diff --git a/.gitattributes b/.gitattributes index a4749a2c2e012db3683a05b8c4d27e9bf68bfe9b..248c3e40dd4501fa394fd484d1aad6af8d2caa66 100644 --- a/.gitattributes +++ b/.gitattributes @@ -14,3 +14,4 @@ build/torch29-cxx11-cu130-x86_64-linux/_deep_gemm_cuda_a68a39f.abi3.so filter=lf build/torch210-cxx11-cu126-aarch64-linux/_deep_gemm_cuda_a68a39f.abi3.so filter=lfs diff=lfs merge=lfs -text build/torch210-cxx11-cu128-aarch64-linux/_deep_gemm_cuda_a68a39f.abi3.so filter=lfs diff=lfs merge=lfs -text build/torch210-cxx11-cu130-aarch64-linux/_deep_gemm_cuda_a68a39f.abi3.so filter=lfs diff=lfs merge=lfs -text +build/torch29-cxx11-cu126-aarch64-linux/_deep_gemm_cuda_a68a39f.abi3.so filter=lfs diff=lfs merge=lfs -text diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/evt/evt_layout_sm80_90.py b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/evt/evt_layout_sm80_90.py new file mode 100644 index 0000000000000000000000000000000000000000..f5a7b7f7a336dce0651f299d26b17df04952be99 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/evt/evt_layout_sm80_90.py @@ -0,0 +1,173 @@ +################################################################################ +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+# +################################################################################ + +""" +Unit test for store nodes in SM90 +""" + +import logging +import unittest + +import cutlass_cppgen +from cutlass_cppgen.backend import * +from cutlass_cppgen.epilogue import * + +from utils.evt_testbed import EVTTestBed, EVTTestCaseBase + +cutlass_cppgen.set_log_level(logging.WARNING) + + +@unittest.skipIf(device_cc() not in [80, 86, 89, 90], "This unittest is only supported on CC [80, 86, 89, 90]") +class TestEVTLayout(EVTTestCaseBase): + + def test_permute_1(self): + """ + Returning a tensor with shape [m, n] + """ + def evt_permute(accum, alpha, C): + F = alpha * accum + F_permute = permute(F, indices=(0, 2, 1)) + D_permute = F_permute + permute(C, indices=(0, 2, 1)) + D = permute(D_permute, indices=(0, 2, 1)) + return D, F + + for m, n, k, l in self.get_problem_sizes(8): + example_inputs = { + "accum": self.fake_tensor(self.element, (l, m, n)), + "alpha": 0.5, + "C": self.fake_tensor(self.element, (l, m, n)), + "F": self.fake_tensor(self.element, (l, m, n)), + "D": self.fake_tensor(self.element, (l, m, n)), + } + + launcher = EVTTestBed(self.element, evt_permute, example_inputs) + input_keys = ["C", "alpha"] + result_keys = ["D", "F"] + launcher.verify((m, n, k), input_keys, result_keys, l) + + @unittest.skipIf(device_cc() != 90, "This unittest is for cc = Sm90 only") + def test_permute_2(self): + """ + Returning a tensor with shape [m, n] + """ + def evt_permute(accum, alpha, C): + F = alpha * accum + F_permute = permute(F, indices=(0, 2, 1)) + D = F_permute + C + return D, F + + for m, n, k, l in self.get_problem_sizes(8): + example_inputs = { + "accum": self.fake_tensor(self.element, (l, m, n)), + "alpha": 0.5, + "C": self.fake_tensor(self.element, (l, n, m)), + "F": self.fake_tensor(self.element, (l, m, n)), + "D": self.fake_tensor(self.element, (l, n, m)), + } + + launcher = EVTTestBed(self.element, evt_permute, example_inputs) + input_keys = ["C", "alpha"] + result_keys = ["D", "F"] + launcher.verify((m, n, k), input_keys, result_keys, l) + + @unittest.skipIf(device_cc() != 90, "This unittest is for cc = Sm90 only") + def test_permute_3(self): + """ + Returning a tensor with shape [m, n] + """ + def evt_permute(accum, alpha, C): + F = alpha * accum + F_permute = permute(F, indices=(1, 0, 2)) + D = F_permute + C + return D, F + + for m, n, k, l in self.get_problem_sizes(8): + example_inputs = { + "accum": self.fake_tensor(self.element, (l, m, n)), + "alpha": 0.5, + "C": self.fake_tensor(self.element, (m, l, n)), + "F": self.fake_tensor(self.element, (l, m, n)), + "D": self.fake_tensor(self.element, (m, l, n)), + } + + launcher = EVTTestBed(self.element, evt_permute, example_inputs) + input_keys = ["C", "alpha"] + result_keys = ["D", "F"] + launcher.verify((m, n, k), input_keys, result_keys, l) + + def test_reshape(self): + """ + Test reshape + """ + def evt_reshape(accum, alpha, TensorE): + F = alpha * accum + E_reshape = reshape(TensorE, new_shape=(512, 1)) + D = F + E_reshape + return D + + example_inputs = { + "accum": self.fake_tensor(self.element, (self.l, self.m, self.n)), + "alpha": 0.5, + "TensorE": self.fake_tensor(self.element, (16, 32)), + "D": self.fake_tensor(self.element, (self.l, self.m, self.n)), + } + + launcher = EVTTestBed(self.element, evt_reshape, example_inputs) + input_keys = ["alpha", "TensorE"] + result_keys = ["D"] + launcher.verify(self.problem_size, input_keys, result_keys, self.l) + + def test_reshape2(self): + """ + Test reshape + """ + def evt_reshape(accum, 
alpha, TensorE): + F = alpha * accum + F_reshape = reshape(F, new_shape=(2, 3, 512, 256)) + D = F_reshape + TensorE + return D + + example_inputs = { + "accum": self.fake_tensor(self.element, (self.l, self.m, self.n)), + "alpha": 0.5, + "TensorE": self.fake_tensor(self.element, (2, 3, 1, self.n)), + "D": self.fake_tensor(self.element, (2, 3, self.m, self.n)), + } + + launcher = EVTTestBed(self.element, evt_reshape, example_inputs) + input_keys = ["alpha", "TensorE"] + result_keys = ["D"] + launcher.verify(self.problem_size, input_keys, result_keys, self.l) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/evt/evt_load_sm80_90.py b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/evt/evt_load_sm80_90.py new file mode 100644 index 0000000000000000000000000000000000000000..57a5c6bb17bb44bf294cc7a6a749c706601034f6 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/evt/evt_load_sm80_90.py @@ -0,0 +1,142 @@ +################################################################################ +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+# +################################################################################ + +""" +Unit test for load nodes in SM90 +""" + +import logging +import unittest + +import cutlass_cppgen +from cutlass_cppgen.backend import * +from cutlass_cppgen.epilogue import * + +from utils.evt_testbed import EVTTestBed, EVTTestCaseBase + +cutlass_cppgen.set_log_level(logging.WARNING) + + +@unittest.skipIf(device_cc() not in [80, 86, 89, 90], "This unittest is only supported on CC [80, 86, 89, 90]") +class TestEVTLoad(EVTTestCaseBase): + + def test_tensor_load(self): + """ + Load extra tensor with shape [m, n] + """ + def evt_tensor_load(accum, C, aux, aux_batch): + D = accum + C + aux + aux_batch + return D + + for m, n, k, l in self.get_problem_sizes(8): + example_inputs = { + "accum": self.fake_tensor(self.element, (l, m, n)), + "C": self.fake_tensor(self.element, (l, m, n)), + "aux": self.fake_tensor(self.element, (m, n)), + "aux_batch": self.fake_tensor(np.float32, (l, m, n)), + "D": self.fake_tensor(self.element, (l, m, n)), + } + + launcher = EVTTestBed(self.element, evt_tensor_load, example_inputs) + input_keys = ["C", "aux", "aux_batch"] + result_keys = ["D"] + launcher.verify((m, n, k), input_keys, result_keys, l) + + def test_row_broadcast(self): + """ + Load extra tensor with shape [1, n] + """ + def evt_row_broadcast(accum, C, bias, bias_batch): + D = accum + C + bias + bias_batch + return D + + for m, n, k, l in self.get_problem_sizes(8): + example_inputs = { + "accum": self.fake_tensor(self.element, (l, m, n)), + "C": self.fake_tensor(self.element, (l, m, n)), + "bias": self.fake_tensor(self.element, (n,)), + "bias_batch": self.fake_tensor(np.float32, (l, 1, n)), + "D": self.fake_tensor(self.element, (l, m, n)), + } + + launcher = EVTTestBed(self.element, evt_row_broadcast, example_inputs) + input_keys = ["C", "bias", "bias_batch"] + result_keys = ["D"] + launcher.verify((m, n, k), input_keys, result_keys, l) + + def test_column_broadcast(self): + """ + Load extra tensor with shape [m, 1] + """ + def evt_column_broadcast(accum, C, bias, bias_batch): + D = accum + C + bias + bias_batch + return D + + for m, n, k, l in self.get_problem_sizes(8): + example_inputs = { + "accum": self.fake_tensor(self.element, (l, m, n)), + "C": self.fake_tensor(self.element, (l, m, n)), + "bias": self.fake_tensor(self.element, (m, 1)), + "bias_batch": self.fake_tensor(np.float32, (l, m, 1)), + "D": self.fake_tensor(self.element, (l, m, n)), + } + + launcher = EVTTestBed(self.element, evt_column_broadcast, example_inputs) + input_keys = ["C", "bias", "bias_batch"] + result_keys = ["D"] + launcher.verify((m, n, k), input_keys, result_keys, l) + + def test_scalar_broadcast(self): + """ + Load extra tensor with shape [1, 1] + """ + def evt_scalar_broadcast(accum, C, alpha, alpha_batch): + D = accum + C + alpha + alpha_batch + return D + + for m, n, k, l in self.get_problem_sizes(8): + example_inputs = { + "accum": self.fake_tensor(self.element, (l, m, n)), + "C": self.fake_tensor(self.element, (l, m, n)), + "alpha": 0.5, + "alpha_batch": self.fake_tensor(np.float32, (l, 1, 1)), + "D": self.fake_tensor(self.element, (l, m, n)), + } + + launcher = EVTTestBed(self.element, evt_scalar_broadcast, example_inputs) + input_keys = ["C", "alpha", "alpha_batch"] + result_keys = ["D"] + launcher.verify((m, n, k), input_keys, result_keys, l) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/evt/evt_mixed_sm80_90.py 
b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/evt/evt_mixed_sm80_90.py new file mode 100644 index 0000000000000000000000000000000000000000..30dc8fe0d5ec413f1da57a8fa0875ed5e7baa887 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/evt/evt_mixed_sm80_90.py @@ -0,0 +1,319 @@ +################################################################################ +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+# +################################################################################ + +""" +Unittest for mixed types of nodes in SM90 +""" + +import logging +import unittest + +import cutlass_cppgen +from cutlass_cppgen.backend import * +from cutlass_cppgen.epilogue import * +from cutlass_cppgen.swizzle import ThreadblockSwizzleStreamK + +from utils.evt_testbed import EVTTestBed, EVTTestCaseBase + +cutlass_cppgen.set_log_level(logging.WARNING) + + +@unittest.skipIf(device_cc() not in [80, 86, 89, 90], "This unittest is only supported on CC [80, 86, 89, 90]") +class TestEVTMixed(EVTTestCaseBase): + + def test_same_variable_used_multiple_times(self): + """ + The same variable z0 is used multiple times + """ + def evt_aux_store(accum): + z0 = relu(accum) + D = z0 + z0 + return z0, D + + for m, n, k, l in self.get_problem_sizes(8): + example_inputs = { + "accum": self.fake_tensor(self.element, (l, m, n)), + "D": self.fake_tensor(self.element, (l, m, n)), + "z0": self.fake_tensor(self.element, (l, m, n)), + } + + launcher = EVTTestBed(self.element, evt_aux_store, example_inputs) + input_keys = ["accum"] + result_keys = ["z0", "D"] + launcher.verify((m, n, k), input_keys, result_keys, l) + + def test_no_lca(self): + """ + The same variable z0 is used multiple times + """ + def evt_no_lca(accum, bias): + E = relu(accum) + F = E + bias + tmp_2 = E + 2 + D = tmp_2 + E + return D + + for m, n, k, l in self.get_problem_sizes(8): + example_inputs = { + "accum": self.fake_tensor(self.element, (l, m, n)), + "D": self.fake_tensor(self.element, (l, m, n)), + "bias": self.fake_tensor(self.element, (m,1), stride=(1,0)), + } + + launcher = EVTTestBed(self.element, evt_no_lca, example_inputs) + input_keys = ["accum", "bias"] + result_keys = ["D"] + launcher.verify((m, n, k), input_keys, result_keys, l) + + def test_mixed_dag(self): + def evt_mixed_dag(accum, alpha, C, beta, aux, cbias, rbias): + F = alpha * accum + (beta * C + aux) + F_row_max = max(F, dim=[0, 1]) + E = relu(F + 1) + cbias + rbias + E_col_max = max(E, dim=[0, 2]) + D = E + F + return D, F, F_row_max, E_col_max + + if device_cc() == 80: + alignments = [2, 4, 8] + else: + # Sm90 EVT currently only supports 128-bit alignment + alignments = [8,] + for align in alignments: + for m, n, k, l in self.get_problem_sizes(align): + example_inputs = { + "accum": self.fake_tensor(self.element, (l, m, n)), + "alpha": 1.0, + "C": self.fake_tensor(self.element, (l, m, n)), + "beta": 1.0, + "aux": self.fake_tensor(self.element, (l, m, n)), + "cbias": self.fake_tensor(self.element, (m, 1)), + "rbias": self.fake_tensor(self.element, (n,)), + "D": self.fake_tensor(self.element, (l, m, n)), + "F": self.fake_tensor(self.element, (l, m, n)), + "F_row_max": self.fake_tensor(DataType.f32, (n,)), + "E_col_max": self.fake_tensor(DataType.f32, (m, 1)) + } + + launcher = EVTTestBed(self.element, evt_mixed_dag, example_inputs) + input_keys = ["alpha", "C", "beta", "aux", "cbias", "rbias"] + result_keys = ["D", "F", "F_row_max", "E_col_max"] + launcher.verify((m, n, k), input_keys, result_keys, l) + + @unittest.skipIf(device_cc() not in [80, 89], "This unittest is for cc 80 and 89 only") + def test_mixed_dag_float(self): + def evt_mixed_dag(accum, alpha, C, beta, aux, cbias, rbias): + F = alpha * accum + (beta * C + aux) + F_row_max = max(F, dim=[0, 1]) + E = relu(F + 1) + cbias + rbias + E_col_max = max(E, dim=[0, 2]) + D = E + F + return D, F, F_row_max, E_col_max + + for align in [3, 2, 4]: + for m, n, k, l in self.get_problem_sizes(align): + example_inputs = { + 
"accum": self.fake_tensor(np.float32, (l, m, n)), + "alpha": 1.0, + "C": self.fake_tensor(np.float32, (l, m, n)), + "beta": 1.0, + "aux": self.fake_tensor(np.float32, (l, m, n)), + "cbias": self.fake_tensor(np.float32, (m, 1)), + "rbias": self.fake_tensor(np.float32, (n,)), + "D": self.fake_tensor(np.float32, (l, m, n)), + "F": self.fake_tensor(np.float32, (l, m, n)), + "F_row_max": self.fake_tensor(np.float32, (n,)), + "E_col_max": self.fake_tensor(np.float32, (m, 1)) + } + launcher = EVTTestBed(DataType.f32, evt_mixed_dag, example_inputs) + input_keys = ["alpha", "C", "beta", "aux", "cbias", "rbias"] + result_keys = ["D", "F", "F_row_max", "E_col_max"] + launcher.verify((m, n, k), input_keys, result_keys, l) + + @unittest.skipIf(device_cc() not in [80, 89], "This unittest is for cc 80 and 89 only") + def test_mixed_dag_stage2(self): + def evt_mixed_dag(accum, alpha, C, beta, aux, cbias, rbias): + F = alpha * accum + (beta * C + aux) + F_row_max = max(F, dim=[0, 1]) + E = relu(F + 1) + cbias + rbias + E_col_max = max(E, dim=[0, 2]) + D = E + F + return D, F, F_row_max, E_col_max + + for m, n, k, l in self.get_problem_sizes(8): + example_inputs = { + "accum": self.fake_tensor(self.element, (l, m, n)), + "alpha": 1.0, + "C": self.fake_tensor(self.element, (l, m, n)), + "beta": 1.0, + "aux": self.fake_tensor(self.element, (l, m, n)), + "cbias": self.fake_tensor(self.element, (m, 1)), + "rbias": self.fake_tensor(self.element, (n,)), + "D": self.fake_tensor(self.element, (l, m, n)), + "F": self.fake_tensor(self.element, (l, m, n)), + "F_row_max": self.fake_tensor(DataType.f32, (n,)), + "E_col_max": self.fake_tensor(DataType.f32, (m, 1)) + } + + launcher = EVTTestBed(self.element, evt_mixed_dag, example_inputs, epilogue_stages=2) + input_keys = ["alpha", "C", "beta", "aux", "cbias", "rbias"] + result_keys = ["D", "F", "F_row_max", "E_col_max"] + launcher.verify((m, n, k), input_keys, result_keys, l) + + @unittest.skipIf(device_cc() not in [80, 89], "This unittest is for cc 80 and 89 only") + def test_mixed_dag_partition_k(self): + def evt_mixed_dag(accum, alpha, C, beta, aux, cbias, rbias): + F = alpha * accum + (beta * C + aux) + F_row_max = max(F, dim=[0, 1]) + E = relu(F + 1) + cbias + rbias + E_col_max = max(E, dim=[0, 2]) + D = E + F + return D, F, F_row_max, E_col_max + + for m, n, k, l in self.get_problem_sizes(8): + example_inputs = { + "accum": self.fake_tensor(self.element, (l, m, n)), + "alpha": 1.0, + "C": self.fake_tensor(self.element, (l, m, n)), + "beta": 1.0, + "aux": self.fake_tensor(self.element, (l, m, n)), + "cbias": self.fake_tensor(self.element, (m, 1)), + "rbias": self.fake_tensor(self.element, (n,)), + "D": self.fake_tensor(self.element, (l, m, n)), + "F": self.fake_tensor(self.element, (l, m, n)), + "F_row_max": self.fake_tensor(DataType.f32, (n,)), + "E_col_max": self.fake_tensor(DataType.f32, (m, 1)) + } + + tile_description = { + "threadblock_shape": [128, 128, 64], + "warp_count": [2, 2, 2] + } + + launcher = EVTTestBed(self.element, evt_mixed_dag, example_inputs, tile_description=tile_description, epilogue_stages=2) + input_keys = ["alpha", "C", "beta", "aux", "cbias", "rbias"] + result_keys = ["D", "F", "F_row_max", "E_col_max"] + launcher.verify((m, n, k), input_keys, result_keys, l) + + @unittest.skipIf(device_cc() not in [80, 89], "This unittest is for cc 80 and 89 only") + def test_mixed_dag_stream_k(self): + def evt_mixed_dag(accum, alpha, C, beta, aux, cbias, rbias): + F = alpha * accum + (beta * C + aux) + F_row_max = max(F, dim=[0, 1]) + E = relu(F + 1) + 
cbias + rbias + E_col_max = max(E, dim=[0, 2]) + D = E + F + return D, F, F_row_max, E_col_max + + # High per-sm occupancy tile_description + tile_description = { + "threadblock_shape": [128, 128, 32], + "warp_count": [2, 2, 1], + "stages": 3 + } + tds = [None, tile_description] + for td in tds: + for m, n, k, l in self.get_problem_sizes(8, k=960, batch_count=[1, 3]): + if l == 1: + example_inputs = { + "accum": self.fake_tensor(self.element, (m, n)), + "alpha": 1.0, + "C": self.fake_tensor(self.element, (m, n)), + "beta": 1.0, + "aux": self.fake_tensor(self.element, (m, n)), + "cbias": self.fake_tensor(self.element, (m, 1)), + "rbias": self.fake_tensor(self.element, (n,)), + "D": self.fake_tensor(self.element, (m, n)), + "F": self.fake_tensor(self.element, (m, n)), + "F_row_max": self.fake_tensor(DataType.f32, (n,)), + "E_col_max": self.fake_tensor(DataType.f32, (m, 1)) + } + else: + example_inputs = { + "accum": self.fake_tensor(self.element, (l, m, n)), + "alpha": 1.0, + "C": self.fake_tensor(self.element, (l, m, n)), + "beta": 1.0, + "aux": self.fake_tensor(self.element, (l, m, n)), + "cbias": self.fake_tensor(self.element, (m, 1)), + "rbias": self.fake_tensor(self.element, (n,)), + "D": self.fake_tensor(self.element, (l, m, n)), + "F": self.fake_tensor(self.element, (l, m, n)), + "F_row_max": self.fake_tensor(DataType.f32, (n,)), + "E_col_max": self.fake_tensor(DataType.f32, (m, 1)) + } + + if td is not None: + launcher = EVTTestBed( + self.element, evt_mixed_dag, example_inputs, + tile_description=td, + swizzling_functor=ThreadblockSwizzleStreamK, backend="torch") + else: + launcher = EVTTestBed( + self.element, evt_mixed_dag, example_inputs, + swizzling_functor=ThreadblockSwizzleStreamK, backend="torch") + + input_keys = ["alpha", "C", "beta", "aux", "cbias", "rbias"] + result_keys = ["D", "F", "F_row_max", "E_col_max"] + launcher.verify((m, n, k), input_keys, result_keys, l) + + def test_mixed_dag_no_batch(self): + def evt_mixed_dag_no_batch(accum, alpha, C, beta, aux, cbias, rbias): + F = alpha * accum + (beta * C + aux) + F_row_max = max(F, dim=[0, 1]) + E = relu(F + 1) + cbias + rbias + E_col_max = max(E, dim=[0, 2]) + D = E + F + return D, F, F_row_max, E_col_max + + for m, n, k, _ in self.get_problem_sizes(8): + example_inputs = { + "accum": self.fake_tensor(self.element, (m, n)), + "alpha": 1.0, + "C": self.fake_tensor(self.element, (m, n)), + "beta": 1.0, + "aux": self.fake_tensor(self.element, (m, n)), + "cbias": self.fake_tensor(self.element, (m, 1)), + "rbias": self.fake_tensor(self.element, (n,)), + "D": self.fake_tensor(self.element, (m, n)), + "F": self.fake_tensor(self.element, (m, n)), + "F_row_max": self.fake_tensor(DataType.f32, (n,)), + "E_col_max": self.fake_tensor(DataType.f32, (m, 1)) + } + + launcher = EVTTestBed(self.element, evt_mixed_dag_no_batch, example_inputs) + input_keys = ["alpha", "C", "beta", "aux", "cbias", "rbias"] + result_keys = ["D", "F", "F_row_max", "E_col_max"] + launcher.verify((m, n, k), input_keys, result_keys, 1) + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/evt/evt_store_sm80_90.py b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/evt/evt_store_sm80_90.py new file mode 100644 index 0000000000000000000000000000000000000000..b47f11e4f3bde3499948ae68b1b5bb79347f0fd1 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/evt/evt_store_sm80_90.py 
@@ -0,0 +1,180 @@ +################################################################################ +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################ + +""" +Unit test for store nodes in SM90 +""" + +import logging +import unittest + +import cutlass_cppgen +from cutlass_cppgen.backend import * +from cutlass_cppgen.epilogue import * + +from utils.evt_testbed import EVTTestBed, EVTTestCaseBase + +cutlass_cppgen.set_log_level(logging.WARNING) + + +@unittest.skipIf(device_cc() not in [80, 86, 89, 90], "This unittest is only supported on CC [80, 86, 89, 90]") +class TestEVTStore(EVTTestCaseBase): + + @unittest.skipIf(device_cc() != 90, "This test is only for CC 90") + def test_invalid_store(self): + """ + Test invalid store + """ + def evt_invalid_store(accum): + D = accum + F = D + 1 # D has users, which is not allowed on SM90 or higher + return D, F + + for m, n, k, l in self.get_problem_sizes(8): + example_inputs = { + "accum": self.fake_tensor(self.element, (l, m, n)), + "D": self.fake_tensor(self.element, (l, m, n)), + "F": self.fake_tensor(self.element, (l, m, n)) + } + with self.assertRaisesRegex( + RuntimeError, + r"On SM90 or higher, D is expected to be a output node with 0 users " + r"to enable smem reuse between C and D, but got 1" + ): + launcher = EVTTestBed(self.element, evt_invalid_store, example_inputs) + + break # Only need to test once + + def test_aux_store(self): + """ + Returning a tensor with shape [m, n] + """ + def evt_aux_store(accum, alpha, C): + F = alpha * accum + D = F + C + return D, F + + for m, n, k, l in self.get_problem_sizes(8): + example_inputs = { + "accum": self.fake_tensor(self.element, (l, m, n)), + "alpha": 0.5, + "C": self.fake_tensor(self.element, (l, m, n)), + "F": self.fake_tensor(self.element, (l, m, n)), + "D": self.fake_tensor(self.element, (l, m, n)), + } + + launcher = EVTTestBed(self.element, evt_aux_store, 
example_inputs) + input_keys = ["C", "alpha"] + result_keys = ["D", "F"] + launcher.verify((m, n, k), input_keys, result_keys, l) + + def test_col_reduce(self): + """ + Reduction [m, n] -> [m, 1] + """ + def evt_row_reduce(accum, alpha, C): + acc_row_max = max(accum, dim=[2,]) + F = alpha * accum + F_row_max = max(F, dim=[0, 2]) + D = F + C + return D, F_row_max, acc_row_max + + for m, n, k, l in self.get_problem_sizes(8): + example_inputs = { + "accum": self.fake_tensor(self.element, (l, m, n)), + "alpha": 2.0, + "C": self.fake_tensor(self.element, (l, m, n)), + "F_row_max": self.fake_tensor(np.float32, (m, 1)), + "acc_row_max": self.fake_tensor(np.float32, (l, m, 1)), + "D": self.fake_tensor(self.element, (l, m, n)), + } + + launcher = EVTTestBed(self.element, evt_row_reduce, example_inputs) + input_keys = ["C", "alpha"] + result_keys = ["D", "F_row_max", "acc_row_max"] + launcher.verify((m, n, k), input_keys, result_keys, l) + + def test_row_reduce(self): + """ + Reduction [m, n] -> [n] + """ + def evt_col_reduce(accum, alpha, C): + acc_col_max = max(accum, dim=[1,]) + F = alpha * accum + F_col_max = max(F, dim=[0, 1]) + D = F + C + return D, F_col_max, acc_col_max + + for m, n, k, l in self.get_problem_sizes(8): + example_inputs = { + "accum": self.fake_tensor(self.element, (l, m, n)), + "alpha": 2.0, + "C": self.fake_tensor(self.element, (l, m, n)), + "F_col_max": self.fake_tensor(np.float32, (n,)), + "acc_col_max": self.fake_tensor(np.float32, (l, 1, n)), + "D": self.fake_tensor(self.element, (l, m, n)), + } + + launcher = EVTTestBed(self.element, evt_col_reduce, example_inputs) + input_keys = ["C", "alpha"] + result_keys = ["D", "F_col_max", "acc_col_max"] + launcher.verify((m, n, k), input_keys, result_keys, l) + + def test_scalar_reduce(self): + """ + Reduction [m, n] -> [1,] + """ + def evt_scalar_reduce(accum, alpha, C): + acc_max = max(accum, dim=[1, 2]) + F = alpha * accum + F_max = max(F, dim=[0, 1, 2]) + D = F + C + return D, F_max, acc_max + + for m, n, k, l in self.get_problem_sizes(8): + example_inputs = { + "accum": self.fake_tensor(self.element, (l, m, n)), + "alpha": 2.0, + "C": self.fake_tensor(self.element, (l, m, n)), + "acc_max": self.fake_tensor(np.float32, (l, 1, 1)), + "F_max": self.fake_tensor(np.float32, (1,)), + "D": self.fake_tensor(self.element, (l, m, n)), + } + + launcher = EVTTestBed(self.element, evt_scalar_reduce, example_inputs) + input_keys = ["C", "alpha"] + result_keys = ["D", "F_max", "acc_max"] + launcher.verify((m, n, k), input_keys, result_keys, l) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/evt/run_all_tests.py b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/evt/run_all_tests.py new file mode 100644 index 0000000000000000000000000000000000000000..5bb84e2e8c85e602b45b9ee18ce324accd3a32cd --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/evt/run_all_tests.py @@ -0,0 +1,44 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. 
Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +import pathlib +import unittest + + +if __name__ == '__main__': + loader = unittest.TestLoader() + script_dir = str(pathlib.Path(__file__).parent.resolve()) + '/' + tests = loader.discover(script_dir, 'evt_*.py') + testRunner = unittest.runner.TextTestRunner() + results = testRunner.run(tests) + if not results.wasSuccessful(): + raise Exception('Test cases failed') diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/evt/utils/evt_testbed.py b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/evt/utils/evt_testbed.py new file mode 100644 index 0000000000000000000000000000000000000000..62d375d856ffaef6be50b39b76121e0eb78a7465 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/evt/utils/evt_testbed.py @@ -0,0 +1,235 @@ +################################################################################ +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################ + +""" +Testbed classes of EVT +""" + +import torch +import unittest + +import cutlass_cppgen +from cutlass_cppgen import Tensor +import cutlass_cppgen.backend.evt +from cutlass_cppgen.shape import GemmCoord +from cutlass_cppgen.utils.datatypes import torch_type +from cutlass_cppgen.utils.profiler import CUDAEventProfiler + + +class EVTReferenceModule: + def __init__(self, layout_A, layout_B, layout_C, epilogue_visitor): + self.layout_A = layout_A + self.layout_B = layout_B + self.layout_C = layout_C + self.epilogue_visitor = epilogue_visitor + + def run(self, A, B, C, problem_size, alpha, beta, batch=1): + if self.layout_A == cutlass_cppgen.LayoutType.RowMajor: + A_row = A.view((batch, problem_size.m, problem_size.k)) + else: + A_col = A.view((batch, problem_size.k, problem_size.m)) + A_row = torch.permute(A_col, (0, 2, 1)) + + if self.layout_B == cutlass_cppgen.LayoutType.RowMajor: + B_row = B.view((batch, problem_size.k, problem_size.n)) + else: + B_col = B.view((batch, problem_size.n, problem_size.k)) + B_row = torch.permute(B_col, (0, 2, 1)) + + if self.layout_C == cutlass_cppgen.LayoutType.RowMajor: + C_row = C.view((batch, problem_size.m, problem_size.n)) + else: + C_col = C.view((batch, problem_size.n, problem_size.m)) + C_row = torch.permute(C_col, (0, 2, 1)) + + out_row = torch.matmul(A_row, B_row) * alpha + C_row * beta + + if self.layout_C == cutlass_cppgen.LayoutType.ColumnMajor: + out = torch.permute(out_row, (0, 2, 1)) + else: + out = out_row + + return torch.flatten(out) + + def __call__(self, A, B, C, problem_size, batch=1, epilogue_args=None): + # Running the mainloop + accum = self.run( + A, B, C, problem_size, 1.0, 0.0, batch=batch + ).reshape(batch, problem_size.m, problem_size.n) + + # Running the epilogue + epilogue_args["accum"] = accum + references = self.epilogue_visitor(**epilogue_args) + + # Return the results + if not isinstance(references, tuple): + references = (references,) + return references + + +class EVTTestBed: + """ + Epilogue Visitor Testbed + """ + def __init__(self, element, evt_fn, example_inputs, profile=False, **kwargs) -> None: + self.element = element + layout = cutlass_cppgen.LayoutType.RowMajor + self.example_inputs = example_inputs + + # Create the Gemm plan + self.plan = cutlass_cppgen.op.Gemm(element=element, layout=layout, element_accumulator=torch.float32) + + if "tile_description" in kwargs: + self.plan.tile_description = kwargs["tile_description"] + + if "swizzling_functor" in kwargs: + self.plan.swizzling_functor = kwargs["swizzling_functor"] + + # Compile the epilogue visitor + epilogue_visitor = cutlass_cppgen.epilogue.trace(evt_fn, example_inputs) + if "epilogue_stages" in kwargs: + epilogue_visitor.epilogue_stages = kwargs["epilogue_stages"] + self.plan.epilogue_visitor = epilogue_visitor + + # Reference model + self.reference_fn = EVTReferenceModule(layout, layout, layout, epilogue_visitor) + + self.profile = profile + + 
def get_torch_tensor(self, shape, dtype=None, fill=None): + if dtype is None: + dtype = self.element + + dtype = torch_type(dtype) + if fill is None: + return torch.ceil( + torch.empty(size=shape, dtype=dtype, device="cuda").uniform_(-4.5, 3.5) + ) + else: + return torch.full(shape, fill, dtype=dtype, device="cuda") + + def verify(self, problem_size, input_keys, result_keys, batch_count=1): + """ + Verify the results + """ + problem_size = GemmCoord(*problem_size) + + # Initiate the GEMM arguments + tensor_A = self.get_torch_tensor((batch_count, problem_size.m, problem_size.k)) + tensor_B = self.get_torch_tensor((batch_count, problem_size.k, problem_size.n)) + + # Initialize the epilogue args + epilogue_args = {} + for key in self.example_inputs.keys(): + if key in input_keys: + tensor = self.example_inputs[key] + if isinstance(tensor, Tensor): + epilogue_args[key] = self.get_torch_tensor(tensor.shape, tensor.element) + else: + epilogue_args[key] = tensor + elif key in result_keys: + tensor = self.example_inputs[key] + if isinstance(tensor, Tensor): + if "max" in key: + fill = -1000 + else: + fill = 0 + epilogue_args[key] = self.get_torch_tensor(tensor.shape, tensor.element, fill=fill) + else: + epilogue_args[key] = tensor + + tensor_D = epilogue_args["D"] + if "C" in epilogue_args: + tensor_C = epilogue_args["C"] + else: + tensor_C = tensor_D + # Run the device kernel + self.plan.run(tensor_A, tensor_B, tensor_C, tensor_D, visitor_args=epilogue_args) + + # Run the host reference + evt_args_inputs = {} + for key in input_keys: + evt_args_inputs[key] = epilogue_args[key] + + reference_results = self.reference_fn( + tensor_A, tensor_B, tensor_C, problem_size, batch_count, evt_args_inputs) + + # Compare the results + for result, ref in zip(result_keys, reference_results): + assert torch.equal( + epilogue_args[result].flatten(), + ref.masked_fill(torch.isnan(ref), float('inf')).flatten()) + + # Run profile + if self.profile: + profiler = CUDAEventProfiler( + self.plan, 100, 100, tensor_A, tensor_B, tensor_C, tensor_D, + visitor_args = epilogue_args + ) + print(f"Cutlass Python Duration: {profiler()}") + + +class EVTTestCaseBase(unittest.TestCase): + """ + Base class for EVT Unittest + """ + def __init__(self, methodName: str = "runTest", lmnk=(6, 512, 256, 128)) -> None: + super().__init__(methodName) + + self.element = cutlass_cppgen.DataType.f16 + self.l, self.m, self.n, self.k = lmnk + + self.problem_size = (self.m, self.n, self.k) + + torch.random.manual_seed(42) + + def fake_tensor(self, element, shape, stride=None): + if stride is None: + return Tensor(element=element, shape=shape, layout_tag=cutlass_cppgen.LayoutType.RowMajor) + else: + return Tensor(element=element, shape=shape, stride=stride) + + def get_problem_sizes(self, alignment, k=None, batch_count=[3,]): + k = k if k else self.k + problem_size_m = [alignment, 512 - 3 * alignment] + problem_size_n = [alignment, 512 - alignment] + if alignment % 8 == 0: + problem_size_m.append(768) + problem_size_n.append(768) + problem_size_l = batch_count + problem_sizes = [] + for m in problem_size_m: + for n in problem_size_n: + for l in problem_size_l: + problem_sizes.append((m, n, k, l)) + + return problem_sizes diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_batched.py b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_batched.py new file mode 100644 index 
0000000000000000000000000000000000000000..155426ab902d1f99eafc7b03c388fc79b4520317 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_batched.py @@ -0,0 +1,134 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +High-level tests for running batched GEMMs +""" + +from functools import partial +import logging +from math import prod +import unittest + +import cutlass_cppgen +from cutlass_cppgen.backend.utils.device import device_cc +import torch + +from utils import LayoutCombination + +cutlass_cppgen.set_log_level(logging.WARNING) + +torch.manual_seed(2023) + + +def pytorch_reference(A, B, C, alpha, beta): + # Get the batch count. Assume that any of A, B, and C + # with a batch dimension ahve matching batch count. Thus, + # we break out of the loop once we have found the first + # tensor containing a batch dimension. 
+ batch_count = (1,) + for tensor in [A, B, C]: + if len(tensor.shape) > 2: + batch_count = tensor.shape[:-2] + break + + int_batch_count = prod(batch_count) + + def add_batch(tensor): + if len(tensor.shape) == 2: + return tensor.unsqueeze(0).repeat(int_batch_count, 1, 1) + else: + return tensor.reshape(-1, tensor.size(-2), tensor.size(-1)) + + # Reshape tensors to have batch dimension + A = add_batch(A) + B = add_batch(B) + C = add_batch(C) + + ret = (torch.bmm(A, B) * alpha) + (C * beta) + reshape_vals = batch_count + C.shape[-2:] + return ret.reshape(*reshape_vals) + + +def initialize(rows, cols, batch): + tensor = torch.randint(-3, 3, size=(rows*cols*prod(batch),), device='cuda').half() + if len(batch) > 0 and prod(batch) > 1: + reshape_vals = batch + (rows, cols) + return tensor.reshape(*reshape_vals) + else: + return tensor.reshape(rows, cols) + + +class GemmF16Batched(unittest.TestCase): + def run_batched(self, batch_count: tuple, batch_A: bool, batch_B: bool, batch_C: bool): + M = 512 + N = 256 + K = 128 + alpha = 1. + beta = 2. + + A = initialize(M, K, batch_count if batch_A else (1,)) + B = initialize(K, N, batch_count if batch_B else (1,)) + C = initialize(M, N, batch_count if batch_C else (1,)) + D = initialize(M, N, batch_count) + + plan = cutlass_cppgen.op.Gemm(A=A, B=B, C=C, D=D, element_accumulator=cutlass_cppgen.DataType.f32) + plan.run(A, B, C, D, alpha, beta) + reference = pytorch_reference(A, B, C, alpha, beta) + assert reference.equal(D) + + def test_batched_ABC(self): + self.run_batched((3,), True, True, True) + self.run_batched((2, 3), True, True, True) + + def test_batched_AB(self): + self.run_batched((3,), True, True, False) + self.run_batched((2, 3), True, True, False) + + def test_batched_AC(self): + self.run_batched((3,), True, False, True) + self.run_batched((2, 3), True, False, True) + + def test_batched_BC(self): + self.run_batched((3,), False, True, True) + self.run_batched((2, 3), False, True, True) + + def test_batched_A(self): + self.run_batched((3,), True, False, False) + self.run_batched((2, 3), True, False, False) + + def test_batched_B(self): + self.run_batched((3,), False, True, False) + self.run_batched((2, 3), False, True, False) + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f16_sm80.py b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f16_sm80.py new file mode 100644 index 0000000000000000000000000000000000000000..dbd26951ec5d8a1eb6cbe38491c64fde2873b9c3 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f16_sm80.py @@ -0,0 +1,128 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. 
Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Low-level functionality tests for GEMM with F16 operands on SM80 +""" + +from functools import partial +import logging +import unittest + +import cutlass_cppgen +from cutlass_cppgen.backend.utils.device import device_cc + +from utils import LayoutCombination, add_test_gemm + + +cutlass_cppgen.set_log_level(logging.WARNING) +cc = 80 +dtype = cutlass_cppgen.DataType.f16 + +@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.') +@unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') +class GemmF16Sm80(unittest.TestCase): + """ + Wrapper class to which tests will be added dynamically in __main__ + """ + pass + + +@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.') +@unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') +class GemmF16Sm80StreamK(unittest.TestCase): + """ + Wrapper class to which tests will be added dynamically in __main__ + """ + pass + +add_test_specialized = partial(add_test_gemm, element=dtype, cc=cc, cluster_shape=[1, 1, 1]) + +# Tests using TensorOp +add_test_tensorop = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.TensorOp) + +add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.NNN, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) +add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.NNT, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) +add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.NTN, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) +add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.NTT, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, 
threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) +add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNN, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) +add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) +add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TTN, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) +add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TTT, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) +add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[ 64, 128, 32], warp_count=[1, 2, 1], stages=3) +add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 64, 32], warp_count=[2, 1, 1], stages=3) +add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[ 64, 64, 64], warp_count=[1, 1, 1], stages=3) +add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[4, 4, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) +add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[4, 4, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f16, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) +add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f16, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) +add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[ 64, 64, 64], warp_count=[1, 1, 1], stages=5) +add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[2, 2, 2], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f16, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) + +# Tests using SIMT 
+add_test_simt = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.Simt) + +add_test_simt(cls=GemmF16Sm80, layouts=LayoutCombination.NNN, alignments=[1, 1, 1], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2) +add_test_simt(cls=GemmF16Sm80, layouts=LayoutCombination.TNN, alignments=[1, 1, 1], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[ 64, 128, 8], warp_count=[1, 2, 1], stages=2) +add_test_simt(cls=GemmF16Sm80, layouts=LayoutCombination.NTN, alignments=[1, 1, 1], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 64, 8], warp_count=[2, 1, 1], stages=2) +add_test_simt(cls=GemmF16Sm80, layouts=LayoutCombination.TTN, alignments=[1, 1, 1], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[ 64, 64, 8], warp_count=[1, 1, 1], stages=2) +add_test_simt(cls=GemmF16Sm80, layouts=LayoutCombination.NNT, alignments=[1, 1, 1], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f16, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2) + +# Stream K tests +add_test_streamk = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.TensorOp, swizzle=cutlass_cppgen.swizzle.ThreadblockSwizzleStreamK) +add_test_streamk(cls=GemmF16Sm80StreamK, layouts=LayoutCombination.NNN, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) +add_test_streamk(cls=GemmF16Sm80StreamK, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[ 64, 64, 64], warp_count=[1, 1, 1], stages=5) + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f16_sm90.py b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f16_sm90.py new file mode 100644 index 0000000000000000000000000000000000000000..61aa295b966daf5943e7092572c98ee20143e2b5 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f16_sm90.py @@ -0,0 +1,146 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. 
+# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Low-level functionality tests for GEMM with F16 operands on SM90 +""" + +from functools import partial +import logging +import unittest + +import cutlass_cppgen +from cutlass_cppgen.backend.utils.device import device_cc + +from utils import LayoutCombination, add_test_gemm + + +cutlass_cppgen.set_log_level(logging.WARNING) +cc = 90 +dtype = cutlass_cppgen.DataType.f16 + +@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM90 tests.') +@unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') +class GemmF16Sm90(unittest.TestCase): + """ + Wrapper class to which tests will be added dynamically in __main__ + """ + pass + + +add_test_specialized = partial(add_test_gemm, cls=GemmF16Sm90, element=dtype, + warp_count=None, compilation_modes=['nvcc']) + +add_test_tensorop = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.TensorOp) + +# Tests with 1x1x1 clusters +add_test_unit_cluster = partial(add_test_tensorop, cluster_shape=[1, 1, 1]) +add_test_unit_cluster(layouts=LayoutCombination.NNN, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], stages=3) +add_test_unit_cluster(layouts=LayoutCombination.NNT, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], stages=None) +add_test_unit_cluster(layouts=LayoutCombination.NTN, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], stages=None) +add_test_unit_cluster(layouts=LayoutCombination.NTT, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], stages=None) +add_test_unit_cluster(layouts=LayoutCombination.TNN, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], stages=None) +add_test_unit_cluster(layouts=LayoutCombination.TNT, alignments=[4, 4, 8], element_output=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], stages=None) +add_test_unit_cluster(layouts=LayoutCombination.TNT, 
alignments=[4, 4, 8], element_output=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f16, threadblock_shape=[128, 128, 32], stages=None) +add_test_unit_cluster(layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f16, threadblock_shape=[128, 128, 32], stages=None) +add_test_unit_cluster(layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[ 64, 64, 64], stages=5) +add_test_unit_cluster(layouts=LayoutCombination.TNT, alignments=[2, 2, 2], element_output=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f16, threadblock_shape=[128, 128, 32], stages=None) + +# Tests with different cluster shapes +add_test_cluster_shape = partial(add_test_tensorop, threadblock_shape=[64, 128, 64], stages=None) +add_test_cluster_shape(layouts=LayoutCombination.TTN, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f16, cluster_shape=[2, 2, 1]) +add_test_cluster_shape(layouts=LayoutCombination.TNN, alignments=[8, 8, 4], element_output=cutlass_cppgen.DataType.f32, + element_accumulator=cutlass_cppgen.DataType.f32, cluster_shape=[2, 2, 1]) +add_test_cluster_shape(layouts=LayoutCombination.NTN, alignments=[8, 8, 4], element_output=cutlass_cppgen.DataType.f32, + element_accumulator=cutlass_cppgen.DataType.f32, cluster_shape=[2, 2, 1]) +add_test_cluster_shape(layouts=LayoutCombination.NNN, alignments=[8, 8, 4], element_output=cutlass_cppgen.DataType.f32, + element_accumulator=cutlass_cppgen.DataType.f32, cluster_shape=[2, 2, 1]) +add_test_cluster_shape(layouts=LayoutCombination.TTN, alignments=[8, 8, 4], element_output=cutlass_cppgen.DataType.f32, + element_accumulator=cutlass_cppgen.DataType.f32, cluster_shape=[1, 4, 1]) +add_test_cluster_shape(layouts=LayoutCombination.TTN, alignments=[8, 8, 4], element_output=cutlass_cppgen.DataType.f32, + element_accumulator=cutlass_cppgen.DataType.f32, cluster_shape=[2, 4, 1]) +add_test_cluster_shape(layouts=LayoutCombination.TTN, alignments=[8, 8, 4], element_output=cutlass_cppgen.DataType.f32, + element_accumulator=cutlass_cppgen.DataType.f32, cluster_shape=[4, 1, 1]) +add_test_cluster_shape(layouts=LayoutCombination.TTN, alignments=[8, 8, 4], element_output=cutlass_cppgen.DataType.f32, + element_accumulator=cutlass_cppgen.DataType.f32, cluster_shape=[4, 2, 1]) + +# Tests for different schedule modes +add_test_schedule = partial(add_test_specialized, layouts=LayoutCombination.TTN, alignments=[8, 8, 4], + element_output=cutlass_cppgen.DataType.f32, element_accumulator=cutlass_cppgen.DataType.f32, + opclass=cutlass_cppgen.OpcodeClass.TensorOp, threadblock_shape=[128, 128, 64], stages=None) +add_test_schedule( + cluster_shape=[1, 1, 1], + kernel_schedule=cutlass_cppgen.KernelScheduleType.TmaWarpSpecializedPingpong, + epilogue_schedule=cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecialized +) +add_test_schedule( + cluster_shape=[1, 1, 1], + kernel_schedule=cutlass_cppgen.KernelScheduleType.TmaWarpSpecializedCooperative, + epilogue_schedule=cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecializedCooperative +) +add_test_schedule( + cluster_shape=[2, 1, 1], + kernel_schedule=cutlass_cppgen.KernelScheduleType.TmaWarpSpecializedPingpong, + epilogue_schedule=cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecialized +) +add_test_schedule( + cluster_shape=[2, 1, 1], + 
kernel_schedule=cutlass_cppgen.KernelScheduleType.TmaWarpSpecializedCooperative, + epilogue_schedule=cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecializedCooperative +) + +# Tests using SIMT +add_test_simt = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.Simt, alignments=[1, 1, 1], cluster_shape=[1, 1, 1], stages=2) +add_test_simt(layouts=LayoutCombination.NNN, element_output=cutlass_cppgen.DataType.f16, element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 8]) +add_test_simt(layouts=LayoutCombination.TNN, element_output=cutlass_cppgen.DataType.f16, element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[ 64, 128, 8]) +add_test_simt(layouts=LayoutCombination.NTN, element_output=cutlass_cppgen.DataType.f16, element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 64, 8]) +add_test_simt(layouts=LayoutCombination.TTN, element_output=cutlass_cppgen.DataType.f16, element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[ 64, 64, 8]) +add_test_simt(layouts=LayoutCombination.NNT, element_output=cutlass_cppgen.DataType.f16, element_accumulator=cutlass_cppgen.DataType.f16, threadblock_shape=[128, 128, 8]) + +# Tests with void-C kernels +add_test_cluster_shape(layouts=LayoutCombination.NNT, alignments=[8, 8, 8], element_output=cutlass_cppgen.DataType.f16, + element_accumulator=cutlass_cppgen.DataType.f32, threadblock_shape=[128, 128, 32], stages=None, + cluster_shape=[2, 1, 1], element_C=cutlass_cppgen.DataType.void) + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f32_sm80.py b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f32_sm80.py new file mode 100644 index 0000000000000000000000000000000000000000..bf662b9208ab2a5343d0fd11106835b7d9a5b2e9 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f32_sm80.py @@ -0,0 +1,104 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Low-level functionality tests for GEMM with F32 operands on SM80 +""" + +from functools import partial +import logging +import unittest + +import cutlass_cppgen +from cutlass_cppgen.backend.utils.device import device_cc + +from utils import LayoutCombination, add_test_gemm + + +cutlass_cppgen.set_log_level(logging.WARNING) +cc = 80 +dtype = cutlass_cppgen.DataType.f32 + + +@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.') +@unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') +class GemmF32Sm80(unittest.TestCase): + """ + Wrapper class to which tests will be added dynamically in __main__ + """ + pass + + +@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.') +@unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') +class GemmF32Sm80StreamK(unittest.TestCase): + """ + Wrapper class to which tests will be added dynamically in __main__ + """ + pass + + +add_test_specialized = partial(add_test_gemm, element=dtype, cc=cc, cluster_shape=[1, 1, 1]) + +# Tests using TensorOp +add_test_tensorop = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.TensorOp) + +add_test_tensorop(cls=GemmF32Sm80, layouts=LayoutCombination.NNN, alignments=[4, 4, 4], element_output=dtype, element_C=dtype, + element_accumulator=dtype, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) +add_test_tensorop(cls=GemmF32Sm80, layouts=LayoutCombination.NNT, alignments=[4, 4, 4], element_output=dtype, element_C=dtype, + element_accumulator=dtype, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) +add_test_tensorop(cls=GemmF32Sm80, layouts=LayoutCombination.NTN, alignments=[4, 4, 4], element_output=dtype, element_C=dtype, + element_accumulator=dtype, threadblock_shape=[ 64, 128, 32], warp_count=[1, 2, 1], stages=3) +add_test_tensorop(cls=GemmF32Sm80, layouts=LayoutCombination.NTN, alignments=[4, 4, 4], element_output=dtype, element_C=dtype, + element_accumulator=dtype, threadblock_shape=[ 64, 64, 32], warp_count=[1, 1, 1], stages=4) +# Tests using SIMT +add_test_simt = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.Simt) + +add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.NNN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype, + element_accumulator=dtype, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2) +add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.TNN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype, + element_accumulator=dtype, threadblock_shape=[ 64, 128, 8], warp_count=[1, 2, 1], stages=2) +add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.NTN, alignments=[1, 1, 1], element_output=dtype, 
element_C=dtype, + element_accumulator=dtype, threadblock_shape=[128, 64, 8], warp_count=[2, 1, 1], stages=2) +add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.TTN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype, + element_accumulator=dtype, threadblock_shape=[ 64, 64, 8], warp_count=[1, 1, 1], stages=2) +add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.NNT, alignments=[1, 1, 1], element_output=dtype, element_C=dtype, + element_accumulator=dtype, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2) + +# Stream K tests +add_test_streamk = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.TensorOp, swizzle=cutlass_cppgen.swizzle.ThreadblockSwizzleStreamK) +add_test_streamk(cls=GemmF32Sm80StreamK, layouts=LayoutCombination.TTN, alignments=[4, 4, 4], element_output=dtype, element_C=dtype, + element_accumulator=dtype, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f64_sm80.py b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f64_sm80.py new file mode 100644 index 0000000000000000000000000000000000000000..3075ddf74bf2a119759ca1a3e47c0815f4b0923c --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f64_sm80.py @@ -0,0 +1,103 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+# +################################################################################################# + +""" +Low-level functionality tests for GEMM with F64 operands on SM80 +""" + +from functools import partial +import logging +import unittest + +import cutlass_cppgen +from cutlass_cppgen.backend.utils.device import device_cc + +from utils import LayoutCombination, add_test_gemm + + +cutlass_cppgen.set_log_level(logging.WARNING) +cc = 80 +dtype = cutlass_cppgen.DataType.f64 + + +@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.') +@unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') +class GemmF64Sm80(unittest.TestCase): + """ + Wrapper class to which tests will be added dynamically in __main__ + """ + pass + + +@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.') +@unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') +class GemmF64Sm80StreamK(unittest.TestCase): + """ + Wrapper class to which tests will be added dynamically in __main__ + """ + pass + + +add_test_specialized = partial(add_test_gemm, element=dtype, cc=cc, cluster_shape=[1, 1, 1]) + +# Tests using TensorOp +add_test_tensorop = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.TensorOp) + +add_test_tensorop(cls=GemmF64Sm80, layouts=LayoutCombination.NNN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype, + element_accumulator=dtype, threadblock_shape=[128, 128, 16], warp_count=[4, 2, 1], stages=3) +add_test_tensorop(cls=GemmF64Sm80, layouts=LayoutCombination.NTN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype, + element_accumulator=dtype, threadblock_shape=[ 64, 64, 16], warp_count=[2, 2, 1], stages=4) +add_test_tensorop(cls=GemmF64Sm80, layouts=LayoutCombination.TTN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype, + element_accumulator=dtype, threadblock_shape=[ 32, 32, 16], warp_count=[2, 1, 1], stages=5) + +# Tests using SIMT +add_test_simt = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.Simt) + +add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.NNN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype, + element_accumulator=dtype, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2) +add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.TNN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype, + element_accumulator=dtype, threadblock_shape=[ 64, 128, 8], warp_count=[1, 2, 1], stages=2) +add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.NTN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype, + element_accumulator=dtype, threadblock_shape=[128, 64, 8], warp_count=[2, 1, 1], stages=2) +add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.TTN, alignments=[1, 1, 1], element_output=dtype, element_C=dtype, + element_accumulator=dtype, threadblock_shape=[ 64, 64, 8], warp_count=[1, 1, 1], stages=2) +add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.NNT, alignments=[1, 1, 1], element_output=dtype, element_C=dtype, + element_accumulator=dtype, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2) + +# Stream K tests +add_test_streamk = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.TensorOp, swizzle=cutlass_cppgen.swizzle.ThreadblockSwizzleStreamK) +add_test_streamk(cls=GemmF64Sm80StreamK, 
layouts=LayoutCombination.NTT, alignments=[1, 1, 1], element_output=dtype, element_C=dtype, + element_accumulator=dtype, threadblock_shape=[128, 128, 16], warp_count=[4, 2, 1], stages=3) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f64_sm90.py b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f64_sm90.py new file mode 100644 index 0000000000000000000000000000000000000000..9bf36fc77436fef22882e98c752b7a599cf7fb95 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f64_sm90.py @@ -0,0 +1,71 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+# +################################################################################################# + +""" +Low-level functionality tests for GEMM with F64 operands on SM90 +""" + +from functools import partial +import logging +import unittest + +import cutlass_cppgen +from cutlass_cppgen.backend.utils.device import device_cc + +from utils import LayoutCombination, add_test_gemm + + +cutlass_cppgen.set_log_level(logging.WARNING) +cc = 90 +dtype = cutlass_cppgen.DataType.f64 + + +@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM90 tests.') +@unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') +class GemmF64Sm90(unittest.TestCase): + """ + Wrapper class to which tests will be added dynamically in __main__ + """ + pass + + +add_test_specialized = partial(add_test_gemm, cls=GemmF64Sm90, alignments=[1, 1, 1], cluster_shape=[1, 1, 1], + element=dtype, element_output=dtype, element_accumulator=dtype, compilation_modes=['nvcc']) + +add_test_specialized(opclass=cutlass_cppgen.OpcodeClass.TensorOp, layouts=LayoutCombination.NNT, threadblock_shape=[128, 128, 32], stages=3) +add_test_specialized(opclass=cutlass_cppgen.OpcodeClass.TensorOp, layouts=LayoutCombination.TNN, threadblock_shape=[128, 128, 32], stages=3) +add_test_specialized( opclass=cutlass_cppgen.OpcodeClass.Simt, layouts=LayoutCombination.NNN, threadblock_shape=[128, 128, 8], stages=2) +add_test_specialized( opclass=cutlass_cppgen.OpcodeClass.Simt, layouts=LayoutCombination.TTT, threadblock_shape=[ 64, 128, 8], stages=2) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f8_sm90.py b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f8_sm90.py new file mode 100644 index 0000000000000000000000000000000000000000..fef6d457a6528a61613d1295877a2b6b8f80fef5 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_f8_sm90.py @@ -0,0 +1,112 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+#
+#################################################################################################
+
+"""
+Low-level functionality tests for GEMM with F8 operands on SM90
+"""
+
+from functools import partial
+import logging
+import unittest
+
+import cutlass_cppgen
+from cutlass_cppgen.backend.utils.device import device_cc
+
+from utils import LayoutCombination, add_test_gemm
+
+
+cutlass_cppgen.set_log_level(logging.WARNING)
+cc = 90
+dtype = cutlass_cppgen.DataType.e4m3
+
+
+@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM90 tests.')
+@unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}')
+class GemmF8E4M3Sm90(unittest.TestCase):
+    """
+    Wrapper class to which tests will be added dynamically in __main__
+    """
+    pass
+
+
+add_test_specialized = partial(add_test_gemm, cls=GemmF8E4M3Sm90, element=dtype, compilation_modes=['nvcc'])
+
+add_test_tensorop = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.TensorOp)
+
+# Test with 1x1x1 clusters
+add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass_cppgen.DataType.e4m3,
+                  element_accumulator=cutlass_cppgen.DataType.f32, cluster_shape=[1, 1, 1], threadblock_shape=[128, 128, 128], stages=None)
+
+# Tests with different cluster shapes
+add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass_cppgen.DataType.e4m3,
+                  element_accumulator=cutlass_cppgen.DataType.f32, cluster_shape=[2, 2, 1], threadblock_shape=[128, 128, 128], stages=None)
+add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass_cppgen.DataType.e4m3,
+                  element_accumulator=cutlass_cppgen.DataType.f32, cluster_shape=[1, 4, 1], threadblock_shape=[128, 128, 128], stages=None)
+
+# Tests with warp-specialized ping-pong schedule
+add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass_cppgen.DataType.e4m3,
+                  element_accumulator=cutlass_cppgen.DataType.f32, cluster_shape=[2, 1, 1], threadblock_shape=[128, 128, 128], stages=None,
+                  kernel_schedule=cutlass_cppgen.KernelScheduleType.TmaWarpSpecializedPingpong,
+                  epilogue_schedule=cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecialized)
+
+# Tests for SIMT
+add_test_simt = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.Simt)
+add_test_simt(layouts=LayoutCombination.TNN, alignments=[1, 1, 1], element_output=cutlass_cppgen.DataType.e4m3,
+              element_accumulator=cutlass_cppgen.DataType.f32, cluster_shape=[1, 1, 1], threadblock_shape=[64, 32, 8], stages=2)
+
+
+#
+# Add a test for E5M2
+#
+dtype = cutlass_cppgen.DataType.e5m2
+
+
+@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM90 tests.')
+@unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}')
+class
GemmF8E5M2Sm90(unittest.TestCase): + """ + Wrapper class to which tests will be added dynamically in __main__ + """ + pass + + +add_test_specialized = partial(add_test_gemm, cls=GemmF8E5M2Sm90, element=dtype, compilation_modes=['nvcc']) + +add_test_tensorop = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.TensorOp) + +# Tests with 1x1x1 clusters +add_test_tensorop(layouts=LayoutCombination.TNN, alignments=[16, 16, 16], element_output=dtype, + element_accumulator=cutlass_cppgen.DataType.f32, cluster_shape=[1, 1, 1], threadblock_shape=[128, 128, 128], stages=3) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_mixed_sm80.py b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_mixed_sm80.py new file mode 100644 index 0000000000000000000000000000000000000000..0a002a5fbad80de5f7b29e42db0806469244914c --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_mixed_sm80.py @@ -0,0 +1,75 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+# +################################################################################################# + +""" +Low-level functionality tests for GEMM with mixed operands on SM80 +""" + +from functools import partial +import logging +import unittest + +import cutlass_cppgen +from cutlass_cppgen.backend.utils.device import device_cc + +from utils import LayoutCombination, add_test_gemm + + +cutlass_cppgen.set_log_level(logging.WARNING) +cc = 80 +dtype =cutlass_cppgen.DataType.f16 + + +@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.') +@unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') +class GemmMixedSm80(unittest.TestCase): + """ + Wrapper class to which tests will be added dynamically in __main__ + """ + pass + + +add_test_mixed = partial(add_test_gemm, cls=GemmMixedSm80, element=dtype, cc=cc, cluster_shape=[1, 1, 1], + opclass=cutlass_cppgen.OpcodeClass.TensorOp, threadblock_shape=[128, 128, 64], + warp_count=[2, 2, 1], stages=3, element_accumulator=cutlass_cppgen.DataType.f32) + +# Test with upcast on A +add_test_mixed(element_A=cutlass_cppgen.DataType.s8, alignments=[16, 8, 8], layouts=LayoutCombination.TNT) +add_test_mixed(element_A=cutlass_cppgen.DataType.s8, alignments=[16, 8, 8], layouts=LayoutCombination.TNN) + +# Test with upcast on B +add_test_mixed(element_B=cutlass_cppgen.DataType.s8, alignments=[8, 16, 8], layouts=LayoutCombination.TNT) +add_test_mixed(element_B=cutlass_cppgen.DataType.s8, alignments=[8, 16, 8], layouts=LayoutCombination.TNN) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_s8_sm80.py b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_s8_sm80.py new file mode 100644 index 0000000000000000000000000000000000000000..e226e23684147cb0a9cd5c1270468eb96c67ba15 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_s8_sm80.py @@ -0,0 +1,103 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Low-level functionality tests for GEMM with S8 operands on SM80 +""" + +from functools import partial +import logging +import unittest + +import cutlass_cppgen +from cutlass_cppgen.backend.utils.device import device_cc + +from utils import LayoutCombination, add_test_gemm + + +cutlass_cppgen.set_log_level(logging.WARNING) +cc = 80 +dtype = cutlass_cppgen.DataType.s8 + + +@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.') +@unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') +class GemmS8Sm80(unittest.TestCase): + """ + Wrapper class to which tests will be added dynamically in __main__ + """ + pass + + +@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.') +@unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') +class GemmS8Sm80StreamK(unittest.TestCase): + """ + Wrapper class to which tests will be added dynamically in __main__ + """ + pass + + +add_test_specialized = partial(add_test_gemm, element=dtype, cc=cc, cluster_shape=[1, 1, 1]) + +# Tests using TensorOp +add_test_tensorop = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.TensorOp) + +add_test_tensorop(cls=GemmS8Sm80, layouts=LayoutCombination.TNN, alignments=[16, 16, 16], element_output=cutlass_cppgen.DataType.s8, element_C=cutlass_cppgen.DataType.s8, + element_accumulator=cutlass_cppgen.DataType.s32, threadblock_shape=[256, 128, 64], warp_count=[4, 2, 1], stages=3) +add_test_tensorop(cls=GemmS8Sm80, layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass_cppgen.DataType.s8, element_C=cutlass_cppgen.DataType.s8, + element_accumulator=cutlass_cppgen.DataType.s32, threadblock_shape=[128, 256, 64], warp_count=[2, 4, 1], stages=3) +add_test_tensorop(cls=GemmS8Sm80, layouts=LayoutCombination.TNN, alignments=[16, 16, 4], element_output=cutlass_cppgen.DataType.s32, element_C=cutlass_cppgen.DataType.s32, + element_accumulator=cutlass_cppgen.DataType.s32, threadblock_shape=[ 64, 64, 64], warp_count=[1, 1, 1], stages=4) + +# Tests using SIMT +add_test_simt = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.Simt) + +add_test_simt(cls=GemmS8Sm80, layouts=LayoutCombination.NNN, alignments=[1, 1, 1], element_output=cutlass_cppgen.DataType.s8, element_C=cutlass_cppgen.DataType.s8, + element_accumulator=cutlass_cppgen.DataType.s32, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2) +add_test_simt(cls=GemmS8Sm80, layouts=LayoutCombination.TNN, alignments=[1, 1, 1], element_output=cutlass_cppgen.DataType.s8, element_C=cutlass_cppgen.DataType.s8, + element_accumulator=cutlass_cppgen.DataType.s32, threadblock_shape=[ 64, 128, 8], warp_count=[1, 2, 1], stages=2) 
+add_test_simt(cls=GemmS8Sm80, layouts=LayoutCombination.NTN, alignments=[1, 1, 1], element_output=cutlass_cppgen.DataType.s8, element_C=cutlass_cppgen.DataType.s8, + element_accumulator=cutlass_cppgen.DataType.s32, threadblock_shape=[128, 64, 8], warp_count=[2, 1, 1], stages=2) +add_test_simt(cls=GemmS8Sm80, layouts=LayoutCombination.TTN, alignments=[1, 1, 1], element_output=cutlass_cppgen.DataType.s32, element_C=cutlass_cppgen.DataType.s32, + element_accumulator=cutlass_cppgen.DataType.s32, threadblock_shape=[ 64, 64, 8], warp_count=[1, 1, 1], stages=2) +add_test_simt(cls=GemmS8Sm80, layouts=LayoutCombination.NNT, alignments=[1, 1, 1], element_output=cutlass_cppgen.DataType.s32, element_C=cutlass_cppgen.DataType.s32, + element_accumulator=cutlass_cppgen.DataType.s32, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2) + +# Stream K tests +add_test_streamk = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.TensorOp, swizzle=cutlass_cppgen.swizzle.ThreadblockSwizzleStreamK) +add_test_streamk(cls=GemmS8Sm80StreamK, layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass_cppgen.DataType.s8, element_C=cutlass_cppgen.DataType.s8, + element_accumulator=cutlass_cppgen.DataType.s32, threadblock_shape=[128, 256, 64], warp_count=[2, 4, 1], stages=3) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_s8_sm90.py b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_s8_sm90.py new file mode 100644 index 0000000000000000000000000000000000000000..ec0101f78da3b62b599a5deeb89f5596a7e515ce --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_s8_sm90.py @@ -0,0 +1,98 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Low-level functionality tests for GEMM with S8 operands on SM90 +""" + +from functools import partial +import logging +import unittest + +import cutlass_cppgen +from cutlass_cppgen.backend.utils.device import device_cc + +from utils import LayoutCombination, add_test_gemm + + +cutlass_cppgen.set_log_level(logging.WARNING) +cc = 90 +dtype = cutlass_cppgen.DataType.s8 + + +@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM90 tests.') +@unittest.skipIf(cutlass_cppgen.utils.datatypes.torch_type(dtype) is None, f'Version of torch installed does not contain a datatype match for {dtype}') +class GemmS8Sm90(unittest.TestCase): + """ + Wrapper class to which tests will be added dynamically in __main__ + """ + pass + + +add_test_specialized = partial(add_test_gemm, cls=GemmS8Sm90, element=dtype, compilation_modes=['nvcc']) + +add_test_tensorop = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.TensorOp) + +# Tests with 1x1x1 clusters +add_test_tensorop(layouts=LayoutCombination.TNN, alignments=[16, 16, 16], element_output=cutlass_cppgen.DataType.s8, + element_accumulator=cutlass_cppgen.DataType.s32, cluster_shape=[1, 1, 1], threadblock_shape=[128, 128, 128], stages=3) +add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass_cppgen.DataType.s8, + element_accumulator=cutlass_cppgen.DataType.s32, cluster_shape=[1, 1, 1], threadblock_shape=[128, 128, 128], stages=None) +add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 8], element_output=cutlass_cppgen.DataType.s8, + element_accumulator=cutlass_cppgen.DataType.s32, cluster_shape=[1, 1, 1], threadblock_shape=[128, 128, 128], stages=None) +add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass_cppgen.DataType.s8, + element_accumulator=cutlass_cppgen.DataType.s32, cluster_shape=[1, 1, 1], threadblock_shape=[64, 128, 128], stages=None) +add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass_cppgen.DataType.s8, + element_accumulator=cutlass_cppgen.DataType.s32, cluster_shape=[1, 1, 1], threadblock_shape=[128, 64, 32], stages=None) +add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[ 4, 4, 16], element_output=cutlass_cppgen.DataType.s8, + element_accumulator=cutlass_cppgen.DataType.s32, cluster_shape=[1, 1, 1], threadblock_shape=[128, 128, 128], stages=None) + +# Tests with different cluster shapes +add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass_cppgen.DataType.s8, + element_accumulator=cutlass_cppgen.DataType.s32, cluster_shape=[2, 2, 1], threadblock_shape=[128, 128, 128], stages=None) +add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass_cppgen.DataType.s8, + element_accumulator=cutlass_cppgen.DataType.s32, cluster_shape=[1, 4, 1], 
threadblock_shape=[128, 128, 128], stages=None) + +# Tests with warp-specialized ping-pong schedule +add_test_tensorop(layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass_cppgen.DataType.s8, + element_accumulator=cutlass_cppgen.DataType.s32, cluster_shape=[2, 1, 1], threadblock_shape=[128, 128, 128], stages=None, + kernel_schedule=cutlass_cppgen.KernelScheduleType.TmaWarpSpecializedPingpong, + epilogue_schedule=cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecialized) + +# Tests for SIMT +add_test_simt = partial(add_test_specialized, opclass=cutlass_cppgen.OpcodeClass.Simt) +add_test_simt(layouts=LayoutCombination.TNN, alignments=[1, 1, 1], element_output=cutlass_cppgen.DataType.s8, + element_accumulator=cutlass_cppgen.DataType.s32, cluster_shape=[1, 1, 1], threadblock_shape=[64, 32, 8], stages=2) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_testbed.py b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_testbed.py new file mode 100644 index 0000000000000000000000000000000000000000..6ffda5b47e37f184c2352f0ee4e737635dbd4147 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/gemm_testbed.py @@ -0,0 +1,423 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+# +################################################################################################# + +from math import prod +import os +import re +import subprocess + +import torch + +from cutlass_library import ( + DataType, + DataTypeSize, + GemmUniversalMode, + LayoutType, + OpcodeClass, + ShortDataTypeNames, + SwizzlingFunctor +) + +from cutlass_cppgen.backend import compiler +from cutlass_cppgen.backend.gemm_operation import GemmArguments, GemmOperationUniversal +from cutlass_cppgen.backend.reduction_operation import ReductionArguments, ReductionOperation +from cutlass_cppgen.shape import GemmCoord, MatrixCoord +from cutlass_cppgen.utils.datatypes import torch_type + + +class GemmUniversalLauncher: + def __init__( + self, + operation, + seed=2080, + verification=True, + iterations=500, + compiler_mode= "nvcc", + **kwargs, + ) -> None: + self.math_operation = operation.tile_description.math_instruction.math_operation + self.verification = verification + + if compiler_mode == "nvcc": + compiler.nvcc() + elif compiler_mode == "nvrtc": + compiler.nvrtc() + else: + raise Exception(f"Unexpected compiler string {compiler_mode}") + + op_list = [operation] + if operation.arch < 90: + # Split K via Python is currently only supported for pre-SM90 kernels + self.reduction_operation: ReductionOperation = ReductionOperation( + shape=MatrixCoord(4, 32 * operation.C.alignment), + C=operation.C, + element_accumulator=operation.tile_description.math_instruction.element_accumulator, + element_compute=operation.epilogue_functor.element_epilogue, + epilogue_functor=operation.epilogue_functor, + count=operation.C.alignment, + ) + op_list.append(self.reduction_operation) + + compiler.add_module(op_list, bypass_cache=False) + + self.operation = operation + + self.dtype_A = torch_type(operation.A.element if not self.operation.switched else self.operation.B.element) + self.dtype_B = torch_type(operation.B.element if not self.operation.switched else self.operation.A.element) + self.dtype_C = torch_type(operation.C.element) + self.dtype_D = torch_type(operation.epilogue_functor.element_output) + + element_size = min(DataTypeSize[operation.A.element], DataTypeSize[operation.B.element]) + + if element_size == 1: + self.rand_max = 1 + self.rand_min = 0 + elif element_size <= 8: + self.rand_max = 1 + self.rand_min = -1 + elif element_size == 16: + self.rand_max = 4 + self.rand_min = -4 + else: + self.rand_max = 8 + self.rand_min = -8 + + self.seed = seed + + self.compute_type = operation.epilogue_functor.element_epilogue + self.accumulator_type = operation.tile_description.math_instruction.element_accumulator + + def print_problem_size(self, p, mode, batch_count): + if mode == GemmUniversalMode.Gemm: + mode = "Gemm" + elif mode == GemmUniversalMode.Batched: + mode = "GemmBatched" + elif mode == GemmUniversalMode.GemmSplitKParallel: + mode = "GemmSplitKParallel" + print(f"problem: {p.m}, {p.n}, {p.k}\n batch_count: {batch_count}\n mode: {mode}") + + def uniform_init(self, shape, dtype, layout): + size = prod(shape) + if dtype.is_floating_point: + # Initialize data in FP32 and call convert to the data type we desire. 
+            # This is a workaround for the following error that occurs when attempting to
+            # call uniform_ on a tensor with torch.float8_e4m3fn data:
+            # RuntimeError: "check_uniform_bounds" not implemented for 'Float8_e4m3fn'
+            data = torch.ceil(
+                torch.empty(size=(size,), dtype=torch.float32, device="cuda").uniform_(
+                    self.rand_min - 0.5, self.rand_max - 0.5)
+            ).to(dtype)
+        else:
+            # PyTorch does not currently support integer-typed matrix multiplications on GPU.
+            # Fall back to CPU for integer type references.
+            data = torch.empty(size=(size,), dtype=dtype, device="cpu").random_(self.rand_min, self.rand_max + 1)
+
+        is_fp8 = dtype == getattr(torch, "float8_e4m3fn", -1) or dtype == getattr(torch, "float8_e5m2", -1)
+
+        if dtype == torch.float64 or dtype == torch.float32 or is_fp8:
+            data = data.to("cpu")
+
+        data_ref = data.reshape(shape)
+
+        if layout == LayoutType.RowMajor:
+            data_cutlass = data_ref
+        else:
+            data_cutlass = data_ref.transpose(-1, -2).contiguous()
+
+        data_cutlass = data_cutlass.to("cuda")
+
+        # As of this writing, few operations in PyTorch are supported with FP8 data.
+        # Thus, we perform computation in FP32 for FP8 reference checks.
+        if is_fp8:
+            data_ref = data_ref.to(torch.float32)
+
+        return data_cutlass, data_ref
+
+    def reference(self, problem_size, tensor_A, tensor_B, tensor_C, alpha, beta):
+        # If any tensor is on CPU, place all tensors on CPU unless only
+        # tensor C is on CPU
+        # Handle mixed-input cases by casting to the larger data type and overriding
+        # to whatever the data type of the larger type is
+        if self.dtype_A != self.dtype_B:
+            if DataTypeSize[self.operation.A.element] < DataTypeSize[self.operation.B.element]:
+                tensor_A = tensor_A.to(self.dtype_B).to(tensor_B.device)
+            else:
+                tensor_B = tensor_B.to(self.dtype_A).to(tensor_A.device)
+
+        devices = [x.device.type for x in [tensor_A, tensor_B]]
+        if tensor_C is not None:
+            devices.append(tensor_C.device.type)
+
+        if "cpu" in devices and devices != ["cuda", "cuda", "cpu"]:
+            device = torch.device("cpu")
+        else:
+            device = tensor_A.device
+
+        tensor_A = tensor_A.to(device)
+        tensor_B = tensor_B.to(device)
+        if tensor_C is not None:
+            tensor_C = tensor_C.to(device)
+
+        dtype = torch_type(self.compute_type)
+        alpha_torch = torch.tensor([alpha], device=device).to(dtype)
+        beta_torch = torch.tensor([beta], device=device).to(dtype)
+
+        tmp = tensor_A @ tensor_B
+        tensor_D_ref = (alpha_torch * tmp)
+        if tensor_C is not None:
+            tensor_D_ref += (tensor_C * beta_torch)
+        return tensor_D_ref.to(self.dtype_D)
+
+    def run(self, mode, problem_size, batch_count=1, split_k_slices=1, alpha=1.0, beta=0.0):
+        torch.random.manual_seed(self.seed)
+
+        # Assign an actual batch count in cases where we are not running in batched mode.
+        # This is to differentiate between the number of split K slices and the batch count,
+        # which are overloaded within the single `batch_count` variable.
+ if mode == GemmUniversalMode.Batched: + true_batch_count = batch_count + else: + true_batch_count = 1 + + def transpose(layout): + if layout == LayoutType.RowMajor: + return LayoutType.ColumnMajor + else: + return LayoutType.RowMajor + + tensor_A, tensor_A_ref = self.uniform_init( + (true_batch_count, problem_size.m, problem_size.k), + self.dtype_A, + self.operation.A.layout if not self.operation.switched else transpose(self.operation.B.layout), + ) + tensor_B, tensor_B_ref = self.uniform_init( + (true_batch_count, problem_size.k, problem_size.n), + self.dtype_B, + self.operation.B.layout if not self.operation.switched else transpose(self.operation.A.layout), + ) + if self.dtype_C is not None: + tensor_C, tensor_C_ref = self.uniform_init( + (true_batch_count, problem_size.m, problem_size.n), + self.dtype_C, + self.operation.C.layout if not self.operation.switched else transpose(self.operation.C.layout), + ) + else: + tensor_C = None + tensor_C_ref = None + + tensor_D, _ = self.uniform_init( + (true_batch_count, problem_size.m, problem_size.n), + self.dtype_D, + self.operation.C.layout if not self.operation.switched else transpose(self.operation.C.layout), + ) + tensor_D = torch.zeros_like(tensor_D) + + if self.compute_type in [DataType.s8, DataType.s32, DataType.u8, DataType.u32]: + alpha = int(alpha) + beta = int(beta) + + # + # Launch kernel + # + + arguments = GemmArguments( + operation=self.operation, + problem_size=problem_size, + A=tensor_A, + B=tensor_B, + C=tensor_C, + D=tensor_D, + output_op=self.operation.epilogue_type(alpha, beta), + gemm_mode=mode, + split_k_slices=split_k_slices, + batch=batch_count, + ) + + if mode == GemmUniversalMode.GemmSplitKParallel: + reduction_arguments = ReductionArguments( + self.reduction_operation, + problem_size=[problem_size.m, problem_size.n], + partitions=split_k_slices, + workspace=arguments.ptr_D, + destination=tensor_D, + source=tensor_C, + output_op=self.reduction_operation.epilogue_type(alpha, beta), + ) + + self.operation.run(arguments) + + if mode == GemmUniversalMode.GemmSplitKParallel: + self.reduction_operation.run(reduction_arguments) + + passed = True + + if self.verification: + if mode == GemmUniversalMode.GemmSplitKParallel: + reduction_arguments.sync() + + # Free memory allocated by args because we are not + # calling `arguments.sync()` in this case (which will free memory) + arguments.free() + else: + arguments.sync() + tensor_D_ref = self.reference( + problem_size, + tensor_A_ref, + tensor_B_ref, + tensor_C_ref, + alpha, + beta, + ) + + tensor_D_ref = tensor_D_ref.to('cuda') + + if self.operation.switched or self.operation.C.layout == LayoutType.ColumnMajor: + tensor_D = tensor_D.transpose(-1, -2).contiguous() + + passed = tensor_D.equal(tensor_D_ref) + + try: + assert passed + except AssertionError: + self.print_problem_size(problem_size, mode, batch_count) + del arguments + if mode == GemmUniversalMode.GemmSplitKParallel: + del reduction_arguments + + return passed + + +def test_all_gemm(operation: "GemmOperationUniversal", testcase="universal", compilation_mode="nvcc"): + passed = True + + minimum_operand_element_size = min( + DataTypeSize[operation.A.element], DataTypeSize[operation.B.element] + ) + opcode_class = operation.tile_description.math_instruction.opcode_class + + if opcode_class == OpcodeClass.Simt: + alignment = 1 + else: + alignment = 128 // minimum_operand_element_size + + alignment_m = alignment + alignment_n = alignment + alignment_k = alignment + + # INT8 alignment constraints + if opcode_class == 
OpcodeClass.Simt:
+        A_is_s8 = operation.A.element == DataType.s8
+        B_is_s8 = operation.B.element == DataType.s8
+
+        if A_is_s8 and operation.A.layout == LayoutType.ColumnMajor:
+            alignment_m = 4
+        if B_is_s8 and operation.A.layout == LayoutType.RowMajor:
+            alignment_n = 4
+        if A_is_s8 and B_is_s8 and (operation.A.layout == LayoutType.RowMajor or operation.B.layout == LayoutType.ColumnMajor):
+            alignment_k = 4
+
+    threadblock_k = operation.tile_description.threadblock_shape[2]
+
+    assert testcase != "interleaved"
+
+    supports_split_k = operation.arch < 90 and operation.swizzling_functor != SwizzlingFunctor.StreamK
+
+    if testcase == "multistage":
+        modes = [GemmUniversalMode.Gemm]
+        problem_size_m = [16, 528]
+        problem_size_n = [16, 528]
+        problem_size_k = [
+            threadblock_k,
+            threadblock_k * operation.tile_description.stages
+            + operation.tile_description.math_instruction.instruction_shape[2],
+        ]
+        problem_alpha = [1.0]
+        problem_beta = [0.0]
+        batch_counts = [1]
+    else:
+        modes = [GemmUniversalMode.Gemm]
+        batch_counts = [1, 2, 3, 5, 7]
+        if supports_split_k:
+            modes.append(GemmUniversalMode.GemmSplitKParallel)
+
+        problem_size_m = [alignment_m, 512 - 3 * alignment_m]
+        problem_size_n = [alignment_n, 512 - 2 * alignment_n]
+        if operation.tile_description.stages is None:
+            stages_for_k_calc = 7
+        else:
+            stages_for_k_calc = operation.tile_description.stages
+        problem_size_k = [
+            alignment_k,
+            threadblock_k * stages_for_k_calc - alignment_k,
+            threadblock_k * stages_for_k_calc * 3 - alignment_k,
+        ]
+        problem_alpha = [1.0]
+        problem_beta = [2.0]
+
+    testbed = GemmUniversalLauncher(operation, compiler_mode=compilation_mode)
+
+    for mode in modes:
+        for m in problem_size_m:
+            for n in problem_size_n:
+                for k in problem_size_k:
+                    for batch_count in batch_counts:
+                        for alpha in problem_alpha:
+                            for beta in problem_beta:
+                                # skip very small K problems
+                                if testcase == "universal":
+                                    if k // batch_count < 2 * threadblock_k:
+                                        continue
+
+                                problem_size = GemmCoord(m, n, k)
+
+                                if supports_split_k:
+                                    split_k_slices = batch_count
+                                else:
+                                    split_k_slices = 1
+
+                                overridden_mode = mode
+                                if mode == GemmUniversalMode.Gemm and batch_count > 1:
+                                    overridden_mode = GemmUniversalMode.Batched
+
+                                passed = testbed.run(
+                                    overridden_mode,
+                                    problem_size,
+                                    batch_count,
+                                    split_k_slices,
+                                    alpha,
+                                    beta,
+                                )
+
+                                if not passed:
+                                    return False
+
+    return passed
diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/run_all_tests.py b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/run_all_tests.py
new file mode 100644
index 0000000000000000000000000000000000000000..bc5e7467b1e0040ce3012ff8541dfbac381bb861
--- /dev/null
+++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/run_all_tests.py
@@ -0,0 +1,44 @@
+#################################################################################################
+#
+# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2.
Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +import pathlib +import unittest + + +if __name__ == '__main__': + loader = unittest.TestLoader() + script_dir = str(pathlib.Path(__file__).parent.resolve()) + '/' + tests = loader.discover(script_dir, 'gemm_*.py') + testRunner = unittest.runner.TextTestRunner() + results = testRunner.run(tests) + if not results.wasSuccessful(): + raise Exception('Test cases failed') diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/utils.py b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..28bba3e922961c96df75f8685e3064ab55cbbc87 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/gemm/utils.py @@ -0,0 +1,260 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +from cutlass_library import SubstituteTemplate + +import cutlass_cppgen +from cutlass_library import ( + DataTypeNames, + EpilogueScheduleSuffixes, + KernelScheduleSuffixes, + LayoutType, + OpcodeClassNames, + ShortDataTypeNames, + ShortLayoutTypeNames +) +from cutlass_cppgen.backend import library + +from gemm_testbed import test_all_gemm + + +class Layout: + """ + Utility class to map transpose and non-transpose terminology to row- and column-major terminology + """ + + T = LayoutType.RowMajor + N = LayoutType.ColumnMajor + + +class LayoutCombination: + """ + Utility class defining all combinations of row- and column-major layouts for operands to a GEMMs + """ + + NNN = (Layout.N, Layout.N, Layout.N) + NNT = (Layout.N, Layout.N, Layout.T) + NTN = (Layout.N, Layout.T, Layout.N) + NTT = (Layout.N, Layout.T, Layout.T) + TNN = (Layout.T, Layout.N, Layout.N) + TNT = (Layout.T, Layout.N, Layout.T) + TTN = (Layout.T, Layout.T, Layout.N) + TTT = (Layout.T, Layout.T, Layout.T) + + +def get_name( + layouts, + alignments, + element_output, + element_accumulator, + element_epilogue, + cluster_shape, + threadblock_shape, + stages, + element_a, + element_b, + element_c, + arch, + opclass, + kernel_schedule=None, + epilogue_schedule=None, + suffix="", +): + """ + Generates a procedural name for a test case. 
+ + :param layouts: indexable container of layouts of A, B, and C operands + :param alignments: indexable container of alignments of A, B, and C operands + :param element_output: data type of the output element + :param element_accumulator: data type used in accumulation + :param element_epilogue: data type used in computing the epilogue + :param cluster_shape: indexable container of dimensions of threadblock cluster to be launched + :param threadblock_shape: indexable container of dimensions of threadblock tiles + :param stages: number of pipeline stages to use in the kernel + :type stages: int + :param element_a: data type of operand A + :param element_b: data type of operand B + :param element_c: data type of operand C + :param arch: compute capability of kernel being generated + :type arch: int + :param opclass: class of operation being performed (e.g., SIMT, Tensor Core) + :type opclass: cutlass_cppgen.OpcodeClass + :param kernel_schedule: kernel_schedule type + :type kernel_schedule: cutlass_cppgen.KernelScheduleType + :param epilogue_schedule: epilogue_schedule type + :type epilogue_schedule: cutlass_cppgen.EpilogueScheduleType + :param suffix: additional string to add to the suffix of the name + :type suffix: str + + :return: str + """ + name_format = "test_SM${arch}_Device_Gemm_${eA}${lA}_${eB}${lB}_${eC}${lC}_${opclass}_${acc}_${tbM}x${tbN}x${tbK}_${cM}x${cN}x${cK}_${stages}_align${aA}-${aB}-${aC}${k}${e}${suffix}" + return SubstituteTemplate( + name_format, + { + "arch": str(arch), + "eA": DataTypeNames[element_a], + "eB": DataTypeNames[element_b], + "eC": DataTypeNames[element_c], + "lA": ShortLayoutTypeNames[layouts[0]], + "lB": ShortLayoutTypeNames[layouts[1]], + "lC": ShortLayoutTypeNames[layouts[2]], + "opclass": OpcodeClassNames[opclass], + "acc": DataTypeNames[element_accumulator], + "cM": str(cluster_shape[0]), + "cN": str(cluster_shape[1]), + "cK": str(cluster_shape[2]), + "tbM": str(threadblock_shape[0]), + "tbN": str(threadblock_shape[1]), + "tbK": str(threadblock_shape[2]), + "stages": str(stages) if stages is not None else "auto", + "aA": str(alignments[0]), + "aB": str(alignments[1]), + "aC": str(alignments[2]), + "k": "" if kernel_schedule is None else KernelScheduleSuffixes[kernel_schedule], + "e": "" if epilogue_schedule is None else EpilogueScheduleSuffixes[epilogue_schedule], + "suffix": "" if suffix is None else suffix, + }, + ) + + +def add_test_gemm( + cls=None, + cc=None, + element=None, + layouts=None, + alignments=None, + element_output=None, + element_accumulator=None, + cluster_shape=None, + threadblock_shape=None, + warp_count=None, + stages=None, + opclass=None, + swizzle=None, + kernel_schedule=None, + epilogue_schedule=None, + compilation_modes=['nvcc', 'nvrtc'], + element_A=None, + element_B=None, + element_C=None): + """ + Create test-running functions with the given specification and set it as a method of ``cls``. 
+
+    :param cls: class to which the generated method will be added
+    :type cls: type
+    :param cc: compute capability to compile for
+    :type cc: int
+    :param element: data type of A and B operands
+    :type element: cutlass_cppgen.DataType
+    :param layouts: layouts of A, B, and C operands
+    :type layouts: list or tuple
+    :param alignments: alignments of A, B, and C operands
+    :type alignments: list or tuple
+    :param element_output: data type of the output element
+    :type element_output: cutlass_cppgen.DataType
+    :param element_accumulator: data type used in accumulation
+    :type element_accumulator: cutlass_cppgen.DataType
+    :param cluster_shape: dimensions of clusters
+    :type cluster_shape: list or tuple
+    :param threadblock_shape: dimensions of threadblock tiles
+    :type threadblock_shape: list or tuple
+    :param warp_count: warps to be launched per threadblock dimension
+    :type warp_count: list or tuple
+    :param stages: number of pipeline stages to use in the kernel
+    :type stages: int
+    :param opclass: class of operation being performed (e.g., SIMT, Tensor Core)
+    :type opclass: cutlass_cppgen.OpcodeClass
+    :param swizzle: threadblock swizzling functor
+    :param kernel_schedule: kernel schedule to use
+    :type kernel_schedule: cutlass_cppgen.KernelScheduleType
+    :param epilogue_schedule: epilogue schedule to use
+    :type epilogue_schedule: cutlass_cppgen.EpilogueScheduleType
+    :param compilation_modes: list of compilers to use in testing the kernel (options: 'nvrtc', 'nvcc')
+    :type compilation_modes: list
+    :param element_A: data type of operand A. If set, overrides ``element``
+    :type element_A: cutlass_cppgen.DataType
+    :param element_B: data type of operand B. If set, overrides ``element``
+    :type element_B: cutlass_cppgen.DataType
+    :param element_C: data type of operand C. If set, overrides ``element``
+    :type element_C: cutlass_cppgen.DataType
+    """
+
+    if element_A is None:
+        element_A = element
+    if element_B is None:
+        element_B = element
+    if element_C is None:
+        element_C = element
+    if element_output is None:
+        element_output = element
+    if element_accumulator is None:
+        element_accumulator = element
+
+    for compilation_mode in compilation_modes:
+        def run(self):
+            """
+            Dynamically-generated function that constructs a GEMM operation and verifies it against
+            multiple test cases.
+ """ + + layout_A, layout_B, layout_C = layouts + alignment_A, alignment_B, alignment_C = alignments + + plan = cutlass_cppgen.op.Gemm(element_A=element_A, element_B=element_B, + element_C=element_C, element_D=element_output, + layout_A=layout_A, layout_B=layout_B, layout_C=layout_C, + element_accumulator=element_accumulator, + kernel_cc=cc) + + plan.opclass = opclass + if swizzle is not None: + plan.swizzling_functor = swizzle + + td = plan.tile_descriptions()[0] + + if warp_count is not None: + td.warp_count = warp_count + td.threadblock_shape = threadblock_shape + td.stages = stages + td.cluster_shape = cluster_shape + op = plan.construct(tile_description=td, alignment_A=alignment_A, alignment_B=alignment_B, alignment_C=alignment_C) + self.assertTrue(test_all_gemm(op, 'universal', compilation_mode=compilation_mode)) + + element_epilogue = element_accumulator + name = get_name( + layouts=layouts, alignments=alignments, element_output=element_output, element_accumulator=element_accumulator, + element_epilogue=element_epilogue, cluster_shape=cluster_shape, threadblock_shape=threadblock_shape, + stages=stages, element_a=element_A, element_b=element_B, element_c=element_C, arch=cc, opclass=opclass, + kernel_schedule=kernel_schedule, epilogue_schedule=epilogue_schedule, suffix=f'_{compilation_mode}') + + setattr(cls, name, run) diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/installation.py b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/installation.py new file mode 100644 index 0000000000000000000000000000000000000000..f550c394812c7fede55070e4c99c4471a69c2f88 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/installation.py @@ -0,0 +1,57 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+# +################################################################################################# + +""" +Tests for a successful installation of the CUTLASS Python interface +""" + +import os +import unittest + +import cutlass_cppgen +import cutlass_library + + +class InstallationTest(unittest.TestCase): + def test_cutlass_source_paths(self): + """ + Tests that CUTLASS source is available as part of the cutlass and cutlass_library packages + """ + src_file = 'include/cutlass/cutlass.h' + library_file = os.path.join(cutlass_library.source_path, src_file) + cutlass_file = os.path.join(cutlass_cppgen.CUTLASS_PATH, src_file) + assert os.path.isfile(library_file), f"Unable to locate file {library_file}. Installation has not succeeded." + assert os.path.isfile(cutlass_file), f"Unable to locate file {cutlass_file}. Installation has not succeeded." + + +if __name__ == "__main__": + unittest.main() diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/interface/conv2d_interface.py b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/interface/conv2d_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..2b5d46d45d617198a46bec85cd7218cb5431a7b1 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/interface/conv2d_interface.py @@ -0,0 +1,284 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+# +################################################################################################# + +""" +Tests the high-level Conv2d interface +""" + +from math import ceil +import unittest + +import cutlass_cppgen +import cutlass_cppgen.utils.datatypes as datatypes +from cutlass_cppgen.backend.utils.device import device_cc +from utils import ExpectException +import os + + +class Conv2dEquivalence: + """ + Helper class for testing the equivalence of different constructions of the Conv2d interface + """ + def __init__(self, conv_kind, element_A, element_B, element_C, element_D, element_accumulator, + alignment_A, alignment_B, alignment_C): + + self.element_A = element_A + self.element_B = element_B + self.element_C = element_C + self.element_D = element_D + self.element_accumulator = element_accumulator + self.alignment_A = alignment_A + self.alignment_B = alignment_B + self.alignment_C = alignment_C + + self.conv_kind = conv_kind + + self.plan = cutlass_cppgen.op.Conv2d( + kind=self.conv_kind, element_A=element_A, element_B=element_B, element_C=element_C, + element_D=element_D, element_accumulator=element_accumulator) + + self.op = self.plan.construct( + alignment_A=self.alignment_A, alignment_B=self.alignment_B, + alignment_C=self.alignment_C) + + def _plans_equal(self, other_plan) -> bool: + """ + Compares whether two plans are equal + + :param other_plan: plan to compare against the default Conv2d + :type other_plan: cutlass_cppgen.op.Conv2d + + :return: whether `other_plan` is equivalent to `self.plan` + :rtype: bool + """ + other_op = other_plan.construct( + alignment_A=self.alignment_A, alignment_B=self.alignment_B, + alignment_C=self.alignment_C) + + return self.op.rt_module.emit() == other_op.rt_module.emit() + + def generic_test(self): + """ + Tests the equivalence of various constructions of the Conv2d interface when using CUTLASS data types + and layouts for constructing the Conv2d interface + """ + if not datatypes.is_numpy_available(): + return + + # Test when specifying all parameters + plan_other = cutlass_cppgen.op.Conv2d( + kind=self.conv_kind, + element_A=self.element_A, element_B=self.element_B, element_C=self.element_C, + element_D=self.element_D, element_accumulator=self.element_accumulator) + assert self._plans_equal(plan_other) + + # Test when specifying all parameters but A + plan_other = cutlass_cppgen.op.Conv2d( + kind=self.conv_kind, + element_B=self.element_B, element_C=self.element_C, + element_D=self.element_D, element_accumulator=self.element_accumulator, + element=self.element_A) + assert self._plans_equal(plan_other) + + # Test when specifying all parameters but A and B as tensors using generic element and output + plan_other = cutlass_cppgen.op.Conv2d( + kind=self.conv_kind, + element_C=self.element_C, + element_D=self.element_D, element_accumulator=self.element_accumulator, + element=self.element_A) + assert self._plans_equal(plan_other) + + # Test without explicit accumulator. Only run if the type of C and the accumulator are equal + if self.element_C == self.element_accumulator: + plan_other = cutlass_cppgen.op.Conv2d( + kind=self.conv_kind, + element_C=self.element_C, + element_D=self.element_D, + element=self.element_A) + assert self._plans_equal(plan_other) + + # Test with only the generic types. 
Only run if the types of A, B, C, and D are the same.
+        if (self.element_A == self.element_B and self.element_A == self.element_C and self.element_A == self.element_D
+            and self.element_A == self.element_accumulator):
+            plan_other = cutlass_cppgen.op.Conv2d(kind=self.conv_kind, element=self.element_A)
+            assert self._plans_equal(plan_other)
+
+    def numpy_test(self):
+        """
+        Tests the equivalence of various constructions of the Conv2d interface when using numpy as a frontend
+        """
+        if not datatypes.is_numpy_available():
+            return
+
+        import numpy as np
+        type_A = datatypes.numpy_type(self.element_A)
+        type_B = datatypes.numpy_type(self.element_B)
+        type_C = datatypes.numpy_type(self.element_C)
+        type_D = datatypes.numpy_type(self.element_D)
+        type_accum = datatypes.numpy_type(self.element_accumulator)
+
+        size = (2, 2)
+        A = np.zeros(size, dtype=type_A)
+        B = np.zeros(size, dtype=type_B)
+        C = np.zeros(size, dtype=type_C)
+        D = np.zeros(size, dtype=type_D)
+
+        return self.tensor_test(type_A, type_B, type_C, type_D, type_accum, A, B, C, D)
+
+    def torch_test(self):
+        """
+        Tests the equivalence of various constructions of the Conv2d interface when using torch as a frontend
+        """
+        if not datatypes.is_torch_available():
+            return
+
+        import torch
+        type_A = datatypes.torch_type(self.element_A)
+        type_B = datatypes.torch_type(self.element_B)
+        type_C = datatypes.torch_type(self.element_C)
+        type_D = datatypes.torch_type(self.element_D)
+        type_accum = datatypes.torch_type(self.element_accumulator)
+
+        size = (2, 2)
+
+        A = torch.empty(size, dtype=type_A)
+        B = torch.empty(size, dtype=type_B)
+        C = torch.empty(size, dtype=type_C)
+        D = torch.empty(size, dtype=type_D)
+
+        return self.tensor_test(type_A, type_B, type_C, type_D, type_accum, A, B, C, D)
+
+    def tensor_test(self, type_A, type_B, type_C, type_D, type_accum, A, B, C, D):
+        # Test when specifying all parameters via tensors
+        plan_np = cutlass_cppgen.op.Conv2d(kind=self.conv_kind, A=A, B=B, C=C, D=D, element_accumulator=type_accum)
+        assert self._plans_equal(plan_np)
+
+        # Test when specifying all parameters but A as tensors
+        plan_np = cutlass_cppgen.op.Conv2d(kind=self.conv_kind, B=B, C=C, D=D, element_accumulator=type_accum, element_A=type_A)
+        assert self._plans_equal(plan_np)
+
+        # Test when specifying all parameters but A and B as tensors and using generic element and output
+        if type_A == type_B:
+            plan_np = cutlass_cppgen.op.Conv2d(kind=self.conv_kind, C=C, D=D, element_accumulator=type_accum, element=type_A)
+            assert self._plans_equal(plan_np)
+
+        # Test without explicit accumulator. Only run if the types of C and the accumulator are equal.
+        if type_C == type_accum:
+            plan_np = cutlass_cppgen.op.Conv2d(kind=self.conv_kind, A=A, B=B, C=C, D=D)
+            assert self._plans_equal(plan_np)
+
+        # Test with only the generic types. Only run if the types of A, B, C, and D are the same.
+        if (type_A == type_B and type_A == type_C and type_A == type_D and type_A == type_accum):
+            plan_np = cutlass_cppgen.op.Conv2d(kind=self.conv_kind, element=type_A)
+            assert self._plans_equal(plan_np)
+
+    def test_all(self):
+        """
+        Runs all tests on the Conv2d interface
+        """
+        self.generic_test()
+        self.numpy_test()
+        self.torch_test()
+
+
+@unittest.skipIf(device_cc() <= 80, 'Device compute capability is insufficient for SM80 tests.')
+class ConvEquivalenceTest(unittest.TestCase):
+    """
+    Tests the equivalence of different constructions of the Conv2d interface
+    """
+    pass
+
+type2alignment = {
+    cutlass_cppgen.DataType.f16: 8,
+    cutlass_cppgen.DataType.f32: 4
+}
+
+def add_test(conv_kind, element_A, element_B, element_C, element_D, element_accumulator):
+
+    test_name = f"test_conv2d_{conv_kind}_{element_A}_{element_B}_{element_C}_{element_D}_{element_accumulator}"
+
+    def run(self):
+        conv2d_eq = Conv2dEquivalence(
+            conv_kind=conv_kind,
+            element_A=element_A, element_B=element_B,
+            element_C=element_C, element_D=element_D,
+            element_accumulator=element_accumulator,
+            alignment_A=type2alignment[element_A], alignment_B=type2alignment[element_B],
+            alignment_C=type2alignment[element_C]
+        )
+        conv2d_eq.test_all()
+
+    setattr(ConvEquivalenceTest, test_name, run)
+
+for conv_kind in ["fprop", "wgrad", "dgrad"]:
+    for types in [
+        [cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f16],
+        [cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f32],
+        [cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f16],
+        [cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f16, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f32],
+        [cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f32, cutlass_cppgen.DataType.f32]
+    ]:
+        add_test(conv_kind, types[0], types[1], types[2], types[3], types[4])
+
+
+@unittest.skipIf(device_cc() <= 80, 'Device compute capability is insufficient for SM80 tests.')
+class Conv2dErrorTests(unittest.TestCase):
+    """
+    Tests various error scenarios that arise with the high-level Conv2d interface
+    """
+
+    def test_alignment(self):
+        """
+        Tests the case in which the alignment specified is unsupported
+        """
+        plan = cutlass_cppgen.op.Conv2d(kind="fprop", element=cutlass_cppgen.DataType.f16)
+
+        with ExpectException(True, 'Alignment 3 is not supported for F16. The construction should fail.'):
+            op = plan.construct(alignment_A=3, alignment_B=3, alignment_C=3)
+
+    def test_invalid_tile_description(self):
+        """
+        Tests scenarios in which an invalid tile description is provided for a given CC
+        """
+        plan = cutlass_cppgen.op.Conv2d(kind="fprop", element=cutlass_cppgen.DataType.f16)
+
+        td = plan.tile_descriptions()[0]
+        td.threadblock_shape = [17, 32, 5]
+
+        plan.tile_description = td
+        with ExpectException(True, 'The threadblock shape is invalid. 
The compilation should fail.'): + plan.compile() + # Clean up the error message + os.remove("./cutlass_python_compilation_device_error.txt") + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/interface/evt_interface.py b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/interface/evt_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..e7d67f4d07f01b0936ff5796bfb6fe4c98b5c031 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/interface/evt_interface.py @@ -0,0 +1,254 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+# +################################################################################################# + +""" +Test the EVT interface +""" + +import numpy as np +import unittest + +import cutlass_cppgen +from cutlass_cppgen import LayoutType, Tensor +from cutlass_cppgen.backend.utils.device import device_cc +from cutlass_cppgen.epilogue import reshape, permute + +from utils import ExpectException + + +@unittest.skipIf(device_cc() not in [80, 90], "This unittest is for Sm80 and Sm90 only") +class EVTErrorTests(unittest.TestCase): + """ + Tests various error scenarios that arise with the EVT interface + """ + @unittest.skipIf(device_cc() != 90, "Only Sm90 EVT requires root node be 'D'") + def test_root_not_d(self): + """ + Test when "D" does not exist in Sm90 EVT + """ + def evt_root_not_d(accum, alpha): + F = accum * alpha + return F + + example_tensors = { + "accum": self.fake_tensor(np.float16, (6, 512, 512)), + "alpha": 1.2, + "F": self.fake_tensor(np.float16, (6, 512, 512)) + } + + with ExpectException(device_cc() == 90, + "SyntaxError: Sm90 EVT requires the epilogue to have a returned tensor D, " + "but the variable 'D' is not found in the return values.", True): + + cutlass_cppgen.epilogue.trace(evt_root_not_d, example_tensors) + + def test_no_accum(self): + """ + Test when "accum" is not in input arguments + """ + def evt_no_accum(alpha, C): + D = alpha * C + return D + + example_tensors = { + "C": self.fake_tensor(np.float16, (6, 512, 512)), + "alpha": 1.2, + "D": self.fake_tensor(np.float16, (6, 512, 512)) + } + + with ExpectException(True, "SyntaxError: Cannot find 'accum' in the argument list.", True): + cutlass_cppgen.epilogue.trace(evt_no_accum, example_tensors) + + @unittest.skipIf(device_cc() != 90, "Only Sm90 EVT has concern on smem size") + def test_too_much_shared_memory(self): + """ + Test when the epilogue consumes too much shared memory + """ + def evt_too_much_shared_memory(accum, C1, C2, C3, C4, C5, C6, C7, C8): + D1 = accum + C1 + D2 = D1 + C2 + D3 = D2 + C3 + D4 = D3 + C4 + D5 = D4 + C5 + D6 = D5 + C6 + D7 = D6 + C7 + D = D7 + C8 + return D, D1, D2, D3, D4, D5, D6, D7 + + example_tensors = { + "accum": self.fake_tensor(np.float16, (6, 512, 512)), + "C1": self.fake_tensor(np.float16, (6, 512, 512)), + "C2": self.fake_tensor(np.float16, (6, 512, 512)), + "C3": self.fake_tensor(np.float16, (6, 512, 512)), + "C4": self.fake_tensor(np.float16, (6, 512, 512)), + "C5": self.fake_tensor(np.float16, (6, 512, 512)), + "C6": self.fake_tensor(np.float16, (6, 512, 512)), + "C7": self.fake_tensor(np.float16, (6, 512, 512)), + "C8": self.fake_tensor(np.float16, (6, 512, 512)), + "D1": self.fake_tensor(np.float16, (6, 512, 512)), + "D2": self.fake_tensor(np.float16, (6, 512, 512)), + "D3": self.fake_tensor(np.float16, (6, 512, 512)), + "D4": self.fake_tensor(np.float16, (6, 512, 512)), + "D5": self.fake_tensor(np.float16, (6, 512, 512)), + "D6": self.fake_tensor(np.float16, (6, 512, 512)), + "D7": self.fake_tensor(np.float16, (6, 512, 512)), + "D": self.fake_tensor(np.float16, (6, 512, 512)) + } + + epilogue_visitor = cutlass_cppgen.epilogue.trace(evt_too_much_shared_memory, example_tensors) + + plan = cutlass_cppgen.op.Gemm( + element=np.float16, layout=cutlass_cppgen.LayoutType.RowMajor, + element_accumulator=np.float32 + ) + + with ExpectException(True, + "RuntimeError: The epilogue consumes too much shared memory. 
" + "No valid tile description is found in the generator.", True): + plan.epilogue_visitor = epilogue_visitor + + def test_not_ssa(self): + """ + Test when the epilogue is not in SSA + """ + def evt_redefine(accum, C, alpha): + F = accum + C + F = F * alpha + D = F + return D, F + + example_tensors = { + "accum": self.fake_tensor(np.float16, (6, 512, 512)), + "C": self.fake_tensor(np.float16, (6, 512, 512)), + "alpha": 1.5, + "D": self.fake_tensor(np.float16, (6, 512, 512)), + "F": self.fake_tensor(np.float16, (6, 512, 512)) + } + + with ExpectException(True, "SyntaxError: Variable 'F' cannot be defined twice.", True): + cutlass_cppgen.epilogue.trace(evt_redefine, example_tensors) + + def evt_undefine(accum, alpha): + F = accum + C + D = F * alpha + return D, F + + example_tensors = { + "accum": self.fake_tensor(np.float16, (6, 512, 512)), + "alpha": 1.5, + "D": self.fake_tensor(np.float16, (6, 512, 512)), + "F": self.fake_tensor(np.float16, (6, 512, 512)) + } + + with ExpectException(True, "SyntaxError: Variable 'C' is undefined.", True): + cutlass_cppgen.epilogue.trace(evt_undefine, example_tensors) + + def test_missing_example_tensor(self): + """ + Test when the example tensor of an input/output variable is not provided + """ + def evt_missing_example_tensor(accum, C): + D = accum + C + return D + + example_tensors = { + "accum": self.fake_tensor(np.float16, (6, 512, 512)), + "C": self.fake_tensor(np.float16, (6, 512, 512)), + } + + with ExpectException(True, "RuntimeError: Example input for D is not provided.", True): + cutlass_cppgen.epilogue.trace(evt_missing_example_tensor, example_tensors) + + example_tensors = { + "accum": self.fake_tensor(np.float16, (6, 512, 512)), + "D": self.fake_tensor(np.float16, (6, 512, 512)), + } + + with ExpectException(True, "RuntimeError: Example input for C is not provided.", True): + cutlass_cppgen.epilogue.trace(evt_missing_example_tensor, example_tensors) + + def test_return_expression(self): + """ + Test when the return value is an expression + """ + def evt_return_expr(accum, C): + return accum + C + + example_tensors = { + "accum": self.fake_tensor(np.float16, (6, 512, 512)), + "C": self.fake_tensor(np.float16, (6, 512, 512)), + } + + with ExpectException(True, "SyntaxError: Return value cannot be an expression", True): + cutlass_cppgen.epilogue.trace(evt_return_expr, example_tensors) + + def test_incompatible_shape(self): + """ + Test when the shape of example tensors are incompatible + """ + def evt_incompatible_shape(accum, C): + D = accum + C + return D + + example_tensors = { + "accum": self.fake_tensor(np.float16, (6, 256, 512)), + "C": self.fake_tensor(np.float16, (6, 512, 512)), + "D": self.fake_tensor(np.float16, (6, 512, 512)) + } + + with ExpectException(True, + "RuntimeError: Dimension mismatch between accum(6, 256, 512), C(6, 512, 512).", True): + cutlass_cppgen.epilogue.trace(evt_incompatible_shape, example_tensors) + + def test_no_matching_impl(self): + def evt_no_matching_impl(accum, bias): + D = accum + reshape(permute(bias, indices=(1, 0)), new_shape=(512, 1)) + return D + + example_tensors = { + "accum": self.fake_tensor(np.float16, (6, 512, 256)), + "bias": self.fake_tensor(np.float16, (16, 32)), + "D": self.fake_tensor(np.float16, (6, 512, 256)) + } + + with ExpectException(True, "NotImplementedError: No matching op for node bias with stride (0, (1, 32), 0).", True): + cutlass_cppgen.epilogue.trace(evt_no_matching_impl, example_tensors) + # + # Helper functions + # + + def fake_tensor(self, element, shape): + return 
Tensor(element=element, shape=shape, layout_tag=LayoutType.RowMajor) + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/interface/gemm_interface.py b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/interface/gemm_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..2913d5933f5342cc58b4f252657a724d2c7692da --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/interface/gemm_interface.py @@ -0,0 +1,354 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+# +################################################################################################# + +""" +Tests the high-level GEMM interface +""" + +from math import ceil +import unittest + +import cutlass_cppgen +import cutlass_cppgen.utils.datatypes as datatypes +from cutlass_cppgen.backend.utils.device import device_cc +from utils import ExpectException + + +class GemmEquivalence: + """ + Helper class for testing the equivalence of different constructions of the Gemm interface + """ + def __init__(self, element_A, element_B, element_C, element_D, element_accumulator, + layout_A, layout_B, layout_C, alignment_A, alignment_B, alignment_C): + self.element_A = element_A + self.element_B = element_B + self.element_C = element_C + self.element_D = element_D + self.element_accumulator = element_accumulator + self.layout_A = layout_A + self.layout_B = layout_B + self.layout_C = layout_C + self.alignment_A = alignment_A + self.alignment_B = alignment_B + self.alignment_C = alignment_C + self.plan = cutlass_cppgen.op.Gemm(element_A=element_A, element_B=element_B, element_C=element_C, + element_D=element_D, element_accumulator=element_accumulator, + layout_A=layout_A, layout_B=layout_B, layout_C=layout_C) + self.op = self.plan.construct(alignment_A=alignment_A, alignment_B=alignment_B, alignment_C=alignment_C) + + def _plans_equal(self, other_plan) -> bool: + """ + Compares whether two plans are equal + + :param other_plan: plan to compare against the default GEMM + :type other_plan: cutlass_cppgen.op.Gemm + + :return: whether `other_plan` is equivalent to `self.plan` + :rtype: bool + """ + other_op = other_plan.construct(alignment_A=self.alignment_A, alignment_B=self.alignment_B, alignment_C=self.alignment_C) + + # Compare whether the operations are equal by comparing the C++ code that would be emitted for them + return self.op.rt_module.emit() == other_op.rt_module.emit() + + def generic_test(self): + """ + Tests the equivalence of various constructions of the Gemm interface when using CUTLASS data types + and layouts for constructing the Gemm interface + """ + if not datatypes.is_numpy_available(): + return + + # Test when specifying all parameters + plan_other = cutlass_cppgen.op.Gemm(element_A=self.element_A, element_B=self.element_B, element_C=self.element_C, + element_D=self.element_D, element_accumulator=self.element_accumulator, + layout_A=self.layout_A, layout_B=self.layout_B, layout_C=self.layout_C) + assert self._plans_equal(plan_other) + + # Test when specifying all parameters but A + plan_other = cutlass_cppgen.op.Gemm(element_B=self.element_B, element_C=self.element_C, + element_D=self.element_D, element_accumulator=self.element_accumulator, + layout_B=self.layout_B, layout_C=self.layout_C, + element=self.element_A, layout=self.layout_A) + assert self._plans_equal(plan_other) + + # Test when specifying all parameters but A and B as tensors and using generic element and output + # Only run this test if the layouts and types for A and B are equal. + if self.element_A == self.element_B and self.layout_A == self.layout_B: + plan_other = cutlass_cppgen.op.Gemm(element_C=self.element_C, element_D=self.element_D, element_accumulator=self.element_accumulator, + layout_C=self.layout_C, element=self.element_A, layout=self.layout_A) + assert self._plans_equal(plan_other) + + # Test without explicit accumulator. Only run if the type of C and the accumulator. 
+ if self.element_C == self.element_accumulator: + plan_other = cutlass_cppgen.op.Gemm(element_A=self.element_A, element_B=self.element_B, element_C=self.element_C, + element_D=self.element_D, layout_A=self.layout_A, layout_B=self.layout_B, + layout_C=self.layout_C) + assert self._plans_equal(plan_other) + + # Test with only the generic types and layouts. Only run if types and layouts of A, B, C, and D are the same. + if (self.element_A == self.element_B and self.element_A == self.element_C and self.element_A == self.element_D + and self.element_A == self.element_accumulator and + self.layout_A == self.layout_B and self.layout_A == self.layout_C): + plan_other = cutlass_cppgen.op.Gemm(element=self.element_A, layout=self.layout_A) + assert self._plans_equal(plan_other) + + def numpy_test(self): + """ + Tests the equivalence of various constructions of the Gemm interface when using numpy as a frontend + """ + if not datatypes.is_numpy_available(): + return + + import numpy as np + type_A = datatypes.numpy_type(self.element_A) + type_B = datatypes.numpy_type(self.element_B) + type_C = datatypes.numpy_type(self.element_C) + type_D = datatypes.numpy_type(self.element_D) + type_accum = datatypes.numpy_type(self.element_accumulator) + + layout_to_order = { + cutlass_cppgen.LayoutType.RowMajor: 'C', + cutlass_cppgen.LayoutType.ColumnMajor: 'F' + } + size = (2, 2) + A = np.zeros(size, order=layout_to_order[self.layout_A], dtype=type_A) + B = np.zeros(size, order=layout_to_order[self.layout_B], dtype=type_B) + C = np.zeros(size, order=layout_to_order[self.layout_C], dtype=type_C) + D = np.zeros(size, order=layout_to_order[self.layout_C], dtype=type_D) + + # Test when specifying all parameters via tensors + plan_np = cutlass_cppgen.op.Gemm(A=A, B=B, C=C, D=D, element_accumulator=type_accum) + assert self._plans_equal(plan_np) + + # Test when specifying all parameters but A as tensors + plan_np = cutlass_cppgen.op.Gemm(B=B, C=C, D=D, element_accumulator=type_accum, element_A=type_A, layout_A=self.layout_A) + assert self._plans_equal(plan_np) + + # Test when specifying all parameters but A and B as tensors and using generic element and output + # Only run this test if the layouts and types for A and B are equal. + if type_A == type_B and self.layout_A == self.layout_B: + plan_np = cutlass_cppgen.op.Gemm(C=C, D=D, element_accumulator=type_accum, element=type_A, layout=self.layout_A) + assert self._plans_equal(plan_np) + + # Test without explicit accumulator. Only run if the type of C and the accumulator. + if type_C == type_accum: + plan_np = cutlass_cppgen.op.Gemm(A=A, B=B, C=C, D=D) + assert self._plans_equal(plan_np) + + # Test with only the generic types and layouts. Only run if types and layouts of A, B, C, and D are the same. 
+ if (type_A == type_B and type_A == type_C and type_A == type_D and type_A == type_accum and + self.layout_A == self.layout_B and self.layout_A == self.layout_C): + plan_np = cutlass_cppgen.op.Gemm(element=type_A, layout=self.layout_A) + assert self._plans_equal(plan_np) + + def test_all(self): + """ + Runs all tests on the Gemm interface + """ + self.generic_test() + self.numpy_test() + + +class GemmEquivalenceTest(unittest.TestCase): + """ + Tests the equivalence of different constructions of the Gemm interface + """ + @unittest.skipIf(device_cc() < 70, "Device compute capability is insufficient for FP16 Tensor Core tests.") + def test_gemm_equivalence_f16_f16_f16_f16_f16_ttt_8_8_8(self): + gemm_eq = GemmEquivalence( + element_A=cutlass_cppgen.DataType.f16, element_B=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, + element_D=cutlass_cppgen.DataType.f16, element_accumulator=cutlass_cppgen.DataType.f16, + layout_A=cutlass_cppgen.LayoutType.RowMajor, layout_B=cutlass_cppgen.LayoutType.RowMajor, layout_C=cutlass_cppgen.LayoutType.RowMajor, + alignment_A=8, alignment_B=8, alignment_C=8) + gemm_eq.test_all() + + @unittest.skipIf(device_cc() < 70, "Device compute capability is insufficient for FP16 Tensor Core tests.") + def test_gemm_equivalence_f16_f16_f16_f16_f32_ntn_8_8_8(self): + gemm_eq = GemmEquivalence( + element_A=cutlass_cppgen.DataType.f16, element_B=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, + element_D=cutlass_cppgen.DataType.f16, element_accumulator=cutlass_cppgen.DataType.f32, + layout_A=cutlass_cppgen.LayoutType.ColumnMajor, layout_B=cutlass_cppgen.LayoutType.RowMajor, layout_C=cutlass_cppgen.LayoutType.ColumnMajor, + alignment_A=8, alignment_B=8, alignment_C=8) + gemm_eq.test_all() + + @unittest.skipIf(device_cc() < 70, "Device compute capability is insufficient for FP16 Tensor Core tests.") + def test_gemm_equivalence_f16_f16_f16_f16_f16_ttt_4_4_4(self): + gemm_eq = GemmEquivalence( + element_A=cutlass_cppgen.DataType.f16, element_B=cutlass_cppgen.DataType.f16, element_C=cutlass_cppgen.DataType.f16, + element_D=cutlass_cppgen.DataType.f16, element_accumulator=cutlass_cppgen.DataType.f16, + layout_A=cutlass_cppgen.LayoutType.RowMajor, layout_B=cutlass_cppgen.LayoutType.RowMajor, layout_C=cutlass_cppgen.LayoutType.RowMajor, + alignment_A=8, alignment_B=8, alignment_C=8) + gemm_eq.test_all() + + @unittest.skipIf(device_cc() < 80, "Device compute capability is insufficient for F64 Tensor Core tests.") + def test_gemm_equivalence_f64_f64_f64_f64_f64_tnt_1_1_1(self): + gemm_eq = GemmEquivalence( + element_A=cutlass_cppgen.DataType.f64, element_B=cutlass_cppgen.DataType.f64, element_C=cutlass_cppgen.DataType.f64, + element_D=cutlass_cppgen.DataType.f64, element_accumulator=cutlass_cppgen.DataType.f64, + layout_A=cutlass_cppgen.LayoutType.RowMajor, layout_B=cutlass_cppgen.LayoutType.ColumnMajor, layout_C=cutlass_cppgen.LayoutType.RowMajor, + alignment_A=1, alignment_B=1, alignment_C=1) + gemm_eq.test_all() + + +class GemmErrorTests(unittest.TestCase): + """ + Tests various error scenarios that arise with the high-level Gemm interface + """ + + def test_alignment(self): + """ + Tests case in which the alignment specified is unsupported + """ + plan = cutlass_cppgen.op.Gemm(element=cutlass_cppgen.DataType.f16, layout=cutlass_cppgen.LayoutType.RowMajor) + + with ExpectException(True, 'Alignment 16 is not supported for F16. 
The construction should fail.'): + op = plan.construct(alignment_A=16, alignment_B=16, alignment_C=16) + + def test_tensorop_availability(self): + """ + Tests case in which only SIMT operations are available but TensorOp is requested + """ + cc = device_cc() + + # F64 Tensor Core operations are only avaiable on certain devices + supports_tensorop_f64 = cc in [80, 89, 90] + plan = cutlass_cppgen.op.Gemm(cc=cc, element=cutlass_cppgen.DataType.f64, layout=cutlass_cppgen.LayoutType.RowMajor) + + error_msg = f'Incorrectly raised an exception for availability of TensorOp with F64 operands on SM{cc}' + with ExpectException(not supports_tensorop_f64, error_msg): + plan.opclass = cutlass_cppgen.OpcodeClass.TensorOp + + expected_opclass = cutlass_cppgen.OpcodeClass.TensorOp if supports_tensorop_f64 else cutlass_cppgen.OpcodeClass.Simt + assert plan.opclass == expected_opclass, f'Expected opclass to be {expected_opclass}, but received {plan.opclass} for SM{cc}' + + @unittest.skipIf(device_cc() < 70, "Device compute capability is insufficient for F16 Tensor Core tests.") + def test_opclass_switch(self): + """ + Tests cases in which the opcode class in question is switched (e.g., from TensorOp to SIMT) + """ + plan = cutlass_cppgen.op.Gemm( element=cutlass_cppgen.DataType.f16, layout=cutlass_cppgen.LayoutType.RowMajor) + assert plan.opclass == cutlass_cppgen.OpcodeClass.TensorOp + + # Ensure that all tile descriptions have opclass of TensorOp + for td in plan.tile_descriptions(): + assert td.math_instruction.opcode_class == cutlass_cppgen.OpcodeClass.TensorOp + + plan.opclass = cutlass_cppgen.OpcodeClass.Simt + + # Ensure that all tile descriptions have opclass of Simt + for td in plan.tile_descriptions(): + assert td.math_instruction.opcode_class == cutlass_cppgen.OpcodeClass.Simt + + def test_invalid_tile_description(self): + """ + Tests scenarios in which an invalid tile description is provided for a given CC + """ + cc = device_cc() + plan = cutlass_cppgen.op.Gemm(cc=cc, element=cutlass_cppgen.DataType.f16, layout=cutlass_cppgen.LayoutType.RowMajor) + td = plan.tile_descriptions()[0] + stages = td.stages + + # Zero stage count is valid for SM90+, as this is used to indicate that the builder's auto stage + # count should be used + with ExpectException(cc < 90, f'Requested zero stages'): + td.stages = 0 + plan.construct(td) + + if cc < 90: + with ExpectException(cc < 80, f'Requested more than 2 stages on SM{cc}'): + td.stages = 3 + plan.construct(td) + elif cc == 90: + original_kschedule = td.kernel_schedule + original_eschedule = td.epilogue_schedule + with ExpectException(False, f'Incorrectly flagged an error for insufficient shared memory'): + td.kernel_schedule = cutlass_cppgen.KernelScheduleType.TmaWarpSpecializedPingpong + td.epilogue_schedule = cutlass_cppgen.EpilogueScheduleType.NoSmemWarpSpecialized + td.stages = 3 + plan.construct(td) + # Reset schedules + td.kernel_schedule = original_kschedule + td.epilogue_schedule = original_eschedule + elif cc in [100, 101, 103]: + with ExpectException(False, f'Incorrectly flagged an error for insufficient shared memory'): + td.stages = 3 + plan.construct(td) + + with ExpectException(True, f'Requested too many stages'): + td.stages = 100 + plan.construct(td) + + # Reset stage count + td.stages = stages + + cluster_shape = td.cluster_shape + with ExpectException(cc < 90, f'Requested non-unit cluster shape on SM{cc}'): + td.cluster_shape = [2, 1, 1] + plan.construct(td) + + # Reset cluster shape + td.cluster_shape = cluster_shape + + with 
ExpectException(cc < 90, f'Requested a non-auto schedule on SM{cc}'): + td.kernel_schedule = cutlass_cppgen.KernelScheduleType.TmaWarpSpecializedPingpong + td.epilogue_schedule = cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecialized + plan.construct(td) + + with ExpectException(cc == 90, f'Requested a non-auto kernel schedule with an auto epilogue schedule'): + td.kernel_schedule = cutlass_cppgen.KernelScheduleType.TmaWarpSpecializedPingpong + td.epilogue_schedule = cutlass_cppgen.EpilogueScheduleType.ScheduleAuto + plan.construct(td) + + with ExpectException(cc == 90, f'Requested an auto kernel schedule with a non-auto epilogue schedule'): + td.kernel_schedule = cutlass_cppgen.KernelScheduleType.ScheduleAuto + td.epilogue_schedule = cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecialized + plan.construct(td) + + with ExpectException(cc < 90, f'Requested a tile scheduler on SM{cc}'): + td.kernel_schedule = cutlass_cppgen.KernelScheduleType.TmaWarpSpecializedCooperative + td.epilogue_schedule = cutlass_cppgen.EpilogueScheduleType.TmaWarpSpecializedCooperative + td.tile_scheduler = cutlass_cppgen.TileSchedulerType.StreamK + plan.construct(td) + + # Ensure that all returned tile descriptions are unique + ops = {} + for i, td in enumerate(plan.tile_descriptions()): + op = plan.construct(td) + code_str = op.rt_module.emit() + if code_str in ops: + conflicting_td = ops[code_str] + assert False, f'Multiple tile descriptions emitted {code_str}\nTile descriptions are:\n{td}\n{conflicting_td}' + + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/interface/utils.py b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/interface/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9f93ca26e2d79a15dab4dd0045836ebd9fe62757 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/cutlass/interface/utils.py @@ -0,0 +1,69 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Helper functions & classes for interface test +""" +class ExpectException: + """ + Utility class to assert that an exception was raised when expected + + Example: + + .. highlight:: python + .. code-block:: python + + with ExceptionExpected(True, 'Division by zero'): + x = 1.0 / 0.0 + + :param exception_expected: whether an exception is expected to be raised + :type exception_expected: bool + :param message: message to print if an exception is raised when not expected or vice versa + :type message: str + """ + def __init__(self, exception_expected: bool, message: str = '', verify_msg=False): + self.exception_expected = exception_expected + self.message = message + self.verify_msg = verify_msg + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, traceback): + exception_raised = exc_type is not None + assert self.exception_expected == exception_raised, self.message + if self.verify_msg: + exc_message = f"{exc_type.__name__}: {exc_val}" + assert exc_message == self.message, f"expect error message {self.message}, got {exc_message}" + + # Suppress the exception + return True diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/pycute/run_all_tests.py b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/pycute/run_all_tests.py new file mode 100644 index 0000000000000000000000000000000000000000..b7cdc421ccffffeb7bd1696aaf9916330a6625ca --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/pycute/run_all_tests.py @@ -0,0 +1,75 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Utility script for discovering and running all PyCuTe tests +""" + +import argparse +import logging +import pathlib +import unittest + + +def numeric_log_level(log_level: str) -> int: + """ + Converts the string identifier of the log level into the numeric identifier used + in setting the log level + + :param x: string representation of log level (e.g., 'INFO', 'DEBUG') + :type x: str + + :return: numeric representation of log level + :rtype: int + """ + numeric_level = getattr(logging, log_level.upper(), None) + if not isinstance(numeric_level, int): + raise ValueError(f"Invalid log level: {log_level}") + return numeric_level + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--log-level", default='info', type=numeric_log_level, required=False, + help='Logging level to be used by the generator script') + args = parser.parse_args() + + # Set the logging level based on the user-provided `--log-level` command-line option + logging.basicConfig(level=args.log_level) + + loader = unittest.TestLoader() + script_dir = str(pathlib.Path(__file__).parent.resolve()) + '/' + tests = loader.discover(script_dir, "test_*.py") + test_runner = unittest.runner.TextTestRunner() + results = test_runner.run(tests) + if not results.wasSuccessful(): + raise Exception("Test cases failed") diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/pycute/test_coalesce.py b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/pycute/test_coalesce.py new file mode 100644 index 0000000000000000000000000000000000000000..d4330377cab7079ea16422f194ddf4f2403ea507 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/pycute/test_coalesce.py @@ -0,0 +1,95 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. 
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Unit tests for pycute.coalesce +""" + +import logging +import unittest + +from pycute import * + +_LOGGER = logging.getLogger(__name__) + + +class TestCoalesce(unittest.TestCase): + def helper_test_coalesce(self, layout): + layoutR = coalesce(layout) + + _LOGGER.debug(f"{layout} => {layoutR}") + + self.assertEqual(size(layoutR), size(layout)) + + for i in range(size(layout)): + self.assertEqual(layoutR(i), layout(i)) + + def test_coalesce(self): + layout = Layout(1,0) + self.helper_test_coalesce(layout) + + layout = Layout(1,1) + self.helper_test_coalesce(layout) + + layout = Layout((2,4)) + self.helper_test_coalesce(layout) + + layout = Layout((2,4,6)) + self.helper_test_coalesce(layout) + + layout = Layout((2,4,6), (1,6,2)) + self.helper_test_coalesce(layout) + + layout = Layout((2,1,6), (1,7,2)) + self.helper_test_coalesce(layout) + + layout = Layout((2,1,6), (4,7,8)) + self.helper_test_coalesce(layout) + + layout = Layout((2,(4,6))) + self.helper_test_coalesce(layout) + + layout = Layout((2,4), (4,1)) + self.helper_test_coalesce(layout) + + layout = Layout((2,4,6), (24,6,1)) + self.helper_test_coalesce(layout) + + layout = Layout((2,1,3), (2,4,4)) + self.helper_test_coalesce(layout) + + layout = Layout(((2,2),(2,2)), ((1,4),(8,32))) + self.helper_test_coalesce(layout) + + +if __name__ == "__main__": + unittest.main() diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/pycute/test_complement.py b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/pycute/test_complement.py new file mode 100644 index 0000000000000000000000000000000000000000..5a8684a55b19c90eae11ddd1cca011c2ff8270b5 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/pycute/test_complement.py @@ -0,0 +1,92 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. 
Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Unit tests for pycute.complement +""" + +import logging +import unittest + +from pycute import * + +_LOGGER = logging.getLogger(__name__) + + +class TestComplement(unittest.TestCase): + def helper_test_complement(self, layout): + layoutR = complement(layout) + + _LOGGER.debug(f"{layout} => {layoutR}") + + # Post-condition: test disjointness of the codomains + for a in range(size(layout)): + for b in range(size(layoutR)): + assert (layout(a) != layoutR(b)) or (layout(a) == 0 and layoutR(b) == 0) + + def test_complement(self): + test = Layout(1,0) + self.helper_test_complement(test) + + test = Layout(1,1) + self.helper_test_complement(test) + + test = Layout(4,0) + self.helper_test_complement(test) + + test = Layout((2,4),(1,2)) + self.helper_test_complement(test) + + test = Layout((2,3),(1,2)) + self.helper_test_complement(test) + + test = Layout((2,4),(1,4)) + self.helper_test_complement(test) + + test = Layout((2,4,8),(8,1,64)) + self.helper_test_complement(test) + + test = Layout(((2,2),(2,2)),((1,4),(8,32))) + self.helper_test_complement(test) + + test = Layout((2,(3,4)),(3,(1,6))) + self.helper_test_complement(test) + + test = Layout((4,6),(1,6)) + self.helper_test_complement(test) + + test = Layout((4,10),(1,10)) + self.helper_test_complement(test) + + +if __name__ == "__main__": + unittest.main() diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/pycute/test_composition.py b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/pycute/test_composition.py new file mode 100644 index 0000000000000000000000000000000000000000..6c27eb7fe6cbb7bbbea7bd644ac8e64a2fc853c9 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/pycute/test_composition.py @@ -0,0 +1,213 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. 
Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Unit tests for pycute.composition +""" + +import logging +import unittest + +from pycute import * + +_LOGGER = logging.getLogger(__name__) + + +class TestComposition(unittest.TestCase): + def helper_test_composition(self, layoutA, layoutB): + layoutR = composition(layoutA, layoutB) + + _LOGGER.debug(f"{layoutA} o {layoutB} => {layoutR}") + + # True post-condition: Every coordinate c of layoutB with L1D(c) < size(layoutR) is a coordinate of layoutR. + + # Test that R(c) = A(B(c)) for all coordinates c in layoutR + for i in range(size(layoutR)): + self.assertEqual(layoutR(i), layoutA(layoutB(i))) + + def test_composition(self): + layoutA = Layout(1,0) + layoutB = Layout(1,0) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout(1,0) + layoutB = Layout(1,1) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout(1,1) + layoutB = Layout(1,0) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout(1,1) + layoutB = Layout(1,1) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((4)) + layoutB = Layout((4)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((4), (2)) + layoutB = Layout((4)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((4)) + layoutB = Layout((4), (2)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((4), (0)) + layoutB = Layout((4)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((4)) + layoutB = Layout((4), (0)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((1), (0)) + layoutB = Layout((4)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((4)) + layoutB = Layout((1), (0)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((4)) + layoutB = Layout((2)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((4), (2)) + layoutB = Layout((2)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((4)) + layoutB = Layout((2), (2)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((4), (2)) + layoutB = Layout((2), (2)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((12)) + layoutB = Layout((4,3)) + 
self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((12), (2)) + layoutB = Layout((4,3)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((12)) + layoutB = Layout((4,3), (3,1)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((12), (2)) + layoutB = Layout((4,3), (3,1)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((12)) + layoutB = Layout((2,3), (2,4)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((4,3)) + layoutB = Layout((4,3)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((4,3)) + layoutB = Layout((12)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((4,3)) + layoutB = Layout((6), (2)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((4,3)) + layoutB = Layout((6,2), (2,1)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((4,3), (3,1)) + layoutB = Layout((4,3)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((4,3), (3,1)) + layoutB = Layout((12)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((4,3), (3,1)) + layoutB = Layout((6), (2)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((4,3), (3,1)) + layoutB = Layout((6,2), (2,1)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((8,8)) + layoutB = Layout(((2,2,2), (2,2,2)),((1,16,4), (8,2,32))) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((8,8), (8,1)) + layoutB = Layout(((2,2,2), (2,2,2)),((1,16,4), (8,2,32))) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout(((2,2,2), (2,2,2)),((1,16,4), (8,2,32))) + layoutB = Layout(8, 4) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout(((4,2)), ((1,16))) + layoutB = Layout((4,2), (2,1)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((2,2), (2,1)) + layoutB = Layout((2,2), (2,1)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((4,8,2)) + layoutB = Layout((2,2,2), (2,8,1)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((4,8,2), (2,8,1)) + layoutB = Layout((2,2,2), (1,8,2)) + self.helper_test_composition(layoutA, layoutB) + + layoutA = Layout((4,8,2), (2,8,1)) + layoutB = Layout((4,2,2), (2,8,1)) + self.helper_test_composition(layoutA, layoutB) + + # Pre-coalesced LHS + layoutA = Layout((4,6,8),(1,4,7)) + layoutB = Layout((6),(1)) + self.helper_test_composition(layoutA, layoutB) + + # Mid-layout truncation + layoutA = Layout((4,6,8,10),(2,3,5,7)) + layoutB = Layout(6,12) + self.helper_test_composition(layoutA, layoutB) + +if __name__ == "__main__": + unittest.main() diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/pycute/test_int_tuple.py b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/pycute/test_int_tuple.py new file mode 100644 index 0000000000000000000000000000000000000000..0dbf443c9725735b0051d0a225a55eece9c663a8 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/pycute/test_int_tuple.py @@ -0,0 +1,80 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Unit tests for pycute.int_tuple +""" + +import unittest + +from pycute import * + + +class TestIntTuple(unittest.TestCase): + def test_product(self): + self.assertEqual(product(2), 2) + + self.assertEqual(product((3,2)), 6) + + self.assertEqual(product(product(((2,3),4))), 24) + + def test_inner_product(self): + self.assertEqual(inner_product(2, 3), 6) + + self.assertEqual(inner_product((1,2), (3,2)), 7) + + self.assertEqual(inner_product(((2,3),4), ((2,1),2)), 15) + + def test_shape_div(self): + self.assertEqual(shape_div((3,4), 6), (1,2)) + + self.assertEqual(shape_div((3,4), 12), (1,1)) + + self.assertEqual(shape_div((3,4), 36), (1,1)) + + self.assertEqual(shape_div(((3,4),6), 36), ((1,1),2)) + + self.assertEqual(shape_div((6,(3,4)), 36), (1,(1,2))) + + def test_prefix_product(self): + self.assertEqual(prefix_product(2), 1) + + self.assertEqual(prefix_product((3,2)), (1,3)) + + self.assertEqual(prefix_product((3,2,4)), (1,3,6)) + + self.assertEqual(prefix_product(((2,3),4)), ((1,2),6)) + + self.assertEqual(prefix_product(((2,3),(2, 1, 2),( 5, 2, 1))), + ((1,2),(6,12,12),(24,120,240))) + + diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/pycute/test_left_inverse.py b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/pycute/test_left_inverse.py new file mode 100644 index 0000000000000000000000000000000000000000..a6501fd6c7c6fc5a518e4d22bf93dc0e4746a8ba --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/pycute/test_left_inverse.py @@ -0,0 +1,87 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Unit tests for pycute.left_inverse +""" + +import logging +import unittest + +from pycute import * + +_LOGGER = logging.getLogger(__name__) + + +class TestLeftInverse(unittest.TestCase): + def helper_test_left_inverse(self, layout): + inv_layout = left_inverse(layout) + + _LOGGER.debug(f"{layout} => {inv_layout}") + + for i in range(size(layout)): + self.assertEqual(inv_layout(layout(i)), i) + + def test_left_inverse(self): + test = Layout(1,0) + self.helper_test_left_inverse(test) + + test = Layout((1,1),(0,0)) + self.helper_test_left_inverse(test) + + test = Layout(1,1) + self.helper_test_left_inverse(test) + + test = Layout(4,1) + self.helper_test_left_inverse(test) + + test = Layout(4,2) + self.helper_test_left_inverse(test) + + test = Layout((8,4),(1,8)) + self.helper_test_left_inverse(test) + + test = Layout((8,4),(4,1)) + self.helper_test_left_inverse(test) + + test = Layout((2,4,6),(1,2,8)) + self.helper_test_left_inverse(test) + + test = Layout((2,4,6),(4,1,8)) + self.helper_test_left_inverse(test) + + test = Layout((4,2),(1,16)) + self.helper_test_left_inverse(test) + + +if __name__ == "__main__": + unittest.main() diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/pycute/test_right_inverse.py b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/pycute/test_right_inverse.py new file mode 100644 index 0000000000000000000000000000000000000000..2ed9759d7808da8087fe9c76761d2dd9eaeab08b --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/pycute/test_right_inverse.py @@ -0,0 +1,96 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Unit tests for pycute.left_inverse +""" + +import logging +import unittest + +from pycute import * + +_LOGGER = logging.getLogger(__name__) + + +class TestRightInverse(unittest.TestCase): + def helper_test_right_inverse(self, layout): + inv_layout = right_inverse(layout) + + _LOGGER.debug(f"{layout} => {inv_layout}") + + for i in range(size(inv_layout)): + self.assertEqual(layout(inv_layout(i)), i) + + def test_right_inverse(self): + test = Layout(1,0) + self.helper_test_right_inverse(test) + + test = Layout((1,1),(0,0)) + self.helper_test_right_inverse(test) + + test = Layout((3,7),(0,0)) + self.helper_test_right_inverse(test) + + test = Layout(1,1) + self.helper_test_right_inverse(test) + + test = Layout(4,0) + self.helper_test_right_inverse(test) + + test = Layout(4,1) + self.helper_test_right_inverse(test) + + test = Layout(4,2) + self.helper_test_right_inverse(test) + + test = Layout((2,4),(0,2)) + self.helper_test_right_inverse(test) + + test = Layout((8,4),(1,8)) + self.helper_test_right_inverse(test) + + test = Layout((8,4),(4,1)) + self.helper_test_right_inverse(test) + + test = Layout((2,4,6),(1,2,8)) + self.helper_test_right_inverse(test) + + test = Layout((2,4,6),(4,1,8)) + self.helper_test_right_inverse(test) + + test = Layout((4,2),(1,16)) + self.helper_test_right_inverse(test) + + +if __name__ == "__main__": + unittest.main() diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/pycute/test_typing.py b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/pycute/test_typing.py new file mode 100644 index 0000000000000000000000000000000000000000..9eb99a4833529e18fa22d65a235ce80dad372365 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/python/pycute/test_typing.py @@ -0,0 +1,59 @@ 
+################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +""" +Unit tests for pycute.typing +""" + +import logging +import unittest +from pycute import * + +_LOGGER = logging.getLogger(__name__) + + +class TestTyping(unittest.TestCase): + def helper_test_typing(self, _cls, _obj, cls, expected: bool): + _LOGGER.debug(f"issubclass({_cls}, {cls})") + _LOGGER.debug(f"isinstance({_obj}, {cls})") + + self.assertEqual(expected, issubclass(_cls, cls)) + self.assertEqual(expected, isinstance(_obj, cls)) + + def test_typing(self): + self.helper_test_typing(int, 1, Integer, True) + self.helper_test_typing(float, 1., Integer, False) + self.helper_test_typing(str, 'hi', Integer, False) + self.helper_test_typing(bool, False, Integer, False) + +if __name__ == '__main__': + unittest.main() diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/common/cutlass_unit_test.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/common/cutlass_unit_test.h new file mode 100644 index 0000000000000000000000000000000000000000..86b7823785a9f2a957cf505740d6cfde45ccfef1 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/common/cutlass_unit_test.h @@ -0,0 +1,102 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once +#pragma warning (disable : 4068 ) /* disable unknown pragma warnings for visual studio */ + +#pragma nv_diag_suppress boolean_controlling_expr_is_constant +#include +#pragma nv_diag_warning boolean_controlling_expr_is_constant +#pragma warning( disable : 4503) + +#include +#include + +#include + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Gets a CUDA device +cudaDeviceProp GetCudaDevice(); + +/// Prints device properties +std::ostream &operator<<(std::ostream &out, cudaDeviceProp const &device); + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Sets flags for Unit test +void FilterArchitecture(); + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Reads environment variable `CUTLASS_UNIT_TEST_PROBLEM_COUNT` to control the number and order +// of problem sizes run by CUTLASS unit tests +int CutlassUnitTestProblemCount(); + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// active test macro +#define CUTLASS_TEST_LEVEL_ACTIVE(LEVEL,NAME_STATIC,NAME_DYNAMIC,...) \ + TEST(NAME_STATIC,L##LEVEL##_##NAME_DYNAMIC) __VA_ARGS__ + +// disabled test macro +#define CUTLASS_TEST_LEVEL_DISABLED(LEVEL,NAME_STATIC,NAME_DYNAMIC,...) \ + TEST(NAME_STATIC,DISABLED_L##LEVEL##_##NAME_DYNAMIC) {} + +#if CUTLASS_TEST_LEVEL == 0 +#define CUTLASS_TEST_L0(NAME_STATIC,NAME_DYNAMIC,...) CUTLASS_TEST_LEVEL_ACTIVE(0,NAME_STATIC,NAME_DYNAMIC,__VA_ARGS__) +#define CUTLASS_TEST_L1(NAME_STATIC,NAME_DYNAMIC,...) CUTLASS_TEST_LEVEL_DISABLED(1,NAME_STATIC,NAME_DYNAMIC,__VA_ARGS__) +#define CUTLASS_TEST_L2(NAME_STATIC,NAME_DYNAMIC,...) CUTLASS_TEST_LEVEL_DISABLED(2,NAME_STATIC,NAME_DYNAMIC,__VA_ARGS__) +#elif CUTLASS_TEST_LEVEL == 1 +#define CUTLASS_TEST_L0(NAME_STATIC,NAME_DYNAMIC,...) CUTLASS_TEST_LEVEL_ACTIVE(0,NAME_STATIC,NAME_DYNAMIC,__VA_ARGS__) +#define CUTLASS_TEST_L1(NAME_STATIC,NAME_DYNAMIC,...) 
CUTLASS_TEST_LEVEL_ACTIVE(1,NAME_STATIC,NAME_DYNAMIC,__VA_ARGS__) +#define CUTLASS_TEST_L2(NAME_STATIC,NAME_DYNAMIC,...) CUTLASS_TEST_LEVEL_DISABLED(2,NAME_STATIC,NAME_DYNAMIC,__VA_ARGS__) +#else +#define CUTLASS_TEST_L0(NAME_STATIC,NAME_DYNAMIC,...) CUTLASS_TEST_LEVEL_ACTIVE(0,NAME_STATIC,NAME_DYNAMIC,__VA_ARGS__) +#define CUTLASS_TEST_L1(NAME_STATIC,NAME_DYNAMIC,...) CUTLASS_TEST_LEVEL_ACTIVE(1,NAME_STATIC,NAME_DYNAMIC,__VA_ARGS__) +#define CUTLASS_TEST_L2(NAME_STATIC,NAME_DYNAMIC,...) CUTLASS_TEST_LEVEL_ACTIVE(2,NAME_STATIC,NAME_DYNAMIC,__VA_ARGS__) +#endif + +#if !defined(CUTLASS_TEST_UNIT_ENABLE_WARNINGS) +#define CUTLASS_TEST_UNIT_ENABLE_WARNINGS false +#endif + +#if (__CUDACC_VER_MAJOR__ >= 12) + #define CUDA_12_0_SM90_FEATURES_SUPPORTED true +#else + #define CUDA_12_0_SM90_FEATURES_SUPPORTED false +#endif + +#include +#include +#include + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/conv/cache_testbed_output.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/conv/cache_testbed_output.h new file mode 100644 index 0000000000000000000000000000000000000000..3035e9862bcb79b749b4cbc4a74341bceac9c598 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/conv/cache_testbed_output.h @@ -0,0 +1,907 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Helper to construct cached name for +*/ +#pragma once + +#include +#include +#include +#include +#include + +#include "cutlass/cutlass.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv2d_problem_size.h" + +#include "cutlass/conv/conv3d_problem_size.h" +#include "cutlass/core_io.h" +#include "cutlass/util/tensor_view_io.h" + +#include "thrust/universal_vector.h" + +#ifndef CUTLASS_TEST_ENABLE_CACHED_RESULTS +#define CUTLASS_TEST_ENABLE_CACHED_RESULTS false +#endif + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace test::conv::device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Result of a test +struct CachedTestKey { + + std::string op; ///< Concatenated string representation of operation performed + std::string problem; ///< Concatenated string representation of problem description + std::string types; ///< Concatenated string representation of operand types + uint32_t A; ///< Hashed result of tensor A + uint32_t B; ///< Hashed result of tensor B + uint32_t C; ///< Hashed result of tensor C + + // + // Methods + // + inline CachedTestKey(): A(), B(), C() { } + + inline CachedTestKey( + std::string op, ///< Concatenated string representation of operation performed + std::string problem, ///< Concatenated string representation of problem description + std::string types, ///< Concatenated string representation of operand types + uint32_t A, ///< Hashed result of tensor A + uint32_t B, ///< Hashed result of tensor B + uint32_t C ///< Hashed result of tensor C + ): + op(op), problem(problem), types(types), A(A), B(B), C(C) + { } + + /// Checks for equality of the problem + bool operator==(CachedTestKey const &rhs) const { + return op == rhs.op && problem == rhs.problem && types == rhs.types && A == rhs.A && B == rhs.B && C == rhs.C; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +inline std::istream &operator>>(std::istream &in, CachedTestKey &result) { + + in >> result.op; + in >> result.problem; + in >> result.types; + in >> result.A; + in >> result.B; + in >> result.C; + + return in; +} + +inline std::ostream &operator<<(std::ostream &out, CachedTestKey const &result) { + + out << result.op << " "; + out << result.problem << " "; + out << result.types << " "; + out << result.A << " "; + out << result.B << " "; + out << result.C << " "; + + return out; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +struct CachedTestResult { + uint32_t D; + // + // Methods + // + + CachedTestResult(): D() + { } + + CachedTestResult(uint32_t D): D(D) + { } + + operator bool() const { + return bool(D); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +inline std::istream &operator>>(std::istream &in, CachedTestResult &result) { + in >> result.D; + return in; +} + +inline std::ostream &operator<<(std::ostream &out, CachedTestResult const &result) { + out << result.D; + return out; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +struct CachedTestResultListing { + + std::list> results; + + // + // Methods + // + + inline CachedTestResultListing(std::string const &path) { + std::ifstream file(path); + + while (file.good()) { + CachedTestKey key; + file >> key; + + 
CachedTestResult result; + file >> result; + + if (result) { + results.push_back(std::make_pair(key, result)); + } + } + } + + /// Returns the cached result + std::pair find(CachedTestKey const &rhs) const { + for (auto const & result : results) { + if (result.first == rhs) { + return std::make_pair(true, result.second); + } + } + return std::make_pair(false, CachedTestResult()); + } + + /// Appends an entry + void append(CachedTestKey const &key, CachedTestResult const &result) { + if (result) { + results.push_back(std::make_pair(key, result)); + } + } + + /// Writes the entire listing to a file + bool write(std::string const &path) { + std::ofstream file(path); + if (!file.good()) { + return false; + } + + for (auto const &result : results) { + file << result.first << result.second << std::endl; + } + + return true; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct ScalarEncoder { + Element scalar; + + ScalarEncoder(Element s): scalar(s) { } + + std::string str() const { + std::stringstream ss; + Element s = scalar; + if (s < Element()) { + s = -s; + ss << "n"; + } + ss << s; + return ss.str(); + } +}; + +template +ScalarEncoder EncodeScalar(Element a) { + return ScalarEncoder(a); +} + +template +struct ScalarEncoder> { + cutlass::complex scalar; + + ScalarEncoder(cutlass::complex s): scalar(s) { } + + std::string str() const { + std::stringstream ss; + ss << EncodeScalar(scalar.real()) << "_" << EncodeScalar(scalar.imag()) << "i"; + return ss.str(); + } +}; + +template +std::ostream &operator<<(std::ostream &out, ScalarEncoder const &scalar) { + out << scalar.str(); + return out; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +inline char const *EncodeOperator(cutlass::conv::Operator conv_op) { + switch (conv_op) { + case cutlass::conv::Operator::kFprop: return "fprop"; + case cutlass::conv::Operator::kDgrad: return "dgrad"; + case cutlass::conv::Operator::kWgrad: return "wgrad"; + case cutlass::conv::Operator::kDeconv: return "deconv"; + } + return "conv_unknown"; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Encode GemmCoord (Gemm problem size) +inline std::ostream &EncodeProblemSize( + std::ostream &out, + cutlass::gemm::GemmCoord const &problem) { + + out << problem.m() << "x" << problem.n() << "x" << problem.k() << "_"; + + return out; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// +// Encode Conv2dProblemSize +inline std::ostream &EncodeProblemSize( + std::ostream &out, + cutlass::conv::Conv2dProblemSize const &problem) { + + out << problem.N << "x" << problem.H << "x" << problem.W << "x" << problem.C << "_" + << problem.P << "x" << problem.Q << "_" << problem.K << "x" << problem.R << "x" << problem.S << "_"; + + out << "pad_h" << problem.pad_h << "w" << problem.pad_w << "_"; + out << "stride_h" << problem.stride_h << "w" << problem.stride_w << "_"; + out << "dil_h" << problem.dilation_h << "w" << problem.dilation_w << "_"; + + switch (problem.mode) { + case cutlass::conv::Mode::kCrossCorrelation: + out << "corr"; + break; + case cutlass::conv::Mode::kConvolution: + out << "conv"; + break; + } + + return out; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Encode Conv3dProblemSize +inline std::ostream &EncodeProblemSize( + std::ostream &out, + 
cutlass::conv::Conv3dProblemSize const &problem) { + + out << problem.N << "x" << problem.D << "x" << problem.H << "x" << problem.W << "x" << problem.C << "_" + << problem.Z << problem.P << "x" << problem.Q << "_" << problem.K << "x" << problem.R << "x" << problem.S << "_"; + + out << "pad_d" << problem.pad_h << "h" << problem.pad_h << "w" << problem.pad_w << "_"; + out << "stride_d" << problem.stride_d << "h" << problem.stride_h << "w" << problem.stride_w << "_"; + out << "dil_d" << problem.dilation_d << "h" << problem.dilation_h << "w" << problem.dilation_w << "_"; + + switch (problem.mode) { + case cutlass::conv::Mode::kCrossCorrelation: + out << "corr"; + break; + case cutlass::conv::Mode::kConvolution: + out << "conv"; + break; + } + + return out; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// +// Encode 3.x ConvNd ProblemShape +template +inline std::ostream &EncodeProblemSize( + std::ostream &out, + ProblemShape const& problem_shape) { + + out << problem_shape.shape_A << "_"; + out << problem_shape.shape_B << "_"; + + out << "padl" << problem_shape.lower_padding << "_"; + out << "padu" << problem_shape.upper_padding << "_"; + out << "str" << problem_shape.traversal_stride << "_"; + out << "dil" << problem_shape.dilation << "_"; + + switch (problem_shape.mode) { + case cutlass::conv::Mode::kCrossCorrelation: + out << "corr"; + break; + case cutlass::conv::Mode::kConvolution: + out << "conv"; + break; + } + + return out; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline std::string ElementTypeName() { + return std::string(typeid(Element).name()); +} + +template <> +inline std::string ElementTypeName() { + return "h"; +} + +template <> +inline std::string ElementTypeName>() { + return "ch"; +} + +template <> +inline std::string ElementTypeName() { + return "bf16"; +} + +template <> +inline std::string ElementTypeName>() { + return "cbf16"; +} + +template <> +inline std::string ElementTypeName() { + return "tf32"; +} + +template <> +inline std::string ElementTypeName>() { + return "ctf32"; +} + +template <> +inline std::string ElementTypeName>() { + return "c"; +} + +template <> +inline std::string ElementTypeName>() { + return "z"; +} + +template <> +inline std::string ElementTypeName>() { + return "q"; +} + +template <> +inline std::string ElementTypeName() { + return "s8"; +} + +template <> +inline std::string ElementTypeName() { + return "u8"; +} + +template <> +inline std::string ElementTypeName() { + return "s4"; +} + +template <> +inline std::string ElementTypeName() { + return "u4"; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline std::string LayoutTypeName() { + return std::string(typeid(Layout).name()); +} + +template <> +inline std::string LayoutTypeName() { + return "n"; +} + +template <> +inline std::string LayoutTypeName() { + return "t"; +} + +template <> +inline std::string LayoutTypeName() { + return "nhwc"; +} + +template <> +inline std::string LayoutTypeName>() { + return "nc32hw32"; +} + +template <> +inline std::string LayoutTypeName>() { + return "nc64hw64"; +} + +template <> +inline std::string LayoutTypeName>() { + return "c32rsk32"; +} + +template <> +inline std::string LayoutTypeName>() { + return "c64rsk64"; +} + +template <> +inline std::string LayoutTypeName() { + return "ndhwc"; +} + 
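+
+// The short element/layout codes above are composed by TensorTypeName<Element, Layout>()
+// (defined below) into the operand-type portion of a cached test key. For example,
+// half_t with a TensorNHWC layout encodes as "hnhwc"; types without a specialization
+// fall back to typeid(...).name().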
+///////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline std::string TensorTypeName() { + std::stringstream ss; + ss << ElementTypeName() << LayoutTypeName(); + return ss.str(); +} + +template +inline std::string TensorTypeName() { + std::stringstream ss; + ss << ElementTypeName(); + return ss.str(); +} +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Hash function on a byte array +struct CRC32 { + + uint32_t table[256]; + + // + // Methods + // + + CRC32() { + + uint32_t rem; + int i, j; + + for (i = 0; i < 256; i++) { + rem = i; + for (j = 0; j < 8; j++) { + if (rem & 1) { + rem >>= 1; + rem ^= 0xedb88320; + } else + rem >>= 1; + } + table[i] = rem; + } + } + + /// Computes the CRC of an array of bytes + uint32_t operator()(void const *start, size_t length, uint32_t crc = uint32_t()) const { + uint8_t const *p = static_cast(start); + uint8_t const *q = static_cast(start) + length; + + crc = ~crc; + + for (; p != q; ++p) { + uint8_t octet = *p; + crc = (crc >> 8) ^ table[(crc & 0xff) ^ octet]; + } + + return ~crc; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Element, typename Layout +> +uint32_t TensorHash( + cutlass::TensorView view, + CRC32 const &hash = CRC32(), + uint32_t crc = uint32_t() +) { + + return hash(view.data(), view.capacity() * cutlass::sizeof_bits::value / 8, crc); +} + +template +uint32_t TensorHash( + thrust::universal_vector& tensor, + CRC32 const &hash = CRC32(), + uint32_t crc = uint32_t() +) { + + return hash(tensor.data().get(), tensor.size() * cutlass::sizeof_bits::value / 8, crc); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA, typename LayoutA, + typename ElementB, typename LayoutB, + typename ElementC, typename LayoutC, + typename ElementAccumulator, + typename ElementCompute +> +inline std::ostream &EncodeTypes( + std::ostream &out +) { + + out << TensorTypeName() << "_" + << TensorTypeName() << "_" + << TensorTypeName() << "_" + << ElementTypeName() << "_" + << ElementTypeName(); + + return out; +} + +template < + typename ElementA, + typename ElementB, + typename ElementC, + typename ElementD +> +inline std::ostream &EncodeTypes( + std::ostream &out +) { + + out << TensorTypeName() << "_" + << TensorTypeName() << "_" + << TensorTypeName() << "_" + << ElementTypeName(); + + return out; +} +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA, typename LayoutA, + typename ElementB, typename LayoutB, + typename ElementC, typename LayoutC, + typename ElementAccumulator, + typename ElementCompute +> +inline CachedTestKey CreateCachedGemmTestKey( + cutlass::gemm::GemmCoord const &problem, + ElementCompute alpha, + ElementCompute beta, + cutlass::TensorView A, + cutlass::TensorView B, + cutlass::TensorView C +) { + + CachedTestKey key; + + // Encode gemm operator and problem sizes + key.op = "gemm"; + + std::stringstream ss_problem; + EncodeProblemSize(ss_problem, problem); + ss_problem << "_alpha" << EncodeScalar(alpha) << "_beta" << EncodeScalar(beta); + key.problem = ss_problem.str(); + + // Encode problem data types + std::stringstream ss_types; + EncodeTypes< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementAccumulator, + ElementCompute>(ss_types); + key.types = 
ss_types.str(); + + // Encode hash for problem data + CRC32 crc_hash; + key.A = TensorHash(A, crc_hash); + key.B = TensorHash(B, crc_hash); + key.C = TensorHash(C, crc_hash); + + return key; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + + +template < + typename ElementA, typename LayoutA, + typename ElementB, typename LayoutB, + typename ElementC, typename LayoutC, + typename ElementAccumulator, + typename ElementCompute +> +inline CachedTestKey CreateCachedConv2dTestKey( + + cutlass::conv::Operator conv_operator, + cutlass::conv::Conv2dProblemSize const &problem, + ElementCompute alpha, + ElementCompute beta, + cutlass::TensorView A, + cutlass::TensorView B, + cutlass::TensorView C +) { + + CachedTestKey key; + + // Encode conv2d operator and problem sizes + key.op = "conv2d"; + + std::stringstream ss_problem; + ss_problem << EncodeOperator(conv_operator) << "_"; + EncodeProblemSize(ss_problem, problem); + ss_problem << "_alpha" << EncodeScalar(alpha) << "_beta" << EncodeScalar(beta); + + key.problem = ss_problem.str(); + + // Encode problem data types + std::stringstream ss_types; + EncodeTypes< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementAccumulator, + ElementCompute>(ss_types); + key.types = ss_types.str(); + + // Encode hash for problem data + CRC32 crc_hash; + + key.A = TensorHash(A, crc_hash); + key.B = TensorHash(B, crc_hash); + key.C = TensorHash(C, crc_hash); + + return key; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA, typename LayoutA, + typename ElementB, typename LayoutB, + typename ElementC, typename LayoutC, + typename ElementAccumulator, + typename ElementCompute +> +inline CachedTestKey CreateCachedConv2dWithBroadcastTestKey( + + cutlass::conv::Operator conv_operator, + cutlass::conv::Conv2dProblemSize const &problem, + ElementCompute alpha, + ElementCompute beta, + cutlass::TensorView A, + cutlass::TensorView B, + cutlass::TensorView C +) { + + CachedTestKey key; + + // Encode conv2d operator and problem sizes + key.op = "conv2d_with_broadcast"; + + std::stringstream ss_problem; + ss_problem << EncodeOperator(conv_operator) << "_"; + EncodeProblemSize(ss_problem, problem); + ss_problem << "_alpha" << EncodeScalar(alpha) << "_beta" << EncodeScalar(beta); + + key.problem = ss_problem.str(); + + // Encode problem data types + std::stringstream ss_types; + EncodeTypes< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementAccumulator, + ElementCompute>(ss_types); + key.types = ss_types.str(); + + // Encode hash for problem data + CRC32 crc_hash; + + key.A = TensorHash(A, crc_hash); + key.B = TensorHash(B, crc_hash); + key.C = TensorHash(C, crc_hash); + + return key; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA, typename LayoutA, + typename ElementB, typename LayoutB, + typename ElementC, typename LayoutC, + typename ElementAccumulator, + typename ElementCompute +> +inline CachedTestKey CreateCachedConv2dWithReductionTestKey( + + cutlass::conv::Operator conv_operator, + cutlass::conv::Conv2dProblemSize const &problem, + ElementCompute alpha, + ElementCompute beta, + cutlass::TensorView A, + cutlass::TensorView B, + cutlass::TensorView C +) { + + CachedTestKey key; + + // Encode conv2d operator and problem sizes + key.op = "conv2d_with_reduction"; + + std::stringstream ss_problem; + 
ss_problem << EncodeOperator(conv_operator) << "_"; + EncodeProblemSize(ss_problem, problem); + ss_problem << "_alpha" << EncodeScalar(alpha) << "_beta" << EncodeScalar(beta); + + key.problem = ss_problem.str(); + + // Encode problem data types + std::stringstream ss_types; + EncodeTypes< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementAccumulator, + ElementCompute>(ss_types); + key.types = ss_types.str(); + + // Encode hash for problem data + CRC32 crc_hash; + + key.A = TensorHash(A, crc_hash); + key.B = TensorHash(B, crc_hash); + key.C = TensorHash(C, crc_hash); + + return key; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA, typename LayoutA, + typename ElementB, typename LayoutB, + typename ElementC, typename LayoutC, + typename ElementAccumulator, + typename ElementCompute +> +inline CachedTestKey CreateCachedConv3dTestKey( + cutlass::conv::Operator conv_operator, + cutlass::conv::Conv3dProblemSize const &problem, + ElementCompute alpha, + ElementCompute beta, + cutlass::TensorView A, + cutlass::TensorView B, + cutlass::TensorView C +) { + + CachedTestKey key; + + // Encode conv3d operator and problem sizes + key.op = "conv3d"; + + std::stringstream ss_problem; + + ss_problem << EncodeOperator(conv_operator) << "_"; + EncodeProblemSize(ss_problem, problem); + ss_problem << "_alpha" << EncodeScalar(alpha) << "_beta" << EncodeScalar(beta); + + key.problem = ss_problem.str(); + + // Encode problem data types + std::stringstream ss_types; + EncodeTypes< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementAccumulator, + ElementCompute>(ss_types); + key.types = ss_types.str(); + + // Encode problem data + CRC32 crc_hash; + key.A = TensorHash(A, crc_hash); + key.B = TensorHash(B, crc_hash); + key.C = TensorHash(C, crc_hash); + + return key; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + class ProblemShape, + typename ElementA, + typename ElementB, + typename ElementC, + typename ElementD +> +inline CachedTestKey CreateCachedConvNd3xTestKey( + cutlass::conv::Operator conv_operator, + ProblemShape const& problem_shape, + double alpha, + double beta, + thrust::universal_vector A, + thrust::universal_vector B, + thrust::universal_vector C +) { + + CachedTestKey key; + + // Encode convNd operator and problem sizes + std::stringstream ss_op; + ss_op << "conv" << ProblemShape::RankS << "d"; + key.op = ss_op.str(); + + std::stringstream ss_problem; + ss_problem << EncodeOperator(conv_operator) << "_"; + EncodeProblemSize(ss_problem, problem_shape); + ss_problem << "_alpha" << EncodeScalar(alpha) << "_beta" << EncodeScalar(beta); + key.problem = ss_problem.str(); + + // Encode problem data types + std::stringstream ss_types; + EncodeTypes< + ElementA, + ElementB, + ElementC, + ElementD>(ss_types); + key.types = ss_types.str(); + + // Encode problem data + CRC32 crc_hash; + key.A = TensorHash(A, crc_hash); + key.B = TensorHash(B, crc_hash); + key.C = TensorHash(C, crc_hash); + + return key; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace test::conv::device + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_problems.h 
b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_problems.h new file mode 100644 index 0000000000000000000000000000000000000000..a14134b2854732e669977831207a456d28beed9f --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_problems.h @@ -0,0 +1,927 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file
+    \brief Implicit GEMM testbed sizes for Conv2d problem
+*/
+#pragma once
+
+#include <vector>
+
+#include "cutlass/cutlass.h"
+#include "cutlass/layout/matrix.h"
+#include "cutlass/conv/convolution.h"
+#include "cutlass/conv/conv2d_problem_size.h"
+
+namespace test {
+namespace conv {
+namespace device {
+
+using Conv2dProblemVector = std::vector<cutlass::conv::Conv2dProblemSize>;
+
+//
+// Structures to prune items from Conv2dProblemVector
+//
+// Specification template for pruning items for convolution problem lists
+template <typename T> struct Specification
+{
+  virtual ~Specification() = default;
+  virtual bool is_satisfied(T item) const = 0;
+};
+
+// input size (NHWC) specification
+struct InputSizeSpecification : Specification<cutlass::conv::Conv2dProblemSize>
+{
+  cutlass::Tensor4DCoord input_size;
+
+  InputSizeSpecification(cutlass::Tensor4DCoord input_size_) : input_size(input_size_) {}
+
+  bool is_satisfied(cutlass::conv::Conv2dProblemSize item) const override {
+    return ((input_size.n() == item.N) && (input_size.h() == item.H) && (input_size.w() == item.W) && (input_size.c() == item.C));
+  }
+};
+
+// stride (stride_h, stride_w) specification
+struct StrideSpecification : Specification<cutlass::conv::Conv2dProblemSize>
+{
+  cutlass::MatrixCoord stride;
+
+  StrideSpecification(cutlass::MatrixCoord stride_) : stride(stride_) {}
+
+  bool is_satisfied(cutlass::conv::Conv2dProblemSize item) const override {
+    return ((stride.row() == item.stride_h) && (stride.column() == item.stride_w));
+  }
+};
+
+// channel (C,K) specification, must be multiple of minimum channel
+struct ChannelDivisibilitySpecification : Specification<cutlass::conv::Conv2dProblemSize>
+{
+  int channel_multiple;
+
+  ChannelDivisibilitySpecification(int channel_multiple_) : channel_multiple(channel_multiple_) {}
+
+  bool is_satisfied(cutlass::conv::Conv2dProblemSize item) const override {
+    return ((item.K % channel_multiple == 0) && (item.C % channel_multiple == 0));
+  }
+};
+
+//
+// Pruning function for items from Conv2dProblemVector based on a Specification
+//
+inline Conv2dProblemVector prune(Conv2dProblemVector const &items,
+                                 Specification<cutlass::conv::Conv2dProblemSize> const &spec)
+{
+  Conv2dProblemVector pruned_list;
+
+  for (auto& p : items)
+    if (spec.is_satisfied(p))
+      pruned_list.push_back(p);
+  return pruned_list;
+}
+
+
+////////////////////////////////////////////////////////////////////////////
+/// Structure TestbedConv2dProblemSizes initializes and holds conv default and
+/// important network sizes
+////////////////////////////////////////////////////////////////////////////
+struct TestbedConv2dProblemSizes {
+
+  //
+  // Data members
+  //
+  int minimum_channel_size;
+
+  Conv2dProblemVector conv2d_default_sizes;
+  Conv2dProblemVector conv2d_rigorous_sizes;
+  Conv2dProblemVector conv2d_resnet50_sizes;
+  Conv2dProblemVector conv2d_resnet50_sizes_perf;
+
+  //
+  // Methods
+  //
+  /// Default ctor
+  TestbedConv2dProblemSizes(int minimum_channel_size_ = 64): minimum_channel_size (minimum_channel_size_) {
+    initialize_conv2d_default_sizes();
+    initialize_conv2d_rigorous_sizes();
+    initialize_conv2d_resnet50_sizes(conv2d_resnet50_sizes, 1 /*batch-size*/);
+
+    initialize_conv2d_resnet50_sizes(conv2d_resnet50_sizes_perf, 34 /*batch-size*/);
+    filter_all();
+  }
+
+  /// Eliminates some illegal cases
+  void filter_all() {
+
+    Conv2dProblemVector *problems_vectors[] = {
+      &conv2d_default_sizes,
+      &conv2d_rigorous_sizes,
+      &conv2d_resnet50_sizes,
+      &conv2d_resnet50_sizes_perf
+    };
+
+    for (Conv2dProblemVector *problems : problems_vectors) {
+      Conv2dProblemVector filtered;
+
+      for (cutlass::conv::Conv2dProblemSize const & problem : *problems) {
+        if (!(problem.C % minimum_channel_size))
{ + filtered.push_back(problem); + } + } + + *problems = filtered; + } + } + + // Add a few standard convolution problem sizes + void initialize_conv2d_default_sizes() { + + //////////////////////////////////////////////////////////////////////////////////////////// + // Small input size x stride (1,1) + // C < CTA::K and non-multiples of CTA::K. Typical CTA::K = {32, 64} + //////////////////////////////////////////////////////////////////////////////////////////// + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 1, 1, minimum_channel_size}, // input size (NHWC) + {8, 1, 1, minimum_channel_size}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 1, 8, minimum_channel_size}, // input size (NHWC) + {8, 1, 3, minimum_channel_size}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 7, 8, minimum_channel_size}, // input size (NHWC) + {8, 3, 3, minimum_channel_size}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 7, 9, minimum_channel_size}, // input size (NHWC) + {8, 4, 4, minimum_channel_size}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {2, 7, 9, minimum_channel_size}, // input size (NHWC) + {8, 5, 5, minimum_channel_size}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {3, 7, 9, minimum_channel_size}, // input size (NHWC) + {8, 6, 5, minimum_channel_size}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {3, 7, 9, minimum_channel_size}, // input size (NHWC) + {8, 6, 6, minimum_channel_size}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {3, 7, 9, minimum_channel_size}, // input size (NHWC) + {8, 7, 7, minimum_channel_size}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + //////////////////////////////////////////////////////////////////////////////////////////// + // Small input size x stride (1,1) asymmetric paddings (1, 0, 1, 0) + // C < CTA::K and non-multiples of CTA::K. 
Typical CTA::K = {32, 64} + //////////////////////////////////////////////////////////////////////////////////////////// + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 1, 1, minimum_channel_size}, // input size (NHWC) + {8, 1, 1, minimum_channel_size}, // filter size (KRSC) + {1, 0, 1, 0}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 1, 8, minimum_channel_size}, // input size (NHWC) + {8, 1, 3, minimum_channel_size}, // filter size (KRSC) + {1, 0, 1, 0}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 7, 8, minimum_channel_size}, // input size (NHWC) + {8, 3, 3, minimum_channel_size}, // filter size (KRSC) + {1, 0, 1, 0}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 7, 9, minimum_channel_size}, // input size (NHWC) + {8, 4, 4, minimum_channel_size}, // filter size (KRSC) + {1, 0, 1, 0}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {2, 7, 9, minimum_channel_size}, // input size (NHWC) + {8, 5, 5, minimum_channel_size}, // filter size (KRSC) + {1, 0, 1, 0}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {3, 7, 9, minimum_channel_size}, // input size (NHWC) + {8, 6, 5, minimum_channel_size}, // filter size (KRSC) + {1, 0, 1, 0}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {3, 7, 9, minimum_channel_size}, // input size (NHWC) + {8, 6, 6, minimum_channel_size}, // filter size (KRSC) + {1, 0, 1, 0}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {3, 7, 9, minimum_channel_size}, // input size (NHWC) + {8, 7, 7, minimum_channel_size}, // filter size (KRSC) + {1, 0, 1, 0}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + //////////////////////////////////////////////////////////////////////////////////////////// + // Small input size x stride (2,2) + // C < CTA::K and non-multiples of CTA::K. 
Typical CTA::K = {32, 64} + //////////////////////////////////////////////////////////////////////////////////////////// + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 11, 7, minimum_channel_size}, // input size (NHWC) + {8, 1, 1, minimum_channel_size}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 11, 7, minimum_channel_size}, // input size (NHWC) + {8, 3, 3, minimum_channel_size}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 13, 11, minimum_channel_size}, // input size (NHWC) + {8, 1, 1, minimum_channel_size}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 17, 19, minimum_channel_size}, // input size (NHWC) + {16, 2, 2, minimum_channel_size}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 23, 5, minimum_channel_size}, // input size (NHWC) + {16, 3, 3, minimum_channel_size}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 13, 17, 8}, // input size (NHWC) + {24, 3, 3, 8}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 23, 21, 8}, // input size (NHWC) + {24, 3, 3, 8}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {3, 3}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 20, 24, 8}, // input size (NHWC) + {40, 3, 3, 8}, // filter size (KRSC) + {3, 3, 3, 3}, // padding (pad_h, _, pad_w, _) + {3, 3}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + //////////////////////////////////////////////////////////////////////////////////// + // Medium input size (1x16x16x128), filter size (1x1, 2x2, 3x3, 5x5), stride (1, 1) + //////////////////////////////////////////////////////////////////////////////////// + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 15, 19, 160}, // input size (NHWC) + {224, 1, 1, 160}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 19, 37, 160}, // input size (NHWC) + {224, 3, 3, 160}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 16, 16, 160}, // input 
size (NHWC) + {224, 2, 3, 160}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 23, 21, 128}, // input size (NHWC) + {224, 3, 3, 128}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 29, 37, 160}, // input size (NHWC) + {224, 5, 5, 160}, // filter size (KRSC) + {2, 2, 2, 2}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + //////////////////////////////////////////////////////////////////////////////////// + // C > CTA::K and non-multiples of CTA::K. Typical CTA::K = {32, 64} + //////////////////////////////////////////////////////////////////////////////////// + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 15, 19, 32 + minimum_channel_size}, // input size (NHWC) + {96, 3, 3, 32 + minimum_channel_size}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 16, 24, 64 + minimum_channel_size}, // input size (NHWC) + {96, 3, 3, 64 + minimum_channel_size}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + //////////////////////////////////////////////////////////////////////////////////// + // Medium input size, filter size (1x1, 3,x3, 5x5, 7x7), stride (2, 2) + //////////////////////////////////////////////////////////////////////////////////// + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 13, 16, 288}, // input size (NHWC) + {160, 5, 5, 288}, // filter size (KRSC) + {2, 2, 2, 2}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 55, 51, 256}, // input size (NHWC) + {512, 1, 1, 256}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 71, 80, 32}, // input size (NHWC) + {64, 5, 5, 32}, // filter size (KRSC) + {2, 2, 2, 2}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 224, 224, 8}, // input size (NHWC) + {64, 7, 7, 8}, // filter size (KRSC) + {3, 3, 3, 3}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + //////////////////////////////////////////////////////////////////////////////////// + // Medium input size stride (3, 3), filter (3, 3), non-default padding + //////////////////////////////////////////////////////////////////////////////////// + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 27, 23, 256}, // input size (NHWC) + {512, 3, 3, 256}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, 
_) + {3, 3}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + //////////////////////////////////////////////////////////////////////////////////// + // Medium input size padding > stride, asymmetric filter, padding and striding + //////////////////////////////////////////////////////////////////////////////////// + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 27, 31, 256}, // input size (NHWC) + {512, 3, 3, 256}, // filter size (KRSC) + {5, 5, 7, 7}, // padding (pad_h, _, pad_w, _) + {3, 4}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 27, 35, 256}, // input size (NHWC) + {512, 7, 5, 256}, // filter size (KRSC) + {11, 11, 7, 7}, // padding (pad_h, _, pad_w, _) + {3, 5}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + //////////////////////////////////////////////////////////////////////////////////// + // Medium input size *mixed* stride (1, 2) and (2, 1), + // filter (3, 3), default padding + //////////////////////////////////////////////////////////////////////////////////// + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 27, 27, 256}, // input size (NHWC) + {512, 3, 3, 256}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 2}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 27, 27, 256}, // input size (NHWC) + {512, 3, 3, 256}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {2, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + ///////////////////////////////////////////////////////////////////////////// + // Additional input size + ///////////////////////////////////////////////////////////////////////////// + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {3, 28, 28, 256}, // input size (NHWC) + {256, 2, 2, 256}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 32, 32, 16}, // input size (NHWC) + {32, 3, 3, 16}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {6, 2}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {32, 24, 32, 32}, // input size (NHWC) + {32, 1, 2, 32}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {4, 4, 5, 128}, // input size (NHWC) + {256, 3, 6, 128}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1}, // dilation (dilation_h, dilation_w) + {4, 3, 3, 256} // output size (NPQK) + )); + + conv2d_default_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {4, 2, 3, 256}, // input size (NHWC) + {328, 3, 5, 256}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1}, // dilation (dilation_h, dilation_w) + {4, 1, 1, 328} // output size (NPQK) + )); + } + + + // Add a few large and 
rigorous convolution problem sizes + void initialize_conv2d_rigorous_sizes() { + +#if CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED + conv2d_rigorous_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 124, 224, 96}, // input size (NHWC) + {24, 7, 7, 96}, // filter size (KRSC) + {1, 229, 129, 32} // output size (NPQK) + )); + + conv2d_rigorous_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 233, 35, 48}, // input size (NHWC) + {24, 7, 5, 48}, // filter size (KRSC) + {1, 233, 35, 24} // output size (NPQK) + )); + +#endif + + } + + + // Add resent50 layers to unit testing sizes + void initialize_conv2d_resnet50_sizes(Conv2dProblemVector &conv2d_problem_vector, int batch_size = 1){ + +#if 0 // Resnet50 first layer (layer_id = 0) with channel = 3 is not supported in cutlass + conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( + [1, 224, 224, 3], // input size (NHWC) + [64, 7, 7, 3], // filter size (KRSC) + [3, 3, 3, 3], // padding (pad_h, _, pad_w, _) + [2, 2], // stride (stride_h, stride_w) + [1, 1], // dilation (dilation_h, dilation_w) + )); +#endif + + conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( + {batch_size, 56, 56, 64}, // input size (NHWC) + {256, 1, 1, 64}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( + {batch_size, 56, 56, 64}, // input size (NHWC) + {64, 1, 1, 64}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( + {batch_size, 56, 56, 64}, // input size (NHWC) + {64, 3, 3, 64}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( + {batch_size, 56, 56, 256}, // input size (NHWC) + {64, 1, 1, 256}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( + {batch_size, 56, 56, 256}, // input size (NHWC) + {512, 1, 1, 256}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( + {batch_size, 56, 56, 256}, // input size (NHWC) + {128, 1, 1, 256}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( + {batch_size, 28, 28, 128}, // input size (NHWC) + {128, 3, 3, 128}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( + {batch_size, 28, 28, 128}, // input size (NHWC) + {512, 1, 1, 128}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + 
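+    // The remaining entries cover the later ResNet-50 stages (28x28, 14x14 and 7x7
+    // feature maps). For each entry the output extent follows the usual convolution
+    // arithmetic, roughly P = (H + 2 * pad_h - ((R - 1) * dilation_h + 1)) / stride_h + 1.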
conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( + {batch_size, 28, 28, 512}, // input size (NHWC) + {128, 1, 1, 512}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( + {batch_size, 28, 28, 512}, // input size (NHWC) + {1024, 1, 1, 512}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( + {batch_size, 28, 28, 512}, // input size (NHWC) + {256, 1, 1, 512}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( + {batch_size, 14, 14, 256}, // input size (NHWC) + {256, 3, 3, 256}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( + {batch_size, 14, 14, 256}, // input size (NHWC) + {1024, 1, 1, 256}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( + {batch_size, 14, 14, 1024}, // input size (NHWC) + {256, 1, 1, 1024}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( + {batch_size, 14, 14, 1024}, // input size (NHWC) + {2048, 1, 1, 1024}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( + {batch_size, 14, 14, 1024}, // input size (NHWC) + {512, 1, 1, 1024}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( + {batch_size, 7, 7, 512}, // input size (NHWC) + {512, 3, 3, 512}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( + {batch_size, 7, 7, 512}, // input size (NHWC) + {2048, 1, 1, 512}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + + conv2d_problem_vector.push_back(cutlass::conv::Conv2dProblemSize( + {batch_size, 7, 7, 2048}, // input size (NHWC) + {512, 1, 1, 2048}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + )); + } + +}; + + +//////////////////////////////////////////////////////////////////////////// +/// Structure TestbedGroupConv2dProblemSizes initializes and holds group conv default and +/// important network sizes 
+//////////////////////////////////////////////////////////////////////////// +struct TestbedGroupConv2dProblemSizes { + + // + // Data members + // + int threadblock_n; + int threadblock_k; + int minimum_channel_size; + + Conv2dProblemVector default_single_group_sizes; + Conv2dProblemVector default_multiple_group_sizes; + + // + // Methods + // + /// Default ctor + TestbedGroupConv2dProblemSizes( + int threadblock_n_, + int threadblock_k_, + int minimum_channel_size_ = 64) + : threadblock_n (threadblock_n_), + threadblock_k (threadblock_k_), + minimum_channel_size (minimum_channel_size_) { + initialize_group_conv2d_default_sizes(); + filter_all(); + } + + /// Eliminates some illegal cases + void filter_all() { + + Conv2dProblemVector *problems_vectors[] = { + &default_single_group_sizes, + &default_multiple_group_sizes + }; + + for (Conv2dProblemVector *problems : problems_vectors) { + Conv2dProblemVector filtered; + + for (cutlass::conv::Conv2dProblemSize const & problem : *problems) { + if (!((problem.C / problem.groups) % minimum_channel_size)) { + filtered.push_back(problem); + } + } + + *problems = filtered; + } + } + + // Add a few standard convolution problem sizes + void initialize_group_conv2d_default_sizes() { + + //////////////////////////////////////////////////////////////////////////////////// + // One group calculated by one or multiple CTAs: k_per_group % CTA::N = 0 + // One CTA calculates a single group + //////////////////////////////////////////////////////////////////////////////////// + + for (int cta_per_group_k = 1; cta_per_group_k < 4; ++cta_per_group_k) { + // groups = 2, 3, 4 + for (int groups = 2; groups < 5; ++groups) { + + int conv_k = cta_per_group_k * threadblock_n * groups; + default_single_group_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 8, 8, threadblock_k * 2 * groups}, // input size (NHWC) + {conv_k, 3, 3, threadblock_k * 2}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1}, // dilation (dilation_h, dilation_w) + cutlass::conv::Mode::kCrossCorrelation, + 1, // split_k_slices + groups // groups + )); + + } // loop groups + } // loop cta_per_group_k + + // Partial gemm_k: k_per_group == CTA::N && channels_per_group < CTA::K + default_single_group_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 8, 8, threadblock_k}, // input size (NHWC) + {threadblock_n * 2, 3, 3, threadblock_k / 2}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1}, // dilation (dilation_h, dilation_w) + cutlass::conv::Mode::kCrossCorrelation, + 1, // split_k_slices + 2 // groups + )); + + // Larger problem sizes + + default_single_group_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 56, 56, 696}, // input size (NHWC) + {768, 3, 3, 232}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {1, 1}, // dilation (dilation_h, dilation_w) + cutlass::conv::Mode::kCrossCorrelation, + 1, // split_k_slices + 3 // groups + )); + default_single_group_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 14, 14, 1392}, // input size (NHWC) + {1536, 3, 3, 232}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1}, // dilation (dilation_h, dilation_w) + cutlass::conv::Mode::kCrossCorrelation, + 1, // split_k_slices + 3 // groups + )); + + 
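+    // Note: in the loop above, conv_k = cta_per_group_k * CTA::N * groups, so each group's
+    // k_per_group (cta_per_group_k * CTA::N) is an exact multiple of the threadblock N and
+    // is covered by one or more whole CTAs.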
//////////////////////////////////////////////////////////////////////////////////// + // One CTA calculate multiple groups: CTA::N % k_per_group = 0 + //////////////////////////////////////////////////////////////////////////////////// + + // 2 groups per CTA + default_multiple_group_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 8, 8, threadblock_k * 4}, // input size (NHWC) + {threadblock_n, 3, 3, threadblock_k * 2}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1}, // dilation (dilation_h, dilation_w) + cutlass::conv::Mode::kCrossCorrelation, + 1, // split_k_slices + 2 // groups + )); + + // 2 groups per CTA and partial gemm_k + default_multiple_group_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 8, 8, threadblock_k}, // input size (NHWC) + {threadblock_n, 3, 3, threadblock_k / 2}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1}, // dilation (dilation_h, dilation_w) + cutlass::conv::Mode::kCrossCorrelation, + 1, // split_k_slices + 2 // groups + )); + + // 4 groups per CTA + default_multiple_group_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 8, 8, threadblock_k * 8}, // input size (NHWC) + {threadblock_n / 2, 3, 3, threadblock_k * 2}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1}, // dilation (dilation_h, dilation_w) + cutlass::conv::Mode::kCrossCorrelation, + 1, // split_k_slices + 4 // groups + )); + + // 4 groups per CTA and partial gemm_k + default_multiple_group_sizes.push_back(cutlass::conv::Conv2dProblemSize( + {1, 8, 8, threadblock_k * 2}, // input size (NHWC) + {threadblock_n / 2, 3, 3, threadblock_k / 2}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1}, // dilation (dilation_h, dilation_w) + cutlass::conv::Mode::kCrossCorrelation, + 1, // split_k_slices + 4 // groups + )); + } + +}; + + +} // namespace device +} // namespace conv +} // namespace test diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_testbed.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_testbed.h new file mode 100644 index 0000000000000000000000000000000000000000..34588ecb467b824cc0fcbbff0bc0d99e4385d80e --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_testbed.h @@ -0,0 +1,818 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Implicit GEMM testbed +*/ +#pragma once + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" + +#include "cutlass/conv/device/implicit_gemm_convolution.h" +#include "cutlass/reduction/device/reduce_split_k.h" +#include "cutlass/reduction/thread/reduction_operators.h" + +#include "conv2d_problems.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_compare.h" + +#include "cutlass/util/reference/host/convolution.h" +#include "cutlass/util/reference/device/convolution.h" + +#include "cutlass/core_io.h" +#include "cutlass/util/tensor_view_io.h" + +#include "../cache_testbed_output.h" + +namespace test { +namespace conv { +namespace device { + +template +class TestbedConv2d { +public: + + using ElementA = typename Conv2d::ElementA; + using LayoutA = typename Conv2d::LayoutA; + using ElementB = typename Conv2d::ElementB; + using LayoutB = typename Conv2d::LayoutB; + using ElementC = typename Conv2d::ElementC; + using LayoutC = typename Conv2d::LayoutC; + using ElementAccumulator = typename Conv2d::ElementAccumulator; + using ElementCompute = typename Conv2d::ElementCompute; + using EpilogueOutputOp = typename Conv2d::EpilogueOutputOp; + + static cutlass::conv::Operator const kConvolutionalOperator = Conv2d::kConvolutionalOperator; + + /// Reduction kernel + using ReductionOp = cutlass::reduction::thread::ReduceAdd< + ElementAccumulator, + typename EpilogueOutputOp::ElementAccumulator, + EpilogueOutputOp::kCount + >; + + using ReductionKernel = cutlass::reduction::kernel::ReduceSplitK< + cutlass::MatrixShape<4, 32 * EpilogueOutputOp::kCount>, + EpilogueOutputOp, + ReductionOp + >; + + using ReductionDevice = cutlass::reduction::device::ReduceSplitK; + using ReductionStrideIndex = typename ReductionDevice::StrideIndex; + +public: + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + uint64_t seed; + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_C; + cutlass::HostTensor tensor_D_computed; + cutlass::HostTensor tensor_D_reference; + + int tested_problem_count; + +public: + + TestbedConv2d( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + init_A(init_A_), init_B(init_B_), 
init_C(init_C_), seed(seed_), tested_problem_count(0) { + + } + + /// Helper to initialize a tensor view + template + void initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + int scope; + int bits = cutlass::sizeof_bits::value; + + if (bits <= 8) { + scope = 2; + } + else if (bits == 16) { + if (cutlass::sizeof_bits::value <= 16) { + scope = 3; + } + else { + scope = 5; + } + } + else { + scope = 8; + } + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope, -scope, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential(view.data(), view.capacity()); + } + else { + } + } + + void initialize( + cutlass::conv::Conv2dProblemSize const &problem_size, uint64_t seed = 2019) { + + tensor_A.resize(implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size)); + tensor_B.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size)); + tensor_C.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + tensor_D_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + tensor_D_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + + initialize_tensor(tensor_A.host_view(), init_A, seed); + initialize_tensor(tensor_B.host_view(), init_B, seed * 17); + initialize_tensor(tensor_C.host_view(), init_C, seed * 39); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_D_computed.sync_device(); + tensor_D_reference.sync_device(); + } + + bool sufficient() const { + // + // Determine SMEM requirements and waive if not satisfied + // + + size_t smem_size = sizeof(typename Conv2d::UnderlyingKernel::SharedStorage); + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerBlockOptin < smem_size) { + return false; + } + + return true; + } + + /// Executes one test + bool run( + cutlass::conv::Conv2dProblemSize const &problem_size, + cutlass::conv::SplitKMode const &split_k_mode = cutlass::conv::SplitKMode::kSerial, + ElementCompute alpha = ElementCompute(1), + ElementCompute beta = ElementCompute(0)) { + + // Waive test if insufficient CUDA device + if (!sufficient()) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." << std::endl; + } + return true; + } + + // increment tested problem count run by the testbed + tested_problem_count++; + +#if 0 // display conv2d problem size for debugging + std::cout << problem_size << std::endl + << "alpha, beta: (" << alpha << ", " << beta << ")" << std::endl + << "split_k_mode: " << ((split_k_mode == cutlass::conv::SplitKMode::kSerial) ? 
"(serial)" : "(parallel)") << std::endl + << std::endl; +#endif + + initialize(problem_size); + + // configure the operator + Conv2d conv2d_op; + + typename Conv2d::Arguments conv2d_args( + problem_size, + tensor_A.device_ref(), + tensor_B.device_ref(), + tensor_C.device_ref(), + tensor_D_computed.device_ref(), + {alpha, beta}, + split_k_mode + ); + + // find workspace requirement for parallel split-k reduction + size_t workspace_size = Conv2d::get_workspace_size(conv2d_args); + + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = conv2d_op.initialize(conv2d_args, workspace.get()); + + if (status != cutlass::Status::kSuccess) { + cudaError_t error = cudaGetLastError(); + std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n"; + return true; + } + + // conv2d operation with parallel split-k-mode + if (split_k_mode == cutlass::conv::SplitKMode::kParallel) { + + // conv2d output is written to workspace in global memory + conv2d_args.ref_D.reset(reinterpret_cast(workspace.get())); + // accumulate mma for each cta in k-dimension (1.0 * A * B) + conv2d_args.output_op = {ElementCompute(1), ElementCompute(0)}; + // update conv2d operator arguments + status = conv2d_op.update(conv2d_args, workspace.get()); + } + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + if (status != cutlass::Status::kSuccess) { + return false; + } + + // run conv2d operator + status = conv2d_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + if (status != cutlass::Status::kSuccess) { + std::cerr << "Failed to run." << std::endl; + return false; + } + + + if (split_k_mode == cutlass::conv::SplitKMode::kParallel) { + + // configure parallel reduction operator + ReductionDevice reduction_op; + + typename ReductionDevice::Arguments reduction_args( + cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, problem_size).mn(), + problem_size.split_k_slices, + cutlass::conv::implicit_gemm_tensor_c_size(kConvolutionalOperator, problem_size), + { + reinterpret_cast (workspace.get()), + ReductionStrideIndex(tensor_C.stride()[Conv2d::UnderlyingKernel::kTensorCStrideIdx]) + }, + { + tensor_D_computed.device_data(), + ReductionStrideIndex(tensor_C.stride()[Conv2d::UnderlyingKernel::kTensorCStrideIdx]) + }, + { + tensor_C.device_data(), + ReductionStrideIndex(tensor_C.stride()[Conv2d::UnderlyingKernel::kTensorCStrideIdx]) + }, + // apply alpha, beta to obtain the following equation alpha * ReduceAdd(A * B) + beta * C + {alpha, beta} + ); + + status = reduction_op.initialize(reduction_args, nullptr); + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + if (status != cutlass::Status::kSuccess) { + return false; + } + + // run prallel reduction kernel + status = reduction_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + if (status != cutlass::Status::kSuccess) { + return false; + } + } + bool passed = false; + + cudaError_t result = cudaDeviceSynchronize(); + EXPECT_EQ(result, cudaSuccess) << " device reference error: " + << cudaGetErrorString(result); + + tensor_D_computed.sync_host(); + + // + // Reference check - support caching results + // + + CachedTestKey cached_test_key = CreateCachedConv2dTestKey< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementAccumulator, + ElementCompute + >( + kConvolutionalOperator, + problem_size, + alpha, + beta, + tensor_A.host_view(), + tensor_B.host_view(), + tensor_C.host_view() + ); + + // + // Look for the cached key + // + + bool cached_result_loaded = false; + 
CachedTestResult cached_test_result; + + std::string conv2d_result_cache_name = + std::string("cached_results_") + CUTLASS_TARGET_NAME + ".txt"; + + if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) { + + CachedTestResultListing cached_results(conv2d_result_cache_name); + + auto cached = cached_results.find(cached_test_key); + + cached_result_loaded = cached.first; + if (cached_result_loaded) { + cached_test_result = cached.second; + } + } + + if (!cached_result_loaded) { + +#if CUTLASS_CONV_TEST_UNIT_REFERENCE_DEVICE_ENABLED + + cutlass::reference::device::Conv2d< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementCompute, + ElementAccumulator + >( + kConvolutionalOperator, + problem_size, + tensor_A.device_ref(), + tensor_B.device_ref(), + tensor_C.device_ref(), + tensor_D_reference.device_ref(), + alpha, + beta); + + // sync host (copy device data to host) for dumping error output in case of mismatches + tensor_D_reference.sync_host(); + +#else + + cutlass::reference::host::Conv2d< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementCompute, + ElementAccumulator + >( + kConvolutionalOperator, + problem_size, + tensor_A.host_ref(), + tensor_B.host_ref(), + tensor_C.host_ref(), + tensor_D_reference.host_ref(), + alpha, + beta); + +#endif + + if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) { + + cached_test_result.D = TensorHash(tensor_D_reference.host_view()); + + CachedTestResultListing cached_results(conv2d_result_cache_name); + + cached_results.append(cached_test_key, cached_test_result); + cached_results.write(conv2d_result_cache_name); + } + } // if (!cached_result_loaded) + + uint32_t tensor_D_hash = TensorHash(tensor_D_computed.host_view()); + + if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) { + passed = (tensor_D_hash == cached_test_result.D); + + EXPECT_EQ(tensor_D_hash, cached_test_result.D) + << "Hash-based comparison failed for key:" << "\n" << cached_test_key << "\n"; + } + else { + + passed = cutlass::reference::host::TensorEquals( + tensor_D_computed.host_view(), + tensor_D_reference.host_view()); + } + + EXPECT_TRUE(passed); + + std::stringstream ss_problem_size_text; + ss_problem_size_text << "nhwc_" + << problem_size.N << "x" + << problem_size.H << "x" + << problem_size.W << "x" + << problem_size.C + << "_krsc_" + << problem_size.K << "x" + << problem_size.R << "x" + << problem_size.S << "x" + << problem_size.C + << "_padding_" + << problem_size.pad_h << "x" + << problem_size.pad_w + << "_stride_" + << problem_size.stride_h << "x" + << problem_size.stride_w + << "_dilation_" + << problem_size.dilation_h << "x" + << problem_size.dilation_w << "_" + << (problem_size.mode == cutlass::conv::Mode::kCrossCorrelation ? "xcorr_" : "conv_"); + + if (!passed) { + std::stringstream fname; + + fname << "error_Conv2d_ImplicitGemm_device_" + << (split_k_mode == cutlass::conv::SplitKMode::kSerial ? "serial_reduction_" : "parallel_reduction_") + << (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kFprop ? "fprop_" : + (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kDgrad ? "dgrad_" : + (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kDeconv ? 
"deconv_" : "wgrad_"))) + << ss_problem_size_text.str() + << Conv2d::ThreadblockShape::kM << "x" + << Conv2d::ThreadblockShape::kN << "x" + << Conv2d::ThreadblockShape::kK << "_" + << Conv2d::WarpShape::kM << "x" + << Conv2d::WarpShape::kN << "x" + << Conv2d::WarpShape::kK << ".txt"; + + std::cout << fname.str() << std::endl; + + std::ofstream results(fname.str()); + + results << problem_size << std::endl; + + results + << "\nA:\n" << tensor_A.host_view() << "\n" + << "\nB:\n" << tensor_B.host_view() << "\n" + << "\nC:\n" << tensor_C.host_view() << "\n"; + + results << "\nD reference (hash: " << cached_test_result.D << ")\n"; + + if (!cached_result_loaded) { + results + << tensor_D_reference.host_view() << "\n"; + } + + results + << "\nD computed (hash: " << tensor_D_hash << ")\n" + << tensor_D_computed.host_view() << "\n"; + + } + + return passed; + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////////////////// + +template +bool TestSpecificConv2d( + const Conv2dProblemVector & problem_sizes) { + + bool passed = true; + + // + // Testbed object + // + + TestbedConv2d testbed; + + // Sweep conv2d problem sizes (split-k-mode=kSerial, split-k-slice=1, alpha=1.0, beta=0.0) + for(auto conv_problem : problem_sizes) { + + // + // Test + // + + // test mode = xcross + passed = testbed.run( + conv_problem, + cutlass::conv::SplitKMode::kSerial); + + if (!passed) { + return false; + } + + // test mode = convolution + passed = testbed.run( + conv_problem.reset_mode(cutlass::conv::Mode::kConvolution), + cutlass::conv::SplitKMode::kSerial); + + if (!passed) { + return false; + } + } + + return true; +} + +///////////////////////////////////////////////////////////////////////////////////////////////////////// +// TestAllConv: Runs cutlass::conv::device::ImplicitGemmConvolution operator and compares it with reference +// TestAllConv runs conv operator on default conv problem sizes from test::conv::device::TestbedConv2dProblemSizes +// Additionally, each conv2d test can provide conv problem sizes (conv_test_sizes) and blacklist of sizes +// (conv_blacklist_sizes) +///////////////////////////////////////////////////////////////////////////////////////////////////////////// +template +bool TestAllConv2d( + const Conv2dProblemVector & conv_test_sizes = Conv2dProblemVector(), + const Conv2dProblemVector & conv_blacklist_sizes = Conv2dProblemVector()) { + + bool passed = true; + + // + // Testbed object + // + + TestbedConv2d testbed; + + // + // Get conv problem sizes to run conv operator + // + TestbedConv2dProblemSizes conv_problems(128/cutlass::sizeof_bits::value); + + // Vector of conv2d problem sizes to avoid duplicate runs + Conv2dProblemVector conv_tested_sizes; + + // Vectors of Conv2dProblemVector (lenient/easiest to rigorous problem sizes) + std::vector problem_vectors = { + conv_test_sizes, // run user specified sizes + conv_problems.conv2d_default_sizes, // run default and cudnn bug sizes + //conv_problems.conv2d_resnet50_sizes, // run resnet50 sizes +#if CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED + conv_problems.conv2d_rigorous_sizes, // run large and rigorous sizes if enabled +#endif + }; + + // Flatten 2D problem_vectors into a 1D problem_sizes + std::vector problem_sizes; + for (auto problem_vector : problem_vectors) { + for(auto conv_problem : problem_vector) { + problem_sizes.push_back(conv_problem); + } + } + + // If CUTLASS_UNIT_TEST_PROBLEM_COUNT is set reverse the order (rigorous to lenient) + // run the most rigorous problem 
size first + if (CutlassUnitTestProblemCount()) { + std::reverse(problem_sizes.begin(), problem_sizes.end()); + } + + // Sweep conv2d problem sizes (split-k-mode=kSerial, split-k-slice=1, alpha=1.0, beta=0.0) + for(auto conv_problem : problem_sizes) { + + // Skip blacklist and avoid duplicate problem sizes + if (std::find(conv_blacklist_sizes.begin(), conv_blacklist_sizes.end(), conv_problem) != conv_blacklist_sizes.end() || + std::find(conv_tested_sizes.begin(), conv_tested_sizes.end(), conv_problem) != conv_tested_sizes.end()) { + continue; + } + + // + // Procedurally disable certain cases + // + + // CUTLASS DGRAD's *unity* stride specialization only support stride {1, 1} + if ((ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDgrad || + ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDeconv) && + (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == + cutlass::conv::StrideSupport::kUnity)) { + if (!((conv_problem.stride_h == 1) && (conv_problem.stride_w == 1))) { + continue; + } + } + + // Fixed channels algorithm requires channel count to match access size + if (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kIteratorAlgorithm == + cutlass::conv::IteratorAlgorithm::kFixedChannels) { + if (conv_problem.C != ImplicitGemm::UnderlyingKernel::Mma::IteratorA::AccessType::kElements) { + continue; + } + } + + // Few channels algorithm requires channel count to match access size + if (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kIteratorAlgorithm == + cutlass::conv::IteratorAlgorithm::kFewChannels) { + if (conv_problem.C % ImplicitGemm::UnderlyingKernel::Mma::IteratorA::AccessType::kElements) { + continue; + } + } + + // CUTLASS DGRAD's *strided* stride specialization supports all stride {stride_h, stride_w} + // Although strided dgrad works for all stride combinations, we are only going + // to run strided dgrad for non-unity strides + if ((ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDgrad || + ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDeconv) && + (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == + cutlass::conv::StrideSupport::kStrided)) { + if (((conv_problem.stride_h == 1) && (conv_problem.stride_w == 1))) { + continue; + } + } + + // + // Test + // + // push back tested problem size to avoid re-running duplicates + conv_tested_sizes.push_back(conv_problem); + + // test mode = xcross + passed = testbed.run( + conv_problem, + cutlass::conv::SplitKMode::kSerial); + + if (!passed) { + return false; + } + + // test mode = convolution + passed = testbed.run( + conv_problem.reset_mode(cutlass::conv::Mode::kConvolution), + cutlass::conv::SplitKMode::kSerial); + + if (!passed) { + return false; + } + + // If CUTLASS_UNIT_TEST_PROBLEM_COUNT is set reduce the number of tested problem counts + if (CutlassUnitTestProblemCount() && + testbed.tested_problem_count > CutlassUnitTestProblemCount()) { + return true; + } + } + + // Small-channels convolution can't run here. + if (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kIteratorAlgorithm == + cutlass::conv::IteratorAlgorithm::kFixedChannels) { + + return true; + } + + // Small-channels convolution can't run here. 
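+      // As with the fixed-channel case above, the few-channel specialization skips the split-k sweep below + // (which uses a fixed 288-channel problem size) and reports success from the problem sweep alone.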
+ if (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kIteratorAlgorithm == + cutlass::conv::IteratorAlgorithm::kFewChannels) { + + return true; + } + + // CUTLASS DGRAD's *strided* specialization does not support split-k mode + if ((ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDgrad || + ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDeconv) && + (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == + cutlass::conv::StrideSupport::kStrided)) { + + passed = testbed.run( + cutlass::conv::Conv2dProblemSize( + {1, 56, 56, 8}, // input size (NHWC) + {8, 1, 1, 8}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {1, 1}), // dilation (dilation_h, dilation_w) + cutlass::conv::SplitKMode::kSerial, + cutlass::from_real(2.0), + cutlass::from_real(2.0)); + + passed = testbed.run( + cutlass::conv::Conv2dProblemSize( + {1, 56, 56, 8}, // input size (NHWC) + {8, 1, 1, 8}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1}) // dilation (dilation_h, dilation_w) + .reset_split_k_slices(2), + cutlass::conv::SplitKMode::kSerial, + cutlass::from_real(2.0), + cutlass::from_real(2.0)); + + if (!passed) { + return false; + } + + return passed; + } + // Sweep split-k-slice using serial and prallel reduction with non-unity alpha and non-zero beta for + // a single conv2d problem size. Convolution unit tests take a long time to run so only sweep parameters + // which are abolutely necessary to catch functional bugs. The below code does provide option to sweep + // alpha and beta for local testing, but only runs one value for alpha and beta. + cutlass::conv::Conv2dProblemSize conv2d_split_k_test_size ( + {1, 17, 11, 288}, // input size (NHWC) + {160, 3, 3, 288}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + ); + + cutlass::conv::SplitKMode split_k_modes [] = { + cutlass::conv::SplitKMode::kSerial, + cutlass::conv::SplitKMode::kParallel, + }; + + int split_k_slices[] = { + 1, 2, 3, 4, 201 + }; + + double problem_alpha[] = { + 2.0 + }; + + double problem_beta[] = { + 2.0 + }; + + for (auto split_k_mode : split_k_modes) { + for (auto split_k_slice : split_k_slices) { + for (auto alpha : problem_alpha) { + for (auto beta : problem_beta) { + + passed = testbed.run( + conv2d_split_k_test_size.reset_split_k_slices(split_k_slice), + split_k_mode, + cutlass::from_real(alpha), + cutlass::from_real(beta)); + + if (!passed) { + return false; + } + + // If CUTLASS_UNIT_TEST_PROBLEM_COUNT is set reduce the number of tested problem counts + if (CutlassUnitTestProblemCount() && + testbed.tested_problem_count > CutlassUnitTestProblemCount()) { + return true; + } + } + } + } + } + + return passed; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace conv +} // namespace test + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_testbed_interleaved.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_testbed_interleaved.h new file mode 100644 index 0000000000000000000000000000000000000000..cf075674da673cf8e056172732f912b8acba3c5b --- /dev/null +++ 
b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_testbed_interleaved.h @@ -0,0 +1,666 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Implicit GEMM testbed +*/ +#pragma once + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" + +#include "cutlass/conv/device/implicit_gemm_convolution.h" +#include "cutlass/reduction/device/reduce_split_k.h" +#include "cutlass/reduction/thread/reduction_operators.h" + +#include "conv2d_problems.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/host_reorder.h" + +#include "cutlass/util/reference/host/convolution.h" +#include "cutlass/util/reference/device/convolution.h" + +#include "cutlass/core_io.h" +#include "cutlass/util/tensor_view_io.h" + +#include "../cache_testbed_output.h" + +namespace test { +namespace conv { +namespace device { + +template +class InterleavedTestbedConv2d { +public: + + using ElementA = typename Conv2d::ElementA; + using LayoutA = typename Conv2d::LayoutA; + using ElementB = typename Conv2d::ElementB; + using LayoutB = typename Conv2d::LayoutB; + using ElementC = typename Conv2d::ElementC; + using LayoutC = typename Conv2d::LayoutC; + using ElementAccumulator = typename Conv2d::ElementAccumulator; + using ElementCompute = typename Conv2d::ElementCompute; + using EpilogueOutputOp = typename Conv2d::EpilogueOutputOp; + + static cutlass::conv::Operator const kConvolutionalOperator = Conv2d::kConvolutionalOperator; + + /// Reduction kernel + using ReductionOp = cutlass::reduction::thread::ReduceAdd< + ElementAccumulator, + typename EpilogueOutputOp::ElementAccumulator, + EpilogueOutputOp::kCount + >; + + using ReductionKernel = cutlass::reduction::kernel::ReduceSplitK< + cutlass::MatrixShape<4, 32 * EpilogueOutputOp::kCount>, + EpilogueOutputOp, + ReductionOp + >; + + using ReductionDevice = cutlass::reduction::device::ReduceSplitK; + using ReductionStrideIndex = typename ReductionDevice::StrideIndex; + +public: + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + uint64_t seed; + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_B_reordered; + cutlass::HostTensor tensor_C; + cutlass::HostTensor tensor_D_computed; + cutlass::HostTensor tensor_D_reference; + +public: + + InterleavedTestbedConv2d( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { + + } + + /// Helper to initialize a tensor view + template + void initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + int scope; + int bits = cutlass::sizeof_bits::value; + + if (bits <= 8) { + scope = 2; + } + else if (bits == 16) { + scope = 3; + } + else { + scope = 8; + } + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope, -scope, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + 
cutlass::reference::host::BlockFillSequential(view.data(), view.capacity()); + } + else { + } + } + + void initialize( + cutlass::conv::Conv2dProblemSize const &problem_size, uint64_t seed = 2019) { + + tensor_A.resize(implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size)); + tensor_B.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size)); + tensor_B_reordered.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size)); + tensor_C.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + tensor_D_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + tensor_D_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + + initialize_tensor(tensor_A.host_view(), init_A, seed); + initialize_tensor(tensor_B.host_view(), init_B, seed * 17); + initialize_tensor(tensor_C.host_view(), init_C, seed * 39); + + cutlass::reorder_convK( + tensor_B_reordered.host_ref(), tensor_B.host_ref(), implicit_gemm_problem_size(kConvolutionalOperator, problem_size)); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_B_reordered.sync_device(); + tensor_C.sync_device(); + tensor_D_computed.sync_device(); + tensor_D_reference.sync_device(); + } + + bool sufficient() const { + // + // Determine SMEM requirements and waive if not satisfied + // + + size_t smem_size = sizeof(typename Conv2d::UnderlyingKernel::SharedStorage); + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerMultiprocessor < smem_size) { + return false; + } + + return true; + } + + /// Executes one test + bool run( + cutlass::conv::Conv2dProblemSize const &problem_size, + cutlass::conv::SplitKMode const &split_k_mode = cutlass::conv::SplitKMode::kSerial, + ElementCompute alpha = ElementCompute(1), + ElementCompute beta = ElementCompute(0)) { + + // Waive test if insufficient CUDA device + if (!sufficient()) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." << std::endl; + } + return true; + } + +#if 0 //display conv2d problem size for debugging + std::cout << problem_size << std::endl + << "alpha, beta: (" << float(alpha) << ", " << float(beta) << ")" << std::endl + << "split_k_mode: " << ((split_k_mode == cutlass::conv::SplitKMode::kSerial) ? 
"(serial)" : "(parallel)") << std::endl + << std::endl; +#endif + + initialize(problem_size); + + // configure the operator + Conv2d conv2d_op; + + typename Conv2d::Arguments conv2d_args( + problem_size, + tensor_A.device_ref(), + tensor_B_reordered.device_ref(), + tensor_C.device_ref(), + tensor_D_computed.device_ref(), + {alpha, beta}, + split_k_mode + ); + + // find workspace requirement for parallel split-k reduction + size_t workspace_size = Conv2d::get_workspace_size(conv2d_args); + + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = conv2d_op.initialize(conv2d_args, workspace.get()); + + // conv2d operation with parallel split-k-mode + if (split_k_mode == cutlass::conv::SplitKMode::kParallel) { + + // conv2d output is written to workspace in global memory + conv2d_args.ref_D.reset(reinterpret_cast(workspace.get())); + // accumulate mma for each cta in k-dimension (1.0 * A * B) + conv2d_args.output_op = {ElementCompute(1), ElementCompute(0)}; + // update conv2d operator arguments + status = conv2d_op.update(conv2d_args, workspace.get()); + } + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + if (status != cutlass::Status::kSuccess) { + return false; + } + + // run conv2d operator + status = conv2d_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + if (status != cutlass::Status::kSuccess) { + return false; + } + + if (split_k_mode == cutlass::conv::SplitKMode::kParallel) { + + // configure parallel reduction operator + ReductionDevice reduction_op; + + typename ReductionDevice::Arguments reduction_args( + cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, problem_size).mn(), + problem_size.split_k_slices, + cutlass::conv::implicit_gemm_tensor_c_size(kConvolutionalOperator, problem_size), + { + reinterpret_cast (workspace.get()), + ReductionStrideIndex(tensor_C.stride()[Conv2d::UnderlyingKernel::kTensorCStrideIdx]) + }, + { + tensor_D_computed.device_data(), + ReductionStrideIndex(tensor_C.stride()[Conv2d::UnderlyingKernel::kTensorCStrideIdx]) + }, + { + tensor_C.device_data(), + ReductionStrideIndex(tensor_C.stride()[Conv2d::UnderlyingKernel::kTensorCStrideIdx]) + }, + // apply alpha, beta to obtain the following equation alpha * ReduceAdd(A * B) + beta * C + {alpha, beta} + ); + + status = reduction_op.initialize(reduction_args, nullptr); + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + if (status != cutlass::Status::kSuccess) { + return false; + } + + // run prallel reduction kernel + status = reduction_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + if (status != cutlass::Status::kSuccess) { + return false; + } + } + bool passed = false; + + tensor_D_computed.sync_host(); + + // + // Reference check - support caching results + // + + CachedTestKey cached_test_key = CreateCachedConv2dTestKey< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementAccumulator, + ElementCompute + >( + kConvolutionalOperator, + problem_size, + alpha, + beta, + tensor_A.host_view(), + tensor_B.host_view(), + tensor_C.host_view() + ); + + // + // Look for the cached key + // + + bool cached_result_loaded = false; + CachedTestResult cached_test_result; + + std::string conv2d_result_cache_name = + std::string("cached_results_") + CUTLASS_TARGET_NAME + ".txt"; + + if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) { + + CachedTestResultListing cached_results(conv2d_result_cache_name); + + auto cached = cached_results.find(cached_test_key); + + cached_result_loaded = cached.first; + if (cached_result_loaded) { 
+ cached_test_result = cached.second; + } + } + + if (!cached_result_loaded) { + +#if CUTLASS_CONV_TEST_UNIT_REFERENCE_DEVICE_ENABLED + + cutlass::reference::device::Conv2d< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementCompute, + ElementAccumulator, + cutlass::NumericConverterClamp + >( + kConvolutionalOperator, + problem_size, + tensor_A.device_ref(), + tensor_B.device_ref(), + tensor_C.device_ref(), + tensor_D_reference.device_ref(), + alpha, + beta); + + cudaError_t result = cudaDeviceSynchronize(); + EXPECT_EQ(result, cudaSuccess) << " device reference error: " + << cudaGetErrorString(result); + + // sync host (copy device data to host) for dumping error output in case of mismatches + tensor_D_reference.sync_host(); + +#else + + cutlass::reference::host::Conv2d< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementCompute, + ElementAccumulator, + ElementC, + cutlass::NumericConverterClamp + >( + kConvolutionalOperator, + problem_size, + tensor_A.host_ref(), + tensor_B.host_ref(), + tensor_C.host_ref(), + tensor_D_reference.host_ref(), + alpha, + beta); + +#endif + + if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) { + + cached_test_result.D = TensorHash(tensor_D_reference.host_view()); + + CachedTestResultListing cached_results(conv2d_result_cache_name); + + cached_results.append(cached_test_key, cached_test_result); + cached_results.write(conv2d_result_cache_name); + } + } // if (!cached_result_loaded) + + uint32_t tensor_D_hash = TensorHash(tensor_D_computed.host_view()); + + if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) { + passed = (tensor_D_hash == cached_test_result.D); + + EXPECT_EQ(tensor_D_hash, cached_test_result.D) + << "Hash-based comparison failed for key:" << "\n" << cached_test_key << "\n"; + } + else { + + passed = cutlass::reference::host::TensorEquals( + tensor_D_computed.host_view(), + tensor_D_reference.host_view()); + } + + EXPECT_TRUE(passed); + + if (!passed) { + std::stringstream fname; + + fname << "error_Conv2d_ImplicitGemm_device_" + << (split_k_mode == cutlass::conv::SplitKMode::kSerial ? "serial_reduction_" : "parallel_reduction_") + << (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kFprop ? "fprop_" : + (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kDgrad ? "dgrad_" : "wgrad_")) + << "ncxhwx_" + << problem_size.N << "x" + << problem_size.H << "x" + << problem_size.W << "x" + << problem_size.C + << "_cxrskx_" + << problem_size.K << "x" + << problem_size.R << "x" + << problem_size.S << "x" + << problem_size.C + << "_padding_" + << problem_size.pad_h << "x" + << problem_size.pad_w + << "_stride_" + << problem_size.stride_h << "x" + << problem_size.stride_w + << "_dilation_" + << problem_size.dilation_h << "x" + << problem_size.dilation_w << "_" + << (problem_size.mode == cutlass::conv::Mode::kCrossCorrelation ? 
"xcorr_" : "conv_") + << Conv2d::ThreadblockShape::kM << "x" + << Conv2d::ThreadblockShape::kN << "x" + << Conv2d::ThreadblockShape::kK << "_" + << Conv2d::WarpShape::kM << "x" + << Conv2d::WarpShape::kN << "x" + << Conv2d::WarpShape::kK << ".txt"; + + std::cout << fname.str() << std::endl; + + std::ofstream results(fname.str()); + + results << problem_size << std::endl; + + results + << "\nA:\n" << tensor_A.host_view() << "\n" + << "\nB:\n" << tensor_B.host_view() << "\n" + << "\nC:\n" << tensor_C.host_view() << "\n"; + + results << "\nD reference (hash: " << cached_test_result.D << ")\n"; + + if (!cached_result_loaded) { + results + << tensor_D_reference.host_view() << "\n"; + } + + results + << "\nD computed (hash: " << tensor_D_hash << ")\n" + << tensor_D_computed.host_view() << "\n"; + + } + + return passed; + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////////////// +// TestAllConv: Runs cutlass::conv::device::ImplicitGemmConvolution operator and compares it with reference +// TestAllConv runs conv operator on default conv problem sizes from test::conv::device::TestbedConv2dProblemSizes +// Additionally, each conv2d test can provide conv problem sizes (conv_test_sizes) and blacklist of sizes +// (conv_blacklist_sizes) +///////////////////////////////////////////////////////////////////////////////////////////////////////////// +template +bool TestAllInterleavedConv2d( + const Conv2dProblemVector & conv_test_sizes = Conv2dProblemVector(), + const Conv2dProblemVector & conv_blacklist_sizes = Conv2dProblemVector()) { + + bool passed = true; + + // + // Testbed object + // + + InterleavedTestbedConv2d testbed; + + // + // Get conv problem sizes to run conv operator + // + TestbedConv2dProblemSizes conv_problems(InterleavedK); // minimum channel size must be multiple of InterleavedK for interleaved layout + + // Vector of conv2d problem sizes to avoid duplicate runs + Conv2dProblemVector conv_tested_sizes; + + Conv2dProblemVector const *problem_vectors[] = { + &conv_test_sizes, // run user specified sizes + &conv_problems.conv2d_default_sizes, // run default and cudnn bug sizes + &conv_problems.conv2d_resnet50_sizes, // run resnet50 sizes +#if CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED + &conv_problems.conv2d_rigorous_sizes, // run large and rigorous sizes if enabled +#endif + }; + + // Sweep conv2d problem sizes (split-k-mode=kSerial, split-k-slice=1, alpha=1.0, beta=0.0) + for (Conv2dProblemVector const * problem_vector : problem_vectors) { + + ChannelDivisibilitySpecification channel_spec(InterleavedK); //input and output channels must be multiple of InterleavedK + auto pruned_problem_vector = prune(*problem_vector, channel_spec); + + // Run conv testbed on default convolution sizes + for(auto conv_problem : pruned_problem_vector) { + + // Skip blacklist and avoid duplicate problem sizes + if (std::find(conv_blacklist_sizes.begin(), conv_blacklist_sizes.end(), conv_problem) != conv_blacklist_sizes.end() || + std::find(conv_tested_sizes.begin(), conv_tested_sizes.end(), conv_problem) != conv_tested_sizes.end()) { + continue; + } + + // + // Procedurally disable certain cases + // + + // CUTLASS DGRAD's unity stride specialization only support stride {1, 1} + if ((ImplicitGemm::kConvolutionalOperator == + cutlass::conv::Operator::kDgrad) && + (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == + cutlass::conv::StrideSupport::kUnity)) { + if (!((conv_problem.stride_h == 1) && (conv_problem.stride_w == 1))) { + 
continue; + } + } + + // + // Test + // + // push back tested problem size to avoid re-running duplicates + conv_tested_sizes.push_back(conv_problem); + + // test mode = xcross + passed = testbed.run( + conv_problem, + cutlass::conv::SplitKMode::kSerial); + + if (!passed) { + return false; + } + + // test mode = convolution + passed = testbed.run( + conv_problem.reset_mode(cutlass::conv::Mode::kConvolution), + cutlass::conv::SplitKMode::kSerial); + + if (!passed) { + return false; + } + } + } + +#if 0 + // Sweep split-k-slice using serial and prallel reduction with non-unity alpha and non-zero beta for + // a single conv2d problem size. Convolution unit tests take a long time to run so only sweep parameters + // which are abolutely necessary to catch functional bugs. The below code does provide option to sweep + // alpha and beta for local testing, but only runs one value for alpha and beta. + cutlass::conv::Conv2dProblemSize conv2d_split_k_test_size ( + {1, 17, 11, 288}, // input size (NHWC) + {160, 3, 3, 288}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + ); + + cutlass::conv::SplitKMode split_k_modes [] = { + cutlass::conv::SplitKMode::kSerial, + cutlass::conv::SplitKMode::kParallel, + }; + + int split_k_slices[] = { + 1, 2, 3, 4, 201 + }; + + double problem_alpha[] = { + 2.0 + }; + + double problem_beta[] = { + 2.0 + }; + + for (auto split_k_mode : split_k_modes) { + for (auto split_k_slice : split_k_slices) { + for (auto alpha : problem_alpha) { + for (auto beta : problem_beta) { + + passed = testbed.run( + conv2d_split_k_test_size.reset_split_k_slices(split_k_slice), + split_k_mode, + cutlass::from_real(alpha), + cutlass::from_real(beta)); + + if (!passed) { + return false; + } + } + } + } + } +#endif + + return passed; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace conv +} // namespace test diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_with_absmax_testbed.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_with_absmax_testbed.h new file mode 100644 index 0000000000000000000000000000000000000000..ad7b2ce61a66a79f852c0aac0895d10ba18e5466 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_with_absmax_testbed.h @@ -0,0 +1,622 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Testbed for running device-level Conv2Ds with absolute maximum calculation and scaling +*/ + +#pragma once + +#include +#include +#include + +#include "conv2d_problems.h" +#include "../../common/cutlass_unit_test.h" +#include "../../gemm/device/testbed_utils.h" + +#include "cutlass/matrix_coord.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/layout/matrix.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/convolution.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_reduce.h" + +namespace test { +namespace conv { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Conv, + template class ActivationFunctor +> +struct TestbedConv2dWithAbsMax { + + using ElementAccumulator = typename Conv::ElementAccumulator; + using ElementCompute = typename Conv::UnderlyingKernel::Epilogue::OutputOp::ElementCompute; + using ElementScalingFactor = typename Conv::EpilogueOutputOp::ElementScalingFactor; + using ElementAbsmax = typename Conv::EpilogueOutputOp::ElementAbsmax; + static cutlass::conv::Operator const kConvolutionalOperator = Conv::kConvolutionalOperator; + + static bool const kScaleAux = Conv::EpilogueOutputOp::kIsScalingAndAmaxAuxOutputNeeded; + static bool const kScaleOutput = Conv::EpilogueOutputOp::kIsScalingAndAmaxOutputNeeded; + bool doScaleA; + bool doScaleB; + bool doScaleC; + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + uint64_t seed; + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_C; + cutlass::HostTensor tensor_Aux; + cutlass::HostTensor tensor_D; + cutlass::HostTensor tensor_Vector; + cutlass::HostTensor tmp_D; + cutlass::HostTensor reference_D; + cutlass::HostTensor reference_Aux; + cutlass::HostTensor scale_A; + cutlass::HostTensor scale_B; + cutlass::HostTensor scale_C; + cutlass::HostTensor scale_D; + cutlass::HostTensor scale_Aux; + cutlass::HostTensor abs_max_Aux; + cutlass::HostTensor abs_max_D; + cutlass::HostTensor reference_abs_max_Aux; + cutlass::HostTensor reference_abs_max_D; + + // + // Methods + // + + TestbedConv2dWithAbsMax( + bool scaleA = true, + bool scaleB = true, + bool scaleC = true, + cutlass::Distribution::Kind init_A_ = 
cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + doScaleA(scaleA), doScaleB(scaleB), doScaleC(scaleC), + init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } + + /// Helper to initialize scaling factors + template + bool initialize_scale_factor(cutlass::TensorView view, uint64_t seed, int bits=0) { + cutlass::reference::host::TensorFillRandomUniform(view, seed, double(1.), double(0.), bits); + return true; + } + + /// Helper to initialize a tensor view + template + bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } else if (bits_output == 16) { + scope_max = 5; + scope_min = -5; + } else { + scope_max = 8; + scope_min = -8; + } + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential( + view.data(), view.capacity()); + } + else { + EXPECT_TRUE(false) << "Not implemented"; + return false; + } + + return true; + } + + /// Initializes data structures + void initialize(cutlass::conv::Conv2dProblemSize const &problem_size) { + // + // Allocate the GEMM workspace + // + + tensor_A.resize(implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size)); + tensor_B.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size)); + tensor_C.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + tensor_D.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + tensor_Vector.resize({1, 1, 1, implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size).c()}); + reference_D.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size), false); + tmp_D.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size), false); + + EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2019)); + EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2018)); + EXPECT_TRUE(initialize_tensor(tensor_C.host_view(), init_C, seed + 2017)); + EXPECT_TRUE(initialize_tensor(tensor_Vector.host_view(), init_C, seed + 2020)); + + // It is possible to randomly initialize to all zeros, so override this with non-zeros + // in the upper left corner of each operand. 
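+    // (This also keeps the TensorNorm(...) > 0 sanity checks in compare_reference() from failing + // spuriously in the unlikely event that a random fill produces an all-zero tensor.)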
+ cutlass::Coord<4> origin(0); + tensor_A.host_view().at(origin) = typename Conv::ElementA(1); + tensor_B.host_view().at(origin) = typename Conv::ElementB(1); + tensor_C.host_view().at(origin) = typename Conv::ElementC(1); + tensor_Vector.host_view().at(origin) = typename Conv::ElementC(1); + + cutlass::reference::host::TensorFill(tensor_D.host_view()); + cutlass::reference::host::TensorCopy(reference_D.host_view(), tensor_C.host_view()); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_D.sync_device(); + tensor_Vector.sync_device(); + + int scale_bits = 2; + if (doScaleA) { + scale_A.resize({1, 1, 1, 1}); + EXPECT_TRUE(initialize_scale_factor(scale_A.host_view(), seed + 2021, scale_bits)); + scale_A.sync_device(); + } + + if (doScaleB) { + scale_B.resize({1, 1, 1, 1}); + EXPECT_TRUE(initialize_scale_factor(scale_B.host_view(), seed + 2022, scale_bits)); + scale_B.sync_device(); + } + + if (doScaleC) { + scale_C.resize({1, 1, 1, 1}); + EXPECT_TRUE(initialize_scale_factor(scale_C.host_view(), seed + 2023, scale_bits)); + scale_C.sync_device(); + } + + if (kScaleOutput) { + scale_D.resize({1, 1, 1, 1}); + EXPECT_TRUE(initialize_scale_factor(scale_D.host_view(), seed + 2024, scale_bits)); + scale_D.sync_device(); + + abs_max_D.resize({1, 1, 1, 1}); + cutlass::reference::host::TensorFill(abs_max_D.host_view()); + abs_max_D.sync_device(); + + reference_abs_max_D.resize({1, 1, 1, 1}); + } + + if (kScaleAux) { + tensor_Aux.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + cutlass::reference::host::TensorFill(tensor_Aux.host_view()); + tensor_Aux.sync_device(); + + scale_Aux.resize({1, 1, 1, 1}); + EXPECT_TRUE(initialize_scale_factor(scale_Aux.host_view(), seed + 2025, scale_bits)); + scale_Aux.sync_device(); + + abs_max_Aux.resize({1, 1, 1, 1}); + cutlass::reference::host::TensorFill(abs_max_Aux.host_view()); + abs_max_Aux.sync_device(); + + reference_Aux.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size), false); + reference_abs_max_Aux.resize({1, 1, 1, 1}); + } + } + + /// Compares computed reference with device reference and outputs to a file if incorrect + bool compare_reference( + cutlass::conv::Conv2dProblemSize const &problem_size, + ElementCompute alpha, + ElementCompute beta) { + + tensor_D.sync_host(); + + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_C.host_view()), 0); + + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); + bool passed = cutlass::reference::host::TensorEquals(reference_D.host_view(), tensor_D.host_view()); + + if (kScaleAux) { + tensor_Aux.sync_host(); + abs_max_Aux.sync_host(); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_Aux.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(abs_max_Aux.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(reference_Aux.host_view()), 0); + passed &= cutlass::reference::host::TensorEquals(reference_Aux.host_view(), tensor_Aux.host_view()); + passed &= cutlass::reference::host::TensorEquals(abs_max_Aux.host_view(), reference_abs_max_Aux.host_view()); + } + + if (kScaleOutput) { + abs_max_D.sync_host(); + EXPECT_GT(cutlass::reference::host::TensorNorm(abs_max_D.host_view()), 0); + passed &= 
cutlass::reference::host::TensorEquals(abs_max_D.host_view(), reference_abs_max_D.host_view()); + } + + EXPECT_TRUE(passed) << " mismatched reference"; + + if (!passed) { + + std::ofstream file0("conv_testbed_with_amax_errors_reference.txt"); + std::ofstream file1("conv_testbed_with_amax_errors_computed.txt"); + + std::ofstream file("conv_testbed_with_amax_errors.txt"); + + file + << "problem: " << problem_size + << ", alpha: " << alpha << ", beta: " << beta << "\n\n"; + + file + << "A =\n" << tensor_A.host_view() + << "\nB =\n" << tensor_B.host_view() + << "\nC =\n" << tensor_C.host_view() + << "\nVector =\n" << tensor_Vector.host_view() + << "\nScaleA = " << scale_A.host_view() + << "\nScaleB = " << scale_B.host_view() + << "\nScaleC = " << scale_C.host_view() + << "\nScaleD = " << scale_D.host_view() + << "\nScaleAux = " << scale_Aux.host_view() + << std::endl; + + file0 << "\n\nReference D =\n" << reference_D.host_view() << std::endl; + file1 << "\n\nComputed D =\n" << tensor_D.host_view() << std::endl; + if (kScaleAux) { + file0 << "\n\nReference Aux =\n" << reference_Aux.host_view() << std::endl; + file1 << "\n\nComputed Aux =\n" << tensor_Aux.host_view() << std::endl; + file0 << "\n\nReference Absmax Aux = " << reference_abs_max_Aux.host_view() << std::endl; + file1 << "\n\nComputed Absmax Aux = " << abs_max_Aux.host_view() << std::endl; + } + if (kScaleOutput) { + file0 << "\n\nReference Absmax D = " << reference_abs_max_D.host_view() << std::endl; + file1 << "\n\nComputed Absmax D = " << abs_max_D.host_view() << std::endl; + } + } + + return passed; + } + + /// Verifies the result is a GEMM + bool verify( + cutlass::conv::Conv2dProblemSize const &problem_size, + ElementCompute alpha, + ElementCompute beta) { + + cutlass::Coord<4> origin(0); + ElementCompute scaled_alpha = alpha; + if (doScaleA) { + scaled_alpha *= scale_A.host_view().at(origin); + } + if (doScaleB) { + scaled_alpha *= scale_B.host_view().at(origin); + } + + ElementCompute scaled_beta = beta; + if (doScaleC) { + scaled_beta *= scale_C.host_view().at(origin); + } + + // + // Verify + // + + cutlass::reference::host::Conv2d< + typename Conv::ElementA, typename Conv::LayoutA, + typename Conv::ElementB, typename Conv::LayoutB, + typename Conv::ElementC, typename Conv::LayoutC, + ElementCompute, ElementAccumulator, ElementAccumulator + >( + kConvolutionalOperator, + problem_size, + tensor_A.host_ref(), + tensor_B.host_ref(), + tensor_C.host_ref(), + tmp_D.host_ref(), + scaled_alpha, + scaled_beta + ); + + ElementCompute tmp_abs_max_Aux(0.); + ElementCompute tmp_abs_max_D(0.); + + cutlass::NumericConverter cvt_c_to_compute; + cutlass::NumericConverter cvt_accum_to_compute; + cutlass::NumericConverter cvt_compute_to_absmax; + cutlass::NumericConverter cvt_compute_to_d; + cutlass::NumericConverter cvt_compute_to_aux; + + cutlass::absolute_value_op abs; + cutlass::maximum_with_nan_propogation max; + ActivationFunctor act; + + ElementScalingFactor d_scale = kScaleOutput ? 
scale_D.host_view().at(origin) : ElementScalingFactor(1.); + + for (int n = 0; n < problem_size.N; ++n) { + for (int p = 0; p < problem_size.P; ++p) { + for (int q = 0; q < problem_size.Q; ++q) { + for (int k = 0; k < problem_size.K; ++k) { + ElementCompute intermediate = cvt_accum_to_compute(tmp_D.host_view().at({n, p, q, k})); + ElementCompute bias = cvt_c_to_compute(tensor_Vector.host_view().at({0, 0, 0, k})); + ElementCompute aux = intermediate + bias; + ElementCompute d = act(aux); + tmp_abs_max_Aux = max(abs(aux), tmp_abs_max_Aux); + tmp_abs_max_D = max(abs(d), tmp_abs_max_D); + reference_D.host_view().at({n, p, q, k}) = cvt_compute_to_d(d * d_scale); + + if (kScaleAux) { + reference_Aux.host_view().at({n, p, q, k}) = cvt_compute_to_aux(aux * scale_Aux.host_view().at(origin)); + } + } + } + } + } + if (kScaleAux) { + reference_abs_max_Aux.host_view().at(origin) = cvt_compute_to_absmax(tmp_abs_max_Aux); + } + + if (kScaleOutput) { + reference_abs_max_D.host_view().at(origin) = cvt_compute_to_absmax(tmp_abs_max_D); + } + + return compare_reference(problem_size, alpha, beta); + } + + /// Returns true if the CUDA device is sufficient to execute the kernel. + bool sufficient() const { + // + // Determine SMEM requirements and waive if not satisfied + // + + size_t smem_size = sizeof(typename Conv::UnderlyingKernel::SharedStorage); + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerBlockOptin < smem_size) { + return false; + } + + return true; + } + + /// Executes one test + bool run( + cutlass::conv::Conv2dProblemSize const &problem_size, + ElementCompute alpha = ElementCompute(1), + ElementCompute beta = ElementCompute(0)) + { + + // Waive test if insufficient CUDA device + if (!sufficient()) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." 
<< std::endl; + } + return true; + } + + this->initialize(problem_size); + + // + // Initialize the GEMM operator + // + + typename Conv::EpilogueOutputOp::Params::ActivationParams activation_params{alpha, beta}; + typename Conv::EpilogueOutputOp::Params epilogue_params{ + activation_params, + scale_A.device_data(), + scale_B.device_data(), + scale_C.device_data(), + scale_D.device_data(), + scale_Aux.device_data(), + abs_max_Aux.device_data(), + abs_max_D.device_data() + }; + + typename Conv::Arguments arguments{ + problem_size, + tensor_A.device_ref(), + tensor_B.device_ref(), + tensor_C.device_ref(), + tensor_D.device_ref(), + tensor_Aux.device_ref(), + epilogue_params, + cutlass::conv::SplitKMode::kSerial, + tensor_Vector.device_data(), + 0 + }; + + Conv conv2d_op; + + cutlass::Status status = conv2d_op.can_implement(arguments); + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + size_t workspace_size = Conv::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + status = conv2d_op.initialize(arguments, workspace.get()); + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Run the GEMM + // + + status = conv2d_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + cudaError_t cuda_error = cudaDeviceSynchronize(); + EXPECT_TRUE(cuda_error == cudaSuccess) << cudaGetErrorString(cuda_error); + + // + // Verify + // + + bool passed = this->verify(problem_size, alpha, beta); + + if (!passed) { + std::cout << "Failed" << std::endl; + } + + return passed; + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ImplicitGemm, + template class ActivationFunctor = cutlass::epilogue::thread::Identity +> +bool TestAllConv2dWithAbsmax(bool scaleA=true, bool scaleB=true, bool scaleC=true) { + const Conv2dProblemVector &conv_test_sizes = Conv2dProblemVector(); + const Conv2dProblemVector &conv_blacklist_sizes = Conv2dProblemVector(); + + // + // Testbed object + // + + TestbedConv2dWithAbsMax testbed(scaleA, scaleB, scaleC); + + // + // Get conv problem sizes to run conv operator + // + TestbedConv2dProblemSizes conv_problems(128/cutlass::sizeof_bits::value); + + // Vector of conv2d problem sizes to avoid duplicate runs + Conv2dProblemVector conv_tested_sizes; + + Conv2dProblemVector const *problem_vectors[] = { + &conv_test_sizes, // run user specified sizes + &conv_problems.conv2d_default_sizes, // run default and cudnn bug sizes + &conv_problems.conv2d_resnet50_sizes, // run resnet50 sizes +#if CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED + &conv_problems.conv2d_rigorous_sizes, // run large and rigorous sizes if enabled +#endif + }; + + bool passed = true; + + // Sweep conv2d problem sizes (split-k-mode=kSerial, split-k-slice=1, alpha=1.0, beta=0.0) + for (Conv2dProblemVector const * problem_vector : problem_vectors) { + + // Prune all problems with channels that aren't divisible by the number of elements accessed per + // load for operands A and B. This is meant to align with the requirements of iterators used for + // fprop kernels. 
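+    // For example, assuming 16-bit elements, a 128-bit vectorized access covers 128 / 16 = 8 elements, + // so only problem sizes whose channel count is a multiple of 8 survive the prune below.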
+ ChannelDivisibilitySpecification channel_spec(128 / cutlass::sizeof_bits::value); + auto pruned_problem_vector = prune(*problem_vector, channel_spec); + + // Run conv testbed on default convolution sizes + for(auto conv_problem : pruned_problem_vector) { + + // Skip blacklist and avoid duplicate problem sizes + if (std::find(conv_blacklist_sizes.begin(), conv_blacklist_sizes.end(), conv_problem) != conv_blacklist_sizes.end() || + std::find(conv_tested_sizes.begin(), conv_tested_sizes.end(), conv_problem) != conv_tested_sizes.end()) { + continue; + } + + // + // Test + // + // push back tested problem size to avoid re-running duplicates + conv_tested_sizes.push_back(conv_problem); + + // test mode = xcross + passed &= testbed.run(conv_problem); + + if (!passed) { + return false; + } + + // test mode = convolution + passed &= testbed.run(conv_problem.reset_mode(cutlass::conv::Mode::kConvolution)); + + if (!passed) { + return false; + } + } + } + + return passed; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace conv +} // namespace test + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_with_broadcast_testbed.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_with_broadcast_testbed.h new file mode 100644 index 0000000000000000000000000000000000000000..f768f5b25f425910a49058599d3854352136caef --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_with_broadcast_testbed.h @@ -0,0 +1,734 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +/*! \file + \brief Implicit GEMM for fused epilogue broadcast testbed + + Parallel split-k is not tested because we can just use regular conv kernel + when we need to use parallel-splitk. Broadcast can happen in the reduction + kernel. +*/ +#pragma once + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" + +#include "cutlass/conv/device/implicit_gemm_convolution.h" +#include "cutlass/reduction/device/reduce_split_k.h" +#include "cutlass/reduction/thread/reduction_operators.h" + +#include "conv2d_problems.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_compare.h" + +#include "cutlass/util/reference/host/convolution.h" +#include "cutlass/util/reference/device/convolution.h" + +#include "cutlass/core_io.h" +#include "cutlass/util/tensor_view_io.h" + +#include "../cache_testbed_output.h" + +namespace test { +namespace conv { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Conv2dWithBroadcastReferenceOp { + + using OutputOp = typename Conv2d::EpilogueOutputOp; + + using ElementCompute = typename OutputOp::ElementCompute; + using ElementZ = typename OutputOp::ElementZ; + using ElementT = typename OutputOp::ElementT; + + typename OutputOp::BinaryOp binary_op; + typename OutputOp::ElementwiseOp elementwise_op; + + Conv2dWithBroadcastReferenceOp() { } + + void operator()(ElementZ &Z, ElementT &T, ElementCompute conv2d, ElementCompute bias) { + ElementCompute t_full = binary_op(conv2d, bias); + T = ElementT(t_full); + + ElementCompute z_full = elementwise_op(t_full); + Z = ElementZ(z_full); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Fused testbed +// +// Y = CONV(AB, C) +// +// T[n, p, q, k] = ReductionOp(Y[n, p, q, k], Broadcast[k]) +// +// Z[n, p, q, k] = Elementwise(T[n, p, q, k]) +// + +template < + typename Conv2d, + typename ReferenceOp, + bool AddBroadcastFirst = false +> +class TestbedConv2dWithBroadcast { +public: + + using ElementA = typename Conv2d::ElementA; + using LayoutA = typename Conv2d::LayoutA; + using ElementB = typename Conv2d::ElementB; + using LayoutB = typename Conv2d::LayoutB; + using ElementC = typename Conv2d::ElementC; + using LayoutC = typename Conv2d::LayoutC; + using ElementAccumulator = typename Conv2d::ElementAccumulator; + using ElementCompute = typename Conv2d::ElementCompute; + using EpilogueOutputOp = typename Conv2d::EpilogueOutputOp; + using ElementZ = typename EpilogueOutputOp::ElementZ; + using ElementT = typename EpilogueOutputOp::ElementT; + using ElementVector = typename EpilogueOutputOp::ElementVector; + + static cutlass::conv::Operator const kConvolutionalOperator = Conv2d::kConvolutionalOperator; + static const bool kAddBroadcastFirst = AddBroadcastFirst; + static const bool kStoreT = EpilogueOutputOp::kStoreT; + +public: + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + uint64_t seed; + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_C; + cutlass::HostTensor tensor_C_reference; + cutlass::HostTensor tensor_Z_computed; + cutlass::HostTensor tensor_Z_reference; + 
cutlass::HostTensor tensor_T_computed; + cutlass::HostTensor tensor_T_reference; + cutlass::HostTensor tensor_Y_reference; + cutlass::HostTensor tensor_Broadcast; // Input Broadcast + +public: + + TestbedConv2dWithBroadcast( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { + + } + + /// Helper to initialize a tensor view + template + void initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + int scope; + int bits = cutlass::sizeof_bits::value; + + if (bits <= 8) { + scope = 2; + } + else if (bits == 16) { + if (cutlass::sizeof_bits::value <= 16) { + scope = 3; + } + else { + scope = 5; + } + } + else { + scope = 8; + } + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope, -scope, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential(view.data(), view.capacity()); + } + else { + } + } + + void initialize( + cutlass::conv::Conv2dProblemSize const &problem_size, uint64_t seed = 2019) { + + tensor_A.resize(implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size)); + tensor_B.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size)); + tensor_C.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + tensor_C_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + tensor_Z_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + tensor_Z_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + tensor_T_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + tensor_T_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + tensor_Y_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + tensor_Broadcast.resize({ + 1, + 1, + 1, + implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size).c(), + }); + + initialize_tensor(tensor_A.host_view(), init_A, seed); + initialize_tensor(tensor_B.host_view(), init_B, seed * 17); + initialize_tensor(tensor_C.host_view(), init_C, seed * 39); + initialize_tensor(tensor_Broadcast.host_view(), init_C, seed * 39); + + for (int n = 0; n < tensor_C_reference.extent().n(); ++n) { + for (int p = 0; p < tensor_C_reference.extent().h(); ++p) { + for (int q = 0; q < tensor_C_reference.extent().w(); ++q) { + for (int k = 0; k < tensor_C_reference.extent().c(); ++k) { + tensor_C_reference.at({n, p, q, k}) = ElementAccumulator(tensor_C.at({n, p, q, k})); + } + } + } + } + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_Broadcast.sync_device(); + tensor_C_reference.sync_device(); + tensor_Z_computed.sync_device(); + tensor_Z_reference.sync_device(); + tensor_T_computed.sync_device(); + tensor_T_reference.sync_device(); + tensor_Y_reference.sync_device(); + } + + bool sufficient() 
const { + // + // Determine SMEM requirements and waive if not satisfied + // + + size_t smem_size = sizeof(typename Conv2d::UnderlyingKernel::SharedStorage); + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerBlockOptin < smem_size) { + return false; + } + + return true; + } + + /// Executes one test + bool run( + cutlass::conv::Conv2dProblemSize const &problem_size, + cutlass::conv::SplitKMode const &split_k_mode = cutlass::conv::SplitKMode::kSerial, + ElementCompute alpha = ElementCompute(1), + ElementCompute beta = ElementCompute(1)) { + + // Waive test if insufficient CUDA device + if (!sufficient()) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." << std::endl; + } + return true; + } + +#if 0 //display conv2d problem size for debugging + std::cout << problem_size << std::endl + << "alpha, beta: (" << alpha << ", " << beta << ")" << std::endl + << "split_k_mode: " << ((split_k_mode == cutlass::conv::SplitKMode::kSerial) ? "(serial)" : "(parallel)") << std::endl + << std::endl; +#endif + + initialize(problem_size); + + // configure the operator + Conv2d conv2d_op; + typename Conv2d::Arguments conv2d_args( + problem_size, + tensor_A.device_ref(), + tensor_B.device_ref(), + tensor_C.device_ref(), + tensor_Z_computed.device_ref(), + {alpha, beta}, + split_k_mode, + tensor_Broadcast.device_data(), + kStoreT ? tensor_T_computed.device_data() : nullptr, + 0, // This must be zero + implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size).c() + ); + + // initialize the kernel + size_t workspace_size = Conv2d::get_workspace_size(conv2d_args); + + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = conv2d_op.initialize(conv2d_args, workspace.get()); + + if (status != cutlass::Status::kSuccess) { + cudaError_t error = cudaGetLastError(); + std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n"; + return true; + } + + // run conv2d operator + status = conv2d_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + if (status != cutlass::Status::kSuccess) { + return false; + } + + bool passed = false; + + cudaError_t result = cudaDeviceSynchronize(); + EXPECT_EQ(result, cudaSuccess) << " device reference error: " + << cudaGetErrorString(result); + + tensor_T_computed.sync_host(); + tensor_Z_computed.sync_host(); + + // + // Reference check + // + + // When kAddBroadcastFirst is true, add bias on the host + ElementCompute beta_ref = kAddBroadcastFirst ? 
ElementCompute(0) : beta; + +#if CUTLASS_CONV_TEST_UNIT_REFERENCE_DEVICE_ENABLED + + cutlass::reference::device::Conv2d< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementAccumulator, + LayoutC, + ElementAccumulator, + ElementAccumulator + >( + kConvolutionalOperator, + problem_size, + tensor_A.device_ref(), + tensor_B.device_ref(), + tensor_C_reference.device_ref(), + tensor_Y_reference.device_ref(), + alpha, + beta_ref); + + // sync host (copy device data to host) for dumping error output in case of mismatches + tensor_Y_reference.sync_host(); + +#else + + cutlass::reference::host::Conv2d< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementAccumulator, + LayoutC, + ElementAccumulator, + ElementAccumulator + >( + kConvolutionalOperator, + problem_size, + tensor_A.host_ref(), + tensor_B.host_ref(), + tensor_C_reference.host_ref(), + tensor_Y_reference.host_ref(), + alpha, + beta_ref); + +#endif + ReferenceOp reference_op; + + // compute tensor Z and tensor T + for (int n = 0; n < problem_size.N; ++n) { + for (int p = 0; p < (kConvolutionalOperator == cutlass::conv::Operator::kFprop ? problem_size.P : problem_size.H); ++p) { + for (int q = 0; q < (kConvolutionalOperator == cutlass::conv::Operator::kFprop ? problem_size.Q : problem_size.W); ++q) { + for (int k = 0; k < (kConvolutionalOperator == cutlass::conv::Operator::kFprop ? problem_size.K : problem_size.C); ++k) { + + ElementZ z{}; + ElementT t{}; + + ElementCompute accum = tensor_Y_reference.at({n, p, q, k}); + ElementCompute bias = ElementCompute(tensor_Broadcast.at({0, 0, 0, k})); + + + if (kAddBroadcastFirst) { + reference_op(z, t, accum + bias, + beta * ElementCompute(tensor_C_reference.at({n, p, q, k}))); + } else { + reference_op(z, t, accum, bias); + } + + tensor_Z_reference.at({n, p, q, k}) = z; + tensor_T_reference.at({n, p, q, k}) = t; + } + } + } + } + + if (kStoreT) { + passed = cutlass::reference::host::TensorEquals( + tensor_T_computed.host_view(), + tensor_T_reference.host_view()); + + EXPECT_TRUE(passed); + } + + passed = cutlass::reference::host::TensorEquals( + tensor_Z_computed.host_view(), + tensor_Z_reference.host_view()); + + EXPECT_TRUE(passed); + + if (!passed) { + std::stringstream fname; + + fname << "error_Conv2d_ImplicitGemm_device_" + << (split_k_mode == cutlass::conv::SplitKMode::kSerial ? "serial_reduction_" : "parallel_reduction_") + << (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kFprop ? "fprop_" : + (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kDgrad ? "dgrad_" : + (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kDeconv ? "deconv_" : "wgrad_"))) + << "nhwc_" + << problem_size.N << "x" + << problem_size.H << "x" + << problem_size.W << "x" + << problem_size.C + << "_krsc_" + << problem_size.K << "x" + << problem_size.R << "x" + << problem_size.S << "x" + << problem_size.C + << "_padding_" + << problem_size.pad_h << "x" + << problem_size.pad_w + << "_stride_" + << problem_size.stride_h << "x" + << problem_size.stride_w + << "_dilation_" + << problem_size.dilation_h << "x" + << problem_size.dilation_w << "_" + << (problem_size.mode == cutlass::conv::Mode::kCrossCorrelation ? 
"xcorr_" : "conv_") + << Conv2d::ThreadblockShape::kM << "x" + << Conv2d::ThreadblockShape::kN << "x" + << Conv2d::ThreadblockShape::kK << "_" + << Conv2d::WarpShape::kM << "x" + << Conv2d::WarpShape::kN << "x" + << Conv2d::WarpShape::kK << ".txt"; + + std::cout << fname.str() << std::endl; + + std::ofstream results(fname.str()); + + results << problem_size << std::endl; + + results + << "\nA:\n" << tensor_A.host_view() << "\n" + << "\nB:\n" << tensor_B.host_view() << "\n" + << "\nC:\n" << tensor_C.host_view() << "\n" + << "\nBroadcast:\n" << tensor_Broadcast.host_view() << "\n" + << "\nY reference:\n" << tensor_Y_reference.host_view() << "\n" + << "\nT reference:\n" << tensor_T_reference.host_view() << "\n" + << "\nT computed:\n" << tensor_T_computed.host_view() << "\n" + << "\nZ reference:\n" << tensor_Z_reference.host_view() << "\n" + << "\nZ computed:\n" << tensor_Z_computed.host_view() << "\n"; + } + + return passed; + } +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////////////////// + +template , + bool AddBroadcastFirst = false> +bool TestSpecificConv2dWithBroadcast( + const Conv2dProblemVector & problem_sizes) { + + bool passed = true; + + // + // Testbed object + // + + TestbedConv2dWithBroadcast testbed; + + // Sweep conv2d problem sizes (split-k-mode=kSerial, split-k-slice=1, alpha=1.0, beta=0.0) + for(auto conv_problem : problem_sizes) { + + // + // Test + // + + // test mode = xcross + passed = testbed.run( + conv_problem, + cutlass::conv::SplitKMode::kSerial); + + if (!passed) { + return false; + } + + // test mode = convolution + passed = testbed.run( + conv_problem.reset_mode(cutlass::conv::Mode::kConvolution), + cutlass::conv::SplitKMode::kSerial); + + if (!passed) { + return false; + } + } + + return true; +} + +///////////////////////////////////////////////////////////////////////////////////////////////////////// +// TestAllConv: Runs cutlass::conv::device::ImplicitGemmConvolution operator and compares it with reference +// TestAllConv runs conv operator on default conv problem sizes from test::conv::device::TestbedConv2dProblemSizes +// Additionally, each conv2d test can provide conv problem sizes (conv_test_sizes) and blacklist of sizes +// (conv_blacklist_sizes) +///////////////////////////////////////////////////////////////////////////////////////////////////////////// +template , + bool AddBroadcastFirst = false, + bool TestSplitK = true +> +bool TestAllConv2dWithBroadcast( + const Conv2dProblemVector &conv_test_sizes = Conv2dProblemVector(), + const Conv2dProblemVector &conv_blacklist_sizes = Conv2dProblemVector()) { + + bool passed = true; + + // + // Testbed object + // + + TestbedConv2dWithBroadcast testbed; + + // + // Get conv problem sizes to run conv operator + // + TestbedConv2dProblemSizes conv_problems(128/cutlass::sizeof_bits::value); + + // Vector of conv2d problem sizes to avoid duplicate runs + Conv2dProblemVector conv_tested_sizes; + + Conv2dProblemVector const *problem_vectors[] = { + &conv_test_sizes, // run user specified sizes + &conv_problems.conv2d_default_sizes, // run default and cudnn bug sizes + &conv_problems.conv2d_resnet50_sizes, // run resnet50 sizes +#if CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED + &conv_problems.conv2d_rigorous_sizes, // run large and rigorous sizes if enabled +#endif + }; + + // Sweep conv2d problem sizes (split-k-mode=kSerial, split-k-slice=1, alpha=1.0, beta=0.0) + for (Conv2dProblemVector const * problem_vector : problem_vectors) { + + // Run conv testbed 
on default convolution sizes + for(auto conv_problem : *problem_vector) { + + // Skip blacklist and avoid duplicate problem sizes + if (std::find(conv_blacklist_sizes.begin(), conv_blacklist_sizes.end(), conv_problem) != conv_blacklist_sizes.end() || + std::find(conv_tested_sizes.begin(), conv_tested_sizes.end(), conv_problem) != conv_tested_sizes.end()) { + continue; + } + + // + // Procedurally disable certain cases + // + + // CUTLASS DGRAD's *unity* stride specialization only support stride {1, 1} + if ((ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDgrad || + ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDeconv) && + (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == + cutlass::conv::StrideSupport::kUnity)) { + if (!((conv_problem.stride_h == 1) && (conv_problem.stride_w == 1))) { + continue; + } + } + +#if 0 // relax restrictions on analytic strided dgrad + // CUTLASS DGRAD's *strided* specialization only support stride >= {2, 2} + if ((ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDgrad || + ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDeconv) && + (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == + cutlass::conv::StrideSupport::kStrided)) { + if (((conv_problem.stride_h == 1) && (conv_problem.stride_w == 1))) { + continue; + } + } +#endif + + // + // Test + // + // push back tested problem size to avoid re-running duplicates + conv_tested_sizes.push_back(conv_problem); + + // test mode = xcross + passed = testbed.run( + conv_problem, + cutlass::conv::SplitKMode::kSerial); + + if (!passed) { + return false; + } + + // test mode = convolution + passed = testbed.run( + conv_problem.reset_mode(cutlass::conv::Mode::kConvolution), + cutlass::conv::SplitKMode::kSerial); + + if (!passed) { + return false; + } + } + } + + // CUTLASS DGRAD's *strided* specialization does not support split-k mode + if ((ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDgrad || + ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDeconv) && + (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == + cutlass::conv::StrideSupport::kStrided)) { + + passed = testbed.run( + cutlass::conv::Conv2dProblemSize( + {1, 56, 56, 8}, // input size (NHWC) + {8, 1, 1, 8}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {1, 1}), // dilation (dilation_h, dilation_w) + cutlass::conv::SplitKMode::kSerial, + cutlass::from_real(2.0), + cutlass::from_real(2.0)); + + if (!passed) { + return false; + } + + return passed; + } + + if (!TestSplitK) + return passed; + + // Sweep split-k-slice using serial and prallel reduction with non-unity alpha and non-zero beta for + // a single conv2d problem size. Convolution unit tests take a long time to run so only sweep parameters + // which are abolutely necessary to catch functional bugs. The below code does provide option to sweep + // alpha and beta for local testing, but only runs one value for alpha and beta. 
+ cutlass::conv::Conv2dProblemSize conv2d_split_k_test_size ( + {1, 17, 11, 288}, // input size (NHWC) + {160, 3, 3, 288}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + ); + + cutlass::conv::SplitKMode split_k_modes [] = { + cutlass::conv::SplitKMode::kSerial + }; + + int split_k_slices[] = { + 1, 2, 3, 4, 201 + }; + + double problem_alpha[] = { + 2.0 + }; + + double problem_beta[] = { + 2.0 + }; + + for (auto split_k_mode : split_k_modes) { + for (auto split_k_slice : split_k_slices) { + for (auto alpha : problem_alpha) { + for (auto beta : problem_beta) { + + passed = testbed.run( + conv2d_split_k_test_size.reset_split_k_slices(split_k_slice), + split_k_mode, + cutlass::from_real(alpha), + cutlass::from_real(beta)); + + if (!passed) { + return false; + } + } + } + } + } + + return passed; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace conv +} // namespace test diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_with_reduction_testbed.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_with_reduction_testbed.h new file mode 100644 index 0000000000000000000000000000000000000000..a8ec16ca5de369470f5dc50bb6f8b5e2da3da10d --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/conv/device/conv2d_with_reduction_testbed.h @@ -0,0 +1,643 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Implicit GEMM testbed +*/ +#pragma once + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" + +#include "cutlass/conv/device/implicit_gemm_convolution.h" +#include "cutlass/reduction/device/tensor_reduce.h" +#include "cutlass/reduction/device/reduce_split_k.h" +#include "cutlass/reduction/thread/reduction_operators.h" + +#include "conv2d_problems.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_compare.h" + +#include "cutlass/util/reference/host/convolution.h" +#include "cutlass/util/reference/device/convolution.h" + +#include "cutlass/core_io.h" +#include "cutlass/util/tensor_view_io.h" + +#include "../cache_testbed_output.h" + +namespace test { +namespace conv { +namespace device { + +template +class TestbedConv2dWithReduction { +public: + + using ElementA = typename Conv2d::ElementA; + using LayoutA = typename Conv2d::LayoutA; + using ElementB = typename Conv2d::ElementB; + using LayoutB = typename Conv2d::LayoutB; + using ElementC = typename Conv2d::ElementC; + using LayoutC = typename Conv2d::LayoutC; + using ElementAccumulator = typename Conv2d::ElementAccumulator; + using ElementCompute = typename Conv2d::ElementCompute; + using EpilogueOutputOp = typename Conv2d::EpilogueOutputOp; + using ElementT = typename EpilogueOutputOp::ElementTensor; + + static cutlass::conv::Operator const kConvolutionalOperator = Conv2d::kConvolutionalOperator; + +public: + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + uint64_t seed; + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_C; + + cutlass::HostTensor tensor_Reduction; + cutlass::HostTensor tensor_Tensor; + cutlass::HostTensor tensor_Final_Reduction; + + cutlass::HostTensor tensor_D_computed; + cutlass::HostTensor tensor_D_reference; + +public: + + TestbedConv2dWithReduction( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { + + } + + /// Helper to initialize a tensor view + template + void initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + int scope = 2; + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope, -scope, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential(view.data(), view.capacity()); + } + else { + } + } + + void initialize( + cutlass::conv::Conv2dProblemSize const &problem_size, uint64_t seed = 2019) { + + tensor_A.resize(implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size)); + tensor_B.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size)); + tensor_C.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + + 
tensor_Reduction.resize({ + 1, + 1, + (problem_size.N * problem_size.P * problem_size.Q - 1 + Conv2d::ThreadblockShape::kM) / Conv2d::ThreadblockShape::kM, + (problem_size.K) + }); + + tensor_Final_Reduction.resize({ + 1, + 1, + 1, + (problem_size.K) + }); + + tensor_Tensor.resize({(problem_size.N * problem_size.P * problem_size.Q), problem_size.K}); + + tensor_D_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + tensor_D_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + + initialize_tensor(tensor_A.host_view(), init_A, seed); + initialize_tensor(tensor_B.host_view(), init_B, seed * 17); + initialize_tensor(tensor_C.host_view(), init_C, seed * 39); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_D_computed.sync_device(); + tensor_D_reference.sync_device(); + } + + bool sufficient() const { + // + // Determine SMEM requirements and waive if not satisfied + // + + size_t smem_size = sizeof(typename Conv2d::UnderlyingKernel::SharedStorage); + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerBlockOptin < smem_size) { + return false; + } + + return true; + } + + /// Executes one test + bool run( + cutlass::conv::Conv2dProblemSize const &problem_size, + cutlass::conv::SplitKMode const &split_k_mode = cutlass::conv::SplitKMode::kSerial, + ElementCompute alpha = ElementCompute(1), + ElementCompute beta = ElementCompute(0)) { + + // Waive test if insufficient CUDA device + if (!sufficient()) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." << std::endl; + } + return true; + } + +#if 0 //display conv2d problem size for debugging + std::cout << problem_size << std::endl + << "alpha, beta: (" << alpha << ", " << beta << ")" << std::endl + << "split_k_mode: " << ((split_k_mode == cutlass::conv::SplitKMode::kSerial) ? 
"(serial)" : "(parallel)") << std::endl + << std::endl; +#endif + + initialize(problem_size); + + // configure the operator + Conv2d conv2d_op; + + typename Conv2d::Arguments conv2d_args( + problem_size, + tensor_A.device_ref(), + tensor_B.device_ref(), + tensor_C.device_ref(), + tensor_D_computed.device_ref(), + {alpha, beta}, + split_k_mode, + tensor_Reduction.device_data(), + tensor_Tensor.device_data(), + static_cast(tensor_Reduction.stride()[0]), + static_cast(tensor_Tensor.stride()[0]) + ); + + // find workspace requirement for parallel split-k reduction + size_t workspace_size = Conv2d::get_workspace_size(conv2d_args); + + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = conv2d_op.initialize(conv2d_args, workspace.get()); + + if (status != cutlass::Status::kSuccess) { + cudaError_t error = cudaGetLastError(); + std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n"; + return true; + } + + // conv2d operation with parallel split-k-mode + if (split_k_mode == cutlass::conv::SplitKMode::kParallel) { + + // conv2d output is written to workspace in global memory + conv2d_args.ref_D.reset(reinterpret_cast(workspace.get())); + // accumulate mma for each cta in k-dimension (1.0 * A * B) + conv2d_args.output_op = {ElementCompute(1), ElementCompute(0)}; + // update conv2d operator arguments + status = conv2d_op.update(conv2d_args, workspace.get()); + } + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + if (status != cutlass::Status::kSuccess) { + return false; + } + + // run conv2d operator + status = conv2d_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + if (status != cutlass::Status::kSuccess) { + return false; + } + + bool passed = false; + + cudaError_t result = cudaDeviceSynchronize(); + EXPECT_EQ(result, cudaSuccess) << " device reference error: " + << cudaGetErrorString(result); + + // Final reduction over the partial reduction tensor + using Functor = cutlass::plus; + using TensorReduction = cutlass::reduction::device::TensorReduction< + ElementAccumulator, + ElementAccumulator, + LayoutC, + Functor, + 8, + ElementAccumulator + >; + + TensorReduction reduction(tensor_Reduction.extent(), 2); + + cutlass::DeviceAllocation reduction_device_workspace(reduction.workspace_size()); + + status = reduction.reduce( + tensor_Final_Reduction.device_ref(), + tensor_Reduction.device_ref(), + reduction_device_workspace.get(), + ElementAccumulator()); + + EXPECT_EQ(status, cutlass::Status::kSuccess); + EXPECT_EQ(cudaDeviceSynchronize(), cudaSuccess); + + // + // Reference check + // + + tensor_D_computed.sync_host(); + +#if CUTLASS_CONV_TEST_UNIT_REFERENCE_DEVICE_ENABLED + + cutlass::reference::device::Conv2d< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementCompute, + ElementAccumulator + >( + kConvolutionalOperator, + problem_size, + tensor_A.device_ref(), + tensor_B.device_ref(), + tensor_C.device_ref(), + tensor_D_reference.device_ref(), + alpha, + beta); + + // sync host (copy device data to host) for dumping error output in case of mismatches + tensor_D_reference.sync_host(); + +#else + + cutlass::reference::host::Conv2d< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementCompute, + ElementAccumulator + >( + kConvolutionalOperator, + problem_size, + tensor_A.host_ref(), + tensor_B.host_ref(), + tensor_C.host_ref(), + tensor_D_reference.host_ref(), + alpha, + beta); + +#endif + + passed = cutlass::reference::host::TensorEquals( + 
tensor_D_computed.host_view(), + tensor_D_reference.host_view()); + + EXPECT_TRUE(passed); + + // + // Reference check on reduction results + // + + tensor_Reduction.sync_host(); + tensor_Final_Reduction.sync_host(); + + // compute backwards for reduction results + cutlass::HostTensor reference_Reduction; + reference_Reduction.resize({ + 1, + 1, + 1, + (problem_size.K) + }); + + for (int k = 0; k < problem_size.K; ++k) { + ElementAccumulator reduced_value = ElementAccumulator(); + for (int n = 0; n < problem_size.N; ++n) { + for (int p = 0; p < problem_size.P; ++p) { + for (int q = 0; q < problem_size.Q; ++q) { + reduced_value += tensor_D_reference.at({n, p, q, k}); + } + } + } + reference_Reduction.at({0, 0, 0, k}) = reduced_value; + } + + passed = cutlass::reference::host::TensorEquals( + tensor_Final_Reduction.host_view(), + reference_Reduction.host_view() + ); + + EXPECT_TRUE(passed); + + if (!passed) { + std::stringstream fname; + + fname << "error_Conv2d_ImplicitGemm_device_" + << (split_k_mode == cutlass::conv::SplitKMode::kSerial ? "serial_reduction_" : "parallel_reduction_") + << (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kFprop ? "fprop_" : + (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kDgrad ? "dgrad_" : "wgrad_")) + << "nhwc_" + << problem_size.N << "x" + << problem_size.H << "x" + << problem_size.W << "x" + << problem_size.C + << "_krsc_" + << problem_size.K << "x" + << problem_size.R << "x" + << problem_size.S << "x" + << problem_size.C + << "_padding_" + << problem_size.pad_h << "x" + << problem_size.pad_w + << "_stride_" + << problem_size.stride_h << "x" + << problem_size.stride_w + << "_dilation_" + << problem_size.dilation_h << "x" + << problem_size.dilation_w << "_" + << (problem_size.mode == cutlass::conv::Mode::kCrossCorrelation ? 
"xcorr_" : "conv_") + << Conv2d::ThreadblockShape::kM << "x" + << Conv2d::ThreadblockShape::kN << "x" + << Conv2d::ThreadblockShape::kK << "_" + << Conv2d::WarpShape::kM << "x" + << Conv2d::WarpShape::kN << "x" + << Conv2d::WarpShape::kK << ".txt"; + + std::cout << fname.str() << std::endl; + + std::ofstream results(fname.str()); + + results << problem_size << std::endl; + + results + << "\nA:\n" << tensor_A.host_view() << "\n" + << "\nB:\n" << tensor_B.host_view() << "\n" + << "\nC:\n" << tensor_C.host_view() << "\n" + << "\nD reference:\n" << tensor_D_reference.host_view() << "\n" + << "\nD computed:\n" << tensor_D_computed.host_view() << "\n" + << "\nreduction reference:\n" << reference_Reduction.host_view() << "\n" + << "\nreduction computed:\n" << tensor_Reduction.host_view() << "\n"; + } + + return passed; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////////////// +// TestAllConv: Runs cutlass::conv::device::ImplicitGemmConvolution operator and compares it with reference +// TestAllConv runs conv operator on default conv problem sizes from test::conv::device::TestbedConv2dProblemSizes +// Additionally, each conv2d test can provide conv problem sizes (conv_test_sizes) and blacklist of sizes +// (conv_blacklist_sizes) +///////////////////////////////////////////////////////////////////////////////////////////////////////////// +template +bool TestAllConv2dWithReduction( + const Conv2dProblemVector & conv_test_sizes = Conv2dProblemVector(), + const Conv2dProblemVector & conv_blacklist_sizes = Conv2dProblemVector()) { + + bool passed = true; + + // + // Testbed object + // + + TestbedConv2dWithReduction testbed; + + // + // Get conv problem sizes to run conv operator + // + TestbedConv2dProblemSizes conv_problems(128/cutlass::sizeof_bits::value); + + // Vector of conv2d problem sizes to avoid duplicate runs + Conv2dProblemVector conv_tested_sizes; + + Conv2dProblemVector const *problem_vectors[] = { + &conv_test_sizes, // run user specified sizes + &conv_problems.conv2d_default_sizes, // run default and cudnn bug sizes + &conv_problems.conv2d_resnet50_sizes, // run resnet50 sizes +#if CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED + &conv_problems.conv2d_rigorous_sizes, // run large and rigorous sizes if enabled +#endif + }; + + // Sweep conv2d problem sizes (split-k-mode=kSerial, split-k-slice=1, alpha=1.0, beta=0.0) + for (Conv2dProblemVector const * problem_vector : problem_vectors) { + + // Run conv testbed on default convolution sizes + for(auto conv_problem : *problem_vector) { + + // Skip blacklist and avoid duplicate problem sizes + if (std::find(conv_blacklist_sizes.begin(), conv_blacklist_sizes.end(), conv_problem) != conv_blacklist_sizes.end() || + std::find(conv_tested_sizes.begin(), conv_tested_sizes.end(), conv_problem) != conv_tested_sizes.end()) { + continue; + } + + // + // Procedurally disable certain cases + // + + // CUTLASS DGRAD's *unity* stride specialization only support stride {1, 1} + if ((ImplicitGemm::kConvolutionalOperator == + cutlass::conv::Operator::kDgrad) && + (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == + cutlass::conv::StrideSupport::kUnity)) { + if (!((conv_problem.stride_h == 1) && (conv_problem.stride_w == 1))) { + continue; + } + } + +#if 0 // relax restrictions on analytic strided dgrad + // CUTLASS DGRAD's *strided* specialization only support stride >= {2, 2} + if ((ImplicitGemm::kConvolutionalOperator == + cutlass::conv::Operator::kDgrad) && + 
(ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == + cutlass::conv::StrideSupport::kStrided)) { + if (((conv_problem.stride_h == 1) && (conv_problem.stride_w == 1))) { + continue; + } + } +#endif + + // + // Test + // + // push back tested problem size to avoid re-running duplicates + conv_tested_sizes.push_back(conv_problem); + + // test mode = xcross + passed = testbed.run( + conv_problem, + cutlass::conv::SplitKMode::kSerial); + + if (!passed) { + return false; + } + + // test mode = convolution + passed = testbed.run( + conv_problem.reset_mode(cutlass::conv::Mode::kConvolution), + cutlass::conv::SplitKMode::kSerial); + + if (!passed) { + return false; + } + } + } + + // CUTLASS DGRAD's *strided* specialization does not support split-k mode + if ((ImplicitGemm::kConvolutionalOperator == + cutlass::conv::Operator::kDgrad) && + (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == + cutlass::conv::StrideSupport::kStrided)) { + + passed = testbed.run( + cutlass::conv::Conv2dProblemSize( + {1, 56, 56, 8}, // input size (NHWC) + {8, 1, 1, 8}, // filter size (KRSC) + {0, 0, 0, 0}, // padding (pad_h, _, pad_w, _) + {2, 2}, // stride (stride_h, stride_w) + {1, 1}), // dilation (dilation_h, dilation_w) + cutlass::conv::SplitKMode::kSerial, + cutlass::from_real(2.0), + cutlass::from_real(2.0)); + + if (!passed) { + return false; + } + + return passed; + } + + // Sweep split-k-slice using serial and prallel reduction with non-unity alpha and non-zero beta for + // a single conv2d problem size. Convolution unit tests take a long time to run so only sweep parameters + // which are abolutely necessary to catch functional bugs. The below code does provide option to sweep + // alpha and beta for local testing, but only runs one value for alpha and beta. + cutlass::conv::Conv2dProblemSize conv2d_split_k_test_size ( + {1, 17, 11, 288}, // input size (NHWC) + {160, 3, 3, 288}, // filter size (KRSC) + {1, 1, 1, 1}, // padding (pad_h, _, pad_w, _) + {1, 1}, // stride (stride_h, stride_w) + {1, 1} // dilation (dilation_h, dilation_w) + ); + + // Parallel SplitK is not tested. 
+ cutlass::conv::SplitKMode split_k_modes [] = { + cutlass::conv::SplitKMode::kSerial, + }; + + int split_k_slices[] = { + 1, 2, 3, 4, 201 + }; + + double problem_alpha[] = { + 2.0 + }; + + double problem_beta[] = { + 2.0 + }; + + for (auto split_k_mode : split_k_modes) { + for (auto split_k_slice : split_k_slices) { + for (auto alpha : problem_alpha) { + for (auto beta : problem_beta) { + + passed = testbed.run( + conv2d_split_k_test_size.reset_split_k_slices(split_k_slice), + split_k_mode, + cutlass::from_real(alpha), + cutlass::from_real(beta)); + + if (!passed) { + return false; + } + } + } + } + } + + return passed; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace conv +} // namespace test diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/conv/device/conv3d_problems.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/conv/device/conv3d_problems.h new file mode 100644 index 0000000000000000000000000000000000000000..fae7d6194fb671594221a90faea7cac1e5fbeb9f --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/conv/device/conv3d_problems.h @@ -0,0 +1,293 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Implicit GEMM testbed sizes for Conv2d problem +*/ +#pragma once + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/cutlass.h" + +#include "cutlass/aligned_buffer.h" +#include "cutlass/numeric_types.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/core_io.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv2d_problem_size.h" +#include "cutlass/conv/conv3d_problem_size.h" + +namespace test { +namespace conv { +namespace device { + +using Conv3dProblemVector = std::vector; + +//////////////////////////////////////////////////////////////////////////// +/// Structure TestbedConv3dProblemSizes initializes and holds conv default and +/// important network sizes +//////////////////////////////////////////////////////////////////////////// +struct TestbedConv3dProblemSizes { + + // + // Data members + // + int minimum_channel_size; + Conv3dProblemVector conv3d_default_sizes; + Conv3dProblemVector conv3d_vnet_medical_sizes; + + // + // Methods + // + /// Default ctor + TestbedConv3dProblemSizes(int minimum_channel_size_ = 64): minimum_channel_size (minimum_channel_size_) { + + initialize_conv3d_default_sizes(); + initialize_conv3d_vnet_medical_sizes(conv3d_vnet_medical_sizes, 1 /*batch-size*/); + + filter_all(); + } + + /// Eliminates some illegal cases + void filter_all() { + + Conv3dProblemVector *problems_vectors[] = { + &conv3d_default_sizes, + &conv3d_vnet_medical_sizes + }; + + for (Conv3dProblemVector *problems : problems_vectors) { + Conv3dProblemVector filtered; + + for (cutlass::conv::Conv3dProblemSize const & problem : *problems) { + if (!(problem.C % minimum_channel_size)) { + filtered.push_back(problem); + } + } + + *problems = filtered; + } + } + + // Add a few standard convolution problem sizes + void initialize_conv3d_default_sizes() { + + conv3d_default_sizes.push_back(cutlass::conv::Conv3dProblemSize( + {1, 1, 3, 3, minimum_channel_size}, // input size (NDHWC) + {8, 1, 1, 1, minimum_channel_size}, // filter size (KTRSC) + cutlass::Coord<3>({0, 0, 0}), // padding (pad_d, pad_h, pad_w) + cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) + cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) + )); + + conv3d_default_sizes.push_back(cutlass::conv::Conv3dProblemSize( + {1, 1, 1, 8, minimum_channel_size}, // input size (NDHWC) + {8, 1, 1, 3, minimum_channel_size}, // filter size (KTRSC) + cutlass::Coord<3>({1, 1, 1}), // padding (pad_d, pad_h, pad_w) + cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) + cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) + )); + + conv3d_default_sizes.push_back(cutlass::conv::Conv3dProblemSize( + {1, 1, 1, 8, minimum_channel_size}, // input size (NDHWC) + {8, 1, 1, 3, minimum_channel_size}, // filter size (KTRSC) + CUTLASS_STL_NAMESPACE::make_tuple( + cutlass::Coord<3>({1, 1, 1}), // near padding (pad_d, pad_h, pad_w) + cutlass::Coord<3>({0, 0, 0}) // far padding (pad_d, pad_h, pad_w) + ), + cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) + cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) + )); + + conv3d_default_sizes.push_back(cutlass::conv::Conv3dProblemSize( + {1, 8, 8, 8, minimum_channel_size}, // input size (NDHWC) + {8, 3, 3, 3, minimum_channel_size}, // filter size (KTRSC) + 
cutlass::Coord<3>({1, 1, 1}), // padding (pad_d, pad_h, pad_w) + cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) + cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) + )); + + conv3d_default_sizes.push_back(cutlass::conv::Conv3dProblemSize( + {1, 8, 8, 8, minimum_channel_size}, // input size (NDHWC) + {8, 3, 3, 3, minimum_channel_size}, // filter size (KTRSC) + CUTLASS_STL_NAMESPACE::make_tuple( + cutlass::Coord<3>({1, 1, 1}), // near padding (pad_d, pad_h, pad_w) + cutlass::Coord<3>({0, 0, 0}) // far padding (pad_d, pad_h, pad_w) + ), + cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) + cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) + )); + + conv3d_default_sizes.push_back(cutlass::conv::Conv3dProblemSize( + {1, 16, 16, 16, minimum_channel_size}, // input size (NDHWC) + {8, 3, 3, 3, minimum_channel_size}, // filter size (KTRSC) + cutlass::Coord<3>({1, 1, 1}), // padding (pad_d, pad_h, pad_w) + cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) + cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) + )); + + conv3d_default_sizes.push_back(cutlass::conv::Conv3dProblemSize( + {1, 1, 15, 19, 160}, // input size (NDHWC) + {224, 1, 3, 6, 160}, // filter size (KTRSC) + cutlass::Coord<3>({0, 0, 0}), // padding (pad_d, pad_h, pad_w) + cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) + cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) + )); + + conv3d_default_sizes.push_back(cutlass::conv::Conv3dProblemSize( + {1, 2, 1, 1, minimum_channel_size}, // input size (NDHWC) + {8, 2, 1, 1, minimum_channel_size}, // filter size (KTRSC) + cutlass::Coord<3>({0, 0, 0}), // padding (pad_d, pad_h, pad_w) + cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) + cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) + )); + + conv3d_default_sizes.push_back(cutlass::conv::Conv3dProblemSize( + {1, 1, 7, 7, minimum_channel_size}, // input size (NDHWC) + {16, 1, 3, 3, minimum_channel_size}, // filter size (KTRSC) + cutlass::Coord<3>({0, 0, 0}), // padding (pad_d, pad_h, pad_w) + cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) + cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) + )); + + + conv3d_default_sizes.push_back(cutlass::conv::Conv3dProblemSize( + {1, 11, 15, 19, 64}, // input size (NDHWC) + {32, 4, 3, 6, 64}, // filter size (KTRSC) + cutlass::Coord<3>({2, 1, 3}), // padding (pad_d, pad_h, pad_w) + cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) + cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) + )); + } + + // Add vnet layers to unit testing sizes + void initialize_conv3d_vnet_medical_sizes(Conv3dProblemVector &conv3d_problem_vector, int batch_size = 1) { + + conv3d_problem_vector.push_back(cutlass::conv::Conv3dProblemSize( + {batch_size, 32, 32, 32, 16}, // input size (NDHWC) + {32, 2, 2, 2, 16}, // filter size (KTRSC) + cutlass::Coord<3>({0, 0, 0}), // padding (pad_d, pad_h, pad_w) + cutlass::Coord<3>({2, 2, 2}), // stride (stride_d, stride_h, stride_w) + cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) + )); + + + conv3d_problem_vector.push_back(cutlass::conv::Conv3dProblemSize( + {batch_size, 16, 16, 16, 32}, // input size (NDHWC) + {32, 3, 3, 3, 32}, // filter size (KTRSC) + cutlass::Coord<3>({1, 1, 1}), // padding (pad_d, pad_h, pad_w) + 
cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) + cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) + )); + + + conv3d_problem_vector.push_back(cutlass::conv::Conv3dProblemSize( + {batch_size, 16, 16, 16, 32}, // input size (NDHWC) + {64, 2, 2, 2, 32}, // filter size (KTRSC) + cutlass::Coord<3>({0, 0, 0}), // padding (pad_d, pad_h, pad_w) + cutlass::Coord<3>({2, 2, 2}), // stride (stride_d, stride_h, stride_w) + cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) + )); + + + conv3d_problem_vector.push_back(cutlass::conv::Conv3dProblemSize( + {batch_size, 8, 8, 8, 64}, // input size (NDHWC) + {64, 3, 3, 3, 64}, // filter size (KTRSC) + cutlass::Coord<3>({1, 1, 1}), // padding (pad_d, pad_h, pad_w) + cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) + cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) + )); + + + conv3d_problem_vector.push_back(cutlass::conv::Conv3dProblemSize( + {batch_size, 8, 8, 8, 64}, // input size (NDHWC) + {128, 2, 2, 2, 64}, // filter size (KTRSC) + cutlass::Coord<3>({0, 0, 0}), // padding (pad_d, pad_h, pad_w) + cutlass::Coord<3>({2, 2, 2}), // stride (stride_d, stride_h, stride_w) + cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) + )); + + + conv3d_problem_vector.push_back(cutlass::conv::Conv3dProblemSize( + {batch_size, 4, 4, 4, 128}, // input size (NDHWC) + {128, 3, 3, 3, 128}, // filter size (KTRSC) + cutlass::Coord<3>({1, 1, 1}), // padding (pad_d, pad_h, pad_w) + cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) + cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) + )); + + + conv3d_problem_vector.push_back(cutlass::conv::Conv3dProblemSize( + {batch_size, 8, 8, 8, 128}, // input size (NDHWC) + {128, 3, 3, 3, 128}, // filter size (KTRSC) + cutlass::Coord<3>({1, 1, 1}), // padding (pad_d, pad_h, pad_w) + cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) + cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) + )); + + + conv3d_problem_vector.push_back(cutlass::conv::Conv3dProblemSize( + {batch_size, 16, 16, 16, 64}, // input size (NDHWC) + {64, 3, 3, 3, 64}, // filter size (KTRSC) + cutlass::Coord<3>({1, 1, 1}), // padding (pad_d, pad_h, pad_w) + cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) + cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) + )); + + + conv3d_problem_vector.push_back(cutlass::conv::Conv3dProblemSize( + {batch_size, 32, 32, 32, 16}, // input size (NDHWC) + {64, 2, 2, 2, 16}, // filter size (KTRSC) + cutlass::Coord<3>({0, 0, 0}), // padding (pad_d, pad_h, pad_w) + cutlass::Coord<3>({2, 2, 2}), // stride (stride_d, stride_h, stride_w) + cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) + )); + + + conv3d_problem_vector.push_back(cutlass::conv::Conv3dProblemSize( + {batch_size, 16, 16, 16, 32}, // input size (NDHWC) + {128, 2, 2, 2, 32}, // filter size (KTRSC) + cutlass::Coord<3>({0, 0, 0}), // padding (pad_d, pad_h, pad_w) + cutlass::Coord<3>({2, 2, 2}), // stride (stride_d, stride_h, stride_w) + cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) + )); + + } + +}; + +} // namespace device +} // namespace conv +} // namespace test diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/conv/device/conv3d_testbed.h 
b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/conv/device/conv3d_testbed.h new file mode 100644 index 0000000000000000000000000000000000000000..029f5effb9103bebd4ee61767795d3883541d986 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/conv/device/conv3d_testbed.h @@ -0,0 +1,716 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Implicit GEMM testbed +*/ +#pragma once + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" + + +#include "cutlass/conv/device/implicit_gemm_convolution.h" +#include "cutlass/reduction/device/reduce_split_k.h" +#include "cutlass/reduction/thread/reduction_operators.h" + +#include "cutlass/util/reference/host/tensor_fill.h" + +#include "cutlass/util/reference/host/convolution.h" + +#include "cutlass/util/reference/host/tensor_compare.h" + +#include "cutlass/util/reference/device/convolution.h" +#include "cutlass/util/reference/device/tensor_compare.h" + +#include "conv3d_problems.h" +#include "cutlass/core_io.h" + +#include "../cache_testbed_output.h" + +namespace test { +namespace conv { +namespace device { + +template +class TestbedConv3d { +public: + + using ElementA = typename Conv3d::ElementA; + using LayoutA = typename Conv3d::LayoutA; + using ElementB = typename Conv3d::ElementB; + using LayoutB = typename Conv3d::LayoutB; + using ElementC = typename Conv3d::ElementC; + using LayoutC = typename Conv3d::LayoutC; + using ElementAccumulator = typename Conv3d::ElementAccumulator; + using ElementCompute = typename Conv3d::ElementCompute; + using EpilogueOutputOp = typename Conv3d::EpilogueOutputOp; + + static cutlass::conv::Operator const kConvolutionalOperator = Conv3d::kConvolutionalOperator; + + /// Reduction kernel + using ReductionOp = cutlass::reduction::thread::ReduceAdd< + ElementAccumulator, + typename EpilogueOutputOp::ElementAccumulator, + EpilogueOutputOp::kCount + >; + + using ReductionKernel = cutlass::reduction::kernel::ReduceSplitK< + cutlass::MatrixShape<4, 32 * EpilogueOutputOp::kCount>, + EpilogueOutputOp, + ReductionOp + >; + + using ReductionDevice = cutlass::reduction::device::ReduceSplitK; + using ReductionStrideIndex = typename ReductionDevice::StrideIndex; + +public: + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + uint64_t seed; + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_C; + cutlass::HostTensor tensor_D_computed; + cutlass::HostTensor tensor_D_reference; + +public: + + TestbedConv3d( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { + + } + + /// Helper to initialize a tensor view + template + void initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + int scope; + int bits = cutlass::sizeof_bits::value; + + if (bits <= 8) { + scope = 2; + } + else if (bits == 16) { + scope = 4; + } + else { + scope = 8; + } + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope, -scope, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential(view.data(), view.capacity()); + } + else { + } + } + + void initialize( + cutlass::conv::Conv3dProblemSize const &problem_size, uint64_t seed = 2019) 
{ + + tensor_A.resize(implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size)); + tensor_B.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size)); + tensor_C.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + tensor_D_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + tensor_D_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + + initialize_tensor(tensor_A.host_view(), init_A, seed); + initialize_tensor(tensor_B.host_view(), init_B, seed * 17); + initialize_tensor(tensor_C.host_view(), init_C, seed * 39); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_D_computed.sync_device(); + tensor_D_reference.sync_device(); + } + + bool sufficient() const { + // + // Determine SMEM requirements and waive if not satisfied + // + + size_t smem_size = sizeof(typename Conv3d::UnderlyingKernel::SharedStorage); + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerBlockOptin < smem_size) { + return false; + } + + return true; + } + + + /// Executes one test + bool run( + cutlass::conv::Conv3dProblemSize const &problem_size, + cutlass::conv::SplitKMode const &split_k_mode = cutlass::conv::SplitKMode::kSerial, + ElementCompute alpha = ElementCompute(1), + ElementCompute beta = ElementCompute()) { + + + // Waive test if insufficient CUDA device + if (!sufficient()) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." << std::endl; + } + return true; + } + +#if 0 //display conv2d problem size for debugging + std::cout << problem_size << std::endl + << "alpha, beta: (" << float(alpha) << ", " << float(beta) << ")" << std::endl + << "split_k_mode: " << ((split_k_mode == cutlass::conv::SplitKMode::kSerial) ? 
"(serial)" : "(parallel)") << std::endl + << std::endl; +#endif + + initialize(problem_size); + + // configure the operator + Conv3d conv3d_op; + + typename Conv3d::Arguments conv3d_args( + problem_size, + tensor_A.device_ref(), + tensor_B.device_ref(), + tensor_C.device_ref(), + tensor_D_computed.device_ref(), + {alpha, beta}, + split_k_mode + ); + + cutlass::Status status = conv3d_op.can_implement(conv3d_args); + if (status != cutlass::Status::kSuccess) { + std::cerr << "can_implement failed for the given problem_size: \n"; + return false; + } + + // find workspace requirement for parallel split-k reduction + size_t workspace_size = Conv3d::get_workspace_size(conv3d_args); + + cutlass::device_memory::allocation workspace(workspace_size); + + status = conv3d_op.initialize(conv3d_args, workspace.get()); + + if (status != cutlass::Status::kSuccess) { + cudaError_t error = cudaGetLastError(); + std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n"; + return true; + } + + // conv3d operation with parallel split-k-mode + if (split_k_mode == cutlass::conv::SplitKMode::kParallel) { + + // conv3d output is written to workspace in global memory + conv3d_args.ref_D.reset(reinterpret_cast(workspace.get())); + // accumulate mma for each cta in k-dimension (1.0 * A * B) + conv3d_args.output_op = {1.0, 0.0}; + // update conv3d operator arguments + status = conv3d_op.update(conv3d_args, workspace.get()); + } + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + if (status != cutlass::Status::kSuccess) { + return false; + } + + // run conv3d operator + status = conv3d_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + if (status != cutlass::Status::kSuccess) { + return false; + } + + if (split_k_mode == cutlass::conv::SplitKMode::kParallel) { + + // configure parallel reduction operator + ReductionDevice reduction_op; + + typename ReductionDevice::Arguments reduction_args( + cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, problem_size).mn(), + problem_size.split_k_slices, + cutlass::conv::implicit_gemm_tensor_c_size(kConvolutionalOperator, problem_size), + { + reinterpret_cast (workspace.get()), + ReductionStrideIndex(tensor_C.stride()[Conv3d::UnderlyingKernel::kTensorCStrideIdx]) + }, + { + tensor_D_computed.device_data(), + ReductionStrideIndex(tensor_C.stride()[Conv3d::UnderlyingKernel::kTensorCStrideIdx]) + }, + { + tensor_C.device_data(), + ReductionStrideIndex(tensor_C.stride()[Conv3d::UnderlyingKernel::kTensorCStrideIdx]) + }, + // apply alpha, beta to obtain the following equation alpha * ReduceAdd(A * B) + beta * C + {alpha, beta} + ); + + status = reduction_op.initialize(reduction_args, nullptr); + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + if (status != cutlass::Status::kSuccess) { + return false; + } + + // run prallel reduction kernel + status = reduction_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + if (status != cutlass::Status::kSuccess) { + return false; + } + } + bool passed = false; + + cudaError_t result = cudaDeviceSynchronize(); + EXPECT_EQ(result, cudaSuccess) << " device reference error: " + << cudaGetErrorString(result); + + tensor_D_computed.sync_host(); + + // + // Reference check - support caching results + // + + CachedTestKey cached_test_key = CreateCachedConv3dTestKey< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementAccumulator, + ElementCompute + >( + kConvolutionalOperator, + problem_size, + alpha, + beta, + tensor_A.host_view(), + tensor_B.host_view(), + 
tensor_C.host_view() + ); + + // + // Look for the cached key + // + + bool cached_result_loaded = false; + CachedTestResult cached_test_result; + + std::string conv3d_result_cache_name = + std::string("cached_results_") + CUTLASS_TARGET_NAME + ".txt"; + + if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) { + + CachedTestResultListing cached_results(conv3d_result_cache_name); + + auto cached = cached_results.find(cached_test_key); + + cached_result_loaded = cached.first; + if (cached_result_loaded) { + cached_test_result = cached.second; + } + } + + if (!cached_result_loaded) { + +#if CUTLASS_CONV_TEST_UNIT_REFERENCE_DEVICE_ENABLED + + cutlass::reference::device::Conv3d< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + ElementCompute + >( + kConvolutionalOperator, + problem_size, + tensor_A.device_ref(), + tensor_B.device_ref(), + tensor_C.device_ref(), + tensor_D_reference.device_ref(), + alpha, + beta + ); + + // sync host (copy device data to host) for dumping error output in case of mismatches + tensor_D_reference.sync_host(); + +#else + cutlass::reference::host::Conv3d< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + ElementCompute + >( + kConvolutionalOperator, + problem_size, + tensor_A.host_ref(), + tensor_B.host_ref(), + tensor_C.host_ref(), + tensor_D_reference.host_ref(), + alpha, + beta + ); +#endif + + if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) { + + cached_test_result.D = TensorHash(tensor_D_reference.host_view()); + + CachedTestResultListing cached_results(conv3d_result_cache_name); + + cached_results.append(cached_test_key, cached_test_result); + cached_results.write(conv3d_result_cache_name); + } + } // if (!cached_result_loaded) + + uint32_t tensor_D_hash = TensorHash(tensor_D_computed.host_view()); + if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) { + passed = (tensor_D_hash == cached_test_result.D); + + EXPECT_EQ(tensor_D_hash, cached_test_result.D) + << "Hash-based comparison failed for key:" << "\n" << cached_test_key << "\n"; + } + else { + + passed = cutlass::reference::host::TensorEquals( + tensor_D_computed.host_view(), + tensor_D_reference.host_view()); + } + + EXPECT_TRUE(passed); + + if (!passed) { + std::stringstream fname; + + fname << "error_Conv3d_ImplicitGemm_device_" + << (split_k_mode == cutlass::conv::SplitKMode::kSerial ? "serial_reduction_" : "parallel_reduction_") + << (Conv3d::kConvolutionalOperator == cutlass::conv::Operator::kFprop ? "fprop_" : + (Conv3d::kConvolutionalOperator == cutlass::conv::Operator::kDgrad ? "dgrad_" : + (Conv3d::kConvolutionalOperator == cutlass::conv::Operator::kDeconv ? "deconv_" : "wgrad_"))) + << "ndhwc_" + << problem_size.N << "x" + << problem_size.D << "x" + << problem_size.H << "x" + << problem_size.W << "x" + << problem_size.C + << "_ktrsc_" + << problem_size.K << "x" + << problem_size.T << "x" + << problem_size.R << "x" + << problem_size.S << "x" + << problem_size.C + << "_padding_" + << problem_size.pad_d << "x" + << problem_size.pad_h << "x" + << problem_size.pad_w + << "_stride_" + << problem_size.stride_d << "x" + << problem_size.stride_h << "x" + << problem_size.stride_w + << "_dilation_" + << problem_size.dilation_d << "x" + << problem_size.dilation_h << "x" + << problem_size.dilation_w << "_" + << (problem_size.mode == cutlass::conv::Mode::kCrossCorrelation ? 
"xcorr_" : "conv_") + << Conv3d::ThreadblockShape::kM << "x" + << Conv3d::ThreadblockShape::kN << "x" + << Conv3d::ThreadblockShape::kK << "_" + << Conv3d::WarpShape::kM << "x" + << Conv3d::WarpShape::kN << "x" + << Conv3d::WarpShape::kK << ".txt"; + + std::cout << fname.str() << std::endl; + + std::ofstream results(fname.str()); + + results << problem_size << std::endl; + + results + << "\nA:\n" << tensor_A.host_view() << "\n" + << "\nB:\n" << tensor_B.host_view() << "\n" + << "\nC:\n" << tensor_C.host_view() << "\n"; + + + results << "\nD reference (hash: " << cached_test_result.D << ")\n"; + + if (!cached_result_loaded) { + results + << tensor_D_reference.host_view() << "\n"; + } + + results + << "\nD computed (hash: " << tensor_D_hash << ")\n" + << tensor_D_computed.host_view() << "\n"; + + } + + return passed; + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////////////// +// TestAllConv: Runs cutlass::conv::device::ImplicitGemmConvolution operator and compares it with reference +// TestAllConv runs conv operator on default conv problem sizes from test::conv::device::TestbedConv2dProblemSizes +// Additionally, each conv3d test can provide conv problem sizes (conv_test_sizes) and blacklist of sizes +// (conv_blacklist_sizes) +///////////////////////////////////////////////////////////////////////////////////////////////////////////// + +template +bool TestAllConv3d( + const Conv3dProblemVector & conv_test_sizes = Conv3dProblemVector(), + const Conv3dProblemVector & conv_blacklist_sizes = Conv3dProblemVector()) { + + bool passed = true; + + // + // Testbed object + // + + //TestbedConv3d testbed(cutlass::Distribution::Sequential, cutlass::Distribution::Sequential, cutlass::Distribution::Sequential); + TestbedConv3d testbed; + + // + // Get conv problem sizes to run conv operator + // + TestbedConv3dProblemSizes conv3d_problems(128/cutlass::sizeof_bits::value); + + // Vector of conv3d problem sizes to avoid duplicate runs + Conv3dProblemVector conv_tested_sizes; + + Conv3dProblemVector const *problem_vectors[] = { + &conv3d_problems.conv3d_default_sizes, + &conv3d_problems.conv3d_vnet_medical_sizes, + &conv_test_sizes + }; + + // Sweep conv3d problem sizes (split-k-mode=kSerial, split-k-slice=1, alpha=1.0, beta=0.0) + for (Conv3dProblemVector const * problem_vector : problem_vectors) { + + // Run conv testbed on default convolution sizes + for(auto conv_problem : *problem_vector) { + + // Skip blacklist and avoid duplicate problem sizes + if (std::find(conv_blacklist_sizes.begin(), conv_blacklist_sizes.end(), conv_problem) != conv_blacklist_sizes.end() || + std::find(conv_tested_sizes.begin(), conv_tested_sizes.end(), conv_problem) != conv_tested_sizes.end()) { + continue; + } + + // + // Procedurally disable certain cases + // + + // CUTLASS DGRAD's unity stride specialization only support stride {1, 1, 1} + if ((ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDgrad || + ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDeconv) && + ((ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == + cutlass::conv::StrideSupport::kUnity) || + (ImplicitGemm::UnderlyingKernel::Mma::IteratorB::kStrideSupport == + cutlass::conv::StrideSupport::kUnity))) { + if (!((conv_problem.stride_d == 1) && + (conv_problem.stride_h == 1) && + (conv_problem.stride_w == 1)) + ) { + continue; + } + } + + // + // Test + // + // push back tested problem size to avoid re-running duplicates + 
conv_tested_sizes.push_back(conv_problem); + + // test mode = xcross + passed = testbed.run( + conv_problem, + cutlass::conv::SplitKMode::kSerial); + + if (!passed) { + return false; + } + + // test mode = convolution + passed = testbed.run( + conv_problem.reset_mode(cutlass::conv::Mode::kConvolution), + cutlass::conv::SplitKMode::kSerial); + + if (!passed) { + return false; + } + } + } + + // Sweep split-k-slice using serial reduction with non-unity alpha and non-zero beta for + // a single conv2d problem size. Convolution unit tests take a long time to run so only sweep parameters + // which are abolutely necessary to catch functional bugs. The below code does provide option to sweep + // alpha and beta for local testing, but only runs one value for alpha and beta. + cutlass::conv::Conv3dProblemSize conv3d_split_k_test_size ( + {1, 8, 8, 8, 32}, // input size (NDHWC) + {32, 3, 3, 3, 32}, // filter size (KTRSC) + cutlass::Coord<3>({0, 0, 0}), // padding (pad_d, pad_h, pad_w) + cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) + cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) + ); + + cutlass::conv::SplitKMode split_k_modes [] = { + cutlass::conv::SplitKMode::kSerial, + cutlass::conv::SplitKMode::kParallel + }; + + int split_k_slices[] = { + 1, 2, 3, 4, 201 + }; + + double problem_alpha[] = { + 2.0 + }; + + double problem_beta[] = { + 2.0 + }; + + for (auto split_k_mode : split_k_modes) { + for (auto split_k_slice : split_k_slices) { + for (auto alpha : problem_alpha) { + for (auto beta : problem_beta) { + + passed = testbed.run( + conv3d_split_k_test_size.reset_split_k_slices(split_k_slice), + split_k_mode, + cutlass::from_real(alpha), + cutlass::from_real(beta)); + + if (!passed) { + return false; + } + } + } + } + } + + return passed; +} + +template +bool TestSpecificConv3d( + const Conv3dProblemVector & problem_sizes) { + + bool passed = true; + + // + // Testbed object + // + + TestbedConv3d testbed; + + // Sweep conv3d problem sizes (split-k-mode=kSerial, split-k-slice=1, alpha=1.0, beta=0.0) + for(auto conv_problem : problem_sizes) { + + // + // Test + // + + // test mode = xcross + passed = testbed.run( + conv_problem, + cutlass::conv::SplitKMode::kSerial); + + if (!passed) { + return false; + } + + // test mode = convolution + passed = testbed.run( + conv_problem.reset_mode(cutlass::conv::Mode::kConvolution), + cutlass::conv::SplitKMode::kSerial); + + if (!passed) { + return false; + } + } + + return true; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace conv +} // namespace test diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/conv/device/conv3d_with_broadcast_testbed.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/conv/device/conv3d_with_broadcast_testbed.h new file mode 100644 index 0000000000000000000000000000000000000000..f8ba785c9d0ecbdd518711714558c9e166c0209a --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/conv/device/conv3d_with_broadcast_testbed.h @@ -0,0 +1,732 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Implicit GEMM for fused epilogue broadcast testbed + + Parallel split-k is not tested because we can just use regular conv kernel + when we need to use parallel-splitk. Broadcast can happen in the reduction + kernel. 
+*/ +#pragma once + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" + +#include "cutlass/conv/device/implicit_gemm_convolution.h" +#include "cutlass/reduction/device/reduce_split_k.h" +#include "cutlass/reduction/thread/reduction_operators.h" + +#include "conv3d_problems.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_compare.h" + +#include "cutlass/util/reference/host/convolution.h" +#include "cutlass/util/reference/device/convolution.h" + +#include "cutlass/core_io.h" +#include "cutlass/util/tensor_view_io.h" + +#include "../cache_testbed_output.h" + +namespace test { +namespace conv { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Conv3dWithBroadcastReferenceOp { + + using OutputOp = typename Conv3d::EpilogueOutputOp; + + using ElementCompute = typename OutputOp::ElementCompute; + using ElementZ = typename OutputOp::ElementZ; + using ElementT = typename OutputOp::ElementT; + + typename OutputOp::BinaryOp binary_op; + typename OutputOp::ElementwiseOp elementwise_op; + + Conv3dWithBroadcastReferenceOp() { } + + void operator()(ElementZ &Z, ElementT &T, ElementCompute conv3d, ElementCompute bias) { + ElementCompute t_full = binary_op(conv3d, bias); + T = ElementT(t_full); + + ElementCompute z_full = elementwise_op(t_full); + Z = ElementZ(z_full); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Fused testbed +// +// Y = CONV(AB, C) +// +// T[n, o, p, q, k] = ReductionOp(Y[n, o, p, q, k], Broadcast[k]) +// +// Z[n, o, p, q, k] = Elementwise(T[n, o, p, q, k]) +// + +template < + typename Conv3d, + typename ReferenceOp, + bool AddBroadcastFirst = false +> +class TestbedConv3dWithBroadcast { +public: + + using ElementA = typename Conv3d::ElementA; + using LayoutA = typename Conv3d::LayoutA; + using ElementB = typename Conv3d::ElementB; + using LayoutB = typename Conv3d::LayoutB; + using ElementC = typename Conv3d::ElementC; + using LayoutC = typename Conv3d::LayoutC; + using ElementAccumulator = typename Conv3d::ElementAccumulator; + using ElementCompute = typename Conv3d::ElementCompute; + using EpilogueOutputOp = typename Conv3d::EpilogueOutputOp; + using ElementZ = typename EpilogueOutputOp::ElementZ; + using ElementT = typename EpilogueOutputOp::ElementT; + using ElementVector = typename EpilogueOutputOp::ElementVector; + + static cutlass::conv::Operator const kConvolutionalOperator = Conv3d::kConvolutionalOperator; + static const bool kAddBroadcastFirst = AddBroadcastFirst; + static const bool kStoreT = EpilogueOutputOp::kStoreT; + +public: + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + uint64_t seed; + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_C; + cutlass::HostTensor tensor_C_reference; + cutlass::HostTensor tensor_Z_computed; + cutlass::HostTensor tensor_Z_reference; + cutlass::HostTensor tensor_T_computed; + cutlass::HostTensor tensor_T_reference; + cutlass::HostTensor tensor_Y_reference; + cutlass::HostTensor tensor_Broadcast; // Input Broadcast + +public: + + TestbedConv3dWithBroadcast( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ 
= cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { + + } + + /// Helper to initialize a tensor view + template + void initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + int scope; + int bits = cutlass::sizeof_bits::value; + + if (bits <= 8) { + scope = 2; + } + else if (bits == 16) { + if (cutlass::sizeof_bits::value <= 16) { + scope = 3; + } + else { + scope = 5; + } + } + else { + scope = 8; + } + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope, -scope, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential(view.data(), view.capacity()); + } + else { + } + } + + void initialize( + cutlass::conv::Conv3dProblemSize const &problem_size, bool non_packed_test = false, uint64_t seed = 2019) { + + // to make the layout of tensors a little bit bigger than the problem size + cutlass::Tensor5DCoord stride_increment = cutlass::Tensor5DCoord(8, 16, 32, 32, 64); + + cutlass::Tensor5DCoord tensor_A_extent = implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size); + cutlass::Tensor5DCoord tensor_B_extent = implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size); + cutlass::Tensor5DCoord tensor_C_extent = implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size); + + if (non_packed_test) { + tensor_A_extent += stride_increment; + tensor_C_extent += stride_increment; + } + + tensor_A.resize(tensor_A_extent); + tensor_B.resize(tensor_B_extent); + tensor_C.resize(tensor_C_extent); + tensor_C_reference.resize(tensor_C_extent); + tensor_Z_computed.resize(tensor_C_extent); + tensor_Z_reference.resize(tensor_C_extent); + tensor_T_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + tensor_T_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + tensor_Y_reference.resize(tensor_C_extent); + tensor_Broadcast.resize({ + 1, + 1, + 1, + 1, + implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size).c(), + }); + + initialize_tensor(tensor_A.host_view(), init_A, seed); + initialize_tensor(tensor_B.host_view(), init_B, seed * 17); + initialize_tensor(tensor_C.host_view(), init_C, seed * 39); + initialize_tensor(tensor_Broadcast.host_view(), init_C, seed * 39); + for (int n = 0; n < tensor_C_reference.extent().n(); ++n) { + for (int o = 0; o < tensor_C_reference.extent().d(); ++o) { + for (int p = 0; p < tensor_C_reference.extent().h(); ++p) { + for (int q = 0; q < tensor_C_reference.extent().w(); ++q) { + for (int k = 0; k < tensor_C_reference.extent().c(); ++k) { + tensor_C_reference.at({n, o, p, q, k}) = ElementAccumulator(tensor_C.at({n, o, p, q, k})); + } + } + } + } + } + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_Broadcast.sync_device(); + tensor_C_reference.sync_device(); + tensor_Z_computed.sync_device(); + tensor_Z_reference.sync_device(); + tensor_T_computed.sync_device(); + tensor_T_reference.sync_device(); + tensor_Y_reference.sync_device(); + } + + 
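+  // Illustrative usage (a sketch; the kernel alias "Conv3dFpropWithBroadcastKernel" below is
+  // an assumed placeholder and is not defined in this header):
+  //
+  //   using Conv3dFprop =
+  //       cutlass::conv::device::ImplicitGemmConvolution<Conv3dFpropWithBroadcastKernel>;
+  //
+  //   TEST(SM80_Device_Conv3d_Fprop_With_Broadcast, 64x64_32x3) {
+  //     EXPECT_TRUE(test::conv::device::TestAllConv3dWithBroadcast<Conv3dFprop>());
+  //   }
+  //
+  // TestAllConv3dWithBroadcast (defined later in this header) sweeps the default conv3d
+  // problem sizes and invokes this testbed's run() in both cross-correlation and convolution
+  // modes, so a single test case covers the full default sweep.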
bool sufficient() const { + // + // Determine SMEM requirements and waive if not satisfied + // + + size_t smem_size = sizeof(typename Conv3d::UnderlyingKernel::SharedStorage); + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerBlockOptin < smem_size) { + return false; + } + + return true; + } + + /// Executes one test + bool run( + cutlass::conv::Conv3dProblemSize const &problem_size, + cutlass::conv::SplitKMode const &split_k_mode = cutlass::conv::SplitKMode::kSerial, + bool non_packed_test = false, + ElementCompute alpha = ElementCompute(1), + ElementCompute beta = ElementCompute(1)) { + + // Waive test if insufficient CUDA device + if (!sufficient()) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." << std::endl; + } + return true; + } + +#if 0 //display conv3d problem size for debugging + std::cout << problem_size << std::endl + << "alpha, beta: (" << alpha << ", " << beta << ")" << std::endl + << "split_k_mode: " << ((split_k_mode == cutlass::conv::SplitKMode::kSerial) ? "(serial)" : "(parallel)") << std::endl + << std::endl; +#endif + + initialize(problem_size, non_packed_test); + + // configure the operator + Conv3d conv3d_op; + typename Conv3d::Arguments conv3d_args( + problem_size, + tensor_A.device_ref(), + tensor_B.device_ref(), + tensor_C.device_ref(), + tensor_Z_computed.device_ref(), + {alpha, beta}, + split_k_mode, + tensor_Broadcast.device_data(), + kStoreT ? tensor_T_computed.device_data() : nullptr, + 0, // This must be zero + implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size).c() + ); + + // initialize the kernel + size_t workspace_size = Conv3d::get_workspace_size(conv3d_args); + + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = conv3d_op.initialize(conv3d_args, workspace.get()); + + if (status != cutlass::Status::kSuccess) { + cudaError_t error = cudaGetLastError(); + std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n"; + return true; + } + + // run conv3d operator + status = conv3d_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + if (status != cutlass::Status::kSuccess) { + return false; + } + + bool passed = false; + + cudaError_t result = cudaDeviceSynchronize(); + EXPECT_EQ(result, cudaSuccess) << " device reference error: " + << cudaGetErrorString(result); + + tensor_T_computed.sync_host(); + tensor_Z_computed.sync_host(); + + // + // Reference check + // + + // When kAddBroadcastFirst is true, add bias on the host + ElementCompute beta_ref = kAddBroadcastFirst ? 
ElementCompute(0) : beta; + +#if CUTLASS_CONV_TEST_UNIT_REFERENCE_DEVICE_ENABLED + + cutlass::reference::device::Conv3d< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementAccumulator, + LayoutC, + ElementAccumulator, + ElementAccumulator + >( + kConvolutionalOperator, + problem_size, + tensor_A.device_ref(), + tensor_B.device_ref(), + tensor_C_reference.device_ref(), + tensor_Y_reference.device_ref(), + alpha, + beta_ref); + + // sync host (copy device data to host) for dumping error output in case of mismatches + tensor_Y_reference.sync_host(); + +#else + + cutlass::reference::host::Conv3d< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementAccumulator, + LayoutC, + ElementAccumulator, + ElementAccumulator + >( + kConvolutionalOperator, + problem_size, + tensor_A.host_ref(), + tensor_B.host_ref(), + tensor_C_reference.host_ref(), + tensor_Y_reference.host_ref(), + alpha, + beta_ref); + +#endif + ReferenceOp reference_op; + + // compute tensor Z and tensor T + for (int n = 0; n < problem_size.N; ++n) { + for (int o = 0; o < (kConvolutionalOperator == cutlass::conv::Operator::kFprop ? problem_size.Z : problem_size.D); ++o) { + for (int p = 0; p < (kConvolutionalOperator == cutlass::conv::Operator::kFprop ? problem_size.P : problem_size.H); ++p) { + for (int q = 0; q < (kConvolutionalOperator == cutlass::conv::Operator::kFprop ? problem_size.Q : problem_size.W); ++q) { + for (int k = 0; k < (kConvolutionalOperator == cutlass::conv::Operator::kFprop ? problem_size.K : problem_size.C); ++k) { + + ElementZ z{}; + ElementT t{}; + + ElementCompute accum = tensor_Y_reference.at({n, o, p, q, k}); + ElementCompute bias = ElementCompute(tensor_Broadcast.at({0, 0, 0, 0, k})); + + + if (kAddBroadcastFirst) { + reference_op(z, t, accum + bias, + beta * ElementCompute(tensor_C_reference.at({n, o, p, q, k}))); + } else { + reference_op(z, t, accum, bias); + } + + tensor_Z_reference.at({n, o, p, q, k}) = z; + tensor_T_reference.at({n, o, p, q, k}) = t; + } + } + } + } + } + + if (kStoreT) { + passed = cutlass::reference::host::TensorEquals( + tensor_T_computed.host_view(), + tensor_T_reference.host_view()); + + EXPECT_TRUE(passed); + } + + passed = cutlass::reference::host::TensorEquals( + tensor_Z_computed.host_view(), + tensor_Z_reference.host_view()); + + EXPECT_TRUE(passed); + + if (!passed) { + std::stringstream fname; + + fname << "error_Conv3d_ImplicitGemm_device_" + << (split_k_mode == cutlass::conv::SplitKMode::kSerial ? "serial_reduction_" : "parallel_reduction_") + << (Conv3d::kConvolutionalOperator == cutlass::conv::Operator::kFprop ? "fprop_" : + (Conv3d::kConvolutionalOperator == cutlass::conv::Operator::kDgrad ? "dgrad_" : + (Conv3d::kConvolutionalOperator == cutlass::conv::Operator::kDeconv ? "deconv_" : "wgrad_"))) + << "nnhwc_" + << problem_size.N << "x" + << problem_size.D << "x" + << problem_size.H << "x" + << problem_size.W << "x" + << problem_size.C + << "_krsc_" + << problem_size.K << "x" + << problem_size.T << "x" + << problem_size.R << "x" + << problem_size.S << "x" + << problem_size.C + << "_padding_" + << problem_size.pad_d << "x" + << problem_size.pad_h << "x" + << problem_size.pad_w + << "_stride_" + << problem_size.stride_d << "x" + << problem_size.stride_h << "x" + << problem_size.stride_w + << "_dilation_" + << problem_size.dilation_d << "x" + << problem_size.dilation_h << "x" + << problem_size.dilation_w << "_" + << (problem_size.mode == cutlass::conv::Mode::kCrossCorrelation ? "xcorr_" : "conv_") + << (non_packed_test ? 
"non_packed_tensor_test_" : "packed_tensor_test_") + << Conv3d::ThreadblockShape::kM << "x" + << Conv3d::ThreadblockShape::kN << "x" + << Conv3d::ThreadblockShape::kK << "_" + << Conv3d::WarpShape::kM << "x" + << Conv3d::WarpShape::kN << "x" + << Conv3d::WarpShape::kK << ".txt"; + + std::cout << fname.str() << std::endl; + + std::ofstream results(fname.str()); + + results << problem_size << std::endl; + + results + << "\nA:\n" << tensor_A.host_view() << "\n" + << "\nB:\n" << tensor_B.host_view() << "\n" + << "\nC:\n" << tensor_C.host_view() << "\n" + << "\nBroadcast:\n" << tensor_Broadcast.host_view() << "\n" + << "\nY reference:\n" << tensor_Y_reference.host_view() << "\n" + << "\nT reference:\n" << tensor_T_reference.host_view() << "\n" + << "\nT computed:\n" << tensor_T_computed.host_view() << "\n" + << "\nZ reference:\n" << tensor_Z_reference.host_view() << "\n" + << "\nZ computed:\n" << tensor_Z_computed.host_view() << "\n"; + } + + return passed; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////////////// +// TestAllConv: Runs cutlass::conv::device::ImplicitGemmConvolution operator and compares it with reference +// TestAllConv runs conv operator on default conv problem sizes from test::conv::device::TestbedConv3dProblemSizes +// Additionally, each conv3d test can provide conv problem sizes (conv_test_sizes) and blacklist of sizes +// (conv_blacklist_sizes) +///////////////////////////////////////////////////////////////////////////////////////////////////////////// +template , + bool AddBroadcastFirst = false, + bool TestSplitK = true +> +bool TestAllConv3dWithBroadcast( + const Conv3dProblemVector &conv_test_sizes = Conv3dProblemVector(), + const Conv3dProblemVector &conv_blacklist_sizes = Conv3dProblemVector(), + bool non_packed_test = false) { + + bool passed = true; + + // + // Testbed object + // + + TestbedConv3dWithBroadcast testbed; + + // + // Get conv problem sizes to run conv operator + // + TestbedConv3dProblemSizes conv3d_problems(128/cutlass::sizeof_bits::value); + + // Vector of conv3d problem sizes to avoid duplicate runs + Conv3dProblemVector conv_tested_sizes; + + Conv3dProblemVector const *problem_vectors[] = { + &conv3d_problems.conv3d_default_sizes, + &conv3d_problems.conv3d_vnet_medical_sizes, + &conv_test_sizes + }; + + // Sweep conv3d problem sizes (split-k-mode=kSerial, split-k-slice=1, alpha=1.0, beta=0.0) + for (Conv3dProblemVector const * problem_vector : problem_vectors) { + + // Run conv testbed on default convolution sizes + for(auto conv_problem : *problem_vector) { + + // Skip blacklist and avoid duplicate problem sizes + if (std::find(conv_blacklist_sizes.begin(), conv_blacklist_sizes.end(), conv_problem) != conv_blacklist_sizes.end() || + std::find(conv_tested_sizes.begin(), conv_tested_sizes.end(), conv_problem) != conv_tested_sizes.end()) { + continue; + } + + // + // Procedurally disable certain cases + // + + // CUTLASS DGRAD's *unity* stride specialization only support stride {1, 1} + if ((ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDgrad || + ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDeconv) && + (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == + cutlass::conv::StrideSupport::kUnity)) { + if (!((conv_problem.stride_d == 1) && + (conv_problem.stride_h == 1) && + (conv_problem.stride_w == 1)) + ) { + continue; + } + } + +#if 0 // relax restrictions on analytic strided dgrad + // CUTLASS DGRAD's *strided* specialization only 
support stride >= {2, 2} + if ((ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDgrad || + ImplicitGemm::kConvolutionalOperator == cutlass::conv::Operator::kDeconv) && + (ImplicitGemm::UnderlyingKernel::Mma::IteratorA::kStrideSupport == + cutlass::conv::StrideSupport::kStrided)) { + if (((conv_problem.stride_d == 1) && (conv_problem.stride_h == 1) && (conv_problem.stride_w == 1))) { + continue; + } + } +#endif + + // + // Test + // + // push back tested problem size to avoid re-running duplicates + conv_tested_sizes.push_back(conv_problem); + + // test mode = xcross + passed = testbed.run( + conv_problem, + cutlass::conv::SplitKMode::kSerial, non_packed_test); + + if (!passed) { + return false; + } + + // test mode = convolution + passed = testbed.run( + conv_problem.reset_mode(cutlass::conv::Mode::kConvolution), + cutlass::conv::SplitKMode::kSerial, non_packed_test); + + if (!passed) { + return false; + } + } + } + + if (!TestSplitK) + return passed; + + // Sweep split-k-slice using serial and prallel reduction with non-unity alpha and non-zero beta for + // a single conv3d problem size. Convolution unit tests take a long time to run so only sweep parameters + // which are abolutely necessary to catch functional bugs. The below code does provide option to sweep + // alpha and beta for local testing, but only runs one value for alpha and beta. + cutlass::conv::Conv3dProblemSize conv3d_split_k_test_size ( + {1, 8, 8, 8, 32}, // input size (NDHWC) + {32, 3, 3, 3, 32}, // filter size (KTRSC) + cutlass::Coord<3>({0, 0, 0}), // padding (pad_d, pad_h, pad_w) + cutlass::Coord<3>({1, 1, 1}), // stride (stride_d, stride_h, stride_w) + cutlass::Coord<3>({1, 1, 1}) // dilation (dilation_d, dilation_h, dilation_w) + ); + + cutlass::conv::SplitKMode split_k_modes [] = { + cutlass::conv::SplitKMode::kSerial + }; + + int split_k_slices[] = { + 1, 2, 3, 4, 201 + }; + + double problem_alpha[] = { + 2.0 + }; + + double problem_beta[] = { + 2.0 + }; + + for (auto split_k_mode : split_k_modes) { + for (auto split_k_slice : split_k_slices) { + for (auto alpha : problem_alpha) { + for (auto beta : problem_beta) { + + passed = testbed.run( + conv3d_split_k_test_size.reset_split_k_slices(split_k_slice), + split_k_mode, + false,/*non_packed_test*/ + cutlass::from_real(alpha), + cutlass::from_real(beta)); + + if (!passed) { + return false; + } + } + } + } + } + + return passed; +} + +template , + bool AddBroadcastFirst = false> +bool TestSpecificConv3dWithBroadcast( + const Conv3dProblemVector & problem_sizes, + bool non_packed_test = false) { + + bool passed = true; + + // + // Testbed object + // + + TestbedConv3dWithBroadcast testbed; + + // Sweep conv3d problem sizes (split-k-mode=kSerial, split-k-slice=1, alpha=1.0, beta=0.0) + for(auto conv_problem : problem_sizes) { + + // + // Test + // + + // test mode = xcross, non_packed_test = false + passed = testbed.run( + conv_problem, + cutlass::conv::SplitKMode::kSerial, non_packed_test); + + if (!passed) { + return false; + } + + // test mode = convolution, non_packed_test = false + passed = testbed.run( + conv_problem.reset_mode(cutlass::conv::Mode::kConvolution), + cutlass::conv::SplitKMode::kSerial, non_packed_test); + + if (!passed) { + return false; + } + } + + return true; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace conv +} // namespace test diff --git 
a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/conv/device/depthwise_conv2d_direct_conv_testbed.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/conv/device/depthwise_conv2d_direct_conv_testbed.h new file mode 100644 index 0000000000000000000000000000000000000000..cef5f981c595dfbbb95658fb757865b219538192 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/conv/device/depthwise_conv2d_direct_conv_testbed.h @@ -0,0 +1,473 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Depthwise Direct Conv testbed +*/ +#pragma once + +#include + +#include "../../common/cutlass_unit_test.h" +#include "../cache_testbed_output.h" +#include "conv2d_problems.h" +#include "cutlass/conv/device/direct_convolution.h" + +#include "cutlass/core_io.h" +#include "cutlass/cutlass.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/device/convolution.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/host/convolution.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +namespace test { +namespace conv { +namespace device { + +template +class TestbedDepthwiseDirectConv2d { + public: + + using ElementA = typename Conv2d::ElementA; + using LayoutA = typename Conv2d::LayoutA; + using ElementB = typename Conv2d::ElementB; + using LayoutB = typename Conv2d::LayoutB; + using ElementC = typename Conv2d::ElementC; + using LayoutC = typename Conv2d::LayoutC; + using ElementAccumulator = typename Conv2d::ElementAccumulator; + using ElementCompute = typename Conv2d::ElementCompute; + using EpilogueOutputOp = typename Conv2d::EpilogueOutputOp; + + static cutlass::conv::Operator const kConvolutionalOperator = Conv2d::kConvolutionalOperator; + + public: + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + uint64_t seed; + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_reordered_B; + cutlass::HostTensor tensor_C; + cutlass::HostTensor tensor_D_computed; + cutlass::HostTensor tensor_D_reference; + + int tested_problem_count; + + public: + TestbedDepthwiseDirectConv2d(cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080) + : init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_), tested_problem_count(0) {} + + /// Helper to initialize a tensor view + template + void initialize_tensor(cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + if (dist_kind == cutlass::Distribution::Uniform) { + int scope; + int bits = cutlass::sizeof_bits::value; + + if (bits <= 8) { + scope = 2; + } else if (bits == 16) { + if (cutlass::sizeof_bits::value <= 16) { + scope = 3; + } else { + scope = 5; + } + } else { + scope = 8; + } + cutlass::reference::host::TensorFillRandomUniform(view, seed, scope, -scope, 0); + } else if (dist_kind == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(view); + + } else if (dist_kind == cutlass::Distribution::Gaussian) { + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } else if (dist_kind == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(view.data(), view.capacity()); + } else { + } + } + + void initialize(cutlass::conv::Conv2dProblemSize const &problem_size, uint64_t seed = 2019) { + tensor_A.resize(implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size)); + tensor_B.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size)); + tensor_reordered_B.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size)); + tensor_C.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + 
tensor_D_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + tensor_D_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size)); + + initialize_tensor(tensor_A.host_view(), init_A, seed); + initialize_tensor(tensor_B.host_view(), init_B, seed * 17); + initialize_tensor(tensor_reordered_B.host_view(), init_B, seed * 17); + initialize_tensor(tensor_C.host_view(), init_C, seed * 39); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_reordered_B.sync_device(); + tensor_C.sync_device(); + tensor_D_computed.sync_device(); + tensor_D_reference.sync_device(); + } + + bool sufficient(int smem_size) const { + // + // Determine SMEM requirements and waive if not satisfied + // + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerBlockOptin < static_cast(smem_size)) { + return false; + } + + return true; + } + + /// Executes one test + bool run(cutlass::conv::Conv2dProblemSize const &problem_size, + cutlass::conv::SplitKMode const &split_k_mode = cutlass::conv::SplitKMode::kSerial, + ElementCompute alpha = ElementCompute(1.5), + ElementCompute beta = ElementCompute(1)) { + // increment tested problem count run by the testbed + tested_problem_count++; + +#if 0 // display conv2d problem size for debugging + std::cout << problem_size << std::endl + << "alpha, beta: (" << alpha << ", " << beta << ")" << std::endl + << "split_k_mode: " + << ((split_k_mode == cutlass::conv::SplitKMode::kSerial) ? "(serial)" : "(parallel)") + << std::endl + << std::endl; +#endif + + initialize(problem_size); + + // configure the operator + Conv2d conv2d_op; + + typename Conv2d::Arguments conv2d_args(problem_size, + tensor_A.device_ref(), + tensor_B.device_ref(), + tensor_C.device_ref(), + tensor_D_computed.device_ref(), + {alpha, beta}, + tensor_reordered_B.device_ref(), + split_k_mode); + + // find workspace requirement for parallel split-k reduction + size_t workspace_size = Conv2d::get_workspace_size(conv2d_args); + + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = conv2d_op.can_implement(problem_size); + + if (status != cutlass::Status::kSuccess) { + cudaError_t error = cudaGetLastError(); + std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n"; + return true; + } + + status = conv2d_op.initialize(conv2d_args, workspace.get()); + + if (status != cutlass::Status::kSuccess) { + cudaError_t error = cudaGetLastError(); + std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n"; + return true; + } + + if (!sufficient(conv2d_op.get_smem_size())) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." << std::endl; + } + return true; + } + + // run conv2d operator + status = conv2d_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + if (status != cutlass::Status::kSuccess) { + std::cerr << "Failed to run." 
<< std::endl; + return false; + } + + bool passed = false; + + cudaError_t result = cudaDeviceSynchronize(); + EXPECT_EQ(result, cudaSuccess) << " device reference error: " << cudaGetErrorString(result); + + tensor_D_computed.sync_host(); + + // + // Reference check - support caching results + // + + CachedTestKey cached_test_key = + CreateCachedConv2dTestKey(kConvolutionalOperator, + problem_size, + alpha, + beta, + tensor_A.host_view(), + tensor_B.host_view(), + tensor_C.host_view()); + + // + // Look for the cached key + // + + bool cached_result_loaded = false; + CachedTestResult cached_test_result; + + std::string conv2d_result_cache_name = + std::string("cached_results_") + CUTLASS_TARGET_NAME + ".txt"; + + if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) { + + CachedTestResultListing cached_results(conv2d_result_cache_name); + + auto cached = cached_results.find(cached_test_key); + + cached_result_loaded = cached.first; + if (cached_result_loaded) { + cached_test_result = cached.second; + } + } + + if (!cached_result_loaded) { +#if CUTLASS_CONV_TEST_UNIT_REFERENCE_DEVICE_ENABLED + + cutlass::reference::device::Conv2d(kConvolutionalOperator, + problem_size, + tensor_A.device_ref(), + tensor_B.device_ref(), + tensor_C.device_ref(), + tensor_D_reference.device_ref(), + alpha, + beta); + + // sync host (copy device data to host) for dumping error output in case of mismatches + tensor_D_reference.sync_host(); + +#else + + cutlass::reference::host::Conv2d(kConvolutionalOperator, + problem_size, + tensor_A.host_ref(), + tensor_B.host_ref(), + tensor_C.host_ref(), + tensor_D_reference.host_ref(), + alpha, + beta); + +#endif + + if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) { + + cached_test_result.D = TensorHash(tensor_D_reference.host_view()); + + CachedTestResultListing cached_results(conv2d_result_cache_name); + + cached_results.append(cached_test_key, cached_test_result); + cached_results.write(conv2d_result_cache_name); + } + } // if (!cached_result_loaded) + + uint32_t tensor_D_hash = TensorHash(tensor_D_computed.host_view()); + + if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) { + passed = (tensor_D_hash == cached_test_result.D); + + EXPECT_EQ(tensor_D_hash, cached_test_result.D) + << "Hash-based comparison failed for key:" << "\n" << cached_test_key << "\n"; + } + else { + + passed = cutlass::reference::host::TensorEquals( + tensor_D_computed.host_view(), + tensor_D_reference.host_view()); + } + + EXPECT_TRUE(passed); + + std::stringstream ss_problem_size_text; + ss_problem_size_text << "nhwc_" + << problem_size.N << "x" + << problem_size.H << "x" + << problem_size.W << "x" + << problem_size.C + << "_krsc_" + << problem_size.K << "x" + << problem_size.R << "x" + << problem_size.S << "x" + << problem_size.C + << "_padding_" + << problem_size.pad_h << "x" + << problem_size.pad_w + << "_stride_" + << problem_size.stride_h << "x" + << problem_size.stride_w + << "_dilation_" + << problem_size.dilation_h << "x" + << problem_size.dilation_w << "_" + << (problem_size.mode == cutlass::conv::Mode::kCrossCorrelation ? "xcorr_" : "conv_"); + + if (!passed) { + std::stringstream fname; + + fname << "error_Conv2d_DirectConv_device_" + << (split_k_mode == cutlass::conv::SplitKMode::kSerial ? "serial_reduction_" : "parallel_reduction_") + << (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kFprop ? "fprop_" : + (Conv2d::kConvolutionalOperator == cutlass::conv::Operator::kDgrad ? 
"dgrad_" : "wgrad_")) + << ss_problem_size_text.str() + << Conv2d::ThreadblockShape::kM << "x" + << Conv2d::ThreadblockShape::kN << "x" + << Conv2d::ThreadblockShape::kK << "_" + << Conv2d::WarpShape::kM << "x" + << Conv2d::WarpShape::kN << "x" + << Conv2d::WarpShape::kK << ".txt"; + + std::cout << fname.str() << std::endl; + + std::ofstream results(fname.str()); + + results << problem_size << std::endl; + + results + << "\nA:\n" << tensor_A.host_view() << "\n" + << "\nB:\n" << tensor_B.host_view() << "\n" + << "\nC:\n" << tensor_C.host_view() << "\n"; + + results << "\nD reference (hash: " << cached_test_result.D << ")\n"; + + if (!cached_result_loaded) { + results + << tensor_D_reference.host_view() << "\n"; + } + + results + << "\nD computed (hash: " << tensor_D_hash << ")\n" + << tensor_D_computed.host_view() << "\n"; + + } + + return passed; + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////////////////// + +template +bool TestSpecificDepthwiseDirectConv2d(const Conv2dProblemVector &problem_sizes) { + bool passed = true; + + // + // Testbed object + // + TestbedDepthwiseDirectConv2d testbed; + + // Sweep conv2d problem sizes (split-k-mode=kSerial, split-k-slice=1, alpha=1.0, beta=0.0) + for (auto conv_problem : problem_sizes) { + // + // Test + // + + // test mode = xcross + passed = testbed.run( + conv_problem, + cutlass::conv::SplitKMode::kSerial); + + if (!passed) { + return false; + } + + // test mode = convolution + passed = testbed.run( + conv_problem.reset_mode(cutlass::conv::Mode::kConvolution), + cutlass::conv::SplitKMode::kSerial); + + if (!passed) { + return false; + } + } + + return true; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace conv +} // namespace test + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/conv/device_3x/conv_problem_sizes.hpp b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/conv/device_3x/conv_problem_sizes.hpp new file mode 100644 index 0000000000000000000000000000000000000000..54c11281e14b813b249d7f9710542843b37bcc68 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/conv/device_3x/conv_problem_sizes.hpp @@ -0,0 +1,1385 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief CUTLASS 3.x Implicit GEMM testbed sizes for ConvNd problem +*/ +#pragma once + +#include "cutlass/conv/convnd_problem_shape.hpp" +#include + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace test::conv::device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +std::vector> +inline +get_conv_problem_vector(); + +///////////////////////////////////////////////////////////////////////////////////////////////// +// Fprop +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Specialization for 1D fprop problems +template<> +std::vector> inline +get_conv_problem_vector<1, cutlass::conv::Operator::kFprop>() { + using ProblemShape = cutlass::conv::ConvProblemShape; + std::vector problem_shapes; + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 8, 64}, // nwc + {64, 1, 64}, // ksc + {0}, // padding lower (pad_w) + {0}, // padding upper (pad_w) + {1}, // stride (stride_w) + {1}, // dilation (dilation_w) + 1 // group + }); + // non-packed input strides. + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 8, 64}, // nwc + {800, 80, 1}, // stride (nwc) + {64, 1, 64}, // ksc + {64, 64, 1}, // stride (ksc) + {0}, // padding lower (pad_w) + {0}, // padding upper (pad_w) + {1}, // stride (stride_w) + {1}, // dilation (dilation_w) + 1 // group + }); + // non-packed output strides. 
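+  // A packed {1, 8, 64} NWC/NQK tensor has strides {8 * 64, 64, 1} = {512, 64, 1}; the
+  // {800, 80, 1} output stride below over-allocates each W row, so this case exercises an
+  // output tensor that is not contiguous in memory while the input keeps packed strides.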
+ problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 8, 64}, // nwc + {512, 64, 1}, // stride (nwc) + {64, 1, 64}, // ksc + {64, 64, 1}, // stride (ksc) + {800, 80, 1}, // stride (nqk) + {0}, // padding lower (pad_w) + {0}, // padding upper (pad_w) + {1}, // stride (stride_w) + {1}, // dilation (dilation_w) + 1 // group + }); + // Filter-K = 16 for predication + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 8, 64}, + {16,1, 64}, + {0}, + {0}, + {1}, + {1}, + 1 + }); + // N = 2 and K = 128 for a larger grid + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 64}, + {96, 1, 64}, + {0}, + {0}, + {1}, + {1}, + 1 + }); + // N = 7 and K = 256 for a even larger grid + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {7, 8, 64}, + {256, 1, 64}, + {0}, + {0}, + {1}, + {1}, + 1 + }); + // 3 filter, no padding + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 64}, + {256, 3, 64}, + {0}, + {0}, + {1}, + {1}, + 1 + }); + // 3 filter, symmetric padding with c % cta_k !=0 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 32}, + {256, 3, 32}, + {1}, + {1}, + {1}, + {1}, + 1 + }); + // 4 filter, asymmetric padding + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 64}, + {256, 4, 64}, + {0}, + {1}, + {1}, + {1}, + 1 + }); + // 3 filter, asymmetric padding and tstride of 2 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 64}, + {256, 3, 64}, + {0}, + {1}, + {2}, + {1}, + 1 + }); + // 3 filter, asymmetric padding and dilation of 2 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 64}, + {256, 3, 64}, + {0}, + {1}, + {1}, + {2}, + 1 + }); + return problem_shapes; +} + +// Specialization for 2D fprop problems +template<> +std::vector> inline +get_conv_problem_vector<2, cutlass::conv::Operator::kFprop>() { + using ProblemShape = cutlass::conv::ConvProblemShape; + std::vector problem_shapes; + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 8, 8, 64}, // nhwc + {64, 1, 1, 64}, // krsc + {0, 0}, // padding lower (pad_h, pad_w) + {0, 0}, // padding upper (pad_h, pad_w) + {1, 1}, // stride (stride_h, stride_w) + {1, 1}, // dilation (dilation_h, dilation_w) + 1 // group + }); + // non-packed input strides. + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 8, 8, 64}, // nhwc + {8000, 800, 80, 1}, // stride (nhwc) + {64, 1, 1, 64}, // krsc + {64, 64, 64, 1}, // stride (krsc) + {0, 0}, // padding lower (pad_h, pad_w) + {0, 0}, // padding upper (pad_h, pad_w) + {1, 1}, // stride (stride_h, stride_w) + {1, 1}, // dilation (dilation_h, dilation_w) + 1 // group + }); + // non-packed output strides. 
+ problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 8, 8, 64}, // nhwc + {4096, 512, 64, 1}, // stride (nhwc) + {64, 1, 1, 64}, // krsc + {64, 64, 64, 1}, // stride (krsc) + {8000, 800, 80, 1}, // stride (npqk) + {0, 0}, // padding lower (pad_h, pad_w) + {0, 0}, // padding upper (pad_h, pad_w) + {1, 1}, // stride (stride_h, stride_w) + {1, 1}, // dilation (dilation_h, dilation_w) + 1 // group + }); + // Filter-K = 16 for predication + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 8, 8, 64}, + {16, 1, 1, 64}, + {0, 0}, + {0, 0}, + {1, 1}, + {1, 1}, + 1 + }); + // N = 2 and K = 128 for a larger grid + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 8, 64}, + {96, 1, 1, 64}, + {0, 0}, + {0, 0}, + {1, 1}, + {1, 1}, + 1 + }); + // N = 7 and K = 256 for a even larger grid + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {7, 8, 8, 64}, + {256, 1, 1, 64}, + {0, 0}, + {0, 0}, + {1, 1}, + {1, 1}, + 1 + }); + // 3x3 filter, no padding + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 8, 64}, + {256, 3, 3, 64}, + {0, 0}, + {0, 0}, + {1, 1}, + {1, 1}, + 1 + }); + // 3x3 filter, symmetric padding with c % cta_k !=0 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 8, 32}, + {256, 3, 3, 32}, + {1, 1}, + {1, 1}, + {1, 1}, + {1, 1}, + 1 + }); + // 2x5 filter, asymmetric padding 1,2/1,2 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 8, 64}, + {256, 2, 5, 64}, + {1, 1}, + {2, 2}, + {1, 1}, + {1, 1}, + 1 + }); + // 2x5 filter, asymmetric padding 1,0/1,0, w/ stride + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 7, 7, 64}, + {256, 2, 5, 64}, + {1, 1}, + {0, 0}, + {2, 3}, + {1, 1}, + 1 + }); + // 2x5 filter, asymmetric padding 1,0/1,0, w/ dilation + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 16, 16, 64}, + {256, 2, 5, 64}, + {1, 1}, + {0, 0}, + {1, 1}, + {2, 3}, + 1 + }); + // 2x5 filter, asymmetric padding 1,0/1,0, w/ stride, w/ dilation + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 16, 15, 64}, + {256, 2, 5, 64}, + {1, 1}, + {0, 0}, + {2, 3}, + {2, 3}, + 1 + }); + return problem_shapes; +} + +// Specialization for 3D fprop problems +template<> +std::vector> inline +get_conv_problem_vector<3, cutlass::conv::Operator::kFprop>() { + using ProblemShape = cutlass::conv::ConvProblemShape; + std::vector problem_shapes; + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 1, 8, 8, 64}, // ndhwc + {64, 1, 1, 1, 64}, // ktrsc + {0, 0, 0}, // padding lower (pad_d, pad_h, pad_w) + {0, 0, 0}, // padding upper (pad_d, pad_h, pad_w) + {1, 1, 1}, // stride (stride_d, stride_h, stride_w) + {1, 1, 1}, // dilation (dilation_d, dilation_h, dilation_w) + 1 // group + }); + // non-packed input output strides. 
+ problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 1, 8, 8, 64}, // ndhwc + {8000, 8000, 800, 80, 1}, // stride (ndhwc) + {64, 1, 1, 1, 64}, // ktrsc + {64, 64, 64, 64, 1}, // stride (ktrsc) + {8000, 8000, 800, 80, 1}, // stride (nzpqk) + {0, 0, 0}, // padding lower (pad_d, pad_h, pad_w) + {0, 0, 0}, // padding upper (pad_d, pad_h, pad_w) + {1, 1, 1}, // stride (stride_d, stride_h, stride_w) + {1, 1, 1}, // dilation (dilation_d, dilation_h, dilation_w) + 1 // group + }); + // Filter-K = 16 for predication + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 1, 8, 8, 64}, + {16, 1, 1, 1, 64}, + {0, 0, 0}, + {0, 0, 0}, + {1, 1, 1}, + {1, 1, 1}, + 1 + }); + // N = 7 and K = 256 for a larger grid + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 1, 8, 8, 64}, + {96, 1, 1, 1, 64}, + {0, 0, 0}, + {0, 0, 0}, + {1, 1, 1}, + {1, 1, 1}, + 1 + }); + // Filter 3x3x3 + no padding + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 3, 5, 8, 64}, + {96, 3, 3, 3, 64}, + {0, 0, 0}, + {0, 0, 0}, + {1, 1, 1}, + {1, 1, 1}, + 1 + }); + // Filter 3x3x3 + symmetric padding with c % cta_k !=0 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 3, 5, 8, 32}, + {96, 3, 3, 3, 32}, + {1, 1, 1}, + {1, 1, 1}, + {1, 1, 1}, + {1, 1, 1}, + 1 + }); + // Filter 3x4x5 + symmetric padding 111 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 3, 5, 8, 64}, + {96, 3, 4, 5, 64}, + {1, 1, 1}, + {1, 1, 1}, + {1, 1, 1}, + {1, 1, 1}, + 1 + }); + // Filter 3x4x5 + asymmetric padding 102/010 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 3, 5, 8, 64}, + {96, 3, 4, 5, 64}, + {1, 0, 1}, + {0, 2, 0}, + {1, 1, 1}, + {1, 1, 1}, + 1 + }); + // Filter 3x4x5 + asymmetric padding 102/010, w/ stride + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 16, 10, 16, 64}, + {96, 3, 4, 5, 64}, + {1, 0, 1}, + {0, 2, 0}, + {2, 2, 3}, + {1, 1, 1}, + 1 + }); + // Filter 3x4x5 + asymmetric padding 102/010, w/ dilation + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 16, 10, 16, 64}, + {96, 3, 4, 5, 64}, + {1, 0, 1}, + {0, 2, 0}, + {1, 1, 1}, + {2, 2, 3}, + 1 + }); + // Filter 3x4x5 + asymmetric padding 102/010, w/ stride, w/ dilation + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 16, 10, 16, 64}, + {96, 3, 4, 5, 64}, + {1, 0, 1}, + {0, 2, 0}, + {2, 2, 3}, + {2, 2, 3}, + 1 + }); + return problem_shapes; +} + + +///////////////////////////////////////////////////////////////////////////////////////////////// +// Wgrad +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Specialization for 1D wgrad problems +template<> +std::vector> inline +get_conv_problem_vector<1, cutlass::conv::Operator::kWgrad>() { + using ProblemShape = cutlass::conv::ConvProblemShape; + std::vector problem_shapes; + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 8, 64}, // nwc + {64, 1, 64}, // ksc + {0}, // padding lower (pad_w) + {0}, // padding upper (pad_w) + {1}, // stride (stride_w) + {1}, // dilation (dilation_w) + 1 // group + }); + // Filter-K = 16 for predication + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 8, 64}, + {16,1, 64}, + {0}, + {0}, + {1}, + {1}, + 1 + }); + // N = 2 and K = 128 for a larger grid + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 64}, + {96, 1, 
64}, + {0}, + {0}, + {1}, + {1}, + 1 + }); + // N = 7 and K = 256 for a even larger grid + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {7, 8, 64}, + {256, 1, 64}, + {0}, + {0}, + {1}, + {1}, + 1 + }); + // 3 filter, no padding + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 32}, + {256, 3, 32}, + {0}, + {0}, + {1}, + {1}, + 1 + }); + // 3 filter, symmetric padding + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 32}, + {256, 3, 32}, + {1}, + {1}, + {1}, + {1}, + 1 + }); + // 4 filter, asymmetric padding + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 32}, + {256, 4, 32}, + {0}, + {1}, + {1}, + {1}, + 1 + }); + // 3 filter, asymmetric padding and tstride of 2 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 32}, + {256, 3, 32}, + {0}, + {1}, + {2}, + {1}, + 1 + }); + // 3 filter, asymmetric padding and dilation of 2 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 32}, + {256, 3, 32}, + {0}, + {1}, + {1}, + {2}, + 1 + }); + // To test streamk, equals to gemm-MxNxK size 128x640x2048 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 1024, 128}, + {640, 1, 128}, + {0}, + {0}, + {1}, + {1}, + 1 + }); + // To test streamk, equals to gemm-MxNxK size 128x640x2080 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 1040, 128}, + {640, 1, 128}, + {0}, + {0}, + {1}, + {1}, + 1 + }); + return problem_shapes; +} + +// Specialization for 2D wgrad problems +template<> +std::vector> inline +get_conv_problem_vector<2, cutlass::conv::Operator::kWgrad>() { + using ProblemShape = cutlass::conv::ConvProblemShape; + std::vector problem_shapes; + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 8, 8, 64}, // nhwc + {64, 1, 1, 64}, // krsc + {0, 0}, // padding lower (pad_h, pad_w) + {0, 0}, // padding upper (pad_h, pad_w) + {1, 1}, // stride (stride_h, stride_w) + {1, 1}, // dilation (dilation_h, dilation_w) + 1 // group + }); + // Filter-K = 16 for predication + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 8, 8, 64}, + {16, 1, 1, 64}, + {0, 0}, + {0, 0}, + {1, 1}, + {1, 1}, + 1 + }); + // N = 2 and K = 128 for a larger grid + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 8, 64}, + {96, 1, 1, 64}, + {0, 0}, + {0, 0}, + {1, 1}, + {1, 1}, + 1 + }); + // N = 7 and K = 256 for a even larger grid + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {7, 8, 8, 64}, + {256, 1, 1, 64}, + {0, 0}, + {0, 0}, + {1, 1}, + {1, 1}, + 1 + }); + // 3x3 filter, no padding + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 8, 32}, + {256, 3, 3, 32}, + {0, 0}, + {0, 0}, + {1, 1}, + {1, 1}, + 1 + }); + // 3x3 filter, symmetric padding + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 8, 32}, + {256, 3, 3, 32}, + {1, 1}, + {1, 1}, + {1, 1}, + {1, 1}, + 1 + }); + // 2x5 filter, asymmetric padding 1,0/1,0 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 8, 32}, + {256, 2, 5, 32}, + {1, 1}, + {0, 0}, + {1, 1}, + {1, 1}, + 1 + }); + // 2x5 filter, asymmetric padding 1,0/1,0, w/ stride + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 15, 16, 32}, + {256, 2, 5, 32}, + {1, 1}, + {0, 0}, + {2, 3}, + {1, 1}, + 1 + }); + // 2x5 filter, asymmetric padding 1,0/1,0, w/ dilation + 
problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 16, 16, 32}, + {256, 2, 5, 32}, + {1, 1}, + {0, 0}, + {1, 1}, + {2, 3}, + 1 + }); + // 2x5 filter, asymmetric padding 1,0/1,0, w/ stride, w/ dilation + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 16, 15, 32}, + {256, 2, 5, 32}, + {1, 1}, + {0, 0}, + {2, 3}, + {2, 3}, + 1 + }); + // To test streamk, equals to gemm-MxNxK size 128x640x2048 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 64, 16, 128}, + {640, 1, 1, 128}, + {0, 0}, + {0, 0}, + {1, 1}, + {1, 1}, + 1 + }); + // To test streamk, equals to gemm-MxNxK size 128x640x2080 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 65, 16, 128}, + {640, 1, 1, 128}, + {0, 0}, + {0, 0}, + {1, 1}, + {1, 1}, + 1 + }); + return problem_shapes; +} + +// Specialization for 3D wgrad problems +template<> +std::vector> inline +get_conv_problem_vector<3, cutlass::conv::Operator::kWgrad>() { + using ProblemShape = cutlass::conv::ConvProblemShape; + std::vector problem_shapes; + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 1, 8, 8, 64}, // ndhwc + {64, 1, 1, 1, 64}, // ktrsc + {0, 0, 0}, // padding lower (pad_d, pad_h, pad_w) + {0, 0, 0}, // padding upper (pad_d, pad_h, pad_w) + {1, 1, 1}, // stride (stride_d, stride_h, stride_w) + {1, 1, 1}, // dilation (dilation_d, dilation_h, dilation_w) + 1 // group + }); + // Filter 3x3x3 + no padding + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 3, 5, 8, 32}, + {96, 3, 3, 3, 32}, + {0, 0, 0}, + {0, 0, 0}, + {1, 1, 1}, + {1, 1, 1}, + 1 + }); + // Filter 3x4x5 + asymmetric padding 102/010 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 3, 5, 8, 32}, + {96, 3, 4, 5, 32}, + {1, 0, 1}, + {0, 2, 0}, + {1, 1, 1}, + {1, 1, 1}, + 1 + }); + // Filter 3x4x5 + asymmetric padding 102/010, w/ stride + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 16, 10, 16, 32}, + {96, 3, 4, 5, 32}, + {1, 0, 1}, + {0, 2, 0}, + {2, 2, 3}, + {1, 1, 1}, + 1 + }); + // Filter 3x4x5 + asymmetric padding 102/010, w/ dilation + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 16, 10, 16, 32}, + {96, 3, 4, 5, 32}, + {1, 0, 1}, + {0, 2, 0}, + {1, 1, 1}, + {2, 2, 3}, + 1 + }); + // To test streamk, equals to gemm-MxNxK size 128x640x2048 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 1, 64, 16, 128}, + {640, 1, 1, 1, 128}, + {0, 0, 0}, + {0, 0, 0}, + {1, 1, 1}, + {1, 1, 1}, + 1 + }); + // To test streamk, equals to gemm-MxNxK size 128x640x2080 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 1, 65, 16, 128}, + {640, 1, 1, 1, 128}, + {0, 0, 0}, + {0, 0, 0}, + {1, 1, 1}, + {1, 1, 1}, + 1 + }); + return problem_shapes; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// +// Grouped Wgrad +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Get problem size vectors for group conv problems +template +std::vector> +inline +get_grouped_conv_problem_vector(int GroupsPerTile); + +// Specialization for 3D wgrad problems +template<> +std::vector> inline +get_grouped_conv_problem_vector<3, cutlass::conv::Operator::kWgrad>(int GroupsPerTile) { + using ProblemShape = cutlass::conv::ConvProblemShape; + std::vector problem_shapes; + + if (GroupsPerTile == 1) { + // channel_per_group == 64 + 
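+    // Editorial note (added commentary, not part of the upstream CUTLASS source): in these
+    // grouped-wgrad shapes the activation channel count equals groups * channels_per_group
+    // while the filter's innermost extent is channels_per_group. For GroupsPerTile == 1,
+    // the {1, 1, 16, 16, 2048} activation with 32 groups gives 2048 / 32 = 64 channels per
+    // group, matching the trailing 64 of the {2048, 1, 3, 3, 64} filter below; the later
+    // branches halve this to 32, 16, and 8 as GroupsPerTile doubles.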
problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 1, 16, 16, 2048}, // ndhwc + {2048, 1, 3, 3, 64}, // ktrsc + {0, 1, 1}, // padding lower (pad_d, pad_h, pad_w) + {0, 1, 1}, // padding upper (pad_d, pad_h, pad_w) + {1, 1, 1}, // stride (stride_d, stride_h, stride_w) + {1, 1, 1}, // dilation (dilation_d, dilation_h, dilation_w) + 32 // groups + }); + } + else if (GroupsPerTile == 2) { + // channel_per_group == 32 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 1, 16, 16, 1024}, // ndhwc + {1024, 1, 3, 3, 32}, // ktrsc + {0, 1, 1}, // padding lower (pad_d, pad_h, pad_w) + {0, 1, 1}, // padding upper (pad_d, pad_h, pad_w) + {1, 1, 1}, // stride (stride_d, stride_h, stride_w) + {1, 1, 1}, // dilation (dilation_d, dilation_h, dilation_w) + 32 // groups + }); + } + else if (GroupsPerTile == 4) { + // channel_per_group == 16 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 1, 16, 16, 512}, // ndhwc + {512, 1, 3, 3, 16}, // ktrsc + {0, 1, 1}, // padding lower (pad_d, pad_h, pad_w) + {0, 1, 1}, // padding upper (pad_d, pad_h, pad_w) + {1, 1, 1}, // stride (stride_d, stride_h, stride_w) + {1, 1, 1}, // dilation (dilation_d, dilation_h, dilation_w) + 32 // groups + }); + } + else if (GroupsPerTile == 8) { + // channel_per_group == 8 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 1, 16, 16, 256}, // ndhwc + {256, 1, 3, 3, 8}, // ktrsc + {0, 1, 1}, // padding lower (pad_d, pad_h, pad_w) + {0, 1, 1}, // padding upper (pad_d, pad_h, pad_w) + {1, 1, 1}, // stride (stride_d, stride_h, stride_w) + {1, 1, 1}, // dilation (dilation_d, dilation_h, dilation_w) + 32 // groups + }); + } + return problem_shapes; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// +// Unit Stride Dgrad +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Specialization for 1D dgrad problems +template<> +std::vector> inline +get_conv_problem_vector<1, cutlass::conv::Operator::kDgrad, false>() { + using ProblemShape = cutlass::conv::ConvProblemShape; + std::vector problem_shapes; + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 8, 64}, // nqk + {64, 1, 64}, // ksc + {0}, // padding lower (pad_w) + {0}, // padding upper (pad_w) + {1}, // stride (stride_w) + {1}, // dilation (dilation_w) + 1 // group + }); + // non-packed input strides. + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 8, 64}, // nqk + {800, 80, 1}, // stride (nqk) + {64, 1, 64}, // ksc + {64, 64, 1}, // stride (ksc) + {0}, // padding lower (pad_w) + {0}, // padding upper (pad_w) + {1}, // stride (stride_w) + {1}, // dilation (dilation_w) + 1 // group + }); + // non-packed output strides. 
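+  // Editorial note (added commentary, not part of the upstream CUTLASS source): for dgrad
+  // the tensor roles are flipped relative to fprop -- the first extent listed is the output
+  // gradient (nqk / npqk / nzpqk), the second is the filter (ksc / krsc / ktrsc), and the
+  // result written by the kernel is the activation gradient (nwc / nhwc / ndhwc), which is
+  // why the optional third stride in the non-packed entries below is annotated with the
+  // activation layout.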
+ problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 8, 64}, // nqk + {512, 64, 1}, // stride (nqk) + {64, 1, 64}, // ksc + {64, 64, 1}, // stride (ksc) + {800, 80, 1}, // stride (nwc) + {0}, // padding lower (pad_w) + {0}, // padding upper (pad_w) + {1}, // stride (stride_w) + {1}, // dilation (dilation_w) + 1 // group + }); + // Filter-K = 16 for predication + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 8, 16}, + {64, 1, 16}, + {0}, + {0}, + {1}, + {1}, + 1 + }); + // N = 2 and K = 128 for a larger grid + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 96}, + {64, 1, 96}, + {0}, + {0}, + {1}, + {1}, + 1 + }); + // N = 7 and K = 256 for a even larger grid + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {7, 8, 256}, + {64, 1, 256}, + {0}, + {0}, + {1}, + {1}, + 1 + }); + // 3 filter, no padding + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 256}, + {64, 3, 256}, + {0}, + {0}, + {1}, + {1}, + 1 + }); + // 3 filter, symmetric padding with k % cta_k !=0 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 256}, + {32, 3, 256}, + {1}, + {1}, + {1}, + {1}, + 1 + }); + // 4 filter, asymmetric padding + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 256}, + {64, 4, 256}, + {0}, + {1}, + {1}, + {1}, + 1 + }); + // 3 filter, asymmetric padding and dilation of 2 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 16, 64}, + {256, 3, 64}, + {0}, + {1}, + {1}, + {2}, + 1 + }); + return problem_shapes; +} + +// Specialization for 2D dgrad problems +template<> +std::vector> inline +get_conv_problem_vector<2, cutlass::conv::Operator::kDgrad, false>() { + using ProblemShape = cutlass::conv::ConvProblemShape; + std::vector problem_shapes; + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 8, 8, 64}, // npqk + {64, 1, 1, 64}, // krsc + {0, 0}, // padding lower (pad_h, pad_w) + {0, 0}, // padding upper (pad_h, pad_w) + {1, 1}, // stride (stride_h, stride_w) + {1, 1}, // dilation (dilation_h, dilation_w) + 1 // group + }); + // non-packed input strides. + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 8, 8, 64}, // npqk + {8000, 800, 80, 1}, // stride (npqk) + {64, 1, 1, 64}, // krsc + {64, 64, 64, 1}, // stride (krsc) + {0, 0}, // padding lower (pad_h, pad_w) + {0, 0}, // padding upper (pad_h, pad_w) + {1, 1}, // stride (stride_h, stride_w) + {1, 1}, // dilation (dilation_h, dilation_w) + 1 // group + }); + // non-packed output strides. 
+ problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 8, 8, 64}, // npqk + {4096, 512, 64, 1}, // stride (npqk) + {64, 1, 1, 64}, // krsc + {64, 64, 64, 1}, // stride (krsc) + {8000, 800, 80, 1}, // stride (nhwc) + {0, 0}, // padding lower (pad_h, pad_w) + {0, 0}, // padding upper (pad_h, pad_w) + {1, 1}, // stride (stride_h, stride_w) + {1, 1}, // dilation (dilation_h, dilation_w) + 1 // group + }); + // Filter-K = 16 for predication + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 8, 8, 16}, + {64, 1, 1, 16}, + {0, 0}, + {0, 0}, + {1, 1}, + {1, 1}, + 1 + }); + // N = 2 and K = 128 for a larger grid + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 8, 96}, + {64, 1, 1, 96}, + {0, 0}, + {0, 0}, + {1, 1}, + {1, 1}, + 1 + }); + // N = 7 and K = 256 for a even larger grid + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {7, 8, 8, 256}, + {64, 1, 1, 256}, + {0, 0}, + {0, 0}, + {1, 1}, + {1, 1}, + 1 + }); + // 3x3 filter, no padding + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 8, 256}, + {64, 3, 3, 256}, + {0, 0}, + {0, 0}, + {1, 1}, + {1, 1}, + 1 + }); + // 3x3 filter, symmetric padding with k % cta_k !=0 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 8, 256}, + {32, 3, 3, 256}, + {1, 1}, + {1, 1}, + {1, 1}, + {1, 1}, + 1 + }); + // 2x5 filter, asymmetric padding 1,0/1,0 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 8, 8, 256}, + {64, 2, 5, 256}, + {1, 1}, + {0, 0}, + {1, 1}, + {1, 1}, + 1 + }); + // 2x5 filter, asymmetric padding 1,0/1,0, w/ dilation + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 16, 16, 64}, + {256, 2, 5, 64}, + {1, 1}, + {0, 0}, + {1, 1}, + {2, 3}, + 1 + }); + return problem_shapes; +} + +// Specialization for 3D dgrad problems +template<> +std::vector> inline +get_conv_problem_vector<3, cutlass::conv::Operator::kDgrad, false>() { + using ProblemShape = cutlass::conv::ConvProblemShape; + std::vector problem_shapes; + // Filter-K = 16 for predication + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 1, 8, 8, 16}, + {64, 1, 1, 1, 16}, + {0, 0, 0}, + {0, 0, 0}, + {1, 1, 1}, + {1, 1, 1}, + 1 + }); + // non-packed input output strides. 
+ problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 1, 8, 8, 64}, // nzpqk + {8000, 8000, 800, 80, 1}, // stride (nzpqk) + {64, 1, 1, 1, 64}, // ktrsc + {64, 64, 64, 64, 1}, // stride (ktrsc) + {8000, 8000, 800, 80, 1}, // stride (ndhwc) + {0, 0, 0}, // padding lower (pad_d, pad_h, pad_w) + {0, 0, 0}, // padding upper (pad_d, pad_h, pad_w) + {1, 1, 1}, // stride (stride_d, stride_h, stride_w) + {1, 1, 1}, // dilation (dilation_d, dilation_h, dilation_w) + 1 // group + }); + // N = 7 and K = 256 for a larger grid + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 1, 8, 8, 96}, + {64, 1, 1, 1, 96}, + {0, 0, 0}, + {0, 0, 0}, + {1, 1, 1}, + {1, 1, 1}, + 1 + }); + // Filter 3x4x5 + symmetric padding 111 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 3, 5, 8, 96}, + {64, 3, 4, 5, 96}, + {1, 1, 1}, + {1, 1, 1}, + {1, 1, 1}, + {1, 1, 1}, + 1 + }); + // Filter 3x4x5 + asymmetric padding 102/010 + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 3, 5, 8, 96}, + {64, 3, 4, 5, 96}, + {1, 0, 1}, + {0, 2, 0}, + {1, 1, 1}, + {1, 1, 1}, + 1 + }); + // Filter 3x4x5 + asymmetric padding 102/010, w/ dilation + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 16, 10, 16, 64}, + {64, 3, 4, 5, 96}, + {1, 0, 1}, + {0, 2, 0}, + {1, 1, 1}, + {2, 2, 3}, + 1 + }); + return problem_shapes; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// +// Strided Dgrad +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Specialization for 1D dgrad problems +template<> +std::vector> inline +get_conv_problem_vector<1, cutlass::conv::Operator::kDgrad, true>() { + using ProblemShape = cutlass::conv::ConvProblemShape; + std::vector problem_shapes; + // Test TMA truncation + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 512, 64}, // nqk + {64, 1, 64}, // ksc + {0}, // padding lower (pad_w) + {0}, // padding upper (pad_w) + {2}, // stride (stride_w) + {1}, // dilation (dilation_w) + 1 // group + }); + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 1024, 64}, // nqk + {64, 1, 64}, // ksc + {0}, // padding lower (pad_w) + {0}, // padding upper (pad_w) + {4}, // stride (stride_w) + {1}, // dilation (dilation_w) + 1 // group + }); + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {1, 2048, 64}, // nqk + {64, 1, 64}, // ksc + {0}, // padding lower (pad_w) + {0}, // padding upper (pad_w) + {8}, // stride (stride_w) + {1}, // dilation (dilation_w) + 1 // group + }); + // non-packed input/output strides. + // stride divides dilation + // asymmetric padding + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {3, 8, 64}, // nqk + {800, 80, 1}, // stride (nqk) + {64, 3, 64}, // ksc + {64, 64, 1}, // stride (ksc) + {800, 80, 1}, // stride (nwc) + {0}, // padding lower (pad_w) + {1}, // padding upper (pad_w) + {2}, // stride (stride_w) + {4}, // dilation (dilation_w) + 1 // group + }); + // non-packed input/output strides. 
+ // dilation divides stride + // asymmetric padding + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {3, 8, 64}, // nqk + {800, 80, 1}, // stride (nqk) + {64, 3, 64}, // ksc + {64, 64, 1}, // stride (ksc) + {800, 80, 1}, // stride (nwc) + {1}, // padding lower (pad_w) + {0}, // padding upper (pad_w) + {4}, // stride (stride_w) + {2}, // dilation (dilation_w) + 1 // group + }); + // non-packed input/output strides. + // stride dilation dont divide + // asymmetric padding + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {3, 8, 64}, // nqk + {800, 80, 1}, // stride (nqk) + {64, 3, 64}, // ksc + {64, 64, 1}, // stride (ksc) + {800, 80, 1}, // stride (nwc) + {1}, // padding lower (pad_w) + {2}, // padding upper (pad_w) + {2}, // stride (stride_w) + {3}, // dilation (dilation_w) + 1 // group + }); + return problem_shapes; +} + +// Specialization for 2D dgrad problems +template<> +std::vector> inline +get_conv_problem_vector<2, cutlass::conv::Operator::kDgrad, true>() { + using ProblemShape = cutlass::conv::ConvProblemShape; + std::vector problem_shapes; + // 2x5 filter, asymmetric padding 1,0/1,0, w/ dilation + // mode 0 stride divides dilation + // mode 1 dilation divides stride + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {3, 16, 16, 64}, + {256, 2, 5, 64}, + {1, 0}, + {0, 1}, + {2, 4}, + {4, 2}, + 1 + }); + // 2x5 filter, asymmetric padding 1,0/1,0, w/ dilation + // mode 0 dilation divides stride + // mode 1 stride divides dilation + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {3, 16, 16, 64}, + {256, 2, 5, 64}, + {1, 0}, + {0, 1}, + {4, 2}, + {2, 4}, + 1 + }); + // 2x5 filter, asymmetric padding 1,0/1,0, w/ dilation + // stride dilation dont divide + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {3, 16, 16, 64}, + {256, 2, 5, 64}, + {1, 0}, + {0, 1}, + {3, 2}, + {2, 3}, + 1 + }); + return problem_shapes; +} + +// Specialization for 3D dgrad problems +template<> +std::vector> inline +get_conv_problem_vector<3, cutlass::conv::Operator::kDgrad, true>() { + using ProblemShape = cutlass::conv::ConvProblemShape; + std::vector problem_shapes; + // Filter 3x4x5 + asymmetric padding 102/010, w/ dilation + problem_shapes.push_back({ + cutlass::conv::Mode::kCrossCorrelation, + {2, 16, 10, 16, 64}, + {64, 3, 4, 5, 96}, + {1, 0, 1}, + {0, 2, 0}, + {2, 1, 2}, + {4, 2, 3}, + 1 + }); + return problem_shapes; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::test diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/conv/device_3x/testbed_conv.hpp b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/conv/device_3x/testbed_conv.hpp new file mode 100644 index 0000000000000000000000000000000000000000..99ba9c407cec38e919812fedeee38ba75d9129f7 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/conv/device_3x/testbed_conv.hpp @@ -0,0 +1,768 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Implicit GEMM testbed for 3.x API +*/ +#pragma once + +#include "cutlass/cutlass.h" +#include "../../common/cutlass_unit_test.h" + +#include "cute/tensor.hpp" +#include "cutlass/kernel_hardware_info.hpp" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/convnd_problem_shape.hpp" +#include "../test/unit/gemm/device/gemm_testbed_3x.hpp" + +#include "thrust/universal_vector.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/reference/host/conv.hpp" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/device/tensor_fill.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "conv_problem_sizes.hpp" +#include "../cache_testbed_output.h" + +#include + +#include "cute/layout.hpp" +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace test::conv::device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Initializes a flat device buffer +template +static void +initialize_values( + thrust::universal_vector& dst_ptr, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + if (cutlass::Distribution::Uniform == dist_kind) { + int scope; + int bits = cutlass::sizeof_bits::value; + + if (bits <= 8) { + scope = 2; + } + else if (bits == 16) { + scope = 4; + } + else { + scope = 8; + } + cutlass::reference::host::BlockFillRandomUniform( + dst_ptr.data().get(), dst_ptr.size(), seed, scope, -scope, 0); + } + else if (cutlass::Distribution::Identity == dist_kind) { + cutlass::reference::host::BlockFillRandomUniform( + dst_ptr.data().get(), dst_ptr.size(), seed, 0, 0, 0); + } + else if (cutlass::Distribution::Gaussian == dist_kind) { + 
cutlass::reference::host::BlockFillRandomGaussian(dst_ptr.data().get(), dst_ptr.size(), seed, 0, 0.5); + } + else if (cutlass::Distribution::Sequential == dist_kind) { + cutlass::reference::host::BlockFillSequential(dst_ptr.data().get(), dst_ptr.size()); + } + else { + std::cerr << "Invalid distribution kind!\n."; + exit(1); + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////// +// utils for sparse or dense conv parameters + +template +struct DenseConvParams { + // Default Kernel data types + using ElementA = typename Conv::ConvKernel::ElementA; + using ElementB = typename Conv::ConvKernel::ElementB; + + static constexpr cutlass::conv::Operator ConvOp = Conv::DispatchPolicy::ConvOp; + static constexpr int NumSpatialDimensions = Conv::NumSpatialDimensions; + using ProblemShape = cutlass::conv::ConvProblemShape; + + // get the default arguments without sparse data + auto get_mainloop_arguments( + [[maybe_unused]] ProblemShape const& problem_shape, + thrust::universal_vector& tensor_A, + thrust::universal_vector& tensor_B + ) { + auto args = typename Conv::ConvKernel::MainloopArguments { + tensor_A.data().get(), + tensor_B.data().get(), + }; + return args; + } +}; + +template +struct SparseConvParams { +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +template +struct ConvTestbed { + // Kernel data types + using ElementA = typename Conv::ConvKernel::ElementA; + using ElementB = typename Conv::ConvKernel::ElementB; + using ElementC = cute::conditional_t, + typename Conv::ConvKernel::ElementD, typename Conv::ConvKernel::ElementC>; + using ElementD = typename Conv::ConvKernel::ElementD; + using ElementAccumulator = typename Conv::ConvKernel::ElementAccumulator; + + // ConvTest for sparse kernel + static constexpr bool isSparseEnabled = isSparseEnabled_; + using ConvParams = cute::conditional_t, DenseConvParams>; + ConvParams params; + + // + // FusionOperation derived types/queries + // + using FusionOp = typename Conv::EpilogueOutputOp; + + // fusion types are potentially void if the fusion is not supported + // helper so we don't try to construct HostTensor with void type + template + using non_void_t = cute::conditional_t, U, T>; + using ElementScalar = typename FusionOp::ElementScalar; + using ElementCompute = typename FusionOp::ElementCompute; + using BiasType = typename cutlass::epilogue::collective::detail::IsThreadEpilogueOpWithBias::type; + using ElementBias = non_void_t; + using ActivationType = non_void_t::type, + cutlass::epilogue::thread::Identity>; + static constexpr bool IsActivationEnabled = cutlass::epilogue::collective::detail::IsThreadEpilogueOpWithActivation::value; + using ActivationFunctor = cute::conditional_t>; + + static constexpr bool IsBiasEnabled = cutlass::epilogue::collective::detail::IsThreadEpilogueOpWithBias::value && + !cute::is_same_v; + static constexpr bool IsPerChannelScaleEnabled = cutlass::epilogue::collective::detail::IsThreadEpilogueOpWithPerChannelScaled::value; + + static constexpr bool DisableSource = cute::is_void_v; + + static constexpr bool IsResidualEnabled = cutlass::epilogue::collective::detail::IsThreadEpilogueOpWithResidualAdd::value; + + using StrideC = typename Conv::ConvKernel::StrideC; + using StrideD = typename Conv::ConvKernel::StrideD; + using ThreadEpilogueOp = typename Conv::ConvKernel::CollectiveEpilogue::ThreadEpilogueOp; + + static constexpr cutlass::conv::Operator ConvOp = Conv::DispatchPolicy::ConvOp; + static constexpr 
int NumSpatialDimensions = Conv::NumSpatialDimensions; + using ProblemShape = cutlass::conv::ConvProblemShape; + using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::RasterOrderOptions; + using DecompositionMode = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode; + using MaxSwizzleSize = typename gemm::device::detail::MaxSwizzleSize; + using Splits = typename gemm::device::detail::Splits; + + using Schedule = typename Conv::DispatchPolicy::Schedule; + /// Initialization + cutlass::Distribution::Kind init_A = cutlass::Distribution::Uniform; + cutlass::Distribution::Kind init_B = cutlass::Distribution::Uniform; + cutlass::Distribution::Kind init_C = cutlass::Distribution::Uniform; + cutlass::Distribution::Kind init_bias = cutlass::Distribution::Uniform; + cutlass::Distribution::Kind init_disable = cutlass::Distribution::Identity; // all zeros + uint64_t seed = 6090; + float epsilon = 0.0f; + int split_p_slices = 1; + thrust::universal_vector tensor_A; + thrust::universal_vector tensor_B; + thrust::universal_vector tensor_C; + thrust::universal_vector tensor_D_computed; + thrust::universal_vector tensor_D_reference; + thrust::universal_vector tensor_bias; + thrust::universal_vector tensor_alpha; + thrust::universal_vector tensor_beta; + + // Return true on success, else false + bool initialize(ProblemShape const& problem_shape, uint64_t seed = 6090) { + tensor_A.resize(sizeof(ElementA) * problem_shape.size_A()); + tensor_B.resize(sizeof(ElementB) * problem_shape.size_B()); + tensor_C.resize(sizeof(ElementC) * problem_shape.size_C()); + tensor_D_computed.resize(sizeof(ElementD) * problem_shape.size_C()); + tensor_D_reference.resize(sizeof(ElementD) * problem_shape.size_C()); + tensor_bias.resize(sizeof(ElementBias) * cute::size(cute::get<0>(problem_shape.get_shape_B()))); + if constexpr (IsPerChannelScaleEnabled) { + tensor_alpha.resize(sizeof(ElementScalar) * cute::size(cute::get<0>(problem_shape.get_shape_B()))); + tensor_beta.resize(sizeof(ElementScalar) * cute::size(cute::get<0>(problem_shape.get_shape_B()))); + } + initialize_values(tensor_A, init_A, seed); + initialize_values(tensor_B, init_B, seed * 11); + initialize_values(tensor_C, init_C, seed * 17); + initialize_values(tensor_bias, init_bias, seed * 19); + if constexpr (IsPerChannelScaleEnabled) { + initialize_values(tensor_alpha, init_bias, seed * 23); + if constexpr (DisableSource) { + initialize_values(tensor_beta, init_disable, seed * 27); + } + else { + initialize_values(tensor_beta, init_bias, seed * 27); + } + } + + bool flag = true; + if constexpr (isSparseEnabled) { + flag &= params.initialize(problem_shape, tensor_B, static_cast(seed + 2023)); + } + + return flag; + } + + // Determine SMEM requirements and waive if not satisfied + bool sufficient() const { + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + int max_smem_size; + result = cudaDeviceGetAttribute(&max_smem_size, cudaDevAttrMaxSharedMemoryPerBlockOptin, device_idx); + if (result != cudaSuccess) { + throw std::runtime_error("cudaDeviceGetAttribute() failed"); + } + + return max_smem_size >= Conv::ConvKernel::SharedStorageSize; + } + + auto transform_shape_and_stride_with_groups(ProblemShape const& problem_shape) { + using TensorExtent = cute::array; + using TensorStride = cute::array; + + TensorExtent shape_a_g{}; + TensorExtent shape_b_g{}; + 
TensorExtent shape_c_g{}; + TensorStride stride_a_g{}; + TensorStride stride_b_g{}; + TensorStride stride_c_g{}; + + auto shape_a = cute::reverse(problem_shape.shape_A); + auto shape_b = cute::reverse(problem_shape.shape_B); + auto shape_c = cute::reverse(problem_shape.shape_C); + auto stride_a = cute::reverse(problem_shape.stride_A); + auto stride_b = cute::reverse(problem_shape.stride_B); + auto stride_c = cute::reverse(problem_shape.stride_C); + + int32_t G = problem_shape.groups; + + if constexpr (ConvOp == cutlass::conv::Operator::kFprop || + ConvOp == cutlass::conv::Operator::kDgrad) { + // shape_a_g = (c,w,h,d,n,g) or (k,q,p,z,n,g) + // shape_b_g = (c,s,r,k,t,g) + // shape_c_g = (k,q,p,z,n,g) or (c,w,h,d,n,g) + shape_a_g = cute::to_array(tuple_cat( + cute::make_shape(cute::size<0>(shape_a) / G), + cute::take<1,NumSpatialDimensions + 2>(shape_a), + cute::make_shape(G))); + shape_b_g = cute::to_array(tuple_cat( + cute::take<0,NumSpatialDimensions + 1>(shape_b), + cute::make_shape(cute::size(shape_b) / G, G))); + shape_c_g = cute::to_array(tuple_cat( + cute::make_shape(cute::size<0>(shape_c) / G), + cute::take<1,NumSpatialDimensions + 2>(shape_c), + cute::make_shape(G))); + + stride_a_g = cute::to_array(append(stride_a, cute::size<0>(shape_a) / G)); + stride_b_g = cute::to_array(append(stride_b, + cute::size(stride_b) * cute::size(shape_b) / G)); + stride_c_g = cute::to_array(append(stride_c, cute::size<0>(shape_c) / G)); + } + else if constexpr (ConvOp == cutlass::conv::Operator::kWgrad) { + // shape_a_g = (k,q,p,z,n,g) + // shape_b_g = (c,w,h,d,n,g) + // shape_c_g = (c,s,r,k,t,g) + shape_a_g = cute::to_array(tuple_cat( + cute::make_shape(cute::size<0>(shape_a) / G), + cute::take<1,NumSpatialDimensions + 2>(shape_a), + cute::make_shape(G))); + shape_b_g = cute::to_array(tuple_cat( + cute::make_shape(cute::size<0>(shape_b) / G), + cute::take<1,NumSpatialDimensions + 2>(shape_b), + cute::make_shape(G))); + shape_c_g = cute::to_array(tuple_cat( + cute::take<0,NumSpatialDimensions + 1>(shape_c), + cute::make_shape(cute::size(shape_c) / G, G))); + + stride_a_g = cute::to_array(append(stride_a, cute::size<0>(shape_a) / G)); + stride_b_g = cute::to_array(append(stride_b, cute::size<0>(shape_b) / G)); + stride_c_g = cute::to_array(append(stride_c, + cute::size(stride_c) * cute::size(shape_c) / G)); + } + + return make_tuple(shape_a_g, shape_b_g, shape_c_g, + stride_a_g, stride_b_g, stride_c_g); + } + + // Executes one test + bool run( + ProblemShape const& problem_shape, + ElementScalar alpha = ElementScalar(1), + ElementScalar beta = ElementScalar(0), + dim3 cluster_shape = dim3(0, 0, 0), + dim3 cluster_shape_fallback = dim3(0, 0, 0), + RasterOrderOptions raster_order = RasterOrderOptions::Heuristic, + MaxSwizzleSize max_swizzle = MaxSwizzleSize{}, + Splits splits = Splits{}, + DecompositionMode decomposition_mode = DecompositionMode::Heuristic + ) { + + // Waive test if insufficient CUDA device + if (!sufficient()) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device.\n"; + } + return true; + } + + bool ret = initialize(problem_shape); + + if (!ret) { + std::cerr << "initialize failed for the given problem_shape: \n"; + return false; + } + + cutlass::KernelHardwareInfo hw_info; + cudaGetDevice(&hw_info.device_id); + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + + hw_info.cluster_shape = cluster_shape; + hw_info.cluster_shape_fallback = cluster_shape_fallback; + + // configure the 
operator + Conv conv_op; + auto stride_C = StrideC{}; + auto stride_D = StrideD{}; + if constexpr (ConvOp == cutlass::conv::Operator::kWgrad) { + stride_C = cutlass::make_cute_packed_stride( + StrideC{}, problem_shape.shape_C, problem_shape.stride_C, ConvOp); + stride_D = cutlass::make_cute_packed_stride( + StrideD{}, problem_shape.shape_C, problem_shape.stride_C, ConvOp); + } + // Need to support non-packed output strides for fprop and dgrad kernel. + else { + cute::for_each(cute::make_seq(StrideC{})>{}, [&](auto i) { + cute::get<0, i>(stride_C) = problem_shape.stride_C[ProblemShape::RankT-2-i]; + }); + cute::for_each(cute::make_seq(StrideD{})>{}, [&](auto i) { + cute::get<0, i>(stride_D) = problem_shape.stride_C[ProblemShape::RankT-2-i]; + }); + } + + using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::RasterOrderOptions; + using DecompositionMode = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode; + + typename Conv::ConvKernel::TileScheduler::Arguments scheduler_args{}; + if constexpr (cute::is_same_v) { + scheduler_args = { static_cast(splits), static_cast(max_swizzle), raster_order, decomposition_mode }; + } + + auto mainloop_args = params.get_mainloop_arguments(problem_shape, tensor_A, tensor_B); + + auto epilogue_args = typename Conv::ConvKernel::EpilogueArguments { + {}, + tensor_C.data().get(), + stride_C, + tensor_D_computed.data().get(), + stride_D, + }; + + auto args = typename Conv::Arguments { + problem_shape, + mainloop_args, // MainloopArguments + epilogue_args, // EpilogueArguments + hw_info, + scheduler_args + }; + + auto &fusion_args = args.epilogue.thread; + + fusion_args.alpha = alpha; + fusion_args.beta = beta; + + if constexpr (IsPerChannelScaleEnabled) { + fusion_args.alpha_ptr = tensor_alpha.data().get(); + fusion_args.beta_ptr = tensor_beta.data().get(); + } + + if constexpr (IsBiasEnabled) { + fusion_args.bias_ptr = tensor_bias.data().get(); + } + + // Clamp bound + if constexpr (cute::is_same_v>) { + fusion_args.activation.lower_bound = CUTLASS_STL_NAMESPACE::numeric_limits::lowest(); + fusion_args.activation.upper_bound = CUTLASS_STL_NAMESPACE::numeric_limits::max(); + } + + // Scale + if constexpr (cute::is_same_v> || + cute::is_same_v> || + cute::is_same_v> || + cute::is_same_v> ) { + fusion_args.activation.scale = ElementCompute{1}; + } + + // LeakyRelu + if constexpr (cute::is_same_v> ) { + fusion_args.activation.leaky_alpha = ElementCompute{0}; + } + + cutlass::Status status = cutlass::Status::kInvalid; + + status = conv_op.can_implement(args); + EXPECT_EQ(conv_op.can_implement(args), cutlass::Status::kSuccess); + if (status != cutlass::Status::kSuccess) { + std::cerr << "can_implement failed for the given problem_shape: \n"; + print(problem_shape); + return false; + } + + // find workspace requirement for parallel split-k reduction + size_t workspace_size = Conv::get_workspace_size(args); + thrust::universal_vector workspace(workspace_size); + + status = conv_op.initialize(args, workspace.data().get()); + if (status != cutlass::Status::kSuccess) { + cudaError_t error = cudaGetLastError(); + std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n"; + return true; + } + + // run conv3d operator + status = conv_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + if (status != cutlass::Status::kSuccess) { + return false; + } + + bool passed = false; + cudaError_t result = cudaDeviceSynchronize(); + EXPECT_EQ(result, cudaSuccess) << 
" Kernel execution error: " + << cudaGetErrorString(result); + + // Create cute::Tensors using the logical rank-3 MNK multi-mode shapes the mainloop gives us + auto [shape_mA, shape_mB, shape_mC, stride_mA, stride_mB, stride_mC] = + transform_shape_and_stride_with_groups(problem_shape); + auto shape_mBias = cute::make_shape(cute::size(cute::get<0>(problem_shape.get_shape_B()))); + + auto mA = make_tensor(tensor_A.data().get(), make_layout(shape_mA, stride_mA)); + auto mB = make_tensor(tensor_B.data().get(), make_layout(shape_mB, stride_mB)); + auto mC = make_tensor(tensor_C.data().get(), make_layout(shape_mC, stride_mC)); + auto mD_ref = make_tensor(tensor_D_reference.data().get(), make_layout(shape_mC, stride_mC)); + auto mD_computed = make_tensor(tensor_D_computed.data().get(), make_layout(shape_mC, stride_mC)); + auto mBias = make_tensor(tensor_bias.data().get(), make_layout(shape_mBias)); + auto mAlpha = make_tensor(tensor_alpha.data().get(), make_layout(shape_mBias)); + auto mBeta = make_tensor(tensor_beta.data().get(), make_layout(shape_mBias)); + + cutlass::reference::host::ConvEpilogueFusionParams< + ElementAccumulator, + ElementScalar, + ElementCompute, + ElementC, + ElementD, + IsResidualEnabled, + decltype(mAlpha), + decltype(mBeta), + decltype(mBias), + ActivationFunctor> + epilogue_fusion_params{}; + + epilogue_fusion_params.alpha = alpha; + epilogue_fusion_params.beta = beta; + + if constexpr (IsPerChannelScaleEnabled) { + epilogue_fusion_params.tensor_alpha = mAlpha; + epilogue_fusion_params.tensor_beta = mBeta; + } + + if constexpr (IsBiasEnabled) { + epilogue_fusion_params.tensor_bias = mBias; + } + + auto padding = cute::reverse(problem_shape.lower_padding); + auto tstride = cute::reverse(problem_shape.traversal_stride); + auto dilation = cute::reverse(problem_shape.dilation); + + cutlass::reference::host::ConvReferenceImpl< + ConvOp, + NumSpatialDimensions, + decltype(mA), + decltype(mB), + decltype(mC), + decltype(mD_ref), + decltype(padding), + decltype(tstride), + decltype(dilation), + decltype(epilogue_fusion_params)> + reference_impl(mA, mB, mC, mD_ref, padding, tstride, dilation, epilogue_fusion_params); + + // + // Reference check - support caching results + // + + CachedTestKey cached_test_key = CreateCachedConvNd3xTestKey< + ProblemShape, + ElementA, + ElementB, + ElementC, + ElementD + >( + ConvOp, + problem_shape, + alpha, + beta, + tensor_A, + tensor_B, + tensor_C + ); + + // + // Look for the cached key + // + + bool cached_result_loaded = false; + CachedTestResult cached_test_result; + + std::string convnd_result_cache_name = + std::string("cached_results_") + CUTLASS_TARGET_NAME + ".txt"; + + #if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) + CachedTestResultListing cached_results(convnd_result_cache_name); + + auto cached = cached_results.find(cached_test_key); + + cached_result_loaded = cached.first; + if (cached_result_loaded) { + cached_test_result = cached.second; + } + #endif + + if (!cached_result_loaded) { + // Compute reference + reference_impl.compute_reference(); + + #if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) + cached_test_result.D = TensorHash(tensor_D_reference); + CachedTestResultListing cached_results(convnd_result_cache_name); + + cached_results.append(cached_test_key, cached_test_result); + cached_results.write(convnd_result_cache_name); + #endif + } // if (!cached_result_loaded) + + #if (CUTLASS_TEST_ENABLE_CACHED_RESULTS) + uint32_t tensor_D_computed_hash = TensorHash(tensor_D_computed); + passed = (tensor_D_computed_hash == cached_test_result.D); 
+ // If hash fails, double check against reference implementation. + if(!passed) { + std::cerr << "Hash-based comparison unsuccessful for key:" << "\n" << cached_test_key + << ", comparing with reference implementation now.\n"; + if (cached_result_loaded) { + // Compute reference + reference_impl.compute_reference(); + } + // Validate kernel against reference + passed = compare_reference(mD_ref, mD_computed, mA, mB, mAlpha, mBeta, mBias, this->epsilon); + } + #else + // Validate kernel against reference + passed = compare_reference(mD_ref, mD_computed, mA, mB, mAlpha, mBeta, mBias, this->epsilon); + #endif + + EXPECT_TRUE(passed); + return passed; + } + + template< + class Engine, class Layout, + class EngineA, class LayoutA, + class EngineB, class LayoutB, + class EngineAlpha, class LayoutAlpha, + class EngineBeta, class LayoutBeta, + class EngineBias, class LayoutBias> + static constexpr bool + compare_reference( + cute::Tensor const& reference, + cute::Tensor const& computed, + cute::Tensor const& A, + cute::Tensor const& B, + cute::Tensor const& tensor_alpha, + cute::Tensor const& tensor_beta, + cute::Tensor const& tensor_bias, + float epsilon = 0.0f) { + if (size(reference) != size(computed)) { + return false; + } + + bool passed = true; + if (epsilon == 0.0f) { + // fast refcheck w/o epsilon + for (size_t i = 0; i < size_t(size(reference)); ++i) { + if (reference(i) != computed(i)) { + passed = false; + printf("[%llu] %f, %f\n", static_cast(i), + float(reference(i)), float(computed(i))); + break; + } + } + } else { + // refcheck with epsilon + for (size_t i = 0; i < size_t(size(reference)); ++i) { + auto ref = static_cast(reference(i)); + auto act = static_cast(computed(i)); + auto abs_error = std::abs(act - ref); + auto rel_error = abs_error / (std::max(std::abs(act), std::abs(ref)) + 0.00001f); + if (std::isnan(abs_error) || std::isnan(rel_error) || + std::min(abs_error, rel_error) > epsilon) { + passed = false; + printf("[%llu] %f, %f\n", static_cast(i), + float(reference(i)), float(computed(i))); + break; + } + } + } + #if CUTLASS_DEBUG_TRACE_LEVEL > 1 + if (not passed) { + cute::print("Reference:"); + cute::print_tensor(reference); + cute::print("\nComputed:"); + cute::print_tensor(computed); + cute::print("\n"); + + for (size_t i = 0; i < size_t(size(A)); ++i) { + printf("[%llu]: A = %f\n", static_cast(i), float(A(i))); + } + for (size_t i = 0; i < size_t(size(B)); ++i) { + printf("[%llu]: B = %f\n", static_cast(i), float(B(i))); + } + if constexpr (IsPerChannelScaleEnabled) { + for (size_t i = 0; i < size_t(size(tensor_alpha)); ++i) { + printf("[%llu]: alpha = %f\n", static_cast(i), + float(tensor_alpha(i))); + } + for (size_t i = 0; i < size_t(size(tensor_beta)); ++i) { + printf("[%llu]: beta = %f\n", static_cast(i), + float(tensor_beta(i))); + } + } + if constexpr (IsBiasEnabled) { + for (size_t i = 0; i < size_t(size(tensor_bias)); ++i) { + printf("[%llu]: bias = %f\n", static_cast(i), + float(tensor_bias(i))); + } + } + for (size_t i = 0; i < size_t(size(reference)); ++i) { + printf("[%llu]: ref = %f, computed = %f\n", static_cast(i), + float(reference(i)), float(computed(i))); + } + } + #endif + return passed; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +bool TestAllConv(double alpha = 1.0, double beta = 0.0, float epsilon = 0.0f, + dim3 cluster_shape = dim3(0, 0, 0), + dim3 cluster_shape_fallback = dim3(0, 0, 0) + ) { + using ElementScalar = typename Conv::EpilogueOutputOp::ElementScalar; + + 
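+  // Editorial note (added commentary, not part of the upstream CUTLASS source): a typical
+  // caller instantiates a device-level convolution type (for instance via
+  // cutlass::conv::device::ConvUniversalAdapter over a collective mainloop and epilogue --
+  // an assumed, illustrative choice) and then exercises every shape returned by
+  // get_conv_problem_vector with something like:
+  //
+  //   EXPECT_TRUE(test::conv::device::TestAllConv<Conv>(1.0, 0.5));
+  //
+  // Non-default beta, epsilon, and cluster shapes are simply forwarded to each run.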
bool passed = true; + ConvTestbed testbed; + testbed.epsilon = epsilon; + auto problem_vector = get_conv_problem_vector< + Conv::NumSpatialDimensions, Conv::DispatchPolicy::ConvOp, SupportStrides>(); + + using DecompositionMode = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode; + using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::RasterOrderOptions; + using MaxSwizzleSize = typename gemm::device::detail::MaxSwizzleSize; + using Splits = typename gemm::device::detail::Splits; + + std::vector decomposition_modes = {DecompositionMode::Heuristic}; + static constexpr bool UsesStreamKScheduler = cute::is_same_v; + if constexpr (UsesStreamKScheduler) { + decomposition_modes.push_back(DecompositionMode::DataParallel); + decomposition_modes.push_back(DecompositionMode::SplitK); + decomposition_modes.push_back(DecompositionMode::StreamK); + } + + for (auto conv_problem : problem_vector) { + #if CUTLASS_DEBUG_TRACE_LEVEL > 0 + print(conv_problem); + #endif + for (DecompositionMode decomp_mode : decomposition_modes) { + std::vector problem_splits = {Splits{1}}; + if constexpr (UsesStreamKScheduler) { + if (decomp_mode == DecompositionMode::SplitK) { + problem_splits.push_back(Splits{2}); + problem_splits.push_back(Splits{4}); + } + } + for (auto splits : problem_splits) { + + passed = testbed.run( + conv_problem, + cutlass::from_real(alpha), + cutlass::from_real(beta), + cluster_shape, + cluster_shape_fallback, + RasterOrderOptions::Heuristic, // raster_order + MaxSwizzleSize(1), + splits, + decomp_mode + ); + if (!passed) { + printf("Failed test for "); print(conv_problem); + return false; + } + } // splits + } // decomposition_mode + } + + return passed; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace test::conv::device + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/cute/ampere/tiled_cp_async_testbed.hpp b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/cute/ampere/tiled_cp_async_testbed.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ff170be142ff9d0d02cc684c2873c3ec014bd236 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/cute/ampere/tiled_cp_async_testbed.hpp @@ -0,0 +1,158 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#include "cutlass_unit_test.h" + +#include +#include +#include +#include +#include +#include + +#include +#include + +#include + +using namespace cute; + +template +struct SharedStorage +{ + cute::ArrayEngine> smem; +}; + +template +__global__ void +test_tiled_cp_async_device_cute(T const* g_in, T* g_out, + TiledCopy const tiled_copy, + GmemLayout gmem_layout, SmemLayout smem_layout) +{ + using namespace cute; + + extern __shared__ char shared_memory[]; + using SharedStorage = SharedStorage; + SharedStorage& shared_storage = *reinterpret_cast(shared_memory); + + auto thr_copy = tiled_copy.get_slice(threadIdx.x); + Tensor gA = make_tensor(make_gmem_ptr(g_in), gmem_layout); + Tensor gB = make_tensor(make_gmem_ptr(g_out), gmem_layout); + + // Construct SMEM tensor + Tensor sA = make_tensor(make_smem_ptr(shared_storage.smem.begin()), smem_layout); + + auto tAgA = thr_copy.partition_S(gA); + auto tAsA = thr_copy.partition_D(sA); + +#if 0 + if (thread0()) { + print("gA : "); print(gA.layout()); print("\n"); + print("sA : "); print(sA.layout()); print("\n"); + print("tAgA: "); print(tAgA.layout()); print("\n"); + print("tAsA: "); print(tAsA.layout()); print("\n"); + } +#endif + + copy(tiled_copy, tAgA, tAsA); + + cp_async_fence(); + cp_async_wait<0>(); + __syncthreads(); + + // Store trivially smem -> gmem + + if (thread0()) { + copy(sA, gB); + } + +} + +template +void +test_tiled_cp_async( + TiledCopy const tiled_copy, + GMEM_Layout const& gmem_layout, + SMEM_Layout const& smem_layout) +{ + using namespace cute; + + // Allocate and initialize host test data + size_t N = ceil_div(cosize(gmem_layout) * sizeof_bits::value, 8); + thrust::host_vector h_in(N); + Tensor hA_in = make_tensor(recast_ptr(h_in.data()), gmem_layout); + for (int i = 0; i < size(hA_in); ++i) { hA_in(i) = static_cast(i % 13); } + + // Allocate and initialize device test data + thrust::device_vector d_in = h_in; + thrust::device_vector d_out(h_in.size(), T(-1)); + + // Launch + int smem_size = int(sizeof(SharedStorage)); + test_tiled_cp_async_device_cute<<<1, 128, smem_size>>>( + reinterpret_cast(raw_pointer_cast(d_in.data())), + reinterpret_cast (raw_pointer_cast(d_out.data())), + tiled_copy, + gmem_layout, + smem_layout); + + // Copy results back to host + thrust::host_vector h_out = d_out; + Tensor hA_out = make_tensor(recast_ptr(h_out.data()), gmem_layout); + + // Validate the results. Print only the first 3 errors. 
+ int count = 3; + for (int i = 0; i < size(hA_out) && count > 0; ++i) { + EXPECT_EQ(hA_in(i), hA_out(i)); + if (hA_in(i) != hA_out(i)) { + --count; + } + } +} + +template +void test_cp_async_no_swizzle() { + using namespace cute; + auto smem_atom = SMEM_LAYOUT{}; + auto smem_layout = tile_to_shape(smem_atom, Shape{}); + auto gmem_layout = make_layout(make_shape(M{}, N{}), GMEM_STRIDE_TYPE{}); + test_tiled_cp_async(TILED_COPY{}, gmem_layout, smem_layout); +} + +template +void test_cp_async_with_swizzle() { + using namespace cute; + auto swizzle_atom = SWIZZLE_ATOM{}; + auto smem_atom = composition(swizzle_atom, SMEM_LAYOUT{}); + auto smem_layout = tile_to_shape(smem_atom, Shape{}); + auto gmem_layout = make_layout(make_shape(M{}, N{}), GMEM_STRIDE_TYPE{}); + test_tiled_cp_async(TILED_COPY{}, gmem_layout, smem_layout); +} diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/cute/cooperative_gemm_common.hpp b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/cute/cooperative_gemm_common.hpp new file mode 100644 index 0000000000000000000000000000000000000000..3ff20d4087ee2fd6f4f74338e3e63eef27c221d3 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/cute/cooperative_gemm_common.hpp @@ -0,0 +1,775 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/relatively_equal.h" +#include "cutlass_unit_test.h" +#include "cutlass/util/reference/host/tensor_compare.h" + +#include + +#include +#include + +#include + +using namespace cute; + +template +struct fp64_tester { + using value_type = double; +}; + +template +struct fp64_tester> { + using value_type = complex; +}; + +template // logical shape (M, N) +auto host_generate_gemm_inputs( + ALayout a_layout, + BLayout b_layout, + CLayout c_layout +) { + thrust::host_vector h_a(cosize(a_layout)); + thrust::host_vector h_b(cosize(b_layout)); + thrust::host_vector h_c(cosize(c_layout)); + thrust::host_vector h_c_out(cosize(c_layout)); + + auto h_a_tensor = make_tensor(h_a.data(), a_layout); + auto h_b_tensor = make_tensor(h_b.data(), b_layout); + auto h_c_tensor = make_tensor(h_c.data(), c_layout); + size_t max_size = std::max({static_cast(size(a_layout)), + static_cast(size(b_layout)), + static_cast(size(c_layout))}); + for (size_t i = 0; i < max_size; ++i) { + double di = static_cast(i); + if(i < size(a_layout)) { + h_a_tensor(i) = static_cast(di / size(a_layout)); + } + if(i < size(b_layout)) { + h_b_tensor(i) = static_cast(di / size(a_layout)); + } + if(i < size(c_layout)) { + h_c_tensor(i) = static_cast((di*di) / size(a_layout)); + } + } + + return std::make_tuple(h_a, h_b, h_c, h_c_out); +} + +template +thrust::host_vector +host_reference_gemm(Alpha alpha, + Tensor const& h_a_tensor, + Tensor const& h_b_tensor, + Beta beta, + Tensor const& h_c_tensor, + ALoadTransform const& a_load_transform = {}, + BLoadTransform const& b_load_transform = {}, + CLoadTransform const& c_load_transform = {}, + CStoreTransform const& c_store_transform = {}) + { + // Cannot use ::value_type because it propagates to complex::value_type, + // so ViewEngine>::value_type == double + using TA = remove_cv_t; + using TB = remove_cv_t; + using TC = remove_cv_t; + + using tester = fp64_tester; + using ABC_64 = typename tester::value_type; + + static_assert(std::is_same_v::value_type, typename fp64_tester::value_type>); + static_assert(std::is_same_v::value_type, typename fp64_tester::value_type>); + + thrust::host_vector h_c_ref(cosize(h_c_tensor.layout()), static_cast(0.0)); + auto h_c_ref_tensor = make_tensor(h_c_ref.data(), h_c_tensor.layout()); + // A * B + for (int k = 0; k < size<1>(h_a_tensor); k++) { + for (int m = 0; m < size<0>(h_a_tensor); m++) { + for (int n = 0; n < size<0>(h_b_tensor); n++) { + const auto a_value = a_load_transform(h_a_tensor(m, k)); + const auto b_value = b_load_transform(h_b_tensor(n, k)); + const auto a_value_fp64 = static_cast(a_value); + const auto b_value_fp64 = static_cast(b_value); + h_c_ref_tensor(m, n) += static_cast(a_value_fp64 * b_value_fp64); + } + } + } + // C = A*B + C + for (int i = 0; i < size(h_c_ref_tensor); i++) { + const auto ab_value_fp64 = static_cast(h_c_ref_tensor(i)); + const auto c_value_fp64 = static_cast(c_load_transform(h_c_tensor(i))); + h_c_ref_tensor(i) = c_store_transform(static_cast(alpha * ab_value_fp64 + beta * c_value_fp64)); + } + + return h_c_ref; +} + +template +void verify_gemm_correctness(cute::Tensor const& h_c_out_tensor, + cute::Tensor const& h_c_ref_tensor) +{ + // Cannot use ::value_type because it propagates to complex::value_type, + // so ViewEngine>::value_type == double + using TC = remove_cv_t; + + using tester = fp64_tester; + using ABC_64 = typename tester::value_type; + + for (int i = 0; i < 
size(h_c_ref_tensor); i++) { + ABC_64 h_c_ref_i = h_c_ref_tensor(i); + ABC_64 h_c_out_i = h_c_out_tensor(i); + double epsilon(0.1f); + double nonzero_floor(std::numeric_limits::min()); + bool passed = cutlass::relatively_equal(h_c_out_i, h_c_ref_i, epsilon, nonzero_floor); + ASSERT_TRUE(passed) << i << " - result:" << h_c_out_i << " expected:" << h_c_ref_i; + } +} + + +template +__launch_bounds__(ThreadBlockSize) __global__ void +cooperative_gemm_kernel(GMemALayout gmem_a_layout, + GMemBLayout gmem_b_layout, + GMemCLayout gmem_c_layout, + SMemALayout smem_a_layout, + SMemBLayout smem_b_layout, + SMemCLayout smem_c_layout, + TA const* a, + TB const* b, + TC const* c, + TC * c_out, + Alpha const alpha, + Beta const beta, + TiledMma tiled_mma, + ALoadTransform a_load_transform, + BLoadTransform b_load_transform, + CLoadTransform c_load_transform, + CStoreTransform c_store_transform, + SMemCopyOpA a_copy_op, + SMemCopyOpB b_copy_op, + SMemCopyLdOpC c_copy_ld_op, + SMemCopyStOpC c_copy_st_op) +{ + using namespace cute; + + Tensor g_a_tensor = make_tensor(make_gmem_ptr(a), gmem_a_layout); + Tensor g_b_tensor = make_tensor(make_gmem_ptr(b), gmem_b_layout); + Tensor g_c_tensor = make_tensor(make_gmem_ptr(c), gmem_c_layout); + Tensor g_c_out_tensor = make_tensor(make_gmem_ptr(c_out), gmem_c_layout); + + constexpr uint32_t copy_max_vec_bytes = CopyMaxVecBits / 8; + + extern __shared__ float4 smem_buf[]; + auto* smem_ptr = reinterpret_cast(smem_buf); + auto* smem_ptr_a = smem_ptr; + auto* smem_ptr_b = smem_ptr_a + round_up((sizeof(TA) * cosize(smem_a_layout)), copy_max_vec_bytes); + auto* smem_ptr_c = smem_ptr_b + round_up((sizeof(TB) * cosize(smem_b_layout)), copy_max_vec_bytes); + + Tensor s_a_tensor = make_tensor(make_smem_ptr(smem_ptr_a), smem_a_layout); + Tensor s_b_tensor = make_tensor(make_smem_ptr(smem_ptr_b), smem_b_layout); + Tensor s_c_tensor = make_tensor(make_smem_ptr(smem_ptr_c), smem_c_layout); + + cooperative_copy(threadIdx.x, g_a_tensor, s_a_tensor); + cooperative_copy(threadIdx.x, g_b_tensor, s_b_tensor); + cooperative_copy(threadIdx.x, g_c_tensor, s_c_tensor); + + cp_async_fence(); + cp_async_wait<0>(); + __syncthreads(); + + cooperative_gemm( + threadIdx.x, tiled_mma, + alpha, s_a_tensor, s_b_tensor, beta, s_c_tensor, + a_load_transform, b_load_transform, c_load_transform, c_store_transform, + a_copy_op, b_copy_op, c_copy_ld_op, c_copy_st_op + ); + __syncthreads(); + + cooperative_copy(threadIdx.x, s_c_tensor, g_c_out_tensor); +} + +template +__launch_bounds__(ThreadBlockSize) __global__ void +cooperative_gemm_kernel_rmem_c(GMemALayout gmem_a_layout, + GMemBLayout gmem_b_layout, + GMemCLayout gmem_c_layout, + SMemALayout smem_a_layout, + SMemBLayout smem_b_layout, + TA const* a, + TB const* b, + TC const* c, + TC * c_out, + TiledMma tiled_mma, + ALoadTransform a_load_transform, + BLoadTransform b_load_transform, + CLoadTransform c_load_transform, + CStoreTransform c_store_transform, + SMemCopyOpA a_copy_op, + SMemCopyOpB b_copy_op) + { + using namespace cute; + + Tensor g_a_tensor = make_tensor(make_gmem_ptr(a), gmem_a_layout); + Tensor g_b_tensor = make_tensor(make_gmem_ptr(b), gmem_b_layout); + Tensor g_c_tensor = make_tensor(make_gmem_ptr(c), gmem_c_layout); + Tensor g_c_out_tensor = make_tensor(make_gmem_ptr(c_out), gmem_c_layout); + + constexpr uint32_t copy_max_vec_bytes = CopyMaxVecBits / 8; + + extern __shared__ float4 smem_buf[]; + auto* smem_ptr = reinterpret_cast(smem_buf); + auto* smem_ptr_a = smem_ptr; + auto* smem_ptr_b = smem_ptr_a + round_up((sizeof(TA) * 
cosize(smem_a_layout)), copy_max_vec_bytes); + + Tensor s_a_tensor = make_tensor(make_smem_ptr(smem_ptr_a), smem_a_layout); + Tensor s_b_tensor = make_tensor(make_smem_ptr(smem_ptr_b), smem_b_layout); + + cooperative_copy(threadIdx.x, g_a_tensor, s_a_tensor); + cooperative_copy(threadIdx.x, g_b_tensor, s_b_tensor); + + cp_async_fence(); + cp_async_wait<0>(); + __syncthreads(); + + // Create C fragment for storing intermediate results + auto thr_mma = TiledMma().get_thread_slice(threadIdx.x); + Tensor g_c_partition = thr_mma.partition_C(g_c_tensor); + Tensor g_c_out_partition = thr_mma.partition_C(g_c_out_tensor); + Tensor r_c_partition = thr_mma.make_fragment_C(g_c_partition); + + // Create indexing help for predicated GEMMs + Tensor cC = make_identity_tensor(shape(gmem_c_layout)); + Tensor tCcC = thr_mma.partition_C(cC); + + // Load C from global + // (always loading in predicated way) + CUTE_UNROLL + for (int i = 0; i < size(r_c_partition); ++i) + { + if (elem_less(tCcC(i), shape(g_c_tensor))) + { + r_c_partition(i) = c_load_transform(g_c_partition(i)); + } + } + + cooperative_gemm( + threadIdx.x, tiled_mma, s_a_tensor, s_b_tensor, r_c_partition, + a_load_transform, b_load_transform, a_copy_op, b_copy_op + ); + + __syncthreads(); + + // Store C to global + // (always storing in predicated way) + CUTE_UNROLL + for (int i = 0; i < size(r_c_partition); ++i) + { + if (elem_less(tCcC(i), shape(g_c_tensor))) + { + g_c_out_partition(i) = c_store_transform(r_c_partition(i)); + } + } +} + +template, + class BSMemCopyOp = AutoVectorizingCopyWithAssumedAlignment, + class CSMemCopyLdOp = AutoVectorizingCopyWithAssumedAlignment, + class CSMemCopyStOp = AutoVectorizingCopyWithAssumedAlignment> +void test_cooperative_gemm(GMemALayout gmem_a_layout, + GMemBLayout gmem_b_layout, + GMemCLayout gmem_c_layout, + SMemALayout smem_a_layout, + SMemBLayout smem_b_layout, + SMemCLayout smem_c_layout, + TiledMma tiled_mma, + ALoadTransform a_load_transform = {}, + BLoadTransform b_load_transform = {}, + CLoadTransform c_load_transform = {}, + CStoreTransform c_store_transform = {}, + ASMemCopyOp a_smem_copy_op = {}, + BSMemCopyOp b_smem_copy_op = {}, + CSMemCopyLdOp c_smem_copy_ld_op = {}, + CSMemCopyStOp c_smem_copy_st_op = {}) +{ + static_assert(std::is_same_v::value_type, typename fp64_tester::value_type>); + static_assert(std::is_same_v::value_type, typename fp64_tester::value_type>); + + static_assert(size<0>(gmem_a_layout) == size<0>(gmem_c_layout)); // AM == CM + static_assert(size<0>(gmem_b_layout) == size<1>(gmem_c_layout)); // BN == CN + static_assert(size<1>(gmem_a_layout) == size<1>(gmem_b_layout)); // AK == BK + + static_assert(size<0>(smem_a_layout) == size<0>(smem_c_layout)); // AM == CM + static_assert(size<0>(smem_b_layout) == size<1>(smem_c_layout)); // BN == CN + static_assert(size<1>(smem_a_layout) == size<1>(smem_b_layout)); // AK == BK + + static_assert(cute::size(gmem_a_layout) == cute::size(smem_a_layout)); + static_assert(cute::size(gmem_b_layout) == cute::size(smem_b_layout)); + static_assert(cute::size(gmem_c_layout) == cute::size(smem_c_layout)); + +#if 0 + print(" "); print("gmem: "); print(gmem_layout); print("\n"); + print(" "); print("smem: "); print(smem_layout); print("\n"); + print(" "); print("threads: "); print(ThreadBlockSize); print("\n"); +#endif + + const auto alpha = static_cast(1.1); + const auto beta = static_cast(1.2); + + // Generate inputs + auto [h_a, h_b, h_c, h_c_out] = host_generate_gemm_inputs(gmem_a_layout, gmem_b_layout, gmem_c_layout); + + 
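+  // Stage the generated host inputs on the device; d_c_out is filled with -1 so
+  // any element the kernel never writes remains detectable during verification.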
thrust::device_vector d_a(h_a); + thrust::device_vector d_b(h_b); + thrust::device_vector d_c(h_c); + thrust::device_vector d_c_out(h_c_out.size(), TC(float(-1))); + + constexpr uint32_t copy_max_vec_bytes = CopyMaxVecBits / 8; + + const size_t shared_memory_size = round_up(sizeof(TA) * h_a.size(), copy_max_vec_bytes) + + round_up(sizeof(TB) * h_b.size(), copy_max_vec_bytes) + + sizeof(TC) * h_c.size(); + + + auto kernel = cooperative_gemm_kernel< + ThreadBlockSize, CopyMaxVecBits, + GMemALayout, GMemBLayout, GMemCLayout, + SMemALayout, SMemBLayout, SMemCLayout, + TA, TB, TC, decltype(alpha), decltype(beta), + TiledMma, + ALoadTransform, BLoadTransform, CLoadTransform, CStoreTransform, + ASMemCopyOp, BSMemCopyOp, CSMemCopyLdOp, CSMemCopyStOp + >; + + ASSERT_EQ(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, static_cast(shared_memory_size)), 0); + + kernel<<<1, ThreadBlockSize, shared_memory_size>>>( + gmem_a_layout, + gmem_b_layout, + gmem_c_layout, + smem_a_layout, + smem_b_layout, + smem_c_layout, + thrust::raw_pointer_cast(d_a.data()), + thrust::raw_pointer_cast(d_b.data()), + thrust::raw_pointer_cast(d_c.data()), + thrust::raw_pointer_cast(d_c_out.data()), + alpha, + beta, + tiled_mma, + a_load_transform, + b_load_transform, + c_load_transform, + c_store_transform, + a_smem_copy_op, + b_smem_copy_op, + c_smem_copy_ld_op, + c_smem_copy_st_op + ); + + cudaError_t result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + cudaError_t error = cudaGetLastError(); + FAIL() << "Error at kernel sync: " << cudaGetErrorString(error) << "\n"; + } + + // Reference gemm + auto h_c_ref = host_reference_gemm(alpha, + make_tensor(h_a.data(), gmem_a_layout), + make_tensor(h_b.data(), gmem_b_layout), + beta, + make_tensor(h_c.data(), gmem_c_layout), + a_load_transform, + b_load_transform, + c_load_transform, + c_store_transform); + + // Copy result data + h_c_out = d_c_out; + + // Verify correctness + verify_gemm_correctness(make_tensor(h_c_out.data(), gmem_c_layout), + make_tensor(h_c_ref.data(), gmem_c_layout)); +} + +template, + class BSMemCopyOp = AutoVectorizingCopyWithAssumedAlignment> +void test_cooperative_gemm_rmem_c(GMemALayout gmem_a_layout, + GMemBLayout gmem_b_layout, + GMemCLayout gmem_c_layout, + SMemALayout smem_a_layout, + SMemBLayout smem_b_layout, + TiledMma tiled_mma, + ALoadTransform a_load_transform = {}, + BLoadTransform b_load_transform = {}, + CLoadTransform c_load_transform = {}, + CStoreTransform c_store_transform = {}, + ASMemCopyOp a_smem_copy_op = {}, + BSMemCopyOp b_smem_copy_op = {}) +{ + static_assert(size<0>(gmem_a_layout) == size<0>(gmem_c_layout)); // AM == CM + static_assert(size<0>(gmem_b_layout) == size<1>(gmem_c_layout)); // BN == CN + static_assert(size<1>(gmem_a_layout) == size<1>(gmem_b_layout)); // AK == BK + + static_assert(size<1>(smem_a_layout) == size<1>(smem_b_layout)); // AK == BK + + static_assert(cute::size(gmem_a_layout) == cute::size(smem_a_layout)); + static_assert(cute::size(gmem_b_layout) == cute::size(smem_b_layout)); + +#if 0 + print(" "); print("gmem: "); print(gmem_layout); print("\n"); + print(" "); print("smem: "); print(smem_layout); print("\n"); + print(" "); print("threads: "); print(ThreadBlockSize); print("\n"); +#endif + + const auto alpha = static_cast(1.0); + const auto beta = static_cast(1.0); + + // Generate inputs + auto [h_a, h_b, h_c, h_c_out] = + host_generate_gemm_inputs(gmem_a_layout, gmem_b_layout, gmem_c_layout); + + thrust::device_vector d_a(h_a); + thrust::device_vector d_b(h_b); + 
thrust::device_vector d_c(h_c); + thrust::device_vector d_c_out(h_c_out.size(), static_cast(-1)); + + constexpr uint32_t copy_max_vec_bytes = CopyMaxVecBits / 8; + + const size_t shared_memory_size = round_up(sizeof(TA) * h_a.size(), copy_max_vec_bytes) + + round_up(sizeof(TB) * h_b.size(), copy_max_vec_bytes); + + + auto kernel = cooperative_gemm_kernel_rmem_c< + ThreadBlockSize, CopyMaxVecBits, + GMemALayout, GMemBLayout, GMemCLayout, + SMemALayout, SMemBLayout, + TA, TB, TC, + TiledMma, + ALoadTransform, BLoadTransform, CLoadTransform, CStoreTransform, + ASMemCopyOp, BSMemCopyOp + >; + + ASSERT_EQ(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, static_cast(shared_memory_size)), 0); + + kernel<<<1, ThreadBlockSize, shared_memory_size>>>( + gmem_a_layout, + gmem_b_layout, + gmem_c_layout, + smem_a_layout, + smem_b_layout, + thrust::raw_pointer_cast(d_a.data()), + thrust::raw_pointer_cast(d_b.data()), + thrust::raw_pointer_cast(d_c.data()), + thrust::raw_pointer_cast(d_c_out.data()), + tiled_mma, + a_load_transform, b_load_transform, c_load_transform, c_store_transform, + a_smem_copy_op, b_smem_copy_op + ); + + cudaError_t result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + cudaError_t error = cudaGetLastError(); + FAIL() << "Error at kernel sync: " << cudaGetErrorString(error) << "\n"; + } + + // Copy result data + h_c_out = d_c_out; + + // Reference gemm + auto h_c_ref = host_reference_gemm(alpha, + make_tensor(h_a.data(), gmem_a_layout), + make_tensor(h_b.data(), gmem_b_layout), + beta, + make_tensor(h_c.data(), gmem_c_layout), + a_load_transform, + b_load_transform, + c_load_transform, + c_store_transform); + + // Verify correctness + verify_gemm_correctness(make_tensor(h_c_out.data(), gmem_c_layout), + make_tensor(h_c_ref.data(), gmem_c_layout)); +} + +template +void test_cooperative_gemm_col_major_layout(ShapeMNK shape_mnk, + TiledMma tiled_mma, + Ops ... ops) +{ + auto a_layout = make_layout(select<0, 2>(shape_mnk)); + auto b_layout = make_layout(select<1, 2>(shape_mnk), GenRowMajor{}); + auto c_layout = make_layout(select<0, 1>(shape_mnk)); + + test_cooperative_gemm + (a_layout, + b_layout, + c_layout, + a_layout, + b_layout, + c_layout, + tiled_mma, + ops...); +} + + +template +std::enable_if_t, + cute::is_layout, + cute::is_layout>> +test_cooperative_gemm_col_major_layout(SMemAtomLayoutA smem_atom_layout_a, + SMemAtomLayoutB smem_atom_layout_b, + SMemAtomLayoutC smem_atom_layout_c, + ShapeMNK shape_mnk, + TiledMma tiled_mma, + Ops&& ... ops) +{ + auto gmem_a_layout = make_layout(select<0, 2>(shape_mnk)); + auto gmem_b_layout = make_layout(select<1, 2>(shape_mnk), GenRowMajor{}); + auto gmem_c_layout = make_layout(select<0, 1>(shape_mnk)); + + auto smem_a_layout = tile_to_shape( + smem_atom_layout_a, + make_shape(shape<0>(gmem_a_layout), shape<1>(gmem_a_layout))); + + auto smem_b_layout = tile_to_shape( + smem_atom_layout_b, + make_shape(shape<0>(gmem_b_layout), shape<1>(gmem_b_layout))); + + auto smem_c_layout = tile_to_shape( + smem_atom_layout_c, + make_shape(shape<0>(gmem_c_layout), shape<1>(gmem_c_layout))); + + test_cooperative_gemm + (gmem_a_layout, + gmem_b_layout, + gmem_c_layout, + smem_a_layout, + smem_b_layout, + smem_c_layout, + tiled_mma, + ops...); +} + + +template +void test_cooperative_gemm_col_major_layout_rmem_c(ShapeMNK shape_mnk, + TiledMma tiled_mma, + Ops ... 
ops) +{ + auto a_layout = make_layout(select<0, 2>(shape_mnk)); + auto b_layout = make_layout(select<1, 2>(shape_mnk), GenRowMajor{}); + auto c_layout = make_layout(select<0, 1>(shape_mnk)); + + + test_cooperative_gemm_rmem_c + (a_layout, + b_layout, + c_layout, + a_layout, + b_layout, + tiled_mma, + ops...); +} + +template +std::enable_if_t, + cute::is_layout>> +test_cooperative_gemm_col_major_layout_rmem_c(SMemAtomLayoutA smem_atom_layout_a, + SMemAtomLayoutB smem_atom_layout_b, + ShapeMNK shape_mnk, + TiledMma tiled_mma, + Ops ... ops) +{ + auto gmem_a_layout = make_layout(select<0, 2>(shape_mnk)); + auto gmem_b_layout = make_layout(select<1, 2>(shape_mnk), GenRowMajor{}); + auto gmem_c_layout = make_layout(select<0, 1>(shape_mnk)); + + auto smem_a_layout = tile_to_shape( + smem_atom_layout_a, + make_shape(shape<0>(gmem_a_layout), shape<1>(gmem_a_layout))); + + auto smem_b_layout = tile_to_shape( + smem_atom_layout_b, + make_shape(shape<0>(gmem_b_layout), shape<1>(gmem_b_layout))); + + test_cooperative_gemm_rmem_c + (gmem_a_layout, + gmem_b_layout, + gmem_c_layout, + smem_a_layout, + smem_b_layout, + tiled_mma, + ops...); +} + +template +void test_cooperative_gemm_col_major_layout_rmem_c(Args&& ... args) +{ + test_cooperative_gemm_col_major_layout_rmem_c, + T, T, T> + (static_cast(args)...); +} + +template +void test_cooperative_gemm_col_major_layout(Args&& ... args) +{ + test_cooperative_gemm_col_major_layout, + T, T, T> + (static_cast(args)...); +} diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/cute/hopper/tma_load_testbed.hpp b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/cute/hopper/tma_load_testbed.hpp new file mode 100644 index 0000000000000000000000000000000000000000..4d2620e62ff247e36ae49809ab4ef3416560ae31 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/cute/hopper/tma_load_testbed.hpp @@ -0,0 +1,217 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass_unit_test.h" + +#include +#include + +#include +#include + +#include + +namespace cutlass::test { + +template +struct SharedStorage +{ + cute::ArrayEngine> smem; + alignas(16) cute::uint64_t tma_load_mbar[1]; +}; + +#if CUDA_12_0_SM90_FEATURES_SUPPORTED + +template +__global__ void +tma_test_device_cute(T const* g_in, T* g_out, + CUTE_GRID_CONSTANT TiledCopy const tma, CTA_Tiler cta_tiler, + GmemLayout gmem_layout, SmemLayout smem_layout) +{ + using namespace cute; + CUTE_STATIC_ASSERT_V(product_each(shape(cta_tiler)) == product_each(shape(smem_layout))); + + // Use Shared Storage structure to allocate and distribute aligned SMEM addresses + extern __shared__ char shared_memory[]; + using SharedStorage = SharedStorage; + SharedStorage& shared_storage = *reinterpret_cast(shared_memory); + + // Construct SMEM tensor + Tensor sA = make_tensor(make_smem_ptr(shared_storage.smem.begin()), smem_layout); // (CTA_TILE_M,CTA_TILE_N,...) + // Shared memory barriers use 64bits in SMEM for synchronization + uint64_t* tma_load_mbar = shared_storage.tma_load_mbar; + + // TMA requires special handling of strides to deal with coord codomain mapping + // Represent the full tensors -- get these from TMA + Tensor mA = tma.get_tma_tensor(shape(gmem_layout)); + Tensor mB = make_tensor(make_gmem_ptr(g_out), gmem_layout); + + constexpr int R = rank_v; + Tensor gA = flat_divide(mA, cta_tiler); // (CTA_TILE_M,CTA_TILE_N,...REST_M,REST_N,...) + Tensor gB = flat_divide(mB, cta_tiler); // (CTA_TILE_M,CTA_TILE_N,...REST_M,REST_N,...) 
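+  // flat_divide keeps the tile modes first and appends the leftover tiles:
+  // e.g. assuming a 128x64 gmem_layout with a (_32,_16) cta_tiler, gA and gB
+  // would have shape (_32,_16,4,4) = (CTA_TILE_M,CTA_TILE_N,REST_M,REST_N).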
+ + // + // Prepare the TMA_LOAD + // + + auto cta_tma = tma.get_slice(Int<0>{}); // CTA slice + Tensor tAgA_x = cta_tma.partition_S(gA); // (TMA,TMA_M,TMA_N,REST_M,REST_N) + Tensor tAsA_x = cta_tma.partition_D(sA); // (TMA,TMA_M,TMA_N) + +#if 0 + if (thread0()) { + print(tma); + print("TILE : "); print(cta_tiler); print("\n"); + print(" mA : "); print( mA); print("\n"); + print(" mB : "); print( mB); print("\n"); + print(" gA : "); print( gA); print("\n"); + print(" gB : "); print( gB); print("\n"); + print(" sA : "); print( sA); print("\n"); + print("tAgA_x: "); print(tAgA_x); print("\n"); + print("tAsA_x: "); print(tAsA_x); print("\n"); + } +#endif + + // + // Perform the TMA_LOAD + // + + // INPUT: Group the REST_X modes and the TMA_X modes to easily iterate through the tiles + Tensor tAgA = group_modes<1,rank(tAgA_x)>(tAgA_x); // (TMA,REST) + Tensor tAsA = group_modes<1,rank(tAsA_x)>(tAsA_x); // (TMA,REST) + static_assert(size<1>(tAsA) == 1); + + // OUTPUT: Group the CTA_TILE_X modes and REST_X modes for output + Tensor tBgB = group_modes<0,R>(group_modes(gB)); // (CTA_TILE, REST) + +#if 0 + if (thread0()) { + print("tAgA : "); print(tAgA); print("\n"); + print("tAsA : "); print(tAsA); print("\n"); + print("tBgB : "); print(tBgB); print("\n"); + } +#endif + + // Test L2 prefetch + if (threadIdx.x == 0) { + prefetch(tma, tAgA); + } + + // Loop over the TMA stages, using smem as our buffer + for (int stage = 0; stage < size<1>(tAgA); ++stage) + { + // Set the bytes transferred in this TMA transaction (may involve multiple issues) + constexpr int kTmaTransactionBytes = sizeof(make_tensor_like(tensor<0>(tAsA))); + + if (threadIdx.x == 0) + { + /// Initialize shared memory barrier + tma_load_mbar[0] = 0; + cute::initialize_barrier(tma_load_mbar[0], 1 /*numThreads*/); + cute::set_barrier_transaction_bytes(tma_load_mbar[0], kTmaTransactionBytes); + + copy(tma.with(tma_load_mbar[0]), tAgA(_,stage), tAsA(_,0)); + } + __syncthreads(); + + /// Wait on the shared memory barrier until the phase bit flips from kPhaseBit value + constexpr int kPhaseBit = 0; + cute::wait_barrier(tma_load_mbar[0], kPhaseBit); + + // + // Write out trivially smem -> gmem + // + + // Subbyte elements could cause race conditions, so be even more conservative + if (thread0()) { + copy(sA, tBgB(_,stage)); + } + + __syncthreads(); + } +} + +template +auto +test_tma_load(CopyOp const& copy_op, + GMEM_Layout const& gmem_layout, + SMEM_Layout const& smem_layout, + CTA_Tile const& cta_tile) +{ + using namespace cute; + + // Allocate and initialize host test data + size_t N = ceil_div(cosize(gmem_layout) * sizeof_bits::value, 8); + thrust::host_vector h_in(N); + for (size_t i = 0; i < h_in.size(); ++i) { + h_in[i] = uint8_t(i % 13); + } + Tensor hA_in = make_tensor(recast_ptr(h_in.data()), gmem_layout); + + // Allocate and initialize device test data + thrust::device_vector d_in = h_in; + thrust::device_vector d_out(h_in.size(), uint8_t(-1)); // overflow uint + + // Create TMA for this device Tensor + Tensor gA = make_tensor(make_gmem_ptr(raw_pointer_cast(d_in.data())), gmem_layout); + auto tma = make_tma_copy(copy_op, gA, smem_layout, cta_tile, Int<1>{}); + //print(tma); + + // Launch + int smem_size = int(sizeof(SharedStorage)); + tma_test_device_cute<<<1, 128, smem_size>>>( + reinterpret_cast(raw_pointer_cast(d_in.data())), + reinterpret_cast (raw_pointer_cast(d_out.data())), + tma, cta_tile, + gmem_layout, + smem_layout); + + // Copy results back to host + thrust::host_vector h_out = d_out; + Tensor hA_out = 
make_tensor(recast_ptr(h_out.data()), gmem_layout); + + // Validate the results. Print only the first 3 errors. + int count = 3; + for (int i = 0; i < int(size(hA_out)) && count > 0; ++i) { + EXPECT_EQ(hA_in(i), hA_out(i)); + if (hA_in(i) != hA_out(i)) { + --count; + } + } + + return tma; +} + +#endif + +} // end namespace cutlass::test diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/cute/hopper/tma_mcast_load_testbed.hpp b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/cute/hopper/tma_mcast_load_testbed.hpp new file mode 100644 index 0000000000000000000000000000000000000000..3e0ec46df1b672c35c3c38f731c09b0134d4cd80 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/cute/hopper/tma_mcast_load_testbed.hpp @@ -0,0 +1,242 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once + +#include "cutlass_unit_test.h" + +#include +#include + +#include +#include + +#include +#include +#include + +namespace cutlass::test { + +template +struct SharedStorage +{ + cute::ArrayEngine> smem; + alignas(16) cute::uint64_t tma_load_mbar[1]; +}; + +#if CUDA_12_0_SM90_FEATURES_SUPPORTED + +template +__global__ void +tma_test_device_cute(T const* g_in, T* g_out, GmemLayout gmem_layout, SmemLayout smem_layout, + CUTE_GRID_CONSTANT CopyAtom const tma, CTA_Tiler cta_tiler, Cluster_Size cluster_size) +{ + using namespace cute; + CUTE_STATIC_ASSERT_V(product_each(shape(cta_tiler)) == product_each(shape(smem_layout))); + + // Use Shared Storage structure to allocate and distribute aligned SMEM addresses + extern __shared__ char shared_memory[]; + using SharedStorage = SharedStorage; + SharedStorage& shared_storage = *reinterpret_cast(shared_memory); + + // Construct SMEM tensor + Tensor sA = make_tensor(make_smem_ptr(shared_storage.smem.begin()), smem_layout); // (CTA_TILE_M,CTA_TILE_N,...) + // Shared memory barriers use 64bits in SMEM for synchronization + uint64_t* tma_load_mbar = shared_storage.tma_load_mbar; + + // TMA requires special handling of strides to deal with coord codomain mapping + // Represent the full tensors -- get these from TMA + Tensor mA = tma.get_tma_tensor(shape(gmem_layout)); + Tensor mB = make_tensor(make_gmem_ptr(g_out), gmem_layout); + + Tensor gA = zipped_divide(mA, cta_tiler); // ((CTA_TILE_M,CTA_TILE_N,...),(REST_M,REST_N,...)) + Tensor gB = zipped_divide(mB, cta_tiler); // ((CTA_TILE_M,CTA_TILE_N,...),(REST_M,REST_N,...)) + +#if 1 + if (thread0()) { + print(tma); + print("TILE : "); print(cta_tiler); print("\n"); + print(" mA : "); print( mA); print("\n"); + print(" mB : "); print( mB); print("\n"); + print(" gA : "); print( gA); print("\n"); + print(" gB : "); print( gB); print("\n"); + print(" sA : "); print( sA); print("\n"); + } __syncthreads(); cute::cluster_sync(); +#endif + + // + // Prepare the TMA_LOAD + // + + Tensor sA_x = make_tensor(sA.data(), make_layout(sA.layout(), Layout<_1>{})); // ((CTA_TILE_M,CTA_TILE_N,...),_1) + Tensor tBgB = gB; // ((CTA_TILE_M,CTA_TILE_N,...),(REST_M,REST_N,...)) + + int cta_rank_in_cluster = cute::block_rank_in_cluster(); + auto [tAgA, tAsA] = tma_partition(tma, cta_rank_in_cluster, make_layout(cluster_size), sA_x, gA); + +#if 1 + if (thread0()) { + print("sA_x : "); print(sA_x); print("\n"); + print("tBgB : "); print(tBgB); print("\n"); + print("tAgA : "); print(tAgA); print("\n"); + print("tAsA : "); print(tAsA); print("\n"); + } __syncthreads(); cute::cluster_sync(); +#endif + + // + // TMA Multicast Masks -- Get a mask of the active ctas in each TMA + // + + + int elected_cta_rank = 0; + bool elect_one_cta = (elected_cta_rank == cta_rank_in_cluster); + bool elect_one_thr = cute::elect_one_sync(); + + uint16_t tma_mcast_mask = ((uint16_t(1) << cluster_size) - 1); + +#if 1 + if (thread0()) { + print("tma_mcast_mask : "); print(tma_mcast_mask); print("\n"); + } __syncthreads(); cute::cluster_sync(); +#endif + + // + // Perform the TMA_LOAD + // + + if (elect_one_thr) { + // Initialize TMA barrier + cute::initialize_barrier(tma_load_mbar[0], /* num_threads */ 1); + } + int tma_phase_bit = 0; + // Ensures all CTAs in the Cluster have initialized + __syncthreads(); + cute::cluster_sync(); + + // Loop over the TMA stages, using smem as our buffer + for (int stage = 0; stage < size<1>(tAgA); ++stage) 
+ { + // Set the bytes transferred in this TMA transaction (may involve multiple issues) + constexpr int kTmaTransactionBytes = sizeof(ArrayEngine); + + if (elect_one_thr) + { + cute::set_barrier_transaction_bytes(tma_load_mbar[0], kTmaTransactionBytes); + + copy(tma.with(tma_load_mbar[0], tma_mcast_mask), tAgA(_,stage), tAsA(_,0)); + } + __syncthreads(); + + /// Wait on the shared memory barrier until the phase bit flips from tma_phase_bit value + cute::wait_barrier(tma_load_mbar[0], tma_phase_bit); + tma_phase_bit ^= 1; + + // + // Write out trivially smem -> gmem + // + + // Subbyte elements could cause race conditions, so be even more conservative + if (elect_one_cta && elect_one_thr) { + copy(sA, tBgB(_,stage)); + } + + __syncthreads(); + cute::cluster_sync(); + } +} + +template +auto +test_tma_load(CopyOp const& copy_op, + GMEM_Layout const& gmem_layout, + SMEM_Layout const& smem_layout, + CTA_Tiler const& cta_tiler, + Cluster_Size const& cluster_size) +{ + using namespace cute; + + // Allocate and initialize host test data + size_t N = ceil_div(cosize(gmem_layout) * sizeof_bits::value, 8); + thrust::host_vector h_in(N); + for (size_t i = 0; i < h_in.size(); ++i) { + h_in[i] = uint8_t(i % 13); + } + Tensor hA_in = make_tensor(recast_ptr(h_in.data()), gmem_layout); + + // Allocate and initialize device test data + thrust::device_vector d_in = h_in; + thrust::device_vector d_out(h_in.size(), uint8_t(-1)); // overflow uint + + // Create TMA for this device Tensor + Tensor gA = make_tensor(make_gmem_ptr(raw_pointer_cast(d_in.data())), gmem_layout); + auto tma = make_tma_atom(copy_op, gA, smem_layout, cta_tiler, cluster_size); + //print(tma); + + // Launch + + dim3 dimBlock(32); + dim3 dimCluster(size(cluster_size)); + dim3 dimGrid = dimCluster; + int smem_size = sizeof(SharedStorage); + + void* kernel_ptr = (void*) &tma_test_device_cute; + + cutlass::launch_kernel_on_cluster({dimGrid, dimBlock, dimCluster, smem_size}, + kernel_ptr, + reinterpret_cast(raw_pointer_cast(d_in.data())), + reinterpret_cast(raw_pointer_cast(d_out.data())), + gmem_layout, + smem_layout, + tma, cta_tiler, cluster_size); + + // Copy results back to host + thrust::host_vector h_out = d_out; + Tensor hA_out = make_tensor(recast_ptr(h_out.data()), gmem_layout); + + // Validate the results. Print only the first 3 errors. + int count = 3; + for (int i = 0; i < int(size(hA_out)) && count > 0; ++i) { + EXPECT_EQ(hA_in(i), hA_out(i)); + if (hA_in(i) != hA_out(i)) { + --count; + } + } + + return tma; +} + +#endif + +} // end namespace cutlass::test diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/cute/hopper/tma_store_testbed.hpp b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/cute/hopper/tma_store_testbed.hpp new file mode 100644 index 0000000000000000000000000000000000000000..0429d2435fbf43c690f311c1f7c04f7025a2dd94 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/cute/hopper/tma_store_testbed.hpp @@ -0,0 +1,201 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass_unit_test.h" + +#include +#include + +#include +#include + +#include + +namespace cutlass::test { + +template +struct SharedStorage +{ + cute::ArrayEngine> smem; +}; + +#if CUDA_12_0_SM90_FEATURES_SUPPORTED + +template +__global__ void +tma_test_device_cute(T const* g_in, T* g_out, + CUTE_GRID_CONSTANT TiledCopy const tma, CTA_Tiler cta_tiler, + GmemLayout gmem_layout, SmemLayout smem_layout) +{ + using namespace cute; + CUTE_STATIC_ASSERT_V(product_each(shape(cta_tiler)) == product_each(shape(smem_layout))); + + // Use Shared Storage structure to allocate and distribute aligned SMEM addresses + extern __shared__ char shared_memory[]; + using SharedStorage = SharedStorage; + SharedStorage& shared_storage = *reinterpret_cast(shared_memory); + + // Construct SMEM tensor + Tensor sB = make_tensor(make_smem_ptr(shared_storage.smem.begin()), smem_layout); // (CTA_TILE_M,CTA_TILE_N,...) + + // TMA requires special handling of strides to deal with coord codomain mapping + // Represent the full tensors -- get these from TMA + Tensor mA = make_tensor(make_gmem_ptr(g_in), gmem_layout); + Tensor mB = tma.get_tma_tensor(shape(gmem_layout)); + + constexpr int R = rank_v; + Tensor gA = flat_divide(mA, cta_tiler); // (CTA_TILE_M,CTA_TILE_N,...REST_M,REST_N,...) + Tensor gB = flat_divide(mB, cta_tiler); // (CTA_TILE_M,CTA_TILE_N,...REST_M,REST_N,...) 
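+  // Note the direction relative to the load testbed: here the TMA coordinate
+  // tensor is the destination mB, while the plain gmem tensor mA is the source
+  // that gets staged through smem.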
+ + // + // Prepare the TMA_STORE + // + + auto cta_tma = tma.get_slice(Int<0>{}); // CTA slice + Tensor tBsB_x = cta_tma.partition_S(sB); // (TMA,TMA_M,TMA_N) + Tensor tBgB_x = cta_tma.partition_D(gB); // (TMA,TMA_M,TMA_N,REST_M,REST_N) + +#if 0 + if (thread0()) { + print(tma); + print("TILE : "); print(cta_tiler); print("\n"); + print(" mB : "); print( mB.data()); print(" o "); print( mB.layout()); print("\n"); + print(" gB : "); print( gB.data()); print(" o "); print( gB.layout()); print("\n"); + print("tBgB_x: "); print(tBgB_x.data()); print(" o "); print(tBgB_x.layout()); print("\n"); + print(" sB : "); print( sB.data()); print(" o "); print( sB.layout()); print("\n"); + print("tBsB_x: "); print(tBsB_x.data()); print(" o "); print(tBsB_x.layout()); print("\n"); + } +#endif + + // + // Perform the TMA_STORE + // + + // INPUT: Group the CTA_TILE_X modes and REST_X modes for input + Tensor tAgA = group_modes<0,R>(group_modes(gA)); // (CTA_TILE, REST) + + // OUTPUT: Group the REST_X modes and the TMA_X modes to easily iterate through the tiles + Tensor tBgB = group_modes<1,rank(tBgB_x)>(tBgB_x); // (TMA,REST) + Tensor tBsB = group_modes<1,rank(tBsB_x)>(tBsB_x); // (TMA,REST) + static_assert(size<1>(tBsB) == 1); + +#if 0 + if (thread0()) { + print("tAgA : "); print(tAgA.data()); print(" o "); print(tAgA.layout()); print("\n"); + print("tBsB : "); print(tBsB.data()); print(" o "); print(tBsB.layout()); print("\n"); + print("tBgB : "); print(tBgB.data()); print(" o "); print(tBgB.layout()); print("\n"); + } +#endif + + // Test L2 prefetch + cooperative_prefetch<128>(threadIdx.x, gA); + + // Loop over the TMA stages, using smem as our buffer + for (int stage = 0; stage < size<1>(tBgB); ++stage) + { + // + // Read in trivially gmem -> smem + // + // Subbyte elements could cause race conditions, so be even more conservative + if (thread0()) { + copy(tAgA(_,stage), sB); + } + + __syncthreads(); + cute::cp_async_wait<0>(); + + // + // Perform the TMA_STORE + // + + if (threadIdx.x == 0) { + copy(tma, tBsB(_,0), tBgB(_,stage)); + } + + tma_store_wait<0>(); + __syncthreads(); + } +} + +template +void +test_tma_store(CopyOp const& copy_op, + GMEM_Layout const& gmem_layout, + SMEM_Layout const& smem_layout, + CTA_Tile const& cta_tile) +{ + using namespace cute; + + // Allocate and initialize host test data + size_t N = ceil_div(cosize(gmem_layout) * sizeof_bits::value, 8); + thrust::host_vector h_in(N); + for (size_t i = 0; i < h_in.size(); ++i) { + h_in[i] = uint8_t(i % 13); + } + Tensor hA_in = make_tensor(recast_ptr(h_in.data()), gmem_layout); + + // Allocate and initialize device test data + thrust::device_vector d_in = h_in; + thrust::device_vector d_out(h_in.size(), uint8_t(-1)); // overflow uint + + // Create TMA for this device Tensor + Tensor gA = make_tensor(make_gmem_ptr(raw_pointer_cast(d_out.data())), gmem_layout); + auto tma = make_tma_copy(copy_op, gA, smem_layout, cta_tile, Int<1>{}); + //print(tma); + + // Launch + int smem_size = int(sizeof(SharedStorage)); + tma_test_device_cute<<<1, 128, smem_size>>>( + reinterpret_cast(raw_pointer_cast(d_in.data())), + reinterpret_cast (raw_pointer_cast(d_out.data())), + tma, cta_tile, + gmem_layout, + smem_layout); + + // Copy results back to host + thrust::host_vector h_out = d_out; + Tensor hA_out = make_tensor(recast_ptr(h_out.data()), gmem_layout); + + // Validate the results. Print only the first 3 errors. 
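+  // hA_in made the round trip gmem -> smem -> gmem via TMA_STORE, so any
+  // mismatch indicates a tile written to the wrong coordinates or skipped.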
+ int count = 3; + for (int i = 0; i < int(size(hA_out)) && count > 0; ++i) { + EXPECT_EQ(hA_in(i), hA_out(i)); + if (hA_in(i) != hA_out(i)) { + --count; + } + } +} + +#endif + +} // end namespace cutlass::test diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/epilogue/threadblock/epilogue_with_reduction_testbed.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/epilogue/threadblock/epilogue_with_reduction_testbed.h new file mode 100644 index 0000000000000000000000000000000000000000..3163a0d0eaa24513ee210bd2b310d1bf233773a9 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/epilogue/threadblock/epilogue_with_reduction_testbed.h @@ -0,0 +1,417 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + + \brief Unit tests for epilogues +*/ +#pragma once + +#include + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/aligned_buffer.h" +#include "cutlass/half.h" +#include "cutlass/complex.h" + +#include "cutlass/epilogue/thread/linear_combination.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/host/tensor_fill.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace test { +namespace kernel { + +template +__global__ void epilogue_with_reduction_threadblock( + typename Epilogue::ElementVector *ptr_Reduction, + typename Epilogue::OutputTileIterator::Params params_D, + typename Epilogue::OutputTileIterator::Element *ptr_D, + typename Epilogue::OutputTileIterator::Params params_C, + typename Epilogue::OutputTileIterator::Element *ptr_C, + typename Epilogue::TensorTileIterator::Params params_Tensor, + typename Epilogue::TensorTileIterator::Element *ptr_Tensor, + typename Epilogue::OutputOp::Params params_output_op, + cutlass::MatrixCoord problem_size, + cutlass::TensorRef< + typename Epilogue::WarpMmaOperator::ElementC, + typename Epilogue::WarpMmaOperator::LayoutC> accumulator_ref, + int epilogue_count = 1) { + + __shared__ typename Epilogue::SharedStorage shared_storage; + + int thread_idx = threadIdx.x; + int warp_idx = threadIdx.x / 32; + int lane_idx = threadIdx.x % 32; + + // + // Construct the epilogue + // + + // Tile iterator writing to output tile + typename Epilogue::OutputTileIterator iterator_D( + params_D, + ptr_D, + problem_size, + thread_idx + ); + + // Tile iterator writing to output tile + typename Epilogue::OutputTileIterator iterator_C( + params_C, + ptr_C, + problem_size, + thread_idx + ); + + // Tile iterator writing to output tile + typename Epilogue::TensorTileIterator iterator_T( + params_Tensor, + ptr_Tensor, + problem_size, + thread_idx + ); + + // Epilogue operator + Epilogue epilogue( + shared_storage, + thread_idx, + warp_idx, + lane_idx); + + // + // Initialize the accumulators + // + + int warp_mn = warp_idx % (Epilogue::WarpCount::kM * Epilogue::WarpCount::kN); + int warp_m = warp_mn % Epilogue::WarpCount::kM; + int warp_n = warp_mn / Epilogue::WarpCount::kM; + + accumulator_ref.add_coord_offset({ + warp_m * Epilogue::WarpMmaOperator::Shape::kM, + warp_n * Epilogue::WarpMmaOperator::Shape::kN}); + + typename Epilogue::WarpMmaOperator::IteratorC accumulator_iterator(accumulator_ref, lane_idx); + + typename Epilogue::AccumulatorTile accumulators; + + accumulators.clear(); + accumulator_iterator.load(accumulators); + +#if 0 + // For debugging, enable this block of code to fill each accumulator element with its + // source thread ID. 
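+  // (Left disabled by default; flip the surrounding #if to 1 to overwrite each
+  // accumulator fragment element with its owning thread index.)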
+ CUTLASS_PRAGMA_UNROLL + for (size_t i = 0; i < accumulators.size(); ++i) { + typename Epilogue::WarpMmaOperator::ElementC x(threadIdx.x); + accumulators[i] = x; + } + + __syncthreads(); + +#endif + + // + // Perform the epilogue operation + // + + typename Epilogue::OutputOp output_op(params_output_op); + + // Place the epilogue in a loop + for (int iter = 0; iter < epilogue_count; ++iter) { + epilogue(output_op, ptr_Reduction, iterator_D, accumulators, iterator_C, iterator_T); + } +} + +} // namespace kernel +} // namespace test + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Epilogue_ +> +class EpilogueWithReductionTestbed { +public: + + using Epilogue = Epilogue_; + using ElementAccumulator = typename Epilogue::ElementAccumulator; + using ElementCompute = typename Epilogue::OutputOp::ElementCompute; + using ElementTensor = typename Epilogue::TensorTileIterator::Element; + using ElementOutput = typename Epilogue::ElementOutput; + using OutputOpParams = typename Epilogue::OutputOp::Params; + +public: + + // + // Data members + // + + cutlass::MatrixCoord quantized_size; + cutlass::HostTensor accumulator_tensor; + cutlass::HostTensor source_tensor; + cutlass::HostTensor output_tensor; + cutlass::HostTensor additional_tensor; + cutlass::HostTensor reduction_tensor; + + +public: + + // + // Methods + // + + EpilogueWithReductionTestbed(): + quantized_size(Epilogue::Shape::kM, Epilogue::Shape::kN), + accumulator_tensor({Epilogue::Shape::kM, Epilogue::Shape::kN}), + source_tensor({Epilogue::Shape::kM, Epilogue::Shape::kN}), + output_tensor({Epilogue::Shape::kM, Epilogue::Shape::kN}), + additional_tensor({Epilogue::Shape::kM, Epilogue::Shape::kN}), + reduction_tensor({1, Epilogue::Shape::kN}) { + + // + // Initialize problem space + // + + uint64_t seed = 2019; + + cutlass::reference::host::TensorFillRandomUniform( + accumulator_tensor.host_view(), + seed, + 20, + -20, + 0); + + cutlass::reference::host::TensorFillRandomUniform( + source_tensor.host_view(), + seed + 2018, + 20, + -20, + 0); + + cutlass::reference::host::TensorFill(additional_tensor.host_view(), ElementTensor(1)); + } + + bool run_all() { + + /* + double alpha_values[] = {1, 0, 2.25}; + double beta_values[] = {0, 1, -1.25}; + + // Test runtime explodes if we tried to test every case exhaustively. This tests the full + // output tile and several smaller sizes to stress predication. 
+ for (int m_idx = 0; m_idx < 3; ++m_idx) { + for (int n_idx = 0; n_idx < 3; ++n_idx) { + + int m = quantized_size.row() - m_idx * 3; + int n = quantized_size.column() - n_idx * Epilogue::kElementsPerAccess; + + for (double const &alpha : alpha_values) { + for (double const &beta : beta_values) { + + bool passed = run({m, n}, {cutlass::from_real(alpha), cutlass::from_real(beta)}); + + if (!passed) { + return false; + } + } + } + } + } + return true; + */ + + double alpha = 1; + double beta = 0; + + return run( + {quantized_size.row(), quantized_size.column()}, + {cutlass::from_real(alpha), cutlass::from_real(beta)}); + } + + /// Runs the test + bool run( + cutlass::MatrixCoord problem_size, + OutputOpParams output_params) { + + // + // Initialize problem space + // + + ElementOutput default_output = ElementOutput(-127); + ElementAccumulator default_reduction = ElementAccumulator(); + + cutlass::reference::host::TensorFill(output_tensor.host_view(), default_output); + cutlass::reference::host::TensorFill(reduction_tensor.host_view(), default_reduction); + + accumulator_tensor.sync_device(); + output_tensor.sync_device(); + source_tensor.sync_device(); + additional_tensor.sync_device(); + reduction_tensor.sync_device(); + + // + // Initialize epilogue parameters + // + + typename Epilogue::OutputTileIterator::Params params_D(output_tensor.device_ref().layout()); + typename Epilogue::OutputTileIterator::Params params_C(source_tensor.device_ref().layout()); + typename Epilogue::TensorTileIterator::Params params_T(additional_tensor.device_ref().layout()); + + // + // Launch kernel + // + + dim3 grid(1, 1); + dim3 block(Epilogue::WarpCount::kCount * 32, 1); + + test::kernel::epilogue_with_reduction_threadblock<<< grid, block >>>( + reduction_tensor.device_data(), + params_D, + output_tensor.device_data(), + params_C, + source_tensor.device_data(), + params_T, + additional_tensor.device_data(), + output_params, + problem_size, + accumulator_tensor.device_view()); + + cudaError_t result = cudaDeviceSynchronize(); + + if (result != cudaSuccess) { + std::cerr << "Kernel error: " << cudaGetErrorString(result) << std::endl; + return false; + } + + // + // Verify results + // + output_tensor.sync_host(); + reduction_tensor.sync_host(); + + int errors = 0; + int const kMaxErrors = 5; + + // + // The output has two parts: + // - GEMM tensor epilogue in canonical layout + // - partial reduction in canonical row-major layout + // + + // Verify the GEMM tensor output + for (int r = 0; errors < kMaxErrors && r < quantized_size.row(); ++r) { + for (int c = 0; errors < kMaxErrors && c < quantized_size.column(); ++c) { + + cutlass::MatrixCoord coord{r, c}; + ElementOutput got = output_tensor.at(coord); + + ElementOutput expected; + if (coord.row() < problem_size.row() && coord.column() < problem_size.column()) { + + expected = ElementOutput(output_params.alpha * ElementCompute(accumulator_tensor.at(coord)) + + output_params.beta * ElementCompute(source_tensor.at(coord))); + } + else { + expected = default_output; + } + + if (expected != got) { + + using OutputIO = cutlass::ScalarIO; + + EXPECT_TRUE(false) + << "-------\n" + << "Error - output element (" << coord << ") - expected: " + << OutputIO(expected) + << ", got: " << OutputIO(got) << std::endl; + + ++errors; + } + } + } + + // Verify the partial reduction + for (int c = 0; c < quantized_size.column(); ++c) { + + ElementAccumulator reduction_acc = ElementAccumulator(); + + for (int r = 0; r < quantized_size.row(); ++r) { + reduction_acc += 
accumulator_tensor.at({r, c}); + } + + ElementAccumulator expected = default_reduction; + ElementAccumulator got = reduction_tensor.at({0, c}); + + if (c < problem_size.column()) { + expected = reduction_acc; + } + else { + expected = default_reduction; + } + + if (expected != got) { + + using OutputIO = cutlass::ScalarIO; + + EXPECT_TRUE(false) + << "-------\n" + << "Error - reduction element (" << c << ") - expected: " + << OutputIO(expected) + << ", got: " << OutputIO(got) << std::endl; + } + } + + // + // Report results on error + // + + if (errors) { + std::stringstream ss; + ss + << "output_tensor_op_" << Epilogue::Shape::kM << "x" << Epilogue::Shape::kN << "_" + << Epilogue::WarpTileIterator::WarpShape::kM << "x" + << Epilogue::WarpTileIterator::WarpShape::kN + << "_slice_" << Epilogue::WarpCount::kK << ".csv"; + + std::ofstream output_file(ss.str()); + output_file << output_tensor.host_view(); + } + + return !errors; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/epilogue/threadblock/testbed.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/epilogue/threadblock/testbed.h new file mode 100644 index 0000000000000000000000000000000000000000..e2457fdb4817e1dfb3af73149ae1e4c4458670a2 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/epilogue/threadblock/testbed.h @@ -0,0 +1,356 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Unit tests for epilogues +*/ +#pragma once + +#include +#include + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/aligned_buffer.h" +#include "cutlass/half.h" +#include "cutlass/complex.h" +#include "cutlass/quaternion.h" +#include "cutlass/platform/platform.h" +#include "cutlass/epilogue/thread/linear_combination.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/host/tensor_fill.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace test { +namespace kernel { + +template +__global__ void epilogue_threadblock( + typename Epilogue::OutputTileIterator::Params params_D, + typename Epilogue::OutputTileIterator::Element *ptr_D, + typename Epilogue::OutputTileIterator::Params params_C, + typename Epilogue::OutputTileIterator::Element *ptr_C, + typename Epilogue::OutputOp::Params params_output_op, + cutlass::MatrixCoord problem_size, + cutlass::TensorRef< + typename Epilogue::WarpMmaOperator::ElementC, + typename Epilogue::WarpMmaOperator::LayoutC> accumulator_ref, + int epilogue_count = 1) { + + __shared__ typename Epilogue::SharedStorage shared_storage; + + int thread_idx = threadIdx.x; + int warp_idx = threadIdx.x / 32; + int lane_idx = threadIdx.x % 32; + + // + // Construct the epilogue + // + + // Tile iterator writing to output tile + typename Epilogue::OutputTileIterator iterator_D( + params_D, + ptr_D, + problem_size, + thread_idx + ); + + // Tile iterator writing to output tile + typename Epilogue::OutputTileIterator iterator_C( + params_C, + ptr_C, + problem_size, + thread_idx + ); + + // Epilogue operator + Epilogue epilogue( + shared_storage, + thread_idx, + warp_idx, + lane_idx); + + // + // Initialize the accumulators + // + + int warp_mn = warp_idx % (Epilogue::WarpCount::kM * Epilogue::WarpCount::kN); + int warp_m = warp_mn % Epilogue::WarpCount::kM; + int warp_n = warp_mn / Epilogue::WarpCount::kM; + + accumulator_ref.add_coord_offset({ + warp_m * Epilogue::WarpMmaOperator::Shape::kM, + warp_n * Epilogue::WarpMmaOperator::Shape::kN}); + + typename Epilogue::WarpMmaOperator::IteratorC accumulator_iterator(accumulator_ref, lane_idx); + + typename Epilogue::AccumulatorTile accumulators; + + accumulators.clear(); + accumulator_iterator.load(accumulators); + +#if 0 + // For debugging, enable this block of code to fill each accumulator element with its + // source thread ID. 
+ CUTLASS_PRAGMA_UNROLL + for (size_t i = 0; i < accumulators.size(); ++i) { + typename Epilogue::WarpMmaOperator::ElementC x(threadIdx.x); + accumulators[i] = x; + } + + __syncthreads(); + +#endif + + // + // Perform the epilogue operation + // + + typename Epilogue::OutputOp output_op(params_output_op); + + // Place the epilogue in a loop + for (int iter = 0; iter < epilogue_count; ++iter) { + epilogue(output_op, iterator_D, accumulators, iterator_C); + } +} + +} // namespace kernel +} // namespace test + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Epilogue_ +> +class EpilogueTestbed { +public: + + using Epilogue = Epilogue_; + using ElementAccumulator = typename Epilogue::ElementAccumulator; + using ElementCompute = typename Epilogue::OutputOp::ElementCompute; + using ElementOutput = typename Epilogue::ElementOutput; + using OutputOpParams = typename Epilogue::OutputOp::Params; + +public: + + // + // Data members + // + + cutlass::MatrixCoord quantized_size; + cutlass::HostTensor accumulator_tensor; + cutlass::HostTensor source_tensor; + cutlass::HostTensor output_tensor; + +public: + + // + // Methods + // + + EpilogueTestbed(): + quantized_size(Epilogue::Shape::kM, Epilogue::Shape::kN), + accumulator_tensor({Epilogue::Shape::kM, Epilogue::Shape::kN}), + source_tensor({Epilogue::Shape::kM, Epilogue::Shape::kN}), + output_tensor({Epilogue::Shape::kM, Epilogue::Shape::kN}) { + + // + // Initialize problem space + // + + uint64_t seed = 2019; + + cutlass::reference::host::TensorFillRandomUniform( + accumulator_tensor.host_view(), + seed, + 2, + -2, + 0); + + cutlass::reference::host::TensorFillRandomUniform( + source_tensor.host_view(), + seed + 2018, + 2, + -2, + 0); + } + + bool run_all() { + + double alpha_values[] = {1, 0, 2.25}; + double beta_values[] = {0, 1, -1.25}; + + // Test runtime explodes if we tried to test every case exhaustively. This tests the full + // output tile and several smaller sizes to stress predication. 
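    // As a concrete illustration (assuming, for example, a 64x64 output tile and
    // kElementsPerAccess == 8), the sweep below covers rows {64, 61, 58} and columns
    // {64, 56, 48}: the full tile plus smaller sizes that force predicated (partial)
    // accesses at the tile edges. Together with the 3x3 alpha/beta values above, that
    // is 81 runs per epilogue configuration instead of an exhaustive sweep.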
+ for (int m_idx = 0; m_idx < 3; ++m_idx) { + for (int n_idx = 0; n_idx < 3; ++n_idx) { + + int m = quantized_size.row() - m_idx * 3; + int n = quantized_size.column() - n_idx * Epilogue::kElementsPerAccess; + + for (double const &alpha : alpha_values) { + for (double const &beta : beta_values) { + + bool passed = run({m, n}, {cutlass::from_real(alpha), cutlass::from_real(beta)}); + + if (!passed) { + return false; + } + } + } + } + } + + return true; + } + + /// Runs the test + bool run( + cutlass::MatrixCoord problem_size, + OutputOpParams output_params) { + + // + // Initialize problem space + // + + ElementOutput default_output = ElementOutput(-127); + cutlass::reference::host::TensorFill(output_tensor.host_view(), default_output); + + accumulator_tensor.sync_device(); + output_tensor.sync_device(); + source_tensor.sync_device(); + + // + // Initialize epilogue parameters + // + + typename Epilogue::OutputTileIterator::Params params_D(output_tensor.device_ref().layout()); + typename Epilogue::OutputTileIterator::Params params_C(source_tensor.device_ref().layout()); + + // + // Launch kernel + // + + dim3 grid(1, 1); + dim3 block(Epilogue::WarpCount::kCount * 32, 1); + + test::kernel::epilogue_threadblock<<< grid, block >>>( + params_D, + output_tensor.device_data(), + params_C, + source_tensor.device_data(), + output_params, + problem_size, + accumulator_tensor.device_view()); + + cudaError_t result = cudaDeviceSynchronize(); + + if (result != cudaSuccess) { + std::cerr << "Kernel error: " << cudaGetErrorString(result) << std::endl; + return false; + } + + // + // Verify results + // + output_tensor.sync_host(); + + int errors = 0; + int const kMaxErrors = 5; + + for (int r = 0; errors < kMaxErrors && r < quantized_size.row(); ++r) { + for (int c = 0; errors < kMaxErrors && c < quantized_size.column(); ++c) { + + cutlass::MatrixCoord coord{r, c}; + ElementOutput got = output_tensor.at(coord); + + ElementOutput expected; + if (coord.row() < problem_size.row() && coord.column() < problem_size.column()) { + ElementCompute intermediate = + output_params.alpha * ElementCompute(accumulator_tensor.at(coord)) + + output_params.beta * ElementCompute(source_tensor.at(coord)); + + if ((cutlass::platform::is_same::value + || cutlass::platform::is_same::value + || std::numeric_limits::is_integer) + && !std::numeric_limits::is_integer) { + std::fesetround(FE_TONEAREST); + expected = ElementOutput(std::nearbyint(float(cutlass::real(intermediate)))); + } else { + expected = ElementOutput(intermediate); + } + } else { + expected = default_output; + } + + if (expected != got) { + + using OutputIO = cutlass::ScalarIO; + + EXPECT_TRUE(false) + << "-------\n" + << "Error - output element (" << coord << ") - expected: " + << OutputIO(expected) + << ", got: " << OutputIO(got) + << ", accum: " << (accumulator_tensor.at(coord)) + << ", source: " << OutputIO(source_tensor.at(coord)) + << ", alpha: " << (output_params.alpha) + << ", beta: " << (output_params.beta) << "\n"; + + ++errors; + } + } + } + + // + // Report results on error + // + + if (errors) { + std::stringstream ss; + ss + << "output_tensor_op_" << Epilogue::Shape::kM << "x" << Epilogue::Shape::kN << "_" + << Epilogue::WarpTileIterator::WarpShape::kM << "x" + << Epilogue::WarpTileIterator::WarpShape::kN + << "_slice_" << Epilogue::WarpCount::kK << ".csv"; + + std::ofstream output_file(ss.str()); + output_file << output_tensor.host_view(); + } + + return !errors; + } +}; + 
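As a standalone illustration of the element-wise reference check that run() performs above, the short host-only sketch below (independent of this header and of CUTLASS; the concrete alpha, beta, and operand values are chosen only for illustration) computes alpha * accum + beta * source and, for an integer output element such as int8_t, rounds the result to the nearest integer under FE_TONEAREST, which resolves ties to even - the same convention the testbed applies before comparing against the device output.

#include <cfenv>
#include <cmath>
#include <cstdint>
#include <cstdio>

int main() {
  // Illustrative values; the testbed fills accumulators and sources with small integers in [-2, 2].
  float alpha  = 1.25f;
  float beta   = 0.0f;
  float accum  = 2.0f;
  float source = 1.0f;

  float intermediate = alpha * accum + beta * source;   // 2.5f

  // Integer outputs are rounded to nearest, ties to even (2.5 -> 2), mirroring the check above.
  std::fesetround(FE_TONEAREST);
  int8_t expected = static_cast<int8_t>(std::nearbyint(intermediate));

  std::printf("intermediate = %f, expected output = %d\n", intermediate, int(expected));
  return 0;
}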
+///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/epilogue/threadblock/testbed_planar_complex.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/epilogue/threadblock/testbed_planar_complex.h new file mode 100644 index 0000000000000000000000000000000000000000..a76578f7638ac1d30161a9bcb55ecec70b5c43e0 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/epilogue/threadblock/testbed_planar_complex.h @@ -0,0 +1,394 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Unit tests for epilogues +*/ +#pragma once + +#include + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/aligned_buffer.h" +#include "cutlass/half.h" +#include "cutlass/complex.h" + +#include "cutlass/epilogue/thread/linear_combination_planar_complex.h" + +#include "cutlass/util/host_tensor_planar_complex.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/host/tensor_fill.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace test { +namespace kernel { + +template +__global__ void epilogue_planar_complex_threadblock( + typename Epilogue::OutputTileIterator::Params params_D, + typename Epilogue::OutputTileIterator::Element *ptr_D, + int64_t imaginary_stride_D, + typename Epilogue::OutputTileIterator::Params params_C, + typename Epilogue::OutputTileIterator::Element *ptr_C, + int64_t imaginary_stride_C, + typename Epilogue::OutputOp::Params params_output_op, + cutlass::MatrixCoord problem_size, + cutlass::TensorRef< + typename Epilogue::WarpMmaOperator::ElementC, + typename Epilogue::WarpMmaOperator::LayoutC> accumulator_ref, + int64_t imaginary_stride_accum, + int epilogue_count = 1) { + + __shared__ typename Epilogue::SharedStorage shared_storage; + + int thread_idx = threadIdx.x; + int warp_idx = threadIdx.x / 32; + int lane_idx = threadIdx.x % 32; + + // + // Construct the epilogue + // + + // Tile iterator writing to output tile + typename Epilogue::OutputTileIterator iterator_D_real( + params_D, + ptr_D, + problem_size, + thread_idx + ); + + typename Epilogue::OutputTileIterator iterator_D_imag( + params_D, + ptr_D + imaginary_stride_D, + problem_size, + thread_idx + ); + + // Tile iterator writing to output tile + typename Epilogue::OutputTileIterator iterator_C_real( + params_C, + ptr_C, + problem_size, + thread_idx + ); + + typename Epilogue::OutputTileIterator iterator_C_imag( + params_C, + ptr_C + imaginary_stride_C, + problem_size, + thread_idx + ); + + // Epilogue operator + Epilogue epilogue( + shared_storage, + thread_idx, + warp_idx, + lane_idx); + + // + // Initialize the accumulators + // + + int warp_mn = warp_idx % (Epilogue::WarpCount::kM * Epilogue::WarpCount::kN); + int warp_m = warp_mn % Epilogue::WarpCount::kM; + int warp_n = warp_mn / Epilogue::WarpCount::kM; + + accumulator_ref.add_coord_offset({ + warp_m * Epilogue::WarpMmaOperator::Shape::kM, + warp_n * Epilogue::WarpMmaOperator::Shape::kN}); + + // + // Load accumulators + // + + typename Epilogue::WarpMmaOperator::IteratorC accumulator_iterator(accumulator_ref, lane_idx); + + typename Epilogue::AccumulatorTile accumulators; + + accumulators.clear(); + + accumulator_iterator.load(accumulators.real); + accumulator_iterator.load_with_pointer_offset(accumulators.imag, imaginary_stride_accum); + + // + // Perform the epilogue operation + // + + typename Epilogue::OutputOp output_op(params_output_op); + + // Place the epilogue in a loop so assembly is clearly visible + for (int iter = 0; iter < epilogue_count; ++iter) { + epilogue( + output_op, + iterator_D_real, + iterator_D_imag, + accumulators, + iterator_C_real, + iterator_C_imag); + } +} + +} // namespace kernel +} // namespace test + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Epilogue_ +> +class EpiloguePlanarComplexTestbed { +public: + + using Epilogue = Epilogue_; + using ElementAccumulator = typename Epilogue::ElementAccumulator; + using ElementCompute = 
typename Epilogue::OutputOp::ElementCompute; + using ElementOutput = typename Epilogue::ElementOutput; + using OutputOpParams = typename Epilogue::OutputOp::Params; + + using ComplexElementOutput = cutlass::complex; + using ComplexElementAccumulator = cutlass::complex; + using ComplexElementCompute = cutlass::complex; + +public: + + // + // Data members + // + + cutlass::MatrixCoord quantized_size; + cutlass::HostTensorPlanarComplex accumulator_tensor; + cutlass::HostTensorPlanarComplex source_tensor; + cutlass::HostTensorPlanarComplex output_tensor; + +public: + + // + // Methods + // + + EpiloguePlanarComplexTestbed(): + quantized_size(Epilogue::Shape::kM, Epilogue::Shape::kN), + accumulator_tensor({Epilogue::Shape::kM, Epilogue::Shape::kN}), + source_tensor({Epilogue::Shape::kM, Epilogue::Shape::kN}), + output_tensor({Epilogue::Shape::kM, Epilogue::Shape::kN}) { + + // + // Initialize problem space + // + + #if 1 + uint64_t seed = 2019; + + cutlass::reference::host::TensorFillRandomUniform( + accumulator_tensor.host_view(), + seed, + 20, + -20, + 0); + + cutlass::reference::host::TensorFillRandomUniform( + source_tensor.host_view(), + seed + 2018, + 20, + -20, + 0); + #else + + cutlass::reference::host::BlockFillSequential(accumulator_tensor.host_data(), accumulator_tensor.capacity()); + + #endif + } + + bool run_all() { + + cutlass::complex alpha_values[3]; + + alpha_values[0] = cutlass::complex(1, 0); + alpha_values[1] = cutlass::complex(0, 0); + alpha_values[2] = cutlass::complex(2.25f, -0.5f); + + cutlass::complex beta_values[3]; + + beta_values[0] = cutlass::complex(0, 0); + beta_values[1] = cutlass::complex(1, 0); + beta_values[2] = cutlass::complex(0.5f, -2.25f); + + // Test runtime explodes if we tried to test every case exhaustively. This tests the full + // output tile and several smaller sizes to stress predication. 
+ for (int m_idx = 0; m_idx < 3; ++m_idx) { + for (int n_idx = 0; n_idx < 3; ++n_idx) { + + cutlass::MatrixCoord problem_size( + quantized_size.row() - m_idx * 3, + quantized_size.column() - n_idx * Epilogue::kElementsPerAccess + ); + + for (auto const &alpha : alpha_values) { + for (auto const &beta : beta_values) { + + bool passed = run(problem_size, {alpha, beta}); + + if (!passed) { + return false; + } + } + } + } + } + + return true; + } + + /// Runs the test + bool run( + cutlass::MatrixCoord problem_size, + OutputOpParams output_params) { + + // + // Initialize problem space + // + + ComplexElementOutput default_output = ComplexElementOutput(ElementOutput(-127), ElementOutput(-101)); + + cutlass::reference::host::TensorFill(output_tensor.host_view(), default_output); + + accumulator_tensor.sync_device(); + output_tensor.sync_device(); + source_tensor.sync_device(); + + // + // Initialize epilogue parameters + // + + typename Epilogue::OutputTileIterator::Params params_D(output_tensor.layout()); + typename Epilogue::OutputTileIterator::Params params_C(source_tensor.layout()); + + // + // Launch kernel + // + + dim3 grid(1, 1); + dim3 block(Epilogue::WarpCount::kCount * 32, 1); + + test::kernel::epilogue_planar_complex_threadblock<<< grid, block >>>( + params_D, + output_tensor.device_data(), + output_tensor.imaginary_stride(), + params_C, + source_tensor.device_data(), + source_tensor.imaginary_stride(), + output_params, + problem_size, + accumulator_tensor.device_view_real(), + accumulator_tensor.imaginary_stride() + ); + + cudaError_t result = cudaDeviceSynchronize(); + + if (result != cudaSuccess) { + std::cerr << "Kernel error: " << cudaGetErrorString(result) << std::endl; + return false; + } + + // + // Verify results + // + output_tensor.sync_host(); + + int errors = 0; + int const kMaxErrors = 5; + + for (int r = 0; errors < kMaxErrors && r < quantized_size.row(); ++r) { + for (int c = 0; errors < kMaxErrors && c < quantized_size.column(); ++c) { + + cutlass::MatrixCoord coord{r, c}; + ComplexElementOutput got = output_tensor.at(coord); + + ComplexElementOutput expected = default_output; + + if (coord.row() < problem_size.row() && coord.column() < problem_size.column()) { + + ComplexElementOutput src = source_tensor.at(coord); + + ComplexElementCompute tmp = + output_params.alpha * ComplexElementCompute(accumulator_tensor.at(coord)) + + output_params.beta * ComplexElementCompute(src.real(), src.imag()); + + expected = ComplexElementOutput(ElementOutput(tmp.real()), ElementOutput(tmp.imag())); + } + + if (expected != got) { + + using OutputIO = cutlass::ScalarIO; + + EXPECT_TRUE(false) + << "-------\n" + << "Error - output element (" << coord << ") - expected: " + << OutputIO(expected) + << ", got: " << OutputIO(got) << std::endl; + + ++errors; + } + } + } + + // + // Report results on error + // + + if (errors) { + + + std::cout << "Incorrect result for problem(" + << problem_size.row() << ", " + << problem_size.column() << ") for alpha: " << output_params.alpha << ", beta: " << output_params.beta << std::endl; + + std::stringstream ss; + ss + << "output_tensor_op_" << Epilogue::Shape::kM << "x" << Epilogue::Shape::kN << "_" + << Epilogue::WarpTileIterator::WarpShape::kM << "x" + << Epilogue::WarpTileIterator::WarpShape::kN + << "_slice_" << Epilogue::WarpCount::kK << ".csv"; + + std::ofstream output_file(ss.str()); + output_file << output_tensor.host_view(); + + std::cout << "Wrote workspace to '" << ss.str() << "'" << std::endl; + } + + return !errors; + } +}; + 
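To make the planar-complex storage assumed by the epilogue above concrete, the following host-only sketch (independent of this header; plain std::complex stands in for cutlass::complex, and the buffer sizes and alpha/beta values are illustrative) keeps the real and imaginary parts of each tensor in two planes separated by an imaginary stride and applies the same reference math the testbed verifies, D = alpha * accum + beta * C, in complex arithmetic.

#include <complex>
#include <cstdio>
#include <vector>

int main() {
  int const num_elements = 4;                    // elements per plane
  int const imaginary_stride = num_elements;     // real plane first, then imaginary plane

  std::vector<float> accum  = {1, 2, 3, 4,   -1, 0, 1, 2};   // [real plane | imag plane]
  std::vector<float> source(2 * num_elements, 0.5f);
  std::vector<float> output(2 * num_elements, 0.0f);

  std::complex<float> alpha(2.25f, -0.5f);
  std::complex<float> beta(0.5f, -2.25f);

  for (int i = 0; i < num_elements; ++i) {
    std::complex<float> acc(accum[i],  accum[i + imaginary_stride]);
    std::complex<float> src(source[i], source[i + imaginary_stride]);

    std::complex<float> d = alpha * acc + beta * src;

    output[i]                     = d.real();
    output[i + imaginary_stride]  = d.imag();
  }

  std::printf("D[0] = (%f, %f)\n", output[0], output[imaginary_stride]);
  return 0;
}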
+///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/default_gemm_configuration.hpp b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/default_gemm_configuration.hpp new file mode 100644 index 0000000000000000000000000000000000000000..0054a1b6757a232e9177407fdd2041b6a91cffb9 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/default_gemm_configuration.hpp @@ -0,0 +1,1384 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include "cute/atom/mma_atom.hpp" +#include "cute/atom/copy_atom.hpp" + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/arch/arch.h" +#include "cutlass/arch/mma.h" +#include "cutlass/layout/layout.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/collective_mma.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" + +namespace cutlass { +namespace gemm { +namespace device { +using namespace cute; + +// This type is only intended to demonstrate porting 2.x kernels to 3.0 +template< + class OperatorClass, class ArchTag, + class ElementA, class LayoutA, + class ElementB, class LayoutB, + class ElementC, class LayoutC, + class ElementAccumulator> +struct DefaultGemmConfigurationToCutlass3Types { + static_assert(sizeof(ElementA) == 0, "No valid DefaultGemmConfigurationToCutlass3Types configuration exists."); +}; + +/////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template +struct DefaultGemm_TensorOpSm80_OperandA; + +template +struct DefaultGemm_TensorOpSm80_OperandB; + +// +// F16: 128-by-128-by-64 +// + +/// Operand A - Row-major (K-Major) +template <> +struct DefaultGemm_TensorOpSm80_OperandA +{ + // Smem + using SmemLayoutAtom = decltype( + composition(Swizzle<3,3,3>{}, + Layout, + Stride<_64, _1>>{})); + using SmemCopyAtom = Copy_Atom; + + // Gmem + using GmemTiledCopy = decltype( + make_tiled_copy(Copy_Atom, half_t>{}, + Layout, + Stride< _8,_1>>{}, + Layout>{})); +}; + +/// Operand A - Column-major (M-major) +template +struct DefaultGemm_TensorOpSm80_OperandA +{ + // Smem + using SmemLayoutAtom = decltype( + composition(Swizzle<3,3,3>{}, + Layout, + Stride< _1,_64>>{})); + using SmemCopyAtom = Copy_Atom; + + // Gmem + using GmemTiledCopy = decltype( + make_tiled_copy(Copy_Atom, half_t>{}, + Layout, + Stride< _1,_16>>{}, + Layout>{})); +}; + +// Because the F32F16 TiledMMA is A-B symmetric, we can reuse the DefaultOperands + +// Operand B - Column-Major (K-major) +template +struct DefaultGemm_TensorOpSm80_OperandB + : DefaultGemm_TensorOpSm80_OperandA +{}; + +// Operand B - Row-Major (N-major) +template +struct DefaultGemm_TensorOpSm80_OperandB + : DefaultGemm_TensorOpSm80_OperandA +{}; + +// +// F16: 128-by-128-by-32 (small k-block) +// + +/// Operand A - Row-major (K-Major) +template <> +struct DefaultGemm_TensorOpSm80_OperandA +{ + // Smem + using SmemLayoutAtom = decltype( + composition(Swizzle<2,3,3>{}, + Layout, + Stride<_32, _1>>{})); + using SmemCopyAtom = Copy_Atom; + + // Gmem + using GmemTiledCopy = decltype( + make_tiled_copy(Copy_Atom, half_t>{}, + Layout, + Stride< _4,_1>>{}, + Layout>{})); +}; + +} + +/////////////////////////////////////////////////////////////////////////////// + +// Ampere MMA F32F16 +template +struct DefaultGemmConfigurationToCutlass3Types< + arch::OpClassTensorOp, arch::Sm80, + half_t, LayoutA, + half_t, LayoutB, + float, LayoutC, + float> +{ + using TileShape = Shape<_128, _128, _32>; + static constexpr int ThreadCount = 128; + using DispatchPolicy = MainloopSm80CpAsync<3>; + using TiledMma = TiledMMA< + MMA_Atom, + Layout>, // 2x2x1 thread group + Tile<_32,_32,_16>>; // 32x32x16 MMA for LDSM, 1x2x1 value group + + // A + static constexpr int kAlignmentA = 8; + using 
DefaultOperandA = detail::DefaultGemm_TensorOpSm80_OperandA< + half_t, LayoutA, kAlignmentA, 32>; + using SmemLayoutAtomA = typename DefaultOperandA::SmemLayoutAtom; // M, K + using SmemCopyAtomA = typename DefaultOperandA::SmemCopyAtom; + using GmemTiledCopyA = typename DefaultOperandA::GmemTiledCopy; + + // B + static constexpr int kAlignmentB = 8; + using DefaultOperandB = detail::DefaultGemm_TensorOpSm80_OperandB< + half_t, LayoutB, kAlignmentB, 32>; + using SmemLayoutAtomB = typename DefaultOperandB::SmemLayoutAtom; // N, K + using SmemCopyAtomB = typename DefaultOperandB::SmemCopyAtom; + using GmemTiledCopyB = typename DefaultOperandB::GmemTiledCopy; + + // Mainloop + using CollectiveMainloop = collective::CollectiveMma< + DispatchPolicy, TileShape, + half_t, TagToStrideA_t, + half_t, TagToStrideB_t, + TiledMma, + GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A + GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B + >; + + // Epilogue + using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< + float, + TagToStrideC_t, + TagToStrideC_t, + epilogue::thread::LinearCombination, + cutlass::gemm::EpilogueDefault>; +}; + +/////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +// +// TF32: 128-by-128-by-kblock (kBlock = 16, 32) +// + +/// Operand A - Row-major (K-major) (kBlock = 32) +template <> +struct DefaultGemm_TensorOpSm80_OperandA +{ + // Smem + using SmemLayoutAtom = decltype( + composition(Swizzle<3,2,3>{}, + Layout, + Stride<_32, _1>>{})); + using SmemCopyAtom = Copy_Atom; + + // Gmem + using GmemTiledCopy = decltype( + make_tiled_copy(Copy_Atom, tfloat32_t>{}, + Layout, + Stride< _8,_1>>{}, + Layout>{})); +}; + +/// Operand A - Row-major (K-major) (kBlock = 16) +template <> +struct DefaultGemm_TensorOpSm80_OperandA +{ + // Smem + using SmemLayoutAtom = decltype( + composition(Swizzle<2,2,3>{}, + Layout, + Stride<_16, _1>>{})); + using SmemCopyAtom = Copy_Atom; + // Gmem + using GmemTiledCopy = decltype( + make_tiled_copy(Copy_Atom, tfloat32_t>{}, + Layout, + Stride< _4,_1>>{}, + Layout>{})); +}; + +/// Operand A - Column-major (M-major) +template +struct DefaultGemm_TensorOpSm80_OperandA +{ + // Smem + using SmemLayoutAtom = decltype( + composition(Swizzle<3,2,3>{}, + Layout, + Stride< _1,_32>>{})); + using SmemCopyAtom = Copy_Atom, tfloat32_t>; + // Gmem + using GmemTiledCopy = decltype( + make_tiled_copy(Copy_Atom, tfloat32_t>{}, + Layout, + Stride< _1,_16>>{}, + Layout>{})); +}; + +// Because the TF32 TiledMMA is A-B symmetric, we can reuse the DefaultOperands + +// Operand B - Column-Major (K-major) +template +struct DefaultGemm_TensorOpSm80_OperandB + : DefaultGemm_TensorOpSm80_OperandA +{}; + +// Operand B - Row-Major (N-major) +template +struct DefaultGemm_TensorOpSm80_OperandB + : DefaultGemm_TensorOpSm80_OperandA +{}; + +} + +/////////////////////////////////////////////////////////////////////////////// + +// Ampere MMA F32TF32 +template +struct DefaultGemmConfigurationToCutlass3Types< + arch::OpClassTensorOp, arch::Sm80, + tfloat32_t, LayoutA, + tfloat32_t, LayoutB, + float, LayoutC, + float> +{ + using TileShape = Shape<_128, _128, _32>; + static constexpr int ThreadCount = 128; + using DispatchPolicy = MainloopSm80CpAsync<3>; + using TiledMma = TiledMMA< + MMA_Atom, + Layout, Stride<_2, _1, _1>>, // 2x2x1 thread group + Tile<_32,_32,_8>>; // 32x32x8 MMA for LDSM, 1x2x1 value group + + // A + static constexpr int kAlignmentA = 4; + using DefaultOperandA = 
detail::DefaultGemm_TensorOpSm80_OperandA< + tfloat32_t, LayoutA, kAlignmentA, 32>; + using SmemLayoutAtomA = typename DefaultOperandA::SmemLayoutAtom; // M, K + using SmemCopyAtomA = typename DefaultOperandA::SmemCopyAtom; + using GmemTiledCopyA = typename DefaultOperandA::GmemTiledCopy; + + // B + static constexpr int kAlignmentB = 4; + using DefaultOperandB = detail::DefaultGemm_TensorOpSm80_OperandB< + tfloat32_t, LayoutB, kAlignmentB, 32>; + using SmemLayoutAtomB = typename DefaultOperandB::SmemLayoutAtom; // N, K + using SmemCopyAtomB = typename DefaultOperandB::SmemCopyAtom; + using GmemTiledCopyB = typename DefaultOperandB::GmemTiledCopy; + + // Mainloop + using CollectiveMainloop = collective::CollectiveMma< + DispatchPolicy, TileShape, + tfloat32_t, TagToStrideA_t, + tfloat32_t, TagToStrideB_t, + TiledMma, + GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A + GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B + >; + + // Epilogue + using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< + float, + TagToStrideC_t, + TagToStrideC_t, + epilogue::thread::LinearCombination, + cutlass::gemm::EpilogueDefault>; +}; + +/////////////////////////////////////////////////////////////////////////////// +template +struct DefaultGemmConfigurationToCutlass3Types< + arch::OpClassTensorOp, arch::Sm80, + int8_t, cutlass::layout::RowMajor, + int8_t, cutlass::layout::ColumnMajor, + int32_t, LayoutC, + int32_t> +{ + using TileShape = Shape<_128, _128, _64>; + static constexpr int ThreadCount = 128; + using DispatchPolicy = MainloopSm80CpAsync<3>; + using TiledMma = TiledMMA< + MMA_Atom, + Layout>, // 2x2x1 thread group + Tile<_32,_32,_32>>; // 16x16x32 MMA for LDSM, 1x2x1 value group + + // A (M,K) K-major + using SmemLayoutAtomA = decltype( + composition( + Swizzle<2,4,3>{}, + Layout, + Stride<_64, _1>>{})); + static constexpr int kAlignmentA = 16; + using GmemTiledCopyA = decltype( + make_tiled_copy(Copy_Atom, int8_t>{}, + Layout, + Stride< _4,_1>>{}, + Layout>>{})); + // LDS.32- or LDSM-based copy atom + // using SmemCopyAtomA = Copy_Atom; + using SmemCopyAtomA = Copy_Atom; // LDSM works + + // B (N,K) K-major + using SmemLayoutAtomB = decltype( + composition( + Swizzle<2,4,3>{}, + Layout, + Stride<_64, _1>>{})); + static constexpr int kAlignmentB = 16; + using GmemTiledCopyB = decltype( + make_tiled_copy(Copy_Atom, int8_t>{}, + Layout, + Stride< _4,_1>>{}, + Layout>>{})); + + // LDS.32- or LDSM-based copy atom + // using SmemCopyAtomB = Copy_Atom; + using SmemCopyAtomB = Copy_Atom; // LDSM works + + // Mainloop + using CollectiveMainloop = collective::CollectiveMma< + DispatchPolicy, TileShape, + int8_t, TagToStrideA_t, + int8_t, TagToStrideB_t, + TiledMma, + GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A + GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B + >; + + using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< + int32_t, + TagToStrideC_t, + TagToStrideC_t, + epilogue::thread::LinearCombination, + cutlass::gemm::EpilogueDefault>; +}; + +/////////////////////////////////////////////////////////////////////////////// +//////////////////////////// SIMT TWO STAGE /////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template +struct DefaultGemm_Simt_OperandA; + +/////////////////////////////////////////////////////////////////////////////// + +template +struct DefaultGemm_Simt_OperandA +{ + using SmemLayoutAtom = 
Layout, + Stride< _1,_128>>; + + using SmemCopyAtom = Copy_Atom; + + using GmemTiledCopy = decltype( + make_tiled_copy(Copy_Atom, Element>{}, + Layout, + Stride< _1,_32>>{}, + Layout>{})); +}; + +template +struct DefaultGemm_Simt_OperandA +{ + using SmemLayoutAtom = Layout, + Stride< _1,Int<128 + 4>>>; // Padded + + using SmemCopyAtom = Copy_Atom; + + using GmemTiledCopy = decltype( + make_tiled_copy(Copy_Atom, Element>{}, + Layout, + Stride< _8, _1>>{}, + Layout>{})); + +}; + +template +struct DefaultGemm_Simt_OperandB; + +template +struct DefaultGemm_Simt_OperandB + : DefaultGemm_Simt_OperandA {}; + +template +struct DefaultGemm_Simt_OperandB + : DefaultGemm_Simt_OperandA {}; + +} // end namespace detail + +// SIMT Two Stage +template < + class ArchTag, + class ElementA, class LayoutA, + class ElementB, class LayoutB, + class ElementC, class LayoutC, + class ElementAccumulator> +struct DefaultGemmConfigurationToCutlass3Types< + arch::OpClassSimt, ArchTag, + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementAccumulator> +{ + using TileShape = Shape<_128, _128, _8>; + static constexpr int ThreadCount = 256; + using DispatchPolicy = MainloopSm70TwoStage; + using TiledMma = TiledMMA< + MMA_Atom>, + Layout>>; + + // A + static constexpr int kAlignmentA = 1; + using DefaultOperandA = detail::DefaultGemm_Simt_OperandA; + using SmemLayoutAtomA = typename DefaultOperandA::SmemLayoutAtom; + using SmemCopyAtomA = typename DefaultOperandA::SmemCopyAtom; + using GmemTiledCopyA = typename DefaultOperandA::GmemTiledCopy; + + // B + static constexpr int kAlignmentB = 1; + using DefaultOperandB = detail::DefaultGemm_Simt_OperandB; + using SmemLayoutAtomB = typename DefaultOperandB::SmemLayoutAtom; + using SmemCopyAtomB = typename DefaultOperandB::SmemCopyAtom; + using GmemTiledCopyB = typename DefaultOperandB::GmemTiledCopy; + + // Mainloop + using CollectiveMainloop = collective::CollectiveMma< + DispatchPolicy, TileShape, + ElementA, TagToStrideA_t, + ElementB, TagToStrideB_t, + TiledMma, + GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A + GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B + >; + + // Epilogue + using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< + ElementC, + TagToStrideC_t, + TagToStrideC_t, + epilogue::thread::LinearCombination, + cutlass::gemm::EpilogueDefault>; +}; + + +// +// DP4A - int8 Proof-of-concept +// + +// SIMT Two Stage TN - idp4a +template < + class ArchTag, + class ElementC, class LayoutC> +struct DefaultGemmConfigurationToCutlass3Types< + arch::OpClassSimt, ArchTag, + int8_t, cutlass::layout::RowMajor, + int8_t, cutlass::layout::ColumnMajor, + ElementC, LayoutC, + int32_t> +{ + using TileShape = Shape<_128, _128, _32>; + static constexpr int ThreadCount = 256; + using DispatchPolicy = MainloopSm70TwoStage; + // NOTE: permuting MMA M mode lets us generate 128b smem loads (LDS.128) but has worst case bank conflicts + using TiledMma = TiledMMA< + MMA_Atom, + Layout>>; // Tile of atoms (threads) + + // A (M,K) K-major + using ElementA = int8_t; + // 40% from regular M and N major layout + // using SmemLayoutAtomA = Layout, + // Stride< _1,_128>>; + // 80% from interleaved layouts + using SmemLayoutAtomA = Layout>, + Stride< _4, Stride<_1,_512>>>; + + using SmemCopyAtomA = Copy_Atom; + static constexpr int kAlignmentA = 4; + using GmemTiledCopyA = decltype( + make_tiled_copy(Copy_Atom, ElementA>{}, + Layout, + Stride< _8,_1>>{}, + Layout>{})); + + // B (N,K) K-major + using ElementB = int8_t; + // 40% 
from regular M and N major layout + // using SmemLayoutAtomB = Layout, + // Stride< _1,_128>>; + // 80% from interleaved layouts + using SmemLayoutAtomB = Layout>, + Stride< _4, Stride<_1,_512>>>; + + using SmemCopyAtomB = Copy_Atom; + static constexpr int kAlignmentB = 4; + using GmemTiledCopyB = decltype( + make_tiled_copy(Copy_Atom, ElementB>{}, + Layout, + Stride< _8,_1>>{}, + Layout>{})); + + // Mainloop + using CollectiveMainloop = collective::CollectiveMma< + DispatchPolicy, TileShape, + ElementA, TagToStrideA_t, + ElementB, TagToStrideB_t, + TiledMma, + GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A + GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B + >; + + // Epilogue + using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< + ElementC, + TagToStrideC_t, + TagToStrideC_t, + epilogue::thread::LinearCombination, + cutlass::gemm::EpilogueDefault>; +}; + +/////////////////////////////////////////////////////////////////////////////// + +// SIMT Two Stage NN - idp4a +template < + class ArchTag, + class ElementC, class LayoutC> +struct DefaultGemmConfigurationToCutlass3Types< + arch::OpClassSimt, ArchTag, + int8_t, cutlass::layout::ColumnMajor, + int8_t, cutlass::layout::ColumnMajor, + ElementC, LayoutC, + int32_t> +{ + using TileShape = Shape<_128, _128, _32>; + static constexpr int ThreadCount = 256; + + using DispatchPolicy = MainloopSm70TwoStage; + + using TiledMma = TiledMMA< + MMA_Atom, + Layout>>; + + // A (M,K) M-major + using ElementA = int8_t; + using SmemLayoutAtomA = Layout>, + Stride< _4, Stride<_1,_512>>>; + using SmemCopyAtomA = Copy_Atom; + static constexpr int kAlignmentA = 1; + using GmemTiledCopyA = decltype( + make_tiled_copy(Copy_Atom, ElementA>{}, + Layout, + Stride< _1,_32>>{}, + Layout>{})); + + // B (N,K) K-major + using ElementB = int8_t; + using SmemLayoutAtomB = Layout>, + Stride< _4, Stride<_1,_512>>>; + using SmemCopyAtomB = Copy_Atom; + static constexpr int kAlignmentB = 4; + using GmemTiledCopyB = decltype( + make_tiled_copy(Copy_Atom, ElementB>{}, + Layout, + Stride< _8,_1>>{}, + Layout>{})); + + // Mainloop + using CollectiveMainloop = collective::CollectiveMma< + DispatchPolicy, TileShape, + ElementA, TagToStrideA_t, + ElementB, TagToStrideB_t, + TiledMma, + GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A + GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B + >; + + // Epilogue + using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< + ElementC, + TagToStrideC_t, + TagToStrideC_t, + epilogue::thread::LinearCombination, + cutlass::gemm::EpilogueDefault>; +}; + +/////////////////////////////////////////////////////////////////////////////// + +// SIMT Two Stage NT - idp4a +template < + class ArchTag, + class ElementC, class LayoutC> +struct DefaultGemmConfigurationToCutlass3Types< + arch::OpClassSimt, ArchTag, + int8_t, cutlass::layout::ColumnMajor, + int8_t, cutlass::layout::RowMajor, + ElementC, LayoutC, + int32_t> +{ + using TileShape = Shape<_128, _128, _32>; + static constexpr int ThreadCount = 256; + using DispatchPolicy = MainloopSm70TwoStage; + using TiledMma = TiledMMA< + MMA_Atom, + Layout>>; + + // A (M,K) M-major + using ElementA = int8_t; + using SmemLayoutAtomA = Layout>, + Stride< _4, Stride<_1,_512>>>; + using SmemCopyAtomA = Copy_Atom; + static constexpr int kAlignmentA = 1; + using GmemTiledCopyA = decltype( + make_tiled_copy(Copy_Atom, ElementA>{}, + Layout, + Stride< _1,_32>>{}, + Layout>{})); + + // B (N,K) N-major + using ElementB = 
int8_t; + using SmemLayoutAtomB = Layout>, + Stride< _4, Stride<_1,_512>>>; + using SmemCopyAtomB = Copy_Atom; + static constexpr int kAlignmentB = 1; + using GmemTiledCopyB = decltype( + make_tiled_copy(Copy_Atom, ElementB>{}, + Layout, + Stride< _1,_32>>{}, + Layout>{})); + + // Mainloop + using CollectiveMainloop = collective::CollectiveMma< + DispatchPolicy, TileShape, + ElementA, TagToStrideA_t, + ElementB, TagToStrideB_t, + TiledMma, + GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A + GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B + >; + + // Epilogue + using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< + ElementC, + TagToStrideC_t, + TagToStrideC_t, + epilogue::thread::LinearCombination, + cutlass::gemm::EpilogueDefault>; +}; + +/////////////////////////////////////////////////////////////////////////////// + +// SIMT Two Stage TT - idp4a +template < + class ArchTag, + class ElementC, class LayoutC> +struct DefaultGemmConfigurationToCutlass3Types< + arch::OpClassSimt, ArchTag, + int8_t, cutlass::layout::RowMajor, + int8_t, cutlass::layout::RowMajor, + ElementC, LayoutC, + int32_t> +{ + using TileShape = Shape<_128, _128, _32>; + static constexpr int ThreadCount = 256; + using DispatchPolicy = MainloopSm70TwoStage; + using TiledMma = TiledMMA< + MMA_Atom, + Layout>>; + + // A (M,K) K-major + using ElementA = int8_t; + using SmemLayoutAtomA = Layout>, + Stride< _4, Stride<_1,_512>>>; + using SmemCopyAtomA = Copy_Atom; + static constexpr int kAlignmentA = 4; + using GmemTiledCopyA = decltype( + make_tiled_copy(Copy_Atom, ElementA>{}, + Layout, + Stride< _8,_1>>{}, + Layout>{})); + + // B (N,K) N-major + using ElementB = int8_t; + using SmemLayoutAtomB = Layout>, + Stride< _4, Stride<_1,_512>>>; + using SmemCopyAtomB = Copy_Atom; + static constexpr int kAlignmentB = 1; + using GmemTiledCopyB = decltype( + make_tiled_copy(Copy_Atom, ElementB>{}, + Layout, + Stride< _1,_32>>{}, + Layout>{})); + + // Mainloop + using CollectiveMainloop = collective::CollectiveMma< + DispatchPolicy, TileShape, + ElementA, TagToStrideA_t, + ElementB, TagToStrideB_t, + TiledMma, + GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A + GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B + >; + + // Epilogue + using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< + ElementC, + TagToStrideC_t, + TagToStrideC_t, + epilogue::thread::LinearCombination, + cutlass::gemm::EpilogueDefault>; +}; + +/////////////////////////////////////////////////////////////////////////////// +/////////////////////////// SIMT MULTI STAGE ////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +// SIMT Multi Stage NT +template < + class ElementA, + class ElementB, + class ElementC, class LayoutC, + class ElementAccumulator> +struct DefaultGemmConfigurationToCutlass3Types< + arch::OpClassSimt, arch::Sm80, + ElementA, cutlass::layout::ColumnMajor, + ElementB, cutlass::layout::RowMajor, + ElementC, LayoutC, + ElementAccumulator> +{ + using TileShape = Shape<_128, _128, _16>; + static constexpr int ThreadCount = 256; + using DispatchPolicy = MainloopSm80CpAsync<3>; + using TiledMma = TiledMMA< + MMA_Atom>, + Layout>, // 16x16x1 thread group + Tile,Stride<_2,_1>>, // 32x32x1 MMA with perm for load vectorization + Layout,Stride<_2,_1>>,Underscore>>; + + // A (M,K) M-major + using SmemLayoutAtomA = Layout>; + using SmemCopyAtomA = Copy_Atom; + static constexpr int kAlignmentA = 2; + 
using AlignmentTypeA = cute::uint_byte_t(sizeof(ElementA)) * kAlignmentA>; + using GmemTiledCopyA = decltype( + make_tiled_copy(Copy_Atom, ElementA>{}, + Layout>{}, + Layout>{})); + + // B (N,K) N-major + using SmemLayoutAtomB = Layout>; + using SmemCopyAtomB = Copy_Atom; + static constexpr int kAlignmentB = 2; + using AlignmentTypeB = cute::uint_byte_t(sizeof(ElementB)) * kAlignmentB>; + using GmemTiledCopyB = decltype( + make_tiled_copy(Copy_Atom, ElementB>{}, + Layout>{}, + Layout>{})); + + // Mainloop + using CollectiveMainloop = collective::CollectiveMma< + DispatchPolicy, TileShape, + ElementA, TagToStrideA_t, + ElementB, TagToStrideB_t, + TiledMma, + GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A + GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B + >; + + // Epilogue + using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< + ElementC, + TagToStrideC_t, + TagToStrideC_t, + epilogue::thread::LinearCombination, + cutlass::gemm::EpilogueDefault>; +}; + +/////////////////////////////////////////////////////////////////////////////// + +// SIMT Multi Stage TN +template < + class ElementA, + class ElementB, + class ElementC, class LayoutC, + class ElementAccumulator> +struct DefaultGemmConfigurationToCutlass3Types< + arch::OpClassSimt, arch::Sm80, + ElementA, cutlass::layout::RowMajor, + ElementB, cutlass::layout::ColumnMajor, + ElementC, LayoutC, + ElementAccumulator> +{ + using TileShape = Shape<_128, _128, _16>; + static constexpr int ThreadCount = 256; + using DispatchPolicy = MainloopSm80CpAsync<3>; + using TiledMma = TiledMMA< + MMA_Atom>, + Layout>>; + + // A (M,K) K-major + using SmemLayoutAtomA = Layout, + Stride< _1, Int<128 + 1>>>; // Padded by kAlignmentA + using SmemCopyAtomA = Copy_Atom; + static constexpr int kAlignmentA = 1; + using GmemTiledCopyA = decltype( + make_tiled_copy(Copy_Atom, ElementA>{}, + Layout, + Stride<_16, _1>>{})); + + // B (N,K) K-major + using SmemLayoutAtomB = Layout, + Stride< _1, Int<128 + 1>>>; // Padded by kAlignmentB + using SmemCopyAtomB = Copy_Atom; + static constexpr int kAlignmentB = 1; + using GmemTiledCopyB = decltype( + make_tiled_copy(Copy_Atom, ElementB>{}, + Layout, + Stride<_16, _1>>{})); + + // Mainloop + using CollectiveMainloop = collective::CollectiveMma< + DispatchPolicy, TileShape, + ElementA, TagToStrideA_t, + ElementB, TagToStrideB_t, + TiledMma, + GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A + GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B + >; + + // Epilogue + using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< + ElementC, + TagToStrideC_t, + TagToStrideC_t, + epilogue::thread::LinearCombination, + cutlass::gemm::EpilogueDefault>; +}; + +/////////////////////////////////////////////////////////////////////////////// + +// SIMT Multi Stage NN +template < + class ElementA, + class ElementB, + class ElementC, class LayoutC, + class ElementAccumulator> +struct DefaultGemmConfigurationToCutlass3Types< + arch::OpClassSimt, arch::Sm80, + ElementA, cutlass::layout::ColumnMajor, + ElementB, cutlass::layout::ColumnMajor, + ElementC, LayoutC, + ElementAccumulator> +{ + using TileShape = Shape<_128, _128, _16>; + static constexpr int ThreadCount = 256; + using DispatchPolicy = MainloopSm80CpAsync<3>; + using TiledMma = TiledMMA< + MMA_Atom>, + Layout>, // 16x16x1 thread group + Tile,Stride<_2,_1>>,Underscore,Underscore>>; // 32x16x1 MMA with perm for load vectorization + + // A (M,K) M-major + using SmemLayoutAtomA = Layout>; + 
using SmemCopyAtomA = Copy_Atom; + static constexpr int kAlignmentA = 2; + using AlignmentTypeA = cute::uint_byte_t(sizeof(ElementA)) * kAlignmentA>; + using GmemTiledCopyA = decltype( + make_tiled_copy(Copy_Atom, ElementA>{}, + Layout>{}, + Layout>{})); + + // B (N,K) K-major + using SmemLayoutAtomB = Layout, + Stride< _1, Int<128 + 1>>>; // Padded by kAlignmentB + using SmemCopyAtomB = Copy_Atom; + static constexpr int kAlignmentB = 1; + using GmemTiledCopyB = decltype( + make_tiled_copy(Copy_Atom, ElementB>{}, + Layout, + Stride<_16, _1>>{})); + + // Mainloop + using CollectiveMainloop = collective::CollectiveMma< + DispatchPolicy, TileShape, + ElementA, TagToStrideA_t, + ElementB, TagToStrideB_t, + TiledMma, + GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A + GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B + >; + + // Epilogue + using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< + ElementC, + TagToStrideC_t, + TagToStrideC_t, + epilogue::thread::LinearCombination, + cutlass::gemm::EpilogueDefault>; +}; + +/////////////////////////////////////////////////////////////////////////////// + +// SIMT Multi Stage TT +template < + class ElementA, + class ElementB, + class ElementC, class LayoutC, + class ElementAccumulator> +struct DefaultGemmConfigurationToCutlass3Types< + arch::OpClassSimt, arch::Sm80, + ElementA, cutlass::layout::RowMajor, + ElementB, cutlass::layout::RowMajor, + ElementC, LayoutC, + ElementAccumulator> +{ + using TileShape = Shape<_128, _128, _16>; + static constexpr int ThreadCount = 256; + using DispatchPolicy = MainloopSm80CpAsync<3>; + using TiledMma = TiledMMA< + MMA_Atom>, + Layout>, // 16x16x1 thread group + Tile,Stride<_2,_1>>,Underscore>>; // 16x32x1 MMA with perm for load vectorization + + // A (M,K) K-major + using SmemLayoutAtomA = Layout, + Stride< _1, Int<128 + 1>>>; // Padded by kAlignmentA + using SmemCopyAtomA = Copy_Atom; + static constexpr int kAlignmentA = 1; + using GmemTiledCopyA = decltype( + make_tiled_copy(Copy_Atom, ElementA>{}, + Layout, + Stride<_16, _1>>{})); + + // B (N,K) N-major + using SmemLayoutAtomB = Layout>; + using SmemCopyAtomB = Copy_Atom; + static constexpr int kAlignmentB = 2; + using AlignmentTypeB = cute::uint_byte_t(sizeof(ElementB)) * kAlignmentB>; + using GmemTiledCopyB = decltype( + make_tiled_copy(Copy_Atom, ElementB>{}, + Layout>{}, + Layout>{})); + + // Mainloop + using CollectiveMainloop = collective::CollectiveMma< + DispatchPolicy, TileShape, + ElementA, TagToStrideA_t, + ElementB, TagToStrideB_t, + TiledMma, + GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A + GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B + >; + + // Epilogue + using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< + ElementC, + TagToStrideC_t, + TagToStrideC_t, + epilogue::thread::LinearCombination, + cutlass::gemm::EpilogueDefault>; +}; + +/////////////////////////////////////////////////////////////////////////////// + +// Ampere fp64 MMA TN (K-Major A and K-Major B) +template <> +struct DefaultGemmConfigurationToCutlass3Types< + arch::OpClassTensorOp, arch::Sm80, + double, cutlass::layout::RowMajor, + double, cutlass::layout::ColumnMajor, + double, cutlass::layout::ColumnMajor, + double> +{ + using TileShape = Shape<_128, _64, _16>; + static constexpr int ThreadCount = 128; + using DispatchPolicy = MainloopSm80CpAsync<3>; + using TiledMma = TiledMMA< + MMA_Atom, // Atom + Layout>, // Atom layout + Tile,Stride<_2,_1>>, // 32x32x4 MMA with 
perm for load vectorization + Layout,Stride<_2,_1>>, + Underscore>>; + + // A (M,K) K-Major + using SmemLayoutAtomA = decltype( + composition(Swizzle<2,0,4>{}, + Layout, + Stride<_1, _4>>{})); // M, K + using SmemCopyAtomA = Copy_Atom; + static constexpr int kAlignmentA = 1; + using GmemTiledCopyA = decltype( + make_tiled_copy(Copy_Atom, double>{}, // CopyAtom + Layout, + Stride<_16, _1>>{}, // ThrLayout for CopyAtom + Layout>{})); // Value layout: 1x1 doubles + + // B (N,K) K-Major + using SmemLayoutAtomB = decltype( + composition(Swizzle<2,0,4>{}, + Layout, + Stride<_1, _4>>{})); // N, K + using SmemCopyAtomB = Copy_Atom; + static constexpr int kAlignmentB = 1; + using GmemTiledCopyB = decltype( + make_tiled_copy(Copy_Atom, double>{}, // CopyAtom + Layout, + Stride<_16, _1>>{}, // ThrLayout for CopyAtom + Layout>{})); // Value layout: 1x1 doubles + + // Mainloop + using CollectiveMainloop = collective::CollectiveMma< + DispatchPolicy, TileShape, + double, TagToStrideA_t, + double, TagToStrideB_t, + TiledMma, + GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A + GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B + >; + + // Epilogue + using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< + double, + TagToStrideC_t, + TagToStrideC_t, + epilogue::thread::LinearCombination, + cutlass::gemm::EpilogueDefault>; + +/* + using EpilogueOutputOp = epilogue::collective::Epilogue< + epilogue::thread::LinearCombination, + Layout, + Stride< _1,_64>>, // SMEM layout + Copy_Atom,double>, // R2S with tiled_mma layout + decltype(make_tiled_copy(Copy_Atom,double>{},// S2R + Layout, + Stride< _1,_16>>{}, // Thread layout + Layout>{})), // Value layout + Copy_Atom,double> // R2G with S2R_dst layout + >; +*/ +}; + +/////////////////////////////////////////////////////////////////////////////// + +// Ampere fp64 MMA NN (M-Major A and K-Major B) +template <> +struct DefaultGemmConfigurationToCutlass3Types< + arch::OpClassTensorOp, arch::Sm80, + double, cutlass::layout::ColumnMajor, + double, cutlass::layout::ColumnMajor, + double, cutlass::layout::ColumnMajor, + double> +{ + using TileShape = Shape<_128, _64, _16>; + static constexpr int ThreadCount = 128; + using DispatchPolicy = MainloopSm80CpAsync<3>; + using TiledMma = TiledMMA< + MMA_Atom, // Atom + Layout>, // Atom layout + Tile,Stride<_2,_1>>, // 32x32x4 MMA with perm for load vectorization + Layout,Stride<_2,_1>>, + Underscore>>; + + // A (M,K) M-Major + using SmemLayoutAtomA = decltype( + composition(Swizzle<2,2,2>{}, + Layout, + Stride< _1,_16>>{})); // M, K + using SmemCopyAtomA = Copy_Atom; + static constexpr int kAlignmentA = 2; + using GmemTiledCopyA = decltype( + make_tiled_copy(Copy_Atom, double>{}, // CopyAtom + Layout, + Stride< _1,_16>>{}, // ThrLayout for CopyAtom + Layout>{})); // Value layout: 2x1 doubles + + // B (N,K) K-Major + using SmemLayoutAtomB = decltype( + composition(Swizzle<2,0,4>{}, + Layout, + Stride<_1, _4>>{}));// N, K + using SmemCopyAtomB = Copy_Atom; + static constexpr int kAlignmentB = 1; + using GmemTiledCopyB = decltype( + make_tiled_copy(Copy_Atom, double>{}, // CopyAtom + Layout, + Stride<_16, _1>>{}, // ThrLayout for CopyAtom + Layout>{})); // Value layout: 1x1 doubles + + // Mainloop + using CollectiveMainloop = collective::CollectiveMma< + DispatchPolicy, TileShape, + double, TagToStrideA_t, + double, TagToStrideB_t, + TiledMma, + GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A + GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B 
+ >; + + // Epilogue + using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< + double, + TagToStrideC_t, + TagToStrideC_t, + epilogue::thread::LinearCombination, + cutlass::gemm::EpilogueDefault>; +}; + +/////////////////////////////////////////////////////////////////////////////// + +// Ampere fp64 MMA NT (M-Major A and N-Major B) +template <> +struct DefaultGemmConfigurationToCutlass3Types< + arch::OpClassTensorOp, arch::Sm80, + double, cutlass::layout::ColumnMajor, + double, cutlass::layout::RowMajor, + double, cutlass::layout::ColumnMajor, + double> +{ + using TileShape = Shape<_128, _64, _16>; + static constexpr int ThreadCount = 128; + using DispatchPolicy = MainloopSm80CpAsync<3>; + using TiledMma = TiledMMA< + MMA_Atom, // Atom + Layout>, // Atom layout + Tile,Stride<_2,_1>>, // 32x32x4 MMA with perm for load vectorization + Layout,Stride<_2,_1>>, + Underscore>>; + + // A (M,K) M-Major + using SmemLayoutAtomA = decltype( + composition(Swizzle<2,2,2>{}, + Layout, + Stride< _1,_16>>{})); // M, K + using SmemCopyAtomA = Copy_Atom; + static constexpr int kAlignmentA = 2; + using GmemTiledCopyA = decltype( + make_tiled_copy(Copy_Atom, double>{}, // CopyAtom + Layout, + Stride< _1,_16>>{}, // ThrLayout for CopyAtom + Layout>{})); // Value layout: 2x1 doubles + + // B (N,K) N-Major + using SmemLayoutAtomB = decltype( + composition(Swizzle<2,2,2>{}, + Layout, + Stride< _1,_16>>{})); // N, K + using SmemCopyAtomB = Copy_Atom; + static constexpr int kAlignmentB = 2; + using GmemTiledCopyB = decltype( + make_tiled_copy(Copy_Atom, double>{}, // CopyAtom + Layout, + Stride< _1,_16>>{}, // ThrLayout for CopyAtom + Layout>{})); // Value layout: 2x1 doubles + + // Mainloop + using CollectiveMainloop = collective::CollectiveMma< + DispatchPolicy, TileShape, + double, TagToStrideA_t, + double, TagToStrideB_t, + TiledMma, + GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A + GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B + >; + + // Epilogue + using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< + double, + TagToStrideC_t, + TagToStrideC_t, + epilogue::thread::LinearCombination, + cutlass::gemm::EpilogueDefault>; +}; + +/////////////////////////////////////////////////////////////////////////////// + +// Ampere fp64 MMA TT (K-Major A and N-Major B) +template <> +struct DefaultGemmConfigurationToCutlass3Types< + arch::OpClassTensorOp, arch::Sm80, + double, cutlass::layout::RowMajor, + double, cutlass::layout::RowMajor, + double, cutlass::layout::ColumnMajor, + double> +{ + using TileShape = Shape<_128, _64, _16>; + static constexpr int ThreadCount = 128; + using DispatchPolicy = MainloopSm80CpAsync<3>; + using TiledMma = TiledMMA< + MMA_Atom, // Atom + Layout>, // Atom layout + Tile,Stride<_2,_1>>, // 32x32x4 MMA with perm for load vectorization + Layout,Stride<_2,_1>>, + Underscore>>; + + // A (M,K) K-Major + using SmemLayoutAtomA = decltype( + composition(Swizzle<2,0,4>{}, + Layout, + Stride<_1, _4>>{})); // M, K + using SmemCopyAtomA = Copy_Atom; + static constexpr int kAlignmentA = 1; + using GmemTiledCopyA = decltype( + make_tiled_copy(Copy_Atom, double>{}, // CopyAtom + Layout, + Stride<_16, _1>>{}, // ThrLayout for CopyAtom + Layout>{})); // Value layout: 1x1 doubles + + // B (N,K) N-Major + using SmemLayoutAtomB = decltype( + composition(Swizzle<2,2,2>{}, + Layout, + Stride< _1,_16>>{})); // N, K + using SmemCopyAtomB = Copy_Atom; + static constexpr int kAlignmentB = 2; + using GmemTiledCopyB = decltype( + 
make_tiled_copy(Copy_Atom, double>{}, // CopyAtom + Layout, + Stride< _1,_16>>{}, // ThrLayout for CopyAtom + Layout>{})); // Value layout: 2x1 doubles + + // Mainloop + using CollectiveMainloop = collective::CollectiveMma< + DispatchPolicy, TileShape, + double, TagToStrideA_t, + double, TagToStrideB_t, + TiledMma, + GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A + GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B + >; + + // Epilogue + using CollectiveEpilogue = epilogue::collective::DefaultEpilogue< + double, + TagToStrideC_t, + TagToStrideC_t, + epilogue::thread::LinearCombination, + cutlass::gemm::EpilogueDefault>; +}; + +/////////////////////////////////////////////////////////////////////////////// + +// Hopper fp64 MMA TN +template <> +struct DefaultGemmConfigurationToCutlass3Types< + arch::OpClassTensorOp, arch::Sm90, + double, cutlass::layout::RowMajor, + double, cutlass::layout::ColumnMajor, + double, cutlass::layout::ColumnMajor, + double> +{ + using TileShape = Shape<_128, _64, _16>; + static constexpr int ThreadCount = 128; + using DispatchPolicy = MainloopSm80CpAsync<3>; + using TiledMma = TiledMMA< + MMA_Atom, + Layout>>; + + // A (M,K) K-major + using SmemLayoutAtomA = decltype( + make_ordered_layout(Shape<_128,_16>{}, + Step < _2, _1>{})); // M, K + using SmemCopyAtomA = Copy_Atom; + static constexpr int kAlignmentA = 2; + using GmemTiledCopyA = decltype( + make_tiled_copy(Copy_Atom, double>{}, + Layout, + Stride< _8,_1>>{}, + Layout>{})); + + // B (N,K) K-major + using SmemLayoutAtomB = decltype( + make_ordered_layout(Shape<_64,_16>{}, + Step < _2, _1>{})); // N, K + using SmemCopyAtomB = Copy_Atom; + static constexpr int kAlignmentB = 2; + using GmemTiledCopyB = decltype( + make_tiled_copy(Copy_Atom, double>{}, + Layout, + Stride< _8,_1>>{}, + Layout>{})); + + // Mainloop + using CollectiveMainloop = collective::CollectiveMma< + DispatchPolicy, TileShape, + double, TagToStrideA_t, + double, TagToStrideB_t, + TiledMma, + GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity, // A + GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity // B + >; + + // Epilogue + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + double, double, + double, cutlass::layout::ColumnMajor, 1, + double, cutlass::layout::ColumnMajor, 1, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace cutlass diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/gemm_testbed_3x.hpp b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/gemm_testbed_3x.hpp new file mode 100644 index 0000000000000000000000000000000000000000..89755dd7d3162b114a537e58c6aa33cac80078f9 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/gemm_testbed_3x.hpp @@ -0,0 +1,3993 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#pragma once + +#include +#include +#include +#include +#include +#include // std::lcm + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/gett.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/fusion/operations.hpp" +#include "cutlass/complex.h" +#include "cutlass/transform/device/transform_universal_adapter.hpp" +#include "cutlass/transform/kernel/sparse_gemm_compressor.hpp" +#include "cutlass/detail/collective.hpp" + +#include "testbed_utils.h" + +#include "cutlass/kernel_hardware_info.hpp" +#include "cutlass/layout/matrix.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/gemm/gemm.h" + +#include "cute/int_tuple.hpp" +#include "cute/layout.hpp" +#include "cute/numeric/int.hpp" + +namespace test { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +enum class ScalarLoc { + ON_HOST = 0, + ON_DEVICE = 1 +}; + +enum class VectorScale { + DISABLED = 0, + ENABLED = 1 +}; + +enum class CheckEquality { + EXACT = 0, + RELATIVE = 1 +}; + +namespace detail { + +inline constexpr auto decomp_mode_to_string = + [] (cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode mode) -> std::string { + using Mode = cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode; + if (mode == 
Mode::Heuristic) { + return "Heuristic"; + } + else if (mode == Mode::DataParallel) { + return "DataParallel"; + } + else if (mode == Mode::SplitK) { + return "SplitK"; + } + else if (mode == Mode::StreamK) { + return "StreamK"; + } + else { + return "Unknown"; + } + }; + +inline constexpr auto raster_order_to_string = + [] (cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90Params::RasterOrderOptions mode) -> std::string { + using Mode = cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90Params::RasterOrderOptions; + if (mode == Mode::Heuristic) { + return "Heuristic"; + } + else if (mode == Mode::AlongM) { + return "AlongM"; + } + else if (mode == Mode::AlongN) { + return "AlongN"; + } + else { + return "Unknown"; + } + }; + +// Helper classes that take default data type when +// the Gemm::EpilogueOutputOp does not have ElementCompute +// and ElementScalar. +// (e.g. when Sm90TreeVisitor is used as FusionCallbacks) +template +struct ElementComputeType { + using Type = Default; +}; + +template +struct ElementComputeType>> { + using Type = typename Gemm::EpilogueOutputOp::ElementCompute; +}; + +template +struct ElementScalarType { + using Type = Default; +}; + +template +struct ElementScalarType>> { + using Type = typename Gemm::EpilogueOutputOp::ElementScalar; +}; + + +template +struct IsF8F6F4Kernel { + static constexpr bool value = false; +}; + +template +struct IsF8F6F4Kernel> { + static constexpr bool value = true; +}; + + +template +struct IsSfdEpi : cute::false_type {}; + +template +struct IsSfdEpi> : cute::true_type {}; + +// The maximum swizzle size to use +// +// This class, like Splits above makes it harder to confuse +// the order of arguments of the various run(...) functions in this file. +class MaxSwizzleSize { +public: + MaxSwizzleSize() = default; + + template && + !cute::is_same_v)) > + explicit MaxSwizzleSize(IntegralNotBool max_swizzle_size) : max_swizzle_size_(max_swizzle_size) {} + explicit operator int() const { return max_swizzle_size_; } +private: + int max_swizzle_size_ = 1; +}; + +template +auto make_iterator(T* ptr) { + return cute::recast_ptr(ptr); +} + +template +struct IsDefaultEpilogue { + static constexpr bool value = false; +}; + +template +struct IsDefaultEpilogue> { + static constexpr bool value = true; +}; + +template +struct IsDefaultEpilogue> { + static constexpr bool value = true; +}; + +template +struct IsLegacyEpiloguePolicy { + static constexpr bool value = false; +}; + +template +struct IsLegacyEpiloguePolicy> { + using EpiloguePolicy = typename Epilogue::DispatchPolicy; + static constexpr bool value = cute::is_same_v< + EpiloguePolicy, + cutlass::epilogue::Sm90TmaWarpSpecializedBiasElementwise< + EpiloguePolicy::StagesC, EpiloguePolicy::StagesD, EpiloguePolicy::FragmentSize>>; +}; + +// The number of splits to test. +// +// This class makes it harder to confuse the order of arguments +// of the various run(...) functions in this file. The constructor +// is explicit, so one can't just type 42 (or false, which the +// compiler unhelpfully turns into 0); one has to type Splits(42). +// Splits() picks the default number of splits, 1. +// +// The conversion-to-int operator (operator int()) MUST be explicit! +// Conversion to int MUST require static_cast. +// Otherwise, that defeats a key purpose of this class, +// which is to catch common errors of confusing the order +// of function arguments. 
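+// (Editorial note, illustrative sketch only -- not part of the original header.)
+// Assuming a hypothetical test helper run(Splits, Iterations) with that argument
+// order, the explicit constructors and explicit operator int() described above turn
+// a misuse into a compile-time error instead of a silently reordered call:
+//
+//   run(Splits(4), Iterations(10));         // OK: each argument names its meaning
+//   run(4, 10);                             // error: the constructors are explicit
+//   run(Iterations(10), Splits(4));         // error: no implicit cross-conversion
+//   int s = static_cast<int>(Splits(4));    // reading the value back requires static_cast
+//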
+class Splits { +public: + Splits() = default; + + template && + !cute::is_same_v)) > + explicit Splits(IntegralNotBool splits) : splits_(splits) {} + explicit operator int() const { return splits_; } +private: + int splits_ = 1; +}; + +// The number of iterations to test. +// +// This class, like Splits above makes it harder to confuse +// the order of arguments of the various run(...) functions in this file. +// Iterations() picks the default number of iterations, 20. +class Iterations { +public: + Iterations() = default; + + template && + !cute::is_same_v)) > + explicit Iterations(IntegralNotBool iterations) : iterations_(iterations) {} + explicit operator int() const { return iterations_; } +private: + int iterations_ = 20; +}; + +template +bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } + + else if (bits_input <= 6) { + scope_max = 2; + scope_min = -2; + } + + else if (bits_input <= 8) { + + if constexpr ( + cute::is_same_v){ + scope_max = 4; + scope_min = 1; + } + else { + + scope_max = 1; + scope_min = -1; + + } + + } + else{ + scope_max = 4; + scope_min = -4; + } + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, 0); + } + + else if (dist_kind == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(view); + } + + else if (dist_kind == cutlass::Distribution::Gaussian) { + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + + else if (dist_kind == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential( + view.data(), view.capacity()); + } + + else if (dist_kind == cutlass::Distribution::AllOnes) { + cutlass::reference::host::TensorFill(view, Element(1)); + } + + else { + EXPECT_TRUE(false) << "Not implemented"; + return false; + } + + return true; +} + +// Looks at Cute Stride to check Row / Column Major +template +static constexpr bool is_row_or_col_major(){ + int stride_0 = int(cute::size<0>(Stride{})); + int stride_1 = int(cute::size<1>(Stride{})); + int depth = cute::depth(Stride{}); + return ((stride_0 == 1) || (stride_1 == 1)) && (depth == 1); +} + + +// +// Default MMA input Operands : A , B +// +template< + class ScheduleType_, + class Gemm, + class ElementA_ = typename Gemm::GemmKernel::ElementA, + class ElementB_ = typename Gemm::GemmKernel::ElementB, + class Enable = void> +struct HostCollectiveMainloop { + // Kernel data types + using ElementA = ElementA_; + using StrideA = typename Gemm::GemmKernel::StrideA; + using ElementB = ElementB_; + using StrideB = typename Gemm::GemmKernel::StrideB; + using ScheduleType = typename Gemm::GemmKernel::CollectiveMainloop::DispatchPolicy::Schedule; + using LayoutTagA = cutlass::detail::StrideToLayoutTagA_t; + using LayoutTagB = cutlass::detail::StrideToLayoutTagB_t; + + using ElementAccumulator = typename Gemm::GemmKernel::ElementAccumulator; + using ElementScalingFactor = ElementAccumulator; + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + using EpilogueOutputOp = typename Gemm::EpilogueOutputOp; + + using Arguments = typename Gemm::GemmKernel::MainloopArguments; + + cutlass::ComplexTransform TransformA = Gemm::kTransformA; + cutlass::ComplexTransform TransformB = Gemm::kTransformB; + + StrideA stride_a; + StrideB stride_b; + + 
typename LayoutTagA::Stride stride_factor_A; + typename LayoutTagB::Stride stride_factor_B; + + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_B; + // Whether to use relative equality checks + CheckEquality check_relative_equality = CheckEquality::EXACT; + + uint64_t seed; + static constexpr uint64_t kDefaultSeed = 4096; + + // Note: this limitation comes from testbed / not the library + static_assert(is_row_or_col_major(), + "ERROR : A Layout is neither Row / Column Major)"); + static_assert(is_row_or_col_major(), + "ERROR : B Layout is neither Row / Column Major)"); + + HostCollectiveMainloop( + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + uint64_t seed_ = kDefaultSeed, + typename LayoutTagA::Stride stride_factor_A_ = typename LayoutTagA::Stride(), + typename LayoutTagB::Stride stride_factor_B_ = typename LayoutTagB::Stride() + ): + stride_factor_A(stride_factor_A_), + stride_factor_B(stride_factor_B_), + init_A(init_A_), init_B(init_B_), seed(seed_), + check_relative_equality(check_relative_equality_) { } + + template + bool initialize(ProblemShapeType problem_size) { +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("HostCollectiveMainloop (generic)::initialize(problem_shape)"); +#endif + // + // Allocate the GEMM workspace + // + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto M = cute::size<0>(problem_shape_MNKL); + auto N = cute::size<1>(problem_shape_MNKL); + auto K = cute::size<2>(problem_shape_MNKL); + auto L = cute::size<3>(problem_shape_MNKL); + + stride_a = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); + stride_b = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)); + + // 2.x host tensor does not natively contain a batch stride or coord, so we spoof if by folding it into the outer mode + auto a_coord = cutlass::make_Coord(M * L, K); + // Cutlass has Row/Col major refers to MxK times KxN matrix product, + // so the HostTensorB should be treated as KxN in "coord"'s view + auto b_coord = cutlass::make_Coord(K, N * L); + + try { +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("HostCollectiveMainloop::initialize: tensor_A.resize"); +#endif + tensor_A.resize(a_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(a_coord, stride_factor_A)); +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("HostCollectiveMainloop::initialize: tensor_B.resize"); +#endif + tensor_B.resize(b_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(b_coord, stride_factor_B)); + } + catch (std::exception const& e) { + CUTLASS_TRACE_HOST("HostCollectiveMainloop::initialize: tensor A or B resize threw an exception: " << e.what()); + throw; + } + catch (...) 
{ + CUTLASS_TRACE_HOST("HostCollectiveMainloop::initialize: tensor A or B resize threw an unknown exception"); + throw; + } + + try { + EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2022)); + EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2021)); + } + catch (cutlass::cuda_exception const& e) { + CUTLASS_TRACE_HOST("HostCollectiveMainloop::initialize: checked initialize_tensor threw cutlass::cuda_exception: " << e); + throw; + } + catch (std::exception const& e) { + CUTLASS_TRACE_HOST("HostCollectiveMainloop::initialize: checked initialize_tensor threw an exception: " << e.what()); + throw; + } + catch (...) { + CUTLASS_TRACE_HOST("HostCollectiveMainloop::initialize: checked_initialize_tensor threw an unknown exception"); + throw; + } + + // It is possible to randomly initialize to all zeros, so override this with non-zeros + // in the upper left corner of each operand. + tensor_A.host_view().at({0, 0}) = ElementA(1); + tensor_B.host_view().at({0, 0}) = ElementB(1); + +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + { + CUTLASS_TRACE_HOST("HostCollectiveMainloop::initialize: Check last error before sync_device()"); + cudaError_t error = cudaGetLastError(); + const auto error_str = cudaGetErrorString(error); + CUTLASS_TRACE_HOST("HostCollectiveMainloop::initialize: cudaGetLastError() is " << error_str); + CUTLASS_TRACE_HOST("HostCollectiveMainloop::initialize: tensor_A.host_data()=" << tensor_A.host_data() << ", tensor_A.device_data()=" << tensor_A.device_data()); + CUTLASS_TRACE_HOST("HostCollectiveMainloop::initialize: tensor_B.host_data()=" << tensor_B.host_data() << ", tensor_B.device_data()=" << tensor_B.device_data()); + } +#endif + try { +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("HostCollectiveMainloop::initialize: tensor_A.sync_device"); +#endif + tensor_A.sync_device(); +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("HostCollectiveMainloop::initialize: tensor_B.sync_device"); +#endif + tensor_B.sync_device(); + } + catch (cutlass::cuda_exception const& e) { + CUTLASS_TRACE_HOST("HostCollectiveMainloop::initialize: sync_device() threw cutlass::cuda_exception: " << e); + throw; + } + catch (std::exception const& e) { + CUTLASS_TRACE_HOST("HostCollectiveMainloop::initialize: sync_device() threw an exception: " << e.what()); + throw; + } + catch (...) 
{ + CUTLASS_TRACE_HOST("HostCollectiveMainloop::initialize: sync_device() threw an unknown exception"); + throw; + } + +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("HostCollectiveMainloop::initialize: Reached end"); +#endif + return true; + } + + Arguments to_args() { + + + // Runtime datatype selection + if constexpr (not cute::is_same_v) { + using ArrayElementA = typename Gemm::GemmKernel::CollectiveMainloop::ArrayElementA; + using ArrayElementB = typename Gemm::GemmKernel::CollectiveMainloop::ArrayElementB; + return { + reinterpret_cast(tensor_A.device_data()), stride_a, + reinterpret_cast(tensor_B.device_data()), stride_b + }; + } + else { + + Arguments arguments = + { + tensor_A.device_data(), stride_a, tensor_B.device_data(), stride_b + }; + return arguments; + } + } + + auto to_host_args(ProblemShapeType problem_size) { + using namespace cute; + // + // Allocate the GEMM workspace + // + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto M = cute::size<0>(problem_shape_MNKL); + auto N = cute::size<1>(problem_shape_MNKL); + auto K = cute::size<2>(problem_shape_MNKL); + auto L = cute::size<3>(problem_shape_MNKL); + auto A = make_tensor(make_iterator(tensor_A.host_data()), + make_layout(make_shape(M, K, L), stride_a)); + auto B = make_tensor(make_iterator(tensor_B.host_data()), + make_layout(make_shape(N, K, L), stride_b)); + + + auto dummy_SFA = cute::make_tensor(static_cast(nullptr), + cute::make_layout(cute::make_shape(M, K, L), stride_a)); + auto dummy_SFB = cute::make_tensor(static_cast(nullptr), + cute::make_layout(cute::make_shape(N, K, L), stride_b)); + + cutlass::reference::host::GettMainloopParams mainloop_params{}; + + mainloop_params.A = A; + mainloop_params.B = B; + mainloop_params.transform_A = TransformA; + mainloop_params.transform_B = TransformB; + + return mainloop_params; + } + + void print_tensors(std::ofstream& file) { + file << "A =\n" << tensor_A.host_view() + << "\nB =\n" << tensor_B.host_view(); + } + + template < + class Element, + class Layout + > + bool equality_check( + cutlass::TensorView const& lhs, + cutlass::TensorView const& rhs) const { + + // Factors used for calculating relative equality. CUTLASS's relative-equality + // checks in include/cutlass/relatively_equal.h are inspired by + // https://floating-point-gui.de/errors/comparison/. This reference suggests using + // the minimum normal value of a given type as the nonzero_floor. + Element epsilon(static_cast(0.1f)); + Element nonzero_floor(std::numeric_limits::min()); + + if constexpr (!cutlass::is_complex::value) { + if (check_relative_equality == CheckEquality::RELATIVE) { + return cutlass::reference::host::TensorRelativelyEquals( + lhs, rhs, epsilon, nonzero_floor); + } + else { + return cutlass::reference::host::TensorEquals(lhs, rhs); + } + } + else { + return cutlass::reference::host::TensorEquals(lhs, rhs); + } + } + + bool compare_reference( + cute::Shape problem_shape_MNKL) { + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0); + + bool passed = true; + return passed; + } +}; + +// +// Sparse MMA host implementation +// +template< + class Gemm, + class ElementA_, + class ElementB_> +struct HostCollectiveMainloopSparse +{ + + // Kernel data types + using ElementA = ElementA_; + // CuTe layout A for the kernel's sparse tensorA. 
+ using LayoutA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutA; + using ElementB = ElementB_; + using StrideB = typename Gemm::GemmKernel::StrideB; + using ScheduleType = typename Gemm::GemmKernel::CollectiveMainloop::DispatchPolicy::Schedule; + + using ElementE = typename Gemm::GemmKernel::CollectiveMainloop::ElementE; + // CuTe layout E for the kernel's metadata tensor. + using LayoutE = typename Gemm::GemmKernel::CollectiveMainloop::LayoutE; + using ElementAccumulator = typename Gemm::GemmKernel::ElementAccumulator; + using ElementScalingFactor = ElementAccumulator; + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + using EpilogueOutputOp = typename Gemm::EpilogueOutputOp; + using SparseConfig = typename Gemm::GemmKernel::CollectiveMainloop::SparseConfig; + + // The following typenames are for the reference host tensors. They are non-sparse tensors. + using LayoutTagA = decltype(SparseConfig::deduce_layoutA_tag(LayoutA{})); + using StrideA = cutlass::gemm::TagToStrideA_t; + // We don't care about the actual strideE for the host tensor, but just need one to allocate memory. + using StrideE = StrideA; + + // Deduce Cutlass Layouts (RowMajor & ColumnMajor) + using LayoutTagB = cutlass::detail::StrideToLayoutTagB_t; + using LayoutTagE = cutlass::detail::StrideToLayoutTagA_t; + + using ArchTag = typename Gemm::ArchTag; + + using CompressorUtility = cutlass::transform::kernel::StructuredSparseCompressorUtility< + cute::Shape, + ElementA, + LayoutTagA, + SparseConfig>; + + using CompressorKernel = cutlass::transform::kernel::StructuredSparseCompressor< + cute::Shape, + ElementA, + LayoutTagA, + SparseConfig, + ArchTag>; + + using Compressor = cutlass::transform::device::TransformUniversalAdapter; + + using Arguments = typename Gemm::GemmKernel::MainloopArguments; + // Whether to use relative equality checks + CheckEquality check_relative_equality = CheckEquality::EXACT; + + // Note: this limitation comes from testbed / not the library + static_assert(is_row_or_col_major(), + "ERROR : A Layout is neither Row / Column Major)"); + static_assert(is_row_or_col_major(), + "ERROR : B Layout is neither Row / Column Major)"); + + StrideA stride_a; + StrideA stride_a_compressed; + StrideB stride_b; + StrideE stride_e; + + LayoutA layout_a; + LayoutE layout_e; + + typename LayoutTagA::Stride stride_factor_A; + typename LayoutTagB::Stride stride_factor_B; + typename LayoutTagE::Stride stride_factor_E; + + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_A_Comp; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_E; + uint64_t seed; + static constexpr uint64_t kDefaultSeed = 4096; + static constexpr int MaxSmCount = 16; + + HostCollectiveMainloopSparse( + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + uint64_t seed_ = kDefaultSeed, + typename LayoutTagA::Stride stride_factor_A_ = typename LayoutTagA::Stride(), + typename LayoutTagB::Stride stride_factor_B_ = typename LayoutTagB::Stride(), + typename LayoutTagE::Stride stride_factor_E_ = typename LayoutTagE::Stride() + ): + check_relative_equality(check_relative_equality_), + stride_factor_A(stride_factor_A_), + stride_factor_B(stride_factor_B_), + stride_factor_E(stride_factor_E_), + init_A(init_A_), init_B(init_B_), seed(seed_) { } + + template + bool 
initialize(ProblemShapeType problem_size) { +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("HostCollectiveMainloopSparse::initialize"); +#endif + // + // Allocate the GEMM workspace + // + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto M = cute::size<0>(problem_shape_MNKL); + auto N = cute::size<1>(problem_shape_MNKL); + auto K = cute::size<2>(problem_shape_MNKL); + auto L = cute::size<3>(problem_shape_MNKL); + + stride_a = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); + stride_b = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)); + + CompressorUtility compressor_utility(problem_shape_MNKL, stride_a); + + // TensorE + // In unit of ElementE (uint8_t), after alignment requirement + // M-dim: TensorEAtom_M alignment + // K-dim: TensorEAtom_K alignment + int KAlignedE = compressor_utility.get_metadata_k_physical(); + int MAlignedE = compressor_utility.get_metadata_m_physical(); + + // TensorA Compressed + // In unit of ElementARaw, after alignment requirement + // M-dim: TMA alignment + // K-dim: TMA alignment + int KAlignedAC = compressor_utility.get_tensorA_k_physical(); + int MAlignedAC = compressor_utility.get_tensorA_m_physical(); + + stride_a_compressed = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, KAlignedAC, L)); + stride_e = cutlass::make_cute_packed_stride(StrideE{}, cute::make_shape(MAlignedE, KAlignedE, L)); + + auto a_coord = cutlass::make_Coord(M * L, K); + auto b_coord = cutlass::make_Coord(K, N * L); + auto e_coord = cutlass::make_Coord(MAlignedE * L, KAlignedE); + auto a_comp_coord = cutlass::make_Coord(MAlignedAC * L, KAlignedAC); + + tensor_A.resize(a_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(a_coord, stride_factor_A)); + tensor_A_Comp.resize(a_comp_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(a_comp_coord, stride_factor_A)); + tensor_B.resize(b_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(b_coord, stride_factor_B)); + tensor_E.resize(e_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(e_coord, stride_factor_E)); + + EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2022)); + EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2021)); + + // It is possible to randomly initialize to all zeros, so override this with non-zeros + // in the upper left corner of each operand. 
+ tensor_A.host_view().at({0, 0}) = ElementA(1); + tensor_B.host_view().at({0, 0}) = ElementB(1); + + compressor_utility.structure_sparse_zero_mask_fill(tensor_A.host_data(), static_cast(seed + 2023)); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_E.sync_device(); + tensor_A_Comp.sync_device(); + + cutlass::Status status {cutlass::Status::kSuccess }; + + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = 0; + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + typename Compressor::Arguments arguments{ + {M, N, K, L}, + {tensor_A.device_data(), + stride_a, + tensor_A_Comp.device_data(), + tensor_E.device_data()}, + {hw_info} + }; + + Compressor compressor_op; + size_t workspace_size = Compressor::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + status = compressor_op.can_implement(arguments); + if (status != cutlass::Status::kSuccess) { + return false; + } + + status = compressor_op.initialize(arguments, workspace.get()); + if (status != cutlass::Status::kSuccess) { + return false; + } + + status = compressor_op.run(); + + auto result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + EXPECT_EQ(result, cudaSuccess) << "Error at Kernel Sync."; + return false; + } + + layout_a = SparseConfig::fill_layoutA(problem_shape_MNKL); + layout_e = SparseConfig::fill_layoutE(problem_shape_MNKL); + + tensor_E.sync_host(); + tensor_A_Comp.sync_host(); + + return true; + } + + Arguments to_args() { + using ArrayElementA = typename Gemm::GemmKernel::CollectiveMainloop::ArrayElementA; + using ArrayElementB = typename Gemm::GemmKernel::CollectiveMainloop::ArrayElementB; + return { + reinterpret_cast(tensor_A_Comp.device_data()), layout_a, + reinterpret_cast(tensor_B.device_data()), stride_b, + tensor_E.device_data(), layout_e + }; + } + + auto to_host_args(ProblemShapeType problem_size) { + using namespace cute; + // + // Allocate the GEMM workspace + // + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto M = cute::size<0>(problem_shape_MNKL); + auto N = cute::size<1>(problem_shape_MNKL); + auto K = cute::size<2>(problem_shape_MNKL); + auto L = cute::size<3>(problem_shape_MNKL); + auto A = make_tensor(make_iterator(tensor_A.host_data()), + make_layout(make_shape(M, K, L), stride_a)); + auto B = make_tensor(make_iterator(tensor_B.host_data()), + make_layout(make_shape(N, K, L), stride_b)); + + cutlass::reference::host::GettMainloopParams mainloop_params{A, B}; + return mainloop_params; + } + + void print_tensors(std::ofstream& file) { + file << "A =\n" << tensor_A.host_view() + << "\nB =\n" << tensor_B.host_view(); + } + + bool compare_reference( + cute::Shape problem_shape_MNKL) { + auto [M, N, K, L] = problem_shape_MNKL; + + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0); + return true; + } +}; + +template< + class ScheduleType_, + class Gemm, + class ElementA_, + class ElementB_ +> +struct HostCollectiveMainloop, + typename Gemm::CollectiveMainloop::DispatchPolicy>>> + : HostCollectiveMainloopSparse +{ + using HostCollectiveMainloopSparse::HostCollectiveMainloopSparse; +}; + +// +// Sparse MMA input Operands : A_compressed, B, metadata +// +// Structured Sparse Gemm Input Operands + +template< + class Gemm, + int SchedulerPipelineStageCount_, + int AccumulatorPipelineStageCount_, + typename ElementA_, + typename ElementB_ +> +struct HostCollectiveMainloop, 
+ Gemm, ElementA_, ElementB_> + : HostCollectiveMainloopSparse +{ + using HostCollectiveMainloopSparse::HostCollectiveMainloopSparse; +}; + +// +// Sparse Gemm Input Operands : A , B, E +// +template< + class Gemm, + int SchedulerPipelineStageCount_, + class ElementA_, + class ElementB_ +> +struct HostCollectiveMainloop, + Gemm, ElementA_, ElementB_> : public + HostCollectiveMainloop, + Gemm, ElementA_, ElementB_> { + using Base = HostCollectiveMainloop, + Gemm, ElementA_, ElementB_ >; + HostCollectiveMainloop( + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + uint64_t seed_ = Base::kDefaultSeed, + typename Base::LayoutTagA::Stride stride_factor_A_ = typename Base::LayoutTagA::Stride(), + typename Base::LayoutTagB::Stride stride_factor_B_ = typename Base::LayoutTagB::Stride(), + typename Base::LayoutTagE::Stride stride_factor_E_ = typename Base::LayoutTagE::Stride() + ) : Base::HostCollectiveMainloop(check_relative_equality_, init_A_, init_B_, seed_, stride_factor_A_, + stride_factor_B_, + stride_factor_E_) {} +}; + +// +// Sparse Gemm Input Operands : A , B, E +// +template< + class Gemm, + int SchedulerPipelineStageCount_, + class ElementA_, + class ElementB_ +> +struct HostCollectiveMainloop, + Gemm, ElementA_, ElementB_> : public + HostCollectiveMainloop, + Gemm, ElementA_, ElementB_> { + using Base = HostCollectiveMainloop, + Gemm, ElementA_, ElementB_ >; + HostCollectiveMainloop( + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + uint64_t seed_ = Base::kDefaultSeed, + typename Base::LayoutTagA::Stride stride_factor_A_ = typename Base::LayoutTagA::Stride(), + typename Base::LayoutTagB::Stride stride_factor_B_ = typename Base::LayoutTagB::Stride(), + typename Base::LayoutTagE::Stride stride_factor_E_ = typename Base::LayoutTagE::Stride() + ) : Base::HostCollectiveMainloop(check_relative_equality_, init_A_, init_B_, seed_, stride_factor_A_, + stride_factor_B_, + stride_factor_E_) {} +}; + +// +// Block Scaled Gemm Input Operands : A , B, scalefactorA, scalefactorB +// +template< + class Gemm, + int SchedulerPipelineStageCount_, + int AccumulatorPipelineStageCount_, + class ElementA_, + class ElementB_ +> +struct HostCollectiveMainloop, + Gemm, ElementA_, ElementB_> { + // Kernel data types + using ElementA = ElementA_; + using StrideA = typename Gemm::GemmKernel::StrideA; + using ElementB = ElementB_; + using StrideB = typename Gemm::GemmKernel::StrideB; + using ScheduleType = typename Gemm::GemmKernel::CollectiveMainloop::DispatchPolicy::Schedule; + using LayoutTagA = cutlass::detail::StrideToLayoutTagA_t; + using LayoutTagB = cutlass::detail::StrideToLayoutTagB_t; + + using ElementAccumulator = typename Gemm::GemmKernel::ElementAccumulator; + using ElementScalingFactor = ElementAccumulator; + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + using EpilogueOutputOp = typename Gemm::EpilogueOutputOp; + + static constexpr int SFVecSize = Gemm::GemmKernel::CollectiveMainloop::SFVecSize; + + using ElementSF = typename Gemm::GemmKernel::CollectiveMainloop::ElementSF; + using Sm1xxBlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig; + using Blk_MN = typename Sm1xxBlkScaledConfig::Blk_MN; + using Blk_SF = typename 
Sm1xxBlkScaledConfig::Blk_SF; + using SfAtom = typename Sm1xxBlkScaledConfig::SfAtom; + using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA; + using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB; + + using Arguments = typename Gemm::GemmKernel::MainloopArguments; + + // Whether to use relative equality checks + CheckEquality check_relative_equality = CheckEquality::EXACT; + + StrideA stride_a; + StrideB stride_b; + + LayoutSFA layout_sfa; + LayoutSFB layout_sfb; + + typename LayoutTagA::Stride stride_factor_A; + typename LayoutTagB::Stride stride_factor_B; + + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_SFA; + cutlass::HostTensor tensor_SFB; + + uint64_t seed; + static constexpr uint64_t kDefaultSeed = 4096; + + // Note: this limitation comes from testbed / not the library + static_assert(is_row_or_col_major(), + "ERROR : A Layout is neither Row / Column Major)"); + static_assert(is_row_or_col_major(), + "ERROR : B Layout is neither Row / Column Major)"); + + HostCollectiveMainloop( + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + uint64_t seed_ = kDefaultSeed, + typename LayoutTagA::Stride stride_factor_A_ = typename LayoutTagA::Stride(), + typename LayoutTagB::Stride stride_factor_B_ = typename LayoutTagB::Stride() + ): + check_relative_equality(check_relative_equality_), + stride_factor_A(stride_factor_A_), + stride_factor_B(stride_factor_B_), + init_A(init_A_), init_B(init_B_), seed(seed_) { } + + template + bool initialize(ProblemShapeType problem_size) { +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("HostCollectiveMainloop (KernelTmaWarpSpecializedBlockScaledSm100)::initialize"); +#endif + // + // Allocate the GEMM workspace + // + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto M = cute::size<0>(problem_shape_MNKL); + auto N = cute::size<1>(problem_shape_MNKL); + auto K = cute::size<2>(problem_shape_MNKL); + auto L = cute::size<3>(problem_shape_MNKL); + + stride_a = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); + stride_b = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)); + + // 2.x host tensor does not natively contain a batch stride or coord, so we spoof if by folding it into the outer mode + auto a_coord = cutlass::make_Coord(M * L, K); + // Cutlass has Row/Col major refers to MxK times KxN matrix product, + // so the HostTensorB should be treated as KxN in "coord"'s view + auto b_coord = cutlass::make_Coord(K, N * L); + + tensor_A.resize(a_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(a_coord, stride_factor_A)); + tensor_B.resize(b_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(b_coord, stride_factor_B)); + + EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2022)); + EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2021)); + + // It is possible to randomly initialize to all zeros, so override this with non-zeros + // in the upper left corner of each operand. 
+ tensor_A.host_view().at({0, 0}) = ElementA(1); + tensor_B.host_view().at({0, 0}) = ElementB(1); + + tensor_A.sync_device(); + tensor_B.sync_device(); + + using namespace cute; + auto k_blks = cutlass::ceil_div(K, size<1>(shape(SfAtom{}))); + auto m_blks = cutlass::ceil_div(M, Blk_MN{}); + auto n_blks = cutlass::ceil_div(N, Blk_MN{}); + layout_sfa = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(problem_shape_MNKL); + layout_sfb = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(problem_shape_MNKL); + + // 2.x host tensor does not natively contain a batch stride or coord, so we spoof if by folding it into the outer mode + auto sfa_coord = cutlass::make_Coord(m_blks * Blk_MN{} * L, k_blks * Blk_SF{}); + auto sfb_coord = cutlass::make_Coord(n_blks * Blk_MN{} * L, k_blks * Blk_SF{}); + + tensor_SFA.resize(sfa_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(sfa_coord, stride_factor_A)); + tensor_SFB.resize(sfb_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(sfb_coord, stride_factor_B)); + + EXPECT_TRUE(initialize_tensor(tensor_SFA.host_view(), init_A, seed + 2024)); + EXPECT_TRUE(initialize_tensor(tensor_SFB.host_view(), init_B, seed + 2025)); + + // It is possible to randomly initialize to all zeros, so override this with non-zeros + // in the upper left corner of each operand. + tensor_SFA.host_view().at({0, 0}) = ElementSF(1); + tensor_SFB.host_view().at({0, 0}) = ElementSF(1); + + tensor_SFA.sync_device(); + tensor_SFB.sync_device(); + + return true; + } + + Arguments to_args() { + using ArrayElementA = typename Gemm::GemmKernel::CollectiveMainloop::ArrayElementA; + using ArrayElementB = typename Gemm::GemmKernel::CollectiveMainloop::ArrayElementB; + return { + reinterpret_cast(tensor_A.device_data()), stride_a, + reinterpret_cast(tensor_B.device_data()), stride_b, + tensor_SFA.device_data(), layout_sfa, + tensor_SFB.device_data(), layout_sfb + }; + } + + auto to_host_args(ProblemShapeType problem_size) { + using namespace cute; + // + // Allocate the GEMM workspace + // + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto M = cute::size<0>(problem_shape_MNKL); + auto N = cute::size<1>(problem_shape_MNKL); + auto K = cute::size<2>(problem_shape_MNKL); + auto L = cute::size<3>(problem_shape_MNKL); + auto A = make_tensor(make_iterator(tensor_A.host_data()), + make_layout(make_shape(M, K, L), stride_a)); + auto SfA = make_tensor(tensor_SFA.host_data(), layout_sfa); + + auto B = make_tensor(make_iterator(tensor_B.host_data()), + make_layout(make_shape(N, K, L), stride_b)); + auto SfB = make_tensor(tensor_SFB.host_data(), layout_sfb); + + cutlass::reference::host::GettMainloopParams + mainloop_params{A, SfA, B, SfB}; + return mainloop_params; + } + + void print_tensors(std::ofstream& file) { + file << "A =\n" << tensor_A.host_view() + << "\nB =\n" << tensor_B.host_view() + << "\nSFA =\n" << tensor_SFA.host_view() + << "\nSFB =\n" << tensor_SFB.host_view(); + } + + bool compare_reference( + cute::Shape problem_shape_MNKL) { + auto [M, N, K, L] = problem_shape_MNKL; + + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_SFA.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_SFB.host_view()), 0); + return true; + } +}; + + +// +// Block Scaled Gemm Input Operands : A , B, scalefactorA, scalefactorB +// +template< + class Gemm, + int SchedulerPipelineStageCount_, + class ElementA_, + 
class ElementB_ +> +struct HostCollectiveMainloop, + Gemm, ElementA_, ElementB_> : public + HostCollectiveMainloop, + Gemm, ElementA_, ElementB_> { + using Base = HostCollectiveMainloop, + Gemm, ElementA_, ElementB_>; + HostCollectiveMainloop( + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + uint64_t seed_ = Base::kDefaultSeed, + typename Base::LayoutTagA::Stride stride_factor_A_ = typename Base::LayoutTagA::Stride(), + typename Base::LayoutTagB::Stride stride_factor_B_ = typename Base::LayoutTagB::Stride() + ) : Base::HostCollectiveMainloop(check_relative_equality_, init_A_, init_B_, seed_, stride_factor_A_, stride_factor_B_) {} +}; + +// +// Block Scaled Gemm Input Operands : A , B, scalefactorA, scalefactorB +// +template< + class Gemm, + int SchedulerPipelineStageCount_, + class ElementA_, + class ElementB_ +> +struct HostCollectiveMainloop, + Gemm, ElementA_, ElementB_> : public + HostCollectiveMainloop, + Gemm, ElementA_, ElementB_> { + using Base = HostCollectiveMainloop, + Gemm, ElementA_, ElementB_>; + HostCollectiveMainloop( + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + uint64_t seed_ = Base::kDefaultSeed, + typename Base::LayoutTagA::Stride stride_factor_A_ = typename Base::LayoutTagA::Stride(), + typename Base::LayoutTagB::Stride stride_factor_B_ = typename Base::LayoutTagB::Stride() + ) : Base::HostCollectiveMainloop(check_relative_equality_, init_A_, init_B_, seed_, stride_factor_A_, stride_factor_B_) {} +}; + +// +// Block Scaled Gemm Input Operands : A , B, scalefactorA, scalefactorB +// +template< + class Gemm, + int SchedulerPipelineStageCount_, + int AccumulatorPipelineStageCount_, + class ElementA_, + class ElementB_ +> +struct HostCollectiveMainloop, + Gemm, ElementA_, ElementB_> : public + HostCollectiveMainloop, + Gemm, ElementA_, ElementB_> { + using Base = HostCollectiveMainloop, + Gemm, ElementA_, ElementB_>; + HostCollectiveMainloop( + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + uint64_t seed_ = Base::kDefaultSeed, + typename Base::LayoutTagA::Stride stride_factor_A_ = typename Base::LayoutTagA::Stride(), + typename Base::LayoutTagB::Stride stride_factor_B_ = typename Base::LayoutTagB::Stride() + ) : Base::HostCollectiveMainloop(check_relative_equality_, init_A_, init_B_, seed_, stride_factor_A_, stride_factor_B_) {} +}; + +// +// Block Scaled Structured Sparse Gemm Input Operands : A_compressed, B, metadata, scalefactorA, scalefactorB +// +template< + class Gemm, + int SchedulerPipelineStageCount_, + int AccumulatorPipelineStageCount_, + typename ElementA_, + typename ElementB_ +> +struct HostCollectiveMainloop, + Gemm, ElementA_, ElementB_> { + // Kernel data types + using ElementA = ElementA_; + // CuTe layout A for the kernel's sparse tensorA. 
+ using LayoutA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutA; + using ElementB = ElementB_; + using StrideB = typename Gemm::GemmKernel::StrideB; + using ScheduleType = typename Gemm::GemmKernel::CollectiveMainloop::DispatchPolicy::Schedule; + + using ElementE = typename Gemm::GemmKernel::CollectiveMainloop::ElementE; + // CuTe layout E for the kernel's metadata tensor. + using LayoutE = typename Gemm::GemmKernel::CollectiveMainloop::LayoutE; + using ElementAccumulator = typename Gemm::GemmKernel::ElementAccumulator; + using ElementScalingFactor = ElementAccumulator; + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + using EpilogueOutputOp = typename Gemm::EpilogueOutputOp; + using SparseConfig = typename Gemm::GemmKernel::CollectiveMainloop::SparseConfig; + + // The following typenames are for the reference host tensors. They are non-sparse tensors. + using LayoutTagA = decltype(SparseConfig::deduce_layoutA_tag(LayoutA{})); + using StrideA = cutlass::gemm::TagToStrideA_t; + // We don't care about the actual strideE for the host tensor, but just need one to allocate memory. + using StrideE = StrideA; + + static constexpr int SFVecSize = Gemm::GemmKernel::CollectiveMainloop::SFVecSize; + // Deduce Cutlass Layouts (RowMajor & ColumnMajor) + using LayoutTagB = cutlass::detail::StrideToLayoutTagB_t; + + using LayoutTagE = cutlass::detail::StrideToLayoutTagA_t; + + using ElementSF = typename Gemm::GemmKernel::CollectiveMainloop::ElementSF; + using Sm1xxBlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig; + using Blk_MN = typename Sm1xxBlkScaledConfig::Blk_MN; + using Blk_SF = typename Sm1xxBlkScaledConfig::Blk_SF; + using SfAtom = typename Sm1xxBlkScaledConfig::SfAtom; + using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA; + using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB; + + using CompressorUtility = cutlass::transform::kernel::StructuredSparseCompressorUtility< + cute::Shape, + ElementA, + LayoutTagA, + SparseConfig>; + using CompressorKernel = cutlass::transform::kernel::StructuredSparseCompressor< + cute::Shape, + ElementA, + LayoutTagA, + SparseConfig, + cutlass::arch::Sm100>; + + using Compressor = cutlass::transform::device::TransformUniversalAdapter; + + using Arguments = typename Gemm::GemmKernel::MainloopArguments; + // Whether to use relative equality checks + CheckEquality check_relative_equality = CheckEquality::EXACT; + + StrideA stride_a; + StrideA stride_a_compressed; + StrideB stride_b; + StrideE stride_e; + + LayoutA layout_a; + LayoutE layout_e; + LayoutSFA layout_sfa; + LayoutSFB layout_sfb; + + typename LayoutTagA::Stride stride_factor_A; + typename LayoutTagB::Stride stride_factor_B; + typename LayoutTagE::Stride stride_factor_E; + + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_A_Comp; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_E; + cutlass::HostTensor tensor_SFA; + cutlass::HostTensor tensor_SFB; + + uint64_t seed; + static constexpr uint64_t kDefaultSeed = 4096; + + // Note: this limitation comes from testbed / not the library + static_assert(is_row_or_col_major(), + "ERROR : A Layout is neither Row / Column Major)"); + static_assert(is_row_or_col_major(), + "ERROR : B Layout is neither Row / Column Major)"); + + HostCollectiveMainloop( + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + cutlass::Distribution::Kind init_A_ = 
cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + uint64_t seed_ = kDefaultSeed, + typename LayoutTagA::Stride stride_factor_A_ = typename LayoutTagA::Stride(), + typename LayoutTagB::Stride stride_factor_B_ = typename LayoutTagB::Stride(), + typename LayoutTagE::Stride stride_factor_E_ = typename LayoutTagE::Stride() + ): + check_relative_equality(check_relative_equality_), + stride_factor_A(stride_factor_A_), + stride_factor_B(stride_factor_B_), + stride_factor_E(stride_factor_E_), + init_A(init_A_), init_B(init_B_), seed(seed_) { } + + template + bool initialize(ProblemShapeType problem_size) { +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("HostCollectiveMainloop (KernelSparseTmaWarpSpecializedBlockScaledSm100)::initialize"); +#endif + // + // Allocate the GEMM workspace + // + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto M = cute::size<0>(problem_shape_MNKL); + auto N = cute::size<1>(problem_shape_MNKL); + auto K = cute::size<2>(problem_shape_MNKL); + auto L = cute::size<3>(problem_shape_MNKL); + + stride_a = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); + stride_b = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)); + + CompressorUtility compressor_utility(problem_shape_MNKL, stride_a); + + // TensorE + // In unit of ElementE (uint8_t), after alignment requirement + // M-dim: TensorEAtom_M alignment + // K-dim: TensorEAtom_K alignment + int KAlignedE = compressor_utility.get_metadata_k_physical(); + int MAlignedE = compressor_utility.get_metadata_m_physical(); + + // TensorA Compressed + // In unit of ElementARaw, after alignment requirement + // M-dim: TMA alignment + // K-dim: TMA alignment + int KAlignedAC = compressor_utility.get_tensorA_k_physical(); + int MAlignedAC = compressor_utility.get_tensorA_m_physical(); + + stride_a_compressed = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, KAlignedAC, L)); + stride_e = cutlass::make_cute_packed_stride(StrideE{}, cute::make_shape(MAlignedE, KAlignedE, L)); + + auto a_coord = cutlass::make_Coord(M * L, K); + auto b_coord = cutlass::make_Coord(K, N * L); + auto e_coord = cutlass::make_Coord(MAlignedE * L, KAlignedE); + auto a_comp_coord = cutlass::make_Coord(MAlignedAC * L, KAlignedAC); + + tensor_A.resize(a_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(a_coord, stride_factor_A)); + tensor_A_Comp.resize(a_comp_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(a_comp_coord, stride_factor_A)); + tensor_B.resize(b_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(b_coord, stride_factor_B)); + tensor_E.resize(e_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(e_coord, stride_factor_E)); + + EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2022)); + EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2021)); + + // It is possible to randomly initialize to all zeros, so override this with non-zeros + // in the upper left corner of each operand. 
+ tensor_A.host_view().at({0, 0}) = ElementA(1); + tensor_B.host_view().at({0, 0}) = ElementB(1); + + compressor_utility.structure_sparse_zero_mask_fill(tensor_A.host_data(), static_cast(seed + 2023)); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_E.sync_device(); + tensor_A_Comp.sync_device(); + + cutlass::Status status {cutlass::Status::kSuccess }; + + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = 0; + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + typename Compressor::Arguments arguments{ + {M, N, K, L}, + {tensor_A.device_data(), + stride_a, + tensor_A_Comp.device_data(), + tensor_E.device_data()}, + {hw_info} + }; + + Compressor compressor_op; + size_t workspace_size = Compressor::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + status = compressor_op.can_implement(arguments); + if (status != cutlass::Status::kSuccess) { + return false; + } + + status = compressor_op.initialize(arguments, workspace.get()); + if (status != cutlass::Status::kSuccess) { + return false; + } + + status = compressor_op.run(); + + auto result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + EXPECT_EQ(result, cudaSuccess) << "Error at Kernel Sync."; + return false; + } + + layout_a = SparseConfig::fill_layoutA(problem_shape_MNKL); + layout_e = SparseConfig::fill_layoutE(problem_shape_MNKL); + + tensor_E.sync_host(); + tensor_A_Comp.sync_host(); + + using namespace cute; + auto k_blks = cutlass::ceil_div(K, size<1>(shape(SfAtom{}))); + auto m_blks = cutlass::ceil_div(M, Blk_MN{}); + auto n_blks = cutlass::ceil_div(N, Blk_MN{}); + layout_sfa = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(problem_shape_MNKL); + layout_sfb = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(problem_shape_MNKL); + + // 2.x host tensor does not natively contain a batch stride or coord, so we spoof if by folding it into the outer mode + auto sfa_coord = cutlass::make_Coord(m_blks * Blk_MN{} * L, k_blks * Blk_SF{}); + auto sfb_coord = cutlass::make_Coord(n_blks * Blk_MN{} * L, k_blks * Blk_SF{}); + + tensor_SFA.resize(sfa_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(sfa_coord, stride_factor_A)); + tensor_SFB.resize(sfb_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(sfb_coord, stride_factor_B)); + + EXPECT_TRUE(initialize_tensor(tensor_SFA.host_view(), init_A, seed + 2024)); + EXPECT_TRUE(initialize_tensor(tensor_SFB.host_view(), init_B, seed + 2025)); + + // It is possible to randomly initialize to all zeros, so override this with non-zeros + // in the upper left corner of each operand. 
+ tensor_SFA.host_view().at({0, 0}) = ElementSF(1); + tensor_SFB.host_view().at({0, 0}) = ElementSF(1); + + tensor_SFA.sync_device(); + tensor_SFB.sync_device(); + + return true; + } + + Arguments to_args() { + using ArrayElementA = typename Gemm::GemmKernel::CollectiveMainloop::ArrayElementA; + using ArrayElementB = typename Gemm::GemmKernel::CollectiveMainloop::ArrayElementB; + return { + reinterpret_cast(tensor_A_Comp.device_data()), layout_a, + reinterpret_cast(tensor_B.device_data()), stride_b, + tensor_E.device_data(), layout_e, + tensor_SFA.device_data(), layout_sfa, + tensor_SFB.device_data(), layout_sfb + }; + } + + auto to_host_args(ProblemShapeType problem_size) { + using namespace cute; + // + // Allocate the GEMM workspace + // + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto M = cute::size<0>(problem_shape_MNKL); + auto N = cute::size<1>(problem_shape_MNKL); + auto K = cute::size<2>(problem_shape_MNKL); + auto L = cute::size<3>(problem_shape_MNKL); + auto A = make_tensor(make_iterator(tensor_A.host_data()), + make_layout(make_shape(M, K, L), stride_a)); + auto SfA = make_tensor(tensor_SFA.host_data(), layout_sfa); + + auto B = make_tensor(make_iterator(tensor_B.host_data()), + make_layout(make_shape(N, K, L), stride_b)); + auto SfB = make_tensor(tensor_SFB.host_data(), layout_sfb); + + // return {A, SfA, B, SfB}; + cutlass::reference::host::GettMainloopParams + mainloop_params{A, SfA, B, SfB}; + return mainloop_params; + } + + void print_tensors(std::ofstream& file) { + file << "A =\n" << tensor_A.host_view() + << "\nB =\n" << tensor_B.host_view() + << "\nSFA =\n" << tensor_SFA.host_view() + << "\nSFB =\n" << tensor_SFB.host_view(); + } + + bool compare_reference( + cute::Shape problem_shape_MNKL) { + auto [M, N, K, L] = problem_shape_MNKL; + + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_SFA.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_SFB.host_view()), 0); + return true; + } +}; + +template< + class Gemm, + int SchedulerPipelineStageCount_, + class ElementA_, + class ElementB_ +> +struct HostCollectiveMainloop, + Gemm, ElementA_, ElementB_> : public + HostCollectiveMainloop, + Gemm, ElementA_, ElementB_> { + using Base = HostCollectiveMainloop, + Gemm, ElementA_, ElementB_>; + HostCollectiveMainloop( + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + uint64_t seed_ = Base::kDefaultSeed, + typename Base::LayoutTagA::Stride stride_factor_A_ = typename Base::LayoutTagA::Stride(), + typename Base::LayoutTagB::Stride stride_factor_B_ = typename Base::LayoutTagB::Stride(), + typename Base::LayoutTagE::Stride stride_factor_E_ = typename Base::LayoutTagE::Stride() + ) : Base::HostCollectiveMainloop(check_relative_equality_, init_A_, init_B_, seed_, stride_factor_A_, + stride_factor_B_, + stride_factor_E_) {} +}; + +template< + class Gemm, + int SchedulerPipelineStageCount_, + class ElementA_, + class ElementB_ +> +struct HostCollectiveMainloop, + Gemm, ElementA_, ElementB_> : public + HostCollectiveMainloop, + Gemm, ElementA_, ElementB_> { + using Base = HostCollectiveMainloop, + Gemm, ElementA_, ElementB_>; + HostCollectiveMainloop( + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + 
cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + uint64_t seed_ = Base::kDefaultSeed, + typename Base::LayoutTagA::Stride stride_factor_A_ = typename Base::LayoutTagA::Stride(), + typename Base::LayoutTagB::Stride stride_factor_B_ = typename Base::LayoutTagB::Stride(), + typename Base::LayoutTagE::Stride stride_factor_E_ = typename Base::LayoutTagE::Stride() + ) : Base::HostCollectiveMainloop(check_relative_equality_, init_A_, init_B_, seed_, stride_factor_A_, + stride_factor_B_, + stride_factor_E_) {} +}; + +template +struct HostCollectiveDefaultEpilogue { + // fusion types are potentially void if the fusion is not supported + // helper so we don't try to construct HostTensor with void type + template + using non_void_t = cute::conditional_t, U, T>; + + using ScheduleType = typename Gemm::GemmKernel::CollectiveMainloop::DispatchPolicy::Schedule; + using kernel = typename Gemm::GemmKernel; + using Epilogue = typename kernel::CollectiveEpilogue; + + using ElementD = typename kernel::ElementD; + using StrideD = typename kernel::StrideD; + using ElementC = non_void_t; + using StrideC = typename kernel::StrideC; + + using FusionOp = typename Gemm::EpilogueOutputOp; + + static_assert(rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + static_assert(rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + + static_assert(is_row_or_col_major(), + "ERROR : C Layout is neither Row / Column Major)"); + static_assert(is_row_or_col_major(), + "ERROR : D Layout is neither Row / Column Major)"); + + // Deduce Cutlass Layouts (RowMajor & ColumnMajor) + using LayoutTagC = cutlass::detail::StrideToLayoutTagC_t; + using LayoutTagD = cutlass::detail::StrideToLayoutTagC_t; + using LayoutTagScalar = cutlass::layout::PackedVectorLayout; // scalars are size-1 vectors + using LayoutTagVector = cutlass::layout::PackedVectorLayout; + + using ElementAccumulator = typename kernel::ElementAccumulator; + using ElementScalingFactor = ElementAccumulator; + using ProblemShapeType = typename kernel::ProblemShape; + using ElementCompute = typename ElementComputeType::Type; + using ElementScalar = typename ElementScalarType::Type; + + using Arguments = typename Gemm::GemmKernel::EpilogueArguments; + + /// Initialization + StrideC stride_c; + StrideD stride_d; + + typename LayoutTagC::Stride stride_factor_C; + typename LayoutTagD::Stride stride_factor_D; + + cutlass::HostTensor tensor_C; + // Inputs + ElementScalar alpha; + ElementScalar beta; + + cutlass::HostTensor tensor_D; + cutlass::HostTensor reference_D; + + // Whether to use relative equality checks + CheckEquality check_relative_equality = CheckEquality::EXACT; + // Are scalars copied to device memory before kernel launch + ScalarLoc use_device_scalars = ScalarLoc::ON_HOST; + // If per-row scale is enabled and this is disabled, alpha/beta are passed as a host or device scalar instead of device vector + VectorScale vector_scale_mode = VectorScale::DISABLED; + + cutlass::Distribution::Kind init_C; + uint64_t seed; + static constexpr uint64_t kDefaultSeed = 4096; + + HostCollectiveDefaultEpilogue( + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + ScalarLoc use_device_scalars_ = ScalarLoc::ON_HOST, + VectorScale vector_scale_mode_ = VectorScale::DISABLED, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_scale_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_bias_ = 
cutlass::Distribution::Uniform, + uint64_t seed_ = kDefaultSeed + ): init_C(init_C_), seed(seed_), + stride_factor_C(typename LayoutTagC::Stride()), + stride_factor_D(typename LayoutTagD::Stride()), + check_relative_equality(check_relative_equality_), + use_device_scalars(use_device_scalars_){ } + + bool initialize(ProblemShapeType problem_size, ElementScalar alpha_=1.f, ElementScalar beta_=0.f) { +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("HostCollectiveDefaultEpilogue::initialize(problem_size, alpha, beta)"); +#endif + // Initialize Epilogue tensors + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto [M, N, K, L] = problem_shape_MNKL; + + stride_c = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)); + stride_d = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)); + + // 2.x host tensor does not natively contain a batch stride or coord, so we spoof if by folding it into the outer mode + auto c_coord = cutlass::make_Coord(M * L, N); + try { + tensor_C.resize(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, stride_factor_C)); + tensor_D.resize(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, stride_factor_D)); + reference_D.resize(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, stride_factor_D), false); + } + catch (std::exception const& e) { + CUTLASS_TRACE_HOST("HostCollectiveDefaultEpilogue::initialize: resizing tensors threw an exception: " << e.what()); + throw; + } + catch (...) { + CUTLASS_TRACE_HOST("HostCollectiveDefaultEpilogue::initialize: resizing tensors threw an unknown exception"); + throw; + } + { + const bool init_succeeded = initialize_tensor(tensor_C.host_view(), init_C, seed + 2020); + if (not init_succeeded) { + CUTLASS_TRACE_HOST("HostCollectiveDefaultEpilogue::initialize: initialize_tensor returned false"); + } + EXPECT_TRUE(init_succeeded); + } + tensor_C.host_view().at({0, 0}) = ElementC(1); + + cutlass::reference::host::TensorCopy(reference_D.host_view(), tensor_C.host_view()); + + try { + tensor_C.sync_device(); + tensor_D.sync_device(); + } + catch (std::exception const& e) { + CUTLASS_TRACE_HOST("HostCollectiveDefaultEpilogue::initialize: sync_device() threw an exception: " << e.what()); + throw; + } + catch (...) { + CUTLASS_TRACE_HOST("HostCollectiveDefaultEpilogue::initialize: sync_device() threw an unknown exception"); + throw; + } + + alpha = alpha_; + beta = beta_; + + return true; + } + + template < + class Element, + class Layout + > + bool equality_check( + cutlass::TensorView const& lhs, + cutlass::TensorView const& rhs) const { + + // Factors used for calculating relative equality. CUTLASS's relative-equality + // checks in include/cutlass/relatively_equal.h are inspired by + // https://floating-point-gui.de/errors/comparison/. This reference suggests using + // the minimum normal value of a given type as the nonzero_floor. 
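The RELATIVE comparison mode used by this testbed delegates to cutlass::reference::host::TensorRelativelyEquals with the epsilon and nonzero_floor declared below. A minimal scalar sketch of the rule described in the comment (following the floating-point-gui.de scheme it cites); `relatively_equal` is an illustrative stand-alone helper, not the library routine:

#include <algorithm>
#include <cmath>
#include <limits>

// Exactly equal values (including zeros) pass immediately; values near zero
// are held to an absolute tolerance scaled by nonzero_floor; everything else
// is compared relative to the combined magnitude of the two operands.
bool relatively_equal(double a, double b, double epsilon, double nonzero_floor) {
  double const abs_a = std::fabs(a);
  double const abs_b = std::fabs(b);
  double const diff  = std::fabs(a - b);

  if (a == b) {
    return true;
  }
  if (a == 0.0 || b == 0.0 || (abs_a + abs_b) < nonzero_floor) {
    return diff < epsilon * nonzero_floor;
  }
  return diff / std::min(abs_a + abs_b, std::numeric_limits<double>::max()) < epsilon;
}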
+ Element epsilon(static_cast(0.1f)); + Element nonzero_floor(std::numeric_limits::min()); + + if constexpr (!cutlass::is_complex::value) { + if (check_relative_equality == CheckEquality::RELATIVE) { + return cutlass::reference::host::TensorRelativelyEquals( + lhs, rhs, epsilon, nonzero_floor); + } + else { + return cutlass::reference::host::TensorEquals(lhs, rhs); + } + } + else { + return cutlass::reference::host::TensorEquals(lhs, rhs); + } + } + + bool compare_reference( + cute::Shape problem_shape_MNKL, + ElementScalar alpha, + ElementScalar beta) { + auto [M, N, K, L] = problem_shape_MNKL; + + tensor_D.sync_host(); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_C.host_view()), 0); + + if (tensor_D.size() > 1) { + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0); + } + + if (reference_D.size() > 1) { + EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); + } + + bool passed = equality_check(reference_D.host_view(), tensor_D.host_view()); + if(!passed) { + std::cout<<"D is incorrect"<(problem_size, 1); + auto M = cute::get<0>(problem_shape_MNKL); + auto N = cute::get<1>(problem_shape_MNKL); + auto K = cute::get<2>(problem_shape_MNKL); + auto L = cute::get<3>(problem_shape_MNKL); + auto coord_0 = cutlass::make_Coord(0); + auto C = cute::make_tensor(detail::make_iterator(tensor_C.host_data()), + cute::make_layout(cute::make_shape(M, N, L), stride_c)); + auto D = cute::make_tensor(detail::make_iterator(reference_D.host_data()), + cute::make_layout(cute::make_shape(M, N, L), stride_d)); + + cutlass::reference::host::GettEpilogueParams< + ElementScalar, + ElementScalar, + ElementAccumulator, + ElementCompute, + decltype(C), + decltype(D)> + epilogue_params{}; + + epilogue_params.C = C; + epilogue_params.D = D; + epilogue_params.alpha = alpha; + epilogue_params.beta = beta; + + return epilogue_params; + } +}; + +template +struct HostCollectiveEpilogue { + // fusion types are potentially void if the fusion is not supported + // helper so we don't try to construct HostTensor with void type + template + using non_void_t = cute::conditional_t, U, T>; + + using ScheduleType = typename Gemm::GemmKernel::CollectiveMainloop::DispatchPolicy::Schedule; + using kernel = typename Gemm::GemmKernel; + using Epilogue = typename kernel::CollectiveEpilogue; + static_assert(IsDefaultEpilogue::value == false, "Default Epilogue is not supported"); + + using ElementD = typename kernel::ElementD; + using StrideD = typename kernel::StrideD; + using ElementC = non_void_t; + using StrideC = typename kernel::StrideC; + + static_assert(rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + static_assert(rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + + static_assert(is_row_or_col_major(), + "ERROR : C Layout is neither Row / Column Major)"); + static_assert(is_row_or_col_major(), + "ERROR : D Layout is neither Row / Column Major)"); + + // Deduce Cutlass Layouts (RowMajor & ColumnMajor) + using LayoutTagC = cutlass::detail::StrideToLayoutTagC_t; + using LayoutTagD = cutlass::detail::StrideToLayoutTagC_t; + using LayoutTagScalar = cutlass::layout::PackedVectorLayout; // scalars are size-1 vectors + using LayoutTagVector = cutlass::layout::PackedVectorLayout; + + using ElementAccumulator = typename kernel::ElementAccumulator; + using ElementScalingFactor = ElementAccumulator; + using ProblemShapeType = typename kernel::ProblemShape; + + // + // FusionOperation derived types/queries + // + static constexpr bool IsLegacy = 
detail::IsLegacyEpiloguePolicy::value; + + // FFMA2 SGEMM uses ThreadEpilogueOp for bias and relu support instead of FusionOp, so we compose LinCombPerRowBiasEltAct FusionOp by hand to test the functionality. + static constexpr bool IsFfma2Kernel = cute::is_same_v; + using FusionOp = cute::conditional_t, + typename Gemm::EpilogueOutputOp>; + static_assert(cute::is_base_of_v); + + + // Scale factor Generation related + using SfStrategy = cutlass::reference::host::SfStrategy; + static constexpr bool IsBlockScaleSupported = FusionOp::IsBlockScaleSupported; + static constexpr SfStrategy SfGenStrategy = (!IsBlockScaleSupported) ? SfStrategy::None : SfStrategy::SfDGen; + static constexpr int32_t SFD_VectorSize = IsBlockScaleSupported ? FusionOp::SFVecSize : 1; + static constexpr bool IsKMajorSFD = cute::is_same_v; + using ElementSFD = non_void_t; + using Sm1xxBlockScaledOutputConfig= cutlass::detail::Sm1xxBlockScaledOutputConfig; + using Blk_MN = typename Sm1xxBlockScaledOutputConfig::Blk_MN; + using Blk_SF = typename Sm1xxBlockScaledOutputConfig::Blk_SF; + using OutputSFAtom = typename Sm1xxBlockScaledOutputConfig::SfAtom; + cutlass::HostTensor tensor_SFD; + cutlass::HostTensor reference_SFD; + + using ElementCompute = typename FusionOp::ElementCompute; + using ElementScalar = typename FusionOp::ElementScalar; + using ElementBias = non_void_t; + using ElementAux = non_void_t; + using ElementAmax = non_void_t; + using LayoutTagAux = non_void_t; + using ActivationFunctor = non_void_t>; + + static constexpr bool IsRowBiasEnabled = FusionOp::IsPerRowBiasSupported; + static constexpr bool IsColBiasEnabled = FusionOp::IsPerColBiasSupported; + static_assert(not (IsColBiasEnabled && IsRowBiasEnabled)); + + static constexpr bool IsDeBiasEnabled = FusionOp::IsDePerRowBiasSupported; + static constexpr bool IsPerRowScaleEnabled = FusionOp::IsPerRowScaleSupported; + static constexpr bool IsPerColScaleEnabled = FusionOp::IsPerColScaleSupported; + static constexpr bool IsScaleFactorEnabled = FusionOp::IsScaleFactorSupported; + static constexpr bool IsAuxInEnabled = FusionOp::IsAuxInSupported; + static constexpr bool IsAuxOutEnabled = FusionOp::IsAuxOutSupported; + static constexpr bool IsAbsMaxEnabledD = FusionOp::IsAbsMaxSupported && + (cute::is_same_v || + cute::is_same_v); + static constexpr bool IsAbsMaxEnabledAux = IsAuxOutEnabled && FusionOp::IsAbsMaxSupported && + (cute::is_same_v || + cute::is_same_v); + using Arguments = typename Gemm::GemmKernel::EpilogueArguments; + + /// Initialization + StrideC stride_c; + StrideD stride_d; + + typename LayoutTagC::Stride stride_factor_C; + typename LayoutTagD::Stride stride_factor_D; + + // Inputs + cutlass::HostTensor alpha; + cutlass::HostTensor beta; + cutlass::HostTensor scale_A; + cutlass::HostTensor scale_B; + cutlass::HostTensor scale_C; + cutlass::HostTensor scale_D; + cutlass::HostTensor scale_Aux; + cutlass::HostTensor bias; + cutlass::HostTensor tensor_C; + cutlass::HostTensor norm_constant; + + // Outputs + cutlass::HostTensor abs_max_Aux; + cutlass::HostTensor abs_max_D; + cutlass::HostTensor tensor_Aux; + cutlass::gemm::TagToStrideC_t< LayoutTagAux > stride_Aux; + cutlass::HostTensor tensor_D; + cutlass::HostTensor reference_D; + + // References + cutlass::HostTensor reference_dbias; + cutlass::HostTensor reference_Aux; + cutlass::HostTensor reference_abs_max_Aux; + cutlass::HostTensor reference_abs_max_D; + + // Whether to use relative equality checks + CheckEquality check_relative_equality = CheckEquality::EXACT; + // Are scalars copied to device 
memory before kernel launch + ScalarLoc use_device_scalars = ScalarLoc::ON_HOST; + // If vector scale is supported and this is disabled, alpha/beta are passed as a host or device scalar instead of device vector + VectorScale vector_scale_mode = VectorScale::DISABLED; + + // Random distribution with which to initialize the A/B/C/D/Aux scaling factors + cutlass::Distribution::Kind init_scale = cutlass::Distribution::Uniform; + // Random distribution with which to initialize the bias vector + cutlass::Distribution::Kind init_bias = cutlass::Distribution::Uniform; + cutlass::Distribution::Kind init_C; + uint64_t seed; + static constexpr uint64_t kDefaultSeed = 4096; + + HostCollectiveEpilogue( + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + ScalarLoc use_device_scalars_ = ScalarLoc::ON_HOST, + VectorScale vector_scale_mode_ = VectorScale::DISABLED, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_scale_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_bias_ = cutlass::Distribution::Uniform, + uint64_t seed_ = kDefaultSeed + ): init_scale(init_scale_), init_bias(init_bias_), + init_C(init_C_), seed(seed_), + stride_factor_C(typename LayoutTagC::Stride()), + stride_factor_D(typename LayoutTagD::Stride()), + check_relative_equality(check_relative_equality_), + use_device_scalars(use_device_scalars_){ } + + bool initialize(ProblemShapeType problem_size, ElementScalar alpha_=1.f, ElementScalar beta_=0.f) { +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("HostCollectiveEpilogue::initialize(problem_size, alpha, beta)"); +#endif + // Initialize Epilogue tensors + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto M = cute::size<0>(problem_shape_MNKL); + auto N = cute::size<1>(problem_shape_MNKL); + auto K = cute::size<2>(problem_shape_MNKL); + auto L = cute::size<3>(problem_shape_MNKL); + + stride_c = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)); + stride_d = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)); + + // 2.x host tensor does not natively contain a batch stride or coord, so we spoof if by folding it into the outer mode + auto c_coord = cutlass::make_Coord(M * L, N); + try { + tensor_C.resize(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, stride_factor_C)); + tensor_D.resize(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, stride_factor_D)); + reference_D.resize(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, stride_factor_D), false); + } + catch (std::exception const& e) { + CUTLASS_TRACE_HOST("HostCollectiveEpilogue::initialize: resizing tensors threw an exception: " << e.what()); + throw; + } + catch (...) { + CUTLASS_TRACE_HOST("HostCollectiveEpilogue::initialize: resizing tensors threw an unknown exception"); + throw; + } + + try { + const bool initialize_tensor_C_succeeded = + initialize_tensor(tensor_C.host_view(), init_C, seed + 2020); + if (not initialize_tensor_C_succeeded) { + CUTLASS_TRACE_HOST("HostCollectiveEpilogue::initialize: initialize_tensor returned false"); + } + EXPECT_TRUE(initialize_tensor_C_succeeded); + } + catch (std::exception const& e) { + CUTLASS_TRACE_HOST("HostCollectiveEpilogue::initialize: initialize_tensor threw an exception: " << e.what()); + throw; + } + catch (...) 
{ + CUTLASS_TRACE_HOST("HostCollectiveEpilogue::initialize: initialize_tensor threw an unknown exception"); + throw; + } + + tensor_C.host_view().at({0, 0}) = ElementC(1); + + cutlass::reference::host::TensorCopy(reference_D.host_view(), tensor_C.host_view()); + try { + tensor_C.sync_device(); + tensor_D.sync_device(); + } + catch (std::exception const& e) { + CUTLASS_TRACE_HOST("HostCollectiveEpilogue::initialize: sync_device() threw an exception: " << e.what()); + throw; + } + catch (...) { + CUTLASS_TRACE_HOST("HostCollectiveEpilogue::initialize: sync_device() threw an unknown exception"); + throw; + } + + auto scalar_coord = cutlass::make_Coord(1); + auto col_vector_coord = cutlass::make_Coord(M); + auto row_vector_coord = cutlass::make_Coord(N); + auto batch_vector_coord = cutlass::make_Coord(L); + if constexpr (IsPerRowScaleEnabled or IsPerColScaleEnabled) { + // scalars + if (vector_scale_mode == VectorScale::DISABLED) { + // batched scalars + if (use_device_scalars == ScalarLoc::ON_DEVICE) { + alpha.resize(batch_vector_coord, true); + beta.resize(batch_vector_coord, true); + EXPECT_TRUE(initialize_tensor(alpha.host_view(), init_scale, seed + 2023)); + if (beta_ != ElementScalar(0)) { + EXPECT_TRUE(initialize_tensor(beta.host_view(), init_scale, seed + 2024)); + } + else { + cutlass::reference::host::TensorFill(beta.host_view(), beta_); + } + } + // non-batched scalars + else { + alpha.resize(scalar_coord, false); + beta.resize(scalar_coord, false); + cutlass::reference::host::TensorFill(alpha.host_view(), alpha_); + cutlass::reference::host::TensorFill(beta.host_view(), beta_); + } + } + // batched vectors + else { + auto batched_vector_coord = cutlass::make_Coord((IsPerRowScaleEnabled ? M : N) * L); + alpha.resize(batched_vector_coord, true); + beta.resize(batched_vector_coord, true); + EXPECT_TRUE(initialize_tensor(alpha.host_view(), init_scale, seed + 2023)); + if (beta_ != ElementScalar(0)) { + EXPECT_TRUE(initialize_tensor(beta.host_view(), init_scale, seed + 2024)); + } + else { + cutlass::reference::host::TensorFill(beta.host_view(), beta_); + } + } + } + else { + if (use_device_scalars == ScalarLoc::ON_DEVICE) { + // Set alpha beta for different batches. 
+ alpha.resize(batch_vector_coord, true); + beta.resize(batch_vector_coord, true); + cutlass::reference::host::TensorFill(alpha.host_view(), alpha_); + for (int l = 0; l < L; ++l) { + beta.host_view().at(cutlass::make_Coord(l)) = beta_ + ElementScalar(l); + } + } + else { + alpha.resize(scalar_coord, false); + beta.resize(scalar_coord, false); + cutlass::reference::host::TensorFill(alpha.host_view(), alpha_); + cutlass::reference::host::TensorFill(beta.host_view(), beta_); + } + } + alpha.sync_device(); + beta.sync_device(); + + if constexpr (IsScaleFactorEnabled) { + scale_A.resize(scalar_coord, (use_device_scalars == ScalarLoc::ON_DEVICE)); + scale_B.resize(scalar_coord, (use_device_scalars == ScalarLoc::ON_DEVICE)); + scale_C.resize(scalar_coord, (use_device_scalars == ScalarLoc::ON_DEVICE)); + scale_D.resize(scalar_coord, (use_device_scalars == ScalarLoc::ON_DEVICE)); + EXPECT_TRUE(initialize_tensor(scale_A.host_view(), init_scale, seed + 2023)); + EXPECT_TRUE(initialize_tensor(scale_B.host_view(), init_scale, seed + 2024)); + EXPECT_TRUE(initialize_tensor(scale_C.host_view(), init_scale, seed + 2025)); + EXPECT_TRUE(initialize_tensor(scale_D.host_view(), init_scale, seed + 2026)); + scale_A.sync_device(); + scale_B.sync_device(); + scale_C.sync_device(); + scale_D.sync_device(); + } + + if constexpr (IsRowBiasEnabled or IsColBiasEnabled) { + bias.resize(IsRowBiasEnabled ? col_vector_coord : row_vector_coord); + EXPECT_TRUE(initialize_tensor(bias.host_view(), init_bias, seed + 2023)); + bias.sync_device(); + } + + if constexpr (IsDeBiasEnabled) { + bias.resize(col_vector_coord); + reference_dbias.resize(col_vector_coord); + cutlass::reference::host::TensorFill(bias.host_view(), ElementBias(0)); + cutlass::reference::host::TensorFill(reference_dbias.host_view(), ElementBias(0)); + bias.sync_device(); + } + + if constexpr (IsAbsMaxEnabledD) { + abs_max_D.resize(scalar_coord); + // ensure in-place device reductions perform their own initialization + cutlass::reference::host::TensorFill(abs_max_D.host_view(), + CUTLASS_STL_NAMESPACE::numeric_limits::max()); + abs_max_D.sync_device(); + reference_abs_max_D.resize(scalar_coord); + cutlass::reference::host::TensorFill(reference_abs_max_D.host_view(), ElementAmax(0)); + } + + if constexpr (IsAuxInEnabled) { + auto aux_coord = cutlass::make_Coord(M * L, N); + auto aux_layout = cutlass::layout::Affine2Layout_Factory::layout_factory(aux_coord, typename LayoutTagAux::Stride{}); + tensor_Aux.resize(aux_coord, aux_layout); + EXPECT_TRUE(initialize_tensor(tensor_Aux.host_view(), init_C, seed + 2023)); + tensor_Aux.sync_device(); + stride_Aux = cutlass::make_cute_packed_stride(cutlass::gemm::TagToStrideC_t{}, cute::make_shape(M, N, L)); + } + + if constexpr (IsAuxOutEnabled) { + auto aux_coord = cutlass::make_Coord(M * L, N); + auto aux_layout = cutlass::layout::Affine2Layout_Factory::layout_factory(aux_coord, typename LayoutTagAux::Stride{}); + tensor_Aux.resize(aux_coord, aux_layout); + reference_Aux.resize(aux_coord, aux_layout, false); + tensor_Aux.sync_device(); + stride_Aux = cutlass::make_cute_packed_stride(cutlass::gemm::TagToStrideC_t{}, cute::make_shape(M, N, L)); + + if constexpr (IsScaleFactorEnabled) { + scale_Aux.resize(scalar_coord, (use_device_scalars == ScalarLoc::ON_DEVICE)); + EXPECT_TRUE(initialize_tensor(scale_Aux.host_view(), init_scale, seed + 2027)); + scale_Aux.sync_device(); + } + + if constexpr (IsAbsMaxEnabledAux) { + abs_max_Aux.resize(scalar_coord); + // ensure in-place device reductions perform their own initialization + 
cutlass::reference::host::TensorFill(abs_max_Aux.host_view(), + CUTLASS_STL_NAMESPACE::numeric_limits::max()); + abs_max_Aux.sync_device(); + reference_abs_max_Aux.resize(scalar_coord); + cutlass::reference::host::TensorFill(reference_abs_max_Aux.host_view(), ElementAmax(0)); + } + } + + + if constexpr (IsBlockScaleSupported) { + auto m_blks = cutlass::ceil_div(M, cute::size<0>(cute::shape(OutputSFAtom{}))); + auto n_blks = cutlass::ceil_div(N, cute::size<1>(cute::shape(OutputSFAtom{}))); + auto sfd_coord = [&] () { + if constexpr (IsKMajorSFD) { + return cutlass::make_Coord(m_blks * Blk_MN{} * L, n_blks * Blk_SF{}); + } + else { + return cutlass::make_Coord(m_blks * Blk_SF{} * L, n_blks * Blk_MN{}); + } + }(); + tensor_SFD.resize(sfd_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(sfd_coord, stride_factor_D)); + reference_SFD.resize(sfd_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(sfd_coord, stride_factor_D), false); + tensor_SFD.sync_device(); + norm_constant.resize(scalar_coord, true); + EXPECT_TRUE(initialize_tensor(norm_constant.host_view(), init_scale, seed + 2023)); + norm_constant.sync_device(); + } + + + return true; + } + + template < + class Element, + class Layout + > + bool equality_check( + cutlass::TensorView const& lhs, + cutlass::TensorView const& rhs) const { + + // Factors used for calculating relative equality. CUTLASS's relative-equality + // checks in include/cutlass/relatively_equal.h are inspired by + // https://floating-point-gui.de/errors/comparison/. This reference suggests using + // the minimum normal value of a given type as the nonzero_floor. + Element epsilon(static_cast(0.1f)); + Element nonzero_floor(std::numeric_limits::min()); + + if constexpr (!cutlass::is_complex::value) { + if (check_relative_equality == CheckEquality::RELATIVE) { + return cutlass::reference::host::TensorRelativelyEquals( + lhs, rhs, epsilon, nonzero_floor); + } + else { + return cutlass::reference::host::TensorEquals(lhs, rhs); + } + } + else { + return cutlass::reference::host::TensorEquals(lhs, rhs); + } + } + + bool compare_reference( + cute::Shape problem_shape_MNKL, + ElementScalar alpha, + ElementScalar beta) { + tensor_D.sync_host(); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_C.host_view()), 0); + + if (tensor_D.size() > 1) { + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0); + } + + if (reference_D.size() > 1) { + EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); + } + + bool passed = equality_check(reference_D.host_view(), tensor_D.host_view()); + if(!passed) { + #if 0 + auto [M, N, K, L] = problem_shape_MNKL; + auto ref = cute::make_tensor(detail::make_iterator(reference_D.host_data()), + cute::make_layout(cute::make_shape(M, N, L), stride_d)); + auto comp = cute::make_tensor(detail::make_iterator(tensor_D.host_data()), + cute::make_layout(cute::make_shape(M, N, L), stride_d)); + for(int i=0; i(ElementD(ref(i, j, l))) != static_cast((ElementD(comp(i, j, l))))) { + printf(" ref: %f comp: %f\n", i, j, l, static_cast(ElementD(ref(i, j, l))), static_cast((ElementD(comp(i, j, l))))); + } + } + } + } + #endif + std::cout<<"D is incorrect"<(problem_size, 1); + auto [M, N, K, L] = problem_shape_MNKL; + Arguments arguments = + { + {}, + tensor_C.device_data(), stride_c, tensor_D.device_data(), stride_d + }; + + auto &fusion_args = arguments.thread; + if constexpr (IsLegacy) { + arguments.thread = { + alpha.at(coord_0), + beta.at(coord_0), + alpha.device_data(), + beta.device_data() + }; + 
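The disabled (`#if 0`) mismatch dump above passes the (i, j, l) coordinates to printf but its format string carries no matching `%d` specifiers, so the output would be garbage if re-enabled. A corrected sketch of that loop, assuming the same `ref`/`comp` host views:

// Illustrative fix only: one format specifier per argument.
for (int l = 0; l < L; ++l) {
  for (int j = 0; j < N; ++j) {
    for (int i = 0; i < M; ++i) {
      float r = static_cast<float>(ElementD(ref(i, j, l)));
      float c = static_cast<float>(ElementD(comp(i, j, l)));
      if (r != c) {
        printf("(%d, %d, %d) ref: %f comp: %f\n", i, j, l, r, c);
      }
    }
  }
}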
arguments.ptr_Bias = bias.device_data(); + arguments.ptr_T = tensor_Aux.device_data(); + } + else { + fusion_args.alpha = alpha.at(coord_0); + fusion_args.alpha_ptr = alpha.device_data(); + // Only initializing beta/beta_ptr for non-void source + if constexpr (not cute::is_void_v) { + fusion_args.beta = beta.at(coord_0); + fusion_args.beta_ptr = beta.device_data(); // if vector_scale_mode is true this is nullptr + } + + if constexpr (IsPerRowScaleEnabled) { + int32_t m_stride = vector_scale_mode == VectorScale::ENABLED ? 1 : 0; + int64_t l_stride = vector_scale_mode == VectorScale::ENABLED ? M : (use_device_scalars == ScalarLoc::ON_DEVICE ? 1 : 0); + fusion_args.dAlpha = cute::make_stride(bool(m_stride),cute::_0{}, l_stride); + fusion_args.dBeta = cute::make_stride(bool(m_stride),cute::_0{}, l_stride); + } + else if constexpr (IsPerColScaleEnabled) { + int32_t n_stride = vector_scale_mode == VectorScale::ENABLED ? 1 : 0; + int64_t l_stride = vector_scale_mode == VectorScale::ENABLED ? N : (use_device_scalars == ScalarLoc::ON_DEVICE ? 1 : 0); + fusion_args.dAlpha = cute::make_stride(cute::_0{}, bool(n_stride), l_stride); + fusion_args.dBeta = cute::make_stride(cute::_0{}, bool(n_stride), l_stride); + } + else { + if constexpr (not IsFfma2Kernel) { + if (use_device_scalars == ScalarLoc::ON_DEVICE) { + if (L > 1) { + fusion_args.dAlpha = cute::make_stride(cute::_0{},cute::_0{}, int64_t(1)); + fusion_args.dBeta = cute::make_stride(cute::_0{},cute::_0{}, int64_t(1)); + } + } + } + } + + if constexpr (IsScaleFactorEnabled) { + fusion_args.scale_a = scale_A.at(coord_0); + fusion_args.scale_b = scale_B.at(coord_0); + fusion_args.scale_c = scale_C.at(coord_0); + fusion_args.scale_d = scale_D.at(coord_0); + fusion_args.scale_a_ptr = scale_A.device_data(); + fusion_args.scale_b_ptr = scale_B.device_data(); + fusion_args.scale_c_ptr = scale_C.device_data(); + fusion_args.scale_d_ptr = scale_D.device_data(); + } + + if constexpr (IsRowBiasEnabled or IsColBiasEnabled) { + fusion_args.bias_ptr = bias.device_data(); + } + + if constexpr (IsDeBiasEnabled) { + fusion_args.dbias_ptr = bias.device_data(); + } + + // example of how to set kernel activation arguments + // see ActivationFunctor::Arguments in activation.h for definition + // if Arguments doesn't exist then fusion_args.activation is empty + auto init_activation_args = [] (auto activation, auto& args) { + using Activation = cute::remove_cvref_t; + if constexpr (cute::is_same_v>) { + args.lower_bound = 0; // Treat Clamp as ReLU + args.upper_bound = cutlass::platform::identity_for_minimum(); + } + if constexpr (cute::is_same_v>) { + args.scale = ElementCompute(1); + } + }; + + if constexpr (not cute::is_same_v>) { + init_activation_args(ActivationFunctor{}, fusion_args.activation); + } + if constexpr (IsAbsMaxEnabledD) { + fusion_args.amax_D_ptr = abs_max_D.device_data(); + } + + if constexpr (IsAuxInEnabled) { + fusion_args.aux_ptr = tensor_Aux.device_data(); + fusion_args.dAux = stride_Aux; + } + + if constexpr (IsAuxOutEnabled) { + fusion_args.aux_ptr = tensor_Aux.device_data(); + fusion_args.dAux = stride_Aux; + if constexpr (IsScaleFactorEnabled) { + fusion_args.scale_aux = scale_Aux.at(coord_0); + fusion_args.scale_aux_ptr = scale_Aux.device_data(); + } + if constexpr (IsAbsMaxEnabledAux) { + fusion_args.amax_aux_ptr = abs_max_Aux.device_data(); + } + } + + + if constexpr (IsBlockScaleSupported) { + arguments.thread.block_scale_factor_ptr = tensor_SFD.device_data(); + arguments.thread.norm_constant_ptr = norm_constant.device_data(); + } + } + + 
return arguments; + } + + auto to_host_args(ProblemShapeType problem_size) { + using namespace cute; + // + // Allocate the GEMM workspace + // + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto M = cute::get<0>(problem_shape_MNKL); + auto N = cute::get<1>(problem_shape_MNKL); + auto K = cute::get<2>(problem_shape_MNKL); + auto L = cute::get<3>(problem_shape_MNKL); + auto coord_0 = cutlass::make_Coord(0); + auto C = cute::make_tensor(detail::make_iterator(tensor_C.host_data()), + cute::make_layout(cute::make_shape(M, N, L), stride_c)); + auto D = cute::make_tensor(detail::make_iterator(reference_D.host_data()), + cute::make_layout(cute::make_shape(M, N, L), stride_d)); + auto Bias = cute::make_tensor(detail::make_iterator(IsDeBiasEnabled ? reference_dbias.host_data() : bias.host_data()), + cute::make_layout(cute::make_shape(IsRowBiasEnabled ? M : N))); + auto Aux = cute::make_tensor(detail::make_iterator(IsAuxInEnabled ? tensor_Aux.host_data() : reference_Aux.host_data()), + cute::make_layout(cute::make_shape(M, N, L), stride_Aux)); + auto Valpha = [&](){ + if constexpr (IsPerRowScaleEnabled) { + int m_stride = vector_scale_mode == VectorScale::ENABLED ? 1 : 0; + int l_stride = vector_scale_mode == VectorScale::ENABLED ? M : (use_device_scalars == ScalarLoc::ON_DEVICE ? 1 : 0); + return cute::make_tensor(detail::make_iterator(alpha.host_data()), + cute::make_layout(cute::make_shape(M, N, L), make_stride(m_stride, cute::_0{}, l_stride))); + } + else if constexpr (IsPerColScaleEnabled) { + int n_stride = vector_scale_mode == VectorScale::ENABLED ? 1 : 0; + int l_stride = vector_scale_mode == VectorScale::ENABLED ? N : (use_device_scalars == ScalarLoc::ON_DEVICE ? 1 : 0); + return cute::make_tensor(detail::make_iterator(alpha.host_data()), + cute::make_layout(cute::make_shape(M, N, L), make_stride(cute::_0{}, n_stride, l_stride))); + } + else { + return cute::make_tensor(detail::make_iterator(alpha.host_data()), + cute::make_layout(cute::make_shape(M, N, L), make_stride(cute::_0{}, cute::_0{}, cute::_1{}))); + } + }(); + + auto Vbeta = [&]() { + if constexpr (IsPerRowScaleEnabled) { + int m_stride = vector_scale_mode == VectorScale::ENABLED ? 1 : 0; + int l_stride = vector_scale_mode == VectorScale::ENABLED ? M : (use_device_scalars == ScalarLoc::ON_DEVICE ? 1 : 0); + return cute::make_tensor(detail::make_iterator(beta.host_data()), + cute::make_layout(cute::make_shape(M, N, L), make_stride(m_stride, cute::_0{}, l_stride))); + } + else if constexpr (IsPerColScaleEnabled) { + int n_stride = vector_scale_mode == VectorScale::ENABLED ? 1 : 0; + int l_stride = vector_scale_mode == VectorScale::ENABLED ? N : (use_device_scalars == ScalarLoc::ON_DEVICE ? 1 : 0); + return cute::make_tensor(detail::make_iterator(beta.host_data()), + cute::make_layout(cute::make_shape(M, N, L), make_stride(cute::_0{}, n_stride, l_stride))); + } + else { + return cute::make_tensor(detail::make_iterator(beta.host_data()), + cute::make_layout(cute::make_shape(M, N, L), make_stride(cute::_0{}, cute::_0{}, cute::_1{}))); + } + }(); + + auto SfD = [&](){ + if constexpr (IsBlockScaleSupported) { + auto tensor = make_tensor(detail::make_iterator(reference_SFD.host_data()), + Sm1xxBlockScaledOutputConfig::tile_atom_to_shape_SFD(problem_shape_MNKL)); + return tensor; + } + else { + // Reference kernel has a logic to ignore scalefactor computation if we pass the tensor type same as output D tensor. 
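The Valpha/Vbeta views above encode "scalar vs. per-row vector vs. per-batch scalar" purely through strides: a stride of 0 broadcasts the single stored value along that mode, a nonzero stride walks a real vector. A minimal sketch of the same idea, where `make_alpha_view` is a hypothetical helper rather than testbed code:

#include <cute/tensor.hpp>

// Builds an (M, N, L)-shaped logical view over at most M*L, L, or 1 stored
// alpha values, depending on which modes actually vary.
template <class ElementScalar>
auto make_alpha_view(ElementScalar const* alpha_ptr, int M, int N, int L,
                     bool per_row_vector, bool per_batch_scalar) {
  using namespace cute;
  int m_stride = per_row_vector ? 1 : 0;                          // vary along rows?
  int l_stride = per_row_vector ? M : (per_batch_scalar ? 1 : 0); // vary along batches?
  return make_tensor(alpha_ptr,
                     make_layout(make_shape(M, N, L),
                                 make_stride(m_stride, 0, l_stride)));
}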
+ return D; + } + }(); + cutlass::reference::host::GettEpilogueParams< + ElementScalar, + ElementScalar, + ElementAccumulator, + ElementCompute, + decltype(C), + decltype(D), + decltype(Bias), + decltype(Aux), + decltype(Valpha), + decltype(Vbeta), + ActivationFunctor, + decltype(SfD), + Int, + cutlass::plus, + IsColBiasEnabled + , SfGenStrategy + > epilogue_params{}; + + epilogue_params.C = C; + epilogue_params.D = D; + epilogue_params.alpha = alpha.at(coord_0); + epilogue_params.beta = beta.at(coord_0); + + if constexpr (IsScaleFactorEnabled) { + epilogue_params.scale_a = scale_A.at(coord_0); + epilogue_params.scale_b = scale_B.at(coord_0); + epilogue_params.scale_c = scale_C.at(coord_0); + epilogue_params.scale_d = scale_D.at(coord_0); + } + + if constexpr (IsRowBiasEnabled or IsColBiasEnabled or IsDeBiasEnabled) + { + epilogue_params.Bias = Bias; + } + + if constexpr (IsAbsMaxEnabledD) { + epilogue_params.abs_max_D = reference_abs_max_D.host_data(); + } + + if constexpr (IsAuxInEnabled) { + epilogue_params.Aux = Aux; + } + + if constexpr (IsAuxOutEnabled) { + epilogue_params.Aux = Aux; + if constexpr (IsScaleFactorEnabled) { + epilogue_params.scale_aux = scale_Aux.at(coord_0); + } + if constexpr (IsAbsMaxEnabledAux) { + epilogue_params.abs_max_Aux = reference_abs_max_Aux.host_data(); + } + } + + if constexpr (IsPerRowScaleEnabled or IsPerColScaleEnabled) { + epilogue_params.Valpha = Valpha; + if (vector_scale_mode == VectorScale::ENABLED) { + epilogue_params.Vbeta = Vbeta; + } + } + else { + if (use_device_scalars == ScalarLoc::ON_DEVICE) { + epilogue_params.Valpha = Valpha; + epilogue_params.Vbeta = Vbeta; + } + } + + if constexpr (IsBlockScaleSupported) { + epilogue_params.SfD = SfD; + epilogue_params.st = norm_constant.at(coord_0); + } + return epilogue_params; + } +}; + +template < + typename Gemm, + template class ActivationFunctor_ = cutlass::epilogue::thread::Identity, + bool force_legacy_epilogue = false, + typename ElementA = typename Gemm::GemmKernel::ElementA, + typename ElementB = typename Gemm::GemmKernel::ElementB + , typename RuntimeDatatypeA = void* + , typename RuntimeDatatypeB = void* +> +struct TestbedImpl { + // Kernel data types + using ScheduleType = typename Gemm::GemmKernel::CollectiveMainloop::DispatchPolicy::Schedule; + // All Collective MMA operands are defined by HostCollectiveMainloopType based on the schedule type + using HostCollectiveMainloopType = HostCollectiveMainloop; + + using CollectiveEpilogue = cute::conditional_t::value || force_legacy_epilogue, + HostCollectiveDefaultEpilogue, + HostCollectiveEpilogue>; + + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + using ElementAccumulator = typename Gemm::GemmKernel::ElementAccumulator; + using ElementCompute = typename ElementComputeType::Type; + using ElementScalar = typename ElementScalarType::Type; + + using LayoutTagA = typename HostCollectiveMainloopType::LayoutTagA; + using LayoutTagB = typename HostCollectiveMainloopType::LayoutTagB; + using LayoutTagC = typename CollectiveEpilogue::LayoutTagC; + using LayoutTagD = typename CollectiveEpilogue::LayoutTagD; + + + using InternalElementA = typename Gemm::GemmKernel::ElementA; + using InternalElementB = typename Gemm::GemmKernel::ElementB; + static constexpr bool IsRuntimeDataTypeA = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); + + static constexpr bool IsRuntimeDataTypeB = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); + + static_assert((IsRuntimeDataTypeA && IsRuntimeDataTypeB) || + 
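Setting aside the scale-factor, aux-output, and abs-max paths configured above, the per-element computation the host reference performs with these epilogue parameters is essentially a linear combination followed by the activation. A simplified scalar sketch (not the library routine):

// D = act(alpha * accumulator + beta * C + bias); bias broadcasts along the
// row or column mode depending on IsRowBiasEnabled / IsColBiasEnabled.
template <class ElementCompute, class Activation>
ElementCompute epilogue_scalar(ElementCompute acc, ElementCompute c,
                               ElementCompute alpha, ElementCompute beta,
                               ElementCompute bias, Activation act) {
  return act(alpha * acc + beta * c + bias);
}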
(!IsRuntimeDataTypeA && !IsRuntimeDataTypeB), + "ElementA and ElementB in a GEMM kernel should be both runtime or both static."); + + static constexpr bool IsRuntimeDataType = IsRuntimeDataTypeA && IsRuntimeDataTypeB; + + + uint32_t sm_count; + // Used to force multi-wave tests for persistent kernel schedules + constexpr static int MaxSmCount = 16; + static constexpr uint64_t kDefaultSeed = 4096; + static constexpr uint32_t mma_promotion_interval = 4; + using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::RasterOrderOptions; + using DecompositionMode = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode; + + HostCollectiveMainloopType collective_mma_inputs; + CollectiveEpilogue collective_epilogue; + + // + // Methods + // + + TestbedImpl( + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + ScalarLoc use_device_scalars_ = ScalarLoc::ON_HOST, + VectorScale vector_scale_mode_ = VectorScale::DISABLED, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_scale_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_bias_ = cutlass::Distribution::Uniform, + uint64_t seed_ = kDefaultSeed + ): collective_mma_inputs(HostCollectiveMainloopType(check_relative_equality_, init_A_, init_B_, seed_)), + collective_epilogue(CollectiveEpilogue(check_relative_equality_, use_device_scalars_, vector_scale_mode_, init_C_, init_scale_, init_bias_, seed_)) { } + + TestbedImpl( + typename LayoutTagA::Stride stride_factor_A_, + typename LayoutTagB::Stride stride_factor_B_, + typename LayoutTagC::Stride stride_factor_C_, + typename LayoutTagD::Stride stride_factor_D_, + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + ScalarLoc use_device_scalars_ = ScalarLoc::ON_HOST, + VectorScale vector_scale_mode_ = VectorScale::DISABLED, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_scale_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_bias_ = cutlass::Distribution::Uniform, + uint64_t seed_ = kDefaultSeed + ): collective_mma_inputs(HostCollectiveMainloopType(check_relative_equality_, stride_factor_A_, stride_factor_B_, init_A_, init_B_, seed_)), + collective_epilogue(CollectiveEpilogue(check_relative_equality_, use_device_scalars_, vector_scale_mode_, init_C_, init_scale_, init_bias_, seed_)) { } + + /// Initializes data structures + bool initialize(ProblemShapeType problem_size, ElementScalar alpha_=1.f, ElementScalar beta_=0.f) { +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("TestbedImpl::initialize(problem_size, alpha, beta)"); +#endif + collective_mma_inputs.initialize(problem_size); + collective_epilogue.initialize(problem_size, alpha_, beta_); + + return true; + } + + /// Compares computed reference with device reference and outputs to a file if incorrect + bool compare_reference( + cute::Shape problem_shape_MNKL, + ElementScalar alpha, + ElementScalar beta) + { + auto [M, N, K, L] = problem_shape_MNKL; + + bool passed = collective_mma_inputs.compare_reference(problem_shape_MNKL); + passed &= collective_epilogue.compare_reference(problem_shape_MNKL, alpha, 
beta); + EXPECT_TRUE(passed); + if (!passed) { + std::stringstream fname; + fname << "error_Gemm_device_" + << M << "x" << N << "x" << K << "x" << L << "_" + << cute::get<0>(typename Gemm::GemmKernel::TileShape{}) << "_" + << cute::get<1>(typename Gemm::GemmKernel::TileShape{}) << "_" + << cute::get<2>(typename Gemm::GemmKernel::TileShape{}) << ".txt"; + + std::ofstream file(fname.str()); + file + << "problem: " << ' ' << M << "x" << N << "x" << K << ", Batch count = " << L + << ", alpha: " << alpha << ", beta: " << beta << "\n\n"; + + collective_mma_inputs.print_tensors(file); + collective_epilogue.print_tensors(file); + } + + return passed; + } + + /// Verifies the result is a GEMM + bool verify( + ProblemShapeType problem_size, + ElementScalar alpha, + ElementScalar beta) + { + using namespace cute; + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto mainloop_params = collective_mma_inputs.to_host_args(problem_size); + auto epilogue_params = collective_epilogue.to_host_args(problem_size); + + cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params); + + bool passed = compare_reference(problem_shape_MNKL, alpha, beta); + return passed; + } + + /// Determine if the CUDA device is sufficient to run the kernel + bool sufficient() { + // + // Determine SMEM requirements and waive if not satisfied + // + + size_t smem_size = static_cast(Gemm::GemmKernel::SharedStorageSize); + + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + cudaDeviceProp properties; + result = cudaGetDeviceProperties(&properties, device_idx); + this->sm_count = properties.multiProcessorCount; + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerBlockOptin < smem_size) { + printf("failed due to smem_size\n"); + printf("hardware smem_size: %d, required smem_size: %d\n\n", int(properties.sharedMemPerBlockOptin), int(smem_size)); + return false; + } + + return true; + } + + bool profile( + ProblemShapeType problem_size, + int iterations, + Gemm& gemm_op, + typename Gemm::Arguments& arguments, + cutlass::device_memory::allocation& workspace) { + int M = cute::size<0>(problem_size); + int N = cute::size<1>(problem_size); + int K = cute::size<2>(problem_size); + int L = 1; + if constexpr(cute::rank(ProblemShapeType{}) == 4) { + L = cute::size<3>(problem_size); + } + + + cutlass::Status status; + // + // Run the GEMM + // + cudaError_t result; + + for (int iter = 0; iter < iterations; ++iter) { + status = gemm_op(arguments, workspace.get()); + if (status != cutlass::Status::kSuccess) { + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + return false; + } + } + + result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + EXPECT_EQ(result, cudaSuccess) << "Error at Kernel Sync."; + return false; + } + + return true; + } + + /// Executes one test + bool run( + ProblemShapeType problem_size, + ElementScalar alpha = ElementScalar(1), + ElementScalar beta = ElementScalar(0), + bool profiling = false, + detail::Iterations iterations = detail::Iterations{}, + RasterOrderOptions raster_order = RasterOrderOptions::Heuristic, + detail::MaxSwizzleSize max_swizzle = detail::MaxSwizzleSize{}, + detail::Splits splits = detail::Splits{}, + DecompositionMode decomposition_mode = DecompositionMode::Heuristic + , RuntimeDatatypeA runtime_input_datatype_a = {} + , RuntimeDatatypeB 
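The `sufficient()` check in the next block compares the kernel's static shared-memory requirement (Gemm::GemmKernel::SharedStorageSize) against the opt-in per-block limit reported by the driver, and waives the test if it cannot fit. A self-contained sketch of that capability query, with `device_has_enough_smem` as a hypothetical helper:

#include <cstdio>
#include <cuda_runtime.h>

bool device_has_enough_smem(size_t required_smem_bytes) {
  int device = 0;
  if (cudaGetDevice(&device) != cudaSuccess) {
    return false;
  }
  cudaDeviceProp props{};
  if (cudaGetDeviceProperties(&props, device) != cudaSuccess) {
    return false;
  }
  // sharedMemPerBlockOptin is the maximum dynamic SMEM a block may opt into.
  if (props.sharedMemPerBlockOptin < required_smem_bytes) {
    std::printf("need %zu bytes of SMEM, device allows %zu\n",
                required_smem_bytes, props.sharedMemPerBlockOptin);
    return false;
  }
  return true;
}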
runtime_input_datatype_b = {} + ) + { +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("TestbedImpl::run"); +#endif + + // Fail test if insufficient CUDA device + if (!sufficient()) { + CUTLASS_TRACE_HOST("TestbedImpl::run: Test failed due to insufficient CUDA device"); + std::cout << "Test failed due to insufficient CUDA device." << std::endl; + return false; + } +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + else { + CUTLASS_TRACE_HOST("TestbedImpl::run: sufficient() returned true"); + } +#endif + + try { + const bool initialized = this->initialize(problem_size, alpha, beta); + if (not initialized) { + CUTLASS_TRACE_HOST("TestbedImpl::run: this->initialize returned false"); + std::cerr << "Initialization failed \n"; + return false; + } + } + catch ([[maybe_unused]] std::exception const& e) { + CUTLASS_TRACE_HOST("TestbedImpl::run: this->initialize threw an exception: " << e.what()); + throw; + } + catch (...) { + CUTLASS_TRACE_HOST("TestbedImpl::run: this->initialize threw an unknown exception"); + throw; + } + +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("TestbedImpl::run: this->initialize() returned true"); +#endif + + // + // Initialize the GEMM operator + // + + typename Gemm::Arguments arguments; + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = 0; + if (not profiling) { + this->sm_count = std::min(MaxSmCount, cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id)); + hw_info.sm_count = this->sm_count; + } + else { + this->sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + hw_info.sm_count = this->sm_count; + } + + typename Gemm::GemmKernel::TileScheduler::Arguments scheduler_args; + if constexpr (cute::is_same_v) { + scheduler_args = { static_cast(splits), static_cast(max_swizzle), raster_order, decomposition_mode }; + } + else { + scheduler_args = { static_cast(max_swizzle), raster_order }; + } + typename HostCollectiveMainloopType::Arguments mainloop_args; + + mainloop_args = collective_mma_inputs.to_args(); + + + if constexpr (IsRuntimeDataType) { + mainloop_args.runtime_data_type_a = runtime_input_datatype_a; + mainloop_args.runtime_data_type_b = runtime_input_datatype_b; + } + + + arguments = + { + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + mainloop_args, + collective_epilogue.to_args(problem_size), + hw_info, + scheduler_args + }; + +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("TestbedImpl::run: Creating gemm_op"); +#endif + Gemm gemm_op; + +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("TestbedImpl::run: Calling Gemm::get_workspace_size"); +#endif + size_t workspace_size = Gemm::get_workspace_size(arguments); +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("TestbedImpl::run: Allocating workspace of size " << workspace_size); +#endif + cutlass::device_memory::allocation workspace(workspace_size); + +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("TestbedImpl::run: Calling gemm_op.can_implement"); +#endif + cutlass::Status status = gemm_op.can_implement(arguments); + + if (status != cutlass::Status::kSuccess) { + cudaError_t error = cudaGetLastError(); + const auto error_str = cudaGetErrorString(error); + CUTLASS_TRACE_HOST("TestbedImpl::run: cudaGetLastError() is " << error_str); + std::cerr << "This test is not supported: " << error_str << "\n"; + return true; + } + + // + // Run the GEMM + // + + if (profiling) { +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("TestbedImpl::run: Calling profile"); +#endif + 
return profile(problem_size, static_cast(iterations), gemm_op, arguments, workspace); + } + else { + cudaError_t result; +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("TestbedImpl::run: Calling gemm_op.initialize"); +#endif + status = gemm_op.initialize(arguments, workspace.get()); + if (status != cutlass::Status::kSuccess) { + cudaError_t error = cudaGetLastError(); + const auto error_str = cudaGetErrorString(error); + CUTLASS_TRACE_HOST("TestbedImpl::run: cudaGetLastError() is " << error_str); + } +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("TestbedImpl::run: Calling gemm_op.run"); +#endif + status = gemm_op.run(); + if (status != cutlass::Status::kSuccess) { + cudaError_t error = cudaGetLastError(); + const auto error_str = cudaGetErrorString(error); + CUTLASS_TRACE_HOST("TestbedImpl::run: cudaGetLastError() is " << error_str); + } +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("TestbedImpl::run: Calling cudaDeviceSynchronize"); +#endif + result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + CUTLASS_TRACE_HOST("TestbedImpl::run: cudaDeviceSynchronize reports non-success"); + EXPECT_EQ(result, cudaSuccess) << "Error at Kernel Sync."; + return false; + } + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Verify + // +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("TestbedImpl::run: Calling this->verify"); +#endif + bool passed = this->verify(problem_size, alpha, beta); + if (!passed) { + CUTLASS_TRACE_HOST("TestbedImpl::run: this->verify FAILED"); + cudaError_t error = cudaGetLastError(); + const auto error_str = cudaGetErrorString(error); + CUTLASS_TRACE_HOST("TestbedImpl::run: cudaGetLastError() is " << error_str); + + std::cout << "Error : Failed : with alpha: " << alpha << ", beta: " << beta + << "\n"; + } +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + else { + CUTLASS_TRACE_HOST("TestbedImpl::run: this->verify passed"); + } +#endif + +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("TestbedImpl::run: Reached end"); +#endif + return passed; + } + } +}; + +} // namespace detail + +///////////////////////////////////////////////////////////////////////////////////////////////// + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Gemm, + template class ActivationFunctor = cutlass::epilogue::thread::Identity, + bool force_legacy_epilogue = false, + typename ElementA = typename Gemm::GemmKernel::ElementA, + typename ElementB = typename Gemm::GemmKernel::ElementB + , typename RuntimeDatatypeA = void* + , typename RuntimeDatatypeB = void* +> +struct Testbed3x { + + using TestBedImpl = typename detail::TestbedImpl< + Gemm, + ActivationFunctor, + force_legacy_epilogue, + ElementA, + ElementB + , RuntimeDatatypeA + , RuntimeDatatypeB + >; + using Kernel = typename Gemm::GemmKernel; + using Epilogue = typename Gemm::GemmKernel::CollectiveEpilogue; + + using ElementAccumulator = typename TestBedImpl::ElementAccumulator; + using ElementCompute = typename TestBedImpl::ElementCompute; + using ElementScalar = typename TestBedImpl::ElementScalar; + + using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::RasterOrderOptions; + using DecompositionMode = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode; + + // Detail Implementation + TestBedImpl impl_; + + // + // Methods + // + Testbed3x( + CheckEquality check_relative_equality_ = 
CheckEquality::EXACT, + ScalarLoc use_device_scalars_ = ScalarLoc::ON_DEVICE, + VectorScale vector_scale_mode_ = VectorScale::DISABLED, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_scale_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_bias_ = cutlass::Distribution::Uniform, + uint64_t seed_ = TestBedImpl::kDefaultSeed) + : impl_(check_relative_equality_, use_device_scalars_, vector_scale_mode_, init_A_, init_B_, init_C_, init_scale_, init_bias_, seed_) {} + + /// Executes one test + bool run( + typename TestBedImpl::ProblemShapeType problem_size, + ElementScalar alpha = ElementScalar(1), + ElementScalar beta = ElementScalar(0), + RasterOrderOptions raster_order = RasterOrderOptions::Heuristic, + detail::MaxSwizzleSize max_swizzle = detail::MaxSwizzleSize{}, + detail::Splits splits = detail::Splits{}, + DecompositionMode decomposition_mode = DecompositionMode::Heuristic, + bool profiling = false, + detail::Iterations iterations = detail::Iterations{} + , RuntimeDatatypeA runtime_input_datatype_a = {} + , RuntimeDatatypeB runtime_input_datatype_b = {} + ) + { + return impl_.run( + problem_size, alpha, beta, profiling, iterations, raster_order, max_swizzle, splits, decomposition_mode + , runtime_input_datatype_a, runtime_input_datatype_b + ); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +bool TestGemmPerf3x(int iterations = 20) { + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + using ElementAccumulator = typename Gemm::GemmKernel::ElementAccumulator; + using ElementScalar = ElementAccumulator; + bool passed = true; + using DecompositionMode = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode; + using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::RasterOrderOptions; + + std::vector problem_size_m = { 4608 }; + std::vector problem_size_n = { 4608 }; + std::vector problem_size_k = { 8192 }; + + Testbed3x testbed; + + for (int m : problem_size_m) { + for (int n : problem_size_n) { + for (int k : problem_size_k) { + ProblemShapeType problem_size; + if constexpr (cute::rank(ProblemShapeType{}) == 4) { + problem_size = ProblemShapeType{m, n, k, /* l */ 1}; + } + else { + problem_size = ProblemShapeType{m, n, k}; + } + + passed = testbed.run( + problem_size, + cutlass::from_real(1), + cutlass::from_real(0), + RasterOrderOptions{}, detail::MaxSwizzleSize(1), detail::Splits{1}, DecompositionMode{}, + true, // profiling + detail::Iterations{iterations}); + + if (!passed) { + return false; + } + } + } + } + + return true; +} + + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +template < + typename Gemm, + typename RuntimeDataTypeA, + typename RuntimeDataTypeB, + bool force_legacy_epilogue = false> +bool TestRuntimeDataTypeSmall( + RuntimeDataTypeA runtime_input_datatype_a, + RuntimeDataTypeB runtime_input_datatype_b, + double alpha = 1.0, double beta = cute::is_same_v ? 
0.0 : 1.0, + CheckEquality check_relative_equality = CheckEquality::RELATIVE, ScalarLoc use_device_scalars = ScalarLoc::ON_DEVICE, VectorScale vector_scale_mode = VectorScale::ENABLED, std::vector override_problem_size_k = {}) { + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + using ElementScalar = typename Gemm::EpilogueOutputOp::ElementScalar; + using CtaShape_MNK = typename Gemm::GemmKernel::CollectiveMainloop::CtaShape_MNK; + using DispatchPolicy = typename Gemm::GemmKernel::CollectiveMainloop::DispatchPolicy; + + using InternalElementA = typename Gemm::GemmKernel::ElementA; + using InternalElementB = typename Gemm::GemmKernel::ElementB; + + CtaShape_MNK cta_shape; + static constexpr int SmCount = 16; + static constexpr int MultiplierOffsetM = 1; + static constexpr int MultiplierOffsetN = 2; + static constexpr int MultiplierOffsetK = 3; + int max_alignment = std::max(Gemm::kAlignmentA, Gemm::kAlignmentB); + + float waves[] = {0.5, 1.25, 2.5}; + int cluster_m = 1; + int cluster_n = 1; + + std::vector problem_size_k; + if (override_problem_size_k.empty()) { + problem_size_k = {256 + max_alignment * MultiplierOffsetK, 512 + max_alignment * MultiplierOffsetK}; + } + else { + problem_size_k = override_problem_size_k; + } + + if constexpr(DispatchPolicy::ArchTag::kMinComputeCapability >= 90) { + typename DispatchPolicy::ClusterShape cluster_shape; + cluster_m = cute::size<0>(cluster_shape); + cluster_n = cute::size<1>(cluster_shape); + } + + [[maybe_unused]] constexpr int TileShapeK = cute::size<2>(typename Gemm::GemmKernel::TileShape{}); + using DecompositionMode = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode; + using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::RasterOrderOptions; + + std::vector decomposition_modes = {DecompositionMode::Heuristic}; + static constexpr bool UsesStreamKScheduler = cute::is_same_v; + if constexpr (UsesStreamKScheduler) { + decomposition_modes.push_back(DecompositionMode::DataParallel); + decomposition_modes.push_back(DecompositionMode::SplitK); + decomposition_modes.push_back(DecompositionMode::StreamK); + } + bool passed = true; + + for (float wave : waves) { + for (int k : problem_size_k) { + int grid_m, grid_n = 0; + int num_grid = int(wave * SmCount); + + if (cluster_m >= cluster_n) { + grid_m = cluster_m; + grid_n = num_grid / grid_m; + // Align grid_n to cluster_n + grid_n = std::max((grid_n + cluster_n - 1 ) / cluster_n * cluster_n, 1); + } + else { + grid_n = cluster_n; + grid_m = num_grid / grid_n; + // Align grid_m to cluster_m + grid_m = std::max((grid_m + cluster_m - 1 ) / cluster_m * cluster_m, 1); + } + + int m = grid_m * cute::size<0>(cta_shape) + MultiplierOffsetM * max_alignment; + int n = grid_n * cute::size<1>(cta_shape) + MultiplierOffsetN * max_alignment; + + ProblemShapeType problem_size; + if constexpr (cute::rank(ProblemShapeType{}) == 4) { + problem_size = ProblemShapeType{m, n, k, /* l */ 1}; + } + else { + problem_size = ProblemShapeType{m, n, k}; + } + + for (DecompositionMode decomp_mode : decomposition_modes) { + std::vector problem_splits = {detail::Splits{1}}; + if (decomp_mode == DecompositionMode::Heuristic || decomp_mode == DecompositionMode::SplitK) { + problem_splits.push_back(detail::Splits{2}); + } + for (auto splits : problem_splits) { + + if constexpr (cute::is_same_v && + cute::is_same_v) { + // e2m1_e2m1 + if (runtime_input_datatype_a == cute::UMMA::MXF4Format::E2M1 && + runtime_input_datatype_b == 
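The problem-size selection above picks a grid that covers roughly `wave` waves of a 16-SM device, rounds the grid up to the cluster shape, then converts tiles into a GEMM extent plus an alignment-derived offset. A minimal sketch of that arithmetic, where `size_for_wave` is a hypothetical helper mirroring the loop body:

#include <algorithm>

struct ProblemMN { int m; int n; };

ProblemMN size_for_wave(float wave, int sm_count,
                        int cta_m, int cta_n,
                        int cluster_m, int cluster_n,
                        int offset_m, int offset_n) {
  int num_ctas = static_cast<int>(wave * sm_count);
  int grid_m = 0, grid_n = 0;
  if (cluster_m >= cluster_n) {
    grid_m = cluster_m;
    // Round the other grid dimension up to a multiple of its cluster extent.
    grid_n = std::max((num_ctas / grid_m + cluster_n - 1) / cluster_n * cluster_n, 1);
  }
  else {
    grid_n = cluster_n;
    grid_m = std::max((num_ctas / grid_n + cluster_m - 1) / cluster_m * cluster_m, 1);
  }
  // One CTA tile per grid entry, plus an offset so M/N are not tile-aligned.
  return { grid_m * cta_m + offset_m, grid_n * cta_n + offset_n };
}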
cute::UMMA::MXF4Format::E2M1) { + Testbed3x testbed(check_relative_equality, + use_device_scalars, + vector_scale_mode); + passed = testbed.run( + problem_size, + cutlass::from_real(alpha), + cutlass::from_real(beta), + RasterOrderOptions::Heuristic, // raster_order + detail::MaxSwizzleSize(1), + splits, + decomp_mode, + false, + detail::Iterations{}, + runtime_input_datatype_a, + runtime_input_datatype_b + ); + } + else { + std::cout << "Unsupported configuration for runtime datatype MXFP4." << std::endl; + return false; + } + } + + else + if constexpr (cute::is_same_v && + cute::is_same_v) { + static_assert((cute::is_same_v || + cute::is_same_v || + cute::is_same_v) && + (cute::is_same_v || + cute::is_same_v || + cute::is_same_v), + "Runtime datatype must be selected with an appropriate static umbrella data type."); + if constexpr (cute::is_same_v && + cute::is_same_v) { + // e4m3_e2m1 + if (runtime_input_datatype_a == cute::UMMA::MXF8F6F4Format::E4M3 && + runtime_input_datatype_b == cute::UMMA::MXF8F6F4Format::E2M1) { + Testbed3x testbed(check_relative_equality, + use_device_scalars, + vector_scale_mode); + passed = testbed.run( + problem_size, + cutlass::from_real(alpha), + cutlass::from_real(beta), + RasterOrderOptions::Heuristic, // raster_order + detail::MaxSwizzleSize(1), + splits, + decomp_mode, + false, + detail::Iterations{}, + runtime_input_datatype_a, + runtime_input_datatype_b + ); + } + // Unsupport + else { + std::cout << "Unsupported configuration for runtime datatype Mxf8f6f4." << std::endl; + return false; + } + } + // f6xf4 + else if constexpr (cute::is_same_v && + cute::is_same_v) { + // e3m2_e2m1 + if (runtime_input_datatype_a == cute::UMMA::MXF8F6F4Format::E3M2 && + runtime_input_datatype_b == cute::UMMA::MXF8F6F4Format::E2M1) { + Testbed3x testbed(check_relative_equality, + use_device_scalars, + vector_scale_mode); + + passed = testbed.run( + problem_size, + cutlass::from_real(alpha), + cutlass::from_real(beta), + RasterOrderOptions::Heuristic, // raster_order + detail::MaxSwizzleSize(1), + splits, + decomp_mode, + false, + detail::Iterations{}, + runtime_input_datatype_a, + runtime_input_datatype_b + ); + } + // Unsupport + else { + std::cout << "Unsupported configuration for runtime datatype Mxf8f6f4." << std::endl; + return false; + } + } + else if constexpr (cute::is_same_v && + cute::is_same_v) { + // e2m1_e2m1 + if (runtime_input_datatype_a == cute::UMMA::MXF8F6F4Format::E2M1 && + runtime_input_datatype_b == cute::UMMA::MXF8F6F4Format::E2M1) { + Testbed3x testbed(check_relative_equality, + use_device_scalars, + vector_scale_mode); + passed = testbed.run( + problem_size, + cutlass::from_real(alpha), + cutlass::from_real(beta), + RasterOrderOptions::Heuristic, // raster_order + detail::MaxSwizzleSize(1), + splits, + decomp_mode, + false, + detail::Iterations{}, + runtime_input_datatype_a, + runtime_input_datatype_b + ); + } + // Unsupport + else { + std::cout << "Unsupported configuration for runtime datatype Mxf8f6f4." 
<< std::endl; + return false; + } + } + else if constexpr (cute::is_same_v && + cute::is_same_v) { + // e4m3_e3m2 + if (runtime_input_datatype_a == cute::UMMA::MXF8F6F4Format::E4M3 && + runtime_input_datatype_b == cute::UMMA::MXF8F6F4Format::E3M2) { + Testbed3x testbed(check_relative_equality, + use_device_scalars, + vector_scale_mode); + passed = testbed.run( + problem_size, + cutlass::from_real(alpha), + cutlass::from_real(beta), + RasterOrderOptions::Heuristic, // raster_order + detail::MaxSwizzleSize(1), + splits, + decomp_mode, + false, + detail::Iterations{}, + runtime_input_datatype_a, + runtime_input_datatype_b + ); + } + // Unsupport + else { + std::cout << "Unsupported configuration for runtime datatype Mxf8f6f4." << std::endl; + return false; + } + } + else if constexpr (cute::is_same_v && + cute::is_same_v) { + // e3m2_e2m3 + if (runtime_input_datatype_a == cute::UMMA::MXF8F6F4Format::E3M2 && + runtime_input_datatype_b == cute::UMMA::MXF8F6F4Format::E2M3) { + Testbed3x testbed(check_relative_equality, + use_device_scalars, + vector_scale_mode); + passed = testbed.run( + problem_size, + cutlass::from_real(alpha), + cutlass::from_real(beta), + RasterOrderOptions::Heuristic, // raster_order + detail::MaxSwizzleSize(1), + splits, + decomp_mode, + false, + detail::Iterations{}, + runtime_input_datatype_a, + runtime_input_datatype_b + ); + } + // Unsupported + else { + std::cout << "Unsupported configuration for runtime datatype Mxf8f6f4." << std::endl; + return false; + } + } + else + if constexpr (cute::is_same_v && + cute::is_same_v) { + // e5m2_e5m2 + if (runtime_input_datatype_a == cute::UMMA::MXF8F6F4Format::E5M2 && + runtime_input_datatype_b == cute::UMMA::MXF8F6F4Format::E5M2) { + Testbed3x testbed(check_relative_equality, + use_device_scalars, + vector_scale_mode); + passed = testbed.run( + problem_size, + cutlass::from_real(alpha), + cutlass::from_real(beta), + RasterOrderOptions::Heuristic, // raster_order + detail::MaxSwizzleSize(1), + splits, + decomp_mode, + false, + detail::Iterations{}, + runtime_input_datatype_a, + runtime_input_datatype_b + ); + } + // e4m3_e5m2 + else if (runtime_input_datatype_a == cute::UMMA::MXF8F6F4Format::E4M3 && + runtime_input_datatype_b == cute::UMMA::MXF8F6F4Format::E5M2){ + Testbed3x testbed(check_relative_equality, + use_device_scalars, + vector_scale_mode); + passed = testbed.run( + problem_size, + cutlass::from_real(alpha), + cutlass::from_real(beta), + RasterOrderOptions::Heuristic, // raster_order + detail::MaxSwizzleSize(1), + splits, + decomp_mode, + false, + detail::Iterations{}, + runtime_input_datatype_a, + runtime_input_datatype_b + ); + } + // e5m2_e4m3 + else if (runtime_input_datatype_a == cute::UMMA::MXF8F6F4Format::E5M2 && + runtime_input_datatype_b == cute::UMMA::MXF8F6F4Format::E4M3){ + Testbed3x testbed(check_relative_equality, + use_device_scalars, + vector_scale_mode); + passed = testbed.run( + problem_size, + cutlass::from_real(alpha), + cutlass::from_real(beta), + RasterOrderOptions::Heuristic, // raster_order + detail::MaxSwizzleSize(1), + splits, + decomp_mode, + false, + detail::Iterations{}, + runtime_input_datatype_a, + runtime_input_datatype_b + ); + } + // e4m3_e4m3 + else if (runtime_input_datatype_a == cute::UMMA::MXF8F6F4Format::E4M3 && + runtime_input_datatype_b == cute::UMMA::MXF8F6F4Format::E4M3){ + Testbed3x testbed(check_relative_equality, + use_device_scalars, + vector_scale_mode); + passed = testbed.run( + problem_size, + cutlass::from_real(alpha), + cutlass::from_real(beta), + 
RasterOrderOptions::Heuristic, // raster_order + detail::MaxSwizzleSize(1), + splits, + decomp_mode, + false, + detail::Iterations{}, + runtime_input_datatype_a, + runtime_input_datatype_b + ); + } + // Unsupported + else { + std::cout << "Unsupported configuration for runtime datatype Mxf8f6f4." << std::endl; + return false; + } + } + // Unsupported + else { + std::cout << "Unsupported configuration for runtime datatype Mxf8f6f4." << std::endl; + return false; + } + } + + else { + static_assert(cutlass::detail::dependent_false, + "Unsupported configuration for runtime datatype."); + } + + if (!passed) { + std::cout << __FILE__ << ':' << __LINE__ << " : GEMM MNK " << m << " " << n << " " << k << " FAILED.\n"; + return false; + } + } // splits + } // decomposition_mode + } // k + } // waves + + return passed; +} + +template +bool TestSmall(double alpha = 1.0, double beta = cute::is_same_v ? 0.0 : 1.0, + CheckEquality check_relative_equality = CheckEquality::RELATIVE, + ScalarLoc use_device_scalars = ScalarLoc::ON_DEVICE, + VectorScale vector_scale_mode = VectorScale::ENABLED, + std::vector override_problem_size_k = {}) { + + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + using ElementScalar = typename Gemm::EpilogueOutputOp::ElementScalar; + using CtaShape_MNK = typename Gemm::GemmKernel::CollectiveMainloop::CtaShape_MNK; + using DispatchPolicy = typename Gemm::GemmKernel::CollectiveMainloop::DispatchPolicy; + CtaShape_MNK cta_shape; + Testbed3x testbed(check_relative_equality, use_device_scalars, vector_scale_mode); + static constexpr int SmCount = 16; + static constexpr int MultiplierOffsetM = 1; + static constexpr int MultiplierOffsetN = 2; + static constexpr int MultiplierOffsetK = 3; + int max_alignment_k = 0; + int max_alignment_m = 0; + int max_alignment_n = 0; + + if constexpr (apply_alignment_offset) { + max_alignment_k = std::max(Gemm::kAlignmentA, Gemm::kAlignmentB); + max_alignment_n = std::max(Gemm::kAlignmentA, Gemm::kAlignmentB); + max_alignment_m = std::max(Gemm::kAlignmentA, Gemm::kAlignmentB); + } + // Alignment for SFD + if constexpr (detail::IsSfdEpi::value) { + using GmemLayoutTagScalefactor = typename Gemm::GemmKernel::CollectiveEpilogue::FusionCallbacks::Operation::GmemLayoutTagScalefactor; + constexpr int SFDVecSize = Gemm::GemmKernel::CollectiveEpilogue::FusionCallbacks::Operation::SFVecSize; + if constexpr (cute::is_same_v) { + max_alignment_n = std::lcm(max_alignment_n, SFDVecSize); + } + else { + max_alignment_m = std::lcm(max_alignment_m, SFDVecSize); + } + } + + float waves[] = {0.5, 1.25, 2.5}; + int cluster_m = 1; + int cluster_n = 1; + + std::vector problem_size_k; + if (override_problem_size_k.empty()) { + problem_size_k = {256 + max_alignment_k * MultiplierOffsetK, 512 + max_alignment_k * MultiplierOffsetK}; + } + else { + problem_size_k = override_problem_size_k; + } + + if constexpr(DispatchPolicy::ArchTag::kMinComputeCapability >= 90) { + typename DispatchPolicy::ClusterShape cluster_shape; + cluster_m = cute::size<0>(cluster_shape); + cluster_n = cute::size<1>(cluster_shape); + } + + using DecompositionMode = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode; + using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::RasterOrderOptions; + + std::vector decomposition_modes = {DecompositionMode::Heuristic}; + static constexpr bool UsesStreamKScheduler = cute::is_same_v; + if constexpr (UsesStreamKScheduler) { + 
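Editorial note for readers of these tests: both TestSmall variants derive M and N from a target "wave" count rather than fixed sizes. The sketch below (plain ints in place of the CUTLASS cluster/CTA-shape types; all names hypothetical and not part of this file) mirrors the calculation used in the loops that follow: scale the SM count by the wave fraction to get a target CTA grid, round each grid dimension up to a multiple of the cluster extent, then multiply by the CTA tile and add a small alignment-based offset so the problem is deliberately not tile-aligned.

#include <algorithm>
#include <cstdio>

// Hypothetical standalone mirror of the grid sizing used below (illustration only).
static void wave_to_problem_size(float wave, int sm_count,
                                 int cluster_m, int cluster_n,
                                 int cta_m, int cta_n,
                                 int alignment_m, int alignment_n,
                                 int& m, int& n) {
  int num_grid = static_cast<int>(wave * sm_count);   // target number of CTAs
  int grid_m = 0, grid_n = 0;
  if (cluster_m >= cluster_n) {
    grid_m = cluster_m;
    grid_n = std::max((num_grid / grid_m + cluster_n - 1) / cluster_n * cluster_n, 1);
  }
  else {
    grid_n = cluster_n;
    grid_m = std::max((num_grid / grid_n + cluster_m - 1) / cluster_m * cluster_m, 1);
  }
  m = grid_m * cta_m + 1 * alignment_m;   // MultiplierOffsetM * alignment
  n = grid_n * cta_n + 2 * alignment_n;   // MultiplierOffsetN * alignment
}

int main() {
  int m = 0, n = 0;
  wave_to_problem_size(0.5f, /*SmCount*/16, /*cluster*/2, 1, /*cta tile*/128, 128, /*alignment*/16, 16, m, n);
  std::printf("m=%d n=%d\n", m, n);   // prints m=272 n=544 for these inputs
  return 0;
}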
decomposition_modes.push_back(DecompositionMode::DataParallel); + decomposition_modes.push_back(DecompositionMode::SplitK); + decomposition_modes.push_back(DecompositionMode::StreamK); + } + bool passed = true; + + std::vector raster_order_options = {RasterOrderOptions::Heuristic}; + for (float wave : waves) { + for (int k : problem_size_k) { + int grid_m, grid_n = 0; + int num_grid = int(wave * SmCount); + + if (cluster_m >= cluster_n) { + grid_m = cluster_m; + grid_n = num_grid / grid_m; + // Align grid_n to cluster_n + grid_n = std::max((grid_n + cluster_n - 1 ) / cluster_n * cluster_n, 1); + } + else { + grid_n = cluster_n; + grid_m = num_grid / grid_n; + // Align grid_m to cluster_m + grid_m = std::max((grid_m + cluster_m - 1 ) / cluster_m * cluster_m, 1); + } + + int m = grid_m * cute::size<0>(cta_shape) + MultiplierOffsetM * max_alignment_m; + int n = grid_n * cute::size<1>(cta_shape) + MultiplierOffsetN * max_alignment_n; + int l = test_batched_alpha_beta && wave == waves[0] && k == problem_size_k[0] ? 2 : 1; // only test the smallest problem size + ProblemShapeType problem_size; + if constexpr (cute::rank(ProblemShapeType{}) == 4) { + problem_size = ProblemShapeType{m, n, k, l}; + } + else { + problem_size = ProblemShapeType{m, n, k}; + } + + for (DecompositionMode decomp_mode : decomposition_modes) { + for (RasterOrderOptions raster_order : raster_order_options) { + std::vector problem_splits = {detail::Splits{1}}; + if constexpr (UsesStreamKScheduler) { + if (decomp_mode == DecompositionMode::SplitK) { + problem_splits.push_back(detail::Splits{2}); + problem_splits.push_back(detail::Splits{4}); + } + } + for (auto splits : problem_splits) { + try { + passed = testbed.run( + problem_size, + cutlass::from_real(alpha), + cutlass::from_real(beta), + raster_order, // raster_order + detail::MaxSwizzleSize(0), + splits, + decomp_mode + ); + } + catch (std::exception const& e) { + EXPECT_TRUE(false) << "TestSmall: testbed.run {" + << "m: " << m << ", n: " << n << ", k: " << k << ", l: " << l + << ", alpha: " << alpha << ", beta: " << beta + << ", raster_order: " << detail::raster_order_to_string(raster_order) + << ", max_swizzle_size: 1" + << ", splits: " << static_cast(splits) + << ", decomp_mode: " << detail::decomp_mode_to_string(decomp_mode) + << "} threw an exception: " << e.what(); + throw; + } + catch (...) { + EXPECT_TRUE(false) << "TestSmall: testbed.run {" + << "m: " << m << ", n: " << n << ", k: " << k << ", l: " << l + << ", alpha: " << alpha << ", beta: " << beta + << ", raster_order: " << detail::raster_order_to_string(raster_order) + << ", max_swizzle_size: 1" + << ", splits: " << static_cast(splits) + << ", decomp_mode: " << detail::decomp_mode_to_string(decomp_mode) + << "} threw an exception (unknown)"; + throw; + } + EXPECT_TRUE(passed) << "TestSmall: testbed.run {" + << "m: " << m << ", n: " << n << ", k: " << k << ", l: " << l + << ", alpha: " << alpha << ", beta: " << beta + << ", raster_order: " << detail::raster_order_to_string(raster_order) + << ", max_swizzle_size: 1" + << ", splits: " << static_cast(splits) + << ", decomp_mode: " << detail::decomp_mode_to_string(decomp_mode) + << "} failed"; + + if (!passed) { + std::cout << __FILE__ << ':' << __LINE__ << " : GEMM MNKL " << m << " " << n << " " << k << " " << l << " FAILED.\n"; + return false; + } + } // splits + } // raster_order + } // decomposition_mode + } // k + } // waves + + return passed; +} + +template +bool TestSmallFusion(double alpha = 1.0, double beta = cute::is_same_v ? 
0.0 : 1.0, + CheckEquality check_relative_equality = CheckEquality::RELATIVE, + ScalarLoc use_device_scalars = ScalarLoc::ON_DEVICE, + VectorScale vector_scale_mode = VectorScale::ENABLED, + std::vector override_problem_size_k = {}) { + return TestSmall(alpha, + beta, + check_relative_equality, + use_device_scalars, + vector_scale_mode, + override_problem_size_k); +} + + + +template < + typename Gemm, + template class ActivationFunctor = cutlass::epilogue::thread::Identity +> +bool TestAll(double alpha = 1.0, double beta = cute::is_same_v ? 0.0 : 1.0, CheckEquality check_relative_equality = CheckEquality::RELATIVE) { + using ElementScalar = typename Gemm::EpilogueOutputOp::ElementScalar; + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + + Testbed3x testbed(check_relative_equality, ScalarLoc::ON_HOST, VectorScale::DISABLED); + + int max_alignment_m = std::max({Gemm::kAlignmentA, Gemm::kAlignmentC, Gemm::kAlignmentD}); + int max_alignment_n = std::max({Gemm::kAlignmentB, Gemm::kAlignmentC, Gemm::kAlignmentD}); + if constexpr (std::is_base_of_v) { + max_alignment_m = std::max(max_alignment_m, Gemm::EpilogueOutputOp::AlignmentAux); + max_alignment_n = std::max(max_alignment_n, Gemm::EpilogueOutputOp::AlignmentAux); + } + std::vector problem_size_m = {max_alignment_m, 512 - 3 * max_alignment_m}; + std::vector problem_size_n = {max_alignment_n, 512 - 2 * max_alignment_n}; + + if constexpr (cute::is_same_v) { + problem_size_m.push_back(768); + problem_size_n.push_back(768); + } + + constexpr int Stages = Gemm::GemmKernel::DispatchPolicy::Stages; + constexpr int TileShapeK = cute::size<2>(typename Gemm::GemmKernel::TileShape{}); + + int max_alignment_k = std::max(Gemm::kAlignmentA, Gemm::kAlignmentB); + std::vector problem_size_k = {max_alignment_k, TileShapeK * (Stages + 1) - max_alignment_k}; + + using DecompositionMode = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode; + std::vector decomposition_modes = {DecompositionMode::Heuristic}; + std::vector problem_splits = {detail::Splits{1}}; + static constexpr bool UsesStreamKScheduler = cute::is_same_v; + if constexpr (UsesStreamKScheduler) { + problem_splits.push_back(detail::Splits{2}); + problem_splits.push_back(detail::Splits{3}); + + decomposition_modes.push_back(DecompositionMode::DataParallel); + decomposition_modes.push_back(DecompositionMode::SplitK); + decomposition_modes.push_back(DecompositionMode::StreamK); + + // Use larger K sizes for stream-K tests + static constexpr int min_tiles_per_sk_unit = cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::min_iters_per_sk_unit_; + problem_size_k = {TileShapeK * min_tiles_per_sk_unit, TileShapeK * 3 * min_tiles_per_sk_unit - max_alignment_k}; + } + + using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::RasterOrderOptions; + std::vector raster_orders = {RasterOrderOptions::AlongM, RasterOrderOptions::AlongN}; + std::vector max_swizzle_sizes{detail::MaxSwizzleSize{1}, detail::MaxSwizzleSize{4}}; + + bool passed = true; + + for (int m : problem_size_m) { + for (int n : problem_size_n) { + for (int k : problem_size_k) { + for (auto raster_order : raster_orders) { + for (auto max_swizzle_size : max_swizzle_sizes) { + for (DecompositionMode decomp_mode : decomposition_modes) { + + std::vector problem_splits = {detail::Splits{1}}; + if (decomp_mode == DecompositionMode::Heuristic || decomp_mode == DecompositionMode::SplitK) { + auto max_splits = (k + TileShapeK - 1) 
/ TileShapeK; + if (max_splits > 2) { + problem_splits.push_back(detail::Splits{2}); + } + if (max_splits > 3) { + problem_splits.push_back(detail::Splits{3}); + } + + problem_splits.push_back(detail::Splits{max_splits}); + + // Test the case in which we ask for more splits than there are K tiles in the GEMM. In this + // case, split-K will fall back to a splitting factor of `max_splits`. + problem_splits.push_back(detail::Splits{max_splits + 1}); + } + for (auto splits : problem_splits) { + ProblemShapeType problem_size; + if constexpr (cute::rank(ProblemShapeType{}) == 4) { + problem_size = ProblemShapeType{m, n, k, /* l */ 1}; + } + else { + problem_size = ProblemShapeType{m, n, k}; + } + + try { + passed = testbed.run( + problem_size, + cutlass::from_real(alpha), + cutlass::from_real(beta), + raster_order, + max_swizzle_size, + splits, + decomp_mode + ); + } + catch (std::exception const& e) { + EXPECT_TRUE(false) << "TestAll: testbed.run {" + << "m: " << m << ", n: " << n << ", k: " << k + << ", alpha: " << alpha << ", beta: " << beta + << ", raster_order: ???" + << ", max_swizzle_size: " << static_cast(max_swizzle_size) + << ", splits: " << static_cast(splits) + << ", decomp_mode: " << detail::decomp_mode_to_string(decomp_mode) + << "} threw an exception: " << e.what(); + throw; + } + catch (...) { + EXPECT_TRUE(false) << "TestAll: testbed.run {" + << "m: " << m << ", n: " << n << ", k: " << k + << ", alpha: " << alpha << ", beta: " << beta + << ", raster_order: ???" + << ", max_swizzle_size: " << static_cast(max_swizzle_size) + << ", splits: " << static_cast(splits) + << ", decomp_mode: " << detail::decomp_mode_to_string(decomp_mode) + << "} threw an exception (unknown)"; + throw; + } + + EXPECT_TRUE(passed) << "TestAll: testbed.run {" + << "m: " << m << ", n: " << n << ", k: " << k + << ", alpha: " << alpha << ", beta: " << beta + << ", raster_order: ???" + << ", max_swizzle_size: " << static_cast(max_swizzle_size) + << ", splits: " << static_cast(splits) + << ", decomp_mode: " << detail::decomp_mode_to_string(decomp_mode) + << "} failed"; + + if (!passed) { + std::cout << __FILE__ << ':' << __LINE__ << " : GEMM MNK " << m << " " << n << " " << k << " FAILED.\n"; + return false; + } + } // splits + } // decomposition_mode + } // max_swizzle_size + } // raster_order + } // k + } // n + } // m + + // if we do support batched GEMM, just run one test on it to save on test time + if constexpr (cute::rank(ProblemShapeType{}) == 4) { + auto problem_size = ProblemShapeType{256 + max_alignment_m, 256 + max_alignment_n, 160 + max_alignment_k, /* l */ 3}; + passed = testbed.run( + problem_size, + cutlass::from_real(alpha), + cutlass::from_real(beta) + ); + + if (!passed) { + return false; + } + } + + return passed; +} + +template +bool TestAllBiasElementwise(double alpha = 1.0, double beta = cute::is_same_v ? 
0.0 : 1.0, CheckEquality check_relative_equality = CheckEquality::EXACT) { + return TestAll(alpha, beta, check_relative_equality); +} + +} // namespace device +} // namespace gemm +} // namespace test + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/gemm_testbed_3x_evt.hpp b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/gemm_testbed_3x_evt.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f18a7b39cbfe7dfb8d3251b2750e49261522de8a --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/gemm_testbed_3x_evt.hpp @@ -0,0 +1,1742 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Testbed and host reference for EVT unittest +*/ + + +#pragma once +#include "gemm_testbed_3x.hpp" + +namespace test { +namespace gemm { +namespace device { + +/// Host-side tapply, tapply in cute is HOST_DEVICE +template +constexpr auto +tapply(T&& t, F&& f, G&& g, cute::seq) +{ + return g(f(std::get(static_cast(t)))...); +} + + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// EVT: Base class for EVT Node + +template < class ElementCompute_ > +class HostEVTNodeBase { +public: + using ElementCompute = ElementCompute_; + +private: + bool check_relative_equality_; + // Factors used for calculating relative equality. These default + // values are borrowed from those used by default in the CUTLASS + // profiler for performing relative equality checks. 
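The default epsilon and nonzero floor that follow feed the relative comparison selected by check_relative_equality_. As a rough, hypothetical scalar illustration of that style of check (this is not the exact formula used by cutlass::reference::host::TensorRelativelyEquals, only the general idea):

#include <algorithm>
#include <cmath>

// Illustration only: values whose magnitude falls below the floor are compared
// against an absolute tolerance; everything else uses a relative tolerance.
inline bool roughly_relatively_equal(float a, float b,
                                     float epsilon = 0.05f,
                                     float nonzero_floor = 1.0f / 256.0f) {
  float magnitude = std::max(std::fabs(a), std::fabs(b));
  if (magnitude < nonzero_floor) {
    return std::fabs(a - b) <= epsilon * nonzero_floor;
  }
  return std::fabs(a - b) <= epsilon * magnitude;
}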
+ float epsilon_ = 0.05f; + float nonzero_floor_ = 1.0f / 256.0f; + +public: + HostEVTNodeBase(){} + HostEVTNodeBase(bool check_relative_equality): + check_relative_equality_(check_relative_equality) { } + + + template < + class Element, + class Layout + > + bool equality_check( + cutlass::TensorView const& lhs, + cutlass::TensorView const& rhs) const { + if (check_relative_equality_) { + return cutlass::reference::host::TensorRelativelyEquals( + lhs, rhs, Element(epsilon_), Element(nonzero_floor_) + ); + } + else { + return cutlass::reference::host::TensorEquals(lhs, rhs); + } + } + + void* get_tensor_C_ptr() { + return nullptr; + } + + void* get_tensor_D_ptr() { + return nullptr; + } + + bool compare_reference(std::stringstream& error_ss) { + return true; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// EVT - Accumulator + +template< class ElementCompute = float > +class HostAccumulator: public HostEVTNodeBase { +public: + using Base = HostEVTNodeBase; + + struct Arguments { }; + +public: + HostAccumulator(){} + template + HostAccumulator(ProblemShapeType problem_size, bool check_relative_equality = false, int64_t seed = 2024) + :Base(check_relative_equality) {} + + template + ElementCompute visit( + int64_t m, int64_t n, int64_t l, int m_b, int n_b, + ElementAccumulator acc) { + cutlass::NumericConverter accumulator_converter; + return accumulator_converter(acc); + } + + Arguments get_arguments() { + return Arguments{}; + } + + auto get_flatten_arguments() { + return cute::make_tuple(); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// EVT - Scalar Broadcast + +template < + int Value, + int BroadcastCount = 1, + class StrideMNL = cute::Stride, + template class ReductionFn = cutlass::multiplies, + class ElementCompute = float +> +class HostScalarBroadcast : public HostEVTNodeBase { +public: + + using Base = HostEVTNodeBase; + struct Arguments { + ElementCompute scalar[BroadcastCount] = {0}; + ElementCompute const* scalar_ptrs[BroadcastCount] = { nullptr }; + StrideMNL dScalar[BroadcastCount] = {}; + }; +private: + ElementCompute scalar_{}; + StrideMNL dScalar{}; + ElementCompute scalar_reduced_{}; +public: + HostScalarBroadcast(){} + + template + HostScalarBroadcast(ProblemShapeType problem_size, bool check_relative_equality = false, int64_t seed = 2024) + : Base(check_relative_equality), scalar_(ElementCompute(Value)) { + scalar_ = ElementCompute(Value); + scalar_reduced_ = scalar_; + for (int i = 1; i < BroadcastCount; ++i) { + scalar_reduced_ = ReductionFn{}(scalar_reduced_, ElementCompute(Value)); + } + } + + template + ElementCompute visit( + int64_t m, int64_t n, int64_t l, int m_b, int n_b, + ElementAccumulator acc) { + + return scalar_reduced_; + } + + bool compare_reference(std::stringstream& error_ss) { + error_ss << "Scalar: " << float(scalar_) << "\n\n"; + return true; + } + + Arguments get_arguments() { + if constexpr (BroadcastCount == 1) + return Arguments{{scalar_}, {nullptr}, {dScalar}}; + else if constexpr (BroadcastCount == 2) + return Arguments{{scalar_, scalar_}, {nullptr, nullptr}, {dScalar, dScalar}}; + else if constexpr (BroadcastCount == 3) + return Arguments{{scalar_, scalar_, scalar_}, {nullptr, nullptr, nullptr}, {dScalar, dScalar, dScalar}}; + else + return Arguments{{scalar_}, {nullptr}, {dScalar}}; + } + + auto get_flatten_arguments() { + if constexpr (BroadcastCount == 1) { + return cute::make_tuple(scalar_, nullptr); + } + else 
if constexpr (BroadcastCount == 2) { + return cute::make_tuple(scalar_, scalar_, nullptr, nullptr); + } + else if constexpr (BroadcastCount == 3) { + return cute::make_tuple(scalar_, scalar_, scalar_, nullptr, nullptr, nullptr); + } + else { + return cute::make_tuple(scalar_, nullptr); + } + } +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// EVT - Row Broadcast +template < + typename ElementBias_, + typename StrideMNL = cute::Stride, + typename ElementCompute = float +> +class HostRowBroadcast: public HostEVTNodeBase { +public: + using Base = HostEVTNodeBase; + using ElementBias = ElementBias_; + using LayoutTagVector = cutlass::layout::PackedVectorLayout; + + struct Arguments { + ElementBias const* ptr_row = nullptr; + ElementBias null_default = ElementBias(0); + StrideMNL dRow = {}; + }; +private: + cutlass::NumericConverter bias_converter_; + cutlass::HostTensor bias_; + int N_; +public: + HostRowBroadcast(){} + template + HostRowBroadcast(ProblemShapeType problem_size, bool check_relative_equality = false, int64_t seed = 2024) + : Base(check_relative_equality) { + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + N_ = cute::get<1>(problem_shape_MNKL); + bias_.resize(cutlass::Coord<1>(N_)); + + EXPECT_TRUE( + detail::initialize_tensor( + bias_.host_view(), cutlass::Distribution::Uniform, + seed + ) + ); + bias_.sync_device(); + } + + template + ElementCompute visit( + int64_t m, int64_t n, int64_t l, int m_b, int n_b, + ElementAccumulator acc) { + auto TensorBias = cute::make_tensor(bias_.host_data(), + cute::make_layout(cute::make_shape(cute::_1{}, N_))); + + return bias_converter_(TensorBias(1, n + n_b)); + } + + bool compare_reference(std::stringstream& error_ss) { + error_ss + << "PerColumnBias = \n" << bias_.host_view() << "\n\n"; + return true; + } + + Arguments get_arguments() { + return {bias_.device_data()}; + } + + auto get_flatten_arguments() { + return cute::make_tuple(bias_.device_data(), ElementBias(0), StrideMNL{}); + } + +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// EVT - Column Broadcast +template < + typename ElementBias_, + typename StrideMNL = cute::Stride, + typename ElementCompute = float +> +class HostColBroadcast: public HostEVTNodeBase { +public: + using Base = HostEVTNodeBase; + using ElementBias = ElementBias_; + using LayoutTagVector = cutlass::layout::PackedVectorLayout; + + struct Arguments { + ElementBias const* ptr_row = nullptr; + ElementBias null_default = ElementBias(0); + StrideMNL dRow = {}; + }; +private: + cutlass::NumericConverter bias_converter_; + cutlass::HostTensor bias_; + int M_; +public: + HostColBroadcast(){} + template + HostColBroadcast(ProblemShapeType problem_size, bool check_relative_equality = false, int64_t seed = 2024) + : Base(check_relative_equality) { + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + M_ = cute::get<0>(problem_shape_MNKL); + bias_.resize(cutlass::Coord<1>(M_)); + + EXPECT_TRUE( + detail::initialize_tensor( + bias_.host_view(), cutlass::Distribution::Uniform, + seed + ) + ); + bias_.sync_device(); + } + + template + ElementCompute visit( + int64_t m, int64_t n, int64_t l, int m_b, int n_b, + ElementAccumulator acc) { + auto TensorBias = cute::make_tensor(bias_.host_data(), + cute::make_layout(cute::make_shape(M_, cute::_1{}))); + + return bias_converter_(TensorBias(m + m_b, 1)); + } + + bool compare_reference(std::stringstream& error_ss) { + error_ss + << 
"PerRowBias = \n" << bias_.host_view() << "\n\n"; + return true; + } + + Arguments get_arguments() { + return {bias_.device_data()}; + } + + auto get_flatten_arguments() { + return cute::make_tuple(bias_.device_data(), ElementBias(0), StrideMNL{}); + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// EVT - Aux Load + +template < + typename ElementAuxLoad_, + typename LayoutTagAux_, + bool isC = false, + typename ElementCompute = float +> +class HostAuxLoad: public HostEVTNodeBase { +public: + using Base = HostEVTNodeBase; + using ElementAuxLoad = ElementAuxLoad_; + using LayoutTagAux = LayoutTagAux_; + + using StrideAux = cutlass::gemm::TagToStrideC_t; + struct Arguments_Aux { + ElementAuxLoad const *ptr_aux = nullptr; + ElementAuxLoad null_default = ElementAuxLoad(0); + StrideAux dAux = {}; + }; + + struct Arguments_C {}; + + using Arguments = cute::conditional_t; + +private: + cutlass::NumericConverter aux_load_converter_; + cutlass::HostTensor tensor_aux_load_; + + int M_, N_, L_; + + StrideAux stride_aux_; +public: + HostAuxLoad(){} + template + HostAuxLoad(ProblemShapeType problem_size, bool check_relative_equality = false, int64_t seed = 2024) + : Base(check_relative_equality) { + auto problem_shape_NMKL = cute::append<4>(problem_size, 1); + auto [M_, N_, K, L_] = problem_shape_NMKL; + auto aux_coord = cutlass::make_Coord(M_ * L_, N_); + tensor_aux_load_.resize( + aux_coord, + cutlass::layout::Affine2Layout_Factory::layout_factory( + aux_coord, typename LayoutTagAux::Stride() + ) + ); + EXPECT_TRUE( + detail::initialize_tensor( + tensor_aux_load_.host_view(), + cutlass::Distribution::Uniform, + seed + ) + ); + tensor_aux_load_.sync_device(); + stride_aux_ = cutlass::make_cute_packed_stride(StrideAux{}, cute::make_shape(M_, N_, L_)); + } + + template + ElementCompute visit( + int64_t m, int64_t n, int64_t l, int m_b, int n_b, + ElementAccumulator acc) { + + + auto TensorAuxLoad = cute::make_tensor(tensor_aux_load_.host_data(), + cute::make_layout(cute::make_shape(M_, N_, L_), stride_aux_)); + return aux_load_converter_(TensorAuxLoad(m + m_b, n + n_b, l)); + } + + bool compare_reference(std::stringstream& error_ss) { + if constexpr (!isC) { + error_ss + << "AuxLoad = \n" << tensor_aux_load_.host_view()<< "\n\n"; + } + return true; + } + + void* get_tensor_C_ptr() { + if constexpr (isC) { + return static_cast(tensor_aux_load_.device_data()); + } + else { + return nullptr; + } + } + + Arguments get_arguments() { + if constexpr (isC) + return {}; + else + return {tensor_aux_load_.device_data(), ElementAuxLoad(0), stride_aux_}; + } + + auto get_flatten_arguments() { + if constexpr (isC) + return cute::make_tuple(); + else + return cute::make_tuple(tensor_aux_load_.device_data(), ElementAuxLoad(0), stride_aux_); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// EVT - Compute + +template +T* findNonNullPtr(T* first_ptr) { + return first_ptr; +} + +template +T* findNonNullPtr(T* first_ptr, Args... 
args) { + if (first_ptr) { + return first_ptr; + } + return findNonNullPtr(args...); +} + +template < + template class ComputeOp_, + typename ElementCompute = float +> +class HostCompute: public HostEVTNodeBase { +public: + using Base = HostEVTNodeBase; + using ComputeOp = ComputeOp_; + + struct Arguments { + struct OpArgs {} op; + }; +private: + ComputeOp op_; +public: + HostCompute(){} + template + HostCompute(ProblemShapeType problem_size, bool check_relative_equality = false, int64_t seed = 2024): + Base(check_relative_equality) { } + + template + ElementCompute visit( + int64_t m, int64_t n, int64_t l, int m_b, int n_b, + ElementAccumulator acc, Args... frg_inputs) { + return op_(frg_inputs...); + } + + Arguments get_arguments(){ + return {}; + } + + auto get_flatten_arguments() { + return cute::make_tuple(); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// EVT - Aux Store + +template < + class ElementAuxStore_, + typename LayoutTagAux_, + bool isD = false, + bool isRelu = false, + typename ElementCompute = float +> +class HostAuxStore: public HostEVTNodeBase { +public: + using ElementAuxStore = ElementAuxStore_; + using LayoutTagAux = LayoutTagAux_; + + using Base = HostEVTNodeBase; + + using StrideAux = cutlass::gemm::TagToStrideC_t; + struct Arguments_Aux { + struct OpArgs { + ElementAuxStore* ptr_aux = nullptr; + StrideAux dAux = {}; + } op; + }; + + struct Arguments_D {}; + + using Arguments = cute::conditional_t; + + +private: + cutlass::NumericConverter destination_converter_; + cutlass::HostTensor tensor_aux_store_; + cutlass::HostTensor reference_aux_store_; + int M_, N_, L_; + StrideAux stride_aux_; +public: + HostAuxStore(){} + template + HostAuxStore(ProblemShapeType problem_size, bool check_relative_equality = false, int64_t seed = 2024): + Base(check_relative_equality) { + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto [M_, N_, K, L_] = problem_shape_MNKL; + auto aux_coord = cutlass::make_Coord(M_ * L_, N_); + tensor_aux_store_.resize( + aux_coord, + cutlass::layout::Affine2Layout_Factory::layout_factory( + aux_coord, typename LayoutTagAux::Stride() + ) + ); + + reference_aux_store_.resize( + aux_coord, + cutlass::layout::Affine2Layout_Factory::layout_factory( + aux_coord, typename LayoutTagAux::Stride() + ) + ); + tensor_aux_store_.sync_device(); + stride_aux_ = cutlass::make_cute_packed_stride(StrideAux{}, cute::make_shape(M_, N_, L_)); + } + + template + ElementCompute visit( + int64_t m, int64_t n, int64_t l, int m_b, int n_b, + ElementAccumulator acc, ElementCompute child_0_result) { + + auto TensorAuxStore = cute::make_tensor(detail::make_iterator(static_cast(reference_aux_store_.host_data())), + cute::make_layout(cute::make_shape(M_, N_, L_), stride_aux_)); + if constexpr (isRelu) + TensorAuxStore(m + m_b, n + n_b, l) = destination_converter_(child_0_result >= 0); + else + TensorAuxStore(m + m_b, n + n_b, l) = destination_converter_(child_0_result); + return child_0_result; + } + + bool compare_reference(std::stringstream& error_ss) { + // Verify the store node + tensor_aux_store_.sync_host(); + + bool equal = this->equality_check(reference_aux_store_.host_view(), tensor_aux_store_.host_view()); + if (!equal) { + error_ss + << "\n\nReference =\n" << reference_aux_store_.host_view() + << "\n\nComputed =\n" << tensor_aux_store_.host_view() << "\n\n"; + } + return equal; + } + + void* get_tensor_D_ptr() { + if constexpr (isD) + return static_cast(tensor_aux_store_.device_data()); + 
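One mechanism worth calling out: findNonNullPtr (defined just above) returns the first non-null pointer in its argument pack by plain recursion, and HostVisitorBase further down folds every node's get_tensor_C_ptr()/get_tensor_D_ptr() result through it, so whichever single node actually owns the C or D tensor wins while all other nodes contribute nullptr. A trivial usage sketch, assuming those overloads are visible (the example function itself is hypothetical):

// Hypothetical usage example; relies only on the findNonNullPtr overloads above.
inline void example_find_non_null_ptr() {
  static int owned = 42;
  int* node_a = nullptr;   // e.g. a compute node with no tensor
  int* node_b = nullptr;   // e.g. a scalar broadcast with no tensor
  int* node_c = &owned;    // e.g. the store node that owns the tensor
  int* found = findNonNullPtr(node_a, node_b, node_c);  // == &owned, the first non-null argument
  (void)found;
}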
else + return nullptr; + } + + Arguments get_arguments() { + if constexpr (isD) { + return {}; + } + else { + return {tensor_aux_store_.device_data(), stride_aux_}; + } + } + + auto get_flatten_arguments() { + if constexpr (isD) { + return cute::make_tuple(); + } + else { + return cute::make_tuple(tensor_aux_store_.device_data(), stride_aux_); + } + } +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// EVT - Row Reduce + +template < + template class ReduceFn, + typename ElementReduce, + bool FinalReduction = true, // Should match the FinalReduction in Device type + typename CtaTileShapeMNK = cute::Shape, + typename ElementCompute = float +> +class HostRowReduce: public HostEVTNodeBase { +public: + using Base = HostEVTNodeBase; + using LayoutTagVector = cutlass::layout::PackedVectorLayout; + + using ElementDst = cute::conditional_t; + + static constexpr int TileM = cute::get<0>(CtaTileShapeMNK{}); + static constexpr int TileN = cute::get<1>(CtaTileShapeMNK{}); + + struct Arguments { + struct OpArgs { + ElementReduce* ptr_row = nullptr; + ElementCompute reduce_identity = 0; + cute::Stride dRow = {}; + } op; + }; + +private: + cutlass::NumericConverter destination_converter_; + cutlass::HostTensor tensor_row_reduce_; + cutlass::HostTensor reduce_buffer_; + cutlass::HostTensor reference_row_reduce_; + int N_; + ReduceFn reduce_fn_; + + int extent_m_; + int extent_n_; + int extent_l_; +public: + HostRowReduce(){} + template + HostRowReduce(ProblemShapeType problem_size, bool check_relative_equality = false, int64_t seed = 2024): + Base(check_relative_equality) { + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + N_ = cute::get<1>(problem_shape_MNKL); + if constexpr (FinalReduction) { + tensor_row_reduce_.resize(cutlass::Coord<1>(N_)); + reference_row_reduce_.resize(cutlass::Coord<1>(N_)); + reduce_buffer_.resize(cutlass::Coord<1>(N_)); + } + else { + auto NumTile = cute::ceil_div(cute::select<0,1,3>(problem_shape_MNKL), cute::take<0,2>(CtaTileShapeMNK{})); + extent_m_ = cute::get<0>(NumTile); + extent_n_ = cute::get<1>(NumTile) * TileN; + extent_l_ = cute::get<2>(NumTile); + auto shape = cutlass::make_Coord(extent_m_ * extent_n_ * extent_l_); + tensor_row_reduce_.resize(shape); + reference_row_reduce_.resize(shape); + reduce_buffer_.resize(shape); + } + + cutlass::reference::host::TensorFill(reduce_buffer_.host_view()); + } + + template + ElementCompute visit( + int64_t m, int64_t n, int64_t l, int m_b, int n_b, + ElementAccumulator acc, ElementCompute child_0_result) { + if constexpr (FinalReduction) { + auto TensorRowReduce = cute::make_tensor(reduce_buffer_.host_data(), + cute::make_layout(cute::make_shape(cute::_1{}, N_))); + TensorRowReduce(1, n + n_b) = reduce_fn_(TensorRowReduce(1, n + n_b), child_0_result); + } + else { + auto TensorRowReduce = cute::make_tensor( + reduce_buffer_.host_data(), + cute::make_layout( + cute::make_shape(extent_m_, extent_n_, extent_l_), + cute::make_stride(extent_n_, 1, extent_m_ * extent_l_) + ) + ); + TensorRowReduce((m+m_b)/TileM, n+n_b, l) = reduce_fn_(TensorRowReduce((m+m_b)/TileM, n+n_b, l), child_0_result); + } + + return child_0_result; + } + + bool compare_reference(std::stringstream& error_ss) { + // Verify the store node + tensor_row_reduce_.sync_host(); + + auto TensorRowReduce = cute::make_tensor(reference_row_reduce_.host_data(), + cute::make_layout(cute::make_shape(reference_row_reduce_.size()))); + + auto TensorReduceBuffer = cute::make_tensor(reduce_buffer_.host_data(), 
+ cute::make_layout(cute::make_shape(reduce_buffer_.size()))); + + // Filling the reference tensor with the reduce buffer + for (uint64_t n = 0; n < size(TensorRowReduce); n ++) { + TensorRowReduce(n) = destination_converter_(TensorReduceBuffer(n)); + } + + bool equal = this->equality_check(reference_row_reduce_.host_view(), tensor_row_reduce_.host_view()); + if (!equal) { + error_ss + << "\n\nRow Reduce Reference =\n" << reference_row_reduce_.host_view() + << "\n\nRow Reduce Computed =\n" << tensor_row_reduce_.host_view() << "\n\n"; + } + return equal; + } + + Arguments get_arguments() { + return {tensor_row_reduce_.device_data()}; + } +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// EVT - Column Reduce + +template < + template class ReduceFn, + typename ElementReduce, + bool FinalReduction = true, // Should match the FinalReduction in Device type + typename CtaTileShapeMNK = cute::Shape, + typename ElementCompute = float +> +class HostColumnReduce: public HostEVTNodeBase { +public: + using Base = HostEVTNodeBase; + using LayoutTagVector = cutlass::layout::PackedVectorLayout; + + using ElementDst = cute::conditional_t; + + static constexpr int TileM = cute::get<0>(CtaTileShapeMNK{}); + static constexpr int TileN = cute::get<1>(CtaTileShapeMNK{}); + + struct Arguments { + struct OpArgs { + ElementReduce* ptr_col = nullptr; + ElementCompute reduce_identity = 0; + cute::Stride dRow = {}; + } op; + }; + +private: + cutlass::NumericConverter destination_converter_; + cutlass::HostTensor tensor_column_reduce_; + cutlass::HostTensor reduce_buffer_; + cutlass::HostTensor reference_column_reduce_; + int M_; + ReduceFn reduce_fn_; + + int extent_m_; + int extent_n_; + int extent_l_; +public: + HostColumnReduce(){} + template + HostColumnReduce(ProblemShapeType problem_size, bool check_relative_equality = false, int64_t seed = 2024): + Base(check_relative_equality) { + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + M_ = cute::get<0>(problem_shape_MNKL); + + if constexpr (FinalReduction) { + tensor_column_reduce_.resize(cutlass::Coord<1>(M_)); + reference_column_reduce_.resize(cutlass::Coord<1>(M_)); + reduce_buffer_.resize(cutlass::Coord<1>(M_)); + } + else { + auto NumTile = cute::ceil_div(cute::select<0,1,3>(problem_shape_MNKL), cute::take<0,2>(CtaTileShapeMNK{})); + extent_m_ = cute::get<0>(NumTile) * TileM; + extent_n_ = cute::get<1>(NumTile); + extent_l_ = cute::get<2>(NumTile); + auto shape = cutlass::make_Coord(extent_m_ * extent_n_ * extent_l_); + tensor_column_reduce_.resize(shape); + reference_column_reduce_.resize(shape); + reduce_buffer_.resize(shape); + } + + cutlass::reference::host::TensorFill(reduce_buffer_.host_view()); + } + + template + ElementCompute visit( + int64_t m, int64_t n, int64_t l, int m_b, int n_b, + ElementAccumulator acc, ElementCompute child_0_result) { + auto TensorColReduce = cute::make_tensor(reduce_buffer_.host_data(), + cute::make_layout(cute::make_shape(M_, cute::_1{}))); + if constexpr (FinalReduction) { + TensorColReduce(m + m_b, 1) = reduce_fn_(TensorColReduce(m + m_b, 1), child_0_result); + } + else { + auto shape = reduce_buffer_.extent(); + auto TensorColReduce = cute::make_tensor( + reduce_buffer_.host_data(), + cute::make_layout( + cute::make_shape(extent_m_, extent_n_, extent_l_), + cute::make_stride(1, extent_m_, extent_m_ * extent_l_) + ) + ); + TensorColReduce(m+m_b, (n+n_b)/TileN, l) = reduce_fn_(TensorColReduce(m+m_b, (n+n_b)/TileN, l), child_0_result); + } + return 
child_0_result; + } + + bool compare_reference(std::stringstream& error_ss) { + // Verify the store node + tensor_column_reduce_.sync_host(); + + auto TensorColReduce = cute::make_tensor(reference_column_reduce_.host_data(), + cute::make_layout(cute::make_shape(reference_column_reduce_.size()))); + + auto TensorReduceBuffer = cute::make_tensor(reduce_buffer_.host_data(), + cute::make_layout(cute::make_shape(reduce_buffer_.size()))); + + // Filling the reference tensor with the reduce buffer + for (uint64_t m = 0; m < size(TensorColReduce); m ++) { + TensorColReduce(m) = destination_converter_(TensorReduceBuffer(m)); + } + + bool equal = this->equality_check(reference_column_reduce_.host_view(), tensor_column_reduce_.host_view()); + if (!equal) { + error_ss + << "\n\nColumn Reduce Reference =\n" << reference_column_reduce_.host_view() + << "\n\nColumn Reduce Computed =\n" << tensor_column_reduce_.host_view() << "\n\n"; + } + return equal; + } + + Arguments get_arguments() { + return {tensor_column_reduce_.device_data()}; + } +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// EVT - Scalar Reduce + +template < + template class ReduceFn, + typename ElementReduce, + typename ElementCompute = float, + bool enabled = true +> +class HostScalarReduce: public HostEVTNodeBase { +public: + using Base = HostEVTNodeBase; + using LayoutTagVector = cutlass::layout::PackedVectorLayout; + + struct Arguments { + struct OpArgs { + ElementReduce* ptr_scalar = nullptr; + ElementCompute reduce_identity = 0; + cute::Stride dScalar = {}; + } op; + }; + +private: + cutlass::NumericConverter destination_converter_; + cutlass::HostTensor tensor_scalar_reduce_; + cutlass::HostTensor reduce_buffer_; + cutlass::HostTensor reference_scalar_reduce_; + ReduceFn reduce_fn_; +public: + HostScalarReduce(){} + template + HostScalarReduce(ProblemShapeType problem_size, bool check_relative_equality = false, int64_t seed = 2024): + Base(check_relative_equality) { + tensor_scalar_reduce_.resize(cutlass::Coord<1>(1)); + reference_scalar_reduce_.resize(cutlass::Coord<1>(1)); + reduce_buffer_.resize(cutlass::Coord<1>(1)); + + tensor_scalar_reduce_.sync_device(); + cutlass::reference::host::TensorFill(reduce_buffer_.host_view()); + } + + template + ElementCompute visit( + int64_t m, int64_t n, int64_t l, int m_b, int n_b, + ElementAccumulator acc, ElementCompute child_0_result) { + auto TensorRowReduce = cute::make_tensor(reduce_buffer_.host_data(), + cute::make_layout(cute::make_shape(cute::_1{}))); + TensorRowReduce(0) = reduce_fn_(TensorRowReduce(0), child_0_result); + return child_0_result; + } + + bool compare_reference(std::stringstream& error_ss) { + if constexpr (enabled) { + // Verify the store node + tensor_scalar_reduce_.sync_host(); + + auto TensorRowReduce = cute::make_tensor(reference_scalar_reduce_.host_data(), + cute::make_layout(cute::make_shape(cute::_1{}))); + + auto TensorReduceBuffer = cute::make_tensor(reduce_buffer_.host_data(), + cute::make_layout(cute::make_shape(cute::_1{}))); + + // Filling the reference tensor with the reduce buffer + TensorRowReduce(0) = destination_converter_(TensorReduceBuffer(0)); + + bool equal = this->equality_check(reference_scalar_reduce_.host_view(), tensor_scalar_reduce_.host_view()); + if (!equal) { + error_ss + << "\n\nScalar Reduce Reference =\n" << reference_scalar_reduce_.host_view() + << "\n\nScalar Reduce Computed =\n" << tensor_scalar_reduce_.host_view() << "\n\n"; + } + return equal; + } + else { + return true; 
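A small but deliberate detail shared by the row, column, and scalar reduce nodes above: the host reference accumulates into a buffer of ElementCompute and only converts to ElementReduce when the reference tensor is materialized in compare_reference, i.e. it reduces first and narrows once rather than narrowing every partial result. The toy program below (hypothetical names, with a crude truncation standing in for a narrow type) shows why that ordering matters:

#include <cstdio>

// Illustration only: truncate to ~3 decimal digits as a stand-in for a low-precision type.
static float to_narrow(float x) {
  return static_cast<float>(static_cast<long long>(x * 1000.0f)) / 1000.0f;
}

int main() {
  float reduce_then_narrow = 0.0f;   // what the host reference above does
  float narrow_every_step  = 0.0f;   // lossy alternative
  for (int i = 0; i < 10000; ++i) {
    float v = 1e-4f;
    reduce_then_narrow += v;
    narrow_every_step = to_narrow(narrow_every_step + v);  // each partial sum truncated
  }
  // Prints roughly "1.000 vs 0.000": per-step narrowing loses every contribution.
  std::printf("%.3f vs %.3f\n", to_narrow(reduce_then_narrow), narrow_every_step);
  return 0;
}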
+ } + + } + + Arguments get_arguments() { + return {tensor_scalar_reduce_.device_data()}; + } + + auto get_flatten_arguments() { + return cute::make_tuple(tensor_scalar_reduce_.device_data()); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Host EVT wrapper + +/// The ArgumentPack is used to model the alignment when num ops <= 4 +template +struct ArgumentPack; + +template +struct ArgumentPack { + T arg; + ArgumentPack(T first): + arg(first) {} +}; + +template +struct ArgumentPack { + First arg; + ArgumentPack rest_args; + + ArgumentPack(First first, Rest... rest) : + arg(first), rest_args(rest...) {} +}; + + +/// Base class for Host Visitor +template +struct HostVisitorBase: public HostEVTNodeBase { +public: + using Base = HostEVTNodeBase; + + using Arguments_struct = ArgumentPack; + using Arguments_tuple = cute::tuple; + + constexpr static int Rm1 = sizeof...(Ops); + constexpr static bool cond = Rm1 > 4; + using Arguments = cute::conditional_t; + + std::tuple ops; + + HostVisitorBase(){} + template + HostVisitorBase(ProblemShapeType problem_size, bool check_relative_equality = false, int64_t seed = 2024) + :Base(check_relative_equality), + ops(test::gemm::device::tapply(std::tuple{}, + [&] (auto&& op) { + using Op = cute::remove_cvref_t; + return Op(problem_size, check_relative_equality, seed); + }, + [] (auto&&... _ops) { + return std::make_tuple(_ops...); + }, + cute::make_seq{} + )){ } + + bool compare_reference(std::stringstream& error_ss) { + return cute::detail::tapply(ops, + [&](auto& op) { + return op.compare_reference(error_ss); + }, + [&] (auto&&... inputs) { + return arrayAnd(inputs...); + }, + cute::make_seq{} + ); + } + + void* get_tensor_C_ptr() { + return cute::detail::tapply(ops, + [&](auto& op) { + return op.get_tensor_C_ptr(); + }, + [&] (auto&&... inputs) { + return findNonNullPtr(inputs...); + }, + cute::make_seq{} + ); + } + + void* get_tensor_D_ptr() { + return cute::detail::tapply(ops, + [&](auto& op) { + return op.get_tensor_D_ptr(); + }, + [&] (auto&&... inputs) { + return findNonNullPtr(inputs...); + }, + cute::make_seq{} + ); + } + + Arguments get_arguments() { + return test::gemm::device::tapply(ops, + [&](auto& op) { + return op.get_arguments(); + }, + [&] (auto&&... args) { + if constexpr (Rm1 > 4) { + return cute::make_tuple(args...); + } + else { + return Arguments(args...); + } + }, + cute::make_seq{} + ); + } + + auto get_flatten_arguments() { + return test::gemm::device::tapply(ops, + [&](auto& op) { + return op.get_flatten_arguments(); + }, + [&] (auto&&... args) { + return flatten(cute::make_tuple(args...)); + }, + cute::make_seq{} + ); + } + + bool arrayAnd(bool passed) { + return passed; + } + + template + bool arrayAnd(bool first_passed, Args... 
passed) { + if (first_passed) { + return arrayAnd(passed...); + } + return first_passed; + } + +}; + + +/// Tree-struct visitor +template +struct HostTreeVisitor: public HostVisitorBase { +public: + using ElementCompute = typename NodeOp::Base::ElementCompute; + using Base = HostVisitorBase; + using Arguments = typename Base::Arguments; + + constexpr static int Rm1 = sizeof...(ChildOps); + + HostTreeVisitor(){} + template + HostTreeVisitor(ProblemShapeType problem_size, bool check_relative_equality = false, int64_t seed = 2024) + :Base(problem_size, check_relative_equality, seed){ } + + template + ElementCompute visit( + int64_t m, int64_t n, int64_t l, int m_b, int n_b, + ElementAccumulator acc) { + return cute::detail::tapply(this->ops, + [&] (auto& op) { + return op.visit(m, n, l, m_b, n_b, acc); + }, + [&] (auto&&... frg_inputs) { + return std::get(this->ops).visit(m, n, l, m_b, n_b, acc, frg_inputs...); + }, + cute::make_seq{} + ); + } +}; + + +/// General Graph visitor +template +struct HostTopoVisitor: public HostVisitorBase { +public: + using Base = HostVisitorBase; + constexpr static int Rm1 = Base::Rm1; + using Arguments = typename Base::Arguments; + +private: + ElementCompute frg_outputs_[Rm1]; +public: + HostTopoVisitor(){} + template + HostTopoVisitor(ProblemShapeType problem_size, bool check_relative_equality = false, int64_t seed = 2024) + :Base(problem_size, check_relative_equality, seed) { } + + template + ElementCompute visit_( + int64_t m, int64_t n, int64_t l, int m_b, int n_b, + ElementAccumulator acc) { + frg_outputs_[I] = cute::transform_apply(cute::get(EdgeTuple{}), + [&] (auto&& _E) { + constexpr int e = cute::remove_cvref_t::value; + return frg_outputs_[e]; + }, + [&] (auto const&... frg_inputs) { + ElementCompute res = std::get(this->ops).visit(m, n, l, m_b, n_b, acc, frg_inputs...); + return res; + } + ); + + if constexpr (I < Rm1 - 1) { + return visit_(m, n, l, m_b, n_b, acc); + } + else { + return frg_outputs_[I]; + } + } + + template + ElementCompute visit( + int64_t m, int64_t n, int64_t l, int m_b, int n_b, + ElementAccumulator acc) { + + return visit_(m, n, l, m_b, n_b, acc); + } + +}; + + +/// SplitTree visitor +template +struct HostSplitTreeVisitor: public HostVisitorBase { +public: + using Base = HostVisitorBase; + using Arguments = typename Base::Arguments; + + constexpr static int Rm2 = sizeof...(AuxOutTrees); + +private: + ElementCompute frg_input_; +public: + HostSplitTreeVisitor(){} + template + HostSplitTreeVisitor(ProblemShapeType problem_size, bool check_relative_equality = false, int64_t seed = 2024) + :Base(problem_size, check_relative_equality, seed) { } + + template + void visitAux( + int64_t m, int64_t n, int64_t l, int m_b, int n_b, + ElementAccumulator frag) { + std::get(this->ops).visit(m, n, l, m_b, n_b, frag); + + if constexpr (I < Rm2 - 1) { + return visitAux(m, n, l, m_b, n_b, frag); + } + else { + return; + } + } + + template + ElementCompute visit( + int64_t m, int64_t n, int64_t l, int m_b, int n_b, + ElementAccumulator acc) { + + /// Compute the input tree + frg_input_ = std::get<0>(this->ops).visit(m, n, l, m_b, n_b, acc); + + /// Compute the aux out tree + visitAux(m, n, l, m_b, n_b, frg_input_); + /// Visit the output tree + return std::get(this->ops).visit(m, n, l, m_b, n_b, frg_input_); + } +}; + +/// Universal testbed for EVT w/o smem +template +class Testbed3xEVTnoSmem { +public: + // The EVT Module to test + using EVTModule = EVT; //typename EVT::EVTModule; + + using TestBedImpl = typename detail::TestbedImpl; + using 
Kernel = typename Gemm::GemmKernel; + using Epilogue = typename Gemm::GemmKernel::CollectiveEpilogue; + using ElementAccumulator = typename Kernel::ElementAccumulator; + using ElementC = typename Kernel::ElementC; + using ElementD = typename Kernel::ElementD; + + using ProblemShapeType = typename Kernel::ProblemShape; + + using LayoutTagA = typename TestBedImpl::LayoutTagA; + using LayoutTagB = typename TestBedImpl::LayoutTagB; + + using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::RasterOrderOptions; + using DecompositionMode = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode; + + // + // Methods + // + Testbed3xEVTnoSmem( + bool check_relative_equality_, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + uint64_t seed_ = TestBedImpl::kDefaultSeed ) : + impl_((check_relative_equality_ ? CheckEquality::RELATIVE : CheckEquality::EXACT), ScalarLoc::ON_DEVICE, VectorScale::ENABLED, + init_A_, init_B_, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform, seed_), + check_relative_equality(check_relative_equality_) { } + + Testbed3xEVTnoSmem( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + uint64_t seed_ = TestBedImpl::kDefaultSeed ) : + impl_(CheckEquality::EXACT, ScalarLoc::ON_DEVICE, VectorScale::ENABLED, + init_A_, init_B_, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform, seed_), + check_relative_equality(false) { } + + /// Initializes data structures + void initialize(ProblemShapeType problem_size) { + // + // Allocate the GEMM workspace for A/B tensor + // + impl_.initialize(problem_size); + } + // Detail Implementation + TestBedImpl impl_; + + // Whether to use relative equality checks + bool check_relative_equality; + + bool verify(ProblemShapeType problem_size, EVTModule& host_reference) { + + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto M = cute::get<0>(problem_shape_MNKL); + auto N = cute::get<1>(problem_shape_MNKL); + auto K = cute::get<2>(problem_shape_MNKL); + auto L = cute::get<3>(problem_shape_MNKL); + + auto A = cute::make_tensor(impl_.collective_mma_inputs.tensor_A.host_data(), + cute::make_layout(cute::make_shape(M, K, L), impl_.collective_mma_inputs.stride_a)); + auto B = cute::make_tensor(impl_.collective_mma_inputs.tensor_B.host_data(), + cute::make_layout(cute::make_shape(N, K, L), impl_.collective_mma_inputs.stride_b)); + auto LayoutD = cute::make_layout(cute::make_shape(M, N, L), impl_.collective_epilogue.stride_d); + + cutlass::reference::host::GettMainloopParams mainloop_params{A, B}; + + /// Reference Kernel + static int constexpr kBlockM = 64; + static int constexpr kBlockN = 64; + +#if defined(_OPENMP) + #pragma omp parallel for collapse(3) +#endif + for (int64_t l = 0; l < cute::size<2>(mainloop_params.A.layout()); ++l) { + for (int64_t m = 0; m < cute::size<0>(mainloop_params.A.layout()); m += kBlockM) { + for (int64_t n = 0; n < cute::size<0>(mainloop_params.B.layout()); n += kBlockN) { + ElementAccumulator acc[kBlockM][kBlockN]; + gett_mainloop(mainloop_params, m, n, l, acc); + /// Epilogue EVT + for (int n_b = 0; n_b < kBlockN; ++n_b) { + for (int m_b = 0; m_b < kBlockM; ++m_b) { + if (m + m_b < cute::size<0>(LayoutD) && n + n_b < cute::size<1>(LayoutD)) { + 
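For orientation: the call on the next line hands each in-bounds accumulator element to the host EVT module, whose tree visitor evaluates its children first and then applies the node's op to their results (see HostTreeVisitor above). The hypothetical, much-simplified nodes below (not the classes in this file) show that composition for a D = alpha * acc epilogue using the same visit signature:

#include <cstdint>
#include <functional>

struct ToyAccum {            // leaf: pass the accumulator through
  float visit(int64_t, int64_t, int64_t, int, int, float acc) const { return acc; }
};
struct ToyScalar {           // leaf: a broadcast scalar such as alpha
  float value;
  float visit(int64_t, int64_t, int64_t, int, int, float) const { return value; }
};
template <class Op, class Lhs, class Rhs>
struct ToyCompute {          // interior node: op(child0, child1)
  Lhs lhs; Rhs rhs; Op op;
  float visit(int64_t m, int64_t n, int64_t l, int m_b, int n_b, float acc) const {
    return op(lhs.visit(m, n, l, m_b, n_b, acc), rhs.visit(m, n, l, m_b, n_b, acc));
  }
};

// One element of D = alpha * acc, evaluated the same way the loop here calls visit().
inline float toy_evt_element(float alpha, float acc) {
  ToyCompute<std::multiplies<float>, ToyScalar, ToyAccum> evt{ToyScalar{alpha}, ToyAccum{}, {}};
  return evt.visit(/*m*/0, /*n*/0, /*l*/0, /*m_b*/0, /*n_b*/0, acc);
}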
host_reference.visit(m, n, l, m_b, n_b, acc[m_b][n_b]); + } + } + } + } + } + } + + std::stringstream error_ss; + bool passed = host_reference.compare_reference(error_ss); + if (!passed) { + std::stringstream fname; + fname << "error_Gemm_device_" + << M << "x" << N << "x" << K << "x" << L << "_" + << cute::get<0>(typename Gemm::GemmKernel::TileShape{}) << "_" + << cute::get<1>(typename Gemm::GemmKernel::TileShape{}) << "_" + << cute::get<2>(typename Gemm::GemmKernel::TileShape{}) << ".txt"; + + std::ofstream file(fname.str()); + file + << "problem: " << ' ' << M << "x" << N << "x" << K + << ", Batch count = " << L << "\n\n"; + + file + << "A =\n" << impl_.collective_mma_inputs.tensor_A.host_view() + << "\nB =\n" << impl_.collective_mma_inputs.tensor_B.host_view(); + + file << error_ss.str(); + } + + return passed; + } + + bool run( + ProblemShapeType problem_size, + RasterOrderOptions raster_order = RasterOrderOptions::Heuristic, + detail::MaxSwizzleSize max_swizzle = detail::MaxSwizzleSize{}, + detail::Splits splits = detail::Splits{}, + DecompositionMode decomposition_mode = DecompositionMode::Heuristic, + int iterations = 20, + bool profiling = false) { + // Fail test if insufficient CUDA device + if (!impl_.sufficient()) { + std::cout << "Test failed due to insufficient CUDA device." << std::endl; + return false; + } + // + // Initialize the Gemm operator + // + + typename Gemm::Arguments arguments; + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = 0; + if (not profiling) { + impl_.sm_count = std::min(impl_.MaxSmCount, cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id)); + hw_info.sm_count = impl_.sm_count; + } + else { + impl_.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + hw_info.sm_count = impl_.sm_count; + } + + typename Gemm::GemmKernel::TileScheduler::Arguments scheduler_args; + if constexpr (cute::is_same_v) { + scheduler_args = { static_cast(splits), static_cast(max_swizzle), raster_order, decomposition_mode }; + } + else { + scheduler_args = { static_cast(max_swizzle), raster_order }; + } + + /// Initializes data structures + /// A/B/C/D Tensor + initialize(problem_size); + + /// Initialize the epilogue arguments + EVTModule host_reference(problem_size, check_relative_equality, 2024); + + arguments = typename Gemm::Arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + { + impl_.collective_mma_inputs.tensor_A.device_data(), impl_.collective_mma_inputs.stride_a, + impl_.collective_mma_inputs.tensor_B.device_data(), impl_.collective_mma_inputs.stride_b + }, + {}, + hw_info, + scheduler_args + }; + + // Filling in the thread arguments + if constexpr (FlatArgs) { + auto epilogue_args = host_reference.get_flatten_arguments(); + std::memcpy(&arguments.epilogue.thread, &epilogue_args, sizeof(epilogue_args)); + + arguments.epilogue.ptr_C = static_cast(host_reference.get_tensor_C_ptr()); + arguments.epilogue.dC = impl_.collective_epilogue.stride_c; + + arguments.epilogue.ptr_D = static_cast(host_reference.get_tensor_D_ptr()); + arguments.epilogue.dD = impl_.collective_epilogue.stride_d; + } + else { + auto epilogue_args = host_reference.get_arguments(); + std::memcpy(&arguments.epilogue, &epilogue_args, sizeof(epilogue_args)); + } + + Gemm gemm_op; + + size_t workspace_size = Gemm::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = gemm_op.can_implement(arguments); + + if (status != 
cutlass::Status::kSuccess) { + cudaError_t error = cudaGetLastError(); + std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n"; + return true; + } + + // + // Run the GEMM + // + if (profiling) { + return impl_.profile(problem_size, iterations, gemm_op, arguments, workspace); + } + else { + cudaError_t result; + status = gemm_op.initialize(arguments, workspace.get()); + status = gemm_op.run(); + result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + EXPECT_EQ(result, cudaSuccess) << "Error at Kernel Sync."; + return false; + } + } + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Verify + // + bool passed = this->verify(problem_size, host_reference); + if (!passed) { + std::cout << "Error : Failed \n"; + } + + return passed; + } +}; + +/// Universal testbed for EVT +template +class Testbed3xEVT { +public: + // The EVT Module to test + using EVTModule = typename EVT::EVTModule; + + using TestBedImpl = typename detail::TestbedImpl; + using Kernel = typename Gemm::GemmKernel; + using Epilogue = typename Gemm::GemmKernel::CollectiveEpilogue; + using ElementAccumulator = typename Kernel::ElementAccumulator; + using ElementC = typename Kernel::ElementC; + using ElementD = typename Kernel::ElementD; + + using ProblemShapeType = typename Kernel::ProblemShape; + + using LayoutTagA = typename TestBedImpl::LayoutTagA; + using LayoutTagB = typename TestBedImpl::LayoutTagB; + using LayoutTagC = typename TestBedImpl::LayoutTagC; + using LayoutTagD = typename TestBedImpl::LayoutTagD; + + // + // Methods + // + Testbed3xEVT( + bool check_relative_equality_, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = TestBedImpl::kDefaultSeed + ) : + impl_((check_relative_equality_ ? 
CheckEquality::RELATIVE : CheckEquality::EXACT), ScalarLoc::ON_DEVICE, VectorScale::ENABLED, + init_A_, init_B_, init_C_, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform, seed_), + check_relative_equality(check_relative_equality_) { } + + Testbed3xEVT( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = TestBedImpl::kDefaultSeed + ) : + impl_(CheckEquality::EXACT, ScalarLoc::ON_DEVICE, VectorScale::ENABLED, + init_A_, init_B_, init_C_, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform, seed_), + check_relative_equality(false) { } + + Testbed3xEVT( + typename LayoutTagA::Stride stride_factor_A_, + typename LayoutTagB::Stride stride_factor_B_, + typename LayoutTagC::Stride stride_factor_C_, + typename LayoutTagD::Stride stride_factor_D_, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = TestBedImpl::kDefaultSeed + ) : + impl_(stride_factor_A_, stride_factor_B_, stride_factor_C_, stride_factor_D_, + CheckEquality::EXACT, ScalarLoc::ON_DEVICE, VectorScale::ENABLED, + init_A_, init_B_, init_C_, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform, seed_), + check_relative_equality(false) { } + + /// Initializes data structures + void initialize(ProblemShapeType problem_size) { + // + // Allocate the GEMM workspace for A/B tensor + // + impl_.initialize(problem_size); + } + // Detail Implementation + TestBedImpl impl_; + + // Whether to use relative equality checks + bool check_relative_equality; + + bool verify(ProblemShapeType problem_size, EVTModule& host_reference) { + + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto M = cute::get<0>(problem_shape_MNKL); + auto N = cute::get<1>(problem_shape_MNKL); + auto K = cute::get<2>(problem_shape_MNKL); + auto L = cute::get<3>(problem_shape_MNKL); + + auto A = cute::make_tensor(impl_.collective_mma_inputs.tensor_A.host_data(), + cute::make_layout(cute::make_shape(M, K, L), impl_.collective_mma_inputs.stride_a)); + auto B = cute::make_tensor(impl_.collective_mma_inputs.tensor_B.host_data(), + cute::make_layout(cute::make_shape(N, K, L), impl_.collective_mma_inputs.stride_b)); + auto LayoutD = cute::make_layout(cute::make_shape(M, N, L), impl_.collective_epilogue.stride_d); + + cutlass::reference::host::GettMainloopParams mainloop_params{A, B}; + + /// Reference Kernel + static int constexpr kBlockM = 64; + static int constexpr kBlockN = 64; + +#if defined(_OPENMP) + #pragma omp parallel for collapse(3) +#endif + for (int64_t l = 0; l < cute::size<2>(mainloop_params.A.layout()); ++l) { + for (int64_t m = 0; m < cute::size<0>(mainloop_params.A.layout()); m += kBlockM) { + for (int64_t n = 0; n < cute::size<0>(mainloop_params.B.layout()); n += kBlockN) { + ElementAccumulator acc[kBlockM][kBlockN]; + gett_mainloop(mainloop_params, m, n, l, acc); + /// Epilogue EVT + for (int n_b = 0; n_b < kBlockN; ++n_b) { + for (int m_b = 0; m_b < kBlockM; ++m_b) { + if (m + m_b < cute::size<0>(LayoutD) && n + n_b < cute::size<1>(LayoutD)) { + host_reference.visit(m, n, l, m_b, n_b, acc[m_b][n_b]); + } + } + } + } + } + } + + std::stringstream error_ss; + bool passed = host_reference.compare_reference(error_ss); + if (!passed) { + std::stringstream 
fname; + fname << "error_Gemm_device_" + << M << "x" << N << "x" << K << "x" << L << "_" + << cute::get<0>(typename Gemm::GemmKernel::TileShape{}) << "_" + << cute::get<1>(typename Gemm::GemmKernel::TileShape{}) << "_" + << cute::get<2>(typename Gemm::GemmKernel::TileShape{}) << ".txt"; + + std::ofstream file(fname.str()); + file + << "problem: " << ' ' << M << "x" << N << "x" << K + << ", Batch count = " << L << "\n\n"; + + file + << "A =\n" << impl_.collective_mma_inputs.tensor_A.host_view() + << "\nB =\n" << impl_.collective_mma_inputs.tensor_B.host_view() + << "\nC =\n" << impl_.collective_epilogue.tensor_C.host_view() << "\n\n"; + + file << error_ss.str(); + } + + return passed; + } + + bool run( + ProblemShapeType problem_size, + bool profiling = false, + int iterations = 20, + int splits = 1) { + // Fail test if insufficient CUDA device + if (!impl_.sufficient()) { + std::cout << "Test failed due to insufficient CUDA device." << std::endl; + return false; + } + // + // Initialize the Gemm operator + // + + typename Gemm::Arguments arguments; + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = 0; + if (not profiling) { + impl_.sm_count = std::min(impl_.MaxSmCount, cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id)); + hw_info.sm_count = impl_.sm_count; + } + else { + impl_.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + hw_info.sm_count = impl_.sm_count; + } + + typename Gemm::GemmKernel::TileScheduler::Arguments scheduler_args; + if constexpr (cute::is_same_v) { + scheduler_args = { splits }; + } + + /// Initializes data structures + /// A/B/C/D Tensor + initialize(problem_size); + + /// Initialize the epilogue arguments + EVTModule host_reference(problem_size, check_relative_equality, 2024); + + arguments = typename Gemm::Arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + { + impl_.collective_mma_inputs.tensor_A.device_data(), impl_.collective_mma_inputs.stride_a, + impl_.collective_mma_inputs.tensor_B.device_data(), impl_.collective_mma_inputs.stride_b + }, + { // Epilogue arguments + {}, // thread + static_cast(host_reference.get_tensor_C_ptr()), + impl_.collective_epilogue.stride_c, + static_cast(host_reference.get_tensor_D_ptr()), + impl_.collective_epilogue.stride_d + }, // Epilogue arguments end + hw_info, + scheduler_args + }; + + // Filling in the thread arguments + typename EVTModule::Arguments epilogue_args = host_reference.get_arguments(); + std::memcpy(&arguments.epilogue.thread, &epilogue_args.arg, sizeof(epilogue_args.arg)); + + Gemm gemm_op; + + size_t workspace_size = Gemm::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = gemm_op.can_implement(arguments); + + if (status != cutlass::Status::kSuccess) { + cudaError_t error = cudaGetLastError(); + std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n"; + return true; + } + + // + // Run the GEMM + // + if (profiling) { + return impl_.profile(problem_size, iterations, gemm_op, arguments, workspace); + } + else { + cudaError_t result; + status = gemm_op.initialize(arguments, workspace.get()); + status = gemm_op.run(); + result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + EXPECT_EQ(result, cudaSuccess) << "Error at Kernel Sync."; + return false; + } + } + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Verify + // + bool passed = this->verify(problem_size, 
host_reference); + if (!passed) { + std::cout << "Error : Failed \n"; + } + + return passed; + } +}; + +template +bool TestAllEVT(bool check_relative_equality = false) { + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + + int max_alignment = std::max(Gemm::kAlignmentA, Gemm::kAlignmentB); + std::vector problem_size_m = {max_alignment, 512 - 3 * max_alignment}; + std::vector problem_size_n = {max_alignment, 512 - 2 * max_alignment}; + + if constexpr (cute::is_same_v) { + problem_size_m.push_back(768); + problem_size_n.push_back(768); + } + + constexpr int Stages = Gemm::GemmKernel::DispatchPolicy::Stages; + constexpr int TileShapeK = cute::size<2>(typename Gemm::GemmKernel::TileShape{}); + + std::vector problem_size_k = {max_alignment, TileShapeK * (Stages + 1) - max_alignment}; + + Testbed3xEVT testbed(check_relative_equality); + bool passed = true; + + for (int m : problem_size_m) { + for (int n : problem_size_n) { + for (int k : problem_size_k) { + ProblemShapeType problem_size; + if constexpr (cute::rank(ProblemShapeType{}) == 4) { + problem_size = ProblemShapeType{m, n, k, /* l */ 1}; + } + else { + problem_size = ProblemShapeType{m, n, k}; + } + + passed = testbed.run(problem_size); + + if (!passed) { + return false; + } + } + } + } + + // if we do support batched GEMM, just run one test on it to save on test time + if constexpr (cute::rank(ProblemShapeType{}) == 4) { + auto problem_size = ProblemShapeType{256 + max_alignment, 256 + max_alignment, 160 + max_alignment, /* l */ 3}; + passed = testbed.run( + problem_size + ); + + if (!passed) { + return false; + } + } + + return passed; +} + +} // namespace device +} // namespace gemm +} // namespace test + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/gemm_testbed_3x_ptr_array.hpp b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/gemm_testbed_3x_ptr_array.hpp new file mode 100644 index 0000000000000000000000000000000000000000..cbc54ec582d88d9039968d8153cf6127a06ec274 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/gemm_testbed_3x_ptr_array.hpp @@ -0,0 +1,2409 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Testbed for Ptr-Array and Grouped GEMM interface +*/ + +#pragma once + +#include +#include +#include +#include +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/gett.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/fusion/operations.hpp" +#include "cutlass/complex.h" +#include "testbed_utils.h" + +#include "cutlass/kernel_hardware_info.hpp" +#include "cutlass/layout/matrix.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/gemm/gemm.h" + +#include "cute/int_tuple.hpp" +#include "cute/layout.hpp" +#include "cute/numeric/int.hpp" + +namespace test { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +enum class ScalarLoc { + ON_HOST = 0, + ON_DEVICE = 1 +}; + +enum class VectorScale { + DISABLED = 0, + ENABLED = 1 +}; + +enum class CheckEquality { + EXACT = 0, + RELATIVE = 1 +}; + +namespace detail{ + +// Helper classes that take default data type when +// the Gemm::EpilogueOutputOp does not have ElementCompute +// and ElementScalar. +// (e.g. when Sm90TreeVisitor is used as FusionCallbacks) +template +struct ElementComputeType { + using Type = Default; +}; + +template +struct ElementComputeType> { + using Type = typename Gemm::EpilogueOutputOp::ElementCompute; +}; + +template +struct ElementScalarType { + using Type = Default; +}; + +template +struct ElementScalarType> { + using Type = typename Gemm::EpilogueOutputOp::ElementScalar; +}; + + +template +struct IsF8F6F4Kernel { + static constexpr bool value = false; +}; + +template +struct IsF8F6F4Kernel> { + static constexpr bool value = true; +}; + + +// The maximum swizzle size to use +// +// This class, like Splits above makes it harder to confuse +// the order of arguments of the various run(...) functions in this file. 
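+// For illustration only: with this pattern a call site must spell out what each
+// integer means. A hypothetical run(...) overload taking both values would be
+// invoked as
+//
+//   testbed.run(problem_size,
+//               RasterOrderOptions::Heuristic,
+//               detail::MaxSwizzleSize(4),
+//               detail::Splits(2));
+//
+// whereas a bare `testbed.run(problem_size, ..., 4, 2)` does not compile, so
+// swapped swizzle/split arguments are rejected at compile time instead of being
+// silently accepted.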
+class MaxSwizzleSize { +public: + MaxSwizzleSize() = default; + + template && + !cute::is_same_v)) > + explicit MaxSwizzleSize(IntegralNotBool max_swizzle_size) : max_swizzle_size_(max_swizzle_size) {} + explicit operator int() const { return max_swizzle_size_; } +private: + int max_swizzle_size_ = 1; +}; + +template +auto make_iterator(T* ptr) { + return cute::recast_ptr(ptr); +} + +template +struct IsDefaultEpilogue { + static constexpr bool value = false; +}; + +template +struct IsDefaultEpilogue> { + static constexpr bool value = true; +}; + +template +struct IsDefaultEpilogue> { + static constexpr bool value = true; +}; + +// The number of splits to test. +// +// This class makes it harder to confuse the order of arguments +// of the various run(...) functions in this file. The constructor +// is explicit, so one can't just type 42 (or false, which the +// compiler unhelpfully turns into 0); one has to type Splits(42). +// Splits() picks the default number of splits, 1. +// +// The conversion-to-int operator (operator int()) MUST be explicit! +// Conversion to int MUST require static_cast. +// Otherwise, that defeats a key purpose of this class, +// which is to catch common errors of confusing the order +// of function arguments. +class Splits { +public: + Splits() = default; + + template && + !cute::is_same_v)) > + explicit Splits(IntegralNotBool splits) : splits_(splits) {} + explicit operator int() const { return splits_; } +private: + int splits_ = 1; +}; + +// The number of iterations to test. +// +// This class, like Splits above makes it harder to confuse +// the order of arguments of the various run(...) functions in this file. +// Iterations() picks the default number of iterations, 20. +class Iterations { +public: + Iterations() = default; + + template && + !cute::is_same_v)) > + explicit Iterations(IntegralNotBool iterations) : iterations_(iterations) {} + explicit operator int() const { return iterations_; } +private: + int iterations_ = 20; +}; + +template +bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } + + else if (bits_input <= 6) { + scope_max = 2; + scope_min = -2; + } + + else if (bits_input <= 8) { + + if constexpr ( + cute::is_same_v){ + scope_max = 4; + scope_min = 1; + } + else { + + scope_max = 1; + scope_min = -1; + + } + + } + else{ + scope_max = 4; + scope_min = -4; + } + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, 0); + } + + else if (dist_kind == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(view); + } + + else if (dist_kind == cutlass::Distribution::Gaussian) { + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + + else if (dist_kind == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential( + view.data(), view.capacity()); + } + + else if (dist_kind == cutlass::Distribution::AllOnes) { + cutlass::reference::host::TensorFill(view, Element(1)); + } + + else { + EXPECT_TRUE(false) << "Not implemented"; + return false; + } + + return true; +} + +// Looks at Cute Stride to check Row / Column Major +template +static constexpr bool is_row_or_col_major(){ + int stride_0 = int(cute::size<0>(Stride{})); + int stride_1 = int(cute::size<1>(Stride{})); + int depth = 
cute::depth(Stride{}); + return ((stride_0 == 1) || (stride_1 == 1)) && (depth == 1); +} + + +// +// Default MMA input Operands : A , B +// +template< + class ScheduleType_, + class Gemm, + class ElementA_ = typename Gemm::GemmKernel::ElementA, + class ElementB_ = typename Gemm::GemmKernel::ElementB> +struct HostCollectiveMainloop { + // Kernel data types + using ElementA = ElementA_; + using StrideA = typename Gemm::GemmKernel::StrideA; + using InternalStrideA = typename Gemm::GemmKernel::InternalStrideA; + using ElementB = ElementB_; + using StrideB = typename Gemm::GemmKernel::StrideB; + using InternalStrideB = typename Gemm::GemmKernel::InternalStrideB; + using ScheduleType = typename Gemm::GemmKernel::CollectiveMainloop::DispatchPolicy::Schedule; + using LayoutTagA = cutlass::detail::StrideToLayoutTagA_t; + using LayoutTagB = cutlass::detail::StrideToLayoutTagB_t; + + static constexpr bool IsGroupGemm = !cute::is_same_v; + + using ElementAccumulator = typename Gemm::GemmKernel::ElementAccumulator; + using ElementScalingFactor = ElementAccumulator; + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + using EpilogueOutputOp = typename Gemm::EpilogueOutputOp; + + using Arguments = typename Gemm::GemmKernel::MainloopArguments; + + cutlass::ComplexTransform TransformA = Gemm::kTransformA; + cutlass::ComplexTransform TransformB = Gemm::kTransformB; + + std::vector stride_a_host; + std::vector stride_b_host; + + cutlass::DeviceAllocation stride_a_device; + cutlass::DeviceAllocation stride_b_device; + + typename LayoutTagA::Stride stride_factor_A; + typename LayoutTagB::Stride stride_factor_B; + + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + + std::vector> tensors_A; + std::vector> tensors_B; + cutlass::DeviceAllocation device_tensors_A; + cutlass::DeviceAllocation device_tensors_B; + // Whether to use relative equality checks + CheckEquality check_relative_equality = CheckEquality::EXACT; + + uint64_t seed; + static constexpr uint64_t kDefaultSeed = 4096; + + // Note: this limitation comes from testbed / not the library + static_assert(is_row_or_col_major(), + "ERROR : A Layout is neither Row / Column Major)"); + static_assert(is_row_or_col_major(), + "ERROR : B Layout is neither Row / Column Major)"); + + HostCollectiveMainloop( + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + uint64_t seed_ = kDefaultSeed, + typename LayoutTagA::Stride stride_factor_A_ = typename LayoutTagA::Stride(), + typename LayoutTagB::Stride stride_factor_B_ = typename LayoutTagB::Stride() + ): + stride_factor_A(stride_factor_A_), + stride_factor_B(stride_factor_B_), + init_A(init_A_), init_B(init_B_), seed(seed_), + check_relative_equality(check_relative_equality_) { } + + bool initialize(ProblemShapeType problem_shapes) { + // + // Allocate the GEMM workspace + // + // for pointer array problem_shapes.groups() is 1 + + tensors_A.clear(); + tensors_B.clear(); + stride_a_host.clear(); + stride_b_host.clear(); + + auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(0), 1); + L = cutlass::platform::max(problem_shapes.groups(), L); + + for(int32_t i = 0; i < L; ++i) { + auto [M, N, K, mock_L] = cute::append<4>(problem_shapes.get_host_problem_shape(i), 1); + + stride_a_host.push_back(cutlass::make_cute_packed_stride(InternalStrideA{}, {M, K, 1})); + 
stride_b_host.push_back(cutlass::make_cute_packed_stride(InternalStrideB{}, {N, K, 1})); + + // 2.x host tensor does not natively contain a batch stride or coord, so we spoof if by folding it into the outer mode + auto a_coord = cutlass::make_Coord(M, K); + // Cutlass has Row/Col major refers to MxK times KxN matrix product, + // so the HostTensorB should be treated as KxN in "coord"'s view + auto b_coord = cutlass::make_Coord(K, N); + + tensors_A.push_back(cutlass::HostTensor(a_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(a_coord, stride_factor_A))); + tensors_B.push_back(cutlass::HostTensor(b_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(b_coord, stride_factor_B))); + + EXPECT_TRUE(initialize_tensor(tensors_A[i].host_view(), init_A, seed + 2022 + i)); + EXPECT_TRUE(initialize_tensor(tensors_B[i].host_view(), init_B, seed + 2021 + i)); + + // It is possible to randomly initialize to all zeros, so override this with non-zeros + // in the upper left corner of each operand. + tensors_A[i].host_view().at({0, 0}) = ElementA(1); + tensors_B[i].host_view().at({0, 0}) = ElementB(1); + + tensors_A[i].sync_device(); + tensors_B[i].sync_device(); + } + + return true; + } + + Arguments to_args(ProblemShapeType problem_shapes) { + auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(0), 1); + L = cutlass::platform::max(problem_shapes.groups(), L); + + std::vector ptr_A_host(L); + std::vector ptr_B_host(L); + + for (int32_t i = 0; i < L; ++i) { + ptr_A_host.at(i) = tensors_A[i].device_data(); + ptr_B_host.at(i) = tensors_B[i].device_data(); + } + + device_tensors_A.reset(L); + device_tensors_A.copy_from_host(ptr_A_host.data()); + + device_tensors_B.reset(L); + device_tensors_B.copy_from_host(ptr_B_host.data()); + + stride_a_device.reset(problem_shapes.groups()); + stride_a_device.copy_from_host(stride_a_host.data()); + stride_b_device.reset(problem_shapes.groups()); + stride_b_device.copy_from_host(stride_b_host.data()); + + Arguments arguments; + + if constexpr (IsGroupGemm) { + arguments + = + { + device_tensors_A.get(), stride_a_device.get(), device_tensors_B.get(), stride_b_device.get() + }; + } + else { + arguments = + { + device_tensors_A.get(), stride_a_host[0], device_tensors_B.get(), stride_b_host[0] + }; + } + + return arguments; + } + + auto to_host_args(ProblemShapeType problem_shapes, int batch) { + using namespace cute; + // + // Allocate the GEMM workspace + // + auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(batch), 1); + auto A = make_tensor(make_iterator(tensors_A[batch].host_data()), + make_layout(make_shape(M, K, 1), stride_a_host[batch])); + auto B = make_tensor(make_iterator(tensors_B[batch].host_data()), + make_layout(make_shape(N, K, 1), stride_b_host[batch])); + + cutlass::reference::host::GettMainloopParams mainloop_params{}; + + mainloop_params.A = A; + mainloop_params.B = B; + mainloop_params.transform_A = TransformA; + mainloop_params.transform_B = TransformB; + + return mainloop_params; + } + + void print_tensors(std::ofstream& file, int batch) { + file << "A =\n" << tensors_A[batch].host_view() + << "\nB =\n" << tensors_B[batch].host_view(); + } + + template < + class Element, + class Layout + > + bool equality_check( + cutlass::TensorView const& lhs, + cutlass::TensorView const& rhs) const { + + // Factors used for calculating relative equality. CUTLASS's relative-equality + // checks in include/cutlass/relatively_equal.h are inspired by + // https://floating-point-gui.de/errors/comparison/. 
This reference suggests using + // the minimum normal value of a given type as the nonzero_floor. + Element epsilon(static_cast(0.1f)); + Element nonzero_floor(std::numeric_limits::min()); + + if constexpr (!cutlass::is_complex::value) { + if (check_relative_equality == CheckEquality::RELATIVE) { + return cutlass::reference::host::TensorRelativelyEquals( + lhs, rhs, epsilon, nonzero_floor); + } + else { + return cutlass::reference::host::TensorEquals(lhs, rhs); + } + } + else { + return cutlass::reference::host::TensorEquals(lhs, rhs); + } + } + + bool compare_reference( + ProblemShapeType problem_shapes, int batch) { + EXPECT_GT(cutlass::reference::host::TensorNorm(tensors_A[batch].host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensors_B[batch].host_view()), 0); + + bool passed = true; + return passed; + } +}; + + +// +// Block Scaled Gemm Input Operands : A , B, scalefactorA, scalefactorB +// +template< + class Gemm, + int SchedulerPipelineStageCount_, + int AccumulatorPipelineStageCount_, + class ElementA_, + class ElementB_ +> +struct HostCollectiveMainloop, + Gemm, ElementA_, ElementB_> { + // Kernel data types + using ElementA = ElementA_; + using StrideA = typename Gemm::GemmKernel::StrideA; + using InternalStrideA = typename Gemm::GemmKernel::InternalStrideA; + using ElementB = ElementB_; + using StrideB = typename Gemm::GemmKernel::StrideB; + using InternalStrideB = typename Gemm::GemmKernel::InternalStrideB; + using ScheduleType = typename Gemm::GemmKernel::CollectiveMainloop::DispatchPolicy::Schedule; + using LayoutTagA = cutlass::detail::StrideToLayoutTagA_t; + using LayoutTagB = cutlass::detail::StrideToLayoutTagB_t; + + static constexpr bool IsGroupGemm = !cute::is_same_v; + + using ElementAccumulator = typename Gemm::GemmKernel::ElementAccumulator; + using ElementScalingFactor = ElementAccumulator; + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + using EpilogueOutputOp = typename Gemm::EpilogueOutputOp; + + static constexpr int SFVecSize = Gemm::GemmKernel::CollectiveMainloop::SFVecSize; + + using ElementSF = typename Gemm::GemmKernel::CollectiveMainloop::ElementSF; + using Sm1xxBlkScaledConfig = typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig; + using Blk_MN = typename Sm1xxBlkScaledConfig::Blk_MN; + using Blk_SF = typename Sm1xxBlkScaledConfig::Blk_SF; + using SfAtom = typename Sm1xxBlkScaledConfig::SfAtom; + using LayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFA; + using InternalLayoutSFA = typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFA; + using LayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::LayoutSFB; + using InternalLayoutSFB = typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFB; + + using Arguments = typename Gemm::GemmKernel::MainloopArguments; + + // Whether to use relative equality checks + CheckEquality check_relative_equality = CheckEquality::EXACT; + + std::vector stride_a_host; + std::vector stride_b_host; + cutlass::DeviceAllocation stride_a_device; + cutlass::DeviceAllocation stride_b_device; + + std::vector layout_sfa_host; + std::vector layout_sfb_host; + cutlass::DeviceAllocation layout_sfa_device; + cutlass::DeviceAllocation layout_sfb_device; + + typename LayoutTagA::Stride stride_factor_A; + typename LayoutTagB::Stride stride_factor_B; + + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + + std::vector> tensors_A; + std::vector> tensors_B; + std::vector> tensors_SFA; + std::vector> tensors_SFB; + + 
cutlass::DeviceAllocation device_tensors_A; + cutlass::DeviceAllocation device_tensors_B; + cutlass::DeviceAllocation device_tensors_SFA; + cutlass::DeviceAllocation device_tensors_SFB; + + uint64_t seed; + static constexpr uint64_t kDefaultSeed = 4096; + + // Note: this limitation comes from testbed / not the library + static_assert(is_row_or_col_major(), + "ERROR : A Layout is neither Row / Column Major)"); + static_assert(is_row_or_col_major(), + "ERROR : B Layout is neither Row / Column Major)"); + + HostCollectiveMainloop( + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + uint64_t seed_ = kDefaultSeed, + typename LayoutTagA::Stride stride_factor_A_ = typename LayoutTagA::Stride(), + typename LayoutTagB::Stride stride_factor_B_ = typename LayoutTagB::Stride() + ): + check_relative_equality(check_relative_equality_), + stride_factor_A(stride_factor_A_), + stride_factor_B(stride_factor_B_), + init_A(init_A_), init_B(init_B_), seed(seed_) { } + + template + bool initialize(ProblemShapeType problem_shapes) { + // + // Allocate the GEMM workspace + // + + tensors_A.clear(); + tensors_B.clear(); + stride_a_host.clear(); + stride_b_host.clear(); + tensors_SFA.clear(); + tensors_SFB.clear(); + layout_sfa_host.clear(); + layout_sfb_host.clear(); + + auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(0), 1); + L = std::max(problem_shapes.groups(), L); + + for (int32_t i = 0; i < L; ++i) { + auto [M, N, K, mock_L] = cute::append<4>(problem_shapes.get_host_problem_shape(i), 1); + + stride_a_host.push_back(cutlass::make_cute_packed_stride(InternalStrideA{}, {M, K, 1})); + stride_b_host.push_back(cutlass::make_cute_packed_stride(InternalStrideB{}, {N, K, 1})); + + // 2.x host tensor does not natively contain a batch stride or coord, so we spoof if by folding it into the outer mode + auto a_coord = cutlass::make_Coord(M, K); + // Cutlass has Row/Col major refers to MxK times KxN matrix product, + // so the HostTensorB should be treated as KxN in "coord"'s view + auto b_coord = cutlass::make_Coord(K, N); + + tensors_A.push_back(cutlass::HostTensor(a_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(a_coord, stride_factor_A))); + tensors_B.push_back(cutlass::HostTensor(b_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(b_coord, stride_factor_B))); + + EXPECT_TRUE(initialize_tensor(tensors_A[i].host_view(), init_A, seed + 2022 + i)); + EXPECT_TRUE(initialize_tensor(tensors_B[i].host_view(), init_B, seed + 2021 + i)); + + // It is possible to randomly initialize to all zeros, so override this with non-zeros + // in the upper left corner of each operand. 
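+      // (Keeping a guaranteed non-zero element also means the TensorNorm(...) > 0
+      // sanity checks in compare_reference cannot fail spuriously on an all-zero draw.)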
+ tensors_A[i].host_view().at({0, 0}) = ElementA(1); + tensors_B[i].host_view().at({0, 0}) = ElementB(1); + + tensors_A[i].sync_device(); + tensors_B[i].sync_device(); + + using namespace cute; + + auto k_blks = cutlass::ceil_div(K, size<1>(shape(SfAtom{}))); + auto m_blks = cutlass::ceil_div(M, Blk_MN{}); + auto n_blks = cutlass::ceil_div(N, Blk_MN{}); + layout_sfa_host.push_back(Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(M, N, K, 1))); + layout_sfb_host.push_back(Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(M, N, K, 1))); + + // 2.x host tensor does not natively contain a batch stride or coord, so we spoof if by folding it into the outer mode + auto sfa_coord = cutlass::make_Coord(m_blks * Blk_MN{}, k_blks * Blk_SF{}); + auto sfb_coord = cutlass::make_Coord(n_blks * Blk_MN{}, k_blks * Blk_SF{}); + + tensors_SFA.push_back(cutlass::HostTensor(sfa_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(sfa_coord, stride_factor_A))); + tensors_SFB.push_back(cutlass::HostTensor(sfb_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(sfb_coord, stride_factor_B))); + + EXPECT_TRUE(initialize_tensor(tensors_SFA[i].host_view(), init_A, seed + 2024 + i)); + EXPECT_TRUE(initialize_tensor(tensors_SFB[i].host_view(), init_B, seed + 2025 + i)); + + // It is possible to randomly initialize to all zeros, so override this with non-zeros + // in the upper left corner of each operand. + tensors_SFA[i].host_view().at({0, 0}) = ElementSF(1); + tensors_SFB[i].host_view().at({0, 0}) = ElementSF(1); + + tensors_SFA[i].sync_device(); + tensors_SFB[i].sync_device(); + } + + return true; + } + + Arguments to_args(ProblemShapeType problem_shapes) { + auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(0), 1); + L = std::max(problem_shapes.groups(), L); + + std::vector ptr_A_host(L); + std::vector ptr_B_host(L); + std::vector ptr_SFA_host(L); + std::vector ptr_SFB_host(L); + + for (int32_t i = 0; i < L; ++i) { + ptr_A_host.at(i) = tensors_A[i].device_data(); + ptr_B_host.at(i) = tensors_B[i].device_data(); + ptr_SFA_host.at(i) = tensors_SFA[i].device_data(); + ptr_SFB_host.at(i) = tensors_SFB[i].device_data(); + } + + device_tensors_A.reset(L); + device_tensors_A.copy_from_host(ptr_A_host.data()); + + device_tensors_B.reset(L); + device_tensors_B.copy_from_host(ptr_B_host.data()); + + device_tensors_SFA.reset(L); + device_tensors_SFA.copy_from_host(ptr_SFA_host.data()); + + device_tensors_SFB.reset(L); + device_tensors_SFB.copy_from_host(ptr_SFB_host.data()); + + stride_a_device.reset(problem_shapes.groups()); + stride_a_device.copy_from_host(stride_a_host.data()); + + stride_b_device.reset(problem_shapes.groups()); + stride_b_device.copy_from_host(stride_b_host.data()); + + layout_sfa_device.reset(problem_shapes.groups()); + layout_sfa_device.copy_from_host(layout_sfa_host.data()); + + layout_sfb_device.reset(problem_shapes.groups()); + layout_sfb_device.copy_from_host(layout_sfb_host.data()); + + if constexpr (IsGroupGemm) { + return Arguments{ + device_tensors_A.get(), stride_a_device.get(), + device_tensors_B.get(), stride_b_device.get(), + device_tensors_SFA.get(), layout_sfa_device.get(), + device_tensors_SFB.get(), layout_sfb_device.get() + }; + } + else { + return Arguments{ + device_tensors_A.get(), stride_a_host[0], + device_tensors_B.get(), stride_b_host[0], + device_tensors_SFA.get(), layout_sfa_host[0], + device_tensors_SFB.get(), layout_sfb_host[0] + }; + } + } + + auto to_host_args(ProblemShapeType problem_shapes, int batch) { + 
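+  // Builds the host-side reference mainloop parameters for one group/batch: non-owning
+  // cute tensors over A, B and their scale factors SFA/SFB, using the strides and
+  // scale-factor layouts recorded for that batch in initialize().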
using namespace cute; + // + // Allocate the GEMM workspace + // + auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(batch), 1); + auto A = make_tensor(make_iterator(tensors_A[batch].host_data()), + make_layout(make_shape(M, K, 1), stride_a_host[batch])); + auto SfA = make_tensor(tensors_SFA[batch].host_data(), layout_sfa_host[batch]); + + auto B = make_tensor(make_iterator(tensors_B[batch].host_data()), + make_layout(make_shape(N, K, 1), stride_b_host[batch])); + auto SfB = make_tensor(tensors_SFB[batch].host_data(), layout_sfb_host[batch]); + + return cutlass::reference::host::GettMainloopParams + {A, SfA, B, SfB}; + } + + void print_tensors(std::ofstream& file, int batch) { + file << "A =\n" << tensors_A[batch].host_view() + << "\nB =\n" << tensors_B[batch].host_view() + << "\nSFA =\n" << tensors_SFA[batch].host_view() + << "\nSFB =\n" << tensors_SFB[batch].host_view(); + } + + bool compare_reference( + ProblemShapeType problem_shapes, int batch) { + + EXPECT_GT(cutlass::reference::host::TensorNorm(tensors_A[batch].host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensors_B[batch].host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensors_SFA[batch].host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensors_SFB[batch].host_view()), 0); + return true; + } +}; + +// +// Block Scaled Gemm Input Operands : A , B, scalefactorA, scalefactorB +// +template< + class Gemm, + int SchedulerPipelineStageCount_, + class ElementA_, + class ElementB_ +> +struct HostCollectiveMainloop, + Gemm, ElementA_, ElementB_> : public + HostCollectiveMainloop, + Gemm, ElementA_, ElementB_> { + using Base = HostCollectiveMainloop, + Gemm, ElementA_, ElementB_>; + HostCollectiveMainloop( + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + uint64_t seed_ = Base::kDefaultSeed, + typename Base::LayoutTagA::Stride stride_factor_A_ = typename Base::LayoutTagA::Stride(), + typename Base::LayoutTagB::Stride stride_factor_B_ = typename Base::LayoutTagB::Stride() + ) : Base::HostCollectiveMainloop(check_relative_equality_, init_A_, init_B_, seed_, stride_factor_A_, stride_factor_B_) {} +}; + +// +// Block Scaled Gemm Input Operands : A , B, scalefactorA, scalefactorB +// +template< + class Gemm, + int SchedulerPipelineStageCount_, + class ElementA_, + class ElementB_ +> +struct HostCollectiveMainloop, + Gemm, ElementA_, ElementB_> : public + HostCollectiveMainloop, + Gemm, ElementA_, ElementB_> { + using Base = HostCollectiveMainloop, + Gemm, ElementA_, ElementB_>; + HostCollectiveMainloop( + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + uint64_t seed_ = Base::kDefaultSeed, + typename Base::LayoutTagA::Stride stride_factor_A_ = typename Base::LayoutTagA::Stride(), + typename Base::LayoutTagB::Stride stride_factor_B_ = typename Base::LayoutTagB::Stride() + ) : Base::HostCollectiveMainloop(check_relative_equality_, init_A_, init_B_, seed_, stride_factor_A_, stride_factor_B_) {} +}; + +// +// Block Scaled Gemm Input Operands : A , B, scalefactorA, scalefactorB +// +template< + class Gemm, + int SchedulerPipelineStageCount_, + int AccumulatorPipelineStageCount_, + class ElementA_, + class ElementB_ +> +struct HostCollectiveMainloop, + Gemm, 
ElementA_, ElementB_> : public + HostCollectiveMainloop, + Gemm, ElementA_, ElementB_> { + using Base = HostCollectiveMainloop, + Gemm, ElementA_, ElementB_>; + HostCollectiveMainloop( + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + uint64_t seed_ = Base::kDefaultSeed, + typename Base::LayoutTagA::Stride stride_factor_A_ = typename Base::LayoutTagA::Stride(), + typename Base::LayoutTagB::Stride stride_factor_B_ = typename Base::LayoutTagB::Stride() + ) : Base::HostCollectiveMainloop(check_relative_equality_, init_A_, init_B_, seed_, stride_factor_A_, stride_factor_B_) {} +}; + +template +struct HostCollectiveDefaultEpilogue { + // fusion types are potentially void if the fusion is not supported + // helper so we don't try to construct HostTensor with void type + template + using non_void_t = cute::conditional_t, U, T>; + + using ScheduleType = typename Gemm::GemmKernel::CollectiveMainloop::DispatchPolicy::Schedule; + using kernel = typename Gemm::GemmKernel; + using Epilogue = typename kernel::CollectiveEpilogue; + + using ElementD = typename kernel::ElementD; + using StrideD = typename kernel::StrideD; + using InternalStrideD = typename kernel::InternalStrideD; + using ElementC = non_void_t; + using StrideC = typename kernel::StrideC; + using InternalStrideC = typename kernel::InternalStrideC; + + static constexpr bool IsGroupGemm = !cute::is_same_v; + + using FusionOp = typename Gemm::EpilogueOutputOp; + + static_assert(rank(InternalStrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + static_assert(rank(InternalStrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + + static_assert(is_row_or_col_major(), + "ERROR : C Layout is neither Row / Column Major)"); + static_assert(is_row_or_col_major(), + "ERROR : D Layout is neither Row / Column Major)"); + + // Deduce Cutlass Layouts (RowMajor & ColumnMajor) + using LayoutTagC = cutlass::detail::StrideToLayoutTagC_t; + using LayoutTagD = cutlass::detail::StrideToLayoutTagC_t; + using LayoutTagScalar = cutlass::layout::PackedVectorLayout; // scalars are size-1 vectors + using LayoutTagVector = cutlass::layout::PackedVectorLayout; + + using ElementAccumulator = typename kernel::ElementAccumulator; + using ElementScalingFactor = ElementAccumulator; + using ProblemShapeType = typename kernel::ProblemShape; + using ElementCompute = typename ElementComputeType::Type; + using ElementScalar = typename ElementScalarType::Type; + + using Arguments = typename Gemm::GemmKernel::EpilogueArguments; + + /// Initialization + cutlass::DeviceAllocation stride_c_device; + cutlass::DeviceAllocation stride_d_device; + + std::vector stride_c_host; + std::vector stride_d_host; + + typename LayoutTagC::Stride stride_factor_C; + typename LayoutTagD::Stride stride_factor_D; + + // Inputs + ElementScalar alpha; + ElementScalar beta; + + std::vector> tensors_C; + std::vector> tensors_D; + std::vector> references_D; + cutlass::DeviceAllocation device_tensors_C; + cutlass::DeviceAllocation device_tensors_D; + + // Whether to use relative equality checks + CheckEquality check_relative_equality = CheckEquality::EXACT; + // Are scalars copied to device memory before kernel launch + ScalarLoc use_device_scalars = ScalarLoc::ON_HOST; + // If per-row scale is enabled and this is disabled, alpha/beta are passed as a host or device scalar instead of device vector + VectorScale vector_scale_mode = 
VectorScale::DISABLED; + + cutlass::Distribution::Kind init_C; + uint64_t seed; + static constexpr uint64_t kDefaultSeed = 4096; + + HostCollectiveDefaultEpilogue( + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + ScalarLoc use_device_scalars_ = ScalarLoc::ON_HOST, + VectorScale vector_scale_mode_ = VectorScale::DISABLED, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_scale_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_bias_ = cutlass::Distribution::Uniform, + uint64_t seed_ = kDefaultSeed + ): init_C(init_C_), seed(seed_), + stride_factor_C(typename LayoutTagC::Stride()), + stride_factor_D(typename LayoutTagD::Stride()), + check_relative_equality(check_relative_equality_), + use_device_scalars(use_device_scalars_){ } + + bool initialize(ProblemShapeType problem_shapes, ElementScalar alpha_=1.f, ElementScalar beta_=0.f) { + // Initialize Epilogue tensors + + tensors_C.clear(); + tensors_D.clear(); + references_D.clear(); + stride_c_host.clear(); + stride_d_host.clear(); + + auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(0), 1); + L = cutlass::platform::max(problem_shapes.groups(), L); + + for (int32_t i = 0; i < L; ++i) { + auto [M, N, K, mock_L] = cute::append<4>(problem_shapes.get_host_problem_shape(i), 1); + + stride_c_host.push_back(cutlass::make_cute_packed_stride(InternalStrideC{}, {M, N, 1})); + stride_d_host.push_back(cutlass::make_cute_packed_stride(InternalStrideD{}, {M, N, 1})); + + // 2.x host tensor does not natively contain a batch stride or coord, so we spoof if by folding it into the outer mode + auto c_coord = cutlass::make_Coord(M, N); + + tensors_C.push_back(cutlass::HostTensor(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, stride_factor_C))); + tensors_D.push_back(cutlass::HostTensor(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, stride_factor_D))); + references_D.push_back(cutlass::HostTensor(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, stride_factor_D), false)); + EXPECT_TRUE(initialize_tensor(tensors_C[i].host_view(), init_C, seed + 2020)); + tensors_C[i].host_view().at({0, 0}) = ElementC(1); + + cutlass::reference::host::TensorCopy(references_D[i].host_view(), tensors_C[i].host_view()); + tensors_C[i].sync_device(); + tensors_D[i].sync_device(); + } + alpha = alpha_; + beta = beta_; + + return true; + } + + template < + class Element, + class Layout + > + bool equality_check( + cutlass::TensorView const& lhs, + cutlass::TensorView const& rhs) const { + + // Factors used for calculating relative equality. CUTLASS's relative-equality + // checks in include/cutlass/relatively_equal.h are inspired by + // https://floating-point-gui.de/errors/comparison/. This reference suggests using + // the minimum normal value of a given type as the nonzero_floor. 
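+    // Roughly: two values compare equal when their difference is small relative to
+    // their magnitudes, with nonzero_floor guarding the near-zero case; see
+    // include/cutlass/relatively_equal.h for the exact rule.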
+ Element epsilon(static_cast(0.1f)); + Element nonzero_floor(std::numeric_limits::min()); + + if constexpr (!cutlass::is_complex::value) { + if (check_relative_equality == CheckEquality::RELATIVE) { + return cutlass::reference::host::TensorRelativelyEquals( + lhs, rhs, epsilon, nonzero_floor); + } + else { + return cutlass::reference::host::TensorEquals(lhs, rhs); + } + } + else { + return cutlass::reference::host::TensorEquals(lhs, rhs); + } + } + + bool compare_reference( + ProblemShapeType problem_shapes, + ElementScalar alpha, + ElementScalar beta, + int batch) { + auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(0), 1); + L = cutlass::platform::max(problem_shapes.groups(), L); + + tensors_D[batch].sync_host(); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensors_C[batch].host_view()), 0); + + if (tensors_D[batch].size() > 1) { + EXPECT_GT(cutlass::reference::host::TensorNorm(tensors_D[batch].host_view()), 0); + } + + if (references_D[batch].size() > 1) { + EXPECT_GT(cutlass::reference::host::TensorNorm(references_D[batch].host_view()), 0); + } + + bool passed = equality_check(references_D[batch].host_view(), tensors_D[batch].host_view()); + if(!passed) { + std::cout<<"D is incorrect"<(problem_shapes.get_host_problem_shape(0), 1); + L = cutlass::platform::max(problem_shapes.groups(), L); + + std::vector ptr_C_host(L); + std::vector ptr_D_host(L); + + for (int32_t i = 0; i < L; ++i) { + ptr_C_host.at(i) = tensors_C[i].device_data(); + ptr_D_host.at(i) = tensors_D[i].device_data(); + } + + device_tensors_C.reset(L); + device_tensors_C.copy_from_host(ptr_C_host.data()); + + device_tensors_D.reset(L); + device_tensors_D.copy_from_host(ptr_D_host.data()); + + stride_c_device.reset(problem_shapes.groups()); + stride_c_device.copy_from_host(stride_c_host.data()); + + stride_d_device.reset(problem_shapes.groups()); + stride_d_device.copy_from_host(stride_d_host.data()); + + Arguments arguments; + if constexpr (IsGroupGemm) { + arguments = + { + {alpha, beta}, + device_tensors_C.get(), stride_c_device.get(), device_tensors_D.get(), stride_d_device.get() + }; + } + else { + arguments = + { + {alpha, beta}, + device_tensors_C.get(), stride_c_host[0], device_tensors_D.get(), stride_d_host[0] + }; + } + + return arguments; + } + + auto to_host_args(ProblemShapeType problem_shapes, int batch) { + using namespace cute; + // + // Allocate the GEMM workspace + // + auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(batch), 1); + L = std::max(problem_shapes.groups(), L); + + auto coord_0 = cutlass::make_Coord(0); + auto C = cute::make_tensor(detail::make_iterator(tensors_C[batch].host_data()), + cute::make_layout(cute::make_shape(M, N, 1), stride_c_host[batch])); + auto D = cute::make_tensor(detail::make_iterator(references_D[batch].host_data()), + cute::make_layout(cute::make_shape(M, N, 1), stride_d_host[batch])); + + cutlass::reference::host::GettEpilogueParams< + ElementScalar, + ElementScalar, + ElementAccumulator, + ElementCompute, + decltype(C), + decltype(D)> + epilogue_params{}; + + epilogue_params.C = C; + epilogue_params.D = D; + epilogue_params.alpha = alpha; + epilogue_params.beta = beta; + + return epilogue_params; + } +}; + +template +struct HostCollectiveEpilogue { + // fusion types are potentially void if the fusion is not supported + // helper so we don't try to construct HostTensor with void type + template + using non_void_t = cute::conditional_t, U, T>; + + using ScheduleType = typename 
Gemm::GemmKernel::CollectiveMainloop::DispatchPolicy::Schedule; + using kernel = typename Gemm::GemmKernel; + using Epilogue = typename kernel::CollectiveEpilogue; + static_assert(IsDefaultEpilogue::value == false, "Default Epilogue is not supported"); + + using ElementD = typename kernel::ElementD; + using StrideD = typename kernel::StrideD; + using InternalStrideD = typename kernel::InternalStrideD; + using ElementC = non_void_t; + using StrideC = typename kernel::StrideC; + using InternalStrideC = typename kernel::InternalStrideC; + + static constexpr bool IsGroupGemm = !cute::is_same_v; + + static_assert(rank(InternalStrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + static_assert(rank(InternalStrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + + static_assert(is_row_or_col_major(), + "ERROR : C Layout is neither Row / Column Major)"); + static_assert(is_row_or_col_major(), + "ERROR : D Layout is neither Row / Column Major)"); + + // Deduce Cutlass Layouts (RowMajor & ColumnMajor) + using LayoutTagC = cutlass::detail::StrideToLayoutTagC_t; + using LayoutTagD = cutlass::detail::StrideToLayoutTagC_t; + using LayoutTagScalar = cutlass::layout::PackedVectorLayout; // scalars are size-1 vectors + using LayoutTagVector = cutlass::layout::PackedVectorLayout; + + using ElementAccumulator = typename kernel::ElementAccumulator; + using ElementScalingFactor = ElementAccumulator; + using ProblemShapeType = typename kernel::ProblemShape; + + // + // FusionOperation derived types/queries + // + using EpiloguePolicy = typename Epilogue::DispatchPolicy; + static constexpr bool IsLegacy = + cute::is_same_v< + EpiloguePolicy, + cutlass::epilogue::Sm90TmaWarpSpecializedBiasElementwise< + EpiloguePolicy::StagesC, EpiloguePolicy::StagesD, EpiloguePolicy::FragmentSize> + >; + + using FusionOp = typename Gemm::EpilogueOutputOp; + static_assert(cute::is_base_of_v); + + + // Scale factor Generation related + using SfStrategy = cutlass::reference::host::SfStrategy; + static constexpr bool IsBlockScaleSupported = FusionOp::IsBlockScaleSupported; + static constexpr SfStrategy SfGenStrategy = (!IsBlockScaleSupported) ? SfStrategy::None : SfStrategy::SfDGen; + static constexpr int32_t SFD_VectorSize = IsBlockScaleSupported ? 
FusionOp::SFVecSize : 1; + using ElementSFD = non_void_t, ElementD>; + using Sm1xxBlockScaledOutputConfig= cutlass::detail::Sm1xxBlockScaledOutputConfig< + SFD_VectorSize + >; + using Blk_MN = typename Sm1xxBlockScaledOutputConfig::Blk_MN; + using Blk_SF = typename Sm1xxBlockScaledOutputConfig::Blk_SF; + using OutputSFAtom = typename Sm1xxBlockScaledOutputConfig::SfAtom; + std::vector> tensors_SFD; + std::vector> references_SFD; + cutlass::DeviceAllocation device_tensors_SFD; + + using ElementCompute = typename FusionOp::ElementCompute; + using ElementScalar = typename FusionOp::ElementScalar; + using ElementBias = non_void_t; + using ElementAux = non_void_t; + using ElementAmax = non_void_t; + using LayoutTagAux = non_void_t; + using ActivationFunctor = non_void_t>; + + static constexpr bool IsBiasEnabled = FusionOp::IsPerRowBiasSupported; + static constexpr bool IsDeBiasEnabled = FusionOp::IsDePerRowBiasSupported; + static constexpr bool IsPerRowScaleEnabled = FusionOp::IsPerRowScaleSupported; + static constexpr bool IsScaleFactorEnabled = FusionOp::IsScaleFactorSupported; + static constexpr bool IsAuxInEnabled = FusionOp::IsAuxInSupported; + static constexpr bool IsAuxOutEnabled = FusionOp::IsAuxOutSupported; + static constexpr bool IsAbsMaxEnabledD = FusionOp::IsAbsMaxSupported && + (cute::is_same_v || + cute::is_same_v); + static constexpr bool IsAbsMaxEnabledAux = IsAuxOutEnabled && FusionOp::IsAbsMaxSupported && + (cute::is_same_v || + cute::is_same_v); + + using Arguments = typename Gemm::GemmKernel::EpilogueArguments; + + /// Initialization + cutlass::DeviceAllocation stride_c_device; + cutlass::DeviceAllocation stride_d_device; + + std::vector stride_c_host; + std::vector stride_d_host; + + typename LayoutTagC::Stride stride_factor_C; + typename LayoutTagD::Stride stride_factor_D; + + // Inputs + cutlass::HostTensor alpha; + cutlass::HostTensor beta; + cutlass::HostTensor scale_A; + cutlass::HostTensor scale_B; + cutlass::HostTensor scale_C; + cutlass::HostTensor scale_D; + cutlass::HostTensor scale_Aux; + cutlass::HostTensor bias; + std::vector> tensors_C; + cutlass::DeviceAllocation device_tensors_C; + cutlass::HostTensor norm_constant; + + // Outputs + cutlass::HostTensor abs_max_Aux; + cutlass::HostTensor abs_max_D; + std::vector> tensors_Aux; + cutlass::DeviceAllocation device_tensors_Aux; + cutlass::gemm::TagToStrideC_t< LayoutTagAux > stride_Aux; + std::vector> tensors_D; + std::vector> references_D; + cutlass::DeviceAllocation device_tensors_D; + + // References + cutlass::HostTensor reference_dbias; + std::vector> references_Aux; + cutlass::HostTensor reference_abs_max_Aux; + cutlass::HostTensor reference_abs_max_D; + + // Whether to use relative equality checks + CheckEquality check_relative_equality = CheckEquality::EXACT; + // Are scalars copied to device memory before kernel launch + ScalarLoc use_device_scalars = ScalarLoc::ON_HOST; + // If per-row scale is enabled and this is disabled, alpha/beta are passed as a host or device scalar instead of device vector + VectorScale vector_scale_mode = VectorScale::DISABLED; + + // Random distribution with which to initialize the A/B/C/D/Aux scaling factors + cutlass::Distribution::Kind init_scale = cutlass::Distribution::Uniform; + // Random distribution with which to initialize the bias vector + cutlass::Distribution::Kind init_bias = cutlass::Distribution::Uniform; + cutlass::Distribution::Kind init_C; + uint64_t seed; + static constexpr uint64_t kDefaultSeed = 4096; + + HostCollectiveEpilogue( + CheckEquality 
check_relative_equality_ = CheckEquality::EXACT, + ScalarLoc use_device_scalars_ = ScalarLoc::ON_HOST, + VectorScale vector_scale_mode_ = VectorScale::DISABLED, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_scale_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_bias_ = cutlass::Distribution::Uniform, + uint64_t seed_ = kDefaultSeed + ): init_scale(init_scale_), init_bias(init_bias_), + init_C(init_C_), seed(seed_), + stride_factor_C(typename LayoutTagC::Stride()), + stride_factor_D(typename LayoutTagD::Stride()), + check_relative_equality(check_relative_equality_), + use_device_scalars(use_device_scalars_){ } + + bool initialize(ProblemShapeType problem_shapes, ElementScalar alpha_=1.f, ElementScalar beta_=0.f) { + // Initialize Epilogue tensors + + tensors_C.clear(); + tensors_D.clear(); + references_D.clear(); + stride_c_host.clear(); + stride_d_host.clear(); + + tensors_SFD.clear(); + references_SFD.clear(); + + + auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(0), 1); + L = std::max(problem_shapes.groups(), L); + + for (int32_t i = 0; i < L; ++i) { + auto [M, N, K, mock_L] = cute::append<4>(problem_shapes.get_host_problem_shape(i), 1); + + stride_c_host.push_back(cutlass::make_cute_packed_stride(InternalStrideC{}, {M, N, 1})); + stride_d_host.push_back(cutlass::make_cute_packed_stride(InternalStrideD{}, {M, N, 1})); + + auto c_coord = cutlass::make_Coord(M, N); + tensors_C.push_back(cutlass::HostTensor(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, stride_factor_C))); + tensors_D.push_back(cutlass::HostTensor(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, stride_factor_D))); + references_D.push_back(cutlass::HostTensor(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, stride_factor_D), false)); + EXPECT_TRUE(initialize_tensor(tensors_C[i].host_view(), init_C, seed + 2020)); + tensors_C[i].host_view().at({0, 0}) = ElementC(1); + + cutlass::reference::host::TensorCopy(references_D[i].host_view(), tensors_C[i].host_view()); + tensors_C[i].sync_device(); + tensors_D[i].sync_device(); + } + + auto scalar_coord = cutlass::make_Coord(1); + auto col_vector_coord = cutlass::make_Coord(M); + if constexpr (IsPerRowScaleEnabled) { + alpha.resize(col_vector_coord); + EXPECT_TRUE(initialize_tensor(alpha.host_view(), init_scale, seed + 2023)); + if (vector_scale_mode == VectorScale::DISABLED) { + beta.resize(scalar_coord, false); + cutlass::reference::host::TensorFill(beta.host_view(), beta_); + } + else { + beta.resize(col_vector_coord); + EXPECT_TRUE(initialize_tensor(beta.host_view(), init_scale, seed + 2024)); + } + } + else { + alpha.resize(scalar_coord, (use_device_scalars == ScalarLoc::ON_DEVICE)); + beta.resize(scalar_coord, (use_device_scalars == ScalarLoc::ON_DEVICE)); + cutlass::reference::host::TensorFill(alpha.host_view(), alpha_); + cutlass::reference::host::TensorFill(beta.host_view(), beta_); + } + alpha.sync_device(); + beta.sync_device(); + + if constexpr (IsScaleFactorEnabled) { + scale_A.resize(scalar_coord, (use_device_scalars == ScalarLoc::ON_DEVICE)); + scale_B.resize(scalar_coord, (use_device_scalars == ScalarLoc::ON_DEVICE)); + scale_C.resize(scalar_coord, (use_device_scalars == ScalarLoc::ON_DEVICE)); + scale_D.resize(scalar_coord, (use_device_scalars == ScalarLoc::ON_DEVICE)); + EXPECT_TRUE(initialize_tensor(scale_A.host_view(), init_scale, seed + 2023)); + 
EXPECT_TRUE(initialize_tensor(scale_B.host_view(), init_scale, seed + 2024)); + EXPECT_TRUE(initialize_tensor(scale_C.host_view(), init_scale, seed + 2025)); + EXPECT_TRUE(initialize_tensor(scale_D.host_view(), init_scale, seed + 2026)); + scale_A.sync_device(); + scale_B.sync_device(); + scale_C.sync_device(); + scale_D.sync_device(); + } + + if constexpr (IsBiasEnabled) { + bias.resize(col_vector_coord); + EXPECT_TRUE(initialize_tensor(bias.host_view(), init_bias, seed + 2023)); + bias.sync_device(); + } + + if constexpr (IsDeBiasEnabled) { + bias.resize(col_vector_coord); + reference_dbias.resize(col_vector_coord); + cutlass::reference::host::TensorFill(bias.host_view(), ElementBias(0)); + cutlass::reference::host::TensorFill(reference_dbias.host_view(), ElementBias(0)); + bias.sync_device(); + } + + if constexpr (IsAbsMaxEnabledD) { + abs_max_D.resize(scalar_coord); + // ensure in-place device reductions perform their own initialization + cutlass::reference::host::TensorFill(abs_max_D.host_view(), + CUTLASS_STL_NAMESPACE::numeric_limits::max()); + abs_max_D.sync_device(); + reference_abs_max_D.resize(scalar_coord); + cutlass::reference::host::TensorFill(reference_abs_max_D.host_view(), ElementAmax(0)); + } + + tensors_Aux.clear(); + references_Aux.clear(); + + static_assert(!IsGroupGemm or (IsGroupGemm and !IsAuxInEnabled)); + + if constexpr (IsAuxInEnabled) { + auto aux_coord = cutlass::make_Coord(M, N); + auto aux_layout = cutlass::layout::Affine2Layout_Factory::layout_factory(aux_coord, typename LayoutTagAux::Stride{}); + for (int32_t i = 0; i < L; ++i) { + tensors_Aux.push_back(cutlass::HostTensor(aux_coord, aux_layout)); + EXPECT_TRUE(initialize_tensor(tensors_Aux[i].host_view(), init_C, seed + 2023)); + tensors_Aux[i].sync_device(); + } + stride_Aux = cutlass::make_cute_packed_stride(cutlass::gemm::TagToStrideC_t{}, cute::make_shape(M, N, 1)); + } + + static_assert(!IsGroupGemm or (IsGroupGemm and !IsAuxOutEnabled)); + + if constexpr (IsAuxOutEnabled) { + for (int32_t i = 0; i < L; ++i) { + auto [M, N, K, mock_L] = cute::append<4>(problem_shapes.get_host_problem_shape(i), 1); + auto aux_coord = cutlass::make_Coord(M, N); + auto aux_layout = cutlass::layout::Affine2Layout_Factory::layout_factory(aux_coord, typename LayoutTagAux::Stride{}); + tensors_Aux.push_back(cutlass::HostTensor(aux_coord, aux_layout)); + references_Aux.push_back(cutlass::HostTensor(aux_coord, aux_layout, false)); + tensors_Aux[i].sync_device(); + } + + stride_Aux = cutlass::make_cute_packed_stride(cutlass::gemm::TagToStrideC_t{}, cute::make_shape(M, N, 1)); + + if constexpr (IsScaleFactorEnabled) { + scale_Aux.resize(scalar_coord, (use_device_scalars == ScalarLoc::ON_DEVICE)); + EXPECT_TRUE(initialize_tensor(scale_Aux.host_view(), init_scale, seed + 2027)); + scale_Aux.sync_device(); + } + + if constexpr (IsAbsMaxEnabledAux) { + abs_max_Aux.resize(scalar_coord); + // ensure in-place device reductions perform their own initialization + cutlass::reference::host::TensorFill(abs_max_Aux.host_view(), + CUTLASS_STL_NAMESPACE::numeric_limits::max()); + abs_max_Aux.sync_device(); + reference_abs_max_Aux.resize(scalar_coord); + cutlass::reference::host::TensorFill(reference_abs_max_Aux.host_view(), ElementAmax(0)); + } + } + + + if constexpr (IsBlockScaleSupported) { + for (int32_t i = 0; i < L; ++i) { + auto [M, N, K, _] = cute::append<4>(problem_shapes.get_host_problem_shape(i), 1); + // If block scaled output is supported we always have at least 1 SFD + auto m_blks = cutlass::ceil_div(M, 
cute::size<0>(cute::shape(OutputSFAtom{}))); + auto n_blks = cutlass::ceil_div(N, cute::size<1>(cute::shape(OutputSFAtom{}))); + auto sfd_coord = [&] () { + return cutlass::make_Coord(m_blks * Blk_MN{}, n_blks * Blk_SF{}); + }(); + tensors_SFD.push_back(cutlass::HostTensor(sfd_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(sfd_coord, stride_factor_D))); + references_SFD.push_back(cutlass::HostTensor(sfd_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(sfd_coord, stride_factor_D), false)); + tensors_SFD[i].sync_device(); + } + norm_constant.resize(scalar_coord, true); + EXPECT_TRUE(initialize_tensor(norm_constant.host_view(), init_scale, seed + 2023)); + norm_constant.sync_device(); + } + + + return true; + } + + template < + class Element, + class Layout + > + bool equality_check( + cutlass::TensorView const& lhs, + cutlass::TensorView const& rhs) const { + + // Factors used for calculating relative equality. CUTLASS's relative-equality + // checks in include/cutlass/relatively_equal.h are inspired by + // https://floating-point-gui.de/errors/comparison/. This reference suggests using + // the minimum normal value of a given type as the nonzero_floor. + Element epsilon(static_cast(0.1f)); + Element nonzero_floor(std::numeric_limits::min()); + + if constexpr (!cutlass::is_complex::value) { + if (check_relative_equality == CheckEquality::RELATIVE) { + return cutlass::reference::host::TensorRelativelyEquals( + lhs, rhs, epsilon, nonzero_floor); + } + else { + return cutlass::reference::host::TensorEquals(lhs, rhs); + } + } + else { + return cutlass::reference::host::TensorEquals(lhs, rhs); + } + } + + bool compare_reference( + ProblemShapeType problem_shapes, + ElementScalar alpha, + ElementScalar beta, + int batch) { + tensors_D[batch].sync_host(); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensors_C[batch].host_view()), 0); + + if (tensors_D[batch].size() > 1) { + EXPECT_GT(cutlass::reference::host::TensorNorm(tensors_D[batch].host_view()), 0); + } + + if (references_D[batch].size() > 1) { + EXPECT_GT(cutlass::reference::host::TensorNorm(references_D[batch].host_view()), 0); + } + + bool passed = equality_check(references_D[batch].host_view(), tensors_D[batch].host_view()); + if(!passed) { + std::cout<<"D is incorrect"<(problem_shapes.get_host_problem_shape(0), 1); + L = std::max(problem_shapes.groups(), L); + + std::vector ptr_C_host(L); + std::vector ptr_D_host(L); + + for (int32_t i = 0; i < L; ++i) { + ptr_C_host.at(i) = tensors_C[i].device_data(); + ptr_D_host.at(i) = tensors_D[i].device_data(); + } + + device_tensors_C.reset(L); + device_tensors_C.copy_from_host(ptr_C_host.data()); + + device_tensors_D.reset(L); + device_tensors_D.copy_from_host(ptr_D_host.data()); + + stride_c_device.reset(problem_shapes.groups()); + stride_c_device.copy_from_host(stride_c_host.data()); + + stride_d_device.reset(problem_shapes.groups()); + stride_d_device.copy_from_host(stride_d_host.data()); + + std::vector ptr_Aux_host(L); + if constexpr (IsAuxInEnabled || IsAuxOutEnabled) { + for (int32_t i = 0; i < L; ++i) { + ptr_Aux_host.at(i) = tensors_Aux[i].device_data(); + } + device_tensors_Aux.reset(L); + device_tensors_Aux.copy_from_host(ptr_Aux_host.data()); + } + + auto device_tensors_C_ptr = cute::is_void_v ? 
nullptr : + reinterpret_cast(device_tensors_C.get()); + + Arguments arguments; + if constexpr (IsGroupGemm) { + arguments = + { + {}, + device_tensors_C_ptr, stride_c_device.get(), device_tensors_D.get(), stride_d_device.get() + }; + } + else { + arguments = + { + {}, + device_tensors_C_ptr, stride_c_host[0], device_tensors_D.get(), stride_d_host[0] + }; + } + + auto &fusion_args = arguments.thread; + if constexpr (IsLegacy) { + arguments.thread = { + alpha.at(coord_0), + beta.at(coord_0), + alpha.device_data(), + beta.device_data() + }; + arguments.ptr_Bias = bias.device_data(); + arguments.ptr_T = device_tensors_Aux.get(); + } + else { + fusion_args.alpha = alpha.at(coord_0); + fusion_args.beta = beta.at(coord_0); + + fusion_args.alpha_ptr = alpha.device_data(); + // can_implement requires beta_ptr to not be set if its voidC + fusion_args.beta_ptr = cute::is_void_v ? nullptr : + beta.device_data(); + + if constexpr (IsScaleFactorEnabled) { + fusion_args.scale_a = scale_A.at(coord_0); + fusion_args.scale_b = scale_B.at(coord_0); + fusion_args.scale_c = scale_C.at(coord_0); + fusion_args.scale_d = scale_D.at(coord_0); + fusion_args.scale_a_ptr = scale_A.device_data(); + fusion_args.scale_b_ptr = scale_B.device_data(); + fusion_args.scale_c_ptr = scale_C.device_data(); + fusion_args.scale_d_ptr = scale_D.device_data(); + } + + if constexpr (IsBiasEnabled) { + fusion_args.bias_ptr = bias.device_data(); + } + + if constexpr (IsDeBiasEnabled) { + fusion_args.dbias_ptr = bias.device_data(); + } + + // example of how to set kernel activation arguments + // see ActivationFunctor::Arguments in activation.h for definition + // if Arguments doesn't exist then fusion_args.activation is empty + if constexpr (cute::is_same_v>) { + fusion_args.activation.scale = ElementCompute(1); + } + + // Treat Clamp as ReLU + if constexpr (cute::is_same_v>) { + fusion_args.activation.lower_bound = 0; + fusion_args.activation.upper_bound = std::numeric_limits::max(); + } + + if constexpr (IsAbsMaxEnabledD) { + fusion_args.amax_D_ptr = abs_max_D.device_data(); + } + + if constexpr (IsAuxInEnabled) { + fusion_args.aux_ptr = device_tensors_Aux.get(); + fusion_args.dAux = stride_Aux; + } + + if constexpr (IsAuxOutEnabled) { + fusion_args.aux_ptr = device_tensors_Aux.get(); + fusion_args.dAux = stride_Aux; + if constexpr (IsScaleFactorEnabled) { + fusion_args.scale_aux = scale_Aux.at(coord_0); + fusion_args.scale_aux_ptr = scale_Aux.device_data(); + } + if constexpr (IsAbsMaxEnabledAux) { + fusion_args.amax_aux_ptr = abs_max_Aux.device_data(); + } + } + + if constexpr (IsBlockScaleSupported) { + std::vector ptr_SFD_host(L); + for (int32_t i = 0; i < L; ++i) { + ptr_SFD_host.at(i) = tensors_SFD[i].device_data(); + } + device_tensors_SFD.reset(L); + device_tensors_SFD.copy_from_host(ptr_SFD_host.data()); + + arguments.thread.block_scale_factor_ptr = device_tensors_SFD.get(); + arguments.thread.norm_constant_ptr = norm_constant.device_data(); + } + + } + + return arguments; + } + + auto to_host_args(ProblemShapeType problem_shapes, int batch) { + using namespace cute; + // + // Allocate the GEMM workspace + // + auto problem_shape_MNKL = cute::append<4>(problem_shapes.get_host_problem_shape(batch), 1); + auto [M, N, K, L] = problem_shape_MNKL; + auto coord_0 = cutlass::make_Coord(0); + auto C = cute::make_tensor(detail::make_iterator(tensors_C[batch].host_data()), + cute::make_layout(cute::make_shape(M, N, 1), stride_c_host[batch])); + auto D = cute::make_tensor(detail::make_iterator(references_D[batch].host_data()), + 
cute::make_layout(cute::make_shape(M, N, 1), stride_d_host[batch])); + auto Bias = cute::make_tensor(detail::make_iterator(IsDeBiasEnabled ? reference_dbias.host_data() : bias.host_data()), + cute::make_layout(cute::make_shape(M, cute::_1{}))); + auto Aux_layout = cute::make_layout(cute::make_shape(M, N, 1), stride_Aux); + auto Aux = [&]() { + auto ptr = recast_ptr(nullptr); + if (IsAuxInEnabled) { + ptr = detail::make_iterator(tensors_Aux[batch].host_data()); + } else if (IsAuxOutEnabled) { + ptr = detail::make_iterator(references_Aux[batch].host_data()); + } + return cute::make_tensor(ptr, Aux_layout); + }(); + auto Valpha = cute::make_tensor(detail::make_iterator(alpha.host_data()), + cute::make_layout(cute::make_shape(M, N, cute::_1{}), cute::make_stride(cute::_1{}, cute::_0{}, M))); + auto Vbeta = cute::make_tensor(detail::make_iterator(beta.host_data()), + cute::make_layout(cute::make_shape(M, N, cute::_1{}), cute::make_stride(cute::_1{}, cute::_0{}, N))); + + auto SfD = [&](){ + if constexpr (IsBlockScaleSupported) { + auto tensor = make_tensor(detail::make_iterator(references_SFD[batch].host_data()), + Sm1xxBlockScaledOutputConfig::tile_atom_to_shape_SFD(problem_shape_MNKL)); + return tensor; + } + else { + // Reference kernel has a logic to ignore scalefactor computation if we pass the tensor type same as output D tensor. + return D; + } + }(); + + + cutlass::reference::host::GettEpilogueParams< + ElementScalar, + ElementScalar, + ElementAccumulator, + ElementCompute, + decltype(C), + decltype(D), + decltype(Bias), + decltype(Aux), + decltype(Valpha), + decltype(Vbeta), + ActivationFunctor + , decltype(SfD) + , Int + , cutlass::plus + , false + , SfGenStrategy + > epilogue_params{}; + + epilogue_params.C = C; + epilogue_params.D = D; + epilogue_params.alpha = alpha.at(coord_0); + epilogue_params.beta = beta.at(coord_0); + + if constexpr (IsScaleFactorEnabled) { + epilogue_params.scale_a = scale_A.at(coord_0); + epilogue_params.scale_b = scale_B.at(coord_0); + epilogue_params.scale_c = scale_C.at(coord_0); + epilogue_params.scale_d = scale_D.at(coord_0); + } + + if constexpr (IsBiasEnabled or IsDeBiasEnabled) { + epilogue_params.Bias = Bias; + } + + if constexpr (IsAbsMaxEnabledD) { + epilogue_params.abs_max_D = reference_abs_max_D.host_data(); + } + + if constexpr (IsAuxInEnabled) { + epilogue_params.Aux = Aux; + } + + if constexpr (IsAuxOutEnabled) { + epilogue_params.Aux = Aux; + if constexpr (IsScaleFactorEnabled) { + epilogue_params.scale_aux = scale_Aux.at(coord_0); + } + if constexpr (IsAbsMaxEnabledAux) { + epilogue_params.abs_max_Aux = reference_abs_max_Aux.host_data(); + } + } + + if constexpr (IsPerRowScaleEnabled) { + epilogue_params.Valpha = Valpha; + if (vector_scale_mode == VectorScale::ENABLED) { + epilogue_params.Vbeta = Vbeta; + } + } + + if constexpr (IsBlockScaleSupported) { + epilogue_params.SfD = SfD; + epilogue_params.st = norm_constant.at(coord_0); + } + + return epilogue_params; + } +}; + +template < + typename Gemm, + template class ActivationFunctor_ = cutlass::epilogue::thread::Identity, + bool force_legacy_epilogue = false, + typename ElementA = typename Gemm::GemmKernel::ElementA, + typename ElementB = typename Gemm::GemmKernel::ElementB +> +struct TestbedImpl { + // Kernel data types + using ScheduleType = typename Gemm::GemmKernel::CollectiveMainloop::DispatchPolicy::Schedule; + // All Collective MMA operands are defined by HostCollectiveMainloopType based on the schedule type + using HostCollectiveMainloopType = HostCollectiveMainloop; + using 
CollectiveEpilogue = cute::conditional_t::value || force_legacy_epilogue, + HostCollectiveDefaultEpilogue, + HostCollectiveEpilogue>; + + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + using ElementAccumulator = typename Gemm::GemmKernel::ElementAccumulator; + using ElementCompute = typename ElementComputeType::Type; + using ElementScalar = typename ElementScalarType::Type; + + using LayoutTagA = typename HostCollectiveMainloopType::LayoutTagA; + using LayoutTagB = typename HostCollectiveMainloopType::LayoutTagB; + using LayoutTagC = typename CollectiveEpilogue::LayoutTagC; + using LayoutTagD = typename CollectiveEpilogue::LayoutTagD; + + uint32_t sm_count; + // Used to force multi-wave tests for persistent kernel schedules + constexpr static int MaxSmCount = 16; + static constexpr uint64_t kDefaultSeed = 4096; + static constexpr uint32_t mma_promotion_interval = 4; + using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::RasterOrderOptions; + using DecompositionMode = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode; + + HostCollectiveMainloopType collective_mma_inputs; + CollectiveEpilogue collective_epilogue; + + static constexpr bool IsGroupGemm = CollectiveEpilogue::IsGroupGemm; + + // + // Methods + // + + TestbedImpl( + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + ScalarLoc use_device_scalars_ = ScalarLoc::ON_HOST, + VectorScale vector_scale_mode_ = VectorScale::DISABLED, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_scale_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_bias_ = cutlass::Distribution::Uniform, + uint64_t seed_ = kDefaultSeed + ): collective_mma_inputs(HostCollectiveMainloopType(check_relative_equality_, init_A_, init_B_, seed_)), + collective_epilogue(CollectiveEpilogue(check_relative_equality_, use_device_scalars_, vector_scale_mode_, init_C_, init_scale_, init_bias_, seed_)) { } + + TestbedImpl( + typename LayoutTagA::Stride stride_factor_A_, + typename LayoutTagB::Stride stride_factor_B_, + typename LayoutTagC::Stride stride_factor_C_, + typename LayoutTagD::Stride stride_factor_D_, + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + ScalarLoc use_device_scalars_ = ScalarLoc::ON_HOST, + VectorScale vector_scale_mode_ = VectorScale::DISABLED, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_scale_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_bias_ = cutlass::Distribution::Uniform, + uint64_t seed_ = kDefaultSeed + ): collective_mma_inputs(HostCollectiveMainloopType(check_relative_equality_, stride_factor_A_, stride_factor_B_, init_A_, init_B_, seed_)), + collective_epilogue(CollectiveEpilogue(check_relative_equality_, use_device_scalars_, vector_scale_mode_, init_C_, init_scale_, init_bias_, seed_)) { } + + /// Initializes data structures + bool initialize(ProblemShapeType problem_shapes, ElementScalar alpha_=1.f, ElementScalar beta_=0.f) { + collective_mma_inputs.initialize(problem_shapes); + collective_epilogue.initialize(problem_shapes, alpha_, beta_); + + return true; + } 
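+
+  // Usage sketch (illustrative only; not part of the upstream CUTLASS header):
+  // the testbed is normally driven through run(), which calls initialize(),
+  // launches the kernel, and then verifies the result via verify() and
+  // compare_reference(). For the non-grouped (batched/array) path, a unit test
+  // could drive it roughly as follows, with `Gemm` standing in for a concrete
+  // kernel configuration:
+  //
+  //   test::gemm::device::detail::TestbedImpl<Gemm> testbed;
+  //   auto problem = ProblemShapeType{{256, 256, 128, /*batch*/ 1}};
+  //   bool passed = testbed.run(problem, ElementScalar(1), ElementScalar(0));
+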
+ + /// Compares computed reference with device reference and outputs to a file if incorrect + bool compare_reference( + ProblemShapeType problem_shapes, + ElementScalar alpha, + ElementScalar beta, + int batch) + { + auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(batch), 1); + + bool passed = collective_mma_inputs.compare_reference(problem_shapes, batch); + passed &= collective_epilogue.compare_reference(problem_shapes, alpha, beta, batch); + EXPECT_TRUE(passed); + if (!passed) { + std::stringstream fname; + fname << "error_Gemm_device_" + << M << "x" << N << "x" << K << "x" << batch << "_" + << cute::get<0>(typename Gemm::GemmKernel::TileShape{}) << "_" + << cute::get<1>(typename Gemm::GemmKernel::TileShape{}) << "_" + << cute::get<2>(typename Gemm::GemmKernel::TileShape{}) << ".txt"; + + std::ofstream file(fname.str()); + file + << "problem: " << ' ' << M << "x" << N << "x" << K << ", Batch count = " << batch + << ", alpha: " << alpha << ", beta: " << beta << "\n\n"; + + collective_mma_inputs.print_tensors(file, batch); + collective_epilogue.print_tensors(file, batch); + } + + return passed; + } + + /// Verifies the result is a GEMM + bool verify( + ProblemShapeType problem_shapes, + ElementScalar alpha, + ElementScalar beta) + { + using namespace cute; + auto [M, N, K, L] = cute::append<4>(problem_shapes.get_host_problem_shape(0), 1); + L = std::max(problem_shapes.groups(), L); + + bool passed = true; + for (int32_t i = 0; i < L; ++i) { + auto mainloop_params = collective_mma_inputs.to_host_args(problem_shapes, i); + auto epilogue_params = collective_epilogue.to_host_args(problem_shapes, i); + + cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params); + + passed &= compare_reference(problem_shapes, alpha, beta, i); + } + return passed; + } + + /// Determine if the CUDA device is sufficient to run the kernel + bool sufficient() { + // + // Determine SMEM requirements and waive if not satisfied + // + + size_t smem_size = static_cast(Gemm::GemmKernel::SharedStorageSize); + + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + cudaDeviceProp properties; + result = cudaGetDeviceProperties(&properties, device_idx); + this->sm_count = properties.multiProcessorCount; + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerBlockOptin < smem_size) { + printf("failed due to smem_size\n"); + printf("hardware smem_size: %d, required smem_size: %d\n\n", int(properties.sharedMemPerBlockOptin), int(smem_size)); + return false; + } + + return true; + } + + /// Executes one test + bool run( + ProblemShapeType problem_shapes, + ElementScalar alpha = ElementScalar(1), + ElementScalar beta = ElementScalar(0), + detail::Iterations iterations = detail::Iterations{} + ) + { + + // Fail test if insufficient CUDA device + if (!sufficient()) { + std::cout << "Test failed due to insufficient CUDA device." 
<< std::endl; + return false; + } + + if (!this->initialize(problem_shapes, alpha, beta)) { + std::cerr << "Initialization failed \n"; + return false; + } + + // + // Initialize the GEMM operator + // + + typename Gemm::Arguments arguments; + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = 0; + this->sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + hw_info.sm_count = this->sm_count; + + typename HostCollectiveMainloopType::Arguments mainloop_args; + + mainloop_args = collective_mma_inputs.to_args(problem_shapes); + + if constexpr (IsGroupGemm) { + arguments = + { + cutlass::gemm::GemmUniversalMode::kGrouped, + problem_shapes, + mainloop_args, + collective_epilogue.to_args(problem_shapes), + hw_info + }; + } + else { + arguments = + { + cutlass::gemm::GemmUniversalMode::kArray, + problem_shapes, + mainloop_args, + collective_epilogue.to_args(problem_shapes), + hw_info + }; + } + + + Gemm gemm_op; + + size_t workspace_size = Gemm::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = gemm_op.can_implement(arguments); + + if (status != cutlass::Status::kSuccess) { + cudaError_t error = cudaGetLastError(); + std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n"; + return false; + } + + // + // Run the GEMM + // + + cudaError_t result; + status = gemm_op.initialize(arguments, workspace.get()); + status = gemm_op.run(); + result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + EXPECT_EQ(result, cudaSuccess) << "Error at Kernel Sync."; + return false; + } + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Verify + // + bool passed = this->verify(problem_shapes, alpha, beta); + if (!passed) { + std::cout << "Error : Failed : with alpha: " << alpha << ", beta: " << beta + << "\n"; + } + + return passed; + } +}; + +} // namespace detail + +///////////////////////////////////////////////////////////////////////////////////////////////// + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Gemm, + template class ActivationFunctor = cutlass::epilogue::thread::Identity, + bool force_legacy_epilogue = false, + typename ElementA = typename Gemm::GemmKernel::ElementA, + typename ElementB = typename Gemm::GemmKernel::ElementB +> +struct Testbed3x { + + using TestBedImpl = typename detail::TestbedImpl< + Gemm, + ActivationFunctor, + force_legacy_epilogue, + ElementA, + ElementB + >; + using Kernel = typename Gemm::GemmKernel; + using Epilogue = typename Gemm::GemmKernel::CollectiveEpilogue; + + using ElementAccumulator = typename TestBedImpl::ElementAccumulator; + using ElementCompute = typename TestBedImpl::ElementCompute; + using ElementScalar = typename TestBedImpl::ElementScalar; + + using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::RasterOrderOptions; + using DecompositionMode = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90StreamKParams::DecompositionMode; + + static constexpr bool IsGroupGemm = TestBedImpl::IsGroupGemm; + + // Detail Implementation + TestBedImpl impl_; + + // + // Methods + // + Testbed3x( + CheckEquality check_relative_equality_ = CheckEquality::EXACT, + ScalarLoc use_device_scalars_ = ScalarLoc::ON_DEVICE, + VectorScale vector_scale_mode_ = VectorScale::DISABLED, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + 
cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_scale_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_bias_ = cutlass::Distribution::Uniform, + uint64_t seed_ = TestBedImpl::kDefaultSeed) + : impl_(check_relative_equality_, use_device_scalars_, vector_scale_mode_, init_A_, init_B_, init_C_, init_scale_, init_bias_, seed_) {} + + /// Executes one test + bool run( + typename TestBedImpl::ProblemShapeType problem_shapes, + ElementScalar alpha = ElementScalar(1), + ElementScalar beta = ElementScalar(0), + detail::Iterations iterations = detail::Iterations{} + ) + { + return impl_.run( + problem_shapes, alpha, beta, iterations); + } +}; + +template < + typename Gemm, + template class ActivationFunctor = cutlass::epilogue::thread::Identity +> +bool TestAll(double alpha = 1.0, double beta = 0.0, CheckEquality check_relative_equality = CheckEquality::RELATIVE) { + using ElementScalar = typename Gemm::EpilogueOutputOp::ElementScalar; + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + + Testbed3x testbed(check_relative_equality, ScalarLoc::ON_DEVICE, VectorScale::DISABLED); + + int max_alignment = std::max(Gemm::kAlignmentA, Gemm::kAlignmentB); + std::vector problem_size_m = {max_alignment, 512 - 3 * max_alignment}; + std::vector problem_size_n = {max_alignment, 512 - 2 * max_alignment}; + + constexpr int Stages = Gemm::GemmKernel::DispatchPolicy::Stages; + constexpr int TileShapeK = cute::size<2>(typename Gemm::GemmKernel::TileShape{}); + + std::vector problem_size_k = {max_alignment, TileShapeK * (Stages + 1) - max_alignment}; + + int batches[] = {5, 10}; + + bool passed = true; + + for (int batch : batches) { + for (int m : problem_size_m) { + for (int n : problem_size_n) { + for (int k : problem_size_k) { + + if constexpr (Testbed3x::IsGroupGemm) { + std::vector problem_sizes_host; + cutlass::DeviceAllocation problem_sizes_device; + + for (int i = 0; i < batch; ++i) { + problem_sizes_host.push_back({m * ((i % 3) + 1), n * ((i % 4) + 1), k * ((i % 5) + 1)}); + } + + problem_sizes_device.reset(problem_sizes_host.size()); + problem_sizes_device.copy_from_host(problem_sizes_host.data()); + + passed = testbed.run( + ProblemShapeType{static_cast(problem_sizes_host.size()), problem_sizes_device.get(), problem_sizes_host.data()}, + cutlass::from_real(alpha), + cutlass::from_real(beta) + ); + } + else { + ProblemShapeType problem_size{{m, n, k, batch}}; + + passed = testbed.run( + problem_size, + cutlass::from_real(alpha), + cutlass::from_real(beta) + ); + } + + if (!passed) { + std::cout << __FILE__ << ':' << __LINE__ << " : GEMM MNKL " << m << " " << n << " " << k << " " << batch << " FAILED.\n"; + return false; + } + } // k + } // n + } // m + } // batch + + return passed; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +bool TestSmall(double alpha = 1.0, double beta = 1.0, + CheckEquality check_relative_equality = CheckEquality::RELATIVE, + ScalarLoc use_device_scalars = ScalarLoc::ON_DEVICE, + VectorScale vector_scale_mode = VectorScale::ENABLED, + std::vector override_problem_size_k = {}) { + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + using ElementScalar = typename Gemm::EpilogueOutputOp::ElementScalar; + using ElementA = typename Gemm::GemmKernel::ElementA; + using ElementB = typename Gemm::GemmKernel::ElementB; + using TiledMma = typename 
Gemm::GemmKernel::TiledMma; + + static constexpr bool IsF8F6F4 = cutlass::gemm::collective::detail::is_sm100_mma_f8f6f4(); + // For fp4 and fp6 kernels, the min alignment_input is 128 elements, so we don't need to add alignment_input in test problem sizes. + int alignment_bits_a = cutlass::detail::get_input_alignment_bits(); + int alignment_input_a = (alignment_bits_a / cute::sizeof_bits::value == 128) ? 0 : (alignment_bits_a / cute::sizeof_bits::value); + + int alignment_bits_b = cutlass::detail::get_input_alignment_bits(); + int alignment_input_b = (alignment_bits_b / cute::sizeof_bits::value == 128) ? 0 : (alignment_bits_b / cute::sizeof_bits::value); + + int alignment_input = (alignment_input_a == 0 || alignment_input_b == 0) ? 0 : std::max(alignment_input_a, alignment_input_b); + + if constexpr (apply_alignment_offset) { + // If BlockScaled, then min alignment is SFVecSize + static constexpr bool IsBlockScaleSupported = Gemm::EpilogueOutputOp::IsBlockScaleSupported; + static constexpr int SFVecSize = Gemm::GemmKernel::CollectiveMainloop::SFVecSize; + if constexpr (IsBlockScaleSupported) { + alignment_input = cutlass::round_up(alignment_input, SFVecSize); + } + } + + + using CtaShape_MNK = typename Gemm::GemmKernel::CollectiveMainloop::CtaShape_MNK; + using DispatchPolicy = typename Gemm::GemmKernel::CollectiveMainloop::DispatchPolicy; + CtaShape_MNK cta_shape; + Testbed3x testbed(check_relative_equality, use_device_scalars, vector_scale_mode); + // For Ptr-Array and Grouped GEMM ideally we need to know SM count at runtime + static constexpr int SmCount = 16; + + float waves[] = {0.5, 2.5}; + int batches[] = {3}; + int cluster_m = 1; + int cluster_n = 1; + + std::vector problem_size_k; + if (override_problem_size_k.empty()) { + // this is to test with min alignment + problem_size_k = {256 - alignment_input, 512 + alignment_input}; + } + else { + problem_size_k = override_problem_size_k; + } + + if constexpr(DispatchPolicy::ArchTag::kMinComputeCapability >= 90) { + typename DispatchPolicy::ClusterShape cluster_shape; + cluster_m = cute::size<0>(cluster_shape); + cluster_n = cute::size<1>(cluster_shape); + } + + bool passed = true; + + for (int batch : batches) { + for (float wave : waves) { + for (int k : problem_size_k) { + int grid_m, grid_n = 0; + float num_grid = wave * SmCount; + + if (cluster_m >= cluster_n) { + grid_m = cluster_m; + grid_n = static_cast(num_grid) / grid_m; + // Align grid_n to cluster_n + grid_n = std::max((grid_n + cluster_n - 1 ) / cluster_n * cluster_n, 1); + } + else { + grid_n = cluster_n; + grid_m = static_cast(num_grid) / grid_n; + // Align grid_m to cluster_m + grid_m = std::max((grid_m + cluster_m - 1 ) / cluster_m * cluster_m, 1); + } + + int m = grid_m * cute::size<0>(cta_shape) - alignment_input; // this is just to test with unusual problem shapes + int n = grid_n * cute::size<1>(cta_shape) + alignment_input; + + if constexpr (Testbed3x::IsGroupGemm) { + std::vector problem_sizes_host; + cutlass::DeviceAllocation problem_sizes_device; + for (int i = 0; i < batch; ++i) { + problem_sizes_host.push_back({m * ((i % 2) + 1), n * ((i % 3) + 1), k * ((i % 2) + 1)}); + } + problem_sizes_device.reset(problem_sizes_host.size()); + problem_sizes_device.copy_from_host(problem_sizes_host.data()); + + ProblemShapeType problem_shapes{batch, problem_sizes_device.get(), problem_sizes_host.data()}; + + if (CUTLASS_DEBUG_TRACE_LEVEL > 0) { + for (int i = 0; i < batch; ++i) { + std::cout << "problem_shapes : " << problem_shapes.get_host_problem_shape(i) << " \n"; + } + } 
+ passed = testbed.run( + problem_shapes, + cutlass::from_real(alpha), + cutlass::from_real(beta) + ); + } + else { + ProblemShapeType problem_shapes{{m, n, k, batch}}; + if (CUTLASS_DEBUG_TRACE_LEVEL > 0) { + std::cout << "problem_shapes : " << problem_shapes.get_host_problem_shape() << " \n"; + } + passed = testbed.run( + problem_shapes, + cutlass::from_real(alpha), + cutlass::from_real(beta) + ); + } + + if (!passed) { + std::cout << __FILE__ << ':' << __LINE__ << " : GEMM MNK " << m << " " << n << " " << k << " FAILED.\n"; + return false; + } + } // k + } // waves + } // batches + + return passed; +} + +template +bool TestSmallFusion(double alpha = 1.0, double beta = 0.0, + CheckEquality check_relative_equality = CheckEquality::RELATIVE, + ScalarLoc use_device_scalars = ScalarLoc::ON_DEVICE, + VectorScale vector_scale_mode = VectorScale::ENABLED) { + return TestSmall( + alpha, beta, check_relative_equality, use_device_scalars, vector_scale_mode); +} + +} // namespace device +} // namespace gemm +} // namespace test + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/gemm_testbed_3x_tensor_broadcast.hpp b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/gemm_testbed_3x_tensor_broadcast.hpp new file mode 100644 index 0000000000000000000000000000000000000000..8b00f98a97846de175f1c6f95919c483ab4b81da --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/gemm_testbed_3x_tensor_broadcast.hpp @@ -0,0 +1,515 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Tests for device-wide GEMM interface with elementwise tensor-tensor broadcast epilogue +*/ + +#pragma once + +#include +#include +#include + +#include "../../common/cutlass_unit_test.h" + +#include "testbed_utils.h" +#include "gemm_testbed_3x.hpp" + +namespace test { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Testbed3xTensorBroadcast { + + using TestBedImpl = typename detail::TestbedImpl; + using Kernel = typename Gemm::GemmKernel; + using Epilogue = typename Gemm::GemmKernel::CollectiveEpilogue; + + using ElementA = typename Kernel::ElementA; + using StrideA = typename Kernel::StrideA; + using ElementB = typename Kernel::ElementB; + using StrideB = typename Kernel::StrideB; + using ElementC = typename Kernel::ElementC; + using StrideC = typename Kernel::StrideC; + using ElementD = typename Kernel::ElementD; + using StrideD = typename Kernel::StrideD; + + using ElementAccumulator = typename Kernel::ElementAccumulator; + using ElementCompute = typename Epilogue::ElementCompute; + using ElementScalar = typename Epilogue::ElementScalar; + using ProblemShapeType = typename Kernel::ProblemShape; + using ElementBias = typename Epilogue::ElementBias; + using ActivationFunctor = typename Epilogue::ActivationFunctor; + + static constexpr bool IsBinaryOp0Enabled = Epilogue::IsBinaryOp0Enabled; + static constexpr bool IsBinaryOp1Enabled = Epilogue::IsBinaryOp1Enabled; + static constexpr bool IsUnaryOpEnabled = Epilogue::IsUnaryOpEnabled; + + static constexpr bool PerColBias = Epilogue::PerColumnBias; + + using LayoutTagA = typename TestBedImpl::LayoutTagA; + using LayoutTagB = typename TestBedImpl::LayoutTagB; + using LayoutTagC = typename TestBedImpl::LayoutTagC; + using LayoutTagD = typename TestBedImpl::LayoutTagD; + using LayoutTagVector = cutlass::layout::PackedVectorLayout; + + cutlass::HostTensor bias; + cutlass::HostTensor tensor_C1; + // tensor_C0 is taken from TestbedImpl's tensor_C + + + // Detail Implementation + TestBedImpl impl_; + + // + // Methods + // + Testbed3xTensorBroadcast( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = TestBedImpl::kDefaultSeed + ) : + impl_(CheckEquality::EXACT, ScalarLoc::ON_DEVICE, VectorScale::ENABLED, + init_A_, init_B_, init_C_, cutlass::Distribution::Uniform, cutlass::Distribution::Uniform, seed_) { } + + Testbed3xTensorBroadcast( + typename LayoutTagA::Stride stride_factor_A_, + typename LayoutTagB::Stride stride_factor_B_, + typename LayoutTagC::Stride stride_factor_C_, + typename LayoutTagD::Stride stride_factor_D_, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = TestBedImpl::kDefaultSeed + ) : + impl_(stride_factor_A_, + stride_factor_B_, + stride_factor_C_, + stride_factor_D_, + CheckEquality::EXACT, ScalarLoc::ON_HOST, VectorScale::ENABLED, + init_A_, + init_B_, + init_C_, + cutlass::Distribution::Uniform, + cutlass::Distribution::Uniform, + seed_) { } + + /// Initializes data structures + void initialize(ProblemShapeType problem_size) { + // + // Allocate the GEMM workspace for A/B/C/D tensor + // + impl_.initialize(problem_size); + } + + void 
initialize_bias(ProblemShapeType problem_size) { + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto bias_size = PerColBias ? cute::get<1>(problem_shape_MNKL) : cute::get<0>(problem_shape_MNKL); + bias.resize(cutlass::Coord<1>(bias_size)); + + EXPECT_TRUE(detail::initialize_tensor(bias.host_view(), cutlass::Distribution::Uniform, impl_.collective_mma_inputs.seed + 2023)); + bias.sync_device(); + } + + void initialize_c1(ProblemShapeType problem_size) { + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto M = cute::get<0>(problem_shape_MNKL); + auto N = cute::get<1>(problem_shape_MNKL); + auto L = cute::get<3>(problem_shape_MNKL); + + auto c_coord = cutlass::make_Coord(M * L, N); + + tensor_C1.resize(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, impl_.collective_epilogue.stride_factor_C)); + EXPECT_TRUE(detail::initialize_tensor(tensor_C1.host_view(), cutlass::Distribution::Uniform, impl_.collective_mma_inputs.seed + 2024)); + tensor_C1.sync_device(); + } + + /// Compares computed reference with device reference and outputs to a file if incorrect + bool compare_reference( + cute::Shape problem_shape_MNKL, + ElementScalar alpha, + ElementScalar beta, + bool use_bias) + { + auto [M, N, K, L] = problem_shape_MNKL; + + impl_.collective_epilogue.tensor_D.sync_host(); + EXPECT_GT(cutlass::reference::host::TensorNorm(impl_.collective_mma_inputs.tensor_A.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(impl_.collective_mma_inputs.tensor_B.host_view()), 0); + + if (impl_.collective_epilogue.tensor_D.size() > 1) { + EXPECT_GT(cutlass::reference::host::TensorNorm(impl_.collective_epilogue.tensor_D.host_view()), 0); + } + + if (impl_.collective_epilogue.reference_D.size() > 1) { + EXPECT_GT(cutlass::reference::host::TensorNorm(impl_.collective_epilogue.reference_D.host_view()), 0); + } + + bool passed = cutlass::reference::host::TensorEquals(impl_.collective_epilogue.reference_D.host_view(), impl_.collective_epilogue.tensor_D.host_view()); + + EXPECT_TRUE(passed); + + if (!passed) { + std::stringstream fname; + fname << "error_Gemm_device_broadcast" + << M << "x" << N << "x" << K << "x" << L << "_" + << cute::get<0>(typename Gemm::GemmKernel::TileShape{}) << "_" + << cute::get<1>(typename Gemm::GemmKernel::TileShape{}) << "_" + << cute::get<2>(typename Gemm::GemmKernel::TileShape{}) << ".txt"; + + std::ofstream file(fname.str()); + file + << "problem: " << ' ' << M << "x" << N << "x" << K << ", Batch count = " << L + << ", alpha: " << float(alpha) << ", beta: " << float(beta) << ", use_bias: " << use_bias + << ", per-col bias: " << PerColBias << "\n\n"; + + if (use_bias){ + file << "Bias = \n" << bias.host_view()<< "\n\n"; + } + + file + << "A =\n" << impl_.collective_mma_inputs.tensor_A.host_view() + << "\nB =\n" << impl_.collective_mma_inputs.tensor_B.host_view() + << "\nC0 =\n" << impl_.collective_epilogue.tensor_C.host_view() + << "\nC1 =\n" << tensor_C1.host_view() + << "\n\nReference =\n" << impl_.collective_epilogue.reference_D.host_view() + << "\n\nComputed =\n" <(problem_size, 1); + auto M = cute::get<0>(problem_shape_MNKL); + auto N = cute::get<1>(problem_shape_MNKL); + auto K = cute::get<2>(problem_shape_MNKL); + auto L = cute::get<3>(problem_shape_MNKL); + + auto A = cute::make_tensor(impl_.collective_mma_inputs.tensor_A.host_data(), + cute::make_layout(cute::make_shape(M, K, L), impl_.collective_mma_inputs.stride_a)); + auto B = cute::make_tensor(impl_.collective_mma_inputs.tensor_B.host_data(), + 
cute::make_layout(cute::make_shape(N, K, L), impl_.collective_mma_inputs.stride_b)); + auto D = cute::make_tensor(impl_.collective_epilogue.reference_D.host_data(), + cute::make_layout(cute::make_shape(M, N, L), impl_.collective_epilogue.stride_d)); + auto Bias = cute::make_tensor(static_cast(use_bias ? bias.host_data() : nullptr), + cute::make_layout(PerColBias ? cute::make_shape(1, N) : cute::make_shape(M, 1))); + auto C0 = cute::make_tensor(impl_.collective_epilogue.tensor_C.host_data(), + cute::make_layout(cute::make_shape(M, N, L), impl_.collective_epilogue.stride_c)); + auto C1 = cute::make_tensor(tensor_C1.host_data(), + cute::make_layout(cute::make_shape(M, N, L), impl_.collective_epilogue.stride_c)); + + // Create host workspace for output of testbed. This computes a portion of the epilogue: + // ref_compute_out = Activation(alpha * (A @ B) + bias) + cutlass::HostTensor ref_compute_out; + auto c_coord = cutlass::make_Coord(M * L, N); + ref_compute_out.resize(c_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(c_coord, impl_.collective_epilogue.stride_factor_C), false); + auto RefComputeOut = cute::make_tensor(ref_compute_out.host_data(), + cute::make_layout(cute::make_shape(M, N, L), impl_.collective_epilogue.stride_c)); + + cutlass::reference::host::GettMainloopParams mainloop_params{A, B}; + + // Use a dummy null tensor for operand C because the epilogue overrides C. + auto dummy_C = cute::make_tensor(static_cast(nullptr), + cute::make_layout(cute::make_shape(M, N, L), impl_.collective_epilogue.stride_c)); + ElementCompute dummy_beta(0); + auto dummy_Aux = cute::make_tensor(static_cast(nullptr), + cute::make_layout(cute::make_shape(M, N, L), impl_.collective_epilogue.stride_d)); + auto dummy_Valpha = cute::make_tensor(static_cast(nullptr), + cute::make_layout(cute::make_shape(M, N, 1), cute::make_stride(cute::_1{}, cute::_0{}, M))); + auto dummy_Vbeta = cute::make_tensor(static_cast(nullptr), + cute::make_layout(cute::make_shape(M, N, 1), cute::make_stride(cute::_1{}, cute::_0{}, M))); + + auto dummy_SFD = cute::make_tensor(static_cast(nullptr), + cute::make_layout(cute::make_shape(M, N, L), impl_.collective_epilogue.stride_c)); + using DummySFDVectorSize = cute::Int<0>; + + + cutlass::reference::host::GettEpilogueParams< + ElementScalar, + ElementScalar, + ElementAccumulator, + ElementCompute, + decltype(dummy_C), + decltype(RefComputeOut), + decltype(Bias), + decltype(dummy_Aux), + decltype(dummy_Valpha), + decltype(dummy_Vbeta), + ActivationFunctor, + decltype(dummy_SFD), + DummySFDVectorSize, + cutlass::plus, + PerColBias> epilogue_params{ + alpha, + dummy_beta, + dummy_C, + RefComputeOut, + Bias, + dummy_Aux, + dummy_Valpha, + dummy_Vbeta + }; + + cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params); + + cutlass::NumericConverter source_converter; + cutlass::NumericConverter destination_converter; + cutlass::multiplies mul; + + // Compute broadcast operations atop the reference + #pragma omp parallel for collapse(3) + for (int64_t l = 0; l < cute::size<2>(A.layout()); ++l) { + for (int64_t m = 0; m < cute::size<0>(A.layout()); ++m) { + for (int64_t n = 0; n < cute::size<0>(B.layout()); ++n) { + ElementCompute intermediate = RefComputeOut(m, n, l); + // Apply BinaryOp0, if needed + if constexpr (IsBinaryOp0Enabled) { + typename Epilogue::ThreadEpilogueOp::BinaryOp0 bin0; + ElementCompute converted_source = source_converter(C0(m, n, l)); + intermediate = bin0(intermediate, mul(beta, converted_source)); + } + + // Apply BinaryOp1, if needed + if 
constexpr (IsBinaryOp1Enabled) { + typename Epilogue::ThreadEpilogueOp::BinaryOp1 bin1; + ElementCompute converted_source = source_converter(C1(m, n, l)); + intermediate = bin1(intermediate, mul(beta, converted_source)); + } + + // Apply UnaryOp, if needed + if constexpr (IsUnaryOpEnabled) { + typename Epilogue::ThreadEpilogueOp::UnaryOp unary; + intermediate = unary(intermediate); + } + + D(m, n, l) = destination_converter(intermediate); + } + } + } + + return compare_reference(problem_shape_MNKL, alpha, beta, use_bias); + } + + /// Executes one test + bool run( + ProblemShapeType problem_size, + ElementScalar alpha = ElementScalar(1), + ElementScalar beta = ElementScalar(0), + bool profiling = false, + int iterations = 20, + bool use_bias = true) + { + // Fail test if insufficient CUDA device + if (!impl_.sufficient()) { + std::cout << "Test failed due to insufficient CUDA device." << std::endl; + return false; + } + // + // Initialize the GEMM operator + // + + typename Gemm::Arguments arguments; + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = 0; + if (not profiling) { + impl_.sm_count = std::min(impl_.MaxSmCount, cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id)); + hw_info.sm_count = impl_.sm_count; + } + else { + impl_.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + hw_info.sm_count = impl_.sm_count; + } + + /// Initializes data structures + /// A/B/C0/D Tensor + initialize(problem_size); + initialize_bias(problem_size); + + if constexpr (IsBinaryOp1Enabled) { + initialize_c1(problem_size); + } + + arguments = typename Gemm::Arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + { impl_.collective_mma_inputs.tensor_A.device_data(), impl_.collective_mma_inputs.stride_a, + impl_.collective_mma_inputs.tensor_B.device_data(), impl_.collective_mma_inputs.stride_b, + impl_.mma_promotion_interval + }, + { // Epilogue arguments + { alpha, beta }, // ThreadOp arguments + impl_.collective_epilogue.stride_c, + impl_.collective_epilogue.tensor_D.device_data(), + impl_.collective_epilogue.stride_d, + use_bias ? 
bias.device_data() : nullptr, + impl_.collective_epilogue.tensor_C.device_data(), + tensor_C1.device_data() + }, // Epilogue arguments end + hw_info + }; + + Gemm gemm_op; + + size_t workspace_size = Gemm::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = gemm_op.can_implement(arguments); + + if (status != cutlass::Status::kSuccess) { + cudaError_t error = cudaGetLastError(); + std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n"; + return true; + } + + // + // Run the GEMM + // + + if (profiling) { + return impl_.profile(problem_size, iterations, gemm_op, arguments, workspace); + } + else { + cudaError_t result; + status = gemm_op.initialize(arguments, workspace.get()); + status = gemm_op.run(); + result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + EXPECT_EQ(result, cudaSuccess) << "Error at Kernel Sync."; + return false; + } + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Verify + // + bool passed = this->verify(problem_size, alpha, beta, use_bias); + if (!passed) { + std::cout << "Error : Failed : with alpha: " << float(alpha) + << ", beta: " << float(beta) + << ", use_bias: " << use_bias + << "\n"; + } + + return passed; + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +bool TestAllTensorBroadcast(bool use_bias=true) { + using ElementScalar = typename Gemm::GemmKernel::CollectiveEpilogue::ElementScalar; + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + + int max_alignment = std::max(Gemm::kAlignmentA, Gemm::kAlignmentB); + std::vector problem_size_m = {max_alignment, 512 - 3 * max_alignment}; + std::vector problem_size_n = {max_alignment, 512 - 2 * max_alignment}; + + if constexpr (cute::is_same_v) { + problem_size_m.push_back(768); + problem_size_n.push_back(768); + } + + constexpr int Stages = Gemm::GemmKernel::DispatchPolicy::Stages; + constexpr int TileShapeK = cute::size<2>(typename Gemm::GemmKernel::TileShape{}); + + std::vector problem_size_k = {max_alignment, TileShapeK * (Stages + 1) - max_alignment}; + + Testbed3xTensorBroadcast testbed; + bool passed = true; + + for (int m : problem_size_m) { + for (int n : problem_size_n) { + for (int k : problem_size_k) { + ProblemShapeType problem_size; + if constexpr (cute::rank(ProblemShapeType{}) == 4) { + problem_size = ProblemShapeType{m, n, k, /* l */ 1}; + } + else { + problem_size = ProblemShapeType{m, n, k}; + } + + for (bool use_bias : {true, false}) { + passed = testbed.run( + problem_size, + cutlass::from_real(1), + cutlass::from_real(1), + false, // profiling + 20, // iterations + use_bias + ); + + if (!passed) { + return false; + } + } + } + } + } + + if constexpr (cute::rank(ProblemShapeType{}) == 4) { + auto problem_size = ProblemShapeType{256 + max_alignment, 256 + max_alignment, 160 + max_alignment, /* l */ 3}; + passed = testbed.run( + problem_size, + cutlass::from_real(1), + cutlass::from_real(1), + false, // profiling + 20 // iterations + ); + if (!passed) { + return false; + } + } + return passed; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace test + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git 
a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/multistage_testbed.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/multistage_testbed.h new file mode 100644 index 0000000000000000000000000000000000000000..6ae7b864cb272782da4920ffc038830d3b5984b2 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/multistage_testbed.h @@ -0,0 +1,300 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Tests for device-wide GEMM interface +*/ + +#pragma once + +#include +#include +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/tensor_view_io.h" + +#include "testbed_utils.h" + +namespace test { +namespace gemm { +namespace device { + +//////////////////////////////////////////////////////////////////////////////// + +template +struct MultistageTestbed { + + using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; + using ElementC = typename Gemm::ElementC; + + using ElementAccumulator = typename Gemm::ElementAccumulator; + using ElementCompute = + typename Gemm::GemmKernel::Epilogue::OutputOp::ElementCompute; + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + uint64_t seed; + + // + // Methods + // + + MultistageTestbed( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080) + : init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) {} + + /// Helper to initialize a tensor view + template + bool initialize_tensor(cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, uint64_t seed) { + if (dist_kind == cutlass::Distribution::Uniform) { + int scope = (cutlass::sizeof_bits::value == 8) ? 
2 : 8; + cutlass::reference::host::TensorFillRandomUniform(view, seed, scope, + -scope, 0); + } else if (dist_kind == cutlass::Distribution::Gaussian) { + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5, -1); + } else if (dist_kind == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(view); + } else if (dist_kind == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(view.data(), + view.capacity()); + } else { + EXPECT_TRUE(false) << "Not implemented"; + return false; + } + + return true; + } + + /// Waives test if CUDA device is insufficient + bool sufficient() const { + // + // Determine SMEM requirements and waive if not satisfied + // + + size_t smem_size = sizeof(typename Gemm::GemmKernel::SharedStorage); + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerBlockOptin < smem_size) { + return false; + } + + return true; + } + + /// Executes one test + bool run(cutlass::gemm::GemmCoord problem_size, + ElementCompute alpha = ElementCompute(1), + ElementCompute beta = ElementCompute(0)) { + + // Waives test if CUDA device is insufficient + if (!sufficient()) { + return true; + } + + // + // Allocate the GEMM workspace + // + + cutlass::HostTensor + tensor_A(problem_size.mk()); + + cutlass::HostTensor + tensor_B(problem_size.kn()); + + cutlass::HostTensor + tensor_C(problem_size.mn()); + + cutlass::HostTensor + tensor_D(problem_size.mn()); + + cutlass::HostTensor + reference_D(problem_size.mn(), false); + + EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2019)); + EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2018)); + EXPECT_TRUE(initialize_tensor(tensor_C.host_view(), init_C, seed + 2017)); + + cutlass::reference::host::TensorCopy(reference_D.host_view(), + tensor_C.host_view()); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_D.sync_device(); + + // + // Initialize the GEMM operator + // + + typename Gemm::Arguments arguments{ + problem_size, tensor_A.device_ref(), tensor_B.device_ref(), + tensor_C.device_ref(), tensor_D.device_ref(), {alpha, beta}}; + + Gemm gemm_op; + + cutlass::Status status = gemm_op.initialize(arguments); + + if (status != cutlass::Status::kSuccess) { + cudaError_t error = cudaGetLastError(); + std::cerr << "This test is not supported: " << cudaGetErrorString(error) << "\n"; + return true; + } + + // + // Run the GEMM + // + + status = gemm_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + + // + // Verify + // + + cutlass::reference::host::Gemm< + typename Gemm::ElementA, typename Gemm::LayoutA, + typename Gemm::ElementB, typename Gemm::LayoutB, + typename Gemm::ElementC, typename Gemm::LayoutC, ElementCompute, + ElementAccumulator, typename Gemm::Operator> + reference_gemm; + + reference_gemm( + problem_size, alpha, tensor_A.host_ref(), tensor_B.host_ref(), beta, + reference_D.host_ref(), ElementAccumulator(0)); + + tensor_D.sync_host(); + + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); + + bool passed = cutlass::reference::host::TensorEquals( + 
reference_D.host_view(), tensor_D.host_view()); + + EXPECT_TRUE(passed); + if (!passed) { + std::stringstream fname; + + fname << "error_Gemm_device_" << problem_size.m() << "x" + << problem_size.n() << "x" << problem_size.k() << "_" + << Gemm::ThreadblockShape::kM << "x" << Gemm::ThreadblockShape::kN + << "x" << Gemm::ThreadblockShape::kK << "_" << Gemm::WarpShape::kM + << "x" << Gemm::WarpShape::kN << "x" << Gemm::WarpShape::kK + << ".txt"; + + std::ofstream file(fname.str()); + + file << "problem: " << problem_size << ", alpha: " << alpha + << ", beta: " << beta << "\n\n"; + + file << "A =\n" + << tensor_A.host_view() << "\nB =\n" + << tensor_B.host_view() << "\nC =\n" + << tensor_C.host_view() << "\n\nReference =\n" + << reference_D.host_view() << "\nComputed =\n" + << tensor_D.host_view(); + } + + return passed; + } + + /// Runs a set of problem sizes + bool run_all() { + bool passed = true; + + int problem_size_m[] = {16, 528}; + + int problem_size_n[] = {16, 528}; + + int problem_size_k[] = {Gemm::InstructionShape::kK, + Gemm::ThreadblockShape::kK * Gemm::kStages + + Gemm::InstructionShape::kK}; + + double problem_alpha[] = {1.0}; + + // TODO Try non zero beta value after multistaged epilogue is implemented + double problem_beta[] = {0.0}; + + for (int m : problem_size_m) { + for (int n : problem_size_n) { + for (int k : problem_size_k) { + for (double alpha : problem_alpha) { + for (double beta : problem_beta) { + passed = + run({m, n, k}, ElementCompute(alpha), ElementCompute(beta)); + + if (!passed) { + return false; + } + } + } + } + } + } + + return true; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace test + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/multistage_testbed_interleaved.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/multistage_testbed_interleaved.h new file mode 100644 index 0000000000000000000000000000000000000000..e309208bb4311253be5b7366841164eb62748bab --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/multistage_testbed_interleaved.h @@ -0,0 +1,348 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#pragma once + +#include +#include +#include + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/host_reorder.h" + +namespace test { +namespace gemm { +namespace device { + +//////////////////////////////////////////////////////////////////////////////// + +template +struct MultistageInterleavedTestbed { + + using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; + using ElementC = typename Gemm::ElementC; + using ElementAccumulator = typename Gemm::ElementAccumulator; + using ElementCompute = typename Gemm::GemmKernel::Epilogue::OutputOp::ElementCompute; + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + uint64_t seed; + + // + // Methods + // + + MultistageInterleavedTestbed( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } + + /// Helper to initialize a tensor view + template + bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, 2, -2, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential( + view.data(), view.capacity()); + } + else { + EXPECT_TRUE(false) << "Not implemented"; + return false; + } + + return true; + } + + /// Returns true if the CUDA device is sufficient to execute the kernel. 
+ bool sufficient() const { + // + // Determine SMEM requirements and waive if not satisfied + // + + size_t smem_size = sizeof(typename Gemm::GemmKernel::SharedStorage); + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerMultiprocessor < smem_size) { + return false; + } + + return true; + } + + /// Executes one test + bool run( + cutlass::gemm::GemmCoord problem_size, + ElementCompute alpha = ElementCompute(1), + ElementCompute beta = ElementCompute(0)) { + + // Waive test if insufficient CUDA device + if (!sufficient()) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." << std::endl; + } + return true; + } + + // + // Allocate the GEMM workspace + // + + cutlass::HostTensor< + typename Gemm::ElementA, + typename Gemm::LayoutA> tensor_A(problem_size.mk()); + + cutlass::HostTensor< + typename Gemm::ElementB, + typename Gemm::LayoutB> tensor_B(problem_size.kn()); + + cutlass::HostTensor< + typename Gemm::ElementB, + typename Gemm::LayoutB> tensor_B_reordered(problem_size.kn()); + + cutlass::HostTensor< + typename Gemm::ElementC, + typename Gemm::LayoutC> tensor_C(problem_size.mn()); + + cutlass::HostTensor< + typename Gemm::ElementC, + typename Gemm::LayoutC> tensor_D(problem_size.mn()); + + cutlass::HostTensor< + typename Gemm::ElementC, + typename Gemm::LayoutC> reference_D(problem_size.mn(), false); + + EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2019)); + EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2018)); + EXPECT_TRUE(initialize_tensor(tensor_C.host_view(), init_C, seed + 2017)); + + cutlass::reorder_column( + tensor_B_reordered.host_ref(), tensor_B.host_ref(), problem_size); + + cutlass::reference::host::TensorCopy( + reference_D.host_view(), + tensor_C.host_view()); + + tensor_A.sync_device(); + tensor_B_reordered.sync_device(); + tensor_C.sync_device(); + tensor_D.sync_device(); + + // + // Initialize the GEMM operator + // + + typename Gemm::Arguments arguments{ + problem_size, + tensor_A.device_ref(), + tensor_B_reordered.device_ref(), + tensor_C.device_ref(), + tensor_D.device_ref(), + {alpha, beta} + }; + + Gemm gemm_op; + + cutlass::Status status = gemm_op.initialize(arguments); + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + + // + // Run the GEMM + // + + status = gemm_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + + // + // Verify + // + + cutlass::reference::host::Gemm< + typename Gemm::ElementA, typename Gemm::LayoutA, + typename Gemm::ElementB, typename Gemm::LayoutB, + typename Gemm::ElementC, typename Gemm::LayoutC, ElementCompute, + ElementAccumulator, typename Gemm::Operator> + reference_gemm; + + reference_gemm( + problem_size, + alpha, + tensor_A.host_ref(), + tensor_B.host_ref(), + beta, + reference_D.host_ref(), + ElementAccumulator(0) + ); + + tensor_D.sync_host(); + + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); + + bool passed = cutlass::reference::host::TensorEquals( + reference_D.host_view(), + tensor_D.host_view()); + + EXPECT_TRUE(passed); + if (!passed) { + + std::stringstream fname; + + 
fname << "error_Gemm_device_" + << problem_size.m() << "x" + << problem_size.n() << "x" + << problem_size.k() << "_" + << Gemm::ThreadblockShape::kM << "x" + << Gemm::ThreadblockShape::kN << "x" + << Gemm::ThreadblockShape::kK << "_" + << Gemm::WarpShape::kM << "x" + << Gemm::WarpShape::kN << "x" + << Gemm::WarpShape::kK << ".txt"; + + std::ofstream file(fname.str()); + + file + << "problem: " << problem_size + << ", alpha: " << alpha << ", beta: " << beta << "\n\n"; + + file + << "A =\n" << tensor_A.host_view() + << "\nB =\n" << tensor_B.host_view() + << "\nB_reordered =\n" << tensor_B_reordered.host_view() + << "\nC =\n" << tensor_C.host_view() + << "\n\nReference =\n" << reference_D.host_view() + << "\nComputed =\n" << tensor_D.host_view(); + } + + return passed; + } + + /// Runs a set of problem sizes + bool run_all() { + bool passed = true; + + int problem_size_m[] = { + InterleavedK, 512 + InterleavedK + }; + + int problem_size_n[] = { + InterleavedK, 512 + InterleavedK + }; + + int problem_size_k[] = { + InterleavedK, Gemm::ThreadblockShape::kK * Gemm::kStages + InterleavedK + }; + + double problem_alpha[] = { + 1.0 + }; + + double problem_beta[] = { + 0.0 + }; + + for (int m : problem_size_m) { + for (int n : problem_size_n) { + for (int k : problem_size_k) { + for (double alpha : problem_alpha) { + for (double beta : problem_beta) { + + passed = run( + {m, n, k}, + ElementCompute(alpha), + ElementCompute(beta) + ); + + if (!passed) { + return false; + } + } + } + } + } + } + + return true; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace test + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/simt_sm50.py b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/simt_sm50.py new file mode 100644 index 0000000000000000000000000000000000000000..a180028205abb689436c73403eea82758ade7da9 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/simt_sm50.py @@ -0,0 +1,341 @@ +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +# this file creates the test/unit/gemm/device simt tests + + +outputDir = "" + +################################################################################ +# parameters +# Edge - for tiles, the edges represent the length of one side +# Ratio - the maximum ratio between 2 edges, limits the skinnyness of tiles +# MaxEdge - maximum length of each edge +# Min/Max - minimum/maximum of the product of edge lengths +################################################################################ + +warpsPerThreadblockEdge = [1, 2, 4, 8, 16] +warpsPerThreadblockRatio = 2 +warpsPerThreadblockMax = 16 +# NOTE 1x32 and 2x16 warp tile shapes fail validation for ~10% of cases + +warpShapeEdges = [8, 16, 32, 64, 128, 256] +warpShapeRatio = 4 +warpShapeMax = 64*64 +warpShapeMin = 8*8 + +threadblockEdgeMax = 256 + +# char, type bits/elem, max tile, L0 threadblock tiles +precisions = [ + ["c", "cutlass::complex", 64, 64*128, [ [ 64, 128], [ 64, 32] ] ], + ["q", "cutlass::Quaternion", 64, 64*128, [ [ 64, 128], [ 64, 32] ] ], + ["d", "double", 64, 64*64, [ [ 64, 64], [ 32, 32] ] ], + ["h", "cutlass::half_t", 16, 128*256, [ [256, 128], [ 64, 128], [ 64, 32] ] ], + ["i", "int", 32, 128*128, [ [128, 64], [ 16, 32] ] ], + ["s", "float", 32, 128*128, [ [128, 256], [128, 128], [ 64, 64] ] ], + ["z", "cutlass::complex", 128, 64*64, [ [ 32, 64], [ 16, 32] ] ], + ] +# L1 will have a single kernel for every unique shape +# L2 will have everything else + +transposes = [ + [False, False], + [False, True], + [True, False], + [True, True] + ] + +################################################################################ +# warps per threadblock +################################################################################ +warpsPerThreadblocks = [] +for warpsPerThreadblock0 in warpsPerThreadblockEdge: + for warpsPerThreadblock1 in warpsPerThreadblockEdge: + if warpsPerThreadblock0 / warpsPerThreadblock1 <= warpsPerThreadblockRatio and warpsPerThreadblock1 / warpsPerThreadblock0 <= warpsPerThreadblockRatio and warpsPerThreadblock0 * warpsPerThreadblock1 <= warpsPerThreadblockMax: + warpsPerThreadblocks.append([warpsPerThreadblock0, + warpsPerThreadblock1]) +print("WarpsPerThreadblocks",warpsPerThreadblocks) + +################################################################################ +# warp shapes +################################################################################ +warpNumThreads = 32 +warpShapes = [] +for warp0 in warpShapeEdges: + for warp1 in warpShapeEdges: + if warp0 / warp1 <= warpShapeRatio and warp1 / warp0 <= warpShapeRatio and warp0*warp1 <= warpShapeMax and warp0*warp1 > warpShapeMin: + warpShapes.append([warp0, warp1]) +print("WarpShapes", warpShapes) + +numL0 = 0 +numL1 = 0 +numL2 = 0 + +################################################################################ +# create kernels +# create a file for each precision/transpose +# each file contains many tile sizes +################################################################################ + 
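# [Editorial sketch - not part of the vendored script above] The Edge/Ratio/
# Min/Max parameters documented earlier drive the two enumeration loops that
# build warpsPerThreadblocks and warpShapes. A minimal, self-contained helper
# (hypothetical name `admissible`) expressing that same filter:

def admissible(edge0, edge1, ratio, max_product, min_product=0):
    """True if a tile with sides (edge0, edge1) is not too skinny (ratio bound)
    and its area lies in (min_product, max_product]."""
    if edge0 / edge1 > ratio or edge1 / edge0 > ratio:
        return False
    return min_product < edge0 * edge1 <= max_product

# e.g. admissible(1, 2, warpsPerThreadblockRatio, warpsPerThreadblockMax) is True,
# while admissible(1, 4, warpsPerThreadblockRatio, warpsPerThreadblockMax) is False,
# matching the pairs printed for WarpsPerThreadblocks above.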
+# precisions +for precision in precisions: + + # get precision char + precisionChar = precision[0] + precisionType = precision[1] + precisionBits = precision[2] + threadblockMaxElements = precision[3] + threadblockTilesL0 = precision[4] + + # transposes + for transpose in transposes: + + # get transpose char + columnMajorA = transpose[0] + columnMajorB = transpose[1] + transCharA = "n" if columnMajorA else "t" + transCharB = "n" if columnMajorB else "t" + + # open file + fileName="simt_%sgemm_%s%s_sm50.cu" % (precisionChar, transCharA, transCharB) + print("\n", fileName) + filePath = "%s%s" % (outputDir, fileName) + out = open(filePath, "w+") + + # write file header + out.write("/***************************************************************************************************\n" +" * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. \n" +" * SPDX-License-Identifier: BSD-3-Clause \n" +" * \n" +" * Redistribution and use in source and binary forms, with or without \n" +" * modification, are permitted provided that the following conditions are met: \n" +" * \n" +" * 1. Redistributions of source code must retain the above copyright notice, this \n" +" * list of conditions and the following disclaimer. \n" +" * \n" +" * 2. Redistributions in binary form must reproduce the above copyright notice, \n" +" * this list of conditions and the following disclaimer in the documentation \n" +" * and/or other materials provided with the distribution. \n" +" * \n" +" * 3. Neither the name of the copyright holder nor the names of its \n" +" * contributors may be used to endorse or promote products derived from \n" +" * this software without specific prior written permission. \n" +" * \n" +" * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS \"AS IS\" \n" +" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE \n" +" * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE \n" +" * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE \n" +" * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL \n" +" * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR \n" +" * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER \n" +" * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, \n" +" * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE \n" +" * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. \n" +" *\n" +" **************************************************************************************************/\n" +"/*! 
\\file\n" +" \\brief Tests for device-wide GEMM interface\n" +"*/\n" +"\n" +"#include \n" +"\n" +"#include \"cutlass/cutlass.h\"\n" +"#include \"cutlass/gemm/device/gemm.h\"\n" +"#include \"cutlass/numeric_types.h\"\n" +"\n" +"#include \"../../common/cutlass_unit_test.h\"\n" +"\n" +"#include \"cutlass/util/host_tensor.h\"\n" +"#include \"cutlass/util/tensor_view_io.h\"\n" +"#include \"cutlass/util/reference/host/tensor_fill.h\"\n" +"#include \"cutlass/util/reference/host/tensor_copy.h\"\n" +"#include \"cutlass/util/reference/host/tensor_compare.h\"\n" +"#include \"cutlass/util/reference/host/gemm.h\"\n" +"\n" +"#include \"testbed.h\"\n" +"\n") + foundThreadblockTilesL0 = {} + foundThreadblockTilesL1 = {} + + ######################################################################## + # for each combination of tile sizes + ######################################################################## + for warpsPerThreadblock in warpsPerThreadblocks: + for warpShape in warpShapes: + warpThreadsM = 0 + if warpShape[0] > warpShape[1]: + warpThreadsM = 8 + else: + warpThreadsM = 4 + warpThreadsN = warpNumThreads / warpThreadsM + + # skip shapes with conflicting rectangularity + # they are unlikely to be fastest + blockG = warpsPerThreadblock[0] > warpsPerThreadblock[1] + blockL = warpsPerThreadblock[0] < warpsPerThreadblock[1] + warpG = warpShape[0] > warpShape[1] + warpL = warpShape[0] < warpShape[1] + + blockG2 = warpsPerThreadblock[0] > warpsPerThreadblock[1]*2 + blockL2 = warpsPerThreadblock[0]*2 < warpsPerThreadblock[1] + warpG2 = warpShape[0] > warpShape[1]*2 + warpL2 = warpShape[0]*2 < warpShape[1] + + if blockG2 and warpL: continue + if blockL2 and warpG: continue + if warpG2 and blockL: continue + if warpL2 and blockG: continue + + # check threadblock ratios and max + threadblockTile = [warpShape[0]*warpsPerThreadblock[0], + warpShape[1]*warpsPerThreadblock[1]] + if threadblockTile[0] * threadblockTile[1] > threadblockMaxElements: continue + if threadblockTile[0] > threadblockEdgeMax: continue + if threadblockTile[1] > threadblockEdgeMax: continue + totalThreads = warpNumThreads*warpsPerThreadblock[0]*warpsPerThreadblock[1] + + # calculate unroll + # ensure that every iteration at least a full load of A,B are done + unrollMin = 8 + unrollMin0 = totalThreads / threadblockTile[0] + unrollMin1 = totalThreads / threadblockTile[1] + unroll = max(unrollMin, unrollMin0, unrollMin1) + + threadTileM = warpShape[0] / warpThreadsM + threadTileN = warpShape[1] / warpThreadsN + if threadTileM < 2 or threadTileN < 2: continue + if threadTileM*threadTileN*precisionBits > 8*8*32: continue + + # epilogue currently only supports N < WarpNumThreads + if threadblockTile[1] < warpNumThreads: continue + + # limit smem + smemBitsA = threadblockTile[0]*unroll*2*precisionBits + smemBitsB = threadblockTile[1]*unroll*2*precisionBits + smemKBytes = (smemBitsA+smemBitsB)/8/1024 + if (smemKBytes > 48): continue + + # test level 0 + testLevel = -1 + for tileId in range(0, len(threadblockTilesL0)): + tbTile = threadblockTilesL0[tileId] + if tbTile[0] == threadblockTile[0] and tbTile[1] == threadblockTile[1]: + if tuple(tbTile) not in foundThreadblockTilesL0: + testLevel = 0 + numL0 += 1 + foundThreadblockTilesL0[tuple(tbTile)] = True + + # test level 1 + if testLevel < 0: + threadblockTileAlreadyUsed = False + if tuple(threadblockTile) not in foundThreadblockTilesL1: + testLevel = 1 + numL1 += 1 + foundThreadblockTilesL1[tuple(threadblockTile)] = True + + # test level 2 + if testLevel < 0: + testLevel = 2 + numL2 += 1 + + 
################################################################ + # write this tile to file + ################################################################ + + print("%ix%ix%i__%ix%i_%ix%i_%ix%i L%i" % ( + threadblockTile[0], threadblockTile[1], unroll, + threadTileM, threadTileN, + warpThreadsM, warpThreadsN, + warpsPerThreadblock[0], warpsPerThreadblock[1], testLevel)) + + out.write("////////////////////////////////////////////////////////////////////////////////\n" + "// Elements / Thread: %3i x %3i\n" + "// Threads / Warp: %3i x %3i\n" + "// Warps / Block: %3i x %3i\n" + "// Threadblock: %3i x %3i x %2i\n" + % ( threadTileM, threadTileN, + warpThreadsM, warpThreadsN, + warpsPerThreadblock[0], warpsPerThreadblock[1], + threadblockTile[0], threadblockTile[1], unroll + ) + ) + + out.write("CUTLASS_TEST_L%i(SM50_device_%sgemm_%s%s, %ix%ix%i_%ix%ix1_%ix%i_%ix%i_%ix%i, {\n" % ( + testLevel, + precisionChar, + transCharA, + transCharB, + threadblockTile[0], + threadblockTile[1], + unroll, + warpShape[0], + warpShape[1], + threadTileM, + threadTileN, + warpThreadsM, + warpThreadsN, + warpsPerThreadblock[0], + warpsPerThreadblock[1] + )) + out.write(" using precision = %s;\n" % precisionType) + out.write(" using ThreadblockShape = cutlass::gemm::GemmShape<%i, %i, %i>;\n" % ( + threadblockTile[0], + threadblockTile[1], + unroll)) + out.write(" using WarpShape = cutlass::gemm::GemmShape<%i, %i, %i>;\n\n" % ( + warpShape[0], + warpShape[1], + unroll)) + out.write(" static int const kEpilogueElementsPerAccess = 1;\n" + " using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;\n" + " using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination<\n" + " precision, kEpilogueElementsPerAccess, precision, precision>;\n\n") + + out.write(" using Gemm = cutlass::gemm::device::Gemm<\n" + " precision, cutlass::layout::%sMajor,\n" + " precision, cutlass::layout::%sMajor,\n" + " precision, cutlass::layout::RowMajor,\n" + " precision,\n" + " cutlass::arch::OpClassSimt,\n" + " cutlass::arch::Sm50,\n" + " ThreadblockShape, WarpShape, InstructionShape,\n" + " EpilogueOutputOp,\n" + " cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,\n" + " 2 // Stages\n" + " >;\n" % ( + "Column" if columnMajorA else "Row", + "Column" if columnMajorB else "Row", + )) + out.write(" EXPECT_TRUE(test::gemm::device::TestAllGemm());\n" + "} )\n\n") + + + out.close() +print("NumKernels:", numL0, numL1, numL2) + diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/sm90_evt_operations.hpp b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/sm90_evt_operations.hpp new file mode 100644 index 0000000000000000000000000000000000000000..63ffc3281dd2b9e9f74e0024c73da00628331dd4 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/sm90_evt_operations.hpp @@ -0,0 +1,545 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Host reference and operations for Sm90 EVT unit test +*/ +#pragma once +#include "gemm_testbed_3x_evt.hpp" + +////////////////////////////////////////////////////////////////////////////// +/// Host references used for testing +namespace test::gemm::device { +template +using HEVT = HostTreeVisitor; + +template +using HDAG = HostTopoVisitor; + +template +using HST = HostSplitTreeVisitor; + +/// D = alpha * acc + beta * C + AuxLoad +template +class HostEVTAuxLoad { +public: + using ElementC = typename Gemm::GemmKernel::ElementC; + using LayoutC = cutlass::detail::StrideToLayoutTagC_t; + using ElementD = typename Gemm::GemmKernel::ElementC; + using LayoutD = cutlass::detail::StrideToLayoutTagC_t; + + using ScalarAlpha = HostScalarBroadcast<1>; + using AccFetchNode = HostAccumulator<>; + using AuxLoadNode = HostAuxLoad; + using TernaryCompute0 = HEVT, ScalarAlpha, AccFetchNode, AuxLoadNode>; + using ScalarBeta = HostScalarBroadcast<1>; + using CLoadNode = HostAuxLoad; + using TernaryCompute1 = HEVT, ScalarBeta, CLoadNode, TernaryCompute0>; + using EVTModule = HEVT, TernaryCompute1>; +}; + +/// D = alpha * acc + beta * C + per-column bias +template +class HostPerColBias { +public: + using ElementC = typename Gemm::GemmKernel::ElementC; + using LayoutC = cutlass::detail::StrideToLayoutTagC_t; + using ElementD = typename Gemm::GemmKernel::ElementC; + using LayoutD = cutlass::detail::StrideToLayoutTagC_t; + + using ScalarAlpha = HostScalarBroadcast<1>; + using AccFetchNode = HostAccumulator<>; + using RowBroadcastNode = HostRowBroadcast; + using TernaryCompute0 = HEVT, ScalarAlpha, AccFetchNode, RowBroadcastNode>; + using ScalarBeta = HostScalarBroadcast<1>; + using CLoadNode = HostAuxLoad; + using TernaryCompute1 = HEVT, ScalarBeta, CLoadNode, TernaryCompute0>; + using EVTModule = HEVT, TernaryCompute1>; +}; + +/// D = beta * C + Graph(relu(alpha * acc + aux) + aux) +/// Testing EVT - DAG structure +template +class HostEVTDAG { +public: + using ElementC = typename Gemm::GemmKernel::ElementC; + using LayoutC = cutlass::detail::StrideToLayoutTagC_t; + using ElementD = typename Gemm::GemmKernel::ElementC; + using LayoutD = cutlass::detail::StrideToLayoutTagC_t; + + using 
ScalarAlpha = HostScalarBroadcast<1>; + using AccFetchNode = HostAccumulator<>; + using AuxLoadNode = HostAuxLoad; + using DAGNode = HDAG< + float, + cute::tuple< + cute::tuple<>, // 0. alpha + cute::tuple<>, // 1. acc + cute::tuple<>, // 2. aux load + cute::tuple, // 3. alpha * acc + aux load + cute::tuple, // relu(alpha * acc + aux load) + cute::tuple // relu(alpha * acc + aux load) + aux load + >, + ScalarAlpha, + AccFetchNode, + AuxLoadNode, + HostCompute, + HostCompute, + HostCompute + >; + using ScalarBeta = HostScalarBroadcast<1>; + using CLoadNode = HostAuxLoad; + using TernaryCompute1 = HEVT, ScalarBeta, CLoadNode, DAGNode>; + using EVTModule = HEVT, TernaryCompute1>; +}; + +/// EVT = alpha * acc + C +/// D = Graph(maximum(EVT + per-row bias, EVT)) +/// Testing DAG - EVT +template +class HostDAGEVT { +public: + using ElementC = typename Gemm::GemmKernel::ElementC; + using LayoutC = cutlass::detail::StrideToLayoutTagC_t; + using ElementD = typename Gemm::GemmKernel::ElementC; + using LayoutD = cutlass::detail::StrideToLayoutTagC_t; + + using EVTNode = HEVT< + HostAuxStore, + HEVT< + HostCompute, + HostScalarBroadcast<2>, + HostAccumulator<>, + HostAuxLoad + > + >; + using EVTModule = HEVT< + HostAuxStore, + HDAG< + float, + cute::tuple< + cute::tuple<>, // 0. EVT + cute::tuple<>, // 1. per-row bias + cute::tuple, // 2. EVT + per-row bias + cute::tuple // 3. maximum(EVT + per-row bias, EVT) + >, + EVTNode, + HostColBroadcast>, + HostCompute, + HostCompute + > + >; +}; + +/// Xreduce(alpha * acc + beta * C) +template +class HostReduce { +public: + using ElementC = typename Gemm::GemmKernel::ElementC; + using LayoutC = cutlass::detail::StrideToLayoutTagC_t; + using ElementD = typename Gemm::GemmKernel::ElementC; + using LayoutD = cutlass::detail::StrideToLayoutTagC_t; + + using ScalarAlpha = HostScalarBroadcast<1>; + using AccFetchNode = HostAccumulator<>; + using BinaryCompute0 = HEVT, ScalarAlpha, AccFetchNode>; + using ScalarBeta = HostScalarBroadcast<1>; + using CLoadNode = HostAuxLoad; + using TernaryCompute1 = HEVT, ScalarBeta, CLoadNode, BinaryCompute0>; + using ReduceNode = HEVT; + using EVTModule = HEVT, ReduceNode>; +}; + +// Z = scale_a * scale_b * alpha * acc + beta * scale_c * C + per-row bias +// if D is fp8 +// D = scale_d * activation(Z) +// else +// D = activation(Z) +template class ActivationFn, class ElementD> +class HostScaledLinCombPerRowBiasEltAct { +public: + using ElementC = typename Gemm::GemmKernel::ElementC; + using LayoutC = cutlass::detail::StrideToLayoutTagC_t; + using LayoutD = cutlass::detail::StrideToLayoutTagC_t; + + using EVTModule = HEVT< + HostAuxStore, + HEVT< + HostCompute::template Op>, // activation(Z) * scaled_d + HEVT< + HostCompute, // activation(Z) + HEVT< + HostCompute, + HostScalarBroadcast<1, 2, cute::Stride>, // scale_c * beta + HostAuxLoad, // C + HEVT< + HostCompute, + HostScalarBroadcast<1, 3, cute::Stride>, // scale_a * scale_b * alpha + HostAccumulator<>, + HostColBroadcast> + > + > + >, + HostScalarBroadcast<1> // scale_d + > + >; +}; + +// Z = scale_a * scale_b * alpha * acc + scale_c * beta * C + per-row bias +// if D is fp8 +// amax_d = max(abs(elements in activation(Z))) +// D = scale_d * activation(Z) +// else +// D = activation(Z) +// if Aux is fp8 +// amax_aux = max(abs(elements in Z)) +// Aux = scale_aux * Z +// else +// Aux = Z +template class ActivationFn, class ElementD, class ElementAux = ElementD> +class HostScaledLinCombPerRowBiasEltActAmaxAux { +public: + using ElementC = typename Gemm::GemmKernel::ElementC; + using 
LayoutC = cutlass::detail::StrideToLayoutTagC_t; + using LayoutD = cutlass::detail::StrideToLayoutTagC_t; + + template + using amax = cutlass::maximum_absolute_value_reduction; + using EVTModuleAuxFp8 = HEVT< + HostAuxStore, + HST, + HostScalarBroadcast<1, 2, cute::Stride>, // scale_c * beta + HostAuxLoad, // C + HEVT< + HostCompute, + HostScalarBroadcast<1, 3, cute::Stride>, // scale_a * scale_b * alpha + HostAccumulator<>, + HostColBroadcast> + > + >, + // D = activation(Z) * scaled_d, amax_d = max(abs(elements in D)) + HEVT< + HostCompute::template Op>, + HEVT< + HostScalarReduce, + HEVT< + HostCompute, //activation(Z) * scaled_d + HostAccumulator<> // Z + > + >, + HostScalarBroadcast<1> // scale_d + >, + // Aux = Z * scale_aux, amax_aux = max(abs(elements in Aux)) + HEVT< + HostAuxStore, + HEVT< + HostCompute, + HEVT< + HostScalarReduce, + HostAccumulator<> + >, + HostScalarBroadcast<1> + > + > + > + >; + + using EVTModuleAuxNotFp8 = HEVT< + // D = activation(Z) * scaled_d, amax_d = max(abs(elements in D)) + HostAuxStore, + HEVT< + HostCompute::template Op>, + HEVT< + HostScalarReduce, + HEVT< + HostCompute, //activation(Z) * scaled_d + HEVT< + // Aux = Z + HostAuxStore, + // Z = scale_a * scale_b * alpha * acc + scale_c * beta * C + per-row bias + HEVT< + HostCompute, + HostScalarBroadcast<1, 2, cute::Stride>, // scale_c * beta + HostAuxLoad, // C + HEVT< + HostCompute, + HostScalarBroadcast<1, 3, cute::Stride>, // scale_a * scale_b * alpha + HostAccumulator<>, + HostColBroadcast> + > + > + > + > + >, + HostScalarBroadcast<1> // scale_d + > + >; + + using EVTModule = cute::conditional_t, EVTModuleAuxFp8, EVTModuleAuxNotFp8>; + +}; +} // namespace test::gemm::device + +////////////////////////////////////////////////////////////////////////////// +namespace cutlass::epilogue { +namespace fusion { + +namespace detail { + +template +struct maximum_with_default_nan_propagation : maximum {}; + +} // namespace detail + +////////////////////////////////////////////////////////////////////////////// +/// D = alpha * acc + beta * C + AuxLoad +template< + class EpilogueDescriptor, + class AuxLoadDescriptor, + class ElementOutput, + class ElementCompute, + class ElementScalar = ElementCompute, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90LinCombAuxLoad = + Sm90EVT, // beta * C + (alpha * acc + bias) + Sm90ScalarBroadcast, // beta + Sm90SrcFetch, // C + Sm90EVT, // alpha * acc + bias + Sm90ScalarBroadcast, // alpha + Sm90AccFetch, // acc + Sm90AuxLoad< + AuxLoadDescriptor::Stages, typename EpilogueDescriptor::EpilogueTile, + typename AuxLoadDescriptor::Element, + typename AuxLoadDescriptor::Stride, typename AuxLoadDescriptor::SmemLayoutAtom, + typename AuxLoadDescriptor::CopyOpS2R // aux load + > + > + >; + +////////////////////////////////////////////////////////////////////////////// +/// D = alpha * acc + beta * C + AuxLoadNoSmem +template< + class EpilogueDescriptor, + class ElementAux, + class StrideAux, + class ElementOutput, + class ElementCompute, + class ElementScalar = ElementCompute, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90LinCombAuxLoadNoSmem = + Sm90EVT, // beta * C + (alpha * acc + bias) + Sm90ScalarBroadcast, // beta + Sm90SrcFetch, // C + Sm90EVT, // alpha * acc + bias + Sm90ScalarBroadcast, // alpha + Sm90AccFetch, // acc + Sm90AuxLoad<0, void, ElementAux, StrideAux, void, void> // aux load + > + >; + +////////////////////////////////////////////////////////////////////////////// +/// Example DAG +/// beta * 
C + Graph(alpha * acc + gamma + acc) +template< + typename EpilogueDescriptor, + typename AuxLoadDescriptor, + class ElementOutput, + class ElementCompute, + class ElementScalar = ElementCompute, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90LinCombEVTDAG = + Sm90EVT, // beta * C + (alpha * acc + aux) + Sm90ScalarBroadcast, // beta + Sm90SrcFetch, // C + Sm90TopologicalVisitor< + ElementCompute, + cute::tuple< + cute::seq<>, // 0. alpha + cute::seq<>, // 1. acc + cute::seq<>, // 2. aux load + cute::seq<1, 0, 2>, // 3. alpha * acc + aux load + cute::seq<3>, // relu(alpha & acc + aux load) + cute::seq<2, 4> // relu(alpha * acc + aux load) + aux load + >, + Sm90ScalarBroadcast, // alpha + Sm90AccFetch, // acc + Sm90AuxLoad< + AuxLoadDescriptor::Stages, typename EpilogueDescriptor::EpilogueTile, + typename AuxLoadDescriptor::Element, typename AuxLoadDescriptor::Stride, + typename AuxLoadDescriptor::SmemLayoutAtom, typename AuxLoadDescriptor::CopyOpS2R>, + Sm90Compute, + Sm90Compute, + Sm90Compute + > + >; + + +////////////////////////////////////////////////////////////////////////////// +/// Example DAG +/// EVT = alpha * acc + C +/// D = Graph(maximum(EVT + per-row bias, EVT)) +template< + class EpilogueDescriptor, + class AuxStoreDescriptor, + class ElementOutput, + class ElementCompute, + class ElementBias = ElementOutput, + class ElementScalar = ElementCompute, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90LinCombDAGEVT = + Sm90TopologicalVisitor< + ElementCompute, + cute::tuple< + cute::seq<>, + cute::seq<>, + cute::seq<1, 0>, + cute::seq<0, 2> + >, + Sm90EVT< + Sm90AuxStore< + AuxStoreDescriptor::Stages, typename EpilogueDescriptor::EpilogueTile, + typename AuxStoreDescriptor::Element, RoundStyle, typename AuxStoreDescriptor::Stride, + typename AuxStoreDescriptor::SmemLayoutAtom, typename AuxStoreDescriptor::CopyOpR2S>, + Sm90EVT, + Sm90ScalarBroadcast, + Sm90AccFetch, + Sm90SrcFetch + > + >, + Sm90ColBroadcast<0, typename EpilogueDescriptor::TileShape, ElementBias, ElementCompute>, + Sm90Compute, + Sm90Compute + >; + + +////////////////////////////////////////////////////////////////////////////// +/// D = alpha * acc + beta * C + per-column bias +template< + class EpilogueDescriptor, + class ElementOutput, + class ElementCompute, + class ElementBias = ElementOutput, + class ElementScalar = ElementCompute, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90LinCombPerColumnBias = + Sm90EVT, // beta * C + (alpha * acc + bias) + Sm90ScalarBroadcast, // beta + Sm90SrcFetch, // C + Sm90EVT, // alpha * acc + bias + Sm90ScalarBroadcast, // alpha + Sm90AccFetch, // acc + Sm90RowBroadcast<0, typename EpilogueDescriptor::TileShape, ElementBias, ElementCompute> + > + >; + + +////////////////////////////////////////////////////////////////////////////// +/// D = per-column reduce(alpha * acc + beta * C) +template< + template class RegReduceFn, + template class GmemReduceFn, + class ElementReduce, + class CtaTileShapeMNK, + class ElementOutput, + class ElementCompute, + class ElementScalar = ElementCompute, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90LinCombPerColumnReduce = + Sm90EVT, // per column reduce + Sm90EVT, // beta * C + alpha * acc + Sm90ScalarBroadcast, // beta + Sm90SrcFetch, // C + Sm90EVT, // alpha * acc + Sm90ScalarBroadcast, // alpha + Sm90AccFetch // acc + > + > + >; + + +////////////////////////////////////////////////////////////////////////////// +/// 
D = per-row reduce(alpha * acc + beta * C) +template< + template class RegReduceFn, + template class GmemReduceFn, + class ElementReduce, + class CtaTileShapeMNK, + class ElementOutput, + class ElementCompute, + class ElementScalar = ElementCompute, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90LinCombPerRowReduce = + Sm90EVT, // per column reduce + Sm90EVT, // beta * C + alpha * acc + Sm90ScalarBroadcast, // beta + Sm90SrcFetch, // C + Sm90EVT, // alpha * acc + Sm90ScalarBroadcast, // alpha + Sm90AccFetch // acc + > + > + >; + + +////////////////////////////////////////////////////////////////////////////// +/// D = scalar reduce(alpha * acc + beta * C) +template< + template class RegReduceFn, + template class GmemReduceFn, + class ElementReduce, + class ElementOutput, + class ElementCompute, + class ElementScalar = ElementCompute, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90LinCombScalarReduce = + Sm90EVT, // per column reduce + Sm90EVT, // beta * C + alpha * acc + Sm90ScalarBroadcast, // beta + Sm90SrcFetch, // C + Sm90EVT, // alpha * acc + Sm90ScalarBroadcast, // alpha + Sm90AccFetch // acc + > + > + >; +} // namespace fusion + +} // namespace cutlass::epilogue diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed.h new file mode 100644 index 0000000000000000000000000000000000000000..0007666cdd084f35015200e36fd47f75971f6c1c --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed.h @@ -0,0 +1,639 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#pragma once + +#include +#include +#include + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/gemm.h" + +#include "testbed_utils.h" +#include "testbed_universal.h" + +#include "cutlass/layout/matrix.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" + +namespace test { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Testbed { + + using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; + using ElementC = typename Gemm::ElementC; + using ElementAccumulator = typename Gemm::ElementAccumulator; + using ElementCompute = typename Gemm::GemmKernel::Epilogue::OutputOp::ElementCompute; + + /// Initialization + typename Gemm::LayoutA::Stride stride_factor_A; + typename Gemm::LayoutB::Stride stride_factor_B; + typename Gemm::LayoutC::Stride stride_factor_C; + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + uint64_t seed; + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_C; + cutlass::HostTensor tensor_D; + cutlass::HostTensor reference_D; + + // + // Methods + // + + Testbed( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + stride_factor_A(typename Gemm::LayoutA::Stride()), + stride_factor_B(typename Gemm::LayoutB::Stride()), + stride_factor_C(typename Gemm::LayoutC::Stride()), + init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } + + Testbed( + typename Gemm::LayoutA::Stride stride_factor_A_, + typename Gemm::LayoutB::Stride stride_factor_B_, + typename Gemm::LayoutC::Stride stride_factor_C_, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + stride_factor_A(stride_factor_A_), + stride_factor_B(stride_factor_B_), + stride_factor_C(stride_factor_C_), + init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } + + /// Helper to initialize a tensor view + template + bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } else if (bits_input <= 8) { + scope_max = 1; + scope_min = -1; + } else if (bits_output == 16) { + scope_max = 5; + scope_min = -5; + } else { + scope_max = 8; + scope_min = -8; + } + + cutlass::reference::host::TensorFillRandomUniform( + view, 
seed, scope_max, scope_min, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential( + view.data(), view.capacity()); + } + else { + EXPECT_TRUE(false) << "Not implemented"; + return false; + } + + return true; + } + + /// Initializes data structures + void initialize(cutlass::gemm::GemmCoord problem_size) { + // + // Allocate the GEMM workspace + // + + tensor_A.resize(problem_size.mk(), cutlass::layout::Affine2Layout_Factory::layout_factory(problem_size.mk(), stride_factor_A)); + tensor_B.resize(problem_size.kn(), cutlass::layout::Affine2Layout_Factory::layout_factory(problem_size.kn(), stride_factor_B)); + tensor_C.resize(problem_size.mn(), cutlass::layout::Affine2Layout_Factory::layout_factory(problem_size.mn(), stride_factor_C)); + tensor_D.resize(problem_size.mn(), cutlass::layout::Affine2Layout_Factory::layout_factory(problem_size.mn(), stride_factor_C)); + reference_D.resize(problem_size.mn(), cutlass::layout::Affine2Layout_Factory::layout_factory(problem_size.mn(), stride_factor_C), false); + + EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2019)); + EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2018)); + EXPECT_TRUE(initialize_tensor(tensor_C.host_view(), init_C, seed + 2017)); + + // It is possible to randomly initialize to all zeros, so override this with non-zeros + // in the upper left corner of each operand. + tensor_A.host_view().at({0, 0}) = typename Gemm::ElementA(1); + tensor_B.host_view().at({0, 0}) = typename Gemm::ElementB(1); + tensor_C.host_view().at(cutlass::make_Coord(0, 0)) = typename Gemm::ElementC(1); + + cutlass::reference::host::TensorCopy(reference_D.host_view(), tensor_C.host_view()); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_D.sync_device(); + } + + /// Compares computed reference with device reference and outputs to a file if incorrect + bool compare_reference( + cutlass::gemm::GemmCoord problem_size, + ElementCompute alpha, + ElementCompute beta) { + + tensor_D.sync_host(); + + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_C.host_view()), 0); + + if (tensor_D.size() > 1) { + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0) + << "tensor_D (size " << tensor_D.size() << ") has nonpositive norm"; + } + if (reference_D.size() > 1) { + EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0) + << "reference_D (size " << reference_D.size() << ") has nonpositive norm"; + } + bool passed = cutlass::reference::host::TensorEquals(reference_D.host_view(), tensor_D.host_view()); + + EXPECT_TRUE(passed) << "reference_D does not equal tensor_D"; + + if (!passed) { + + std::stringstream fname; + + fname << "error_Gemm_device_" + << problem_size.m() << "x" + << problem_size.n() << "x" + << problem_size.k() << "_" + << Gemm::ThreadblockShape::kM << "x" + << Gemm::ThreadblockShape::kN << "x" + << Gemm::ThreadblockShape::kK << "_" + << Gemm::WarpShape::kM << "x" + << Gemm::WarpShape::kN << "x" + << Gemm::WarpShape::kK << ".txt"; + + std::ofstream file(fname.str()); + + file 
+ << "problem: " << problem_size + << ", alpha: " << alpha << ", beta: " << beta << "\n\n"; + + file + << "A =\n" << tensor_A.host_view() + << "\nB =\n" << tensor_B.host_view() + << "\nC =\n" << tensor_C.host_view() + << "\n\nReference =\n" << reference_D.host_view() + << "\nComputed =\n" << tensor_D.host_view(); + } + + return passed; + } + + /// Verifies the result is a GEMM + bool verify( + cutlass::gemm::GemmCoord problem_size, + ElementCompute alpha, + ElementCompute beta) { + + // + // Verify + // + + cutlass::reference::host::Gemm< + typename Gemm::ElementA, typename Gemm::LayoutA, + typename Gemm::ElementB, typename Gemm::LayoutB, + typename Gemm::ElementC, typename Gemm::LayoutC, ElementCompute, + ElementAccumulator, typename Gemm::Operator> + reference_gemm; + + reference_gemm( + problem_size, + alpha, + tensor_A.host_ref(), + tensor_B.host_ref(), + beta, + reference_D.host_ref(), + ElementAccumulator(0) + ); + + if (Relu) { + for (int i = 0; i < problem_size.m(); ++i) { + for (int j = 0; j < problem_size.n(); ++j) { + reference_D.at(cutlass::MatrixCoord(i, j)) = + ((ElementCompute)reference_D.at(cutlass::MatrixCoord(i, j)) < (ElementCompute)0) + ? (typename Gemm::ElementC)0 + : reference_D.at(cutlass::MatrixCoord(i, j)); + } + } + } + + return compare_reference(problem_size, alpha, beta); + } + + /// Determine if the CUDA device is sufficient to run the kernel + bool sufficient() const { + // + // Determine SMEM requirements and waive if not satisfied + // + + size_t smem_size = sizeof(typename Gemm::GemmKernel::SharedStorage); + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerBlockOptin < smem_size) { + return false; + } + + return true; + } + + + /// Executes one test + bool run( + cutlass::gemm::GemmCoord problem_size, + int split_k_slices = 1, + ElementCompute alpha = ElementCompute(1), + ElementCompute beta = ElementCompute(0)) + { +/* + std::cout << "\n-----------------------\n"; + std::cout << "problem size: " << problem_size << "\n"; + std::cout << "split_k_slices: " << split_k_slices << "\n"; + std::cout << "alpha: " << alpha << "\n"; + std::cout << "beta: " << beta << "\n"; + std::cout << "-----------------------\n\n"; +*/ + + // Waive test if insufficient CUDA device + if (!sufficient()) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." << std::endl; + } + return true; + } + + this->initialize(problem_size); + + // + // Initialize the GEMM operator + // + + typename Gemm::Arguments arguments{ + problem_size, + tensor_A.device_ref(), + tensor_B.device_ref(), + tensor_C.device_ref(), + tensor_D.device_ref(), + {alpha, beta}, + split_k_slices + }; + + Gemm gemm_op; + + size_t workspace_size = Gemm::get_workspace_size(arguments); + + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = gemm_op.initialize(arguments, workspace.get()); + + EXPECT_TRUE(status == cutlass::Status::kSuccess) + << "gemm_op.initialize returned with error " << to_string(status) + << ", indicating that this test is not supported. 
Last CUDA error: " + << cudaGetErrorString(cudaGetLastError()); + if (status != cutlass::Status::kSuccess) { + return true; + } + + // + // Run the GEMM + // + + try { + status = gemm_op(); + } + catch (std::exception const& e) { + EXPECT_TRUE(false) << "gemm_op() threw a std::exception: " << e.what(); + throw; + } + catch (...) { + EXPECT_TRUE(false) << "gemm_op() threw an exception of unknown type"; + throw; + } + EXPECT_TRUE(status == cutlass::Status::kSuccess) + << "gemm_op failed with error " << to_string(status); + + // + // Verify + // + + bool passed = this->verify(problem_size, alpha, beta); + EXPECT_TRUE(passed) << "Error: split_k_slices = " << split_k_slices + << ", alpha: " << alpha; + + return passed; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +bool TestAllGemmBasic( + const typename Gemm::LayoutA::Stride& stride_factor_A = typename Gemm::LayoutA::Stride(), + const typename Gemm::LayoutB::Stride& stride_factor_B = typename Gemm::LayoutB::Stride(), + const typename Gemm::LayoutC::Stride& stride_factor_C = typename Gemm::LayoutC::Stride()) { + bool passed = true; + + int const kMinimumOperandElementSize = + std::min( + int(cutlass::sizeof_bits::value), + int(cutlass::sizeof_bits::value)); + + int const kAlignment = cutlass::platform::is_same< + typename Gemm::OperatorClass, + cutlass::arch::OpClassSimt>::value ? 1 : 128 / kMinimumOperandElementSize; + + // int8_t gemm alignment constraints + int const kAlignmentM = cutlass::platform::is_same::value && + cutlass::platform::is_same::value && + cutlass::platform::is_same::value ? 4 : kAlignment; + + int const kAlignmentN = cutlass::platform::is_same::value && + cutlass::platform::is_same::value && + cutlass::platform::is_same::value ? 4 : kAlignment; + + int const kAlignmentK = cutlass::platform::is_same::value && + cutlass::platform::is_same::value && + cutlass::platform::is_same::value && + (cutlass::platform::is_same::value || + cutlass::platform::is_same::value) ? 4 : kAlignment; + + int problem_size_m[] = {kAlignmentM, 512 - 3 * kAlignmentM}; + + int problem_size_n[] = {kAlignmentN, 512 - 2 * kAlignmentN}; + + int problem_size_k[] = { + kAlignmentK, Gemm::ThreadblockShape::kK * (Gemm::kStages + 1) - kAlignmentK}; + + int split_k_slices[] = { + 1, 2, 3 + }; + + double problem_alpha[] = { + 1 + }; + + double problem_beta[] = { + 2.0 + }; + + Testbed testbed(stride_factor_A, stride_factor_B, stride_factor_C); + + using ElementCompute = typename Gemm::EpilogueOutputOp::ElementCompute; + + for (int m : problem_size_m) { + for (int n : problem_size_n) { + for (int k : problem_size_k) { + for (int split_k : split_k_slices) { + + if (!Gemm::kSplitKSerial && split_k > 1) { + continue; + } + + if (split_k > 1 && k / Gemm::ThreadblockShape::kK < split_k) { + continue; + } + + for (auto alpha : problem_alpha) { + for (auto beta : problem_beta) { + + cutlass::gemm::GemmCoord problem_size(m, n, k); + try { + passed = testbed.run( + problem_size, + split_k, + cutlass::from_real(alpha), + cutlass::from_real(beta) + ); + } + catch (std::exception const& e) { + EXPECT_TRUE(false) << "TestAllGemmBasic: testbed.run threw an " + "exception {alpha: " << alpha << ", beta: " << beta << ", m: " + << m << ", n: " << n << ", k: " << k << "}: " << e.what(); + throw; + } + catch (...) 
{ + EXPECT_TRUE(false) << "TestAllGemmBasic: testbed.run threw an " + "exception {alpha: " << alpha << ", beta: " << beta << ", m: " + << m << ", n: " << n << ", k: " << k << "}: (unknown)"; + throw; + } + + if (!passed) { + return false; + } + } + } + } + } + } + } + + return passed; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +bool TestAllGemm( + const typename Gemm::LayoutA::Stride& stride_factor_A, + const typename Gemm::LayoutB::Stride& stride_factor_B = typename Gemm::LayoutB::Stride(), + const typename Gemm::LayoutC::Stride& stride_factor_C = typename Gemm::LayoutC::Stride()) +{ + // Test basic GEMM with non-default stride factors + return TestAllGemmBasic(stride_factor_A, stride_factor_B, stride_factor_C); +} + +template +bool TestAllGemm() +{ +#ifdef NDEBUG + // Non-debug builds also test basic GEMM with default stride factors + if (!TestAllGemmBasic()) { + return false; + } +#endif // NDEBUG + + // Test universal GEMM +#if 0 + // Define the universal kernel + using UniversalKernel = cutlass::gemm::kernel::GemmUniversal< + typename Gemm::GemmKernel::Mma, // Mma + typename Gemm::GemmKernel::Epilogue, // Epilogue + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<> // ThreadblockSwizzle + >; +#else + // Define the streamk universal kernel + using UniversalKernel = cutlass::gemm::kernel::GemmUniversalStreamk< + typename Gemm::GemmKernel::Mma, // Mma + typename Gemm::GemmKernel::Epilogue, // Epilogue + cutlass::gemm::threadblock::ThreadblockSwizzleStreamK // ThreadblockSwizzle + >; +#endif + + // Define the universal adaptor + using UniversalGemm = cutlass::gemm::device::GemmUniversalAdapter; + + // Test universal GEMM + return TestAllGemmUniversal(); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// +template +bool TestGemmPerf(int iterations = 1) { + bool passed = true; + + int problem_size_m[] = { 2048 }; + + int problem_size_n[] = { 4352 }; + + int problem_size_k[] = { 4096 }; + + int split_k_slices[] = { 1 }; + double problem_alpha[] = { 1 }; + double problem_beta[] = { 0.0 }; + + Testbed testbed; + + using ElementCompute = typename Gemm::EpilogueOutputOp::ElementCompute; + + for (int m : problem_size_m) { + for (int n : problem_size_n) { + for (int k : problem_size_k) { + for (int split_k : split_k_slices) { + + if (!Gemm::kSplitKSerial && split_k > 1) { + continue; + } + + for (auto alpha : problem_alpha) { + for (auto beta : problem_beta) { + + cutlass::gemm::GemmCoord problem_size(m, n, k); + + for (int i = 0; i < iterations; i++){ + try { + passed = testbed.run( + problem_size, + split_k, + cutlass::from_real(alpha), + cutlass::from_real(beta) + ); + } + catch (std::exception const& e) { + EXPECT_TRUE(false) << "TestGemmPerf: testbed.run threw an " + "exception {alpha: " << alpha << ", beta: " << beta << ", m: " + << m << ", n: " << n << ", k: " << k << "}: " << e.what(); + throw; + } + catch (...) 
{ + EXPECT_TRUE(false) << "TestGemmPerf: testbed.run threw an " + "exception {alpha: " << alpha << ", beta: " << beta << ", m: " + << m << ", n: " << n << ", k: " << k << "}: (unknown)"; + throw; + } + } + + if (!passed) { + return false; + } + } + } + } + } + } + } + + return passed; +} + +} // namespace device +} // namespace gemm +} // namespace test + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_complex.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_complex.h new file mode 100644 index 0000000000000000000000000000000000000000..add984ca3b9a0c05325b93cf52cbadd710527ba6 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_complex.h @@ -0,0 +1,294 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Tests for device-wide GEMM interface +*/ + +#pragma once + +#include +#include +#include + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/gemm_complex.h" + +#include "testbed.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace test { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct TestbedComplex : public Testbed { + + using Base = Testbed; + using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; + using ElementC = typename Gemm::ElementC; + using ElementAccumulator = typename Gemm::ElementAccumulator; + using ElementCompute = typename Gemm::GemmKernel::Epilogue::OutputOp::ElementCompute; + + + // + // Methods + // + + TestbedComplex( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + Base(init_A_, init_B_, init_C_, seed_) { } + + + /// Verifies the result is a GEMM + bool verify( + cutlass::gemm::GemmCoord problem_size, + ElementCompute alpha, + ElementCompute beta) { + + // + // Verify + // + + cutlass::reference::host::GemmComplex( + problem_size, + alpha, + this->tensor_A.host_ref(), + Gemm::kTransformA, + this->tensor_B.host_ref(), + Gemm::kTransformB, + beta, + this->tensor_C.host_ref(), + this->reference_D.host_ref(), + ElementAccumulator(0) + ); + + return this->compare_reference(problem_size, alpha, beta); + } + + /// Returns true if the CUDA device is sufficient to execute the kernel. + bool sufficient() const { + // + // Determine SMEM requirements and waive if not satisfied + // + + size_t smem_size = sizeof(typename Gemm::GemmKernel::SharedStorage); + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerBlockOptin < smem_size) { + return false; + } + + return true; + } + + /// Executes one test + bool run( + cutlass::gemm::GemmCoord problem_size, + int split_k_slices = 1, + ElementCompute alpha = ElementCompute(1), + ElementCompute beta = ElementCompute(0)) { + + // Waive test if insufficient CUDA device + if (!sufficient()) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." 
<< std::endl; + } + return true; + } + + // + // Initialize workspace + // + + this->initialize(problem_size); + + + // + // Initialize the GEMM operator + // + + typename Gemm::Arguments arguments{ + problem_size, + this->tensor_A.device_ref(), + this->tensor_B.device_ref(), + this->tensor_C.device_ref(), + this->tensor_D.device_ref(), + {alpha, beta}, + split_k_slices + }; + + Gemm gemm_op; + + size_t workspace_size = Gemm::get_workspace_size(arguments); + + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = gemm_op.initialize(arguments, workspace.get()); + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Run the GEMM + // + + status = gemm_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Verify + // + + bool passed = this->verify(problem_size, alpha, beta); + + if (!passed) { + std::cout << "Error with split_k_slices = " << split_k_slices << ", alpha: " << alpha << std::endl; + } + + return passed; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +bool TestAllGemmComplex() { + bool passed = true; + + using ElementCompute = typename Gemm::EpilogueOutputOp::ElementCompute; + + int const kMinimumOperandElementSize = + std::min( + int(cutlass::sizeof_bits::value), + int(cutlass::sizeof_bits::value)); + + int const kAlignment = + cutlass::platform::is_same< + typename Gemm::OperatorClass, + cutlass::arch::OpClassSimt>::value ? 1 : 128 / kMinimumOperandElementSize; + + int problem_size_m[] = { + kAlignment, 512 - 3*kAlignment + }; + + int problem_size_n[] = { + kAlignment, 512 - 2*kAlignment + }; + + int problem_size_k[] = { + kAlignment, 128 - kAlignment + }; + + int split_k_slices[] = { + 1, 2, 3 + }; + + double problem_alpha[] = { + 1 + }; + + double problem_beta[] = { + 2.0 + }; + + TestbedComplex testbed; + + for (int m : problem_size_m) { + for (int n : problem_size_n) { + for (int k : problem_size_k) { + for (int split_k : split_k_slices) { + + if (!Gemm::kSplitKSerial && split_k > 1) { + continue; + } + + for (auto alpha : problem_alpha) { + for (auto beta : problem_beta) { + + cutlass::gemm::GemmCoord problem_size(m, n, k); + + passed = testbed.run( + problem_size, + split_k, + cutlass::from_real(alpha), + cutlass::from_real(beta) + ); + + if (!passed) { + return false; + } + } + } + } + } + } + } + + return passed; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace test + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_gemm_with_broadcast.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_gemm_with_broadcast.h new file mode 100644 index 0000000000000000000000000000000000000000..eca0b0ae0decf3293f6f73cb6ebbc5b5735a8e49 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_gemm_with_broadcast.h @@ -0,0 +1,670 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#pragma once + +#include +#include +#include + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/gemm_complex.h" + +#include "testbed_utils.h" + +namespace test { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct GemmWithBroadcastReferenceOp { + + using OutputOp = typename Gemm::GemmKernel::Epilogue::OutputOp; + + using ElementCompute = typename OutputOp::ElementCompute; + using ElementZ = typename OutputOp::ElementZ; + using ElementT = typename OutputOp::ElementT; + + typename OutputOp::BinaryOp binary_op; + typename OutputOp::ElementwiseOp elementwise_op; + + GemmWithBroadcastReferenceOp() { } + + void operator()(ElementZ &Z, ElementT &T, ElementCompute gemm, ElementCompute bias) { + + ElementCompute t_full = binary_op(gemm, bias); + + if (OutputOp::kStoreT) { + T = ElementT(t_full); + } + + if (OutputOp::kStoreZ) { + ElementCompute z_full = elementwise_op(t_full); + Z = ElementZ(z_full); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Fused testbed +// +// Y = GEMM(AB, C) +// +// T[i, j] = BinaryOp(Y[i, j], Broadcast[i]) +// +// Z[i, j] = Elementwise(T[i, j]) +// + +template < + typename Gemm, + typename ReferenceOp = GemmWithBroadcastReferenceOp +> +struct TestbedGemmWithBroadcast { + + 
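+ // Example: under one possible instantiation (not implied by this header) in which
+ // BinaryOp is an elementwise add and ElementwiseOp is a ReLU, the fused computation
+ // above becomes T[i, j] = Y[i, j] + Broadcast[i] and Z[i, j] = max(T[i, j], 0),
+ // where Y = alpha * (A @ B) + beta * C is the scaled GEMM result and Broadcast is a
+ // per-row vector of extent {m, 1}.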
using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; + using OutputOp = typename Gemm::GemmKernel::Epilogue::OutputOp; + using ElementC = typename Gemm::ElementC; + using ElementAccumulator = typename Gemm::ElementAccumulator; + using ElementCompute = typename OutputOp::ElementCompute; + using ElementVector = typename OutputOp::ElementVector; + using ElementZ = typename OutputOp::ElementZ; + using ElementT = typename OutputOp::ElementT; + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + uint64_t seed; + + cutlass::HostTensor tensor_A; // Input A + cutlass::HostTensor tensor_B; // Input B + cutlass::HostTensor tensor_C; // Input C + cutlass::HostTensor tensor_Broadcast; // Input Broadcast + + cutlass::HostTensor tensor_Z; + cutlass::HostTensor tensor_T; + + cutlass::HostTensor tensor_C_ref; + cutlass::HostTensor tensor_Y_ref; + cutlass::HostTensor tensor_Z_ref; + cutlass::HostTensor tensor_T_ref; + + + // + // Methods + // + + TestbedGemmWithBroadcast( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } + + /// Helper to initialize a tensor view + template + bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 1; + scope_min = 0; + } else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } else if (bits_output == 16) { + scope_max = 5; + scope_min = -5; + } else { + scope_max = 8; + scope_min = -8; + } + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential( + view.data(), view.capacity()); + } + else { + EXPECT_TRUE(false) << "Not implemented"; + return false; + } + + return true; + } + + /// Initializes data structures + void initialize(cutlass::gemm::GemmCoord problem_size) { + // + // Allocate the GEMM workspace + // + + tensor_A.resize(problem_size.mk()); + tensor_B.resize(problem_size.kn()); + tensor_C.resize(problem_size.mn()); + tensor_Z.resize(problem_size.mn()); + tensor_T.resize(problem_size.mn()); + tensor_Broadcast.resize({ + problem_size.m(), + 1 + }); + + tensor_C_ref.resize(problem_size.mn()); + tensor_Y_ref.resize(problem_size.mn()); + tensor_Z_ref.resize(problem_size.mn()); + tensor_T_ref.resize(problem_size.mn()); + + EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2019)); + EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2018)); + EXPECT_TRUE(initialize_tensor(tensor_C.host_view(), init_C, seed + 2017)); + EXPECT_TRUE(initialize_tensor(tensor_Broadcast.host_view(), init_C, seed + 2020)); + + // It is possible to randomly initialize to all zeros, so override this 
with non-zeros + // in the upper left corner of each operand. + tensor_A.host_view().at({0, 0}) = typename Gemm::ElementA(1); + tensor_B.host_view().at({0, 0}) = typename Gemm::ElementB(1); + tensor_C.host_view().at({0, 0}) = typename Gemm::ElementC(1); + + for (int m = 0; m < tensor_C_ref.extent().row(); ++m) { + for (int n = 0; n < tensor_C_ref.extent().column(); ++n) { + tensor_C_ref.at({m, n}) = ElementAccumulator(tensor_C.at({m, n})); + } + } + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_Broadcast.sync_device(); + + tensor_Z.sync_device(); + tensor_T.sync_device(); + } + + /// Compares computed reference with device reference and outputs to a file if incorrect + bool compare_reference( + cutlass::gemm::GemmCoord problem_size, + ElementAccumulator alpha, + ElementAccumulator beta) { + + tensor_Z.sync_host(); + tensor_T.sync_host(); + + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_C.host_view()), 0); + + if (OutputOp::kStoreZ) { + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_Z.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_Z_ref.host_view()), 0); + } + + if (OutputOp::kStoreT) { + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_T.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_T_ref.host_view()), 0); + } + + bool passed = true; + float norm_diff = 0; + + if (OutputOp::kStoreZ) { + norm_diff = cutlass::reference::host::TensorNormDiff(tensor_Z_ref.host_view(), tensor_Z.host_view(), float()); + passed = (norm_diff <= 0.1f); + EXPECT_LT(norm_diff, 0.1f) << " tensor_Z is incorrect"; + } + + if (OutputOp::kStoreT) { + + norm_diff = cutlass::reference::host::TensorNormDiff(tensor_T_ref.host_view(), tensor_T.host_view(), float()); + passed = (passed && (norm_diff <= 0.1f)); + + EXPECT_LT(norm_diff, 0.1f) << " tensor_T is incorrect"; + } + + + if (!passed) { + + /* + std::stringstream fname; + + fname << "error_Gemm_device_" + << problem_size.m() << "x" + << problem_size.n() << "x" + << problem_size.k() << "_" + << Gemm::ThreadblockShape::kM << "x" + << Gemm::ThreadblockShape::kN << "x" + << Gemm::ThreadblockShape::kK << "_" + << Gemm::WarpShape::kM << "x" + << Gemm::WarpShape::kN << "x" + << Gemm::WarpShape::kK << ".txt"; + + std::ofstream file(fname.str()); + */ + + std::ofstream file("errors_testbed_gemm_with_broadcast.txt"); + + + file + << "problem: " << problem_size + << ", alpha: " << alpha << ", beta: " << beta << "\n\n"; + + file + << "A =\n" << tensor_A.host_view() + << "\nB =\n" << tensor_B.host_view() + << "\nC =\n" << tensor_C.host_view() + << "\nZ =\n" << tensor_Z.host_view() + << "\nT =\n" << tensor_T.host_view() + << "\n\n" + << "\nY_ref =\n" << tensor_Y_ref.host_view() + << "\nZ_ref =\n" << tensor_Z_ref.host_view() + << "\nT_ref =\n" << tensor_T_ref.host_view(); + } + + return passed; + } + + /// Verifies the result is a GEMM + bool verify( + cutlass::gemm::GemmCoord problem_size, + ElementAccumulator alpha, + ElementAccumulator beta) { + + // + // Verify + // + + cutlass::reference::host::GemmComplex< + typename Gemm::ElementA, typename Gemm::LayoutA, + typename Gemm::ElementB, typename Gemm::LayoutB, + ElementAccumulator, typename Gemm::LayoutC, + ElementAccumulator, ElementAccumulator + >( + problem_size, + alpha, + tensor_A.host_ref(), + Gemm::kTransformA, + tensor_B.host_ref(), + Gemm::kTransformB, + beta, 
+ tensor_C_ref.host_ref(), + tensor_Y_ref.host_ref(), + ElementAccumulator(0) + ); + + using ElementC = typename Gemm::ElementC; + + ReferenceOp reference_op; + + // compute tensor Z and tensor T + for (int m = 0; m < problem_size.m(); ++m) { + for (int n = 0; n < problem_size.n(); ++n) { + + ElementZ z; + ElementT t; + + reference_op(z, t, tensor_Y_ref.at({m, n}), tensor_Broadcast.at({m, 0})); + + if (OutputOp::kStoreZ) { + tensor_Z_ref.at({m, n}) = z; + } + + if (OutputOp::kStoreT) { + tensor_T_ref.at({m, n}) = t; + } + } + } + + return compare_reference(problem_size, alpha, beta); + } + + /// Returns true if the CUDA device is sufficient to execute the kernel. + bool sufficient() const { + + // + // Determine SMEM requirements and waive if not satisfied + // + + size_t smem_size = sizeof(typename Gemm::GemmKernel::SharedStorage); + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerBlockOptin < smem_size) { + return false; + } + + return true; + } + + /// Executes one test + bool run( + cutlass::gemm::GemmUniversalMode mode, + cutlass::gemm::GemmCoord problem_size, + int batch_count = 1, + ElementAccumulator alpha = ElementAccumulator(1), + ElementAccumulator beta = ElementAccumulator(0)) { + + // Waive test if insufficient CUDA device + if (!sufficient()) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." << std::endl; + } + return true; + } + + this->initialize(problem_size); + + // + // Initialize the GEMM operator + // + + typename Gemm::Arguments arguments{ + mode, + problem_size, + batch_count, + {alpha, beta}, + tensor_A.device_data(), + tensor_B.device_data(), + tensor_C.device_data(), + tensor_Z.device_data(), + tensor_Broadcast.device_data(), + tensor_T.device_data(), + problem_size.m() * problem_size.k(), + problem_size.n() * problem_size.k(), + problem_size.m() * problem_size.n(), + problem_size.m() * problem_size.n(), + problem_size.m(), + problem_size.m() * problem_size.n(), + tensor_A.layout().stride(0), + tensor_B.layout().stride(0), + tensor_C.layout().stride(0), + tensor_Z.layout().stride(0), + 0, // This must be zero + tensor_T.layout().stride(0), + }; + + Gemm gemm_op; + + size_t workspace_size = Gemm::get_workspace_size(arguments); + + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = gemm_op.initialize(arguments, workspace.get()); + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Run the GEMM + // + + status = gemm_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Verify + // + + bool passed = true; + + passed = this->verify(problem_size, alpha, beta); + + if (!passed) { + std::cout << "Failed with batch_count/split_k_slices = " << batch_count << std::endl; + } + + // + // Profile + // + + #if 0 // profiling disabled for now. 
+ + int const kWorkspaces = 100; + + cutlass::DeviceAllocation profiling_tensor_A(tensor_A.capacity() * kWorkspaces); + cutlass::DeviceAllocation profiling_tensor_B(tensor_B.capacity() * kWorkspaces); + cutlass::DeviceAllocation profiling_tensor_C(tensor_C.capacity() * kWorkspaces); + cutlass::DeviceAllocation profiling_tensor_Broadcast(tensor_Broadcast.capacity() * kWorkspaces); + cutlass::DeviceAllocation profiling_tensor_Z(tensor_Z.capacity() * kWorkspaces); + cutlass::DeviceAllocation profiling_tensor_T(tensor_T.capacity() * kWorkspaces); + + cudaEvent_t events[2]; + for (auto & event : events) { + cudaError_t result = cudaEventCreate(&event); + if (result != cudaSuccess) { + EXPECT_EQ(result, cudaSuccess) << " cudaEventCreate() failed with error " << cudaGetErrorString(result); + return false; + break; + } + } + + int const kWarmupIterations = 5; + int const kProfilingIterations = 100; + + for (int i = 0; i < kWarmupIterations; ++i) { + status = gemm_op(); + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + } + + + cudaError_t result = cudaEventRecord(events[0]); + EXPECT_EQ(result, cudaSuccess); + + for (int i = 0; i < kProfilingIterations; ++i) { + + typename Gemm::Arguments arguments{ + mode, + problem_size, + batch_count, + {alpha, beta}, + profiling_tensor_A.get() + tensor_A.capacity() * (i % kWorkspaces), + profiling_tensor_B.get() + tensor_B.capacity() * (i % kWorkspaces), + profiling_tensor_C.get() + tensor_C.capacity() * (i % kWorkspaces), + profiling_tensor_Z.get() + tensor_Z.capacity() * (i % kWorkspaces), + profiling_tensor_Broadcast.get() + tensor_Broadcast.capacity() * (i % kWorkspaces), + profiling_tensor_T.get() + tensor_T.capacity() * (i % kWorkspaces), + problem_size.m() * problem_size.k(), + problem_size.n() * problem_size.k(), + problem_size.m() * problem_size.n(), + problem_size.m() * problem_size.n(), + problem_size.m(), + problem_size.m() * problem_size.n(), + tensor_A.layout().stride(0), + tensor_B.layout().stride(0), + tensor_C.layout().stride(0), + tensor_Z.layout().stride(0), + 0, // This must be zero + tensor_T.layout().stride(0), + }; + + gemm_op.initialize(arguments, workspace.get()); + status = gemm_op(); + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + } + + result = cudaEventRecord(events[1]); + EXPECT_EQ(result, cudaSuccess); + + result = cudaDeviceSynchronize(); + EXPECT_EQ(result, cudaSuccess); + + float elapsed_time = 0; + result = cudaEventElapsedTime(&elapsed_time, events[0], events[1]); + EXPECT_EQ(result, cudaSuccess); + + double average_time = double(elapsed_time) / double(kProfilingIterations); + + std::cout << problem_size << ": " << average_time << " ms" << std::endl; + + for (auto & event : events) { + cudaEventDestroy(event); + } + #endif + + return passed; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Gemm, + typename ReferenceOp = GemmWithBroadcastReferenceOp +> +bool TestGemmWithBroadcast( + cutlass::gemm::GemmCoord const & problem_size, + cutlass::gemm::GemmUniversalMode mode, + int batch_count, + double alpha = 1.0, + double beta = 2.0) { + + bool passed = true; + + TestbedGemmWithBroadcast testbed; + + using ElementAccumulator = typename Gemm::ElementAccumulator; + + passed = testbed.run( + mode, + problem_size, + batch_count, + cutlass::from_real(alpha), + cutlass::from_real(beta) + ); + + return passed; +} + 
+///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Gemm, + typename ReferenceOp = GemmWithBroadcastReferenceOp +> +bool TestAllGemmWithBroadcast() { + + int M_problems[] = {8, 136, 264, 520}; + int N_problems[] = {8, 136, 264, 520}; + int K_problems[] = {8, 136, 264, 520}; + double alpha_problems[] = {1.25, 2.25}; + double beta_problems[] = {0, 1, 2.0}; + + bool passed = true; + + for (int M : M_problems) { + for (int N : N_problems) { + for (int K : K_problems) { + for (double alpha : alpha_problems) { + for (double beta : beta_problems) { + + TestbedGemmWithBroadcast testbed; + + using ElementAccumulator = typename Gemm::ElementAccumulator; + + passed = testbed.run( + cutlass::gemm::GemmUniversalMode::kGemm, + {M, N, K}, + 1, + cutlass::from_real(alpha), + cutlass::from_real(beta) + ); + + EXPECT_TRUE(passed) + << "M: " << M << ", N: " << N << ", K: " << K << ", alpha: " << alpha << ", beta: " << beta; + + if (!passed) { + + return passed; + } + } + } + } + } + } + + return passed; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace test + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_gemm_with_reduction.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_gemm_with_reduction.h new file mode 100644 index 0000000000000000000000000000000000000000..af3629ccfb87e09e80b85af508379780d6428dc5 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_gemm_with_reduction.h @@ -0,0 +1,588 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#pragma once + +#include +#include +#include + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/gemm_complex.h" + +#include "testbed_utils.h" + +namespace test { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct GemmWithReductionReference { + + using ElementAccumulator = typename Gemm::ElementAccumulator; + using ElementCompute = typename Gemm::GemmKernel::Epilogue::ElementCompute; + using ElementC = typename Gemm::ElementC; + using ElementT = typename Gemm::GemmKernel::Epilogue::ElementTensor; + // + // Data members + // + + BinaryOp binary_op; + + // + // Methods + // + + GemmWithReductionReference() { } + + ElementCompute operator()( + ElementAccumulator d_y, + ElementT t) { + + return binary_op(ElementCompute(d_y), ElementCompute(t)); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Gemm, + typename ReferenceOp +> +struct TestbedGemmWithReduction { + + using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; + using ElementC = typename Gemm::ElementC; + using ElementAccumulator = typename Gemm::ElementAccumulator; + using ElementT = typename Gemm::GemmKernel::Epilogue::ElementTensor; + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + uint64_t seed; + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_C; + cutlass::HostTensor tensor_D; + cutlass::HostTensor tensor_Reduction; + cutlass::HostTensor tensor_Tensor; + cutlass::HostTensor tensor_C_ref; + cutlass::HostTensor reference_d_Y; + cutlass::HostTensor reference_D; + cutlass::HostTensor reference_Reduction; + + // + // Methods + // + + TestbedGemmWithReduction( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } + + /// Helper to initialize a tensor view + template + bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == 
cutlass::Distribution::Uniform) { + + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 1; + scope_min = 0; + } else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } else if (bits_output == 16) { + scope_max = 5; + scope_min = -5; + } else { + scope_max = 8; + scope_min = -8; + } + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + for (int m = 0; m < view.extent().row(); ++m) { + for (int n = 0; n < view.extent().column(); ++n) { + //view.at({m, n}) = Element(float(((idx ++) % 17) - 8)); + view.at({m, n}) = (n == 0 ? Element(m) : Element()); + + } + } + } + else { + EXPECT_TRUE(false) << "Not implemented"; + return false; + } + + return true; + } + + /// Initializes data structures + void initialize(cutlass::gemm::GemmCoord problem_size) { + // + // Allocate the GEMM workspace + // + + tensor_A.resize(problem_size.mk()); + tensor_B.resize(problem_size.kn()); + tensor_C.resize(problem_size.mn()); + tensor_D.resize(problem_size.mn()); + + tensor_Reduction.resize({ + problem_size.m(), + (problem_size.n() - 1 + Gemm::ThreadblockShape::kN) / Gemm::ThreadblockShape::kN + }); + + tensor_Tensor.resize(problem_size.mn()); + reference_D.resize(problem_size.mn(), false); + reference_d_Y.resize(problem_size.mn(), false); + tensor_C_ref.resize(problem_size.mn(), false); + reference_Reduction.resize({problem_size.m(), 1}, false); + + EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2019)); + EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2018)); + EXPECT_TRUE(initialize_tensor(tensor_C.host_view(), init_C, seed + 2017)); + EXPECT_TRUE(initialize_tensor(tensor_Tensor.host_view(), init_C, seed + 2020)); + + // It is possible to randomly initialize to all zeros, so override this with non-zeros + // in the upper left corner of each operand. 
+ tensor_A.host_view().at({0, 0}) = typename Gemm::ElementA(1); + tensor_B.host_view().at({0, 0}) = typename Gemm::ElementB(1); + tensor_C.host_view().at({0, 0}) = typename Gemm::ElementC(1); + + for (int m = 0; m < tensor_C_ref.extent().row(); ++m) { + for (int n = 0; n < tensor_C_ref.extent().column(); ++n) { + tensor_C_ref.at({m, n}) = ElementAccumulator(tensor_C.at({m, n})); + } + } + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_D.sync_device(); + tensor_Reduction.sync_device(); + tensor_Tensor.sync_device(); + } + + /// Compares computed reference with device reference and outputs to a file if incorrect + bool compare_reference( + cutlass::gemm::GemmCoord problem_size, + ElementAccumulator alpha, + ElementAccumulator beta) { + + tensor_Reduction.sync_host(); + tensor_D.sync_host(); + + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_C.host_view()), 0); + + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_Reduction.host_view()), 0); + + bool passed = true; + for (int m = 0; m < tensor_Reduction.extent().row(); ++m) { + + ElementAccumulator reduced_value = ElementAccumulator(); + for (int j = 0; j < tensor_Reduction.extent().column(); ++j) { + reduced_value += tensor_Reduction.at({m, j}); + } + + if (reduced_value != reference_Reduction.at({m, 0})) { + std::cout << "Error in bias[" << m << "] - Expected: " << reference_Reduction.at({m, 0}) << ", got: " << reduced_value << std::endl; + passed = false; + break; + } + } + EXPECT_TRUE(passed) << "Reduction is incorrect."; + + if (!cutlass::reference::host::TensorEquals(reference_D.host_view(), tensor_D.host_view())) { + EXPECT_TRUE(false) << " mismatched reference"; + passed = false; + } + + if (!passed) { + + /* + std::stringstream fname; + + fname << "error_Gemm_device_" + << problem_size.m() << "x" + << problem_size.n() << "x" + << problem_size.k() << "_" + << Gemm::ThreadblockShape::kM << "x" + << Gemm::ThreadblockShape::kN << "x" + << Gemm::ThreadblockShape::kK << "_" + << Gemm::WarpShape::kM << "x" + << Gemm::WarpShape::kN << "x" + << Gemm::WarpShape::kK << ".txt"; + + std::ofstream file(fname.str()); + */ + + std::ofstream file("testbed_universal_errors_sm70.txt"); + + file + << "problem: " << problem_size + << ", alpha: " << alpha << ", beta: " << beta << "\n\n"; + + file + << "A =\n" << tensor_A.host_view() + << "\nB =\n" << tensor_B.host_view() + << "\nC =\n" << tensor_C.host_view() + << "\nT = \n" << tensor_Tensor.host_view() + << "\n\nReference =\n" << reference_D.host_view() + << "\nComputed =\n" << tensor_D.host_view() + << "\n\nReduction =\n" << tensor_Reduction.host_view() << "\n" + << "\nReference reduction =\n" << reference_Reduction.host_view() << "\n"; + } + + return passed; + } + + /// Verifies the result is a GEMM + bool verify( + cutlass::gemm::GemmCoord problem_size, + ElementAccumulator alpha, + ElementAccumulator beta) { + + // + // Verify + // + + cutlass::reference::host::GemmComplex< + typename Gemm::ElementA, typename Gemm::LayoutA, + typename Gemm::ElementB, typename Gemm::LayoutB, + ElementAccumulator, typename Gemm::LayoutC, + ElementAccumulator, ElementAccumulator + >( + problem_size, + alpha, + tensor_A.host_ref(), + Gemm::kTransformA, + 
tensor_B.host_ref(), + Gemm::kTransformB, + beta, + tensor_C_ref.host_ref(), + reference_d_Y.host_ref(), + ElementAccumulator(0) + ); + + using ElementC = typename Gemm::ElementC; + + ReferenceOp reference_op; + + // compute backwards + for (int m = 0; m < problem_size.m(); ++m) { + ElementAccumulator reduced_value = ElementAccumulator(); + for (int n = 0; n < problem_size.n(); ++n) { + ElementAccumulator d_full = reference_op(reference_d_Y.at({m, n}), tensor_Tensor.at({m, n})); + reduced_value += d_full; + reference_D.at({m, n}) = ElementC(d_full); + } + reference_Reduction.at({m, 0}) = reduced_value; + } + + return compare_reference(problem_size, alpha, beta); + } + + /// Returns true if the CUDA device is sufficient to execute the kernel. + bool sufficient() const { + + // + // Determine SMEM requirements and waive if not satisfied + // + + size_t smem_size = sizeof(typename Gemm::GemmKernel::SharedStorage); + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerBlockOptin < smem_size) { + return false; + } + + return true; + } + + /// Executes one test + bool run( + cutlass::gemm::GemmUniversalMode mode, + cutlass::gemm::GemmCoord problem_size, + int batch_count = 1, + ElementAccumulator alpha = ElementAccumulator(1), + ElementAccumulator beta = ElementAccumulator(0)) { + + // Waive test if insufficient CUDA device + if (!sufficient()) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." << std::endl; + } + return true; + } + + this->initialize(problem_size); + + // + // Initialize the GEMM operator + // + + typename Gemm::Arguments arguments{ + mode, + problem_size, + batch_count, + {alpha, beta}, + tensor_A.device_data(), + tensor_B.device_data(), + tensor_C.device_data(), + tensor_D.device_data(), + tensor_Reduction.device_data(), + tensor_Tensor.device_data(), + problem_size.m() * problem_size.k(), + problem_size.n() * problem_size.k(), + problem_size.m() * problem_size.n(), + problem_size.m() * problem_size.n(), + problem_size.m(), + problem_size.m() * problem_size.n(), + tensor_A.layout().stride(0), + tensor_B.layout().stride(0), + tensor_C.layout().stride(0), + tensor_D.layout().stride(0), + tensor_Reduction.layout().stride(0), + tensor_Tensor.layout().stride(0), + }; + + Gemm gemm_op; + + size_t workspace_size = Gemm::get_workspace_size(arguments); + + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = gemm_op.initialize(arguments, workspace.get()); + + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Run the GEMM + // + + status = gemm_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Verify + // + + bool passed = this->verify(problem_size, alpha, beta); + + if (!passed) { + std::cout << "Failed with batch_count/split_k_slices = " << batch_count << std::endl; + } + + // + // Profile + // + + #if 0 // profiling disabled for now. 
+ + int const kWorkspaces = 100; + + cutlass::DeviceAllocation profiling_tensor_A(tensor_A.capacity() * kWorkspaces); + cutlass::DeviceAllocation profiling_tensor_B(tensor_B.capacity() * kWorkspaces); + cutlass::DeviceAllocation profiling_tensor_C(tensor_C.capacity() * kWorkspaces); + cutlass::DeviceAllocation profiling_tensor_D(tensor_D.capacity() * kWorkspaces); + cutlass::DeviceAllocation profiling_tensor_Reduction(tensor_Reduction.capacity() * kWorkspaces); + cutlass::DeviceAllocation profiling_tensor_Tensor(tensor_Tensor.capacity() * kWorkspaces); + + cudaEvent_t events[2]; + for (auto & event : events) { + cudaError_t result = cudaEventCreate(&event); + if (result != cudaSuccess) { + EXPECT_EQ(result, cudaSuccess) << " cudaEventCreate() failed with error " << cudaGetErrorString(result); + return false; + break; + } + } + + int const kWarmupIterations = 5; + int const kProfilingIterations = 100; + + for (int i = 0; i < kWarmupIterations; ++i) { + status = gemm_op(); + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + } + + + cudaError_t result = cudaEventRecord(events[0]); + EXPECT_EQ(result, cudaSuccess); + + for (int i = 0; i < kProfilingIterations; ++i) { + + typename Gemm::Arguments arguments{ + mode, + problem_size, + batch_count, + {alpha, beta}, + profiling_tensor_A.get() + tensor_A.capacity() * (i % kWorkspaces), + profiling_tensor_B.get() + tensor_B.capacity() * (i % kWorkspaces), + profiling_tensor_C.get() + tensor_C.capacity() * (i % kWorkspaces), + profiling_tensor_D.get() + tensor_D.capacity() * (i % kWorkspaces), + profiling_tensor_Reduction.get() + tensor_Reduction.capacity() * (i % kWorkspaces), + profiling_tensor_Tensor.get() + tensor_Tensor.capacity() * (i % kWorkspaces), + problem_size.m() * problem_size.k(), + problem_size.n() * problem_size.k(), + problem_size.m() * problem_size.n(), + problem_size.m() * problem_size.n(), + problem_size.m(), + problem_size.m() * problem_size.n(), + tensor_A.layout().stride(0), + tensor_B.layout().stride(0), + tensor_C.layout().stride(0), + tensor_D.layout().stride(0), + tensor_Reduction.layout().stride(0), + tensor_Tensor.layout().stride(0), + }; + + gemm_op.initialize(arguments, workspace.get()); + status = gemm_op(); + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + } + + result = cudaEventRecord(events[1]); + EXPECT_EQ(result, cudaSuccess); + + result = cudaDeviceSynchronize(); + EXPECT_EQ(result, cudaSuccess); + + float elapsed_time = 0; + result = cudaEventElapsedTime(&elapsed_time, events[0], events[1]); + EXPECT_EQ(result, cudaSuccess); + + double average_time = double(elapsed_time) / double(kProfilingIterations); + + std::cout << problem_size << ": " << average_time << " ms" << std::endl; + + for (auto & event : events) { + cudaEventDestroy(event); + } + #endif + + return passed; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +template +bool TestGemmWithReduction( + cutlass::gemm::GemmCoord const & problem_size, + cutlass::gemm::GemmUniversalMode mode, + int batch_count = 1, + double alpha = 1.0, + double beta = 2.0) { + + bool passed = true; + + TestbedGemmWithReduction testbed; + + using ElementAccumulator = typename Gemm::ElementAccumulator; + + passed = testbed.run( + mode, + problem_size, + batch_count, + cutlass::from_real(alpha), + cutlass::from_real(beta) + ); + + return passed; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device 
+} // namespace gemm +} // namespace test + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_grouped.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_grouped.h new file mode 100644 index 0000000000000000000000000000000000000000..c7317eb855477e63fe19858ca51cd5722f236eb5 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_grouped.h @@ -0,0 +1,501 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Tests for device-wide GEMM interface + +*/ + +#pragma once + +#include +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/kernel/gemm_grouped.h" +#include "cutlass/gemm/kernel/default_gemm_grouped.h" +#include "cutlass/gemm/device/gemm_grouped.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm_complex.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/tensor_view_io.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace test { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct TestbedGrouped { + + // + // Type definitions + // + + using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; + using ElementC = typename Gemm::ElementC; + using ElementAccumulator = typename Gemm::ElementAccumulator; + + using EpilogueOutputOp = typename Gemm::GemmKernel::Epilogue::OutputOp; + using ElementCompute = typename EpilogueOutputOp::ElementCompute; + + using LayoutA = typename Gemm::LayoutA; + using LayoutB = typename Gemm::LayoutB; + using LayoutC = typename Gemm::LayoutC; + + using MatrixCoord = typename LayoutC::TensorCoord; + + // + // Data members + // + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + uint32_t seed; + + int problem_count; + + std::vector problem_sizes_host; + cutlass::DeviceAllocation problem_sizes_device; + + std::vector offset_A; + std::vector offset_B; + std::vector offset_C; + std::vector offset_D; + + std::vector lda_host; + std::vector ldb_host; + std::vector ldc_host; + std::vector ldd_host; + + cutlass::DeviceAllocation lda; + cutlass::DeviceAllocation ldb; + cutlass::DeviceAllocation ldc; + cutlass::DeviceAllocation ldd; + + cutlass::DeviceAllocation block_A; + cutlass::DeviceAllocation block_B; + cutlass::DeviceAllocation block_C; + cutlass::DeviceAllocation block_D; + + cutlass::DeviceAllocation ptr_A; + cutlass::DeviceAllocation ptr_B; + cutlass::DeviceAllocation ptr_C; + cutlass::DeviceAllocation ptr_D; + + // + // Methods + // + + TestbedGrouped( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint32_t seed_ = 3080 + ): + init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } + + /// Helper to initialize a tensor view + template + bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint32_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } else if (bits_output == 16) { + if (cutlass::sizeof_bits::value <= 16) { + scope_max = 5; + scope_min = -5; + } + else { + scope_max = 8; + scope_min = -8; + } + } else { + scope_max = 8; + scope_min = -8; + } + + 
cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential( + view.data(), view.capacity()); + } + else { + // no fill - remain zero + } + + return true; + } + + /// Initializes data structures + void initialize() { + + // + // Choose random problem sizes + // + + // construct a few problems of random sizes + srand(seed); + + int64_t total_elements_A = 0; + int64_t total_elements_B = 0; + int64_t total_elements_C = 0; + int64_t total_elements_D = 0; + + + lda_host.resize(problem_count); + ldb_host.resize(problem_count); + ldc_host.resize(problem_count); + ldd_host.resize(problem_count); + + problem_sizes_host.clear(); + problem_sizes_host.resize(problem_count); + + for (int32_t i = 0; i < problem_count; ++i) { + + cutlass::gemm::GemmCoord problem( + 8 * (rand() % 64) + 24, + 8 * (rand() % 64) + 24, + 8 * (rand() % 64) + 24); + + if (!i) { + problem = cutlass::gemm::GemmCoord(48, 16, 8); + } + + problem_sizes_host.at(i) = problem; + + // std::cout << "Problem[" << i << "]: " << problem << std::endl; + + lda_host.at(i) = LayoutA::packed({problem.m(), problem.k()}).stride(0); + ldb_host.at(i) = LayoutB::packed({problem.k(), problem.n()}).stride(0); + ldc_host.at(i) = LayoutC::packed({problem.m(), problem.n()}).stride(0); + ldd_host.at(i) = LayoutC::packed({problem.m(), problem.n()}).stride(0); + + offset_A.push_back(total_elements_A); + offset_B.push_back(total_elements_B); + offset_C.push_back(total_elements_C); + offset_D.push_back(total_elements_D); + + int64_t elements_A = problem.m() * problem.k(); + int64_t elements_B = problem.k() * problem.n(); + int64_t elements_C = problem.m() * problem.n(); + int64_t elements_D = problem.m() * problem.n(); + + total_elements_A += elements_A; + total_elements_B += elements_B; + total_elements_C += elements_C; + total_elements_D += elements_D; + + // Random strides between problems? 
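+      // For example, with a column-major A and the fixed first problem
+      // (m=48, n=16, k=8): LayoutA::packed({48, 8}).stride(0) gives lda = 48,
+      // elements_A = 48 * 8 = 384, and the next group's A matrix therefore starts
+      // 384 elements further into block_A.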
+ } + + problem_sizes_device.reset(problem_count); + problem_sizes_device.copy_from_host(problem_sizes_host.data()); + + lda.reset(problem_count); + ldb.reset(problem_count); + ldc.reset(problem_count); + ldd.reset(problem_count); + + lda.copy_from_host(lda_host.data()); + ldb.copy_from_host(ldb_host.data()); + ldc.copy_from_host(ldc_host.data()); + ldd.copy_from_host(ldd_host.data()); + + // + // Assign pointers + // + + block_A.reset(total_elements_A); + block_B.reset(total_elements_B); + block_C.reset(total_elements_C); + block_D.reset(total_elements_D); + + std::vector ptr_A_host(problem_count); + std::vector ptr_B_host(problem_count); + std::vector ptr_C_host(problem_count); + std::vector ptr_D_host(problem_count); + + for (int32_t i = 0; i < problem_count; ++i) { + ptr_A_host.at(i) = block_A.get() + offset_A.at(i); + ptr_B_host.at(i) = block_B.get() + offset_B.at(i); + ptr_C_host.at(i) = block_C.get() + offset_C.at(i); + ptr_D_host.at(i) = block_D.get() + offset_D.at(i); + } + + ptr_A.reset(problem_count); + ptr_A.copy_from_host(ptr_A_host.data()); + + ptr_B.reset(problem_count); + ptr_B.copy_from_host(ptr_B_host.data()); + + ptr_C.reset(problem_count); + ptr_C.copy_from_host(ptr_C_host.data()); + + ptr_D.reset(problem_count); + ptr_D.copy_from_host(ptr_D_host.data()); + + // + // Initialize the problems of the workspace + // + + for (int32_t i = 0; i < problem_count; ++i) { + cutlass::gemm::GemmCoord problem = problem_sizes_host.at(i); + + LayoutA layout_A(lda_host.at(i)); + LayoutB layout_B(ldb_host.at(i)); + LayoutC layout_C(ldc_host.at(i)); + LayoutC layout_D(ldd_host.at(i)); + + MatrixCoord extent_A{problem.m(), problem.k()}; + MatrixCoord extent_B{problem.k(), problem.n()}; + MatrixCoord extent_C{problem.m(), problem.n()}; + + std::vector matrix_A(layout_A.capacity(extent_A)); + std::vector matrix_B(layout_B.capacity(extent_B)); + std::vector matrix_C(layout_C.capacity(extent_C)); + std::vector matrix_D(layout_D.capacity(extent_C)); + + initialize_tensor(cutlass::TensorView(matrix_A.data(), layout_A, extent_A), init_A, seed * 2021); + initialize_tensor(cutlass::TensorView(matrix_B.data(), layout_B, extent_B), init_B, seed * 2022); + initialize_tensor(cutlass::TensorView(matrix_C.data(), layout_C, extent_C), init_C, seed * 2023); + + cutlass::device_memory::copy_to_device(ptr_A_host.at(i), matrix_A.data(), matrix_A.size()); + cutlass::device_memory::copy_to_device(ptr_B_host.at(i), matrix_B.data(), matrix_B.size()); + cutlass::device_memory::copy_to_device(ptr_C_host.at(i), matrix_C.data(), matrix_C.size()); + cutlass::device_memory::copy_to_device(ptr_D_host.at(i), matrix_D.data(), matrix_D.size()); + } + } + + /// Verifies the result is a GEMM + bool verify( + ElementCompute alpha, + ElementCompute beta) { + + bool passed = true; + + for (int32_t i = 0; i < problem_count; ++i) { + cutlass::gemm::GemmCoord problem = problem_sizes_host.at(i); + + LayoutA layout_A(lda_host.at(i)); + LayoutB layout_B(ldb_host.at(i)); + LayoutC layout_C(ldc_host.at(i)); + LayoutC layout_D(ldd_host.at(i)); + + MatrixCoord extent_A{problem.m(), problem.k()}; + MatrixCoord extent_B{problem.k(), problem.n()}; + MatrixCoord extent_C{problem.m(), problem.n()}; + + std::vector matrix_A(layout_A.capacity(extent_A)); + std::vector matrix_B(layout_B.capacity(extent_B)); + std::vector matrix_C(layout_C.capacity(extent_C)); + std::vector matrix_D(layout_D.capacity(extent_C)); + std::vector matrix_Ref(layout_D.capacity(extent_C)); + + cutlass::device_memory::copy_to_host(matrix_A.data(), block_A.get() + 
offset_A.at(i), matrix_A.size()); + cutlass::device_memory::copy_to_host(matrix_B.data(), block_B.get() + offset_B.at(i), matrix_B.size()); + cutlass::device_memory::copy_to_host(matrix_C.data(), block_C.get() + offset_C.at(i), matrix_C.size()); + cutlass::device_memory::copy_to_host(matrix_D.data(), block_D.get() + offset_D.at(i), matrix_D.size()); + + cutlass::TensorView view_A(matrix_A.data(), layout_A, extent_A); + cutlass::TensorView view_B(matrix_B.data(), layout_B, extent_B); + cutlass::TensorView view_C(matrix_C.data(), layout_C, extent_C); + cutlass::TensorView view_D(matrix_D.data(), layout_D, extent_C); + cutlass::TensorView view_Ref(matrix_Ref.data(), layout_D, extent_C); + + // Reference GEMM + cutlass::reference::host::GemmComplex< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementCompute, ElementAccumulator + >( + problem, + alpha, + view_A, + Gemm::kTransformA, + view_B, + Gemm::kTransformB, + beta, + view_C, + view_Ref, + ElementAccumulator(0) + ); + + // Ensure that no input or output is entirely zero + EXPECT_GT(cutlass::reference::host::TensorNorm(view_A), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(view_B), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(view_C), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(view_D), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(view_Ref), 0); + + // Compare against reference + passed = cutlass::reference::host::TensorEquals(view_D, view_Ref); + + if (!passed) { + std::ofstream file("testbed_grouped_errors.txt"); + + file + << "problem: " << problem << " [group: " << i << "]\n" + << ", alpha: " << alpha << ", beta: " << beta << "\n\n"; + + file + << "A =\n" << view_A + << "\nB =\n" << view_B + << "\nC =\n" << view_C + << "\n\nReference =\n" << view_Ref + << "\nComputed =\n" << view_D; + + return passed; + } + } + + return passed; + } + + /// Executes one test + bool run( + int problem_count, + ElementCompute alpha = ElementCompute(1), + ElementCompute beta = ElementCompute(0)) { + + this->problem_count = problem_count; + + // Initialize the problem + initialize(); + + int threadblock_count = Gemm::sufficient(problem_sizes_host.data(), problem_count); + + // Early exit + if (!threadblock_count) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device resources." 
<< std::endl; + } + return true; + } + + // Configure the GEMM arguments + typename EpilogueOutputOp::Params epilogue_op(alpha, beta); + + // Configure GEMM arguments + typename Gemm::Arguments args( + problem_sizes_device.get(), + problem_count, + threadblock_count, + epilogue_op, + ptr_A.get(), + ptr_B.get(), + ptr_C.get(), + ptr_D.get(), + lda.get(), + ldb.get(), + ldc.get(), + ldd.get(), + problem_sizes_host.data() + ); + + // Initialize the GEMM object + Gemm gemm; + + size_t workspace_size = gemm.get_workspace_size(args); + cutlass::DeviceAllocation workspace(workspace_size); + + cutlass::Status status = gemm.initialize(args, workspace.get()); + + if (status != cutlass::Status::kSuccess) { + return false; + } + + // Run the GEMM object + status = gemm.run(); + + if (status != cutlass::Status::kSuccess) { + return false; + } + + // Wait for completion + cudaError_t result = cudaDeviceSynchronize(); + + EXPECT_EQ(result, cudaSuccess) + << "Kernel execution error: " << cudaGetErrorString(result); + + if (result != cudaSuccess) { + return false; + } + + // Verify correctness + return verify(alpha, beta); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // device +} // gemm +} // test + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_grouped_rank_2k.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_grouped_rank_2k.h new file mode 100644 index 0000000000000000000000000000000000000000..f8f08f23c4477745648f1cf8f9e439ae6b5061e2 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_grouped_rank_2k.h @@ -0,0 +1,502 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for grouped Rank2K interface + +*/ + +#pragma once + +#include +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" +#include "cutlass/device_kernel.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/kernel/rank_2k_grouped.h" +#include "cutlass/gemm/kernel/default_rank_2k_grouped.h" +#include "cutlass/gemm/device/rank_2k_grouped.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/rank_2k_complex.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/tensor_view_io.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace test { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct TestbedGrouped { + + // + // Type definitions + // + + using ElementA = typename Rank2K::ElementA; + using ElementB = typename Rank2K::ElementB; + using ElementC = typename Rank2K::ElementC; + using ElementAccumulator = typename Rank2K::ElementAccumulator; + + using EpilogueOutputOp = typename Rank2K::EpilogueOutputOp; + using ElementCompute = typename EpilogueOutputOp::ElementCompute; + + using LayoutA = typename Rank2K::LayoutA; + using LayoutB = typename Rank2K::LayoutB; + using LayoutC = typename Rank2K::LayoutC; + + using MatrixCoord = typename LayoutC::TensorCoord; + + // + // Data members + // + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + uint32_t seed; + + int problem_count; + + std::vector problem_sizes_host; + cutlass::DeviceAllocation problem_sizes_device; + + std::vector offset_A; + std::vector offset_B; + std::vector offset_C; + std::vector offset_D; + + std::vector lda_host; + std::vector ldb_host; + std::vector ldc_host; + std::vector ldd_host; + + cutlass::DeviceAllocation lda; + cutlass::DeviceAllocation ldb; + cutlass::DeviceAllocation ldc; + cutlass::DeviceAllocation ldd; + + cutlass::DeviceAllocation block_A; + cutlass::DeviceAllocation block_B; + cutlass::DeviceAllocation block_C; + cutlass::DeviceAllocation block_D; + + cutlass::DeviceAllocation ptr_A; + cutlass::DeviceAllocation ptr_B; + cutlass::DeviceAllocation ptr_C; + cutlass::DeviceAllocation ptr_D; + + // + // Methods + // + + TestbedGrouped( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint32_t seed_ = 3080 + ): + init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } + + /// 
Helper to initialize a tensor view + template + bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint32_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } else if (bits_output == 16) { + if (cutlass::sizeof_bits::value <= 16) { + scope_max = 5; + scope_min = -5; + } + else { + scope_max = 8; + scope_min = -8; + } + } else { + scope_max = 8; + scope_min = -8; + } + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential( + view.data(), view.capacity()); + } + else { + // no fill - remain zero + } + + return true; + } + + /// Initializes data structures + void initialize() { + + // + // Choose random problem sizes + // + + // construct a few problems of random sizes + srand(seed); + + int64_t total_elements_A = 0; + int64_t total_elements_B = 0; + int64_t total_elements_C = 0; + int64_t total_elements_D = 0; + + + lda_host.resize(problem_count); + ldb_host.resize(problem_count); + ldc_host.resize(problem_count); + ldd_host.resize(problem_count); + + problem_sizes_host.clear(); + problem_sizes_host.resize(problem_count); + + for (int32_t i = 0; i < problem_count; ++i) { + + auto N = 8 * (rand() % 64) + 24; + auto K = 8 * (rand() % 64) + 24; + cutlass::gemm::GemmCoord problem(N, N, K); + + if (!i) { + problem = cutlass::gemm::GemmCoord(16, 16, 8); + } + + problem_sizes_host.at(i) = problem; + + lda_host.at(i) = LayoutA::packed({problem.n(), problem.k()}).stride(0); + ldb_host.at(i) = LayoutB::packed({problem.n(), problem.k()}).stride(0); + ldc_host.at(i) = LayoutC::packed({problem.n(), problem.n()}).stride(0); + ldd_host.at(i) = LayoutC::packed({problem.n(), problem.n()}).stride(0); + + offset_A.push_back(total_elements_A); + offset_B.push_back(total_elements_B); + offset_C.push_back(total_elements_C); + offset_D.push_back(total_elements_D); + + int64_t elements_A = problem.n() * problem.k(); + int64_t elements_B = problem.n() * problem.k(); + int64_t elements_C = problem.n() * problem.n(); + int64_t elements_D = problem.n() * problem.n(); + + total_elements_A += elements_A; + total_elements_B += elements_B; + total_elements_C += elements_C; + total_elements_D += elements_D; + + // Random strides between problems? 
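+      // A rank-2k update produces an n-by-n C from n-by-k A and B
+      // (C = alpha * (A * B^T + B * A^T) + beta * C in the real-valued case),
+      // which is why the GemmCoord above is constructed with m == n.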
+ } + + problem_sizes_device.reset(problem_count); + problem_sizes_device.copy_from_host(problem_sizes_host.data()); + + lda.reset(problem_count); + ldb.reset(problem_count); + ldc.reset(problem_count); + ldd.reset(problem_count); + + lda.copy_from_host(lda_host.data()); + ldb.copy_from_host(ldb_host.data()); + ldc.copy_from_host(ldc_host.data()); + ldd.copy_from_host(ldd_host.data()); + + // + // Assign pointers + // + + block_A.reset(total_elements_A); + block_B.reset(total_elements_B); + block_C.reset(total_elements_C); + block_D.reset(total_elements_D); + + std::vector ptr_A_host(problem_count); + std::vector ptr_B_host(problem_count); + std::vector ptr_C_host(problem_count); + std::vector ptr_D_host(problem_count); + + for (int32_t i = 0; i < problem_count; ++i) { + ptr_A_host.at(i) = block_A.get() + offset_A.at(i); + ptr_B_host.at(i) = block_B.get() + offset_B.at(i); + ptr_C_host.at(i) = block_C.get() + offset_C.at(i); + ptr_D_host.at(i) = block_D.get() + offset_D.at(i); + } + + ptr_A.reset(problem_count); + ptr_A.copy_from_host(ptr_A_host.data()); + + ptr_B.reset(problem_count); + ptr_B.copy_from_host(ptr_B_host.data()); + + ptr_C.reset(problem_count); + ptr_C.copy_from_host(ptr_C_host.data()); + + ptr_D.reset(problem_count); + ptr_D.copy_from_host(ptr_D_host.data()); + + // + // Initialize the problems of the workspace + // + + for (int32_t i = 0; i < problem_count; ++i) { + cutlass::gemm::GemmCoord problem = problem_sizes_host.at(i); + + LayoutA layout_A(lda_host.at(i)); + LayoutB layout_B(ldb_host.at(i)); + LayoutC layout_C(ldc_host.at(i)); + LayoutC layout_D(ldd_host.at(i)); + + MatrixCoord extent_A{problem.n(), problem.k()}; + MatrixCoord extent_B{problem.n(), problem.k()}; + MatrixCoord extent_C{problem.n(), problem.n()}; + + std::vector matrix_A(layout_A.capacity(extent_A)); + std::vector matrix_B(layout_B.capacity(extent_B)); + std::vector matrix_C(layout_C.capacity(extent_C)); + std::vector matrix_D(layout_D.capacity(extent_C)); + + initialize_tensor(cutlass::TensorView(matrix_A.data(), layout_A, extent_A), init_A, seed * 2021); + initialize_tensor(cutlass::TensorView(matrix_B.data(), layout_B, extent_B), init_B, seed * 2022); + initialize_tensor(cutlass::TensorView(matrix_C.data(), layout_C, extent_C), init_C, seed * 2023); + + cutlass::device_memory::copy_to_device(ptr_A_host.at(i), matrix_A.data(), matrix_A.size()); + cutlass::device_memory::copy_to_device(ptr_B_host.at(i), matrix_B.data(), matrix_B.size()); + cutlass::device_memory::copy_to_device(ptr_C_host.at(i), matrix_C.data(), matrix_C.size()); + cutlass::device_memory::copy_to_device(ptr_D_host.at(i), matrix_D.data(), matrix_D.size()); + } + } + + /// Verifies the result is a Rank2K + bool verify( + ElementCompute alpha, + ElementCompute beta) { + + bool passed = true; + + for (int32_t i = 0; i < problem_count; ++i) { + cutlass::gemm::GemmCoord problem = problem_sizes_host.at(i); + + LayoutA layout_A(lda_host.at(i)); + LayoutB layout_B(ldb_host.at(i)); + LayoutC layout_C(ldc_host.at(i)); + LayoutC layout_D(ldd_host.at(i)); + + MatrixCoord extent_A{problem.n(), problem.k()}; + MatrixCoord extent_B{problem.n(), problem.k()}; + MatrixCoord extent_C{problem.n(), problem.n()}; + + std::vector matrix_A(layout_A.capacity(extent_A)); + std::vector matrix_B(layout_B.capacity(extent_B)); + std::vector matrix_C(layout_C.capacity(extent_C)); + std::vector matrix_D(layout_D.capacity(extent_C)); + std::vector matrix_Ref(layout_D.capacity(extent_C)); + + cutlass::device_memory::copy_to_host(matrix_A.data(), block_A.get() + 
offset_A.at(i), matrix_A.size()); + cutlass::device_memory::copy_to_host(matrix_B.data(), block_B.get() + offset_B.at(i), matrix_B.size()); + cutlass::device_memory::copy_to_host(matrix_C.data(), block_C.get() + offset_C.at(i), matrix_C.size()); + cutlass::device_memory::copy_to_host(matrix_D.data(), block_D.get() + offset_D.at(i), matrix_D.size()); + + cutlass::TensorView view_A(matrix_A.data(), layout_A, extent_A); + cutlass::TensorView view_B(matrix_B.data(), layout_B, extent_B); + cutlass::TensorView view_C(matrix_C.data(), layout_C, extent_C); + cutlass::TensorView view_D(matrix_D.data(), layout_D, extent_C); + cutlass::TensorView view_Ref(matrix_Ref.data(), layout_D, extent_C); + + // Reference Rank2K + cutlass::reference::host::Rank2KComplex< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementCompute, ElementAccumulator + >( + problem, + alpha, + view_A, + Rank2K::kTransformA, + view_B, + Rank2K::kTransformB, + beta, + view_C, + view_Ref, + ElementAccumulator(0), + Rank2K::kFillModeC, + Rank2K::kBlasMode + ); + + // Ensure that no input or output is entirely zero + EXPECT_GT(cutlass::reference::host::TensorNorm(view_A), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(view_B), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(view_C), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(view_D), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(view_Ref), 0); + + // Compare against reference + passed = cutlass::reference::host::TensorEquals(view_D, view_Ref); + + if (!passed) { + std::ofstream file("testbed_grouped_errors.txt"); + + file + << "problem: " << problem << " [group: " << i << "]\n" + << ", alpha: " << alpha << ", beta: " << beta << "\n\n"; + + file + << "A =\n" << view_A + << "\nB =\n" << view_B + << "\nC =\n" << view_C + << "\n\nReference =\n" << view_Ref + << "\nComputed =\n" << view_D; + + return passed; + } + } + + return passed; + } + + /// Executes one test + bool run( + int problem_count, + ElementCompute alpha = ElementCompute(1), + ElementCompute beta = ElementCompute(0)) { + + this->problem_count = problem_count; + + // Initialize the problem + initialize(); + + int threadblock_count = Rank2K::sufficient(problem_sizes_host.data(), problem_count); + + // Early exit + if (!threadblock_count) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device resources." 
<< std::endl; + } + return true; + } + + // Configure the Rank2K arguments + typename EpilogueOutputOp::Params epilogue_op(alpha, beta); + + // Configure Rank2K arguments + typename Rank2K::Arguments args( + cutlass::gemm::GemmUniversalMode::kGemm, + problem_sizes_device.get(), + problem_count, + threadblock_count, + epilogue_op, + ptr_A.get(), + ptr_B.get(), + ptr_C.get(), + ptr_D.get(), + lda.get(), + ldb.get(), + ldc.get(), + ldd.get(), + problem_sizes_host.data() + ); + + // Initialize the Rank2K object + Rank2K rank2k; + + size_t workspace_size = rank2k.get_workspace_size(args); + cutlass::DeviceAllocation workspace(workspace_size); + + cutlass::Status status = rank2k.initialize(args, workspace.get()); + + if (status != cutlass::Status::kSuccess) { + return false; + } + + // Run the Rank2K object + status = rank2k.run(); + + if (status != cutlass::Status::kSuccess) { + return false; + } + + // Wait for completion + cudaError_t result = cudaDeviceSynchronize(); + + EXPECT_EQ(result, cudaSuccess) + << "Kernel execution error: " << cudaGetErrorString(result); + + if (result != cudaSuccess) { + return false; + } + + // Verify correctness + return verify(alpha, beta); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // device +} // gemm +} // test + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_grouped_rank_2k_scheduler.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_grouped_rank_2k_scheduler.h new file mode 100644 index 0000000000000000000000000000000000000000..e9315e12e8711f50256e4cfe05666201acd614d3 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_grouped_rank_2k_scheduler.h @@ -0,0 +1,461 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for grouped Rank2K problem visitors +*/ + +#pragma once + +#include +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/kernel/rank_2k_grouped_problem_visitor.h" +#include "cutlass/util/device_memory.h" +#include "cutlass/device_kernel.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace test { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// +// Use simple problem visitor as a baseline +template +struct BaselineProblemVisitor : public cutlass::gemm::kernel::BaseGroupedProblemVisitor { + using Base = cutlass::gemm::kernel::BaseGroupedProblemVisitor; + using Params = typename Base::Params; + static int const kThreadCount = ThreadCount; + static cutlass::FillMode const kFillModeC = FillModeC; + + struct SharedStorage {}; + + int32_t tile_count_sum; + SharedStorage &shared_storage; + + // + // Methods + // + CUTLASS_DEVICE + BaselineProblemVisitor( + Params const ¶ms_, + SharedStorage &shared_storage_, + int32_t block_idx + ): Base(params_, block_idx), + shared_storage(shared_storage_) + { + cutlass::gemm::GemmCoord problem = this->problem_size(); + cutlass::gemm::GemmCoord grid = this->grid_shape(problem); + tile_count_sum = this->tile_count(grid); + } + + CUTLASS_DEVICE + bool next_tile() { + if (this->tile_idx < tile_count_sum) { + return true; + } + + do { + ++this->problem_idx; + + if (this->problem_idx >= this->params.problem_count) { + return false; + } + + cutlass::gemm::GemmCoord problem = this->problem_size(); + cutlass::gemm::GemmCoord grid = this->grid_shape(problem); + + this->problem_tile_start = tile_count_sum; + tile_count_sum += this->tile_count(grid); + + } while (tile_count_sum <= this->tile_idx); + + return true; + } + + static size_t get_workspace_size(const cutlass::gemm::GemmCoord* host_problem_sizes_ptr, + int32_t problem_count, + int32_t block_count) { + return 0; + } + + static void host_precompute(const cutlass::gemm::GemmCoord* host_problem_sizes_ptr, + int32_t problem_count, + int32_t block_count, + void* host_workspace_ptr) {} + + CUTLASS_DEVICE + cutlass::gemm::GemmCoord threadblock_offset(int32_t threadblock_id) const { + int32_t macro_id = threadblock_id / ProblemSizeHelper::OffsetHelper::kThreadblockSkewRatio; + int32_t macro_row = ceil(cutlass::fast_sqrt((2*macro_id) + 2.25) - 0.5) - 1; + int32_t macro_col = macro_id - (((macro_row+1) * macro_row)/2); + + if (FillModeC == cutlass::FillMode::kUpper) { + cutlass::swap(macro_row, macro_col); + } + + int32_t row = ProblemSizeHelper::OffsetHelper::macro_row_to_row(macro_row, threadblock_id); + int32_t col = ProblemSizeHelper::OffsetHelper::macro_col_to_col(macro_col, threadblock_id); + + return 
cutlass::gemm::GemmCoord(row, col, 0); + } +}; + +template +struct ProblemVisitorKernel { + struct SharedStorage { + typename ProblemVisitor::SharedStorage problem_visitor; + }; + + struct Params { + typename ProblemVisitor::Params problem_visitor_params; + int32_t* visited_problems_ptr; + int32_t* visited_tiles_ptr; + int32_t visits_per_block; + + Params(): + visited_problems_ptr(nullptr), + visited_tiles_ptr(nullptr), + visits_per_block(0) {} + + Params(typename ProblemVisitor::Params problem_visitor_params_, + int32_t* visited_problems_ptr_, + int32_t* visited_tiles_ptr_, + int32_t visits_per_block_): + problem_visitor_params(problem_visitor_params_), + visited_problems_ptr(visited_problems_ptr_), + visited_tiles_ptr(visited_tiles_ptr_), + visits_per_block(visits_per_block_) {} + }; + + CUTLASS_DEVICE + void operator()(const Params& params, SharedStorage &shared_storage) { + int32_t store_offset = params.visits_per_block * blockIdx.x; + ProblemVisitor problem_visitor(params.problem_visitor_params, + shared_storage.problem_visitor, + blockIdx.x); + + while (problem_visitor.next_tile()) { + cutlass::gemm::GemmCoord problem_size = problem_visitor.problem_size(); + int32_t problem_idx = problem_visitor.problem_index(); + int32_t threadblock_idx = int32_t(problem_visitor.threadblock_idx()); + + cutlass::gemm::GemmCoord grid_shape = problem_visitor.grid_shape(problem_size); + cutlass::gemm::GemmCoord tile_offset = problem_visitor.threadblock_offset(threadblock_idx); + + problem_visitor.advance(gridDim.x); + + // + // Early exit conditions + // 1) Out of range + // 2) Upper-triangular block in lower-triangular problem + // 3) Lower-triangular block in upper-triangular problem + // + + if (grid_shape.m() <= tile_offset.m() || + grid_shape.n() <= tile_offset.n()) { + continue; + } + + if (ProblemVisitor::kFillModeC == cutlass::FillMode::kLower && + (tile_offset.m() + 1) * ProblemVisitor::ThreadblockShape::kM <= tile_offset.n() * ProblemVisitor::ThreadblockShape::kN) { + continue; + } + + if (ProblemVisitor::kFillModeC == cutlass::FillMode::kUpper && + tile_offset.m() * ProblemVisitor::ThreadblockShape::kM >= (tile_offset.n() + 1) * ProblemVisitor::ThreadblockShape::kN) { + continue; + } + + if (threadIdx.x == 0) { + params.visited_problems_ptr[store_offset] = problem_idx; + params.visited_tiles_ptr[store_offset] = threadblock_idx; + ++store_offset; + } + } + } +}; + +template +struct ProblemVisitorRunner { + using BaseKernel = ProblemVisitorKernel; + using Params = typename BaseKernel::Params; + + Params params; + std::vector host_problem_sizes; + int32_t problem_count; + int32_t threadblock_count; + int32_t visits_per_block; + cutlass::DeviceAllocation visited_problems; + cutlass::DeviceAllocation visited_tiles; + cutlass::DeviceAllocation device_problem_sizes; + cutlass::DeviceAllocation workspace; + std::vector host_visited_problems; + std::vector host_visited_tiles; + + ProblemVisitorRunner(const std::vector& host_problem_sizes_, + int32_t threadblock_count_): + host_problem_sizes(host_problem_sizes_), + problem_count(int32_t(host_problem_sizes_.size())), + threadblock_count(threadblock_count_) {} + + /// Initializes GEMM state from arguments. 
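+  // The steps below mirror how a grouped kernel prepares its problem visitor:
+  // precompute any host-side schedule into a workspace, copy the problem sizes and
+  // workspace to the device, and size per-block buffers so each threadblock can
+  // record the (problem, tile) pairs it visits.
+  //
+  // Hypothetical usage sketch (MyVisitor is a placeholder for a concrete
+  // Rank2KGroupedProblemVisitor instantiation):
+  //
+  //   std::vector<cutlass::gemm::GemmCoord> sizes = {{72, 72, 24}, {136, 136, 64}};
+  //   ProblemVisitorRunner<MyVisitor> runner(sizes, /*threadblock_count=*/4);
+  //   EXPECT_TRUE(runner.run());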
+ cutlass::Status initialize() { + size_t workspace_bytes = ProblemVisitor::get_workspace_size( + host_problem_sizes.data(), + problem_count, + threadblock_count); + + workspace.reset(workspace_bytes); + std::vector host_workspace(workspace_bytes); + + int32_t tile_count = ProblemVisitor::group_tile_count(host_problem_sizes.data(), problem_count); + + ProblemVisitor::host_precompute(host_problem_sizes.data(), problem_count, + threadblock_count, host_workspace.data()); + + workspace.copy_from_host(host_workspace.data(), workspace_bytes); + + device_problem_sizes.reset(problem_count); + device_problem_sizes.copy_from_host(host_problem_sizes.data(), problem_count); + + visits_per_block = (tile_count - 1 + threadblock_count) / threadblock_count; + int32_t total_visits = visits_per_block * threadblock_count; + + visited_problems.reset(total_visits); + visited_tiles.reset(total_visits); + host_visited_problems.resize(total_visits); + host_visited_tiles.resize(total_visits); + + cudaError_t result = cudaMemset(visited_problems.get(), -1, sizeof(int32_t) * total_visits); + if (result != cudaSuccess) { + return cutlass::Status::kErrorInternal; + } + + result = cudaMemset(visited_tiles.get(), -1, sizeof(int32_t) * total_visits); + if (result != cudaSuccess) { + return cutlass::Status::kErrorInternal; + } + + typename ProblemVisitor::Params pv_params(device_problem_sizes.get(), problem_count, workspace.get(), tile_count); + params = Params(pv_params, visited_problems.get(), visited_tiles.get(), visits_per_block); + + return cutlass::Status::kSuccess; + } + + bool verify() { + // Sort by problem size and then by threadblock_idx + std::vector indices(host_visited_problems.size()); + std::iota(indices.begin(), indices.end(), 0); + + std::stable_sort(indices.begin(), indices.end(), + [&](int32_t i1, int32_t i2) { + if (host_visited_problems[i1] == host_visited_problems[i2]) { + return host_visited_tiles[i1] < host_visited_tiles[i2]; + } + return host_visited_problems[i1] < host_visited_problems[i2]; + }); + + int32_t idx = 0; + + // Skip any entries that were not visited + while (host_visited_problems[indices[idx]] == -1) { + ++idx; + } + + // Check that each problem visited has the tiles we expect + for (int32_t problem_idx = 0; problem_idx < problem_count; ++problem_idx) { + auto problem = host_problem_sizes[problem_idx]; + ProblemVisitor::possibly_transpose_problem(problem); + int32_t problem_tiles = ProblemVisitor::tile_count(ProblemVisitor::grid_shape(problem)); + for (int i = 0; i < problem_tiles; ++i) { + EXPECT_EQ(problem_idx, host_visited_problems[indices[idx]]); + EXPECT_EQ(i, host_visited_tiles[indices[idx]]); + ++idx; + } + } + + return true; + } + + bool run(bool skip_tile_check=false, cudaStream_t stream = nullptr) { + cutlass::Status status = initialize(); + if (status != cutlass::Status::kSuccess) { + std::cerr << "Initialization failed" << std::endl; + return false; + } + + dim3 grid(threadblock_count, 1, 1); + dim3 block(ProblemVisitor::kThreadCount, 1, 1); + int smem_size = int(sizeof(typename BaseKernel::SharedStorage)); + + cutlass::Kernel<<>>(params); + + cudaError_t result = cudaGetLastError(); + if (result != cudaSuccess) { + std::cerr << "grid launch failed with error " << cudaGetErrorString(result) << std::endl; + return false; + } + + result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + std::cerr << "cudaDeviceSynchronize failed with error " << cudaGetErrorString(result) << std::endl; + return false; + } + + 
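+    // Copy the per-block visit logs back to the host and, unless tile checking is
+    // skipped, compare them against the schedule expected for these problem sizes.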
visited_problems.copy_to_host(host_visited_problems.data()); + visited_tiles.copy_to_host(host_visited_tiles.data()); + + if (skip_tile_check) { + return true; + } + + return verify(); + } +}; + +template +struct TestbedGroupedRank2KScheduler { + + using BaselinePV = BaselineProblemVisitor, + ThreadblockShape, + PrefetchTileCount, + ThreadCount, + FillModeC>; + + // + // Data members + // + + // Whether to skip checking that the tiles are visited as expected. This is useful + // in cases where ThreadblockShape::kM != ThreadblockShape::kN, for which the grouped + // Rank2K scheduler may assign out-of-bounds tiles that will cause a threadblock to + // exit early, but which are difficult to detect in tests without reimplementing + // this functionality. + bool skip_tile_check; + uint32_t seed; + int problem_count; + int threadblock_count; + std::vector problem_sizes_host; + + // + // Methods + // + + TestbedGroupedRank2KScheduler(bool skip_tile_check_=false, uint32_t seed_ = 3080): + skip_tile_check(skip_tile_check_), seed(seed_) { srand(seed); } + + /// Initializes data structures + void initialize(int32_t scale_factor) { + + // + // Choose random problem sizes + // + + problem_sizes_host.clear(); + problem_sizes_host.resize(problem_count); + + for (int32_t i = 0; i < problem_count; ++i) { + int n = scale_factor * (rand() % 64) + 24; + + cutlass::gemm::GemmCoord problem( + n, + n, + scale_factor * (rand() % 64) + 24); + + problem_sizes_host.at(i) = problem; + } + } + + template + void compare_visitors(const ProblemVisitorRunner& baseline_runner) { + using PV = cutlass::gemm::kernel::Rank2KGroupedProblemVisitor< + ThreadblockShape, + GroupScheduleMode_, + PrefetchTileCount, + ThreadCount, + FillModeC>; + ProblemVisitorRunner runner(problem_sizes_host, threadblock_count); + EXPECT_TRUE(runner.run(skip_tile_check)); + + // Check that this problem visitor visits the same problems and tiles as the baseline + EXPECT_EQ(baseline_runner.host_visited_problems, runner.host_visited_problems); + EXPECT_EQ(baseline_runner.host_visited_tiles, runner.host_visited_tiles); + } + + template + void compare_visitors(const ProblemVisitorRunner& baseline_runner) { + // Compare the next visitor with the baseline visitor + compare_visitors(baseline_runner); + + // Recurse to compare the next visitors + compare_visitors(baseline_runner); + } + + /// Executes the test on all scheduler modes + void run(int problem_count, int threadblock_count, int scale_factor=8) { + + this->problem_count = problem_count; + this->threadblock_count = threadblock_count; + + // Initialize the problem + initialize(scale_factor); + + // Run the baseline visitor to which we will compare all other visitors + ProblemVisitorRunner baseline_runner(problem_sizes_host, threadblock_count); + EXPECT_TRUE(baseline_runner.run(skip_tile_check)); + + compare_visitors(baseline_runner); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // device +} // gemm +} // test + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_grouped_scheduler.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_grouped_scheduler.h new file mode 100644 index 0000000000000000000000000000000000000000..bda2704b517ea95052e2c2060b50712b686344f6 --- /dev/null +++ 
b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_grouped_scheduler.h @@ -0,0 +1,407 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Tests for grouped GEMM problem visitors +*/ + +#pragma once + +#include +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/kernel/gemm_grouped_problem_visitor.h" +#include "cutlass/gemm/kernel/grouped_problem_visitor.h" +#include "cutlass/util/device_memory.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace test { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Use simple problem visitor as a baseline +template +struct BaselineProblemVisitor : public cutlass::gemm::kernel::BaseGroupedProblemVisitor { + using Base = cutlass::gemm::kernel::BaseGroupedProblemVisitor; + using Params = typename Base::Params; + static int const kThreadCount = ThreadCount; + + struct SharedStorage {}; + + int32_t tile_count_sum; + SharedStorage &shared_storage; + + // + // Methods + // + CUTLASS_DEVICE + BaselineProblemVisitor( + Params const ¶ms_, + SharedStorage &shared_storage_, + int32_t block_idx + ): Base(params_, block_idx), + shared_storage(shared_storage_) + { + cutlass::gemm::GemmCoord problem = this->problem_size(); + cutlass::gemm::GemmCoord grid = this->grid_shape(problem); + tile_count_sum = this->tile_count(grid); + } + + CUTLASS_DEVICE + bool next_tile() { + if (this->tile_idx < tile_count_sum) { + return true; + } + + do { + ++this->problem_idx; + + if (this->problem_idx >= this->params.problem_count) { + return false; + } + + cutlass::gemm::GemmCoord problem = this->problem_size(); + cutlass::gemm::GemmCoord grid = this->grid_shape(problem); + + this->problem_tile_start = tile_count_sum; + tile_count_sum += this->tile_count(grid); + + } while (tile_count_sum <= this->tile_idx); + + return true; + } + + static size_t get_workspace_size(const cutlass::gemm::GemmCoord* host_problem_sizes_ptr, + int32_t problem_count, + int32_t block_count) { + return 0; + } + + static void host_precompute(const cutlass::gemm::GemmCoord* host_problem_sizes_ptr, + int32_t problem_count, + int32_t block_count, + void* host_workspace_ptr) {} +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct ProblemVisitorKernel { + struct SharedStorage { + typename ProblemVisitor::SharedStorage problem_visitor; + }; + + struct Params { + typename ProblemVisitor::Params problem_visitor_params; + int32_t* visited_problems_ptr; + int32_t* visited_tiles_ptr; + int32_t visits_per_block; + + Params(): + visited_problems_ptr(nullptr), + visited_tiles_ptr(nullptr), + visits_per_block(0) {} + + Params(typename ProblemVisitor::Params problem_visitor_params_, + int32_t* visited_problems_ptr_, + int32_t* visited_tiles_ptr_, + int32_t visits_per_block_): + problem_visitor_params(problem_visitor_params_), + visited_problems_ptr(visited_problems_ptr_), + visited_tiles_ptr(visited_tiles_ptr_), + visits_per_block(visits_per_block_) {} + }; + + CUTLASS_DEVICE + void operator()(const Params& params, SharedStorage &shared_storage) { + int32_t store_offset = params.visits_per_block * blockIdx.x; + ProblemVisitor problem_visitor(params.problem_visitor_params, + shared_storage.problem_visitor, + blockIdx.x); + + while (problem_visitor.next_tile()) { + int32_t problem_idx = problem_visitor.problem_index(); + int32_t threadblock_idx = int32_t(problem_visitor.threadblock_idx()); + + if (threadIdx.x == 0) { + 
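+        // Only thread 0 of each block records the visit, so every (problem, tile)
+        // pair appears at most once per block in the log.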
params.visited_problems_ptr[store_offset] = problem_idx; + params.visited_tiles_ptr[store_offset] = threadblock_idx; + ++store_offset; + } + problem_visitor.advance(gridDim.x); + } + } +}; + +template +struct ProblemVisitorRunner { + using BaseKernel = ProblemVisitorKernel; + using Params = typename BaseKernel::Params; + + Params params; + std::vector host_problem_sizes; + int32_t problem_count; + int32_t threadblock_count; + int32_t visits_per_block; + cutlass::DeviceAllocation visited_problems; + cutlass::DeviceAllocation visited_tiles; + cutlass::DeviceAllocation device_problem_sizes; + cutlass::DeviceAllocation workspace; + std::vector host_visited_problems; + std::vector host_visited_tiles; + + ProblemVisitorRunner(const std::vector& host_problem_sizes_, + int32_t threadblock_count_): + host_problem_sizes(host_problem_sizes_), + problem_count(int32_t(host_problem_sizes_.size())), + threadblock_count(threadblock_count_) {} + + /// Initializes GEMM state from arguments. + cutlass::Status initialize() { + size_t workspace_bytes = ProblemVisitor::get_workspace_size( + host_problem_sizes.data(), + problem_count, + threadblock_count); + + workspace.reset(workspace_bytes); + std::vector host_workspace(workspace_bytes); + + int32_t tile_count = ProblemVisitor::group_tile_count(host_problem_sizes.data(), problem_count); + + ProblemVisitor::host_precompute(host_problem_sizes.data(), problem_count, + threadblock_count, host_workspace.data()); + + workspace.copy_from_host(host_workspace.data(), workspace_bytes); + + device_problem_sizes.reset(problem_count); + device_problem_sizes.copy_from_host(host_problem_sizes.data(), problem_count); + + visits_per_block = (tile_count - 1 + threadblock_count) / threadblock_count; + int32_t total_visits = visits_per_block * threadblock_count; + + visited_problems.reset(total_visits); + visited_tiles.reset(total_visits); + host_visited_problems.resize(total_visits); + host_visited_tiles.resize(total_visits); + + cudaError_t result = cudaMemset(visited_problems.get(), -1, sizeof(int32_t) * total_visits); + if (result != cudaSuccess) { + return cutlass::Status::kErrorInternal; + } + + result = cudaMemset(visited_tiles.get(), -1, sizeof(int32_t) * total_visits); + if (result != cudaSuccess) { + return cutlass::Status::kErrorInternal; + } + + typename ProblemVisitor::Params pv_params(device_problem_sizes.get(), problem_count, workspace.get(), tile_count); + params = Params(pv_params, visited_problems.get(), visited_tiles.get(), visits_per_block); + + return cutlass::Status::kSuccess; + } + + bool verify() { + // Sort by problem size and then by threadblock_idx + std::vector indices(host_visited_problems.size()); + std::iota(indices.begin(), indices.end(), 0); + + std::stable_sort(indices.begin(), indices.end(), + [&](int32_t i1, int32_t i2) { + if (host_visited_problems[i1] == host_visited_problems[i2]) { + return host_visited_tiles[i1] < host_visited_tiles[i2]; + } + return host_visited_problems[i1] < host_visited_problems[i2]; + }); + + int32_t idx = 0; + + // Skip any entries that were not visited + while (host_visited_problems[indices[idx]] == -1) { + ++idx; + } + + // Check that each problem visited has the tiles we expect + for (int32_t problem_idx = 0; problem_idx < problem_count; ++problem_idx) { + auto problem = host_problem_sizes[problem_idx]; + ProblemVisitor::possibly_transpose_problem(problem); + int32_t problem_tiles = ProblemVisitor::tile_count(ProblemVisitor::grid_shape(problem)); + for (int i = 0; i < problem_tiles; ++i) { + EXPECT_EQ(problem_idx, 
host_visited_problems[indices[idx]]); + EXPECT_EQ(i, host_visited_tiles[indices[idx]]); + ++idx; + } + } + + return true; + } + + bool run(cudaStream_t stream = nullptr) { + cutlass::Status status = initialize(); + if (status != cutlass::Status::kSuccess) { + std::cerr << "Initialization failed" << std::endl; + return false; + } + + dim3 grid(threadblock_count, 1, 1); + dim3 block(ProblemVisitor::kThreadCount, 1, 1); + int smem_size = int(sizeof(typename BaseKernel::SharedStorage)); + + cutlass::Kernel<<>>(params); + + cudaError_t result = cudaGetLastError(); + if (result != cudaSuccess) { + std::cerr << "grid launch failed with error " << cudaGetErrorString(result) << std::endl; + return false; + } + + result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + std::cerr << "cudaDeviceSynchronize failed with error " << cudaGetErrorString(result) << std::endl; + return false; + } + + visited_problems.copy_to_host(host_visited_problems.data()); + visited_tiles.copy_to_host(host_visited_tiles.data()); + + return verify(); + } +}; + +template +struct TestbedGroupedGemmScheduler { + + using PSHelper = cutlass::gemm::kernel::detail::GemmGroupedProblemSizeHelper; + using BaselinePV = BaselineProblemVisitor; + + // + // Data members + // + uint32_t seed; + int problem_count; + int threadblock_count; + std::vector problem_sizes_host; + + // + // Methods + // + + TestbedGroupedGemmScheduler(uint32_t seed_ = 3080): + seed(seed_) { srand(seed); } + + /// Initializes data structures + void initialize(int32_t scale_factor) { + + // + // Choose random problem sizes + // + + problem_sizes_host.clear(); + problem_sizes_host.resize(problem_count); + + for (int32_t i = 0; i < problem_count; ++i) { + + cutlass::gemm::GemmCoord problem( + scale_factor * (rand() % 64) + 24, + scale_factor * (rand() % 64) + 24, + scale_factor * (rand() % 64) + 24); + + problem_sizes_host.at(i) = problem; + } + } + + template + void compare_visitors(const ProblemVisitorRunner& baseline_runner) { + using PV = cutlass::gemm::kernel::GemmGroupedProblemVisitor< + ThreadblockShape, + GroupScheduleMode_, + PrefetchTileCount, + ThreadCount, + Transpose>; + ProblemVisitorRunner runner(problem_sizes_host, threadblock_count); + EXPECT_TRUE(runner.run()); + + // Check that this problem visitor visits the same problems and tiles as the baseline + EXPECT_EQ(baseline_runner.host_visited_problems, runner.host_visited_problems); + EXPECT_EQ(baseline_runner.host_visited_tiles, runner.host_visited_tiles); + } + + template + void compare_visitors(const ProblemVisitorRunner& baseline_runner) { + // Compare the next visitor with the baseline visitor + compare_visitors(baseline_runner); + + // Recurse to compare the next visitors + compare_visitors(baseline_runner); + } + + /// Executes the test on all scheduler modes + void run(int problem_count, int threadblock_count, int scale_factor=8) { + + this->problem_count = problem_count; + this->threadblock_count = threadblock_count; + + // Initialize the problem + initialize(scale_factor); + + // Run the baseline visitor to which we will compare all other visitors + ProblemVisitorRunner baseline_runner(problem_sizes_host, threadblock_count); + EXPECT_TRUE(baseline_runner.run()); + + compare_visitors(baseline_runner); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // device +} // gemm +} // test + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git 
a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_interleaved.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_interleaved.h new file mode 100644 index 0000000000000000000000000000000000000000..2a5956000db8e8c05ea22538e58149998b03e3fc --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_interleaved.h @@ -0,0 +1,346 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Tests for device-wide GEMM interface +*/ + +#include +#include +#include + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/host_reorder.h" + +namespace test { +namespace gemm { +namespace device { + +//////////////////////////////////////////////////////////////////////////////// + +template +struct InterleavedTestbed { + + using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; + using ElementC = typename Gemm::ElementC; + using ElementAccumulator = typename Gemm::ElementAccumulator; + using ElementCompute = typename Gemm::GemmKernel::Epilogue::OutputOp::ElementCompute; + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + uint64_t seed; + + // + // Methods + // + + InterleavedTestbed( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } + + /// Helper to initialize a tensor view + template + bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, 2, -2, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential( + view.data(), view.capacity()); + } + else { + EXPECT_TRUE(false) << "Not implemented"; + return false; + } + + return true; + } + + /// Waives test if CUDA device is insufficient + bool sufficient() const { + // + // Determine SMEM requirements and waive if not satisfied + // + + size_t smem_size = sizeof(typename Gemm::GemmKernel::SharedStorage); + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerBlockOptin < smem_size) { + return false; + } + + return true; + } + + /// Executes one test + bool run( + cutlass::gemm::GemmCoord problem_size, + ElementCompute alpha = ElementCompute(1), + ElementCompute beta = ElementCompute(0)) { + + // Waive test if insufficient CUDA device + if (!sufficient()) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." 
<< std::endl; + } + return true; + } + + // + // Allocate the GEMM workspace + // + + cutlass::HostTensor< + typename Gemm::ElementA, + typename Gemm::LayoutA> tensor_A(problem_size.mk()); + + cutlass::HostTensor< + typename Gemm::ElementB, + typename Gemm::LayoutB> tensor_B(problem_size.kn()); + + cutlass::HostTensor< + typename Gemm::ElementB, + typename Gemm::LayoutB> tensor_B_reordered(problem_size.kn()); + + cutlass::HostTensor< + typename Gemm::ElementC, + typename Gemm::LayoutC> tensor_C(problem_size.mn()); + + cutlass::HostTensor< + typename Gemm::ElementC, + typename Gemm::LayoutC> tensor_D(problem_size.mn()); + + cutlass::HostTensor< + typename Gemm::ElementC, + typename Gemm::LayoutC> reference_D(problem_size.mn(), false); + + EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2019)); + EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2018)); + EXPECT_TRUE(initialize_tensor(tensor_C.host_view(), init_C, seed + 2017)); + + cutlass::reorder_column( + tensor_B_reordered.host_ref(), tensor_B.host_ref(), problem_size); + + cutlass::reference::host::TensorCopy( + reference_D.host_view(), + tensor_C.host_view()); + + tensor_A.sync_device(); + tensor_B_reordered.sync_device(); + tensor_C.sync_device(); + tensor_D.sync_device(); + + // + // Initialize the GEMM operator + // + + typename Gemm::Arguments arguments{ + problem_size, + tensor_A.device_ref(), + tensor_B_reordered.device_ref(), + tensor_C.device_ref(), + tensor_D.device_ref(), + {alpha, beta} + }; + + Gemm gemm_op; + + cutlass::Status status = gemm_op.initialize(arguments); + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + + // + // Run the GEMM + // + + status = gemm_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + + // + // Verify + // + + cutlass::reference::host::Gemm< + typename Gemm::ElementA, typename Gemm::LayoutA, + typename Gemm::ElementB, typename Gemm::LayoutB, + typename Gemm::ElementC, typename Gemm::LayoutC, ElementCompute, + ElementAccumulator, typename Gemm::Operator> + reference_gemm; + + reference_gemm( + problem_size, + alpha, + tensor_A.host_ref(), + tensor_B.host_ref(), + beta, + reference_D.host_ref(), + ElementAccumulator(0) + ); + + tensor_D.sync_host(); + + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); + + bool passed = cutlass::reference::host::TensorEquals( + reference_D.host_view(), + tensor_D.host_view()); + + EXPECT_TRUE(passed); + if (!passed) { + + std::stringstream fname; + + fname << "error_Gemm_device_" + << problem_size.m() << "x" + << problem_size.n() << "x" + << problem_size.k() << "_" + << Gemm::ThreadblockShape::kM << "x" + << Gemm::ThreadblockShape::kN << "x" + << Gemm::ThreadblockShape::kK << "_" + << Gemm::WarpShape::kM << "x" + << Gemm::WarpShape::kN << "x" + << Gemm::WarpShape::kK << ".txt"; + + std::ofstream file(fname.str()); + + file + << "problem: " << problem_size + << ", alpha: " << alpha << ", beta: " << beta << "\n\n"; + + file + << "A =\n" << tensor_A.host_view() + << "\nB =\n" << tensor_B.host_view() + << "\nB_reordered =\n" << tensor_B_reordered.host_view() + << "\nC =\n" << tensor_C.host_view() + << "\n\nReference =\n" << reference_D.host_view() + << "\nComputed =\n" << tensor_D.host_view(); + } + + return passed; + } + + /// Runs a set of problem sizes + bool run_all() { + bool passed = true; + + int problem_size_m[] = { + InterleavedK, 256 + InterleavedK, 512 + InterleavedK + }; + + int problem_size_n[] = { + 
InterleavedK, 256 + InterleavedK, 512 + InterleavedK + }; + + int problem_size_k[] = { + InterleavedK, 256 + InterleavedK, 512 + InterleavedK + }; + + double problem_alpha[] = { + 1.0 + }; + + double problem_beta[] = { + 2.0 + }; + + for (int m : problem_size_m) { + for (int n : problem_size_n) { + for (int k : problem_size_k) { + for (double alpha : problem_alpha) { + for (double beta : problem_beta) { + + passed = run( + {m, n, k}, + ElementCompute(alpha), + ElementCompute(beta) + ); + + if (!passed) { + return false; + } + } + } + } + } + } + + return true; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace test + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_planar_complex.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_planar_complex.h new file mode 100644 index 0000000000000000000000000000000000000000..32452c30e05f64763a268195ae78138f26c09735 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_planar_complex.h @@ -0,0 +1,326 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file
+    \brief Tests for device-wide GEMM interface
+*/
+
+#pragma once
+
+#include <iostream>
+#include <fstream>
+#include <sstream>
+
+#include "../../common/cutlass_unit_test.h"
+
+#include "cutlass/util/distribution.h"
+#include "cutlass/util/reference/host/gemm_planar_complex.h"
+#include "cutlass/util/host_tensor_planar_complex.h"
+#include "cutlass/util/tensor_view_io.h"
+#include "cutlass/util/reference/host/tensor_compare.h"
+#include "cutlass/util/reference/host/tensor_copy.h"
+#include "cutlass/util/reference/host/tensor_fill.h"
+
+////////////////////////////////////////////////////////////////////////////////
+
+namespace test {
+namespace gemm {
+namespace device {
+
+////////////////////////////////////////////////////////////////////////////////
+
+template <typename Gemm>
+class TestbedPlanarComplex {
+public:
+
+  using ElementA = typename Gemm::ElementA;
+  using LayoutA = typename Gemm::LayoutA;
+  using ElementB = typename Gemm::ElementB;
+  using LayoutB = typename Gemm::LayoutB;
+  using ElementC = typename Gemm::ElementC;
+  using LayoutC = typename Gemm::LayoutC;
+  using ElementCompute = typename Gemm::EpilogueOutputOp::ElementCompute;
+  using ElementAccumulator = typename Gemm::ElementAccumulator;
+
+  //
+  // Data members
+  //
+
+  cutlass::gemm::GemmCoord problem_size;
+  cutlass::HostTensorPlanarComplex<ElementA, LayoutA> tensor_A;
+  cutlass::HostTensorPlanarComplex<ElementB, LayoutB> tensor_B;
+  cutlass::HostTensorPlanarComplex<ElementC, LayoutC> tensor_C;
+  cutlass::HostTensorPlanarComplex<ElementC, LayoutC> tensor_D;
+  cutlass::HostTensorPlanarComplex<ElementC, LayoutC> tensor_D_ref;
+
+  //
+  // Methods
+  //
+
+  TestbedPlanarComplex(cutlass::gemm::GemmCoord const & problem_size): problem_size(problem_size) {
+
+    tensor_A.reset({problem_size.m(), problem_size.k()});
+    tensor_B.reset({problem_size.k(), problem_size.n()});
+    tensor_C.reset({problem_size.m(), problem_size.n()});
+    tensor_D.reset({problem_size.m(), problem_size.n()});
+    tensor_D_ref.reset({problem_size.m(), problem_size.n()}, false);
+  }
+
+  void initialize() {
+
+    uint64_t seed = 1073;
+
+    int scope_max = 8;
+    int scope_min = -8;
+
+    cutlass::reference::host::TensorFillRandomUniform(
+      tensor_A.host_view(), seed, scope_max, scope_min, 0);
+
+    cutlass::reference::host::TensorFillRandomUniform(
+      tensor_B.host_view(), seed * 2019, scope_max, scope_min, 0);
+
+    cutlass::reference::host::TensorFillRandomUniform(
+      tensor_C.host_view(), seed * 2020, scope_max, scope_min, 0);
+
+    cutlass::reference::host::TensorFill(tensor_D.host_view(), cutlass::complex<ElementC>());
+    cutlass::reference::host::TensorFill(tensor_D_ref.host_view(), cutlass::complex<ElementC>());
+
+    tensor_A.sync_device();
+    tensor_B.sync_device();
+    tensor_C.sync_device();
+    tensor_D.sync_device();
+  }
+
+  /// Returns true if the CUDA device is sufficient to execute the kernel.
+ bool sufficient() const { + // + // Determine SMEM requirements and waive if not satisfied + // + + size_t smem_size = sizeof(typename Gemm::GemmKernel::SharedStorage); + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerBlockOptin < smem_size) { + return false; + } + + return true; + } + + bool run( + cutlass::complex alpha = {1, 0}, + cutlass::complex beta = {0, 0}) { + + // Waive test if insufficient CUDA device + if (!sufficient()) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." << std::endl; + } + return true; + } + + initialize(); + + int batch_count = 1; + + ElementA *ptr_A = tensor_A.device_data(); + ElementB *ptr_B = tensor_B.device_data(); + ElementC *ptr_C = tensor_C.device_data(); + ElementC *ptr_D = tensor_D.device_data(); + + typename LayoutA::Stride::Index lda = tensor_A.layout().stride(0); + typename LayoutB::Stride::Index ldb = tensor_B.layout().stride(0); + typename LayoutC::Stride::Index ldc = tensor_C.layout().stride(0); + typename LayoutC::Stride::Index ldd = tensor_D.layout().stride(0); + + int64_t imag_stride_A = tensor_A.imaginary_stride(); + int64_t imag_stride_B = tensor_B.imaginary_stride(); + int64_t imag_stride_C = tensor_C.imaginary_stride(); + int64_t imag_stride_D = tensor_D.imaginary_stride(); + + // + // Launch device kernel + // + + Gemm gemm_op; + + typename Gemm::Arguments args{ + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + batch_count, + {alpha, beta}, + ptr_A, + ptr_A + imag_stride_A, + ptr_B, + ptr_B + imag_stride_B, + ptr_C, + ptr_C + imag_stride_C, + ptr_D, + ptr_D + imag_stride_D, + lda, + lda, + ldb, + ldb, + ldc, + ldc, + ldd, + ldd + }; + + cutlass::Status status = gemm_op(args); + + EXPECT_EQ(status, cutlass::Status::kSuccess); + + cudaError_t error = cudaDeviceSynchronize(); + + tensor_D.sync_host(); + + // + // Compute reference + // + + cutlass::reference::host::GemmPlanarComplex< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementAccumulator + >( + problem_size, + alpha, + tensor_A.host_ref(), + Gemm::kTransformA, + tensor_B.host_ref(), + Gemm::kTransformB, + beta, + tensor_C.host_ref(), + tensor_D_ref.host_ref() + ); + + bool passed = cutlass::reference::host::TensorEquals( + tensor_D.host_view(), + tensor_D_ref.host_view() + ); + + EXPECT_TRUE(passed); + + if (!passed) { + std::ofstream output("gemm_planar_complex.txt"); + + output + << "A:\n" << tensor_A.host_view() << "\n" + << "B:\n" << tensor_B.host_view() << "\n" + << "C:\n" << tensor_C.host_view() << "\n" + << "Reference:\n" + << tensor_D_ref.host_view() << "\n" + << "Computed:\n" + << tensor_D.host_view() << "\n"; + } + + return passed; + } +}; + +template +bool TestOneGemmPlanarComplex(cutlass::gemm::GemmCoord problem_size) { + + TestbedPlanarComplex testbed(problem_size); + + return testbed.run(); +} + +template +bool TestAllGemmPlanarComplex() { + + int M[] = { + 16, 64, 72, 144, 264, 520, + }; + + int N[] = { + 16, 64, 72, 144, 248, 264, 520 + }; + + int K[] = { + 8, 64, 72, 96, 264, 520 + }; + + using ElementCompute = typename Gemm::EpilogueOutputOp::ElementCompute; + + cutlass::complex alpha_values[] = { + {ElementCompute(1.25), 
ElementCompute(-0.5)} + }; + + cutlass::complex beta_values[] = { + {ElementCompute(-2.25), ElementCompute(1.5)} + }; + + for (int m : M) { + for (int n : N) { + for (int k : K) { + + test::gemm::device::TestbedPlanarComplex testbed({m, n, k}); + + for (auto const &alpha : alpha_values) { + for (auto const &beta : beta_values) { + + bool passed = testbed.run(alpha, beta); + if (!passed) { + return false; + } + } + } + } + } + } + + return true; +} + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace test + +///////////////////////////////////////////////////////////////////////////////////////////////// + + diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_rank2k_universal.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_rank2k_universal.h new file mode 100644 index 0000000000000000000000000000000000000000..4d9f6743a45e5dc3a7b4ddd3e2a7b2abceffbb18 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_rank2k_universal.h @@ -0,0 +1,641 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Tests for device-wide Rank 2k update interface + +*/ + +#pragma once + +#include +#include +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/blas3.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/error_metrics.h" +#include "cutlass/util/reference/host/rank_2k.h" +#include "cutlass/util/reference/host/rank_2k_complex.h" + +#include "testbed_utils.h" + +namespace test { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct TestbedRank2KUniversal { + + using ElementA = typename Rank2K::ElementA; + using ElementB = typename Rank2K::ElementB; + using ElementC = typename Rank2K::ElementC; + using ElementAccumulator = typename Rank2K::ElementAccumulator; + using ElementCompute = typename Rank2K::Rank2Kkernel::Epilogue::OutputOp::ElementCompute; + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + uint64_t seed; + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_C; + cutlass::HostTensor tensor_D; + cutlass::HostTensor reference_D; + + // + // Methods + // + + TestbedRank2KUniversal( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } + + /// Helper to initialize a tensor view + template + bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed, + int mantissa_in_bits) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } else if (bits_output == 16) { + scope_max = 5; + scope_min = -5; + } else { + scope_max = 8; + scope_min = -8; + } + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, mantissa_in_bits); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5, mantissa_in_bits); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential( + view.data(), view.capacity()); + } + else { + + EXPECT_TRUE(false) << "Input distribution not implemented"; + return false; + } + + return true; + } + + + /// Helper to initialize a tensor view + template + bool initialize_symmetric_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed, + int mantissa_in_bits) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = 
cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } else if (bits_output == 16) { + scope_max = 5; + scope_min = -5; + } else { + scope_max = 8; + scope_min = -8; + } + + cutlass::reference::host::TensorFillSymmetricRandomUniform( + view, seed, Rank2K::kFillModeC, scope_max, scope_min, mantissa_in_bits); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillSymmetricRandomGaussian( + view, seed, Rank2K::kFillModeC, 0, 0.5, mantissa_in_bits); + } + else { + + EXPECT_TRUE(false) << "Input distribution (symmetric tensor) not implemented"; + return false; + } + + return true; + } + /// Initializes data structures + void initialize(cutlass::gemm::GemmCoord problem_size) { + // + // Allocate the Rank2K workspace + // + + tensor_A.resize(problem_size.mk()); + tensor_B.resize(problem_size.mk()); + tensor_C.resize(problem_size.mn()); + tensor_D.resize(problem_size.mn()); + reference_D.resize(problem_size.mn(), false); + + EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2019, cutlass::MantissaInBits::bits)); + EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2018, cutlass::MantissaInBits::bits)); + EXPECT_TRUE(initialize_symmetric_tensor(tensor_C.host_view(), init_C, seed + 2017, cutlass::MantissaInBits::bits)); + + // It is possible to randomly initialize to all zeros, so override this with non-zeros + // in the upper left corner of each operand. + tensor_A.host_view().at({0, 0}) = typename Rank2K::ElementA(1); + tensor_B.host_view().at({0, 0}) = typename Rank2K::ElementB(1); + tensor_C.host_view().at({0, 0}) = typename Rank2K::ElementC(1); + + cutlass::reference::host::TensorCopy(reference_D.host_view(), tensor_C.host_view()); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_D.sync_device(); + } + + /// Compares computed reference with device reference and outputs to a file if incorrect + bool compare_reference( + cutlass::gemm::GemmCoord problem_size, + ElementCompute alpha, + ElementCompute beta) { + + tensor_D.sync_host(); + + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_C.host_view()), 0); + + if (tensor_D.size() > 1) + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0); + + if (reference_D.size() > 1) + EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); + + double l2_norm = cutlass::reference::host::TensorRelativeErrorMetric(reference_D.host_view(), tensor_D.host_view()); + + bool passed = l2_norm < cutlass::MantissaInBits::error; + + return passed; + } + + /// Verifies the result is a Rank2K + bool verify( + cutlass::gemm::GemmCoord problem_size, + ElementCompute alpha, + ElementCompute beta) { + + // + // Verify + // + cutlass::reference::host::Rank2KComplex< + typename Rank2K::ElementA, typename Rank2K::LayoutA, + typename Rank2K::ElementB, typename Rank2K::LayoutB, + typename Rank2K::ElementC, typename Rank2K::LayoutC, + ElementCompute, ElementAccumulator + >( + problem_size, + alpha, + tensor_A.host_ref(), + Rank2K::kTransformA, + tensor_B.host_ref(), + Rank2K::kTransformB, + beta, + tensor_C.host_ref(), + reference_D.host_ref(), + ElementAccumulator(0), + Rank2K::kFillModeC, + Rank2K::kBlasMode + ); + + return compare_reference(problem_size, alpha, 
beta); + } + + /// Returns true if the CUDA device is sufficient to execute the kernel. + bool sufficient() const { + // + // Determine SMEM requirements and waive if not satisfied + // + + size_t smem_size = sizeof(typename Rank2K::Rank2Kkernel::SharedStorage); + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerBlockOptin < smem_size) { + return false; + } + return true; + } + + /// Executes one test + bool run( + cutlass::gemm::GemmUniversalMode mode, + cutlass::gemm::GemmCoord problem_size, + int batch_count = 1, + ElementCompute alpha = ElementCompute(1), + ElementCompute beta = ElementCompute(0)) { + + // Waive test if insufficient CUDA device + if (!sufficient()) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." << std::endl; + } + return true; + } + +#if 0 + std::cout << "[TestbedRank2KUniversal::run()] problem(m, n, k): " << problem_size + << " alpha: " << ElementCompute(alpha) + << " beta: " << ElementCompute(beta) << std::endl; +#endif + + this->initialize(problem_size); + + // + // Initialize the Rank2K operator + // + + typename Rank2K::Arguments arguments{ + mode, + problem_size, + batch_count, + {alpha, beta}, + tensor_A.device_data(), + tensor_B.device_data(), + tensor_C.device_data(), + tensor_D.device_data(), + problem_size.n() * problem_size.k(), + problem_size.n() * problem_size.k(), + problem_size.m() * problem_size.n(), + problem_size.m() * problem_size.n(), + tensor_A.layout().stride(0), + tensor_B.layout().stride(0), + tensor_C.layout().stride(0), + tensor_D.layout().stride(0) + }; + + Rank2K rank2k_op; + + size_t workspace_size = Rank2K::get_workspace_size(arguments); + + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = rank2k_op.initialize(arguments, workspace.get()); + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Run the Rank2K + // + + status = rank2k_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Verify + // + + bool passed = this->verify(problem_size, alpha, beta); + + //if (true) { + if (!passed) { + std::stringstream fname; + + fname << "error_Rank2k_device_" + << "fill_mode_c_" + << (Rank2K::kFillModeC == cutlass::FillMode::kLower ? "lower_" : + (Rank2K::kFillModeC == cutlass::FillMode::kUpper ? 
"upper_" : "invalid_")) + << "mnk_" + << problem_size.m() << "x" + << problem_size.n() << "x" + << problem_size.k() << "_" + << Rank2K::ThreadblockShape::kM << "x" + << Rank2K::ThreadblockShape::kN << "x" + << Rank2K::ThreadblockShape::kK << "_" + << Rank2K::WarpShape::kM << "x" + << Rank2K::WarpShape::kN << "x" + << Rank2K::WarpShape::kK << ".txt"; + + std::cout << fname.str() << std::endl; + + std::ofstream results(fname.str()); + + results << problem_size << std::endl; + + results + << "\nA:\n" << tensor_A.host_view() << "\n" + << "\nB:\n" << tensor_B.host_view() << "\n" + << "\nC:\n" << tensor_C.host_view() << "\n" + << "\nD reference:\n" << reference_D.host_view() << "\n" + << "\nD computed:\n" << tensor_D.host_view() << "\n"; + + } + + return passed; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +template +bool TestRank2kUniversal( + cutlass::gemm::GemmCoord const & problem_size, + cutlass::gemm::GemmUniversalMode mode, + int batch_count, + double alpha = 1.0, + double beta = 2.0) { + + bool passed = true; + + TestbedRank2KUniversal testbed; + + using ElementCompute = typename Rank2K::EpilogueOutputOp::ElementCompute; + + passed = testbed.run( + mode, + problem_size, + batch_count, + cutlass::from_real(alpha), + cutlass::from_real(beta) + ); + + return passed; +} + +template +bool TestAllRank2KUniversal() { + bool passed = true; + + + int const kMinimumOperandElementSize = int(cutlass::sizeof_bits::value); + + int const kAlignment = cutlass::platform::is_same< + typename Rank2K::OperatorClass, + cutlass::arch::OpClassSimt>::value ? 1 : 128 / kMinimumOperandElementSize; + + // int8_t gemm alignment constraints + int const kAlignmentM = cutlass::platform::is_same::value && + cutlass::platform::is_same::value && + cutlass::platform::is_same::value ? 4 : kAlignment; + + int const kAlignmentN = kAlignmentM; + + int const kAlignmentK = cutlass::platform::is_same::value && + cutlass::platform::is_same::value && + cutlass::platform::is_same::value + ? 
4 : kAlignment; + + cutlass::gemm::GemmUniversalMode modes[] = { + cutlass::gemm::GemmUniversalMode::kGemm, + }; + + int problem_size_n[] = { + kAlignmentN, 512 - 2*kAlignmentN + }; + + int problem_size_k[] = { + kAlignmentK, + Rank2K::ThreadblockShape::kK * Rank2K::kStages - kAlignmentK, + Rank2K::ThreadblockShape::kK * Rank2K::kStages * 3 - kAlignmentK + }; + + int batch_counts[] = { // may be interpretted as batch count or split-K slices + 1 // Just running one batch for now (removing 2, 3, 5, 7) + }; + + double problem_alpha[] = { + 1.0, 3.25 + }; + + double problem_beta[] = { + 0.0, 2.15 + }; + + using ElementCompute = typename Rank2K::EpilogueOutputOp::ElementCompute; + + for (cutlass::gemm::GemmUniversalMode mode : modes) { + for (int n : problem_size_n) { + for (int k : problem_size_k) { + for (int batch_count : batch_counts) { + + for (auto alpha : problem_alpha) { + for (auto beta : problem_beta) { + + if (mode == cutlass::gemm::GemmUniversalMode::kGemm || + mode == cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel) { + + // skip very small K problems + //if (k / batch_count < 2 * Rank2K::ThreadblockShape::kK) { + // continue; + //} + } + + cutlass::gemm::GemmCoord problem_size(n, n, k); + + TestbedRank2KUniversal testbed; + + passed = testbed.run( + mode, + problem_size, + batch_count, + cutlass::from_real(alpha), + cutlass::from_real(beta) + ); + + if (!passed) { + return false; + } + } + } + } + } + } + } + + return passed; +} + +template +bool TestAllRank2KHermitianUniversal() { + bool passed = true; + + using ElementCompute = typename Rank2K::EpilogueOutputOp::ElementCompute; + using ElementAccumulator = typename Rank2K::ElementAccumulator; + + int const kMinimumOperandElementSize = int(cutlass::sizeof_bits::value); + + int const kAlignment = cutlass::platform::is_same< + typename Rank2K::OperatorClass, + cutlass::arch::OpClassSimt>::value ? 1 : 128 / kMinimumOperandElementSize; + + // int8_t gemm alignment constraints + int const kAlignmentM = cutlass::platform::is_same::value && + cutlass::platform::is_same::value && + cutlass::platform::is_same::value ? 4 : kAlignment; + + int const kAlignmentN = kAlignmentM; + + int const kAlignmentK = cutlass::platform::is_same::value && + cutlass::platform::is_same::value && + cutlass::platform::is_same::value + ? 
4 : kAlignment; + + cutlass::gemm::GemmUniversalMode modes[] = { + cutlass::gemm::GemmUniversalMode::kGemm, + }; + + int problem_size_n[] = { + kAlignmentN, 512 - 2*kAlignmentN + }; + + int problem_size_k[] = { + kAlignmentK, + Rank2K::ThreadblockShape::kK * Rank2K::kStages - kAlignmentK, + Rank2K::ThreadblockShape::kK * Rank2K::kStages * 3 - kAlignmentK + }; + + int batch_counts[] = { // may be interpretted as batch count or split-K slices + 1 // Just running one batch for now (removing 2, 3, 5, 7) + }; + + /* Complex alpha for HER2K */ + ElementAccumulator problem_alpha[] = { + {1.0}, + {1.25, 3.25}, + {-0.25, -2.25} + }; + + ElementAccumulator problem_beta[] = { + 0.0, -2.25 + }; + + for (cutlass::gemm::GemmUniversalMode mode : modes) { + for (int n : problem_size_n) { + for (int k : problem_size_k) { + for (int batch_count : batch_counts) { + + for (auto alpha : problem_alpha) { + for (auto beta : problem_beta) { + + if (mode == cutlass::gemm::GemmUniversalMode::kGemm || + mode == cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel) { + + // skip very small K problems + //if (k / batch_count < 2 * Rank2K::ThreadblockShape::kK) { + // continue; + //} + } + + cutlass::gemm::GemmCoord problem_size(n, n, k); + + TestbedRank2KUniversal testbed; + + passed = testbed.run( + mode, + problem_size, + batch_count, + alpha, + beta + ); + + if (!passed) { + return false; + } + } + } + } + } + } + } + + return passed; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace test + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_rank_k_universal.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_rank_k_universal.h new file mode 100644 index 0000000000000000000000000000000000000000..cb46528a049ae1254d0492b6235821210e47b957 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_rank_k_universal.h @@ -0,0 +1,511 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide Rank 2k update interface + +*/ + +#pragma once + +#include +#include +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/blas3.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/error_metrics.h" +#include "cutlass/util/reference/host/rank_k_complex.h" + +#include "testbed_utils.h" + +namespace test { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct TestbedRank2KUniversal { + + using ElementA = typename RankK::ElementA; + using ElementC = typename RankK::ElementC; + using ElementAccumulator = typename RankK::ElementAccumulator; + using ElementCompute = typename RankK::RankKkernel::Epilogue::OutputOp::ElementCompute; + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_C; + uint64_t seed; + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_C; + cutlass::HostTensor tensor_D; + cutlass::HostTensor reference_D; + + // + // Methods + // + + TestbedRank2KUniversal( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + init_A(init_A_), init_C(init_C_), seed(seed_) { } + + /// Helper to initialize a tensor view + template + bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed, + int mantissa_in_bits) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } else if (bits_output == 16) { + scope_max = 5; + scope_min = -5; + } else { + scope_max = 8; + scope_min = -8; + } + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, mantissa_in_bits); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5, mantissa_in_bits); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential( + view.data(), view.capacity()); + } + else { + + EXPECT_TRUE(false) << "Input distribution not implemented"; + return 
false; + } + + return true; + } + + + /// Helper to initialize a tensor view + template + bool initialize_symmetric_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed, + int mantissa_in_bits) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } else if (bits_output == 16) { + scope_max = 5; + scope_min = -5; + } else { + scope_max = 8; + scope_min = -8; + } + + cutlass::reference::host::TensorFillSymmetricRandomUniform( + view, seed, RankK::kFillModeC, scope_max, scope_min, mantissa_in_bits); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillSymmetricRandomGaussian( + view, seed, RankK::kFillModeC, 0, 0.5, mantissa_in_bits); + } + else { + + EXPECT_TRUE(false) << "Input distribution (symmetric tensor) not implemented"; + return false; + } + + return true; + } + /// Initializes data structures + void initialize(cutlass::gemm::GemmCoord problem_size) { + // + // Allocate the RankK workspace + // + + tensor_A.resize(problem_size.mk()); + tensor_C.resize(problem_size.mn()); + tensor_D.resize(problem_size.mn()); + reference_D.resize(problem_size.mn(), false); + + EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2019, cutlass::MantissaInBits::bits)); + EXPECT_TRUE(initialize_symmetric_tensor(tensor_C.host_view(), init_C, seed + 2017, cutlass::MantissaInBits::bits)); + + // It is possible to randomly initialize to all zeros, so override this with non-zeros + // in the upper left corner of each operand. 
+ tensor_A.host_view().at({0, 0}) = typename RankK::ElementA(1); + tensor_C.host_view().at({0, 0}) = typename RankK::ElementC(1); + + cutlass::reference::host::TensorCopy(reference_D.host_view(), tensor_C.host_view()); + + tensor_A.sync_device(); + tensor_C.sync_device(); + tensor_D.sync_device(); + } + + /// Compares computed reference with device reference and outputs to a file if incorrect + bool compare_reference( + cutlass::gemm::GemmCoord problem_size, + ElementCompute alpha, + ElementCompute beta) { + + tensor_D.sync_host(); + + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_C.host_view()), 0); + + if (tensor_D.size() > 1) + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0); + + if (reference_D.size() > 1) + EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); + + double l2_norm = cutlass::reference::host::TensorRelativeErrorMetric(reference_D.host_view(), tensor_D.host_view()); + + bool passed = l2_norm < cutlass::MantissaInBits::error; + + return passed; + } + + /// Verifies the result is a RankK + bool verify( + cutlass::gemm::GemmCoord problem_size, + ElementCompute alpha, + ElementCompute beta) { + + // + // Verify + // + cutlass::reference::host::Rank2KComplex< + typename RankK::ElementA, typename RankK::LayoutA, + typename RankK::ElementC, typename RankK::LayoutC, + ElementCompute, ElementAccumulator + >( + problem_size, + alpha, + tensor_A.host_ref(), + RankK::kTransformA, + beta, + tensor_C.host_ref(), + reference_D.host_ref(), + ElementAccumulator(0), + RankK::kFillModeC, + RankK::kBlasMode + ); + + return compare_reference(problem_size, alpha, beta); + } + + /// Returns true if the CUDA device is sufficient to execute the kernel. + bool sufficient() const { + // + // Determine SMEM requirements and waive if not satisfied + // + + size_t smem_size = sizeof(typename RankK::RankKkernel::SharedStorage); + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerBlockOptin < smem_size) { + return false; + } + + return true; + } + + /// Executes one test + bool run( + cutlass::gemm::GemmUniversalMode mode, + cutlass::gemm::GemmCoord problem_size, + int batch_count = 1, + ElementCompute alpha = ElementCompute(1), + ElementCompute beta = ElementCompute(0)) { + + // Waive test if insufficient CUDA device + if (!sufficient()) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." 
<< std::endl; + } + return true; + } + +#if 0 + std::cout << "[TestbedRankKUniversal::run()] problem(m, n, k): " << problem_size + << " alpha: " << ElementCompute(alpha) + << " beta: " << ElementCompute(beta) << std::endl; +#endif + + this->initialize(problem_size); + + // + // Initialize the RankK operator + // + + typename RankK::Arguments arguments{ + mode, + problem_size, + batch_count, + {alpha, beta}, + tensor_A.device_data(), + tensor_C.device_data(), + tensor_D.device_data(), + problem_size.n() * problem_size.k(), + problem_size.m() * problem_size.n(), + problem_size.m() * problem_size.n(), + tensor_A.layout().stride(0), + tensor_C.layout().stride(0), + tensor_D.layout().stride(0) + }; + + RankK rank2k_op; + + size_t workspace_size = RankK::get_workspace_size(arguments); + + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = rank2k_op.initialize(arguments, workspace.get()); + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Run the RankK + // + + status = rank2k_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Verify + // + + bool passed = this->verify(problem_size, alpha, beta); + + //if (true) { + if (!passed) { + std::stringstream fname; + + fname << "error_RankK_device_" + << "fill_mode_c_" + << (RankK::kFillModeC == cutlass::FillMode::kLower ? "lower_" : + (RankK::kFillModeC == cutlass::FillMode::kUpper ? "upper_" : "invalid_")) + << "mnk_" + << problem_size.m() << "x" + << problem_size.n() << "x" + << problem_size.k() << "_" + << RankK::ThreadblockShape::kM << "x" + << RankK::ThreadblockShape::kN << "x" + << RankK::ThreadblockShape::kK << "_" + << RankK::WarpShape::kM << "x" + << RankK::WarpShape::kN << "x" + << RankK::WarpShape::kK << ".txt"; + + std::cout << fname.str() << std::endl; + + std::ofstream results(fname.str()); + + results << problem_size << std::endl; + + results + << "\nA:\n" << tensor_A.host_view() << "\n" + << "\nC:\n" << tensor_C.host_view() << "\n" + << "\nD reference:\n" << reference_D.host_view() << "\n" + << "\nD computed:\n" << tensor_D.host_view() << "\n"; + + } + + return passed; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +template +bool TestRank2kUniversal( + cutlass::gemm::GemmCoord const & problem_size, + cutlass::gemm::GemmUniversalMode mode, + int batch_count, + double alpha = 1.0, + double beta = 2.0) { + + bool passed = true; + + TestbedRank2KUniversal testbed; + + using ElementCompute = typename RankK::EpilogueOutputOp::ElementCompute; + + passed = testbed.run( + mode, + problem_size, + batch_count, + cutlass::from_real(alpha), + cutlass::from_real(beta) + ); + + return passed; +} + +template +bool TestAllRankKUniversal() { + bool passed = true; + + + int const kMinimumOperandElementSize = int(cutlass::sizeof_bits::value); + int const kAlignmentN = 128 / kMinimumOperandElementSize; + int const kAlignmentK = 128 / kMinimumOperandElementSize; + + cutlass::gemm::GemmUniversalMode modes[] = { + cutlass::gemm::GemmUniversalMode::kGemm, + }; + + int problem_size_n[] = { + kAlignmentN, 512 - 2*kAlignmentN + }; + + int problem_size_k[] = { + kAlignmentK, + RankK::ThreadblockShape::kK * RankK::kStages - kAlignmentK, + RankK::ThreadblockShape::kK * RankK::kStages * 3 - kAlignmentK + }; + + int batch_counts[] = { // may be interpretted as batch count or split-K slices + 1 // Just running one batch for now (removing 2, 3, 5, 7) + }; + + double problem_alpha[] = { + 1.0 + 
}; + + double problem_beta[] = { + 2.0 + }; + + + using ElementCompute = typename RankK::EpilogueOutputOp::ElementCompute; + + for (cutlass::gemm::GemmUniversalMode mode : modes) { + for (int n : problem_size_n) { + for (int k : problem_size_k) { + for (int batch_count : batch_counts) { + + for (auto alpha : problem_alpha) { + for (auto beta : problem_beta) { + + if (mode == cutlass::gemm::GemmUniversalMode::kGemm || + mode == cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel) { + } + + cutlass::gemm::GemmCoord problem_size(n, n, k); + + TestbedRank2KUniversal testbed; + + passed = testbed.run( + mode, + problem_size, + batch_count, + cutlass::from_real(alpha), + cutlass::from_real(beta) + ); + + if (!passed) { + return false; + } + } + } + } + } + } + } + + return passed; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace test + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_sanity.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_sanity.h new file mode 100644 index 0000000000000000000000000000000000000000..0a01a6a32ee2db84f2e890059423cd6b8477f766 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_sanity.h @@ -0,0 +1,238 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Tests for device-wide GEMM interface +*/ + +#include +#include + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/core_io.h" + +#include "testbed.h" + + +namespace test { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// +// List of Gemm internal paramters this testbed supports user verification +// +enum class ParameterID { + + // Threadblock-level parameters + kSmemASize, + kSmemBSize, + + // Warp-level parameters + kWarpFragmentASize, + kWarpFragmentBSize, + kWarpFragmentCSize, + kInvalid +}; + +struct Reference { + ParameterID parameter_id; + + union { + int value; + + struct { + int m, n, k; + } gemm_shape; + + struct { + int row, column; + } matrix_shape; + }; + + std::string error_msg; + + Reference( + ParameterID parameter_id_, + int value_=-1, + std::string const &error_msg_="") : parameter_id(parameter_id_), value(value_), error_msg(error_msg_) {} +}; + + +template +struct TestbedSanity { + + // + // Type definitions (All Gemm types top down) + // + + // Unpacking Gemm types in the following order + // Kernel-level > Threadblock-level > Warp-level > Instruction-level + + // kernel-level cutlass Gemm + using GemmKernel = typename Gemm::GemmKernel; + + // + // Threadblock-level gemm types + // + using MmaThreadBlock = typename GemmKernel::Mma; + + // Threadblock-level gemm shape covering one stage + using ThreadblockShape = typename MmaThreadBlock::Shape; + + // Shared memory size covering all stages + using SmemShapeA = typename MmaThreadBlock::Base::SharedStorage::ShapeA; + using SmemPaddingA = typename MmaThreadBlock::Policy::SmemPaddingA; + using SmemShapeB = typename MmaThreadBlock::Base::SharedStorage::ShapeB; + using SmemPaddingB = typename MmaThreadBlock::Policy::SmemPaddingB; + + + /// Number of stages + static int const kStages = MmaThreadBlock::Base::kStages; + + /// Number of warp-level GEMM oeprations + static int const kWarpGemmIterations = MmaThreadBlock::kWarpGemmIterations; + + + // + // Warp-level gemm types + // + + // Warp-level gemm operator + using MmaWarp = typename MmaThreadBlock::Operator; + + // Warp-level gemm shape covering all kgroups + using WarpShape = typename MmaWarp::Shape; + + // Warp-level framents holding operands A & B operand and destination C + using WarpFragmentA = typename MmaWarp::FragmentA; + using WarpFragmentB = typename MmaWarp::FragmentB; + using WarpFragmentC = typename MmaWarp::FragmentC; + + // + // Instruction-level gemm types + // + + // Instruction-level gemm operator + using MmaInstruction = typename MmaWarp::Policy::Operator; + + // Instruction shape + using InstructionShape = typename MmaInstruction::Shape; + + // Instruction-level framents holding operands A & B operand and destination C + using InstructionFragmentA = typename MmaInstruction::FragmentA; + using InstructionFragmentB = typename MmaInstruction::FragmentB; + using InstructionFragmentC = typename MmaInstruction::FragmentC; + + // + // Testbed types + // + + // Vector of values holding user provided reference + using ReferenceVector = std::vector; + + // + // 
Data members + // + ReferenceVector references; + + // + // Methods + // + + TestbedSanity(ReferenceVector const &references_ = ReferenceVector()) : references(references_){ } + + // verify all parameter in ReferenceVector + bool verify() { + for(auto ref : references) + verify_parameter(ref); + return true; + } + + // verify parameter of type Reference + void verify_parameter(Reference const& ref) { + switch(ref.parameter_id) { + case ParameterID::kWarpFragmentASize : EXPECT_TRUE(WarpFragmentA::kElements == ref.value) << *this; break; + case ParameterID::kWarpFragmentBSize : EXPECT_TRUE(WarpFragmentB::kElements == ref.value) << *this; break; + case ParameterID::kWarpFragmentCSize : EXPECT_TRUE(WarpFragmentC::kElements == ref.value) << *this; break; + } + } + +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +// Overload output operators for TesbedSanity +/////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +template +std::ostream & operator<<(std::ostream &out, TestbedSanity const &test) { + + + out << "Gemm internal parameters" << std::endl + << " Threadblock-level parameters:" << std::endl + << " ThreadblockShape = " << typename TestbedSanity::ThreadblockShape() << std::endl + << " kStages = " << TestbedSanity::kStages << std::endl + << " kWarpGemmIterations = "<< TestbedSanity::kWarpGemmIterations << std::endl + <<" Shared memory sizes:" << std::endl + <<" SmemPaddingA = " << typename TestbedSanity::SmemPaddingA() << std::endl + <<" SmemPaddingB = " << typename TestbedSanity::SmemPaddingB() << std::endl + <<" SmemShapeA = " << typename TestbedSanity::SmemShapeA() << std::endl + <<" SmemShapeB = " << typename TestbedSanity::SmemShapeB() << std::endl + <<" Warp-level parameters" << std::endl + <<" WarpShape = " << typename TestbedSanity::WarpShape() << std::endl + <<" Fragment sizes:" << std::endl + <<" WarpFragmentA::kElements = " << TestbedSanity::WarpFragmentA::kElements << std::endl + <<" WarpFragmentB::kElements = " << TestbedSanity::WarpFragmentB::kElements << std::endl + <<" WarpFragmentC::kElements = " << TestbedSanity::WarpFragmentC::kElements << std::endl + <<" Instruction-level parameters" << std::endl + <<" InstructionShape = " << typename TestbedSanity::InstructionShape() << std::endl + <<" Fragment sizes:" << std::endl + <<" InstructionFragmentA::kElements = " << TestbedSanity::InstructionFragmentA::kElements << std::endl + <<" InstructionFragmentB::kElements = " << TestbedSanity::InstructionFragmentB::kElements << std::endl + <<" InstructionFragmentC::kElements = " << TestbedSanity::InstructionFragmentC::kElements << std::endl; + + return out; +} + +} // namespace device +} // namespace gemm +} // namespace test + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_sparse.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_sparse.h new file mode 100644 index 0000000000000000000000000000000000000000..a95bf996bac337b44da616dc9fbf9c9bdb2a625c --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_sparse.h @@ -0,0 +1,487 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION 
& AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface + + Testbed for sparse operations not to be released for CUDA 11.0 GA. Expected release is 11.1. 
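+
+    A minimal usage sketch (illustrative only, not part of the original file): the
+    Gemm type below is a hypothetical sparse GEMM configuration with its template
+    arguments elided, and the run() signature is assumed to follow the other
+    testbeds in this directory (problem size plus alpha/beta with defaults):
+
+      // Hypothetical device-level sparse GEMM configuration (template arguments elided).
+      using Gemm = cutlass::gemm::device::SparseGemm<...>;
+
+      test::gemm::device::SparseTestbed<Gemm> testbed;
+
+      // Runs one problem size with the default alpha = 1, beta = 0 and compares
+      // the device result against the host reference GEMM.
+      bool passed = testbed.run({128, 128, 64});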
+*/ + +#pragma once + +#include +#include +#include + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/host_reorder.h" +#include "cutlass/util/host_uncompress.h" + +#include "testbed_utils.h" + +namespace test { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct SparseTestbed { + + using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; + using ElementC = typename Gemm::ElementC; + using ElementAccumulator = typename Gemm::ElementAccumulator; + using ElementCompute = typename Gemm::GemmKernel::Epilogue::OutputOp::ElementCompute; + + static int const kSparse = Gemm::GemmKernel::kSparse; + static int const kMetaSizeInBits = Gemm::GemmKernel::kMetaSizeInBits; + static int const kMaxID2 = Gemm::GemmKernel::kMaxID2; + static int const kElementsPerElementE = Gemm::GemmKernel::kElementsPerElementE; + + using ElementE = typename Gemm::GemmKernel::ElementE; + using LayoutE = cutlass::layout::RowMajor; + using ReorderedLayoutE = typename Gemm::GemmKernel::LayoutE; + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + cutlass::Distribution::Kind init_E; + uint64_t seed; + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_A_uncompressed; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_C; + cutlass::HostTensor tensor_D; + cutlass::HostTensor reference_D; + cutlass::HostTensor tensor_E; + cutlass::HostTensor tensor_E_reordered; + + // + // Methods + // + + SparseTestbed( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_E_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080) + : init_A(init_A_), + init_B(init_B_), + init_C(init_C_), + init_E(init_E_), + seed(seed_) {} + + /// Helper to initialize a tensor view + template + bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } else if (bits_input <= 8) { + scope_max = 1; + scope_min = -1; + } else if (bits_output == 16) { + scope_max = 5; + scope_min = -5; + } else { + scope_max = 8; + scope_min = -8; + } + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential( + view.data(), view.capacity()); + } + 
else { + EXPECT_TRUE(false) << "Not implemented"; + return false; + } + + return true; + } + + /// Initializes data structures + void initialize(cutlass::gemm::GemmCoord problem_size) { + // + // Allocate the GEMM workspace + // + tensor_A.resize(cutlass::make_Coord(problem_size.m(), problem_size.k() / kSparse)); + tensor_A_uncompressed.resize(problem_size.mk()); + tensor_B.resize(problem_size.kn()); + tensor_C.resize(problem_size.mn()); + tensor_D.resize(problem_size.mn()); + reference_D.resize(problem_size.mn(), false); + tensor_E.resize(cutlass::make_Coord( + problem_size.m(), problem_size.k() / kSparse / kElementsPerElementE)); + tensor_E_reordered.resize(cutlass::make_Coord( + problem_size.m(), problem_size.k() / kSparse / kElementsPerElementE)); + + EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2019)); + EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2018)); + EXPECT_TRUE(initialize_tensor(tensor_C.host_view(), init_C, seed + 2017)); + + if (init_E == cutlass::Distribution::Uniform) { + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomSparseMeta( + tensor_E.host_view(), seed, kMetaSizeInBits); + } else if (init_E == cutlass::Distribution::Identity) { + uint32_t content = (kMaxID2 == 1) ? 0x44444444 : 0x4444; + cutlass::reference::host::TensorFill(tensor_E.host_view(), + (ElementE)(content)); + } else { + EXPECT_TRUE(false); + } + + cutlass::reorder_meta(tensor_E_reordered.host_ref(), tensor_E.host_ref(), + {problem_size.m(), problem_size.n(), + problem_size.k() / kSparse / kElementsPerElementE}); + + // It is possible to randomly initialize to all zeros, so override this with non-zeros + // in the upper left corner of each operand. + tensor_A.host_view().at({0, 0}) = typename Gemm::ElementA(1); + tensor_B.host_view().at({0, 0}) = typename Gemm::ElementB(1); + tensor_C.host_view().at({0, 0}) = typename Gemm::ElementC(1); + + cutlass::reference::host::TensorCopy(reference_D.host_view(), tensor_C.host_view()); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_D.sync_device(); + tensor_E_reordered.sync_device(); + } + + /// Compares computed reference with device reference and outputs to a file if incorrect + bool compare_reference( + cutlass::gemm::GemmCoord problem_size, + ElementCompute alpha, + ElementCompute beta) { + + tensor_D.sync_host(); + + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_C.host_view()), 0); + + if (tensor_D.size() > 1) + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0); + + if (reference_D.size() > 1) + EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); + + bool passed = cutlass::reference::host::TensorEquals(reference_D.host_view(), tensor_D.host_view()); + + EXPECT_TRUE(passed); + + if (!passed) { + + std::stringstream fname; + + fname << "error_Gemm_device_" + << problem_size.m() << "x" + << problem_size.n() << "x" + << problem_size.k() << "_" + << Gemm::ThreadblockShape::kM << "x" + << Gemm::ThreadblockShape::kN << "x" + << Gemm::ThreadblockShape::kK << "_" + << Gemm::WarpShape::kM << "x" + << Gemm::WarpShape::kN << "x" + << Gemm::WarpShape::kK << ".txt"; + + std::ofstream file(fname.str()); + + file + << "problem: " << problem_size + << ", alpha: " << alpha << ", beta: " << beta << "\n\n"; + + file + << "A =\n" << tensor_A.host_view() + << "\nB =\n" << 
tensor_B.host_view() + << "\nC =\n" << tensor_C.host_view() + << "\nE =\n" << tensor_E.host_view() + << "\n\nReference =\n" << reference_D.host_view() + << "\nComputed =\n" << tensor_D.host_view(); + } + + return passed; + } + + /// Verifies the result is a GEMM + bool verify( + cutlass::gemm::GemmCoord problem_size, + ElementCompute alpha, + ElementCompute beta) { + + // + // Verify + // + + cutlass::uncompress(tensor_A_uncompressed.host_ref(), tensor_A.host_ref(), + tensor_E.host_ref(), problem_size.m(), problem_size.k()); + + cutlass::reference::host::Gemm< + typename Gemm::ElementA, typename Gemm::LayoutA, + typename Gemm::ElementB, typename Gemm::LayoutB, + typename Gemm::ElementC, typename Gemm::LayoutC, + ElementCompute, + ElementAccumulator, typename Gemm::Operator> + reference_gemm; + + reference_gemm( + problem_size, + alpha, + tensor_A_uncompressed.host_ref(), + tensor_B.host_ref(), + beta, + reference_D.host_ref(), + ElementAccumulator(0) + ); + + return compare_reference(problem_size, alpha, beta); + } + + /// Returns true if the CUDA device is sufficient to execute the kernel. + bool sufficient() const { + // + // Determine SMEM requirements and waive if not satisfied + // + + size_t smem_size = sizeof(typename Gemm::GemmKernel::SharedStorage); + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerBlockOptin < smem_size) { + return false; + } + + return true; + } + + /// Executes one test + bool run( + cutlass::gemm::GemmCoord problem_size, + int split_k_slices = 1, + ElementCompute alpha = ElementCompute(1), + ElementCompute beta = ElementCompute(0)) { + + // Waive test if insufficient CUDA device + if (!sufficient()) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." << std::endl; + } + return true; + } + + this->initialize(problem_size); + + // + // Initialize the GEMM operator + // + + typename Gemm::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + split_k_slices, + {alpha, beta}, + tensor_A.device_data(), + tensor_B.device_data(), + tensor_C.device_data(), + tensor_D.device_data(), + tensor_E_reordered.device_data(), + int64_t(), + int64_t(), + int64_t(), + int64_t(), + int64_t(), + tensor_A.layout().stride(0), + tensor_B.layout().stride(0), + tensor_C.layout().stride(0), + tensor_D.layout().stride(0), + tensor_E_reordered.layout().stride(0) + }; + + Gemm gemm_op; + + size_t workspace_size = Gemm::get_workspace_size(arguments); + + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = gemm_op.initialize(arguments, workspace.get()); + + // This failure is likely due to insufficient device capabilities. Waive the test. 
+ if (status != cutlass::Status::kSuccess) { + return true; + } + + // + // Run the GEMM + // + + status = gemm_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Verify + // + + bool passed = this->verify(problem_size, alpha, beta); + + if (!passed) { + std::cout << "Error with split_k_slices = " << split_k_slices << ", alpha: " << alpha << ", beta: " << beta << ", m: " << problem_size.m() << ", n: " << problem_size.n() << ", k:" < +bool TestAllSparseGemm() { + bool passed = true; + + int const kMinimumOperandElementSize = + std::min( + int(cutlass::sizeof_bits::value), + int(cutlass::sizeof_bits::value)); + + // M dimension has to be multiple of 32 (sparse float) or 16 (sparse int) + // because of the reordering of operand E + int const kAlignmentM = std::max(((sizeof(typename Gemm::ElementE) == 2) ? 32 : 16), + kMinimumOperandElementSize); + + int const kAlignmentN = 128 / kMinimumOperandElementSize; + + int problem_size_m[] = {kAlignmentM, 512 - 3 * kAlignmentM}; + + int problem_size_n[] = {kAlignmentN, 512 - 2 * kAlignmentN}; + + int problem_size_k[] = {Gemm::ThreadblockShape::kK * 8}; + + int split_k_slices[] = { + 1, 2 + }; + + double problem_alpha[] = { + 1 + }; + + double problem_beta[] = { + 2.0 + }; + + SparseTestbed testbed; + + using ElementCompute = typename Gemm::EpilogueOutputOp::ElementCompute; + + for (int m : problem_size_m) { + for (int n : problem_size_n) { + for (int k : problem_size_k) { + for (int split_k : split_k_slices) { + + for (auto alpha : problem_alpha) { + for (auto beta : problem_beta) { + cutlass::gemm::GemmCoord problem_size(m, n, k); + + passed = testbed.run( + problem_size, + split_k, + cutlass::from_real(alpha), + cutlass::from_real(beta) + ); + + if (!passed) { + return false; + } + } + } + } + } + } + } + + return passed; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace test + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_splitk.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_splitk.h new file mode 100644 index 0000000000000000000000000000000000000000..8fa4a85505316d08f1d050702b78448f8fae8565 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_splitk.h @@ -0,0 +1,218 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#pragma once + +#include +#include + +#include "../../common/cutlass_unit_test.h" + +#include "testbed.h" + +namespace test { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct TestbedSplitK : public Testbed { + + using Base = Testbed; + + using ElementCompute = typename Base::ElementCompute; + + // + // Methods + // + + TestbedSplitK( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + Base(init_A_, init_B_, init_C_, seed_) { } + + /// Returns true if the CUDA device is sufficient to execute the kernel. + bool sufficient() const { + // + // Determine SMEM requirements and waive if not satisfied + // + + size_t smem_size = sizeof(typename Gemm::GemmKernel::SharedStorage); + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerBlockOptin < smem_size) { + return false; + } + + return true; + } + + /// Executes one test + bool run( + cutlass::gemm::GemmCoord problem_size, + int split_k_slices, + ElementCompute alpha = ElementCompute(1), + ElementCompute beta = ElementCompute(0)) { + + // Waive test if insufficient CUDA device + if (!sufficient()) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." 
<< std::endl; + } + return true; + } + + this->initialize(problem_size); + + // + // Initialize the GEMM operator + // + + typename Gemm::Arguments arguments{ + problem_size, + this->tensor_A.device_ref(), + this->tensor_B.device_ref(), + this->tensor_C.device_ref(), + this->tensor_D.device_ref(), + {alpha, beta}, + split_k_slices + }; + + Gemm gemm_op; + + size_t workspace_size = Gemm::get_workspace_size(arguments); + + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = gemm_op.initialize(arguments, workspace.get()); + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + + // + // Run the GEMM + // + + status = gemm_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess); + + // + // Verify + // + + return this->verify(problem_size, alpha, beta); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +bool TestAllGemmSplitK() { + bool passed = true; + + cutlass::gemm::GemmCoord problem_sizes[] = { + {8, 8, 2048}, + {8, 8, 2056}, + {264, 72, 520}, + {264, 520, 120}, + {264, 520, 264} + }; + + int split_k_slices[] = { + 1, 2, 4, 5, 7 + }; + + double problem_alpha[] = { + 0.5 + }; + + double problem_beta[] = { + 2.0 + }; + + using Testbed = TestbedSplitK; + using ElementCompute = typename Testbed::ElementCompute; + + Testbed testbed; + + for (auto problem_size : problem_sizes) { + for (int split_k_count : split_k_slices) { + for (double alpha : problem_alpha) { + for (double beta : problem_beta) { + + passed = testbed.run( + problem_size, + split_k_count, + ElementCompute(alpha), + ElementCompute(beta) + ); + + if (!passed) { + std::cout << "Failed on size " << problem_size << " with split_k_count " << split_k_count << std::endl; + return false; + } + } + } + } + } + + EXPECT_TRUE(passed); + + return passed; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace test + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_symm_universal.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_symm_universal.h new file mode 100644 index 0000000000000000000000000000000000000000..b7a57f7eb0ca73c23460e5a9ce1301061c2cc286 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_symm_universal.h @@ -0,0 +1,592 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide Symm update interface + +*/ + +#pragma once + +#include +#include +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/blas3.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/error_metrics.h" +#include "cutlass/util/reference/host/symm.h" +#include "cutlass/util/reference/host/symm_complex.h" + +#include "testbed_utils.h" + +namespace test { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct TestbedSymmUniversal { + + using ElementA = typename Symm::ElementA; + using ElementB = typename Symm::ElementB; + using ElementC = typename Symm::ElementC; + using ElementAccumulator = typename Symm::ElementAccumulator; + using ElementCompute = typename Symm::SymmKernel::Epilogue::OutputOp::ElementCompute; + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + uint64_t seed; + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_C; + cutlass::HostTensor tensor_D; + cutlass::HostTensor reference_D; + + // + // Methods + // + + TestbedSymmUniversal( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } + + /// Helper to initialize a tensor view + template + bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed, + int mantissa_in_bits) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } else if (bits_output == 16) { + scope_max = 5; + 
scope_min = -5; + } else { + scope_max = 8; + scope_min = -8; + } + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, mantissa_in_bits); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5, mantissa_in_bits); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential( + view.data(), view.capacity()); + } + else { + + EXPECT_TRUE(false) << "Input distribution not implemented"; + return false; + } + + return true; + } + + + /// Helper to initialize a tensor view + template + bool initialize_symmetric_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed, + int mantissa_in_bits) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } else if (bits_output == 16) { + scope_max = 5; + scope_min = -5; + } else { + scope_max = 8; + scope_min = -8; + } + + cutlass::reference::host::TensorFillSymmetricRandomUniform( + view, seed, Symm::kFillModeA, scope_max, scope_min, mantissa_in_bits); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillSymmetricRandomGaussian( + view, seed, Symm::kFillModeA, 0, 0.5, mantissa_in_bits); + } + else { + + EXPECT_TRUE(false) << "Input distribution (symmetric tensor) not implemented"; + return false; + } + + return true; + } + /// Initializes data structures + void initialize(cutlass::gemm::GemmCoord problem_size) { + // + // Allocate the Symm workspace + // + + if (Symm::kSideModeA == cutlass::SideMode::kLeft) { + tensor_A.resize(cutlass::make_Coord(problem_size.m(),problem_size.m())); + } + else if (Symm::kSideModeA == cutlass::SideMode::kRight) { + tensor_A.resize(cutlass::make_Coord(problem_size.n(),problem_size.n())); + } + + tensor_B.resize(problem_size.mn()); + tensor_C.resize(problem_size.mn()); + tensor_D.resize(problem_size.mn()); + reference_D.resize(problem_size.mn(), false); + + EXPECT_TRUE(initialize_symmetric_tensor(tensor_A.host_view(), init_A, seed + 2019, cutlass::MantissaInBits::bits)); + EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2018, cutlass::MantissaInBits::bits)); + EXPECT_TRUE(initialize_tensor(tensor_C.host_view(), init_C, seed + 2017, cutlass::MantissaInBits::bits)); + + // It is possible to randomly initialize to all zeros, so override this with non-zeros + // in the upper left corner of each operand. 
+ tensor_A.host_view().at({0, 0}) = typename Symm::ElementA(1); + tensor_B.host_view().at({0, 0}) = typename Symm::ElementB(1); + tensor_C.host_view().at({0, 0}) = typename Symm::ElementC(1); + + cutlass::reference::host::TensorCopy(reference_D.host_view(), tensor_C.host_view()); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_D.sync_device(); + } + + /// Compares computed reference with device reference and outputs to a file if incorrect + bool compare_reference( + cutlass::gemm::GemmCoord problem_size, + ElementCompute alpha, + ElementCompute beta) { + + tensor_D.sync_host(); + + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_C.host_view()), 0); + + if (tensor_D.size() > 1) + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0); + + if (reference_D.size() > 1) + EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); + + double l2_norm = cutlass::reference::host::TensorRelativeErrorMetric(reference_D.host_view(), tensor_D.host_view()); + + bool passed = l2_norm < cutlass::MantissaInBits::error; + + return passed; + } + + /// Verifies the result is a Symm + bool verify( + cutlass::gemm::GemmCoord problem_size, + ElementCompute alpha, + ElementCompute beta) { + + // + // Verify + // + + using HostReference = typename cutlass::platform::conditional< + (cutlass::platform::is_same + >::value || + cutlass::platform::is_same + >::value + ), + cutlass::reference::host::SymmComplex< + typename Symm::ElementA, typename Symm::LayoutA, + Symm::kSideModeA, Symm::kFillModeA, + typename Symm::ElementB, typename Symm::LayoutB, + typename Symm::ElementC, typename Symm::LayoutC, + ElementCompute, + ElementAccumulator, + Symm::kBlasMode>, + cutlass::reference::host::Symm< + typename Symm::ElementA, typename Symm::LayoutA, + Symm::kSideModeA, Symm::kFillModeA, + typename Symm::ElementB, typename Symm::LayoutB, + typename Symm::ElementC, typename Symm::LayoutC, + ElementCompute, + ElementAccumulator> + >::type; + + + HostReference reference_symm; + + reference_symm( + problem_size, + alpha, + tensor_A.host_ref(), + tensor_B.host_ref(), + beta, + tensor_C.host_ref(), + reference_D.host_ref(), + ElementAccumulator(0) + ); + + return compare_reference(problem_size, alpha, beta); + } + + /// Returns true if the CUDA device is sufficient to execute the kernel. 
+ bool sufficient() const { + // + // Determine SMEM requirements and waive if not satisfied + // + + size_t smem_size = sizeof(typename Symm::SymmKernel::SharedStorage); + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerBlockOptin < smem_size) { + return false; + } + + return true; + } + + /// Executes one test + bool run( + cutlass::gemm::GemmUniversalMode mode, + cutlass::gemm::GemmCoord problem_size, + int batch_count = 1, + ElementCompute alpha = ElementCompute(1), + ElementCompute beta = ElementCompute(0)) { + + // Waive test if insufficient CUDA device + if (!sufficient()) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." << std::endl; + } + return true; + } + +#if 0 + std::cout << "[TestbedSymmUniversal::run()] problem(m, n, k): " << problem_size + << " alpha: " << ElementCompute(alpha) + << " beta: " << ElementCompute(beta) << std::endl; +#endif + + this->initialize(problem_size); + + // + // Initialize the Symm operator + // + + int batch_stride_A; + if (Symm::kSideModeA == cutlass::SideMode::kLeft) + batch_stride_A = problem_size.m()*problem_size.m(); + if (Symm::kSideModeA == cutlass::SideMode::kRight) + batch_stride_A = problem_size.n()*problem_size.n(); + + typename Symm::Arguments arguments{ + mode, + problem_size, + batch_count, + {alpha, beta}, + tensor_A.device_data(), + tensor_B.device_data(), + tensor_C.device_data(), + tensor_D.device_data(), + batch_stride_A, + problem_size.m() * problem_size.n(), + problem_size.m() * problem_size.n(), + problem_size.m() * problem_size.n(), + tensor_A.layout().stride(0), + tensor_B.layout().stride(0), + tensor_C.layout().stride(0), + tensor_D.layout().stride(0) + }; + + Symm symm_op; + + size_t workspace_size = Symm::get_workspace_size(arguments); + + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = symm_op.initialize(arguments, workspace.get()); + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Run the Symm + // + + status = symm_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Verify + // + + bool passed = this->verify(problem_size, alpha, beta); + + //if (true) { + if (!passed) { + std::stringstream fname; + + fname << "error_" + << (Symm::kBlasMode == cutlass::BlasMode::kSymmetric ? "symm_" : "hemm_" ) + << "device_" + << "fill_mode_a_" + << (Symm::kSideModeA == cutlass::SideMode::kLeft ? "leftside_" : + (Symm::kSideModeA == cutlass::SideMode::kRight ? "rightside_" : "invalid_")) + << (Symm::kFillModeA == cutlass::FillMode::kLower ? "lower_" : + (Symm::kFillModeA == cutlass::FillMode::kUpper ? 
"upper_" : "invalid_")) + << "mnk_" + << problem_size.m() << "x" + << problem_size.n() << "x" + << problem_size.k() << "_" + << Symm::ThreadblockShape::kM << "x" + << Symm::ThreadblockShape::kN << "x" + << Symm::ThreadblockShape::kK << "_" + << Symm::WarpShape::kM << "x" + << Symm::WarpShape::kN << "x" + << Symm::WarpShape::kK << ".txt"; + + std::cout << fname.str() << std::endl; + + std::ofstream results(fname.str()); + + results << problem_size << std::endl; + + results + << "alpha: " << ElementCompute(alpha) << "\n" + << "beta: " << ElementCompute(beta) << "\n" + << "\nA:\n" << tensor_A.host_view() << "\n" + << "\nB:\n" << tensor_B.host_view() << "\n" + << "\nC:\n" << tensor_C.host_view() << "\n" + << "\nD reference:\n" << reference_D.host_view() << "\n" + << "\nD computed:\n" << tensor_D.host_view() << "\n"; + + } + + return passed; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +template +bool TestsymmUniversal( + cutlass::gemm::GemmCoord const & problem_size, + cutlass::gemm::GemmUniversalMode mode, + int batch_count, + double alpha = 1.0, + double beta = 2.0) { + + bool passed = true; + + TestbedSymmUniversal testbed; + + using ElementCompute = typename Symm::EpilogueOutputOp::ElementCompute; + + passed = testbed.run( + mode, + problem_size, + batch_count, + cutlass::from_real(alpha), + cutlass::from_real(beta) + ); + + return passed; +} + +template +bool TestAllSymmUniversal() { + bool passed = true; + + + int const kMinimumOperandElementSize = int(cutlass::sizeof_bits::value); + + int const kAlignment = cutlass::platform::is_same< + typename Symm::OperatorClass, + cutlass::arch::OpClassSimt>::value ? 1 : 128 / kMinimumOperandElementSize; + + // int8_t gemm alignment constraints + int const kAlignmentM = cutlass::platform::is_same::value && + cutlass::platform::is_same::value && + cutlass::platform::is_same::value ? 4 : kAlignment; + + int const kAlignmentN = kAlignmentM; + + int const kAlignmentK = cutlass::platform::is_same::value && + cutlass::platform::is_same::value && + cutlass::platform::is_same::value + ? 
4 : kAlignment; + + cutlass::gemm::GemmUniversalMode modes[] = { + cutlass::gemm::GemmUniversalMode::kGemm, + }; + + int problem_size_m[] = { + kAlignmentK, + Symm::ThreadblockShape::kK * Symm::kStages - kAlignmentK, + Symm::ThreadblockShape::kK * Symm::kStages * 3 - kAlignmentK + }; + + int problem_size_n[] = { + kAlignmentN, 512 - 2*kAlignmentN + }; + + int batch_counts[] = { // may be interpretted as batch count or split-K slices + 1 // Just running one batch for now (removing 2, 3, 5, 7) + }; + + double problem_alpha[] = { + 1.0, 3.0 + }; + + double problem_beta[] = { + 0, 2.0 + }; + + + using ElementCompute = typename Symm::EpilogueOutputOp::ElementCompute; + + for (cutlass::gemm::GemmUniversalMode mode : modes) { + for (int m : problem_size_m) { + for (int n : problem_size_n) { + for (int batch_count : batch_counts) { + + for (auto alpha : problem_alpha) { + for (auto beta : problem_beta) { + + int k = 0; + if (Symm::kSideModeA == cutlass::SideMode::kLeft) + k = m; + else if (Symm::kSideModeA == cutlass::SideMode::kRight) + k = n; + + if (mode == cutlass::gemm::GemmUniversalMode::kGemm || + mode == cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel) { + + #if 0 + // skip very small K problems + if (k / batch_count < 2 * Symm::ThreadblockShape::kK) { + continue; + } + #endif + } + + cutlass::gemm::GemmCoord problem_size(m, n, k); + + TestbedSymmUniversal testbed; + + passed = testbed.run( + mode, + problem_size, + batch_count, + cutlass::from_real(alpha), + cutlass::from_real(beta) + ); + + if (!passed) { + return false; + } + } + } + } + } + } + } + + return passed; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace test + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_trmm_universal.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_trmm_universal.h new file mode 100644 index 0000000000000000000000000000000000000000..b30acfed6bba547986efd3afa8eb829be2a255e4 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_trmm_universal.h @@ -0,0 +1,606 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide TRMM interface + + +*/ + +#pragma once + +#include +#include +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/blas3.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/error_metrics.h" +#include "cutlass/util/reference/host/trmm.h" +#include "cutlass/util/reference/host/trmm_complex.h" +#include "cutlass/core_io.h" + +#include "testbed_utils.h" + +namespace test { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct TestbedTrmmUniversal { + + using ElementA = typename Trmm::ElementA; + using ElementB = typename Trmm::ElementB; + using ElementC = typename Trmm::ElementC; + using ElementAccumulator = typename Trmm::ElementAccumulator; + using ElementCompute = typename Trmm::TrmmKernel::Epilogue::OutputOp::ElementCompute; + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_D; + uint64_t seed; + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_D; + cutlass::HostTensor reference_D; + + // + // Methods + // + + TestbedTrmmUniversal( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_D_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + init_A(init_A_), init_B(init_B_), init_D(init_D_), seed(seed_) { } + + /// Helper to initialize a tensor view + template + bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed, + int mantissa_in_bits) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } else if (bits_output == 16) { + scope_max = 5; + scope_min = -5; + } else { + scope_max = 8; + scope_min = -8; + } + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, mantissa_in_bits); + } + else if (dist_kind 
== cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5, mantissa_in_bits); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential( + view.data(), view.capacity()); + } + else { + EXPECT_TRUE(false) << "Not implemented"; + return false; + } + + return true; + } + + + /// Helper to initialize a tensor view + template + bool initialize_symmetric_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed, + int mantissa_in_bits) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } else if (bits_output == 16) { + scope_max = 5; + scope_min = -5; + } else { + scope_max = 8; + scope_min = -8; + } + + cutlass::reference::host::TensorFillSymmetricRandomUniform( + view, seed, Trmm::kFillMode, scope_max, scope_min, mantissa_in_bits); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillSymmetricRandomGaussian( + view, seed, Trmm::kFillMode, 0, 0.5, mantissa_in_bits); + } + else { + EXPECT_TRUE(false) << "Not implemented"; + return false; + } + + return true; + } + + /// Helper to initialize a tensor view (pad diagonal fill with zeros for up to alignment on wrong side of diagonal) + template + bool initialize_pad_diagonal_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed, + int alignment) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } else if (bits_output == 16) { + scope_max = 5; + scope_min = -5; + } else { + scope_max = 8; + scope_min = -8; + } + + cutlass::reference::host::TensorFillPadDiagonalRandomUniform( + view, seed, Trmm::kFillMode, scope_max, scope_min, 0, alignment); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + EXPECT_TRUE(false) << "Gaussian distribution for pad diagonal not implemented"; + } + else { + EXPECT_TRUE(false) << "Not implemented"; + return false; + } + + return true; + } + + /// Initializes data structures + void initialize(cutlass::gemm::GemmCoord problem_size) { + // + // Allocate the TRMM workspace + // + + if (Trmm::kSideMode == cutlass::SideMode::kLeft) { + tensor_A.resize(cutlass::make_Coord(problem_size.m(),problem_size.m())); + } + else if (Trmm::kSideMode == cutlass::SideMode::kRight) { + tensor_A.resize(cutlass::make_Coord(problem_size.n(),problem_size.n())); + } + + tensor_B.resize(problem_size.mn()); + tensor_D.resize(problem_size.mn()); + reference_D.resize(problem_size.mn(), false); + + //EXPECT_TRUE(initialize_symmetric_tensor(tensor_A.host_view(), init_A, seed + 2017)); + //EXPECT_TRUE(initialize_pad_diagonal_tensor(tensor_A.host_view(), init_A, seed + 2017, Trmm::kAlignmentA)); + EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2017, cutlass::MantissaInBits::bits)); + EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 
2019, cutlass::MantissaInBits::bits)); + + // It is possible to randomly initialize to all zeros, so override this with non-zeros + // in the upper left corner of each operand. + tensor_A.host_view().at({0, 0}) = typename Trmm::ElementA(1); + tensor_B.host_view().at({0, 0}) = typename Trmm::ElementB(1); + + cutlass::reference::host::TensorCopy(reference_D.host_view(), tensor_D.host_view()); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_D.sync_device(); + } + + /// Compares computed reference with device reference and outputs to a file if incorrect + bool compare_reference( + cutlass::gemm::GemmCoord problem_size, + ElementCompute alpha) { + + tensor_D.sync_host(); + + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0); + + if (tensor_D.size() > 1) + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0); + + if (reference_D.size() > 1) + EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); + + double l2_norm = cutlass::reference::host::TensorRelativeErrorMetric(reference_D.host_view(), tensor_D.host_view()); + + bool passed = l2_norm < cutlass::MantissaInBits::error; + + return passed; + } + + /// Verifies the result is a TRMM + bool verify( + cutlass::gemm::GemmCoord problem_size, + ElementCompute alpha) { + + // + // Verify + // + + using HostReference = typename cutlass::platform::conditional< + (cutlass::platform::is_same + >::value || + cutlass::platform::is_same + >::value + ), + cutlass::reference::host::TrmmComplex< + typename Trmm::ElementA, typename Trmm::LayoutA, + Trmm::kTransformA, + Trmm::kSideMode, Trmm::kFillMode, Trmm::kDiagType, + typename Trmm::ElementB, typename Trmm::LayoutB, + Trmm::kTransformB, + typename Trmm::ElementC, typename Trmm::LayoutC, + ElementCompute, + ElementAccumulator>, + cutlass::reference::host::Trmm< + typename Trmm::ElementA, typename Trmm::LayoutA, + Trmm::kSideMode, Trmm::kFillMode, Trmm::kDiagType, + typename Trmm::ElementB, typename Trmm::LayoutB, + typename Trmm::ElementC, typename Trmm::LayoutC, + ElementCompute, + ElementAccumulator> + >::type; + + + HostReference reference_trmm; + + reference_trmm( + problem_size, + alpha, + tensor_A.host_ref(), + tensor_B.host_ref(), + reference_D.host_ref(), + ElementAccumulator(0) + ); + + return compare_reference(problem_size, alpha); + } + + /// Returns true if the CUDA device is sufficient to execute the kernel. + bool sufficient() const { + // + // Determine SMEM requirements and waive if not satisfied + // + + size_t smem_size = sizeof(typename Trmm::TrmmKernel::SharedStorage); + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerBlockOptin < smem_size) { + return false; + } + + return true; + } + + /// Executes one test + bool run( + cutlass::gemm::GemmUniversalMode mode, + cutlass::gemm::GemmCoord problem_size, + int batch_count = 1, + ElementCompute alpha = ElementCompute(1)) { + + // Waive test if insufficient CUDA device + if (!sufficient()) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." 
<< std::endl; + } + return true; + } + +#if 0 + std::cout << "[TestbedTrmmUniversal::run()] problem(m, n, k): " << problem_size + << " alpha: " << ElementCompute(alpha) << std::endl; +#endif + + this->initialize(problem_size); + + // + // Initialize the TRMM operator + // + + int batch_stride_A; + if (Trmm::kSideMode == cutlass::SideMode::kLeft) + batch_stride_A = problem_size.m()*problem_size.m(); + if (Trmm::kSideMode == cutlass::SideMode::kRight) + batch_stride_A = problem_size.n()*problem_size.n(); + + typename Trmm::Arguments arguments{ + mode, + problem_size, + batch_count, + {alpha}, + tensor_A.device_data(), + tensor_B.device_data(), + tensor_D.device_data(), + batch_stride_A, + problem_size.m() * problem_size.n(), + problem_size.m() * problem_size.n(), + tensor_A.layout().stride(0), + tensor_B.layout().stride(0), + tensor_D.layout().stride(0) + }; + + Trmm trmm_op; + + size_t workspace_size = Trmm::get_workspace_size(arguments); + + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = trmm_op.initialize(arguments, workspace.get()); + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Run the TRMM + // + + status = trmm_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Verify + // + bool passed = this->verify(problem_size, alpha); + + if (!passed) { + std::stringstream fname; + + fname << "error_Trmm_device_" + << "fill_mode_" + << (Trmm::kFillMode == cutlass::FillMode::kLower ? "lower_" : + (Trmm::kFillMode == cutlass::FillMode::kUpper ? "upper_" : "invalid_")) + << "side_mode_" + << (Trmm::kSideMode == cutlass::SideMode::kLeft ? "left_" : + (Trmm::kSideMode == cutlass::SideMode::kRight ? "right_" : "invalid_")) + << "mnk_" + << problem_size.m() << "x" + << problem_size.n() << "x" + << problem_size.k() << "_" + << Trmm::ThreadblockShape::kM << "x" + << Trmm::ThreadblockShape::kN << "x" + << Trmm::ThreadblockShape::kK << "_" + << Trmm::WarpShape::kM << "x" + << Trmm::WarpShape::kN << "x" + << Trmm::WarpShape::kK << ".txt"; + + std::cout << fname.str() << std::endl; + + std::ofstream results(fname.str()); + + results << problem_size << std::endl; + + results + << "\nA:\n" << tensor_A.host_view() << "\n" + << "\nB:\n" << tensor_B.host_view() << "\n" + << "\nD reference:\n" << reference_D.host_view() << "\n" + << "\nD computed:\n" << tensor_D.host_view() << "\n"; + } + + return passed; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +template +bool TestTrmmUniversal( + cutlass::gemm::GemmCoord const & problem_size, + cutlass::gemm::GemmUniversalMode mode, + int batch_count, + double alpha = 1.0) { + + bool passed = true; + + TestbedTrmmUniversal testbed; + + using ElementCompute = typename Trmm::EpilogueOutputOp::ElementCompute; + + passed = testbed.run( + mode, + problem_size, + batch_count, + cutlass::from_real(alpha) + ); + + return passed; +} + +template +bool TestAllTrmmUniversal() { + bool passed = true; + + int const kMinimumOperandElementSize = int(cutlass::sizeof_bits::value); + + int const kAlignment = cutlass::platform::is_same< + typename Trmm::OperatorClass, + cutlass::arch::OpClassSimt>::value ? 1 : 128 / kMinimumOperandElementSize; + + // int8_t gemm alignment constraints + int const kAlignmentM = cutlass::platform::is_same::value && + cutlass::platform::is_same::value && + cutlass::platform::is_same::value ? 
4 : kAlignment; + + int const kAlignmentN = kAlignmentM; + + int const kAlignmentK = cutlass::platform::is_same::value && + cutlass::platform::is_same::value && + cutlass::platform::is_same::value + ? 4 : kAlignment; + + cutlass::gemm::GemmUniversalMode modes[] = { + cutlass::gemm::GemmUniversalMode::kGemm, + }; + + int problem_size_m[] = { + kAlignmentK, + Trmm::ThreadblockShape::kK * Trmm::kStages - kAlignmentK, + Trmm::ThreadblockShape::kK * Trmm::kStages * 3 - kAlignmentK + }; + + int problem_size_n[] = { + kAlignmentN, 512 - 2*kAlignmentN + }; + + int batch_counts[] = { // may be interpretted as batch count or split-K slices + 1 // Just running one batch for now (removing 2, 3, 5, 7) + }; + + double problem_alpha[] = { + 1.0, 2.0 + }; + + using ElementCompute = typename Trmm::EpilogueOutputOp::ElementCompute; + + for (cutlass::gemm::GemmUniversalMode mode : modes) { + for (int m : problem_size_m) { + for (int n : problem_size_n) { + for (int batch_count : batch_counts) { + for (auto alpha : problem_alpha) { + + int k = 0; + if (Trmm::kSideMode == cutlass::SideMode::kLeft) + k = m; + else if (Trmm::kSideMode == cutlass::SideMode::kRight) + k = n; + + if (mode == cutlass::gemm::GemmUniversalMode::kGemm || + mode == cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel) { + +#if 0 + // skip very small K problems + if (k / batch_count < 2 * Trmm::ThreadblockShape::kK) { + continue; + } +#endif + } + + cutlass::gemm::GemmCoord problem_size(m, n, k); + + TestbedTrmmUniversal testbed; + + passed = testbed.run( + mode, + problem_size, + batch_count, + cutlass::from_real(alpha) + ); + + if (!passed) { + return false; + } + } + } + } + } + } + + return passed; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace test + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_universal.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_universal.h new file mode 100644 index 0000000000000000000000000000000000000000..00368a5e8eebc128719f64069583010c83dc0c1f --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_universal.h @@ -0,0 +1,553 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#pragma once + +#include +#include +#include + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/gemm_complex.h" + +#include "testbed_utils.h" + +namespace test { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct TestbedUniversal { + + using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; + using ElementC = typename Gemm::ElementC; + using ElementAccumulator = typename Gemm::ElementAccumulator; + using ElementCompute = typename Gemm::GemmKernel::Epilogue::OutputOp::ElementCompute; + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + uint64_t seed; + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_C; + cutlass::HostTensor tensor_D; + cutlass::HostTensor reference_D; + + // + // Methods + // + + TestbedUniversal( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } + + /// Helper to initialize a tensor view + template + bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + bool is_unsigned_int = std::numeric_limits::is_integer && !std::numeric_limits::is_signed; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } else if (bits_input <= 8) { + scope_max = is_unsigned_int ? 2 : 1; + scope_min = is_unsigned_int ? 
0 : -1; + } else if (bits_output == 16) { + constexpr auto u8_bf16 = + (cutlass::platform::is_same::value && + cutlass::platform::is_same::value) || + (cutlass::platform::is_same::value && + cutlass::platform::is_same::value); + scope_max = is_unsigned_int ? 10 : (u8_bf16 ? 3 : 5); + scope_min = is_unsigned_int ? 0 : (u8_bf16 ? -3 : -5); + } else { + scope_max = 8; + scope_min = -8; + } + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential( + view.data(), view.capacity()); + } + else { + EXPECT_TRUE(false) << "Not implemented"; + return false; + } + + return true; + } + + /// Initializes data structures + void initialize(cutlass::gemm::GemmCoord problem_size) { + // + // Allocate the GEMM workspace + // + + tensor_A.resize(problem_size.mk()); + tensor_B.resize(problem_size.kn()); + tensor_C.resize(problem_size.mn()); + tensor_D.resize(problem_size.mn()); + reference_D.resize(problem_size.mn(), false); + + EXPECT_TRUE(initialize_tensor(tensor_A.host_view(), init_A, seed + 2019)); + EXPECT_TRUE(initialize_tensor(tensor_B.host_view(), init_B, seed + 2018)); + EXPECT_TRUE(initialize_tensor(tensor_C.host_view(), init_C, seed + 2017)); + + // It is possible to randomly initialize to all zeros, so override this with non-zeros + // in the upper left corner of each operand. + cutlass::Coord<2> origin(0); + tensor_A.host_view().at(origin) = typename Gemm::ElementA(1); + tensor_B.host_view().at(origin) = typename Gemm::ElementB(1); + tensor_C.host_view().at(origin) = typename Gemm::ElementC(1); + + cutlass::reference::host::TensorCopy(reference_D.host_view(), tensor_C.host_view()); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_D.sync_device(); + } + + /// Compares computed reference with device reference and outputs to a file if incorrect + bool compare_reference( + cutlass::gemm::GemmCoord problem_size, + ElementCompute alpha, + ElementCompute beta) { + + tensor_D.sync_host(); + + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_A.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_B.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_C.host_view()), 0); + + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_D.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); + + bool passed = cutlass::reference::host::TensorEquals(reference_D.host_view(), tensor_D.host_view()); + + EXPECT_TRUE(passed) << " mismatched reference"; + + if (!passed) { + + /* + + std::stringstream fname; + + fname << "error_Gemm_device_" + << problem_size.m() << "x" + << problem_size.n() << "x" + << problem_size.k() << "_" + << Gemm::ThreadblockShape::kM << "x" + << Gemm::ThreadblockShape::kN << "x" + << Gemm::ThreadblockShape::kK << "_" + << Gemm::WarpShape::kM << "x" + << Gemm::WarpShape::kN << "x" + << Gemm::WarpShape::kK << ".txt"; + + std::ofstream file(fname.str()); + */ + + std::ofstream file("testbed_universal_errors.txt"); + + file + << "problem: " << problem_size + << ", alpha: " << alpha << ", beta: " << beta << "\n\n"; + + file + << "A =\n" << tensor_A.host_view() + << "\nB =\n" 
<< tensor_B.host_view() + << "\nC =\n" << tensor_C.host_view() + << "\n\nReference =\n" << reference_D.host_view() + << "\nComputed =\n" << tensor_D.host_view(); + } + + return passed; + } + + /// Verifies the result is a GEMM + bool verify( + cutlass::gemm::GemmCoord problem_size, + ElementCompute alpha, + ElementCompute beta) { + + // + // Verify + // + + cutlass::reference::host::GemmComplex< + typename Gemm::ElementA, typename Gemm::LayoutA, + typename Gemm::ElementB, typename Gemm::LayoutB, + typename Gemm::ElementC, typename Gemm::LayoutC, + ElementCompute, ElementAccumulator + >( + problem_size, + alpha, + tensor_A.host_ref(), + Gemm::kTransformA, + tensor_B.host_ref(), + Gemm::kTransformB, + beta, + tensor_C.host_ref(), + reference_D.host_ref(), + ElementAccumulator(0) + ); + + if (Relu) { + for (int i = 0; i < problem_size.m(); ++i) { + for (int j = 0; j < problem_size.n(); ++j) { + reference_D.at(cutlass::MatrixCoord(i, j)) = + ((ElementCompute)reference_D.at(cutlass::MatrixCoord(i, j)) < (ElementCompute)0) + ? (typename Gemm::ElementC)0 + : reference_D.at(cutlass::MatrixCoord(i, j)); + } + } + } + + return compare_reference(problem_size, alpha, beta); + } + + /// Returns true if the CUDA device is sufficient to execute the kernel. + bool sufficient() const { + // + // Determine SMEM requirements and waive if not satisfied + // + + size_t smem_size = sizeof(typename Gemm::GemmKernel::SharedStorage); + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerBlockOptin < smem_size) { + return false; + } + + return true; + } + + /// Executes one test + bool run( + cutlass::gemm::GemmUniversalMode mode, + cutlass::gemm::GemmCoord problem_size, + int batch_count = 1, + ElementCompute alpha = ElementCompute(1), + ElementCompute beta = ElementCompute(0)) + { +/* + std::cout << "\n-----------------------\n"; + std::cout << "mode: " << (int) mode << "\n"; + std::cout << "problem size: " << problem_size << "\n"; + std::cout << "batch_count: " << batch_count << "\n"; + std::cout << "alpha: " << alpha << "\n"; + std::cout << "beta: " << beta << "\n"; + std::cout << "-----------------------\n\n"; +*/ + + // Waive test if insufficient CUDA device + if (!sufficient()) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." 
<< std::endl; + } + return true; + } + + this->initialize(problem_size); + + // + // Initialize the GEMM operator + // + + typename Gemm::Arguments arguments{ + mode, + problem_size, + batch_count, + {alpha, beta}, + tensor_A.device_data(), + tensor_B.device_data(), + tensor_C.device_data(), + tensor_D.device_data(), + problem_size.m() * problem_size.k(), + problem_size.n() * problem_size.k(), + problem_size.m() * problem_size.n(), + problem_size.m() * problem_size.n(), + tensor_A.layout().stride(0), + tensor_B.layout().stride(0), + tensor_C.layout().stride(0), + tensor_D.layout().stride(0) + }; + + Gemm gemm_op; + + size_t workspace_size = Gemm::get_workspace_size(arguments); + + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = gemm_op.initialize(arguments, workspace.get()); + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Run the GEMM + // + + status = gemm_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Verify + // + + bool passed = this->verify(problem_size, alpha, beta); + + if (!passed) { + std::cout << "Failed with batch_count/split_k_slices = " << batch_count << std::endl; + } + + return passed; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +template +bool TestGemmUniversal( + cutlass::gemm::GemmCoord const & problem_size, + cutlass::gemm::GemmUniversalMode mode, + int batch_count, + double alpha = 1.0, + double beta = 2.0) { + + bool passed = true; + + TestbedUniversal testbed; + + using ElementCompute = typename Gemm::EpilogueOutputOp::ElementCompute; + + passed = testbed.run( + mode, + problem_size, + batch_count, + cutlass::from_real(alpha), + cutlass::from_real(beta) + ); + + return passed; +} + +template +bool TestAllGemmUniversal() { + bool passed = true; + + + int const kMinimumOperandElementSize = + std::min( + int(cutlass::sizeof_bits::value), + int(cutlass::sizeof_bits::value)); + + int const kAlignment = cutlass::platform::is_same< + typename Gemm::OperatorClass, + cutlass::arch::OpClassSimt>::value ? 1 : 128 / kMinimumOperandElementSize; + + // int8_t gemm alignment constraints + int const kAlignmentM = cutlass::platform::is_same::value && + cutlass::platform::is_same::value && + cutlass::platform::is_same::value ? 4 : kAlignment; + + int const kAlignmentN = cutlass::platform::is_same::value && + cutlass::platform::is_same::value && + cutlass::platform::is_same::value ? 4 : kAlignment; + + int const kAlignmentK = cutlass::platform::is_same::value && + cutlass::platform::is_same::value && + cutlass::platform::is_same::value && + (cutlass::platform::is_same::value || + cutlass::platform::is_same::value) ? 
4 : kAlignment; + + + + cutlass::gemm::GemmUniversalMode modes[] = { + cutlass::gemm::GemmUniversalMode::kGemm, + }; + + int problem_size_m[] = { + kAlignmentM, 512 - 3*kAlignmentM + }; + + int problem_size_n[] = { + kAlignmentN, 512 - 2*kAlignmentN + }; + + int problem_size_k[] = { + kAlignmentK, + Gemm::ThreadblockShape::kK * Gemm::kStages - kAlignmentK, + Gemm::ThreadblockShape::kK * Gemm::kStages * 3 - kAlignmentK + }; + + int batch_counts[] = { // may be interpretted as batch count or split-K slices + 1, 2, 3, 5, 7 + }; + + double problem_alpha[] = { + 1 + }; + + double problem_beta[] = { + 2.0 + }; + + + using ElementCompute = typename Gemm::EpilogueOutputOp::ElementCompute; + + for (cutlass::gemm::GemmUniversalMode mode : modes) { + for (int m : problem_size_m) { + for (int n : problem_size_n) { + for (int k : problem_size_k) { + for (int batch_count : batch_counts) { + + for (auto alpha : problem_alpha) { + for (auto beta : problem_beta) { + + if (mode == cutlass::gemm::GemmUniversalMode::kGemm || + mode == cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel) { + + // skip very small K problems + if (k / batch_count < 2 * Gemm::ThreadblockShape::kK) { + continue; + } + } + + cutlass::gemm::GemmCoord problem_size(m, n, k); + + TestbedUniversal testbed; + + passed = testbed.run( + mode, + problem_size, + batch_count, + cutlass::from_real(alpha), + cutlass::from_real(beta) + ); + + if (!passed) { + return false; + } + } + } + } + } + } + } + } + + /* + // large problem with high coverage + for (int split_k_slices = 1; split_k_slices <= 3; ++split_k_slices) { + TestbedUniversal testbed; + + cutlass::gemm::GemmCoord problem_size(72, 56, 8192); + + passed = testbed.run( + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + split_k_slices, + cutlass::from_real(1.0), + cutlass::from_real(2.0) + ); + + if (!passed) { + break; + } + } + */ + + return passed; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace test + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_utils.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..89ac33a1028061515d08d50fdb6cce7833ae88ce --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_utils.h @@ -0,0 +1,53 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +inline char const *to_string(cutlass::Status status) { + + switch (status) { + case cutlass::Status::kSuccess: return "kSuccess"; + case cutlass::Status::kErrorMisalignedOperand: return "kErrorMisalignedOperand"; + case cutlass::Status::kErrorInvalidLayout: return "kErrorInvalidLayout"; + case cutlass::Status::kErrorInvalidProblem: return "kErrorInvalidProblem"; + case cutlass::Status::kErrorNotSupported: return "kErrorNotSupported"; + case cutlass::Status::kErrorWorkspaceNull: return "kErrorWorkspaceNull"; + case cutlass::Status::kErrorInternal: return "kErrorInternal"; + case cutlass::Status::kInvalid: return "kInvalid"; + default: break; + } + return "invalid"; +} diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_with_absmax.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_with_absmax.h new file mode 100644 index 0000000000000000000000000000000000000000..8b5588f57c40c4e8f8d06adfa9f1e673350fb5e5 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/device/testbed_with_absmax.h @@ -0,0 +1,609 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Testbed for running device-level GEMMs with absolute maximum calculation and scaling +*/ + +#pragma once + +#include +#include +#include + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/gemm_complex.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/gemm.h" + +#include "testbed.h" +#include "testbed_sparse.h" +#include "testbed_utils.h" + +#include "cutlass/layout/matrix.h" +#include "cutlass/matrix_coord.h" + +namespace test { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Gemm, + typename GemmTestbed, + template class ActivationFunctor +> +struct TestbedWithAmax { + + static_assert(std::is_same_v> || std::is_same_v>); + static constexpr bool IsSparseTestbed = std::is_same_v>; + + using ElementAccumulator = typename Gemm::ElementAccumulator; + using ElementCompute = typename Gemm::GemmKernel::Epilogue::OutputOp::ElementCompute; + using ElementScalingFactor = typename Gemm::EpilogueOutputOp::ElementScalingFactor; + using ElementAbsmax = typename Gemm::EpilogueOutputOp::ElementAbsmax; + + static bool const kScaleAux = Gemm::EpilogueOutputOp::kIsScalingAndAmaxAuxOutputNeeded; + static bool const kScaleOutput = Gemm::EpilogueOutputOp::kIsScalingAndAmaxOutputNeeded; + bool doScaleA; + bool doScaleB; + bool doScaleC; + + GemmTestbed underlying_testbed; + + cutlass::HostTensor tensor_Aux; + cutlass::HostTensor tensor_Vector; + cutlass::HostTensor tmp_D; + cutlass::HostTensor reference_D; + cutlass::HostTensor reference_Aux; + cutlass::HostTensor scale_A; + cutlass::HostTensor scale_B; + cutlass::HostTensor scale_C; + cutlass::HostTensor scale_D; + cutlass::HostTensor scale_Aux; + cutlass::HostTensor abs_max_Aux; + cutlass::HostTensor abs_max_D; + cutlass::HostTensor reference_abs_max_Aux; + cutlass::HostTensor reference_abs_max_D; + + // + // Methods + // + + TestbedWithAmax( + bool scaleA = true, + bool scaleB = true, + bool scaleC = true, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform + ): + doScaleA(scaleA), doScaleB(scaleB), 
doScaleC(scaleC), + underlying_testbed(init_A_, init_B_, init_C_) { } + + /// Helper to initialize scaling factors + template + bool initialize_scale_factor(cutlass::TensorView view, uint64_t seed, int bits=0) { + cutlass::reference::host::TensorFillRandomUniform(view, seed, double(1.), double(0.), bits); + return true; + } + + /// Initializes data structures + void initialize(cutlass::gemm::GemmCoord problem_size) { + // + // Allocate the GEMM workspace + // + underlying_testbed.initialize(problem_size); + + tensor_Vector.resize({1, problem_size.n()}); + reference_D.resize(problem_size.mn(), false); + tmp_D.resize(problem_size.mn(), false); + + EXPECT_TRUE( + underlying_testbed.initialize_tensor(tensor_Vector.host_view(), underlying_testbed.init_C, underlying_testbed.seed + 2020) + ); + + // It is possible to randomly initialize to all zeros, so override this with non-zeros + // in the upper left corner of each operand. + cutlass::Coord<2> origin(0); + tensor_Vector.host_view().at(origin) = typename Gemm::ElementC(1); + + cutlass::reference::host::TensorCopy(reference_D.host_view(), underlying_testbed.tensor_C.host_view()); + + tensor_Vector.sync_device(); + + int scale_bits = 2; + if (doScaleA) { + scale_A.resize({1, 1}); + EXPECT_TRUE(initialize_scale_factor(scale_A.host_view(), underlying_testbed.seed + 2021, scale_bits)); + scale_A.sync_device(); + } + + if (doScaleB) { + scale_B.resize({1, 1}); + EXPECT_TRUE(initialize_scale_factor(scale_B.host_view(), underlying_testbed.seed + 2022, scale_bits)); + scale_B.sync_device(); + } + + if (doScaleC) { + scale_C.resize({1, 1}); + EXPECT_TRUE(initialize_scale_factor(scale_C.host_view(), underlying_testbed.seed + 2023, scale_bits)); + scale_C.sync_device(); + } + + if (kScaleOutput) { + scale_D.resize({1, 1}); + EXPECT_TRUE(initialize_scale_factor(scale_D.host_view(), underlying_testbed.seed + 2024, scale_bits)); + scale_D.sync_device(); + + abs_max_D.resize({1, 1}); + cutlass::reference::host::TensorFill(abs_max_D.host_view()); + abs_max_D.sync_device(); + + reference_abs_max_D.resize({1, 1}); + } + + if (kScaleAux) { + tensor_Aux.resize(problem_size.mn()); + cutlass::reference::host::TensorFill(tensor_Aux.host_view()); + tensor_Aux.sync_device(); + + scale_Aux.resize({1, 1}); + EXPECT_TRUE(initialize_scale_factor(scale_Aux.host_view(), underlying_testbed.seed + 2025, scale_bits)); + scale_Aux.sync_device(); + + abs_max_Aux.resize({1, 1}); + cutlass::reference::host::TensorFill(abs_max_Aux.host_view()); + abs_max_Aux.sync_device(); + + reference_Aux.resize(problem_size.mn(), false); + reference_abs_max_Aux.resize({1, 1}); + } + } + + /// Compares computed reference with device reference and outputs to a file if incorrect + bool compare_reference( + cutlass::gemm::GemmCoord problem_size, + ElementCompute alpha, + ElementCompute beta) { + + underlying_testbed.tensor_D.sync_host(); + + EXPECT_GT(cutlass::reference::host::TensorNorm(underlying_testbed.tensor_A.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(underlying_testbed.tensor_B.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(underlying_testbed.tensor_C.host_view()), 0); + + EXPECT_GT(cutlass::reference::host::TensorNorm(underlying_testbed.tensor_D.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(reference_D.host_view()), 0); + bool passed = cutlass::reference::host::TensorEquals(reference_D.host_view(), underlying_testbed.tensor_D.host_view()); + if (!passed) { + std::cout << "Comparison of D failed" << std::endl; + } + + if 
(kScaleAux) { + tensor_Aux.sync_host(); + abs_max_Aux.sync_host(); + EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_Aux.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(abs_max_Aux.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(reference_Aux.host_view()), 0); + if (!cutlass::reference::host::TensorEquals(reference_Aux.host_view(), tensor_Aux.host_view())) { + passed = false; + std::cout << "Comparison of Aux failed" << std::endl; + } + if (!cutlass::reference::host::TensorEquals(abs_max_Aux.host_view(), reference_abs_max_Aux.host_view())) { + passed = false; + std::cout << "Comparison of Aux absmax failed" << std::endl; + } + } + + if (kScaleOutput) { + abs_max_D.sync_host(); + EXPECT_GT(cutlass::reference::host::TensorNorm(abs_max_D.host_view()), 0); + if (!cutlass::reference::host::TensorEquals(abs_max_D.host_view(), reference_abs_max_D.host_view())) { + passed = false; + std::cout << "Comparison of D absmax failed" << std::endl; + } + } + + EXPECT_TRUE(passed) << " mismatched reference"; + + if (!passed) { + + std::ofstream file("testbed_with_amax_errors.txt"); + + file + << "problem: " << problem_size + << ", alpha: " << alpha << ", beta: " << beta << "\n\n"; + + file + << "A =\n" << underlying_testbed.tensor_A.host_view() + << "\nB =\n" << underlying_testbed.tensor_B.host_view() + << "\nC =\n" << underlying_testbed.tensor_C.host_view() + << "\nVector =\n" << tensor_Vector.host_view() + << "\nScaleA = " << scale_A.host_view() + << "\nScaleB = " << scale_B.host_view() + << "\nScaleC = " << scale_C.host_view() + << "\nScaleD = " << scale_D.host_view() + << "\nScaleAux = " << scale_Aux.host_view() + << "\n\nReference D =\n" << reference_D.host_view() + << "\nComputed D =\n" << underlying_testbed.tensor_D.host_view(); + if (kScaleAux) { + file + << "\n\nReference Aux =\n" << reference_Aux.host_view() + << "\nComputed Aux =\n" << tensor_Aux.host_view() + << "\n\nReference Absmax Aux = " << reference_abs_max_Aux.host_view() + << "\nComputed Absmax Aux = " << abs_max_Aux.host_view(); + } + if (kScaleOutput) { + file + << "\n\nReference Absmax D = " << reference_abs_max_D.host_view() + << "\nComputed Absmax D = " << abs_max_D.host_view(); + } + } + + return passed; + } + + /// Verifies the result is a GEMM + bool verify( + cutlass::gemm::GemmCoord problem_size, + ElementCompute alpha, + ElementCompute beta) { + + cutlass::Coord<2> origin(0); + ElementCompute scaled_alpha = alpha; + if (doScaleA) { + scaled_alpha *= scale_A.host_view().at(origin); + } + if (doScaleB) { + scaled_alpha *= scale_B.host_view().at(origin); + } + + ElementCompute scaled_beta = beta; + if (doScaleC) { + scaled_beta *= scale_C.host_view().at(origin); + } + + // + // Verify + // + + auto ref_tA = [&](){ + if constexpr (IsSparseTestbed) { + cutlass::uncompress( + underlying_testbed.tensor_A_uncompressed.host_ref(), + underlying_testbed.tensor_A.host_ref(), + underlying_testbed.tensor_E.host_ref(), + problem_size.m(), + problem_size.k() + ); + return underlying_testbed.tensor_A_uncompressed.host_ref(); + } + else { + return underlying_testbed.tensor_A.host_ref(); + } + }(); + + // Run reference kernel with ElementOutput of type ElementAccumulator + // so that we can compute the absmax epilogue on data that is of type + // ElementAccumulator (which is what the GEMM we are testing will do). 
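+    // In scalar form, the reference computation assembled below is, per output element
+    // (an informal sketch; act is the ActivationFunctor and each scale_* factor is 1 when the
+    // corresponding scaling is disabled):
+    //
+    //   accum(m, n) = scaled_alpha * (A * B)(m, n) + scaled_beta * C(m, n)
+    //   Aux(m, n)   = accum(m, n) + Vector(n)            // broadcast bias add
+    //   D(m, n)     = scale_D * act(Aux(m, n))
+    //   abs_max_Aux = max over (m, n) of |Aux(m, n)|
+    //   abs_max_D   = max over (m, n) of |act(Aux(m, n))|
+    //
+    // and Aux is additionally multiplied by scale_Aux before being stored when kScaleAux is set.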
+ cutlass::reference::host::GemmComplex< + typename Gemm::ElementA, typename Gemm::LayoutA, + typename Gemm::ElementB, typename Gemm::LayoutB, + typename Gemm::ElementC, typename Gemm::LayoutC, + ElementCompute, ElementAccumulator, ElementAccumulator + >( + problem_size, + scaled_alpha, + ref_tA, + Gemm::kTransformA, + underlying_testbed.tensor_B.host_ref(), + Gemm::kTransformB, + scaled_beta, + underlying_testbed.tensor_C.host_ref(), + tmp_D.host_ref(), + ElementAccumulator(0) + ); + + ElementCompute tmp_abs_max_Aux(0.); + ElementCompute tmp_abs_max_D(0.); + + cutlass::NumericConverter cvt_c_to_compute; + cutlass::NumericConverter cvt_accum_to_compute; + cutlass::NumericConverter cvt_compute_to_absmax; + cutlass::NumericConverter cvt_compute_to_d; + cutlass::NumericConverter cvt_compute_to_aux; + + cutlass::absolute_value_op abs; + cutlass::maximum_with_nan_propogation max; + ActivationFunctor act; + + ElementScalingFactor d_scale = kScaleOutput ? scale_D.host_view().at(origin) : ElementScalingFactor(1.); + + for (int m = 0; m < problem_size.m(); ++m) { + for (int n = 0; n < problem_size.n(); ++n) { + ElementCompute intermediate = cvt_accum_to_compute(tmp_D.host_view().at({m, n})); + ElementCompute bias = cvt_c_to_compute(tensor_Vector.host_view().at({0, n})); + ElementCompute aux = intermediate + bias; + ElementCompute d = act(aux); + tmp_abs_max_Aux = max(abs(aux), tmp_abs_max_Aux); + tmp_abs_max_D = max(abs(d), tmp_abs_max_D); + reference_D.host_view().at({m, n}) = cvt_compute_to_d(d * d_scale); + + if (kScaleAux) { + reference_Aux.host_view().at({m, n}) = cvt_compute_to_aux(aux * scale_Aux.host_view().at(origin)); + } + } + } + + if (kScaleAux) { + reference_abs_max_Aux.host_view().at(origin) = cvt_compute_to_absmax(tmp_abs_max_Aux); + } + + if (kScaleOutput) { + reference_abs_max_D.host_view().at(origin) = cvt_compute_to_absmax(tmp_abs_max_D); + } + + return compare_reference(problem_size, alpha, beta); + } + + /// Returns true if the CUDA device is sufficient to execute the kernel. + bool sufficient() const { + // + // Determine SMEM requirements and waive if not satisfied + // + return underlying_testbed.sufficient(); + } + + /// Executes one test + bool run( + cutlass::gemm::GemmUniversalMode mode, + cutlass::gemm::GemmCoord problem_size, + int batch_count = 1, + ElementCompute alpha = ElementCompute(1), + ElementCompute beta = ElementCompute(0)) + { + + // Waive test if insufficient CUDA device + if (!sufficient()) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." 
<< std::endl; + } + return true; + } + + this->initialize(problem_size); + + // + // Initialize the GEMM operator + // + + typename Gemm::EpilogueOutputOp::Params::ActivationParams activation_params{alpha, beta}; + typename Gemm::EpilogueOutputOp::Params epilogue_params{ + activation_params, + scale_A.device_data(), + scale_B.device_data(), + scale_C.device_data(), + scale_D.device_data(), + scale_Aux.device_data(), + abs_max_Aux.device_data(), + abs_max_D.device_data() + }; + + auto arguments = [&]() { + if constexpr (IsSparseTestbed) { + return typename Gemm::Arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + batch_count, + epilogue_params, + underlying_testbed.tensor_A.device_data(), + underlying_testbed.tensor_B.device_data(), + underlying_testbed.tensor_C.device_data(), + underlying_testbed.tensor_D.device_data(), + underlying_testbed.tensor_E_reordered.device_data(), + tensor_Aux.device_data(), + tensor_Vector.device_data(), + int64_t(), + int64_t(), + int64_t(), + int64_t(), + int64_t(), + int64_t(), + int64_t(), + underlying_testbed.tensor_A.layout().stride(0), + underlying_testbed.tensor_B.layout().stride(0), + underlying_testbed.tensor_C.layout().stride(0), + underlying_testbed.tensor_D.layout().stride(0), + underlying_testbed.tensor_E_reordered.layout().stride(0), + tensor_Aux.layout().stride(0), + 0 // stride vector + }; + } + else { + return typename Gemm::Arguments{ + mode, + problem_size, + batch_count, + epilogue_params, + underlying_testbed.tensor_A.device_data(), + underlying_testbed.tensor_B.device_data(), + underlying_testbed.tensor_C.device_data(), + underlying_testbed.tensor_D.device_data(), + tensor_Aux.device_data(), + tensor_Vector.device_data(), + problem_size.m() * problem_size.k(), + problem_size.n() * problem_size.k(), + problem_size.m() * problem_size.n(), + problem_size.m() * problem_size.n(), + 0, // stride vector + underlying_testbed.tensor_A.layout().stride(0), + underlying_testbed.tensor_B.layout().stride(0), + underlying_testbed.tensor_C.layout().stride(0), + underlying_testbed.tensor_D.layout().stride(0), + (int64_t)0 // Leading dimension of vector. 
This must be 0 }; } }(); + + Gemm gemm_op; + + cutlass::Status status = gemm_op.can_implement(arguments); + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + size_t workspace_size = Gemm::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + status = gemm_op.initialize(arguments, workspace.get()); + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + // + // Run the GEMM + // + + status = gemm_op(); + + EXPECT_TRUE(status == cutlass::Status::kSuccess) << to_string(status); + + cudaError_t cuda_error = cudaDeviceSynchronize(); + EXPECT_TRUE(cuda_error == cudaSuccess) << cudaGetErrorString(cuda_error); + + // + // Verify + // + + bool passed = this->verify(problem_size, alpha, beta); + + if (!passed) { + std::cout << "Failed with batch_count/split_k_slices = " << batch_count << std::endl; + } + + return passed; + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Gemm, + typename GemmTestbed, + template class ActivationFunctor = cutlass::epilogue::thread::Identity +> +bool TestAllGemmWithAbsmax(bool scaleA=true, bool scaleB=true, bool scaleC=true) { + + int const kMinimumOperandElementSize = + std::min( + int(cutlass::sizeof_bits::value), + int(cutlass::sizeof_bits::value)); + + int constexpr kAlignmentM = [&]() { + if constexpr (std::is_same_v>) { + // M dimension has to be multiple of 32 (sparse float) or 16 (sparse int) + // because of the reordering of operand E + return std::max(((sizeof(typename Gemm::ElementE) == 2) ? 32 : 16), + kMinimumOperandElementSize); + } + else { + return 128 / kMinimumOperandElementSize; + } + }(); + + int const kAlignmentN = 128 / kMinimumOperandElementSize; + + int M_problems[] = {kAlignmentM, 128 + 32}; + int N_problems[] = {kAlignmentN, 512 - 2 * kAlignmentN}; + int K_problems[] = {Gemm::ThreadblockShape::kK * 2}; + double alpha_problems[] = {1.}; + double beta_problems[] = {0.}; + int split_k_slices[] = { + 1, 2 + }; + + bool passed = true; + + for (int M : M_problems) { + for (int N : N_problems) { + for (int K : K_problems) { + for (int split_k : split_k_slices) { + if (cutlass::sizeof_bits_v <= 8 && split_k > 1) { + // Don't test split-K with FP8 output. The kernel being tested will write partial accumulations + // for different splits to global memory in FP8, while the reference kernel will not. This leads + // to mismatches that are difficult to capture without a permissive relative equality check threshold.
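+          // (For example, with split_k = 2 each partial sum over half of K is rounded to the FP8
+          // output type before the final reduction, so the reduced value can differ from the
+          // reference's full-precision accumulation.)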
+ continue; + } + + for (double alpha : alpha_problems) { + for (double beta : beta_problems) { + TestbedWithAmax testbed(scaleA, scaleB, scaleC); + + using ElementAccumulator = typename Gemm::ElementAccumulator; + + passed = testbed.run( + cutlass::gemm::GemmUniversalMode::kGemm, + {M, N, K}, + split_k, + cutlass::from_real(alpha), + cutlass::from_real(beta) + ); + + EXPECT_TRUE(passed) + << "M: " << M << ", N: " << N << ", K: " << K << ", alpha: " << alpha << ", beta: " << beta << ", split_k:" << split_k; + + if (!passed) { + + return passed; + } + } + } + } + } + } + } + + return passed; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace gemm +} // namespace test + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/kernel/testbed_gemv.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/kernel/testbed_gemv.h new file mode 100644 index 0000000000000000000000000000000000000000..8e939f9710403a5f5c3fd8c61e34c4e8021ff423 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/kernel/testbed_gemv.h @@ -0,0 +1,358 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/core_io.h" +#include "cutlass/numeric_types.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/tensor_ref.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/gemm.h" + +#include "cutlass/gemm/kernel/default_gemv.h" +#include "cutlass/gemm/kernel/gemv_batched_strided.h" + +namespace test { +namespace gemm { +namespace kernel { + +template +void batched_gemv_kernel_test(cutlass::gemm::BatchedGemmCoord problem_size, + ElementCD_ alpha = ElementCD_(1), + ElementCD_ beta = ElementCD_(0), + bool perf_test = false, + int perf_test_iter = 1) +{ + using ThreadBlockShape = ThreadBlockShape_; + using ThreadShape = ThreadShape_; + using ElementA = ElementAB_; + using LayoutA = LayoutA_; + using ElementB = ElementAB_; + using LayoutB = LayoutB_; + using ElementAccumulator = ElementCD_; + using ElementCD = ElementCD_; + using LayoutCD = LayoutCD_; + + using GemvKernel = cutlass::gemm::kernel::DefaultGemv; + + using ThreadBlockGemv = typename GemvKernel::ThreadBlockGemv; + using ThreadBlockSwizzle = typename GemvKernel::ThreadBlockSwizzle; + + if (DEBUG) + { + problem_size = cutlass::gemm::BatchedGemmCoord( + problem_size.m(), problem_size.n(), problem_size.k(), 1); + } + + // Create host tensors that will be the backing store for the batches + // Note that no device memory is initially allocated + cutlass::HostTensor matrix_A({problem_size.m(), problem_size.k()}, false); + cutlass::HostTensor matrix_B({problem_size.k(), problem_size.n()}, false); + cutlass::HostTensor matrix_C_computed({problem_size.m(), problem_size.n()}, false); + cutlass::HostTensor matrix_C_reference({problem_size.m(), problem_size.n()}, false); + + // Reserve memory for the batch of tensors + matrix_A.reserve(problem_size.m()*problem_size.k()*problem_size.batch()); + matrix_B.reserve(problem_size.n()*problem_size.k()*problem_size.batch()); + matrix_C_computed.reserve(problem_size.m()*problem_size.n()*problem_size.batch()); + matrix_C_reference.reserve(problem_size.m()*problem_size.n()*problem_size.batch(), false); + + // Fill eatch tensor batch + const int seed = 9876; + for (int b = 0; b < problem_size.batch(); b++) + { + if(DEBUG) + { + cutlass::reference::host::BlockFillSequential( + matrix_A.host_data_ptr_offset(b*matrix_A.capacity()), matrix_A.capacity()); + cutlass::reference::host::BlockFillSequential( + matrix_B.host_data_ptr_offset(b*matrix_B.capacity()), matrix_B.capacity()); + } + else + { + cutlass::reference::host::TensorFillRandomUniform( + matrix_A.host_view(b*matrix_A.capacity()), + seed + 1660, + 8, + -8, + 0 + ); + + cutlass::reference::host::TensorFillRandomUniform( + matrix_B.host_view(b*matrix_B.capacity()), + seed + 1880, + 8, + -8, + 0 + ); + } + + cutlass::reference::host::TensorFill(matrix_C_computed.host_view(b*matrix_C_computed.capacity())); + cutlass::reference::host::TensorFill(matrix_C_reference.host_view(b*matrix_C_reference.capacity())); + } + + matrix_A.sync_device(); + matrix_B.sync_device(); + matrix_C_computed.sync_device(); + + ThreadBlockSwizzle swizzle; + + cutlass::gemm::BatchedGemmCoord tiled_size{ThreadBlockShape::kM, + ThreadBlockShape::kN, + 
problem_size.k(), // no split-k + DEBUG ? 1 : THREAD_B }; + + cutlass::gemm::BatchedGemmCoord tiled_shape = swizzle.get_tiled_shape(problem_size, tiled_size); + + #if 0 + printf("tiled_size = %d %d %d %d\n", tiled_size.m(), tiled_size.n(), tiled_size.k(), tiled_size.batch()); + printf("tiled_shape = %d %d %d %d\n", tiled_shape.m(), tiled_shape.n(), tiled_shape.k(), tiled_shape.batch()); + #endif + + // No split-k + EXPECT_EQ(tiled_size.k(), problem_size.k()); + + dim3 grid = swizzle.get_grid_shape(tiled_shape); + dim3 block(tiled_size.n() / ThreadShape::kN, tiled_size.batch(), tiled_size.k() / problem_size.k()); + + // Some sanity checks + EXPECT_TRUE( block.x*block.y*block.z <= 1024 ); + EXPECT_TRUE( block.x <= 1024 ); + EXPECT_TRUE( block.y <= 1024 ); + EXPECT_TRUE( block.z <= 64 ); + + #if 0 + printf("grid dim = %d, %d, %d\n", grid.x, grid.y, grid.z); + printf("block dim = %d, %d, %d\n", block.x, block.y, block.z); + #endif + + cudaError_t result; + cudaEvent_t start_event, end_event; + + for (int iter = 0; iter < (perf_test ? (perf_test_iter+1) : 1); ++iter) + { + if (perf_test && iter == 1) + { + result = cudaEventCreate(&start_event); + EXPECT_EQ(result, cudaSuccess); + + result = cudaEventCreate(&end_event); + EXPECT_EQ(result, cudaSuccess); + + result = cudaEventRecord(start_event); + EXPECT_EQ(result, cudaSuccess); + } + + if (beta == ElementCD(0)) + { + if (alpha == ElementCD(1)) + { + cutlass::gemm::kernel::GemvBatchedStrided<<< grid, block >>>( + problem_size, + matrix_A.device_ref(), + matrix_A.capacity(), + matrix_B.device_ref(), + matrix_B.capacity(), + matrix_C_computed.device_ref(), + matrix_C_computed.capacity() + ); + } + else + { + cutlass::gemm::kernel::GemvBatchedStrided<<< grid, block >>>( + problem_size, + alpha, + matrix_A.device_ref(), + matrix_A.capacity(), + matrix_B.device_ref(), + matrix_B.capacity(), + matrix_C_computed.device_ref(), + matrix_C_computed.capacity() + ); + } + } + else + { + cutlass::gemm::kernel::GemvBatchedStrided<<< grid, block >>>( + problem_size, + alpha, + beta, + matrix_A.device_ref(), + matrix_A.capacity(), + matrix_B.device_ref(), + matrix_B.capacity(), + matrix_C_computed.device_ref(), + matrix_C_computed.capacity(), + matrix_C_computed.device_ref(), + matrix_C_computed.capacity() + ); + } + + if (iter == 0) + { + result = cudaGetLastError(); + EXPECT_EQ(result, cudaSuccess) << " kernel error: " << cudaGetErrorString(result); + } + } + + if (perf_test) + { + result = cudaEventRecord(end_event); + EXPECT_EQ(result, cudaSuccess); + } + + result = cudaDeviceSynchronize(); + EXPECT_EQ(result, cudaSuccess) << " kernel error: " << cudaGetErrorString(result); + + if (perf_test) + { + float ms; + result = cudaEventElapsedTime(&ms, start_event, end_event); + EXPECT_EQ(result, cudaSuccess); + + double flops = (double(problem_size.m()) * + double(problem_size.n()) * + double(problem_size.k()) * + double(problem_size.batch()) * 2); // 2 for MAC + + double read_bytes = double(problem_size.batch()) * (sizeof(ElementA)*double(problem_size.m())*double(problem_size.k()) + + sizeof(ElementB)*double(problem_size.k())*double(problem_size.n())); + + double write_bytes = double(problem_size.batch()) * (sizeof(ElementCD)*double(problem_size.m())*double(problem_size.n())); + + double avg_runtime = double(ms) / perf_test_iter; + double gflops_per_sec = flops / 1.0e6 / avg_runtime; + double read_bandwidth = read_bytes / 1.0e6 / avg_runtime; + double write_bandwidth = write_bytes / 1.0e6 / avg_runtime; + + std::cout << "\n\nProblem size: " + << problem_size.m() 
+ << " x " << problem_size.n() + << " x " << problem_size.k() + << " x " << problem_size.batch() + << std::endl; + + std::cout << " GFLOPs: " << gflops_per_sec << std::endl; + std::cout << "BW (R/W): " << read_bandwidth << " / " << write_bandwidth << " GB/sec" << std::endl; + std::cout << " Runtime: " << avg_runtime << " ms" << std::endl; + } + else + { + matrix_C_computed.sync_host(); + + // Compute the batched gemms + for (int b = 0; b < problem_size.batch(); b++) + { + cutlass::reference::host::Gemm + reference_gemm; + + reference_gemm( + problem_size.mnk(), alpha, + matrix_A.host_ref(b * matrix_A.capacity()), + matrix_B.host_ref(b * matrix_B.capacity()), beta, + matrix_C_reference.host_ref(b * matrix_C_computed.capacity())); + + bool passed = cutlass::reference::host::TensorEquals( + matrix_C_computed.host_view(b * matrix_C_computed.capacity()), + matrix_C_reference.host_view(b * matrix_C_reference.capacity())); + + EXPECT_TRUE(passed) + //<< "A:\n" << matrix_A.host_view() << "\n" + //<< "B:\n" << matrix_B.host_view() << "\n" + << "Batch: " << b << "\n" + << "Reference:\n" + << matrix_C_reference.host_view(b * matrix_C_reference.capacity()) + << "\n" + << "Computed:\n" + << matrix_C_computed.host_view(b * matrix_C_computed.capacity()) + << "\n"; + } + } +} + +template +void batched_gemv_kernel_perf_test(cutlass::gemm::BatchedGemmCoord problem_size, + ElementCD_ alpha = ElementCD_(1), + ElementCD_ beta = ElementCD_(0), + int iter = 50) +{ + batched_gemv_kernel_test(problem_size, alpha, beta, true, iter); +} + +} // namespace threadblock +} // namespace kernel +} // namespace test diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/thread/host/testbed_host.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/thread/host/testbed_host.h new file mode 100644 index 0000000000000000000000000000000000000000..6e3d6ab079d44345f2f55f4126ba3efc1eba47cb --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/thread/host/testbed_host.h @@ -0,0 +1,232 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Unit tests for thread-level GEMM +*/ + +#pragma once + +#include "cutlass/gemm/thread/mma.h" +#include "cutlass/layout/vector.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" + +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/gemm.h" + +namespace test { +namespace gemm { +namespace thread { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Thread-level matrix multiply-accumulate +template +void kernel( + typename Mma::ElementC *D, + typename Mma::ElementA const *A, + typename Mma::ElementB const *B, + typename Mma::ElementC const *C) { + + auto ptr_D = reinterpret_cast *>(D); + auto ptr_A = reinterpret_cast const *>(A); + auto ptr_B = reinterpret_cast const *>(B); + auto ptr_C = reinterpret_cast const *>(C); + + Mma mma; + + auto a = *ptr_A; + auto b = *ptr_B; + auto c = *ptr_C; + + using Btype = typename Mma::ElementB; + cutlass::Array d; + + mma(d, a, b, c); + + *ptr_D = d; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape, + /// Data type of A elements + typename ElementA, + /// Layout of A matrix (concept: MatrixLayout) + typename LayoutA, + /// Data type of B elements + typename ElementB, + /// Layout of B matrix (concept: MatrixLayout) + typename LayoutB, + /// Element type of C matrix + typename ElementC, + /// Layout of C matrix (concept: MatrixLayout) + typename LayoutC +> +struct Testbed { + + /// Thread-level matrix multiply-accumulate operator + using Mma = cutlass::gemm::thread::Mma< + Shape, + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC + >; + + // + // Data members + // + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_C; + cutlass::HostTensor tensor_D_computed; + cutlass::HostTensor tensor_D_reference; + + // + // Methods + // + + /// Allocates workspace in device memory + Testbed() { + + tensor_A.reset(cutlass::make_Coord(Shape::kM, Shape::kK), false); + tensor_B.reset(cutlass::make_Coord(Shape::kK, Shape::kN), false); + tensor_C.reset(cutlass::make_Coord(Shape::kM, Shape::kN), false); + tensor_D_computed.reset(cutlass::make_Coord(Shape::kM, Shape::kN), false); + tensor_D_reference.reset(cutlass::make_Coord(Shape::kM, Shape::kN), false); + } + + /// Runs the test + bool run() { + + // + // initialize device memory + // + + cutlass::reference::host::detail::RandomUniformFunc< ElementA > tfill_rand_func( + 0, // seed + 10, // max + 0, // min + 0); // bits after decimal + + 
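+    // tfill_rand_func draws integer-valued samples in [0, 10] (zero bits kept after the decimal
+    // point); together with the per-coordinate functor below, this is roughly equivalent to
+    //   cutlass::reference::host::TensorFillRandomUniform(tensor_A.host_view(), 0, 10, 0, 0);
+    // keeping A integer-valued, presumably so that the exact TensorEquals comparison at the end
+    // of run() stays meaningful.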
cutlass::reference::host::detail::TensorFillRandomUniformFunc< ElementA, LayoutA > tfill_rand( + tensor_A.host_view(), + tfill_rand_func); + + for (auto i=0; i< Shape::kM; i++) + for (auto j=0; j< Shape::kK; j++) + tfill_rand(cutlass::make_Coord(i,j)); + + cutlass::reference::host::BlockFillSequential( + tensor_B.host_data(), + tensor_B.capacity(), + ElementB(1), + ElementB(2) + ); + + cutlass::reference::host::TensorFill( + tensor_C.host_view(), + ElementC(0) + ); + + cutlass::reference::host::TensorFill( + tensor_D_computed.host_view(), + ElementC(0) + ); + + cutlass::reference::host::TensorFill( + tensor_D_reference.host_view(), + ElementC(0) + ); + + + // Host side call + kernel( + tensor_D_computed.host_data(), + tensor_A.host_data(), + tensor_B.host_data(), + tensor_C.host_data()); + + // + // Reference implementation + // + + cutlass::reference::host::Gemm + reference_gemm; + + reference_gemm( + {Shape::kM, Shape::kN, Shape::kK}, + ElementC(1), + tensor_A.host_ref(), + tensor_B.host_ref(), + ElementC(0), + tensor_D_reference.host_ref() + ); + + // + // Verify equivalence + // + + // compare + bool passed = cutlass::reference::host::TensorEquals( + tensor_D_computed.host_view(), + tensor_D_reference.host_view() + ); + + EXPECT_TRUE(passed) + << "A:\n" << tensor_A.host_view() << "\n\n" + << "B:\n" << tensor_B.host_view() << "\n\n" + << "C:\n" << tensor_C.host_view() << "\n\n" + << "Reference:\n" << tensor_D_reference.host_view() << "\n\n" + << "Computed:\n" << tensor_D_computed.host_view() << std::endl; + + + return passed; + } +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace thread +} // namespace gemm +} // namespace test diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/thread/testbed.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/thread/testbed.h new file mode 100644 index 0000000000000000000000000000000000000000..8d34d7992b57cefa0eaf7300a5e1fb49f41a93e2 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/thread/testbed.h @@ -0,0 +1,236 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Unit tests for thread-level GEMM +*/ + +#pragma once + +#include "cutlass/gemm/thread/mma.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" + +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/gemm.h" + +namespace test { +namespace gemm { +namespace thread { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Thread-level matrix multiply-accumulate +template +__global__ void kernel( + typename Mma::ElementC *D, + typename Mma::ElementA const *A, + typename Mma::ElementB const *B, + typename Mma::ElementC const *C) { + + auto ptr_D = reinterpret_cast *>(D); + auto ptr_A = reinterpret_cast const *>(A); + auto ptr_B = reinterpret_cast const *>(B); + auto ptr_C = reinterpret_cast const *>(C); + + Mma mma; + + auto a = *ptr_A; + auto b = *ptr_B; + auto c = *ptr_C; + + cutlass::Array d; + + mma(d, a, b, c); + + *ptr_D = d; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape, + /// Data type of A elements + typename ElementA, + /// Layout of A matrix (concept: MatrixLayout) + typename LayoutA, + /// Data type of B elements + typename ElementB, + /// Layout of B matrix (concept: MatrixLayout) + typename LayoutB, + /// Element type of C matrix + typename ElementC, + /// Layout of C matrix (concept: MatrixLayout) + typename LayoutC +> +struct Testbed { + + /// Thread-level matrix multiply-accumulate operator + using Mma = cutlass::gemm::thread::Mma< + Shape, + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC + >; + + // + // Data members + // + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_C; + cutlass::HostTensor tensor_D_computed; + cutlass::HostTensor tensor_D_reference; + + // + // Methods + // + + /// Allocates workspace in device memory + Testbed() { + + tensor_A.reset(cutlass::make_Coord(Shape::kM, Shape::kK)); + tensor_B.reset(cutlass::make_Coord(Shape::kK, Shape::kN)); + tensor_C.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); + tensor_D_computed.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); + tensor_D_reference.reset(cutlass::make_Coord(Shape::kM, Shape::kN), false); + } + + /// Runs the test + bool run() { + + // + // initialize device memory + // + + cutlass::reference::host::BlockFillSequential( + tensor_A.host_data(), + tensor_A.capacity() + ); + + cutlass::reference::host::BlockFillSequential( + tensor_B.host_data(), + tensor_B.capacity(), + ElementB(1), + ElementB(2) + ); + + cutlass::reference::host::TensorFill( + tensor_C.host_view(), + 
+      ElementC(0)
+    );
+
+    cutlass::reference::host::TensorFill(
+      tensor_D_computed.host_view(),
+      ElementC(0)
+    );
+
+    cutlass::reference::host::TensorFill(
+      tensor_D_reference.host_view(),
+      ElementC(0)
+    );
+
+    tensor_A.sync_device();
+    tensor_B.sync_device();
+    tensor_C.sync_device();
+    tensor_D_computed.sync_device();
+
+    // launch kernel
+    kernel<Mma><<< dim3(1, 1), dim3(1, 1, 1) >>>(
+      tensor_D_computed.device_data(),
+      tensor_A.device_data(),
+      tensor_B.device_data(),
+      tensor_C.device_data());
+
+    // verify no errors
+    cudaError_t result = cudaDeviceSynchronize();
+
+    EXPECT_EQ(result, cudaSuccess) << "CUDA ERROR: " << cudaGetErrorString(result);
+    if (result != cudaSuccess) {
+      return false;
+    }
+
+    tensor_D_computed.sync_host();
+
+    //
+    // Reference implementation
+    //
+
+    //tensor_D_reference.fill(tensor_C.host_view());
+
+    cutlass::reference::host::Gemm<
+      ElementA, LayoutA,
+      ElementB, LayoutB,
+      ElementC, LayoutC,
+      ElementC, ElementC> reference_gemm;
+
+    reference_gemm(
+      {Shape::kM, Shape::kN, Shape::kK},
+      ElementC(1),
+      tensor_A.host_ref(),
+      tensor_B.host_ref(),
+      ElementC(0),
+      tensor_D_reference.host_ref()
+    );
+
+    //
+    // Verify equivalence
+    //
+
+    // compare
+    bool passed = cutlass::reference::host::TensorEquals(
+      tensor_D_computed.host_view(),
+      tensor_D_reference.host_view()
+    );
+
+    EXPECT_TRUE(passed)
+      << "A:\n" << tensor_A.host_view() << "\n\n"
+      << "B:\n" << tensor_B.host_view() << "\n\n"
+      << "C:\n" << tensor_C.host_view() << "\n\n"
+      << "Reference:\n" << tensor_D_reference.host_view() << "\n\n"
+      << "Computed:\n" << tensor_D_computed.host_view() << std::endl;
+
+
+    return passed;
+  }
+};
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+} // namespace thread
+} // namespace gemm
+} // namespace test
diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/threadblock/mma_multistage_sparse_testbed.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/threadblock/mma_multistage_sparse_testbed.h
new file mode 100644
index 0000000000000000000000000000000000000000..1f3bc8cf114d7eb2ac00bd19ae92c984558b7228
--- /dev/null
+++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/threadblock/mma_multistage_sparse_testbed.h
@@ -0,0 +1,435 @@
+/***************************************************************************************************
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Unit testbed for kernel-level GEMM +*/ + +#pragma once + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/aligned_buffer.h" +#include "cutlass/array.h" +#include "cutlass/core_io.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/threadblock/default_mma_core_sparse_sm80.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/numeric_types.h" +#include "cutlass/transform/threadblock/predicated_tile_access_iterator.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/host_reorder.h" +#include "cutlass/util/host_uncompress.h" + +namespace test { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +template +__global__ void kernel_multistage_mma_sparse(cutlass::gemm::GemmCoord problem_size, + typename Mma::IteratorA::Params params_A, + typename Mma::IteratorA::TensorRef ref_A, + typename Mma::IteratorB::Params params_B, + typename Mma::IteratorB::TensorRef ref_B, + typename Mma::ElementC *ptr_C, + typename Mma::LayoutC::Stride::Index ldc, + typename Mma::IteratorE::Params params_E, + typename Mma::IteratorE::TensorRef ref_E) { + // Shared storage needed by threadblock-scoped matrix multiply- + // Dynamic shared memory base pointer + extern __shared__ int GemmSharedStorageBase[]; + + // Declare pointer to dynamic shared memory. 
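+  //
+  // Note: multistage pipelines commonly need more shared memory than the 48 KB
+  // static per-block limit, which is why this kernel receives its SharedStorage
+  // through dynamic shared memory. The testbed later in this file sizes that
+  // allocation and opts in before launching; roughly (a sketch only, the exact
+  // launch arguments are chosen by the testbed):
+  //
+  //   int smem_size = int(sizeof(typename Mma::SharedStorage));
+  //   cudaFuncSetAttribute(kernel_multistage_mma_sparse<Mma>,
+  //                        cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);
+  //   kernel_multistage_mma_sparse<Mma><<<grid, block, smem_size>>>(/* args */);
+  //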
+ typename Mma::SharedStorage *shared_storage = + reinterpret_cast(GemmSharedStorageBase); + + // Compute threadblock location + cutlass::gemm::GemmCoord tb_tile_offset = {int(blockIdx.x), int(blockIdx.y), + 0}; + + cutlass::MatrixCoord tb_offset_A{tb_tile_offset.m() * Mma::Shape::kM, + tb_tile_offset.k() / Mma::kSparse}; + + cutlass::MatrixCoord tb_offset_B{tb_tile_offset.k(), + tb_tile_offset.n() * Mma::Shape::kN}; + + cutlass::MatrixCoord tb_offset_E{tb_tile_offset.m() * Mma::Shape::kM, + tb_tile_offset.k() / Mma::kSparse}; + + // Compute position within threadblock + int tb_thread_id = threadIdx.y * blockDim.x + threadIdx.x; + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A(params_A, ref_A.data(), + {problem_size.m(), problem_size.k() / Mma::kSparse}, + tb_thread_id, tb_offset_A); + + typename Mma::IteratorB iterator_B(params_B, ref_B.data(), + {problem_size.k(), problem_size.n()}, + tb_thread_id, tb_offset_B); + + typename Mma::IteratorE iterator_E( + params_E, ref_E.data(), + {problem_size.m(), + problem_size.k() / Mma::kSparse / Mma::kElementsPerElementE}, + tb_thread_id, tb_offset_E); + + int warp_id = __shfl_sync(0xffffffff, threadIdx.y, 0); + + // Construct thread-scoped matrix multiply + Mma mma(*shared_storage, tb_thread_id, warp_id, threadIdx.x); + + typename Mma::FragmentC accum; + + accum.clear(); + + int gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, accum, iterator_A, iterator_B, iterator_E, accum); + + // Output results + typename Mma::Operator::IteratorC iterator_C({ptr_C, ldc}, threadIdx.x); + + iterator_C.add_tile_offset( + {(tb_tile_offset.m() * Mma::WarpCount::kM) + + (warp_id % Mma::WarpCount::kM), + (tb_tile_offset.n() * Mma::WarpCount::kN) + + (warp_id / Mma::WarpCount::kM)}); + + iterator_C.store(accum); +} + +//////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product +template < + /// Threadblock-level matrix multiply-accumulate + typename MmaCore_> +struct SparseTestbed { + /// Threadblock-level GEMM implementation + using MmaCore = MmaCore_; + using ThreadblockShape = typename MmaCore::Shape; + using WarpShape = typename MmaCore::WarpShape; + using InstructionShape = typename MmaCore::InstructionShape; + using ElementA = typename MmaCore::ElementA; + using LayoutA = typename MmaCore::LayoutA; + using ElementB = typename MmaCore::ElementB; + using LayoutB = typename MmaCore::LayoutB; + using ElementC = typename MmaCore::ElementC; + using LayoutC = typename MmaCore::LayoutC; + using ElementE = typename MmaCore::ElementE; + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using ThreadMapE = typename MmaCore::IteratorThreadMapE; + using AccessTypeA = cutlass::Array; + using AccessTypeB = cutlass::Array; + using AccessTypeE = cutlass::Array; + static int const Stages = MmaCore::kStages; + static cutlass::arch::CacheOperation::Kind const CacheOpA = + MmaCore::kCacheOpA; + static cutlass::arch::CacheOperation::Kind const CacheOpB = + MmaCore::kCacheOpB; + static cutlass::arch::CacheOperation::Kind const CacheOpE = + MmaCore::kCacheOpE; + + static int const Sparse = MmaCore::kSparse; + static int const MetaSizeInBits = MmaCore::kMetaSizeInBits; + static int const MaxID2 = MmaCore::kMaxID2; + + using LayoutE = cutlass::layout::RowMajor; + using ReorderedLayoutE = typename 
MmaCore::GmemLayoutE; + + static int const ElementsPerElementE = MmaCore::kElementsPerElementE; + + // Define iterators over tiles from the A operand + using IteratorA = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementA, LayoutA, 1, ThreadMapA, AccessTypeA>; + + // Define iterators over tiles from the B operand + using IteratorB = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementB, LayoutB, 0, ThreadMapB, AccessTypeB>; + + // Define iterators over tiles from the E operand + using IteratorE = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementE, ReorderedLayoutE, 1, ThreadMapE, AccessTypeE>; + + // Define the threadblock-scoped pipelined matrix multiply + using Mma = cutlass::gemm::threadblock::SparseMmaMultistage< + typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, + CacheOpA, IteratorB, typename MmaCore::SmemIteratorB, CacheOpB, ElementC, + LayoutC, IteratorE, typename MmaCore::SmemIteratorE, CacheOpE, + typename MmaCore::MmaPolicy, Stages>; + + // + // Data members + // + + cutlass::HostTensor matrix_A; + cutlass::HostTensor matrix_A_uncompressed; + cutlass::HostTensor matrix_B; + cutlass::HostTensor matrix_C_computed; + cutlass::HostTensor matrix_C_reference; + cutlass::HostTensor matrix_E; + cutlass::HostTensor matrix_E_reordered; + + cutlass::gemm::GemmCoord problem_size; + float alpha, beta; + + // + // Methods + // + + /// Allocates workspace in device memory + SparseTestbed(int m, int n, int k, float alpha_ = float(1), float beta_ = float(0)) + : problem_size(m, n, k), alpha(alpha_), beta(beta_) { + matrix_A.reset(cutlass::make_Coord(m, k / Sparse)); + matrix_A_uncompressed.reset(cutlass::make_Coord(m, k)); + matrix_B.reset(cutlass::make_Coord(k, n)); + matrix_C_computed.reset(cutlass::make_Coord(m, n)); + matrix_C_reference.reset(cutlass::make_Coord(m, n), false); + matrix_E.reset(cutlass::make_Coord(m, k / Sparse / ElementsPerElementE)); + matrix_E_reordered.reset( + cutlass::make_Coord(m, k / Sparse / ElementsPerElementE)); + } + + /// Returns true if the CUDA device is sufficient to execute the kernel. 
+ bool sufficient() const { + // + // Determine SMEM requirements and waive if not satisfied + // + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + return true; + } + + /// Runs the test + bool run( + dim3 grid, dim3 block, + cutlass::Distribution::Kind init_A = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_E = cutlass::Distribution::Uniform) { + + // Waive the test + if (!sufficient()) { + return true; + } + + // + // initialize device memory + // + + if (init_A == cutlass::Distribution::Uniform) { + + int scope_max = 8; + int scope_min = -8; + + if (cutlass::sizeof_bits::value == 4) { + scope_max = 2; + scope_min = -2; + } else if (cutlass::sizeof_bits::value == 1) { + scope_max = 2; + scope_min = 0; + } + + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomUniform( + matrix_A.host_view(), seed, scope_max, scope_min, 0); + } else if (init_A == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(matrix_A.host_data(), + matrix_A.capacity()); + } else if (init_A == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(matrix_A.host_view()); + } else { + return false; + } + + if (init_B == cutlass::Distribution::Uniform) { + + int scope_max = 8; + int scope_min = -8; + + if (cutlass::sizeof_bits::value == 4) { + scope_max = 2; + scope_min = -2; + } else if (cutlass::sizeof_bits::value == 1) { + scope_max = 2; + scope_min = 0; + } + + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomUniform( + matrix_B.host_view(), seed + 16, scope_max, scope_min, 0); + } else if (init_B == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(matrix_B.host_data(), + matrix_B.capacity()); + } else if (init_B == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(matrix_B.host_view()); + } else { + return false; + } + + cutlass::reference::host::TensorFill(matrix_C_computed.host_view()); + + cutlass::reference::host::TensorFill(matrix_C_reference.host_view()); + + if (init_E == cutlass::Distribution::Uniform) { + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomSparseMeta( + matrix_E.host_view(), seed, MetaSizeInBits); + } else if (init_E == cutlass::Distribution::Identity) { + uint32_t content = (MaxID2 == 1) ? 
0x44444444 : 0x4444; + cutlass::reference::host::TensorFill(matrix_E.host_view(), + (ElementE)(content)); + } else { + return false; + } + + cutlass::reorder_meta(matrix_E_reordered.host_ref(), matrix_E.host_ref(), + {problem_size.m(), problem_size.n(), + problem_size.k() / Sparse / ElementsPerElementE}); + + matrix_A.sync_device(); + matrix_B.sync_device(); + matrix_C_computed.sync_device(); + matrix_E_reordered.sync_device(); + + typename IteratorA::Params params_A(matrix_A.layout()); + typename IteratorB::Params params_B(matrix_B.layout()); + typename IteratorE::Params params_E(matrix_E_reordered.layout()); + + cudaError_t result; + + int smem_size = int(sizeof(typename Mma::SharedStorage)); + if (smem_size >= (48 << 10)) { + result = cudaFuncSetAttribute( + test::gemm::threadblock::kernel_multistage_mma_sparse, + cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + + if (result != cudaSuccess) { + return true; + } + + result = cudaFuncSetAttribute( + test::gemm::threadblock::kernel_multistage_mma_sparse, + cudaFuncAttributePreferredSharedMemoryCarveout, 100); + + if (result != cudaSuccess) { + return true; + } + } + + test::gemm::threadblock::kernel_multistage_mma_sparse + <<>>( + problem_size, params_A, matrix_A.device_ref(), params_B, + matrix_B.device_ref(), matrix_C_computed.device_data(), + matrix_C_computed.layout().stride(0), params_E, + matrix_E_reordered.device_ref()); + + // + // Check error code + // + + result = cudaDeviceSynchronize(); + EXPECT_EQ(result, cudaSuccess) + << " kernel error: " << cudaGetErrorString(result); + + matrix_C_computed.sync_host(); + + cutlass::uncompress(matrix_A_uncompressed.host_ref(), matrix_A.host_ref(), + matrix_E.host_ref(), problem_size.m(), + problem_size.k()); + + cutlass::reference::host::Gemm + reference_gemm; + + reference_gemm(problem_size, ElementC(alpha), + matrix_A_uncompressed.host_view(), matrix_B.host_view(), + ElementC(beta), matrix_C_reference.host_view()); + + bool passed = cutlass::reference::host::TensorEquals( + matrix_C_computed.host_view(), matrix_C_reference.host_view()); + + EXPECT_TRUE(passed); + + if (!passed && CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + + std::cout + << __FILE__ << ":" << __LINE__ << " " + << "A:\n" << matrix_A.host_view() << "\n" + << "B:\n" << matrix_B.host_view() << "\n" + << "E:\n" << matrix_E.host_view() << "\n" + << "Reference:\n" + << matrix_C_reference.host_view() << "\n" + << "Computed:\n" + << matrix_C_computed.host_view() << "\n"; + } + + EXPECT_GT(cutlass::reference::host::TensorNorm(matrix_C_reference.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(matrix_C_computed.host_view()), 0); + + return passed; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace test diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/threadblock/mma_multistage_testbed.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/threadblock/mma_multistage_testbed.h new file mode 100644 index 0000000000000000000000000000000000000000..5caaf38ace92758bbc86970d8d4ff339d87348ab --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/threadblock/mma_multistage_testbed.h @@ -0,0 +1,372 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Unit testbed for kernel-level GEMM +*/ + +#pragma once + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/aligned_buffer.h" +#include "cutlass/array.h" +#include "cutlass/core_io.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/numeric_types.h" +#include "cutlass/transform/threadblock/predicated_tile_access_iterator.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +namespace test { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +template +__global__ void kernel_multistage_mma(cutlass::gemm::GemmCoord problem_size, + typename Mma::IteratorA::Params params_A, + typename Mma::IteratorA::TensorRef ref_A, + typename Mma::IteratorB::Params params_B, + typename Mma::IteratorB::TensorRef ref_B, + typename Mma::ElementC *ptr_C, + typename Mma::LayoutC::Stride::Index ldc) { + // Shared storage needed by threadblock-scoped matrix multiply-accumulate + + // Dynamic shared memory base pointer + extern __shared__ int GemmSharedStorageBase[]; + + // Declare pointer to dynamic shared memory. 
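+  //
+  // Note: this kernel assumes blockDim.x == 32 (one full warp in x, one warp per
+  // value of threadIdx.y), since threadIdx.x is used directly as the lane index
+  // below. The __shfl_sync broadcast of threadIdx.y from lane 0 makes the warp
+  // index warp-uniform. An illustrative launch shape (the actual tests pick it):
+  //
+  //   dim3 block(32, Mma::WarpCount::kM * Mma::WarpCount::kN, 1);
+  //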
+ typename Mma::SharedStorage *shared_storage = + reinterpret_cast(GemmSharedStorageBase); + + // Compute threadblock location + cutlass::gemm::GemmCoord tb_tile_offset = {int(blockIdx.x), int(blockIdx.y), + 0}; + + cutlass::MatrixCoord tb_offset_A{tb_tile_offset.m() * Mma::Shape::kM, + tb_tile_offset.k()}; + + cutlass::MatrixCoord tb_offset_B{tb_tile_offset.k(), + tb_tile_offset.n() * Mma::Shape::kN}; + + // Compute position within threadblock + int tb_thread_id = threadIdx.y * blockDim.x + threadIdx.x; + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A(params_A, ref_A.data(), + {problem_size.m(), problem_size.k()}, + tb_thread_id, tb_offset_A); + + typename Mma::IteratorB iterator_B(params_B, ref_B.data(), + {problem_size.k(), problem_size.n()}, + tb_thread_id, tb_offset_B); + + int warp_id = __shfl_sync(0xffffffff, threadIdx.y, 0); + + // Construct thread-scoped matrix multiply + Mma mma(*shared_storage, tb_thread_id, warp_id, threadIdx.x); + + typename Mma::FragmentC accum; + + accum.clear(); + + int gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum); + + // Output results + typename Mma::Operator::IteratorC iterator_C({ptr_C, ldc}, threadIdx.x); + + iterator_C.add_tile_offset( + {(tb_tile_offset.m() * Mma::WarpCount::kM) + + (warp_id % Mma::WarpCount::kM), + (tb_tile_offset.n() * Mma::WarpCount::kN) + + (warp_id / Mma::WarpCount::kM)}); + + iterator_C.store(accum); +} + +//////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product +template < + /// Threadblock-level matrix multiply-accumulate + typename MmaCore_> +struct Testbed { + /// Threadblock-level GEMM implementation + using MmaCore = MmaCore_; + using ThreadblockShape = typename MmaCore::Shape; + using WarpShape = typename MmaCore::WarpShape; + using InstructionShape = typename MmaCore::InstructionShape; + using ElementA = typename MmaCore::ElementA; + using LayoutA = typename MmaCore::LayoutA; + using ElementB = typename MmaCore::ElementB; + using LayoutB = typename MmaCore::LayoutB; + using ElementC = typename MmaCore::ElementC; + using LayoutC = typename MmaCore::LayoutC; + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeA = cutlass::Array; + using AccessTypeB = cutlass::Array; + static int const Stages = MmaCore::kStages; + static cutlass::arch::CacheOperation::Kind const CacheOpA = + MmaCore::kCacheOpA; + static cutlass::arch::CacheOperation::Kind const CacheOpB = + MmaCore::kCacheOpB; + + // Define iterators over tiles from the A operand + using IteratorA = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementA, LayoutA, 1, ThreadMapA, AccessTypeA>; + + // Define iterators over tiles from the B operand + using IteratorB = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementB, LayoutB, 0, ThreadMapB, AccessTypeB>; + + // Define the threadblock-scoped pipelined matrix multiply + using Mma = cutlass::gemm::threadblock::MmaMultistage< + typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, + CacheOpA, IteratorB, typename MmaCore::SmemIteratorB, CacheOpB, ElementC, + LayoutC, typename MmaCore::MmaPolicy, Stages>; + + // + // Data members + // + + cutlass::HostTensor matrix_A; + cutlass::HostTensor 
matrix_B; + cutlass::HostTensor matrix_C_computed; + cutlass::HostTensor matrix_C_reference; + + cutlass::gemm::GemmCoord problem_size; + float alpha, beta; + + // + // Methods + // + + /// Allocates workspace in device memory + Testbed(int m, int n, int k, float alpha_ = float(1), float beta_ = float(0)) + : problem_size(m, n, k), alpha(alpha_), beta(beta_) { + matrix_A.reset(cutlass::make_Coord(m, k)); + matrix_B.reset(cutlass::make_Coord(k, n)); + matrix_C_computed.reset(cutlass::make_Coord(m, n)); + matrix_C_reference.reset(cutlass::make_Coord(m, n), false); + } + + /// Returns true if the CUDA device is sufficient to execute the kernel. + bool sufficient() const { + + // + // Determine SMEM requirements and waive if not satisfied + // + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + return true; + } + + /// Runs the test + bool run( + dim3 grid, dim3 block, + cutlass::Distribution::Kind init_A = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B = cutlass::Distribution::Uniform) { + + if (!sufficient()) { + return true; + } + + // + // initialize device memory + // + + if (init_A == cutlass::Distribution::Uniform) { + + int scope_max = 8; + int scope_min = -8; + + if (cutlass::sizeof_bits::value == 4) { + scope_max = 2; + scope_min = -2; + } else if (cutlass::sizeof_bits::value == 1) { + scope_max = 2; + scope_min = 0; + } + + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomUniform( + matrix_A.host_view(), seed, scope_max, scope_min, 0); + } else if (init_A == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(matrix_A.host_data(), + matrix_A.capacity()); + } else if (init_A == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(matrix_A.host_view()); + } else { + return false; + } + + if (init_B == cutlass::Distribution::Uniform) { + + int scope_max = 8; + int scope_min = -8; + + if (cutlass::sizeof_bits::value == 4) { + scope_max = 2; + scope_min = -2; + } else if (cutlass::sizeof_bits::value == 1) { + scope_max = 2; + scope_min = 0; + } + + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomUniform( + matrix_B.host_view(), seed + 16, scope_max, scope_min, 0); + } else if (init_B == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(matrix_B.host_data(), + matrix_B.capacity()); + } else if (init_B == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(matrix_B.host_view()); + } else { + return false; + } + + cutlass::reference::host::TensorFill(matrix_C_computed.host_view()); + + cutlass::reference::host::TensorFill(matrix_C_reference.host_view()); + + matrix_A.sync_device(); + matrix_B.sync_device(); + matrix_C_computed.sync_device(); + + typename IteratorA::Params params_A(matrix_A.layout()); + typename IteratorB::Params params_B(matrix_B.layout()); + + cudaError_t result; + + int smem_size = int(sizeof(typename Mma::SharedStorage)); + if (smem_size >= (48 << 10)) { + result = cudaFuncSetAttribute( + test::gemm::threadblock::kernel_multistage_mma, + cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + + if (result != cudaSuccess) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << 
"Test waived due to insufficient CUDA device." << std::endl; + } + return true; + } + + result = cudaFuncSetAttribute( + test::gemm::threadblock::kernel_multistage_mma, + cudaFuncAttributePreferredSharedMemoryCarveout, 100); + + if (result != cudaSuccess) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." << std::endl; + } + return true; + } + } + + test::gemm::threadblock::kernel_multistage_mma + <<>>( + problem_size, params_A, matrix_A.device_ref(), params_B, + matrix_B.device_ref(), matrix_C_computed.device_data(), + matrix_C_computed.layout().stride(0)); + + // + // Check error code + // + + result = cudaDeviceSynchronize(); + EXPECT_EQ(result, cudaSuccess) + << " kernel error: " << cudaGetErrorString(result); + + matrix_C_computed.sync_host(); + + cutlass::reference::host::Gemm reference_gemm; + + reference_gemm( + problem_size, ElementC(alpha), matrix_A.host_view(), + matrix_B.host_view(), ElementC(beta), matrix_C_reference.host_view()); + + bool passed = cutlass::reference::host::TensorEquals( + matrix_C_computed.host_view(), matrix_C_reference.host_view()); + + EXPECT_TRUE(passed); + + if (!passed && CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cout + << __FILE__ << ":" << __LINE__ << " " + << "A:\n" << matrix_A.host_view() << "\n" + << "B:\n" << matrix_B.host_view() << "\n" + << "Reference:\n" + << matrix_C_reference.host_view() << "\n" + << "Computed:\n" + << matrix_C_computed.host_view() << "\n"; + } + + EXPECT_GT(cutlass::reference::host::TensorNorm(matrix_C_reference.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(matrix_C_computed.host_view()), 0); + + return passed; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace test diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/threadblock/mma_multistage_testbed_slicedk.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/threadblock/mma_multistage_testbed_slicedk.h new file mode 100644 index 0000000000000000000000000000000000000000..4e617d6327594570b1a88a5b28f2ec4d0467b534 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/threadblock/mma_multistage_testbed_slicedk.h @@ -0,0 +1,387 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Unit testbed for kernel-level GEMM +*/ + +#pragma once + +#include + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/aligned_buffer.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/vector.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/core_io.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" + +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_fill.h" + +#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" +#include "cutlass/transform/threadblock/predicated_tile_access_iterator.h" +#include "cutlass/cutlass.h" +#include "cutlass/platform/platform.h" + +namespace test { +namespace gemm { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +__global__ void kernel_multistage_mma(cutlass::gemm::GemmCoord problem_size, + typename Mma::IteratorA::Params params_A, + typename Mma::IteratorA::TensorRef ref_A, + typename Mma::IteratorB::Params params_B, + typename Mma::IteratorB::TensorRef ref_B, + typename Mma::ElementC **ptr_C, + typename Mma::LayoutC::Stride::Index ldc) { + // Shared storage needed by threadblock-scoped matrix multiply-accumulate + + // Dynamic shared memory base pointer + extern __shared__ int GemmSharedStorageBase[]; + + // Declare pointer to dynamic shared memory. 
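+  //
+  // Note: this is the sliced-K variant. Warps are partitioned into
+  // Mma::WarpCount::kM * Mma::WarpCount::kN groups per K slice; partitionsK_idx
+  // selects the slice, and each slice writes its partial accumulation to its own
+  // output buffer ptr_C[partitionsK_idx]. The testbed then reduces the partial
+  // outputs on the host before comparing against the reference, conceptually:
+  //
+  //   for (int k = 1; k < kPartitionsK; ++k)   // see Testbed::run() below
+  //     matrix_C_computed[0] += matrix_C_computed[k];
+  //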
+ typename Mma::SharedStorage *shared_storage = + reinterpret_cast(GemmSharedStorageBase); + + // Compute threadblock location + cutlass::gemm::GemmCoord tb_tile_offset = {int(blockIdx.x), int(blockIdx.y), + 0}; + + cutlass::MatrixCoord tb_offset_A{tb_tile_offset.m() * Mma::Shape::kM, + tb_tile_offset.k()}; + + cutlass::MatrixCoord tb_offset_B{tb_tile_offset.k(), + tb_tile_offset.n() * Mma::Shape::kN}; + + // Compute position within threadblock + int tb_thread_id = threadIdx.y * blockDim.x + threadIdx.x; + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A(params_A, ref_A.data(), + {problem_size.m(), problem_size.k()}, + tb_thread_id, tb_offset_A); + + typename Mma::IteratorB iterator_B(params_B, ref_B.data(), + {problem_size.k(), problem_size.n()}, + tb_thread_id, tb_offset_B); + + int warp_id = __shfl_sync(0xffffffff, threadIdx.y, 0); + int lane_id = threadIdx.x; + + int partitionsK_idx = warp_id / (Mma::WarpCount::kM * Mma::WarpCount::kN); + + // Construct thread-scoped matrix multiply + Mma mma(*shared_storage, tb_thread_id, warp_id, threadIdx.x); + + typename Mma::FragmentC accum; + + accum.clear(); + + int gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum); + + // Output results + typename Mma::Operator::IteratorC iterator_C({ptr_C[partitionsK_idx], ldc}, lane_id); + + int warp_idx_mn = warp_id % (Mma::WarpCount::kM * Mma::WarpCount::kN); + iterator_C.add_tile_offset( + {(tb_tile_offset.m() * Mma::WarpCount::kM) + + (warp_idx_mn % Mma::WarpCount::kM), + (tb_tile_offset.n() * Mma::WarpCount::kN) + + (warp_idx_mn / Mma::WarpCount::kM)}); + + iterator_C.store(accum); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product +template < + /// Threadblock-level matrix multiply-accumulate + typename MmaCore_> +struct Testbed { + /// Threadblock-level GEMM implementation + using MmaCore = MmaCore_; + using ThreadblockShape = typename MmaCore::Shape; + using WarpShape = typename MmaCore::WarpShape; + using InstructionShape = typename MmaCore::InstructionShape; + using ElementA = typename MmaCore::ElementA; + using LayoutA = typename MmaCore::LayoutA; + using ElementB = typename MmaCore::ElementB; + using LayoutB = typename MmaCore::LayoutB; + using ElementC = typename MmaCore::ElementC; + using LayoutC = typename MmaCore::LayoutC; + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeA = cutlass::Array; + using AccessTypeB = cutlass::Array; + static int const Stages = MmaCore::kStages; + static cutlass::arch::CacheOperation::Kind const CacheOpA = + MmaCore::kCacheOpA; + static cutlass::arch::CacheOperation::Kind const CacheOpB = + MmaCore::kCacheOpB; + + // Define iterators over tiles from the A operand + using IteratorA = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementA, LayoutA, 1, ThreadMapA, AccessTypeA>; + + // Define iterators over tiles from the B operand + using IteratorB = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementB, LayoutB, 0, ThreadMapB, AccessTypeB>; + + // Define the threadblock-scoped pipelined matrix multiply + using Mma = cutlass::gemm::threadblock::MmaMultistage< + typename MmaCore::Shape, IteratorA, typename 
MmaCore::SmemIteratorA, CacheOpA, + IteratorB, typename MmaCore::SmemIteratorB, CacheOpB, ElementC, LayoutC, + typename MmaCore::MmaPolicy, Stages>; + + static int const kPartitionsK = MmaCore::MmaPolicy::kPartitionsK; + + // + // Data members + // + + cutlass::HostTensor matrix_A; + cutlass::HostTensor matrix_B; + cutlass::HostTensor matrix_C_computed[kPartitionsK]; + cutlass::HostTensor matrix_C_reference; + cutlass::HostTensor matrix_C_pointers; + + cutlass::gemm::GemmCoord problem_size; + float alpha, beta; + + // + // Methods + // + + /// Allocates workspace in device memory + Testbed(int m, int n, int k, float alpha_ = float(1), float beta_ = float(0)) + : problem_size(m, n, k), alpha(alpha_), beta(beta_) { + matrix_A.reset(cutlass::make_Coord(m, k)); + matrix_B.reset(cutlass::make_Coord(k, n)); + + CUTLASS_PRAGMA_UNROLL + for(int k = 0; k < kPartitionsK; k++) + matrix_C_computed[k].reset(cutlass::make_Coord(m, n)); + + matrix_C_reference.reset(cutlass::make_Coord(m, n), false); + matrix_C_pointers.reset(cutlass::Coord<1>(kPartitionsK)); + } + + /// Runs the test + bool run( + dim3 grid, dim3 block, + cutlass::Distribution::Kind init_A = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B = cutlass::Distribution::Uniform) { + // + // initialize device memory + // + + if (init_A == cutlass::Distribution::Uniform) { + + int scope_max = 8; + int scope_min = -8; + + if (cutlass::sizeof_bits::value == 4) { + scope_max = 2; + scope_min = -2; + } else if (cutlass::sizeof_bits::value == 1) { + scope_max = 2; + scope_min = 0; + } + + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomUniform( + matrix_A.host_view(), seed, scope_max, scope_min, 0); + } else if (init_A == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(matrix_A.host_data(), + matrix_A.capacity()); + } else if (init_A == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(matrix_A.host_view()); + } else { + return false; + } + + if (init_B == cutlass::Distribution::Uniform) { + + int scope_max = 8; + int scope_min = -8; + + if (cutlass::sizeof_bits::value == 4) { + scope_max = 2; + scope_min = -2; + } else if (cutlass::sizeof_bits::value == 1) { + scope_max = 2; + scope_min = 0; + } + + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomUniform( + matrix_B.host_view(), seed + 16, scope_max, scope_min, 0); + } else if (init_B == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(matrix_B.host_data(), + matrix_B.capacity()); + } else if (init_B == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(matrix_B.host_view()); + } else { + return false; + } + + CUTLASS_PRAGMA_UNROLL + for(int k = 0; k < kPartitionsK; k++) + cutlass::reference::host::TensorFill(matrix_C_computed[k].host_view()); + + cutlass::reference::host::TensorFill(matrix_C_reference.host_view()); + + matrix_A.sync_device(); + matrix_B.sync_device(); + + CUTLASS_PRAGMA_UNROLL + for(int k = 0; k < kPartitionsK; k++) + matrix_C_computed[k].sync_device(); + + typename IteratorA::Params params_A(matrix_A.layout()); + typename IteratorB::Params params_B(matrix_B.layout()); + + CUTLASS_PRAGMA_UNROLL + for(int k = 0; k < kPartitionsK; k++) + matrix_C_pointers.at(cutlass::Coord<1>(k)) = matrix_C_computed[k].device_data(); + + matrix_C_pointers.sync_device(); + + cudaError_t result; + + int smem_size = int(sizeof(typename Mma::SharedStorage)); + if (smem_size >= (48 << 10)) { + result = 
cudaFuncSetAttribute( + test::gemm::threadblock::kernel_multistage_mma, + cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + + EXPECT_EQ(result, cudaSuccess) + << " cudaFuncSetAttribute " + "cudaFuncAttributeMaxDynamicSharedMemorySize error: " + << cudaGetErrorString(result); + + result = cudaFuncSetAttribute( + test::gemm::threadblock::kernel_multistage_mma, + cudaFuncAttributePreferredSharedMemoryCarveout, 100); + + EXPECT_EQ(result, cudaSuccess) + << " cudaFuncSetAttribute " + "cudaFuncAttributePreferredSharedMemoryCarveout error: " + << cudaGetErrorString(result); + } + + test::gemm::threadblock::kernel_multistage_mma<<>>( + problem_size, params_A, matrix_A.device_ref(), params_B, + matrix_B.device_ref(), matrix_C_pointers.device_data(), + matrix_C_computed[0].layout().stride(0)); + + // + // Check error code + // + + result = cudaDeviceSynchronize(); + EXPECT_EQ(result, cudaSuccess) + << " kernel error: " << cudaGetErrorString(result); + + CUTLASS_PRAGMA_UNROLL + for(int k = 0; k < kPartitionsK; k++) + matrix_C_computed[k].sync_host(); + + // TODO: this is temporary. it will be removed after slicing can de + // reduction + // + // Reduce matrix_C_computed + // + CUTLASS_PRAGMA_UNROLL + for(int k = 1; k < kPartitionsK; k++) { + CUTLASS_PRAGMA_UNROLL + for(int m = 0; m < matrix_C_computed[0].extent().row(); m++){ + CUTLASS_PRAGMA_UNROLL + for(int n = 0; n < matrix_C_computed[0].extent().column(); n++){ + matrix_C_computed[0].at({m, n}) += matrix_C_computed[k].at({m, n}); + } + } + } + + cutlass::reference::host::Gemm + reference_gemm; + + reference_gemm( + problem_size, ElementC(alpha), matrix_A.host_view(), + matrix_B.host_view(), ElementC(beta), matrix_C_reference.host_view()); + + bool passed = cutlass::reference::host::TensorEquals( + matrix_C_computed[0].host_view(), matrix_C_reference.host_view()); + + EXPECT_TRUE(passed); + + if (!passed) { + std::ofstream output("mma_multistage_testbed_errors.txt"); + + output + << "A:\n" << matrix_A.host_view() << "\n" + << "B:\n" << matrix_B.host_view() << "\n" + << "Reference:\n" + << matrix_C_reference.host_view() << "\n" + << "Computed:\n" + << matrix_C_computed[0].host_view() << "\n"; + } + + return passed; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace test diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/threadblock/mma_pipelined_testbed.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/threadblock/mma_pipelined_testbed.h new file mode 100644 index 0000000000000000000000000000000000000000..7eb62f9a39fe4472f77446efc591267001758c58 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/threadblock/mma_pipelined_testbed.h @@ -0,0 +1,353 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Unit testbed for kernel-level GEMM +*/ + +#pragma once + +#include + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/aligned_buffer.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/vector.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/core_io.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" + +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_fill.h" + +#include "cutlass/gemm/threadblock/default_mma_core_simt.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm75.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm70.h" +#include "cutlass/transform/threadblock/predicated_tile_iterator.h" +#include "cutlass/transform/threadblock/predicated_tile_iterator_2dthreadtile.h" +#include "cutlass/cutlass.h" +#include "cutlass/platform/platform.h" + +namespace test { +namespace gemm { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +__global__ void kernel_mma(cutlass::gemm::GemmCoord problem_size, + typename Mma::IteratorA::Params params_A, + typename Mma::IteratorA::TensorRef ref_A, + typename Mma::IteratorB::Params params_B, + typename Mma::IteratorB::TensorRef ref_B, + typename Mma::ElementC *ptr_C, + typename Mma::LayoutC::Stride::Index ldc) { + // Shared storage needed by threadblock-scoped matrix multiply-accumulate + __shared__ typename Mma::SharedStorage shared_storage; + + // Compute threadblock location + cutlass::gemm::GemmCoord tb_tile_offset = {int(blockIdx.x), int(blockIdx.y), + 0}; + + cutlass::MatrixCoord tb_offset_A{tb_tile_offset.m() * Mma::Shape::kM, + tb_tile_offset.k()}; + + cutlass::MatrixCoord tb_offset_B{tb_tile_offset.k(), + tb_tile_offset.n() * Mma::Shape::kN}; + + // Compute position within threadblock + int tb_thread_id = threadIdx.y * blockDim.x + threadIdx.x; + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A(params_A, ref_A.data(), + {problem_size.m(), problem_size.k()}, 
+ tb_thread_id, tb_offset_A); + + typename Mma::IteratorB iterator_B(params_B, ref_B.data(), + {problem_size.k(), problem_size.n()}, + tb_thread_id, tb_offset_B); + + int warp_id = threadIdx.y; + int lane_id = threadIdx.x; + + // Construct thread-scoped matrix multiply + Mma mma(shared_storage, tb_thread_id, warp_id, threadIdx.x); + + typename Mma::FragmentC accum; + + accum.clear(); + + int gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum); + + // Output results + typename Mma::Operator::IteratorC iterator_C({ptr_C, ldc}, lane_id); + + iterator_C.add_tile_offset( + {(tb_tile_offset.m() * Mma::WarpCount::kM) + + (warp_id % Mma::WarpCount::kM), + (tb_tile_offset.n() * Mma::WarpCount::kN) + + (warp_id / Mma::WarpCount::kM)}); + + iterator_C.store(accum); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product +template < + /// Threadblock-level matrix multiply-accumulate + typename MmaCore_, + /// Number of stages + int Stages = 2> +struct Testbed { + /// Threadblock-level GEMM implementation + using MmaCore = MmaCore_; + using ThreadblockShape = typename MmaCore::Shape; + using WarpShape = typename MmaCore::WarpShape; + using InstructionShape = typename MmaCore::InstructionShape; + using ElementA = typename MmaCore::ElementA; + using LayoutA = typename MmaCore::LayoutA; + using ElementB = typename MmaCore::ElementB; + using LayoutB = typename MmaCore::LayoutB; + using ElementC = typename MmaCore::ElementC; + using LayoutC = typename MmaCore::LayoutC; + static const int kStages = Stages; + + // Define iterators over tiles from the A operand + static const bool use_idp4a = cutlass::platform::is_same::value && + cutlass::platform::is_same::value && + cutlass::platform::is_same::value; + + static const bool transposeA = cutlass::platform::is_same< LayoutA, cutlass::layout::ColumnMajor >::value; + static const bool transposeB = cutlass::platform::is_same< LayoutB, cutlass::layout::RowMajor >::value; + + using IteratorA = typename cutlass::platform::conditional< use_idp4a, + cutlass::transform::threadblock::PredicatedTileIterator2dThreadTile< + cutlass::MatrixShape, + ElementA, LayoutA, 1, typename MmaCore::IteratorThreadMapA, transposeA> , + + cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, + ElementA, LayoutA, 1, typename MmaCore::IteratorThreadMapA> + >::type; + + // Define iterators over tiles from the B operand + using IteratorB = typename cutlass::platform::conditional< use_idp4a, + cutlass::transform::threadblock::PredicatedTileIterator2dThreadTile< + cutlass::MatrixShape, + ElementB, LayoutB, 0, typename MmaCore::IteratorThreadMapB, transposeB> , + + cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, + ElementB, LayoutB, 0, typename MmaCore::IteratorThreadMapB> + >::type; + + // Define MmaPipeline Single Stage + using MmaPipelineSingleStage = cutlass::gemm::threadblock::MmaSingleStage< + typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, + IteratorB, typename MmaCore::SmemIteratorB, ElementC, LayoutC, + typename MmaCore::MmaPolicy>; + + // Define MmaPipeline Two Stages + using MmaPipelineTwoStages = cutlass::gemm::threadblock::MmaPipelined< + typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, + IteratorB, typename MmaCore::SmemIteratorB, ElementC, LayoutC, 
+ typename MmaCore::MmaPolicy>; + + // Define the threadblock-scoped pipelined matrix multiply (Select between Single vs. Two stages) + using Mma = typename cutlass::platform::conditional<(kStages==1), MmaPipelineSingleStage, MmaPipelineTwoStages>::type; + // + // Data members + // + + cutlass::HostTensor matrix_A; + cutlass::HostTensor matrix_B; + cutlass::HostTensor matrix_C_computed; + cutlass::HostTensor matrix_C_reference; + + cutlass::gemm::GemmCoord problem_size; + float alpha, beta; + + // + // Methods + // + + /// Allocates workspace in device memory + Testbed(int m, int n, int k, float alpha_, float beta_) + : problem_size(m, n, k), alpha(alpha_), beta(beta_) { + matrix_A.reset(cutlass::make_Coord(m, k)); + matrix_B.reset(cutlass::make_Coord(k, n)); + matrix_C_computed.reset(cutlass::make_Coord(m, n)); + matrix_C_reference.reset(cutlass::make_Coord(m, n), false); + } + + bool sufficient() { + return true; + } + + /// Runs the test + bool run( + dim3 grid, dim3 block, + cutlass::Distribution::Kind init_A = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B = cutlass::Distribution::Uniform) { + + // Waive test if insufficient CUDA device + if (!sufficient()) { + if (CUTLASS_TEST_UNIT_ENABLE_WARNINGS) { + std::cerr << "Test waived due to insufficient CUDA device." << std::endl; + } + return true; + } + + + // + // initialize device memory + // + + if (init_A == cutlass::Distribution::Uniform) { + + int scope_max = 8; + int scope_min = -8; + + if (cutlass::sizeof_bits::value == 4) { + scope_max = 2; + scope_min = -2; + } else if (cutlass::sizeof_bits::value == 1) { + scope_max = 2; + scope_min = 0; + } + + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomUniform( + matrix_A.host_view(), seed, scope_max, scope_min, 0); + } else if (init_A == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(matrix_A.host_data(), + matrix_A.capacity()); + } else if (init_A == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(matrix_A.host_view()); + } else { + return false; + } + + if (init_B == cutlass::Distribution::Uniform) { + + int scope_max = 8; + int scope_min = -8; + + if (cutlass::sizeof_bits::value == 4) { + scope_max = 2; + scope_min = -2; + } else if (cutlass::sizeof_bits::value == 1) { + scope_max = 2; + scope_min = 0; + } + + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomUniform( + matrix_B.host_view(), seed + 16, scope_max, scope_min, 0); + } else if (init_B == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(matrix_B.host_data(), + matrix_B.capacity()); + } else if (init_B == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(matrix_B.host_view()); + } else { + return false; + } + + cutlass::reference::host::TensorFill(matrix_C_computed.host_view()); + + cutlass::reference::host::TensorFill(matrix_C_reference.host_view()); + + matrix_A.sync_device(); + matrix_B.sync_device(); + matrix_C_computed.sync_device(); + + typename IteratorA::Params params_A(matrix_A.layout()); + typename IteratorB::Params params_B(matrix_B.layout()); + + test::gemm::threadblock::kernel_mma<<>>( + problem_size, params_A, matrix_A.device_ref(), params_B, + matrix_B.device_ref(), matrix_C_computed.device_data(), + matrix_C_computed.layout().stride(0)); + + // + // Check error code + // + + cudaError_t result = cudaDeviceSynchronize(); + EXPECT_EQ(result, cudaSuccess) + << " kernel error: " << cudaGetErrorString(result) << " on device 
" << GetCudaDevice(); + + matrix_C_computed.sync_host(); + + cutlass::reference::host::Gemm + reference_gemm; + + reference_gemm( + problem_size, ElementC(alpha), matrix_A.host_view(), + matrix_B.host_view(), ElementC(beta), matrix_C_reference.host_view()); + + bool passed = cutlass::reference::host::TensorEquals( + matrix_C_computed.host_view(), matrix_C_reference.host_view()); + + EXPECT_TRUE(passed) << "Failed on device " << GetCudaDevice(); + + if (!passed) { + std::ofstream output("mma_pipelined_testbed_errors.txt"); + + output + << "A:\n" << matrix_A.host_view() << "\n" + << "B:\n" << matrix_B.host_view() << "\n" + << "Reference:\n" + << matrix_C_reference.host_view() << "\n" + << "Computed:\n" + << matrix_C_computed.host_view() << "\n"; + } + + return passed; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace test diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/threadblock/mma_pipelined_testbed_slicedk.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/threadblock/mma_pipelined_testbed_slicedk.h new file mode 100644 index 0000000000000000000000000000000000000000..36e55b2542b2258542336a052cdd14bf4b85f78d --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/threadblock/mma_pipelined_testbed_slicedk.h @@ -0,0 +1,370 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! 
\file + \brief Unit testbed for kernel-level GEMM +*/ + +#pragma once + +#include + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/aligned_buffer.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/vector.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/core_io.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" + +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_fill.h" + +#include "cutlass/gemm/threadblock/default_mma_core_simt.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm75.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm70.h" +#include "cutlass/transform/threadblock/predicated_tile_iterator.h" +#include "cutlass/transform/threadblock/predicated_tile_iterator_2dthreadtile.h" +#include "cutlass/cutlass.h" +#include "cutlass/platform/platform.h" + +namespace test { +namespace gemm { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +__global__ void kernel_mma(cutlass::gemm::GemmCoord problem_size, + typename Mma::IteratorA::Params params_A, + typename Mma::IteratorA::TensorRef ref_A, + typename Mma::IteratorB::Params params_B, + typename Mma::IteratorB::TensorRef ref_B, + typename Mma::ElementC **ptr_C, + typename Mma::LayoutC::Stride::Index ldc) { + // Shared storage needed by threadblock-scoped matrix multiply-accumulate + __shared__ typename Mma::SharedStorage shared_storage; + + // Compute threadblock location + cutlass::gemm::GemmCoord tb_tile_offset = {int(blockIdx.x), int(blockIdx.y), + 0}; + + cutlass::MatrixCoord tb_offset_A{tb_tile_offset.m() * Mma::Shape::kM, + tb_tile_offset.k()}; + + cutlass::MatrixCoord tb_offset_B{tb_tile_offset.k(), + tb_tile_offset.n() * Mma::Shape::kN}; + + // Compute position within threadblock + int tb_thread_id = threadIdx.y * blockDim.x + threadIdx.x; + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A(params_A, ref_A.data(), + {problem_size.m(), problem_size.k()}, + tb_thread_id, tb_offset_A); + + typename Mma::IteratorB iterator_B(params_B, ref_B.data(), + {problem_size.k(), problem_size.n()}, + tb_thread_id, tb_offset_B); + + int warp_id = threadIdx.y; + int lane_id = threadIdx.x; + + int partitionsK_idx = warp_id / (Mma::WarpCount::kM * Mma::WarpCount::kN); + + // Construct thread-scoped matrix multiply + Mma mma(shared_storage, tb_thread_id, warp_id, threadIdx.x); + + typename Mma::FragmentC accum; + + accum.clear(); + + int gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum); + + // Output results + typename Mma::Operator::IteratorC iterator_C({ptr_C[partitionsK_idx], ldc}, lane_id); + + + int warp_idx_mn = warp_id % (Mma::WarpCount::kM * Mma::WarpCount::kN); + iterator_C.add_tile_offset( + {(tb_tile_offset.m() * Mma::WarpCount::kM) + + (warp_idx_mn % Mma::WarpCount::kM), + (tb_tile_offset.n() * Mma::WarpCount::kN) + + (warp_idx_mn / Mma::WarpCount::kM)}); + + iterator_C.store(accum); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product +template < + /// Threadblock-level matrix multiply-accumulate + typename MmaCore_> 
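+// This testbed exercises a threadblock-scoped MMA whose K dimension is sliced across warp
+// partitions (MmaPolicy::kPartitionsK). Each partition writes its partial accumulators to a
+// separate output tensor; the partials are summed on the host before comparison with the
+// reference GEMM.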
+struct Testbed { + /// Threadblock-level GEMM implementation + using MmaCore = MmaCore_; + using ThreadblockShape = typename MmaCore::Shape; + using WarpShape = typename MmaCore::WarpShape; + using InstructionShape = typename MmaCore::InstructionShape; + using ElementA = typename MmaCore::ElementA; + using LayoutA = typename MmaCore::LayoutA; + using ElementB = typename MmaCore::ElementB; + using LayoutB = typename MmaCore::LayoutB; + using ElementC = typename MmaCore::ElementC; + using LayoutC = typename MmaCore::LayoutC; + + // Define iterators over tiles from the A operand + static const bool use_idp4a = cutlass::platform::is_same::value && + cutlass::platform::is_same::value && + cutlass::platform::is_same::value; + + static const bool transposeA = cutlass::platform::is_same< LayoutA, cutlass::layout::ColumnMajor >::value; + static const bool transposeB = cutlass::platform::is_same< LayoutB, cutlass::layout::RowMajor >::value; + + using IteratorA = typename cutlass::platform::conditional< use_idp4a, + cutlass::transform::threadblock::PredicatedTileIterator2dThreadTile< + cutlass::MatrixShape, + ElementA, LayoutA, 1, typename MmaCore::IteratorThreadMapA, transposeA> , + + cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, + ElementA, LayoutA, 1, typename MmaCore::IteratorThreadMapA> + >::type; + + // Define iterators over tiles from the B operand + using IteratorB = typename cutlass::platform::conditional< use_idp4a, + cutlass::transform::threadblock::PredicatedTileIterator2dThreadTile< + cutlass::MatrixShape, + ElementB, LayoutB, 0, typename MmaCore::IteratorThreadMapB, transposeB> , + + cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, + ElementB, LayoutB, 0, typename MmaCore::IteratorThreadMapB> + >::type; + + // Define the threadblock-scoped pipelined matrix multiply + using Mma = cutlass::gemm::threadblock::MmaPipelined< + typename MmaCore::Shape, IteratorA, typename MmaCore::SmemIteratorA, + IteratorB, typename MmaCore::SmemIteratorB, ElementC, LayoutC, + typename MmaCore::MmaPolicy>; + + static int const kPartitionsK = MmaCore::MmaPolicy::kPartitionsK; + + // + // Data members + // + + cutlass::HostTensor matrix_A; + cutlass::HostTensor matrix_B; + cutlass::HostTensor matrix_C_computed[kPartitionsK]; + cutlass::HostTensor matrix_C_reference; + cutlass::HostTensor matrix_C_pointers; + + cutlass::gemm::GemmCoord problem_size; + float alpha, beta; + + // + // Methods + // + + /// Allocates workspace in device memory + Testbed(int m, int n, int k, float alpha_, float beta_) + : problem_size(m, n, k), alpha(alpha_), beta(beta_) { + matrix_A.reset(cutlass::make_Coord(m, k)); + matrix_B.reset(cutlass::make_Coord(k, n)); + + CUTLASS_PRAGMA_UNROLL + for(int k = 0; k < kPartitionsK; k++) + matrix_C_computed[k].reset(cutlass::make_Coord(m, n)); + + matrix_C_reference.reset(cutlass::make_Coord(m, n), false); + matrix_C_pointers.reset(cutlass::Coord<1>(kPartitionsK)); + } + + /// Runs the test + bool run( + dim3 grid, dim3 block, + cutlass::Distribution::Kind init_A = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B = cutlass::Distribution::Uniform) { + // + // initialize device memory + // + + if (init_A == cutlass::Distribution::Uniform) { + + int scope_max = 8; + int scope_min = -8; + + if (cutlass::sizeof_bits::value == 4) { + scope_max = 2; + scope_min = -2; + } else if (cutlass::sizeof_bits::value == 1) { + scope_max = 2; + scope_min = 0; + } + + uint64_t seed = 7; + 
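+      // Fill A on the host with uniformly distributed random values in [scope_min, scope_max].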
cutlass::reference::host::TensorFillRandomUniform( + matrix_A.host_view(), seed, scope_max, scope_min, 0); + } else if (init_A == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(matrix_A.host_data(), + matrix_A.capacity()); + } else if (init_A == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(matrix_A.host_view()); + } else { + return false; + } + + if (init_B == cutlass::Distribution::Uniform) { + + int scope_max = 8; + int scope_min = -8; + + if (cutlass::sizeof_bits::value == 4) { + scope_max = 2; + scope_min = -2; + } else if (cutlass::sizeof_bits::value == 1) { + scope_max = 2; + scope_min = 0; + } + + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomUniform( + matrix_B.host_view(), seed + 16, scope_max, scope_min, 0); + } else if (init_B == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(matrix_B.host_data(), + matrix_B.capacity()); + } else if (init_B == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(matrix_B.host_view()); + } else { + return false; + } + + CUTLASS_PRAGMA_UNROLL + for(int k = 0; k < kPartitionsK; k++) + cutlass::reference::host::TensorFill(matrix_C_computed[k].host_view()); + + cutlass::reference::host::TensorFill(matrix_C_reference.host_view()); + + matrix_A.sync_device(); + matrix_B.sync_device(); + + CUTLASS_PRAGMA_UNROLL + for(int k = 0; k < kPartitionsK; k++) + matrix_C_computed[k].sync_device(); + + typename IteratorA::Params params_A(matrix_A.layout()); + typename IteratorB::Params params_B(matrix_B.layout()); + + CUTLASS_PRAGMA_UNROLL + for(int k = 0; k < kPartitionsK; k++) + matrix_C_pointers.at(cutlass::Coord<1>(k)) = matrix_C_computed[k].device_data(); + + matrix_C_pointers.sync_device(); + + test::gemm::threadblock::kernel_mma<<>>( + problem_size, params_A, matrix_A.device_ref(), params_B, + matrix_B.device_ref(), matrix_C_pointers.device_data(), + matrix_C_computed[0].layout().stride(0)); + + // + // Check error code + // + + cudaError_t result = cudaDeviceSynchronize(); + EXPECT_EQ(result, cudaSuccess) + << " kernel error: " << cudaGetErrorString(result); + + CUTLASS_PRAGMA_UNROLL + for(int k = 0; k < kPartitionsK; k++) + matrix_C_computed[k].sync_host(); + + // TODO: this is temporary. 
it will be removed after slicing can de + // reduction + // + // Reduce matrix_C_computed + // + CUTLASS_PRAGMA_UNROLL + for(int k = 1; k < kPartitionsK; k++) { + CUTLASS_PRAGMA_UNROLL + for(int m = 0; m < matrix_C_computed[0].extent().row(); m++){ + CUTLASS_PRAGMA_UNROLL + for(int n = 0; n < matrix_C_computed[0].extent().column(); n++){ + matrix_C_computed[0].at({m, n}) += matrix_C_computed[k].at({m, n}); + } + } + } + + cutlass::reference::host::Gemm + reference_gemm; + + reference_gemm( + problem_size, ElementC(alpha), matrix_A.host_view(), + matrix_B.host_view(), ElementC(beta), matrix_C_reference.host_view()); + + bool passed = cutlass::reference::host::TensorEquals( + matrix_C_computed[0].host_view(), matrix_C_reference.host_view()); + + EXPECT_TRUE(passed); + + if (!passed) { + std::ofstream output("mma_pipelined_testbed_errors.txt"); + + output + << "A:\n" << matrix_A.host_view() << "\n" + << "B:\n" << matrix_B.host_view() << "\n" + << "Reference:\n" + << matrix_C_reference.host_view() << "\n" + << "Computed:\n" + << matrix_C_computed[0].host_view() << "\n"; + } + + return passed; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace test diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/threadblock/mma_planar_complex_testbed.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/threadblock/mma_planar_complex_testbed.h new file mode 100644 index 0000000000000000000000000000000000000000..e5fdc07769726353b33c1a5da65dedfadb4ce1e7 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/threadblock/mma_planar_complex_testbed.h @@ -0,0 +1,350 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Unit testbed for kernel-level GEMM +*/ + +#pragma once + +#include + +#include "../../common/cutlass_unit_test.h" + +#include "cutlass/cutlass.h" +#include "cutlass/platform/platform.h" + +#include "cutlass/aligned_buffer.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/vector.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/core_io.h" +#include "cutlass/util/host_tensor_planar_complex.h" +#include "cutlass/util/tensor_view_io.h" + +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/gemm_planar_complex.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_fill.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace test { +namespace gemm { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +__global__ void kernel_mma_planar_complex( + cutlass::gemm::GemmCoord problem_size, + typename Mma::IteratorA::Params params_A, + typename Mma::IteratorA::Element *ptr_A, + int64_t imaginary_stride_A, + typename Mma::IteratorB::Params params_B, + typename Mma::IteratorB::Element *ptr_B, + int64_t imaginary_stride_B, + typename Mma::ElementC *ptr_C, + typename Mma::LayoutC::Stride::Index ldc, int64_t imaginary_stride_C) { + + // Shared storage needed by threadblock-scoped matrix multiply-accumulate + __shared__ typename Mma::SharedStorage shared_storage; + + // Compute threadblock location + cutlass::gemm::GemmCoord tb_tile_offset = {int(blockIdx.x), int(blockIdx.y), + 0}; + + cutlass::MatrixCoord tb_offset_A{tb_tile_offset.m() * Mma::Shape::kM, + tb_tile_offset.k()}; + + cutlass::MatrixCoord tb_offset_B{tb_tile_offset.k(), + tb_tile_offset.n() * Mma::Shape::kN}; + + // Compute position within threadblock + int tb_thread_id = threadIdx.y * blockDim.x + threadIdx.x; + + // Construct iterators to A operand + typename Mma::IteratorA iterator_A_real(params_A, ptr_A, + {problem_size.m(), problem_size.k()}, + tb_thread_id, tb_offset_A); + + typename Mma::IteratorA iterator_A_imag(params_A, ptr_A + imaginary_stride_A, + {problem_size.m(), problem_size.k()}, + tb_thread_id, tb_offset_A); + + // Construct iterators to B operand + typename Mma::IteratorB iterator_B_real(params_B, ptr_B, + {problem_size.k(), problem_size.n()}, + tb_thread_id, tb_offset_B); + + typename Mma::IteratorB iterator_B_imag(params_B, ptr_B + imaginary_stride_B, + {problem_size.k(), problem_size.n()}, + tb_thread_id, tb_offset_B); + + int warp_id = threadIdx.y; + int lane_id = threadIdx.x; + + // Construct thread-scoped matrix multiply + Mma mma(shared_storage, tb_thread_id, warp_id, threadIdx.x); + + typename Mma::FragmentC accum; + + accum.clear(); + + int gemm_k_iterations = (problem_size.k() + 
Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, accum, iterator_A_real, iterator_A_imag, iterator_B_real, iterator_B_imag, accum); + + // Output results + typename Mma::Operator::IteratorC iterator_C({ptr_C, ldc}, lane_id); + + iterator_C.add_tile_offset( + {(tb_tile_offset.m() * Mma::WarpCount::kM) + + (warp_id % Mma::WarpCount::kM), + (tb_tile_offset.n() * Mma::WarpCount::kN) + + (warp_id / Mma::WarpCount::kM)}); + + iterator_C.store(accum.real); + + iterator_C.store_with_pointer_offset(accum.imag, imaginary_stride_C); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product +template < + /// Threadblock-level matrix multiply-accumulate + typename Mma_> +struct TestbedPlanarComplex { + + using Mma = Mma_; + using ThreadblockShape = typename Mma::Shape; + using IteratorA = typename Mma::IteratorA; + using ElementA = typename Mma::IteratorA::Element; + using LayoutA = typename Mma::IteratorA::Layout; + using IteratorB = typename Mma::IteratorB; + using ElementB = typename Mma::IteratorB::Element; + using LayoutB = typename Mma::IteratorB::Layout; + using ElementC = typename Mma::ElementC; + using ElementAccumulator = typename Mma::ElementC; + using LayoutC = typename Mma::LayoutC; + using ThreadMapA = typename Mma::IteratorA::ThreadMap; + using ThreadMapB = typename Mma::IteratorB::ThreadMap; + using AccessTypeA = cutlass::Array; + using AccessTypeB = cutlass::Array; + static int const Stages = Mma::kStages; + static cutlass::arch::CacheOperation::Kind const CacheOpA = + Mma::kCacheOpA; + static cutlass::arch::CacheOperation::Kind const CacheOpB = + Mma::kCacheOpB; + + // + // Data members + // + + cutlass::HostTensorPlanarComplex matrix_A; + cutlass::HostTensorPlanarComplex matrix_B; + cutlass::HostTensorPlanarComplex matrix_C_computed; + cutlass::HostTensorPlanarComplex matrix_C_reference; + + cutlass::gemm::GemmCoord problem_size; + + // + // Methods + // + + /// Allocates workspace in device memory + TestbedPlanarComplex(int m, int n, int k) + : problem_size(m, n, k) { + + matrix_A.reset(cutlass::make_Coord(m, k)); + matrix_B.reset(cutlass::make_Coord(k, n)); + matrix_C_computed.reset(cutlass::make_Coord(m, n)); + matrix_C_reference.reset(cutlass::make_Coord(m, n), false); + } + + /// Runs the test + bool run( + dim3 grid, dim3 block, + cutlass::Distribution::Kind init_A = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B = cutlass::Distribution::Uniform) { + + // + // initialize device memory + // + + if (init_A == cutlass::Distribution::Uniform) { + + int scope_max = 8; + int scope_min = -8; + + if (cutlass::sizeof_bits::value == 4) { + scope_max = 2; + scope_min = -2; + } else if (cutlass::sizeof_bits::value == 1) { + scope_max = 2; + scope_min = 0; + } + + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomUniform( + matrix_A.host_view(), seed, scope_max, scope_min, 0); + + } else if (init_A == cutlass::Distribution::Sequential) { + + for (int i = 0; i < matrix_A.capacity() * 2; ++i) { + matrix_A.host_data()[i] = cutlass::half_t(float(i % 5) - 2); + } + /* + cutlass::reference::host::BlockFillSequential(matrix_A.host_data(), + matrix_A.capacity() * 2); + */ + } else if (init_A == cutlass::Distribution::Identity) { + //cutlass::reference::host::TensorFillIdentity(matrix_A.host_view()); + } else { + return false; + } + + if (init_B == cutlass::Distribution::Uniform) { + + + int scope_max = 8; 
+ int scope_min = -8; + + if (cutlass::sizeof_bits::value == 4) { + scope_max = 2; + scope_min = -2; + } else if (cutlass::sizeof_bits::value == 1) { + scope_max = 2; + scope_min = 0; + } + + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomUniform( + matrix_B.host_view(), seed + 16, scope_max, scope_min, 0); + + + } else if (init_B == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential(matrix_B.host_data(), + matrix_B.capacity() * 2); + + for (int i = 0; i < matrix_B.capacity() * 2; ++i) { + matrix_B.host_data()[i] = cutlass::half_t(float((i + 3) % 5) - 2); + } + + + } else if (init_B == cutlass::Distribution::Identity) { + + //cutlass::reference::host::TensorFillIdentity(matrix_B.host_view()); + + } else { + return false; + } + + matrix_A.sync_device(); + matrix_B.sync_device(); + matrix_C_computed.sync_device(); + + typename IteratorA::Params params_A(matrix_A.layout()); + typename IteratorB::Params params_B(matrix_B.layout()); + + test::gemm::threadblock::kernel_mma_planar_complex<<>>( + problem_size, + params_A, + matrix_A.device_data(), + matrix_A.imaginary_stride(), + params_B, + matrix_B.device_data(), + matrix_B.imaginary_stride(), + matrix_C_computed.device_data(), + matrix_C_computed.layout().stride(0), + matrix_C_computed.imaginary_stride() + ); + + + // + // Check error code + // + + cudaError_t result = cudaDeviceSynchronize(); + EXPECT_EQ(result, cudaSuccess) + << " kernel error: " << cudaGetErrorString(result); + + matrix_C_computed.sync_host(); + + cutlass::reference::host::GemmPlanarComplex< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementAccumulator + >( + problem_size, + cutlass::complex(ElementAccumulator(1)), + matrix_A.host_ref(), + Mma::kTransformA, + matrix_B.host_ref(), + Mma::kTransformB, + cutlass::complex(ElementAccumulator(0)), + matrix_C_reference.host_ref(), + matrix_C_reference.host_ref() + ); + + bool passed = cutlass::reference::host::TensorEquals( + matrix_C_computed.host_view(), + matrix_C_reference.host_view() + ); + + EXPECT_TRUE(passed); + + if (!passed) { + std::ofstream output("mma_pipelined_testbed_errors.txt"); + + output + << "A:\n" << matrix_A.host_view() << "\n" + << "B:\n" << matrix_B.host_view() << "\n" + << "Reference:\n" + << matrix_C_reference.host_view() << "\n" + << "Computed:\n" + << matrix_C_computed.host_view() << "\n"; + } + + return passed; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace test diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/warp/testbed.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/warp/testbed.h new file mode 100644 index 0000000000000000000000000000000000000000..921d1abdc40c2040104815cfffb8b2ea32384136 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/gemm/warp/testbed.h @@ -0,0 +1,1543 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Unit tests for thread-level GEMM +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/aligned_buffer.h" +#include "cutlass/numeric_types.h" +#include "cutlass/subbyte_reference.h" +#include "cutlass/platform/platform.h" +#include "cutlass/arch/arch.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" + +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/gemm_complex.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/host_reorder.h" +#include "cutlass/util/host_uncompress.h" + +namespace test { +namespace gemm { +namespace warp { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Test kernel +template +__global__ void kernel( + typename Mma::ElementC *output_C, + typename Mma::ElementA const *input_A, + typename Mma::ElementB const *input_B, + typename Mma::ElementC const *input_C, + int iterations = 1) { + + // Use AlignedBuffer to store trivially copyable objects in unions and __shared__ buffers. 
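+  // smem_buffer_A holds one threadblock tile of A (kM x kK elements) and smem_buffer_B one tile
+  // of B (kK x kN elements); thread 0 stages both operands from global memory before the
+  // warp-level MMA is constructed.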
+ __shared__ cutlass::AlignedBuffer< + typename Mma::ElementA, ThreadblockShape::kM * ThreadblockShape::kK> smem_buffer_A; + + __shared__ cutlass::AlignedBuffer< + typename Mma::ElementB, ThreadblockShape::kN * ThreadblockShape::kK> smem_buffer_B; + + if (threadIdx.x == 0) { + typename Mma::ElementA *smem_ptr_A = smem_buffer_A.data(); + #pragma unroll 1 + for (size_t i = 0; i < smem_buffer_A.size(); ++i) { + cutlass::ReferenceFactory::get(smem_ptr_A, i) = + cutlass::ReferenceFactory::type>::get(input_A, i); + } + + typename Mma::ElementB *smem_ptr_B = smem_buffer_B.data(); + #pragma unroll 1 + for (size_t i = 0; i < smem_buffer_B.size(); ++i) { + cutlass::ReferenceFactory::get(smem_ptr_B, i) = + cutlass::ReferenceFactory::type>::get(input_B, i); + } + } + + __syncthreads(); + + // + // Construct warp-level matrix product + // + + using FragmentA = typename Mma::FragmentA; + using FragmentB = typename Mma::FragmentB; + using FragmentC = typename Mma::FragmentC; + + typename Mma::LayoutA layout_A = Mma::LayoutA::packed({ThreadblockShape::kM, ThreadblockShape::kK}); + typename Mma::LayoutB layout_B = Mma::LayoutB::packed({ThreadblockShape::kK, ThreadblockShape::kN}); + typename Mma::LayoutC layout_C = Mma::LayoutC::packed({Mma::Shape::kM, Mma::Shape::kN}); + + typename Mma::IteratorA iter_A({smem_buffer_A.data(), layout_A}, cutlass::arch::LaneId()); + + typename Mma::IteratorB iter_B({smem_buffer_B.data(), layout_B}, cutlass::arch::LaneId()); + + FragmentA frag_A; + FragmentB frag_B; + + FragmentC accum; + + Mma mma; + + accum.clear(); + + CUTLASS_PRAGMA_NO_UNROLL + for (int iter = 0; iter < iterations; ++iter) { // place in loop that is not unrolled + + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < ThreadblockShape::kK; + k += Mma::Policy::MmaShape::kK) { + iter_A.load(frag_A); + iter_B.load(frag_B); + + ++iter_A; + ++iter_B; + + mma(accum, frag_A, frag_B, accum); + } + } + + typename Mma::IteratorC iter_C({output_C, layout_C}, cutlass::arch::LaneId()); + + iter_C.store(accum); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product +template < + /// Warp-level matrix multiply-accumulate + typename Mma_, + /// Size of threadblock-scoped shape used to store SMEM + typename ThreadblockShape_, + /// The inner product operation performed by GEMM + typename Operator_ = cutlass::arch::OpMultiplyAdd +> +struct Testbed { + + /// Thread-level matrix multiply-accumulate operator + using Mma = Mma_; + using ThreadblockShape = ThreadblockShape_; + using Operator = Operator_; + + using Shape = typename Mma::Shape; + using ElementA = typename Mma::ElementA; + using LayoutA = typename Mma::LayoutA; + using ElementB = typename Mma::ElementB; + using LayoutB = typename Mma::LayoutB; + using ElementC = typename Mma::ElementC; + using LayoutC = typename Mma::LayoutC; + + // + // Data members + // + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_C; + cutlass::HostTensor tensor_D_computed; + cutlass::HostTensor tensor_D_reference; + + // + // Methods + // + + /// Allocates workspace in device memory + Testbed() { + + tensor_A.reset(cutlass::make_Coord(ThreadblockShape::kM, ThreadblockShape::kK)); + tensor_B.reset(cutlass::make_Coord(ThreadblockShape::kK, ThreadblockShape::kN)); + tensor_C.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); + tensor_D_computed.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); + tensor_D_reference.reset(cutlass::make_Coord(Shape::kM, Shape::kN), 
false); + } + + /// Returns true if the CUDA device is sufficient to execute the kernel. + bool sufficient() const { + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.major == 9) { + // NVIDIA Hopper drops support for several data types + if ( + cutlass::sizeof_bits::value < 8 || + cutlass::sizeof_bits::value < 8 || + cutlass::sizeof_bits::value < 8) { + + return false; + } + } + + return true; + } + + + /// Runs the test + bool run( + cutlass::Distribution::Kind init_A = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B = cutlass::Distribution::Uniform) { + + if (!sufficient()) { + return true; + } + + // + // initialize device memory + // + + if (init_A == cutlass::Distribution::Uniform) { + int scope_max = 8; + int scope_min = -8; + + if (cutlass::sizeof_bits::value == 4) { + scope_max = 2; + scope_min = -2; + } else if (cutlass::sizeof_bits::value == 1) { + scope_max = 2; + scope_min = 0; + } + + uint64_t seed = 7; + + cutlass::reference::host::BlockFillRandomUniform(tensor_A.host_data(), + tensor_A.capacity(), seed, scope_max, scope_min, 0); + + } else if (init_A == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(tensor_A.host_data(), + tensor_A.capacity()); + } else if (init_A == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(tensor_A.host_view()); + } else { + return false; + } + + if (init_B == cutlass::Distribution::Uniform) { + int scope_max = 8; + int scope_min = -8; + + if (cutlass::sizeof_bits::value == 4) { + scope_max = 2; + scope_min = -2; + } else if (cutlass::sizeof_bits::value == 1) { + scope_max = 2; + scope_min = 0; + } + + uint64_t seed = 7; + + cutlass::reference::host::BlockFillRandomUniform(tensor_B.host_data(), + tensor_B.capacity(), seed, scope_max, scope_min, 0); + + } else if (init_B == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(tensor_B.host_data(), + tensor_B.capacity()); + } else if (init_B == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(tensor_B.host_view()); + } else { + return false; + } + + cutlass::reference::host::TensorFill( + tensor_C.host_view(), + ElementC(0) + ); + + cutlass::reference::host::TensorFill( + tensor_D_computed.host_view(), + ElementC(0) + ); + + cutlass::reference::host::TensorFill( + tensor_D_reference.host_view(), + ElementC(0) + ); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_D_computed.sync_device(); + + // launch kernel + kernel<<< dim3(1, 1), dim3(32, 1, 1) >>>( + tensor_D_computed.device_data(), + tensor_A.device_data(), + tensor_B.device_data(), + tensor_C.device_data()); + + // verify no errors + cudaError_t result = cudaDeviceSynchronize(); + + EXPECT_EQ(result, cudaSuccess) << "CUDA ERROR: " << cudaGetErrorString(result); + if (result != cudaSuccess) { + return false; + } + + tensor_D_computed.sync_host(); + + // + // Reference implementation + // + + cutlass::reference::host::Gemm + reference_gemm; + + reference_gemm( + {Shape::kM, Shape::kN, ThreadblockShape::kK}, + ElementC(1), + tensor_A.host_ref(), + tensor_B.host_ref(), + ElementC(0), + tensor_D_reference.host_ref() + ); + + // + 
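+    // The device-computed tile must match the host reference exactly; TensorEquals performs an
+    // element-wise comparison of the two views.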
// Verify equivalence + // + + // compare + bool passed = cutlass::reference::host::TensorEquals( + tensor_D_computed.host_view(), + tensor_D_reference.host_view() + ); + + EXPECT_TRUE(passed); + + if (!passed) { + + cutlass::TensorView tensor_A_physical( + tensor_A.host_data(), + tensor_A.stride()[0], + tensor_A.extent()); + + cutlass::TensorView tensor_B_physical( + tensor_B.host_data(), + tensor_B.stride()[0], + tensor_B.extent()); + + std::cout <<"cutlass::sizeof_bits::value = "<::value<<"\n"; + std::cout + << "A:\n" << tensor_A.host_view() << "\n\n" + << "A(physical - stride: " << tensor_A.stride()[0] + << ", extent: " << tensor_A.extent() << "):\n" << tensor_A_physical << "\n\n"; + + std::cout <<"cutlass::sizeof_bits::value = "<::value<<"\n"; + std::cout + << "B:\n" << tensor_B.host_view() << "\n\n" + << "B(physical - stride: " << tensor_B.stride()[0] + << ", extent: " << tensor_B.extent() << "):\n" << tensor_B_physical << "\n\n"; + + std::cout + << "C:\n" << tensor_C.host_view() << "\n\n" + << "Reference:\n" << tensor_D_reference.host_view() << "\n\n" + << "Computed:\n" << tensor_D_computed.host_view() << std::endl; + } + + return passed; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product +template < + /// Warp-level matrix multiply-accumulate + typename Mma_, + /// Size of threadblock-scoped shape used to store SMEM + typename ThreadblockShape_ +> +struct TestbedComplex { + + /// Thread-level matrix multiply-accumulate operator + using Mma = Mma_; + using ThreadblockShape = ThreadblockShape_; + + using Shape = typename Mma::Shape; + using ElementA = typename Mma::ElementA; + using LayoutA = typename Mma::LayoutA; + using ElementB = typename Mma::ElementB; + using LayoutB = typename Mma::LayoutB; + using ElementC = typename Mma::ElementC; + using LayoutC = typename Mma::LayoutC; + + // + // Data members + // + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_C; + cutlass::HostTensor tensor_D_computed; + cutlass::HostTensor tensor_D_reference; + + // + // Methods + // + + /// Allocates workspace in device memory + TestbedComplex() { + + tensor_A.reset(cutlass::make_Coord(ThreadblockShape::kM, ThreadblockShape::kK)); + tensor_B.reset(cutlass::make_Coord(ThreadblockShape::kK, ThreadblockShape::kN)); + tensor_C.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); + tensor_D_computed.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); + tensor_D_reference.reset(cutlass::make_Coord(Shape::kM, Shape::kN), false); + } + + /// Returns true if the CUDA device is sufficient to execute the kernel. 
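+  /// Configurations with sub-8-bit operand or accumulator types are waived (treated as passing)
+  /// on compute capability 9.x devices.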
+ bool sufficient() const { + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.major == 9) { + // NVIDIA Hopper drops support for several data types + if ( + cutlass::sizeof_bits::value < 8 || + cutlass::sizeof_bits::value < 8 || + cutlass::sizeof_bits::value < 8) { + + return false; + } + } + + return true; + } + + /// Runs the test + bool run( + cutlass::Distribution::Kind init_A = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B = cutlass::Distribution::Uniform) { + + if (!sufficient()) { + return true; + } + + // + // initialize device memory + // + + if (init_A == cutlass::Distribution::Uniform) { + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomUniform(tensor_A.host_view(), + seed, 8, -8, 0); + } else if (init_A == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(tensor_A.host_data(), + tensor_A.capacity()); + } else if (init_A == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(tensor_A.host_view()); + } else { + return false; + } + + if (init_B == cutlass::Distribution::Uniform) { + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomUniform(tensor_B.host_view(), + seed + 16, 8, -8, 0); + } else if (init_B == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(tensor_B.host_data(), + tensor_B.capacity()); + } else if (init_B == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(tensor_B.host_view()); + } else { + return false; + } + + cutlass::reference::host::TensorFill( + tensor_C.host_view(), + ElementC(0) + ); + + cutlass::reference::host::TensorFill( + tensor_D_computed.host_view(), + ElementC(0) + ); + + cutlass::reference::host::TensorFill( + tensor_D_reference.host_view(), + ElementC(0) + ); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_D_computed.sync_device(); + + // launch kernel + kernel<<< dim3(1, 1), dim3(32, 1, 1) >>>( + tensor_D_computed.device_data(), + tensor_A.device_data(), + tensor_B.device_data(), + tensor_C.device_data()); + + // verify no errors + cudaError_t result = cudaDeviceSynchronize(); + + EXPECT_EQ(result, cudaSuccess) << "CUDA ERROR: " << cudaGetErrorString(result); + if (result != cudaSuccess) { + return false; + } + + tensor_D_computed.sync_host(); + + // + // Reference implementation + // + + cutlass::reference::host::GemmComplex( + {Shape::kM, Shape::kN, ThreadblockShape::kK}, + ElementC(1), + tensor_A.host_ref(), + Mma::kTransformA, + tensor_B.host_ref(), + Mma::kTransformB, + ElementC(0), + tensor_C.host_ref(), + tensor_D_reference.host_ref() + ); + + // + // Verify equivalence + // + + // compare + bool passed = cutlass::reference::host::TensorEquals( + tensor_D_computed.host_view(), + tensor_D_reference.host_view() + ); + + EXPECT_TRUE(passed); + + if (!passed) { + + cutlass::TensorView tensor_A_physical( + tensor_A.host_data(), + tensor_A.stride()[0], + tensor_A.extent()); + + cutlass::TensorView tensor_B_physical( + tensor_B.host_data(), + tensor_B.stride()[0], + tensor_B.extent()); + + std::cout <<"cutlass::sizeof_bits::value = "<::value<<"\n"; + std::cout + << "A:\n" << tensor_A.host_view() 
<< "\n\n" + << "A(physical - stride: " << tensor_A.stride()[0] << ", extent: " << tensor_A.extent() << "):\n" << tensor_A_physical << "\n\n"; + + std::cout <<"cutlass::sizeof_bits::value = "<::value<<"\n"; + std::cout + << "B:\n" << tensor_B.host_view() << "\n\n" + << "B(physical - stride: " << tensor_B.stride()[0] << ", extent: " << tensor_B.extent() <<"):\n" << tensor_B_physical << "\n\n"; + + std::cout + << "C:\n" << tensor_C.host_view() << "\n\n" + << "Reference:\n" << tensor_D_reference.host_view() << "\n\n" + << "Computed:\n" << tensor_D_computed.host_view() << std::endl; + } + + return passed; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Test kernel +template +__global__ void kernel_transform( + typename Mma::ElementC *output_C, + typename Mma::ElementA const *input_A, + typename Mma::ElementB const *input_B, + typename Mma::ElementC const *input_C, + int iterations = 1) { + + // Use AlignedBuffer to store trivially copyable objects in unions and __shared__ buffers. + __shared__ cutlass::AlignedBuffer< + typename Mma::ElementA, ThreadblockShape::kM * ThreadblockShape::kK> smem_buffer_A; + + __shared__ cutlass::AlignedBuffer< + typename Mma::ElementB, ThreadblockShape::kN * ThreadblockShape::kK> smem_buffer_B; + + if (threadIdx.x == 0) { + typename Mma::ElementA *smem_ptr_A = smem_buffer_A.data(); + #pragma unroll 1 + for (size_t i = 0; i < smem_buffer_A.size(); ++i) { + cutlass::ReferenceFactory::get(smem_ptr_A, i) = + cutlass::ReferenceFactory::type>::get(input_A, i); + } + + typename Mma::ElementB *smem_ptr_B = smem_buffer_B.data(); + #pragma unroll 1 + for (size_t i = 0; i < smem_buffer_B.size(); ++i) { + cutlass::ReferenceFactory::get(smem_ptr_B, i) = + cutlass::ReferenceFactory::type>::get(input_B, i); + } + } + + __syncthreads(); + + // + // Construct warp-level matrix product + // + + using FragmentA = typename Mma::FragmentA; + using FragmentB = typename Mma::FragmentB; + using FragmentC = typename Mma::FragmentC; + + using TransformedFragmentA = typename Mma::TransformedFragmentA; + using TransformedFragmentB = typename Mma::TransformedFragmentB; + + typename Mma::LayoutA layout_A = Mma::LayoutA::packed({ThreadblockShape::kM, ThreadblockShape::kK}); + typename Mma::LayoutB layout_B = Mma::LayoutB::packed({ThreadblockShape::kK, ThreadblockShape::kN}); + typename Mma::LayoutC layout_C = Mma::LayoutC::packed({Mma::Shape::kM, Mma::Shape::kN}); + + typename Mma::IteratorA iter_A({smem_buffer_A.data(), layout_A}, cutlass::arch::LaneId()); + + typename Mma::IteratorB iter_B({smem_buffer_B.data(), layout_B}, cutlass::arch::LaneId()); + + FragmentA loaded_frag_A; + FragmentB loaded_frag_B; + TransformedFragmentA transformed_frag_A; + TransformedFragmentB transformed_frag_B; + + FragmentC accum; + + Mma mma; + + accum.clear(); + + CUTLASS_PRAGMA_NO_UNROLL + for (int iter = 0; iter < iterations; ++iter) { // place in loop that is not unrolled + + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < ThreadblockShape::kK; + k += Mma::Policy::MmaShape::kK) { + iter_A.load(loaded_frag_A); + iter_B.load(loaded_frag_B); + + ++iter_A; + ++iter_B; + + mma.transform(transformed_frag_A, transformed_frag_B, loaded_frag_A, + loaded_frag_B); + + mma(accum, transformed_frag_A, transformed_frag_B, accum); + } + } + + typename Mma::IteratorC iter_C({output_C, layout_C}, cutlass::arch::LaneId()); + + iter_C.store(accum); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure 
to compute the matrix product +template < + /// Warp-level matrix multiply-accumulate + typename Mma_, + /// Size of threadblock-scoped shape used to store SMEM + typename ThreadblockShape_, + /// The innter product operation performed by GEMM + typename Operator_ = cutlass::arch::OpMultiplyAdd +> +struct TransformTestbed { + + /// Thread-level matrix multiply-accumulate operator + using Mma = Mma_; + using ThreadblockShape = ThreadblockShape_; + using Operator = Operator_; + + using Shape = typename Mma::Shape; + using ElementA = typename Mma::ElementA; + using LayoutA = typename Mma::LayoutA; + using ElementB = typename Mma::ElementB; + using LayoutB = typename Mma::LayoutB; + using ElementC = typename Mma::ElementC; + using LayoutC = typename Mma::LayoutC; + + // + // Data members + // + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_C; + cutlass::HostTensor tensor_D_computed; + cutlass::HostTensor tensor_D_reference; + + // + // Methods + // + + /// Allocates workspace in device memory + TransformTestbed() { + + tensor_A.reset(cutlass::make_Coord(ThreadblockShape::kM, ThreadblockShape::kK)); + tensor_B.reset(cutlass::make_Coord(ThreadblockShape::kK, ThreadblockShape::kN)); + tensor_C.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); + tensor_D_computed.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); + tensor_D_reference.reset(cutlass::make_Coord(Shape::kM, Shape::kN), false); + } + + /// Returns true if the CUDA device is sufficient to execute the kernel. + bool sufficient() const { + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.major == 9) { + // NVIDIA Hopper drops support for several data types + if ( + cutlass::sizeof_bits::value < 8 || + cutlass::sizeof_bits::value < 8 || + cutlass::sizeof_bits::value < 8) { + + return false; + } + } + + return true; + } + + /// Runs the test + bool run( + cutlass::Distribution::Kind init_A = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B = cutlass::Distribution::Uniform) { + + if (!sufficient()) { + return true; + } + + // + // initialize device memory + // + + if (init_A == cutlass::Distribution::Uniform) { + int scope_max = 8; + int scope_min = -8; + + if (cutlass::sizeof_bits::value == 4) { + scope_max = 2; + scope_min = -2; + } else if (cutlass::sizeof_bits::value == 1) { + scope_max = 2; + scope_min = 0; + } + + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomUniform( + tensor_A.host_view(), seed, scope_max, scope_min, 0); + } else if (init_A == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(tensor_A.host_data(), + tensor_A.capacity()); + } else if (init_A == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(tensor_A.host_view()); + } else { + return false; + } + + if (init_B == cutlass::Distribution::Uniform) { + int scope_max = 8; + int scope_min = -8; + + if (cutlass::sizeof_bits::value == 4) { + scope_max = 2; + scope_min = -2; + } else if (cutlass::sizeof_bits::value == 1) { + scope_max = 2; + scope_min = 0; + } + + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomUniform( + tensor_B.host_view(), seed + 16, scope_max, scope_min, 0); + } 
else if (init_B == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(tensor_B.host_data(), + tensor_B.capacity()); + } else if (init_B == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(tensor_B.host_view()); + } else { + return false; + } + + cutlass::reference::host::TensorFill( + tensor_C.host_view(), + ElementC(0) + ); + + cutlass::reference::host::TensorFill( + tensor_D_computed.host_view(), + ElementC(0) + ); + + cutlass::reference::host::TensorFill( + tensor_D_reference.host_view(), + ElementC(0) + ); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_D_computed.sync_device(); + + // launch kernel + kernel_transform<<>>( + tensor_D_computed.device_data(), tensor_A.device_data(), + tensor_B.device_data(), tensor_C.device_data()); + + // verify no errors + cudaError_t result = cudaDeviceSynchronize(); + + EXPECT_EQ(result, cudaSuccess) << "CUDA ERROR: " << cudaGetErrorString(result); + if (result != cudaSuccess) { + return false; + } + + tensor_D_computed.sync_host(); + + // + // Reference implementation + // + + cutlass::reference::host::Gemm + reference_gemm; + + reference_gemm( + {Shape::kM, Shape::kN, ThreadblockShape::kK}, + ElementC(1), + tensor_A.host_ref(), + tensor_B.host_ref(), + ElementC(0), + tensor_D_reference.host_ref() + ); + + // + // Verify equivalence + // + + // compare + bool passed = cutlass::reference::host::TensorEquals( + tensor_D_computed.host_view(), + tensor_D_reference.host_view() + ); + + EXPECT_TRUE(passed); + + if (!passed) { + + cutlass::TensorView tensor_A_physical( + tensor_A.host_data(), + tensor_A.stride()[0], + tensor_A.extent()); + + cutlass::TensorView tensor_B_physical( + tensor_B.host_data(), + tensor_B.stride()[0], + tensor_B.extent()); + + std::cout <<"cutlass::sizeof_bits::value = "<::value<<"\n"; + std::cout + << "A:\n" << tensor_A.host_view() << "\n\n" + << "A(physical - stride: " << tensor_A.stride()[0] << ", extent: " << tensor_A.extent() << "):\n" << tensor_A_physical << "\n\n"; + + std::cout <<"cutlass::sizeof_bits::value = "<::value<<"\n"; + std::cout + << "B:\n" << tensor_B.host_view() << "\n\n" + << "B(physical - stride: " << tensor_B.stride()[0] << ", extent: " << tensor_B.extent() << "):\n" << tensor_B_physical << "\n\n"; + + std::cout + << "C:\n" << tensor_C.host_view() << "\n\n" + << "Reference:\n" << tensor_D_reference.host_view() << "\n\n" + << "Computed:\n" << tensor_D_computed.host_view() << std::endl; + } + + return passed; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product +template < + /// Warp-level matrix multiply-accumulate + typename Mma_, + /// Size of threadblock-scoped shape used to store SMEM + typename ThreadblockShape_ +> +struct TransformedTestbedComplex { + + /// Thread-level matrix multiply-accumulate operator + using Mma = Mma_; + using ThreadblockShape = ThreadblockShape_; + + using Shape = typename Mma::Shape; + using ElementA = typename Mma::ElementA; + using LayoutA = typename Mma::LayoutA; + using ElementB = typename Mma::ElementB; + using LayoutB = typename Mma::LayoutB; + using ElementC = typename Mma::ElementC; + using LayoutC = typename Mma::LayoutC; + + // + // Data members + // + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_C; + cutlass::HostTensor tensor_D_computed; + cutlass::HostTensor tensor_D_reference; + + // + // Methods + // + + /// 
Allocates workspace in device memory + TransformedTestbedComplex() { + + tensor_A.reset(cutlass::make_Coord(ThreadblockShape::kM, ThreadblockShape::kK)); + tensor_B.reset(cutlass::make_Coord(ThreadblockShape::kK, ThreadblockShape::kN)); + tensor_C.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); + tensor_D_computed.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); + tensor_D_reference.reset(cutlass::make_Coord(Shape::kM, Shape::kN), false); + } + + /// Returns true if the CUDA device is sufficient to execute the kernel. + bool sufficient() const { + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.major == 9) { + // NVIDIA Hopper drops support for several data types + if ( + cutlass::sizeof_bits::value < 8 || + cutlass::sizeof_bits::value < 8 || + cutlass::sizeof_bits::value < 8) { + + return false; + } + } + + return true; + } + + /// Runs the test + bool run( + cutlass::Distribution::Kind init_A = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B = cutlass::Distribution::Uniform) { + + if (!sufficient()) { + return true; + } + + // + // initialize device memory + // + + if (init_A == cutlass::Distribution::Uniform) { + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomUniform(tensor_A.host_view(), + seed, 8, -8, 0); + } else if (init_A == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(tensor_A.host_data(), + tensor_A.capacity()); + } else if (init_A == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(tensor_A.host_view()); + } else { + return false; + } + + if (init_B == cutlass::Distribution::Uniform) { + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomUniform(tensor_B.host_view(), + seed + 16, 8, -8, 0); + } else if (init_B == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(tensor_B.host_data(), + tensor_B.capacity()); + } else if (init_B == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(tensor_B.host_view()); + } else { + return false; + } + + cutlass::reference::host::TensorFill( + tensor_C.host_view(), + ElementC(0) + ); + + cutlass::reference::host::TensorFill( + tensor_D_computed.host_view(), + ElementC(0) + ); + + cutlass::reference::host::TensorFill( + tensor_D_reference.host_view(), + ElementC(0) + ); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_D_computed.sync_device(); + + // launch kernel + kernel_transform<<< dim3(1, 1), dim3(32, 1, 1) >>>( + tensor_D_computed.device_data(), + tensor_A.device_data(), + tensor_B.device_data(), + tensor_C.device_data()); + + // verify no errors + cudaError_t result = cudaDeviceSynchronize(); + + EXPECT_EQ(result, cudaSuccess) << "CUDA ERROR: " << cudaGetErrorString(result); + if (result != cudaSuccess) { + return false; + } + + tensor_D_computed.sync_host(); + + // + // Reference implementation + // + + cutlass::reference::host::GemmComplex( + {Shape::kM, Shape::kN, ThreadblockShape::kK}, + ElementC(1), + tensor_A.host_ref(), + Mma::kTransformA, + tensor_B.host_ref(), + Mma::kTransformB, + ElementC(0), + tensor_C.host_ref(), + tensor_D_reference.host_ref() + ); + + // + // Verify 
equivalence + // + + // compare + bool passed = cutlass::reference::host::TensorEquals( + tensor_D_computed.host_view(), + tensor_D_reference.host_view() + ); + + EXPECT_TRUE(passed); + + if (!passed) { + + cutlass::TensorView tensor_A_physical( + tensor_A.host_data(), + tensor_A.stride()[0], + tensor_A.extent()); + + cutlass::TensorView tensor_B_physical( + tensor_B.host_data(), + tensor_B.stride()[0], + tensor_B.extent()); + + std::cout <<"cutlass::sizeof_bits::value = "<::value<<"\n"; + std::cout + << "A:\n" << tensor_A.host_view() << "\n\n" + << "A(physical - stride: " << tensor_A.stride()[0] << ", extent: " << tensor_A.extent() << "):\n" << tensor_A_physical << "\n\n"; + + std::cout <<"cutlass::sizeof_bits::value = "<::value<<"\n"; + std::cout + << "B:\n" << tensor_B.host_view() << "\n\n" + << "B(physical - stride: " << tensor_B.stride()[0] << ", extent: " << tensor_B.extent() <<"):\n" << tensor_B_physical << "\n\n"; + + std::cout + << "C:\n" << tensor_C.host_view() << "\n\n" + << "Reference:\n" << tensor_D_reference.host_view() << "\n\n" + << "Computed:\n" << tensor_D_computed.host_view() << std::endl; + } + + return passed; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Test kernel +template +__global__ void sparse_kernel( + typename Mma::ElementC *output_C, + typename Mma::ElementA const *input_A, + typename Mma::ElementB const *input_B, + typename Mma::ElementC const *input_C, + typename Mma::ElementE const *input_E, + int iterations = 1) { + + // Use AlignedBuffer to store trivially copyable objects in unions and __shared__ buffers. + __shared__ cutlass::AlignedBuffer + smem_buffer_A; + + __shared__ cutlass::AlignedBuffer< + typename Mma::ElementB, ThreadblockShape::kN * ThreadblockShape::kK> smem_buffer_B; + + __shared__ cutlass::AlignedBuffer< + typename Mma::ElementE, Mma::Shape::kM * Mma::Shape::kK / + Mma::kSparse / Mma::kElementsPerElementE> + smem_buffer_E; + + __syncthreads(); + + if (threadIdx.x == 0) { + typename Mma::ElementA *smem_ptr_A = smem_buffer_A.data(); + #pragma unroll 1 + for (size_t i = 0; i < smem_buffer_A.size(); ++i) { + cutlass::ReferenceFactory::get(smem_ptr_A, i) = + cutlass::ReferenceFactory::type>::get(input_A, i); + } + + typename Mma::ElementB *smem_ptr_B = smem_buffer_B.data(); + #pragma unroll 1 + for (size_t i = 0; i < smem_buffer_B.size(); ++i) { + cutlass::ReferenceFactory::get(smem_ptr_B, i) = + cutlass::ReferenceFactory::type>::get(input_B, i); + } + + typename Mma::ElementE *smem_ptr_E = smem_buffer_E.data(); + #pragma unroll 1 + for (size_t i = 0; i < smem_buffer_E.size(); ++i) { + cutlass::ReferenceFactory::get(smem_ptr_E, i) = + cutlass::ReferenceFactory::type>::get(input_E, i); + } + } + + __syncthreads(); + + // + // Construct warp-level matrix product + // + + using FragmentA = typename Mma::FragmentA; + using FragmentB = typename Mma::FragmentB; + using FragmentC = typename Mma::FragmentC; + using FragmentE = typename Mma::FragmentE; + + typename Mma::LayoutA layout_A = Mma::LayoutA::packed( + {ThreadblockShape::kM, ThreadblockShape::kK / Mma::kSparse}); + typename Mma::LayoutB layout_B = + Mma::LayoutB::packed({ThreadblockShape::kK, ThreadblockShape::kN}); + typename Mma::LayoutC layout_C = Mma::LayoutC::packed({Mma::Shape::kM, Mma::Shape::kN}); + typename Mma::LayoutE layout_E = + Mma::LayoutE::packed({Mma::Shape::kM * Mma::kInterleaved, + Mma::Shape::kK / Mma::kSparse / + Mma::kElementsPerElementE / Mma::kInterleaved}); + + typename Mma::IteratorA 
iter_A({smem_buffer_A.data(), layout_A}, cutlass::arch::LaneId()); + + typename Mma::IteratorB iter_B({smem_buffer_B.data(), layout_B}, cutlass::arch::LaneId()); + + typename Mma::IteratorE iter_E({smem_buffer_E.data(), layout_E}, cutlass::arch::LaneId()); + + FragmentA frag_A; + FragmentB frag_B; + + FragmentC accum; + + FragmentE frag_E; + + Mma mma; + + accum.clear(); + + CUTLASS_PRAGMA_NO_UNROLL + for (int iter = 0; iter < iterations; ++iter) { // place in loop that is not unrolled + + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < ThreadblockShape::kK; + k += Mma::Policy::MmaShape::kK) { + iter_A.load(frag_A); + iter_B.load(frag_B); + iter_E.load(frag_E); + + ++iter_A; + ++iter_B; + ++iter_E; + + mma(accum, frag_A, frag_B, accum, frag_E); + } + } + + typename Mma::IteratorC iter_C({output_C, layout_C}, cutlass::arch::LaneId()); + + iter_C.store(accum); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product +template < + /// Warp-level matrix multiply-accumulate + typename Mma_, + /// Size of threadblock-scoped shape used to store SMEM + typename ThreadblockShape_, + /// The innter product operation performed by GEMM + typename Operator_ = cutlass::arch::OpMultiplyAdd +> +struct SparseTestbed { + + /// Thread-level matrix multiply-accumulate operator + using Mma = Mma_; + using ThreadblockShape = ThreadblockShape_; + using Operator = Operator_; + + using Shape = typename Mma::Shape; + using ElementA = typename Mma::ElementA; + using LayoutA = typename Mma::LayoutA; + using ElementB = typename Mma::ElementB; + using LayoutB = typename Mma::LayoutB; + using ElementC = typename Mma::ElementC; + using LayoutC = typename Mma::LayoutC; + + static int const Sparse = Mma::kSparse; + static int const MetaSizeInBits = Mma::kMetaSizeInBits; + static int const MaxID2 = Mma::kMaxID2; + static int const Interleaved = Mma::kInterleaved; + + using ElementE = typename Mma::ElementE; + + static int const ElementsPerElementE = Mma::kElementsPerElementE; + + using LayoutE = cutlass::layout::RowMajor; + using ReorderedLayoutE = + cutlass::layout::ColumnMajorInterleaved; + + // + // Data members + // + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_A_uncompressed; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_C; + cutlass::HostTensor tensor_D_computed; + cutlass::HostTensor tensor_D_reference; + cutlass::HostTensor tensor_E; + cutlass::HostTensor tensor_E_reordered; + + // + // Methods + // + + /// Allocates workspace in device memory + SparseTestbed() { + + tensor_A.reset(cutlass::make_Coord(ThreadblockShape::kM, + ThreadblockShape::kK / Sparse)); + tensor_A_uncompressed.reset( + cutlass::make_Coord(ThreadblockShape::kM, ThreadblockShape::kK)); + tensor_B.reset(cutlass::make_Coord(ThreadblockShape::kK, ThreadblockShape::kN)); + tensor_C.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); + tensor_D_computed.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); + tensor_D_reference.reset(cutlass::make_Coord(Shape::kM, Shape::kN), false); + tensor_E.reset(cutlass::make_Coord( + Shape::kM, Shape::kK / Sparse / ElementsPerElementE)); + tensor_E_reordered.reset(cutlass::make_Coord( + Shape::kM, Shape::kK / Sparse / ElementsPerElementE)); + } + + /// Returns true if the CUDA device is sufficient to execute the kernel. 
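+  /// The same waiver applies to the sparse testbed: element types narrower than 8 bits are not
+  /// exercised on compute capability 9.x devices.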
+ bool sufficient() const { + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.major == 9) { + // NVIDIA Hopper drops support for several data types + if ( + cutlass::sizeof_bits<ElementA>::value < 8 || + cutlass::sizeof_bits<ElementB>::value < 8 || + cutlass::sizeof_bits<ElementC>::value < 8) { + + return false; + } + } + + return true; + } + + /// Runs the test + bool run( + cutlass::Distribution::Kind init_A = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_E = cutlass::Distribution::Uniform) { + + if (!sufficient()) { + return true; + } + + // + // initialize device memory + // + + if (init_A == cutlass::Distribution::Uniform) { + int scope_max = 8; + int scope_min = -8; + + if (cutlass::sizeof_bits<ElementA>::value == 4) { + scope_max = 2; + scope_min = -2; + } else if (cutlass::sizeof_bits<ElementA>::value == 1) { + scope_max = 2; + scope_min = 0; + } + + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomUniform( + tensor_A.host_view(), seed, scope_max, scope_min, 0); + } else if (init_A == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(tensor_A.host_data(), + tensor_A.capacity()); + } else if (init_A == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(tensor_A.host_view()); + } else { + return false; + } + + if (init_B == cutlass::Distribution::Uniform) { + int scope_max = 8; + int scope_min = -8; + + if (cutlass::sizeof_bits<ElementB>::value == 4) { + scope_max = 2; + scope_min = -2; + } else if (cutlass::sizeof_bits<ElementB>::value == 1) { + scope_max = 2; + scope_min = 0; + } + + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomUniform( + tensor_B.host_view(), seed + 16, scope_max, scope_min, 0); + } else if (init_B == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(tensor_B.host_data(), + tensor_B.capacity()); + } else if (init_B == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(tensor_B.host_view()); + } else { + return false; + } + + cutlass::reference::host::TensorFill( + tensor_C.host_view(), + ElementC(0) + ); + + cutlass::reference::host::TensorFill( + tensor_D_computed.host_view(), + ElementC(0) + ); + + cutlass::reference::host::TensorFill( + tensor_D_reference.host_view(), + ElementC(0) + ); + + if (init_E == cutlass::Distribution::Uniform) { + uint64_t seed = 7; + cutlass::reference::host::TensorFillRandomSparseMeta( + tensor_E.host_view(), seed, MetaSizeInBits); + } else if (init_E == cutlass::Distribution::Identity) { + uint32_t content = (MaxID2 == 1) ? 
0x44444444 : 0x4444; + cutlass::reference::host::TensorFill(tensor_E.host_view(), + (ElementE)(content)); + } else { + return false; + } + + cutlass::reorder_meta( + tensor_E_reordered.host_ref(), tensor_E.host_ref(), + {Shape::kM, Shape::kN, Shape::kK / Sparse / ElementsPerElementE}); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_D_computed.sync_device(); + tensor_E_reordered.sync_device(); + + // launch kernel + sparse_kernel<<< dim3(1, 1), dim3(32, 1, 1) >>>( + tensor_D_computed.device_data(), + tensor_A.device_data(), + tensor_B.device_data(), + tensor_C.device_data(), + tensor_E_reordered.device_data()); + + // verify no errors + cudaError_t result = cudaDeviceSynchronize(); + + EXPECT_EQ(result, cudaSuccess) << "CUDA ERROR: " << cudaGetErrorString(result); + if (result != cudaSuccess) { + return false; + } + + tensor_D_computed.sync_host(); + + // + // Reference implementation + // + cutlass::uncompress(tensor_A_uncompressed.host_ref(), tensor_A.host_ref(), + tensor_E.host_ref(), Shape::kM, Shape::kK); + + cutlass::reference::host::Gemm + reference_gemm; + + reference_gemm( + {Shape::kM, Shape::kN, ThreadblockShape::kK}, + ElementC(1), + tensor_A_uncompressed.host_ref(), + tensor_B.host_ref(), + ElementC(0), + tensor_D_reference.host_ref() + ); + + // + // Verify equivalence + // + + // compare + bool passed = cutlass::reference::host::TensorEquals( + tensor_D_computed.host_view(), + tensor_D_reference.host_view() + ); + + EXPECT_TRUE(passed); + + if (!passed) { + std::cout <<"cutlass::sizeof_bits::value = "<::value<<"\n"; + std::cout << "A:\n" << tensor_A.host_view() << "\n\n"; + + std::cout <<"cutlass::sizeof_bits::value = "<::value<<"\n"; + std::cout << "B:\n" << tensor_B.host_view() << "\n\n"; + + std::cout <<"cutlass::sizeof_bits::value = "<::value<<"\n"; + std::cout << "E:\n" << tensor_E.host_view() << "\n\n"; + + std::cout + << "C:\n" << tensor_C.host_view() << "\n\n" + << "Reference:\n" << tensor_D_reference.host_view() << "\n\n" + << "Computed:\n" << tensor_D_computed.host_view() << "\n"; + } + + return passed; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace gemm +} // namespace test diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/nvrtc/cutlass/nvrtc/environment.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/nvrtc/cutlass/nvrtc/environment.h new file mode 100644 index 0000000000000000000000000000000000000000..3311e915db892466a9a4c52c82d100c2e1319966 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/nvrtc/cutlass/nvrtc/environment.h @@ -0,0 +1,43 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include +#include "cutlass/cutlass.h" + +namespace cutlass { +namespace nvrtc { + +extern char const *kCutlassHeaders[]; +extern char const *kCutlassHeaderNames[]; +extern size_t const kCutlassHeaderCount; +} // namespace nvrtc +} // namespace cutlass diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/nvrtc/kernel/thread/contraction.hpp b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/nvrtc/kernel/thread/contraction.hpp new file mode 100644 index 0000000000000000000000000000000000000000..55df44379c847034ed38cfab23477331ee4a537c --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/nvrtc/kernel/thread/contraction.hpp @@ -0,0 +1,127 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#include "cute/tensor.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" + + +namespace nvrtc { +namespace thread { + +template< + typename ElementA, typename ElementB, typename ElementC, + typename TileShape, typename ClusterShape, + bool kTransA, bool kTransB, + int RANK_M, int RANK_N, int RANK_K, int RANK_L +> +struct ContractionKernel { + +using ElementScalar = float; +using ElementAccum = float; +using EpilogueThread = cutlass::epilogue::thread::LinearCombination; + +static constexpr cute::GMMA::Major majorA = ! kTransA ? cute::GMMA::Major::MN : cute::GMMA::Major::K; +static constexpr cute::GMMA::Major majorB = ! kTransB ? cute::GMMA::Major::K : cute::GMMA::Major::MN; + +/// Kernel config +typedef int64_t stride_type; +typedef int32_t extent_type; + +static constexpr const stride_type* stride_null = nullptr; +static constexpr const extent_type* extent_null = nullptr; + +template +static constexpr +auto +make_stride_tuple(Indexable const& t, int n, int64_t init_default = 0) { + static_assert(Rank > 1); + if constexpr (IsMajor) { + return cute::transform(cute::make_seq{}, [&](auto i) { + if constexpr (i == 0) { + return cute::Int<1>{}; + } + else { + return i < n ? 
t[i] : init_default; + } + }); + } + else { + return cute::make_int_tuple(t, n, init_default); + } +} + +using StrideA = decltype(cute::make_stride( + make_stride_tuple(stride_null, 0, 0), + make_stride_tuple(stride_null, 0, 0), + cute::make_int_tuple(stride_null, 0, 0))); + +using StrideB = decltype(cute::make_stride( + make_stride_tuple(stride_null, 0, 0), + make_stride_tuple(stride_null, 0, 0), + cute::make_int_tuple(stride_null, 0, 0))); + +using StrideC = decltype(cute::make_stride( + cute::make_int_tuple(stride_null, 0, 0), + cute::make_int_tuple(stride_null, 0, 0), + cute::make_int_tuple(stride_null, 0, 0))); + +using ProblemShape = decltype(cute::make_shape( + cute::make_int_tuple(extent_null, 0, 0), + cute::make_int_tuple(extent_null, 0, 0), + cute::make_int_tuple(extent_null, 0, 0), + cute::make_int_tuple(extent_null, 0, 0))); + +using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + ElementA, StrideA, 16 / sizeof(ElementA), + ElementB, StrideB, 16 / sizeof(ElementB), + ElementAccum, + TileShape, ClusterShape, cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelTmaWarpSpecialized +>::CollectiveOp; + +using EpilogueOutputOp = cutlass::epilogue::collective::DefaultEpilogue; +using CollectiveEpilogue = cutlass::epilogue::collective::detail::Sm90TmaWarpSpecializedAdapter; +using Kernel = cutlass::gemm::kernel::GemmUniversal< + ProblemShape, + CollectiveOp, + CollectiveEpilogue>; + +}; + +} // namespace nvrtc +} // namespace thread diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/nvrtc/kernel/thread/testbed_kernel.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/nvrtc/kernel/thread/testbed_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..576f55cd868cd64c8c09c055d8b9a956e40c87ae --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/nvrtc/kernel/thread/testbed_kernel.h @@ -0,0 +1,76 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Unit tests for thread-level GEMM +*/ + +#pragma once + +#include "cutlass/array.h" + +namespace test { +namespace nvrtc { +namespace kernel { +namespace thread { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Thread-level matrix multiply-accumulate +template +__global__ void testbed_kernel( + typename Mma::ElementC *D, + typename Mma::ElementA const *A, + typename Mma::ElementB const *B, + typename Mma::ElementC const *C) { + + auto ptr_D = reinterpret_cast *>(D); + auto ptr_A = reinterpret_cast const *>(A); + auto ptr_B = reinterpret_cast const *>(B); + auto ptr_C = reinterpret_cast const *>(C); + + Mma mma; + + auto a = *ptr_A; + auto b = *ptr_B; + auto c = *ptr_C; + + cutlass::Array d; + + mma(d, a, b, c); + + *ptr_D = d; +} + +} +} +} +} + diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/nvrtc/stdlib/assert.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/nvrtc/stdlib/assert.h new file mode 100644 index 0000000000000000000000000000000000000000..c7e6e94691c82b2f343959421c884c8b0b06f9b4 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/nvrtc/stdlib/assert.h @@ -0,0 +1,30 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/nvrtc/stdlib/stdint.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/nvrtc/stdlib/stdint.h new file mode 100644 index 0000000000000000000000000000000000000000..5ba5432fd568af71e15b20b8cdab1571f303bcdf --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/nvrtc/stdlib/stdint.h @@ -0,0 +1,129 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +typedef char int8_t; +typedef unsigned char uint8_t; +typedef short int16_t; +typedef unsigned short uint16_t; +typedef int int32_t; +typedef unsigned int uint32_t; +typedef long long int int64_t; +typedef unsigned long long int uint64_t; + +#if defined __x86_64__ && !defined __ILP32__ +# define __WORDSIZE 64 +#else +# define __WORDSIZE 32 +#endif + + +/* Small types. */ + +/* Signed. 
*/ +typedef signed char int_least8_t; +typedef short int int_least16_t; +typedef int int_least32_t; +#if __WORDSIZE == 64 +typedef long int int_least64_t; +#else +__extension__ +typedef long long int int_least64_t; +#endif + +/* Unsigned. */ +typedef unsigned char uint_least8_t; +typedef unsigned short int uint_least16_t; +typedef unsigned int uint_least32_t; +#if __WORDSIZE == 64 +typedef unsigned long int uint_least64_t; +#else +__extension__ +typedef unsigned long long int uint_least64_t; +#endif + + +/* Fast types. */ + +/* Signed. */ +typedef signed char int_fast8_t; +#if __WORDSIZE == 64 +typedef long int int_fast16_t; +typedef long int int_fast32_t; +typedef long int int_fast64_t; +#else +typedef int int_fast16_t; +typedef int int_fast32_t; +__extension__ +typedef long long int int_fast64_t; +#endif + +/* Unsigned. */ +typedef unsigned char uint_fast8_t; +#if __WORDSIZE == 64 +typedef unsigned long int uint_fast16_t; +typedef unsigned long int uint_fast32_t; +typedef unsigned long int uint_fast64_t; +#else +typedef unsigned int uint_fast16_t; +typedef unsigned int uint_fast32_t; +__extension__ +typedef unsigned long long int uint_fast64_t; +#endif + +/* Types for `void *' pointers. */ +#if __WORDSIZE == 64 +# ifndef __intptr_t_defined +typedef long int intptr_t; +# define __intptr_t_defined +# endif +typedef unsigned long int uintptr_t; +#else +# ifndef __intptr_t_defined +typedef int intptr_t; +# define __intptr_t_defined +# endif +typedef unsigned int uintptr_t; +#endif + + +/* Largest integral types. */ +#if __WORDSIZE == 64 +typedef long int intmax_t; +typedef unsigned long int uintmax_t; +#else +__extension__ +typedef long long int intmax_t; +__extension__ +typedef unsigned long long int uintmax_t; +#endif + diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/nvrtc/thread/testbed.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/nvrtc/thread/testbed.h new file mode 100644 index 0000000000000000000000000000000000000000..8fd6863e8fa003d3fbc4e0b498e3b9b454ade190 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/nvrtc/thread/testbed.h @@ -0,0 +1,398 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Unit tests for thread-level GEMM +*/ + +#pragma once + +#include +#include +#include + +#include "cutlass/gemm/thread/mma.h" +#include "../kernel/thread/testbed_kernel.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/trace.h" + +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/gemm.h" + +#include +#include +#include "../cutlass/nvrtc/environment.h" +#include + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace test { +namespace nvrtc { +namespace thread { + +#define NVRTC_RETURN_IF_ERROR(api) \ + do { \ + nvrtcResult _result = api; \ + if (_result != NVRTC_SUCCESS) { \ + CUTLASS_TRACE_HOST("Nvrtc error: " << _result); \ + return false; \ + } \ + } while(0) + +inline const char * cuda_source_fmt = R"""( + +#include "kernel/thread/contraction.hpp" + +using Operator = %s; + +extern "C" __global__ void global_entry(__grid_constant__ Operator::Params const params) { + extern __shared__ char smem[]; + + Operator op; + op(params, smem); +} + +)"""; + +struct TestbedKernel { + static bool compile(std::string const &kernel, std::vector const &opts) { + int sz = std::snprintf(nullptr, 0, cuda_source_fmt, kernel.c_str()); + std::vector cuda_source(sz + 1); + std::snprintf(&cuda_source[0], cuda_source.size(), cuda_source_fmt, kernel.c_str()); + + nvrtcProgram program; + NVRTC_RETURN_IF_ERROR( + nvrtcCreateProgram( + &program, + cuda_source.data(), + nullptr, + static_cast(cutlass::nvrtc::kCutlassHeaderCount), + cutlass::nvrtc::kCutlassHeaders, + cutlass::nvrtc::kCutlassHeaderNames) + ); + + nvrtcResult compile_result = + nvrtcCompileProgram( + program, + static_cast(opts.size()), + opts.data()); + + size_t log_size; + NVRTC_RETURN_IF_ERROR( + nvrtcGetProgramLogSize(program, &log_size) + ); + + if (log_size > 1) { + auto log = std::make_unique(log_size); + + NVRTC_RETURN_IF_ERROR( + nvrtcGetProgramLog(program, log.get()) + ); + + std::cout << log.get() << std::endl; + } + + NVRTC_RETURN_IF_ERROR(compile_result); + + NVRTC_RETURN_IF_ERROR( + nvrtcDestroyProgram(&program) + ); + + return true; + } +}; + +/// Structure to compute the matrix product +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape, + /// Data type of A elements + typename ElementA, + /// Layout of A matrix (concept: MatrixLayout) + typename LayoutA, + /// Data type of B elements + typename ElementB, + /// Layout of B matrix (concept: MatrixLayout) + typename LayoutB, + /// Element type of C matrix + typename ElementC, + /// Layout of C matrix (concept: MatrixLayout) + typename LayoutC +> +struct Testbed { + + /// Thread-level matrix multiply-accumulate operator + using Mma = 
cutlass::gemm::thread::Mma< + Shape, + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC + >; + + // + // Data members + // + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_C; + cutlass::HostTensor tensor_D_computed; + cutlass::HostTensor tensor_D_reference; + + // + // Methods + // + + /// Allocates workspace in device memory + Testbed() { + + tensor_A.reset(cutlass::make_Coord(Shape::kM, Shape::kK)); + tensor_B.reset(cutlass::make_Coord(Shape::kK, Shape::kN)); + tensor_C.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); + tensor_D_computed.reset(cutlass::make_Coord(Shape::kM, Shape::kN)); + tensor_D_reference.reset(cutlass::make_Coord(Shape::kM, Shape::kN), false); + } + + static inline bool check_nvrtc_error(nvrtcResult error) { + if (error != NVRTC_SUCCESS) { + std::cerr << "failed to compile "; + return false; + } + return true; + } + + /// Runs the test + bool run(std::string const &gemm_traits) { + + // + // initialize device memory + // + + cutlass::reference::host::BlockFillSequential( + tensor_A.host_data(), + tensor_A.capacity() + ); + + cutlass::reference::host::BlockFillSequential( + tensor_B.host_data(), + tensor_B.capacity(), + ElementB(1), + ElementB(2) + ); + + cutlass::reference::host::TensorFill( + tensor_C.host_view(), + ElementC(0) + ); + + cutlass::reference::host::TensorFill( + tensor_D_computed.host_view(), + ElementC(0) + ); + + cutlass::reference::host::TensorFill( + tensor_D_reference.host_view(), + ElementC(0) + ); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_D_computed.sync_device(); + +#if 0 + // launch kernel + cutlass::gemm::kernel::testbed_kernel<<< dim3(1, 1), dim3(1, 1, 1) >>>( + tensor_D_computed.device_data(), + tensor_A.device_data(), + tensor_B.device_data(), + tensor_C.device_data()); + +#else + // Instantiate gemm_kernel + nvrtcResult result_nvrtc; + nvrtcProgram program; + static char const *src = + "#include \"cutlass/gemm/thread/mma.h\"\n" + "#include \"cutlass/gemm/gemm.h\"\n" + "#include \"cutlass/layout/matrix.h\"\n" + "#include \"unit/nvrtc/kernel/thread/testbed_kernel.h\"\n" + ; + + std::string type_name; +#if 0 + // TODO Ideally we'd use nvrtcGetTypeName to determine the type, but it cannot resolve enum symbol names + // As altername solution we might want to implement to_string() to get the traits string. + nvrtcGetTypeName(&type_name); +#else + type_name = gemm_traits; +#endif + + result_nvrtc = nvrtcCreateProgram(&program, + src, + NULL, + (int)cutlass::nvrtc::kCutlassHeaderCount, + cutlass::nvrtc::kCutlassHeaders, + cutlass::nvrtc::kCutlassHeaderNames); + check_nvrtc_error(result_nvrtc); + + std::string gemm_kernel_instantiation = + "test::nvrtc::kernel::thread::testbed_kernel< " + type_name + " >"; + nvrtcAddNameExpression(program, gemm_kernel_instantiation.c_str()); + + const char *opts[] = {"--gpu-architecture=compute_75", + "--std=c++17", + "--include-path=/usr/local/cuda-10.1/include"}; + + result_nvrtc = nvrtcCompileProgram(program, 3, opts); + if (result_nvrtc != NVRTC_SUCCESS) { + size_t logSize; + nvrtcGetProgramLogSize(program, &logSize); + std::vector log(logSize); + nvrtcGetProgramLog(program, log.data()); + std::cout << "Compile log:" << std::endl << log.data() << std::endl; + } + if (!check_nvrtc_error(result_nvrtc)) { + assert(0); + } + + // The lowered name is the name of the template instantiation in the generated PTX code. 
+ char const *gemm_kernel_lowered_name; + nvrtcGetLoweredName(program, gemm_kernel_instantiation.c_str(), &gemm_kernel_lowered_name); + if (!check_nvrtc_error(result_nvrtc)) { + assert(0); + } + + // Query the size of the genereated PTX so that we can allocate storage and retrieve it afterwards + size_t ptx_size; + result_nvrtc = nvrtcGetPTXSize(program, &ptx_size); + if (!check_nvrtc_error(result_nvrtc)) { + assert(0); + } + + std::vector ptx(ptx_size); + result_nvrtc = nvrtcGetPTX(program, ptx.data()); + if (!check_nvrtc_error(result_nvrtc)) { + assert(0); + } + + // we do not need the nvrtc program anymore + //nvrtcDestroyProgram(&program); + + CUmodule module; + CUresult result_cuda; + result_cuda = cuModuleLoadDataEx(&module, ptx.data(), 0, 0, 0); + if (result_cuda != CUDA_SUCCESS) { + assert(0); + } + + CUfunction kernel; + result_cuda = cuModuleGetFunction(&kernel, module, gemm_kernel_lowered_name); + if (result_cuda != CUDA_SUCCESS) { + assert(0); + } + + void* d_a = (void*)tensor_A.device_data(); + void* d_b = (void*)tensor_B.device_data(); + void* d_c = (void*)tensor_C.device_data(); + void* d_d = (void*)tensor_D_computed.device_data(); + void* args[] = { &d_d, &d_a, &d_b, &d_c }; + + // CUfunction f, unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ, unsigned int blockDimX, unsigned int blockDimY, unsigned int blockDimZ, unsigned int sharedMemBytes, CUstream hStream, void** kernelParams, void** extra + result_cuda = cuLaunchKernel(kernel, 1, 1, 1, 1, 1, 1, 0, 0 /*cudaStreamDefault*/, args, 0); + if (result_cuda != CUDA_SUCCESS) { + assert(0); + } else { +} +#endif + + // verify no errors + cudaError_t result = cudaDeviceSynchronize(); + + if (result != cudaSuccess) { + std::cout << "CUDA ERROR: " << cudaGetErrorString(result); + return false; + } + + tensor_D_computed.sync_host(); + + // + // Reference implementation + // + + //tensor_D_reference.fill(tensor_C.host_view()); + + cutlass::reference::host::Gemm reference_gemm; + + reference_gemm( + {Shape::kM, Shape::kN, Shape::kK}, + ElementC(1), + tensor_A.host_ref(), + tensor_B.host_ref(), + ElementC(0), + tensor_D_reference.host_ref() + ); + + // + // Verify equivalence + // + + // compare + bool passed = cutlass::reference::host::TensorEquals( + tensor_D_computed.host_view(), + tensor_D_reference.host_view() + ); + + if(!passed) std::cout + << "A:\n" << tensor_A.host_view() << "\n\n" + << "B:\n" << tensor_B.host_view() << "\n\n" + << "C:\n" << tensor_C.host_view() << "\n\n" + << "Reference:\n" << tensor_D_reference.host_view() << "\n\n" + << "Computed:\n" << tensor_D_computed.host_view() << std::endl; + + std::cout << "passed " << passed << std::endl; + + return passed; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace thread +} // namespace nvrtc +} // namespace test diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/pipeline/testbed.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/pipeline/testbed.h new file mode 100644 index 0000000000000000000000000000000000000000..6cc2946a2c51cfb8c1971345c81c1910bd667208 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/pipeline/testbed.h @@ -0,0 +1,145 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Common Testbed file shared by Pipeline unit tests +*/ + +#include +#include +#include +#include + +#include "cutlass/util/command_line.h" +#include "../common/cutlass_unit_test.h" + +#if CUDA_12_0_SM90_FEATURES_SUPPORTED + #define CUTLASS_UNIT_TEST_PIPELINE true +#else + #define CUTLASS_UNIT_TEST_PIPELINE false +#endif + +// Command line test options +struct Options { + // + // Data Members + // + bool help; + bool verification_enabled; + int SM_count; + int clock_MHz; + + // + // Methods + // + Options(): + help(false), + verification_enabled(true), + SM_count(116), + clock_MHz(1477) + { } + + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + } + + cmd.get_cmd_line_argument("verification-enabled", verification_enabled, true); + cmd.get_cmd_line_argument("sm-count", SM_count, 116); + cmd.get_cmd_line_argument("clock", clock_MHz, 1477); + } + + /// Prints the usage statement. 
+ std::ostream & print_usage(std::ostream &out) const { + + out << "Options:\n\n" + << " --help If specified, displays this usage statement.\n\n" + << " --verification-enabled= Enable/Disable verification\n" + << " --sm-count= Number of SMs on the chip\n" + << " --clock= Locked clock value in Mhz\n"; + + return out; + } +}; + +// +// Testbed +// + +template +struct Testbed { +private: + // Commandline options + Options options; + + void run_test(uint32_t const kNumIters) { + + // Run CuTe Gemm + Pipeline pipeline; + + cudaError_t result = pipeline.run(kNumIters); + + CUTE_CHECK_LAST(); + } + + +public: + Testbed(Options const &options_) : options(options_) { + int device_id = 0; + cudaDeviceProp device_prop; + CUTE_CHECK_ERROR(cudaSetDevice(device_id)); + CUTE_CHECK_ERROR(cudaGetDeviceProperties(&device_prop, device_id)); + + if (device_prop.major < 1) { + fprintf(stderr, "Device does not support CUDA.\n"); + exit(1); + } + } + + /// Run verification Gemm problem sizes + bool verification() { + + std::array kNumIters; + + for (size_t i = 0; i < kNumIters.size(); ++i) { + kNumIters[i] = static_cast( (rand() % 1000) + 1 ); + } + + for (int n : kNumIters) { + std::cout << "Stages = " << Pipeline::Stages << " kNumIters = " << n << "\n"; + run_test(n); + } + + return true; + } +}; diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/pipeline/testbed_cluster_launch_control.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/pipeline/testbed_cluster_launch_control.h new file mode 100644 index 0000000000000000000000000000000000000000..50a68a1437956c95aa4e7912e93adc8b1481c9cc --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/pipeline/testbed_cluster_launch_control.h @@ -0,0 +1,154 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Testbed file used by cluster launch control pipeline unit test +*/ + +// + +// + +#if CUDA_12_0_SM90_FEATURES_SUPPORTED + #define CUTLASS_UNIT_TEST_PIPELINE true +#else + #define CUTLASS_UNIT_TEST_PIPELINE false +#endif + +#include +#include +#include +#include + +#include "cutlass/util/command_line.h" + +// Command line test options +struct OptionsClusterLaunch { + // + // Data Members + // + bool help = false; + bool verification_enabled = true; + int SM_count = 116; + int clock_MHz = 1477; + dim3 grid_dim = {0,0,0}; + + // + // Methods + // + + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + } + + cmd.get_cmd_line_argument("verification-enabled", verification_enabled, verification_enabled); + cmd.get_cmd_line_argument("sm-count", SM_count, SM_count); + cmd.get_cmd_line_argument("clock", clock_MHz, clock_MHz); + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "Options:\n\n" + << " --help If specified, displays this usage statement.\n\n" + << " --verification-enabled= Enable/Disable verification\n" + << " --sm-count= Number of SMs on the chip\n" + << " --clock= Locked clock value in Mhz\n"; + + return out; + } +}; + +// +// Testbed +// + +template +class TestbedClusterLaunch { +private: + // Commandline options + OptionsClusterLaunch options; + + bool run_test() { + + // Run CuTe Gemm + Pipeline pipeline; + + bool success = false; + cudaError_t result = pipeline.run(success, this->options.grid_dim); + + CUTE_CHECK_LAST(); + return success; + } + + +public: + TestbedClusterLaunch(OptionsClusterLaunch const &options_) : options(options_) { + int device_id = 0; + cudaDeviceProp device_prop; + CUTE_CHECK_ERROR(cudaSetDevice(device_id)); + CUTE_CHECK_ERROR(cudaGetDeviceProperties(&device_prop, device_id)); + + if (device_prop.major < 1) { + fprintf(stderr, "Device does not support CUDA.\n"); + exit(1); + } + } + + /// Run verification Gemm problem sizes + bool verification() { + +#if !defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + printf( + "CUTLASS_ARCH_MMA_SM100_SUPPORTED must be set, but it is not. 
\n" + "This test is waived.\n" + ); + return true; +#endif + +#if 0 + bool is_success = false; + for (int i = 0; i< 10; i++){ + printf("iteration = %d\n", i); + is_success = run_test(); + if ( not is_success ) + return is_success; + } + return is_success; +#else + // Run the test with single launch + return run_test(); +#endif + } +}; diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/reduction/kernel/reduce_splitk_testbed.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/reduction/kernel/reduce_splitk_testbed.h new file mode 100644 index 0000000000000000000000000000000000000000..e44a42463ae95e4f76388d791c661de875092c93 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/reduction/kernel/reduce_splitk_testbed.h @@ -0,0 +1,45 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Unit tests for thread-level Reduction +*/ + +#pragma once + +#include "cutlass/reduction/thread/reduce.h" + +#include "cutlass/layout/vector.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" + +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_compare.h" diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/reduction/thread/testbed.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/reduction/thread/testbed.h new file mode 100644 index 0000000000000000000000000000000000000000..239f228831a25527106af1659383112535943df1 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/reduction/thread/testbed.h @@ -0,0 +1,242 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Unit tests for thread-level Reduction +*/ + +#pragma once + +#include "cutlass/reduction/thread/reduce.h" + +#include "cutlass/layout/vector.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" + +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_compare.h" + +namespace test { +namespace reduction { +namespace thread { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the reduction +template < + /// Data type of elements + typename Element, + /// Number of elements + int N +> +struct Testbed_reduce_host { + + /// Thread-level reduction operator + using Reduce = cutlass::reduction::thread::Reduce< + cutlass::plus, + cutlass::Array + >; + + // + // Data members + // + + cutlass::Array tensor_in; + cutlass::Array reduced_tensor_computed; + cutlass::Array reduced_tensor_reference; + + // + // Methods + // + + /// Allocates workspace in device memory + Testbed_reduce_host() { + tensor_in.clear(); + reduced_tensor_computed.clear(); + reduced_tensor_reference.clear(); + } + + /// Runs the test + bool run() { + + // + // initialize memory + // + + for(int i = 0; i < N; i++) + tensor_in.at(i) = Element(i); + + + Reduce reduce; + + cutlass::Array *out_ptr = &reduced_tensor_computed; + out_ptr[0] = reduce(tensor_in); + + // + // Reference implementation + // + Element e(0); + for (int i = 0; i < N; i++) + e = e + Element(i); + + reduced_tensor_reference.at(0) = e; + + // + // Verify equivalence + // + + // compare + bool passed = reduced_tensor_reference[0] == reduced_tensor_computed[0]; + + EXPECT_TRUE(passed) + << "Expected = " << float(reduced_tensor_reference.at(0)) << "\n\n" + << "Actual = " << float(reduced_tensor_computed.at(0)) << "\n\n" + << std::endl; + + return passed; + } +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Thread-level reduction kernel +template +__global__ void kernel_reduce(Element const *array_in, Element *result) { + + /// Thread-level reduction operator + using Reduce = cutlass::reduction::thread::Reduce< + cutlass::plus, + cutlass::Array + >; + + Reduce reduce; + + auto ptr_in = reinterpret_cast const *>(array_in); + auto result_ptr = reinterpret_cast *>(result); + auto in = *ptr_in; + result_ptr[0] = reduce(in); +} + + +/// Structure to compute the reduction +template < + /// Data type of elements + typename Element, + /// Number of elements + int N +> +struct Testbed_reduce_device { + + using Layout = cutlass::layout::PackedVectorLayout; + + // + // Data members + // + + cutlass::HostTensor tensor_in; + cutlass::HostTensor reduced_tensor_computed; + cutlass::HostTensor reduced_tensor_reference; + + // + // Methods + // + + /// Allocates workspace in device memory + Testbed_reduce_device() { + + tensor_in.reset(cutlass::make_Coord(N), true); + reduced_tensor_computed.reset(cutlass::make_Coord(1), true); + reduced_tensor_reference.reset(cutlass::make_Coord(1), true); + } + + + /// Runs the test + bool run() { + + // + // initialize memory + // + + cutlass::reference::host::TensorFill( + tensor_in.host_view(), + Element(1) + ); + + cutlass::reference::host::TensorFill( + reduced_tensor_computed.host_view(), + Element(0) + ); + + cutlass::reference::host::TensorFill( + reduced_tensor_reference.host_view(), + Element(N) + ); + + tensor_in.sync_device(); + 
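+    // Host-side buffers are synced to device memory here so kernel_reduce sees the initialized values.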
reduced_tensor_computed.sync_device(); + reduced_tensor_reference.sync_device(); + + /// call the kernel + kernel_reduce<<< dim3(1, 1), dim3(1, 1, 1) >>> ( + tensor_in.device_data(), + reduced_tensor_computed.device_data() + ); + + // verify no errors + cudaError_t result = cudaDeviceSynchronize(); + + EXPECT_EQ(result, cudaSuccess) << "CUDA ERROR: " << cudaGetErrorString(result); + if (result != cudaSuccess) { + return false; + } + + // Copy back results + reduced_tensor_computed.sync_host(); + + // Verify equivalence + bool passed = cutlass::reference::host::TensorEquals( + reduced_tensor_computed.host_view(), + reduced_tensor_reference.host_view() + ); + + EXPECT_TRUE(passed) + << "Expected = " << reduced_tensor_reference.host_view() << "\n\n" + << "Actual = " << reduced_tensor_computed.host_view() << "\n\n" + << std::endl; + + return passed; + } +}; + +} // namespace thread +} // namespace reduction +} // namespace test diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/transform/device/sm90_sparse_gemm_compressor_legacy.hpp b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/transform/device/sm90_sparse_gemm_compressor_legacy.hpp new file mode 100644 index 0000000000000000000000000000000000000000..c4e7de4351076dba3a699b4cb1c8a6e01485bc20 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/transform/device/sm90_sparse_gemm_compressor_legacy.hpp @@ -0,0 +1,481 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! 
\file + \brief Compress utils specific for SM90 structure sparse kernels +*/ + +#pragma once + +#include // std::fill +#include // std::array +#include +#include // std::mt19937 + +#include "cute/container/bit_field.hpp" // cute::bit_field +#include "cute/numeric/numeric_types.hpp" // cute::sizeof_bits_v +#include "cute/tensor.hpp" // cute::Tensor, cute::make_tensor, cute::print_tensor +#include "cutlass/arch/arch.h" // cutlass::arch::Sm90 +#include "cutlass/cutlass.h" // cutlass::Status +#include "cutlass/detail/collective.hpp" +#include "cutlass/detail/layout.hpp" // cutlass::TagToStrideA_t +#include "cutlass/fast_math.h" // cutlass::ceil_div, cutlass::round_up +#include "cutlass/kernel_hardware_info.h" // cutlass::KernelHardwareInfo +#include "cutlass/util/packed_stride.hpp" // cutlass::make_cute_packed_stride +#include "cutlass/numeric_size.h" // cutlass::bits_to_bytes +#include "cutlass/cuda_host_adapter.hpp" // cutlass::CudaHostAdapter + +namespace cutlass +{ +namespace transform +{ +namespace kernel +{ + +using namespace cute; + +namespace detail { + + template + CUTLASS_HOST_DEVICE + static uint8_t + encode_in_chunk_idx_legacy(int in_chunk_idx){ + if (sizeof(T) == 4) { + return in_chunk_idx == 0 ? 0b0100 : 0b1110; + } + else { + uint8_t res = 0; + if (in_chunk_idx == 0) { + res = 0b00; + } + else if (in_chunk_idx == 1) { + res = 0b01; + } + else if (in_chunk_idx == 2) { + res = 0b10; + } + else { + res = 0b11; + } + return res; + } + } + + template < + class SparseConfig, + class EngineA, + class LayoutA, + class EngineAc, + class LayoutAc + > + CUTLASS_HOST_DEVICE + static void + compress_two_chunks_legacy( + Tensor tensorA, + Tensor tensorAc, + uint8_t& meta_two_chunk, + int effective_elems) { + + using ElementA = typename EngineAc::value_type; + + static constexpr int LogicalElemsAPerChunk = typename SparseConfig::LogicalElemsAPerChunk{}; + static constexpr int PhysicalElemsAPerChunk = typename SparseConfig::PhysicalElemsAPerChunk{}; + static constexpr int ElemsARawPerElementAMmaRaw = typename SparseConfig::ElemsARawPerElementAMmaRaw{}; + static constexpr int ElementEBitsPerElementAMma = typename SparseConfig::ElementEBitsPerElementAMma{}; + static constexpr int LogicalSubChunk = ceil_div(LogicalElemsAPerChunk, ElemsARawPerElementAMmaRaw); + static constexpr int PhysicalSubChunk = ceil_div(PhysicalElemsAPerChunk, ElemsARawPerElementAMmaRaw); + + /* + Legal metadata chunk in SM90 + Index Bin HEX + 0, 1 0b0100 4 + 1, 2 0b1001 9 + 2, 3 0b1110 E + 0, 2 0b1000 8 + 1, 3 0b1101 D + 0, 3 0b1100 C + 2, 1 0b0110 6 (Not used) + ----------------------------------- + TF32 + 0 0b0100 4 + 1 0b1110 E + */ + + if (effective_elems <= 0) { + return; + } + + // initialize + // 0 is the initial value for this function while 0x44 is the initial value for hardware. 
+ meta_two_chunk = 0; + + for (int chunk_idx = 0; chunk_idx < 2; ++chunk_idx) { + // If Only One Chunk within this Two Chunk + if ( effective_elems <= chunk_idx * ElemsARawPerElementAMmaRaw * LogicalSubChunk ) { + break; + } + /// init result; + int non_zero_cnt = 0; + int32_t nnz_chunk_idx[PhysicalSubChunk] = { 0 }; + ElementA Ac_chunk[PhysicalSubChunk][ElemsARawPerElementAMmaRaw] = { ElementA{0} }; + + for (int subchunk_idx = 0; subchunk_idx < LogicalSubChunk; ++subchunk_idx) { + bool is_nz = true; + ElementA subchunk_elems[ElemsARawPerElementAMmaRaw] = { ElementA{0} }; + /// Check if subchunk is non-zero + for(int elem_idx = 0; elem_idx < ElemsARawPerElementAMmaRaw; elem_idx++) { + int offset = chunk_idx * LogicalElemsAPerChunk + subchunk_idx * ElemsARawPerElementAMmaRaw + elem_idx; + subchunk_elems[elem_idx] = offset < effective_elems ? tensorA(offset) : ElementA(0); + + ElementA zero = static_cast(0); + ElementA minus_zero = static_cast(ElementA(1) << cutlass::sizeof_bits_v - 1); + if (subchunk_elems[elem_idx] != zero && subchunk_elems[elem_idx] != minus_zero) { + if (non_zero_cnt >= PhysicalSubChunk) { + #ifdef __CUDA_ARCH__ + asm volatile ("brkpt;\n" ::); + #else + throw std::runtime_error("Found extra non-zero elements in a chunk!\n"); + #endif + } + is_nz = false; + } + } + + /// There is non-zero element in the subchunk + if(!is_nz) { + nnz_chunk_idx[non_zero_cnt] = subchunk_idx; + memcpy(Ac_chunk[non_zero_cnt], subchunk_elems, sizeof(ElementA) * ElemsARawPerElementAMmaRaw); + non_zero_cnt++; + } + } + + /* + Special cases + nnz == 1 and non-tf32 and nnz_idx = 3 + */ + ElementA elementA_zeros[ElemsARawPerElementAMmaRaw] = { ElementA{0} }; + if constexpr (sizeof_bits_v < 32) { + if (non_zero_cnt == 1 && nnz_chunk_idx[0] == 3) { + memcpy(Ac_chunk[1], Ac_chunk[0], sizeof(ElementA) * ElemsARawPerElementAMmaRaw); + memcpy(Ac_chunk[0], elementA_zeros, sizeof(ElementA) * ElemsARawPerElementAMmaRaw); + nnz_chunk_idx[1] = 3; + nnz_chunk_idx[0] = 0; + } + else if (non_zero_cnt == 1) { + memcpy(Ac_chunk[1], elementA_zeros, sizeof(ElementA) * ElemsARawPerElementAMmaRaw); + nnz_chunk_idx[1] = 3; + } + } + + /// Setup metadata + uint8_t meta_chunk = 0; + for (int i = 0; i < PhysicalSubChunk; i++) { + meta_chunk = static_cast(meta_chunk | (encode_in_chunk_idx_legacy(nnz_chunk_idx[i]) << (i * ElementEBitsPerElementAMma))); + for(int j = 0; j < ElemsARawPerElementAMmaRaw; j++) { + tensorAc(chunk_idx * PhysicalElemsAPerChunk + i * ElemsARawPerElementAMmaRaw + j) = Ac_chunk[i][j]; + } + } + meta_two_chunk = uint8_t(meta_two_chunk | (meta_chunk << (chunk_idx * _4{}))); + } + } +} + +template< + class ProblemShape_, + class ElementA_, + class LayoutATag_, + class SparseConfig_ +> +class SM90StructuredSparseCompressorLegacy { +public: + using SparseConfig = SparseConfig_; + using ProblemShape = ProblemShape_; + + // * EltA + using ElementA = ElementA_; + using ElementAUint = cute::uint_bit_t>; + static constexpr bool IsRuntimeDataTypeA = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); + using ArrayElementA = cute::conditional_t>, + ElementA>; + using ElementAMma = typename SparseConfig::ElementAMma; + using ElementAMmaRaw = typename SparseConfig::ElementAMmaRaw; + using ElementASparsity = typename SparseConfig::ElementASparsity; + using ElementAMmaSparsity = typename SparseConfig::ElementAMmaSparsity; + using LayoutATag = LayoutATag_; + using LayoutA = LayoutATag; + using StrideA = cutlass::gemm::TagToStrideA_t; + + // * EltE + using ElementEMma = typename SparseConfig::ElementEMma; + 
using ElementEMmaRaw = typename SparseConfig::ElementEMmaRaw; + using ElementEMmaSparsity = typename SparseConfig::ElementEMmaSparsity; + + // * AtomE + using TensorEAtom = typename SparseConfig::TensorEAtom; + using TensorEAtomK = typename SparseConfig::TensorEAtomK; + using TensorEAtomM = typename SparseConfig::TensorEAtomM; + + static constexpr int ElemsARawPerElementAMmaRaw = typename SparseConfig::ElemsARawPerElementAMmaRaw{}; + static constexpr int LogicalElemsAPerChunk = typename SparseConfig::LogicalElemsAPerChunk{}; + static constexpr int PhysicalElemsAPerChunk = typename SparseConfig::PhysicalElemsAPerChunk{}; + static constexpr int LogicalElemsAMmaRawPerChunk = cutlass::ceil_div(LogicalElemsAPerChunk, ElemsARawPerElementAMmaRaw); + static constexpr int PhysicalElemsAMmaRawPerChunk = cutlass::ceil_div(PhysicalElemsAPerChunk, ElemsARawPerElementAMmaRaw); + + // * Alignment + static constexpr int TensorEAlignmentM = typename SparseConfig::TensorEAlignmentM{}; + static constexpr int TensorEAlignmentK = typename SparseConfig::TensorEAlignmentK{}; + static constexpr int TensorAAlignmentK = typename SparseConfig::TensorAAlignmentK{}; + static constexpr int TensorAAlignmentM = typename SparseConfig::TensorAAlignmentM{}; + + // Required by `device_kernel` + static constexpr int MaxThreadsPerBlock = 1; + static constexpr int MinBlocksPerMultiprocessor = 1; + using ArchTag = arch::Sm90; + + struct SharedStorage { + /* empty, no smem needed */ + }; + + static constexpr int SharedStorageSize = sizeof(SharedStorage); + + struct TransformArguments { + ArrayElementA const* ptr_A{nullptr}; + StrideA dA{}; + ArrayElementA* ptr_ACompress{nullptr}; + ElementEMmaRaw* ptr_E{nullptr}; + }; + + using TransformParams = TransformArguments; + + struct Arguments { + ProblemShape problem_shape{}; + TransformArguments transform{}; + KernelHardwareInfo hw_info{}; + }; + + struct Params { + ProblemShape problem_shape{}; + TransformParams transform{}; + KernelHardwareInfo hw_info{}; + void* workspace = nullptr; + }; + + static Params + to_underlying_arguments(Arguments & args, void* workspace) { + return Params{{args.problem_shape}, + {args.transform.ptr_A, args.transform.dA, args.transform.ptr_ACompress, args.transform.ptr_E}, + {args.hw_info}, + workspace}; + } + + static Status + can_implement(Arguments const& args) { + auto [M, N, K, L] = args.problem_shape; + if (K % LogicalElemsAPerChunk != 0) { + CUTLASS_TRACE_HOST("SM90 Sparse Compressor CAN NOT IMPLEMENT: GemmK not multiplier of logical chunk size\n"); + return Status::kErrorInvalidProblem; + } + + return Status::kSuccess; + } + + static size_t + get_workspace_size(Arguments const& args) { + auto problem = args.problem_shape; + const int m = cute::size<0>(problem); + const int k = cute::size<2>(problem); + const int l = cute::size<3>(problem); + const int metadata_k = round_up(k, TensorEAlignmentK); + const int metadata_m = round_up(m, TensorEAlignmentM); + const int metadata_bytes = metadata_m * metadata_k / ElementEMmaSparsity{} * l; + return metadata_bytes; + } + + static Status + initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr, + CudaHostAdapter *cuda_adapter = nullptr) { + cudaError_t cuda_error; + + auto workspace_size = get_workspace_size(args); + if (workspace_size == 0) { + return Status::kSuccess; + } else if (workspace == nullptr) { + return Status::kErrorInternal; + } + + cudaPointerAttributes attri; + cuda_error = cudaPointerGetAttributes(&attri, workspace); + if (cuda_error != 
cudaSuccess) { + return Status::kErrorInternal; + } + + if ( attri.type == cudaMemoryTypeDevice ) { +#if defined(CUTLASS_ENABLE_CUDA_HOST_ADAPTER) && CUTLASS_ENABLE_CUDA_HOST_ADAPTER + CUTLASS_ASSERT(cuda_adapter); + if (Status::kSuccess != cuda_adapter->memsetDevice(workspace, static_cast(0), workspace_size, stream)) { + return Status::kErrorInternal; + } +#else + cudaMemsetAsync(workspace, 0, workspace_size, stream); + cuda_error = cudaGetLastError(); + if (cuda_error != cudaSuccess) { + return Status::kErrorInternal; + } +#endif + } else { + memset(workspace, 0, workspace_size); + } + + return Status::kSuccess; + } + + static dim3 + get_grid_shape(Params const& params) { + return dim3(1, 1, 1); + } + + static dim3 + get_block_shape() { + return dim3(1, 1, 1); + } + + CUTE_HOST_DEVICE + void + operator()(Params params, char* smem_buf = nullptr) { + run(params, smem_buf); + } + + CUTE_HOST_DEVICE + static void + run(Params params, char* smem_buf = nullptr) { + do_compress_device_host(params); + } + +private: + + CUTE_HOST_DEVICE + static void + do_compress_device_host(Params params) { + auto [m, n, k, l] = params.problem_shape; + auto [ptr_A, dA, ptr_ACompress, ptr_E] = params.transform; + auto workspace = params.workspace; + + const int aligned_k = (k + TensorAAlignmentK - 1) / TensorAAlignmentK * TensorAAlignmentK; + const int aligned_m = (m + TensorAAlignmentM - 1) / TensorAAlignmentM * TensorAAlignmentM; + const int metadata_k = (k + TensorEAlignmentK - 1) / TensorEAlignmentK * TensorEAlignmentK; + const int metadata_m = (m + TensorEAlignmentM - 1) / TensorEAlignmentM * TensorEAlignmentM; + const int k_compressed = aligned_k / ElementASparsity{}; + + // Convert to CuTe tensors. But don't want to use sparse_ptr, which is making everything complicated here. + cute::Tensor tensorA = make_tensor(recast_ptr(ptr_A), make_layout(make_shape(m, k, l), dA)); + + cute::Tensor tensorAc = make_tensor(recast_ptr(ptr_ACompress), + make_shape(aligned_m, k_compressed, l), + make_cute_packed_stride(StrideA{}, cute::make_shape(aligned_m, k_compressed, l))); + + cute::Tensor tensorE_raw_compress_logical = make_tensor(recast_ptr>(workspace), + make_shape(metadata_m, make_shape(TensorEAtomK{}, metadata_k / TensorEAtomK{}), l), + make_stride(TensorEAtomK{}, make_stride(_1{}, metadata_m*TensorEAtomK{}), metadata_m*metadata_k)); + + cute::Tensor tensorE_raw_compress = recast(tensorE_raw_compress_logical); + + // The following vars are all logical. 
+ int atom_m = size<0>(TensorEAtom{}); + int atom_k = size<1>(TensorEAtom{}); + int tiled_m = metadata_m / atom_m; + int tiled_ke = metadata_k / atom_k; + // Col major when viewing atoms + int stride_tile_m = cosize(TensorEAtom{}); + int stride_tile_ke = atom_k * metadata_m; + + // Logical metadata tensor + cute::Tensor tensorE_logical = make_tensor(recast_ptr>(ptr_E), + make_layout(make_shape(append(shape<0>(TensorEAtom{}), tiled_m), + append(shape<1>(TensorEAtom{}), tiled_ke), + shape<2>(tensorE_raw_compress_logical)), + make_stride(append(stride<0>(TensorEAtom{}), stride_tile_m), + append(stride<1>(TensorEAtom{}), stride_tile_ke), + stride<2>(tensorE_raw_compress_logical)))); + // Physical metadata tensor + cute::Tensor tensorE = recast(tensorE_logical); + + // void do_init() + cute::clear(tensorAc); + cute::clear(tensorE_raw_compress); + + // void do_raw_compress() + using TileStepA = Int; + using TileStepAc = Int; + + cute::Tensor tensorATiled = logical_divide(tensorA, make_shape(_, TileStepA{}, _)); + cute::Tensor tensorAcTiled = logical_divide(tensorAc, make_shape(_, TileStepAc{}, _)); + + for (int batch_idx = 0; batch_idx < l; batch_idx++) { + for (int m_idx = 0; m_idx < m; m_idx++) { + for (int tiler_k_idx = 0; tiler_k_idx < size<1,1>(tensorATiled); tiler_k_idx++) { + int effective_elems = cute::min(TileStepA{}, k - (tiler_k_idx * TileStepA{})); + detail::compress_two_chunks_legacy(tensorATiled(m_idx, make_coord(_, tiler_k_idx), batch_idx), + tensorAcTiled(m_idx, make_coord(_, tiler_k_idx), batch_idx), + tensorE_raw_compress(m_idx, tiler_k_idx, batch_idx), + effective_elems); + } + } + } + + // void do_reorder() + // Fast path when we don't permute. + if constexpr (sizeof_bits_v <= 8) { + memcpy(tensorE.data(), tensorE_raw_compress.data(), tensorE.size()); + } + else { + cute::copy(tensorE_raw_compress, tensorE); + } + + #if 0 + print("--> TensorA\n"); + auto tensorA_eltA = cute::recast(tensorA); + cute::print_tensor(tensorA_eltA); printf("\n\n"); + + print("--> REF TensorAC\n"); + auto tensorAc_eltA = cute::recast(tensorAc); + cute::print_tensor(tensorAc_eltA); printf("\n\n"); + + print("--> REF TensorE\n"); + cute::print_tensor(tensorE); printf("\n\n"); + #endif + + } +}; + +} // namespace kernel +} // namespace transform +} // namespace cutlass diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/transform/device/testbed_sparse_gemm_compressor.hpp b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/transform/device/testbed_sparse_gemm_compressor.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f44458244e0d3c4c80ecc29a0115cd6906211559 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/test/unit/transform/device/testbed_sparse_gemm_compressor.hpp @@ -0,0 +1,877 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/* + * @brief Test for structured sparse gemm compressor device kernel + */ + +#pragma once + +#include // cudaGetLastError + +#include // uint64_t +#include // printf +#include // malloc +#include // std::cout +#include +#include + +#include "cute/layout.hpp" // cute::make_shape +#include "cute/util/type_traits.hpp" // cute::is_same_v +#include "cutlass/coord.h" // cutlass::make_Coord +#include "cutlass/cutlass.h" // cutlass::Status +#include "cutlass/kernel_hardware_info.hpp" // cutlass::KernelHardwareInfo +#include "cutlass/layout/matrix.h" // cutlass::layout::Affine2Layout_Factory +#include "cutlass/numeric_types.h" // cutlass::sizeof_bits, cutlass::float_ +#include "cutlass/tensor_view.h" // cutlass::TensorView +#include "cutlass/transform/device/transform_universal_adapter.hpp" // cutlass::transform::device::TransformUniversalAdapter +#include "cutlass/transform/kernel/sparse_gemm_compressor.hpp" // cutlass::transform::kernel::StructuredSparseCompressorUtility +#include "cutlass/util/device_memory.h" // cutlass::device_memory::allocation +#include "cutlass/util/distribution.h" // cutlass::Distribution +#include "cutlass/util/host_tensor.h" // cutlass::HostTensor +#include "cutlass/util/packed_stride.hpp" // cutlass::make_cute_packed_stride +#include "cutlass/util/reference/host/tensor_compare.h" // cutlass::reference::host::TensorEquals +#include "cutlass/util/reference/host/tensor_fill.h" // cutlass::reference::host::TensorFillRandomUniform, TensorFillIdentity, TensorFillRandomGaussian, BlockFillSequential, TensorFill +#include "cutlass/detail/collective.hpp" + +#include "sm90_sparse_gemm_compressor_legacy.hpp" // Legacy host compressor +#include "../../common/cutlass_unit_test.h" // CUTLASS UT, EXPECT_TRUE + + +#define CUDA_CHECK_FALSE(cuda_error) \ + { \ + if (cuda_error != cudaSuccess) { \ + printf("cudaError %s in %s:%d\n", cudaGetErrorString(cuda_error), __func__, __LINE__ ); \ + return false; \ + } \ + } + +#define CUDA_CHECK(cuda_error) \ + { \ + if (cuda_error != cudaSuccess) { \ + printf("cudaError %s in %s:%d\n", cudaGetErrorString(cuda_error), __func__, __LINE__ ); \ + return; \ + } \ + } + + 
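+// ----------------------------------------------------------------------------------------------
+// Usage sketch (added for clarity, not part of the original testbed): CUDA_CHECK_FALSE above is
+// intended for the bool-returning testbed methods further down in this file, while CUDA_CHECK is
+// intended for void contexts such as the CudaRAII constructor/destructor. Both print the
+// cudaError string together with the enclosing function and line before bailing out. A typical
+// call site looks like:
+//
+//   bool sync_stream(cudaStream_t stream) {
+//     CUDA_CHECK_FALSE(cudaStreamSynchronize(stream));   // returns false on any CUDA error
+//     CUDA_CHECK_FALSE(cudaGetLastError());              // also surfaces earlier async failures
+//     return true;
+//   }
+//
+// `sync_stream` is only an illustrative helper; the testbed below invokes the macros directly
+// (e.g. CUDA_CHECK_FALSE(cudaGetLastError()) in initialize() and run_device()).
+// ----------------------------------------------------------------------------------------------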
+/////////////////////////////////////////////////////////////////////////////////////////////////// +// * Test Bed +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace test +{ +namespace transform +{ +namespace device +{ + +// Helper Functions +template +bool +initialize_tensor(cutlass::TensorView view, cutlass::Distribution::Kind dist_kind, uint64_t seed) +{ + if (dist_kind == cutlass::Distribution::Uniform) { + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } + else if (bits_input <= 8) { + scope_max = 1; + scope_min = -1; + } else { + scope_max = 4; + scope_min = -4; + } + cutlass::reference::host::TensorFillRandomUniform(view, seed, scope_max, scope_min, 0); + } + + else if (dist_kind == cutlass::Distribution::Identity) { + cutlass::reference::host::TensorFillIdentity(view); + } + + else if (dist_kind == cutlass::Distribution::Gaussian) { + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + + else if (dist_kind == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(view.data(), view.capacity()); + } + + else if (dist_kind == cutlass::Distribution::AllOnes) { + cutlass::reference::host::TensorFill(view, Element(1)); + } + + else if (dist_kind == cutlass::Distribution::AllZeros) { + cutlass::reference::host::TensorFill(view, Element(0)); + } + + else { + EXPECT_TRUE(false) << "Not implemented"; + return false; + } + + return true; +} + +// Testbed +template +struct TestbedSparseGemmCompressor { +public: + using Compressor = Compressor_; + using CompressorKernel = typename Compressor::TransformKernel; + + using ElementA = typename CompressorKernel::ElementA; + using LayoutATag = typename CompressorKernel::LayoutATag; + using StrideA = typename CompressorKernel::StrideA; + using ArrayElementA = + ElementA + ; + + using ElementE = typename CompressorKernel::ElementEMmaRaw; + using LayoutETag = cutlass::layout::RowMajor; // We don't care about the major here, just to allocate tensor + + using SparseConfig = typename CompressorKernel::SparseConfig; + using ProblemShapeType = typename CompressorKernel::ProblemShape; + + using CompressorUtility = cutlass::transform::kernel::StructuredSparseCompressorUtility< + ProblemShapeType, + ElementA, + LayoutATag, + SparseConfig>; + + using CompressorKernelHost = cutlass::transform::kernel::SM90StructuredSparseCompressorLegacy< + ProblemShapeType, + ElementA, + LayoutATag, + SparseConfig>; + + using CompressorHost = cutlass::transform::device::TransformUniversalAdapter; + + static constexpr auto LogicalElemsAPerChunk = CompressorKernel::LogicalElemsAPerChunk; + static constexpr auto PhysicalElemsAPerChunk = CompressorKernel::PhysicalElemsAPerChunk; + + struct Data { + // Data Storage + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_A_Comp; + cutlass::HostTensor tensor_E; + cutlass::HostTensor tensor_A_Comp_ref; + cutlass::HostTensor tensor_E_ref; + }; + + struct CudaRAII { + cudaStream_t stream; + cudaEvent_t start; + cudaEvent_t stop; + + CudaRAII(){ + CUDA_CHECK(cudaStreamCreate( &stream )); + CUDA_CHECK(cudaEventCreate( &start )); + CUDA_CHECK(cudaEventCreate( &stop )); + }; + + CudaRAII(const CudaRAII&) = delete; + CudaRAII& operator=(const CudaRAII&) = delete; + CudaRAII(CudaRAII&&) = delete; + CudaRAII& operator=(CudaRAII&&) = delete; + + ~CudaRAII(){ + CUDA_CHECK(cudaStreamDestroy( stream )); + CUDA_CHECK(cudaEventDestroy( 
start )); + CUDA_CHECK(cudaEventDestroy( stop )); + } + }; + +public: + TestbedSparseGemmCompressor( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_E_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_A_Comp_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 7) + : init_A(init_A_) + , init_E(init_E_) + , init_A_Comp(init_A_Comp_) + , seed(seed_) + { + } + + bool valid_test(ProblemShapeType problem_shape_MNKL) + { + const int GemmK = cute::size<2>(problem_shape_MNKL); + + if ( GemmK % LogicalElemsAPerChunk != 0 ) { + printf("GemmK needs to be multiplier of LogicalElemsAPerChunk\n"); + return false; + } + + return true; + } + + bool initialize(ProblemShapeType problem_shape_MNKL, Data& datas) + { + CUDA_CHECK_FALSE(cudaGetLastError()); + + // In unit of ElementARaw + const int GemmM = cute::size<0>(problem_shape_MNKL); + const int GemmN = cute::size<1>(problem_shape_MNKL); + const int GemmK = cute::size<2>(problem_shape_MNKL); + const int GemmL = cute::size<3>(problem_shape_MNKL); + + // Compressor utility to get allocated data size + auto stride_a = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(GemmM, GemmK, GemmL)); + CompressorUtility compressor_utility(problem_shape_MNKL, stride_a); + + // TensorA + // In unit of ElementARaw, after alignment requirement + // M-dim: no alignment requirement + // K-dim: multiplier of chunk size + + // TensorA Compressed + // In unit of ElementARaw, after alignment requirement + // M-dim: TMA alignment + // K-dim: TMA alignment + const int GemmMAlignedAC = compressor_utility.get_tensorA_m_physical(); + const int GemmKAlignedAC = compressor_utility.get_tensorA_k_physical(); + + // TensorE + // In unit of ElementE (uint8_t), after alignment requirement + // M-dim: TensorEAtom_M alignment + // K-dim: TensorEAtom_K alignment + const int GemmMAlignedE = compressor_utility.get_metadata_m_physical(); + const int GemmKAlignedE = compressor_utility.get_metadata_k_physical(); + + auto a_coord = cutlass::make_Coord(GemmM * GemmL, GemmK); + auto e_coord = cutlass::make_Coord(GemmMAlignedE * GemmL, GemmKAlignedE); + auto a_comp_coord = cutlass::make_Coord(GemmMAlignedAC * GemmL, GemmKAlignedAC); + + typename LayoutATag::Stride stride_factor_A; + typename LayoutETag::Stride stride_factor_E; + + datas.tensor_A.resize(a_coord, + cutlass::layout::Affine2Layout_Factory::layout_factory(a_coord, stride_factor_A)); + datas.tensor_A_Comp.resize(a_comp_coord, + cutlass::layout::Affine2Layout_Factory::layout_factory(a_comp_coord, stride_factor_A)); + datas.tensor_A_Comp_ref.resize(a_comp_coord, + cutlass::layout::Affine2Layout_Factory::layout_factory(a_comp_coord, stride_factor_A), + false); + datas.tensor_E.resize(e_coord, + cutlass::layout::Affine2Layout_Factory::layout_factory(e_coord, stride_factor_E)); + datas.tensor_E_ref.resize(e_coord, + cutlass::layout::Affine2Layout_Factory::layout_factory(e_coord, stride_factor_E), + false); + + EXPECT_TRUE(initialize_tensor(datas.tensor_A.host_view(), init_A, seed + 1)); + EXPECT_TRUE(initialize_tensor(datas.tensor_E.host_view(), init_E, seed + 2)); + EXPECT_TRUE(initialize_tensor(datas.tensor_E_ref.host_view(), init_E, seed + 3)); + EXPECT_TRUE(initialize_tensor(datas.tensor_A_Comp.host_view(), init_A_Comp, seed + 4)); + EXPECT_TRUE(initialize_tensor(datas.tensor_A_Comp_ref.host_view(), init_A_Comp, seed + 5)); + + compressor_utility.structure_sparse_zero_mask_fill(datas.tensor_A.host_data(), seed + 6); + + // Check for failed devide + 
CUDA_CHECK_FALSE(cudaGetLastError()); + + datas.tensor_A.sync_device(); + datas.tensor_A_Comp.sync_device(); + datas.tensor_E.sync_device(); + + // Check for failed devide + CUDA_CHECK_FALSE(cudaGetLastError()); + + return true; + } + + bool run_device(ProblemShapeType problem_shape_MNKL, Data& datas, float* time = nullptr) + { + CudaRAII cuda_raii; + + const int GemmM = cute::size<0>(problem_shape_MNKL); + const int GemmN = cute::size<1>(problem_shape_MNKL); + const int GemmK = cute::size<2>(problem_shape_MNKL); + const int GemmL = cute::size<3>(problem_shape_MNKL); + + StrideA stride_a = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(GemmM, GemmK, GemmL)); + + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = 0; + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + typename Compressor::Arguments arguments{ + {GemmM, GemmN, GemmK, GemmL}, + {datas.tensor_A.device_data(), + stride_a, + datas.tensor_A_Comp.device_data(), + datas.tensor_E.device_data()}, + {hw_info} + }; + + Compressor compressor_op; + size_t workspace_size = Compressor::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status {cutlass::Status::kSuccess }; + + status = compressor_op.can_implement(arguments); + if (status != cutlass::Status::kSuccess) { + CUDA_CHECK_FALSE(cudaGetLastError()); + } + + status = compressor_op.initialize(arguments, workspace.get(), cuda_raii.stream); + if (status != cutlass::Status::kSuccess) { + CUDA_CHECK_FALSE(cudaGetLastError()); + } + + CUDA_CHECK_FALSE(cudaStreamSynchronize(cuda_raii.stream)); + CUDA_CHECK_FALSE(cudaEventRecord(cuda_raii.start, cuda_raii.stream)); + + status = compressor_op.run(cuda_raii.stream); + if (status != cutlass::Status::kSuccess) { + CUDA_CHECK_FALSE(cudaGetLastError()); + } + + CUDA_CHECK_FALSE(cudaEventRecord(cuda_raii.stop, cuda_raii.stream)); + CUDA_CHECK_FALSE(cudaEventSynchronize(cuda_raii.stop)); + CUDA_CHECK_FALSE(cudaStreamSynchronize(cuda_raii.stream)); + if ( time != nullptr ){ + CUDA_CHECK_FALSE(cudaEventElapsedTime(time, cuda_raii.start, cuda_raii.stop)); + } + + datas.tensor_A_Comp.sync_host(); + datas.tensor_E.sync_host(); + + #if 0 + { + printf("\n--> DEVICE OUTPUT\n"); + printf("datas.tensor_A\n"); + std::cout << datas.tensor_A.host_view() << std::endl << std::endl; + printf("datas.tensor_A_Comp\n"); + std::cout << datas.tensor_A_Comp.host_view() << std::endl << std::endl; + printf("datas.tensor_E\n"); + std::cout << datas.tensor_E.host_view() << std::endl << std::endl; + } + #endif + + return true; + } + + bool run_host_ref(ProblemShapeType problem_shape_MNKL, Data& datas) + { + const int GemmM = cute::size<0>(problem_shape_MNKL); + const int GemmN = cute::size<1>(problem_shape_MNKL); + const int GemmK = cute::size<2>(problem_shape_MNKL); + const int GemmL = cute::size<3>(problem_shape_MNKL); + + StrideA stride_a = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(GemmM, GemmK, GemmL)); + + typename CompressorKernelHost::Arguments arguments{ + {GemmM, GemmN, GemmK, GemmL}, + {datas.tensor_A.host_data(), + stride_a, + datas.tensor_A_Comp_ref.host_data(), + datas.tensor_E_ref.host_data()}, + {}}; + + const auto can_imp = CompressorKernelHost::can_implement(arguments); + if (can_imp != cutlass::Status::kSuccess) { + printf("can_implement() check failed\n"); + return false; + } + + // Relies on std::vector for RAII + auto workspace_size = + 
static_cast::size_type>(CompressorKernelHost::get_workspace_size(arguments)); + std::vector workspace_vector(workspace_size); + auto workspace = static_cast(workspace_vector.data()); + + cutlass::Status status = CompressorKernelHost::initialize_workspace(arguments, workspace); + if (status != cutlass::Status::kSuccess) { + printf("initialize_workspace() failed\n"); + return false; + } + + auto params = CompressorKernelHost::to_underlying_arguments(arguments, workspace); + CompressorKernelHost::run(params); + + return true; + } + + bool compare_reference(Data& datas) + { + bool check_tensor_a_compressed = + cutlass::reference::host::TensorEquals(datas.tensor_A_Comp_ref.host_view(), datas.tensor_A_Comp.host_view()); + if (!check_tensor_a_compressed) { + printf("A-Compressed Mismatch\n"); + } + + bool check_tensor_e = cutlass::reference::host::TensorEquals(datas.tensor_E_ref.host_view(), datas.tensor_E.host_view()); + if (!check_tensor_e) { + printf("E Mismatch\n"); + } + + return check_tensor_a_compressed && check_tensor_e; + } + + bool run_auto_small() + { + return run_auto(true); + } + + bool run_auto(bool run_small = false) + { + constexpr auto TensorEAlignmentM = typename SparseConfig::TensorEAlignmentM{}; + constexpr auto TensorEAlignmentK = typename SparseConfig::TensorEAlignmentK{}; + constexpr int LogicalElemsAPerChunk = typename SparseConfig::LogicalElemsAPerChunk{}; + + constexpr int GemmN = 1; + + using ProblemType = typename std::array; + + std::vector problems; + + const std::vector problems_multiplier_of_tensor_e_atom = { + // * Regular Cases (multiplier of TensorEAlignment) + {TensorEAlignmentM * 1, GemmN, TensorEAlignmentK * 2, 1}, + {TensorEAlignmentM * 1, GemmN, TensorEAlignmentK * 2, 1}, + {TensorEAlignmentM * 1, GemmN, TensorEAlignmentK * 3, 1}, + + {TensorEAlignmentM * 2, GemmN, TensorEAlignmentK * 2, 1}, + {TensorEAlignmentM * 2, GemmN, TensorEAlignmentK * 2, 1}, + {TensorEAlignmentM * 2, GemmN, TensorEAlignmentK * 3, 1}, + + {TensorEAlignmentM * 3, GemmN, TensorEAlignmentK * 2, 1}, + {TensorEAlignmentM * 3, GemmN, TensorEAlignmentK * 2, 1}, + {TensorEAlignmentM * 3, GemmN, TensorEAlignmentK * 3, 1}, + + {TensorEAlignmentM * 1, GemmN, TensorEAlignmentK * 2, 2}, + {TensorEAlignmentM * 1, GemmN, TensorEAlignmentK * 2, 2}, + {TensorEAlignmentM * 1, GemmN, TensorEAlignmentK * 3, 2}, + + {TensorEAlignmentM * 2, GemmN, TensorEAlignmentK * 2, 2}, + {TensorEAlignmentM * 2, GemmN, TensorEAlignmentK * 2, 2}, + {TensorEAlignmentM * 2, GemmN, TensorEAlignmentK * 3, 2}, + + {TensorEAlignmentM * 3, GemmN, TensorEAlignmentK * 2, 2}, + {TensorEAlignmentM * 3, GemmN, TensorEAlignmentK * 2, 2}, + {TensorEAlignmentM * 3, GemmN, TensorEAlignmentK * 3, 2}, + + {TensorEAlignmentM * 1, GemmN, TensorEAlignmentK * 2, 3}, + {TensorEAlignmentM * 1, GemmN, TensorEAlignmentK * 2, 3}, + {TensorEAlignmentM * 1, GemmN, TensorEAlignmentK * 3, 3}, + + {TensorEAlignmentM * 2, GemmN, TensorEAlignmentK * 2, 3}, + {TensorEAlignmentM * 2, GemmN, TensorEAlignmentK * 2, 3}, + {TensorEAlignmentM * 2, GemmN, TensorEAlignmentK * 3, 3}, + + {TensorEAlignmentM * 3, GemmN, TensorEAlignmentK * 2, 3}, + {TensorEAlignmentM * 3, GemmN, TensorEAlignmentK * 2, 3}, + {TensorEAlignmentM * 3, GemmN, TensorEAlignmentK * 3, 3}, + }; + + const std::vector problems_multiplier_of_tensor_e_atom_large = { + // * Large Case (multiplier of TensorEAlignment) + {TensorEAlignmentM * 10, GemmN, TensorEAlignmentK * 13, 1}, + // {TensorEAlignmentM * 11, GemmN, TensorEAlignmentK * 14, 2}, + // {TensorEAlignmentM * 12, GemmN, 
TensorEAlignmentK * 15, 3}, + }; + + const std::vector problems_multiplier_of_twochunk { + // * Corner Cases + {4, GemmN, LogicalElemsAPerChunk * 2, 1}, + {4, GemmN, LogicalElemsAPerChunk * 4, 1}, + {4, GemmN, LogicalElemsAPerChunk * 6, 1}, + {4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 2, 1}, + {4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 4, 1}, + {4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 6, 1}, + {4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 2, 1}, + {4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 4, 1}, + {4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 6, 1}, + + {4, GemmN, LogicalElemsAPerChunk * 2, 2}, + {4, GemmN, LogicalElemsAPerChunk * 4, 2}, + {4, GemmN, LogicalElemsAPerChunk * 6, 2}, + {4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 2, 2}, + {4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 4, 2}, + {4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 6, 2}, + {4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 2, 2}, + {4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 4, 2}, + {4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 6, 2}, + + {4, GemmN, LogicalElemsAPerChunk * 2, 3}, + {4, GemmN, LogicalElemsAPerChunk * 4, 3}, + {4, GemmN, LogicalElemsAPerChunk * 6, 3}, + {4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 2, 3}, + {4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 4, 3}, + {4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 6, 3}, + {4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 2, 3}, + {4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 4, 3}, + {4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 6, 3}, + + {32 + 4, GemmN, LogicalElemsAPerChunk * 2, 1}, + {32 + 4, GemmN, LogicalElemsAPerChunk * 4, 1}, + {32 + 4, GemmN, LogicalElemsAPerChunk * 6, 1}, + {32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 2, 1}, + {32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 4, 1}, + {32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 6, 1}, + {32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 2, 1}, + {32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 4, 1}, + {32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 6, 1}, + + {32 + 4, GemmN, LogicalElemsAPerChunk * 2, 2}, + {32 + 4, GemmN, LogicalElemsAPerChunk * 4, 2}, + {32 + 4, GemmN, LogicalElemsAPerChunk * 6, 2}, + {32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 2, 2}, + {32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 4, 2}, + {32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 6, 2}, + {32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 2, 2}, + {32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 4, 2}, + {32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 6, 2}, + + {32 + 4, GemmN, LogicalElemsAPerChunk * 2, 3}, + {32 + 4, GemmN, LogicalElemsAPerChunk * 4, 3}, + {32 + 4, GemmN, LogicalElemsAPerChunk * 6, 3}, + {32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 2, 3}, + {32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 4, 3}, + {32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 6, 3}, + {32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 2, 3}, + {32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 4, 3}, + {32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 6, 3}, + + {TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 2, 1}, + {TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 4, 1}, + 
{TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 6, 1}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 2, 1}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 4, 1}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 6, 1}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 2, 1}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 4, 1}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 6, 1}, + + {TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 2, 2}, + {TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 4, 2}, + {TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 6, 2}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 2, 2}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 4, 2}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 6, 2}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 2, 2}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 4, 2}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 6, 2}, + + {TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 2, 3}, + {TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 4, 3}, + {TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 6, 3}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 2, 3}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 4, 3}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 6, 3}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 2, 3}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 4, 3}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 6, 3}, + + {TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 2, 1}, + {TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 4, 1}, + {TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 6, 1}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 2, 1}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 4, 1}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 6, 1}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 2, 1}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 4, 1}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 6, 1}, + + {TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 2, 2}, + {TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 4, 2}, + {TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 6, 2}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 2, 2}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 4, 2}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 6, 2}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 2, 2}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 4, 2}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 6, 2}, + + {TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 
2, 3}, + {TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 4, 3}, + {TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 6, 3}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 2, 3}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 4, 3}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 6, 3}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 2, 3}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 4, 3}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 6, 3}, + }; + + const std::vector problems_multiplier_of_onechunk { + {4, GemmN, LogicalElemsAPerChunk * 1, 1}, + {4, GemmN, LogicalElemsAPerChunk * 3, 1}, + {4, GemmN, LogicalElemsAPerChunk * 5, 1}, + {4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 1, 1}, + {4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 3, 1}, + {4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 5, 1}, + {4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 1, 1}, + {4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 3, 1}, + {4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 5, 1}, + + {4, GemmN, LogicalElemsAPerChunk * 1, 2}, + {4, GemmN, LogicalElemsAPerChunk * 3, 2}, + {4, GemmN, LogicalElemsAPerChunk * 5, 2}, + {4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 1, 2}, + {4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 3, 2}, + {4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 5, 2}, + {4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 1, 2}, + {4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 3, 2}, + {4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 5, 2}, + + {4, GemmN, LogicalElemsAPerChunk * 1, 3}, + {4, GemmN, LogicalElemsAPerChunk * 3, 3}, + {4, GemmN, LogicalElemsAPerChunk * 5, 3}, + {4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 1, 3}, + {4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 3, 3}, + {4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 5, 3}, + {4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 1, 3}, + {4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 3, 3}, + {4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 5, 3}, + + {32 + 4, GemmN, LogicalElemsAPerChunk * 1, 1}, + {32 + 4, GemmN, LogicalElemsAPerChunk * 3, 1}, + {32 + 4, GemmN, LogicalElemsAPerChunk * 5, 1}, + {32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 1, 1}, + {32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 3, 1}, + {32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 5, 1}, + {32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 1, 1}, + {32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 3, 1}, + {32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 5, 1}, + + {32 + 4, GemmN, LogicalElemsAPerChunk * 1, 2}, + {32 + 4, GemmN, LogicalElemsAPerChunk * 3, 2}, + {32 + 4, GemmN, LogicalElemsAPerChunk * 5, 2}, + {32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 1, 2}, + {32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 3, 2}, + {32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 5, 2}, + {32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 1, 2}, + {32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 3, 2}, + {32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 5, 2}, + + {32 + 4, GemmN, LogicalElemsAPerChunk * 1, 3}, + {32 + 4, GemmN, 
LogicalElemsAPerChunk * 3, 3}, + {32 + 4, GemmN, LogicalElemsAPerChunk * 5, 3}, + {32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 1, 3}, + {32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 3, 3}, + {32 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 5, 3}, + {32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 1, 3}, + {32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 3, 3}, + {32 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 5, 3}, + + {TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 1, 1}, + {TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 3, 1}, + {TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 5, 1}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 1, 1}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 3, 1}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 5, 1}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 1, 1}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 3, 1}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 5, 1}, + + {TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 1, 2}, + {TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 3, 2}, + {TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 5, 2}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 1, 2}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 3, 2}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 5, 2}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 1, 2}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 3, 2}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 5, 2}, + + {TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 1, 3}, + {TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 3, 3}, + {TensorEAlignmentM + 4, GemmN, LogicalElemsAPerChunk * 5, 3}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 1, 3}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 3, 3}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 5, 3}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 1, 3}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 3, 3}, + {TensorEAlignmentM + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 5, 3}, + + {TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 1, 1}, + {TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 3, 1}, + {TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 5, 1}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 1, 1}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 3, 1}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 5, 1}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 1, 1}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 3, 1}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 5, 1}, + + {TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 1, 2}, + {TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 3, 2}, + {TensorEAlignmentM * 2 + 4, 
GemmN, LogicalElemsAPerChunk * 5, 2}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 1, 2}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 3, 2}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 5, 2}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 1, 2}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 3, 2}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 5, 2}, + + {TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 1, 3}, + {TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 3, 3}, + {TensorEAlignmentM * 2 + 4, GemmN, LogicalElemsAPerChunk * 5, 3}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 1, 3}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 3, 3}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK + LogicalElemsAPerChunk * 5, 3}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 1, 3}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 3, 3}, + {TensorEAlignmentM * 2 + 4, GemmN, TensorEAlignmentK * 2 + LogicalElemsAPerChunk * 5, 3}, + }; + + // Run small only run multiplier of chunk size cases + if (run_small) { + problems.insert(problems.end(), problems_multiplier_of_tensor_e_atom.begin(), problems_multiplier_of_tensor_e_atom.end()); + } + // Run full run all corner cases + else { + problems.insert(problems.end(), problems_multiplier_of_tensor_e_atom_large.begin(), problems_multiplier_of_tensor_e_atom_large.end()); + problems.insert(problems.end(), problems_multiplier_of_tensor_e_atom.begin(), problems_multiplier_of_tensor_e_atom.end()); + problems.insert(problems.end(), problems_multiplier_of_twochunk.begin(), problems_multiplier_of_twochunk.end()); + problems.insert(problems.end(), problems_multiplier_of_onechunk.begin(), problems_multiplier_of_onechunk.end()); + } + + for (const auto& problem_shape_MNKL : problems) { + const auto [GemmM, GemmN, GemmK, GemmL] = problem_shape_MNKL; + bool passed = run({GemmM, GemmN, GemmK, GemmL}); + printf("run() (%.4d,%.4d,%.4d,%.4d) %s\n", GemmM, GemmN, GemmK, GemmL, passed ? "PASS" : "FAIL"); + CUTLASS_TRACE_HOST("run() " << GemmM << " " << GemmN << " " << GemmK << " " << GemmL << passed ? 
" PASS" : " FAIL"); + if (not passed) { + return false; + } + } + + return true; + } + + bool run(ProblemShapeType problem_shape_MNKL) + { + // Check if valid test + if (not valid_test(problem_shape_MNKL)) { + CUTLASS_TRACE_HOST("valid_test() fail\n"); + return false; + } + + // Data Storage + Data datas; + + // Initialize Data + if (not initialize(problem_shape_MNKL, datas)) { + CUTLASS_TRACE_HOST("initialize() fail\n"); + return false; + } + + // Run Compressor (Host Ref) + if (not run_host_ref(problem_shape_MNKL, datas)) { + CUTLASS_TRACE_HOST("run_host() fail\n"); + return false; + } + + // Run Compressor (Device) + if (not run_device(problem_shape_MNKL, datas)) { + CUTLASS_TRACE_HOST("run_device() fail\n"); + return false; + } + + // Verify + if (not compare_reference(datas)) { + CUTLASS_TRACE_HOST("compare_reference() DEVICE <-> LEGACY HOST fail\n"); + printf("compare_reference() DEVICE <-> LEGACY HOST fail\n"); + return false; + } + // else { + // printf("DEVICE <-> HOST PASS\n"); + // } + + return true; + } + + bool benchmark(ProblemShapeType problem_shape_MNKL) { + const auto [GemmM, GemmN, GemmK, GemmL] = problem_shape_MNKL; + printf("Benchmark() (%.4d,%.4d,%.4d,%.4d) START\n", GemmM, GemmN, GemmK, GemmL); + + // Check if valid test + if (valid_test(problem_shape_MNKL) == false) { + CUTLASS_TRACE_HOST("valid_test() fail\n"); + return false; + } + + // 2 warm-up iterations and 10 timing iterations + constexpr int num_warmup = 5; + constexpr int num_iter = 10; + + // Duplicate data to mimic cold cache + Data data[num_warmup + num_iter]; + double total_time_milliseconds{0.0}; + + for (int i = 0; i < num_warmup + num_iter; ++i ) { + printf("Benchmark() (%.4d,%.4d,%.4d,%.4d) ITER %d\n", GemmM, GemmN, GemmK, GemmL, i ); + + auto& datum_i = data[i]; + + // Initialize Data + if (initialize(problem_shape_MNKL, datum_i) == false) { + CUTLASS_TRACE_HOST("initialize() fail\n"); + return false; + } + + // Run Compressor (Device) + double time_i_milliseconds{0.0f}; + if (not run_device(problem_shape_MNKL, datum_i, &time_i_milliseconds)) { + CUTLASS_TRACE_HOST("run_device() fail\n"); + return false; + } + + if ( i >= num_warmup ) { + total_time_milliseconds += time_i_milliseconds; + } + } + + const double mean_time_milliseconds = total_time_milliseconds / num_iter; + printf("Mean time (ms): %.5f\n", mean_time_milliseconds); + + return true; + } + +public: + // Data Init Setting + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_A_Comp; + cutlass::Distribution::Kind init_E; + uint64_t seed; +}; + +} // namespace device +} // namespace transform +} // namespace test diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/arch_mappings.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/arch_mappings.h new file mode 100644 index 0000000000000000000000000000000000000000..df241e3ca6e6e584af7351402d990a8028e2abed --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/arch_mappings.h @@ -0,0 +1,156 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! + \file + + \brief CUTLASS Library is an object-oriented approach to managing operations implemented by CUTLASS. + + Generally, + + description - compile-time constant parameters used to instantiate an operation + + configuration - runtime parameters with computationally expensive initialization + + arguments - runtime parameters that may be passed to an initialized operation with low + computational overhead +*/ + +#pragma once + +#include "cutlass/arch/mma.h" +#include "cutlass/arch/arch.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template struct ArchMap; + +template <> struct ArchMap { + static int const kMin = 50; + static int const kMax = 1024; +}; + +template <> struct ArchMap { + static int const kMin = 60; + static int const kMax = 1024; +}; + +template <> struct ArchMap { + static int const kMin = 61; + static int const kMax = 1024; +}; + +template <> struct ArchMap { + static int const kMin = 70; + static int const kMax = 1024; +}; + +template <> struct ArchMap { + static int const kMin = 70; + static int const kMax = 75; +}; + +template struct ArchMap { + static int const kMin = 75; + static int const kMax = 1024; +}; + +template struct ArchMap { + static int const kMin = 80; + static int const kMax = 1024; +}; + +template struct ArchMap { + static int const kMin = 86; + static int const kMax = 1024; +}; + +template struct ArchMap { + static int const kMin = 89; + static int const kMax = 100; +}; + +template struct ArchMap { + static int const kMin = 90; + static int const kMax = 1024; +}; + +// Arch conditional WGMMA +template <> struct ArchMap { + static int const kMin = 90; + static int const kMax = 90; +}; + +// Arch conditional sparse WGMMA +template <> struct ArchMap { + static int const kMin = 90; + static int const kMax = 90; +}; + + +template struct ArchMap { + static int const kMin = 100; + static int const kMax = 1024; +}; 
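+// ----------------------------------------------------------------------------------------------
+// Illustrative use (a sketch added for clarity, not part of this header): each ArchMap entry
+// defines the inclusive compute-capability window [kMin, kMax] in which an operation built for a
+// given architecture tag (and opcode class) may run; arch-conditional entries such as the WGMMA
+// ones above pin kMin == kMax == 90, while open-ended entries use kMax = 1024. Assuming the
+// primary template takes an architecture tag and an operator class (the exact parameter list is
+// not reproduced above), a caller would typically filter operations like this:
+//
+//   template <typename ArchTag, typename OperatorClass>
+//   bool arch_in_range(int compute_capability) {   // compute_capability = 10 * major + minor
+//     return ArchMap<ArchTag, OperatorClass>::kMin <= compute_capability &&
+//            compute_capability <= ArchMap<ArchTag, OperatorClass>::kMax;
+//   }
+//
+// `arch_in_range` is a hypothetical helper name used only for illustration.
+// ----------------------------------------------------------------------------------------------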
+ +template <> struct ArchMap { + static int const kMin = 100; + #if (__CUDACC_VER_MAJOR__ >= 13) + static int const kMax = 110; + #else + static int const kMax = 103; + #endif // __CUDACC_VER_MAJOR__ >= 13 +}; + +template struct ArchMap { + static int const kMin = 103; + static int const kMax = 1024; +}; +template <> struct ArchMap { + static int const kMin = 103; + static int const kMax = 103; +}; + +template struct ArchMap { + static int const kMin = 120; + static int const kMax = 121; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/descriptions.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/descriptions.h new file mode 100644 index 0000000000000000000000000000000000000000..5e80c124e59d24cd90c7c1b0c06bcc3bedfee62f --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/descriptions.h @@ -0,0 +1,815 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once + +#include +#include +#include + +#include + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +struct MathInstructionDescription { + + /// Shape of the target math instruction + cutlass::gemm::GemmCoord instruction_shape; + + /// Describes the data type of the internal accumulator + NumericTypeID element_accumulator; + + /// Classification of math instruction + OpcodeClassID opcode_class; + + /// Type of math operation performed + MathOperationID math_operation; + + // + // Methods + // + + MathInstructionDescription( + cutlass::gemm::GemmCoord instruction_shape = cutlass::gemm::GemmCoord(), + NumericTypeID element_accumulator = NumericTypeID::kInvalid, + OpcodeClassID opcode_class = OpcodeClassID::kInvalid, + MathOperationID math_operation = MathOperationID::kMultiplyAdd + ): + instruction_shape(instruction_shape), + element_accumulator(element_accumulator), + opcode_class(opcode_class), + math_operation(math_operation) {} + + // Equality operator + inline + bool operator==(MathInstructionDescription const& rhs) const{ + return ( + (instruction_shape == rhs.instruction_shape) && + (element_accumulator == rhs.element_accumulator) && + (opcode_class == rhs.opcode_class) && + (math_operation == rhs.math_operation)); + } + + // Inequality operator + inline + bool operator!=(MathInstructionDescription const& rhs) const { + return !(*this == rhs); + } + +}; + +/// Structure describing the tiled structure of a GEMM-like computation +struct TileDescription { + + /// Describes the shape of a threadblock (in elements) + cutlass::gemm::GemmCoord threadblock_shape; + + /// Describes the number of pipeline stages in the threadblock-scoped mainloop + int threadblock_stages; + + /// Number of warps in each logical dimension + cutlass::gemm::GemmCoord warp_count; + + /// Core math instruction + MathInstructionDescription math_instruction; + + /// Minimum compute capability (e.g. 70, 75) of a device eligible to run the operation. + int minimum_compute_capability; + + /// Minimum compute capability (e.g. 70, 75) of a device eligible to run the operation. 
+ int maximum_compute_capability; + + /// Describes the shape of a cluster (in blocks) + cutlass::gemm::GemmCoord cluster_shape; + + // + // Methods + // + + TileDescription( + cutlass::gemm::GemmCoord threadblock_shape = cutlass::gemm::GemmCoord(), + int threadblock_stages = 0, + cutlass::gemm::GemmCoord warp_count = cutlass::gemm::GemmCoord(), + MathInstructionDescription math_instruction = MathInstructionDescription(), + int minimum_compute_capability = 0, + int maximum_compute_capability = 0, + cutlass::gemm::GemmCoord cluster_shape = cutlass::gemm::GemmCoord(1,1,1) + ): + threadblock_shape(threadblock_shape), + threadblock_stages(threadblock_stages), + warp_count(warp_count), + math_instruction(math_instruction), + minimum_compute_capability(minimum_compute_capability), + maximum_compute_capability(maximum_compute_capability), + cluster_shape(cluster_shape) { } + + // Equality operator + inline + bool operator==(TileDescription const& rhs) const{ + return ( + (threadblock_shape == rhs.threadblock_shape) && + (threadblock_stages == rhs.threadblock_stages) && + (warp_count == rhs.warp_count) && + (math_instruction == rhs.math_instruction) && + (minimum_compute_capability == rhs.minimum_compute_capability) && + (maximum_compute_capability == rhs.maximum_compute_capability)); + } + + // Inequality operator + inline + bool operator!=(TileDescription const& rhs) const { + return !(*this == rhs); + } +}; + +/// High-level description of an operation +struct OperationDescription { + + /// Unique identifier describing the operation + char const * name; + + /// Operation provider + Provider provider; + + /// Kind of operation + OperationKind kind; + + /// Describes the tiled structure of a GEMM-like computation + TileDescription tile_description; + + // + // Methods + // + OperationDescription( + char const * name = "unknown", + Provider provider = Provider::kInvalid, + OperationKind kind = OperationKind::kInvalid, + TileDescription const& tile_description = TileDescription() + ): + name(name), provider(provider), kind(kind), tile_description(tile_description) { } +}; + +/// Structure describing the properties of a tensor +struct TensorDescription { + + /// Numeric type of an individual element + NumericTypeID element; + + /// Enumerant identifying the layout function for the tensor + LayoutTypeID layout; + + /// Alignment restriction on pointers, strides, and extents + int alignment; + + /// log2() of the maximum extent of each dimension + int log_extent_range; + + /// log2() of the maximum value each relevant stride may have + int log_stride_range; + + // + // Methods + // + + TensorDescription( + NumericTypeID element = NumericTypeID::kInvalid, + LayoutTypeID layout = LayoutTypeID::kInvalid, + int alignment = 1, + int log_extent_range = 24, + int log_stride_range = 24 + ): + element(element), + layout(layout), + alignment(alignment), + log_extent_range(log_extent_range), + log_stride_range(log_stride_range) { } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Description of all GEMM computations +struct GemmDescription : public OperationDescription { + + /// Indicates the kind of GEMM performed + GemmKind gemm_kind; + + /// Describes the A operand + TensorDescription A; + + /// Describes the B operand + TensorDescription B; + + /// Describes the source matrix + TensorDescription C; + + /// Describes the destination matrix + TensorDescription D; + + /// Describes the sparse meta matrices + TensorDescription E; + + /// Describes 
the data type of the scalars passed to the epilogue + NumericTypeID element_epilogue; + + /// Describes the structure of parallel reductions + SplitKMode split_k_mode; + + /// Transformation on A operand + ComplexTransform transform_A; + + /// Transformation on B operand + ComplexTransform transform_B; + + // + // Methods + // + + GemmDescription( + GemmKind gemm_kind = GemmKind::kGemm, + TensorDescription const& A = TensorDescription(), + TensorDescription const& B = TensorDescription(), + TensorDescription const& C = TensorDescription(), + TensorDescription const& D = TensorDescription(), + NumericTypeID element_epilogue = NumericTypeID::kInvalid, + SplitKMode split_k_mode = SplitKMode::kNone, + ComplexTransform transform_A = ComplexTransform::kNone, + ComplexTransform transform_B = ComplexTransform::kNone + ): + gemm_kind(gemm_kind), + A(A), + B(B), + C(C), + D(D), + element_epilogue(element_epilogue), + split_k_mode(split_k_mode), + transform_A(transform_A), + transform_B(transform_B) {} + + GemmDescription( + OperationDescription op_desc, + GemmKind gemm_kind, + TensorDescription const& A, + TensorDescription const& B, + TensorDescription const& C, + TensorDescription const& D, + NumericTypeID element_epilogue, + SplitKMode split_k_mode, + ComplexTransform transform_A, + ComplexTransform transform_B + ): + OperationDescription(op_desc), + gemm_kind(gemm_kind), + A(A), + B(B), + C(C), + D(D), + element_epilogue(element_epilogue), + split_k_mode(split_k_mode), + transform_A(transform_A), + transform_B(transform_B) {} +}; + +struct BlockScaleDescription { + /// Describes the SFA operand + TensorDescription SFA; + + /// Describes the SFB operand + TensorDescription SFB; + + /// Describes the SFD operand + TensorDescription SFD; + + /// Describes the input ScaleFactor VectorSize + int SFMVecSize; + int SFNVecSize; + int SFKVecSize; + + /// Describes the Output ScaleFactor VectorSize + int EpilogueSFVecSize; + + /// Describes the underlying kind of scaling: + /// Tensor Core supported (BlockScaled) or manual scaling (Blockwise) + OperationKind kind; +}; + +struct GroupedGemmDescription : public OperationDescription { + GemmDescription gemm; + std::optional block_scales; +}; + +/// Description of all GEMM computations +struct BlockScaledGemmDescription : public OperationDescription { + + /// Indicates the kind of GEMM performed + GemmKind gemm_kind; + + /// Describes the A operand + TensorDescription A; + + /// Describes the B operand + TensorDescription B; + + /// Describes the source matrix + TensorDescription C; + + /// Describes the destination matrix + TensorDescription D; + + /// Describes the SFA operand + TensorDescription SFA; + + /// Describes the SFB operand + TensorDescription SFB; + + /// Describes the SFD operand + TensorDescription SFD; + + /// Describes the data type of the scalars passed to the epilogue + NumericTypeID element_epilogue; + + /// Describes the structure of parallel reductions + SplitKMode split_k_mode; + + /// Transformation on A operand + ComplexTransform transform_A; + + /// Transformation on B operand + ComplexTransform transform_B; + + /// Describes the input ScaleFactor VectorSize + int SFVecSize; + + /// Describes the Output ScaleFactor VectorSize + int EpilogueSFVecSize; + + // + // Methods + // + + BlockScaledGemmDescription( + GemmKind gemm_kind = GemmKind::kGemm, + TensorDescription const& A = TensorDescription(), + TensorDescription const& B = TensorDescription(), + TensorDescription const& C = TensorDescription(), + TensorDescription const& D = 
TensorDescription(), + NumericTypeID element_epilogue = NumericTypeID::kInvalid, + SplitKMode split_k_mode = SplitKMode::kNone, + ComplexTransform transform_A = ComplexTransform::kNone, + ComplexTransform transform_B = ComplexTransform::kNone + ): + gemm_kind(gemm_kind), + A(A), + B(B), + C(C), + D(D), + element_epilogue(element_epilogue), + split_k_mode(split_k_mode), + transform_A(transform_A), + transform_B(transform_B) {} + + BlockScaledGemmDescription( + OperationDescription op_desc, + GemmKind gemm_kind, + TensorDescription const& A, + TensorDescription const& B, + TensorDescription const& C, + TensorDescription const& D, + NumericTypeID element_epilogue, + SplitKMode split_k_mode, + ComplexTransform transform_A, + ComplexTransform transform_B + ): + OperationDescription(op_desc), + gemm_kind(gemm_kind), + A(A), + B(B), + C(C), + D(D), + element_epilogue(element_epilogue), + split_k_mode(split_k_mode), + transform_A(transform_A), + transform_B(transform_B) {} +}; + +/// Description of all GEMM computations +struct BlockwiseGemmDescription : public OperationDescription { + + /// Indicates the kind of GEMM performed + GemmKind gemm_kind; + + /// Describes the A operand + TensorDescription A; + + /// Describes the B operand + TensorDescription B; + + /// Describes the source matrix + TensorDescription C; + + /// Describes the destination matrix + TensorDescription D; + + /// Describes the SFA operand + TensorDescription SFA; + + /// Describes the SFB operand + TensorDescription SFB; + + /// Describes the data type of the scalars passed to the epilogue + NumericTypeID element_epilogue; + + /// Describes the structure of parallel reductions + SplitKMode split_k_mode; + + /// Transformation on A operand + ComplexTransform transform_A; + + /// Transformation on B operand + ComplexTransform transform_B; + + /// Describes the input ScaleFactor VectorSize + int SFMVecSize; + int SFNVecSize; + int SFKVecSize; + + // + // Methods + // + + BlockwiseGemmDescription( + GemmKind gemm_kind = GemmKind::kGemm, + TensorDescription const& A = TensorDescription(), + TensorDescription const& B = TensorDescription(), + TensorDescription const& C = TensorDescription(), + TensorDescription const& D = TensorDescription(), + NumericTypeID element_epilogue = NumericTypeID::kInvalid, + SplitKMode split_k_mode = SplitKMode::kNone, + ComplexTransform transform_A = ComplexTransform::kNone, + ComplexTransform transform_B = ComplexTransform::kNone + ): + gemm_kind(gemm_kind), + A(A), + B(B), + C(C), + D(D), + element_epilogue(element_epilogue), + split_k_mode(split_k_mode), + transform_A(transform_A), + transform_B(transform_B) {} + + BlockwiseGemmDescription( + OperationDescription op_desc, + GemmKind gemm_kind, + TensorDescription const& A, + TensorDescription const& B, + TensorDescription const& C, + TensorDescription const& D, + NumericTypeID element_epilogue, + SplitKMode split_k_mode, + ComplexTransform transform_A, + ComplexTransform transform_B + ): + OperationDescription(op_desc), + gemm_kind(gemm_kind), + A(A), + B(B), + C(C), + D(D), + element_epilogue(element_epilogue), + split_k_mode(split_k_mode), + transform_A(transform_A), + transform_B(transform_B) {} +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Description for structured sparse GEMMs. 
+struct SparseGemmDescription : public GemmDescription { + + /// Description structure for structured sparse GEMM + SparseGemmDescription( + GemmKind gemm_kind = GemmKind::kGemm, + TensorDescription const& A = TensorDescription(), + TensorDescription const& B = TensorDescription(), + TensorDescription const& C = TensorDescription(), + TensorDescription const& D = TensorDescription(), + TensorDescription const& E = TensorDescription(), + NumericTypeID element_epilogue = NumericTypeID::kInvalid, + SplitKMode split_k_mode = SplitKMode::kNone, + ComplexTransform transform_A = ComplexTransform::kNone, + ComplexTransform transform_B = ComplexTransform::kNone + ): + GemmDescription(gemm_kind, A, B, C, D, element_epilogue, split_k_mode, transform_A, transform_B) + {this->E = E;} +}; + +/// Description of all Reduction operations +struct ReductionDescription : public OperationDescription { + + /// Describes the data type of workspace + NumericTypeID element_workspace; + + /// Describes the data type of final output + NumericTypeID element_output; + + /// Describes the data type of the scalars passed to the epilogue + NumericTypeID element_epilogue; +}; + +/// Description of all Rank K update computations (SYRK, HERK, SYR2K, HER2K) +struct RankKDescription : public OperationDescription { + + /// Indicates which device template is used (universal or regular) + RankKKind rank_k_kind; + + /// Number of rank update (rank k or rank 2k) + int num_ranks; + + /// Describes the A operand + TensorDescription A; + + /// Describes the B operand (used only for SYR2K and HER2K) + TensorDescription B; + + /// Describes the source and destination matrices + TensorDescription C; + + /// Describes the fill mode for matrix C + FillMode fill_mode; + + /// Describes the blas mode (symmetric/hermitian) + BlasMode blas_mode; + + /// Describes the data type of the scalars passed to the epilogue + NumericTypeID element_epilogue; + + /// Describes the structure of parallel reductions + SplitKMode split_k_mode; + + /// Transformation on A operand + ComplexTransform transform_A; + + /// Transformation on B operand + ComplexTransform transform_B; + + // + // Methods + // + + RankKDescription( + RankKKind rank_k_kind = RankKKind::kUniversal, + int num_ranks = 1, + TensorDescription const& A = TensorDescription(), + TensorDescription const& B = TensorDescription(), + TensorDescription const& C = TensorDescription(), + FillMode fill_mode = FillMode::kInvalid, + BlasMode blas_mode = BlasMode::kInvalid, + NumericTypeID element_epilogue = NumericTypeID::kInvalid, + SplitKMode split_k_mode = SplitKMode::kNone, + ComplexTransform transform_A = ComplexTransform::kNone, + ComplexTransform transform_B = ComplexTransform::kNone + ): + rank_k_kind(rank_k_kind), + num_ranks(num_ranks), + A(A), + B(B), + C(C), + fill_mode(fill_mode), + blas_mode(blas_mode), + element_epilogue(element_epilogue), + split_k_mode(split_k_mode), + transform_A(transform_A), + transform_B(transform_B) {} +}; +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Description of all TRMM computations +struct TrmmDescription : public OperationDescription { + + /// Indicates the kind of TRMM performed + TrmmKind trmm_kind; + + /// Describes the A operand + TensorDescription A; + + /// Describes the side mode for matrix A + SideMode side_mode; + + /// Describes the fill mode for matrix A + FillMode fill_mode; + + /// Describes the diag type for matrix A + DiagType diag_type; + + /// Describes the B operand + 
TensorDescription B; + + /// Describes the source and destination matrices + TensorDescription D; + + /// Describes the data type of the scalars passed to the epilogue + NumericTypeID element_epilogue; + + /// Describes the structure of parallel reductions + SplitKMode split_k_mode; + + /// Transformation on A operand + ComplexTransform transform_A; + + // + // Methods + // + + TrmmDescription( + TrmmKind trmm_kind = TrmmKind::kUniversal, + TensorDescription const& A = TensorDescription(), + SideMode side_mode = SideMode::kInvalid, + FillMode fill_mode = FillMode::kInvalid, + DiagType diag_type = DiagType::kInvalid, + TensorDescription const& B = TensorDescription(), + TensorDescription const& D = TensorDescription(), + NumericTypeID element_epilogue = NumericTypeID::kInvalid, + SplitKMode split_k_mode = SplitKMode::kNone, + ComplexTransform transform_A = ComplexTransform::kNone + ): + trmm_kind(trmm_kind), + A(A), + side_mode(side_mode), + fill_mode(fill_mode), + diag_type(diag_type), + B(B), + D(D), + element_epilogue(element_epilogue), + split_k_mode(split_k_mode), + transform_A(transform_A) {} +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Description of all SYMM/HEMM update computations +struct SymmDescription : public OperationDescription { + + /// Indicates which device template is used (universal or regular) + SymmKind symm_kind; + + /// Describes the A operand + TensorDescription A; + + /// Describes the B operand + TensorDescription B; + + /// Describes the source and destination matrices + TensorDescription C; + + /// Describes the side mode for matrix A + SideMode side_mode; + + /// Describes the fill mode for matrix A + FillMode fill_mode; + + /// Describes the blas mode (symmetric/hermitian) + BlasMode blas_mode; + + /// Describes the data type of the scalars passed to the epilogue + NumericTypeID element_epilogue; + + /// Describes the structure of parallel reductions + SplitKMode split_k_mode; + + /// Transformation on A operand + ComplexTransform transform_A; + + /// Transformation on B operand + ComplexTransform transform_B; + + // + // Methods + // + + SymmDescription( + SymmKind symm_kind = SymmKind::kUniversal, + TensorDescription const& A = TensorDescription(), + TensorDescription const& B = TensorDescription(), + TensorDescription const& C = TensorDescription(), + SideMode side_mode = SideMode::kInvalid, + FillMode fill_mode = FillMode::kInvalid, + BlasMode blas_mode = BlasMode::kInvalid, + NumericTypeID element_epilogue = NumericTypeID::kInvalid, + SplitKMode split_k_mode = SplitKMode::kNone, + ComplexTransform transform_A = ComplexTransform::kNone, + ComplexTransform transform_B = ComplexTransform::kNone + ): + symm_kind(symm_kind), + A(A), + B(B), + C(C), + side_mode(side_mode), + fill_mode(fill_mode), + blas_mode(blas_mode), + element_epilogue(element_epilogue), + split_k_mode(split_k_mode), + transform_A(transform_A), + transform_B(transform_B) {} +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Description of all Conv2d operations +struct ConvDescription : public OperationDescription { + /// Describes the convolution dimension support (2D or 3D) + int conv_dim; + + /// Describes the kind of convolution + ConvKind conv_kind; + + /// Describes the type of iterator algorithm (analytic or precomputed) + IteratorAlgorithmID 
iterator_algorithm; + + /// Describes the A operand + TensorDescription A; + + /// Describes the B operand + TensorDescription B; + + /// Describes the C operand + TensorDescription C; + + /// Describes the data type of the scalars passed to the epilogue + NumericTypeID element_epilogue; + + // + // Methods + // + // Returns Activation TensorDescription + TensorDescription activation() const { + switch(conv_kind) { + case library::ConvKind::kFprop : return A; + case library::ConvKind::kDgrad : return C; + case library::ConvKind::kWgrad : return B; + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + + // Returns Filter TensorDescription + TensorDescription filter() const { + switch(conv_kind) { + case library::ConvKind::kFprop : return B; + case library::ConvKind::kDgrad : return B; + case library::ConvKind::kWgrad : return C; + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + + // Returns Output TensorDescription + TensorDescription output() const { + switch(conv_kind) { + case library::ConvKind::kFprop : return C; + case library::ConvKind::kDgrad : return A; + case library::ConvKind::kWgrad : return A; + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/handle.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/handle.h new file mode 100644 index 0000000000000000000000000000000000000000..027944eb6ac8c6e8f250d83ed33c0899adfbd3e8 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/handle.h @@ -0,0 +1,365 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
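// Illustrative only, not part of the vendored diff: the operand mapping encoded
// by ConvDescription::activation()/filter()/output() above, summarized per
// convolution kind (derived directly from those switch statements):
//   kFprop:  A = activation,       B = filter,      C/D = output
//   kDgrad:  A = output gradient,  B = filter,      C/D = activation gradient
//   kWgrad:  A = output gradient,  B = activation,  C/D = filter gradient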
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief BLAS-like handle used to launch operations on the CUDA device. +*/ + +#pragma once + +#include +#include "cutlass/library/library.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Handle object +class Handle { +private: + + /// Host workspace + static int const kHostWorkspaceSize = (4 << 10); + + /// Provider of operations + Provider provider_; + + /// CUDA device properties + cudaDeviceProp device_; + + /// CUDA stream + cudaStream_t stream_; + + /// Device workspace + void *workspace_; + + /// Size of device workspace in bytes + size_t workspace_size_; + + /// Indicates whether scalars are host or device pointers + ScalarPointerMode scalar_pointer_mode_; + + /// Pointer to the most recently executed operation + Operation const *last_operation_; + + int device_idx_; + +public: + + /// Constructor + Handle(cudaStream_t stream = nullptr, size_t workspace_size = (4<<20)); + + /// Destructor + ~Handle(); + + /// Move constructor + Handle(Handle && handle); + + /// Move assignment operator + Handle &operator=(Handle && handle); + + // + // Persistent state accessors + // + + /// Returns compute capability of the selected device + int compute_capability() const; + + /// Sets the current CUDA stream + void set_stream(cudaStream_t stream); + + /// Gets the current CUDA stream + cudaStream_t get_stream() const; + + /// Gets the current provider + Provider get_provider() const; + + /// Sets the provider of operations + void set_provider(Provider provider); + + /// Gets the device workspace size + size_t get_workspace_size() const; + + /// Gets a pointer to the device workspace allocation in Global Memory + void *get_workspace() const; + + /// Sets the size of device workspace, invalidating calls to get_device_workspace() + void set_workspace_size(size_t bytes); + + /// Gets the scalar pointer mode + ScalarPointerMode get_scalar_pointer_mode() const; + + /// Sets the scalar pointer mode + void set_scalar_pointer_mode(ScalarPointerMode mode); + + /// Gets the most recently executed operation + Operation const *get_last_operation() const; + + // + // Computations + // + + /// Executes a GEMM computation: D <= alpha * A*B + beta * C + Status gemm( + + int M, /// GEMM M dimension + int N, /// GEMM N dimension + int K, /// GEMM K dimension + + NumericTypeID element_compute, /// Data type of internal accumulation + + NumericTypeID element_scalar, /// Data type of alpha/beta scalars + + void const *alpha, /// Pointer to alpha scalar + + NumericTypeID element_A, /// Data type of A matrix elements + LayoutTypeID layout_A, /// Layout of A matrix + ComplexTransform transform_A, /// Complex transformation applied to A matrix - ignored for real-valued 
matrices + + void const * ptr_A, /// Pointer to A matrix in Global Memory + int64_t lda, /// Leading dimension of A matrix + + NumericTypeID element_B, /// Data type of B matrix elements + LayoutTypeID layout_B, /// Layout of B matrix + ComplexTransform transform_B, /// Complex transformation applied to B matrix - ignored for real-valued matrices + + void const * ptr_B, /// Pointer to B matrix in Global Memory + int64_t ldb, /// Leading dimension of B matrix + + void const * beta, /// Pointer to beta scalar + + NumericTypeID element_C, /// Data type of C and D matrices + + void const * ptr_C, /// Pointer to C matrix + int64_t ldc, /// Leading dimension of C matrix + + void * ptr_D, /// Pointer to D matrix + int64_t ldd /// Leading dimension of D matrix + ); + + /// Executes a GEMM computation: D <= alpha * A*B + beta * C. + // + // Supports batched-strided, batched array or split-K serial or split-K parallel. + // + Status gemm_universal( + + GemmUniversalMode mode, /// indicates the mode in which the kUniversal GEMM is launched + + int M, /// GEMM M dimension + int N, /// GEMM N dimension + int K, /// GEMM K dimension + + int cluster_m, /// cluster shape M dimension + int cluster_n, /// cluster shape N dimension + int cluster_k, /// cluster shape K dimension + int cluster_m_fallback, /// Fallback cluster shape M dimension + int cluster_n_fallback, /// Fallback cluster shape N dimension + int cluster_k_fallback, /// Fallback cluster shape K dimension + + + NumericTypeID element_compute, /// Data type of internal accumulation + + NumericTypeID element_scalar, /// Data type of alpha/beta scalars + + void const *alpha, /// Pointer to alpha scalar + + NumericTypeID element_A, /// Data type of A matrix elements + LayoutTypeID layout_A, /// Layout of A matrix + ComplexTransform transform_A, /// Complex transformation applied to A matrix - ignored for real-valued matrices + void const * ptr_A, /// Pointer to A matrix in Global Memory + int64_t lda, /// Leading dimension of A matrix + + NumericTypeID element_B, /// Data type of B matrix elements + LayoutTypeID layout_B, /// Layout of B matrix + ComplexTransform transform_B, /// Complex transformation applied to B matrix - ignored for real-valued matrices + void const * ptr_B, /// Pointer to B matrix in Global Memory + int64_t ldb, /// Leading dimension of B matrix + + void const * beta, /// Pointer to beta scalar + + NumericTypeID element_C, /// Data type of C matrix + LayoutTypeID layout_C, /// Layout of D matrix + void const * ptr_C, /// Pointer to C matrix + int64_t ldc, /// Leading dimension of C matrix + + NumericTypeID element_D, /// Data type of D matrix + LayoutTypeID layout_D, /// Layout of D matrix + void * ptr_D, /// Pointer to D matrix + int64_t ldd, /// Leading dimension of D matrix + + int batch_count = 1, /// Batch count or number of split-K slices + + int64_t batch_stride_A = 0, /// Batch stride of A operand + int64_t batch_stride_B = 0, /// Batch stride of B operand + int64_t batch_stride_C = 0, /// Batch stride of C operand + int64_t batch_stride_D = 0 /// Batch stride of D operand + ); + + /// Planar complex GEMM + /// + /// Note, all data types are the real-valued base types used by the planar-complex GEMM kernel. 
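  // Illustrative only, not part of the vendored diff: a hedged sketch of a
  // single-precision, column-major call to the gemm() entry point declared
  // above (D = alpha * A*B + beta * C); the enumerator spellings are
  // assumptions based on the library's type identifiers.
  //
  //   cutlass::library::Handle handle;
  //   float alpha = 1.0f, beta = 0.0f;
  //   handle.gemm(M, N, K,
  //               NumericTypeID::kF32,            // internal accumulation
  //               NumericTypeID::kF32, &alpha,    // scalar type, alpha
  //               NumericTypeID::kF32, LayoutTypeID::kColumnMajor,
  //               ComplexTransform::kNone, ptr_A, lda,
  //               NumericTypeID::kF32, LayoutTypeID::kColumnMajor,
  //               ComplexTransform::kNone, ptr_B, ldb,
  //               &beta,
  //               NumericTypeID::kF32, ptr_C, ldc,
  //               ptr_D, ldd);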
+ /// + Status gemm_planar_complex( + + int M, /// GEMM M dimension + int N, /// GEMM N dimension + int K, /// GEMM K dimension + + NumericTypeID element_compute, /// Data type of internal accumulation + + NumericTypeID element_scalar, /// Data type of alpha/beta scalars + + void const *alpha, /// Pointer to alpha scalar + + NumericTypeID element_A, /// Data type of A matrix elements + LayoutTypeID layout_A, /// Layout of A matrix + ComplexTransform transform_A, /// Complex transformation applied to A matrix + + void const * ptr_A_real, /// Pointer to real part of A matrix + void const * ptr_A_imag, /// Pointer to imaginary part of A matrix + int64_t lda_real, /// Leading dimension of real part of A matrix + int64_t lda_imag, /// Leading dimension of imaginary part of A matrix + + NumericTypeID element_B, /// Data type of B matrix elements + LayoutTypeID layout_B, /// Layout of B matrix + ComplexTransform transform_B, /// Complex transformation applied to B matrix + + void const * ptr_B_real, /// Pointer to real part of B matrix + void const * ptr_B_imag, /// Pointer to imaginary part of B matrix + int64_t ldb_real, /// Leading dimension of real part of B matrix + int64_t ldb_imag, /// Leading dimension of imaginary part of B matrix + + void const * beta, /// Pointer to beta scalar + + NumericTypeID element_C, /// Data type of C and D matrix + + void const * ptr_C_real, /// Pointer to real part of C matrix + void const * ptr_C_imag, /// Pointer to imaginary part of C matrix + int64_t ldc_real, /// Leading dimension of real part of C matrix + int64_t ldc_imag, /// Leading dimension of imaginary part of C matrix + + void * ptr_D_real, /// Pointer to real part of D matrix + void * ptr_D_imag, /// Pointer to imaginary part of D matrix + int64_t ldd_real, /// Leading dimension of real part of D matrix + int64_t ldd_imag, /// Leading dimension of imaginary part of D matrix + + int batch_count = 1, /// Number of batched GEMMs to execute + + int64_t batch_stride_A_real = 0, + int64_t batch_stride_A_imag = 0, + + int64_t batch_stride_B_real = 0, + int64_t batch_stride_B_imag = 0, + + int64_t batch_stride_C_real = 0, + int64_t batch_stride_C_imag = 0, + + int64_t batch_stride_D_real = 0, + int64_t batch_stride_D_imag = 0 + ); + + /// Planar complex GEMM loading pointers from arrays in global memory + Status gemm_planar_complex_array( + + int expected_M, /// Expected GEMM M dimension (used for sizing CUDA grid) + int expected_N, /// Expected GEMM N dimension (used for sizing CUDA grid) + int expected_K, /// Expected GEMM K dimension + int batch_count, /// Number of independent GEMM computations to execute + + int const *M, /// Array containing the GEMM M dimension for each batch index + int const *N, /// Array containing the GEMM N dimension for each batch index + int const *K, /// Array containing the GEMM K dimension for each batch index + + NumericTypeID element_compute, /// Data type of internal accumulation + + NumericTypeID element_scalar, /// Data type of alpha/beta scalars + + void const *alpha, /// Pointer to alpha scalar + + NumericTypeID element_A, /// Data type of A matrix elements + LayoutTypeID layout_A, /// Layout of A matrix + ComplexTransform transform_A, /// Complex transformation applied to A matrix + + void const * const * ptr_A_real, /// Pointer to array containing pointers to real part of A matrices + void const * const * ptr_A_imag, /// Pointer to array containing pointers to imaginary part of A matrices + + int64_t lda_real, /// Leading dimension of real part of A matrix + 
int64_t lda_imag, /// Leading dimension of imaginary part of A matrix + + NumericTypeID element_B, /// Data type of B matrix elements + LayoutTypeID layout_B, /// Layout of B matrix + ComplexTransform transform_B, /// Complex transformation applied to B matrix + + void const * const * ptr_B_real, /// Pointer to array containing pointers to real part of B matrices + void const * const * ptr_B_imag, /// Pointer to array containing pointers to imaginary part of B matrices + + int64_t ldb_real, /// Leading dimension of real part of B matrix + int64_t ldb_imag, /// Leading dimension of imaginary part of B matrix + + void const * beta, /// Pointer to beta scalar + + NumericTypeID element_C, /// Data type of C and D matrix + + void const * const * ptr_C_real, /// Pointer to array containing pointers to real part of C matrices + void const * const * ptr_C_imag, /// Pointer to array containing pointers to imaginary part of C matrices + + int64_t ldc_real, /// Leading dimension of real part of C matrix + int64_t ldc_imag, /// Leading dimension of imaginary part of C matrix + + void * const * ptr_D_real, /// Pointer to array containing pointers to real part of D matrices + void * const * ptr_D_imag, /// Pointer to array containing pointers to imaginary part of D matrices + + int64_t ldd_real, /// Leading dimension of real part of D matrix + int64_t ldd_imag /// Leading dimension of imaginary part of D matrix + ); + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Unique pointer storing the handle +using HandlePtr = std::unique_ptr; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Finds conv2d operation instances with Conv2d::ElementC = Reduction::ElementWorkspace +Operation const* find_conv_operation_for_parallel_reduction(Operation const *operation); +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Finds gemm operation instances with ElementC = Reduction::ElementWorkspace +Operation const* find_gemm_operation_for_parallel_reduction(Operation const *operation); +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/library.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/library.h new file mode 100644 index 0000000000000000000000000000000000000000..6764d9a6d81286c8bba0f5184b17819bfae86978 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/library.h @@ -0,0 +1,995 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! + \file + + \brief CUTLASS Library is an object-oriented approach to managing operations implemented by CUTLASS. + + Generally, + + description - compile-time constant parameters used to instantiate an operation + + configuration - runtime parameters with computationally expensive initialization + + arguments - runtime parameters that may be passed to an initialized operation with low + computational overhead +*/ + +#ifndef CUTLASS_LIBRARY_LIBRARY_H +#define CUTLASS_LIBRARY_LIBRARY_H + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#include +#include +#include +#include +#include + +#include "cutlass/cutlass.h" +#include "cutlass/library/types.h" +#include "cutlass/library/descriptions.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/tensor_coord.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/blas3.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv2d_problem_size.h" +#include "cutlass/conv/conv3d_problem_size.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Mode of Universal GEMM +using GemmUniversalMode = cutlass::gemm::GemmUniversalMode; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Base class for all operations +class Operation { +public: + + virtual ~Operation() { } + + virtual OperationDescription const & description() const = 0; + + virtual Status can_implement( + void const *configuration, + void const *arguments) const = 0; + + virtual uint64_t get_host_workspace_size( + void const *configuration) const = 0; + + virtual uint64_t get_device_workspace_size( + void const *configuration, + void const *arguments = nullptr) const = 0; + + virtual Status initialize( + void const *configuration, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const = 0; + + // Originally designed for metadata, but should be useful for FP8/6/4 too. 
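  // Illustrative only, not part of the vendored diff: the typical lifecycle of
  // the virtuals declared above, assuming `op` points at an Operation and
  // `config`/`args` are the matching *Configuration / *Arguments structs
  // defined later in this header:
  //   1. op->can_implement(&config, &args)        -> check Status::kSuccess
  //   2. op->get_host_workspace_size(&config) and
  //      op->get_device_workspace_size(&config)   -> allocate workspaces
  //   3. op->initialize(&config, host_ws, device_ws, stream)  -> expensive, done once
  //   4. op->run(&args, host_ws, device_ws, stream)           -> cheap, repeatable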
+ virtual Status initialize_with_profiler_workspace( + void const *configuration, + void *host_workspace, + void *device_workspace, + uint8_t **profiler_workspace_ptrs, + int problem_count, + cudaStream_t stream = nullptr) { + return Status::kErrorNotSupported; + } + + virtual Status run( + void const *arguments, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const = 0; + + // Set arguments that should only be set once before verifying or profiling the kernel. + // This should encompass any expensive operations that don't vary from run to run + // (e.g., max_active_clusters). + virtual Status initialize_with_arguments(void* arguments_ptr) const { + return Status::kSuccess; + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Configuration for basic GEMM operations +// +// OperationKind: Gemm +// GemmKind: Gemm +// +struct GemmConfiguration { + + /// GEMM problem size + gemm::GemmCoord problem_size{}; + + /// Leading dimension of A matrix + int64_t lda{0}; + + /// Leading dimension of B matrix + int64_t ldb{0}; + + /// Leading dimension of C matrix + int64_t ldc{0}; + + /// Leading dimension of D matrix + int64_t ldd{0}; + + /// Number of partitions of K dimension + int split_k_slices{0}; +}; + +/// Arguments for GEMM +struct GemmArguments { + + /// Pointer to A matrix + void const *A{nullptr}; + + /// Pointer to B matrix + void const *B{nullptr}; + + /// Pointer to C matrix + void const *C{nullptr}; + + /// Pointer to D matrix + void *D{nullptr}; + + /// Host or device pointer to alpha scalar + void const *alpha{nullptr}; + + /// Host or device pointer to beta scalar + void const *beta{nullptr}; + + /// Enumerant indicating whether alpha/beta point to host or device memory + ScalarPointerMode pointer_mode{}; + + /// Whether to use PDL when launching the kernel + bool use_pdl{false}; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Configuration for batched GEMM in which multiple matrix products are computed +// +// OperationKind: Gemm +// GemmKind: Batched + +struct GemmBatchedConfiguration { + + /// GEMM problem size + gemm::GemmCoord problem_size{}; + + /// Leading dimension of A matrix + int64_t lda{0}; + + /// Leading dimension of B matrix + int64_t ldb{0}; + + /// Leading dimension of C matrix + int64_t ldc{0}; + + /// Leading dimension of D matrix + int64_t ldd{0}; + + /// Stride between instances of the A matrix in memory + int64_t batch_stride_A{0}; + + /// Stride between instances of the B matrix in memory + int64_t batch_stride_B{0}; + + /// Stride between instances of the C matrix in memory + int64_t batch_stride_C{0}; + + /// Stride between instances of the D matrix in memory + int64_t batch_stride_D{0}; + + /// Number of GEMMs in batch + int batch_count{1}; +}; + +/// Arguments to batched GEMM +using GemmBatchedArguments = GemmArguments; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Configuration for batched GEMM in which multiple matrix products are computed +// +// OperationKind: Gemm +// GemmKind: Array + +struct GemmArrayConfiguration { + + gemm::GemmCoord problem_size{}; + + /// Leading dimension of A matrix + int64_t lda{0}; + + /// Leading dimension of B matrix + int64_t ldb{0}; + + /// Leading dimension of C matrix + int64_t ldc{0}; + + /// Leading dimension of D matrix + int64_t ldd{0}; + + int batch_count{1}; +}; + +/// Arguments for 
GEMM - used by all the GEMM operations +struct GemmArrayArguments { + void const * const *A{nullptr}; + void const * const *B{nullptr}; + void const * const *C{nullptr}; + void * const *D{nullptr}; + void const *alpha{nullptr}; + void const *beta{nullptr}; + ScalarPointerMode pointer_mode{}; + bool use_pdl{false}; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Universal GEMM supporting multiple split-K modes, multiple batched modes, real and complex +// +// OperationKind: Gemm +// GemmKind: Universal + +struct GemmUniversalConfiguration { + + GemmUniversalMode mode{GemmUniversalMode::kGemm}; + gemm::GemmCoord problem_size{}; + gemm::GemmCoord cluster_shape{}; + gemm::GemmCoord cluster_shape_fallback{}; + int batch_count{1}; + + int64_t lda{0}; + int64_t ldb{0}; + int64_t ldc{0}; + int64_t ldd{0}; + + int device_count{1}; +}; + +enum class Sm90MixedInputWiderOperand { + A = 0, + B = 1 +}; + +struct GemmUniversalArguments { + // NOTE: these are replicated for 3.0 interfaces + gemm::GemmCoord problem_size{}; + gemm::GemmCoord cluster_shape{}; + gemm::GemmCoord cluster_shape_fallback{}; + int batch_count{1}; + + void const *A{nullptr}; + void const *B{nullptr}; + void const *C{nullptr}; + void *D{nullptr}; + + void const *alpha{nullptr}; + void const *beta{nullptr}; + ScalarPointerMode pointer_mode{}; + + // NOTE: these are replicated for 3.0 interfaces + int64_t lda{0}; + int64_t ldb{0}; + int64_t ldc{0}; + int64_t ldd{0}; + + int64_t batch_stride_A{0}; + int64_t batch_stride_B{0}; + int64_t batch_stride_C{0}; + int64_t batch_stride_D{0}; + + // Needed for some 3.x kernels + int sm_count{0}; + library::RasterOrder raster_order{}; + library::RuntimeDatatype runtime_input_datatype_a{}; + library::RuntimeDatatype runtime_input_datatype_b{}; + int swizzle_size{1}; + int split_k_slices{1}; + + // For SM90 mixed input dtype kernels + bool is_sm90_mixed_dtype{false}; + Sm90MixedInputWiderOperand wider_operand{Sm90MixedInputWiderOperand::B}; + bool generate_scale_and_zero{false}; + bool generate_dequantized_AB{false}; + void *Scale{nullptr}; // Scale tensor + void *Zero{nullptr}; // Zero tensor + void *dequantized_AB{nullptr}; // Dequantized A or B tensor for verification + void *encoded_AB{nullptr}; // Encoded A or B in int4 x fp8 or shuffle + void *packed_Scale{nullptr}; // Packed scale for int4 * fp8 + + int device_index{0}; + + bool use_pdl{false}; +}; + +/// Block Scaled GEMM +// +// OperationKind: kBlockScaledGemm +// GemmKind: Universal + +struct BlockScaledGemmArguments { + // NOTE: these are replicated for 3.0 interfaces + gemm::GemmCoord problem_size{}; + gemm::GemmCoord cluster_shape{}; + gemm::GemmCoord cluster_shape_fallback{}; + int batch_count{1}; + + void const *A{nullptr}; + void const *B{nullptr}; + void const *SFA{nullptr}; + void const *SFB{nullptr}; + void const *C{nullptr}; + void *D{nullptr}; + void *SFD{nullptr}; + + void const *alpha{nullptr}; + void const *beta{nullptr}; + ScalarPointerMode pointer_mode{}; + + // NOTE: these are replicated for 3.0 interfaces + int64_t lda{0}; + int64_t ldb{0}; + int64_t ldc{0}; + int64_t ldd{0}; + + int64_t batch_stride_A{0}; + int64_t batch_stride_B{0}; + int64_t batch_stride_C{0}; + int64_t batch_stride_D{0}; + + // Needed for ScaleFactor Generation + void const *norm_constant{nullptr}; + + // Needed for some 3.x kernels + int sm_count{0}; + library::RasterOrder raster_order{}; + int swizzle_size{1}; + int split_k_slices{1}; + + library::RuntimeDatatype 
runtime_input_datatype_a{library::RuntimeDatatype::kStatic}; + library::RuntimeDatatype runtime_input_datatype_b{library::RuntimeDatatype::kStatic}; + + bool use_pdl{false}; +}; + +/// Blockwise GEMM +// +// OperationKind: kBlockwiseGemm +// GemmKind: Universal + +struct BlockwiseGemmArguments { + // NOTE: these are replicated for 3.0 interfaces + gemm::GemmCoord problem_size{}; + gemm::GemmCoord cluster_shape{}; + gemm::GemmCoord cluster_shape_fallback{}; + int batch_count{1}; + + void const *A{nullptr}; + void const *B{nullptr}; + void const *SFA{nullptr}; + void const *SFB{nullptr}; + void const *C{nullptr}; + void *D{nullptr}; + + void const *alpha{nullptr}; + void const *beta{nullptr}; + ScalarPointerMode pointer_mode{}; + + // NOTE: these are replicated for 3.0 interfaces + int64_t lda{0}; + int64_t ldb{0}; + int64_t ldc{0}; + int64_t ldd{0}; + + int64_t batch_stride_A{0}; + int64_t batch_stride_B{0}; + int64_t batch_stride_C{0}; + int64_t batch_stride_D{0}; + + int sf_m_vec_size{0}; + int sf_n_vec_size{0}; + int sf_k_vec_size{0}; + + // Needed for some 3.x kernels + int sm_count{0}; + library::RasterOrder raster_order{}; + int swizzle_size{1}; + int split_k_slices{1}; + + library::RuntimeDatatype runtime_input_datatype_a{library::RuntimeDatatype::kStatic}; + library::RuntimeDatatype runtime_input_datatype_b{library::RuntimeDatatype::kStatic}; + + bool use_pdl{false}; +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Complex valued GEMM in which real and imaginary parts are separated by a stride +// +// OperationKind: Gemm +// GemmKind: Planar complex + +struct GemmPlanarComplexConfiguration { + + GemmUniversalMode mode{GemmUniversalMode::kGemm}; + gemm::GemmCoord problem_size{}; + int batch_count{1}; + int64_t lda_real{0}; + int64_t lda_imag{0}; + int64_t ldb_real{0}; + int64_t ldb_imag{0}; + int64_t ldc_real{0}; + int64_t ldc_imag{0}; + int64_t ldd_real{0}; + int64_t ldd_imag{0}; +}; + +/// Arguments for planar complex GEMMs +struct GemmPlanarComplexArguments { + + void const *A_real{nullptr}; + void const *A_imag{nullptr}; + void const *B_real{nullptr}; + void const *B_imag{nullptr}; + void const *C_real{nullptr}; + void const *C_imag{nullptr}; + void *D_real{nullptr}; + void *D_imag{nullptr}; + void const *alpha{nullptr}; + void const *beta{nullptr}; + ScalarPointerMode pointer_mode{}; + + int64_t batch_stride_A_real{0}; + int64_t batch_stride_A_imag{0}; + int64_t batch_stride_B_real{0}; + int64_t batch_stride_B_imag{0}; + int64_t batch_stride_C_real{0}; + int64_t batch_stride_C_imag{0}; + int64_t batch_stride_D_real{0}; + int64_t batch_stride_D_imag{0}; + bool use_pdl{false}; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// This is a special form of planar complex which loads pointers and problem size +/// from memory. 
+struct GemmPlanarComplexArrayConfiguration { + + gemm::GemmCoord problem_size{}; + int batch_count{1}; + + int64_t lda_real{0}; + int64_t lda_imag{0}; + int64_t ldb_real{0}; + int64_t ldb_imag{0}; + int64_t ldc_real{0}; + int64_t ldc_imag{0}; + int64_t ldd_real{0}; + int64_t ldd_imag{0}; +}; + +/// Arguments for planar complex GEMMs +struct GemmPlanarComplexArrayArguments { + + int const *M{nullptr}; + int const *N{nullptr}; + int const *K{nullptr}; + + void const * const * A_real{nullptr}; + void const * const * A_imag{nullptr}; + void const * const * B_real{nullptr}; + void const * const * B_imag{nullptr}; + void const * const * C_real{nullptr}; + void const * const * C_imag{nullptr}; + void * const * D_real{nullptr}; + void * const * D_imag{nullptr}; + + void const * alpha{nullptr}; + void const * beta{nullptr}; + ScalarPointerMode pointer_mode{}; + bool use_pdl{false}; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Grouped GEMM supporting +// +// OperationKind: Gemm +// GemmKind: Grouped + +struct GemmGroupedConfiguration { + int problem_count{0}; + // GemmGroupedConfiguration is passed to initialize(), which + // is responsible for allocating the device-side stride storage. + int64_t* lda; + int64_t* ldb; + int64_t* ldc; + + cute::Shape* problem_sizes_3x_host; +}; + +struct GemmGroupedArguments { + int problem_count{}; + gemm::GemmCoord* problem_sizes{nullptr}; + + void* ptr_A{nullptr}; + void* ptr_B{nullptr}; + void* ptr_C{nullptr}; + void* ptr_D{nullptr}; + + int64_t* lda{nullptr}; + int64_t* ldb{nullptr}; + int64_t* ldc{nullptr}; + int64_t* ldd{nullptr}; + + void const *alpha{nullptr}; + void const *beta{nullptr}; + ScalarPointerMode pointer_mode{}; + bool use_pdl{false}; + + gemm::GemmCoord cluster_shape{}; + gemm::GemmCoord cluster_shape_fallback{}; + + library::RasterOrder raster_order{}; + library::RuntimeDatatype runtime_input_datatype_a{library::RuntimeDatatype::kStatic}; + library::RuntimeDatatype runtime_input_datatype_b{library::RuntimeDatatype::kStatic}; + int swizzle_size{1}; + + // these should really be in the configuration but staying consistent with GEMM + int sm_count{0}; + int max_active_clusters{0}; + + // The user is responsible for allocating storage for problem sizes. + // Since GemmGroupedArguments is used by both the 2.x and 3.x APIs, we + // unfortunately need to have both options in this struct, and the + // underlying operation uses the one it needs. + cute::Shape* problem_sizes_3x; + cute::Shape* problem_sizes_3x_host; +}; + +struct GroupedGemmBlockScaledArguments : GemmGroupedArguments { + void* SFA{nullptr}; + void* SFB{nullptr}; + void* SFD{nullptr}; + void* norm_constant{nullptr}; +}; + +struct GroupedGemmBlockwiseArguments : GemmGroupedArguments { + void* SFA{nullptr}; + void* SFB{nullptr}; +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// OperationKind: kSparseGemm +// + +/// Computes GEMM assuming one of the inputs has 2:4 structured sparsity. 
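// Illustrative only, not part of the vendored diff: with 2:4 structured
// sparsity the sparse operand keeps two nonzeros out of every four elements
// along K, so a logical M x K operand is stored as M x K/2 values plus the
// metadata matrix E (lde / batch_stride_E below) recording which positions
// were kept. A hedged sizing sketch (the helper name is hypothetical),
// assuming K is a multiple of four:
inline int64_t sparse_gemm_stored_extent_k(int64_t k) { return k / 2; }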
+struct SparseGemmConfiguration { + + GemmUniversalMode mode{GemmUniversalMode::kGemm}; + gemm::GemmCoord problem_size{}; + int batch_count{1}; /// number of sparse matrix products in batch + int64_t lda{0}; /// leading dimension of A operand + int64_t ldb{0}; /// leading dimension of B operand + int64_t ldc{0}; /// leading dimension of C operand + int64_t ldd{0}; /// leading dimension of D operand + int64_t lde{0}; /// leading dimension of E operand (metadata matrix) + int64_t batch_stride_A{0}; // stride between matrices + int64_t batch_stride_B{0}; // stride between matrices + int64_t batch_stride_C{0}; // stride between matrices + int64_t batch_stride_D{0}; // stride between matrices + int64_t batch_stride_E{0}; // stride between matrices +}; + +/// Arguments for sparse GEMMs +struct SparseGemmArguments { + void const *A{nullptr}; /// pointer to A matrix + void const *B{nullptr}; /// pointer to B matrix + void const *C{nullptr}; /// pointer to C matrix + void *D{nullptr}; /// pointer to D matrix + void const *E{nullptr}; /// pointer to E matrix (metadata) + void const *alpha{nullptr}; /// pointer to alpha scalar + void const *beta{nullptr}; /// pointer to beta scalar + ScalarPointerMode pointer_mode{}; /// enumerant indicating whether alpha/beta pointers are host + /// or device pointers. + bool use_pdl{false}; /// Whether to use PDL when launching the kernel +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Configuration for basic Rank K update operations +// +// OperationKind: (Syrk, Herk, Syr2k, Her2k) +// RankKKind: Universal +// +struct RankKConfiguration { + + /// SYRK problem size + gemm::GemmCoord problem_size{}; + + /// Leading dimension of A matrix + int64_t lda{0}; + + /// Leading dimension of B matrix + int64_t ldb{0}; + + /// Leading dimension of C matrix + int64_t ldc{0}; + + /// Leading dimension of D matrix + int64_t ldd{0}; + + /// Batch Count + int batch_count{1}; +}; + +/// Arguments for (Syrk, Herk, Syr2k, Her2k) +struct RankKArguments { + + /// Pointer to A matrix + void const *A{nullptr}; + + /// Pointer to B matrix (used only for Syr2k and Her2k) + void const *B{nullptr}; + + /// Pointer to C matrix + void const *C{nullptr}; + + /// Pointer to D matrix + void *D{nullptr}; + + /// Host or device pointer to alpha scalar + void const *alpha{nullptr}; + + /// Host or device pointer to beta scalar + void const *beta{nullptr}; + + /// Enumerant indicating whether alpha/beta point to host or device memory + ScalarPointerMode pointer_mode{}; + + int64_t batch_stride_A{0}; + int64_t batch_stride_B{0}; + int64_t batch_stride_C{0}; + int64_t batch_stride_D{0}; + bool use_pdl{false}; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Configuration for basic TRMM operations +// +// OperationKind: Trmm +// TrmmKind: Universal +// +struct TrmmConfiguration { + + /// TRMM problem size + gemm::GemmCoord problem_size{}; + + /// Leading dimension of A matrix + int64_t lda{0}; + + /// Leading dimension of B matrix + int64_t ldb{0}; + + /// Leading dimension of D matrix + int64_t ldd{0}; + + /// Batch Count + int batch_count{1}; +}; + +/// Arguments for TRMM +struct TrmmArguments { + + /// Pointer to A matrix + void const *A{nullptr}; + + /// Pointer to B matrix + void const *B{nullptr}; + + /// Pointer to D matrix + void *D{nullptr}; + + /// Host or device pointer to alpha scalar + void const *alpha{nullptr}; + + /// Host or device pointer to beta scalar + void 
const *beta{nullptr}; + + /// Enumerant indicating whether alpha/beta point to host or device memory + ScalarPointerMode pointer_mode{}; + + int64_t batch_stride_A{0}; + int64_t batch_stride_B{0}; + int64_t batch_stride_D{0}; + bool use_pdl{false}; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Configuration for basic SYMM/HEMM update operations +// +// OperationKind: (Symm, Hemm) +// SymmKind: Universal +// +struct SymmConfiguration { + + /// SYMM/HEMM problem size + gemm::GemmCoord problem_size{}; + + /// Leading dimension of A matrix + int64_t lda{0}; + + /// Leading dimension of B matrix + int64_t ldb{0}; + + /// Leading dimension of C matrix + int64_t ldc{0}; + + /// Leading dimension of D matrix + int64_t ldd{0}; + + /// Batch Count + int batch_count{1}; +}; + +/// Arguments for (Symm, Hemm) +struct SymmArguments { + + /// Pointer to A matrix + void const *A{nullptr}; + + /// Pointer to B matrix + void const *B{nullptr}; + + /// Pointer to C matrix + void const *C{nullptr}; + + /// Pointer to D matrix + void *D{nullptr}; + + /// Host or device pointer to alpha scalar + void const *alpha{nullptr}; + + /// Host or device pointer to beta scalar + void const *beta{nullptr}; + + /// Enumerant indicating whether alpha/beta point to host or device memory + ScalarPointerMode pointer_mode{}; + + int64_t batch_stride_A{0}; + int64_t batch_stride_B{0}; + int64_t batch_stride_C{0}; + int64_t batch_stride_D{0}; + bool use_pdl{false}; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Two dimensional convolution +// +// OperationKind: Conv2d +// +struct Conv2dConfiguration { + + conv::SplitKMode split_k_mode; + + /// Conv2d problem size + // contains strictly conv2d size (N,H,W,C,K,R,S,P,Q,padding,stride,dilation,mode) + // also includes (split_k_slices, groups) + conv::Conv2dProblemSize problem_size{}; + + // stride of operand A + std::vector stride_a{}; + + // stride of operand B + std::vector stride_b{}; + + // stride of operand C + std::vector stride_c{}; +}; + + +/// Three dimensional convolution +// +// OperationKind: Conv3d +// +struct Conv3dConfiguration { + + conv::SplitKMode split_k_mode{}; + + /// Conv2d problem size + // contains strictly conv2d size (N,D,H,W,C,K,T,R,S,Z,P,Q,padding,stride,dilation,mode) + // also includes (split_k_slices, groups) + conv::Conv3dProblemSize problem_size{}; + + /// Layout object for activations tensor + layout::TensorNDHWC layout_activations{}; + + /// Layout object for filters tensor + layout::TensorNDHWC layout_filters{}; + + /// Layout object for source tensor + layout::TensorNDHWC layout_source{}; + + /// Layout object for output tensor + layout::TensorNDHWC layout_output{}; + + // + // Methods + // + + // Mapping functions (A,B,C -> activation,filter,output) + layout::TensorNDHWC layout_a(library::ConvKind const &conv_kind) const { + switch (conv_kind) { + case library::ConvKind::kFprop: return layout_activations; + case library::ConvKind::kDgrad: return layout_output; + case library::ConvKind::kWgrad: return layout_output; + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + + layout::TensorNDHWC layout_b(library::ConvKind const &conv_kind) const { + switch (conv_kind) { + case library::ConvKind::kFprop: return layout_filters; + case library::ConvKind::kDgrad: return layout_filters; + 
case library::ConvKind::kWgrad: return layout_activations; + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + + layout::TensorNDHWC layout_c(library::ConvKind const &conv_kind) const { + switch (conv_kind) { + case library::ConvKind::kFprop: return layout_output; + case library::ConvKind::kDgrad: return layout_activations; + case library::ConvKind::kWgrad: return layout_filters; + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } +}; + +/// Arguments for CONV +struct ConvArguments { + + ///////////////////////////////////////////////////////// + /// ImplicitGemm matrices A, B, C, D + ///////////////////////////////////////////////////////// + /// pointer to implicit gemm matrix A + void const *A{nullptr}; + + /// pointer to implicit gemm matrix B + void const *B{nullptr}; + + /// pointer to reordered matrix B + void const *reordered_B{nullptr}; + + /// pointer to implicit gemm matrix C + void const *C{nullptr}; + + /// pointer to implicit gemm destination matrix D + void *D{nullptr}; + + /// Host or device pointer to alpha scalar + void const *alpha{nullptr}; + + /// Host or device pointer to beta scalar + void const *beta{nullptr}; + + /// Enumerant indicating whether alpha/beta point to host or device memory + ScalarPointerMode pointer_mode{}; + + /// Whether to use PDL when launching the kernel + bool use_pdl{false}; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Configuration for Reduction operations +// +// OperationKind: Reduction +// +struct ReductionConfiguration { + + /// Reduction problem size + MatrixCoord problem_size{}; + + /// Number of partitions to reduce + int partitions{0}; + + /// Number of elements between each partition + int64_t partition_stride{0}; + + /// leading dimension of 'w'orkspace operand + int64_t ldw{0}; + + /// leading dimension of 's'ource operand + int64_t lds{0}; + + /// leading dimension of 'd'estination operand + int64_t ldd{0}; +}; + +/// Arguments for Reduction +struct ReductionArguments { + + /// Pointer to workspace matrix + void const *workspace{nullptr}; + + /// Pointer to source matrix + void const *source{nullptr}; + + /// Pointer to destination matrix + void *destination{nullptr}; + + /// pointer to reference matrix + void *reference{nullptr}; + + /// Host or device pointer to alpha scalar + void const *alpha{nullptr}; + + /// Host or device pointer to beta scalar + void const *beta{nullptr}; + + /// Enumerant indicating whether alpha/beta point to host or device memory + ScalarPointerMode pointer_mode{}; + + /// Whether to use PDL when launching the kernel + bool use_pdl{false}; +}; + +} // namespace library +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#endif diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/manifest.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/manifest.h new file mode 100644 index 0000000000000000000000000000000000000000..c4fb0ee8ca32124450b1063cc3613078e600479d --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/manifest.h @@ -0,0 +1,114 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA 
CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Manifest of CUTLASS Library + + This is the root of the data structure containing CUTLASS objects +*/ + +#pragma once + +#include +#include +#include + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +#include "library.h" + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// Forward declaration +class Manifest; + +// init and insert all cutlass gemm operations in manifest object (procedurally generated using generator.py) +void initialize_all(Manifest &manifest); + +// init and insert all reduction op in manifest object (manually instantiated in library/reduction) +void initialize_all_reduction_op(Manifest &manifest); + +///////////////////////////////////////////////////////////////////////////////////////////////////////// + +/// List of operations +using OperationVector = std::vector>; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Manifest of CUTLASS Library +class Manifest { +private: + + /// Operation provider + Provider provider_; + + /// Global list of operations + OperationVector operations_; + +public: + Manifest (Provider provider = library::Provider::kCUTLASS) : provider_(provider) { } + + /// Top-level initialization + Status initialize(); + + /// Used for initialization + void reserve(size_t operation_count); + + /// Graceful shutdown + Status release(); + + /// Appends an operation and takes ownership + void append(Operation *operation_ptr) {\ + // This function is inline s.t. 
it is present in generated libraries + // without having to compile or link in manifest.cpp + operations_.emplace_back(operation_ptr); + } + + /// Returns an iterator to the first operation + OperationVector const &operations() const; + + /// Returns a const iterator + OperationVector::const_iterator begin() const; + + /// Returns a const iterator + OperationVector::const_iterator end() const; +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/operation_table.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/operation_table.h new file mode 100644 index 0000000000000000000000000000000000000000..f36232c8dc833e2b24d681686f6662e79b7ecd0a --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/operation_table.h @@ -0,0 +1,905 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* + \file + \brief Defines a data structure in which a set of functionally equivalent library::Operation + instances may be queried. 
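+
+  Operations are grouped by a functional key (provider, operation kind, element types,
+  layouts, transforms) and, within each group, ordered by a preference key (e.g. minimum
+  compute capability and alignment for GEMM, or compute capability and iterator algorithm
+  for convolution). The lookup sketch below is illustrative only and is not part of this
+  header; it assumes the Singleton declared in singleton.h has been constructed:
+
+    GemmFunctionalKey key(Provider::kCUTLASS, GemmKind::kGemm);
+    auto const &table = Singleton::get().operation_table;
+    auto it = table.gemm_operations.find(key);
+    if (it != table.gemm_operations.end()) {
+      // it->second maps GemmPreferenceKey -> std::vector of Operation const *
+    }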
+*/ + +#pragma once +#include +#include +#include +#include + +#include "cutlass/library/library.h" +#include "cutlass/library/manifest.h" +#include "cutlass/library/util.h" +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +///////////////////////////////////////////////////////////////////////////////////////////////// +// Data Structures for Gemm Functional Maps +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Tuple uniquely identifying Gemm functional behavior +struct GemmFunctionalKey { + + Provider provider; + GemmKind gemm_kind; + NumericTypeID element_compute; + NumericTypeID element_scalar; + NumericTypeID element_A; + LayoutTypeID layout_A; + ComplexTransform transform_A; + NumericTypeID element_B; + LayoutTypeID layout_B; + ComplexTransform transform_B; + NumericTypeID element_C; + LayoutTypeID layout_C; + NumericTypeID element_D; + LayoutTypeID layout_D; + + // + // Methods + // + + inline + GemmFunctionalKey( + Provider provider, + GemmKind gemm_kind = GemmKind::kGemm, + NumericTypeID element_compute = NumericTypeID::kF32, + NumericTypeID element_scalar = NumericTypeID::kF32, + NumericTypeID element_A = NumericTypeID::kF16, + LayoutTypeID layout_A = LayoutTypeID::kColumnMajor, + ComplexTransform transform_A = ComplexTransform::kNone, + NumericTypeID element_B = NumericTypeID::kF16, + LayoutTypeID layout_B = LayoutTypeID::kColumnMajor, + ComplexTransform transform_B = ComplexTransform::kNone, + NumericTypeID element_C = NumericTypeID::kF16, + LayoutTypeID layout_C = LayoutTypeID::kColumnMajor, + NumericTypeID element_D = NumericTypeID::kF16, + LayoutTypeID layout_D = LayoutTypeID::kColumnMajor + ): + provider(provider), + gemm_kind(gemm_kind), + element_compute(element_compute), + element_scalar(element_scalar), + element_A(element_A), + layout_A(layout_A), + transform_A(transform_A), + element_B(element_B), + layout_B(layout_B), + transform_B(transform_B), + element_C(element_C), + layout_C(layout_C), + element_D(element_D), + layout_D(layout_D) + { } + + inline + bool operator==(GemmFunctionalKey const &rhs) const { + return + (provider == rhs.provider) && + (gemm_kind == rhs.gemm_kind) && + (element_compute == rhs.element_compute) && + (element_scalar == rhs.element_scalar) && + (element_A == rhs.element_A) && + (layout_A == rhs.layout_A) && + (transform_A == rhs.transform_A) && + (element_B == rhs.element_B) && + (layout_B == rhs.layout_B) && + (transform_B == rhs.transform_B) && + (element_C == rhs.element_C) && + (layout_C == rhs.layout_C) && + (element_D == rhs.element_D) && + (layout_D == rhs.layout_D); + } + + inline + bool operator!=(GemmFunctionalKey const &rhs) const { + return !(*this == rhs); + } +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// +inline +std::ostream & operator<<(std::ostream &out, cutlass::library::GemmFunctionalKey const &k) { + + out << "{\n" + << " provider: " << to_string(k.provider) << "\n" + << " gemm_kind: " << to_string(k.gemm_kind) << "\n" + << " element_compute: " << to_string(k.element_compute) << "\n" + << " element_scalar: " << to_string(k.element_scalar) << "\n" + << " element_A: " << to_string(k.element_A) << "\n" + << " layout_A: " << to_string(k.layout_A) << "\n" + << " transform_A: " << to_string(k.transform_A) << "\n" + << " element_B: " << to_string(k.element_B) << "\n" + << " layout_B: " << to_string(k.layout_B) << "\n" 
+ << " transform_B: " << to_string(k.transform_B) << "\n" + << " element_C: " << to_string(k.element_C) << "\n" + << " layout_C: " << to_string(k.layout_C) << "\n" + << " element_D: " << to_string(k.element_D) << "\n" + << " layout_D: " << to_string(k.layout_D) << "\n" + << "}"; + + return out; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Hash function for GemmFunctionalKey +struct GemmFunctionalKeyHasher { + using IntHash = std::hash; + + inline + static size_t rotl(size_t key, int shl) { + return (key << shl) | (key >> (sizeof(key)*8u - static_cast(shl))); + } + + inline + size_t operator()(GemmFunctionalKey const &key) const { + IntHash hash; + + return + rotl(hash(int(key.provider)), 1) ^ + rotl(hash(int(key.gemm_kind)), 2) ^ + rotl(hash(int(key.element_compute)), 3) ^ + rotl(hash(int(key.element_scalar)), 4) ^ + rotl(hash(int(key.element_A)), 5) ^ + rotl(hash(int(key.layout_A)), 6) ^ + rotl(hash(int(key.transform_A)), 7) ^ + rotl(hash(int(key.element_B)), 8) ^ + rotl(hash(int(key.layout_B)), 9) ^ + rotl(hash(int(key.transform_B)), 10) ^ + rotl(hash(int(key.element_C)), 11) ^ + rotl(hash(int(key.layout_C)), 12) ^ + rotl(hash(int(key.element_D)), 13) ^ + rotl(hash(int(key.layout_D)), 14); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Establishes a partial ordering to search for GEMM operators +struct GemmPreferenceKey { + + int compute_capability; + int alignment; + + // + // Methods + // + + GemmPreferenceKey(): compute_capability(), alignment() { } + + GemmPreferenceKey(int cc, int alignment): compute_capability(cc), alignment(alignment) { } + + bool operator<(GemmPreferenceKey const &rhs) const { + return (compute_capability < rhs.compute_capability) || + ((compute_capability == rhs.compute_capability) && (alignment < rhs.alignment)); + } + + bool operator==(GemmPreferenceKey const &rhs) const { + return compute_capability == rhs.compute_capability; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +inline +std::ostream& operator<< (std::ostream& out, const cutlass::library::GemmPreferenceKey& key) { + out << "{\n" + << "compute_capability : " << key.compute_capability << std::endl + << "alignment : " << key.alignment << std::endl + << "}"; + + return out; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Maps minimum compute capability onto a vector of possible operations +using GemmOperationVectorMap = std::map< + GemmPreferenceKey, + std::vector +>; + +/// Maps a GemmFunctionalKey onto a vector of Operation * objects expected to be of kind kGemm +using GemmOperationFunctionalMap = std::unordered_map< + GemmFunctionalKey, + GemmOperationVectorMap, + GemmFunctionalKeyHasher +>; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +///////////////////////////////////////////////////////////////////////////////////////////////// +// Data Structures for BlockScaled Gemm Functional Maps +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Tuple uniquely identifying Gemm functional behavior +struct BlockScaledGemmFunctionalKey { + + Provider provider; + GemmKind gemm_kind; + OperationKind kind; + NumericTypeID element_compute; + NumericTypeID element_scalar; + NumericTypeID element_A; + LayoutTypeID layout_A; + NumericTypeID 
element_SFA; + NumericTypeID element_B; + LayoutTypeID layout_B; + NumericTypeID element_SFB; + NumericTypeID element_C; + LayoutTypeID layout_C; + NumericTypeID element_D; + LayoutTypeID layout_D; + NumericTypeID element_SFD; + LayoutTypeID layout_SFD; + int SFVecSize; + int EpilogueSFVecSize; + // + // Methods + // + + inline + BlockScaledGemmFunctionalKey( + Provider provider, + GemmKind gemm_kind = GemmKind::kGemm, + OperationKind kind = OperationKind::kBlockScaledGemm, + NumericTypeID element_compute = NumericTypeID::kF32, + NumericTypeID element_scalar = NumericTypeID::kF32, + NumericTypeID element_A = NumericTypeID::kF16, + LayoutTypeID layout_A = LayoutTypeID::kColumnMajor, + NumericTypeID element_SFA = NumericTypeID::kF16, + NumericTypeID element_B = NumericTypeID::kF16, + LayoutTypeID layout_B = LayoutTypeID::kColumnMajor, + NumericTypeID element_SFB = NumericTypeID::kF16, + NumericTypeID element_C = NumericTypeID::kF16, + LayoutTypeID layout_C = LayoutTypeID::kColumnMajor, + NumericTypeID element_D = NumericTypeID::kF16, + LayoutTypeID layout_D = LayoutTypeID::kColumnMajor, + NumericTypeID element_SFD = NumericTypeID::kF16, + LayoutTypeID layout_SFD = LayoutTypeID::kRowMajor, + int sf_vec_size = 32 + , int epilogue_sf_vec_size = 32 + ): + provider(provider), + gemm_kind(gemm_kind), + kind(kind), + element_compute(element_compute), + element_scalar(element_scalar), + element_A(element_A), + layout_A(layout_A), + element_SFA(element_SFA), + element_B(element_B), + layout_B(layout_B), + element_SFB(element_SFB), + element_C(element_C), + layout_C(layout_C), + element_D(element_D), + layout_D(layout_D), + element_SFD(element_SFD), + layout_SFD(layout_SFD), + SFVecSize(sf_vec_size) + , EpilogueSFVecSize(epilogue_sf_vec_size) + { } + + inline + bool operator==(BlockScaledGemmFunctionalKey const &rhs) const { + return + (provider == rhs.provider) && + (gemm_kind == rhs.gemm_kind) && + (kind == rhs.kind) && + (element_compute == rhs.element_compute) && + (element_scalar == rhs.element_scalar) && + (element_A == rhs.element_A) && + (layout_A == rhs.layout_A) && + (element_SFA == rhs.element_SFA) && + (element_B == rhs.element_B) && + (layout_B == rhs.layout_B) && + (element_SFB == rhs.element_SFB) && + (element_C == rhs.element_C) && + (layout_C == rhs.layout_C) && + (element_D == rhs.element_D) && + (layout_D == rhs.layout_D) && + (element_SFD == rhs.element_SFD) && + (layout_SFD == rhs.layout_SFD) && + (SFVecSize == rhs.SFVecSize) + && (EpilogueSFVecSize == rhs.EpilogueSFVecSize) + ; + } + + inline + bool operator!=(BlockScaledGemmFunctionalKey const &rhs) const { + return !(*this == rhs); + } +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// +inline +std::ostream & operator<<(std::ostream &out, cutlass::library::BlockScaledGemmFunctionalKey const &k) { + + out << "{\n" + << " provider: " << to_string(k.provider) << "\n" + << " gemm_kind: " << to_string(k.gemm_kind) << "\n" + << " kind: " << to_string(k.kind) << "\n" + << " element_compute: " << to_string(k.element_compute) << "\n" + << " element_scalar: " << to_string(k.element_scalar) << "\n" + << " element_A: " << to_string(k.element_A) << "\n" + << " layout_A: " << to_string(k.layout_A) << "\n" + << " element_SFA: " << to_string(k.element_SFA) << "\n" + << " element_B: " << to_string(k.element_B) << "\n" + << " layout_B: " << to_string(k.layout_B) << "\n" + << " element_SFB: " << to_string(k.element_SFB) << "\n" + << " element_C: " << to_string(k.element_C) << "\n" + << " 
layout_C: " << to_string(k.layout_C) << "\n" + << " element_D: " << to_string(k.element_D) << "\n" + << " layout_D: " << to_string(k.layout_D) << "\n" + << " element_SFD: " << to_string(k.element_SFD) << "\n" + << " layout_SFD: " << to_string(k.layout_SFD) << "\n" + << " SFVecSize: " << k.SFVecSize << "\n" + << "EpilogueSFVecSize: " << k.EpilogueSFVecSize << "\n" + << "}"; + + return out; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Hash function for BlockScaledGemmFunctionalKeyHasher +struct BlockScaledGemmFunctionalKeyHasher { + using IntHash = std::hash; + + inline + static size_t rotl(size_t key, int shl) { + return (key << shl) | (key >> (sizeof(key)*8u - static_cast(shl))); + } + + inline + size_t operator()(BlockScaledGemmFunctionalKey const &key) const { + IntHash hash; + + return + rotl(hash(int(key.provider)), 1) ^ + rotl(hash(int(key.gemm_kind)), 2) ^ + rotl(hash(int(key.kind)), 3) ^ + rotl(hash(int(key.element_compute)), 4) ^ + rotl(hash(int(key.element_scalar)), 5) ^ + rotl(hash(int(key.element_A)), 6) ^ + rotl(hash(int(key.layout_A)), 7) ^ + rotl(hash(int(key.element_SFA)), 8) ^ + rotl(hash(int(key.element_B)), 9) ^ + rotl(hash(int(key.layout_B)), 10) ^ + rotl(hash(int(key.element_SFB)), 11) ^ + rotl(hash(int(key.element_C)), 12) ^ + rotl(hash(int(key.layout_C)), 13) ^ + rotl(hash(int(key.element_D)), 14) ^ + rotl(hash(int(key.layout_D)), 15) ^ + rotl(hash(int(key.element_SFD)), 16) ^ + rotl(hash(int(key.layout_SFD)), 17) ^ + rotl(hash(int(key.SFVecSize)), 18) ^ + rotl(hash(int(key.EpilogueSFVecSize)), 19) + ; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Maps a GemmFunctionalKey onto a vector of Operation * objects expected to be of kind kGemm +using BlockScaledGemmOperationFunctionalMap = std::unordered_map< + BlockScaledGemmFunctionalKey, + GemmOperationVectorMap, + BlockScaledGemmFunctionalKeyHasher +>; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +///////////////////////////////////////////////////////////////////////////////////////////////// +// Data Structures for Blockwise Gemm Functional Maps +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Tuple uniquely identifying Gemm functional behavior +struct BlockwiseGemmFunctionalKey { + + Provider provider; + GemmKind gemm_kind; + OperationKind kind; + NumericTypeID element_compute; + NumericTypeID element_scalar; + NumericTypeID element_A; + LayoutTypeID layout_A; + NumericTypeID element_SFA; + NumericTypeID element_B; + LayoutTypeID layout_B; + NumericTypeID element_SFB; + NumericTypeID element_C; + LayoutTypeID layout_C; + NumericTypeID element_D; + LayoutTypeID layout_D; + int SFMVecSize; + int SFNVecSize; + int SFKVecSize; + // + // Methods + // + + inline + BlockwiseGemmFunctionalKey( + Provider provider, + GemmKind gemm_kind = GemmKind::kGemm, + OperationKind kind = OperationKind::kBlockwiseGemm, + NumericTypeID element_compute = NumericTypeID::kF32, + NumericTypeID element_scalar = NumericTypeID::kF32, + NumericTypeID element_A = NumericTypeID::kF16, + LayoutTypeID layout_A = LayoutTypeID::kColumnMajor, + NumericTypeID element_SFA = NumericTypeID::kF16, + NumericTypeID element_B = NumericTypeID::kF16, + LayoutTypeID layout_B = LayoutTypeID::kColumnMajor, + NumericTypeID element_SFB = NumericTypeID::kF16, + NumericTypeID element_C = NumericTypeID::kF16, + LayoutTypeID 
layout_C = LayoutTypeID::kColumnMajor, + NumericTypeID element_D = NumericTypeID::kF16, + LayoutTypeID layout_D = LayoutTypeID::kColumnMajor, + int sfm_vec_size = 32, + int sfn_vec_size = 32, + int sfk_vec_size = 32 + ): + provider(provider), + gemm_kind(gemm_kind), + kind(kind), + element_compute(element_compute), + element_scalar(element_scalar), + element_A(element_A), + layout_A(layout_A), + element_SFA(element_SFA), + element_B(element_B), + layout_B(layout_B), + element_SFB(element_SFB), + element_C(element_C), + layout_C(layout_C), + element_D(element_D), + layout_D(layout_D), + SFMVecSize(sfm_vec_size), + SFNVecSize(sfn_vec_size), + SFKVecSize(sfk_vec_size) + { } + + inline + bool operator==(BlockwiseGemmFunctionalKey const &rhs) const { + return + (provider == rhs.provider) && + (gemm_kind == rhs.gemm_kind) && + (kind == rhs.kind) && + (element_compute == rhs.element_compute) && + (element_scalar == rhs.element_scalar) && + (element_A == rhs.element_A) && + (layout_A == rhs.layout_A) && + (element_SFA == rhs.element_SFA) && + (element_B == rhs.element_B) && + (layout_B == rhs.layout_B) && + (element_SFB == rhs.element_SFB) && + (element_C == rhs.element_C) && + (layout_C == rhs.layout_C) && + (element_D == rhs.element_D) && + (layout_D == rhs.layout_D) && + (SFMVecSize == rhs.SFMVecSize) && + (SFNVecSize == rhs.SFNVecSize) && + (SFKVecSize == rhs.SFKVecSize); + } + + inline + bool operator!=(BlockwiseGemmFunctionalKey const &rhs) const { + return !(*this == rhs); + } +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// +inline +std::ostream & operator<<(std::ostream &out, cutlass::library::BlockwiseGemmFunctionalKey const &k) { + + out << "{\n" + << " provider: " << to_string(k.provider) << "\n" + << " gemm_kind: " << to_string(k.gemm_kind) << "\n" + << " kind: " << to_string(k.kind) << "\n" + << " element_compute: " << to_string(k.element_compute) << "\n" + << " element_scalar: " << to_string(k.element_scalar) << "\n" + << " element_A: " << to_string(k.element_A) << "\n" + << " layout_A: " << to_string(k.layout_A) << "\n" + << " element_SFA: " << to_string(k.element_SFA) << "\n" + << " element_B: " << to_string(k.element_B) << "\n" + << " layout_B: " << to_string(k.layout_B) << "\n" + << " element_SFB: " << to_string(k.element_SFB) << "\n" + << " element_C: " << to_string(k.element_C) << "\n" + << " layout_C: " << to_string(k.layout_C) << "\n" + << " element_D: " << to_string(k.element_D) << "\n" + << " layout_D: " << to_string(k.layout_D) << "\n" + << " SFMVecSize: " << k.SFMVecSize << "\n" + << " SFNVecSize: " << k.SFNVecSize << "\n" + << " SFKVecSize: " << k.SFKVecSize << "\n" + << "}"; + + return out; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Hash function for BlockwiseGemmFunctionalKeyHasher +struct BlockwiseGemmFunctionalKeyHasher { + using IntHash = std::hash; + + inline + static size_t rotl(size_t key, int shl) { + return (key << shl) | (key >> (sizeof(key)*8u - static_cast(shl))); + } + + inline + size_t operator()(BlockwiseGemmFunctionalKey const &key) const { + IntHash hash; + + return + rotl(hash(int(key.provider)), 1) ^ + rotl(hash(int(key.gemm_kind)), 2) ^ + rotl(hash(int(key.kind)), 3) ^ + rotl(hash(int(key.element_compute)), 4) ^ + rotl(hash(int(key.element_scalar)), 5) ^ + rotl(hash(int(key.element_A)), 6) ^ + rotl(hash(int(key.layout_A)), 7) ^ + rotl(hash(int(key.element_SFA)), 8) ^ + rotl(hash(int(key.element_B)), 9) ^ + 
rotl(hash(int(key.layout_B)), 10) ^ + rotl(hash(int(key.element_SFB)), 11) ^ + rotl(hash(int(key.element_C)), 12) ^ + rotl(hash(int(key.layout_C)), 13) ^ + rotl(hash(int(key.element_D)), 14) ^ + rotl(hash(int(key.layout_D)), 15) ^ + rotl(hash(int(key.SFMVecSize)), 16) ^ + rotl(hash(int(key.SFNVecSize)), 17) ^ + rotl(hash(int(key.SFKVecSize)), 18) + ; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Maps a GemmFunctionalKey onto a vector of Operation * objects expected to be of kind kGemm +using BlockwiseGemmOperationFunctionalMap = std::unordered_map< + BlockwiseGemmFunctionalKey, + GemmOperationVectorMap, + BlockwiseGemmFunctionalKeyHasher +>; + + + +///////////////////////////////////////////////////////////////////////////////////////////////// +// Data Structures for Conv Functional Maps +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Tuple uniquely identifying conv2d functional behavior +struct ConvFunctionalKey { + library::Provider provider; + library::ConvKind conv_kind; + library::NumericTypeID element_A; + library::LayoutTypeID layout_A; + library::NumericTypeID element_B; + library::LayoutTypeID layout_B; + library::NumericTypeID element_C; + library::LayoutTypeID layout_C; + library::NumericTypeID element_accumulator; + library::NumericTypeID element_compute; + + + // + // Methods + // + + inline + ConvFunctionalKey( + library::Provider provider = library::Provider::kInvalid, + library::ConvKind conv_kind = library::ConvKind::kFprop, + library::NumericTypeID element_A = library::NumericTypeID::kF16, + library::LayoutTypeID layout_A = library::LayoutTypeID::kTensorNHWC, + library::NumericTypeID element_B = library::NumericTypeID::kF16, + library::LayoutTypeID layout_B = library::LayoutTypeID::kTensorNHWC, + library::NumericTypeID element_C = library::NumericTypeID::kF16, + library::LayoutTypeID layout_C = library::LayoutTypeID::kTensorNHWC, + library::NumericTypeID element_accumulator = library::NumericTypeID::kF32, + library::NumericTypeID element_compute = library::NumericTypeID::kF32 + ): + provider(provider), + conv_kind(conv_kind), + element_A(element_A), + layout_A(layout_A), + element_B(element_B), + layout_B(layout_B), + element_C(element_C), + layout_C(layout_C), + element_accumulator(element_accumulator), + element_compute(element_compute) + { } + + inline + bool operator==(ConvFunctionalKey const &rhs) const { + return + (provider == rhs.provider) && + (conv_kind == rhs.conv_kind) && + (element_A == rhs.element_A) && + (layout_A == rhs.layout_A) && + (element_B == rhs.element_B) && + (layout_B == rhs.layout_B) && + (element_C == rhs.element_C) && + (layout_C == rhs.layout_C) && + (element_accumulator == rhs.element_accumulator) && + (element_compute == rhs.element_compute); + } + + inline + bool operator!=(ConvFunctionalKey const &rhs) const { + return !(*this == rhs); + } +}; +///////////////////////////////////////////////////////////////////////////////////////////////// +inline +std::ostream& operator<< (std::ostream& out, const cutlass::library::ConvFunctionalKey& key) { + out << "{\n" + << "provider: " << to_string(key.provider) << std::endl + << "conv_kind: " << to_string(key.conv_kind) << std::endl + << "element_A: " << to_string(key.element_A) << std::endl + << "layout_A: " << to_string(key.layout_A) << std::endl + << "element_B: " << to_string(key.element_B) << std::endl + << "layout_B: " << to_string(key.layout_B) << std::endl + << 
"element_C: " << to_string(key.element_C) << std::endl + << "layout_C: " << to_string(key.layout_C) << std::endl + << "element_accumulator: " << to_string(key.element_accumulator) << std::endl + << "element_compute: " << to_string(key.element_compute) << std::endl + << "}"; + + return out; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// +struct ConvFunctionalKeyHasher { + using IntHash = std::hash; + + inline + static size_t rotl(size_t key, int shl) { + return (key << shl) | (key >> (sizeof(key)*8u - static_cast(shl))); + } + + inline + size_t operator()(ConvFunctionalKey const &key) const { + IntHash hash; + + return + rotl(hash(int(key.provider)), 1) ^ + rotl(hash(int(key.conv_kind)), 2) ^ + rotl(hash(int(key.element_A)), 3) ^ + rotl(hash(int(key.layout_A)), 4) ^ + rotl(hash(int(key.element_B)), 5) ^ + rotl(hash(int(key.layout_B)), 6) ^ + rotl(hash(int(key.element_C)), 7) ^ + rotl(hash(int(key.layout_C)), 8) ^ + rotl(hash(int(key.element_accumulator)), 9) ^ + rotl(hash(int(key.element_compute)), 10); + } +}; +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Establishes a partial ordering to search for Conv2d operators +struct ConvPreferenceKey { + + int compute_capability; + IteratorAlgorithmID iterator_algorithm; + + + // + // Methods + // + + ConvPreferenceKey(): compute_capability(), iterator_algorithm() { } + + ConvPreferenceKey(int cc, IteratorAlgorithmID iterator_algorithm): + compute_capability(cc), iterator_algorithm(iterator_algorithm) { } + + bool operator<(ConvPreferenceKey const &rhs) const { + return (compute_capability < rhs.compute_capability) || + ((compute_capability == rhs.compute_capability) && (iterator_algorithm < rhs.iterator_algorithm)); + } + + bool operator==(ConvPreferenceKey const &rhs) const { + return (compute_capability == rhs.compute_capability) && + (iterator_algorithm == rhs.iterator_algorithm); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Maps minimum compute capability onto a vector of possible operations +using ConvOperationVectorMap = std::map< + ConvPreferenceKey, + std::vector +>; + +/// Maps a GemmFunctionalKey onto a vector of Operation * objects expected to be of kind kGemm +using ConvOperationFunctionalMap = std::unordered_map< + ConvFunctionalKey, + ConvOperationVectorMap, + ConvFunctionalKeyHasher +>; +///////////////////////////////////////////////////////////////////////////////////////////////// + + +/// Tuple uniquely identifying conv2d functional behavior +struct ReductionFunctionalKey { + library::Provider provider; + library::NumericTypeID element_workspace; + library::NumericTypeID element_accumulator; + library::NumericTypeID element_output; + library::NumericTypeID element_compute; + library::MathOperationID reduce_math_op; + library::EpilogueKind epilogue_math_op; + + + // + // Methods + // + + inline + ReductionFunctionalKey( + library::Provider provider = library::Provider::kInvalid, + library::NumericTypeID element_workspace = library::NumericTypeID::kF16, + library::NumericTypeID element_accumulator = library::NumericTypeID::kF32, + library::NumericTypeID element_output = library::NumericTypeID::kF16, + library::NumericTypeID element_compute = library::NumericTypeID::kF32, + library::MathOperationID reduce_math_op = library::MathOperationID::kAdd, + library::EpilogueKind epilogue_math_op = library::EpilogueKind::kLinearCombination + ): + 
provider(provider), + element_workspace(element_workspace), + element_accumulator(element_accumulator), + element_output(element_output), + element_compute(element_compute), + reduce_math_op(reduce_math_op), + epilogue_math_op(epilogue_math_op) + { } + + inline + bool operator==(ReductionFunctionalKey const &rhs) const { + return + (provider == rhs.provider) && + (element_workspace == rhs.element_workspace) && + (element_accumulator == rhs.element_accumulator) && + (element_output == rhs.element_output) && + (element_compute == rhs.element_compute) && + (reduce_math_op == rhs.reduce_math_op) && + (epilogue_math_op == rhs.epilogue_math_op); + } + + inline + bool operator!=(ReductionFunctionalKey const &rhs) const { + return !(*this == rhs); + } +}; + + +struct ReductionFunctionalKeyHasher { + using IntHash = std::hash; + + inline + static size_t rotl(size_t key, int shl) { + return (key << shl) | (key >> (sizeof(key)*8u - static_cast(shl))); + } + + inline + size_t operator()(ReductionFunctionalKey const &key) const { + IntHash hash; + + return + rotl(hash(int(key.provider)), 1) ^ + rotl(hash(int(key.element_workspace)), 2) ^ + rotl(hash(int(key.element_accumulator)), 3) ^ + rotl(hash(int(key.element_output)), 4) ^ + rotl(hash(int(key.element_compute)), 5) ^ + rotl(hash(int(key.reduce_math_op)), 6) ^ + rotl(hash(int(key.epilogue_math_op)), 7); + } +}; +///////////////////////////////////////////////////////////////////////////////////////////////// + +inline +std::ostream& operator<< (std::ostream& out, const ReductionFunctionalKey& key) { + out << "{\n" + << "provider: " << library::to_string(key.provider) << std::endl + << "element_workspace : " << library::to_string(key.element_workspace) << std::endl + << "element_accumulator : " << library::to_string(key.element_accumulator) << std::endl + << "element_output : " << library::to_string(key.element_output) << std::endl + << "element_compute : " << library::to_string(key.element_compute) << std::endl + << "}"; + return out; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// +// ReductionOperationFunctionalMap has NO preference key and a single instance per functional key +// i.e. 
only one tile size configuration per functional key +using ReductionOperationFunctionalMap = std::unordered_map< + ReductionFunctionalKey, + library::Operation const *, + ReductionFunctionalKeyHasher +>; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Table of cutlass::library::Operation instances +class OperationTable { +public: + + /// Map of all operations of type kGemm + // provider (kCUTLASS) + GemmOperationFunctionalMap gemm_operations; + + // provider (kCUTLASS, kReferenceHost, kReferenceDevice) + BlockScaledGemmOperationFunctionalMap block_scaled_gemm_operations; + + // provider (kCUTLASS, kReferenceHost, kReferenceDevice) + BlockwiseGemmOperationFunctionalMap blockwise_gemm_operations; + + /// Map of all operations of type kConv2d + // provider (kCUTLASS, kReferenceHost, kReferenceDevice) + ConvOperationFunctionalMap conv2d_operations; + + /// Map of all operations of type kConv3d + // provider (kCUTLASS, kReferenceHost, kReferenceDevice) + ConvOperationFunctionalMap conv3d_operations; + + /// Map of all operations of type kConv2d + // provider (kCUTLASS) + ReductionOperationFunctionalMap reduction_operations; + +public: + + void append(Manifest const &manifest); + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + +std::ostream & operator<<(std::ostream &out, cutlass::library::GemmFunctionalKey const &k); diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/singleton.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/singleton.h new file mode 100644 index 0000000000000000000000000000000000000000..9a757433f38fbf10d9a352e07c7f3084a99e4098 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/singleton.h @@ -0,0 +1,68 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/library/library.h" +#include "cutlass/library/manifest.h" +#include "cutlass/library/operation_table.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Singleton instance stores a Manifest and Operation table +class Singleton { +public: + + /// Manifest object + Manifest manifest; + + /// Operation table referencing the Manifest + OperationTable operation_table; + +public: + + Singleton(); + + static Singleton const &get(); +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/types.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/types.h new file mode 100644 index 0000000000000000000000000000000000000000..9f8c4ff13ba543b4ec63997ba55e9278bfb357a6 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/types.h @@ -0,0 +1,295 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + #pragma once + + ///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Layout type identifier +enum class LayoutTypeID { + kUnknown, + kColumnMajor, + kRowMajor, + kBlockScalingTensor, + kColumnMajorInterleavedK2, + kRowMajorInterleavedK2, + kColumnMajorInterleavedK4, + kRowMajorInterleavedK4, + kColumnMajorInterleavedK16, + kRowMajorInterleavedK16, + kColumnMajorInterleavedK32, + kRowMajorInterleavedK32, + kColumnMajorInterleavedK64, + kRowMajorInterleavedK64, + kTensorNCHW, + kTensorNCDHW, + kTensorNHWC, + kTensorNDHWC, + kTensorNC32HW32, + kTensorC32RSK32, + kTensorNC64HW64, + kTensorC64RSK64, + kInvalid +}; + +/// Numeric data type +enum class NumericTypeID { + kUnknown, + kVoid, + kB1, + kU2, + kU4, + kU8, + kU16, + kU32, + kU64, + kS2, + kS4, + kS8, + kS16, + kS32, + kS64, + kFE4M3, + kFE5M2, + + kFE2M3, + kFE3M2, + kFE2M1, + kFUE8M0, + kFUE4M3, + kF8, + kF6, + kF4, + + kF16, + kBF16, + kTF32, + kF32, + kF64, + kCF16, + kCBF16, + kCF32, + kCTF32, + kCF64, + kCS2, + kCS4, + kCS8, + kCS16, + kCS32, + kCS64, + kCU2, + kCU4, + kCU8, + kCU16, + kCU32, + kCU64, + kInvalid +}; + +/// Enumerated type describing a transformation on a complex value. 
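+/// kNone uses the operand unchanged; kConjugate applies complex conjugation (negates the
+/// imaginary part) before the operand is consumed.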
+enum class ComplexTransform { + kNone, + kConjugate, + kInvalid +}; + +/// Providers +enum class Provider { + kNone, + kCUTLASS, + kReferenceHost, + kReferenceDevice, + kCUBLAS, + kCUDNN, + kInvalid +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Enumeration indicating the kind of operation +enum class OperationKind { + kGemm, + kBlockScaledGemm, + kBlockwiseGemm, + kRankK, + kRank2K, + kTrmm, + kSymm, + kConv2d, + kConv3d, + kEqGemm, + kSparseGemm, + kReduction, + kGroupedGemm, + kInvalid +}; + +/// Enumeration indicating whether scalars are in host or device memory +enum class ScalarPointerMode { + kHost, + kDevice, + kInvalid +}; + +/// Describes how reductions are performed across threadblocks +enum class SplitKMode { + kNone, + kSerial, + kParallel, + kParallelSerial, + kInvalid +}; + +/// Indicates the classificaition of the math instruction +enum class OpcodeClassID { + kSimt, + kTensorOp, + kWmmaTensorOp, + kSparseTensorOp, + kBlockScaledOp, + kInvalid +}; + +enum class MathOperationID { + kAdd, + kMultiplyAdd, + kMultiplyAddSaturate, + kMultiplyAddMixedInputUpcast, + kMultiplyAddFastBF16, + kMultiplyAddFastF16, + kMultiplyAddFastF32, + kMultiplyAddComplex, + kMultiplyAddComplexFastF32, + kMultiplyAddGaussianComplex, + kXorPopc, + kInvalid +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Enumeration indicating what kind of GEMM operation to perform +enum class GemmKind { + kGemm, + kBlockScaledGemm, + kSparse, + kUniversal, + kPlanarComplex, + kPlanarComplexArray, + kGrouped, + kInvalid +}; + +/// Enumeration indicating what kind of RankK update operation to perform +enum class RankKKind { + kUniversal, + kInvalid +}; + +/// Enumeration indicating what kind of TRMM operation to perform +enum class TrmmKind { + kUniversal, + kInvalid +}; + +/// Enumeration indicating what kind of SYMM/HEMM operation to perform +enum class SymmKind { + kUniversal, + kInvalid +}; + +/// Enumeration indicating what kind of Conv2d operation to perform +enum class ConvKind { + kUnknown, + kFprop, + kDgrad, + kWgrad, + kInvalid +}; + +enum class ConvModeID { + kCrossCorrelation, + kConvolution, + kInvalid +}; + +// Iterator algorithm enum in order of general performance-efficiency +enum class IteratorAlgorithmID { + kNone, + kAnalytic, + kOptimized, + kFixedChannels, + kFewChannels, + kInvalid +}; + + +enum class EpilogueKind { + kUnknown, + kConversion, + kLinearCombination, + kLinearCombinationClamp, + kLinearCombinationPlanarComplex, + kLinearCombinationRelu, + kLinearCombinationSigmoid, + kInvalid +}; + + +enum class RuntimeDatatype { + kStatic, + kE4M3, + kE5M2, + kE3M2, + kE2M3, + kE2M1, + + kInvalid +}; + + +enum class RasterOrder { + kAlongN, + kAlongM, + kHeuristic, + kInvalid +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/util.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/util.h new file mode 100644 index 0000000000000000000000000000000000000000..f537421751c1f2af3b95a2e1951006af441b28e0 --- /dev/null +++ 
b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/library/include/cutlass/library/util.h @@ -0,0 +1,281 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! + \file + + \brief Utilities accompanying the CUTLASS library for interacting with Library types. 
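+
+  Includes to_string/from_string conversions for library enumerants, numeric-type
+  predicates (complex, integer, signed, floating-point), lexical casts between strings
+  and raw byte representations of scalars, and a small RAII wrapper (CudaBuffer) around
+  cudaMalloc/cudaFree.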
+*/ + +#ifndef CUTLASS_LIBRARY_UTIL_H +#define CUTLASS_LIBRARY_UTIL_H + +#include "cutlass/cutlass.h" +#include "cutlass/library/library.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Lexical cast from string +template T from_string(std::string const &); + +/// Converts a Provider enumerant to a string +char const *to_string(Provider provider, bool pretty = false); + +/// Parses a Provider enumerant from a string +template <> Provider from_string(std::string const &str); + +/// Converts a GemmKind enumerant to a string +char const *to_string(GemmKind type, bool pretty = false); + +/// Converts a RankKKind enumerant to a string +char const *to_string(RankKKind type, bool pretty = false); + +/// Converts a TrmmKind enumerant to a string +char const *to_string(TrmmKind type, bool pretty = false); + +/// Converts a SymmKind enumerant to a string +char const *to_string(SymmKind type, bool pretty = false); + +/// Converts a SideMode enumerant to a string +char const *to_string(SideMode type, bool pretty = false); + +/// Converts a FillMode enumerant to a string +char const *to_string(FillMode type, bool pretty = false); + +/// Converts a BlasMode enumerant to a string +char const *to_string(BlasMode type, bool pretty = false); + +/// Converts a DiagType enumerant to a string +char const *to_string(DiagType type, bool pretty = false); + +/// Converts a NumericType enumerant to a string +char const *to_string(OperationKind type, bool pretty = false); + +/// Parses a NumericType enumerant from a string +template <> OperationKind from_string(std::string const &str); + +/// Converts a NumericType enumerant to a string +char const *to_string(NumericTypeID type, bool pretty = false); + +/// Parses a NumericType enumerant from a string +template <> NumericTypeID from_string(std::string const &str); + +/// Returns the size of a data type in bits +int sizeof_bits(NumericTypeID type); + +/// Returns true if the numeric type is a complex data type or false if real-valued. 
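+/// (For example, kCF32 and kCS8 are complex; kF32 and kS8 are real-valued.)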
+bool is_complex_type(NumericTypeID type); + +/// Returns the real-valued type underlying a type (only different from 'type' if complex) +NumericTypeID get_real_type(NumericTypeID type); + +/// Returns true if numeric type is integer +bool is_integer_type(NumericTypeID type); + +/// Returns true if numeric type is signed +bool is_signed_type(NumericTypeID type); + +/// Returns true if numeric type is a signed integer +bool is_signed_integer(NumericTypeID type); + +/// Returns true if numeric type is an unsigned integer +bool is_unsigned_integer(NumericTypeID type); + +/// Returns true if numeric type is floating-point type +bool is_float_type(NumericTypeID type); + +/// To string method for cutlass::Status +char const *to_string(Status status, bool pretty = false); + +/// Converts a LayoutTypeID enumerant to a string +char const *to_string(LayoutTypeID layout, bool pretty = false); + +/// Parses a LayoutTypeID enumerant from a string +template <> LayoutTypeID from_string<LayoutTypeID>(std::string const &str); + +/// Returns the rank of a layout's stride based on the LayoutTypeID +int get_layout_stride_rank(LayoutTypeID layout_id); + +/// Converts an OpcodeClassID enumerant to a string +char const *to_string(OpcodeClassID type, bool pretty = false); + +/// Converts an OpcodeClassID enumerant from a string +template <> +OpcodeClassID from_string<OpcodeClassID>(std::string const &str); + +/// Converts a ComplexTransform enumerant to a string +char const *to_string(ComplexTransform type, bool pretty = false); + +/// Converts a ComplexTransform enumerant from a string +template <> +ComplexTransform from_string<ComplexTransform>(std::string const &str); + + +/// Converts a SplitKMode enumerant to a string +char const *to_string(SplitKMode split_k_mode, bool pretty = false); + +/// Converts a SplitKMode enumerant from a string +template <> +SplitKMode from_string<SplitKMode>(std::string const &str); + +/// Converts a ConvModeID enumerant to a string +char const *to_string(ConvModeID type, bool pretty = false); + +/// Converts a ConvModeID enumerant from a string +template <> +ConvModeID from_string<ConvModeID>(std::string const &str); + +/// Converts an IteratorAlgorithmID enumerant to a string +char const *to_string(IteratorAlgorithmID type, bool pretty = false); + +/// Converts an IteratorAlgorithmID enumerant from a string +template <> +IteratorAlgorithmID from_string<IteratorAlgorithmID>(std::string const &str); + +/// Converts a ConvKind enumerant to a string +char const *to_string(ConvKind type, bool pretty = false); + +/// Converts a ConvKind enumerant from a string +template <> +ConvKind from_string<ConvKind>(std::string const &str); + + +/// Converts a RuntimeDatatype enumerant to a string +char const *to_string(cutlass::library::RuntimeDatatype type, bool pretty = false); + +/// Converts a RuntimeDatatype enumerant from a string +template<> +cutlass::library::RuntimeDatatype from_string<cutlass::library::RuntimeDatatype>(std::string const &str); + + +/// Converts a RasterOrder enumerant to a string +char const *to_string(RasterOrder type, bool pretty = false); + +/// Converts a RasterOrder enumerant from a string +template<> +RasterOrder from_string<RasterOrder>(std::string const &str); + +/// Converts a bool to a string +char const *to_string(bool type, bool pretty = false); + +/// Converts a bool from a string +template<> +bool from_string<bool>(std::string const &str); + +/// Lexical cast from int64_t to string +std::string lexical_cast(int64_t int_value); + +/// Lexical cast a string to a byte array. Returns true if cast is successful or false if invalid.
+bool lexical_cast(std::vector &bytes, NumericTypeID type, std::string const &str); + +/// Lexical cast TO a string FROM a byte array. Returns true if cast is successful or false if invalid. +std::string lexical_cast(std::vector &bytes, NumericTypeID type); + +/// Casts from a signed int64 to the destination type. Returns true if successful. +bool cast_from_int64(std::vector &bytes, NumericTypeID type, int64_t src); + +/// Casts from an unsigned int64 to the destination type. Returns true if successful. +bool cast_from_uint64(std::vector &bytes, NumericTypeID type, uint64_t src); + +/// Casts from a real value represented as a double to the destination type. Returns true if successful. +bool cast_from_double(std::vector &bytes, NumericTypeID type, double src); + +NumericTypeID dynamic_datatype_to_id(RuntimeDatatype type); + +#define CUDA_CHECK(call) \ + do { \ + cudaError_t err = (call); \ + if (err != cudaSuccess) { \ + std::cerr << "CUDA Error: " << cudaGetErrorString(err) << " in " << __func__ << " at " \ + << __FILE__ << ":" << __LINE__ << std::endl; \ + return Status::kInvalid; \ + } \ + } while (0) + +// RAII CUDA buffer container +class CudaBuffer { +public: + CudaBuffer() : size_(0), d_ptr_(nullptr) {} + + explicit CudaBuffer(size_t size) : size_(size), d_ptr_(nullptr) { + cudaError_t err = cudaMalloc(&d_ptr_, size_); + if (err != cudaSuccess) { + throw std::runtime_error("cudaMalloc failed: " + std::string(cudaGetErrorString(err))); + } + } + + ~CudaBuffer() { + if (d_ptr_) { + cudaFree(d_ptr_); + } + } + + CudaBuffer(CudaBuffer const&) = delete; + CudaBuffer& operator=(CudaBuffer const&) = delete; + + CudaBuffer(CudaBuffer&& other) noexcept : size_(other.size_), d_ptr_(other.d_ptr_) { + other.d_ptr_ = nullptr; + other.size_ = 0; + } + + CudaBuffer& operator=(CudaBuffer&& other) noexcept { + if (this != &other) { + if (d_ptr_) { + cudaFree(d_ptr_); + } + d_ptr_ = other.d_ptr_; + size_ = other.size_; + other.d_ptr_ = nullptr; + other.size_ = 0; + } + return *this; + } + + void* data() const noexcept { return d_ptr_; } + size_t size() const noexcept { return size_; } + +private: + size_t size_; + void* d_ptr_; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#endif diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/library/src/block_scaled_gemm_operation_3x.hpp b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/library/src/block_scaled_gemm_operation_3x.hpp new file mode 100644 index 0000000000000000000000000000000000000000..c96b9a2212b42c191551ea70da3ac3baecbed487 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/library/src/block_scaled_gemm_operation_3x.hpp @@ -0,0 +1,450 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines operations for all GEMM operation kinds in CUTLASS Library. +*/ + + + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/detail/collective.hpp" +#include "cutlass/library/library.h" +#include "library_internal.h" +#include "gemm_operation_3x.hpp" +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class BlockScaledGemmUniversal3xOperation : public GemmOperation3xBase { +public: + using Operator = Operator_; + using OperatorArguments = typename Operator::Arguments; + using ElementA = typename Operator::CollectiveMainloop::ElementA; + using ElementSFA = typename Operator::CollectiveMainloop::ElementSF; + using LayoutA = typename Operator::LayoutA; + using ElementB = typename Operator::CollectiveMainloop::ElementB; + using ElementSFB = typename Operator::CollectiveMainloop::ElementSF; + using LayoutB = typename Operator::LayoutB; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + using ElementD = typename Operator::ElementD; + using LayoutD = typename Operator::LayoutD; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + + using TiledMma = typename Operator::CollectiveMainloop::TiledMma; + constexpr static int SFVecSize = TiledMma::SFVecSize; + + using CollectiveMainloop = typename Operator::CollectiveMainloop; + using CollectiveEpilogue = typename Operator::CollectiveEpilogue; + using ThreadEpilogueOp = typename CollectiveEpilogue::ThreadEpilogueOp; + + using Sm1xxBlkScaledConfig = typename CollectiveMainloop::Sm1xxBlkScaledConfig; + + static constexpr bool epilogue_scalefactor_generation = not cute::is_same_v; + static constexpr int32_t SFD_VectorSize = epilogue_scalefactor_generation ? 
ThreadEpilogueOp::SFVecSize : SFVecSize; + using ElementSFD = cute::conditional_t; + using LayoutSFD = cute::conditional_t; + + + + static constexpr bool IsRuntimeDataTypeA = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); + + static constexpr bool IsRuntimeDataTypeB = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); + + static_assert((IsRuntimeDataTypeA && IsRuntimeDataTypeB) || + (!IsRuntimeDataTypeA && !IsRuntimeDataTypeB), + "ElementA and ElementB in a GEMM kernel should be both runtime or both static."); + + static constexpr bool IsRuntimeDataType = IsRuntimeDataTypeA && IsRuntimeDataTypeB; + using RuntimeDataTypeA = typename Operator::CollectiveMainloop::RuntimeDataTypeA; + using RuntimeDataTypeB = typename Operator::CollectiveMainloop::RuntimeDataTypeB; + + +private: + BlockScaledGemmDescription description_; + +public: + + /// Constructor + BlockScaledGemmUniversal3xOperation(char const *name = "unknown_gemm"): + GemmOperation3xBase(name, GemmKind::kUniversal) { + description_.kind = OperationKind::kBlockScaledGemm; + description_.SFA.element = NumericTypeMap::kId; + description_.SFA.layout = LayoutTypeID::kRowMajor; + description_.SFA.alignment = 128; + description_.SFA.log_extent_range = 32; + description_.SFA.log_stride_range = 32; + + description_.SFB.element = NumericTypeMap::kId; + description_.SFB.layout = LayoutTypeID::kRowMajor; + description_.SFB.alignment = 128; + description_.SFB.log_extent_range = 32; + description_.SFB.log_stride_range = 32; + + description_.SFVecSize = SFVecSize; + + description_.SFD = make_TensorDescription(128); + description_.EpilogueSFVecSize = SFD_VectorSize; + + + description_.name = name; + description_.provider = Provider::kCUTLASS; + description_.gemm_kind = GemmKind::kUniversal; + + description_.tile_description.threadblock_shape = make_Coord( + Operator::ThreadblockShape::kM, + Operator::ThreadblockShape::kN, + Operator::ThreadblockShape::kK); + + if constexpr (Operator::ArchTag::kMinComputeCapability >= 90) { + description_.tile_description.cluster_shape = make_Coord( + Operator::ClusterShape::kM, + Operator::ClusterShape::kN, + Operator::ClusterShape::kK); + } + + description_.tile_description.threadblock_stages = Operator::kStages; + + description_.tile_description.warp_count = make_Coord( + Operator::WarpCount::kM, + Operator::WarpCount::kN, + Operator::WarpCount::kK); + + description_.tile_description.math_instruction.instruction_shape = make_Coord( + Operator::InstructionShape::kM, + Operator::InstructionShape::kN, + Operator::InstructionShape::kK); + + description_.tile_description.math_instruction.element_accumulator = + NumericTypeMap::kId; + + description_.tile_description.math_instruction.opcode_class = + OpcodeClassMap::kId; + + description_.tile_description.math_instruction.math_operation = + MathOperationMap::kId; + + description_.tile_description.minimum_compute_capability = + ArchMap::kMin; + + description_.tile_description.maximum_compute_capability = + ArchMap::kMax; + + description_.A = make_TensorDescription(Operator::kAlignmentA); + description_.B = make_TensorDescription(Operator::kAlignmentB); + description_.C = make_TensorDescription(Operator::kAlignmentC); + description_.D = make_TensorDescription(Operator::kAlignmentD); + description_.element_epilogue = NumericTypeMap::kId; + + description_.split_k_mode = SplitKMode::kNone; + } + + /// Returns the description of the GEMM operation + virtual OperationDescription const & description() const { + return description_; + } + + /// Returns the 
description of the GEMM operation + BlockScaledGemmDescription const& get_gemm_description() const { + return description_; + } + +protected: + + /// Constructs the arguments structure given the configuration and arguments + static Status construct_arguments_( + OperatorArguments &operator_args, GemmUniversalConfiguration const *configuration) { + // NOTE: GemmUniversalConfiguration does not contain problem shapes or batch strides + // Do nothing here and construct kernel arguments in update_arguments_ instead + // We also cannot construct TMA descriptors without all the arguments available + + operator_args.mode = configuration->mode; + return Status::kSuccess; + } + + template + struct UpdateFusionArgs { + static Status update_(FusionArgs const& fusion_args, BlockScaledGemmArguments const &arguments) { + // If a custom EVT is instantiated then it is the users's responsibility + // to ensure alpha and beta are updated appropriately + return Status::kSuccess; + } + }; + + template + struct UpdateFusionArgs> { + static Status update_(FusionArgs& fusion_args, BlockScaledGemmArguments const &arguments) { + + if constexpr (epilogue_scalefactor_generation) { + fusion_args.block_scale_factor_ptr = static_cast(arguments.SFD); + fusion_args.norm_constant_ptr = static_cast(arguments.norm_constant); + } + + + if (arguments.pointer_mode == ScalarPointerMode::kHost) { + fusion_args.alpha = *static_cast(arguments.alpha); + fusion_args.beta = *static_cast(arguments.beta); + fusion_args.alpha_ptr = nullptr; + fusion_args.beta_ptr = nullptr; + + return Status::kSuccess; + } + else if (arguments.pointer_mode == ScalarPointerMode::kDevice) { + fusion_args.alpha = 0; + fusion_args.beta = 0; + fusion_args.alpha_ptr = static_cast(arguments.alpha); + fusion_args.beta_ptr = static_cast(arguments.beta); + + return Status::kSuccess; + } + else { + return Status::kErrorInvalidProblem; + } + } + }; + + /// Constructs the arguments structure given the configuration and arguments + static Status update_arguments_( + OperatorArguments &operator_args, + BlockScaledGemmArguments const *arguments) { + Status status = Status::kSuccess; + + status = UpdateFusionArgs::update_( + operator_args.epilogue.thread, *arguments); + if (status != Status::kSuccess) { + return status; + } + + operator_args.problem_shape = cute::make_shape( + arguments->problem_size.m(), + arguments->problem_size.n(), + arguments->problem_size.k(), + arguments->batch_count); + + // update arguments + + if constexpr (IsRuntimeDataType) { + using ArrayElementA = typename Operator::GemmKernel::CollectiveMainloop::ArrayElementA; + using ArrayElementB = typename Operator::GemmKernel::CollectiveMainloop::ArrayElementB; + operator_args.mainloop.ptr_A = static_cast(arguments->A); + operator_args.mainloop.ptr_B = static_cast(arguments->B); + + using RuntimeDataTypeA = typename Operator::GemmKernel::CollectiveMainloop::RuntimeDataTypeA; + using RuntimeDataTypeB = typename Operator::GemmKernel::CollectiveMainloop::RuntimeDataTypeB; + + static_assert(cute::is_same_v, + "RuntimeDataTypeA/B should be identical, either MXF8F6F4Format or MXF4Format"); + using RuntimeDatatypeArg = RuntimeDataTypeA; + + auto mapping = [](RuntimeDatatype type) { + if constexpr (cute::is_same_v) { + if (type == RuntimeDatatype::kE3M2) { + return cute::UMMA::MXF8F6F4Format::E3M2; + } else if (type == RuntimeDatatype::kE2M3) { + return cute::UMMA::MXF8F6F4Format::E2M3; + } else if (type == RuntimeDatatype::kE2M1) { + return cute::UMMA::MXF8F6F4Format::E2M1; + } else { + assert("Invalid input 
datatype."); + } + } + else if constexpr (cute::is_same_v) { + if (type == RuntimeDatatype::kE2M1) { + return cute::UMMA::MXF4Format::E2M1; + } else { + assert("Invalid input datatype."); + } + } + // BlockScaled kernels receive either MXF4Format or MXF8F6F4Format runtime datatype + CUTE_GCC_UNREACHABLE; + }; + + operator_args.mainloop.runtime_data_type_a = mapping(arguments->runtime_input_datatype_a); + operator_args.mainloop.runtime_data_type_b = mapping(arguments->runtime_input_datatype_b); + + } + else { + + operator_args.mainloop.ptr_A = static_cast(arguments->A); + operator_args.mainloop.ptr_B = static_cast(arguments->B); + } + operator_args.mainloop.ptr_SFA = static_cast(arguments->SFA); + operator_args.mainloop.ptr_SFB = static_cast(arguments->SFB); + operator_args.epilogue.ptr_C = static_cast(arguments->C); + operator_args.epilogue.ptr_D = static_cast(arguments->D); + + operator_args.mainloop.dA = cute::make_int_tuple_from( + arguments->lda, arguments->batch_stride_A); + operator_args.mainloop.dB = cute::make_int_tuple_from( + arguments->ldb, arguments->batch_stride_B); + operator_args.epilogue.dC = cute::make_int_tuple_from( + arguments->ldc, arguments->batch_stride_C); + operator_args.epilogue.dD = operator_args.epilogue.dC; + + operator_args.mainloop.layout_SFA = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(operator_args.problem_shape); + operator_args.mainloop.layout_SFB = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(operator_args.problem_shape); + + /* Query device SM count to pass onto the kernel as an argument, where needed */ + operator_args.hw_info.sm_count = arguments->sm_count; + if constexpr (!std::is_const_v) { + operator_args.scheduler.max_swizzle_size = arguments->swizzle_size; + } + + if constexpr (!std::is_const_v) { + using Enum_t = decltype(operator_args.scheduler.raster_order); + switch (arguments->raster_order) { + case RasterOrder::kAlongN: + operator_args.scheduler.raster_order = Enum_t::AlongN; + break; + case RasterOrder::kAlongM: + operator_args.scheduler.raster_order = Enum_t::AlongM; + break; + default: + operator_args.scheduler.raster_order = Enum_t::Heuristic; + } + } + + if constexpr (std::is_same_v) { + operator_args.scheduler.splits = arguments->split_k_slices; + } + + + if constexpr (Operator::ArchTag::kMinComputeCapability >= 100) { + operator_args.hw_info.cluster_shape = dim3( + arguments->cluster_shape.m(), + arguments->cluster_shape.n(), + arguments->cluster_shape.k()); + operator_args.hw_info.cluster_shape_fallback = dim3( + arguments->cluster_shape_fallback.m(), + arguments->cluster_shape_fallback.n(), + arguments->cluster_shape_fallback.k()); + } + + return status; + } + +public: + + /// Returns success if the operation can proceed + Status can_implement( + void const *configuration_ptr, void const *arguments_ptr) const override { + + GemmUniversalConfiguration const *configuration = + static_cast(configuration_ptr); + BlockScaledGemmArguments const *arguments = + static_cast(arguments_ptr); + + OperatorArguments args; + auto status = update_arguments_(args, arguments); + if (status != Status::kSuccess) { + return status; + } + + // can_implement rules may need access to problem shape + args.problem_shape = cute::make_shape( + configuration->problem_size.m(), + configuration->problem_size.n(), + configuration->problem_size.k(), + configuration->batch_count); + + return Operator::can_implement(args); + } + + /// Gets the host-side workspace + uint64_t get_host_workspace_size(void const *configuration) const override { + return 
sizeof(Operator); + } + + /// Gets the device-side workspace + uint64_t get_device_workspace_size( + void const *configuration_ptr,void const *arguments_ptr) const override { + + OperatorArguments args; + auto status = update_arguments_( + args, static_cast(arguments_ptr)); + if (status != Status::kSuccess) { + return 0; + } + + uint64_t size = Operator::get_workspace_size(args); + return size; + } + + /// Initializes the workspace + Status initialize( + void const *configuration_ptr, + void *host_workspace, + void *device_workspace, + cudaStream_t stream = nullptr) const override { + Operator *op = new (host_workspace) Operator; + return Status::kSuccess; + } + + Status initialize_with_profiler_workspace( + void const *configuration, + void *host_workspace, + void *device_workspace, + uint8_t **profiler_workspaces, + int problem_count_from_profiler, + cudaStream_t stream = nullptr) { + return Status::kSuccess; + } + + /// Runs the kernel + Status run( + void const *arguments_ptr, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const override { + + OperatorArguments args; + Status status = update_arguments_(args, static_cast(arguments_ptr)); + if (status != Status::kSuccess) { + return status; + } + + Operator *op = static_cast(host_workspace); + // We need to call initialize() since we have to rebuild TMA desc for every new set of args + status = op->run(args, device_workspace, stream, nullptr, static_cast(arguments_ptr)->use_pdl); + return status; + } +}; +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::library + +/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/library/src/blockwise_gemm_operation_3x.hpp b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/library/src/blockwise_gemm_operation_3x.hpp new file mode 100644 index 0000000000000000000000000000000000000000..00347a993e29035e58401e69698267045b399f7d --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/library/src/blockwise_gemm_operation_3x.hpp @@ -0,0 +1,429 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines operations for all GEMM operation kinds in CUTLASS Library. +*/ + + + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/detail/collective.hpp" +#include "cutlass/library/library.h" +#include "library_internal.h" +#include "gemm_operation_3x.hpp" +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class BlockwiseGemmUniversal3xOperation : public GemmOperation3xBase { +public: + using Operator = Operator_; + using OperatorArguments = typename Operator::Arguments; + using ElementA = typename Operator::CollectiveMainloop::ElementA; + using ElementSFA = typename Operator::ElementAccumulator; + using LayoutA = typename Operator::LayoutA; + using ElementB = typename Operator::CollectiveMainloop::ElementB; + using ElementSFB = typename Operator::ElementAccumulator; + using LayoutB = typename Operator::LayoutB; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + using ElementD = typename Operator::ElementD; + using LayoutD = typename Operator::LayoutD; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + + using TiledMma = typename Operator::CollectiveMainloop::TiledMma; + + using CollectiveMainloop = typename Operator::CollectiveMainloop; + using CollectiveEpilogue = typename Operator::CollectiveEpilogue; + using ThreadEpilogueOp = typename CollectiveEpilogue::ThreadEpilogueOp; + + static constexpr bool IsRuntimeDataTypeA = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); + + static constexpr bool IsRuntimeDataTypeB = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); + + static_assert((IsRuntimeDataTypeA && IsRuntimeDataTypeB) || + (!IsRuntimeDataTypeA && !IsRuntimeDataTypeB), + "ElementA and ElementB in a GEMM kernel should be both runtime or both static."); + + static constexpr bool IsRuntimeDataType = IsRuntimeDataTypeA && IsRuntimeDataTypeB; + +private: + BlockwiseGemmDescription description_; + +public: + + /// Constructor + BlockwiseGemmUniversal3xOperation(char const *name = "unknown_gemm"): + GemmOperation3xBase(name, GemmKind::kUniversal) { + description_.kind = OperationKind::kBlockwiseGemm; + description_.SFA.element = NumericTypeMap::kId; + description_.SFA.layout = size<0,1>(typename CollectiveMainloop::LayoutSFA{}.stride()) == 1 ? 
+ LayoutTypeID::kColumnMajor : LayoutTypeID::kRowMajor; + description_.SFA.alignment = CollectiveMainloop::AlignmentSFA; + description_.SFA.log_extent_range = 32; + description_.SFA.log_stride_range = 32; + + description_.SFB.element = NumericTypeMap::kId; + description_.SFB.layout = size<0,1>(typename CollectiveMainloop::LayoutSFB{}.stride()) == 1 ? + LayoutTypeID::kRowMajor : LayoutTypeID::kColumnMajor; + description_.SFB.alignment = CollectiveMainloop::AlignmentSFA; + description_.SFB.log_extent_range = 32; + description_.SFB.log_stride_range = 32; + + description_.SFMVecSize = Operator::CollectiveMainloop::ScaleGranularityM; + description_.SFNVecSize = Operator::CollectiveMainloop::ScaleGranularityN; + description_.SFKVecSize = Operator::CollectiveMainloop::ScaleGranularityK; + + description_.name = name; + description_.provider = Provider::kCUTLASS; + description_.gemm_kind = GemmKind::kUniversal; + + description_.tile_description.threadblock_shape = make_Coord( + Operator::ThreadblockShape::kM, + Operator::ThreadblockShape::kN, + Operator::ThreadblockShape::kK); + + if constexpr (Operator::ArchTag::kMinComputeCapability >= 90) { + description_.tile_description.cluster_shape = make_Coord( + Operator::ClusterShape::kM, + Operator::ClusterShape::kN, + Operator::ClusterShape::kK); + } + + description_.tile_description.threadblock_stages = Operator::kStages; + + description_.tile_description.warp_count = make_Coord( + Operator::WarpCount::kM, + Operator::WarpCount::kN, + Operator::WarpCount::kK); + + description_.tile_description.math_instruction.instruction_shape = make_Coord( + Operator::InstructionShape::kM, + Operator::InstructionShape::kN, + Operator::InstructionShape::kK); + + description_.tile_description.math_instruction.element_accumulator = + NumericTypeMap::kId; + + description_.tile_description.math_instruction.opcode_class = + OpcodeClassMap::kId; + + description_.tile_description.math_instruction.math_operation = + MathOperationMap::kId; + + description_.tile_description.minimum_compute_capability = + ArchMap::kMin; + + description_.tile_description.maximum_compute_capability = + ArchMap::kMax; + + description_.A = make_TensorDescription(Operator::kAlignmentA); + description_.B = make_TensorDescription(Operator::kAlignmentB); + description_.C = make_TensorDescription(Operator::kAlignmentC); + description_.D = make_TensorDescription(Operator::kAlignmentD); + description_.element_epilogue = NumericTypeMap::kId; + + description_.split_k_mode = SplitKMode::kNone; + } + + /// Returns the description of the GEMM operation + virtual OperationDescription const & description() const { + return description_; + } + + /// Returns the description of the GEMM operation + BlockwiseGemmDescription const& get_gemm_description() const { + return description_; + } + +protected: + + /// Constructs the arguments structure given the configuration and arguments + static Status construct_arguments_( + OperatorArguments &operator_args, GemmUniversalConfiguration const *configuration) { + // NOTE: GemmUniversalConfiguration does not contain problem shapes or batch strides + // Do nothing here and construct kernel arguments in update_arguments_ instead + // We also cannot construct TMA descriptors without all the arguments available + + operator_args.mode = configuration->mode; + return Status::kSuccess; + } + + template + struct UpdateFusionArgs { + static Status update_(FusionArgs const& fusion_args, BlockwiseGemmArguments const &arguments) { + // If a custom EVT is instantiated then it is the 
users's responsibility + // to ensure alpha and beta are updated appropriately + return Status::kSuccess; + } + }; + + template + struct UpdateFusionArgs> { + static Status update_(FusionArgs& fusion_args, BlockwiseGemmArguments const &arguments) { + if (arguments.pointer_mode == ScalarPointerMode::kHost) { + fusion_args.alpha = *static_cast(arguments.alpha); + fusion_args.beta = *static_cast(arguments.beta); + fusion_args.alpha_ptr = nullptr; + fusion_args.beta_ptr = nullptr; + + return Status::kSuccess; + } + else if (arguments.pointer_mode == ScalarPointerMode::kDevice) { + fusion_args.alpha = 0; + fusion_args.beta = 0; + fusion_args.alpha_ptr = static_cast(arguments.alpha); + fusion_args.beta_ptr = static_cast(arguments.beta); + + return Status::kSuccess; + } + else { + return Status::kErrorInvalidProblem; + } + } + }; + + /// Constructs the arguments structure given the configuration and arguments + static Status update_arguments_( + OperatorArguments &operator_args, + BlockwiseGemmArguments const *arguments) { + Status status = Status::kSuccess; + + status = UpdateFusionArgs::update_( + operator_args.epilogue.thread, *arguments); + if (status != Status::kSuccess) { + return status; + } + + operator_args.problem_shape = cute::make_shape( + arguments->problem_size.m(), + arguments->problem_size.n(), + arguments->problem_size.k(), + arguments->batch_count); + + // update arguments + + if constexpr (IsRuntimeDataType) { + using ArrayElementA = typename Operator::GemmKernel::CollectiveMainloop::ArrayElementA; + using ArrayElementB = typename Operator::GemmKernel::CollectiveMainloop::ArrayElementB; + operator_args.mainloop.ptr_A = static_cast(arguments->A); + operator_args.mainloop.ptr_B = static_cast(arguments->B); + + std::unordered_map mapping = { + {RuntimeDatatype::kE4M3, cute::UMMA::MXF8F6F4Format::E4M3}, + {RuntimeDatatype::kE5M2, cute::UMMA::MXF8F6F4Format::E5M2}, + {RuntimeDatatype::kE3M2, cute::UMMA::MXF8F6F4Format::E3M2}, + {RuntimeDatatype::kE2M1, cute::UMMA::MXF8F6F4Format::E2M1} + }; + + auto iter_runtime_a = mapping.find(arguments->runtime_input_datatype_a); + auto iter_runtime_b = mapping.find(arguments->runtime_input_datatype_b); + + if (iter_runtime_a != mapping.end()) { + operator_args.mainloop.runtime_data_type_a = iter_runtime_a->second; + } else { + assert("invalid runtime argument for datatype A!"); + } + + if (iter_runtime_b != mapping.end()) { + operator_args.mainloop.runtime_data_type_b = iter_runtime_b->second; + } else { + assert("invalid runtime argument for datatype B!"); + } + + } + else { + + operator_args.mainloop.ptr_A = static_cast(arguments->A); + operator_args.mainloop.ptr_B = static_cast(arguments->B); + } + operator_args.mainloop.ptr_SFA = static_cast(arguments->SFA); + operator_args.mainloop.ptr_SFB = static_cast(arguments->SFB); + operator_args.epilogue.ptr_C = static_cast(arguments->C); + operator_args.epilogue.ptr_D = static_cast(arguments->D); + + operator_args.mainloop.dA = cute::make_int_tuple_from( + arguments->lda, arguments->batch_stride_A); + operator_args.mainloop.dB = cute::make_int_tuple_from( + arguments->ldb, arguments->batch_stride_B); + operator_args.epilogue.dC = cute::make_int_tuple_from( + arguments->ldc, arguments->batch_stride_C); + operator_args.epilogue.dD = operator_args.epilogue.dC; + + operator_args.mainloop.layout_SFA = Operator::CollectiveMainloop::ScaleConfig::tile_atom_to_shape_SFA(operator_args.problem_shape); + operator_args.mainloop.layout_SFB = 
Operator::CollectiveMainloop::ScaleConfig::tile_atom_to_shape_SFB(operator_args.problem_shape); + + /* Query device SM count to pass onto the kernel as an argument, where needed */ + operator_args.hw_info.sm_count = arguments->sm_count; + if constexpr (!std::is_const_v) { + operator_args.scheduler.max_swizzle_size = arguments->swizzle_size; + } + + if constexpr (!std::is_const_v) { + using Enum_t = decltype(operator_args.scheduler.raster_order); + switch (arguments->raster_order) { + case RasterOrder::kAlongN: + operator_args.scheduler.raster_order = Enum_t::AlongN; + break; + case RasterOrder::kAlongM: + operator_args.scheduler.raster_order = Enum_t::AlongM; + break; + default: + operator_args.scheduler.raster_order = Enum_t::Heuristic; + } + } + + if constexpr (std::is_same_v) { + operator_args.scheduler.splits = arguments->split_k_slices; + } + + + if constexpr (Operator::ArchTag::kMinComputeCapability >= 100) { + operator_args.hw_info.cluster_shape = dim3( + arguments->cluster_shape.m(), + arguments->cluster_shape.n(), + arguments->cluster_shape.k()); + operator_args.hw_info.cluster_shape_fallback = dim3( + arguments->cluster_shape_fallback.m(), + arguments->cluster_shape_fallback.n(), + arguments->cluster_shape_fallback.k()); + } + + return status; + } + +public: + + /// Returns success if the operation can proceed + Status can_implement( + void const *configuration_ptr, void const *arguments_ptr) const override { + + GemmUniversalConfiguration const *configuration = + static_cast(configuration_ptr); + BlockwiseGemmArguments const *arguments = + static_cast(arguments_ptr); + + if (arguments->sf_m_vec_size != description_.SFMVecSize && arguments->sf_m_vec_size != 0) { + return Status::kErrorInvalidProblem; + } + if (arguments->sf_n_vec_size != description_.SFNVecSize && arguments->sf_n_vec_size != 0) { + return Status::kErrorInvalidProblem; + } + if (arguments->sf_k_vec_size != description_.SFKVecSize && arguments->sf_k_vec_size != 0) { + return Status::kErrorInvalidProblem; + } + + OperatorArguments args; + auto status = update_arguments_(args, arguments); + if (status != Status::kSuccess) { + return status; + } + + // can_implement rules may need access to problem shape + args.problem_shape = cute::make_shape( + configuration->problem_size.m(), + configuration->problem_size.n(), + configuration->problem_size.k(), + configuration->batch_count); + + return Operator::can_implement(args); + } + + /// Gets the host-side workspace + uint64_t get_host_workspace_size(void const *configuration) const override { + return sizeof(Operator); + } + + /// Gets the device-side workspace + uint64_t get_device_workspace_size( + void const *configuration_ptr,void const *arguments_ptr) const override { + + OperatorArguments args; + auto status = update_arguments_( + args, static_cast(arguments_ptr)); + if (status != Status::kSuccess) { + return 0; + } + + uint64_t size = Operator::get_workspace_size(args); + return size; + } + + /// Initializes the workspace + Status initialize( + void const *configuration_ptr, + void *host_workspace, + void *device_workspace, + cudaStream_t stream = nullptr) const override { + Operator *op = new (host_workspace) Operator; + return Status::kSuccess; + } + + Status initialize_with_profiler_workspace( + void const *configuration, + void *host_workspace, + void *device_workspace, + uint8_t **profiler_workspaces, + int problem_count_from_profiler, + cudaStream_t stream = nullptr) { + return Status::kSuccess; + } + + /// Runs the kernel + Status run( + void const 
*arguments_ptr, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const override { + + OperatorArguments args; + Status status = update_arguments_(args, static_cast(arguments_ptr)); + if (status != Status::kSuccess) { + return status; + } + + Operator *op = static_cast(host_workspace); + // We need to call initialize() since we have to rebuild TMA desc for every new set of args + status = op->run(args, device_workspace, stream, nullptr, static_cast(arguments_ptr)->use_pdl); + return status; + } +}; +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::library + +/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/library/src/conv2d_operation.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/library/src/conv2d_operation.h new file mode 100644 index 0000000000000000000000000000000000000000..3b1a1584db92c4379e04c84a2658f79313b3eaad --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/library/src/conv2d_operation.h @@ -0,0 +1,650 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines operations for all CONV operation kinds in CUTLASS Library. 
+*/ + +#pragma once +#include +#include "cutlass/cutlass.h" +#include "cutlass/conv/kernel/default_conv2d_fprop.h" +#include "cutlass/conv/kernel/default_conv2d_group_fprop.h" +#include "cutlass/conv/kernel/default_depthwise_fprop.h" +#include "cutlass/conv/kernel/default_conv2d_dgrad.h" +#include "cutlass/conv/kernel/default_conv2d_wgrad.h" +#include "cutlass/conv/device/implicit_gemm_convolution.h" +#include "cutlass/conv/device/direct_convolution.h" + +#include "cutlass/library/library.h" +#include "library_internal.h" +#include "cutlass/util/host_tensor.h" + +#include "cutlass/util/reference/host/convolution.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/core_io.h" +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class Conv2dOperationBase : public Operation { +public: + + using Operator = Operator_; + + using ElementA = typename Operator::ElementA; + using LayoutA = typename Operator::LayoutA; + using ElementB = typename Operator::ElementB; + using LayoutB = typename Operator::LayoutB; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + static cutlass::conv::IteratorAlgorithm const kIteratorAlgorithm = Operator::kIteratorAlgorithm; + static cutlass::conv::Operator const kConvolutionalOperator = Operator::kConvolutionalOperator; + + using OperatorArguments = typename Operator::Arguments; + +protected: + + /// + ConvDescription description_; + +public: + + /// Constructor + Conv2dOperationBase(char const *name = "unknown_conv2d") { + + description_.name = name; + description_.provider = Provider::kCUTLASS; + description_.kind = OperationKind::kConv2d; + description_.conv_dim = Operator::kConvDim; + + description_.iterator_algorithm = IteratorAlgorithmMap::kId; + + description_.tile_description.threadblock_shape = make_Coord( + Operator::ThreadblockShape::kM, + Operator::ThreadblockShape::kN, + Operator::ThreadblockShape::kK); + + description_.tile_description.threadblock_stages = Operator::kStages; + + description_.tile_description.warp_count = make_Coord( + Operator::UnderlyingKernel::WarpCount::kM, + Operator::UnderlyingKernel::WarpCount::kN, + Operator::UnderlyingKernel::WarpCount::kK); + + description_.tile_description.math_instruction.instruction_shape = make_Coord( + Operator::InstructionShape::kM, + Operator::InstructionShape::kN, + Operator::InstructionShape::kK); + + description_.tile_description.math_instruction.element_accumulator = + NumericTypeMap::kId; + + description_.tile_description.math_instruction.opcode_class = + OpcodeClassMap::kId; + + description_.tile_description.math_instruction.math_operation = + MathOperationMap::kId; + + description_.tile_description.minimum_compute_capability = + ArchMap::kMin; + + description_.tile_description.maximum_compute_capability = + ArchMap::kMax; + + description_.A = make_TensorDescription(); + description_.B = make_TensorDescription(); + description_.C = make_TensorDescription(); + description_.element_epilogue = NumericTypeMap::kId; + + // TODO: Add split k mode Serial and parallel to convolutions + // description_.split_k_mode = Operator::kSplitK ? 
SplitKMode::kSerial : SplitKMode::kNone; + + } + + /// Returns the description of the GEMM operation + virtual OperationDescription const & description() const { + return description_; + } +}; + + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Conv2d library operation class for cutlass profiler +// +/////////////////////////////////////////////////////////////////////////////////////////////////// +template +class Conv2dOperation : public Conv2dOperationBase { +public: + + using Operator = Operator_; + + using ElementA = typename Operator::ElementA; + using LayoutA = typename Operator::LayoutA; + using ElementB = typename Operator::ElementB; + using LayoutB = typename Operator::LayoutB; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + static cutlass::conv::Operator const kConvolutionalOperator = Operator::kConvolutionalOperator; + + using OperatorArguments = typename Operator::Arguments; + +public: + /// Constructor + Conv2dOperation(char const *name = "unknown_conv2d_fprop") : Conv2dOperationBase(name) { + this->description_.conv_kind = ConvKindMap::kId; + } + +protected: + + /// Constructs the arguments structure given the configuration and arguments + static Status construct_arguments_( + OperatorArguments &operator_args, + Conv2dConfiguration const *configuration) { + + + operator_args.problem_size = configuration->problem_size; + + operator_args.ref_A = + { + nullptr, + LayoutA::packed(implicit_gemm_tensor_a_extent(kConvolutionalOperator, configuration->problem_size)) + }; + + operator_args.ref_B = + { + nullptr, + LayoutB::packed(implicit_gemm_tensor_b_extent(kConvolutionalOperator, configuration->problem_size)) + }; + + operator_args.ref_C = + { + nullptr, + LayoutC::packed(implicit_gemm_tensor_c_extent(kConvolutionalOperator, configuration->problem_size)) + }; + + operator_args.ref_D = + { + nullptr, + LayoutC::packed(implicit_gemm_tensor_c_extent(kConvolutionalOperator, configuration->problem_size)) + }; + + operator_args.split_k_mode = configuration->split_k_mode; + + return Status::kSuccess; + } + + /// Constructs the arguments structure given the configuration and arguments + static Status update_arguments_( + OperatorArguments &operator_args, + ConvArguments const *arguments) { + + if (arguments->pointer_mode == ScalarPointerMode::kHost) { + typename Operator::EpilogueOutputOp::Params params( + *static_cast(arguments->alpha), + *static_cast(arguments->beta) + ); + operator_args.output_op = params; + } + else if (arguments->pointer_mode == ScalarPointerMode::kDevice){ + typename Operator::EpilogueOutputOp::Params params( + static_cast(arguments->alpha), + static_cast(arguments->beta) + ); + operator_args.output_op = params; + } + else { + return Status::kErrorInvalidProblem; + } + + operator_args.ref_A.reset(static_cast(const_cast(arguments->A))); + operator_args.ref_B.reset(static_cast(const_cast(arguments->B))); + operator_args.ref_C.reset(static_cast(const_cast(arguments->C))); + operator_args.ref_D.reset(static_cast(const_cast(arguments->D))); + + if (arguments->use_pdl) { + return Status::kErrorNotSupported; + } + + return Status::kSuccess; + } + +public: + + /// Returns success if the operation can proceed + virtual Status can_implement( + void const *configuration_ptr, + void const *arguments_ptr) const { + + 
Conv2dConfiguration const *configuration = + static_cast(configuration_ptr); + + ConvArguments const *arguments = + static_cast(arguments_ptr); + + OperatorArguments args; + + Status status = construct_arguments_(args, configuration); + + if (status != Status::kSuccess) { + return status; + } + + status = update_arguments_(args, arguments); + + if (status != Status::kSuccess) { + return status; + } + + return Operator::can_implement(args); + + } + + /// Gets the host-side workspace + virtual uint64_t get_host_workspace_size( + void const *configuration) const { + + return sizeof(Operator); + } + + /// Gets the device-side workspace + virtual uint64_t get_device_workspace_size( + void const *configuration_ptr, + void const *arguments_ptr = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return 0; + } + + return Operator::get_workspace_size(args); + } + + /// Initializes the workspace + virtual Status initialize( + void const *configuration_ptr, + void *host_workspace, + void *device_workspace, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = new (host_workspace) Operator; + //std::cout << "initialize library::Conv2dOperation" << std::endl; + //print_operator_args(args); + return op->initialize(args, device_workspace, stream); + + } + + /// Runs the kernel + virtual Status run( + void const *arguments_ptr, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = update_arguments_( + args, + static_cast(arguments_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = static_cast(host_workspace); + + status = op->update(args, device_workspace); + + if (status != Status::kSuccess) { + return status; + } + //std::cout << "run library::Conv2dOperation" << std::endl; + //print_operator_args(args); + return op->run(stream); + } + + /// Call print_operator_args from the Conv2dOperation::initialize() + // to dump arguments passed on to cutlass operator for debugging + void print_operator_args(OperatorArguments &operator_args) const { + std::cout << "Conv2dOperation::OperatorArguments" << std::endl + << " problem_size:" << std::endl + << operator_args.problem_size << std::endl + << " split_k_mode: " + << (operator_args.split_k_mode == cutlass::conv::SplitKMode::kSerial ? 
"serial" : "parallel") << std::endl + << " epilogue (alpha, beta): " + << operator_args.output_op.alpha << ", " + << operator_args.output_op.beta << std::endl + << " ref_A (ptr, {stride}): " + << operator_args.ref_A.data() << ", {" + << operator_args.ref_A.stride(0) << ", " + << operator_args.ref_A.stride(1) << ", " + << operator_args.ref_A.stride(2) << "}" << std::endl + << " ref_B (ptr, {stride}): " + << operator_args.ref_B.data() << ", {" + << operator_args.ref_B.stride(0) << ", " + << operator_args.ref_B.stride(1) << ", " + << operator_args.ref_B.stride(2) << "}" << std::endl + << " ref_C (ptr, {stride}): " + << operator_args.ref_C.data() << ", {" + << operator_args.ref_C.stride(0) << ", " + << operator_args.ref_C.stride(1) << ", " + << operator_args.ref_C.stride(2) << "}" << std::endl + << " ref_D (ptr, {stride}): " + << operator_args.ref_D.data() << ", {" + << operator_args.ref_D.stride(0) << ", " + << operator_args.ref_D.stride(1) << ", " + << operator_args.ref_D.stride(2) << "}" << std::endl; + } +}; + + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// +// DirectConv2d library operation class for cutlass profiler +// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class DirectConv2dOperation : public Conv2dOperation { +public: + + using Operator = Operator_; + using Base = Conv2dOperation; + + using ElementA = typename Operator::ElementA; + using LayoutA = typename Operator::LayoutA; + using ElementB = typename Operator::ElementB; + using LayoutB = typename Operator::LayoutB; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + static cutlass::conv::Operator const kConvolutionalOperator = Operator::kConvolutionalOperator; + + using OperatorArguments = typename Operator::Arguments; + +public: + /// Constructor + DirectConv2dOperation(char const *name = "unknown_direct)conv2d_fprop") : Conv2dOperation(name) { + this->description_.conv_kind = ConvKindMap::kId; + } + +protected: + + /// Constructs the arguments structure given the configuration and arguments + static Status construct_arguments_( + OperatorArguments &operator_args, + Conv2dConfiguration const *configuration) { + + + operator_args.problem_size = configuration->problem_size; + + operator_args.ref_A = + { + nullptr, + LayoutA::packed(implicit_gemm_tensor_a_extent(kConvolutionalOperator, configuration->problem_size)) + }; + + operator_args.ref_B = + { + nullptr, + LayoutB::packed(implicit_gemm_tensor_b_extent(kConvolutionalOperator, configuration->problem_size)) + }; + + operator_args.ref_reordered_B = + { + nullptr, + LayoutB::packed(implicit_gemm_tensor_b_extent(kConvolutionalOperator, configuration->problem_size)) + }; + + operator_args.ref_C = + { + nullptr, + LayoutC::packed(implicit_gemm_tensor_c_extent(kConvolutionalOperator, configuration->problem_size)) + }; + + operator_args.ref_D = + { + nullptr, + LayoutC::packed(implicit_gemm_tensor_c_extent(kConvolutionalOperator, configuration->problem_size)) + }; + + operator_args.split_k_mode = configuration->split_k_mode; + + return Status::kSuccess; + } + + /// Constructs the arguments structure given the configuration and arguments + static Status update_arguments_( + OperatorArguments &operator_args, + ConvArguments const *arguments) { + + if 
(arguments->pointer_mode == ScalarPointerMode::kHost) { + typename Operator::EpilogueOutputOp::Params params( + *static_cast(arguments->alpha), + *static_cast(arguments->beta) + ); + operator_args.output_op = params; + } + else if (arguments->pointer_mode == ScalarPointerMode::kDevice){ + typename Operator::EpilogueOutputOp::Params params( + static_cast(arguments->alpha), + static_cast(arguments->beta) + ); + operator_args.output_op = params; + } + else { + return Status::kErrorInvalidProblem; + } + + operator_args.ref_A.reset(static_cast(const_cast(arguments->A))); + operator_args.ref_B.reset(static_cast(const_cast(arguments->B))); + operator_args.ref_C.reset(static_cast(const_cast(arguments->C))); + operator_args.ref_D.reset(static_cast(const_cast(arguments->D))); + operator_args.ref_reordered_B.reset(static_cast(const_cast(arguments->reordered_B))); + + if (arguments->use_pdl) { + return Status::kErrorNotSupported; + } + + return Status::kSuccess; + } + +public: + + /// Returns success if the operation can proceed + virtual Status can_implement( + void const *configuration_ptr, + void const *arguments_ptr) const { + + Conv2dConfiguration const *configuration = + static_cast(configuration_ptr); + + ConvArguments const *arguments = + static_cast(arguments_ptr); + + OperatorArguments args; + + Status status = construct_arguments_(args, configuration); + + if (status != Status::kSuccess) { + return status; + } + + status = update_arguments_(args, arguments); + + if (status != Status::kSuccess) { + return status; + } + + return Operator::can_implement(args); + + } + + /// Gets the host-side workspace + virtual uint64_t get_host_workspace_size( + void const *configuration) const { + + return sizeof(Operator); + } + + /// Gets the device-side workspace + virtual uint64_t get_device_workspace_size( + void const *configuration_ptr, + void const *arguments_ptr = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return 0; + } + + return Operator::get_workspace_size(args); + } + + /// Initializes the workspace + virtual Status initialize( + void const *configuration_ptr, + void *host_workspace, + void *device_workspace, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = new (host_workspace) Operator; + //std::cout << "initialize library::Conv2dOperation" << std::endl; + //print_operator_args(args); + return op->initialize(args, device_workspace, stream); + + } + + /// Runs the kernel + virtual Status run( + void const *arguments_ptr, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = update_arguments_( + args, + static_cast(arguments_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = static_cast(host_workspace); + + status = op->update(args, device_workspace); + + if (status != Status::kSuccess) { + return status; + } + //std::cout << "run library::Conv2dOperation" << std::endl; + //print_operator_args(args); + return op->run(stream); + } + + /// Call print_operator_args from the Conv2dOperation::initialize() + // to dump arguments passed on to cutlass operator for debugging + void print_operator_args(OperatorArguments &operator_args) const { + std::cout << 
"Conv2dOperation::OperatorArguments" << std::endl + << " problem_size:" << std::endl + << operator_args.problem_size << std::endl + << " split_k_mode: " + << (operator_args.split_k_mode == cutlass::conv::SplitKMode::kSerial ? "serial" : "parallel") << std::endl + << " epilogue (alpha, beta): " + << operator_args.output_op.alpha << ", " + << operator_args.output_op.beta << std::endl + << " ref_A (ptr, {stride}): " + << operator_args.ref_A.data() << ", {" + << operator_args.ref_A.stride(0) << ", " + << operator_args.ref_A.stride(1) << ", " + << operator_args.ref_A.stride(2) << "}" << std::endl + << " ref_B (ptr, {stride}): " + << operator_args.ref_B.data() << ", {" + << operator_args.ref_B.stride(0) << ", " + << operator_args.ref_B.stride(1) << ", " + << operator_args.ref_B.stride(2) << "}" << std::endl + << " ref_C (ptr, {stride}): " + << operator_args.ref_C.data() << ", {" + << operator_args.ref_C.stride(0) << ", " + << operator_args.ref_C.stride(1) << ", " + << operator_args.ref_C.stride(2) << "}" << std::endl + << " ref_D (ptr, {stride}): " + << operator_args.ref_D.data() << ", {" + << operator_args.ref_D.stride(0) << ", " + << operator_args.ref_D.stride(1) << ", " + << operator_args.ref_D.stride(2) << "}" << std::endl; + } +}; + +} // namespace library +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/library/src/conv3d_operation.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/library/src/conv3d_operation.h new file mode 100644 index 0000000000000000000000000000000000000000..fe402c4494c27a882bf42f867a708e954ee87dc0 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/library/src/conv3d_operation.h @@ -0,0 +1,389 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines operations for all CONV operation kinds in CUTLASS Library. +*/ + +#pragma once +#include +#include "cutlass/cutlass.h" +#include "cutlass/conv/kernel/default_conv3d_fprop.h" +#include "cutlass/conv/kernel/default_conv3d_dgrad.h" +#include "cutlass/conv/kernel/default_conv3d_wgrad.h" +#include "cutlass/conv/device/implicit_gemm_convolution.h" + +#include "cutlass/library/library.h" +#include "library_internal.h" +#include "cutlass/util/host_tensor.h" + +#include "cutlass/util/reference/host/convolution.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/core_io.h" +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class Conv3dOperationBase : public Operation { +public: + + using Operator = Operator_; + + using ElementA = typename Operator::ElementA; + using LayoutA = typename Operator::LayoutA; + using ElementB = typename Operator::ElementB; + using LayoutB = typename Operator::LayoutB; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + static cutlass::conv::IteratorAlgorithm const kIteratorAlgorithm = Operator::kIteratorAlgorithm; + static cutlass::conv::Operator const kConvolutionalOperator = Operator::kConvolutionalOperator; + + using OperatorArguments = typename Operator::Arguments; + +protected: + + /// + ConvDescription description_; + +public: + + /// Constructor + Conv3dOperationBase(char const *name = "unknown_conv3d") { + + description_.name = name; + description_.provider = Provider::kCUTLASS; + description_.kind = OperationKind::kConv3d; + description_.conv_dim = Operator::kConvDim; + + description_.iterator_algorithm = IteratorAlgorithmMap::kId; + + description_.tile_description.threadblock_shape = make_Coord( + Operator::ThreadblockShape::kM, + Operator::ThreadblockShape::kN, + Operator::ThreadblockShape::kK); + + description_.tile_description.threadblock_stages = Operator::kStages; + + description_.tile_description.warp_count = make_Coord( + Operator::UnderlyingKernel::WarpCount::kM, + Operator::UnderlyingKernel::WarpCount::kN, + Operator::UnderlyingKernel::WarpCount::kK); + + description_.tile_description.math_instruction.instruction_shape = make_Coord( + Operator::InstructionShape::kM, + Operator::InstructionShape::kN, + Operator::InstructionShape::kK); + + description_.tile_description.math_instruction.element_accumulator = + NumericTypeMap::kId; + + description_.tile_description.math_instruction.opcode_class = + OpcodeClassMap::kId; + + description_.tile_description.minimum_compute_capability 
= + ArchMap::kMin; + + description_.tile_description.maximum_compute_capability = + ArchMap::kMax; + + description_.A = make_TensorDescription(); + description_.B = make_TensorDescription(); + description_.C = make_TensorDescription(); + description_.element_epilogue = NumericTypeMap::kId; + + } + + /// Returns the description of the GEMM operation + virtual OperationDescription const & description() const { + return description_; + } +}; + + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Conv2d library operation class for cutlass profiler +// +/////////////////////////////////////////////////////////////////////////////////////////////////// +template +class Conv3dOperation : public Conv3dOperationBase { +public: + + using Operator = Operator_; + + using ElementA = typename Operator::ElementA; + using LayoutA = typename Operator::LayoutA; + using ElementB = typename Operator::ElementB; + using LayoutB = typename Operator::LayoutB; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + static cutlass::conv::Operator const kConvolutionalOperator = Operator::kConvolutionalOperator; + + using OperatorArguments = typename Operator::Arguments; + +public: + /// Constructor + Conv3dOperation(char const *name = "unknown_conv3d_fprop") : Conv3dOperationBase(name) { + this->description_.conv_kind = ConvKindMap::kId; + } + +protected: + + /// Constructs the arguments structure given the configuration and arguments + static Status construct_arguments_( + OperatorArguments &operator_args, + Conv3dConfiguration const *configuration) { + + + operator_args.problem_size = configuration->problem_size; + + operator_args.ref_A = + { + nullptr, + LayoutA::packed(implicit_gemm_tensor_a_extent(kConvolutionalOperator, configuration->problem_size)) + }; + + operator_args.ref_B = + { + nullptr, + LayoutB::packed(implicit_gemm_tensor_b_extent(kConvolutionalOperator, configuration->problem_size)) + }; + + operator_args.ref_C = + { + nullptr, + LayoutC::packed(implicit_gemm_tensor_c_extent(kConvolutionalOperator, configuration->problem_size)) + }; + + operator_args.ref_D = + { + nullptr, + LayoutC::packed(implicit_gemm_tensor_c_extent(kConvolutionalOperator, configuration->problem_size)) + }; + + operator_args.split_k_mode = configuration->split_k_mode; + + return Status::kSuccess; + } + + /// Constructs the arguments structure given the configuration and arguments + static Status update_arguments_( + OperatorArguments &operator_args, + ConvArguments const *arguments) { + + if (arguments->pointer_mode == ScalarPointerMode::kHost) { + typename Operator::EpilogueOutputOp::Params params( + *static_cast(arguments->alpha), + *static_cast(arguments->beta) + ); + operator_args.output_op = params; + } + else if (arguments->pointer_mode == ScalarPointerMode::kDevice){ + typename Operator::EpilogueOutputOp::Params params( + static_cast(arguments->alpha), + static_cast(arguments->beta) + ); + operator_args.output_op = params; + } + else { + return Status::kErrorInvalidProblem; + } + + operator_args.ref_A.reset(static_cast(const_cast(arguments->A))); + operator_args.ref_B.reset(static_cast(const_cast(arguments->B))); + operator_args.ref_C.reset(static_cast(const_cast(arguments->C))); + operator_args.ref_D.reset(static_cast(const_cast(arguments->D))); + + if (arguments->use_pdl) { + 
return Status::kErrorNotSupported; + } + + return Status::kSuccess; + } + +public: + + /// Returns success if the operation can proceed + virtual Status can_implement( + void const *configuration_ptr, + void const *arguments_ptr) const { + + Conv3dConfiguration const *configuration = + static_cast(configuration_ptr); + + ConvArguments const *arguments = + static_cast(arguments_ptr); + + OperatorArguments args; + + Status status = construct_arguments_(args, configuration); + + if (status != Status::kSuccess) { + return status; + } + + status = update_arguments_(args, arguments); + + if (status != Status::kSuccess) { + return status; + } + + return Operator::can_implement(args); + + } + + /// Gets the host-side workspace + virtual uint64_t get_host_workspace_size( + void const *configuration) const { + + return sizeof(Operator); + } + + /// Gets the device-side workspace + virtual uint64_t get_device_workspace_size( + void const *configuration_ptr, + void const *arguments_ptr = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return 0; + } + + return Operator::get_workspace_size(args); + } + + /// Initializes the workspace + virtual Status initialize( + void const *configuration_ptr, + void *host_workspace, + void *device_workspace, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = new (host_workspace) Operator; + //std::cout << "initialize library::Conv3dOperation" << std::endl; + //print_operator_args(args); + return op->initialize(args, device_workspace, stream); + + } + + /// Runs the kernel + virtual Status run( + void const *arguments_ptr, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = update_arguments_( + args, + static_cast(arguments_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = static_cast(host_workspace); + + status = op->update(args, device_workspace); + + if (status != Status::kSuccess) { + return status; + } + //std::cout << "run library::Conv3dOperation" << std::endl; + //print_operator_args(args); + return op->run(stream); + } + + /// Call print_operator_args from the Conv3dOperation::initialize() + // to dump arguments passed on to cutlass operator for debugging + void print_operator_args(OperatorArguments &operator_args) const { + std::cout << "Conv3dOperation::OperatorArguments" << std::endl + << " problem_size: " + << operator_args.problem_size << std::endl + << " split_k_mode: " + << (operator_args.split_k_mode == cutlass::conv::SplitKMode::kSerial ? 
"serial" : "parallel") << std::endl + << " epilogue (alpha, beta): " + << operator_args.output_op.alpha << ", " + << operator_args.output_op.beta << std::endl + << " ref_A (ptr, {stride}): " + << operator_args.ref_A.data() << ", {" + << operator_args.ref_A.stride(0) << ", " + << operator_args.ref_A.stride(1) << ", " + << operator_args.ref_A.stride(2) << ", " + << operator_args.ref_A.stride(3) << "}" << std::endl + << " ref_B (ptr, {stride}): " + << operator_args.ref_B.data() << ", {" + << operator_args.ref_B.stride(0) << ", " + << operator_args.ref_B.stride(1) << ", " + << operator_args.ref_B.stride(2) << ", " + << operator_args.ref_B.stride(3) << "}" << std::endl + << " ref_C (ptr, {stride}): " + << operator_args.ref_C.data() << ", {" + << operator_args.ref_C.stride(0) << ", " + << operator_args.ref_C.stride(1) << ", " + << operator_args.ref_C.stride(2) << ", " + << operator_args.ref_C.stride(3) << "}" << std::endl + << " ref_D (ptr, {stride}): " + << operator_args.ref_D.data() << ", {" + << operator_args.ref_D.stride(0) << ", " + << operator_args.ref_D.stride(1) << ", " + << operator_args.ref_D.stride(2) << ", " + << operator_args.ref_D.stride(3) << "}" << std::endl; + } +}; + +} // namespace library +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/library/src/conv_operation_3x.hpp b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/library/src/conv_operation_3x.hpp new file mode 100644 index 0000000000000000000000000000000000000000..86c1513e9c934c22e281cf37e1c5e7783e23d305 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/library/src/conv_operation_3x.hpp @@ -0,0 +1,980 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines operations for all CONV operation kinds in CUTLASS Library. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/library/library.h" +#include "library_internal.h" +#include "cutlass/conv/convnd_problem_shape.hpp" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/detail/dependent_false.hpp" +#include "cutlass/trace.h" +#include +#include +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) +#include +#endif + +namespace cutlass::library { + +namespace detail { + +template +constexpr cute::array +vector_to_array_strides_helper(const std::vector& v, + std::index_sequence) +{ + return {v[(sizeof...(Indices) - 1u) - Indices]..., ValueType(1)}; +} + +template +cute::array +vector_to_array_strides(const std::vector& v, std::integral_constant) +{ + static_assert(Size != 0); + CUTLASS_ASSERT(v.size() + 1u == Size); + return vector_to_array_strides_helper(v, std::make_index_sequence{}); +} + +template +constexpr cute::array +coord_to_array_strides_helper( + const ::cutlass::Coord coord, + std::index_sequence) +{ + return {int64_t(coord[(sizeof...(Indices) - 1u) - Indices])..., int64_t(1)}; +} + +template +cute::array +coord_to_array_strides(const ::cutlass::Coord& coord) +{ + static_assert(Rank >= 0); + return coord_to_array_strides_helper(coord, std::make_index_sequence{}); +} + +} // namespace detail + +// Tells the profiler about CUTLASS 3's 2-D and 3-D convolutions. +// For CUTLASS 2's 2-D convolutions, see Conv2dOperation. +// For CUTLASS 2's 3-D convolutions, see Conv3dOperation. 
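+// ConvOperation3x adapts the profiler-facing Conv2dConfiguration /
+// Conv3dConfiguration plus ConvArguments into the ConvProblemShape-based
+// Arguments that the CUTLASS 3 kernel expects; see the
+// update_operator_arguments_from_configuration overloads below.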
+template +class ConvOperation3x : public Operation { +public: + using Operator = Operator_; + + static_assert(Operator::NumSpatialDimensions == 2 || + Operator::NumSpatialDimensions == 3, + "The profiler currently only supports convolutions with 2 or 3 spatial dimensions."); + using LayoutA = cute::conditional_t + >; + using LayoutB = LayoutA; + using LayoutC = LayoutA; + + using ElementA = typename Operator::ElementA; + using ElementB = typename Operator::ElementB; + using ElementC = typename Operator::ElementC; + using ElementD = typename Operator::ElementD; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + static cutlass::conv::Operator const kConvolutionalOperator = Operator::kConvolutionalOperator; + + ConvOperation3x(const char* name = "unknown_cutlass_3_conv") { + // Initialize OperationDescription (the base class) + description_.name = name; + description_.provider = Provider::kCUTLASS; + + if constexpr (Operator::NumSpatialDimensions == 2) { + description_.kind = OperationKind::kConv2d; + } + else if constexpr (Operator::NumSpatialDimensions == 3) { + description_.kind = OperationKind::kConv3d; + } + else { + static_assert(::cutlass::detail::dependent_false, + "This class currently only supports 2-D and 3-D convolutions."); + } + + description_.tile_description.threadblock_shape = make_Coord( + Operator::ThreadblockShape::kM, + Operator::ThreadblockShape::kN, + Operator::ThreadblockShape::kK); + + description_.tile_description.threadblock_stages = Operator::kStages; + + description_.tile_description.warp_count = make_Coord( + Operator::WarpCount::kM, + Operator::WarpCount::kN, + Operator::WarpCount::kK); + + description_.tile_description.math_instruction.instruction_shape = make_Coord( + Operator::InstructionShape::kM, + Operator::InstructionShape::kN, + Operator::InstructionShape::kK); + + description_.tile_description.math_instruction.element_accumulator = + NumericTypeMap::kId; + + description_.tile_description.math_instruction.opcode_class = + OpcodeClassMap::kId; + + description_.tile_description.math_instruction.math_operation = + MathOperationID::kMultiplyAdd; + + description_.tile_description.minimum_compute_capability = + ArchMap::kMin; + + description_.tile_description.maximum_compute_capability = + ArchMap::kMax; + + // Initialize ConvDescription (the subclass) + + // kConvDim does not exist in Operator for CUTLASS 3 convolutions. + // For CUTLASS 2 convolutions, it is the number of spatial dimensions. + description_.conv_dim = Operator::NumSpatialDimensions; + description_.conv_kind = ConvKindMap::kId; + + description_.iterator_algorithm = {}; + + description_.A = make_TensorDescription(); + description_.B = make_TensorDescription(); + description_.C = make_TensorDescription(); + description_.element_epilogue = NumericTypeMap::kId; + } + + ~ConvOperation3x() override = default; + + OperationDescription const& description() const override { + return static_cast(description_); + } + +private: + Status update_operator_arguments_from_configuration_2d_or_3d( + typename Operator::Arguments& out_args, + void const* configuration) const { + Status status = Status::kInvalid; + + CUTLASS_ASSERT(configuration != nullptr); + + if constexpr (Operator::NumSpatialDimensions == 2) { + CUTLASS_ASSERT(description_.kind == OperationKind::kConv2d); + // tools/library/include/cutlass/library/library.h + // defines Conv2dConfiguration. 
+ // tools/profiler/include/cutlass/profiler/conv2d_operation_profiler.h + // uses Conv2dConfiguration. + auto* conf_ptr = reinterpret_cast(configuration); + status = update_operator_arguments_from_configuration(out_args, *conf_ptr); + } + else if constexpr (Operator::NumSpatialDimensions == 3) { + CUTLASS_ASSERT(description_.kind == OperationKind::kConv3d); + auto* conf_ptr = reinterpret_cast(configuration); + status = update_operator_arguments_from_configuration(out_args, *conf_ptr); + } + else { + static_assert(::cutlass::detail::dependent_false, + "This class currently only supports 2-D and 3-D convolutions."); + } + + return status; + } + +public: + Status can_implement( + void const* configuration, + void const* arguments) const override { + Status status = Status::kInvalid; + + // gemm_operation_3x.hpp accesses "configuration" as + // GemmUniversalConfiguration (which lives in + // tools/library/include/cutlass/library/library.h) and + // "arguments" as GemmUniversalArguments (which lives in + // tools/library/include/cutlass/library/library.h). + // Those things don't apply to convolutions. + // Despite the existence of ConvUniversal, there's no + // corresponding "ConvUniversalConfiguration" or + // "ConvUniversalArguments." + + CUTLASS_ASSERT(configuration != nullptr); + CUTLASS_ASSERT(arguments != nullptr); + + typename Operator::Arguments out_args{}; + status = update_operator_arguments_from_configuration_2d_or_3d(out_args, configuration); + if (status != Status::kSuccess) { + CUTLASS_TRACE_HOST("*** can_implement: update_operator_arguments_from_configuration_2d_or_3d failed"); + return status; + } + + auto* in_args_ptr = reinterpret_cast(arguments); + status = update_operator_arguments_from_arguments(out_args, *in_args_ptr); + if (status != Status::kSuccess) { + CUTLASS_TRACE_HOST("*** can_implement: update_operator_arguments_from_arguments failed"); + return status; + } + + return Operator::can_implement(out_args); + } + + uint64_t get_host_workspace_size(void const* /* configuration */) const override { + return sizeof(Operator); + } + + uint64_t get_device_workspace_size( + void const* configuration, + void const* arguments = nullptr) const override + { + // This presumes that at least one of configuration or arguments is nonnull. + Status status = Status::kInvalid; + + // gemm_operation_3x.hpp has get_device_workspace_size return 0 on + // error. It's not clear that this is what we want -- perhaps we + // should return something like expected? -- but + // it's the only option that preserves the current interface. 
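+    // As a consequence, a return value of 0 is ambiguous: it is produced both
+    // when the kernel genuinely needs no device workspace and when the
+    // configuration or arguments could not be translated.  Callers that need
+    // to tell the two cases apart can call can_implement() first.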
+ constexpr uint64_t error_indication = 0; + + typename Operator::Arguments out_args{}; + if (configuration != nullptr) { + status = update_operator_arguments_from_configuration_2d_or_3d(out_args, configuration); + if (status != Status::kSuccess) { + return error_indication; + } + } + if (arguments != nullptr) { + auto* in_args_ptr = reinterpret_cast(arguments); + status = update_operator_arguments_from_arguments(out_args, *in_args_ptr); + if (status != Status::kSuccess) { + return error_indication; + } + } + + if (status == Status::kSuccess) { + return static_cast(Operator::get_workspace_size(out_args)); + } + else { + return error_indication; + } + } + + Status initialize( + void const* configuration, + void* host_workspace, + void* /* device_workspace */ = nullptr, + cudaStream_t stream = nullptr) const override + { + Status status = Status::kInvalid; + + if (configuration == nullptr) { + CUTLASS_TRACE_HOST("Input configuration is null."); + return Status::kInvalid; + } + + typename Operator::Arguments out_args{}; + status = update_operator_arguments_from_configuration_2d_or_3d(out_args, configuration); + if (status != Status::kSuccess) { + // Any kind of failure invalidates the last successful configuration. + clear_last_successful_config(); + return status; + } + else { + set_last_successful_config(configuration); + } + + if (host_workspace == nullptr) { + CUTLASS_TRACE_HOST("host_workspace is null."); + return Status::kInvalid; + } + (void) new (host_workspace) Operator; + return status; + + // CUTLASS 2 convolutions call the Operator's initialize function + // here, like this. + // + //return op->initialize(args, device_workspace, stream); + // + // CUTLASS 3 convolutions (ConvUniversal), like CUTLASS 3 Gemms + // (GemmUniversal), lack an "initialize" member function. + } + + Status run( + void const* arguments, + void* host_workspace, + void* device_workspace = nullptr, + cudaStream_t stream = nullptr) const override + { + auto status = Status::kInvalid; + + // The Operator doesn't appear to save the last configuration (it + // doesn't have a way to do that, since it lacks an initialize() + // member function), so we have to use the stored configuration + // from the last successful initialize() call (if any). + typename Operator::Arguments out_args{}; + status = update_operator_arguments_from_stored_configuration(out_args); + if (status != Status::kSuccess) { + CUTLASS_TRACE_HOST("Updating from previous successful configuration failed."); + return status; + } + + if (arguments == nullptr) { + CUTLASS_TRACE_HOST("Input argument 'arguments' is null."); + return Status::kInvalid; + } + auto* in_args_ptr = reinterpret_cast(arguments); + status = update_operator_arguments_from_arguments(out_args, *in_args_ptr); + if (status != Status::kSuccess) { + return status; + } + + auto* op = reinterpret_cast(host_workspace); + return op->run(out_args, device_workspace, stream, nullptr, in_args_ptr->use_pdl); + } + +private: + ConvDescription description_; + // Result of initialize() calling + // update_operator_arguments_from_configuration() successfully. + // This is needed because run() doesn't take a configuration, just + // arguments, and the kernel doesn't appear to save the + // configuration from the last initialize() call. + // + // Unfortunately, this must be declared mutable, because it must be + // set in initialize(), and initialize() is inherited as const. 
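+  // std::monostate means that no successful initialize() call has happened
+  // yet; otherwise the variant holds a copy of whichever 2-D or 3-D
+  // configuration matches Operator::NumSpatialDimensions.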
+ mutable std::variant< + std::monostate, + Conv2dConfiguration, + Conv3dConfiguration> last_successful_config_{std::monostate{}}; + + // Clear the last configuration resulting from a successful initialize() call. + // + // Unfortunately, this must be declared const, because initialize() is. + void clear_last_successful_config() const { + last_successful_config_ = std::monostate{}; + } + + // Set the last configuration resulting from a successful initialize() call. + // + // Unfortunately, this must be declared const, because initialize() is. + void set_last_successful_config(void const* configuration) const { + CUTLASS_ASSERT(configuration != nullptr); + + if constexpr (Operator::NumSpatialDimensions == 2) { + CUTLASS_ASSERT(description_.kind == OperationKind::kConv2d); + auto* conf_ptr = reinterpret_cast(configuration); + last_successful_config_ = *conf_ptr; + } else if constexpr (Operator::NumSpatialDimensions == 3) { + CUTLASS_ASSERT(description_.kind == OperationKind::kConv3d); + auto* conf_ptr = reinterpret_cast(configuration); + last_successful_config_ = *conf_ptr; + } + else { + static_assert(::cutlass::detail::dependent_false, + "This class currently only supports 2-D and 3-D convolutions."); + } + } + + // Whether a configuration from a successful initialize() call exists. + bool last_successful_config_exists() const { + return not std::holds_alternative(last_successful_config_); + } + + // Visitor for update_operator_arguments_from_stored_configuration. + struct ConfigurationVisitor { + typename Operator::Arguments& out_args; + + Status operator() (std::monostate const&) const { + CUTLASS_TRACE_HOST("No successful previous configuration exists. " + "One cause is calling run() before a successful initialize() call."); + return Status::kInvalid; + } + Status operator() (Conv2dConfiguration const& conf2d) const { + return update_operator_arguments_from_configuration(out_args, conf2d); + } + Status operator() (Conv3dConfiguration const& conf3d) const { + return update_operator_arguments_from_configuration(out_args, conf3d); + } + }; + + // Like update_operator_arguments_from_configuration, but on the + // stored configuration from the last successful initialize() call, + // if any. If there was no last successful initialize() call, + // then return Status::kInvalid. + // + // Unfortunately, this must be declared const, because run() is. + Status update_operator_arguments_from_stored_configuration( + typename Operator::Arguments& out_args) const + { + return std::visit(ConfigurationVisitor{out_args}, last_successful_config_); + } + + template + struct UpdateFusionArgs { + static Status update_( + FusionArgs const&, + ConvArguments const&) + { + // For custom EVT, it is the user's responsibility to ensure + // that alpha and beta are updated appropriately. 
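+      // The primary template therefore leaves the fusion arguments untouched
+      // and simply reports success; only the specialization below writes the
+      // alpha/beta scalars (or their device pointers).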
+ return Status::kSuccess; + } + }; + + template + struct UpdateFusionArgs> { + static Status update_( + FusionArgs& fusion_args, + ConvArguments const& arguments) + { + if (arguments.pointer_mode == ScalarPointerMode::kHost) { + fusion_args.alpha = *static_cast(arguments.alpha); + fusion_args.beta = *static_cast(arguments.beta); + fusion_args.alpha_ptr = nullptr; + fusion_args.beta_ptr = nullptr; + + return Status::kSuccess; + } + else if (arguments.pointer_mode == ScalarPointerMode::kDevice) { + fusion_args.alpha = 0; + fusion_args.beta = 0; + fusion_args.alpha_ptr = static_cast(arguments.alpha); + fusion_args.beta_ptr = static_cast(arguments.beta); + + return Status::kSuccess; + } + else { + return Status::kErrorInvalidProblem; + } + } + }; + + static Status update_operator_arguments_from_configuration( + typename Operator::Arguments& out_args, + Conv2dConfiguration const& config) + { +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("ConvOperator3x::" + "update_operator_arguments_from_configuration" + "(Conv2dConfiguration)\n"); +#endif + using detail::vector_to_array_strides; + + constexpr int num_spatial_dims = Operator::NumSpatialDimensions; + if constexpr (num_spatial_dims != 2) { + CUTLASS_TRACE_HOST("You can only use Conv2dConfiguration " + "with an Operator whose NumSpatialDimensions is exactly 2."); + return Status::kInvalid; + } + else { + // Convolutions split the metadata (in Conv2dConfiguration) from + // the data (ConvArguments, which only has pointers and a single + // enum value). Thus, this class will need both the + // configuration and the (user's input) arguments to set up the + // kernel's arguments. This function can fill in what the + // configuration has now, but the class will need the user's + // input arguments later. + if (config.split_k_mode != conv::SplitKMode::kSerial) { + CUTLASS_TRACE_HOST("CUTLASS 3 convolutions currently only support split_k_mode = kSerial."); + return Status::kInvalid; + } + // config.problem_size.split_k_slices is only meaningful if + // split_k_mode != kSerial. If this code later supports other + // split_k_mode values, then it will also need to read + // split_k_slices. + + const int N = config.problem_size.N; + const int H = config.problem_size.H; + const int W = config.problem_size.W; + const int C = config.problem_size.C; + const int K = config.problem_size.K; + const int R = config.problem_size.R; + const int S = config.problem_size.S; + const int pad_h = config.problem_size.pad_h; + const int pad_w = config.problem_size.pad_w; + const int traversal_stride_h = config.problem_size.stride_h; + const int traversal_stride_w = config.problem_size.stride_w; + const int dilation_h = config.problem_size.dilation_h; + const int dilation_w = config.problem_size.dilation_w; + + // CUTLASS 3's implicit GEMM convolution kernels currently only + // support cross correlation (passing over the activation and + // filter tensors in the same order). The convolution mode is + // future work. 
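+      // (Cross correlation slides the filter over the activation without
+      // flipping it; a true convolution would mirror the filter along each
+      // spatial dimension first.)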
+ const auto mode = config.problem_size.mode; + if (mode != cutlass::conv::Mode::kCrossCorrelation) { + CUTLASS_TRACE_HOST("Convolution modes other than kCrossCorrelation " + "are not currently supported."); + return Status::kInvalid; + } + + constexpr int num_spatial_dims = Operator::NumSpatialDimensions; + constexpr size_t stride_size = size_t(num_spatial_dims) + 2u; + constexpr auto the_stride_size = std::integral_constant{}; + +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) + std::cerr << " num_spatial_dims = " << num_spatial_dims << "\n" + << " stride_size = " << stride_size << "\n"; + auto print_stride = [] (auto const& stride, char const variable_name[]) { + std::cerr << " " << variable_name << ": ["; + for (size_t k = 0; k < stride.size(); ++k) { + std::cerr << stride[k]; + if (k + 1u < stride.size()) { + std::cerr << ", "; + } + } + std::cerr << "]\n"; + }; + print_stride(config.stride_a, "config.stride_a"); + print_stride(config.stride_b, "config.stride_b"); + print_stride(config.stride_c, "config.stride_c"); +#endif + + // Conv2dConfiguration stores the strides as std::vector, + // so the code needs to check the run-time vector lengths. + if (config.stride_a.size() + 1u != stride_size) { +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) + std::ostringstream os; + os << "config.stride_a.size() + 1u = " + << (config.stride_a.size() + 1u) + << " != num_spatial_dims + 2u = " << stride_size; + CUTLASS_TRACE_HOST( os.str() ); +#endif + return Status::kInvalid; + } + if (config.stride_b.size() + 1u != stride_size) { +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) + std::ostringstream os; + os << "config.stride_b.size() + 1u = " + << (config.stride_b.size() + 1u) + << " != num_spatial_dims + 2u = " << stride_size; + CUTLASS_TRACE_HOST( os.str() ); +#endif + return Status::kInvalid; + } + if (config.stride_c.size() + 1u != stride_size) { +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) + std::ostringstream os; + os << "config.stride_c.size() + 1u = " + << (config.stride_c.size() + 1u) + << " != num_spatial_dims + 2u = " << stride_size; + CUTLASS_TRACE_HOST( os.str() ); +#endif + return Status::kInvalid; + } + + constexpr cutlass::conv::Operator conv_op = Operator::DispatchPolicy::ConvOp; + using problem_shape_type = + cutlass::conv::ConvProblemShape; + // cute::array; must convert to the kernel's native strides + using TensorStride = typename problem_shape_type::TensorStride; + + const TensorStride stride_A = vector_to_array_strides(config.stride_a, the_stride_size); + const TensorStride stride_B = vector_to_array_strides(config.stride_b, the_stride_size); + const TensorStride stride_C = vector_to_array_strides(config.stride_c, the_stride_size); + + // cutlass::library::Conv2dConfiguration has no member stride_d. + // The code below imitates the testbed, + // which just sets D's strides to C's strides. + + const int num_groups = config.problem_size.groups; + if (num_groups != 1) { + CUTLASS_TRACE_HOST("CUTLASS 3 kernels currently only support groups = 1."); + return Status::kInvalid; + } + // ConvProblemShape is how CUTLASS 3 kernels represent + // convolution problems. ConvProblemShape's constructors take + // shape_act, stride_act, shape_flt, and stride_flt, and set + // shape_A, stride_A, shape_B, stride_B, shape_C, and stride_C + // according to Fprop / Dgrad / Wgrad. + // + // This means that stride_act isn't always config.stride_A, + // depending on Fprop / Dgrad / Wgrad. 
The code here "undoes" + // the logic in Conv2dWorkspace::set_stride_vector so that we + // can recover the strides of the activation and filter tensors. + // It doesn't need to worry about the so-called "output" tensor + // (which might not be C), as ConvProblemShape's constructor + // figures out its shapes and strides. + using TensorExtent = typename problem_shape_type::TensorExtent; + TensorExtent shape_act{N, H, W, C}; + auto stride_act = [&] () { + // Some compilers consider conv_op (defined above), as + // captured by this lambda, as "not a constant expression." + constexpr auto conv_kind = Operator::DispatchPolicy::ConvOp; + if constexpr (conv_kind == cutlass::conv::Operator::kFprop) { + return stride_A; + } + else if constexpr (conv_kind == cutlass::conv::Operator::kDgrad) { + return stride_C; + } + else { // conv_kind == cutlass::conv::Operator::kWgrad + return stride_B; + } + } (); + TensorExtent shape_flt{K, R, S, C}; + auto stride_flt = [&] () { + // Some compilers consider conv_op (defined above), as + // captured by this lambda, as "not a constant expression." + constexpr auto conv_kind = Operator::DispatchPolicy::ConvOp; + if constexpr (conv_kind == cutlass::conv::Operator::kFprop) { + return stride_B; + } + else if constexpr (conv_kind == cutlass::conv::Operator::kDgrad) { + return stride_B; + } + else { // conv_kind == cutlass::conv::Operator::kWgrad + return stride_C; + } + } (); + + problem_shape_type problem_shape( + /* mode = */ mode, + /* shape_act = */ shape_act, + /* stride_act = */ stride_act, + /* shape_flt = */ shape_flt, + /* stride_flt = */ stride_flt, + /* lower_padding = */ {pad_h, pad_w}, + /* upper_padding = */ {pad_h, pad_w}, + /* traversal_stride = */ {traversal_stride_h, traversal_stride_w}, + /* dilation = */ {dilation_h, dilation_w}, + num_groups); + out_args.problem_shape = problem_shape; + + // ConvProblemShape's constructor sets its shape_C member. +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) + printf("\n problem_shape.shape_C: "); + print(problem_shape.shape_C); + printf("\n problem_shape.stride_C: "); + print(problem_shape.stride_C); + printf("\n"); +#endif + // Initialization of C's and D's strides follows the CUTLASS 3 + // convolutions testbed (test/unit/conv/device_3x/testbed_conv.hpp). 
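+      // For Wgrad the "output" is the filter tensor, so dC and dD are built
+      // with make_cute_packed_stride from problem_shape's shape_C / stride_C;
+      // for Fprop and Dgrad the modes of dC and dD are copied element by
+      // element from problem_shape.stride_C via the RankT-2-i indexing below.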
+ { + using StrideC = typename Operator::ConvKernel::StrideC; + using StrideD = typename Operator::ConvKernel::StrideD; + auto stride_C = StrideC{}; + auto stride_D = StrideD{}; + + if constexpr (conv_op == cutlass::conv::Operator::kWgrad) { + stride_C = cutlass::make_cute_packed_stride( + StrideC{}, problem_shape.shape_C, problem_shape.stride_C, conv_op); + stride_D = cutlass::make_cute_packed_stride( + StrideD{}, problem_shape.shape_C, problem_shape.stride_C, conv_op); +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) + std::cerr << " Wgrad: stride_C: " << stride_C << "\n"; +#endif + } + else { + cute::for_each(cute::make_seq(StrideC{})>{}, [&](auto i) { +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) + const auto stride_C_i = problem_shape.stride_C[problem_shape_type::RankT-2-i]; + std::cerr << " Fprop or Dgrad: get<0, " << i << ">(stride_C): " + << stride_C_i << "\n"; +#endif + cute::get<0, i>(stride_C) = problem_shape.stride_C[problem_shape_type::RankT-2-i]; + }); + cute::for_each(cute::make_seq(StrideD{})>{}, [&](auto i) { +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) + const auto stride_D_i = problem_shape.stride_C[problem_shape_type::RankT-2-i]; + std::cerr << " Fprop or Dgrad: get<0, " << i << ">(stride_D): " + << stride_D_i << "\n"; +#endif + cute::get<0, i>(stride_D) = problem_shape.stride_C[problem_shape_type::RankT-2-i]; + }); + } + out_args.epilogue.dC = stride_C; + out_args.epilogue.dD = stride_D; + } + return Status::kSuccess; + } + } + + static Status update_operator_arguments_from_configuration( + typename Operator::Arguments& out_args, + Conv3dConfiguration const& config) + { +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("ConvOperator3x::" + "update_operator_arguments_from_configuration" + "(Conv3dConfiguration)\n"); +#endif + using detail::coord_to_array_strides; + + constexpr int num_spatial_dims = Operator::NumSpatialDimensions; + if constexpr (num_spatial_dims != 3) { + CUTLASS_TRACE_HOST("You can only use Conv3dConfiguration " + "with an Operator whose NumSpatialDimensions is exactly 3."); + return Status::kInvalid; + } + else { + // Convolutions split the metadata (in Conv3dConfiguration) from + // the data (ConvArguments, which only has pointers and a single + // enum value). Thus, this class will need both the + // configuration and the (user's input) arguments to set up the + // kernel's arguments. This function can fill in what the + // configuration has now, but the class will need the user's + // input arguments later. + if (config.split_k_mode != conv::SplitKMode::kSerial) { + CUTLASS_TRACE_HOST("CUTLASS 3 convolutions currently only support split_k_mode = kSerial."); + return Status::kInvalid; + } + // config.problem_size.split_k_slices is only meaningful if + // split_k_mode != kSerial. If this code later supports other + // split_k_mode values, then it will also need to read + // split_k_slices. 
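+      // Relative to the 2-D overload above, the 3-D path additionally unpacks
+      // the activation depth D, the filter depth T, and the depth padding,
+      // traversal stride, and dilation.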
+ + const int N = config.problem_size.N; + const int D = config.problem_size.D; + const int H = config.problem_size.H; + const int W = config.problem_size.W; + const int C = config.problem_size.C; + const int K = config.problem_size.K; + const int T = config.problem_size.T; + const int R = config.problem_size.R; + const int S = config.problem_size.S; + const int pad_d = config.problem_size.pad_d; + const int pad_h = config.problem_size.pad_h; + const int pad_w = config.problem_size.pad_w; + const int traversal_stride_d = config.problem_size.stride_d; + const int traversal_stride_h = config.problem_size.stride_h; + const int traversal_stride_w = config.problem_size.stride_w; + const int dilation_d = config.problem_size.dilation_d; + const int dilation_h = config.problem_size.dilation_h; + const int dilation_w = config.problem_size.dilation_w; + + // CUTLASS 3's implicit GEMM convolution kernels currently only + // support cross correlation (passing over the activation and + // filter tensors in the same order). The convolution mode is + // future work. + const auto mode = config.problem_size.mode; + if (mode != cutlass::conv::Mode::kCrossCorrelation) { + CUTLASS_TRACE_HOST("Convolution modes other than kCrossCorrelation " + "are not currently supported."); + return Status::kInvalid; + } + + using Stride = cutlass::layout::TensorNDHWC::Stride; + static_assert(std::is_same_v>); + + const cutlass::library::ConvKind conv_kind = [] () { + constexpr cutlass::conv::Operator op = Operator::DispatchPolicy::ConvOp; + if constexpr (op == cutlass::conv::Operator::kFprop) { + return library::ConvKind::kFprop; + } + else if constexpr (op == cutlass::conv::Operator::kDgrad) { + return library::ConvKind::kDgrad; + } + else /* if constexpr (op == cutlass::conv::Operator::kWgrad) */ { + return library::ConvKind::kWgrad; + } + } (); + const Stride input_stride_a = config.layout_a(conv_kind).stride(); + const Stride input_stride_b = config.layout_b(conv_kind).stride(); + const Stride input_stride_c = config.layout_c(conv_kind).stride(); + +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) + constexpr size_t stride_size = size_t(num_spatial_dims) + 2u; + std::cerr << " num_spatial_dims = " << num_spatial_dims << "\n" + << " stride_size = " << stride_size << "\n"; + auto print_stride = [] (Stride const& stride, char const variable_name[]) { + std::cerr << " " << variable_name << ": ["; + for (size_t k = 0; k < Stride::kRank; ++k) { + std::cerr << stride[static_cast(k)]; + if (k + 1u < Stride::kRank) { + std::cerr << ", "; + } + } + std::cerr << "]\n"; + }; + print_stride(input_stride_a, "input_stride_a"); + print_stride(input_stride_b, "input_stride_b"); + print_stride(input_stride_c, "input_stride_c"); +#endif + // Conv3dConfiguration stores the strides as Coord (with + // compile-time size), so there's no need to check sizes here + // (unlike Conv2dConfiguration, which stores strides as + // std::vector). 
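+      // coord_to_array_strides reverses the Coord and appends a unit stride;
+      // for example, Coord<4>{c0, c1, c2, c3} becomes {c3, c2, c1, c0, 1},
+      // which is the cute::array form used by ConvProblemShape's TensorStride.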
+ + constexpr cutlass::conv::Operator conv_op = Operator::DispatchPolicy::ConvOp; + using problem_shape_type = + cutlass::conv::ConvProblemShape; + // cute::array; must convert to the kernel's native strides + using TensorStride = typename problem_shape_type::TensorStride; + + const TensorStride stride_A = coord_to_array_strides(input_stride_a); + const TensorStride stride_B = coord_to_array_strides(input_stride_b); + const TensorStride stride_C = coord_to_array_strides(input_stride_c); + + const int num_groups = config.problem_size.groups; + if (num_groups != 1) { + CUTLASS_TRACE_HOST("CUTLASS 3 kernels currently only support groups = 1."); + return Status::kInvalid; + } + // ConvProblemShape is how CUTLASS 3 kernels represent + // convolution problems. ConvProblemShape's constructors take + // shape_act, stride_act, shape_flt, and stride_flt, and set + // shape_A, stride_A, shape_B, stride_B, shape_C, and stride_C + // according to Fprop / Dgrad / Wgrad. + // + // Conv3dConfiguration differs a bit from Conv2dConfiguration, + // but the idea is the same: the "input_stride_a" from config + // depends on conv_kind (Fprop, Dgrad, or Wgrad), so stride_act + // isn't always input_stride_a. Analogously, stride_flt isn't + // always input_stride_b. The code here "undoes" the logic in + // config.layout_a(conv_kind) and config.layout_b(conv_kind) + // (analogous to Conv2dWorkspace::set_stride_vector) so that we + // can recover the strides of the activation and filter tensors. + // It doesn't need to worry about the so-called "output" tensor + // (which might not be C), as ConvProblemShape's constructor + // figures out its shapes and strides. + using TensorExtent = typename problem_shape_type::TensorExtent; + TensorExtent shape_act{N, D, H, W, C}; + auto stride_act = [&] () { + // Some compilers consider conv_op (defined above), as + // captured by this lambda, as "not a constant expression." + constexpr auto conv_kind = Operator::DispatchPolicy::ConvOp; + if constexpr (conv_kind == cutlass::conv::Operator::kFprop) { + return stride_A; + } + else if constexpr (conv_kind == cutlass::conv::Operator::kDgrad) { + return stride_C; + } + else { // conv_kind == cutlass::conv::Operator::kWgrad + return stride_B; + } + } (); + TensorExtent shape_flt{K, T, R, S, C}; + auto stride_flt = [&] () { + // Some compilers consider conv_op (defined above), as + // captured by this lambda, as "not a constant expression." + constexpr auto conv_kind = Operator::DispatchPolicy::ConvOp; + if constexpr (conv_kind == cutlass::conv::Operator::kFprop) { + return stride_B; + } + else if constexpr (conv_kind == cutlass::conv::Operator::kDgrad) { + return stride_B; + } + else { // conv_kind == cutlass::conv::Operator::kWgrad + return stride_C; + } + } (); + + problem_shape_type problem_shape( + /* mode = */ mode, + /* shape_act = */ shape_act, + /* stride_act = */ stride_act, + /* shape_flt = */ shape_flt, + /* stride_flt = */ stride_flt, + /* lower_padding = */ {pad_d, pad_h, pad_w}, + /* upper_padding = */ {pad_d, pad_h, pad_w}, + /* traversal_stride = */ {traversal_stride_d, traversal_stride_h, traversal_stride_w}, + /* dilation = */ {dilation_d, dilation_h, dilation_w}, + num_groups); + out_args.problem_shape = problem_shape; + + // ConvProblemShape's constructor sets its shape_C member. 
+#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) + printf("\n problem_shape.shape_C: "); + print(problem_shape.shape_C); + printf("\n problem_shape.stride_C: "); + print(problem_shape.stride_C); + printf("\n"); +#endif + // Initialization of C's and D's strides follows the CUTLASS 3 + // convolutions testbed (test/unit/conv/device_3x/testbed_conv.hpp). + { + using StrideC = typename Operator::ConvKernel::StrideC; + using StrideD = typename Operator::ConvKernel::StrideD; + auto stride_C = StrideC{}; + auto stride_D = StrideD{}; + + if constexpr (conv_op == cutlass::conv::Operator::kWgrad) { + stride_C = cutlass::make_cute_packed_stride( + StrideC{}, problem_shape.shape_C, problem_shape.stride_C, conv_op); + stride_D = cutlass::make_cute_packed_stride( + StrideD{}, problem_shape.shape_C, problem_shape.stride_C, conv_op); +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) + std::cerr << " Wgrad: stride_C: " << stride_C << "\n"; +#endif + } + else { + cute::for_each(cute::make_seq(StrideC{})>{}, [&](auto i) { +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) + const auto stride_C_i = problem_shape.stride_C[problem_shape_type::RankT-2-i]; + std::cerr << " Fprop or Dgrad: get<0, " << i << ">(stride_C): " + << stride_C_i << "\n"; +#endif + cute::get<0, i>(stride_C) = problem_shape.stride_C[problem_shape_type::RankT-2-i]; + }); + cute::for_each(cute::make_seq(StrideD{})>{}, [&](auto i) { +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) + const auto stride_D_i = problem_shape.stride_C[problem_shape_type::RankT-2-i]; + std::cerr << " Fprop or Dgrad: get<0, " << i << ">(stride_D): " + << stride_D_i << "\n"; +#endif + cute::get<0, i>(stride_D) = problem_shape.stride_C[problem_shape_type::RankT-2-i]; + }); + } + out_args.epilogue.dC = stride_C; + out_args.epilogue.dD = stride_D; + } + return Status::kSuccess; + } + } + + Status update_operator_arguments_from_arguments( + typename Operator::Arguments& out_args, + ConvArguments const& in_args) const + { +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("ConvOperation3x::update_operator_arguments_from_arguments\n"); +#endif + auto status = UpdateFusionArgs::update_( + out_args.epilogue.thread, in_args); + if (status != Status::kSuccess) { + return status; + } + + out_args.mainloop.ptr_A = reinterpret_cast(in_args.A); + out_args.mainloop.ptr_B = reinterpret_cast(in_args.B); + + out_args.epilogue.ptr_C = reinterpret_cast(in_args.C); + out_args.epilogue.ptr_D = reinterpret_cast(in_args.D); + + return Status::kSuccess; + } +}; + +} // namespace cutlass::library diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/library/src/gemm_operation.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/library/src/gemm_operation.h new file mode 100644 index 0000000000000000000000000000000000000000..880cb4bf34b1f3d946e1dc86b80806309bb2b3c1 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/library/src/gemm_operation.h @@ -0,0 +1,1408 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines operations for all GEMM operation kinds in CUTLASS Library. +*/ + +#pragma once +#include "cutlass/cutlass.h" + +#include "cutlass/gemm/device/gemm.h" +#include "cutlass/gemm/device/gemm_sparse.h" +#include "cutlass/gemm/device/gemm_complex.h" +#include "cutlass/gemm/device/gemm_batched.h" +#include "cutlass/gemm/device/gemm_array.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/default_gemm_universal.h" +#include "cutlass/gemm/kernel/default_gemm_planar_complex_universal.h" + +#include "cutlass/library/library.h" +#include "library_internal.h" + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class GemmOperationBase : public Operation { +public: + using Operator = Operator_; + using ElementA = typename Operator::ElementA; + using LayoutA = typename Operator::LayoutA; + using ElementB = typename Operator::ElementB; + using LayoutB = typename Operator::LayoutB; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + using ElementD = ElementC; + using LayoutD = LayoutC; + // assuming all tensors use same type for StrideIndex + using StrideIndex = typename Operator::LayoutA::Index; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + + using OperatorArguments = typename Operator::Arguments; + +protected: + + /// + GemmDescription description_; + +public: + + /// Constructor + GemmOperationBase(char const *name = "unknown_gemm") { + + description_.name = name; + description_.provider = Provider::kCUTLASS; + description_.kind = OperationKind::kGemm; + description_.gemm_kind = GemmKind::kGemm; + + description_.tile_description.threadblock_shape = make_Coord( + Operator::ThreadblockShape::kM, + Operator::ThreadblockShape::kN, + 
Operator::ThreadblockShape::kK); + + description_.tile_description.threadblock_stages = Operator::kStages; + + description_.tile_description.warp_count = make_Coord( + Operator::GemmKernel::WarpCount::kM, + Operator::GemmKernel::WarpCount::kN, + Operator::GemmKernel::WarpCount::kK); + + description_.tile_description.math_instruction.instruction_shape = make_Coord( + Operator::InstructionShape::kM, + Operator::InstructionShape::kN, + Operator::InstructionShape::kK); + + description_.tile_description.math_instruction.element_accumulator = + NumericTypeMap::kId; + + description_.tile_description.math_instruction.opcode_class = + OpcodeClassMap::kId; + + description_.tile_description.math_instruction.math_operation = + MathOperationMap::kId; + + description_.tile_description.minimum_compute_capability = + ArchMap::kMin; + + description_.tile_description.maximum_compute_capability = + ArchMap::kMax; + + description_.A = make_TensorDescription(Operator::kAlignmentA); + description_.B = make_TensorDescription(Operator::kAlignmentB); + description_.C = make_TensorDescription(Operator::kAlignmentC); + description_.D = make_TensorDescription(Operator::kAlignmentC); + description_.element_epilogue = NumericTypeMap::kId; + + description_.split_k_mode = SplitKMode::kNone; + description_.transform_A = ComplexTransformMap::kId; + description_.transform_B = ComplexTransformMap::kId; + } + + /// Returns the description of the GEMM operation + virtual OperationDescription const & description() const { + return description_; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class GemmOperation : public GemmOperationBase { +public: + + using Operator = Operator_; + using ElementA = typename Operator::ElementA; + using LayoutA = typename Operator::LayoutA; + using ElementB = typename Operator::ElementB; + using LayoutB = typename Operator::LayoutB; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + using ElementD = ElementC; + using LayoutD = LayoutC; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + using OperatorArguments = typename Operator::Arguments; + +public: + + /// Constructor + GemmOperation(char const *name = "unknown_gemm"): GemmOperationBase(name) { + + this->description_.gemm_kind = GemmKind::kGemm; + } + +protected: + + /// Constructs the arguments structure given the configuration and arguments + static Status construct_arguments_( + OperatorArguments &operator_args, + GemmConfiguration const *configuration) { + + operator_args.problem_size = configuration->problem_size; + + operator_args.ref_A = {nullptr, configuration->lda}; + operator_args.ref_B = {nullptr, configuration->ldb}; + operator_args.ref_C = {nullptr, configuration->ldc}; + operator_args.ref_D = {nullptr, configuration->ldd}; + + operator_args.split_k_slices = configuration->split_k_slices; + + return Status::kSuccess; + } + + /// Constructs the arguments structure given the configuration and arguments + static Status update_arguments_( + OperatorArguments &operator_args, + GemmArguments const *arguments) { + + if (arguments->pointer_mode == ScalarPointerMode::kHost) { + typename Operator::EpilogueOutputOp::Params params( + *static_cast(arguments->alpha), + *static_cast(arguments->beta) + ); + operator_args.epilogue = params; + } + else if (arguments->pointer_mode == ScalarPointerMode::kDevice){ + typename 
Operator::EpilogueOutputOp::Params params( + static_cast(arguments->alpha), + static_cast(arguments->beta) + ); + operator_args.epilogue = params; + } + else { + return Status::kErrorInvalidProblem; + } + + if (arguments->use_pdl) { + return Status::kErrorNotSupported; + } + + operator_args.ref_A.reset(static_cast(arguments->A)); + operator_args.ref_B.reset(static_cast(arguments->B)); + operator_args.ref_C.reset(static_cast(arguments->C)); + operator_args.ref_D.reset(static_cast(arguments->D)); + + return Status::kSuccess; + } + +public: + + /// Returns success if the operation can proceed + virtual Status can_implement( + void const *configuration_ptr, + void const *arguments_ptr) const { + + GemmConfiguration const *configuration = + static_cast(configuration_ptr); + + GemmArguments const *arguments = + static_cast(arguments_ptr); + + OperatorArguments args; + + Status status = construct_arguments_(args, configuration); + + if (status != Status::kSuccess) { + return status; + } + + status = update_arguments_(args, arguments); + + if (status != Status::kSuccess) { + return status; + } + + return Operator::can_implement(args); + } + + /// Gets the host-side workspace + virtual uint64_t get_host_workspace_size( + void const *configuration) const { + + return sizeof(Operator); + } + + /// Gets the device-side workspace + virtual uint64_t get_device_workspace_size( + void const *configuration_ptr, + void const *arguments_ptr = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return 0; + } + + return Operator::get_workspace_size(args); + } + + /// Initializes the workspace + virtual Status initialize( + void const *configuration_ptr, + void *host_workspace, + void *device_workspace, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = new (host_workspace) Operator; + + return op->initialize(args, device_workspace, stream); + } + + /// Runs the kernel + virtual Status run( + void const *arguments_ptr, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = update_arguments_( + args, + static_cast(arguments_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = static_cast(host_workspace); + + status = op->update(args); + + if (status != Status::kSuccess) { + return status; + } + + return op->run(stream); + } + + void print_operator_args(OperatorArguments &operator_args) const { +#if 0 + std::cout << "GemmOperation::OperatorArguments" << std::endl; + std::cout << " problem_size: " << operator_args.problem_size.m() << ", "<< operator_args.problem_size.n() << "," << operator_args.problem_size.k() << std::endl; + std::cout << " alpha: " << operator_args.epilogue.alpha << std::endl; + std::cout << " alpha_ptr: " << operator_args.epilogue.alpha_ptr << std::endl; + std::cout << " beta: " << operator_args.epilogue.beta << std::endl; + std::cout << " beta_ptr: " << operator_args.epilogue.beta_ptr << std::endl; + std::cout << " ref_A.data(): " << operator_args.ref_A.data() << std::endl; + std::cout << " ref_A.stride: " << operator_args.ref_A.stride(0) << std::endl; + std::cout << " ref_B.data(): " << operator_args.ref_B.data() << std::endl; + std::cout << " ref_B.stride: " << 
operator_args.ref_B.stride(0) << std::endl; + std::cout << " ref_C.data(): " << operator_args.ref_C.data() << std::endl; + std::cout << " ref_C.stride: " << operator_args.ref_C.stride(0) << std::endl; +#endif + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class GemmSparseOperation : public GemmOperationBase { +public: + + using Operator = Operator_; + using ElementA = typename Operator::ElementA; + using LayoutA = typename Operator::LayoutA; + using ElementB = typename Operator::ElementB; + using LayoutB = typename Operator::LayoutB; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + using ElementD = ElementC; + using LayoutD = LayoutC; + using ElementE = typename Operator::ElementE; + using LayoutE = typename Operator::LayoutE; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + + using OperatorArguments = typename Operator::Arguments; + +public: + + /// Constructor + GemmSparseOperation(char const *name = "unknown_gemm"): GemmOperationBase(name) { + + this->description_.kind = OperationKind::kSparseGemm; + this->description_.gemm_kind = GemmKind::kSparse; + this->description_.E = make_TensorDescription(Operator::kAlignmentE); + } + +protected: + + /// Constructs the arguments structure given the configuration and arguments + static Status construct_arguments_( + OperatorArguments &operator_args, + SparseGemmConfiguration const *configuration) { + + operator_args.problem_size = configuration->problem_size; + operator_args.ref_A = {nullptr, configuration->lda}; + operator_args.ref_B = {nullptr, configuration->ldb}; + operator_args.ref_C = {nullptr, configuration->ldc}; + operator_args.ref_D = {nullptr, configuration->ldd}; + operator_args.ref_E = {nullptr, configuration->lde}; + + return Status::kSuccess; + } + + /// Constructs the arguments structure given the configuration and arguments + static Status update_arguments_( + OperatorArguments &operator_args, + SparseGemmArguments const *arguments) { + + if (arguments->pointer_mode == ScalarPointerMode::kHost) { + typename Operator::EpilogueOutputOp::Params params( + *static_cast(arguments->alpha), + *static_cast(arguments->beta) + ); + operator_args.epilogue = params; + } + else if (arguments->pointer_mode == ScalarPointerMode::kDevice){ + typename Operator::EpilogueOutputOp::Params params( + static_cast(arguments->alpha), + static_cast(arguments->beta) + ); + operator_args.epilogue = params; + } + else { + return Status::kErrorInvalidProblem; + } + + operator_args.ref_A.reset(static_cast(arguments->A)); + operator_args.ref_B.reset(static_cast(arguments->B)); + operator_args.ref_C.reset(static_cast(arguments->C)); + operator_args.ref_D.reset(static_cast(arguments->D)); + operator_args.ref_E.reset(static_cast(arguments->E)); + + if (arguments->use_pdl) { + return Status::kErrorNotSupported; + } + + return Status::kSuccess; + } + +public: + + /// Returns success if the operation can proceed + virtual Status can_implement( + void const *configuration_ptr, + void const *arguments_ptr) const { + + SparseGemmConfiguration const *configuration = + static_cast(configuration_ptr); + + SparseGemmArguments const *arguments = + static_cast(arguments_ptr); + + OperatorArguments args; + + Status status = construct_arguments_(args, configuration); + + if (status != Status::kSuccess) { + return status; + } + + status = 
update_arguments_(args, arguments); + + if (status != Status::kSuccess) { + return status; + } + + return Operator::can_implement(args); + } + + /// Gets the host-side workspace + virtual uint64_t get_host_workspace_size( + void const *configuration) const { + + return sizeof(Operator); + } + + /// Gets the device-side workspace + virtual uint64_t get_device_workspace_size( + void const *configuration_ptr, + void const *arguments_ptr = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return 0; + } + + return Operator::get_workspace_size(args); + } + + /// Initializes the workspace + virtual Status initialize( + void const *configuration_ptr, + void *host_workspace, + void *device_workspace, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = new (host_workspace) Operator; + + return op->initialize(args, device_workspace, stream); + } + + /// Runs the kernel + virtual Status run( + void const *arguments_ptr, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = update_arguments_( + args, + static_cast(arguments_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = static_cast(host_workspace); + + status = op->update(args); + + if (status != Status::kSuccess) { + return status; + } + + return op->run(stream); + } + + void print_operator_args(OperatorArguments &operator_args) const { +#if 0 + std::cout << "GemmOperation::OperatorArguments" << std::endl; + std::cout << " problem_size: " << operator_args.problem_size.m() << ", "<< operator_args.problem_size.n() << "," << operator_args.problem_size.k() << std::endl; + std::cout << " alpha: " << operator_args.epilogue.alpha << std::endl; + std::cout << " alpha_ptr: " << operator_args.epilogue.alpha_ptr << std::endl; + std::cout << " beta: " << operator_args.epilogue.beta << std::endl; + std::cout << " beta_ptr: " << operator_args.epilogue.beta_ptr << std::endl; + std::cout << " ref_A.data(): " << operator_args.ref_A.data() << std::endl; + std::cout << " ref_A.stride: " << operator_args.ref_A.stride(0) << std::endl; + std::cout << " ref_B.data(): " << operator_args.ref_B.data() << std::endl; + std::cout << " ref_B.stride: " << operator_args.ref_B.stride(0) << std::endl; + std::cout << " ref_C.data(): " << operator_args.ref_C.data() << std::endl; + std::cout << " ref_C.stride: " << operator_args.ref_C.stride(0) << std::endl; +#endif + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class GemmUniversalOperation : public GemmOperationBase { +public: + + using Operator = Operator_; + using ElementA = typename Operator::ElementA; + using LayoutA = typename Operator::LayoutA; + using ElementB = typename Operator::ElementB; + using LayoutB = typename Operator::LayoutB; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + using ElementD = ElementC; + using LayoutD = LayoutC; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + + using OperatorArguments = typename Operator::Arguments; + +public: + + /// Constructor + 
GemmUniversalOperation(char const *name = "unknown_gemm"): + GemmOperationBase(name) { + + this->description_.gemm_kind = GemmKind::kUniversal; + } + +protected: + + /// Constructs the arguments structure given the configuration and arguments + static Status construct_arguments_( + OperatorArguments &operator_args, + GemmUniversalConfiguration const *configuration) { + + operator_args.mode = configuration->mode; + + operator_args.problem_size = configuration->problem_size; + operator_args.batch_count = configuration->batch_count; + + operator_args.lda = (configuration->lda); + operator_args.ldb = (configuration->ldb); + operator_args.ldc = (configuration->ldc); + operator_args.ldd = (configuration->ldd); + + return Status::kSuccess; + } + + /// Constructs the arguments structure given the configuration and arguments + static Status update_arguments_( + OperatorArguments &operator_args, + GemmUniversalArguments const *arguments) { + + if (arguments->pointer_mode == ScalarPointerMode::kHost) { + typename Operator::EpilogueOutputOp::Params params( + *static_cast(arguments->alpha), + *static_cast(arguments->beta) + ); + operator_args.epilogue = params; + } + else if (arguments->pointer_mode == ScalarPointerMode::kDevice){ + typename Operator::EpilogueOutputOp::Params params( + static_cast(arguments->alpha), + static_cast(arguments->beta) + ); + operator_args.epilogue = params; + } + else { + return Status::kErrorInvalidProblem; + } + + // update arguments + operator_args.ptr_A = arguments->A; + operator_args.ptr_B = arguments->B; + operator_args.ptr_C = arguments->C; + operator_args.ptr_D = arguments->D; + + operator_args.batch_stride_A = arguments->batch_stride_A; + operator_args.batch_stride_B = arguments->batch_stride_B; + operator_args.batch_stride_C = arguments->batch_stride_C; + operator_args.batch_stride_D = arguments->batch_stride_D; + + if (arguments->use_pdl) { + return Status::kErrorNotSupported; + } + + return Status::kSuccess; + } + +public: + + /// Returns success if the operation can proceed + virtual Status can_implement( + void const *configuration_ptr, + void const *arguments_ptr) const { + + GemmUniversalConfiguration const *configuration = + static_cast(configuration_ptr); + + GemmUniversalArguments const *arguments = + static_cast(arguments_ptr); + + OperatorArguments args; + + Status status = construct_arguments_(args, configuration); + + if (status != Status::kSuccess) { + return status; + } + + status = update_arguments_(args, arguments); + + if (status != Status::kSuccess) { + return status; + } + + return Operator::can_implement(args); + } + + /// Gets the host-side workspace + virtual uint64_t get_host_workspace_size( + void const *configuration) const { + + return sizeof(Operator); + } + + /// Gets the device-side workspace + virtual uint64_t get_device_workspace_size( + void const *configuration_ptr, + void const *arguments_ptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return 0; + } + + status = update_arguments_( + args, + static_cast(arguments_ptr)); + + if (status != Status::kSuccess) { + return 0; + } + + uint64_t size = Operator::get_workspace_size(args); + + return size; + } + + /// Initializes the workspace + virtual Status initialize( + void const *configuration_ptr, + void *host_workspace, + void *device_workspace, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + 
static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = new (host_workspace) Operator; + + status = op->initialize(args, device_workspace, stream); + + return status; + } + + /// Runs the kernel + virtual Status run( + void const *arguments_ptr, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = update_arguments_( + args, + static_cast(arguments_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = static_cast(host_workspace); + + status = op->update(args); + + if (status != Status::kSuccess) { + return status; + } + + status = op->run(stream); + + return status; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class GemmPlanarComplexOperation : public GemmOperationBase { +public: + + using Operator = Operator_; + using ElementA = typename Operator::ElementA; + using LayoutA = typename Operator::LayoutA; + using ElementB = typename Operator::ElementB; + using LayoutB = typename Operator::LayoutB; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + using ElementD = ElementC; + using LayoutD = LayoutC; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + + using OperatorArguments = typename Operator::Arguments; + +public: + + /// Constructor + GemmPlanarComplexOperation(char const *name = "unknown_gemm"): GemmOperationBase(name) { + + this->description_.gemm_kind = GemmKind::kPlanarComplex; + } + +protected: + + /// Constructs the arguments structure given the configuration and arguments + static Status construct_arguments_( + OperatorArguments &operator_args, + GemmPlanarComplexConfiguration const *configuration) { + + operator_args.mode = cutlass::gemm::GemmUniversalMode::kBatched; + operator_args.problem_size = configuration->problem_size; + operator_args.batch_count = configuration->batch_count; + + + operator_args.lda_real = configuration->lda_real; + operator_args.lda_imag = configuration->lda_imag; + operator_args.ldb_real = configuration->ldb_real; + operator_args.ldb_imag = configuration->ldb_imag; + operator_args.ldc_real = configuration->ldc_real; + operator_args.ldc_imag = configuration->ldc_imag; + operator_args.ldd_real = configuration->ldd_real; + operator_args.ldd_imag = configuration->ldd_imag; + + return Status::kSuccess; + } + + /// Constructs the arguments structure given the configuration and arguments + static Status update_arguments_( + OperatorArguments &operator_args, + GemmPlanarComplexArguments const *arguments) { + + if (arguments->pointer_mode == ScalarPointerMode::kHost) { + typename Operator::EpilogueOutputOp::Params params( + *static_cast const *>(arguments->alpha), + *static_cast const *>(arguments->beta) + ); + operator_args.epilogue = params; + } + else if (arguments->pointer_mode == ScalarPointerMode::kDevice){ + typename Operator::EpilogueOutputOp::Params params( + static_cast const *>(arguments->alpha), + static_cast const *>(arguments->beta) + ); + operator_args.epilogue = params; + } + else { + return Status::kErrorInvalidProblem; + } + + // update arguments + operator_args.ptr_A_real = arguments->A_real; + operator_args.ptr_A_imag = arguments->A_imag; + operator_args.ptr_B_real = arguments->B_real; + operator_args.ptr_B_imag = arguments->B_imag; + 
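// Planar-complex tensors are stored as separate real and imaginary planes, so each
// operand carries a pointer pair plus its own real/imag leading dimensions and batch
// strides; C and D are wired up the same way below.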
operator_args.ptr_C_real = arguments->C_real; + operator_args.ptr_C_imag = arguments->C_imag; + operator_args.ptr_D_real = arguments->D_real; + operator_args.ptr_D_imag = arguments->D_imag; + + operator_args.batch_stride_A = arguments->batch_stride_A_real; + operator_args.batch_stride_A_imag = arguments->batch_stride_A_imag; + operator_args.batch_stride_B = arguments->batch_stride_B_real; + operator_args.batch_stride_B_imag = arguments->batch_stride_B_imag; + operator_args.batch_stride_C = arguments->batch_stride_C_real; + operator_args.batch_stride_C_imag = arguments->batch_stride_C_imag; + operator_args.batch_stride_D = arguments->batch_stride_D_real; + operator_args.batch_stride_D_imag = arguments->batch_stride_D_imag; + + return Status::kSuccess; + } + +public: + + /// Returns success if the operation can proceed + virtual Status can_implement( + void const *configuration_ptr, + void const *arguments_ptr) const { + + GemmPlanarComplexConfiguration const *configuration = + static_cast(configuration_ptr); + + GemmPlanarComplexArguments const *arguments = + static_cast(arguments_ptr); + + OperatorArguments args; + + Status status = construct_arguments_(args, configuration); + + if (status != Status::kSuccess) { + return status; + } + + status = update_arguments_(args, arguments); + + if (status != Status::kSuccess) { + return status; + } + + return Operator::can_implement(args); + } + + /// Gets the host-side workspace + virtual uint64_t get_host_workspace_size( + void const *configuration) const { + + return sizeof(Operator); + } + + /// Gets the device-side workspace + virtual uint64_t get_device_workspace_size( + void const *configuration_ptr, + void const *arguments_ptr = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return 0; + } + + uint64_t size = Operator::get_workspace_size(args); + + return size; + } + + /// Initializes the workspace + virtual Status initialize( + void const *configuration_ptr, + void *host_workspace, + void *device_workspace, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = new (host_workspace) Operator; + + status = op->initialize(args, device_workspace, stream); + + return status; + } + + /// Runs the kernel + virtual Status run( + void const *arguments_ptr, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const { + OperatorArguments args; + + Status status = update_arguments_( + args, + static_cast(arguments_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = static_cast(host_workspace); + + status = op->update(args); + + if (status != Status::kSuccess) { + return status; + } + + status = op->run(stream); + + return status; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class GemmPlanarComplexArrayOperation : public GemmOperationBase { +public: + + using Operator = Operator_; + using ElementA = typename Operator::ElementA; + using LayoutA = typename Operator::LayoutA; + using ElementB = typename Operator::ElementB; + using LayoutB = typename Operator::LayoutB; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + using ElementD = ElementC; + using LayoutD = 
LayoutC; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + + using OperatorArguments = typename Operator::Arguments; + +public: + + /// Constructor + GemmPlanarComplexArrayOperation(char const *name = "unknown_gemm"): GemmOperationBase(name) { + + this->description_.gemm_kind = GemmKind::kPlanarComplexArray; + } + +protected: + + /// Constructs the arguments structure given the configuration and arguments + static Status construct_arguments_( + OperatorArguments &operator_args, + GemmPlanarComplexArrayConfiguration const *configuration) { + + operator_args.mode = cutlass::gemm::GemmUniversalMode::kArray; + operator_args.problem_size = configuration->problem_size; + operator_args.batch_count = configuration->batch_count; + + operator_args.lda_real = configuration->lda_real; + operator_args.lda_imag = configuration->lda_imag; + operator_args.ldb_real = configuration->ldb_real; + operator_args.ldb_imag = configuration->ldb_imag; + operator_args.ldc_real = configuration->ldc_real; + operator_args.ldc_imag = configuration->ldc_imag; + operator_args.ldd_real = configuration->ldd_real; + operator_args.ldd_imag = configuration->ldd_imag; + + return Status::kSuccess; + } + + /// Constructs the arguments structure given the configuration and arguments + static Status update_arguments_( + OperatorArguments &operator_args, + GemmPlanarComplexArrayArguments const *arguments) { + + if (arguments->pointer_mode == ScalarPointerMode::kHost) { + typename Operator::EpilogueOutputOp::Params params( + *static_cast const *>(arguments->alpha), + *static_cast const *>(arguments->beta) + ); + operator_args.epilogue = params; + } + else if (arguments->pointer_mode == ScalarPointerMode::kDevice){ + typename Operator::EpilogueOutputOp::Params params( + static_cast const *>(arguments->alpha), + static_cast const *>(arguments->beta) + ); + operator_args.epilogue = params; + } + else { + return Status::kErrorInvalidProblem; + } + + // update arguments + operator_args.ptr_A_real = arguments->A_real; + operator_args.ptr_A_imag = arguments->A_imag; + operator_args.ptr_B_real = arguments->B_real; + operator_args.ptr_B_imag = arguments->B_imag; + operator_args.ptr_C_real = arguments->C_real; + operator_args.ptr_C_imag = arguments->C_imag; + operator_args.ptr_D_real = arguments->D_real; + operator_args.ptr_D_imag = arguments->D_imag; + + operator_args.ptr_M = arguments->M; + operator_args.ptr_N = arguments->N; + operator_args.ptr_K = arguments->K; + + if (arguments->use_pdl) { + return Status::kErrorNotSupported; + } + + return Status::kSuccess; + } + +public: + + /// Returns success if the operation can proceed + virtual Status can_implement( + void const *configuration_ptr, + void const *arguments_ptr) const { + + GemmPlanarComplexArrayConfiguration const *configuration = + static_cast(configuration_ptr); + + GemmPlanarComplexArrayArguments const *arguments = + static_cast(arguments_ptr); + + OperatorArguments args; + + Status status = construct_arguments_(args, configuration); + + if (status != Status::kSuccess) { + return status; + } + + status = update_arguments_(args, arguments); + + if (status != Status::kSuccess) { + return status; + } + + return Operator::can_implement(args); + } + + /// Gets the host-side workspace + virtual uint64_t get_host_workspace_size( + void const *configuration) const { + + return sizeof(Operator); + } + + /// Gets the device-side workspace + virtual uint64_t get_device_workspace_size( + void 
const *configuration_ptr, + void const *arguments_ptr = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return 0; + } + + uint64_t size = Operator::get_workspace_size(args); + + return size; + } + + /// Initializes the workspace + virtual Status initialize( + void const *configuration_ptr, + void *host_workspace, + void *device_workspace, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = new (host_workspace) Operator; + + status = op->initialize(args, device_workspace, stream); + + return status; + } + + /// Runs the kernel + virtual Status run( + void const *arguments_ptr, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = update_arguments_( + args, + static_cast(arguments_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = static_cast(host_workspace); + + status = op->update(args); + + if (status != Status::kSuccess) { + return status; + } + + status = op->run(stream); + + return status; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class GemmGroupedOperation : public GemmOperationBase { +public: + + using Operator = Operator_; + using ElementA = typename Operator::ElementA; + using LayoutA = typename Operator::LayoutA; + using ElementB = typename Operator::ElementB; + using LayoutB = typename Operator::LayoutB; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + using ElementD = ElementC; + using LayoutD = LayoutC; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + + using OperatorArguments = typename Operator::Arguments; + +public: + + /// Constructor + GemmGroupedOperation(char const *name = "unknown_gemm"): + GemmOperationBase(name) { + + this->description_.kind = OperationKind::kGroupedGemm; + this->description_.provider = Provider::kCUTLASS; + this->threadblock_count = Operator::sufficient(); + + this->description_.gemm = GemmOperationBase::description_; + this->description_.gemm.gemm_kind = GemmKind::kGrouped; + this->description_.tile_description = this->description_.gemm.tile_description; + } + + /// Returns the description of the GroupedGEMM operation + virtual OperationDescription const & description() const override final { + return description_; + } + + +private: + int threadblock_count; + GroupedGemmDescription description_; + +protected: + + /// Constructs the arguments structure given the configuration and arguments + Status construct_arguments_( + OperatorArguments &op_args, + GemmGroupedConfiguration const *config) const { + + op_args.problem_count = config->problem_count; + op_args.threadblock_count = threadblock_count; + + return Status::kSuccess; + } + + /// Constructs the arguments structure given the configuration and arguments + Status update_arguments_( + OperatorArguments &op_args, + GemmGroupedArguments const *arguments) const { + + if (arguments->pointer_mode == ScalarPointerMode::kHost) { + + typename Operator::EpilogueOutputOp::Params params( + *static_cast(arguments->alpha), + *static_cast(arguments->beta) + ); + + 
op_args.output_op = params; + } + else if (arguments->pointer_mode == ScalarPointerMode::kDevice) { + + typename Operator::EpilogueOutputOp::Params params( + static_cast(arguments->alpha), + static_cast(arguments->beta) + ); + + op_args.output_op = params; + } + else { + return Status::kErrorInvalidProblem; + } + + op_args.threadblock_count = threadblock_count; + op_args.problem_count = arguments->problem_count; + op_args.problem_sizes = arguments->problem_sizes; + + op_args.ptr_A = static_cast(arguments->ptr_A); + op_args.ptr_B = static_cast(arguments->ptr_B); + op_args.ptr_C = static_cast(arguments->ptr_C); + op_args.ptr_D = static_cast(arguments->ptr_D); + + op_args.lda = arguments->lda; + op_args.ldb = arguments->ldb; + op_args.ldc = arguments->ldc; + op_args.ldd = arguments->ldd; + + if (arguments->use_pdl) { + return Status::kErrorNotSupported; + } + + return Status::kSuccess; + } + +public: + + /// Returns success if the operation can proceed + virtual Status can_implement( + void const *configuration_ptr, + void const *arguments_ptr) const { + + GemmGroupedConfiguration const *configuration = + static_cast(configuration_ptr); + + GemmGroupedArguments const *arguments = + static_cast(arguments_ptr); + + OperatorArguments args; + + Status status = construct_arguments_(args, configuration); + + if (status != Status::kSuccess) { + return status; + } + + status = update_arguments_(args, arguments); + + if (status != Status::kSuccess) { + return status; + } + + return Operator::can_implement(args); + } + + /// Gets the host-side workspace + virtual uint64_t get_host_workspace_size( + void const *configuration) const { + + return sizeof(Operator); + } + + /// Gets the device-side workspace + virtual uint64_t get_device_workspace_size( + void const *configuration_ptr, + void const *arguments_ptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return 0; + } + + status = update_arguments_( + args, + static_cast(arguments_ptr)); + + if (status != Status::kSuccess) { + return 0; + } + + uint64_t size = Operator::get_workspace_size(args); + + return size; + } + + /// Initializes the workspace + virtual Status initialize( + void const *configuration_ptr, + void *host_workspace, + void *device_workspace, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = new (host_workspace) Operator; + + status = op->initialize(args, device_workspace, stream); + + return status; + } + + /// Runs the kernel + virtual Status run( + void const *arguments_ptr, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = update_arguments_( + args, + static_cast(arguments_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = static_cast(host_workspace); + + status = op->update(args); + + if (status != Status::kSuccess) { + return status; + } + + status = op->run(stream); + + return status; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git 
a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/library/src/gemm_operation_3x.hpp b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/library/src/gemm_operation_3x.hpp new file mode 100644 index 0000000000000000000000000000000000000000..2c1d17943f11fe8126b3070c3fcead5598e2d207 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/library/src/gemm_operation_3x.hpp @@ -0,0 +1,714 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines operations for all GEMM operation kinds in CUTLASS Library. 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/detail/collective.hpp" +#include "cutlass/array.h" +#include "cutlass/array_subbyte.h" +#include "cutlass/library/library.h" +#include "library_internal.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/mixed_dtype_utils.hpp" +#include "cutlass/util/device_memory.h" +#include "cutlass/util/reference/device/tensor_fill.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cute/tensor.hpp" +#include + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class GemmOperation3xBase : public Operation { +public: + using Operator = Operator_; + using OperatorArguments = typename Operator::Arguments; + using ElementA = typename Operator::ElementA; + using LayoutA = typename Operator::LayoutA; + using ElementB = typename Operator::ElementB; + using LayoutB = typename Operator::LayoutB; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + using ElementD = typename Operator::ElementD; + using LayoutD = typename Operator::LayoutD; + // assuming all tensors use same type for StrideIndex + using StrideIndex = typename Operator::LayoutA::Index; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + +protected: + GemmDescription description_; + +public: + + /// Constructor + GemmOperation3xBase(char const *name = "unknown_gemm", GemmKind gemm_kind_ = GemmKind::kGemm) { + + description_.name = name; + description_.provider = Provider::kCUTLASS; + description_.kind = OperationKind::kGemm; + description_.gemm_kind = gemm_kind_; + + description_.tile_description.threadblock_shape = make_Coord( + Operator::ThreadblockShape::kM, + Operator::ThreadblockShape::kN, + Operator::ThreadblockShape::kK); + + if constexpr (Operator::ArchTag::kMinComputeCapability >= 90) { + description_.tile_description.cluster_shape = make_Coord( + Operator::ClusterShape::kM, + Operator::ClusterShape::kN, + Operator::ClusterShape::kK); + } + + description_.tile_description.threadblock_stages = Operator::kStages; + + description_.tile_description.warp_count = make_Coord( + Operator::WarpCount::kM, + Operator::WarpCount::kN, + Operator::WarpCount::kK); + + description_.tile_description.math_instruction.instruction_shape = make_Coord( + Operator::InstructionShape::kM, + Operator::InstructionShape::kN, + Operator::InstructionShape::kK); + + description_.tile_description.math_instruction.element_accumulator = + NumericTypeMap::kId; + + description_.tile_description.math_instruction.opcode_class = + OpcodeClassMap::kId; + + description_.tile_description.math_instruction.math_operation = + MathOperationMap::kId; + + description_.tile_description.minimum_compute_capability = + ArchMap::kMin; + + description_.tile_description.maximum_compute_capability = + ArchMap::kMax; + + description_.A = make_TensorDescription(Operator::kAlignmentA); + description_.B = make_TensorDescription(Operator::kAlignmentB); + description_.C = make_TensorDescription(Operator::kAlignmentC); + description_.D = make_TensorDescription(Operator::kAlignmentD); + description_.element_epilogue = NumericTypeMap::kId; + + description_.split_k_mode = SplitKMode::kNone; + description_.transform_A = 
ComplexTransformMap::kId; + description_.transform_B = ComplexTransformMap::kId; + } + + /// Returns the description of the GEMM operation + virtual OperationDescription const & description() const { + return description_; + } + + /// Returns the description of the GEMM operation + GemmDescription const& get_gemm_description() const { + return description_; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class GemmUniversal3xOperation : public GemmOperation3xBase { +public: + + using Operator = Operator_; + using OperatorArguments = typename Operator::Arguments; + using ElementA = typename Operator::ElementA; + using LayoutA = typename Operator::LayoutA; + using ElementB = typename Operator::ElementB; + using LayoutB = typename Operator::LayoutB; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + using ElementD = typename Operator::ElementD; + using LayoutD = typename Operator::LayoutD; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + + using CollectiveMainloop = typename Operator::CollectiveMainloop; + using CollectiveEpilogue = typename Operator::CollectiveEpilogue; + using ThreadEpilogueOp = typename CollectiveEpilogue::ThreadEpilogueOp; + + static constexpr bool IsRuntimeDataTypeA = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); + + static constexpr bool IsRuntimeDataTypeB = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); + + static_assert((IsRuntimeDataTypeA && IsRuntimeDataTypeB) || + (!IsRuntimeDataTypeA && !IsRuntimeDataTypeB), + "ElementA and ElementB in a GEMM kernel should be both runtime or both static."); + + static constexpr bool IsRuntimeDataType = IsRuntimeDataTypeA && IsRuntimeDataTypeB; + + +public: + + /// Constructor + GemmUniversal3xOperation(char const *name = "unknown_gemm"): + GemmOperation3xBase(name, GemmKind::kUniversal) { + if constexpr (Operator::ArchTag::kMinComputeCapability == 90) { + dim3 cluster_dims( + cute::size<0>(typename Operator::GemmKernel::ClusterShape{}), + cute::size<1>(typename Operator::GemmKernel::ClusterShape{}), + cute::size<2>(typename Operator::GemmKernel::ClusterShape{})); + uint32_t threads_per_block = Operator::GemmKernel::MaxThreadsPerBlock; + void const* kernel_ptr = (void*)(device_kernel); + max_active_clusters = cutlass::KernelHardwareInfo::query_device_max_active_clusters( + cluster_dims, + threads_per_block, + kernel_ptr); + } + } + +private: + int max_active_clusters{}; + +protected: + + /// Constructs the arguments structure given the configuration and arguments + static Status construct_arguments_( + OperatorArguments &operator_args, GemmUniversalConfiguration const *configuration) { + // NOTE: GemmUniversalConfiguration does not contain problem shapes or batch strides + // Do nothing here and construct kernel arguments in update_arguments_ instead + // We also cannot construct TMA descriptors without all the arguments available + + operator_args.mode = configuration->mode; + return Status::kSuccess; + } + + template + struct UpdateFusionArgs { + static Status update_(FusionArgs const& fusion_args, GemmUniversalArguments const &arguments) { + // If a custom EVT is instantiated then it is the users's responsibility + // to ensure alpha and beta are updated appropriately + return Status::kSuccess; + } + }; + + template + struct UpdateFusionArgs> { + static Status update_(FusionArgs& 
fusion_args, GemmUniversalArguments const &arguments) { + if (arguments.pointer_mode == ScalarPointerMode::kHost) { + fusion_args.alpha = *static_cast(arguments.alpha); + fusion_args.beta = *static_cast(arguments.beta); + fusion_args.alpha_ptr = nullptr; + fusion_args.beta_ptr = nullptr; + + return Status::kSuccess; + } + else if (arguments.pointer_mode == ScalarPointerMode::kDevice) { + fusion_args.alpha = 0; + fusion_args.beta = 0; + fusion_args.alpha_ptr = static_cast(arguments.alpha); + fusion_args.beta_ptr = static_cast(arguments.beta); + + return Status::kSuccess; + } + else { + return Status::kErrorInvalidProblem; + } + } + }; + + template class Policy, int Stages, class ClusterShape, class KernelSchedule> + static constexpr bool is_sm90_mixed_dtype_mainloop_(Policy policy) { + return (cute::is_same_v, + cutlass::gemm::MainloopSm90TmaGmmaRmemAWarpSpecializedMixedInput>); + } + + template + static constexpr bool is_sm90_mixed_dtype_mainloop_(DispatchPolicy) { + return false; + } + + template < + typename ElementWide, + typename ElementNarrow, + typename ElementScaleMainloop, + class ActualStrideAB, + Sm90MixedInputWiderOperand wider_operand, + bool is_n4w8, + typename ElementScale, + typename ElementZero, + class Layout_SZ> + static void dequantize_encode_( + OperatorArguments &operator_args, + GemmUniversalArguments const *arguments, + cudaStream_t stream, + const int &problem_mn, + const int &problem_k, + const int &options_l, + const int &options_g, + ElementScale *ptr_S, + ElementZero *ptr_Z, + const size_t &SZ_size, + Layout_SZ layout_SZ + ) { + + auto shape_AB = cute::make_shape(problem_mn, problem_k, options_l); + auto stride_AB = cutlass::make_cute_packed_stride(ActualStrideAB{}, shape_AB); + auto layout_AB = cute::make_layout(shape_AB, stride_AB); + auto *ptr_dequantized_AB = static_cast(arguments->dequantized_AB); + const ElementNarrow *ptr_AB = nullptr; + if constexpr(wider_operand == Sm90MixedInputWiderOperand::A) { + ptr_AB = static_cast(arguments->B); + } + else { + ptr_AB = static_cast(arguments->A); + } + dequantize(ptr_dequantized_AB, ptr_AB, layout_AB, ptr_S, ptr_Z, layout_SZ, options_g, stream); + if constexpr(is_n4w8) { + size_t AB_size = cute::size(layout_AB); + cutlass::int4b_t *encoded_AB = static_cast(arguments->encoded_AB); + unified_encode_int4b(ptr_AB, encoded_AB, AB_size); + if constexpr(wider_operand == Sm90MixedInputWiderOperand::A) { + operator_args.mainloop.ptr_B = static_cast(encoded_AB); + } + else { + operator_args.mainloop.ptr_A = static_cast(encoded_AB); + } + ElementScaleMainloop *ptr_packed_Scale = static_cast(arguments->packed_Scale); + pack_scale_fp8(ptr_S, ptr_packed_Scale, SZ_size); + } + } + + template < + typename ElementAB, + class ActualStrideAB, + class LayoutAB_Reordered, + class LayoutAtomQuant, + Sm90MixedInputWiderOperand wider_operand> + static void handle_shuffle_tensor_( + OperatorArguments &operator_args, + GemmUniversalArguments const *arguments, + const int &problem_mn, + const int &problem_k, + const int &options_l) { + + auto shape_AB = cute::make_shape(problem_mn, problem_k, options_l); + auto stride_AB = cutlass::make_cute_packed_stride(ActualStrideAB{}, shape_AB); + auto layout_AB = cute::make_layout(shape_AB, stride_AB); + LayoutAB_Reordered layout_AB_reordered = cute::tile_to_shape(LayoutAtomQuant{}, shape_AB); + if constexpr(wider_operand == Sm90MixedInputWiderOperand::A) { + operator_args.mainloop.dB = layout_AB_reordered; + } + else { + operator_args.mainloop.dA = layout_AB_reordered; + } + if 
(arguments->generate_dequantized_AB) { + size_t AB_size = cute::size(layout_AB); + ElementAB *AB_reordered = cutlass::device_memory::allocate(AB_size); + const ElementAB *AB_src = nullptr; + if constexpr(wider_operand == Sm90MixedInputWiderOperand::A) { + AB_src = static_cast(operator_args.mainloop.ptr_B); + } + else { + AB_src = static_cast(operator_args.mainloop.ptr_A); + } + reorder_tensor(AB_src, layout_AB, AB_reordered, layout_AB_reordered); + ElementAB *AB_dst = static_cast(arguments->encoded_AB); + cutlass::device_memory::copy_device_to_device(AB_dst, AB_reordered, AB_size); + cutlass::device_memory::free(AB_reordered); + if constexpr(wider_operand == Sm90MixedInputWiderOperand::A) { + operator_args.mainloop.ptr_B = AB_dst; + } + else { + operator_args.mainloop.ptr_A = AB_dst; + } + } + } + + /// Constructs the arguments structure given the configuration and arguments + Status update_arguments_( + OperatorArguments& operator_args, + GemmUniversalArguments const* arguments, + cudaStream_t stream = nullptr) const { + Status status = Status::kSuccess; + + status = UpdateFusionArgs::update_( + operator_args.epilogue.thread, *arguments); + if (status != Status::kSuccess) { + return status; + } + + // TODO: type erase Arguments structure in 3.0 GEMM + operator_args.problem_shape = cute::make_shape( + arguments->problem_size.m(), + arguments->problem_size.n(), + arguments->problem_size.k(), + arguments->batch_count); + + // update arguments + + if constexpr (IsRuntimeDataType) { + using ArrayElementA = typename Operator::GemmKernel::CollectiveMainloop::ArrayElementA; + using ArrayElementB = typename Operator::GemmKernel::CollectiveMainloop::ArrayElementB; + operator_args.mainloop.ptr_A = static_cast(arguments->A); + operator_args.mainloop.ptr_B = static_cast(arguments->B); + + std::unordered_map mapping = { + {RuntimeDatatype::kE4M3, cute::UMMA::MXF8F6F4Format::E4M3}, + {RuntimeDatatype::kE5M2, cute::UMMA::MXF8F6F4Format::E5M2}, + {RuntimeDatatype::kE3M2, cute::UMMA::MXF8F6F4Format::E3M2}, + {RuntimeDatatype::kE2M1, cute::UMMA::MXF8F6F4Format::E2M1} + }; + + auto iter_runtime_a = mapping.find(arguments->runtime_input_datatype_a); + auto iter_runtime_b = mapping.find(arguments->runtime_input_datatype_b); + + if (iter_runtime_a != mapping.end()) { + operator_args.mainloop.runtime_data_type_a = iter_runtime_a->second; + } else { + assert("invalid runtime argument for datatype A!"); + } + + if (iter_runtime_b != mapping.end()) { + operator_args.mainloop.runtime_data_type_b = iter_runtime_b->second; + } else { + assert("invalid runtime argument for datatype B!"); + } + + } + else { + operator_args.mainloop.ptr_A = static_cast(arguments->A); + operator_args.mainloop.ptr_B = static_cast(arguments->B); + } + operator_args.epilogue.ptr_C = static_cast(arguments->C); + operator_args.epilogue.ptr_D = static_cast(arguments->D); + + // Stride{A,B} is a Layout if and only if: + // (1) This is a mixed dtype kernel, and + // (2) This mixed dtype kernel is using shuffling, and + // (3) sizeof(narrow_type) == 4 or 8 bits, and + // (4) sizeof(wide_type) == 16 bits. 
+ // If A/B has the narrow data type, Stride{A/B} will be a Layout + constexpr bool is_StrideA_Layout = cute::is_layout::value; + constexpr bool is_StrideB_Layout = cute::is_layout::value; + static_assert(!(is_StrideA_Layout && is_StrideB_Layout), "Incorrect kernel configuration: StrideA and StrideB are both cute::Layout"); + if constexpr(!is_StrideA_Layout) { + operator_args.mainloop.dA = cute::make_int_tuple_from( + arguments->lda, arguments->batch_stride_A); + } + if constexpr(!is_StrideB_Layout) { + operator_args.mainloop.dB = cute::make_int_tuple_from( + arguments->ldb, arguments->batch_stride_B); + } + operator_args.epilogue.dC = cute::make_int_tuple_from( + arguments->ldc, arguments->batch_stride_C); + operator_args.epilogue.dD = operator_args.epilogue.dC; + + using MainloopPolicy = typename CollectiveMainloop::DispatchPolicy; + if constexpr(is_sm90_mixed_dtype_mainloop_(MainloopPolicy{})) { + const int problem_m = arguments->problem_size.m(); + const int problem_n = arguments->problem_size.n(); + const int problem_k = arguments->problem_size.k(); + const int options_l = arguments->batch_count; + + constexpr Sm90MixedInputWiderOperand wider_operand = + (cutlass::sizeof_bits::value > cutlass::sizeof_bits::value) ? + Sm90MixedInputWiderOperand::A : Sm90MixedInputWiderOperand::B; + using ElementWide = std::conditional_t; + using ElementNarrow = std::conditional_t; + + constexpr bool has_scale = !std::is_same_v; + constexpr bool has_zero = !std::is_same_v; + + const int options_g = problem_k; + const int scale_k = (problem_k + options_g - 1) / options_g; + + constexpr bool is_A4B8 = ( + cutlass::is_same_v && + (cutlass::is_same_v || + cutlass::is_same_v)); + constexpr bool is_A8B4 = ( + cutlass::is_same_v && + (cutlass::is_same_v || + cutlass::is_same_v)); + constexpr bool is_int4_x_fp8 = is_A4B8 || is_A8B4; + + // If this is a convert-only kernel, we still need to generate dequantized A or B for verification, + // and in this case ElementScale is the same as ElementWide + // In int4 * fp8, ElementScale is a cutlass::Array, need to take out it's real element + using DummyElementScaleMainloop = std::conditional_t< + is_int4_x_fp8, + typename cutlass::Array, + ElementWide + >; + using ElementScaleMainloop = std::conditional_t< + has_scale, + typename CollectiveMainloop::ElementScale, + DummyElementScaleMainloop + >; + using ElementScale = std::conditional_t< + has_scale, + typename UnderlyingElement::type, + ElementWide + >; + using StrideScale = typename CollectiveMainloop::StrideScale; + // In ScaleOnly mode, we have allocated the same size of memory for arguments->Z and arguments->S + using ElementZero = std::conditional_t< + has_zero, + typename CollectiveMainloop::ElementZero, + ElementScale + >; + const int SZ_1st_dim = (wider_operand == Sm90MixedInputWiderOperand::A) ? problem_n : problem_m; + const size_t SZ_size = static_cast(SZ_1st_dim * scale_k * options_l); + auto shape_SZ = cute::make_shape(SZ_1st_dim, scale_k, options_l); + ElementScale *ptr_S = static_cast(arguments->Scale); + ElementZero *ptr_Z = static_cast(arguments->Zero); + + // 1. If arguments is initialized in profiler, S and Z needs to be allocated and filled + if (arguments->generate_scale_and_zero) { + float scale_min = 1.0f, scale_max = 1.0f; + if constexpr(has_scale) { + const float elt_max_f = float(cutlass::platform::numeric_limits::max()); + // Need to fix max_dequant_val and min_dequant_val? 
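// With the constants below, scale_max works out to 0.25 and scale_min to 0.5 / elt_max_f,
// so randomly drawn scales keep dequantized magnitudes between roughly min_dequant_val
// and a quarter of the element type's maximum value.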
+ const float max_dequant_val = elt_max_f * 0.25f; + const float min_dequant_val = 0.5f; + scale_max = max_dequant_val / elt_max_f; + scale_min = min_dequant_val / elt_max_f; + } + uint64_t seed = 2023; + cutlass::reference::device::BlockFillRandomUniform( + ptr_S, SZ_size, seed, ElementScale(scale_max), ElementScale(scale_min)); + + // In ScaleOnly mode, set Z as zero for generating dequantized A or B + const float zero_max = has_zero ? 2.0f : 0.0f; + const float zero_min = has_zero ? -2.0f : 0.0f; + cutlass::reference::device::BlockFillRandomUniform( + ptr_Z, SZ_size, seed, ElementZero(zero_max), ElementZero(zero_min)); + } // End of "if (arguments->generate_scale_and_zero)" + + // 2. Generate the dequantized A or B for verification + if (arguments->generate_dequantized_AB) { + StrideScale stride_SZ = cutlass::make_cute_packed_stride(StrideScale{}, shape_SZ); + auto layout_SZ = cute::make_layout(shape_SZ, stride_SZ); + if constexpr(wider_operand == Sm90MixedInputWiderOperand::A) { + if constexpr(is_StrideB_Layout) { + // The generator only generates row-major A and col-major B at the moment + // Need a way to read out the actual layout of B later + using ActualLayoutB = cutlass::layout::ColumnMajor; + using ActualStrideB = cutlass::detail::TagToStrideB_t; + dequantize_encode_( + operator_args, arguments, stream, problem_m, problem_k, options_l, options_g, ptr_S, ptr_Z, SZ_size, layout_SZ); + } + else { + using ActualStrideB = typename CollectiveMainloop::StrideB; + dequantize_encode_( + operator_args, arguments, stream, problem_m, problem_k, options_l, options_g, ptr_S, ptr_Z, SZ_size, layout_SZ); + } + } + else { + if constexpr(is_StrideA_Layout) { + // The generator only generates row-major A and col-major B at the moment + // Need a way to read out the actual layout of A later + using ActualLayoutA = cutlass::layout::RowMajor; + using ActualStrideA = cutlass::detail::TagToStrideA_t; + dequantize_encode_( + operator_args, arguments, stream, problem_m, problem_k, options_l, options_g, ptr_S, ptr_Z, SZ_size, layout_SZ); + } + else { + using ActualStrideA = typename CollectiveMainloop::StrideA; + dequantize_encode_( + operator_args, arguments, stream, problem_m, problem_k, options_l, options_g, ptr_S, ptr_Z, SZ_size, layout_SZ); + } + } // End of "if constexpr(wider_operand == Sm90MixedInputWiderOperand::A)" + } // End of "if (arguments->generate_dequantized_AB)" + + // 3. 
Put Scale and Zero in mainloop + if constexpr(has_scale) { + if constexpr(is_int4_x_fp8) { + operator_args.mainloop.ptr_S = static_cast(arguments->packed_Scale); + } + else { + operator_args.mainloop.ptr_S = static_cast(arguments->Scale); + } + operator_args.mainloop.dS = cutlass::make_cute_packed_stride(StrideScale{}, shape_SZ); + operator_args.mainloop.group_size = options_g; + if constexpr(has_zero) { + operator_args.mainloop.ptr_Z = static_cast(arguments->Zero); + } + } // End of "if constexpr(has_scale)" + + // Handle the shuffling + using ValueShuffle = std::conditional_t< + cutlass::sizeof_bits::value == 4, + cute::Layout, cute::Stride>, + cute::Layout, cute::Stride> + >; + constexpr int NumShuffleAtoms = 1; + using MmaAtomShape = cute::Layout>>; + using LayoutAtomQuant = decltype(compute_memory_reordering_atom()); + // The generator only generates row-major A and col-major B at the moment + // Need a way to read out the actual layout and stride of A/B later + if constexpr(wider_operand == Sm90MixedInputWiderOperand::A && is_StrideB_Layout) { + using ActualLayoutB = cutlass::layout::ColumnMajor; + using ActualStrideB = cutlass::detail::TagToStrideB_t; + using LayoutB_Reordered = typename CollectiveMainloop::StrideB; + handle_shuffle_tensor_( + operator_args, arguments, problem_n, problem_k, options_l); + } + if constexpr(wider_operand == Sm90MixedInputWiderOperand::B && is_StrideA_Layout) { + using ActualLayoutA = cutlass::layout::RowMajor; + using ActualStrideA = cutlass::detail::TagToStrideA_t; + using LayoutA_Reordered = typename CollectiveMainloop::StrideA; + handle_shuffle_tensor_( + operator_args, arguments, problem_m, problem_k, options_l); + } + } // End of "if constexpr(is_sm90_mixed_dtype_mainloop_(MainloopPolicy{}))" + + /* Query device SM count and max active clusters to pass onto the kernel as an argument, where needed */ + operator_args.hw_info.sm_count = arguments->sm_count; + if constexpr (Operator::ArchTag::kMinComputeCapability == 90) { + operator_args.hw_info.max_active_clusters = max_active_clusters; + } + if constexpr (!std::is_const_v) { + operator_args.scheduler.max_swizzle_size = arguments->swizzle_size; + } + + if constexpr (!std::is_const_v) { + using Enum_t = decltype(operator_args.scheduler.raster_order); + switch (arguments->raster_order) { + case RasterOrder::kAlongN: + operator_args.scheduler.raster_order = Enum_t::AlongN; + break; + case RasterOrder::kAlongM: + operator_args.scheduler.raster_order = Enum_t::AlongM; + break; + default: + operator_args.scheduler.raster_order = Enum_t::Heuristic; + } + } + + if constexpr (std::is_same_v) { + operator_args.scheduler.splits = arguments->split_k_slices; + } + + if constexpr (Operator::ArchTag::kMinComputeCapability >= 100) { + operator_args.hw_info.cluster_shape = dim3( + arguments->cluster_shape.m(), + arguments->cluster_shape.n(), + arguments->cluster_shape.k()); + operator_args.hw_info.cluster_shape_fallback = dim3( + arguments->cluster_shape_fallback.m(), + arguments->cluster_shape_fallback.n(), + arguments->cluster_shape_fallback.k()); + } + return status; + } + +public: + + /// Returns success if the operation can proceed + Status can_implement( + [[maybe_unused]] void const *configuration_ptr, void const *arguments_ptr) const override { + GemmUniversalArguments const *arguments = + static_cast(arguments_ptr); + OperatorArguments args; + + auto status = update_arguments_(args, arguments); + if (status != Status::kSuccess) { + return status; + } + + Status can_impl = Operator::can_implement(args); + + 
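// At this point args has been fully translated from the type-erased GemmUniversalArguments,
// so can_impl reflects both successful argument translation and the kernel's own
// implementability constraints.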
//return Operator::can_implement(args); + return can_impl; + } + + /// Gets the host-side workspace + uint64_t get_host_workspace_size(void const *configuration) const override { + return sizeof(Operator); + } + + /// Gets the device-side workspace + uint64_t get_device_workspace_size( + void const *configuration_ptr,void const *arguments_ptr) const override { + + OperatorArguments args; + auto status = update_arguments_( + args, static_cast(arguments_ptr)); + if (status != Status::kSuccess) { + return 0; + } + + uint64_t size = Operator::get_workspace_size(args); + return size; + } + + /// Initializes the workspace + Status initialize( + void const *configuration_ptr, + void *host_workspace, + void *device_workspace, + cudaStream_t stream = nullptr) const override { + Operator *op = new (host_workspace) Operator; + return Status::kSuccess; + } + + /// Runs the kernel + Status run( + void const *arguments_ptr, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const override { + + OperatorArguments args; + Status status = update_arguments_(args, static_cast(arguments_ptr), stream); + if (status != Status::kSuccess) { + return status; + } + + Operator *op = static_cast(host_workspace); + // We need to call initialize() since we have to rebuild TMA desc for every new set of args + status = op->run(args, device_workspace, stream, nullptr, + static_cast(arguments_ptr)->use_pdl); + return status; + } +}; +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::library + +/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/library/src/grouped_gemm_operation_3x.hpp b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/library/src/grouped_gemm_operation_3x.hpp new file mode 100644 index 0000000000000000000000000000000000000000..91f618d4fab74a6d43e2d82c572d215d5bea5a1c --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/library/src/grouped_gemm_operation_3x.hpp @@ -0,0 +1,873 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines operations for all grouped GEMM operations in CUTLASS Library. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/detail/collective.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/library/library.h" +#include "cutlass/library/util.h" +#include "gemm_operation_3x.hpp" +#include "library_internal.h" + +namespace cutlass::library { + +template +class GroupedGemmOperation3xBase : public GemmOperation3xBase { +public: + using Operator = Operator_; + using OperatorArguments = typename Operator::Arguments; + using ElementA = typename Operator::ElementA; + using LayoutA = typename Operator::LayoutA; + using ElementB = typename Operator::ElementB; + using LayoutB = typename Operator::LayoutB; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + using ElementD = typename Operator::ElementD; + using LayoutD = typename Operator::LayoutD; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + + using CollectiveMainloop = typename Operator::CollectiveMainloop; + using CollectiveEpilogue = typename Operator::CollectiveEpilogue; + using ThreadEpilogueOp = typename CollectiveEpilogue::ThreadEpilogueOp; + + static constexpr bool IsRuntimeDataTypeA = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); + static constexpr bool IsRuntimeDataTypeB = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4(); + static_assert((IsRuntimeDataTypeA && IsRuntimeDataTypeB) || + (!IsRuntimeDataTypeA && !IsRuntimeDataTypeB), + "ElementA and ElementB in a GEMM kernel should be both runtime or both static."); + static constexpr bool IsRuntimeDataType = IsRuntimeDataTypeA && IsRuntimeDataTypeB; + + GroupedGemmOperation3xBase(char const* name = "unknown_gemm") + : GemmOperation3xBase(name, GemmKind::kGrouped) { + this->description_.kind = OperationKind::kGroupedGemm; + this->description_.name = name; + this->description_.provider = Provider::kCUTLASS; + + this->description_.gemm = GemmOperation3xBase::description_; + this->description_.tile_description = this->description_.gemm.tile_description; + }; + +public: + mutable CudaBuffer strideA_device; + mutable CudaBuffer strideB_device; + mutable CudaBuffer strideC_device; + mutable CudaBuffer strideD_device; + + /// Returns the description of the GEMM operation + virtual OperationDescription const& description() const override final { return description_; } + /// Gets the host-side workspace + uint64_t get_host_workspace_size(void const* configuration) const override final { + return sizeof(Operator); + } + +protected: + library::GroupedGemmDescription description_; + + Status initialize_strides(GemmGroupedConfiguration const& config) const { + auto const num_groups = config.problem_count; + this->strideA_device = + 
CudaBuffer(sizeof(typename Operator::GemmKernel::InternalStrideA) * num_groups); + this->strideB_device = + CudaBuffer(sizeof(typename Operator::GemmKernel::InternalStrideB) * num_groups); + this->strideC_device = + CudaBuffer(sizeof(typename Operator::GemmKernel::InternalStrideC) * num_groups); + this->strideD_device = + CudaBuffer(sizeof(typename Operator::GemmKernel::InternalStrideD) * num_groups); + + std::vector strideA_host(num_groups); + std::vector strideB_host(num_groups); + std::vector strideC_host(num_groups); + std::vector strideD_host(num_groups); + for (int group_idx = 0; group_idx < num_groups; group_idx++) { + strideA_host[group_idx] = + cute::make_int_tuple_from( + config.lda[group_idx]); + strideB_host[group_idx] = + cute::make_int_tuple_from( + config.ldb[group_idx]); + strideC_host[group_idx] = + cute::make_int_tuple_from( + config.ldc[group_idx]); + strideD_host[group_idx] = + cute::make_int_tuple_from( + config.ldc[group_idx]); + } + CUDA_CHECK(cudaMemcpy( + this->strideA_device.data(), + strideA_host.data(), + sizeof(typename Operator::GemmKernel::InternalStrideA) * num_groups, + cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaMemcpy( + this->strideB_device.data(), + strideB_host.data(), + sizeof(typename Operator::GemmKernel::InternalStrideB) * num_groups, + cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaMemcpy( + this->strideC_device.data(), + strideC_host.data(), + sizeof(typename Operator::GemmKernel::InternalStrideC) * num_groups, + cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaMemcpy( + this->strideD_device.data(), + strideD_host.data(), + sizeof(typename Operator::GemmKernel::InternalStrideD) * num_groups, + cudaMemcpyHostToDevice)); + return Status::kSuccess; + } + + /// Constructs the arguments structure given the configuration and arguments + Status update_arguments_base( + OperatorArguments& operator_args, + GemmGroupedArguments const& arguments) const { + operator_args.mode = cutlass::gemm::GemmUniversalMode::kGrouped; + operator_args.problem_shape = { + arguments.problem_count, + arguments.problem_sizes_3x, + arguments.pointer_mode == ScalarPointerMode::kHost ? arguments.problem_sizes_3x_host + : nullptr}; + + if constexpr (IsRuntimeDataType) { + using ArrayElementA = typename Operator::GemmKernel::CollectiveMainloop::ArrayElementA; + using ArrayElementB = typename Operator::GemmKernel::CollectiveMainloop::ArrayElementB; + operator_args.mainloop.ptr_A = static_cast(arguments.ptr_A); + operator_args.mainloop.ptr_B = static_cast(arguments.ptr_B); + + using RuntimeDataTypeA = typename Operator::GemmKernel::CollectiveMainloop::RuntimeDataTypeA; + using RuntimeDataTypeB = typename Operator::GemmKernel::CollectiveMainloop::RuntimeDataTypeB; + + static_assert(cute::is_same_v, + "RuntimeDataTypeA/B should be identical, either MXF8F6F4Format or MXF4Format"); + using RuntimeDatatypeArg = RuntimeDataTypeA; + + auto mapping = [](RuntimeDatatype type) { + if constexpr (cute::is_same_v) { + if (type == RuntimeDatatype::kE5M2) { + return cute::UMMA::MXF8F6F4Format::E5M2; + } + else if (type == RuntimeDatatype::kE4M3) { + return cute::UMMA::MXF8F6F4Format::E4M3; + } + else if (type == RuntimeDatatype::kE3M2) { + return cute::UMMA::MXF8F6F4Format::E3M2; + } + else if (type == RuntimeDatatype::kE2M3) { + return cute::UMMA::MXF8F6F4Format::E2M3; + } + else if (type == RuntimeDatatype::kE2M1) { + return cute::UMMA::MXF8F6F4Format::E2M1; + } + else { + #if defined(CUTLASS_DEBUG_TRACE_LEVEL) && CUTLASS_DEBUG_TRACE_LEVEL >= 1 + std::cerr << "Invalid input datatype specified. 
Running with e4m3." << std::endl; + #endif + return cute::UMMA::MXF8F6F4Format::E4M3; + } + } + else if constexpr (cute::is_same_v) { + if (type == RuntimeDatatype::kE2M1) { + return cute::UMMA::MXF4Format::E2M1; + } + else { + #if defined(CUTLASS_DEBUG_TRACE_LEVEL) && CUTLASS_DEBUG_TRACE_LEVEL >= 1 + std::cerr << "Invalid input datatype specified. Running with e2m1." << std::endl; + #endif + return cute::UMMA::MXF4Format::E2M1; + } + } + // BlockScaled kernels receive either MXF4Format or MXF8F6F4Format runtime datatype + CUTE_GCC_UNREACHABLE; + }; + operator_args.mainloop.runtime_data_type_a = mapping(arguments.runtime_input_datatype_a); + operator_args.mainloop.runtime_data_type_b = mapping(arguments.runtime_input_datatype_b); + } + else { + operator_args.mainloop.ptr_A = static_cast(arguments.ptr_A); + operator_args.mainloop.ptr_B = static_cast(arguments.ptr_B); + } + operator_args.epilogue.ptr_C = static_cast(arguments.ptr_C); + operator_args.epilogue.ptr_D = static_cast(arguments.ptr_D); + + operator_args.mainloop.dA = + static_cast(this->strideA_device.data()); + operator_args.mainloop.dB = + static_cast(this->strideB_device.data()); + operator_args.epilogue.dC = + static_cast(this->strideC_device.data()); + operator_args.epilogue.dD = + static_cast(this->strideD_device.data()); + + /* Query device SM count and max active clusters to pass onto the kernel as an argument, where needed */ + operator_args.hw_info.sm_count = arguments.sm_count; + if constexpr (Operator::ArchTag::kMinComputeCapability >= 90) { + operator_args.hw_info.max_active_clusters = arguments.max_active_clusters; + } + if constexpr (!std::is_const_v) { + operator_args.scheduler.max_swizzle_size = arguments.swizzle_size; + } + + if constexpr (!std::is_const_v) { + using Enum_t = decltype(operator_args.scheduler.raster_order); + switch (arguments.raster_order) { + case RasterOrder::kAlongN: + operator_args.scheduler.raster_order = Enum_t::AlongN; + break; + case RasterOrder::kAlongM: + operator_args.scheduler.raster_order = Enum_t::AlongM; + break; + default: + operator_args.scheduler.raster_order = Enum_t::Heuristic; + } + } + + if constexpr (Operator::ArchTag::kMinComputeCapability >= 100) { + operator_args.hw_info.cluster_shape = + dim3(arguments.cluster_shape.m(), arguments.cluster_shape.n(), arguments.cluster_shape.k()); + operator_args.hw_info.cluster_shape_fallback = dim3( + arguments.cluster_shape_fallback.m(), + arguments.cluster_shape_fallback.n(), + arguments.cluster_shape_fallback.k()); + } + return Status::kSuccess; + } + + template + static Status update_fusion_args(FusionArgs& fusion_args, GemmGroupedArguments const& arguments) { + if (arguments.pointer_mode == ScalarPointerMode::kHost) { + fusion_args.alpha = *static_cast(arguments.alpha); + fusion_args.beta = *static_cast(arguments.beta); + fusion_args.alpha_ptr = nullptr; + fusion_args.beta_ptr = nullptr; + fusion_args.alpha_ptr_array = nullptr; + fusion_args.beta_ptr_array = nullptr; + + return Status::kSuccess; + } + else if (arguments.pointer_mode == ScalarPointerMode::kDevice) { + fusion_args.alpha = 0; + fusion_args.beta = 0; + fusion_args.alpha_ptr = static_cast(arguments.alpha); + fusion_args.beta_ptr = static_cast(arguments.beta); + fusion_args.alpha_ptr_array = nullptr; + fusion_args.beta_ptr_array = nullptr; + return Status::kSuccess; + } + else { + return Status::kErrorInvalidProblem; + } + } +}; + +/// **** CAUTION **** +/// Unlike other operations, initialize() must be called when +/// certain arguments change. 
See initialize() for details. +template +class GroupedGemmUniversal3xOperation : public GroupedGemmOperation3xBase { +public: + using Operator = Operator_; + using OperatorArguments = typename Operator::Arguments; + +public: + GroupedGemmUniversal3xOperation(char const* name = "unknown_gemm") + : GroupedGemmOperation3xBase(name) {} + + ~GroupedGemmUniversal3xOperation() override = default; + +private: + int max_active_clusters{}; + +protected: + template struct UpdateFusionArgs { + static Status update_(FusionArgs const& fusion_args, GemmGroupedArguments const& arguments) { + // If a custom EVT is instantiated then it is the users's responsibility + // to ensure alpha and beta are updated appropriately + return Status::kSuccess; + } + }; + + template + struct UpdateFusionArgs> { + static Status update_(FusionArgs& fusion_args, GemmGroupedArguments const& arguments) { + return GroupedGemmOperation3xBase::update_fusion_args(fusion_args, arguments); + } + }; + + /// Constructs the arguments structure given the configuration and arguments + Status + update_arguments_(OperatorArguments& operator_args, GemmGroupedArguments const* arguments) const { + + Status status = UpdateFusionArgs::update_( + operator_args.epilogue.thread, + *arguments); + if (status != Status::kSuccess) { + return status; + } + + status = this->update_arguments_base(operator_args, *arguments); + return status; + } + +public: + /// Returns success if the operation can proceed + Status can_implement([[maybe_unused]] void const* configuration_ptr, void const* arguments_ptr) + const override { + GemmGroupedArguments const* arguments = static_cast(arguments_ptr); + OperatorArguments args; + auto status = update_arguments_(args, arguments); + if (status != Status::kSuccess) { + return status; + } + + status = Operator::can_implement(args); + return status; + } + + /// Gets the device-side workspace + uint64_t get_device_workspace_size(void const* configuration_ptr, void const* arguments_ptr) + const override { + + OperatorArguments args; + auto status = update_arguments_(args, static_cast(arguments_ptr)); + if (status != Status::kSuccess) { + return 0; + } + + uint64_t size = Operator::get_workspace_size(args); + return size; + } + + /// Initializes the workspace + /// **** CAUTION **** + /// Must be called when lda, ldb, ldc, or ldd change. + /// The CUTLASS library stores the operations in a type- + /// erased manifest. Therefore, only this class knows + /// the type of strideA, strideB, strideC, and strideD. + /// Since grouped GEMM needs to allocate storage for + /// the strides on device, the concrete type of the stride + /// must be known in order to copy in the correct memory + /// layout on device. + Status initialize( + void const* configuration_ptr, + void* host_workspace, + void* device_workspace, + cudaStream_t stream = nullptr) const override { + + Operator* op = new (host_workspace) Operator; + + auto const& config = *static_cast(configuration_ptr); + return this->initialize_strides(config); + } + + /// **** CAUTION **** + /// initialize() must be called if lda, ldb, ldc, or ldd change. 
+ Status run( + void const* arguments_ptr, + void* host_workspace, + void* device_workspace = nullptr, + cudaStream_t stream = nullptr) const override { + + OperatorArguments operator_args; + auto const& args = *static_cast(arguments_ptr); + + Status status = update_arguments_(operator_args, &args); + if (status != Status::kSuccess) { + return status; + } + + Operator* op = static_cast(host_workspace); + // We need to call initialize() since we have to rebuild TMA desc for every new set of args + status = op->run(operator_args, device_workspace, stream, nullptr, args.use_pdl); + return status; + } + + // Set arguments that should only be set once before verifying or profiling the kernel. + // This should encompass any expensive operations that don't vary from run to run + // (e.g., max_active_clusters). + Status initialize_with_arguments(void* arguments_ptr) const override { + if constexpr (Operator::ArchTag::kMinComputeCapability < 90) { + return Status::kSuccess; + } + + GemmGroupedArguments* args = static_cast(arguments_ptr); + + dim3 cluster_dims; + if constexpr (cute::is_static_v) { + cluster_dims = dim3( + cute::size<0>(typename Operator::GemmKernel::ClusterShape{}), + cute::size<1>(typename Operator::GemmKernel::ClusterShape{}), + cute::size<2>(typename Operator::GemmKernel::ClusterShape{}) + ); + } + else { + cluster_dims = dim3( + args->cluster_shape.m(), + args->cluster_shape.n(), + args->cluster_shape.k() + ); + } + + uint32_t threads_per_block = Operator::GemmKernel::MaxThreadsPerBlock; + void const* kernel_ptr = (void*)(device_kernel); + args->max_active_clusters = cutlass::KernelHardwareInfo::query_device_max_active_clusters( + cluster_dims, + threads_per_block, + kernel_ptr); + + if (args->max_active_clusters == 0) { + std::cerr << "Max Active Clusters could not be queried. " + << "Falling back to heuristics mode (static cluster shape) or preferred cluster mode.\n"; + } + + return Status::kSuccess; + } +}; + +template +class GroupedBlockScaledGemmUniversal3xOperation : public GroupedGemmOperation3xBase { +public: + using Operator = Operator_; + using OperatorArguments = typename Operator::Arguments; + using ElementD = typename Operator::ElementD; + using LayoutD = typename Operator::LayoutD; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + + using CollectiveMainloop = typename Operator::CollectiveMainloop; + using CollectiveEpilogue = typename Operator::CollectiveEpilogue; + using ThreadEpilogueOp = typename CollectiveEpilogue::ThreadEpilogueOp; + + using ElementSFA = typename Operator::CollectiveMainloop::ElementSF; + using ElementSFB = typename Operator::CollectiveMainloop::ElementSF; + + using TiledMma = typename Operator::CollectiveMainloop::TiledMma; + constexpr static int SFVecSize = TiledMma::SFVecSize; + + + static constexpr bool epilogue_scalefactor_generation = not cute::is_same_v; + static constexpr int32_t SFD_VectorSize = epilogue_scalefactor_generation ? 
ThreadEpilogueOp::SFVecSize : SFVecSize; + using ElementSFD = cute::conditional_t; + using LayoutSFD = cute::conditional_t; + + GroupedBlockScaledGemmUniversal3xOperation(char const* name = "unknown_gemm") + : GroupedGemmOperation3xBase(name) { + + BlockScaleDescription block_scaled_desc{}; + block_scaled_desc.kind = OperationKind::kBlockScaledGemm; + block_scaled_desc.SFA.element = NumericTypeMap::kId; + block_scaled_desc.SFA.layout = LayoutTypeID::kRowMajor; + block_scaled_desc.SFA.alignment = 128; + block_scaled_desc.SFA.log_extent_range = 32; + block_scaled_desc.SFA.log_stride_range = 32; + + block_scaled_desc.SFB.element = NumericTypeMap::kId; + block_scaled_desc.SFB.layout = LayoutTypeID::kRowMajor; + block_scaled_desc.SFB.alignment = 128; + block_scaled_desc.SFB.log_extent_range = 32; + block_scaled_desc.SFB.log_stride_range = 32; + + block_scaled_desc.SFMVecSize = 1; + block_scaled_desc.SFNVecSize = 1; + block_scaled_desc.SFKVecSize = SFVecSize; + + block_scaled_desc.SFD = make_TensorDescription(128); + block_scaled_desc.EpilogueSFVecSize = SFD_VectorSize; + + this->description_.block_scales = block_scaled_desc; + } + + ~GroupedBlockScaledGemmUniversal3xOperation() override = default; + + mutable CudaBuffer layout_SFA_device; + mutable CudaBuffer layout_SFB_device; + +protected: + template struct UpdateFusionArgs { + static Status update_(FusionArgs const& fusion_args, GemmGroupedArguments const& arguments) { + // If a custom EVT is instantiated then it is the users's responsibility + // to ensure alpha and beta are updated appropriately + return Status::kSuccess; + } + }; + + template + struct UpdateFusionArgs> { + static Status + update_(FusionArgs& fusion_args, GroupedGemmBlockScaledArguments const& arguments) { + + if constexpr (epilogue_scalefactor_generation) { + fusion_args.block_scale_factor_ptr = static_cast(arguments.SFD); + fusion_args.norm_constant_ptr = static_cast(arguments.norm_constant); + } + + return GroupedGemmOperation3xBase::update_fusion_args(fusion_args, arguments); + } + }; + +public: + /// Returns success if the operation can proceed + Status can_implement([[maybe_unused]] void const* configuration_ptr, void const* arguments_ptr) + const override { + GroupedGemmBlockScaledArguments const* arguments = + static_cast(arguments_ptr); + OperatorArguments args; + auto status = update_arguments_(args, arguments); + if (status != Status::kSuccess) { + return status; + } + + status = Operator::can_implement(args); + return status; + } + + Status update_arguments_( + OperatorArguments& operator_args, + GroupedGemmBlockScaledArguments const* arguments) const { + Status status = UpdateFusionArgs::update_( + operator_args.epilogue.thread, + *arguments); + if (status != Status::kSuccess) { + return status; + } + + operator_args.mainloop.ptr_SFA = + static_cast(arguments->SFA); + operator_args.mainloop.ptr_SFB = + static_cast(arguments->SFB); + + operator_args.mainloop.layout_SFA = + static_cast(this->layout_SFA_device.data()); + operator_args.mainloop.layout_SFB = + static_cast(this->layout_SFB_device.data()); + + return this->update_arguments_base(operator_args, *arguments); + } + + uint64_t get_device_workspace_size(void const* configuration_ptr, void const* arguments_ptr) + const override { + + OperatorArguments args; + auto status = + update_arguments_(args, static_cast(arguments_ptr)); + if (status != Status::kSuccess) { + return 0; + } + + uint64_t size = Operator::get_workspace_size(args); + return size; + } + + /// Initializes the workspace + /// **** CAUTION 
**** + /// Must be called when lda, ldb, ldc, or ldd change. + /// The CUTLASS library stores the operations in a type- + /// erased manifest. Therefore, only this class knows + /// the type of strideA, strideB, strideC, and strideD. + /// Since grouped GEMM needs to allocate storage for + /// the strides on device, the concrete type of the stride + /// must be known in order to copy in the correct memory + /// layout on device. + Status initialize( + void const* configuration_ptr, + void* host_workspace, + void* device_workspace, + cudaStream_t stream = nullptr) const override { + + auto const& config = *static_cast(configuration_ptr); + auto status = this->initialize_strides(config); + if (status != Status::kSuccess) { + return status; + } + + auto num_groups = config.problem_count; + this->layout_SFA_device = + CudaBuffer(sizeof(typename CollectiveMainloop::InternalLayoutSFA) * num_groups); + this->layout_SFB_device = + CudaBuffer(sizeof(typename CollectiveMainloop::InternalLayoutSFB) * num_groups); + auto layout_SFA_host = std::vector(num_groups); + auto layout_SFB_host = std::vector(num_groups); + + for (int group_idx = 0; group_idx < num_groups; group_idx++) { + auto const& shape = config.problem_sizes_3x_host[group_idx]; + auto M = get<0>(shape); + auto N = get<1>(shape); + auto K = get<2>(shape); + + auto layout_SFA = CollectiveMainloop::Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA(cute::make_shape(M, N, K, 1)); + auto layout_SFB = CollectiveMainloop::Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB(cute::make_shape(M, N, K, 1)); + layout_SFA_host[group_idx] = layout_SFA; + layout_SFB_host[group_idx] = layout_SFB; + } + + CUDA_CHECK(cudaMemcpy( + this->layout_SFA_device.data(), + layout_SFA_host.data(), + sizeof(typename CollectiveMainloop::InternalLayoutSFA) * num_groups, + cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaMemcpy( + this->layout_SFB_device.data(), + layout_SFB_host.data(), + sizeof(typename CollectiveMainloop::InternalLayoutSFB) * num_groups, + cudaMemcpyHostToDevice)); + + Operator* op = new (host_workspace) Operator; + return status; + } + + /// **** CAUTION **** + /// initialize() must be called if lda, ldb, ldc, or ldd change. 
+ Status run( + void const* arguments_ptr, + void* host_workspace, + void* device_workspace = nullptr, + cudaStream_t stream = nullptr) const override { + + OperatorArguments operator_args; + auto const& args = *static_cast(arguments_ptr); + + Status status = update_arguments_(operator_args, &args); + if (status != Status::kSuccess) { + return status; + } + + Operator* op = static_cast(host_workspace); + status = op->run(operator_args, device_workspace, stream, nullptr); + return status; + } +}; + +template +class GroupedBlockwiseGemmUniversal3xOperation : public GroupedGemmOperation3xBase { +public: + using Operator = Operator_; + using OperatorArguments = typename Operator::Arguments; + using ElementD = typename Operator::ElementD; + using LayoutD = typename Operator::LayoutD; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + + using CollectiveMainloop = typename Operator::CollectiveMainloop; + using CollectiveEpilogue = typename Operator::CollectiveEpilogue; + using ThreadEpilogueOp = typename CollectiveEpilogue::ThreadEpilogueOp; + + using ElementSFA = typename Operator::ElementAccumulator; + using ElementSFB = typename Operator::ElementAccumulator; + + using TiledMma = typename Operator::CollectiveMainloop::TiledMma; + + GroupedBlockwiseGemmUniversal3xOperation(char const* name = "unknown_gemm") + : GroupedGemmOperation3xBase(name) { + + BlockScaleDescription blockwise_desc{}; + blockwise_desc.kind = OperationKind::kBlockwiseGemm; + blockwise_desc.SFA.element = NumericTypeMap::kId; + blockwise_desc.SFA.layout = size<0,1>(typename CollectiveMainloop::InternalLayoutSFA{}.stride()) == 1 ? + LayoutTypeID::kColumnMajor : LayoutTypeID::kRowMajor; + blockwise_desc.SFA.alignment = CollectiveMainloop::AlignmentSFA; + blockwise_desc.SFA.log_extent_range = 32; + blockwise_desc.SFA.log_stride_range = 32; + + blockwise_desc.SFB.element = NumericTypeMap::kId; + blockwise_desc.SFB.layout = size<0,1>(typename CollectiveMainloop::InternalLayoutSFB{}.stride()) == 1 ? 
+ LayoutTypeID::kRowMajor : LayoutTypeID::kColumnMajor; + blockwise_desc.SFB.alignment = CollectiveMainloop::AlignmentSFA; + blockwise_desc.SFB.log_extent_range = 32; + blockwise_desc.SFB.log_stride_range = 32; + + blockwise_desc.SFMVecSize = Operator::CollectiveMainloop::ScaleGranularityM; + blockwise_desc.SFNVecSize = Operator::CollectiveMainloop::ScaleGranularityN; + blockwise_desc.SFKVecSize = Operator::CollectiveMainloop::ScaleGranularityK; + + blockwise_desc.EpilogueSFVecSize = 0; + + this->description_.block_scales = blockwise_desc; + } + + ~GroupedBlockwiseGemmUniversal3xOperation() override = default; + + mutable CudaBuffer layout_SFA_device; + mutable CudaBuffer layout_SFB_device; + +protected: + template struct UpdateFusionArgs { + static Status update_(FusionArgs const& fusion_args, GemmGroupedArguments const& arguments) { + // If a custom EVT is instantiated then it is the users's responsibility + // to ensure alpha and beta are updated appropriately + return Status::kSuccess; + } + }; + + template + struct UpdateFusionArgs> { + static Status + update_(FusionArgs& fusion_args, GroupedGemmBlockwiseArguments const& arguments) { + return GroupedGemmOperation3xBase::update_fusion_args(fusion_args, arguments); + } + }; + +public: + /// Returns success if the operation can proceed + Status can_implement([[maybe_unused]] void const* configuration_ptr, void const* arguments_ptr) + const override { + GroupedGemmBlockwiseArguments const* arguments = + static_cast(arguments_ptr); + OperatorArguments args; + auto status = update_arguments_(args, arguments); + if (status != Status::kSuccess) { + return status; + } + + status = Operator::can_implement(args); + return status; + } + + Status update_arguments_( + OperatorArguments& operator_args, + GroupedGemmBlockwiseArguments const* arguments) const { + Status status = UpdateFusionArgs::update_( + operator_args.epilogue.thread, + *arguments); + if (status != Status::kSuccess) { + return status; + } + + operator_args.mainloop.ptr_SFA = + static_cast(arguments->SFA); + operator_args.mainloop.ptr_SFB = + static_cast(arguments->SFB); + + operator_args.mainloop.layout_SFA = + static_cast(this->layout_SFA_device.data()); + operator_args.mainloop.layout_SFB = + static_cast(this->layout_SFB_device.data()); + + return this->update_arguments_base(operator_args, *arguments); + } + + uint64_t get_device_workspace_size(void const* configuration_ptr, void const* arguments_ptr) + const override { + + OperatorArguments args; + auto status = + update_arguments_(args, static_cast(arguments_ptr)); + if (status != Status::kSuccess) { + return 0; + } + + uint64_t size = Operator::get_workspace_size(args); + return size; + } + + /// Initializes the workspace + /// **** CAUTION **** + /// Must be called when lda, ldb, ldc, or ldd change. + /// The CUTLASS library stores the operations in a type- + /// erased manifest. Therefore, only this class knows + /// the type of strideA, strideB, strideC, and strideD. + /// Since grouped GEMM needs to allocate storage for + /// the strides on device, the concrete type of the stride + /// must be known in order to copy in the correct memory + /// layout on device. 
+ Status initialize( + void const* configuration_ptr, + void* host_workspace, + void* device_workspace, + cudaStream_t stream = nullptr) const override { + + auto const& config = *static_cast(configuration_ptr); + auto status = this->initialize_strides(config); + if (status != Status::kSuccess) { + return status; + } + + auto num_groups = config.problem_count; + this->layout_SFA_device = + CudaBuffer(sizeof(typename CollectiveMainloop::InternalLayoutSFA) * num_groups); + this->layout_SFB_device = + CudaBuffer(sizeof(typename CollectiveMainloop::InternalLayoutSFB) * num_groups); + auto layout_SFA_host = std::vector(num_groups); + auto layout_SFB_host = std::vector(num_groups); + + for (int group_idx = 0; group_idx < num_groups; group_idx++) { + auto const& shape = config.problem_sizes_3x_host[group_idx]; + auto M = get<0>(shape); + auto N = get<1>(shape); + auto K = get<2>(shape); + + auto layout_SFA = CollectiveMainloop::ScaleConfig::tile_atom_to_shape_SFA(cute::make_shape(M, N, K, 1)); + auto layout_SFB = CollectiveMainloop::ScaleConfig::tile_atom_to_shape_SFB(cute::make_shape(M, N, K, 1)); + layout_SFA_host[group_idx] = layout_SFA; + layout_SFB_host[group_idx] = layout_SFB; + } + + CUDA_CHECK(cudaMemcpy( + this->layout_SFA_device.data(), + layout_SFA_host.data(), + sizeof(typename CollectiveMainloop::InternalLayoutSFA) * num_groups, + cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaMemcpy( + this->layout_SFB_device.data(), + layout_SFB_host.data(), + sizeof(typename CollectiveMainloop::InternalLayoutSFB) * num_groups, + cudaMemcpyHostToDevice)); + + Operator* op = new (host_workspace) Operator; + return status; + } + + /// **** CAUTION **** + /// initialize() must be called if lda, ldb, ldc, or ldd change. + Status run( + void const* arguments_ptr, + void* host_workspace, + void* device_workspace = nullptr, + cudaStream_t stream = nullptr) const override { + + OperatorArguments operator_args; + auto const& args = *static_cast(arguments_ptr); + + Status status = update_arguments_(operator_args, &args); + if (status != Status::kSuccess) { + return status; + } + + Operator* op = static_cast(host_workspace); + status = op->run(operator_args, device_workspace, stream, nullptr); + return status; + } +}; + + +} // namespace cutlass::library diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/library/src/library_internal.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/library/src/library_internal.h new file mode 100644 index 0000000000000000000000000000000000000000..e8bd77397f3b85cce2da2a7a8e447ab6ccb48aea --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/library/src/library_internal.h @@ -0,0 +1,427 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! + \file + + \brief CUTLASS Library is an object-oriented approach to managing operations implemented by CUTLASS. + + Generally, + + description - compile-time constant parameters used to instantiate an operation + + configuration - runtime parameters with computationally expensive initialization + + arguments - runtime parameters that may be passed to an initialized operation with low + computational overhead +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/complex.h" +#include "cutlass/numeric_types.h" +#include "cutlass/arch/arch.h" +#include "cutlass/arch/mma.h" +#include "cutlass/layout/matrix.h" + +#include "cutlass/library/library.h" +#include "cutlass/library/arch_mappings.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template struct NumericTypeMap; + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kVoid; +}; + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kB1; +}; + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kS2; +}; + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kS4; +}; + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kS8; +}; + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kS16; +}; + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kS32; +}; + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kS64; +}; + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kU2; +}; + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kU4; +}; + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kU8; +}; + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kFE4M3; +}; + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kFE5M2; +}; + + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kFE2M3; +}; + +template <> struct NumericTypeMap { + static NumericTypeID 
const kId = NumericTypeID::kFE3M2; +}; + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kFE2M1; +}; +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kFUE8M0; +}; + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kFUE4M3; +}; + + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kU16; +}; + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kU32; +}; + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kU64; +}; + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kF16; +}; + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kF32; +}; + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kF64; +}; + +template <> struct NumericTypeMap > { + static NumericTypeID const kId = NumericTypeID::kCF16; +}; + +template <> struct NumericTypeMap > { + static NumericTypeID const kId = NumericTypeID::kCF32; +}; + +template <> struct NumericTypeMap > { + static NumericTypeID const kId = NumericTypeID::kCF64; +}; + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kBF16; +}; + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kTF32; +}; + + + + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kF8; +}; + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kF6; +}; + +template <> struct NumericTypeMap { + static NumericTypeID const kId = NumericTypeID::kF4; +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template struct MathOperationMap { + static MathOperationID const kId = MathOperationID::kInvalid; +}; + +template <> struct MathOperationMap { + static MathOperationID const kId = MathOperationID::kMultiplyAdd; +}; + +template <> struct MathOperationMap { + static MathOperationID const kId = MathOperationID::kMultiplyAddFastBF16; +}; + +template <> struct MathOperationMap { + static MathOperationID const kId = MathOperationID::kMultiplyAddFastF16; +}; + +template <> struct MathOperationMap { + static MathOperationID const kId = MathOperationID::kMultiplyAddSaturate; +}; + +template <> struct MathOperationMap { + static MathOperationID const kId = MathOperationID::kMultiplyAddMixedInputUpcast; +}; + +template <> struct MathOperationMap { + static MathOperationID const kId = MathOperationID::kMultiplyAddComplex; +}; + +template <> struct MathOperationMap { + static MathOperationID const kId = MathOperationID::kMultiplyAddGaussianComplex; +}; + +template <> struct MathOperationMap { + static MathOperationID const kId = MathOperationID::kXorPopc; +}; + + +template <> struct MathOperationMap { + static MathOperationID const kId = MathOperationID::kMultiplyAddFastF32; +}; + +template <> struct MathOperationMap { + static MathOperationID const kId = MathOperationID::kMultiplyAddComplexFastF32; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template struct LayoutMap; + +template <> struct LayoutMap { + static LayoutTypeID const kId = LayoutTypeID::kColumnMajor; +}; + +template <> struct LayoutMap { + static LayoutTypeID const kId = LayoutTypeID::kRowMajor; +}; + +template <> struct LayoutMap> { + static LayoutTypeID const 
kId = LayoutTypeID::kColumnMajorInterleavedK2; +}; + +template <> struct LayoutMap> { + static LayoutTypeID const kId = LayoutTypeID::kRowMajorInterleavedK2; +}; + +template <> struct LayoutMap> { + static LayoutTypeID const kId = LayoutTypeID::kColumnMajorInterleavedK4; +}; + +template <> struct LayoutMap> { + static LayoutTypeID const kId = LayoutTypeID::kRowMajorInterleavedK4; +}; + +template <> struct LayoutMap> { + static LayoutTypeID const kId = LayoutTypeID::kColumnMajorInterleavedK16; +}; + +template <> struct LayoutMap> { + static LayoutTypeID const kId = LayoutTypeID::kRowMajorInterleavedK16; +}; + +template <> struct LayoutMap> { + static LayoutTypeID const kId = LayoutTypeID::kColumnMajorInterleavedK32; +}; + +template <> struct LayoutMap> { + static LayoutTypeID const kId = LayoutTypeID::kRowMajorInterleavedK32; +}; + +template <> struct LayoutMap> { + static LayoutTypeID const kId = LayoutTypeID::kColumnMajorInterleavedK64; +}; + +template <> struct LayoutMap> { + static LayoutTypeID const kId = LayoutTypeID::kRowMajorInterleavedK64; +}; + +template <> struct LayoutMap { + static LayoutTypeID const kId = LayoutTypeID::kTensorNHWC; +}; + +template <> struct LayoutMap { + static LayoutTypeID const kId = LayoutTypeID::kTensorNDHWC; +}; + +template <> struct LayoutMap> { + static LayoutTypeID const kId = LayoutTypeID::kTensorNC32HW32; +}; + +template <> struct LayoutMap> { + static LayoutTypeID const kId = LayoutTypeID::kTensorNC64HW64; +}; + +template <> struct LayoutMap> { + static LayoutTypeID const kId = LayoutTypeID::kTensorC32RSK32; +}; + +template <> struct LayoutMap> { + static LayoutTypeID const kId = LayoutTypeID::kTensorC64RSK64; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template struct OpcodeClassMap; + +template <> struct OpcodeClassMap { + static OpcodeClassID const kId = OpcodeClassID::kSimt; +}; + +template <> struct OpcodeClassMap { + static OpcodeClassID const kId = OpcodeClassID::kTensorOp; +}; + +template <> struct OpcodeClassMap { + static OpcodeClassID const kId = OpcodeClassID::kSparseTensorOp; +}; + + +template <> struct OpcodeClassMap { + static OpcodeClassID const kId = OpcodeClassID::kBlockScaledOp; +}; + + +template <> struct OpcodeClassMap { + static OpcodeClassID const kId = OpcodeClassID::kWmmaTensorOp; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template struct ComplexTransformMap; + +template <> struct ComplexTransformMap { + static cutlass::library::ComplexTransform const kId = cutlass::library::ComplexTransform::kNone; +}; + +template <> struct ComplexTransformMap { + static cutlass::library::ComplexTransform const kId = cutlass::library::ComplexTransform::kConjugate; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template struct ConvModeMap; + +template <> struct ConvModeMap { + static ConvModeID const kId = ConvModeID::kCrossCorrelation; +}; + +template <> struct ConvModeMap { + static ConvModeID const kId = ConvModeID::kConvolution; +}; + + +template struct ConvKindMap; + +template <> struct ConvKindMap { + static ConvKind const kId = ConvKind::kFprop; +}; + +template <> struct ConvKindMap { + static ConvKind const kId = ConvKind::kDgrad; +}; + +template <> struct ConvKindMap { + static ConvKind const kId = ConvKind::kWgrad; +}; + + +template struct IteratorAlgorithmMap; + +template <> struct IteratorAlgorithmMap { + static IteratorAlgorithmID const kId = 
IteratorAlgorithmID::kAnalytic; +}; + +template <> struct IteratorAlgorithmMap { + static IteratorAlgorithmID const kId = IteratorAlgorithmID::kOptimized; +}; + +template <> struct IteratorAlgorithmMap { + static IteratorAlgorithmID const kId = IteratorAlgorithmID::kFixedChannels; +}; + +template <> struct IteratorAlgorithmMap { + static IteratorAlgorithmID const kId = IteratorAlgorithmID::kFewChannels; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +TensorDescription make_TensorDescription(int alignment = 1) { + TensorDescription desc; + + desc.element = NumericTypeMap::kId; + desc.layout = LayoutMap::kId; + desc.alignment = alignment; + desc.log_extent_range = int(sizeof(typename Layout::TensorCoord::Index) - 1) * 8; + desc.log_stride_range = int(sizeof(typename Layout::Stride::Index) - 1) * 8; + + return desc; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/library/src/rank_2k_operation.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/library/src/rank_2k_operation.h new file mode 100644 index 0000000000000000000000000000000000000000..76d8d0dfdb1aa6ed0324b9d6299b06ebf3f436d9 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/library/src/rank_2k_operation.h @@ -0,0 +1,377 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +/* \file + \brief Defines operations for all Rank 2K operation kinds (Syr2k, Her2k) + in CUTLASS Library. + + +*/ + +#pragma once +#include +#include "cutlass/cutlass.h" + +#include "cutlass/gemm/device/rank_2k.h" +#include "cutlass/gemm/kernel/default_rank_2k_universal.h" + +#include "cutlass/library/library.h" +#include "library_internal.h" +#include "cutlass/core_io.h" +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class Rank2KOperationBase : public Operation { +public: + using Operator = Operator_; + using ElementA = typename Operator::ElementA; + using LayoutA = typename Operator::LayoutA; + using ElementB = typename Operator::ElementB; + using LayoutB = typename Operator::LayoutB; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + static BlasMode const kBlasMode = Operator::kBlasMode; + static int const kUpdateRank = Operator::kUpdateRank; + static FillMode const kFillModeC = Operator::kFillModeC; + + using OperatorArguments = typename Operator::Arguments; + +protected: + + /// + RankKDescription description_; + +public: + + /// Constructor + Rank2KOperationBase(char const *name = "unknown_rank_k") { + + description_.name = name; + description_.provider = Provider::kCUTLASS; + description_.rank_k_kind = RankKKind::kUniversal; + description_.fill_mode = kFillModeC; + description_.blas_mode = kBlasMode; + description_.num_ranks = kUpdateRank; + + description_.kind = OperationKind::kRank2K; + + description_.tile_description.threadblock_shape = make_Coord( + Operator::ThreadblockShape::kM, + Operator::ThreadblockShape::kN, + Operator::ThreadblockShape::kK); + + description_.tile_description.threadblock_stages = Operator::kStages; + + description_.tile_description.warp_count = make_Coord( + Operator::Rank2Kkernel::WarpCount::kM, + Operator::Rank2Kkernel::WarpCount::kN, + Operator::Rank2Kkernel::WarpCount::kK); + + description_.tile_description.math_instruction.instruction_shape = make_Coord( + Operator::InstructionShape::kM, + Operator::InstructionShape::kN, + Operator::InstructionShape::kK); + + description_.tile_description.math_instruction.element_accumulator = + NumericTypeMap::kId; + + description_.tile_description.math_instruction.opcode_class = + OpcodeClassMap::kId; + + description_.tile_description.math_instruction.math_operation = + MathOperationMap::kId; + + description_.tile_description.minimum_compute_capability = + ArchMap::kMin; + + description_.tile_description.maximum_compute_capability = + ArchMap::kMax; + + description_.A = make_TensorDescription(Operator::kAlignmentA); + description_.B = make_TensorDescription(Operator::kAlignmentB); + description_.C = make_TensorDescription(Operator::kAlignmentC); + description_.element_epilogue = NumericTypeMap::kId; + + description_.split_k_mode = SplitKMode::kNone; + description_.transform_A = ComplexTransformMap::kId; + description_.transform_B = ComplexTransformMap::kId; + } + + /// Returns the description of the SYRK operation + virtual OperationDescription const & description() const { + return description_; + } +}; + 
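Note: the Rank2KOperationBase constructor above fills its type-erased description_ entirely from Operator_ typedefs via the trait maps introduced earlier in this diff (NumericTypeMap, OpcodeClassMap, MathOperationMap, make_TensorDescription in library_internal.h). As a rough standalone sketch of that pattern only — all names below are hypothetical stand-ins, not CUTLASS APIs — a compile-time type is collapsed into a runtime enum so a uniform description struct can be queried without knowing the concrete template arguments:

#include <cstdio>

// Hypothetical analogue of library::NumericTypeID.
enum class NumericTypeIDSketch { kF32, kF64, kUnknown };

// Trait map: one specialization per supported element type.
template <typename T> struct NumericTypeMapSketch {
  static constexpr NumericTypeIDSketch kId = NumericTypeIDSketch::kUnknown;
};
template <> struct NumericTypeMapSketch<float> {
  static constexpr NumericTypeIDSketch kId = NumericTypeIDSketch::kF32;
};
template <> struct NumericTypeMapSketch<double> {
  static constexpr NumericTypeIDSketch kId = NumericTypeIDSketch::kF64;
};

// Hypothetical analogue of TensorDescription / make_TensorDescription:
// compile-time parameters become plain runtime fields.
struct TensorDescriptionSketch {
  NumericTypeIDSketch element;
  int alignment;
};

template <typename Element>
TensorDescriptionSketch make_description_sketch(int alignment) {
  return {NumericTypeMapSketch<Element>::kId, alignment};
}

int main() {
  // A type-erased caller can now inspect the operation without templates.
  TensorDescriptionSketch desc = make_description_sketch<float>(4);
  std::printf("element id = %d, alignment = %d\n",
              static_cast<int>(desc.element), desc.alignment);
  return 0;
}

The same idea is what lets the manifest store every instantiated kernel behind the common Operation interface while still reporting element types, layouts, and alignments at runtime.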
+/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class Rank2KOperation : public Rank2KOperationBase { +public: + + using Operator = Operator_; + using ElementA = typename Operator::ElementA; + using LayoutA = typename Operator::LayoutA; + using ElementB = typename Operator::ElementB; + using LayoutB = typename Operator::LayoutB; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + + static BlasMode const kBlasMode = Operator::kBlasMode; + static int const kUpdateRank = Operator::kUpdateRank; + static FillMode const kFillModeC = Operator::kFillModeC; + + using OperatorArguments = typename Operator::Arguments; + +public: + + /// Constructor + Rank2KOperation(char const *name = "unknown_rank_2k"): + Rank2KOperationBase(name) { + + this->description_.rank_k_kind = RankKKind::kUniversal; + } + +protected: + + /// Constructs the arguments structure given the configuration and arguments + static Status construct_arguments_( + OperatorArguments &operator_args, + RankKConfiguration const *configuration) { + + //operator_args.mode = configuration->mode; + + operator_args.problem_size = configuration->problem_size; + operator_args.batch_count = configuration->batch_count; + + operator_args.lda = int(configuration->lda); + operator_args.ldb = int(configuration->ldb); + operator_args.ldc = int(configuration->ldc); + operator_args.ldd = int(configuration->ldd); + + return Status::kSuccess; + } + + /// Constructs the arguments structure given the configuration and arguments + static Status update_arguments_( + OperatorArguments &operator_args, + RankKArguments const *arguments) { + + if (arguments->pointer_mode == ScalarPointerMode::kHost) { + typename Operator::EpilogueOutputOp::Params params( + *static_cast(arguments->alpha), + *static_cast(arguments->beta) + ); + operator_args.epilogue = params; + } + else if (arguments->pointer_mode == ScalarPointerMode::kDevice){ + typename Operator::EpilogueOutputOp::Params params( + static_cast(arguments->alpha), + static_cast(arguments->beta) + ); + operator_args.epilogue = params; + } + else { + return Status::kErrorInvalidProblem; + } + + // update arguments + operator_args.ptr_A = arguments->A; + operator_args.ptr_B = arguments->B; + operator_args.ptr_C = arguments->C; + operator_args.ptr_D = arguments->D; + + operator_args.batch_stride_A = arguments->batch_stride_A; + operator_args.batch_stride_B = arguments->batch_stride_B; + operator_args.batch_stride_C = arguments->batch_stride_C; + operator_args.batch_stride_D = arguments->batch_stride_D; + + if (arguments->use_pdl) { + return Status::kErrorNotSupported; + } + + return Status::kSuccess; + } + +public: + + /// Returns success if the operation can proceed + virtual Status can_implement( + void const *configuration_ptr, + void const *arguments_ptr) const { + + RankKConfiguration const *configuration = + static_cast(configuration_ptr); + + RankKArguments const *arguments = + static_cast(arguments_ptr); + + OperatorArguments args; + + Status status = construct_arguments_(args, configuration); + + if (status != Status::kSuccess) { + return status; + } + + status = update_arguments_(args, arguments); + + if (status != Status::kSuccess) { + return status; + } + + return Operator::can_implement(args); + } + + /// Gets the host-side workspace + virtual uint64_t 
get_host_workspace_size( + void const *configuration) const { + + return sizeof(Operator); + } + + /// Gets the device-side workspace + virtual uint64_t get_device_workspace_size( + void const *configuration_ptr, + void const *arguments_ptr = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return 0; + } + + uint64_t size = Operator::get_workspace_size(args); + + return size; + } + + /// Initializes the workspace + virtual Status initialize( + void const *configuration_ptr, + void *host_workspace, + void *device_workspace, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = new (host_workspace) Operator; + + //std::cout << "initialize() library::Rank2KOperation" << std::endl; + //print_operator_args(args); + status = op->initialize(args, device_workspace, stream); + + return status; + } + + /// Runs the kernel + virtual Status run( + void const *arguments_ptr, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = update_arguments_( + args, + static_cast(arguments_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = static_cast(host_workspace); + + status = op->update(args, device_workspace); + + if (status != Status::kSuccess) { + return status; + } + + //std::cout << "run() library::Rank2KOperation" << std::endl; + //print_operator_args(args); + status = op->run(stream); + + return status; + } + + /// Call print_operator_args from the Conv2dOperation::initialize() + // to dump arguments passed on to cutlass operator for debugging + void print_operator_args(OperatorArguments &operator_args) const { + std::cout << "Rank2KOperation::OperatorArguments" << std::endl + << " problem_size:" << std::endl + << operator_args.problem_size << std::endl + << " epilogue (alpha, beta): " + << operator_args.epilogue.alpha << ", " + << operator_args.epilogue.beta << std::endl + << " ref_A (ptr, {stride}): " + << operator_args.ptr_A << ", {" + << operator_args.lda << "}" << std::endl + << " ref_B (ptr, {stride}): " + << operator_args.ptr_B << ", {" + << operator_args.ldb << "}" << std::endl + << " ref_C (ptr, {stride}): " + << operator_args.ptr_C << ", {" + << operator_args.ldc << "}" << std::endl + << " ref_D (ptr, {stride}): " + << operator_args.ptr_D << ", {" + << operator_args.ldd << "}" << std::endl; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/library/src/rank_k_operation.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/library/src/rank_k_operation.h new file mode 100644 index 0000000000000000000000000000000000000000..021f7f03fcc4449bdc2ef2c97e29fe0fead09a64 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/library/src/rank_k_operation.h @@ -0,0 +1,348 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & 
AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines operations for all Rank K operation kinds (Syrk, Herk) + in CUTLASS Library. 
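+
+   For reference, these are the standard BLAS-style rank-k updates (an
+   illustrative summary, not specific to this header):
+
+     SYRK:  C := alpha * A * op(A) + beta * C,   op(A) = A^T
+     HERK:  C := alpha * A * op(A) + beta * C,   op(A) = A^H   (alpha, beta real)
+
+   with only the triangle of C selected by FillMode being updated.
+
+   A minimal, illustrative call sequence through the library Operation
+   interface defined below (all names are hypothetical except the virtual
+   methods themselves):
+
+     cutlass::library::Operation const *op = ...;   // e.g. located via the manifest
+     std::vector<char> host_ws(op->get_host_workspace_size(&config));
+     if (op->can_implement(&config, &args) == cutlass::Status::kSuccess) {
+       op->initialize(&config, host_ws.data(), device_ws, stream);
+       op->run(&args, host_ws.data(), device_ws, stream);
+     }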
+ + +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/gemm/device/rank_k.h" +#include "cutlass/gemm/kernel/default_rank_k_universal.h" + +#include "cutlass/library/library.h" +#include "library_internal.h" + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class RankKOperationBase : public Operation { +public: + using Operator = Operator_; + using ElementA = typename Operator::ElementA; + using LayoutA = typename Operator::LayoutA; + using ElementB = typename Operator::ElementA; + using LayoutB = typename Operator::LayoutA; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + static BlasMode const kBlasMode = Operator::kBlasMode; + static int const kUpdateRank = Operator::kUpdateRank; + static FillMode const kFillModeC = Operator::kFillModeC; + + using OperatorArguments = typename Operator::Arguments; + +protected: + + /// + RankKDescription description_; + +public: + + /// Constructor + RankKOperationBase(char const *name = "unknown_rank_k") { + + description_.name = name; + description_.provider = Provider::kCUTLASS; + description_.rank_k_kind = RankKKind::kUniversal; + description_.fill_mode = kFillModeC; + description_.blas_mode = kBlasMode; + description_.num_ranks = kUpdateRank; + + description_.kind = OperationKind::kRankK; + + description_.tile_description.threadblock_shape = make_Coord( + Operator::ThreadblockShape::kM, + Operator::ThreadblockShape::kN, + Operator::ThreadblockShape::kK); + + description_.tile_description.threadblock_stages = Operator::kStages; + + description_.tile_description.warp_count = make_Coord( + Operator::RankKkernel::WarpCount::kM, + Operator::RankKkernel::WarpCount::kN, + Operator::RankKkernel::WarpCount::kK); + + description_.tile_description.math_instruction.instruction_shape = make_Coord( + Operator::InstructionShape::kM, + Operator::InstructionShape::kN, + Operator::InstructionShape::kK); + + description_.tile_description.math_instruction.element_accumulator = + NumericTypeMap::kId; + + description_.tile_description.math_instruction.opcode_class = + OpcodeClassMap::kId; + + description_.tile_description.math_instruction.math_operation = + MathOperationMap::kId; + + description_.tile_description.minimum_compute_capability = + ArchMap::kMin; + + description_.tile_description.maximum_compute_capability = + ArchMap::kMax; + + description_.A = make_TensorDescription(Operator::kAlignmentA); + description_.B = make_TensorDescription(Operator::kAlignmentA); + description_.C = make_TensorDescription(Operator::kAlignmentC); + description_.element_epilogue = NumericTypeMap::kId; + + description_.split_k_mode = SplitKMode::kNone; + description_.transform_A = ComplexTransformMap::kId; + description_.transform_B = ComplexTransformMap::kId; + } + + /// Returns the description of the SYRK operation + virtual OperationDescription const & description() const { + return description_; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class RankKOperation : public RankKOperationBase { +public: + + using Operator = Operator_; + using ElementA = typename Operator::ElementA; + using 
LayoutA = typename Operator::LayoutA; + using ElementB = typename Operator::ElementA; + using LayoutB = typename Operator::LayoutA; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + + static BlasMode const kBlasMode = Operator::kBlasMode; + static int const kUpdateRank = Operator::kUpdateRank; + static FillMode const kFillModeC = Operator::kFillModeC; + + using OperatorArguments = typename Operator::Arguments; + +public: + + /// Constructor + RankKOperation(char const *name = "unknown_rank_k"): + RankKOperationBase(name) { + + this->description_.rank_k_kind = RankKKind::kUniversal; + } + +protected: + + /// Constructs the arguments structure given the configuration and arguments + static Status construct_arguments_( + OperatorArguments &operator_args, + RankKConfiguration const *configuration) { + + //operator_args.mode = configuration->mode; + + operator_args.problem_size = configuration->problem_size; + operator_args.batch_count = configuration->batch_count; + + operator_args.lda = int(configuration->lda); + operator_args.ldb = int(configuration->lda); + operator_args.ldc = int(configuration->ldc); + operator_args.ldd = int(configuration->ldd); + + return Status::kSuccess; + } + + /// Constructs the arguments structure given the configuration and arguments + static Status update_arguments_( + OperatorArguments &operator_args, + RankKArguments const *arguments) { + + if (arguments->pointer_mode == ScalarPointerMode::kHost) { + typename Operator::EpilogueOutputOp::Params params( + *static_cast(arguments->alpha), + *static_cast(arguments->beta) + ); + operator_args.epilogue = params; + } + else if (arguments->pointer_mode == ScalarPointerMode::kDevice){ + typename Operator::EpilogueOutputOp::Params params( + static_cast(arguments->alpha), + static_cast(arguments->beta) + ); + operator_args.epilogue = params; + } + else { + return Status::kErrorInvalidProblem; + } + + // update arguments + operator_args.ptr_A = arguments->A; + operator_args.ptr_C = arguments->C; + operator_args.ptr_D = arguments->D; + + operator_args.batch_stride_A = arguments->batch_stride_A; + operator_args.batch_stride_C = arguments->batch_stride_C; + operator_args.batch_stride_D = arguments->batch_stride_D; + + if (arguments->use_pdl) { + return Status::kErrorNotSupported; + } + + return Status::kSuccess; + } + +public: + + /// Returns success if the operation can proceed + virtual Status can_implement( + void const *configuration_ptr, + void const *arguments_ptr) const { + + RankKConfiguration const *configuration = + static_cast(configuration_ptr); + + RankKArguments const *arguments = + static_cast(arguments_ptr); + + OperatorArguments args; + + Status status = construct_arguments_(args, configuration); + + if (status != Status::kSuccess) { + return status; + } + + status = update_arguments_(args, arguments); + + if (status != Status::kSuccess) { + return status; + } + + return Operator::can_implement(args); + } + + /// Gets the host-side workspace + virtual uint64_t get_host_workspace_size( + void const *configuration) const { + + return sizeof(Operator); + } + + /// Gets the device-side workspace + virtual uint64_t get_device_workspace_size( + void const *configuration_ptr, + void const *arguments_ptr = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + 
static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return 0; + } + + uint64_t size = Operator::get_workspace_size(args); + + return size; + } + + /// Initializes the workspace + virtual Status initialize( + void const *configuration_ptr, + void *host_workspace, + void *device_workspace, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = new (host_workspace) Operator; + + status = op->initialize(args, device_workspace, stream); + + return status; + } + + /// Runs the kernel + virtual Status run( + void const *arguments_ptr, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = update_arguments_( + args, + static_cast(arguments_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = static_cast(host_workspace); + + status = op->update(args, device_workspace); + + if (status != Status::kSuccess) { + return status; + } + + status = op->run(stream); + + return status; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/library/src/reduction/reduction_operation.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/library/src/reduction/reduction_operation.h new file mode 100644 index 0000000000000000000000000000000000000000..6e948540e3f29dceace42b5e8ef3f91118c01b37 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/library/src/reduction/reduction_operation.h @@ -0,0 +1,294 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines operations for reduction operation in CUTLASS Library. +*/ + +#pragma once +#include +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/epilogue/thread/linear_combination_clamp.h" +#include "cutlass/reduction/thread/reduction_operators.h" +#include "cutlass/reduction/device/reduce_split_k.h" + +#include "cutlass/library/library.h" +#include "library_internal.h" +#include "cutlass/core_io.h" + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class ReductionOperation : public Operation { +public: + using Operator = Operator_; + + using ElementWorkspace = typename Operator::ElementWorkspace; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementOutput = typename Operator::ElementOutput; + + using ElementCompute = typename Operator::OutputOp::ElementCompute; + + using OperatorArguments = typename Operator::Arguments; + +protected: + + /// + ReductionDescription description_; + +public: + + /// Constructor + ReductionOperation(char const *name = "unknown_reduction") { + + description_.name = name; + description_.provider = Provider::kCUTLASS; + description_.kind = OperationKind::kReduction; + + description_.tile_description.threadblock_shape = make_Coord(Operator::Shape::kRow, Operator::Shape::kColumn, 1); + + description_.tile_description.math_instruction.instruction_shape = make_Coord(1, 1, 1); + description_.tile_description.math_instruction.element_accumulator = NumericTypeMap::kId; + description_.tile_description.math_instruction.opcode_class = OpcodeClassID::kSimt; + description_.tile_description.math_instruction.math_operation = MathOperationID::kAdd; + + description_.tile_description.minimum_compute_capability = 50; + description_.tile_description.maximum_compute_capability = 1024; + + description_.element_workspace = NumericTypeMap::kId; + description_.element_output = NumericTypeMap::kId; + description_.element_epilogue = NumericTypeMap::kId; + + } + + /// Returns the description of the Reduction operation + virtual OperationDescription const & description() const { + return description_; + } + + +protected: + + /// Constructs the arguments structure given the configuration and arguments + static Status construct_arguments_( + OperatorArguments &operator_args, + ReductionConfiguration const *configuration) { + + operator_args.problem_size = configuration->problem_size; + operator_args.partitions = configuration->partitions; + operator_args.partition_stride = configuration->partition_stride; + + operator_args.workspace = {nullptr, int(configuration->ldw)}; + operator_args.source = {nullptr, int(configuration->lds)}; + 
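+    // Only the leading strides are known at configuration time, so the
+    // workspace/source/destination refs are built with null pointers here;
+    // the actual device pointers are bound later in update_arguments_()
+    // via reset().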
operator_args.destination = {nullptr, int(configuration->ldd)}; + + return Status::kSuccess; + } + + /// Constructs the arguments structure given the configuration and arguments + static Status update_arguments_( + OperatorArguments &operator_args, + ReductionArguments const *arguments) { + + if (arguments->pointer_mode == ScalarPointerMode::kHost) { + typename Operator::OutputOp::Params params( + *static_cast(arguments->alpha), + *static_cast(arguments->beta) + ); + operator_args.output = params; + } + else if (arguments->pointer_mode == ScalarPointerMode::kDevice){ + typename Operator::OutputOp::Params params( + static_cast(arguments->alpha), + static_cast(arguments->beta) + ); + operator_args.output = params; + } + else { + return Status::kErrorInvalidProblem; + } + + operator_args.workspace.reset(static_cast(const_cast(arguments->workspace))); + operator_args.source.reset(static_cast(const_cast(arguments->source))); + operator_args.destination.reset(static_cast(const_cast(arguments->destination))); + + if (arguments->use_pdl) { + return Status::kErrorNotSupported; + } + + return Status::kSuccess; + } + +public: + + /// Returns success if the operation can proceed + virtual Status can_implement( + void const *configuration_ptr, + void const *arguments_ptr) const { + + ReductionConfiguration const *configuration = + static_cast(configuration_ptr); + + ReductionArguments const *arguments = + static_cast(arguments_ptr); + + OperatorArguments args; + + Status status = construct_arguments_(args, configuration); + + if (status != Status::kSuccess) { + return status; + } + + status = update_arguments_(args, arguments); + + if (status != Status::kSuccess) { + return status; + } + + return Operator::can_implement(args); + } + + /// Gets the host-side workspace + virtual uint64_t get_host_workspace_size( + void const *configuration) const { + + return sizeof(Operator); + } + + /// Gets the device-side workspace + virtual uint64_t get_device_workspace_size( + void const *configuration_ptr, + void const *arguments_ptr = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return 0; + } + + return Operator::get_workspace_size(args); + } + + /// Initializes the workspace + virtual Status initialize( + void const *configuration_ptr, + void *host_workspace, + void *device_workspace, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = new (host_workspace) Operator; + //std::cout << "initialize library::Reduction" << std::endl; + //print_operator_args(args); + return op->initialize(args, device_workspace, stream); + } + + /// Runs the kernel + virtual Status run( + void const *arguments_ptr, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = update_arguments_( + args, + static_cast(arguments_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = static_cast(host_workspace); + + status = op->update(args, device_workspace); + + if (status != Status::kSuccess) { + return status; + } + + //std::cout << "run library::Reduction" << std::endl; + //print_operator_args(args); + return op->run(stream); + } + + /// Call print_operator_args from the Reduction::initialize() + // to dump arguments 
passed on to cutlass operator for debugging + void print_operator_args(OperatorArguments &operator_args) const { + std::cout << "Reduction::OperatorArguments" << std::endl + << " problem_size: " + << operator_args.problem_size << std::endl + << " partitions: " + << operator_args.partitions << std::endl + << " partition_stride: " + << operator_args.partition_stride << std::endl + << " epilogue (alpha, beta): " + << operator_args.output.alpha << ", " + << operator_args.output.beta << std::endl + << " workspace (ptr, stride): " + << operator_args.workspace.data() << ", " + << operator_args.workspace.stride(0) << std::endl + << " source (ptr, stride): " + << operator_args.source.data() << ", " + << operator_args.source.stride(0) << std::endl + << " destination (ptr, stride): " + << operator_args.destination.data() << ", " + << operator_args.destination.stride(0) << std::endl; + } +}; + + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/library/src/reference/block_scaled_gemm_reference_operation.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/library/src/reference/block_scaled_gemm_reference_operation.h new file mode 100644 index 0000000000000000000000000000000000000000..769da1c8515877536fd9b9fd72c836fd43ebd5d8 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/library/src/reference/block_scaled_gemm_reference_operation.h @@ -0,0 +1,453 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +/* \file + \brief Defines reference operations for block-scaled GEMM operation kinds in CUTLASS Library +*/ + + + +#pragma once + +#include +#include +#include + +#include "cutlass/cutlass.h" + +#include "cutlass/library/library.h" +#include "cutlass/library/manifest.h" +#include "cutlass/library/util.h" +#include "cutlass/util/packed_stride.hpp" +#include "library_internal.h" + +#include "cutlass/util/reference/host/gett.hpp" +#include "cutlass/detail/sm100_blockscaled_layout.hpp" + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +namespace detail { +template +auto make_iterator(T* ptr) { + return cute::recast_ptr(ptr); +} +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + Provider Provider_, + typename ElementA_, + typename LayoutA_, + typename ElementSFA_, + typename ElementB_, + typename LayoutB_, + typename ElementSFB_, + typename ElementC_, + typename LayoutC_, + typename ElementCompute_, + typename ElementAccumulator_ = ElementCompute_, + typename ElementD_ = ElementC_, + typename ElementSFD_ = void, + typename LayoutSFD_ = LayoutC_, + int SFVecSize_ = 32, + int EpilogueSFVecSize_ = 0, + typename ConvertOp_ = NumericConverter, + typename InnerProductOp_ = multiply_add +> +class BlockScaledGemmReferenceOperation : public Operation { +public: + static Provider const kProvider = Provider_; + + using ElementA = ElementA_; + using LayoutA = LayoutA_; + using ElementSFA = ElementSFA_; + using ElementB = ElementB_; + using LayoutB = LayoutB_; + using ElementSFB = ElementSFB_; + using ElementC = ElementC_; + using LayoutC = LayoutC_; + using ElementD = ElementD_; + using ElementSFD = ElementSFD_; + using LayoutSFD = LayoutSFD_; + using ElementCompute = ElementCompute_; + using ElementAccumulator = ElementAccumulator_; + using ConvertOp = ConvertOp_; + using InnerProductOp = InnerProductOp_; + constexpr static int SFVecSize = SFVecSize_; + constexpr static int EpilogueSFVecSize = EpilogueSFVecSize_; + +protected: + + /// Storage for the name string + std::string name_; + + /// + BlockScaledGemmDescription description_; + +public: + + /// Constructor + BlockScaledGemmReferenceOperation() { + + // Basic information + description_.provider = kProvider; + description_.kind = OperationKind::kBlockScaledGemm; + description_.gemm_kind = GemmKind::kUniversal; + + // Tensor description + description_.A = make_TensorDescription(); + description_.SFA = make_TensorDescription(); + description_.B = make_TensorDescription(); + description_.SFB = make_TensorDescription(); + description_.C = make_TensorDescription(); + description_.D = make_TensorDescription(); + description_.SFD = make_TensorDescription(); + + // Epilogue compute and accumulator type description + description_.element_epilogue = NumericTypeMap::kId; + + description_.tile_description.math_instruction.element_accumulator = + NumericTypeMap::kId; + + // Compute capability for gemm reference + description_.tile_description.minimum_compute_capability = + (kProvider == Provider::kReferenceDevice ? 
50 : 0); + + description_.tile_description.maximum_compute_capability = 1024; + + description_.SFVecSize = SFVecSize; + description_.EpilogueSFVecSize = EpilogueSFVecSize; + + // Procedural name + std::stringstream ss; + + ss << "gemm" + << "_reference_" << to_string(description_.provider) + << "_" << to_string(description_.A.element) << to_string(description_.A.layout) + << "_" << to_string(description_.SFA.element) << to_string(description_.SFA.layout) + << "_" << to_string(description_.B.element) << to_string(description_.B.layout) + << "_" << to_string(description_.SFB.element) << to_string(description_.SFB.layout) + << "_" << to_string(description_.C.element) << to_string(description_.C.layout) + << "_" << to_string(description_.SFD.element) << to_string(description_.SFD.layout) + << "_" << to_string(description_.tile_description.math_instruction.element_accumulator); + + name_ = ss.str(); + + description_.name = name_.c_str(); + + // Epilogue compute and accumulator type description + description_.element_epilogue = NumericTypeMap::kId; + + description_.tile_description.math_instruction.element_accumulator = + NumericTypeMap::kId; + } + + /// Returns the description of the GEMM operation + virtual OperationDescription const & description() const { + return description_; + } + + virtual Status can_implement( + void const *configuration, + void const *arguments) const { + + return Status::kSuccess; + } + + virtual uint64_t get_host_workspace_size( + void const *configuration) const { + + return sizeof(GemmUniversalConfiguration); + } + + virtual uint64_t get_device_workspace_size( + void const *configuration, + void const *arguments = nullptr) const { + + return 0; + } + + virtual Status initialize( + void const *configuration, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const { + return Status::kSuccess; + } + + virtual Status run( + void const *arguments, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const { + using namespace cute; + + BlockScaledGemmArguments const &args = *static_cast(arguments); + + // Construct cute::Tensor A/B/C + + int M = args.problem_size.m(); + int N = args.problem_size.n(); + int K = args.problem_size.k(); + int L = args.batch_count; + + auto problem_shape_MNKL = cute::make_shape(M, N, K, L); + + auto alpha = *(static_cast(args.alpha)); + auto beta = *(static_cast(args.beta)); + + using StrideA = cutlass::gemm::TagToStrideA_t; + using StrideB = cutlass::gemm::TagToStrideB_t; + using StrideC = cutlass::gemm::TagToStrideC_t; + using StrideD = cutlass::gemm::TagToStrideC_t; + + auto stride_a = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); + auto stride_b = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)); + auto stride_c = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)); + auto stride_d = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)); + + using Sm1xxBlockScaledConfig = cutlass::detail::Sm1xxBlockScaledConfig; + auto A = cute::make_tensor(detail::make_iterator(static_cast(args.A)), + cute::make_layout(cute::make_shape(M, K, L), stride_a)); + auto SfA = make_tensor(static_cast(args.SFA), Sm1xxBlockScaledConfig::tile_atom_to_shape_SFA(problem_shape_MNKL)); + + auto B = cute::make_tensor(detail::make_iterator(static_cast(args.B)), + cute::make_layout(cute::make_shape(N, K, L), stride_b)); + auto SfB = make_tensor(static_cast(args.SFB), 
Sm1xxBlockScaledConfig::tile_atom_to_shape_SFB(problem_shape_MNKL)); + + auto C = [&]() { + if constexpr (not is_same_v) { + return cute::make_tensor(detail::make_iterator(static_cast(args.C)), + cute::make_layout(cute::make_shape(M, N, L), stride_c)); + } + else { + return cute::make_tensor(detail::make_iterator(static_cast(nullptr)), + cute::make_layout(cute::make_shape(M, N, L), stride_c)); + } + }(); + + auto D = cute::make_tensor(detail::make_iterator(static_cast(args.D)), + cute::make_layout(cute::make_shape(M, N, L), stride_d)); + + cutlass::reference::host::GettBlockScalingMainloopParams + mainloop_params{A, SfA, B, SfB}; + + if constexpr (not is_same_v) { + + using Sm1xxBlockScaledOutputConfig= cutlass::detail::Sm1xxBlockScaledOutputConfig< + EpilogueSFVecSize + >; + + auto SfD = cute::make_tensor(detail::make_iterator(static_cast(args.SFD)), Sm1xxBlockScaledOutputConfig::tile_atom_to_shape_SFD(problem_shape_MNKL)); + + cutlass::reference::host::GettBlockScalingEpilogueParams< + ElementCompute, ElementAccumulator, ElementCompute, + decltype(C), decltype(D), decltype(SfD), Int, cutlass::reference::host::SfStrategy::SfDGen> + epilogue_params{alpha, beta, C, D, SfD, *(static_cast(args.norm_constant))}; + + cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params); + } + else { + // W/O SF generation + auto SfD = cute::make_tensor(static_cast(nullptr), + cute::make_layout(cute::make_shape(M, N, L))); // not used. + cutlass::reference::host::GettBlockScalingEpilogueParams< + ElementCompute, ElementAccumulator, ElementCompute, + decltype(C), decltype(D), decltype(SfD)> + epilogue_params{alpha, beta, C, D, SfD}; + + cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params); + } + + return Status::kSuccess; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA_, + typename ElementSFA_, + typename ElementB_, + typename ElementSFB_, + typename ElementC_, + typename ElementCompute_, + typename ElementSFD_ = void, + typename ElementAccumulator_ = ElementCompute_, + typename ElementD_ = ElementC_, + int SFVecSize = 32, + int EpilogueSFVecSize = SFVecSize, + typename ConvertOp_ = NumericConverter, + typename InnerProductOp_ = multiply_add +> +void make_block_scaled_gemm_tn(Manifest &manifest) { +#if !defined(CUTLASS_PROFILER_DISABLE_REFERENCE) + manifest.append(new BlockScaledGemmReferenceOperation< + Provider::kReferenceHost, + ElementA_, + cutlass::layout::RowMajor, + ElementSFA_, + ElementB_, + cutlass::layout::ColumnMajor, + ElementSFB_, + ElementC_, + cutlass::layout::RowMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ElementSFD_, + cutlass::layout::RowMajor, + SFVecSize, + EpilogueSFVecSize, + ConvertOp_, + InnerProductOp_ + >); +#endif // !defined(CUTLASS_PROFILER_DISABLE_REFERENCE) +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA_, + typename ElementSFA_, + typename ElementB_, + typename ElementSFB_, + typename ElementC_, + typename ElementCompute_, + typename ElementSFD_ = void, + typename ElementAccumulator_ = ElementCompute_, + typename ElementD_ = ElementC_, + int SFVecSize = 32, + int EpilogueSFVecSize = SFVecSize, + typename ConvertOp_ = NumericConverter, + typename InnerProductOp_ = multiply_add +> +void make_block_scaled_gemm(Manifest &manifest) { + /// + /// A is Row , B is Col + /// + manifest.append(new BlockScaledGemmReferenceOperation< + 
Provider::kReferenceHost, + ElementA_, + cutlass::layout::RowMajor, + ElementSFA_, + ElementB_, + cutlass::layout::ColumnMajor, + ElementSFB_, + ElementC_, + cutlass::layout::RowMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ElementSFD_, + cutlass::layout::RowMajor, + SFVecSize, + EpilogueSFVecSize, + ConvertOp_, + InnerProductOp_ + >); + manifest.append(new BlockScaledGemmReferenceOperation< + Provider::kReferenceHost, + ElementA_, + cutlass::layout::RowMajor, + ElementSFA_, + ElementB_, + cutlass::layout::ColumnMajor, + ElementSFB_, + ElementC_, + cutlass::layout::ColumnMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ElementSFD_, + cutlass::layout::RowMajor, + SFVecSize, + EpilogueSFVecSize, + ConvertOp_, + InnerProductOp_ + >); + /// + /// A is Col , B is Row + /// + manifest.append(new BlockScaledGemmReferenceOperation< + Provider::kReferenceHost, + ElementA_, + cutlass::layout::ColumnMajor, + ElementSFA_, + ElementB_, + cutlass::layout::RowMajor, + ElementSFB_, + ElementC_, + cutlass::layout::RowMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ElementSFD_, + cutlass::layout::RowMajor, + SFVecSize, + EpilogueSFVecSize, + ConvertOp_, + InnerProductOp_ + >); + manifest.append(new BlockScaledGemmReferenceOperation< + Provider::kReferenceHost, + ElementA_, + cutlass::layout::ColumnMajor, + ElementSFA_, + ElementB_, + cutlass::layout::RowMajor, + ElementSFB_, + ElementC_, + cutlass::layout::ColumnMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ElementSFD_, + cutlass::layout::RowMajor, + SFVecSize, + EpilogueSFVecSize, + ConvertOp_, + InnerProductOp_ + >); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/library/src/reference/blockwise_gemm_reference_operation.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/library/src/reference/blockwise_gemm_reference_operation.h new file mode 100644 index 0000000000000000000000000000000000000000..fd988f899f563acfc6f8003bdb49523bca51d6d9 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/library/src/reference/blockwise_gemm_reference_operation.h @@ -0,0 +1,807 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines reference operations for blockwise/groupwise GEMM operation kinds in CUTLASS Library +*/ + + + +#pragma once + +#include +#include +#include + +#include "cutlass/cutlass.h" + +#include "cutlass/library/library.h" +#include "cutlass/library/manifest.h" +#include "cutlass/library/util.h" +#include "cutlass/util/packed_stride.hpp" +#include "library_internal.h" + +#include "cutlass/util/reference/host/gett.hpp" +#include "cutlass/detail/blockwise_scale_layout.hpp" + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + Provider Provider_, + typename ElementA_, + typename LayoutA_, + typename LayoutSFA_, + typename ElementSFA_, + typename ElementB_, + typename LayoutB_, + typename LayoutSFB_, + typename ElementSFB_, + typename ElementC_, + typename LayoutC_, + typename ElementCompute_, + typename ElementAccumulator_ = ElementCompute_, + typename ElementD_ = ElementC_, + typename ConvertOp_ = NumericConverter, + typename InnerProductOp_ = multiply_add +> +class BlockwiseGemmReferenceOperation : public Operation { +public: + static Provider const kProvider = Provider_; + + using ElementA = ElementA_; + using LayoutA = LayoutA_; + using ElementSFA = ElementSFA_; + using ElementB = ElementB_; + using LayoutB = LayoutB_; + using ElementSFB = ElementSFB_; + using ElementC = ElementC_; + using LayoutC = LayoutC_; + using ElementD = ElementD_; + using ElementCompute = ElementCompute_; + using ElementAccumulator = ElementAccumulator_; + using ConvertOp = ConvertOp_; + using InnerProductOp = InnerProductOp_; + +protected: + + /// Storage for the name string + std::string name_; + + /// + BlockwiseGemmDescription description_; + +public: + + /// Constructor + BlockwiseGemmReferenceOperation(int SFMVecSize_, int SFNVecSize_, int SFKVecSize_) + : SFMVecSize(SFMVecSize_), SFNVecSize(SFNVecSize_), SFKVecSize(SFKVecSize_) { + + // Basic information + description_.provider = kProvider; + description_.kind = OperationKind::kBlockwiseGemm; + description_.gemm_kind = GemmKind::kUniversal; + + // Tensor description + description_.A = make_TensorDescription(); + description_.SFA = make_TensorDescription(); + description_.B = make_TensorDescription(); + description_.SFB = make_TensorDescription(); + description_.C = make_TensorDescription(); + description_.D = make_TensorDescription(); + + // Epilogue compute and accumulator type description + description_.element_epilogue = NumericTypeMap::kId; 
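+
+    // The scale-factor vector sizes recorded below give the blockwise scaling
+    // granularity used by this reference: roughly one SFA value per
+    // (SFMVecSize x SFKVecSize) block of A and one SFB value per
+    // (SFNVecSize x SFKVecSize) block of B (e.g. 1x128 row-wise scaling of A
+    // with 128x128 tile scaling of B in the common instantiations further down).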
+ + description_.tile_description.math_instruction.element_accumulator = + NumericTypeMap::kId; + + // Compute capability for gemm reference + description_.tile_description.minimum_compute_capability = + (kProvider == Provider::kReferenceDevice ? 50 : 0); + + description_.tile_description.maximum_compute_capability = 1024; + + description_.SFMVecSize = SFMVecSize; + description_.SFNVecSize = SFNVecSize; + description_.SFKVecSize = SFKVecSize; + + // Procedural name + std::stringstream ss; + + ss << "gemm" + << "_reference_" << to_string(description_.provider) + << "_" << to_string(description_.A.element) << to_string(description_.A.layout) + << "_" << to_string(description_.SFA.element) << SFMVecSize << "x" << SFKVecSize << to_string(description_.SFA.layout) + << "_" << to_string(description_.B.element) << to_string(description_.B.layout) + << "_" << to_string(description_.SFB.element) << SFNVecSize << "x" << SFKVecSize << to_string(description_.SFB.layout) + << "_" << to_string(description_.C.element) << to_string(description_.C.layout) + << "_" << to_string(description_.tile_description.math_instruction.element_accumulator); + + name_ = ss.str(); + + description_.name = name_.c_str(); + + // Epilogue compute and accumulator type description + description_.element_epilogue = NumericTypeMap::kId; + + description_.tile_description.math_instruction.element_accumulator = + NumericTypeMap::kId; + } + + /// Returns the description of the GEMM operation + virtual OperationDescription const & description() const { + return description_; + } + + virtual Status can_implement( + void const *configuration, + void const *arguments) const { + + return Status::kSuccess; + } + + virtual uint64_t get_host_workspace_size( + void const *configuration) const { + + return sizeof(GemmUniversalConfiguration); + } + + virtual uint64_t get_device_workspace_size( + void const *configuration, + void const *arguments = nullptr) const { + + return 0; + } + + virtual Status initialize( + void const *configuration, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const { + return Status::kSuccess; + } + + virtual Status run( + void const *arguments, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const { + using namespace cute; + + BlockwiseGemmArguments const &args = *static_cast(arguments); + + // Construct cute::Tensor A/B/C + + int M = args.problem_size.m(); + int N = args.problem_size.n(); + int K = args.problem_size.k(); + int L = args.batch_count; + + auto problem_shape_MNKL = cute::make_shape(M, N, K, L); + + auto alpha = *(static_cast(args.alpha)); + auto beta = *(static_cast(args.beta)); + + using StrideA = cutlass::gemm::TagToStrideA_t; + using StrideB = cutlass::gemm::TagToStrideB_t; + using StrideC = cutlass::gemm::TagToStrideC_t; + using StrideD = cutlass::gemm::TagToStrideC_t; + + auto stride_a = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); + auto stride_b = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)); + auto stride_c = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)); + auto stride_d = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)); + using BlockwiseConfig = cutlass::detail::RuntimeBlockwiseScaleConfig<>; + auto A = cute::make_tensor(static_cast(args.A), + cute::make_layout(cute::make_shape(M, K, L), stride_a)); + auto SfA = make_tensor(static_cast(args.SFA), BlockwiseConfig::tile_atom_to_shape_SFA(problem_shape_MNKL, 
cute::make_tuple(SFMVecSize, SFNVecSize, SFKVecSize))); + + auto B = cute::make_tensor(static_cast(args.B), + cute::make_layout(cute::make_shape(N, K, L), stride_b)); + auto SfB = make_tensor(static_cast(args.SFB), BlockwiseConfig::tile_atom_to_shape_SFB(problem_shape_MNKL, cute::make_tuple(SFMVecSize, SFNVecSize, SFKVecSize))); + + auto C = [&]() { + if constexpr (not is_same_v) { + return cute::make_tensor(static_cast(args.C), + cute::make_layout(cute::make_shape(M, N, L), stride_c)); + } + else { + return cute::make_tensor(static_cast(nullptr), + cute::make_layout(cute::make_shape(M, N, L), stride_c)); + } + }(); + + auto D = cute::make_tensor(static_cast(args.D), + cute::make_layout(cute::make_shape(M, N, L), stride_d)); + + cutlass::reference::host::GettBlockScalingMainloopParams + mainloop_params{A, SfA, B, SfB}; + + // W/O SF generation + cutlass::reference::host::GettEpilogueParams< + ElementCompute, ElementAccumulator, ElementAccumulator, ElementCompute, + decltype(C), decltype(D)> + epilogue_params{alpha, beta, C, D}; + + cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params); + + return Status::kSuccess; + } + +private: + int SFMVecSize; + int SFNVecSize; + int SFKVecSize; +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA_, + typename ElementSFA_, + typename ElementB_, + typename ElementSFB_, + typename ElementC_, + typename ElementCompute_, + typename ElementAccumulator_ = ElementCompute_, + typename ElementD_ = ElementC_, + typename ConvertOp_ = NumericConverter, + typename InnerProductOp_ = multiply_add +> +void make_blockwise_gemm(Manifest &manifest, int SFMVecSize, int SFNVecSize, int SFKVecSize) { + manifest.append(new BlockwiseGemmReferenceOperation< + Provider::kReferenceHost, + ElementA_, + cutlass::layout::RowMajor, + cutlass::layout::ColumnMajor, + ElementSFA_, + ElementB_, + cutlass::layout::ColumnMajor, + cutlass::layout::RowMajor, + ElementSFB_, + ElementC_, + cutlass::layout::RowMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(SFMVecSize, SFNVecSize, SFKVecSize)); + + manifest.append(new BlockwiseGemmReferenceOperation< + Provider::kReferenceHost, + ElementA_, + cutlass::layout::RowMajor, + cutlass::layout::ColumnMajor, + ElementSFA_, + ElementB_, + cutlass::layout::ColumnMajor, + cutlass::layout::ColumnMajor, + ElementSFB_, + ElementC_, + cutlass::layout::RowMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(SFMVecSize, SFNVecSize, SFKVecSize)); + + + manifest.append(new BlockwiseGemmReferenceOperation< + Provider::kReferenceHost, + ElementA_, + cutlass::layout::RowMajor, + cutlass::layout::ColumnMajor, + ElementSFA_, + ElementB_, + cutlass::layout::ColumnMajor, + cutlass::layout::RowMajor, + ElementSFB_, + ElementC_, + cutlass::layout::ColumnMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(SFMVecSize, SFNVecSize, SFKVecSize)); + + manifest.append(new BlockwiseGemmReferenceOperation< + Provider::kReferenceHost, + ElementA_, + cutlass::layout::RowMajor, + cutlass::layout::ColumnMajor, + ElementSFA_, + ElementB_, + cutlass::layout::ColumnMajor, + cutlass::layout::ColumnMajor, + ElementSFB_, + ElementC_, + cutlass::layout::ColumnMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(SFMVecSize, SFNVecSize, SFKVecSize)); + + + manifest.append(new BlockwiseGemmReferenceOperation< + 
Provider::kReferenceHost, + ElementA_, + cutlass::layout::RowMajor, + cutlass::layout::ColumnMajor, + ElementSFA_, + ElementB_, + cutlass::layout::RowMajor, + cutlass::layout::RowMajor, + ElementSFB_, + ElementC_, + cutlass::layout::RowMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(SFMVecSize, SFNVecSize, SFKVecSize)); + + manifest.append(new BlockwiseGemmReferenceOperation< + Provider::kReferenceHost, + ElementA_, + cutlass::layout::RowMajor, + cutlass::layout::ColumnMajor, + ElementSFA_, + ElementB_, + cutlass::layout::RowMajor, + cutlass::layout::ColumnMajor, + ElementSFB_, + ElementC_, + cutlass::layout::RowMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(SFMVecSize, SFNVecSize, SFKVecSize)); + + + manifest.append(new BlockwiseGemmReferenceOperation< + Provider::kReferenceHost, + ElementA_, + cutlass::layout::RowMajor, + cutlass::layout::ColumnMajor, + ElementSFA_, + ElementB_, + cutlass::layout::RowMajor, + cutlass::layout::RowMajor, + ElementSFB_, + ElementC_, + cutlass::layout::ColumnMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(SFMVecSize, SFNVecSize, SFKVecSize)); + + manifest.append(new BlockwiseGemmReferenceOperation< + Provider::kReferenceHost, + ElementA_, + cutlass::layout::RowMajor, + cutlass::layout::ColumnMajor, + ElementSFA_, + ElementB_, + cutlass::layout::RowMajor, + cutlass::layout::ColumnMajor, + ElementSFB_, + ElementC_, + cutlass::layout::ColumnMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(SFMVecSize, SFNVecSize, SFKVecSize)); + + + + manifest.append(new BlockwiseGemmReferenceOperation< + Provider::kReferenceHost, + ElementA_, + cutlass::layout::ColumnMajor, + cutlass::layout::ColumnMajor, + ElementSFA_, + ElementB_, + cutlass::layout::ColumnMajor, + cutlass::layout::RowMajor, + ElementSFB_, + ElementC_, + cutlass::layout::RowMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(SFMVecSize, SFNVecSize, SFKVecSize)); + manifest.append(new BlockwiseGemmReferenceOperation< + Provider::kReferenceHost, + ElementA_, + cutlass::layout::ColumnMajor, + cutlass::layout::ColumnMajor, + ElementSFA_, + ElementB_, + cutlass::layout::ColumnMajor, + cutlass::layout::ColumnMajor, + ElementSFB_, + ElementC_, + cutlass::layout::RowMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(SFMVecSize, SFNVecSize, SFKVecSize)); + + + manifest.append(new BlockwiseGemmReferenceOperation< + Provider::kReferenceHost, + ElementA_, + cutlass::layout::ColumnMajor, + cutlass::layout::ColumnMajor, + ElementSFA_, + ElementB_, + cutlass::layout::ColumnMajor, + cutlass::layout::RowMajor, + ElementSFB_, + ElementC_, + cutlass::layout::ColumnMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(SFMVecSize, SFNVecSize, SFKVecSize)); + manifest.append(new BlockwiseGemmReferenceOperation< + Provider::kReferenceHost, + ElementA_, + cutlass::layout::ColumnMajor, + cutlass::layout::ColumnMajor, + ElementSFA_, + ElementB_, + cutlass::layout::ColumnMajor, + cutlass::layout::ColumnMajor, + ElementSFB_, + ElementC_, + cutlass::layout::ColumnMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(SFMVecSize, SFNVecSize, SFKVecSize)); + + + manifest.append(new BlockwiseGemmReferenceOperation< + Provider::kReferenceHost, + ElementA_, + 
cutlass::layout::ColumnMajor, + cutlass::layout::ColumnMajor, + ElementSFA_, + ElementB_, + cutlass::layout::RowMajor, + cutlass::layout::RowMajor, + ElementSFB_, + ElementC_, + cutlass::layout::RowMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(SFMVecSize, SFNVecSize, SFKVecSize)); + manifest.append(new BlockwiseGemmReferenceOperation< + Provider::kReferenceHost, + ElementA_, + cutlass::layout::ColumnMajor, + cutlass::layout::ColumnMajor, + ElementSFA_, + ElementB_, + cutlass::layout::RowMajor, + cutlass::layout::ColumnMajor, + ElementSFB_, + ElementC_, + cutlass::layout::RowMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(SFMVecSize, SFNVecSize, SFKVecSize)); + + + manifest.append(new BlockwiseGemmReferenceOperation< + Provider::kReferenceHost, + ElementA_, + cutlass::layout::ColumnMajor, + cutlass::layout::ColumnMajor, + ElementSFA_, + ElementB_, + cutlass::layout::RowMajor, + cutlass::layout::RowMajor, + ElementSFB_, + ElementC_, + cutlass::layout::ColumnMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(SFMVecSize, SFNVecSize, SFKVecSize)); + + manifest.append(new BlockwiseGemmReferenceOperation< + Provider::kReferenceHost, + ElementA_, + cutlass::layout::ColumnMajor, + cutlass::layout::ColumnMajor, + ElementSFA_, + ElementB_, + cutlass::layout::RowMajor, + cutlass::layout::ColumnMajor, + ElementSFB_, + ElementC_, + cutlass::layout::ColumnMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(SFMVecSize, SFNVecSize, SFKVecSize)); + + +} + +template +void initialize_blockwise_gemm_reference_operations_given_C_and_D(Manifest &manifest) { + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 1, 1 , 128); + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 1, 128, 128); + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 1, 128); + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 128, 128); + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 64, 1, 128); + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 64, 128, 128); + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 32, 128); + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 1, 32, 128); + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 64, 128); + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, 
float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 1, 64, 128); + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 256, 128); + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 1, 256, 128); + + + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 1, 1 , 128); + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 1, 128, 128); + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 1, 128); + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 128, 128); + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 64, 1 , 128); + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 64, 128, 128); + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 32, 128); + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 1, 32, 128); + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 64, 128); + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 1, 64, 128); + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 256, 128); + make_blockwise_gemm< + float_e4m3_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 1, 256, 128); + + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 1, 1 , 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 1, 128, 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 1, 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, 
ElementD /*D*/ + >(manifest, 128, 128, 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 64, 1, 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 64, 128, 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 32, 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 1, 32, 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 64, 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 1, 64, 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 256, 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e4m3_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 1, 256, 128); + + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 1, 1 , 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 1, 128, 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 1, 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 128, 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 64, 1 , 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 64, 128, 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 32, 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 1, 32, 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 64, 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 1, 64, 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, 
float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 128, 256, 128); + make_blockwise_gemm< + float_e5m2_t /*A*/, float /*SFA*/, float_e5m2_t /*B*/, float /*SFB*/, + ElementC /*D*/, float /*Compute*/, float /*Accum*/, ElementD /*D*/ + >(manifest, 1, 256, 128); + +} + + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/library/src/reference/conv_reference_operation.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/library/src/reference/conv_reference_operation.h new file mode 100644 index 0000000000000000000000000000000000000000..240fe18d16a27778bf75e0c02f99d251c096353f --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/library/src/reference/conv_reference_operation.h @@ -0,0 +1,636 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +/* \file + \brief Defines operations for all CONV operation kinds in CUTLASS Library +*/ + +#pragma once + +#include +#include +#include + +#include "cutlass/cutlass.h" + +#include "cutlass/library/library.h" +#include "cutlass/library/manifest.h" +#include "cutlass/library/util.h" +#include "library_internal.h" + +#include "cutlass/conv/convolution.h" +#include "cutlass/util/reference/host/convolution.h" +#include "cutlass/util/reference/device/convolution.h" + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template < + Provider kProvider, + cutlass::conv::Operator ConvolutionalOperator, + int ConvDim, + typename ElementA_, + typename LayoutA_, + typename ElementB_, + typename LayoutB_, + typename ElementC_, + typename LayoutC_, + typename ElementCompute_, + typename ElementAccumulator_ = ElementCompute_, + typename ConvertOp_ = NumericConverter<ElementC_, ElementCompute_>, + typename InnerProductOp_ = multiply_add<ElementAccumulator_> +> +struct ConvReferenceDispatcher; + +/// Dispatcher for Conv2d (partially specialized for kConvDim == 2) +template < + Provider kProvider, + cutlass::conv::Operator kConvolutionalOperator, + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator, + typename ConvertOp, + typename InnerProductOp +> +struct ConvReferenceDispatcher< + kProvider, + kConvolutionalOperator, + 2, + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementCompute, + ElementAccumulator, + ConvertOp, + InnerProductOp> { + + static Status dispatch( + void const *configuration, + ElementA *ptr_A, + ElementB *ptr_B, + ElementC *ptr_C, + ElementC *ptr_D, + ElementCompute alpha, + ElementCompute beta, + cudaStream_t stream = nullptr + ) { + + Conv2dConfiguration const &config = + *static_cast<Conv2dConfiguration const *>(configuration); + + // TODO: make the code below more general; it is currently fixed to the NHWC layout.
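+ // Note: cutlass::layout::TensorNHWC keeps a rank-3 stride (the contiguous C dimension is implicit), + // so only three stride values per tensor are read from the configuration below.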
+ layout::TensorNHWC layout_a; + layout::TensorNHWC layout_b; + layout::TensorNHWC layout_c; + + layout_a.stride() = + make_Coord(int32_t(config.stride_a[0]), + int32_t(config.stride_a[1]), + int32_t(config.stride_a[2])); + + layout_b.stride() = + make_Coord(int32_t(config.stride_b[0]), + int32_t(config.stride_b[1]), + int32_t(config.stride_b[2])); + + layout_c.stride() = + make_Coord(int32_t(config.stride_c[0]), + int32_t(config.stride_c[1]), + int32_t(config.stride_c[2])); + + if (kProvider == Provider::kReferenceHost) { + + cutlass::reference::host::Conv2d< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC , + LayoutC, + ElementCompute, + ElementAccumulator, + ElementC, + ConvertOp, + InnerProductOp + >( + kConvolutionalOperator, + config.problem_size, + {ptr_A, layout_a}, + {ptr_B, layout_b}, + {ptr_C, layout_c}, + {ptr_D, layout_c}, + alpha, + beta + ); + + return Status::kSuccess; + } + else if (kProvider == Provider::kReferenceDevice) { + return cutlass::reference::device::Conv2d< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementCompute, + ElementAccumulator, + ConvertOp, + InnerProductOp + >( + kConvolutionalOperator, + config.problem_size, + {ptr_A, layout_a}, + {ptr_B, layout_b}, + {ptr_C, layout_c}, + {ptr_D, layout_c}, + alpha, + beta, + stream + ); + } + return Status::kErrorNotSupported; + } +}; + +/// Dispatcher for Conv3d (partially specialized for kConvDim == 3) +template < + Provider kProvider, + cutlass::conv::Operator kConvolutionalOperator, + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator, + typename ConvertOp, + typename InnerProductOp +> +struct ConvReferenceDispatcher< + kProvider, + kConvolutionalOperator, + 3, + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementCompute, + ElementAccumulator, + ConvertOp, + InnerProductOp> { + + static Status dispatch( + void const *configuration, + ElementA *ptr_A, + ElementB *ptr_B, + ElementC *ptr_C, + ElementC *ptr_D, + ElementCompute alpha, + ElementCompute beta, + cudaStream_t stream = nullptr + ) { + + Conv3dConfiguration const &config = + *static_cast(configuration); + + ConvKind const conv_kind = ConvKindMap::kId; + + if (kProvider == Provider::kReferenceHost) { + cutlass::reference::host::Conv3d< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC , + LayoutC, + ElementCompute, + ElementAccumulator, + ConvertOp, + InnerProductOp + >( + kConvolutionalOperator, + config.problem_size, + {ptr_A, config.layout_a(conv_kind)}, + {ptr_B, config.layout_b(conv_kind)}, + {ptr_C, config.layout_c(conv_kind)}, + {ptr_D, config.layout_c(conv_kind)}, + alpha, + beta + ); + + return Status::kSuccess; + } + else if (kProvider == Provider::kReferenceDevice) { + return cutlass::reference::device::Conv3d< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementCompute, + ElementAccumulator, + ConvertOp, + InnerProductOp + >( + kConvolutionalOperator, + config.problem_size, + {ptr_A, config.layout_a(conv_kind)}, + {ptr_B, config.layout_b(conv_kind)}, + {ptr_C, config.layout_c(conv_kind)}, + {ptr_D, config.layout_c(conv_kind)}, + alpha, + beta, + stream + ); + } + return Status::kErrorNotSupported; + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + Provider Provider_, + cutlass::conv::Operator ConvolutionalOperator, + int ConvDim, + 
typename ElementA_, + typename LayoutA_, + typename ElementB_, + typename LayoutB_, + typename ElementC_, + typename LayoutC_, + typename ElementCompute_, + typename ElementAccumulator_ = ElementCompute_, + typename ConvertOp_ = NumericConverter, + typename InnerProductOp_ = multiply_add +> +class ConvReferenceOperation : public Operation { +public: + static Provider const kProvider = Provider_; + static cutlass::conv::Operator const kConvolutionalOperator = ConvolutionalOperator; + static int const kConvDim = ConvDim; + + using ElementA = ElementA_; + using LayoutA = LayoutA_; + using ElementB = ElementB_; + using LayoutB = LayoutB_; + using ElementC = ElementC_; + using LayoutC = LayoutC_; + using ElementCompute = ElementCompute_; + using ElementAccumulator = ElementAccumulator_; + using ConvertOp = ConvertOp_; + using InnerProductOp = InnerProductOp_; + +protected: + + /// Storage for the name string + std::string name_; + + /// + ConvDescription description_; + +public: + + /// Constructor + ConvReferenceOperation() { + + // Basic information + description_.provider = kProvider; + description_.kind = (kConvDim == 2 ? OperationKind::kConv2d : OperationKind::kConv3d); + description_.conv_kind = ConvKindMap::kId; + description_.conv_dim = kConvDim; + + // Tensor description + description_.A = make_TensorDescription(); + description_.B = make_TensorDescription(); + description_.C = make_TensorDescription(); + + // Epilogue compute and accumulator type description + description_.element_epilogue = NumericTypeMap::kId; + + description_.tile_description.math_instruction.element_accumulator = + NumericTypeMap::kId; + + // Iterator algorithm for convolution reference + description_.iterator_algorithm = IteratorAlgorithmID::kNone; + + // Compute capability for convolution reference + description_.tile_description.minimum_compute_capability = + (kProvider == Provider::kReferenceDevice ? 
50 : 0); + + description_.tile_description.maximum_compute_capability = 1024; + + // Procedural name + std::stringstream ss; + + ss << "conv" << kConvDim << "d_" << to_string(description_.conv_kind) + << "_reference_" << to_string(description_.provider) + << "_" << to_string(description_.A.element) << to_string(description_.A.layout) + << "_" << to_string(description_.B.element) << to_string(description_.B.layout) + << "_" << to_string(description_.C.element) << to_string(description_.C.layout) + << "_" << to_string(description_.tile_description.math_instruction.element_accumulator); + + name_ = ss.str(); + + description_.name = name_.c_str(); + + // Epilogue compute and accumulator type description + description_.element_epilogue = NumericTypeMap::kId; + + description_.tile_description.math_instruction.element_accumulator = + NumericTypeMap::kId; + } + + /// Returns the description of the GEMM operation + virtual OperationDescription const & description() const { + return description_; + } + + virtual Status can_implement( + void const *configuration, + void const *arguments) const { + + return Status::kSuccess; + } + + virtual uint64_t get_host_workspace_size( + void const *configuration) const { + + switch (kConvDim) { + case 2: + return sizeof(Conv2dConfiguration); + case 3: + return sizeof(Conv3dConfiguration); + default: + break; + } + + return 0; + } + + virtual uint64_t get_device_workspace_size( + void const *configuration, + void const *arguments = nullptr) const { + + return 0; + } + + virtual Status initialize( + void const *configuration, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const { + + std::memcpy(host_workspace, configuration, get_host_workspace_size(configuration)); + + return Status::kSuccess; + } + + virtual Status run( + void const *arguments, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const { + + ConvArguments const &args = *static_cast(arguments); + + ElementCompute alpha; + ElementCompute beta; + + alpha = *static_cast(args.alpha); + beta = *static_cast(args.beta); + + // TODO - respect pointer mode + + // Invoke 2D or 3D convolution + return detail::ConvReferenceDispatcher< + kProvider, + kConvolutionalOperator, + kConvDim, + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementCompute, + ElementAccumulator, + ConvertOp, + InnerProductOp + >::dispatch( + host_workspace, + static_cast(const_cast(args.A)), + static_cast(const_cast(args.B)), + static_cast(const_cast(args.C)), + static_cast(args.D), + alpha, + beta, + stream + ); + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Constructs Fprop reference operators. 
+template < + int kConvDim, + typename ElementA_, + typename LayoutA_, + typename ElementB_, + typename LayoutB_, + typename ElementC_, + typename LayoutC_, + typename ElementCompute_, + typename ElementAccumulator_ = ElementCompute_, + typename ConvertOp_ = NumericConverter, + typename InnerProductOp_ = multiply_add +> +void make_conv_fprop(Manifest &manifest) { +#if !defined(CUTLASS_PROFILER_DISABLE_REFERENCE) + manifest.append(new ConvReferenceOperation< + Provider::kReferenceHost, + cutlass::conv::Operator::kFprop, + kConvDim, + ElementA_, LayoutA_, + ElementB_, LayoutB_, + ElementC_, LayoutC_, + ElementCompute_, + ElementAccumulator_, + ConvertOp_, + InnerProductOp_ + >); + + manifest.append(new ConvReferenceOperation< + Provider::kReferenceDevice, + cutlass::conv::Operator::kFprop, + kConvDim, + ElementA_, LayoutA_, + ElementB_, LayoutB_, + ElementC_, LayoutC_, + ElementCompute_, + ElementAccumulator_, + ConvertOp_, + InnerProductOp_ + >); +#endif // !defined(CUTLASS_PROFILER_DISABLE_REFERENCE) +} + +/// Constructs Dgrad and Wgrad reference operators. +template < + int kConvDim, + typename ElementA_, + typename LayoutA_, + typename ElementB_, + typename LayoutB_, + typename ElementC_, + typename LayoutC_, + typename ElementCompute_, + typename ElementAccumulator_ = ElementCompute_, + typename ConvertOp_ = NumericConverter, + typename InnerProductOp_ = multiply_add +> +void make_conv_backwards(Manifest &manifest) { +#if !defined(CUTLASS_PROFILER_DISABLE_REFERENCE) + manifest.append(new ConvReferenceOperation< + Provider::kReferenceHost, + cutlass::conv::Operator::kDgrad, + kConvDim, + ElementA_, LayoutA_, + ElementB_, LayoutB_, + ElementC_, LayoutC_, + ElementCompute_, + ElementAccumulator_, + ConvertOp_, + InnerProductOp_ + >); + + manifest.append(new ConvReferenceOperation< + Provider::kReferenceDevice, + cutlass::conv::Operator::kDgrad, + kConvDim, + ElementA_, LayoutA_, + ElementB_, LayoutB_, + ElementC_, LayoutC_, + ElementCompute_, + ElementAccumulator_, + ConvertOp_, + InnerProductOp_ + >); + + manifest.append(new ConvReferenceOperation< + Provider::kReferenceHost, + cutlass::conv::Operator::kWgrad, + kConvDim, + ElementA_, LayoutA_, + ElementB_, LayoutB_, + ElementC_, LayoutC_, + ElementCompute_, + ElementAccumulator_, + ConvertOp_, + InnerProductOp_ + >); + + manifest.append(new ConvReferenceOperation< + Provider::kReferenceDevice, + cutlass::conv::Operator::kWgrad, + kConvDim, + ElementA_, LayoutA_, + ElementB_, LayoutB_, + ElementC_, LayoutC_, + ElementCompute_, + ElementAccumulator_, + ConvertOp_, + InnerProductOp_ + >); +#endif // !defined(CUTLASS_PROFILER_DISABLE_REFERENCE) +} + +/// Six operators for the price of one. 
+template < + int kConvDim, + typename ElementA_, + typename LayoutA_, + typename ElementB_, + typename LayoutB_, + typename ElementC_, + typename LayoutC_, + typename ElementCompute_, + typename ElementAccumulator_ = ElementCompute_, + typename ConvertOp_ = NumericConverter, + typename InnerProductOp_ = multiply_add +> +void make_conv_all(Manifest &manifest) { + + make_conv_fprop< + kConvDim, + ElementA_, LayoutA_, + ElementB_, LayoutB_, + ElementC_, LayoutC_, + ElementCompute_, + ElementAccumulator_, + ConvertOp_, + InnerProductOp_ + >(manifest); + + make_conv_backwards< + kConvDim, + ElementA_, LayoutA_, + ElementB_, LayoutB_, + ElementC_, LayoutC_, + ElementCompute_, + ElementAccumulator_, + ConvertOp_, + InnerProductOp_ + >(manifest); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/library/src/reference/gemm_reference_operation.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/library/src/reference/gemm_reference_operation.h new file mode 100644 index 0000000000000000000000000000000000000000..e07158b0602eef1d71cfdca95323b3da60553747 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/library/src/reference/gemm_reference_operation.h @@ -0,0 +1,543 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +/* \file + \brief Defines reference operations for GEMM operation kinds in CUTLASS Library +*/ + +#pragma once + +#include +#include +#include + +#include "cutlass/cutlass.h" + +#include "cutlass/library/library.h" +#include "cutlass/library/manifest.h" +#include "cutlass/library/util.h" +#include "library_internal.h" + +#include "cutlass/util/reference/host/gemm_complex.h" +#include "cutlass/util/reference/device/gemm_complex.h" + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + Provider Provider_, + typename ElementA_, + typename LayoutA_, + cutlass::ComplexTransform TransformA, + typename ElementB_, + typename LayoutB_, + cutlass::ComplexTransform TransformB, + typename ElementC_, + typename LayoutC_, + typename ElementCompute_, + typename ElementAccumulator_ = ElementCompute_, + typename ElementD_ = ElementC_, + typename ConvertOp_ = NumericConverter, + typename InnerProductOp_ = multiply_add +> +class GemmReferenceOperation : public Operation { +public: + static Provider const kProvider = Provider_; + + using ElementA = ElementA_; + using LayoutA = LayoutA_; + using TensorRefA = TensorRef; + static cutlass::ComplexTransform const kTransformA = TransformA; + using ElementB = ElementB_; + using LayoutB = LayoutB_; + using TensorRefB = TensorRef; + static cutlass::ComplexTransform const kTransformB = TransformB; + using ElementC = ElementC_; + using LayoutC = LayoutC_; + using ElementD = ElementD_; + using TensorRefC = TensorRef; + using TensorRefD = TensorRef; + using ElementCompute = ElementCompute_; + using ElementAccumulator = ElementAccumulator_; + using ConvertOp = ConvertOp_; + using InnerProductOp = InnerProductOp_; + +protected: + + /// Storage for the name string + std::string name_; + + /// + GemmDescription description_; + +public: + + /// Constructor + GemmReferenceOperation() { + + // Basic information + description_.provider = kProvider; + description_.kind = OperationKind::kGemm; + description_.gemm_kind = GemmKind::kUniversal; + + // Tensor description + description_.A = make_TensorDescription(); + description_.transform_A = ComplexTransformMap::kId; + description_.B = make_TensorDescription(); + description_.transform_B = ComplexTransformMap::kId; + description_.C = make_TensorDescription(); + description_.D = make_TensorDescription(); + + // Epilogue compute and accumulator type description + description_.element_epilogue = NumericTypeMap::kId; + + description_.tile_description.math_instruction.element_accumulator = + NumericTypeMap::kId; + + // Compute capability for gemm reference + description_.tile_description.minimum_compute_capability = + (kProvider == Provider::kReferenceDevice ? 
50 : 0); + + description_.tile_description.maximum_compute_capability = 1024; + + // Procedural name + std::stringstream ss; + + ss << "gemm" + << "_reference_" << to_string(description_.provider) + << "_" << to_string(description_.A.element) << to_string(description_.A.layout) + << "_" << to_string(description_.B.element) << to_string(description_.B.layout) + << "_" << to_string(description_.C.element) << to_string(description_.C.layout) + << "_" << to_string(description_.tile_description.math_instruction.element_accumulator); + + name_ = ss.str(); + + description_.name = name_.c_str(); + + // Epilogue compute and accumulator type description + description_.element_epilogue = NumericTypeMap::kId; + + description_.tile_description.math_instruction.element_accumulator = + NumericTypeMap::kId; + } + + /// Returns the description of the GEMM operation + virtual OperationDescription const & description() const { + return description_; + } + + virtual Status can_implement( + void const *configuration, + void const *arguments) const { + + return Status::kSuccess; + } + + virtual uint64_t get_host_workspace_size( + void const *configuration) const { + + return sizeof(GemmUniversalConfiguration); + } + + virtual uint64_t get_device_workspace_size( + void const *configuration, + void const *arguments = nullptr) const { + + return 0; + } + + virtual Status initialize( + void const *configuration, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const { + + std::memcpy(host_workspace, configuration, get_host_workspace_size(configuration)); + + return Status::kSuccess; + } + + virtual Status run( + void const *arguments, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const { + + GemmUniversalConfiguration const &config = *static_cast(host_workspace); + GemmUniversalArguments const &args = *static_cast(arguments); + + TensorRefA ref_A{static_cast(const_cast(args.A)), LayoutA(int(config.lda))}; + TensorRefB ref_B{static_cast(const_cast(args.B)), LayoutB(int(config.ldb))}; + TensorRefC ref_C{static_cast(const_cast(args.C)), LayoutC(int(config.ldc))}; + TensorRefD ref_D{static_cast(args.D), LayoutC(int(config.ldd))}; + + if (kProvider == Provider::kReferenceHost) { + + cutlass::reference::host::GemmComplex< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementCompute, + ElementAccumulator, + ElementD, + ConvertOp, + InnerProductOp + >( + config.problem_size, + *static_cast(args.alpha), + ref_A, + kTransformA, + ref_B, + kTransformB, + *static_cast(args.beta), + ref_C, + ref_D, + ElementAccumulator(), + ((config.mode == library::GemmUniversalMode::kBatched) ? config.batch_count : 1), + args.batch_stride_A, + args.batch_stride_B, + args.batch_stride_C, + args.batch_stride_D + ); + + return Status::kSuccess; + } + else if (kProvider == Provider::kReferenceDevice) { + + cutlass::reference::device::GemmComplex< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementCompute, + ElementAccumulator, + ElementD, + ConvertOp, + InnerProductOp + >( + config.problem_size, + *static_cast(args.alpha), + ref_A, + kTransformA, + ref_B, + kTransformB, + *static_cast(args.beta), + ref_C, + ref_D, + ElementAccumulator(), + ((config.mode == library::GemmUniversalMode::kBatched) ? 
config.batch_count : 1), + args.batch_stride_A, + args.batch_stride_B, + args.batch_stride_C, + args.batch_stride_D + ); + + return Status::kSuccess; + } + + return Status::kErrorNotSupported; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA_, + typename LayoutA_, + cutlass::ComplexTransform TransformA, + typename ElementB_, + typename LayoutB_, + cutlass::ComplexTransform TransformB, + typename ElementC_, + typename LayoutC_, + typename ElementCompute_, + typename ElementAccumulator_ = ElementCompute_, + typename ElementD_ = ElementC_, + typename ConvertOp_ = NumericConverter, + typename InnerProductOp_ = multiply_add +> +void make_gemm(Manifest &manifest) { +#if !defined(CUTLASS_PROFILER_DISABLE_REFERENCE) + manifest.append(new GemmReferenceOperation< + Provider::kReferenceHost, + ElementA_, LayoutA_, TransformA, + ElementB_, LayoutB_, TransformB, + ElementC_, LayoutC_, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >); + + manifest.append(new GemmReferenceOperation< + Provider::kReferenceDevice, + ElementA_, LayoutA_, TransformA, + ElementB_, LayoutB_, TransformB, + ElementC_, LayoutC_, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >); +#endif +} + +/// Helper to create NN, NT, TN, and TT GEMM layouts. +template < + typename ElementA_, cutlass::ComplexTransform TransformA, + typename ElementB_, cutlass::ComplexTransform TransformB, + typename ElementC_, + typename ElementCompute_, + typename ElementAccumulator_ = ElementCompute_, + typename ElementD_ = ElementC_, + typename ConvertOp_ = NumericConverter, + typename InnerProductOp_ = multiply_add +> +void make_gemm_canonical_layouts(Manifest &manifest) { + + // M Major outputs + make_gemm< + ElementA_, cutlass::layout::ColumnMajor, TransformA, + ElementB_, cutlass::layout::ColumnMajor, TransformB, + ElementC_, cutlass::layout::ColumnMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(manifest); + + make_gemm< + ElementA_, cutlass::layout::ColumnMajor, TransformA, + ElementB_, cutlass::layout::RowMajor, TransformB, + ElementC_, cutlass::layout::ColumnMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(manifest); + + make_gemm< + ElementA_, cutlass::layout::RowMajor, TransformA, + ElementB_, cutlass::layout::ColumnMajor, TransformB, + ElementC_, cutlass::layout::ColumnMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(manifest); + + make_gemm< + ElementA_, cutlass::layout::RowMajor, TransformA, + ElementB_, cutlass::layout::RowMajor, TransformB, + ElementC_, cutlass::layout::ColumnMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(manifest); + + // N Major outputs + make_gemm< + ElementA_, cutlass::layout::ColumnMajor, TransformA, + ElementB_, cutlass::layout::ColumnMajor, TransformB, + ElementC_, cutlass::layout::RowMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(manifest); + + make_gemm< + ElementA_, cutlass::layout::ColumnMajor, TransformA, + ElementB_, cutlass::layout::RowMajor, TransformB, + ElementC_, cutlass::layout::RowMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(manifest); + + make_gemm< + ElementA_, cutlass::layout::RowMajor, TransformA, + ElementB_, 
cutlass::layout::ColumnMajor, TransformB, + ElementC_, cutlass::layout::RowMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(manifest); + + make_gemm< + ElementA_, cutlass::layout::RowMajor, TransformA, + ElementB_, cutlass::layout::RowMajor, TransformB, + ElementC_, cutlass::layout::RowMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(manifest); +} + + +/// Helper to create TN and interleaved layouts GEMM layouts. +template < + int InterleaveK, + typename ElementA_, + typename ElementB_, + typename ElementC_, + typename ElementCompute_, + typename ElementAccumulator_ = ElementCompute_, + typename ElementD_ = ElementC_, + typename ConvertOp_ = NumericConverter, + typename InnerProductOp_ = multiply_add +> +void make_gemm_interleaved_layouts(Manifest &manifest) { + + make_gemm< + ElementA_, cutlass::layout::RowMajor, cutlass::ComplexTransform::kNone, + ElementB_, cutlass::layout::ColumnMajor, cutlass::ComplexTransform::kNone, + ElementC_, cutlass::layout::ColumnMajor, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(manifest); + +} + +/// Helper to real-valued GEMM with canonical layouts +template < + typename ElementA_, + typename ElementB_, + typename ElementC_, + typename ElementCompute_, + typename ElementAccumulator_ = ElementCompute_, + typename ElementD_ = ElementC_, + typename ConvertOp_ = NumericConverter, + typename InnerProductOp_ = multiply_add +> +void make_gemm_real_canonical_layouts(Manifest &manifest) { + make_gemm_canonical_layouts< + ElementA_, cutlass::ComplexTransform::kNone, + ElementB_, cutlass::ComplexTransform::kNone, + ElementC_, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(manifest); +} + +// Helper to create all complex transformation permutations +template < + typename ElementA_, + typename ElementB_, + typename ElementC_, + typename ElementCompute_, + typename ElementAccumulator_ = ElementCompute_, + typename ElementD_ = ElementC_, + typename ConvertOp_ = NumericConverter, + typename InnerProductOp_ = multiply_add +> +void make_gemm_complex_canonical_layouts(Manifest &manifest) { + + make_gemm_canonical_layouts< + ElementA_, cutlass::ComplexTransform::kNone, + ElementB_, cutlass::ComplexTransform::kNone, + ElementC_, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(manifest); + + make_gemm_canonical_layouts< + ElementA_, cutlass::ComplexTransform::kConjugate, + ElementB_, cutlass::ComplexTransform::kConjugate, + ElementC_, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(manifest); + + make_gemm_canonical_layouts< + ElementA_, cutlass::ComplexTransform::kNone, + ElementB_, cutlass::ComplexTransform::kConjugate, + ElementC_, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(manifest); + + make_gemm_canonical_layouts< + ElementA_, cutlass::ComplexTransform::kConjugate, + ElementB_, cutlass::ComplexTransform::kNone, + ElementC_, + ElementCompute_, + ElementAccumulator_, + ElementD_, + ConvertOp_, + InnerProductOp_ + >(manifest); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// + diff --git 
a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/library/src/sparse_gemm_operation_3x.hpp b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/library/src/sparse_gemm_operation_3x.hpp new file mode 100644 index 0000000000000000000000000000000000000000..01caa11e229ffd9109b0973dcca01064df448fa3 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/library/src/sparse_gemm_operation_3x.hpp @@ -0,0 +1,504 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines operations for all GEMM operation kinds in CUTLASS Library. 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/detail/collective.hpp" +#include "cutlass/array.h" +#include "cutlass/array_subbyte.h" +#include "cutlass/library/library.h" +#include "cutlass/transform/kernel/sparse_gemm_compressor.hpp" // StructuredSparseCompressor +#include "cutlass/transform/device/transform_universal_adapter.hpp" // TransformUniversalAdapter +#include "cutlass/util/packed_stride.hpp" // make_cute_packed_stride +#include "gemm_operation_3x.hpp" +#include "library_internal.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/mixed_dtype_utils.hpp" +#include "cutlass/util/device_memory.h" +#include "cutlass/util/reference/device/tensor_fill.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cute/tensor.hpp" +#include + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +// Limitation & Assumptions: +// 1. The tensor must be densely packed. That is, lda is k if the tensor is k-major, +// and lda is m if the tensor is m-major. +// 2. The circular buffers for tensorA and tensorE may hold fewer entries than those for tensorB and the other tensors, +// because the problem_count is not available in get_device_workspace_size(). When the circular buffer is enabled, +// it will use at least 192MB of memory. +template <typename Operator_> +class SparseGemmUniversal3xOperation : public GemmOperation3xBase<Operator_> { +public: + + using Operator = Operator_; + using OperatorArguments = typename Operator::Arguments; + using ElementA = typename Operator::ElementA; + using LayoutA = typename Operator::LayoutA; + using ElementB = typename Operator::ElementB; + using LayoutB = typename Operator::LayoutB; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + using ElementD = typename Operator::ElementD; + using LayoutD = typename Operator::LayoutD; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + + using CollectiveMainloop = typename Operator::CollectiveMainloop; + using CollectiveEpilogue = typename Operator::CollectiveEpilogue; + using ThreadEpilogueOp = typename CollectiveEpilogue::ThreadEpilogueOp; + + static constexpr bool IsRuntimeDataTypeA = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4<ElementA>(); + + static constexpr bool IsRuntimeDataTypeB = cutlass::gemm::collective::detail::is_sm10x_runtime_f8f6f4<ElementB>(); + + static_assert((IsRuntimeDataTypeA && IsRuntimeDataTypeB) || + (!IsRuntimeDataTypeA && !IsRuntimeDataTypeB), + "ElementA and ElementB in a GEMM kernel should be both runtime or both static."); + + static constexpr bool IsRuntimeDataType = IsRuntimeDataTypeA && IsRuntimeDataTypeB; + + using ElementE = typename CollectiveMainloop::ElementE; + using LayoutE = typename CollectiveMainloop::LayoutE; + using SparseConfig = typename CollectiveMainloop::SparseConfig; + using LayoutATag = decltype(SparseConfig::deduce_layoutA_tag(typename CollectiveMainloop::LayoutA{})); + using CompressorUtility = cutlass::transform::kernel::StructuredSparseCompressorUtility< + cute::Shape<int, int, int, int>, + ElementA, + LayoutATag, + SparseConfig>; + using CompressorKernel = cutlass::transform::kernel::StructuredSparseCompressor< + cute::Shape<int, int, int, int>, + ElementA, + LayoutATag, + SparseConfig, + typename
Operator::ArchTag>; + + using Compressor = cutlass::transform::device::TransformUniversalAdapter; + +public: + + /// Constructor + SparseGemmUniversal3xOperation(char const *name = "unknown_gemm"): + GemmOperation3xBase(name, GemmKind::kUniversal) {} + +protected: + + /// Constructs the arguments structure given the configuration and arguments + static Status construct_arguments_( + OperatorArguments &operator_args, GemmUniversalConfiguration const *configuration) { + // NOTE: GemmUniversalConfiguration does not contain problem shapes or batch strides + // Do nothing here and construct kernel arguments in update_arguments_ instead + // We also cannot construct TMA descriptors without all the arguments available + + operator_args.mode = configuration->mode; + return Status::kSuccess; + } + + template + struct UpdateFusionArgs { + static Status update_(FusionArgs const& fusion_args, GemmUniversalArguments const &arguments) { + // If a custom EVT is instantiated then it is the users's responsibility + // to ensure alpha and beta are updated appropriately + return Status::kSuccess; + } + }; + + template + struct UpdateFusionArgs> { + static Status update_(FusionArgs& fusion_args, GemmUniversalArguments const &arguments) { + if (arguments.pointer_mode == ScalarPointerMode::kHost) { + fusion_args.alpha = *static_cast(arguments.alpha); + fusion_args.beta = *static_cast(arguments.beta); + fusion_args.alpha_ptr = nullptr; + fusion_args.beta_ptr = nullptr; + + return Status::kSuccess; + } + else if (arguments.pointer_mode == ScalarPointerMode::kDevice) { + fusion_args.alpha = 0; + fusion_args.beta = 0; + fusion_args.alpha_ptr = static_cast(arguments.alpha); + fusion_args.beta_ptr = static_cast(arguments.beta); + + return Status::kSuccess; + } + else { + return Status::kErrorInvalidProblem; + } + } + }; + + /// Constructs the arguments structure given the configuration and arguments + static Status update_arguments_( + OperatorArguments &operator_args, + GemmUniversalArguments const *arguments, + CompressorUtility const& compressor_utility, + void* device_a_compressed_ptr = nullptr, + void* device_e_ptr = nullptr) { + Status status = Status::kSuccess; + + status = UpdateFusionArgs::update_( + operator_args.epilogue.thread, *arguments); + if (status != Status::kSuccess) { + return status; + } + + operator_args.problem_shape = cute::make_shape( + arguments->problem_size.m(), + arguments->problem_size.n(), + arguments->problem_size.k(), + arguments->batch_count); + + // update arguments + + if constexpr (IsRuntimeDataType) { + using ArrayElementA = typename Operator::GemmKernel::CollectiveMainloop::ArrayElementA; + using ArrayElementB = typename Operator::GemmKernel::CollectiveMainloop::ArrayElementB; + operator_args.mainloop.ptr_A = static_cast(device_a_compressed_ptr); + operator_args.mainloop.ptr_B = static_cast(arguments->B); + + std::unordered_map mapping = { + {RuntimeDatatype::kE4M3, cute::UMMA::MXF8F6F4Format::E4M3}, + {RuntimeDatatype::kE5M2, cute::UMMA::MXF8F6F4Format::E5M2}, + {RuntimeDatatype::kE3M2, cute::UMMA::MXF8F6F4Format::E3M2}, + {RuntimeDatatype::kE2M1, cute::UMMA::MXF8F6F4Format::E2M1} + }; + + auto iter_runtime_a = mapping.find(arguments->runtime_input_datatype_a); + auto iter_runtime_b = mapping.find(arguments->runtime_input_datatype_b); + + if (iter_runtime_a != mapping.end()) { + operator_args.mainloop.runtime_data_type_a = iter_runtime_a->second; + } else { + assert("invalid runtime argument for datatype A!"); + } + + if (iter_runtime_b != mapping.end()) { + 
operator_args.mainloop.runtime_data_type_b = iter_runtime_b->second; + } else { + assert("invalid runtime argument for datatype B!"); + } + + } + else { + operator_args.mainloop.ptr_A = static_cast(device_a_compressed_ptr); + operator_args.mainloop.ptr_B = static_cast(arguments->B); + } + operator_args.mainloop.ptr_E = static_cast(device_e_ptr); + operator_args.epilogue.ptr_C = static_cast(arguments->C); + operator_args.epilogue.ptr_D = static_cast(arguments->D); + + operator_args.mainloop.layout_a = compressor_utility.fill_layoutA_from_compressor(); + operator_args.mainloop.layout_e = compressor_utility.fill_layoutE_from_compressor(); + operator_args.mainloop.dB = cute::make_int_tuple_from( + arguments->ldb, arguments->batch_stride_B); + operator_args.epilogue.dC = cute::make_int_tuple_from( + arguments->ldc, arguments->batch_stride_C); + operator_args.epilogue.dD = operator_args.epilogue.dC; + + /* Query device SM count and max active clusters to pass onto the kernel as an argument, where needed */ + operator_args.hw_info.sm_count = arguments->sm_count; + if constexpr (!std::is_const_v) { + operator_args.scheduler.max_swizzle_size = arguments->swizzle_size; + } + + if constexpr (!std::is_const_v) { + using Enum_t = decltype(operator_args.scheduler.raster_order); + switch (arguments->raster_order) { + case RasterOrder::kAlongN: + operator_args.scheduler.raster_order = Enum_t::AlongN; + break; + case RasterOrder::kAlongM: + operator_args.scheduler.raster_order = Enum_t::AlongM; + break; + default: + operator_args.scheduler.raster_order = Enum_t::Heuristic; + } + } + + if constexpr (std::is_same_v) { + operator_args.scheduler.splits = arguments->split_k_slices; + } + + if constexpr (Operator::ArchTag::kMinComputeCapability >= 100) { + operator_args.hw_info.cluster_shape = dim3( + arguments->cluster_shape.m(), + arguments->cluster_shape.n(), + arguments->cluster_shape.k()); + operator_args.hw_info.cluster_shape_fallback = dim3( + arguments->cluster_shape_fallback.m(), + arguments->cluster_shape_fallback.n(), + arguments->cluster_shape_fallback.k()); + } + return status; + } + +public: + + /// Returns success if the operation can proceed + Status can_implement( + void const *configuration_ptr, void const *arguments_ptr) const override { + + GemmUniversalConfiguration const *configuration = + static_cast(configuration_ptr); + GemmUniversalArguments const *arguments = + static_cast(arguments_ptr); + + OperatorArguments args; + auto problem_shape_MNKL = cute::make_shape( + configuration->problem_size.m(), + configuration->problem_size.n(), + configuration->problem_size.k(), + configuration->batch_count); + + const int M = configuration->problem_size.m(); + const int N = configuration->problem_size.n(); + const int K = configuration->problem_size.k(); + const int L = configuration->batch_count; + using StrideA = typename CompressorUtility::StrideA; + auto dA = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); + compressor_utility.set_problem_size(problem_shape_MNKL, dA); + auto status = update_arguments_(args, arguments, compressor_utility); + if (status != Status::kSuccess) { + return status; + } + + // can_implement rules may need access to problem shape + args.problem_shape = problem_shape_MNKL; + return Operator::can_implement(args); + } + + /// Gets the host-side workspace + uint64_t get_host_workspace_size(void const *) const override { + // Memory to hold operator + host_op_workspace_size = sizeof(Operator); + + // Memory to hold result of 
`.structure_sparse_zero_mask_fill()` + tensor_a_size = compressor_utility.get_raw_tensor_A_bytes(); + + // NOTE: order here is the order of workspace partition + const uint64_t size = host_op_workspace_size + tensor_a_size; + + return size; + } + + /// Gets the device-side workspace + uint64_t get_device_workspace_size( + void const *configuration_ptr,void const *arguments_ptr) const override { + + OperatorArguments args; + auto status = update_arguments_( + args, static_cast(arguments_ptr), compressor_utility); + if (status != Status::kSuccess) { + return 0; + } + + typename Compressor::Arguments compress_arguments { + {compressor_utility.M, 0, compressor_utility.K, compressor_utility.L}, + {/*Empty Not Use*/}, + {/*Empty Not Use*/} }; + + // Size for one iteration + // For multi-iteration, will need to multiply result of this function w/ actual problem_count + tensor_ac_size = compressor_utility.get_compressed_tensor_A_bytes(); + tensor_e_size = compressor_utility.get_tensor_E_bytes(); + device_op_workspace_size = Operator::get_workspace_size(args); + device_compress_workspace_size = Compressor::get_workspace_size(compress_arguments); + + // NOTE: order here is the order of workspace partition + device_per_iter_workspace_size = device_op_workspace_size + device_compress_workspace_size + tensor_ac_size + tensor_e_size; + + return device_per_iter_workspace_size; + } + + /// Initializes the workspace + Status initialize( + void const *configuration_ptr, + void *host_workspace, + void *device_workspace, + cudaStream_t stream = nullptr) const override { + return Status::kErrorInternal; + } + + Status initialize_with_profiler_workspace( + void const *configuration, + void *host_workspace, + void *device_workspace, + uint8_t **profiler_workspaces, + int problem_count_from_profiler, + cudaStream_t stream = nullptr) { + + iter_idx.resize(static_cast(configuration)->device_count, 0); + + // Set problem_count. + problem_count = problem_count_from_profiler; + + // * Host Ptr + auto* host_op_workspace_ptr = reinterpret_cast(host_workspace); + auto* host_a_raw_ptr = host_op_workspace_ptr + host_op_workspace_size; + + // * Construct Op + Operator *op = new (host_op_workspace_ptr) Operator; + + // * Device Ptr (1st iteration) + // Device workspace : | iter1 | iter2 | iter3 | .. 
| iterx | + // iteri : op_workspace | tensor_ac | tensor_e + auto* device_ptr_iter1 = static_cast(device_workspace); + auto* device_op_workspace_ptr_iter1 = device_ptr_iter1; + auto* device_compressor_workspace_ptr_iter1 = device_op_workspace_ptr_iter1 + device_op_workspace_size; + auto* device_a_compressed_ptr_iter1 = device_compressor_workspace_ptr_iter1 + device_compress_workspace_size; + auto* device_e_ptr_iter1 = device_a_compressed_ptr_iter1 + tensor_ac_size; + + // * Device A Raw Ptr + auto* device_a_raw_ptr = profiler_workspaces[0]; + + // * Random fill 50% of TensorA w/ zero following the structured sparse requirement + CUDA_CHECK(cudaMemcpyAsync(host_a_raw_ptr, device_a_raw_ptr, tensor_a_size, cudaMemcpyDeviceToHost, stream)); + compressor_utility.structure_sparse_zero_mask_fill(host_a_raw_ptr, 2000); + CUDA_CHECK(cudaMemcpyAsync(device_a_raw_ptr, host_a_raw_ptr, tensor_a_size, cudaMemcpyHostToDevice, stream)); + + CUDA_CHECK(cudaGetLastError()); + + // * Compress DTensorA and get DTensorAC & DTensorE + cutlass::KernelHardwareInfo hw_info; + CUDA_CHECK(cudaGetDevice(&hw_info.device_id)); + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + typename Compressor::Arguments arguments{ + {compressor_utility.M, 0, compressor_utility.K, compressor_utility.L}, + {device_a_raw_ptr, + compressor_utility.dA, + device_a_compressed_ptr_iter1, + device_e_ptr_iter1}, + {hw_info} + }; + + cutlass::Status status {cutlass::Status::kSuccess }; + + Compressor compressor_op; + status = compressor_op.can_implement(arguments); + if (status != Status::kSuccess) { + return status; + } + + status = compressor_op.initialize(arguments, device_compressor_workspace_ptr_iter1, stream); + if (status != Status::kSuccess) { + return status; + } + + status = compressor_op.run(stream); + if (status != Status::kSuccess) { + return status; + } + + // * Copy Iter1's DTensorAC DTensorE to each iteration's DTensorAC DTensorE + for (int iter_i = 1; iter_i < problem_count; iter_i++) { + // * Device AC E Ptr per iteration + // Device workspace : | iter1 | iter2 | iter3 | .. 
| iterx | + // iteri : op_workspace | tensor_ac | tensor_e + auto* device_ptr_iteri = static_cast(device_workspace) + device_per_iter_workspace_size * iter_i; + auto* device_op_workspace_ptr = device_ptr_iteri; + auto* device_compressor_workspace_ptr = device_op_workspace_ptr + device_op_workspace_size; + auto* device_a_compressed_ptr = device_compressor_workspace_ptr + device_compress_workspace_size; + auto* device_e_ptr = device_a_compressed_ptr + tensor_ac_size; + + CUDA_CHECK(cudaMemcpyAsync(device_a_compressed_ptr, device_a_compressed_ptr_iter1, tensor_ac_size, cudaMemcpyDeviceToDevice, stream)); + CUDA_CHECK(cudaMemcpyAsync(device_e_ptr, device_e_ptr_iter1, tensor_e_size, cudaMemcpyDeviceToDevice, stream)); + } + + CUDA_CHECK(cudaStreamSynchronize(stream)); + + CUDA_CHECK(cudaGetLastError()); + + return Status::kSuccess; + } + + /// Runs the kernel + Status run( + void const *arguments_ptr, + void *host_workspace, + void *device_workspace, + cudaStream_t stream = nullptr) const override { + + OperatorArguments operator_args; + + + const auto device_index = static_cast(arguments_ptr)->device_index; + + auto* device_ptr_iteri = static_cast(device_workspace) + device_per_iter_workspace_size * iter_idx[device_index]; + auto* device_op_workspace_ptr = device_ptr_iteri; + auto* device_compressor_workspace_ptr = device_op_workspace_ptr + device_op_workspace_size; + auto* device_a_compressed_ptr = device_compressor_workspace_ptr + device_compress_workspace_size; + auto* device_e_ptr = device_a_compressed_ptr + tensor_ac_size; + iter_idx[device_index] = (iter_idx[device_index] + 1) % problem_count; + + Status status = update_arguments_(operator_args, static_cast(arguments_ptr), compressor_utility, device_a_compressed_ptr, device_e_ptr ); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = static_cast(host_workspace); + // We need to call initialize() since we have to rebuild TMA desc for every new set of args + status = op->run(operator_args, device_op_workspace_ptr, stream, nullptr, + static_cast(arguments_ptr)->use_pdl); + return status; + } + +private: + // Variables that must change in the const functions. + mutable CompressorUtility compressor_utility; + mutable int problem_count = 1; + mutable std::vector iter_idx; + + mutable uint64_t tensor_ac_size = 0; + mutable uint64_t tensor_e_size = 0; + mutable uint64_t tensor_a_size = 0; + mutable uint64_t host_op_workspace_size = 0; + mutable uint64_t device_compress_workspace_size = 0; + mutable uint64_t device_op_workspace_size = 0; + mutable uint64_t device_per_iter_workspace_size = 0; +}; +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::library + +/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/library/src/symm_operation.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/library/src/symm_operation.h new file mode 100644 index 0000000000000000000000000000000000000000..c95d238a81f825dbbeae689ec452467cc8ca3afa --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/library/src/symm_operation.h @@ -0,0 +1,382 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines operations for all Symm operation kinds (Symm, Hemm) + in CUTLASS Library. + + +*/ + +#pragma once +#include +#include "cutlass/cutlass.h" + +#include "cutlass/gemm/device/symm.h" +#include "cutlass/gemm/kernel/default_symm_universal.h" + +#include "cutlass/library/library.h" +#include "library_internal.h" +#include "cutlass/core_io.h" +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class SymmOperationBase : public Operation { +public: + using Operator = Operator_; + using ElementA = typename Operator::ElementA; + using LayoutA = typename Operator::LayoutA; + using ElementB = typename Operator::ElementB; + using LayoutB = typename Operator::LayoutB; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + static BlasMode const kBlasMode = Operator::kBlasMode; + static SideMode const kSideModeA = Operator::kSideModeA; + static FillMode const kFillModeA = Operator::kFillModeA; + + using OperatorArguments = typename Operator::Arguments; + +protected: + + /// + SymmDescription description_; + +public: + + /// Constructor + SymmOperationBase(char const *name = "unknown_symm") { + + description_.name = name; + description_.provider = Provider::kCUTLASS; + description_.symm_kind = SymmKind::kUniversal; + description_.side_mode = kSideModeA; + description_.fill_mode = kFillModeA; + description_.blas_mode = kBlasMode; + + description_.kind = OperationKind::kSymm; + + description_.tile_description.threadblock_shape = make_Coord( + 
Operator::ThreadblockShape::kM, + Operator::ThreadblockShape::kN, + Operator::ThreadblockShape::kK); + + description_.tile_description.threadblock_stages = Operator::kStages; + + description_.tile_description.warp_count = make_Coord( + Operator::SymmKernel::WarpCount::kM, + Operator::SymmKernel::WarpCount::kN, + Operator::SymmKernel::WarpCount::kK); + + description_.tile_description.math_instruction.instruction_shape = make_Coord( + Operator::InstructionShape::kM, + Operator::InstructionShape::kN, + Operator::InstructionShape::kK); + + description_.tile_description.math_instruction.element_accumulator = + NumericTypeMap::kId; + + description_.tile_description.math_instruction.opcode_class = + OpcodeClassMap::kId; + + description_.tile_description.math_instruction.math_operation = + MathOperationMap::kId; + + description_.tile_description.minimum_compute_capability = + ArchMap::kMin; + + description_.tile_description.maximum_compute_capability = + ArchMap::kMax; + + description_.A = make_TensorDescription(Operator::kAlignmentA); + description_.B = make_TensorDescription(Operator::kAlignmentB); + description_.C = make_TensorDescription(Operator::kAlignmentC); + description_.element_epilogue = NumericTypeMap::kId; + + description_.split_k_mode = SplitKMode::kNone; + } + + /// Returns the description of the SYMM operation + virtual OperationDescription const & description() const { + return description_; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class SymmOperation : public SymmOperationBase { +public: + + using Operator = Operator_; + using ElementA = typename Operator::ElementA; + using LayoutA = typename Operator::LayoutA; + using ElementB = typename Operator::ElementB; + using LayoutB = typename Operator::LayoutB; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + + static BlasMode const kBlasMode = Operator::kBlasMode; + static SideMode const kSideModeA = Operator::kSideModeA; + static FillMode const kFillModeA = Operator::kFillModeA; + + using OperatorArguments = typename Operator::Arguments; + +public: + + /// Constructor + SymmOperation(char const *name = "unknown_symm"): + SymmOperationBase(name) { + + this->description_.symm_kind = SymmKind::kUniversal; + } + +protected: + + /// Constructs the arguments structure given the configuration and arguments + static Status construct_arguments_( + OperatorArguments &operator_args, + SymmConfiguration const *configuration) { + + //operator_args.mode = configuration->mode; + + operator_args.problem_size = configuration->problem_size; + operator_args.batch_count = configuration->batch_count; + + operator_args.lda = int(configuration->lda); + operator_args.ldb = int(configuration->ldb); + operator_args.ldc = int(configuration->ldc); + operator_args.ldd = int(configuration->ldd); + + return Status::kSuccess; + } + + /// Constructs the arguments structure given the configuration and arguments + static Status update_arguments_( + OperatorArguments &operator_args, + SymmArguments const *arguments) { + + if (arguments->pointer_mode == ScalarPointerMode::kHost) { + typename Operator::EpilogueOutputOp::Params params( + *static_cast(arguments->alpha), + *static_cast(arguments->beta) + ); + operator_args.epilogue = params; + } + else if (arguments->pointer_mode == 
ScalarPointerMode::kDevice){ + typename Operator::EpilogueOutputOp::Params params( + static_cast(arguments->alpha), + static_cast(arguments->beta) + ); + operator_args.epilogue = params; + } + else { + return Status::kErrorInvalidProblem; + } + + // update arguments + operator_args.ptr_A = arguments->A; + operator_args.ptr_B = arguments->B; + operator_args.ptr_C = arguments->C; + operator_args.ptr_D = arguments->D; + + operator_args.batch_stride_A = arguments->batch_stride_A; + operator_args.batch_stride_B = arguments->batch_stride_B; + operator_args.batch_stride_C = arguments->batch_stride_C; + operator_args.batch_stride_D = arguments->batch_stride_D; + + if (arguments->use_pdl) { + return Status::kErrorNotSupported; + } + + return Status::kSuccess; + } + +public: + + /// Returns success if the operation can proceed + virtual Status can_implement( + void const *configuration_ptr, + void const *arguments_ptr) const { + + SymmConfiguration const *configuration = + static_cast(configuration_ptr); + + SymmArguments const *arguments = + static_cast(arguments_ptr); + + OperatorArguments args; + + Status status = construct_arguments_(args, configuration); + + if (status != Status::kSuccess) { + return status; + } + + status = update_arguments_(args, arguments); + + if (status != Status::kSuccess) { + return status; + } + + return Operator::can_implement(args); + } + + /// Gets the host-side workspace + virtual uint64_t get_host_workspace_size( + void const *configuration) const { + + return sizeof(Operator); + } + + /// Gets the device-side workspace + virtual uint64_t get_device_workspace_size( + void const *configuration_ptr, + void const *arguments_ptr = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return 0; + } + + uint64_t size = Operator::get_workspace_size(args); + + return size; + } + + /// Initializes the workspace + virtual Status initialize( + void const *configuration_ptr, + void *host_workspace, + void *device_workspace, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = new (host_workspace) Operator; + + //std::cout << "initialize() library::SymmOperation" << std::endl; + //print_operator_args(args); + status = op->initialize(args, device_workspace, stream); + + return status; + } + + /// Runs the kernel + virtual Status run( + void const *arguments_ptr, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const { + OperatorArguments args; + + Status status = update_arguments_( + args, + static_cast(arguments_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = static_cast(host_workspace); + + bool need_swapped_matrices = (kSideModeA == SideMode::kLeft && + std::is_same::value) || + (kSideModeA == SideMode::kRight && + std::is_same::value); + if (need_swapped_matrices) { + status = op->update(args.swapped_matrices(), device_workspace); + } else { + status = op->update(args, device_workspace); + } + + if (status != Status::kSuccess) { + return status; + } + + //std::cout << "run() library::SymmOperation" << std::endl; + //print_operator_args(args); + status = op->run(stream); + + return status; + } + + /// Call print_operator_args from the Conv2dOperation::initialize() + // to dump arguments passed on to cutlass 
operator for debugging + void print_operator_args(OperatorArguments &operator_args) const { + std::cout << "SymmOperation::OperatorArguments" << std::endl + << " problem_size:" << std::endl + << operator_args.problem_size << std::endl + << " epilogue (alpha, beta): " + << operator_args.epilogue.alpha << ", " + << operator_args.epilogue.beta << std::endl + << " ref_A (ptr, {stride}): " + << operator_args.ptr_A << ", {" + << operator_args.lda << "}" << std::endl + << " ref_B (ptr, {stride}): " + << operator_args.ptr_B << ", {" + << operator_args.ldb << "}" << std::endl + << " ref_C (ptr, {stride}): " + << operator_args.ptr_C << ", {" + << operator_args.ldc << "}" << std::endl + << " ref_D (ptr, {stride}): " + << operator_args.ptr_D << ", {" + << operator_args.ldd << "}" << std::endl; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/library/src/trmm_operation.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/library/src/trmm_operation.h new file mode 100644 index 0000000000000000000000000000000000000000..d419723791ace5d90eb7955223be9db72bbc2c3c --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/library/src/trmm_operation.h @@ -0,0 +1,350 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines operations for all TRMM operation kinds in CUTLASS Library. 
+ + +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/gemm/device/trmm.h" +#include "cutlass/gemm/kernel/default_trmm_universal.h" +#include "cutlass/gemm/kernel/trmm_universal.h" + +#include "cutlass/library/library.h" +#include "library_internal.h" + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace library { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class TrmmOperationBase : public Operation { +public: + using Operator = Operator_; + using ElementA = typename Operator::ElementA; + using LayoutA = typename Operator::LayoutA; + static SideMode const kSideMode = Operator::kSideMode; + static FillMode const kFillMode = Operator::kFillMode; + static DiagType const kDiagType = Operator::kDiagType; + using ElementB = typename Operator::ElementB; + using LayoutB = typename Operator::LayoutB; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + + using OperatorArguments = typename Operator::Arguments; + +protected: + + /// + TrmmDescription description_; + +public: + + /// Constructor + TrmmOperationBase(char const *name = "unknown_trmm") { + + description_.name = name; + description_.provider = Provider::kCUTLASS; + description_.kind = OperationKind::kTrmm; + description_.trmm_kind = TrmmKind::kUniversal; + description_.side_mode = kSideMode; + description_.fill_mode = kFillMode; + description_.diag_type = kDiagType; + + description_.tile_description.threadblock_shape = make_Coord( + Operator::ThreadblockShape::kM, + Operator::ThreadblockShape::kN, + Operator::ThreadblockShape::kK); + + description_.tile_description.threadblock_stages = Operator::kStages; + + description_.tile_description.warp_count = make_Coord( + Operator::TrmmKernel::WarpCount::kM, + Operator::TrmmKernel::WarpCount::kN, + Operator::TrmmKernel::WarpCount::kK); + + description_.tile_description.math_instruction.instruction_shape = make_Coord( + Operator::InstructionShape::kM, + Operator::InstructionShape::kN, + Operator::InstructionShape::kK); + + description_.tile_description.math_instruction.element_accumulator = + NumericTypeMap::kId; + + description_.tile_description.math_instruction.opcode_class = + OpcodeClassMap::kId; + + description_.tile_description.math_instruction.math_operation = + MathOperationMap::kId; + + description_.tile_description.minimum_compute_capability = + ArchMap::kMin; + + description_.tile_description.maximum_compute_capability = + ArchMap::kMax; + + description_.A = make_TensorDescription(Operator::kAlignmentA); + description_.B = make_TensorDescription(Operator::kAlignmentB); + description_.D = make_TensorDescription(Operator::kAlignmentC); + description_.element_epilogue = NumericTypeMap::kId; + + description_.split_k_mode = SplitKMode::kNone; + description_.transform_A = ComplexTransformMap::kId; + } + + /// Returns the description of the TRMM operation + virtual OperationDescription const & description() const { + return description_; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class TrmmOperation : public TrmmOperationBase { +public: + + using Operator = Operator_; + using ElementA = typename Operator::ElementA; + using LayoutA = typename 
Operator::LayoutA; + static SideMode const kSideMode = Operator::kSideMode; + static FillMode const kFillMode = Operator::kFillMode; + static DiagType const kDiagType = Operator::kDiagType; + using ElementB = typename Operator::ElementB; + using LayoutB = typename Operator::LayoutB; + using ElementC = typename Operator::ElementC; + using LayoutC = typename Operator::LayoutC; + using ElementAccumulator = typename Operator::ElementAccumulator; + using ElementCompute = typename Operator::EpilogueOutputOp::ElementCompute; + + using OperatorArguments = typename Operator::Arguments; + +public: + + /// Constructor + TrmmOperation(char const *name = "unknown_trmm"): + TrmmOperationBase(name) { + + this->description_.trmm_kind = TrmmKind::kUniversal; + } + +protected: + + /// Constructs the arguments structure given the configuration and arguments + static Status construct_arguments_( + OperatorArguments &operator_args, + TrmmConfiguration const *configuration) { + + //operator_args.mode = configuration->mode; + + operator_args.problem_size = configuration->problem_size; + operator_args.batch_count = configuration->batch_count; + + operator_args.lda = int(configuration->lda); + operator_args.ldb = int(configuration->ldb); + operator_args.ldd = int(configuration->ldd); + + return Status::kSuccess; + } + + /// Constructs the arguments structure given the configuration and arguments + static Status update_arguments_( + OperatorArguments &operator_args, + TrmmArguments const *arguments) { + + if (arguments->pointer_mode == ScalarPointerMode::kHost) { + typename Operator::EpilogueOutputOp::Params params( + *static_cast(arguments->alpha), + *static_cast(arguments->beta) + ); + operator_args.epilogue = params; + } + else if (arguments->pointer_mode == ScalarPointerMode::kDevice){ + typename Operator::EpilogueOutputOp::Params params( + static_cast(arguments->alpha), + static_cast(arguments->beta) + ); + operator_args.epilogue = params; + } + else { + return Status::kErrorInvalidProblem; + } + + // update arguments + operator_args.ptr_A = arguments->A; + operator_args.ptr_B = arguments->B; + operator_args.batch_stride_A = arguments->batch_stride_A; + operator_args.batch_stride_B = arguments->batch_stride_B; + operator_args.ptr_D = arguments->D; + operator_args.batch_stride_D = arguments->batch_stride_D; + + if (arguments->use_pdl) { + return Status::kErrorNotSupported; + } + + return Status::kSuccess; + } + +public: + + /// Returns success if the operation can proceed + virtual Status can_implement( + void const *configuration_ptr, + void const *arguments_ptr) const { + + TrmmConfiguration const *configuration = + static_cast(configuration_ptr); + + TrmmArguments const *arguments = + static_cast(arguments_ptr); + + OperatorArguments args; + + Status status = construct_arguments_(args, configuration); + + if (status != Status::kSuccess) { + return status; + } + + status = update_arguments_(args, arguments); + + if (status != Status::kSuccess) { + return status; + } + + return Operator::can_implement(args); + } + + /// Gets the host-side workspace + virtual uint64_t get_host_workspace_size( + void const *configuration) const { + + return sizeof(Operator); + } + + /// Gets the device-side workspace + virtual uint64_t get_device_workspace_size( + void const *configuration_ptr, + void const *arguments_ptr = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return 0; + } + + uint64_t size = 
Operator::get_workspace_size(args); + + return size; + } + + /// Initializes the workspace + virtual Status initialize( + void const *configuration_ptr, + void *host_workspace, + void *device_workspace, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = construct_arguments_( + args, + static_cast(configuration_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = new (host_workspace) Operator; + + status = op->initialize(args, device_workspace, stream); + + return status; + } + + /// Runs the kernel + virtual Status run( + void const *arguments_ptr, + void *host_workspace, + void *device_workspace = nullptr, + cudaStream_t stream = nullptr) const { + + OperatorArguments args; + + Status status = update_arguments_( + args, + static_cast(arguments_ptr)); + + if (status != Status::kSuccess) { + return status; + } + + Operator *op = static_cast(host_workspace); + + bool need_swapped_matrices = (kSideMode == SideMode::kLeft && + std::is_same::value) || + (kSideMode == SideMode::kRight && + std::is_same::value); + if (need_swapped_matrices) { + status = op->update(args.swapped_matrices(), device_workspace); + } else { + status = op->update(args, device_workspace); + } + + if (status != Status::kSuccess) { + return status; + } + + status = op->run(stream); + + return status; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace library +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/block_scaled_gemm_operation_profiler.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/block_scaled_gemm_operation_profiler.h new file mode 100644 index 0000000000000000000000000000000000000000..5d500d9149bf645eadf8110d98612c40882d742c --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/block_scaled_gemm_operation_profiler.h @@ -0,0 +1,330 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Blockscale Gemm Profiler +*/ + + + +#pragma once + +#include +#include +#include +#include +#include +#include + +// CUTLASS Library includes +#include "cutlass/library/library.h" +#include "cutlass/library/util.h" +#include "cutlass/library/manifest.h" + +// Profiler includes +#include "options.h" +#include "device_context.h" +#include "operation_profiler.h" +#include "performance_result.h" +#include "problem_space.h" +#include "reduction_operation_profiler.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace profiler { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Abstract base class for each math function +class BlockScaledGemmOperationProfiler : public OperationProfiler { +public: + + /// Problem structure obtained from problem space + struct GemmProblem { + + cutlass::library::GemmUniversalMode mode{library::GemmUniversalMode::kGemm}; + + /// For profiling purposes + std::vector problem_sizes; + std::vector> leading_dims; + std::vector> preferred_clusters; + std::vector> fallback_clusters; + std::vector raster_orders; + std::vector swizzle_sizes; + + int64_t m{16}; + int64_t n{16}; + int64_t k{16}; + + + int cluster_m{1}; + int cluster_n{1}; + int cluster_k{1}; + int cluster_m_fallback{1}; + int cluster_n_fallback{1}; + int cluster_k_fallback{1}; + + + int64_t lda{0}; + int64_t ldb{0}; + int64_t ldc{0}; + std::vector alpha; + std::vector beta; + + cutlass::library::SplitKMode split_k_mode{library::SplitKMode::kNone}; + int split_k_slices{1}; + int batch_count{1}; + + cutlass::library::RasterOrder raster_order{cutlass::library::RasterOrder::kHeuristic}; + int swizzle_size{1}; + cutlass::library::RuntimeDatatype runtime_input_datatype_a{}; + cutlass::library::RuntimeDatatype runtime_input_datatype_b{}; + + + // gemm with parallel interleaved reduction + // gemm epilogue (alpha, beta) = (1.0, 0.0) + // reduction epilogue (alpha, beta) = (GemmProblem::alpha, GemmProblem::beta) + std::vector alpha_one; + std::vector beta_zero; + + bool use_pdl{false}; + // + // Methods + // + + /// Parses the problem + Status parse( + library::BlockScaledGemmDescription const &operation_desc, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + int64_t bytes_with_problem_shape( + library::BlockScaledGemmDescription const &operation_desc, + gemm::GemmCoord const &problem_shape) const; + + int64_t flops_with_problem_shape( + library::BlockScaledGemmDescription const &operation_desc, + gemm::GemmCoord const &problem_shape) const; + + /// Total number of bytes loaded + int64_t bytes(library::BlockScaledGemmDescription const &operation_desc) const; + + /// Total number of flops computed + int64_t flops(library::BlockScaledGemmDescription const &operation_desc) const; + + /// Initializes 
a performance result + void initialize_result( + PerformanceResult &result, + library::BlockScaledGemmDescription const &operation_desc, + ProblemSpace const &problem_space); + }; + + /// Workspace used + struct GemmWorkspace { + + DeviceAllocation *A{nullptr}; + DeviceAllocation *SFA{nullptr}; + DeviceAllocation *B{nullptr}; + DeviceAllocation *SFB{nullptr}; + DeviceAllocation *C{nullptr}; + DeviceAllocation *Computed{nullptr}; + DeviceAllocation *Reference{nullptr}; + DeviceAllocation *Computed_SFD{nullptr}; + DeviceAllocation *Reference_SFD{nullptr}; + DeviceAllocation *Norm_constant{nullptr}; + + /// Number of copies of the problem workspace which are visited sequentially during + /// profiling to avoid camping in the last level cache. + int problem_count{1}; + + library::GemmUniversalConfiguration configuration; + library::BlockScaledGemmArguments arguments; + + /// Buffer used for the operation's host workspace + std::vector host_workspace; + + /// Buffer used for the operations' device workspace + DeviceAllocation device_workspace; + + /// Library configuration and arguments for reduction operator + library::ReductionConfiguration reduction_configuration; + library::ReductionArguments reduction_arguments; + + /// Buffer used for the cutlass reduction operations' host workspace + std::vector reduction_host_workspace; + + cudaStream_t stream; + }; + +protected: + + // + // Data members + // + + /// GEMM problem obtained from problem space + GemmProblem problem_; + + /// Device memory allocations + GemmWorkspace gemm_workspace_; + + /// CUTLASS parallel reduction operation to follow this* gemm operation + library::Operation const *reduction_op_; + +public: + // + // Methods + // + + /// Ctor + BlockScaledGemmOperationProfiler(Options const &options); + + /// Destructor + virtual ~BlockScaledGemmOperationProfiler(); + + GemmProblem const& problem() const { return problem_; } + + /// Prints usage statement for the math function + virtual void print_usage(std::ostream &out) const; + + /// Prints examples + virtual void print_examples(std::ostream &out) const; + + /// Extracts the problem dimensions + virtual Status initialize_configuration( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Initializes workspace + virtual Status initialize_workspace( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Verifies CUTLASS against references + virtual bool verify_cutlass( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Measures performance results + virtual bool profile( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +protected: + + /// Update workspace configuration according to flexible user setups + void update_workspace_( + GemmWorkspace &gemm_workspace, + gemm::GemmCoord const &problem_shape, + std::array const &leading_dim, + std::array const &preferred_cluster, + std::array const &fallback_cluster, + cutlass::library::RasterOrder const &raster_order, 
+ int swizzle_size, + bool is_dynamic_cluster_enabled); + + /// Update performance result configuration according to flexible user setups + void update_result_( + PerformanceResult &result, + library::BlockScaledGemmDescription const &operation_desc, + ProblemSpace const &problem_space, + gemm::GemmCoord const &problem_shape, + cutlass::library::RasterOrder const &raster_order, + std::array const &preferred_cluster, + std::array const &fallback_cluster, + int swizzle_size, + bool is_dynamic_cluster_enabled); + + /// Initializes the performance result + void initialize_result_( + PerformanceResult &result, + Options const &options, + library::BlockScaledGemmDescription const &operation_desc, + ProblemSpace const &problem_space); + + /// Verifies CUTLASS against references + bool verify_with_cublas_( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Verifies CUTLASS against host and device references + bool verify_with_reference_( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem, + cutlass::library::NumericTypeID element_A, + cutlass::library::NumericTypeID element_B); + + /// Method to profile a CUTLASS Operation + Status profile_cutlass_( + PerformanceResult &result, + Options const &options, + library::Operation const *operation, + void *arguments, + void *host_workspace, + void *device_workspace); + + /// Initialize reduction problem dimensions and library::Operation + bool initialize_reduction_configuration_( + library::Operation const *operation, + ProblemSpace::Problem const &problem); +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace profiler +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/blockwise_gemm_operation_profiler.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/blockwise_gemm_operation_profiler.h new file mode 100644 index 0000000000000000000000000000000000000000..c110de278cac640c1cedd8dd29d1b8ac09de81ef --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/blockwise_gemm_operation_profiler.h @@ -0,0 +1,305 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Blockscale Gemm Profiler +*/ + + + +#pragma once + +#include +#include +#include +#include +#include + +// CUTLASS Library includes +#include "cutlass/library/library.h" +#include "cutlass/library/util.h" +#include "cutlass/library/manifest.h" + +// Profiler includes +#include "options.h" +#include "device_context.h" +#include "operation_profiler.h" +#include "performance_result.h" +#include "problem_space.h" +#include "reduction_operation_profiler.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace profiler { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Abstract base class for each math function +class BlockwiseGemmOperationProfiler : public OperationProfiler { +public: + + /// Problem structure obtained from problem space + struct GemmProblem { + + cutlass::library::GemmUniversalMode mode{library::GemmUniversalMode::kGemm}; + + int64_t m{16}; + int64_t n{16}; + int64_t k{16}; + + int64_t sf_vec_m{0}; + int64_t sf_vec_n{0}; + int64_t sf_vec_k{0}; + + int cluster_m{1}; + int cluster_n{1}; + int cluster_k{1}; + int cluster_m_fallback{1}; + int cluster_n_fallback{1}; + int cluster_k_fallback{1}; + + + int64_t lda{0}; + int64_t ldb{0}; + int64_t ldc{0}; + std::vector alpha; + std::vector beta; + + cutlass::library::SplitKMode split_k_mode{library::SplitKMode::kNone}; + int split_k_slices{1}; + int batch_count{1}; + + cutlass::library::RasterOrder raster_order{cutlass::library::RasterOrder::kHeuristic}; + int swizzle_size{1}; + + /// For profiling purposes + std::vector problem_sizes; + std::vector> leading_dims; + std::vector> preferred_clusters; + std::vector> fallback_clusters; + std::vector raster_orders; + std::vector swizzle_sizes; + + cutlass::library::RuntimeDatatype runtime_input_datatype_a{}; + cutlass::library::RuntimeDatatype runtime_input_datatype_b{}; + + + // gemm with parallel interleaved reduction + // gemm epilogue (alpha, beta) = (1.0, 0.0) + // reduction epilogue (alpha, beta) = (GemmProblem::alpha, GemmProblem::beta) + std::vector alpha_one; + std::vector beta_zero; + + bool use_pdl{false}; + // + // Methods + // + + /// Parses the problem + Status parse( + library::BlockwiseGemmDescription const &operation_desc, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + int64_t bytes_with_problem_shape( 
+ library::BlockwiseGemmDescription const &operation_desc, + gemm::GemmCoord const &problem_shape) const; + + int64_t flops_with_problem_shape( + library::BlockwiseGemmDescription const &operation_desc, + gemm::GemmCoord const &problem_shape) const; + + /// Total number of bytes loaded + int64_t bytes(library::BlockwiseGemmDescription const &operation_desc) const; + + /// Total number of flops computed + int64_t flops(library::BlockwiseGemmDescription const &operation_desc) const; + + /// Initializes a performance result + void initialize_result( + PerformanceResult &result, + library::BlockwiseGemmDescription const &operation_desc, + ProblemSpace const &problem_space); + }; + + /// Workspace used + struct GemmWorkspace { + + DeviceAllocation *A{nullptr}; + DeviceAllocation *SFA{nullptr}; + DeviceAllocation *B{nullptr}; + DeviceAllocation *SFB{nullptr}; + DeviceAllocation *C{nullptr}; + DeviceAllocation *Computed{nullptr}; + DeviceAllocation *Reference{nullptr}; + + /// Number of copies of the problem workspace which are visited sequentially during + /// profiling to avoid camping in the last level cache. + int problem_count{1}; + + library::GemmUniversalConfiguration configuration; + library::BlockwiseGemmArguments arguments; + + /// Buffer used for the operation's host workspace + std::vector host_workspace; + + /// Buffer used for the operations' device workspace + DeviceAllocation device_workspace; + + /// Library configuration and arguments for reduction operator + library::ReductionConfiguration reduction_configuration; + library::ReductionArguments reduction_arguments; + + /// Buffer used for the cutlass reduction operations' host workspace + std::vector reduction_host_workspace; + }; + +protected: + + // + // Data members + // + + /// GEMM problem obtained from problem space + GemmProblem problem_; + + /// Device memory allocations + GemmWorkspace gemm_workspace_; + + /// CUTLASS parallel reduction operation to follow this* gemm operation + library::Operation const *reduction_op_; + +public: + // + // Methods + // + + /// Ctor + BlockwiseGemmOperationProfiler(Options const &options); + + /// Destructor + virtual ~BlockwiseGemmOperationProfiler(); + + GemmProblem const& problem() const { return problem_; } + + /// Prints usage statement for the math function + virtual void print_usage(std::ostream &out) const; + + /// Prints examples + virtual void print_examples(std::ostream &out) const; + + /// Extracts the problem dimensions + virtual Status initialize_configuration( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Initializes workspace + virtual Status initialize_workspace( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Verifies CUTLASS against references + virtual bool verify_cutlass( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Measures performance results + virtual bool profile( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + 
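  // --------------------------------------------------------------------------------------------
  // Editor's note (illustrative sketch, not part of the generated header): GemmProblem::alpha_one
  // and GemmProblem::beta_zero above encode the gemm-with-parallel-reduction convention used by
  // this profiler: the GEMM stage writes raw per-slice partial sums with epilogue
  // (alpha, beta) = (1.0, 0.0), and the follow-up reduction operator applies the user-requested
  // (GemmProblem::alpha, GemmProblem::beta). The commented, standalone C++ sketch below
  // (hypothetical names and values, no CUTLASS dependency) shows why the two-step scheme still
  // yields D = alpha * (A x B) + beta * C once the split-K partial sums are accumulated.
  //
  //   #include <cassert>
  //   #include <vector>
  //
  //   // Sum split-K partial results, then apply the user's epilogue scalars.
  //   float reduce_split_k(std::vector<float> const &partials, float alpha, float beta, float c) {
  //     float acc = 0.0f;                      // GEMM stage already used (alpha, beta) = (1, 0)
  //     for (float p : partials) acc += p;     // parallel reduction over split-K slices
  //     return alpha * acc + beta * c;         // reduction epilogue applies (alpha, beta)
  //   }
  //
  //   int main() {
  //     std::vector<float> partials = {1.0f, 2.5f, -0.5f};   // per-slice pieces of one dot product
  //     float d = reduce_split_k(partials, 2.0f, 0.5f, 4.0f);
  //     assert(d == 2.0f * 3.0f + 0.5f * 4.0f);              // alpha * (A x B) + beta * C
  //     return 0;
  //   }
  // --------------------------------------------------------------------------------------------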
+protected: + + /// Initializes the performance result + void initialize_result_( + PerformanceResult &result, + Options const &options, + library::BlockwiseGemmDescription const &operation_desc, + ProblemSpace const &problem_space); + + /// Verifies CUTLASS against references + bool verify_with_cublas_( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Verifies CUTLASS against host and device references + bool verify_with_reference_( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem, + cutlass::library::NumericTypeID element_A, + cutlass::library::NumericTypeID element_B); + + /// Method to profile a CUTLASS Operation + Status profile_cutlass_( + PerformanceResult &result, + Options const &options, + library::Operation const *operation, + void *arguments, + void *host_workspace, + void *device_workspace); + + /// Initialize reduction problem dimensions and library::Operation + bool initialize_reduction_configuration_( + library::Operation const *operation, + ProblemSpace::Problem const &problem); +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace profiler +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/conv2d_operation_profiler.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/conv2d_operation_profiler.h new file mode 100644 index 0000000000000000000000000000000000000000..683465f50cda19c8d505f2e66bcb60173d7e942d --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/conv2d_operation_profiler.h @@ -0,0 +1,495 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines profiling functionality for convolution + +*/ + +#pragma once + +#include +#include +#include +#include +#include + +// CUTLASS Library includes +#include "cutlass/library/library.h" +#include "cutlass/library/util.h" +#include "cutlass/library/handle.h" +#include "cutlass/library/manifest.h" +#include "cutlass/library/singleton.h" + +// Profiler includes +#include "options.h" +#include "device_context.h" +#include "operation_profiler.h" +#include "performance_result.h" +#include "problem_space.h" +#include "reduction_operation_profiler.h" +#if CUTLASS_ENABLE_CUDNN +#include "cudnn_helpers.h" +#endif //#if CUTLASS_ENABLE_CUDNN +#include "debug.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace profiler { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Abstract base class for each math function +class Conv2dOperationProfiler : public OperationProfiler { +public: + + /// Problem structure obtained from problem space + struct Conv2dProblem { + + int64_t n, h, w, c, p, q, k, r, s; + int64_t groups; + int64_t pad_h, pad_w; + int64_t stride_h, stride_w; + int64_t dilation_h, dilation_w; + + std::vector alpha; + std::vector beta; + + library::SplitKMode split_k_mode; + int64_t split_k_slices; + + library::ConvModeID conv_mode; + + library::Provider eq_gemm_provider; + + // convolution with parallel interleaved reduction + // convolution epilogue (alpha, beta) = (1.0, 0.0) + // reduction epilogue (alpha, beta) = (Conv2dProblem::alpha, Conv2dProblem::beta) + std::vector alpha_one; + std::vector beta_zero; + + // + // Methods + // + + /// Total number of bytes loaded + int64_t bytes(library::ConvDescription const &operation_desc) const; + + /// Total number of flops computed + int64_t flops(library::ConvDescription const &operation_desc) const; + + void set_default_output_size() { + p = ((h + pad_h - r * dilation_h) / stride_h) + 1; + q = ((w + pad_w - s * dilation_w) / stride_w) + 1; + } + + // Returns equivalent gemm problem size for convolution + cutlass::gemm::GemmCoord eq_gemm_size(library::ConvKind const &conv_kind) const { + + switch (conv_kind) { + case library::ConvKind::kFprop: return cutlass::gemm::GemmCoord(int(n * p * q), int(k), int(r * s * c / groups)); + case library::ConvKind::kDgrad: return cutlass::gemm::GemmCoord(int(n * h * w), int(c), int(k * r * s)); + case library::ConvKind::kWgrad: return cutlass::gemm::GemmCoord(int(k), int(r * s * c), int(n * p * q)); + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + + // Returns extent for tensor A + std::vector extent_a(library::ConvKind const &conv_kind) const { + + switch (conv_kind) { + case library::ConvKind::kFprop: return {int(n), int(h), int(w), int(c)}; + case 
library::ConvKind::kDgrad: return {int(n), int(p), int(q), int(k)}; + case library::ConvKind::kWgrad: return {int(n), int(p), int(q), int(k)}; + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + + // Returns extent for tensor B + std::vector extent_b(library::ConvKind const &conv_kind) const { + + switch (conv_kind) { + case library::ConvKind::kFprop: return {int(k), int(r), int(s), int(c / groups)}; + case library::ConvKind::kDgrad: return {int(k), int(r), int(s), int(c)}; + case library::ConvKind::kWgrad: return {int(n), int(h), int(w), int(c)}; + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + + // Returns extent for tensor C + std::vector extent_c(library::ConvKind const &conv_kind) const { + + switch (conv_kind) { + case library::ConvKind::kFprop: return {int(n), int(p), int(q), int(k)}; + case library::ConvKind::kDgrad: return {int(n), int(h), int(w), int(c)}; + case library::ConvKind::kWgrad: return {int(k), int(r), int(s), int(c)}; + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + + // Returns layout for equivalent gemm matrix A + library::LayoutTypeID eq_gemm_layout_a(library::ConvKind const &conv_kind) const { + + switch (conv_kind) { + case library::ConvKind::kFprop: return library::LayoutTypeID::kRowMajor; // TN Gemm + case library::ConvKind::kDgrad: return library::LayoutTypeID::kRowMajor; // TT Gemm + case library::ConvKind::kWgrad: return library::LayoutTypeID::kColumnMajor; // NT Gemm + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + + // Returns layout for equivalent gemm matrix B + library::LayoutTypeID eq_gemm_layout_b(library::ConvKind const &conv_kind) const { + + switch (conv_kind) { + case library::ConvKind::kFprop: return library::LayoutTypeID::kColumnMajor; // TN Gemm + case library::ConvKind::kDgrad: return library::LayoutTypeID::kRowMajor; // TT Gemm + case library::ConvKind::kWgrad: return library::LayoutTypeID::kRowMajor; // NT Gemm + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + + // Returns layout for equivalent gemm matrix C + library::LayoutTypeID eq_gemm_layout_c(library::ConvKind const &conv_kind) const { + + switch (conv_kind) { + // Gemm operator assumes column-major output + case library::ConvKind::kFprop: + case library::ConvKind::kDgrad: + case library::ConvKind::kWgrad: return library::LayoutTypeID::kColumnMajor; + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + + // Returns leading dimension for equivalent gemm matrix A + int64_t eq_gemm_lda(library::ConvKind const &conv_kind) const { + + switch (conv_kind) { + case library::ConvKind::kFprop: return eq_gemm_size(conv_kind).k(); + case library::ConvKind::kDgrad: return eq_gemm_size(conv_kind).k(); + case library::ConvKind::kWgrad: return eq_gemm_size(conv_kind).m(); + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + + // Returns leading dimension for equivalent gemm matrix B + int64_t eq_gemm_ldb(library::ConvKind const &conv_kind) const { + + switch (conv_kind) { + case library::ConvKind::kFprop: return eq_gemm_size(conv_kind).k(); + case library::ConvKind::kDgrad: return eq_gemm_size(conv_kind).n(); + case library::ConvKind::kWgrad: return eq_gemm_size(conv_kind).n(); + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + + // Returns leading 
dimension for equivalent gemm matrix C + int64_t eq_gemm_ldc(library::ConvKind const &conv_kind) const { + + switch (conv_kind) { + case library::ConvKind::kFprop: + case library::ConvKind::kDgrad: + case library::ConvKind::kWgrad: return eq_gemm_size(conv_kind).m(); + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + }; + + /// Workspace used + struct Conv2dWorkspace { + + /// Conv device allocations + DeviceAllocation *A; + DeviceAllocation *B; + DeviceAllocation *reordered_B; + DeviceAllocation *C; + DeviceAllocation *Computed; + DeviceAllocation *Reference; + + /// Library configuration and arguments for convolution operator + library::Conv2dConfiguration configuration; + library::ConvArguments arguments; + + /// Number of copies of the problem workspace which are visited sequentially during + /// profiling to avoid camping in the last level cache. + int problem_count; + + /// Buffer used for the cutlass conv2d operations' host workspace + std::vector host_workspace; + + /// Buffer used for the cutlass operations' device workspace + DeviceAllocation device_workspace; + + /// Library configuration and arguments for reduction operator + library::ReductionConfiguration reduction_configuration; + library::ReductionArguments reduction_arguments; + + /// Buffer used for the cutlass reduction operations' host workspace + std::vector reduction_host_workspace; + + /// Host data buffers for host reference operation + /// host buffer for tensor + std::vector host_tensor_a; + + /// host buffer for tensor b + std::vector host_tensor_b; + + /// host buffer for tensor c + std::vector host_tensor_c; + + // + // Methods + // + + Conv2dWorkspace() + : A(nullptr), + B(nullptr), + reordered_B(nullptr), + C(nullptr), + Computed(nullptr), + Reference(nullptr) {} + + // Set stride vector for tensor activations, filters, output + void set_stride_vector(Conv2dProblem const &problem, + library::ConvKind const &conv_kind, + library::LayoutTypeID const &layout_a, + library::LayoutTypeID const &layout_b, + library::LayoutTypeID const &layout_c) { + std::vector stride_activations; + std::vector stride_filters; + std::vector stride_output; + + // Strides for interleaved fprop + if (conv_kind == library::ConvKind::kFprop && + ((layout_a == library::LayoutTypeID::kTensorNC32HW32 && + layout_b == library::LayoutTypeID::kTensorC32RSK32 && + layout_c == library::LayoutTypeID::kTensorNC32HW32) || + (layout_a == library::LayoutTypeID::kTensorNC64HW64 && + layout_b == library::LayoutTypeID::kTensorC64RSK64 && + layout_c == library::LayoutTypeID::kTensorNC64HW64))) { + int interleave = + (layout_a == library::LayoutTypeID::kTensorNC32HW32) ? 
32 : 64; + + stride_activations.push_back(int(problem.w) * interleave); + stride_activations.push_back(int(problem.w) * int(problem.h) * + interleave); + stride_activations.push_back(int(problem.h) * int(problem.w) * + int(problem.c)); + + stride_filters.push_back(int(problem.k) * interleave); + stride_filters.push_back(int(problem.k) * int(problem.s) * interleave); + stride_filters.push_back(int(problem.k) * int(problem.s) * + int(problem.r) * interleave); + + stride_output.push_back(int(problem.q) * interleave); + stride_output.push_back(int(problem.q) * int(problem.p) * interleave); + stride_output.push_back(int(problem.q) * int(problem.p) * + int(problem.k)); + } else { + // Strides for the rest cases + stride_activations.push_back(int(problem.c)); + stride_activations.push_back(int(problem.w) * int(problem.c)); + stride_activations.push_back(int(problem.h) * int(problem.w) * + int(problem.c)); + + stride_filters.push_back(int(problem.c / problem.groups)); + stride_filters.push_back(int(problem.s) * int(problem.c / problem.groups)); + stride_filters.push_back(int(problem.r) * int(problem.s) * + int(problem.c / problem.groups)); + + stride_output.push_back(int(problem.k)); + stride_output.push_back(int(problem.q) * int(problem.k)); + stride_output.push_back(int(problem.q) * int(problem.p) * + int(problem.k)); + } + + switch (conv_kind) { + case library::ConvKind::kFprop: + configuration.stride_a = stride_activations; + configuration.stride_b = stride_filters; + configuration.stride_c = stride_output; + + break; + case library::ConvKind::kDgrad: + configuration.stride_a = stride_output; + configuration.stride_b = stride_filters; + configuration.stride_c = stride_activations; + + break; + case library::ConvKind::kWgrad: + configuration.stride_a = stride_output; + configuration.stride_b = stride_activations; + configuration.stride_c = stride_filters; + + break; + default: + throw std::runtime_error( + "Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + }; + +protected: + + // + // Data members + // + + /// CONV problem obtained from problem space + Conv2dProblem problem_; + + /// Device memory allocations + Conv2dWorkspace conv_workspace_; + + /// CUTLASS parallel reduction operation to follow this* conv2d operation + library::Operation const *reduction_op_; + +public: + // + // Methods + // + + /// Ctor + Conv2dOperationProfiler(Options const &options); + + /// Destructor + virtual ~Conv2dOperationProfiler(); + + Conv2dProblem const& problem() const { return problem_; } + + /// Prints usage statement for the math function + virtual void print_usage(std::ostream &out) const; + + /// Prints examples + virtual void print_examples(std::ostream &out) const; + + /// Extracts the problem dimensions + virtual Status initialize_configuration( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Initializes workspace + virtual Status initialize_workspace( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Verifies CUTLASS against references + virtual bool verify_cutlass( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + 
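  // --------------------------------------------------------------------------------------------
  // Editor's note (illustrative sketch, not part of the generated header): the Conv2dProblem
  // helpers above map a convolution onto an implicit GEMM. The commented, standalone C++ sketch
  // below (hypothetical problem sizes) walks one fprop example through the same formulas: the
  // output extents from set_default_output_size(), the equivalent GEMM extent
  // (M, N, K) = (N*P*Q, K, R*S*C/groups) from eq_gemm_size(kFprop), and the packed NHWC
  // activation strides {C, W*C, H*W*C} used by the non-interleaved branch of set_stride_vector().
  //
  //   #include <cassert>
  //   #include <cstdint>
  //
  //   int main() {
  //     // Assumed fprop problem: NHWC activations, KRSC filters, unit stride and dilation.
  //     int64_t n = 2, h = 14, w = 14, c = 64, k = 128, r = 3, s = 3, groups = 1;
  //     int64_t pad_h = 2, pad_w = 2, stride_h = 1, stride_w = 1, dilation_h = 1, dilation_w = 1;
  //
  //     // Output size, matching Conv2dProblem::set_default_output_size().
  //     int64_t p = ((h + pad_h - r * dilation_h) / stride_h) + 1;   // 14
  //     int64_t q = ((w + pad_w - s * dilation_w) / stride_w) + 1;   // 14
  //
  //     // Equivalent implicit-GEMM extent for fprop, matching eq_gemm_size(kFprop).
  //     int64_t gemm_m = n * p * q, gemm_n = k, gemm_k = r * s * c / groups;
  //     assert(gemm_m == 392 && gemm_n == 128 && gemm_k == 576);
  //
  //     // Packed NHWC activation strides, matching the non-interleaved branch above.
  //     int64_t stride_activations[3] = {c, w * c, h * w * c};       // {64, 896, 12544}
  //     assert(stride_activations[2] == 12544);
  //     return 0;
  //   }
  // --------------------------------------------------------------------------------------------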
/// Measures performance results + virtual bool profile( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +protected: + /// Method to profile an initialized CUTLASS operation + virtual Status profile_cutlass_( + PerformanceResult &result, + Options const &options, + library::Operation const *operation, + void *arguments, + void *host_workspace, + void *device_workspace); + + + /// Initialize reduction problem dimensions and library::Operation + bool initialize_reduction_configuration_( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Initializes the performance result + void initialize_result_( + PerformanceResult &result, + Options const &options, + library::ConvDescription const &operation_desc, + ProblemSpace const &problem_space); + + /// Verifies CUTLASS against host reference + bool verify_with_host_reference_( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Verifies CUTLASS against device reference + bool verify_with_device_reference_( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +#if CUTLASS_ENABLE_CUDNN + + /// Verifies CUTLASS against cudnn reference + bool verify_with_cudnn_( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +#endif //#if CUTLASS_ENABLE_CUDNN + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace profiler +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/conv3d_operation_profiler.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/conv3d_operation_profiler.h new file mode 100644 index 0000000000000000000000000000000000000000..ac4abdef238b00f216053419620a60dfccfd5316 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/conv3d_operation_profiler.h @@ -0,0 +1,449 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines profiling functionality for convolution + +*/ + +#pragma once + +#include +#include +#include +#include +#include + +// CUTLASS Library includes +#include "cutlass/library/library.h" +#include "cutlass/library/util.h" +#include "cutlass/library/handle.h" +#include "cutlass/library/manifest.h" +#include "cutlass/library/singleton.h" + +// Profiler includes +#include "options.h" +#include "device_context.h" +#include "operation_profiler.h" +#include "performance_result.h" +#include "problem_space.h" +#include "reduction_operation_profiler.h" +#if CUTLASS_ENABLE_CUDNN +#include "cudnn_helpers.h" +#endif //#if CUTLASS_ENABLE_CUDNN +#include "debug.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace profiler { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Abstract base class for each math function +class Conv3dOperationProfiler : public OperationProfiler { +public: + + /// Problem structure obtained from problem space + struct Conv3dProblem { + + int64_t n, d, h, w, c, z, p, q, k, t, r, s; + int64_t pad_d, pad_h, pad_w; + int64_t stride_d, stride_h, stride_w; + int64_t dilation_d, dilation_h, dilation_w; + + std::vector alpha; + std::vector beta; + + library::SplitKMode split_k_mode; + int64_t split_k_slices; + + library::ConvModeID conv_mode; + + library::Provider eq_gemm_provider; + + // convolution with parallel interleaved reduction + // convolution epilogue (alpha, beta) = (1.0, 0.0) + // reduction epilogue (alpha, beta) = (Conv3dProblem::alpha, Conv3dProblem::beta) + std::vector alpha_one; + std::vector beta_zero; + + // + // Methods + // + + /// Total number of bytes loaded + int64_t bytes(library::ConvDescription const &operation_desc) const; + + /// Total number of flops computed + int64_t flops(library::ConvDescription const &operation_desc) const; + + /// Infers output size from the input size, padding, stride, and dilation + void set_default_output_size() { + z = ((d + pad_d - t * dilation_d) / stride_d) + 1; + p = ((h + pad_h - r * dilation_h) / stride_h) + 1; + q = ((w + pad_w - s * dilation_w) / stride_w) + 1; 
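+
+      // Illustrative worked example (editorial addition, not in the CUTLASS
+      // source): with d = h = w = 8, a 3x3x3 filter (t = r = s = 3), unit
+      // stride and dilation, and pad_d = pad_h = pad_w = 2, the formula above
+      // gives z = p = q = ((8 + 2 - 3) / 1) + 1 = 8, i.e. the output volume
+      // matches the input volume.
+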
+ } + + // Returns equivalent gemm problem size for convolution + cutlass::gemm::GemmCoord eq_gemm_size(library::ConvKind const &conv_kind) const { + + switch (conv_kind) { + case library::ConvKind::kFprop: return cutlass::gemm::GemmCoord(int(n * z * p * q), int(k), int(t * r * s * c)); + case library::ConvKind::kDgrad: return cutlass::gemm::GemmCoord(int(n * d * h * w), int(c), int(t * r * s * k)); + case library::ConvKind::kWgrad: return cutlass::gemm::GemmCoord(int(k), int(t * r * s * c), int(n * z * p * q)); + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + + // Returns extent for tensor A + std::vector extent_a(library::ConvKind const &conv_kind) const { + + switch (conv_kind) { + case library::ConvKind::kFprop: return {int(n), int(d), int(h), int(w), int(c)}; + case library::ConvKind::kDgrad: return {int(n), int(z), int(p), int(q), int(k)}; + case library::ConvKind::kWgrad: return {int(n), int(z), int(p), int(q), int(k)}; + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + + // Returns extent for tensor B + std::vector extent_b(library::ConvKind const &conv_kind) const { + + switch (conv_kind) { + case library::ConvKind::kFprop: return {int(k), int(t), int(r), int(s), int(c)}; + case library::ConvKind::kDgrad: return {int(k), int(t), int(r), int(s), int(c)}; + case library::ConvKind::kWgrad: return {int(n), int(d), int(h), int(w), int(c)}; + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + + // Returns extent for tensor C + std::vector extent_c(library::ConvKind const &conv_kind) const { + + switch (conv_kind) { + case library::ConvKind::kFprop: return {int(n), int(z), int(p), int(q), int(k)}; + case library::ConvKind::kDgrad: return {int(n), int(d), int(h), int(w), int(c)}; + case library::ConvKind::kWgrad: return {int(k), int(t), int(r), int(s), int(c)}; + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + + // Returns layout for equivalent gemm matrix A + library::LayoutTypeID eq_gemm_layout_a(library::ConvKind const &conv_kind) const { + + switch (conv_kind) { + case library::ConvKind::kFprop: return library::LayoutTypeID::kRowMajor; // TN Gemm + case library::ConvKind::kDgrad: return library::LayoutTypeID::kRowMajor; // TT Gemm + case library::ConvKind::kWgrad: return library::LayoutTypeID::kColumnMajor; // NT Gemm + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + + // Returns layout for equivalent gemm matrix B + library::LayoutTypeID eq_gemm_layout_b(library::ConvKind const &conv_kind) const { + + switch (conv_kind) { + case library::ConvKind::kFprop: return library::LayoutTypeID::kColumnMajor; // TN Gemm + case library::ConvKind::kDgrad: return library::LayoutTypeID::kRowMajor; // TT Gemm + case library::ConvKind::kWgrad: return library::LayoutTypeID::kRowMajor; // NT Gemm + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + + // Returns layout for equivalent gemm matrix C + library::LayoutTypeID eq_gemm_layout_c(library::ConvKind const &conv_kind) const { + + switch (conv_kind) { + // Gemm operator assumes column-major output + case library::ConvKind::kFprop: + case library::ConvKind::kDgrad: + case library::ConvKind::kWgrad: return library::LayoutTypeID::kColumnMajor; + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + + // Returns leading dimension for equivalent gemm matrix 
A + int64_t eq_gemm_lda(library::ConvKind const &conv_kind) const { + + switch (conv_kind) { + case library::ConvKind::kFprop: return eq_gemm_size(conv_kind).k(); + case library::ConvKind::kDgrad: return eq_gemm_size(conv_kind).k(); + case library::ConvKind::kWgrad: return eq_gemm_size(conv_kind).m(); + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + + // Returns leading dimension for equivalent gemm matrix B + int64_t eq_gemm_ldb(library::ConvKind const &conv_kind) const { + + switch (conv_kind) { + case library::ConvKind::kFprop: return eq_gemm_size(conv_kind).k(); + case library::ConvKind::kDgrad: return eq_gemm_size(conv_kind).n(); + case library::ConvKind::kWgrad: return eq_gemm_size(conv_kind).n(); + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + + // Returns leading dimension for equivalent gemm matrix C + int64_t eq_gemm_ldc(library::ConvKind const &conv_kind) const { + + switch (conv_kind) { + case library::ConvKind::kFprop: + case library::ConvKind::kDgrad: + case library::ConvKind::kWgrad: return eq_gemm_size(conv_kind).m(); + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + }; + + /// Workspace used + struct Conv2dWorkspace { + + /// Conv device allocations + DeviceAllocation *A; + DeviceAllocation *B; + DeviceAllocation *C; + DeviceAllocation *Computed; + DeviceAllocation *Reference; + + /// Library configuration and arguments for convolution operator + library::Conv3dConfiguration configuration; + library::ConvArguments arguments; + + /// Number of copies of the problem workspace which are visited sequentially during + /// profiling to avoid camping in the last level cache. + int problem_count; + + /// Buffer used for the cutlass conv2d operations' host workspace + std::vector host_workspace; + + /// Buffer used for the cutlass operations' device workspace + DeviceAllocation device_workspace; + + /// Library configuration and arguments for reduction operator + library::ReductionConfiguration reduction_configuration; + library::ReductionArguments reduction_arguments; + + /// Buffer used for the cutlass reduction operations' host workspace + std::vector reduction_host_workspace; + + /// Host data buffers for host reference operation + /// host buffer for tensor + std::vector host_tensor_a; + + /// host buffer for tensor b + std::vector host_tensor_b; + + /// host buffer for tensor c + std::vector host_tensor_c; + + + // + // Methods + // + + Conv2dWorkspace(): + A(nullptr), B(nullptr), C(nullptr), Computed(nullptr), Reference(nullptr) { } + + // Returns stride vector for tensor A + std::vector stride_a(library::ConvKind const &conv_kind) { + return { + configuration.layout_a(conv_kind).stride()[0], + configuration.layout_a(conv_kind).stride()[1], + configuration.layout_a(conv_kind).stride()[2], + configuration.layout_a(conv_kind).stride()[3] + }; + } + + // Returns stride vector for tensor B + std::vector stride_b(library::ConvKind const &conv_kind) { + + return { + configuration.layout_b(conv_kind).stride()[0], + configuration.layout_b(conv_kind).stride()[1], + configuration.layout_b(conv_kind).stride()[2], + configuration.layout_b(conv_kind).stride()[3] + }; + } + + // Returns stride vector for tensor C + std::vector stride_c(library::ConvKind const &conv_kind) { + + return { + configuration.layout_c(conv_kind).stride()[0], + configuration.layout_c(conv_kind).stride()[1], + configuration.layout_c(conv_kind).stride()[2], + 
configuration.layout_c(conv_kind).stride()[3] + }; + } + }; + +protected: + + // + // Data members + // + + /// CONV problem obtained from problem space + Conv3dProblem problem_; + + /// Device memory allocations + Conv2dWorkspace conv_workspace_; + + /// CUTLASS parallel reduction operation to follow this* conv2d operation + library::Operation const *reduction_op_; + +public: + // + // Methods + // + + /// Ctor + Conv3dOperationProfiler(Options const &options); + + /// Destructor + virtual ~Conv3dOperationProfiler(); + + Conv3dProblem const& problem() const { return problem_; } + + /// Prints usage statement for the math function + virtual void print_usage(std::ostream &out) const; + + /// Prints examples + virtual void print_examples(std::ostream &out) const; + + /// Extracts the problem dimensions + virtual Status initialize_configuration( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Initializes workspace + virtual Status initialize_workspace( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Verifies CUTLASS against references + virtual bool verify_cutlass( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Measures performance results + virtual bool profile( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +protected: + + /// Updates the arguments structure for the CUTLASS operator based on + /// the problem index. 
+ void set_cutlass_operator_arguments_(int problem_idx = 0); + + /// Method to profile an initialized CUTLASS operation + virtual Status profile_cutlass_( + PerformanceResult &result, + Options const &options, + library::Operation const *operation, + void *arguments, + void *host_workspace, + void *device_workspace); + + /// Initialize reduction problem dimensions and library::Operation + bool initialize_reduction_configuration_( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Initializes the performance result + void initialize_result_( + PerformanceResult &result, + Options const &options, + library::ConvDescription const &operation_desc, + ProblemSpace const &problem_space); + + /// Verifies CUTLASS against host reference + bool verify_with_host_reference_( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Verifies CUTLASS against device reference + bool verify_with_device_reference_( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +#if CUTLASS_ENABLE_CUDNN + + /// Verifies CUTLASS against cudnn reference + bool verify_with_cudnn_( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +#endif //#if CUTLASS_ENABLE_CUDNN + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace profiler +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/cublas_helpers.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/cublas_helpers.h new file mode 100644 index 0000000000000000000000000000000000000000..873ba1abe03c05df29edc032ea3f1ffd2f19c3ee --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/cublas_helpers.h @@ -0,0 +1,456 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Helper functions for mapping CUTLASS concepts to cuBLAS. +*/ + +#pragma once + +#if CUTLASS_ENABLE_CUBLAS +#include +#include + +#include "cutlass/cutlass.h" +#include "cutlass/library/library.h" +#include "cutlass/library/util.h" +#include "cutlass/blas3.h" + +#include "options.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace profiler { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Converts a cuBLAS status to cutlass::Status +Status get_cutlass_status(cublasStatus_t cublas); + +/// Converts a cuBLAS status to cutlass::profiler::Disposition +Disposition get_cutlass_disposition(cublasStatus_t cublas_status); + +/// Maps a CUTLASS tensor layout to a cuBLAS transpose operation +bool get_cublas_transpose_operation( + cublasOperation_t &operation, + library::LayoutTypeID layout, + library::ComplexTransform transform = library::ComplexTransform::kNone); + +/// Maps a CUTLASS numeric type to a cuBLAS data type enumeration +bool get_cublas_datatype(cublasDataType_t &data_type, library::NumericTypeID element_type); + +/// Gets the cublas algorithm given threadblock tile dimensions and math opcode class +cublasGemmAlgo_t get_cublas_gemm_algo( + int cta_m, + int cta_n, + int cta_k, + library::OpcodeClassID opcode_class); + +/// Returns a status if cuBLAS can satisfy a particular GEMM description +Status cublas_satisfies(library::GemmDescription const &desc); + +/// Returns a status if cuBLAS can satisfy a particular RankK description +Status cublas_satisfies(library::RankKDescription const &desc); + +/// Returns a status if cuBLAS can satisfy a particular TRMM description +Status cublas_satisfies(library::TrmmDescription const &desc); + +/// Returns a status if cuBLAS can satisfy a particular SYMM/HEMM description +Status cublas_satisfies(library::SymmDescription const &desc); + +/// This is a helper class to create cublasHandle_t automatically on CublasCreate object creation and +/// to destroy cublasHandle_t on CublasCreate object destruction. 
+/// Additionally, it provides implicit cast from CublasCreate's object to cublasHandle_t's object +class CublasCreate { +private: + cublasHandle_t handle; + cublasStatus_t status; + +public: + CublasCreate() { + status = cublasCreate(&handle); + } + + ~CublasCreate() { + cublasDestroy(handle); + } + + /// Implicit cast CublasCreate object to cublasHandle_t + operator cublasHandle_t() const { return handle; } + + /// returns cublasStatus_t for handle creation + cublasStatus_t get_cublas_create_status() { return status; } +}; + +/// This is a helper class to create cublasLtHandle_t automatically on CublasLtCreate object creation and +/// to destroy cublasLtHandle_t on CublasLtCreate object destruction. +/// Additionally, it provides implicit cast from CublasLtCreate's object to cublasLtHandle_t's object +class CublasLtCreate { +private: + cublasLtHandle_t handle; + cublasStatus_t status; + +public: + CublasLtCreate() { + status = cublasLtCreate(&handle); + } + + ~CublasLtCreate() { + cublasLtDestroy(handle); + } + + /// Implicit cast CublasLtCreate object to cublasLtHandle_t + operator cublasLtHandle_t() const { return handle; } + + /// returns cublasLtStatus_t for handle creation + cublasStatus_t get_cublaslt_create_status() { return status; } +}; +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +/// Selects one or more cuBLAS algorithms. +static void select_cublas_algorithms( + std::vector &algorithms, + Options const &options, + library::GemmDescription const &op_desc) { + + library::OpcodeClassID const & opcode_class = + op_desc.tile_description.math_instruction.opcode_class; + + switch (options.library.algorithm_mode) { + case AlgorithmMode::kMatching: + { + algorithms.push_back(get_cublas_gemm_algo( + op_desc.tile_description.threadblock_shape.m(), + op_desc.tile_description.threadblock_shape.n(), + op_desc.tile_description.threadblock_shape.k(), + opcode_class)); + break; + } + + case AlgorithmMode::kBest: + { + // Choose first enumerated mode. If none are enumerated, choose based on opcode class + // and evaluate all of them. + + if (options.library.algorithms.empty()) { + // Enumerate all algorithms + if (opcode_class == library::OpcodeClassID::kSimt) { + + for (int algo = CUBLAS_GEMM_DEFAULT; + algo <= CUBLAS_GEMM_ALGO23; + ++algo) { + + algorithms.push_back(cublasGemmAlgo_t(algo)); + } + } + else { + + for (int algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP; + algo <= CUBLAS_GEMM_ALGO15_TENSOR_OP; + ++algo) { + + algorithms.push_back(cublasGemmAlgo_t(algo)); + } + } + } + else { + // Use the listed algorithms + algorithms.reserve(options.library.algorithms.size()); + + for (int algo : options.library.algorithms) { + algorithms.push_back(reinterpret_cast(algo)); + } + } + + break; + } + + case AlgorithmMode::kDefault: + { + + // Use the library's default algorithm + algorithms.push_back((opcode_class == library::OpcodeClassID::kSimt ? 
+ CUBLAS_GEMM_DEFAULT : CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + + break; + } + default: + { + break; + } + } +} + +/// Dispatcher to cublasGemmEx() +struct cublasGemmExDispatcher { + + // + // Data members + // + library::GemmUniversalConfiguration configuration; + library::GemmUniversalArguments arguments; + + // cublas-specific data structures to fill cublas API call arguments + cublasOperation_t trans_A; + cublasOperation_t trans_B; + cudaDataType_t data_type_A; + cudaDataType_t data_type_B; + cudaDataType_t data_type_C; + cudaDataType_t compute_data_type; + +#if (__CUDACC_VER_MAJOR__ >= 11) + cublasComputeType_t compute_type; +#endif + + cublasGemmAlgo_t algo; + Status status; + + // + // Methods + // + + cublasGemmExDispatcher( + library::GemmDescription const &op_desc, + library::GemmUniversalConfiguration configuration_, + library::GemmUniversalArguments arguments_, + cublasGemmAlgo_t algorithm = CUBLAS_GEMM_DFALT + ); + + /// Executes GEMM using these arguments + cublasStatus_t operator()(cublasHandle_t handle); +}; + +/// Dispatcher to cublaslt kernels +// +struct cublasLtGemmExDispatcher { + + // + // Data members + // + library::GemmDescription const &op_desc; + library::GemmUniversalConfiguration configuration; + library::GemmUniversalArguments arguments; + + // cublas-specific data structures to fill cublas API call arguments + cublasOperation_t trans_A; + cublasOperation_t trans_B; + cudaDataType_t data_type_A; + cudaDataType_t data_type_B; + cudaDataType_t data_type_C; + cudaDataType_t compute_data_type = CUDA_R_32F; + + //cublasLt-specific data structures + cublasLtMatmulDesc_t operationDesc = NULL; + cublasLtMatrixLayout_t Adesc = NULL, Bdesc = NULL, Cdesc = NULL, Ddesc = NULL; + cublasLtMatmulPreference_t preference = NULL; + + //is set by call to get_cublaslt_algo() + cublasLtMatmulHeuristicResult_t heuristicResult_; + void *workspace = nullptr; + + Status status; + +#if (__CUDACC_VER_MAJOR__ >= 11) + cublasComputeType_t compute_type; +#endif + + // + // Methods + // + + cublasLtGemmExDispatcher( + library::GemmDescription const &op_desc, + library::GemmUniversalConfiguration configuration_, + library::GemmUniversalArguments arguments_ + ); + + /// Initialize the cublasLt variables + void initialize_cublaslt(); + + + /// Runs auto-tuning for the cublas heuristics + bool get_cublaslt_algo(cublasLtHandle_t handle, + AlgorithmMode algorithm_mode + ); + + /// Executes GEMM using these arguments + cublasStatus_t operator()(cublasLtHandle_t handle, cudaStream_t stream = nullptr); + + ~cublasLtGemmExDispatcher(){ + + // descriptors are no longer needed as all GPU work was already enqueued + if (preference) cublasLtMatmulPreferenceDestroy(preference); + if (Ddesc) cublasLtMatrixLayoutDestroy(Ddesc); + if (Cdesc) cublasLtMatrixLayoutDestroy(Cdesc); + if (Bdesc) cublasLtMatrixLayoutDestroy(Bdesc); + if (Adesc) cublasLtMatrixLayoutDestroy(Adesc); + if (operationDesc) cublasLtMatmulDescDestroy(operationDesc); + + if (workspace) { + cudaFree(workspace); + } + + } + +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Dispatcher to cublas rank k update kernels +struct cublasRankKDispatcher { + + // + // Data members + // + library::RankKConfiguration configuration; + library::RankKArguments arguments; + + // cublas-specific data structures to fill cublas API call arguments + cublasOperation_t trans_A; + cublasFillMode_t uplo; + cudaDataType_t data_type_A; + cudaDataType_t data_type_C; + cudaDataType_t compute_data_type; + +#if 
(__CUDACC_VER_MAJOR__ >= 11) + cublasComputeType_t compute_type; +#endif + + int num_ranks; //(rank-k or rank-2k) + BlasMode blas_mode; //(symmetric or hermitian) + Status status; + + // + // Methods + // + + cublasRankKDispatcher( + library::RankKDescription const &op_desc, + library::RankKConfiguration configuration_, + library::RankKArguments arguments_ + ); + + /// Executes RankK using these arguments + cublasStatus_t operator()(cublasHandle_t handle); +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Dispatcher to cublasTrmm() +struct cublasTrmmDispatcher { + + // + // Data members + // + library::TrmmConfiguration configuration; + library::TrmmArguments arguments; + + // cublas-specific data structures to fill cublas API call arguments + cublasOperation_t trans_A; + cublasSideMode_t side; + cublasFillMode_t uplo; + cublasDiagType_t diag; + cudaDataType_t data_type_A; + cudaDataType_t data_type_B; + cudaDataType_t data_type_D; + cudaDataType_t compute_data_type; + +#if (__CUDACC_VER_MAJOR__ >= 11) + cublasComputeType_t compute_type; +#endif + + Status status; + + // + // Methods + // + + cublasTrmmDispatcher( + library::TrmmDescription const &op_desc, + library::TrmmConfiguration configuration_, + library::TrmmArguments arguments_ + ); + + /// Executes TRMM using these arguments + cublasStatus_t operator()(cublasHandle_t handle); +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Dispatcher to cublas symm/hemm update kernels +struct cublasSymmDispatcher { + + // + // Data members + // + library::SymmConfiguration configuration; + library::SymmArguments arguments; + + // cublas-specific data structures to fill cublas API call arguments + cublasSideMode_t side; + cublasFillMode_t uplo; + cudaDataType_t data_type_A; + cudaDataType_t data_type_B; + cudaDataType_t data_type_C; + cudaDataType_t compute_data_type; + +#if (__CUDACC_VER_MAJOR__ >= 11) + cublasComputeType_t compute_type; +#endif + + BlasMode blas_mode; //(symmetric or hermitian) + Status status; + + // + // Methods + // + + cublasSymmDispatcher( + library::SymmDescription const &op_desc, + library::SymmConfiguration configuration_, + library::SymmArguments arguments_ + ); + + /// Executes Symm using these arguments + cublasStatus_t operator()(cublasHandle_t handle); +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace detail + +} // namespace profiler +} // namespace cutlass + + +#endif // #if CUTLASS_ENABLE_CUBLAS diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/cudnn_helpers.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/cudnn_helpers.h new file mode 100644 index 0000000000000000000000000000000000000000..7ce9eea5a883fa4c5732f5d8aec120a99064bac0 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/cudnn_helpers.h @@ -0,0 +1,590 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Helper functions for mapping CUTLASS concepts to cuDNN. + +*/ + +#pragma once +#if CUTLASS_ENABLE_CUDNN +#include +#include +#include +#include "cutlass/cutlass.h" +#include "cutlass/util/device_memory.h" +#include "cutlass/library/library.h" +#include "enumerated_types.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace profiler { + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Converts a cuDNN status to cutlass::Status +Status get_cutlass_status(cudnnStatus_t cudnn_status); + +/// Converts a cuDNN status to cutlass::profiler::Disposition +Disposition get_cutlass_disposition(cudnnStatus_t cudnn_status); + +/// Checks cudnnStatus_t converts to cutlas status and returns if Status::kSuccess o.w. 
throws exception +Status checkCudnnErr(cudnnStatus_t cudnn_status); + +/// Maps a CUTLASS conv mode to a cuDNN conv mode enumeration +bool get_cudnn_conv_mode(cudnnConvolutionMode_t &cudnn_conv_mode, conv::Mode conv_mode); + +/// Maps a CUTLASS layout type to a cuDNN data type enumeration +bool get_cudnn_layout(cudnnTensorFormat_t &cudnn_layout, library::LayoutTypeID layout); + +/// Maps a CUTLASS numeric type to a cuDNN data type enumeration +bool get_cudnn_datatype(cudnnDataType_t &cudnn_element_type, library::NumericTypeID element_type); + +/// Maps CUTLASS math OpcodeClassID and MathOperationID to cuDNN math_type +bool get_cudnn_mathtype(cudnnMathType_t &cudnn_math_type, library::ConvDescription const &conv_desc); + +/// Returns a status if cudnn can satisfy a particular Conv2d description +Status cudnn_satisfies(library::ConvDescription const &desc, library::Conv2dConfiguration const &configuration); + +/// Returns a status if cudnn can satisfy a particular Conv3d description +Status cudnn_satisfies(library::ConvDescription const &desc, library::Conv3dConfiguration const &configuration); + +/// Cudnn compute type seems to be hardcoded to float (To handle a possible cudnn issue) +float cast_cudnn_compute_type_to_float(library::NumericTypeID type, void const * src); + + +/// This is a helper class to create cudnnHandle_t automatically on CudnnCreate object creation and +/// to destroy cudnnHandle_t on CudnnCreate object destruction. +/// Additionally, it provides implicit cast from CudnnCreate's object to cudnnHandle_t's object +class CudnnCreate { +private: + cudnnHandle_t handle; + cudnnStatus_t status; + +public: + CudnnCreate() { + status = cudnnCreate(&handle); + } + + ~CudnnCreate() { + cudnnDestroy(handle); + } + + /// Implicit cast CudnnCreate object to cudnnHandle_t + operator cudnnHandle_t() const { return handle; } + + /// returns cudnnStatus_t for handle creation + cudnnStatus_t get_cudnn_create_status() { return status; } +}; + + +namespace detail { + +/// Dispatcher to cudnn convolution operators +struct cudnnConvDispatcher { + + // + // Data members + // + //library::Conv2dConfiguration configuration; + library::ConvArguments arguments; + library::ConvKind conv_kind; + + // cudnn-specific data structures to fill cudnn API call arguments + // cudnn activation, filter, and output descriptors + cudnnTensorDescriptor_t activation_desc; + cudnnFilterDescriptor_t filter_desc; + cudnnTensorDescriptor_t output_desc; + cudnnConvolutionDescriptor_t conv_desc; + + // cudnn datatypes + cudnnDataType_t data_type_activation; + cudnnDataType_t data_type_filter; + cudnnDataType_t data_type_output; + + // cudnn layouts + cudnnTensorFormat_t layout_activation; + cudnnTensorFormat_t layout_filter; + cudnnTensorFormat_t layout_output; + + // cudnn convolution mode + cudnnConvolutionMode_t conv_mode; + + // cudnn math type (tensorop, tensorop with conversion, simt) + cudnnMathType_t math_type; + + // cudnn compute data type + cudnnDataType_t compute_type; + + // cudnn compute type seems to be hardcoded to float (to handle a possible a cudnn issue) + float alpha; + float beta; + + // cudnn workspace + size_t workspace_size_in_bytes = 0; + cutlass::device_memory::allocation workspace; + + // select cudnn's implicit gemm precomputed algorithm with tensor operations + static cudnnConvolutionFwdAlgo_t const fprop_algo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM; + static cudnnConvolutionBwdDataAlgo_t const dgrad_algo = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1; + static 
cudnnConvolutionBwdFilterAlgo_t const wgrad_algo = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1; + + Status status; + + // + // Methods + // + + // TODO: unify ctor cudnnConvDispatcher for conv2d and conv3d by unifying Conv2dConfiguration + + // ctor for conv2d + cudnnConvDispatcher( + library::ConvDescription const &op_desc, + library::Conv2dConfiguration configuration, + library::ConvArguments arguments_, + cudnnHandle_t handle + ): + //configuration(configuration_), + arguments(arguments_), + conv_kind(op_desc.conv_kind), + status(Status::kSuccess) { + + bool good = true; + + // Get cudnn datatype, layout, and convolution mode from library::ConvDescription + good = (good && get_cudnn_datatype(data_type_activation, op_desc.A.element)); + good = (good && get_cudnn_datatype(data_type_filter, op_desc.B.element)); + good = (good && get_cudnn_datatype(data_type_output, op_desc.C.element)); + good = (good && get_cudnn_layout(layout_activation, op_desc.A.layout)); + good = (good && get_cudnn_layout(layout_filter, op_desc.B.layout)); + good = (good && get_cudnn_layout(layout_output, op_desc.C.layout)); + good = (good && get_cudnn_conv_mode(conv_mode, configuration.problem_size.mode)); + // Get cudnn mathtype (cudnnMathType_t) + good = (good && get_cudnn_mathtype(math_type, op_desc)); + good = (good && get_cudnn_datatype( + compute_type, + op_desc.tile_description.math_instruction.element_accumulator)); + // Check cutlass Conv2d description has equivalent operator in cudnn + if (!good) { + status = Status::kErrorNotSupported; + return; + } + // cudnn compute type seems to be hardcoded to float (to handle a possible a cudnn issue) + alpha = cast_cudnn_compute_type_to_float(op_desc.element_epilogue, arguments.alpha); + beta = cast_cudnn_compute_type_to_float(op_desc.element_epilogue, arguments.beta); + + // Create convolution descriptor object + status = get_cutlass_status(cudnnCreateConvolutionDescriptor(&conv_desc)); + + // Configure convolution operator + std::vector padding {configuration.problem_size.pad_h, configuration.problem_size.pad_w}; + std::vector stride {configuration.problem_size.stride_h, configuration.problem_size.stride_w}; + std::vector dilation {configuration.problem_size.dilation_h, configuration.problem_size.dilation_w}; + + status = get_cutlass_status( + cudnnSetConvolutionNdDescriptor( + conv_desc, + op_desc.conv_dim, + padding.data(), + stride.data(), + dilation.data(), + conv_mode, + compute_type + )); + + // Set groups + status = get_cutlass_status(cudnnSetConvolutionGroupCount(conv_desc, configuration.problem_size.groups)); + + // Create activation, filter, and output descriptor objects + status = get_cutlass_status(cudnnCreateTensorDescriptor(&activation_desc)); + status = get_cutlass_status(cudnnCreateFilterDescriptor(&filter_desc)); + status = get_cutlass_status(cudnnCreateTensorDescriptor(&output_desc)); + + // Set activation, filter, and output descriptor + status = get_cutlass_status( + cudnnSetTensor4dDescriptor( + activation_desc, + layout_activation, + data_type_activation, + configuration.problem_size.N, + configuration.problem_size.C, + configuration.problem_size.H, + configuration.problem_size.W + )); + + status = get_cutlass_status( + cudnnSetFilter4dDescriptor( + filter_desc, + data_type_filter, + layout_filter, + configuration.problem_size.K, + configuration.problem_size.C / configuration.problem_size.groups, + configuration.problem_size.R, + configuration.problem_size.S + )); + + status = get_cutlass_status( + cudnnSetTensor4dDescriptor( + output_desc, + 
layout_output, + data_type_output, + configuration.problem_size.N, + configuration.problem_size.K, + configuration.problem_size.P, + configuration.problem_size.Q + )); + + // Set math instruction to tensor op + status = get_cutlass_status( + cudnnSetConvolutionMathType(conv_desc, math_type)); + + // Initialize workspace + switch (conv_kind) { + case library::ConvKind::kFprop: + status = get_cutlass_status( + cudnnGetConvolutionForwardWorkspaceSize( + handle, + activation_desc, + filter_desc, + conv_desc, + output_desc, + fprop_algo, + &workspace_size_in_bytes + )); break; + case library::ConvKind::kDgrad: + status = get_cutlass_status( + cudnnGetConvolutionBackwardDataWorkspaceSize( + handle, + filter_desc, + output_desc, + conv_desc, + activation_desc, + dgrad_algo, + &workspace_size_in_bytes + )); break; + case library::ConvKind::kWgrad: + status = get_cutlass_status( + cudnnGetConvolutionBackwardFilterWorkspaceSize( + handle, + activation_desc, + output_desc, + conv_desc, + filter_desc, + wgrad_algo, + &workspace_size_in_bytes + )); break; + + } + + workspace = cutlass::device_memory::allocation(workspace_size_in_bytes); + } + + + // ctor for conv3d + cudnnConvDispatcher( + library::ConvDescription const &op_desc, + library::Conv3dConfiguration configuration, + library::ConvArguments arguments_, + cudnnHandle_t handle + ): + //configuration(configuration_), + arguments(arguments_), + conv_kind(op_desc.conv_kind), + status(Status::kSuccess) { + + bool good = true; + + // Get cudnn datatype, layout, and convolution mode from library::ConvDescription + good = (good && get_cudnn_datatype(data_type_activation, op_desc.A.element)); + good = (good && get_cudnn_datatype(data_type_filter, op_desc.B.element)); + good = (good && get_cudnn_datatype(data_type_output, op_desc.C.element)); + + good = (good && get_cudnn_layout(layout_activation, op_desc.A.layout)); + good = (good && get_cudnn_layout(layout_filter, op_desc.B.layout)); + good = (good && get_cudnn_layout(layout_output, op_desc.C.layout)); + + good = (good && get_cudnn_conv_mode(conv_mode, configuration.problem_size.mode)); + + // cudnn compute type seems to be hardcoded to float (to handle a possible a cudnn issue) + alpha = cast_cudnn_compute_type_to_float(op_desc.element_epilogue, arguments.alpha); + beta = cast_cudnn_compute_type_to_float(op_desc.element_epilogue, arguments.beta); + + good = (good && get_cudnn_datatype( + compute_type, + op_desc.tile_description.math_instruction.element_accumulator)); + + // Check cutlass Conv2d description has equivalent operator in cudnn + if (!good) { + status = Status::kErrorNotSupported; + } + + // Create convolution descriptor object + status = get_cutlass_status(cudnnCreateConvolutionDescriptor(&conv_desc)); + + // Configure convolution operator + std::vector padding {configuration.problem_size.pad_d, configuration.problem_size.pad_h, configuration.problem_size.pad_w}; + std::vector stride {configuration.problem_size.stride_d, configuration.problem_size.stride_h, configuration.problem_size.stride_w}; + std::vector dilation {configuration.problem_size.dilation_d, configuration.problem_size.dilation_h, configuration.problem_size.dilation_w}; + + status = get_cutlass_status( + cudnnSetConvolutionNdDescriptor( + conv_desc, + op_desc.conv_dim, + padding.data(), + stride.data(), + dilation.data(), + conv_mode, + compute_type + )); + + // Set groups + status = get_cutlass_status(cudnnSetConvolutionGroupCount(conv_desc, configuration.problem_size.groups)); + + // Create activation, filter, and output 
descriptor objects + status = get_cutlass_status(cudnnCreateTensorDescriptor(&activation_desc)); + status = get_cutlass_status(cudnnCreateFilterDescriptor(&filter_desc)); + status = get_cutlass_status(cudnnCreateTensorDescriptor(&output_desc)); + + // Set activation descriptor + std::vector activation_extent { + configuration.problem_size.N, + configuration.problem_size.C, + configuration.problem_size.D, + configuration.problem_size.H, + configuration.problem_size.W + }; + + std::vector activation_stride { + configuration.layout_activations.stride()[3], + 1, + configuration.layout_activations.stride()[2], + configuration.layout_activations.stride()[1], + configuration.layout_activations.stride()[0] + }; + + status = get_cutlass_status( + cudnnSetTensorNdDescriptor( + activation_desc, + data_type_activation, + op_desc.conv_dim + 2, + activation_extent.data(), + activation_stride.data() + )); + + // Set filter descriptor + std::vector filter_extent { + configuration.problem_size.K, + configuration.problem_size.C, + configuration.problem_size.T, + configuration.problem_size.R, + configuration.problem_size.S + }; + + std::vector filter_stride { + configuration.layout_filters.stride()[3], + 1, + configuration.layout_filters.stride()[2], + configuration.layout_filters.stride()[1], + configuration.layout_filters.stride()[0] + }; + + status = get_cutlass_status( + cudnnSetFilterNdDescriptor( + filter_desc, + data_type_filter, + layout_filter, + op_desc.conv_dim + 2, + filter_extent.data() + )); + + + // Set output descriptor + std::vector output_extent { + configuration.problem_size.N, + configuration.problem_size.K, + configuration.problem_size.Z, + configuration.problem_size.P, + configuration.problem_size.Q + }; + + std::vector output_stride { + configuration.layout_output.stride()[3], + 1, + configuration.layout_output.stride()[2], + configuration.layout_output.stride()[1], + configuration.layout_output.stride()[0] + }; + + status = get_cutlass_status( + cudnnSetTensorNdDescriptor( + output_desc, + data_type_output, + op_desc.conv_dim + 2, + output_extent.data(), + output_stride.data() + )); + + // Set math instruction to tensor op + status = get_cutlass_status( + cudnnSetConvolutionMathType(conv_desc, math_type)); + + // Initialize workspace + switch (conv_kind) { + case library::ConvKind::kFprop: + status = get_cutlass_status( + cudnnGetConvolutionForwardWorkspaceSize( + handle, + activation_desc, + filter_desc, + conv_desc, + output_desc, + fprop_algo, + &workspace_size_in_bytes + )); break; + case library::ConvKind::kDgrad: + status = get_cutlass_status( + cudnnGetConvolutionBackwardDataWorkspaceSize( + handle, + filter_desc, + output_desc, + conv_desc, + activation_desc, + dgrad_algo, + &workspace_size_in_bytes + )); break; + case library::ConvKind::kWgrad: + status = get_cutlass_status( + cudnnGetConvolutionBackwardFilterWorkspaceSize( + handle, + activation_desc, + output_desc, + conv_desc, + filter_desc, + wgrad_algo, + &workspace_size_in_bytes + )); break; + + } + + workspace = cutlass::device_memory::allocation(workspace_size_in_bytes); + } + + /// Executes Conv2d operator from cudnn library + cudnnStatus_t operator()(cudnnHandle_t handle) { + + switch (conv_kind) { + case library::ConvKind::kFprop: + return cudnnConvolutionForward( + handle, + &alpha, + activation_desc, + activation(), + filter_desc, + filter(), + conv_desc, + fprop_algo, + workspace.get(), + workspace_size_in_bytes, + &beta, + output_desc, + arguments.D + ); + case library::ConvKind::kDgrad: + return 
cudnnConvolutionBackwardData( + handle, + &alpha, + filter_desc, + filter(), + output_desc, + output(), + conv_desc, + dgrad_algo, + workspace.get(), + workspace_size_in_bytes, + &beta, + activation_desc, + arguments.D + ); + case library::ConvKind::kWgrad: + return cudnnConvolutionBackwardFilter( + handle, + &alpha, + activation_desc, + activation(), + output_desc, + output(), + conv_desc, + wgrad_algo, + workspace.get(), + workspace_size_in_bytes, + &beta, + filter_desc, + arguments.D + ); + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + + // Returns Activation Tensor + void const * activation() const { + switch(conv_kind) { + case library::ConvKind::kFprop : return arguments.A; + case library::ConvKind::kDgrad : return arguments.C; + case library::ConvKind::kWgrad : return arguments.B; + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + + // Returns Filter Tensor + void const *filter() const { + switch(conv_kind) { + case library::ConvKind::kFprop : return arguments.B; + case library::ConvKind::kDgrad : return arguments.B; + case library::ConvKind::kWgrad : return arguments.C; + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } + + // Returns Output Tensor + void const *output() const { + switch(conv_kind) { + case library::ConvKind::kFprop : return arguments.C; + case library::ConvKind::kDgrad : return arguments.A; + case library::ConvKind::kWgrad : return arguments.A; + default : throw std::runtime_error("Invalid Conv Operator (fprop, dgrad, wgrad)"); + } + } +}; + +} // namespace detail +///////////////////////////////////////////////////////////////////////////////////////////////// +#endif //#if CUTLASS_ENABLE_CUDNN +} // namespace profiler +} // namespace cutlass diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/cutlass_profiler.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/cutlass_profiler.h new file mode 100644 index 0000000000000000000000000000000000000000..be82245325cebb147e2c801965a52ece91395cb2 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/cutlass_profiler.h @@ -0,0 +1,93 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Execution environment +*/ + +#pragma once +// CUTLASS Library includes +#include "cutlass/library/library.h" +#include "cutlass/library/manifest.h" +#include "cutlass/library/singleton.h" + +#include "options.h" +#include "operation_profiler.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace profiler { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// CUTLASS Profiler application +class CutlassProfiler { +private: + + // + // Data members + // + + /// Performance testbench options + Options options_; + + /// Entry points for each operation + OperationProfilerVector operation_profilers_; + +private: + + /// Prints usage + void print_usage_(std::ostream &); + + /// Prints usage + void print_options_(std::ostream &); + + /// Enumerates all operations + void enumerate_(); + + /// Profiles all operations + int profile_(); + +public: + + CutlassProfiler(Options const &options); + ~CutlassProfiler(); + + /// Invokes profiling operations + int operator()(); +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace profiler +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/debug.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/debug.h new file mode 100644 index 0000000000000000000000000000000000000000..98f1fdc3044501e456c927471b30d74b09eafd39 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/debug.h @@ -0,0 +1,56 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. 
+ * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief +*/ + +#pragma once + +#include + +//#define report(x) { std::cout << "\033[31m" << __FILE__ << ":" << __LINE__ << " " << x << "\033[0m" << std::endl; } +//#define report(x) {} + +// Enable/Disable Profiler debug prints +//#define DEBUG_PROFILER + +//RED 31m // profiler prints debug messages in red +//YELLOW 33m // ir prints debug messages in yellow + +#ifndef DEBUG_PROFILER +#define debugprof(...) +#else +#define debugprof(...) do { \ + printf("\033[33m[DEBUG PROF] %s:%d | ", __FILE__, __LINE__); \ + printf(__VA_ARGS__); \ + printf("\033[0m\n"); \ + } while (0) +#endif diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/device_allocation.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/device_allocation.h new file mode 100644 index 0000000000000000000000000000000000000000..488b635c2ec233e3027303bbf15a34f375a438fd --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/device_allocation.h @@ -0,0 +1,246 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Execution environment +*/ + +#pragma once + +#include +#include +#include + +#include "cutlass/library/library.h" +#include "cutlass/util/distribution.h" + +#include "enumerated_types.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace profiler { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Device memory allocation +class DeviceAllocation { +private: + + /// Data type of contained elements + library::NumericTypeID type_; + + /// Gets the stride between elements + size_t batch_stride_; + + /// Capacity in elements of device allocation + size_t capacity_; + + /// Pointer to device memory + void *pointer_; + + /// Layout type ID + library::LayoutTypeID layout_; + + /// Stride vector + std::vector stride_; + + /// Extent vector + std::vector extent_; + + /// Support allocating a 'batch' of non-overlapping tensors in contiguous memory + int batch_count_; + + /// Buffer holding TensorRef instance to recently allocated memory + std::vector tensor_ref_buffer_; + + /// The device ID where the allocation is made + int device_; + +public: + // + // Static member functions + // + + /// Determines the number of bytes needed to represent this numeric type + static size_t bytes(library::NumericTypeID type, size_t capacity); + + /// Returns the stride of a packed layout + static std::vector get_packed_layout( + library::LayoutTypeID layout_id, + std::vector const &extent); + + /// returns the capacity needed + static size_t construct_layout( + void *bytes, + library::LayoutTypeID layout_id, + std::vector const &extent, + std::vector &stride); + + /// Returns true if two blocks have exactly the same value + static bool block_compare_equal( + library::NumericTypeID numeric_type, + void const *ptr_A, + void const *ptr_B, + size_t capacity); + + /// Returns true if two blocks have approximately the same value + static bool block_compare_relatively_equal( + library::NumericTypeID numeric_type, + void const *ptr_A, + void const *ptr_B, + size_t capacity, + double epsilon, + double nonzero_floor); + +public: + // + // Methods + // + + DeviceAllocation(); + + DeviceAllocation( + library::NumericTypeID type, + size_t capacity, + int device = -1); + + DeviceAllocation( + library::NumericTypeID type, + library::LayoutTypeID layout_id, + std::vector const &extent, + std::vector const &stride = std::vector(), + int batch_count = 1, + int device = -1); + + ~DeviceAllocation(); + + DeviceAllocation &reset(); + + /// Allocates device memory of a given type and capacity + DeviceAllocation &reset(library::NumericTypeID type, size_t capacity); + + /// Allocates memory for a given layout and tensor + DeviceAllocation &reset( + library::NumericTypeID type, + library::LayoutTypeID layout_id, + std::vector const &extent, + 
std::vector const &stride = std::vector(), + int batch_count = 1); + + /// Returns a buffer owning the tensor reference + std::vector &tensor_ref() { + return tensor_ref_buffer_; + } + + bool good() const; + + /// Data type of contained elements + library::NumericTypeID type() const; + + /// Pointer to start of device memory allocation + void *data() const; + + /// Pointer to the first element of a batch + void *batch_data(int batch_idx) const; + + /// Gets the layout type + library::LayoutTypeID layout() const; + + /// Gets the stride vector + std::vector const & stride() const; + + /// Gets the extent vector + std::vector const & extent() const; + + /// Gets the number of adjacent tensors in memory + int batch_count() const; + + /// Gets the stride (in units of elements) between items + int64_t batch_stride() const; + + /// Gets the stride (in units of bytes) between items + int64_t batch_stride_bytes() const; + + /// Capacity of allocation in number of elements + size_t capacity() const; + + /// Capacity of allocation in bytes + size_t bytes() const; + + /// Initializes a device allocation to a random distribution using cuRAND + void initialize_random_device(int seed, Distribution dist); + + /// Initializes a host allocation to a random distribution using std::cout + void initialize_random_host(int seed, Distribution dist); + + /// Initializes a device allocation to a sequential distribution + void initialize_sequential_device(Distribution dist); + + /// Initializes a host allocation to a sequential distribution + void initialize_sequential_host(Distribution dist); + + /// Initializes a device allocation to a random distribution using cuRAND + void initialize_random_sparsemeta_device(int seed, int MetaSizeInBits); + + /// Initializes a host allocation to a random distribution using std::cout + void initialize_random_sparsemeta_host(int seed, int MetaSizeInBits); + + /// Uniformly fills a tensor with a value when provided o.w. zero + void fill_device(double value); + + /// Uniformly fills a host allocation with a value when provided o.w. 
zero + void fill_host(double value); + + /// Copies from an equivalent-sized tensor in device memory + void copy_from_device(void const *ptr); + + /// Copies from an equivalent-sized tensor in device memory + void copy_from_host(void const *ptr); + + /// Copies from an equivalent-sized tensor in device memory + void copy_to_host(void *ptr); + + /// Writes a tensor to csv + void write_tensor_csv(std::ostream &out); + +private: + /// A wrapper that sets the device, performs malloc, and sets back + cudaError_t malloc(void** ptr, size_t size); +}; + +using DeviceAllocationList = std::list; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace profiler +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/device_context.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/device_context.h new file mode 100644 index 0000000000000000000000000000000000000000..0443b340397426bfafc812c1a4b9179fc6af0de4 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/device_context.h @@ -0,0 +1,136 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
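// ---------------------------------------------------------------------------
// Illustrative sketch -- not from the CUTLASS sources. It shows one common
// shape of the relative comparison that an interface like
// block_compare_relatively_equal() above suggests: two values agree when their
// difference is small relative to their magnitudes, with 'nonzero_floor'
// guarding the scale against values near zero. The exact rule the profiler
// uses may differ; the names below are hypothetical.
#include <algorithm>
#include <cmath>
#include <cstddef>

inline bool relatively_equal(double a, double b, double epsilon, double nonzero_floor) {
  double diff  = std::fabs(a - b);
  double scale = std::max({std::fabs(a), std::fabs(b), nonzero_floor});
  return diff <= epsilon * scale;
}

inline bool block_relatively_equal(
    double const *a, double const *b, std::size_t count,
    double epsilon, double nonzero_floor) {
  for (std::size_t i = 0; i < count; ++i) {
    if (!relatively_equal(a[i], b[i], epsilon, nonzero_floor)) {
      return false;                       // first mismatch ends the comparison
    }
  }
  return true;
}
// ---------------------------------------------------------------------------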
+ * + **************************************************************************************************/ +/* \file + \brief +*/ + +#pragma once + +#include +#include + + +#include "cutlass/library/library.h" +#include "cutlass/library/util.h" + +#include "options.h" +#include "device_allocation.h" + +namespace cutlass { +namespace profiler { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Collection of allocations on the device +class DeviceContext { +public: + + // + // Type definitions + // + using AllocationMap = std::map; + +private: + // + // Data members + // + + /// Memory allocations that exist (owning) + DeviceAllocationList device_memory_; + + /// Non-owning set of named allocations + AllocationMap allocations_; + +public: + + /// Allocates memory of a given type, capacity (elements), and name + DeviceAllocation *allocate_block( + Options const &options, + std::string const &name, + library::NumericTypeID type, + size_t capacity, + size_t device_index); + + /// Allocates memory of a given type, capacity (elements), and name + DeviceAllocation *allocate_tensor( + Options const &options, + std::string const &name, + library::NumericTypeID type, + library::LayoutTypeID layout_id, + std::vector const &extent, + std::vector const &stride, + int batch_count, + size_t device_index); + + /// Allocates memory of a given type, capacity (elements), and name + DeviceAllocation *allocate_and_initialize_tensor( + Options const &options, + std::string const &name, + library::NumericTypeID type, + library::LayoutTypeID layout_id, + std::vector const &extent, + std::vector const &stride, + int batch_count, + int seed_shift, + size_t device_index); + + /// Allocates memory for sparse meta data + DeviceAllocation *allocate_and_initialize_sparsemeta_tensor( + Options const &options, + std::string const &name, + library::NumericTypeID type, + library::LayoutTypeID layout_id, + library::NumericTypeID type_a, + std::vector const &extent, + std::vector const &stride, + int batch_count, + int seed_shift, + size_t device_index); + + /// Clears named allocations (but does not necessarily free memory) + void clear(); + + /// Frees all device memory allocations + void free(); + + /// Gets the allocation by name + DeviceAllocation &at(std::string const &name); + + size_t size() const; + + AllocationMap::iterator begin(); + AllocationMap::iterator end(); +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace profiler +} // namespace cutlass diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/enumerated_types.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/enumerated_types.h new file mode 100644 index 0000000000000000000000000000000000000000..897311c228ce76c4e8814ce996929561d44d2465 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/enumerated_types.h @@ -0,0 +1,169 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Provides several functions for filling tensors with data. +*/ + +#pragma once + +#include +#include +#include +#include +#include "cutlass/library/library.h" + +#define TRACE(x) { std::cout << __FILE__ << ":" << __LINE__ << " " << x << std::endl; } + +namespace cutlass { +namespace profiler { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +T from_string(std::string const &); + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Enumerated type describing how the performance testbench evaluates kernels. 
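// ---------------------------------------------------------------------------
// Illustrative sketch -- not part of this header. Each enumeration below pairs
// a to_string() overload with a from_string<>() specialization of the template
// declared above; this self-contained example shows that pattern with a
// hypothetical enum ('DemoMode') and hypothetical string tokens.
#include <string>

enum class DemoMode { kFast, kAccurate, kInvalid };

inline char const *to_string(DemoMode mode, bool pretty = false) {
  switch (mode) {
    case DemoMode::kFast:     return pretty ? "Fast"     : "fast";
    case DemoMode::kAccurate: return pretty ? "Accurate" : "accurate";
    default:                  return pretty ? "Invalid"  : "invalid";
  }
}

template <typename T> T from_string(std::string const &);

template <> inline DemoMode from_string<DemoMode>(std::string const &str) {
  if (str == "fast")     { return DemoMode::kFast; }
  if (str == "accurate") { return DemoMode::kAccurate; }
  return DemoMode::kInvalid;
}
// ---------------------------------------------------------------------------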
+enum class ExecutionMode { + kProfile, ///< regular verification and profiling + kDryRun, ///< no kernels are launched or workspaces allocated; used to assess what operators might be launched + kEnumerate, ///< no kernels launched or workspaces allocated; lists all operation kind and operations + kTrace, ///< executes a single device-side computation with no other kernel launches + kInvalid +}; + +/// Converts a ExecutionMode enumerant to a string +char const *to_string(ExecutionMode mode, bool pretty = false); + +/// Parses a ExecutionMode enumerant from a string +template <> +ExecutionMode from_string(std::string const &str); + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Library algorithm mode +enum class AlgorithmMode { + kMatching, ///< compare against best matching algorithm + kBest, ///< evaluate all library algorithms and report best + kDefault, ///< use the library's default algorithm option + kInvalid +}; + +/// Converts a ExecutionMode enumerant to a string +char const *to_string(AlgorithmMode mode, bool pretty = false); + +/// Parses a ExecutionMode enumerant from a string +template <> +AlgorithmMode from_string(std::string const &str); + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Outcome of a performance test +enum class Disposition { + kPassed, + kFailed, // kernel itself reported an error + kNotRun, + kIncorrect, // kernel finished without a detected error, but result does not equal expected result + kNotVerified, + kInvalidProblem, + kNotSupported, + kInvalid +}; + +/// Converts a Disposition enumerant to a string +char const *to_string(Disposition disposition, bool pretty = false); + +/// Parses a Disposition enumerant from a string +template <> +Disposition from_string(std::string const &str); + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Indicates when to save +enum class SaveWorkspace { + kNever, + kIncorrect, + kAlways, + kInvalid +}; + +/// Converts a SaveWorkspace enumerant to a string +char const *to_string(SaveWorkspace save_option, bool pretty = false); + +/// Parses a SaveWorkspace enumerant from a string +template <> +SaveWorkspace from_string(std::string const &str); + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Indicates the type of kernel argument +// ArgumentType can be both ScalarType or NumericType. Thus, enums kScalar and kNumeric +// 1) kScalar: e.g. of a Scalar ArgumentType is u32 is a Scalar type. +// Its c++ equivalent as "type name = initializer" is "u32 m = 32" +// 2) kNumeric: e.g. of a Numeric ArgumentType is NumericTypeID is a Numeric type. 
+// Its c++ equivalent as "type name = initializer" is "NumericTypeID numeric_type = u32" +enum class ArgumentTypeID { + kScalar, + kInteger, + kTensor, + kBatchedTensor, + kStructure, + kEnumerated, + kInvalid +}; + +/// Converts a ArgumentTypeID enumerant to a string +char const *to_string(ArgumentTypeID type, bool pretty = false); + +/// Parses a ArgumentTypeID enumerant from a string +template <> +ArgumentTypeID from_string(std::string const &str); + +///////////////////////////////////////////////////////////////////////////////////////////////// +// Profiler typedefs +using ProviderVector = std::vector; +using DispositionMap = std::map; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Print vector for the report +template +std::ostream& operator<< (std::ostream& out, const std::vector& v) { + for (size_t i = 0; i < v.size(); ++i) { + out << to_string(v[i], true) << (i + 1u != v.size() ? "," : ""); + } + return out; +} +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace profiler +} // namespace cutlass diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/gemm_operation_profiler.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/gemm_operation_profiler.h new file mode 100644 index 0000000000000000000000000000000000000000..faf317152473cac6dc62ecf8970cd1acfb2c1622 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/gemm_operation_profiler.h @@ -0,0 +1,333 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +/* \file + \brief Gemm Profiler +*/ + +#pragma once + +#include +#include +#include +#include +#include +#include + +// CUTLASS Library includes +#include "cutlass/library/library.h" +#include "cutlass/library/util.h" +#include "cutlass/library/manifest.h" + +// Profiler includes +#include "options.h" +#include "device_context.h" +#include "operation_profiler.h" +#include "performance_result.h" +#include "problem_space.h" +#include "reduction_operation_profiler.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace profiler { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Abstract base class for each math function +class GemmOperationProfiler : public OperationProfiler { +public: + + /// Problem structure obtained from problem space + struct GemmProblem { + + cutlass::library::GemmUniversalMode mode{library::GemmUniversalMode::kGemm}; + + /// For profiling purposes + std::vector problem_sizes; + std::vector> leading_dims; + std::vector> preferred_clusters; + std::vector> fallback_clusters; + std::vector raster_orders; + std::vector swizzle_sizes; + + int64_t m{16}; + int64_t n{16}; + int64_t k{16}; + + + int cluster_m{1}; + int cluster_n{1}; + int cluster_k{1}; + int cluster_m_fallback{1}; + int cluster_n_fallback{1}; + int cluster_k_fallback{1}; + + + int64_t lda{0}; + int64_t ldb{0}; + int64_t ldc{0}; + std::vector alpha; + std::vector beta; + + cutlass::library::SplitKMode split_k_mode{library::SplitKMode::kNone}; + int split_k_slices{1}; + int batch_count{1}; + + cutlass::library::RasterOrder raster_order{cutlass::library::RasterOrder::kHeuristic}; + int swizzle_size{1}; + cutlass::library::RuntimeDatatype runtime_input_datatype_a{}; + cutlass::library::RuntimeDatatype runtime_input_datatype_b{}; + + + // gemm with parallel interleaved reduction + // gemm epilogue (alpha, beta) = (1.0, 0.0) + // reduction epilogue (alpha, beta) = (GemmProblem::alpha, GemmProblem::beta) + std::vector alpha_one; + std::vector beta_zero; + + bool use_pdl{false}; + + bool enable_sm90_mixed_dtype_shuffle_test{false}; + + // + // Methods + // + + /// Parses the problem + Status parse( + library::GemmDescription const &operation_desc, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + int64_t bytes_with_problem_shape( + library::GemmDescription const &operation_desc, + gemm::GemmCoord const &problem_shape) const; + + int64_t flops_with_problem_shape( + library::GemmDescription const &operation_desc, + gemm::GemmCoord const &problem_shape) const; + + /// Total number of bytes loaded + int64_t bytes(library::GemmDescription const &operation_desc) const; + + /// Total number of flops computed + int64_t flops(library::GemmDescription const &operation_desc) const; + + /// Initializes a performance result + void initialize_result( + PerformanceResult &result, + library::GemmDescription const &operation_desc, + ProblemSpace const &problem_space); + }; + + /// Workspace used + struct GemmWorkspace { + + DeviceAllocation *A{nullptr}; + DeviceAllocation *B{nullptr}; + DeviceAllocation *C{nullptr}; + DeviceAllocation *Computed{nullptr}; + DeviceAllocation *Reference{nullptr}; + + /// Number of copies of the problem workspace which are visited sequentially during + /// profiling to avoid camping in the last level cache. 
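// ---------------------------------------------------------------------------
// Illustrative sketch -- not from the profiler sources. The comment above
// describes rotating through several copies of the operand workspace so that
// timed iterations do not keep reusing data already resident in the
// last-level cache. A hypothetical helper for sizing that rotation:
#include <algorithm>
#include <cstdint>

inline int choose_workspace_copies(int64_t bytes_per_copy, int64_t llc_capacity_bytes) {
  if (bytes_per_copy <= 0) {
    return 1;
  }
  // enough copies that one full rotation exceeds the last-level cache capacity
  int64_t copies = (llc_capacity_bytes + bytes_per_copy - 1) / bytes_per_copy + 1;
  return static_cast<int>(std::max<int64_t>(copies, 1));
}
// ---------------------------------------------------------------------------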
+ int problem_count{1}; + + library::GemmUniversalConfiguration configuration; + library::GemmUniversalArguments arguments; + + /// Buffer used for the operation's host workspace + std::vector host_workspace; + + /// Buffer used for the operations' device workspace + DeviceAllocation device_workspace; + + /// Library configuration and arguments for reduction operator + library::ReductionConfiguration reduction_configuration; + library::ReductionArguments reduction_arguments; + + /// Buffer used for the cutlass reduction operations' host workspace + std::vector reduction_host_workspace; + + /// For mixed input dtype kernels + DeviceAllocation *Scale{nullptr}; // Scale tensor + DeviceAllocation *Zero{nullptr}; // Zero tensor + DeviceAllocation *dequantized_AB{nullptr}; // Dequantized A or B tensor for verification + DeviceAllocation *encoded_AB{nullptr}; // Encoded A or B in int4 x fp8 or shuffle + DeviceAllocation *packed_Scale{nullptr}; // Packed scale for int4 * fp8 + + cudaStream_t stream; + }; + +protected: + + // + // Data members + // + + /// GEMM problem obtained from problem space + GemmProblem problem_; + + /// Device memory allocations + std::vector gemm_workspace_; + + /// CUTLASS parallel reduction operation to follow this* gemm operation + library::Operation const *reduction_op_; + +public: + // + // Methods + // + + /// Ctor + GemmOperationProfiler(Options const &options); + + /// Destructor + virtual ~GemmOperationProfiler(); + + GemmProblem const& problem() const { return problem_; } + + /// Prints usage statement for the math function + virtual void print_usage(std::ostream &out) const; + + /// Prints examples + virtual void print_examples(std::ostream &out) const; + + /// Extracts the problem dimensions + virtual Status initialize_configuration( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Initializes workspace + virtual Status initialize_workspace( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Verifies CUTLASS against references + virtual bool verify_cutlass( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Measures performance results + virtual bool profile( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +protected: + /// Update workspace configuration according to flexible user setups + void update_workspace_( + GemmWorkspace &gemm_workspace, + gemm::GemmCoord const &problem_shape, + std::array const &leading_dim, + std::array const &preferred_cluster, + std::array const &fallback_cluster, + cutlass::library::RasterOrder const &raster_order, + int swizzle_size, + bool is_dynamic_cluster_enabled); + + /// Update performance result configuration according to flexible user setups + void update_result_( + PerformanceResult &result, + library::GemmDescription const &operation_desc, + ProblemSpace const &problem_space, + gemm::GemmCoord const &problem_shape, + cutlass::library::RasterOrder const &raster_order, + std::array const 
&preferred_cluster, + std::array const &fallback_cluster, + int swizzle_size, + bool is_dynamic_cluster_enabled); + + /// Initializes the performance result + void initialize_result_( + PerformanceResult &result, + Options const &options, + library::GemmDescription const &operation_desc, + ProblemSpace const &problem_space); + + /// Verifies CUTLASS against references + bool verify_with_cublas_( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem, + GemmWorkspace &gemm_workspace); + + /// Verifies CUTLASS against host and device references + bool verify_with_reference_( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem, + cutlass::library::NumericTypeID element_A, + cutlass::library::NumericTypeID element_B); + + /// Method to profile a CUTLASS Operation + Status profile_cutlass_( + PerformanceResult &result, + Options const &options, + library::Operation const *operation, + void *arguments, + void *host_workspace, + void *device_workspace); + + /// Initialize reduction problem dimensions and library::Operation + bool initialize_reduction_configuration_( + library::Operation const *operation, + ProblemSpace::Problem const &problem); +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace profiler +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/gpu_timer.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/gpu_timer.h new file mode 100644 index 0000000000000000000000000000000000000000..154045295d6443d930ba53387366f4b8abe408a4 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/gpu_timer.h @@ -0,0 +1,77 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines a math function +*/ + +#pragma once + +#include +#include "cutlass/cutlass.h" + +namespace cutlass { +namespace profiler { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +struct GpuTimer { + + cudaEvent_t events[2]; + + // + // Methods + // + + GpuTimer(); + + GpuTimer(GpuTimer const&) = delete; + + GpuTimer(GpuTimer &&gpu_timer) noexcept; + + ~GpuTimer(); + + /// Records a start event in the stream, the flag is for cudaEventRecordWithFlags + void start(cudaStream_t stream = nullptr, unsigned int flag = cudaEventRecordDefault); + + /// Records a stop event in the stream, the flag is for cudaEventRecordWithFlags + void stop(cudaStream_t stream = nullptr, unsigned int flag = cudaEventRecordDefault); + + /// Records a stop event in the stream and synchronizes on the stream, the flag is for cudaEventRecordWithFlags + void stop_and_wait(cudaStream_t stream = nullptr, unsigned int flag = cudaEventRecordDefault); + + /// Returns the duration in milliseconds + double duration(int iterations = 1) const; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace profiler +} // namespace cutlass diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/grouped_gemm_operation_profiler.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/grouped_gemm_operation_profiler.h new file mode 100644 index 0000000000000000000000000000000000000000..62d47990584cbb984935a00a267cff15dbb4f4e5 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/grouped_gemm_operation_profiler.h @@ -0,0 +1,344 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
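// ---------------------------------------------------------------------------
// Illustrative sketch -- not from the CUTLASS sources. It shows the
// cudaEvent-based timing pattern behind the GpuTimer struct declared above:
// record events around the enqueued work, synchronize on the stop event, then
// read the elapsed time and average over the iteration count. 'launch' is a
// placeholder for whatever enqueues the kernels being measured.
#include <cuda_runtime.h>

inline float time_region_ms(cudaStream_t stream, int iterations, void (*launch)(cudaStream_t)) {
  cudaEvent_t start, stop;
  cudaEventCreate(&start);
  cudaEventCreate(&stop);

  cudaEventRecord(start, stream);
  for (int i = 0; i < iterations; ++i) {
    launch(stream);                        // enqueue the work being measured
  }
  cudaEventRecord(stop, stream);
  cudaEventSynchronize(stop);              // wait for the stop event to complete

  float elapsed_ms = 0.0f;
  cudaEventElapsedTime(&elapsed_ms, start, stop);

  cudaEventDestroy(start);
  cudaEventDestroy(stop);
  return iterations > 0 ? elapsed_ms / iterations : elapsed_ms;
}
// ---------------------------------------------------------------------------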
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/* \file + \brief GroupedGemm Profiler +*/ + +#pragma once + +#include +#include +#include +#include +#include + +// CUTLASS Library includes +#include "cutlass/library/library.h" + +// Profiler includes +#include "device_context.h" +#include "operation_profiler.h" +#include "options.h" +#include "performance_result.h" +#include "problem_space.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace profiler { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Abstract base class for each math function +class GroupedGemmOperationProfiler : public OperationProfiler { +public: + /// Problem structure obtained from problem space + struct GroupedGemmProblem { + + cutlass::library::GemmUniversalMode mode{library::GemmUniversalMode::kGrouped}; + + std::vector problem_sizes; + std::vector> problem_sizes_3x; + + /// For exploration purposes + std::vector> preferred_clusters; + std::vector> fallback_clusters; + std::vector raster_orders; + std::vector swizzle_sizes; + + int cluster_m{1}; + int cluster_n{1}; + int cluster_k{1}; + int cluster_m_fallback{1}; + int cluster_n_fallback{1}; + int cluster_k_fallback{1}; + + std::vector lda{0}; + std::vector ldb{0}; + std::vector ldc{0}; + + std::vector alpha; + std::vector beta; + + cutlass::library::RasterOrder raster_order{cutlass::library::RasterOrder::kHeuristic}; + int swizzle_size{1}; + + cutlass::library::RuntimeDatatype runtime_input_datatype_a{}; + cutlass::library::RuntimeDatatype runtime_input_datatype_b{}; + + bool use_pdl{false}; + + /// Parses the problem + Status parse( + library::GroupedGemmDescription const& operation_desc, + ProblemSpace const& problem_space, + ProblemSpace::Problem const& problem); + + int64_t m(int group_idx) const { return problem_sizes[group_idx].m(); }; + int64_t n(int group_idx) const { return problem_sizes[group_idx].n(); }; + int64_t k(int group_idx) const { return problem_sizes[group_idx].k(); }; + + /// Total number of bytes loaded + int64_t bytes(library::GroupedGemmDescription const& operation_desc) const; + + /// Total number of flops computed + int64_t flops(library::GroupedGemmDescription const& operation_desc) const; + + /// Initializes a performance result + void initialize_result( + PerformanceResult& result, + library::GroupedGemmDescription const& operation_desc, + ProblemSpace const& problem_space); + }; + + struct BlockScalingWorkspace { + // host vector (per L2 workspace) of device vectors (per group) of device pointers + std::vector SFA_ptr_array_device; + std::vector 
SFB_ptr_array_device; + std::vector SFC_ptr_array_device; + std::vector SFD_ptr_array_device; + + // host vector (per group) of device tensors + // (where each batch of device allocation is for a L2 workspace) + std::vector SFA_ptr_array_host; + std::vector SFB_ptr_array_host; + std::vector SFC_ptr_array_host; + std::vector SFD_ptr_array_host; + std::vector SFD_reference_ptr_array_host; + + // matrix wide constant, not per-batch or per-group + DeviceAllocation* norm_constant; + }; + + // workspace contains the allocated blocks, arguments just contain the raw + // pointers + struct GroupedGemmWorkspace { + + // host vector (per L2 workspace) of device vectors (per group) of device pointers + std::vector A_ptr_array_device; + std::vector B_ptr_array_device; + std::vector C_ptr_array_device; + std::vector D_ptr_array_device; + std::vector reference_ptr_array_host; + + // host vector (per group) of device tensors + // (where each batch of device allocation is for a L2 workspace) + std::vector A_ptr_array_host; + std::vector B_ptr_array_host; + std::vector C_ptr_array_host; + std::vector D_ptr_array_host; + + /// Number of copies of the problem workspace which are visited sequentially during + /// profiling to avoid camping in the last level cache. + /// *NOT* the number of groups in the grouped GEMM (we use `num_groups` in the profiler) + int problem_count{1}; + + DeviceAllocation* problem_sizes_array_device{nullptr}; + DeviceAllocation* problem_sizes_3x_array_device{nullptr}; + DeviceAllocation* lda_array_device{nullptr}; + DeviceAllocation* ldb_array_device{nullptr}; + DeviceAllocation* ldc_array_device{nullptr}; + DeviceAllocation* ldd_array_device{nullptr}; + + std::optional block_scales; + + library::GemmGroupedConfiguration configuration; + library::GroupedGemmBlockScaledArguments arguments; + + std::vector host_workspace; + DeviceAllocation device_workspace; + + cudaStream_t stream; + }; + +private: + void init_arguments(Options const& options) { + auto& arguments = gemm_workspace_.arguments; + // these get updated in each profiler run to ensure L2 cycling + arguments.ptr_A = gemm_workspace_.A_ptr_array_device[0]->data(); + arguments.ptr_B = gemm_workspace_.B_ptr_array_device[0]->data(); + arguments.ptr_C = gemm_workspace_.C_ptr_array_device[0]->data(); + arguments.ptr_D = gemm_workspace_.D_ptr_array_device[0]->data(); + + arguments.alpha = problem_.alpha.data(); + arguments.beta = problem_.beta.data(); + arguments.pointer_mode = library::ScalarPointerMode::kHost; + arguments.lda = static_cast(gemm_workspace_.lda_array_device->data()); + arguments.ldb = static_cast(gemm_workspace_.ldb_array_device->data()); + arguments.ldc = static_cast(gemm_workspace_.ldc_array_device->data()); + arguments.ldd = static_cast(gemm_workspace_.ldc_array_device->data()); + arguments.problem_sizes = + static_cast(gemm_workspace_.problem_sizes_array_device->data()); + arguments.problem_sizes_3x = static_cast*>( + gemm_workspace_.problem_sizes_3x_array_device->data()); + gemm_workspace_.arguments.problem_sizes_3x_host = problem_.problem_sizes_3x.data(); + gemm_workspace_.arguments.problem_count = problem_.problem_sizes.size(); + gemm_workspace_.arguments.cluster_shape = {int(problem_.cluster_m), int(problem_.cluster_n), int(problem_.cluster_k)}; + gemm_workspace_.arguments.cluster_shape_fallback = {int(problem_.cluster_m_fallback), int(problem_.cluster_n_fallback), int(problem_.cluster_k_fallback)}; + + /* Query device SM count to pass onto the kernel as an argument, where needed */ + arguments.sm_count = 
options.device.get_sm_count(0); + if (is_block_scaled) { + auto& block_scaled_ws = gemm_workspace_.block_scales.value(); + arguments.SFA = block_scaled_ws.SFA_ptr_array_device[0]->data(); + arguments.SFB = block_scaled_ws.SFB_ptr_array_device[0]->data(); + arguments.SFD = block_scaled_ws.SFD_ptr_array_device[0]->data(); + arguments.norm_constant = block_scaled_ws.norm_constant->data(); + } + else if (is_blockwise) { + auto& block_scaled_ws = gemm_workspace_.block_scales.value(); + arguments.SFA = block_scaled_ws.SFA_ptr_array_device[0]->data(); + arguments.SFB = block_scaled_ws.SFB_ptr_array_device[0]->data(); + } + } + +protected: + /// GEMM problem obtained from problem space + GroupedGemmProblem problem_; + + /// Device memory allocations + GroupedGemmWorkspace gemm_workspace_; + + bool is_block_scaled{false}; + bool is_blockwise{false}; + +public: + GroupedGemmOperationProfiler(Options const& options); + + virtual ~GroupedGemmOperationProfiler(); + + GroupedGemmProblem const& problem() const { return problem_; } + + /// Prints usage statement for the math function + virtual void print_usage(std::ostream& out) const; + + /// Prints examples + virtual void print_examples(std::ostream& out) const; + + /// Extracts the problem dimensions + virtual Status initialize_configuration( + Options const& options, + PerformanceReport& report, + DeviceContext& device_context, + library::Operation const* operation, + ProblemSpace const& problem_space, + ProblemSpace::Problem const& problem); + + /// Initializes workspace + virtual Status initialize_workspace( + Options const& options, + PerformanceReport& report, + DeviceContext& device_context, + library::Operation const* operation, + ProblemSpace const& problem_space, + ProblemSpace::Problem const& problem); + + /// Verifies CUTLASS against references + virtual bool verify_cutlass( + Options const& options, + PerformanceReport& report, + DeviceContext& device_context, + library::Operation const* operation, + ProblemSpace const& problem_space, + ProblemSpace::Problem const& problem); + + /// Measures performance results + virtual bool profile( + Options const& options, + PerformanceReport& report, + DeviceContext& device_context, + library::Operation const* operation, + ProblemSpace const& problem_space, + ProblemSpace::Problem const& problem); + +protected: + /// Initializes the performance result + void initialize_result_( + PerformanceResult& result, + Options const& options, + library::GroupedGemmDescription const& operation_desc, + ProblemSpace const& problem_space); + + /// Update workspace configuration according to flexible user setups + void update_workspace_( + GroupedGemmWorkspace &gemm_workspace, + std::array const &preferred_cluster, + std::array const &fallback_cluster, + cutlass::library::RasterOrder const &raster_order, + int swizzle_size, + bool is_dynamic_cluster_enabled); + + /// Update performance result configuration for exploration parameters + void update_workspace_and_result_( + GroupedGemmWorkspace &gemm_workspace, + PerformanceResult &result, + ProblemSpace const &problem_space, + cutlass::library::RasterOrder const &raster_order, + std::array const &preferred_cluster, + std::array const &fallback_cluster, + int swizzle_size, + bool is_dynamic_cluster_enabled); + + /// Verifies CUTLASS against host and device references + bool verify_with_reference_( + Options const& options, + PerformanceReport& report, + DeviceContext& device_context, + library::Operation const* operation, + ProblemSpace const& problem_space, + 
ProblemSpace::Problem const& problem, + cutlass::library::NumericTypeID element_A, + cutlass::library::NumericTypeID element_B); + + /// Method to profile a CUTLASS Operation + Status profile_cutlass_( + PerformanceResult& result, + Options const& options, + library::Operation const* operation, + void* arguments, + void* host_workspace, + void* device_workspace) override; + + /// Method to profile a CUTLASS Operation for the best configuration for a fixed shape + bool profile_cutlass_for_fixed_shape_( + Options const& options, + library::Operation const* operation, + ProblemSpace const& problem_space); + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace profiler +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/operation_profiler.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/operation_profiler.h new file mode 100644 index 0000000000000000000000000000000000000000..446ef2c16739b28aaf038ca62bad6e3cdf667813 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/operation_profiler.h @@ -0,0 +1,287 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +/* \file + \brief Defines a math function +*/ + +#pragma once + +#include +#include +#include +#include + +// CUTLASS includes +#include "cutlass/trace.h" + +// CUTLASS Library includes +#include "cutlass/library/library.h" +#include "cutlass/library/util.h" +#include "cutlass/library/manifest.h" + +// Profiler includes +#include "options.h" +#include "device_context.h" +#include "performance_result.h" +#include "performance_report.h" +#include "problem_space.h" +#include "debug.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace profiler { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Abstract base class for each math function +class OperationProfiler { +public: + + +protected: + // + // Data members + // + + /// Top-level operation kind + library::OperationKind kind_; + + /// Human readable description + std::string description_; + + /// Arguments parsed from command line + ArgumentDescriptionVector arguments_; + + /// List of providers used to verify and compare each result + ProviderVector verification_providers_; + + /// Model performance result initialized by the operation profiler with workload statistics + /// and reasonable default state. + PerformanceResult model_result_; + + /// Performance result vector constructed by profiling the operation + PerformanceResultVector results_; + +public: + + // + // Methods + // + + /// Ctor + OperationProfiler(); + + OperationProfiler( + Options const &options, + library::OperationKind kind, + ArgumentDescriptionVector const &arguments = ArgumentDescriptionVector(), + ProviderVector const & verification_providers = ProviderVector()); + + /// Destructor + virtual ~OperationProfiler(); + + /// Obtains the operation kind + library::OperationKind kind() const { return kind_; } + + /// Gets the schema description + std::string const &description() const; + + /// Returns a reference to the arguments + ArgumentDescriptionVector const &arguments() const { return arguments_; } + +public: + + // + // Basic overrides + // + + + /// Prints usage statement for the math function + virtual void print_usage(std::ostream &out) const; + + /// Prints examples + virtual void print_examples(std::ostream &out) const =0; + + /// Entry point to profile all operations in the manifest + virtual int profile_all( + Options const &options, + library::Manifest const &manifest, + DeviceContext &device_context); + +public: + + // + // Operation-specific phases of verification and profiling + // + + /// Extracts the problem dimensions + virtual Status initialize_configuration( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem) = 0; + + /// Initializes workspace + virtual Status initialize_workspace( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem) = 0; + + /// Verifies CUTLASS against references + virtual bool verify_cutlass( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const 
&problem) = 0; + + /// Measures performance results + virtual bool profile( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem) = 0; + +public: + + // + // Static helpers + // + + /// Sleep for a given duration in ms + static void sleep(int sleep_duration); + + /// Returns true if the current operation description satisfies the problem space + static bool satisfies( + library::OperationDescription const &op_desc, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Compares tensors for equality + static Disposition compare_tensors( + Options const &options, + DeviceAllocation &experimental, + DeviceAllocation &reference, + int64_t count = 0); + + static void save_workspace( + DeviceContext &device_context, + Options const &options, + library::OperationDescription const &desc, + library::Provider provider, + library::Provider verification_provider = library::Provider::kInvalid); + + /// Helper to set a performance result member + static void set_argument( + PerformanceResult &result, + char const *name, + ProblemSpace const &problem_space, + std::string const &value); + + /// Helper to set a performance result member + static void set_argument( + PerformanceResult &result, + char const *name, + ProblemSpace const &problem_space, + int64_t value); + +protected: + + /// Sets operation description + static void initialize_result_( + PerformanceResult &result, + library::OperationDescription const &operation_desc, + ProblemSpace const &problem_space); + + /// Method to profile an initialized CUTLASS operation + virtual Status profile_cutlass_( + PerformanceResult &result, + Options const &options, + library::Operation const *operation, + void *arguments, + void *host_workspace, + void *device_workspace); + + /// Profiles the GPU kernel launched in `func` running simultaneously on all + /// requested devices. 
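// ---------------------------------------------------------------------------
// Illustrative sketch -- not from the profiler sources. A CUDA-graph-based
// profiling path such as profile_kernel_w_cuda_graphs_() below typically
// captures the launches once via stream capture, instantiates the graph, and
// then replays it for the timed iterations. This standalone example assumes
// the CUDA 12+ cudaGraphInstantiate() signature; the names are hypothetical.
#include <cuda_runtime.h>
#include <functional>

inline cudaError_t replay_with_graph(
    std::function<void(cudaStream_t)> const &func, cudaStream_t stream, int iterations) {
  cudaGraph_t graph;
  cudaGraphExec_t graph_exec;

  cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
  func(stream);                                 // work is recorded, not executed
  cudaStreamEndCapture(stream, &graph);
  cudaGraphInstantiate(&graph_exec, graph, 0);  // CUDA 12+ flags overload

  for (int i = 0; i < iterations; ++i) {
    cudaGraphLaunch(graph_exec, stream);        // replay the captured work
  }
  cudaError_t result = cudaStreamSynchronize(stream);

  cudaGraphExecDestroy(graph_exec);
  cudaGraphDestroy(graph);
  return result;
}
// ---------------------------------------------------------------------------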
+ Status profile_kernel_w_cuda_graphs_( + PerformanceResult& result, + Options const& options, + std::function const& func, + std::vector const& streams); + + Status profile_kernel_( + PerformanceResult& result, + Options const& options, + std::function const& func, + std::vector const& streams); + + /// Profiles the GPU kernel launched in `func` on the `stream` + Status profile_kernel_( + PerformanceResult& result, + Options const& options, + std::function const& func, + cudaStream_t stream = nullptr); + + /// Profiles the GPU kernel launched in `func` on the `stream` + Status profile_kernel_no_cuda_graphs_( + PerformanceResult& result, + Options const& options, + std::function const& func, + cudaStream_t stream = nullptr); + +private: + /// finds string matches filter_string in operation_name + bool find_string_matches_( + std::string const &filter_string, + std::string const &operation_name); +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Vector of owning operation profilers +using OperationProfilerVector = std::vector>; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace profiler +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/options.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/options.h new file mode 100644 index 0000000000000000000000000000000000000000..1a957b36eea35f7c0a5366645c3a62298ca56dea --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/options.h @@ -0,0 +1,384 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Command line options for performance test program +*/ + +#pragma once + +#include +#include +#include + +#include + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/library/library.h" + +#include "enumerated_types.h" + +namespace cutlass { +namespace profiler { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Global options +class Options { +public: + + /// cuBLAS and cuDNN options + struct Library { + + // + // Data members + // + + /// Algorithm mode + AlgorithmMode algorithm_mode; + + /// Algorithm enumerants + std::vector algorithms; + + // + // Methods + // + + explicit Library(CommandLine const &cmdline); + + void print_usage(std::ostream &out) const; + void print_options(std::ostream &out, int indent = 0) const; + }; + + /// Options related to the selected device + struct Device { + + /// Device ID + std::vector devices; + + /// Number of total devices + /// This is not set by the user; it is set automatically + int num_devices; + + /// CUDA Device properties + std::vector properties; + + /// Total memory allocation on each device + size_t maximum_capacity; + + private: + /// SM Count + /// Limits the number of SMs to use on each device + int sm_count; + + // + // Methods + // + public: + explicit Device(CommandLine const &cmdline); + + void print_usage(std::ostream &out) const; + void print_options(std::ostream &out, int indent = 0) const; + void print_device_info(std::ostream &out) const; + + /// Returns the device ID from a device index + int device_id(size_t device_index) const; + + /// Returns the sm_count if set, otherwise returns the number of SMs on the device + int get_sm_count(int device_index) const; + + /// Returns the compute capability of the listed devices (e.g. 70, 75, 80, etc.) + int compute_capability(int device_index) const; + }; + + /// Options related to initializing input tensors + struct Initialization { + + /// If true, data is initialized randomly. If false, no initialization is performed after + /// allocating tensors. + bool enabled; + + /// If true, the data distribution is set by the user and is not allowed to change. + /// If false, the data distribution is allowed to change based on element_type (library::NumericTypeID) + bool fix_data_distribution; + + /// Data distribution for input tensors + Distribution data_distribution; + + /// Source of random tensor elements + library::Provider provider; + + /// Random number generator seed.
+ int seed; + + // + // Methods + // + + explicit Initialization(CommandLine const &cmdline); + + void print_usage(std::ostream &out) const; + void print_options(std::ostream &out, int indent = 0) const; + + /// Helper to parse a Distribution object from the command line parser + static void get_distribution( + cutlass::CommandLine const &args, + std::string const &arg, + cutlass::Distribution &dist); + }; + + /// Options related to verification of the result + struct Verification { + + // + // Data members + // + + /// If true, kernels are verified before they are profiled + bool enabled; + + /// If true, causes profiler to return an error code if no reference check is run. + /// Only valid when verification is enabled. + bool required; + + /// Relative error threshold - zero to require bit-level consistency + double epsilon; + + /// Values smaller than this are assumed to be zero + double nonzero_floor; + + /// List of providers used to verify each result + ProviderVector providers; + + /// Indicates when to save the workspace + SaveWorkspace save_workspace; + + // + // Methods + // + + explicit Verification(CommandLine const &cmdline); + + void print_usage(std::ostream &out) const; + void print_options(std::ostream &out, int indent = 0) const; + + /// Returns true if a provider is enabled + bool provider_enabled(library::Provider provider) const; + + /// Returns the index of a provider if it is enabled + size_t index(library::Provider provider) const; + }; + + /// Options related to profiling + struct Profiling { + + /// Number of workspaces to rotate through to avoid cache-resident working sets + int workspace_count{0}; + + /// Number of iterations to warm up each kernel prior to profiling + int warmup_iterations{10}; + + /// Number of iterations to profile each kernel - if 0, kernels are launched up to the profiling duration. + /// This will always override profiling-duration and min-iterations. + int iterations{100}; + + /// Time to spend profiling each kernel (ms) + int duration{10}; + + /// Minimum number of iterations to profile + int min_iterations{10}; + + /// If true, profiling is performed with CUDA graphs enabled. + bool use_cuda_graphs{false}; + + /// If enabled, the CUTLASS profiler searches for the best-performing kernel + /// within the subset of kernels matching a kernel filter regex. The best + /// performance is determined by screening over a set of predefined M/N/K + /// sizes and performance-related parameters, including cluster shapes, + /// swizzle sizes, and rasterization orders. + /// For now, it only supports legacy GEMM and blockscaled GEMM. + bool enable_kernel_performance_search{false}; + + /// If enabled, the CUTLASS profiler searches for the best-performing kernel + /// for a given M/N/K problem size by evaluating various performance-related + /// parameters such as cluster shapes, swizzle sizes, and rasterization orders. + /// For now, it only supports legacy GEMM and blockscaled GEMM. + bool enable_best_kernel_for_fixed_shape{false}; + + /// Number of milliseconds to sleep between profiling periods + int sleep_duration{50}; + + /// If true, profiling is actually conducted. + bool enabled{true}; + + /// If true, profiling returns an error code if no kernels are found to match the filters. + bool error_on_no_match{false}; + + /// If true, profiling returns an error code if no kernels are profiled. + // Sometimes a kernel matches but fails to profile (e.g.
can_implement() error) + bool error_if_nothing_is_profiled{false}; + + /// List of providers of each functionality to be profiled + ProviderVector providers; + + // + // Methods + // + + explicit Profiling(CommandLine const &cmdline); + + void print_usage(std::ostream &out) const; + void print_options(std::ostream &out, int indent = 0) const; + + /// Returns true if a provider is enabled + bool provider_enabled(library::Provider provider) const; + + /// Returns the index of a provider if its enabled + size_t index(library::Provider provider) const; + }; + + /// Options related to reporting + struct Report { + + /// If true, result is appended to possibly existing file + bool append; + + /// Path to a file containing results + std::string output_path; + + /// Path to a file containing junit xml results + std::string junit_output_path; + + /// Sequence of tags to attach to each result + std::vector> pivot_tags; + + /// If true, reports status of all kernels including those that were + /// not run for the given arguments + bool report_not_run; + + /// Prints human-readable text to stdout. If false, nothing is written to stdout + bool verbose; + + /// Sort results by flops-per-byte + bool sort_flops_per_byte; + + /// Sort results by flops-per-second + bool sort_flops_per_sec; + + /// Prints the name of the kernel being profiled before running the kernel. + /// This is useful for determining which kernel is causing a run of the profiler to hang + bool print_kernel_before_running; + + // + // Methods + // + + explicit Report(CommandLine const &cmdline); + + void print_usage(std::ostream &out) const; + void print_options(std::ostream &out, int indent = 0) const; + }; + + /// Options related to printing usage and version information + struct About { + + /// If true, usage is printed and the program ends. 
+ bool help; + + /// Prints version string + bool version; + + /// Print information about devices + bool device_info; + + // + // Methods + // + + explicit About(CommandLine const &cmdline); + + void print_usage(std::ostream &out) const; + void print_options(std::ostream &out, int indent = 0) const; + + static void print_version(std::ostream &out); + }; + +public: + + // + // Data members + // + + /// Top-level execution mode + ExecutionMode execution_mode; + + /// Name of math function to profile + library::OperationKind operation_kind; + + /// Vector of operation name substrings + std::vector operation_names; + + /// Map of problems to run for each operation + /// [operation_name] -> vector of problems, each problem specified as a vector of [argument name] -> [argument value] + std::unordered_map> operation_problems; + + /// Vector of operation name substrings + std::vector excluded_operation_names; + + + // + // Detailed configuration options + // + + /// Configuration + CommandLine cmdline; + Device device; + Initialization initialization; + Library library; + Verification verification; + Profiling profiling; + Report report; + About about; + +public: + + explicit Options(CommandLine const &cmdline); + + void print_usage(std::ostream &out) const; + void print_options(std::ostream &out) const; + + static std::string indent_str(int indent); +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace profiler +} // namespace cutlass diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/performance_report.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/performance_report.h new file mode 100644 index 0000000000000000000000000000000000000000..07102c99bc0f38a071e1ab828aab30678a3e2d44 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/performance_report.h @@ -0,0 +1,128 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Class performing output during profiling +*/ + +#pragma once + +#include +#include + +// CUTLASS Profiler includes +#include "options.h" +#include "enumerated_types.h" +#include "performance_result.h" + +// CUTLASS Library includes +#include "cutlass/library/library.h" + +namespace cutlass { +namespace profiler { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +class PerformanceReport { +private: + + /// Reference to options + Options const &options_; + + /// Operation kind + library::OperationKind op_kind_; + + /// Operation file name containing performance report of op_kind + std::string op_file_name_; + + /// Output file containing results + std::ofstream output_file_; + + /// Operation file name containing junit performance report of op_kind + std::string op_junit_file_name_; + + /// Output file containing junit results + std::ofstream junit_output_file_; + + /// Flag indicating the performance report is valid + bool good_; + + /// Vector of argument names + std::vector argument_names_; + + /// Counter uniquely identifying problem within the report + size_t problem_index_; + + /// Collection of all results + PerformanceResultVector concatenated_results_; + +public: + + PerformanceReport(Options const &options, std::vector const &argument_names, library::OperationKind const &op_kind); + ~PerformanceReport(); + + bool good() const { return good_; } + + void next_problem(); + void append_result(PerformanceResult result); + void sort_flops_per_byte(PerformanceResultVector &results); + void sort_flops_per_sec(PerformanceResultVector &results); + void append_results(PerformanceResultVector const &results); + +public: + + /// Prints the CSV header + std::ostream & print_csv_header_(std::ostream &out); + + /// Prints the CSV + std::ostream & print_result_csv_(std::ostream &out, PerformanceResult const &result); + + /// @defgroup jUnit Result Generation + /// Functions related to generation of the jUnit results + /// @{ + + std::ostream & print_junit_header_(std::ostream &out); + std::ostream & print_junit_result_(std::ostream &out, PerformanceResult const &result); + std::ostream & print_junit_footer_(std::ostream &out); + + /// @} + + /// Prints the result in human readable form + std::ostream & print_result_pretty_( + std::ostream &out, + PerformanceResult const &result, + bool use_shell_coloring = true); +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace profiler +} // namespace cutlass + diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/performance_result.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/performance_result.h new file mode 100644 index 
0000000000000000000000000000000000000000..986ac89bc86a267ce8fb181a986f28f3f0936566 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/performance_result.h @@ -0,0 +1,137 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +/* \file + \brief Defines the performance result object produced by the profiler +*/ + +#pragma once + +#include + +#include "cutlass/cutlass.h" + +// CUTLASS Profiler includes +#include "enumerated_types.h" + +// CUTLASS Library includes +#include "cutlass/library/library.h" + +namespace cutlass { +namespace profiler { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Performance result object +struct PerformanceResult { + + /// Index of problem + size_t problem_index; + + /// library::Provider + library::Provider provider; + + /// Operation kind + library::OperationKind op_kind; + + /// CUTLASS status result from kernels (success or failure) + // Status carries no information on verification + Status status; + + /// Outcome of verification (worst case verification result) + Disposition disposition; + + /// Outcome of verification (all verification results) + DispositionMap verification_map; + + /// Operation name + std::string operation_name; + + /// Stringified vector of argument values + std::vector<std::pair<std::string, std::string> > arguments; + + /// Number of bytes read or written + int64_t bytes; + + /// Number of DL flops performed by the math function + int64_t flops; + + /// Average runtime in ms + double runtime; + + /// Average runtime in ms per device + std::vector<double> runtime_vector; + + // + // Methods + // + + /// Ctor + PerformanceResult(): + problem_index(0), + provider(library::Provider::kInvalid), + op_kind(library::OperationKind::kInvalid), + status(Status::kInvalid), + disposition(Disposition::kNotRun), + bytes(0), + flops(0), + runtime(0) + { } + + // Copy constructor for deep copy + PerformanceResult(const PerformanceResult& other) = default; + + // Explicitly define copy assignment operator + PerformanceResult& operator=(const PerformanceResult& other) = default; + + /// Returns true if the runtime is valid + bool good() const { + return runtime > 0; + } + + /// Math throughput in units of GFLOP/s (runtime is in ms, hence the 1.0e6 divisor) + double gflops_per_sec() const { + return double(flops) / runtime / 1.0e6; + } + + /// Memory bandwidth in units of GiB/s (bytes / 2^30, scaled from ms to seconds) + double gbytes_per_sec() const { + return double(bytes) / double(1 << 30) / runtime * 1000.0; + } + +}; + +using PerformanceResultVector = std::vector<PerformanceResult>; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace profiler +} // namespace cutlass + + diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/problem_space.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/problem_space.h new file mode 100644 index 0000000000000000000000000000000000000000..9bdbec657c10cff0dafebd2cb6cd52057f3695c9 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/problem_space.h @@ -0,0 +1,1039 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2.
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief + + "Any sufficiently complicated C or Fortran program contains an ad-hoc, informally-specified, + bug-ridden, slow implementation of half of Common Lisp." + + - Greenspun's Tenth Rule of Programming + + + cutlass::profiler::ProblemSpace defines a set of data structures which represent the Cartesian + product of sequences defined by integer ranges, lists of scalars, and sets of enumerated types. + + These permit a single invocation of the CUTLASS Profiler to iterate over a large set of problems, + verify and profile various operations when they are compatible with the command line, and + construct data tables of results that are convenient inputs to post processing in Excel or Pandas. + + By executing multiple problems per invocation, startup overheads may be amortized across many + kernel launches. 
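+
+  As an illustration (the command-line syntax shown here is an assumed example, not a
+  specification), an invocation such as
+
+      cutlass_profiler --operation=Gemm --m=1024:4096:1024 --n=4096 --k=1024,2048
+
+  describes the Cartesian product of a range over M, a single N, and a list of K values; the
+  profiler then visits every resulting combination in a single run.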
+*/ + +#pragma once + +// Standard Library includes +#include +#include +#include +#include +#include + +// CUTLASS Utility includes +#include "cutlass/util/command_line.h" + +// CUTLASS Library includes +#include "cutlass/library/library.h" + +// Profiler includes +#include "enumerated_types.h" + +namespace cutlass { +namespace profiler { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines the argument schema +struct ArgumentDescription { + + /// Type of argument + ArgumentTypeID type; + + /// Prioritized array of aliases used in command line parsing + std::vector aliases; + + /// Description of argument + std::string description; + + // + // Methods + // + + /// Default ctor + ArgumentDescription(): + type(ArgumentTypeID::kInvalid) { } + + /// Constructor with aliases + ArgumentDescription( + ArgumentTypeID type_, + std::vector const &aliases_, + std::string const &description_ + ): + type(type_), aliases(aliases_), description(description_) { } +}; + +/// Vector of arguments +using ArgumentDescriptionVector = std::vector; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Base class for kernel arguments +struct KernelArgument { + + // + // Type definitions + // + + /// Value base class + struct Value { + + KernelArgument const *argument; + bool not_null; + + // + // Methods + // + + Value( + KernelArgument const *argument_ = nullptr, + bool not_null_ = true + ): argument(argument_), not_null(not_null_) { } + + virtual ~Value() { } + + virtual std::ostream &print(std::ostream &out) const =0; + }; + + /// Abstract base class to iterate over values within arguments + struct ValueIterator { + + /// Indicates type of kernel argument + KernelArgument const *argument; + + /// If the iterator points to an argument that is null, it needs to be distinguished + /// from end. + bool null_argument; + + // + // Methods + // + + /// Constructs a value iterator - no methods are valid if argument_ == nullptr + ValueIterator( + KernelArgument const *argument_ = nullptr, + bool null_argument_ = false): + argument(argument_), null_argument(null_argument_) { + + if (!argument_->not_null()) { + null_argument = true; + } + } + + virtual ~ValueIterator() { } + + /// Advances to next point in range + virtual void operator++() = 0; + + /// Compares against another value iterator - must be of the same KernelArgument type + virtual bool operator==(ValueIterator const &it) const = 0; + + /// Returns a unique_ptr object pointing to a newly created value object + virtual std::unique_ptr at() const = 0; + + /// Gets the type of the iterator + ArgumentTypeID type() const { + return argument->description->type; + } + + /// Helper to compute inequality + bool operator!=(ValueIterator const &it) const { + return !(*this == it); + } + + std::ostream &print(std::ostream &out) const; + }; + + // + // Data members + // + + /// Describes the argument + ArgumentDescription const *description; + + /// Parent node + KernelArgument *parent; + + /// Sequence in which the kernel argument is to be iterated over. + /// Smaller means faster changing. 
-1 is don't care + int ordinal; + + // + // Methods + // + + /// Default ctor + KernelArgument( + ArgumentDescription const *description_ = nullptr, + KernelArgument *parent_ = nullptr, + int ordinal_ = -1 + ): description(description_), parent(parent_), ordinal(ordinal_) { } + + virtual ~KernelArgument(); + + /// Returns true if the kernel argument iself is empty + virtual bool not_null() const =0; + + /// Returns a string name for debugging + std::string qualified_name() const { + if (description) { + if (description->aliases.empty()) { + return ""; + } + return description->aliases.front(); + } + return ""; + } + + virtual std::unique_ptr begin() const =0; + virtual std::unique_ptr end() const =0; +}; + +using KernelArgumentVector = std::vector>; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a scalar argument type as a string that is lexically cast to the appropriate kernel +/// type. +struct ScalarArgument : public KernelArgument { + + // + // Type definitions + // + + /// Value type + struct ScalarValue : public KernelArgument::Value { + + std::string value; + + // + // Methods + // + + ScalarValue( + std::string const &value_ = "", + ScalarArgument const *argument = nullptr, + bool not_null_ = true + ); + + virtual std::ostream &print(std::ostream &out) const; + }; + + using ValueCollection = std::vector; + + /// Abstract base class to iterate over values within arguments + struct ScalarValueIterator : public KernelArgument::ValueIterator { + + // + // Data members + // + + ValueCollection::const_iterator value_it; + + // + // Methods + // + + explicit ScalarValueIterator(ScalarArgument const *argument = nullptr); + + virtual void operator++(); + virtual bool operator==(ValueIterator const &it) const; + + /// Gets the value pointed to + virtual std::unique_ptr at() const; + }; + + // + // Data members + // + + /// Set of possible values + ValueCollection values; + + // + // Methods + // + + /// Default ctor + explicit ScalarArgument( + ArgumentDescription const *description + ): + KernelArgument(description) { } + + virtual bool not_null() const { + return !values.empty(); + } + + virtual std::unique_ptr begin() const; + virtual std::unique_ptr end() const; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Closed range supporting additive increment +struct Range { + + // + // Type definitions + // + + enum class Mode { + kSequence, + kRandom, + kRandomLog2, + kInvalid + }; + + struct Iterator { + + int64_t value; + int64_t increment; + Range const *range; + + // + // Methods + // + + Iterator( + int64_t value_ = 0, + int64_t increment_ = 1, + Range const *range_ = nullptr + ): + value(value_), increment(increment_), range(range_) { } + + Iterator & operator++() { + value += increment; + return *this; + } + + Iterator operator++(int) { + Iterator self(*this); + ++(*this); + return self; + } + + bool operator==(Iterator const &it) const { + return value == it.value; + } + + bool operator!=(Iterator const &it) const { + return !(*this == it); + } + + static int64_t round(int64_t value, int64_t divisible) { + int64_t rem = (value % divisible); + + // Round either up or down + if (rem > divisible / 2) { + value += (divisible - rem); + } + else { + value -= rem; + } + + return value; + } + + int64_t at() const { + if (!range) { + return value; + } + + switch (range->mode) { + case Mode::kSequence: return value; + + case Mode::kRandom: { + double rnd = 
double(range->minimum) + + double(std::rand()) / double(RAND_MAX) * (double(range->maximum) - double(range->minimum)); + + int64_t value = int64_t(rnd); + + return round(value, range->divisible); + } + break; + + case Mode::kRandomLog2: { + double lg2_minimum = std::log(double(range->minimum)) / std::log(2.0); + double lg2_maximum = std::log(double(range->maximum)) / std::log(2.0); + double rnd = lg2_minimum + double(std::rand()) / double(RAND_MAX) * (lg2_maximum - lg2_minimum); + + int64_t value = int64_t(std::pow(2.0, rnd)); + + return round(value, range->divisible); + } + break; + default: break; + } + return value; + } + + int64_t operator*() const { + return at(); + } + }; + + // + // Data members + // + + int64_t first; ///< first element in range + int64_t last; ///< last element in range + int64_t increment; ///< additive increment between values + + Mode mode; ///< mode selection enables alternative values + int64_t minimum; ///< minimum value to return + int64_t maximum; ///< maximum value to return + int64_t divisible; ///< rounds value down to an integer multiple of this value + + // + // Methods + // + + /// Default constructor - range acts as a scalar + Range(int64_t first_ = 0): first(first_), last(first_), increment(1), mode(Mode::kSequence), minimum(0), maximum(0), divisible(1) { } + + /// Range acts as a range + Range( + int64_t first_, + int64_t last_, + int64_t increment_ = 1, + Mode mode_ = Mode::kSequence, + int64_t minimum_ = 0, + int64_t maximum_ = 0, + int64_t divisible_ = 1 + ): first(first_), last(last_), increment(increment_), mode(mode_), minimum(minimum_), maximum(maximum_), divisible(divisible_) { + + // Helpers to avoid constructing invalid ranges + if (increment > 0) { + if (last < first) { + std::swap(last, first); + } + } + else if (increment < 0) { + if (first < last) { + std::swap(last, first); + } + } + else if (last != first) { + last = first; + increment = 1; + } + } + + /// Helper to construct a sequence range + static Range Sequence(int64_t first_, int64_t last_, int64_t increment_ = 1) { + return Range(first_, last_, increment_, Mode::kSequence); + } + + /// Helper to construct a range that is a random distribution + static Range Random(int64_t minimum_, int64_t maximum_, int64_t count_, int64_t divisible_ = 1) { + return Range(1, count_, 1, Mode::kRandom, minimum_, maximum_, divisible_); + } + + /// Helper to construct a range that is a random distribution over a log scale + static Range RandomLog2(int64_t minimum_, int64_t maximum_, int64_t count_, int64_t divisible_ = 1) { + return Range(1, count_, 1, Mode::kRandomLog2, minimum_, maximum_, divisible_); + } + + /// Returns an iterator to the first element within the range + Iterator begin() const { + return Iterator(first, increment, this); + } + + /// Returns an iterator to the first element *after* the range + Iterator end() const { + return Iterator(first + ((last - first)/increment + 1) * increment, increment, this); + } +}; + +/// Integer-valued argument - represented as a list of integer-valued ranges +struct IntegerArgument : public KernelArgument { + + // + // Type definitions + // + + /// Value type + struct IntegerValue : public KernelArgument::Value { + + int64_t value; + + // + // Methods + // + + IntegerValue( + int64_t value_ = 0, + IntegerArgument const *argument_ = nullptr, + bool not_null_ = true + ); + + /// Pretty printer for debugging + virtual std::ostream &print(std::ostream &out) const; + }; + + /// Collection of ranges represent the IntegerArgument's state + using 
RangeCollection = std::vector; + + /// Abstract base class to iterate over values within arguments + struct IntegerValueIterator : public KernelArgument::ValueIterator { + + // + // Data members + // + + RangeCollection::const_iterator range_it; + Range::Iterator value_it; + + // + // Methods + // + + IntegerValueIterator(); + IntegerValueIterator(IntegerArgument const *argument); + + virtual void operator++(); + virtual bool operator==(ValueIterator const &it) const; + + /// Gets the value pointed to + virtual std::unique_ptr at() const; + }; + + // + // Data members + // + + /// Set of possible values + RangeCollection ranges; + + // + // Methods + // + + /// Default ctor + IntegerArgument( + ArgumentDescription const *description + ): + KernelArgument(description) { } + + virtual bool not_null() const { + bool _not_null = !ranges.empty(); + return _not_null; + } + + virtual std::unique_ptr begin() const; + virtual std::unique_ptr end() const; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure defining the data type of tensors +struct TensorArgument : public KernelArgument { + + // + // Type definitions + // + + struct TensorDescription { + + /// Data type of elements + library::NumericTypeID element; + + /// Layout definition + library::LayoutTypeID layout; + + /// Computed extent + std::vector extent; + + /// Enables directly specifying stride value used to size tensor + std::vector stride; + + // + // Methods + // + + TensorDescription( + library::NumericTypeID element_ = library::NumericTypeID::kUnknown, + library::LayoutTypeID layout_ = library::LayoutTypeID::kUnknown, + std::vector extent_ = std::vector(), + std::vector stride_ = std::vector() + ): + element(element_), layout(layout_), extent(extent_), stride(stride_) {} + }; + + using ValueCollection = std::vector; + + /// Value structure + struct TensorValue : public KernelArgument::Value { + + TensorDescription desc; + + // + // Methods + // + + TensorValue( + TensorDescription const &desc_ = TensorDescription(), + TensorArgument const *argument_ = nullptr, + bool not_null_ = true + ); + + /// Pretty printer for debugging + virtual std::ostream &print(std::ostream &out) const; + }; + + /// Abstract base class to iterate over values within arguments + struct TensorValueIterator : public KernelArgument::ValueIterator { + + // + // Data members + // + + ValueCollection::const_iterator value_it; + + // + // Methods + // + + explicit TensorValueIterator(TensorArgument const *argument_); + + virtual void operator++(); + virtual bool operator==(ValueIterator const &it) const; + + /// Gets the value pointed to + virtual std::unique_ptr at() const; + }; + + /// Set of possible values + ValueCollection values; + + // + // Methods + // + + /// Default ctor + explicit TensorArgument( + ArgumentDescription const *description + ): + KernelArgument(description) { } + + virtual bool not_null() const { + return !values.empty(); + } + + virtual std::unique_ptr begin() const; + virtual std::unique_ptr end() const; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Numeric data type +struct EnumeratedTypeArgument : public KernelArgument { + + // + // Type definitions + // + + struct EnumeratedTypeValue : public KernelArgument::Value { + + /// Data type of element + std::string element; + + // + // Methods + // + + EnumeratedTypeValue( + std::string const &element_ = std::string(), + EnumeratedTypeArgument const 
*argument_ = nullptr, + bool not_null_ = true + ); + + /// Pretty printer for debugging + virtual std::ostream &print(std::ostream &out) const; + }; + + using ValueCollection = std::vector; + + /// Abstract base class to iterate over values within arguments + struct EnumeratedTypeValueIterator : public KernelArgument::ValueIterator { + + // + // Data members + // + + ValueCollection::const_iterator value_it; + + // + // Methods + // + + explicit EnumeratedTypeValueIterator(EnumeratedTypeArgument const *argument_ = nullptr); + + virtual void operator++(); + virtual bool operator==(ValueIterator const &it) const; + + /// Gets the value pointed to + virtual std::unique_ptr at() const; + }; + + // + // Data members + // + + ValueCollection values; + + // + // Members + // + + /// Default ctor + explicit EnumeratedTypeArgument(ArgumentDescription const *description): + KernelArgument(description) {} + + virtual bool not_null() const { + return !values.empty(); + } + + virtual std::unique_ptr begin() const; + virtual std::unique_ptr end() const; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Object storing the space argument values +class ProblemSpace { +public: + + /// Tuple of arguments + using Problem = std::vector>; + + /// Type used to iterator over things + using IteratorVector = std::vector>; + + /// Iterates over points in the design space + class Iterator { + private: + + /// One iterator per argument + IteratorVector iterators; + + public: + + // + // Methods + // + + explicit Iterator(); + Iterator(ProblemSpace const &problem_space); + Iterator(Iterator &&it); + + // Rule of three + Iterator(Iterator const &) = delete; + Iterator &operator=(Iterator const &it) = delete; + ~Iterator() = default; + + /// Pre-increment - advances to next point in argument range + void operator++(); + + /// Gets the current argument value + Problem at() const; + + /// Moves iterator to end + void move_to_end(); + + /// Equality operator + bool operator==(Iterator const &it) const; + + /// Inequality operator + bool operator!=(Iterator const &it) const { + return !(*this == it); + } + + /// Helper to call at() method + Problem operator*() const { + return at(); + } + + /// Helper to print iterator state + std::ostream & print(std::ostream &out) const; + + private: + + /// Helper for recursively constructing iterators + void construct_(KernelArgument const *argument); + }; + +public: + + // + // Data members + // + + KernelArgumentVector arguments; + + /// Map of argument names to their position within the argument vector + std::unordered_map argument_index_map; + +public: + + // + // Methods + // + + /// Default ctor + ProblemSpace() = default; + + /// Constructs a problem space from a vector of arguments. This vector must outlive + /// the ProblemSpace object, which stores pointers to objects within the + /// ArgumentDescriptionVector. 
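+ ///
+ /// A minimal usage sketch (illustrative; assumes a schema and a parsed command line exist):
+ ///
+ ///   static ArgumentDescriptionVector const schema = { /* argument descriptions */ };  // outlives the space
+ ///   ProblemSpace space(schema, cmdline);
+ ///   for (auto it = space.begin(); it != space.end(); ++it) {
+ ///     ProblemSpace::Problem problem = *it;   // one point in the Cartesian product
+ ///   }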
+ ProblemSpace(ArgumentDescriptionVector const &schema, CommandLine const &cmdline); + + Iterator begin() const; // returns an iterator to the first point in the range + Iterator end() const; // returns an iterator to the first point after the range + + /// Returns the index of an argument by name + size_t argument_index(char const *name) const; + + /// Gets all argument names as an ordered vector + std::vector argument_names() const; + + /// Returns the number of dimensions of the problem space + size_t rank() const { return arguments.size(); } + +private: + + /// Helper for recursively cloning + void clone_( + KernelArgumentVector &kernel_args, + ArgumentDescription const *arg_desc); + + /// Parses command line argument + void parse_( + KernelArgument *arg, + CommandLine const &cmdline); +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Lexically casts an argument to an int if it is defined. Returns true if not null. +bool arg_as_int(int &int_value, KernelArgument::Value const *value_ptr); + +/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. +bool arg_as_int(int64_t &int_value, KernelArgument::Value const *value_ptr); + +/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. +bool arg_as_int( + int &int_value, + char const *name, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. +bool arg_as_int( + int64_t &int_value, + char const *name, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +bool arg_as_bool(bool &bool_value, KernelArgument::Value const *value_ptr); + +bool arg_as_bool(bool &bool_value, + char const *name, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. +bool arg_as_NumericTypeID(library::NumericTypeID &numeric_type, KernelArgument::Value const *value_ptr); + +/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. +bool arg_as_NumericTypeID( + library::NumericTypeID &numeric_type, + char const *name, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. +bool arg_as_LayoutTypeID(library::LayoutTypeID &layout_type, KernelArgument::Value const *value_ptr); + +/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. +bool arg_as_LayoutTypeID( + library::LayoutTypeID &layout_type, + char const *name, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + +/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. +bool arg_as_OpcodeClassID(library::OpcodeClassID &opcode_class, KernelArgument::Value const *value_ptr); + +/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. +bool arg_as_OpcodeClassID( + library::OpcodeClassID &opcode_class, + char const *name, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + +/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. +bool arg_as_SplitKModeID(library::SplitKMode &split_k_mode, KernelArgument::Value const *value_ptr); + +/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. 
+bool arg_as_SplitKModeID( + library::SplitKMode &split_k_mode, + char const *name, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. +bool arg_as_ConvModeID(library::ConvModeID &conv_mode, KernelArgument::Value const *value_ptr); + +/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. +bool arg_as_ConvModeID( + library::ConvModeID &conv_mode, + char const *name, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. +bool arg_as_IteratorAlgorithmID(library::IteratorAlgorithmID &iterator_algorithm, KernelArgument::Value const *value_ptr); + +/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. +bool arg_as_IteratorAlgorithmID( + library::IteratorAlgorithmID &iterator_algorithm, + char const *name, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + +/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. +bool arg_as_RuntimeDatatype(library::RuntimeDatatype &runtime_datatype, KernelArgument::Value const *value_ptr); + +/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. +bool arg_as_RuntimeDatatype( + library::RuntimeDatatype &runtime_datatype, + char const *name, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + +/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. +bool arg_as_RasterOrder(library::RasterOrder &raster_order, KernelArgument::Value const *value_ptr); + +/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. +bool arg_as_RasterOrder( + library::RasterOrder &raster_order, + char const *name, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. +bool arg_as_ProviderID(library::Provider &provider, KernelArgument::Value const *value_ptr); + +/// Lexically casts an argument to an int64 if it is defined. Returns true if not null. +bool arg_as_ProviderID( + library::Provider &provider, + char const *name, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +/// Lexically casts an argument to a given type stored in a byte array. Returns true if not null. +bool arg_as_scalar( + std::vector &bytes, + library::NumericTypeID numeric_type, + KernelArgument::Value const *value_ptr); + +/// Lexically casts an argument to a given type stored in a byte array. Returns true if not null. 
+bool arg_as_scalar( + std::vector &bytes, + library::NumericTypeID numeric_type, + char const *name, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +bool arg_as_string( + std::string& arg, + char const* name, + ProblemSpace const& problem_space, + ProblemSpace::Problem const& problem); + +/// Returns true if a tensor description satisfies a `tensor` value +bool tensor_description_satisfies( + library::TensorDescription const &tensor_desc, + TensorArgument::TensorValue const *value_ptr); + +/// Returns true if a tensor description satisfies a `tensor` value +bool tensor_description_satisfies( + library::TensorDescription const &tensor_desc, + char const *name, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + +/// Returns true if a conv kind satisfies the value +bool conv_kind_satisfies( + library::ConvKind const &conv_kind, + EnumeratedTypeArgument::EnumeratedTypeValue const *value_ptr); + +/// Returns true if a conv kind satisfies the value +bool conv_kind_satisfies( + library::ConvKind const &conv_kind, + char const *name, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +/// Returns true if a iterator algorithm satisfies the value +bool iterator_algorithm_satisfies( + library::IteratorAlgorithmID const &iterator_algorithm, + EnumeratedTypeArgument::EnumeratedTypeValue const *value_ptr); + +/// Returns true if a iterator algorithm satisfies the value +bool iterator_algorithm_satisfies( + library::IteratorAlgorithmID const &iterator_algorithm, + char const *name, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace profiler +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/rank_2k_operation_profiler.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/rank_2k_operation_profiler.h new file mode 100644 index 0000000000000000000000000000000000000000..ba47a6832077984c334a5467257a151735b088b3 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/rank_2k_operation_profiler.h @@ -0,0 +1,229 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines a math function + + +*/ + +#pragma once + +#include +#include +#include +#include +#include + +// CUTLASS Library includes +#include "cutlass/blas3.h" +#include "cutlass/library/library.h" +#include "cutlass/library/util.h" +#include "cutlass/library/manifest.h" + +// Profiler includes +#include "options.h" +#include "device_context.h" +#include "operation_profiler.h" +#include "performance_result.h" +#include "problem_space.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace profiler { + +///////////////////////////////////////////////////////////////////////////////////////////////// + + +/// Abstract base class for each math function +class Rank2KOperationProfiler : public OperationProfiler { +public: + + /// Problem structure obtained from problem space + struct RankKProblem { + int64_t n; + int64_t k; + int64_t lda; + int64_t ldb; + int64_t ldc; + FillMode fill_mode; + BlasMode blas_mode; + std::vector alpha; + std::vector beta; + int64_t split_k_slices; + int64_t batch_count; + + // + // Methods + // + + RankKProblem(): + n(16), k(16), lda(0), ldc(0), + fill_mode(FillMode::kInvalid), blas_mode(BlasMode::kInvalid), + split_k_slices(1), batch_count(1) { } + + /// Parses the problem + Status parse( + library::RankKDescription const &operation_desc, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Total number of bytes loaded + int64_t bytes(library::RankKDescription const &operation_desc) const; + + /// Total number of flops computed + int64_t flops(library::RankKDescription const &operation_desc) const; + + /// Initializes a performance result + void initialize_result( + PerformanceResult &result, + library::RankKDescription const &operation_desc, + ProblemSpace const &problem_space); + }; + + /// Workspace used + struct RankKWorkspace { + + DeviceAllocation *A; + DeviceAllocation *B; + DeviceAllocation *C; + DeviceAllocation *Computed; + DeviceAllocation *Reference; + + library::RankKConfiguration configuration; + library::RankKArguments arguments; + + /// Buffer used for the operation's host workspace + std::vector host_workspace; + + /// Buffer used for the operations' device workspace + DeviceAllocation device_workspace; + + // + // Methods + // + + RankKWorkspace(): + A(nullptr), B(nullptr), C(nullptr), Computed(nullptr), Reference(nullptr) { } + }; + +protected: + + // + // Data members + // + + /// GEMM problem obtained from problem space + RankKProblem problem_; + + /// Device memory allocations + RankKWorkspace rank_k_workspace_; + + 
+public: + // + // Methods + // + + /// Ctor + Rank2KOperationProfiler(Options const &options); + + /// Destructor + virtual ~Rank2KOperationProfiler(); + + /// Prints usage statement for the math function + virtual void print_usage(std::ostream &out) const; + + /// Prints examples + virtual void print_examples(std::ostream &out) const; + + /// Extracts the problem dimensions + virtual Status initialize_configuration( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Initializes workspace + virtual Status initialize_workspace( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Verifies CUTLASS against references + virtual bool verify_cutlass( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Measures performance results + virtual bool profile( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +protected: + + /// Initializes the performance result + void initialize_result_( + PerformanceResult &result, + Options const &options, + library::RankKDescription const &operation_desc, + ProblemSpace const &problem_space); + + /// Verifies CUTLASS against references + bool verify_with_cublas_( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace profiler +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/rank_k_operation_profiler.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/rank_k_operation_profiler.h new file mode 100644 index 0000000000000000000000000000000000000000..fff190a7570cd5811c6e5de6284bf96e40c404b7 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/rank_k_operation_profiler.h @@ -0,0 +1,227 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines a math function + + +*/ + +#pragma once + +#include +#include +#include +#include +#include + +// CUTLASS Library includes +#include "cutlass/blas3.h" +#include "cutlass/library/library.h" +#include "cutlass/library/util.h" +#include "cutlass/library/manifest.h" + +// Profiler includes +#include "options.h" +#include "device_context.h" +#include "operation_profiler.h" +#include "performance_result.h" +#include "problem_space.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace profiler { + +///////////////////////////////////////////////////////////////////////////////////////////////// + + +/// Abstract base class for each math function +class RankKOperationProfiler : public OperationProfiler { +public: + + /// Problem structure obtained from problem space + struct RankKProblem { + int64_t n; + int64_t k; + int64_t lda; + int64_t ldc; + FillMode fill_mode; + BlasMode blas_mode; + std::vector alpha; + std::vector beta; + int64_t split_k_slices; + int64_t batch_count; + + // + // Methods + // + + RankKProblem(): + n(16), k(16), lda(0), ldc(0), + fill_mode(FillMode::kInvalid), blas_mode(BlasMode::kInvalid), + split_k_slices(1), batch_count(1) { } + + /// Parses the problem + Status parse( + library::RankKDescription const &operation_desc, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Total number of bytes loaded + int64_t bytes(library::RankKDescription const &operation_desc) const; + + /// Total number of flops computed + int64_t flops(library::RankKDescription const &operation_desc) const; + + /// Initializes a performance result + void initialize_result( + PerformanceResult &result, + library::RankKDescription const &operation_desc, + ProblemSpace const &problem_space); + }; + + /// Workspace used + struct RankKWorkspace { + + DeviceAllocation *A; + DeviceAllocation *C; + DeviceAllocation *Computed; + DeviceAllocation *Reference; + + library::RankKConfiguration configuration; + library::RankKArguments arguments; + + /// Buffer used for the operation's host workspace + std::vector host_workspace; + + /// Buffer used for the operations' device workspace + DeviceAllocation device_workspace; + + // + // Methods + // + + RankKWorkspace(): + A(nullptr), C(nullptr), Computed(nullptr), Reference(nullptr) { } + }; + +protected: + + // + // Data members + // + + /// 
GEMM problem obtained from problem space + RankKProblem problem_; + + /// Device memory allocations + RankKWorkspace rank_k_workspace_; + + +public: + // + // Methods + // + + /// Ctor + RankKOperationProfiler(Options const &options); + + /// Destructor + virtual ~RankKOperationProfiler(); + + /// Prints usage statement for the math function + virtual void print_usage(std::ostream &out) const; + + /// Prints examples + virtual void print_examples(std::ostream &out) const; + + /// Extracts the problem dimensions + virtual Status initialize_configuration( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Initializes workspace + virtual Status initialize_workspace( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Verifies CUTLASS against references + virtual bool verify_cutlass( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Measures performance results + virtual bool profile( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +protected: + + /// Initializes the performance result + void initialize_result_( + PerformanceResult &result, + Options const &options, + library::RankKDescription const &operation_desc, + ProblemSpace const &problem_space); + + /// Verifies CUTLASS against references + bool verify_with_cublas_( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace profiler +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/reduction_operation_profiler.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/reduction_operation_profiler.h new file mode 100644 index 0000000000000000000000000000000000000000..0c81ef4637175a6de1f44cedddf319436aaff24d --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/reduction_operation_profiler.h @@ -0,0 +1,173 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines profiling functionality for reduction operation + +*/ + +#pragma once + +#include +#include +#include +#include +#include + +// CUTLASS Library includes +#include "cutlass/library/library.h" +#include "cutlass/library/util.h" +#include "cutlass/library/manifest.h" + +// Profiler includes +#include "options.h" +#include "device_context.h" +#include "operation_profiler.h" +#include "performance_result.h" +#include "problem_space.h" +#if CUTLASS_ENABLE_CUDNN +#include "cudnn_helpers.h" +#endif //#if CUTLASS_ENABLE_CUDNN +#include "debug.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace profiler { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Abstract base class for each math function +class ReductionOperationProfiler : public OperationProfiler { +public: + + + /// Workspace used + struct ReductionWorkspace { + + /// Conv device allocations + DeviceAllocation *Workspace; + DeviceAllocation *Source; + DeviceAllocation *Destination; + DeviceAllocation *Reference; + + /// Library configuration and arguments + library::ReductionConfiguration configuration; + library::ReductionArguments arguments; + + /// Buffer used for the cutlass operations' host workspace + std::vector host_workspace; + + /// Buffer used for the cutlass operations' device workspace + DeviceAllocation device_workspace; + + // + // Methods + // + + ReductionWorkspace(): + Workspace(nullptr), Source(nullptr), Destination(nullptr), Reference(nullptr) { } + }; + +protected: + + // + // Data members + // + + /// Reduction problem obtained from problem space + MatrixCoord problem_; + + /// Device memory allocations + ReductionWorkspace conv_workspace_; + + +public: + // + // Methods + // + + /// Ctor + ReductionOperationProfiler(Options const &options); + + /// Destructor + virtual ~ReductionOperationProfiler(); + + /// Prints usage statement for the math function + virtual void print_usage(std::ostream &out) const; + + /// Prints examples + virtual void print_examples(std::ostream &out) const; + + /// Extracts the problem dimensions + virtual Status 
initialize_configuration( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Initializes workspace + virtual Status initialize_workspace( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Verifies CUTLASS against references + virtual bool verify_cutlass( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Measures performance results + virtual bool profile( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace profiler +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/sparse_gemm_operation_profiler.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/sparse_gemm_operation_profiler.h new file mode 100644 index 0000000000000000000000000000000000000000..60204d8c9d458ab12020a6492de23174739aa584 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/sparse_gemm_operation_profiler.h @@ -0,0 +1,214 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief + +*/ + +#pragma once + +#include +#include +#include +#include +#include + +// CUTLASS Library includes +#include "cutlass/library/library.h" +#include "cutlass/library/util.h" +#include "cutlass/library/manifest.h" + +// Profiler includes +#include "options.h" +#include "device_context.h" +#include "operation_profiler.h" +#include "performance_result.h" +#include "problem_space.h" +#include "gemm_operation_profiler.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace profiler { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Abstract base class for each math function +class SparseGemmOperationProfiler : public OperationProfiler { +public: + + /// Problem structure obtained from problem space + struct SparseGemmProblem { + int64_t m; + int64_t n; + int64_t k; + int64_t lda; + int64_t ldb; + int64_t ldc; + int64_t lde; + std::vector alpha; + std::vector beta; + int64_t split_k_slices; + int64_t batch_count; + static int const sparse = 2; + // every 128b ElementA uses one elementE + int elements_per_128b; + + // + // Methods + // + + SparseGemmProblem(): + m(16), n(16), k(16), lda(0), ldb(0), ldc(0), lde(0), split_k_slices(1), batch_count(1) { } + + /// Parses the problem + Status parse( + library::SparseGemmDescription const &operation_desc, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Initializes a performance result + void initialize_result( + PerformanceResult &result, + library::SparseGemmDescription const &operation_desc, + ProblemSpace const &problem_space); + }; + + /// Workspace used + struct SparseGemmWorkspace { + + DeviceAllocation *A; + DeviceAllocation *B; + DeviceAllocation *C; + DeviceAllocation *E; + DeviceAllocation *Computed; + DeviceAllocation *Reference; + + library::SparseGemmConfiguration configuration; + library::SparseGemmArguments arguments; + + /// Buffer used for the operation's host workspace + std::vector host_workspace; + + /// Buffer used for the operations' device workspace + DeviceAllocation device_workspace; + + // + // Methods + // + + SparseGemmWorkspace(): + A(nullptr), B(nullptr), C(nullptr), E(nullptr), Computed(nullptr), Reference(nullptr) { } + }; + +protected: + + // + // Data members + // + + // GEMM problem + SparseGemmProblem problem_; + + /// Device memory allocations + SparseGemmWorkspace gemm_workspace_; + + +public: + // + // Methods + // + + /// Ctor + SparseGemmOperationProfiler(Options const &options); + + /// Destructor + virtual ~SparseGemmOperationProfiler(); + + /// Prints usage statement for the math function + virtual void print_usage(std::ostream &out) const; + + /// Prints examples + virtual void print_examples(std::ostream &out) const; + + /// Extracts the problem dimensions + virtual Status 
initialize_configuration( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Initializes workspace + virtual Status initialize_workspace( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Verifies CUTLASS against references + virtual bool verify_cutlass( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Measures performance results + virtual bool profile( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +protected: + + /// Initializes the performance result + void initialize_result_( + PerformanceResult &result, + Options const &options, + library::SparseGemmDescription const &operation_desc, + ProblemSpace const &problem_space); +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace profiler +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/symm_operation_profiler.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/symm_operation_profiler.h new file mode 100644 index 0000000000000000000000000000000000000000..94ded5e803bf914e5ae8c4ebb867cfe42ef829bc --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/symm_operation_profiler.h @@ -0,0 +1,230 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
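Each of the profiler headers introduced in this patch (RankK, Rank2K, Reduction, SparseGemm, and the Symm/Trmm headers that follow) overrides the same four OperationProfiler entry points: initialize_configuration(), initialize_workspace(), verify_cutlass(), and profile(). The sketch below shows, under stated assumptions, how a single problem could be pushed through that sequence; it is illustrative only, is not the actual cutlass_profiler dispatch loop, and the profile_one helper name is invented.

// Minimal sketch (assumes the profiler headers above are included). NOT the real
// cutlass_profiler loop; profile_one is an invented helper and the surrounding
// objects are assumed to have been constructed elsewhere.
inline bool profile_one(
    cutlass::profiler::OperationProfiler &profiler,
    cutlass::profiler::Options const &options,
    cutlass::profiler::PerformanceReport &report,
    cutlass::profiler::DeviceContext &context,
    cutlass::library::Operation const *operation,
    cutlass::profiler::ProblemSpace const &space,
    cutlass::profiler::ProblemSpace::Problem const &problem) {

  // 1. Extract the problem dimensions (m, n, k, leading dimensions, ...).
  if (profiler.initialize_configuration(options, report, context, operation, space, problem) !=
      cutlass::Status::kSuccess) {
    return false;
  }
  // 2. Allocate host and device workspaces for operands and results.
  if (profiler.initialize_workspace(options, report, context, operation, space, problem) !=
      cutlass::Status::kSuccess) {
    return false;
  }
  // 3. Verify the CUTLASS result against references (e.g. cuBLAS) before timing.
  if (!profiler.verify_cutlass(options, report, context, operation, space, problem)) {
    return false;
  }
  // 4. Measure performance.
  return profiler.profile(options, report, context, operation, space, problem);
}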
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines a math function + + +*/ + +#pragma once + +#include +#include +#include +#include +#include + +// CUTLASS Library includes +#include "cutlass/blas3.h" +#include "cutlass/library/library.h" +#include "cutlass/library/util.h" +#include "cutlass/library/manifest.h" + +// Profiler includes +#include "options.h" +#include "device_context.h" +#include "operation_profiler.h" +#include "performance_result.h" +#include "problem_space.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace profiler { + +///////////////////////////////////////////////////////////////////////////////////////////////// + + +/// Abstract base class for each math function +class SymmOperationProfiler : public OperationProfiler { +public: + + /// Problem structure obtained from problem space + struct SymmProblem { + int64_t m; + int64_t n; + int64_t lda; + int64_t ldb; + int64_t ldc; + SideMode side_mode; + FillMode fill_mode; + BlasMode blas_mode; + std::vector alpha; + std::vector beta; + int64_t split_k_slices; + int64_t batch_count; + + // + // Methods + // + + SymmProblem(): + m(16), n(16), lda(0), ldb(0), ldc(0), + side_mode(SideMode::kInvalid), fill_mode(FillMode::kInvalid), blas_mode(BlasMode::kInvalid), + split_k_slices(1), batch_count(1) { } + + /// Parses the problem + Status parse( + library::SymmDescription const &operation_desc, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Total number of bytes loaded + int64_t bytes(library::SymmDescription const &operation_desc) const; + + /// Total number of flops computed + int64_t flops(library::SymmDescription const &operation_desc) const; + + /// Initializes a performance result + void initialize_result( + PerformanceResult &result, + library::SymmDescription const &operation_desc, + ProblemSpace const &problem_space); + }; + + /// Workspace used + struct SymmWorkspace { + + DeviceAllocation *A; + DeviceAllocation *B; + DeviceAllocation *C; + DeviceAllocation *Computed; + DeviceAllocation *Reference; + + library::SymmConfiguration configuration; + library::SymmArguments arguments; + + /// Buffer used for the operation's host workspace + std::vector host_workspace; + + /// Buffer used for the operations' device workspace + DeviceAllocation device_workspace; + + // + // Methods + // + + SymmWorkspace(): + A(nullptr), B(nullptr), C(nullptr), Computed(nullptr), Reference(nullptr) { } + }; + +protected: + + // + // Data members + // + + /// GEMM problem obtained from problem space + SymmProblem problem_; + + /// Device memory allocations + SymmWorkspace symm_workspace_; + + +public: + // + // Methods + // + + /// Ctor + SymmOperationProfiler(Options const &options); + + /// Destructor + virtual ~SymmOperationProfiler(); + + /// Prints usage statement for the math function + virtual void 
print_usage(std::ostream &out) const; + + /// Prints examples + virtual void print_examples(std::ostream &out) const; + + /// Extracts the problem dimensions + virtual Status initialize_configuration( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Initializes workspace + virtual Status initialize_workspace( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Verifies CUTLASS against references + virtual bool verify_cutlass( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Measures performance results + virtual bool profile( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +protected: + + /// Initializes the performance result + void initialize_result_( + PerformanceResult &result, + Options const &options, + library::SymmDescription const &operation_desc, + ProblemSpace const &problem_space); + + /// Verifies CUTLASS against references + bool verify_with_cublas_( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace profiler +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/trmm_operation_profiler.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/trmm_operation_profiler.h new file mode 100644 index 0000000000000000000000000000000000000000..9f21dafa0ecc869840fdba0a9c4414a89bbf4a7d --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/profiler/include/cutlass/profiler/trmm_operation_profiler.h @@ -0,0 +1,222 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines a math function + + +*/ + +#pragma once + +#include +#include +#include +#include +#include + +// CUTLASS Library includes +#include "cutlass/blas3.h" +#include "cutlass/library/library.h" +#include "cutlass/library/util.h" +#include "cutlass/library/manifest.h" + +// Profiler includes +#include "options.h" +#include "device_context.h" +#include "operation_profiler.h" +#include "performance_result.h" +#include "problem_space.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace profiler { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Abstract base class for each math function +class TrmmOperationProfiler : public OperationProfiler { +public: + + /// Problem structure obtained from problem space + struct TrmmProblem { + int64_t m; + int64_t n; + int64_t lda; + int64_t ldb; + int64_t ldd; + SideMode side_mode; + FillMode fill_mode; + DiagType diag_type; + std::vector alpha; + std::vector beta; + int64_t split_k_slices; + int64_t batch_count; + + // + // Methods + // + + TrmmProblem(): + m(16), n(16), lda(0), ldb(0), ldd(0), split_k_slices(1), batch_count(1) { } + + /// Parses the problem + Status parse( + library::TrmmDescription const &operation_desc, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Initializes a performance result + void initialize_result( + PerformanceResult &result, + library::TrmmDescription const &operation_desc, + ProblemSpace const &problem_space); + }; + + /// Workspace used + struct TrmmWorkspace { + + DeviceAllocation *A; + DeviceAllocation *B; + DeviceAllocation *D; + DeviceAllocation *Computed; + DeviceAllocation *Reference; + + library::TrmmConfiguration configuration; + library::TrmmArguments arguments; + + /// Buffer used for the operation's host workspace + std::vector host_workspace; + + /// Buffer used for the operations' device workspace + DeviceAllocation device_workspace; + + // + // Methods + // + + TrmmWorkspace(): + A(nullptr), B(nullptr), D(nullptr), Computed(nullptr), Reference(nullptr) { } + }; + +protected: + + // + // Data members + // + + /// GEMM problem obtained from problem space + TrmmProblem problem_; + + /// Device memory allocations + TrmmWorkspace trmm_workspace_; + + +public: + // + // Methods + // + + /// Ctor + TrmmOperationProfiler(Options const &options); + + /// Destructor + virtual ~TrmmOperationProfiler(); + + /// Prints usage statement for the math function + virtual void print_usage(std::ostream &out) const; + + /// Prints examples 
+ virtual void print_examples(std::ostream &out) const; + + /// Extracts the problem dimensions + virtual Status initialize_configuration( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Initializes workspace + virtual Status initialize_workspace( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Verifies CUTLASS against references + virtual bool verify_cutlass( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + + /// Measures performance results + virtual bool profile( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +protected: + + /// Initializes the performance result + void initialize_result_( + PerformanceResult &result, + Options const &options, + library::TrmmDescription const &operation_desc, + ProblemSpace const &problem_space); + + /// Verifies CUTLASS against references + bool verify_with_cublas_( + Options const &options, + PerformanceReport &report, + DeviceContext &device_context, + library::Operation const *operation, + ProblemSpace const &problem_space, + ProblemSpace::Problem const &problem); + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace profiler +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/GPU_Clock.hpp b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/GPU_Clock.hpp new file mode 100644 index 0000000000000000000000000000000000000000..c2727c989e645eca8e67a5d8d50391ced803cffa --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/GPU_Clock.hpp @@ -0,0 +1,67 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include + +struct GPU_Clock +{ + GPU_Clock() { + cudaEventCreate(&start_); + cudaEventCreate(&stop_); + cudaEventRecord(start_); + } + + ~GPU_Clock() { + cudaEventDestroy(start_); + cudaEventDestroy(stop_); + } + + void start() { + cudaEventRecord(start_); + } + + float milliseconds() { + cudaEventRecord(stop_); + cudaEventSynchronize(stop_); + float time; + cudaEventElapsedTime(&time, start_, stop_); + return time; + } + + float seconds() { + return milliseconds() * float(1e-3); + } + + private: + cudaEvent_t start_, stop_; +}; diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/command_line.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/command_line.h new file mode 100644 index 0000000000000000000000000000000000000000..c95bd1cbeb56cc566394b155ea7ac24f07c28162 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/command_line.h @@ -0,0 +1,324 @@ +/****************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
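GPU_Clock above is a small CUDA-event timer: the constructor creates and records the start event, start() re-records it, and milliseconds()/seconds() record the stop event, synchronize on it, and return the elapsed time. A minimal, hypothetical usage sketch follows; the kernel, launch shape, and include path are placeholders, not part of this patch.

// Hypothetical usage of GPU_Clock; my_kernel and its launch shape are placeholders.
#include <cstdio>
#include <cuda_runtime.h>
#include "cutlass/util/GPU_Clock.hpp"   // path assumed from this header's location

__global__ void my_kernel() { }

int main() {
  GPU_Clock timer;                   // start event recorded on construction
  timer.start();                     // optionally re-record just before the timed region
  my_kernel<<<1, 1>>>();
  float ms = timer.milliseconds();   // records stop, synchronizes, returns elapsed ms
  std::printf("elapsed: %f ms (%f s)\n", ms, ms * 1e-3f);
  return 0;
}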
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#pragma once + +/** + * \file + * Utility for parsing command line arguments + */ + +#include +#include +#include +#include +#include +#include + +#include + +#include "cutlass/cutlass.h" + +namespace cutlass { + +/****************************************************************************** + * command_line + ******************************************************************************/ + +/** + * Utility for parsing command line arguments + */ +struct CommandLine { + std::vector keys; + std::vector values; + std::vector args; + + /** + * Constructor + */ + CommandLine(int argc, const char** argv) { + using namespace std; + + for (int i = 1; i < argc; i++) { + string arg = argv[i]; + + if ((arg[0] != '-') || (arg[1] != '-')) { + args.push_back(arg); + continue; + } + + string::size_type pos; + string key, val; + if ((pos = arg.find('=')) == string::npos) { + key = string(arg, 2, arg.length() - 2); + val = ""; + } else { + key = string(arg, 2, pos - 2); + val = string(arg, pos + 1, arg.length() - 1); + } + + keys.push_back(key); + values.push_back(val); + } + } + + /** + * Constructor to represent a command line from a map of [argument] -> [value] + */ + CommandLine(std::unordered_map& arg_map) { + for (const auto& [key, value] : arg_map) { + keys.push_back(key); + values.push_back(value); + } + } + + /** + * Checks whether a flag "--" is present in the commandline + */ + bool check_cmd_line_flag(const char* arg_name) const { + using namespace std; + + for (int i = 0; i < int(keys.size()); ++i) { + if (keys[i] == string(arg_name)) return true; + } + return false; + } + + /** + * Returns number of naked (non-flag and non-key-value) commandline parameters + */ + size_t num_naked_args() const { + return args.size(); + } + + /** + * Print naked (non-flag and non-key-value) commandline parameters + */ + void print_naked_args(std::ostream &out) const { + for (auto arg : args) { + out << " " << arg <<"\n"; + } + } + + /** + * Returns the commandline parameter for a given index (not including flags) + */ + template + void get_cmd_line_argument(size_t index, value_t& val) const { + using namespace std; + if (index < args.size()) { + istringstream str_stream(args[index]); + str_stream >> val; + } + } + + /** + * Obtains the boolean value specified for a given commandline parameter --= + */ + void get_cmd_line_argument(const char* arg_name, bool& val, bool _default) const { + val = _default; + if (check_cmd_line_flag(arg_name)) { + std::string value; + get_cmd_line_argument(arg_name, value); + + val = !(value == "0" || value == "false"); + } + } + + /** + * Obtains the value specified for a given commandline parameter --= + */ + template + void get_cmd_line_argument(const char* arg_name, + value_t& val) const { + + get_cmd_line_argument(arg_name, val, val); + } + + /** + * Obtains the value specified for a given commandline parameter --= + */ + template + void 
get_cmd_line_argument(const char* arg_name, + value_t& val, + value_t const& _default) const { + using namespace std; + + val = _default; + + for (int i = 0; i < int(keys.size()); ++i) { + if (keys[i] == string(arg_name)) { + istringstream str_stream(values[i]); + str_stream >> val; + } + } + } + + /** + * Returns the values specified for a given commandline parameter --=,* + */ + template + void get_cmd_line_arguments(const char* arg_name, + std::vector& vals, + char sep = ',') const { + using namespace std; + + if (check_cmd_line_flag(arg_name)) { + // Clear any default values + vals.clear(); + + // Recover from multi-value string + for (size_t i = 0; i < keys.size(); ++i) { + if (keys[i] == string(arg_name)) { + string val_string(values[i]); + separate_string(val_string, vals, sep); + } + } + } + } + + /** + * Returns the values specified for a given commandline parameter + * --=,* + */ + void get_cmd_line_argument_pairs(const char* arg_name, + std::vector >& tokens, + char delim = ',', + char sep = ':') const { + if (check_cmd_line_flag(arg_name)) { + std::string value; + get_cmd_line_argument(arg_name, value); + + tokenize(tokens, value, delim, sep); + } + } + + /** + * Returns a list of ranges specified for a given commandline parameter + * --=,* + */ + void get_cmd_line_argument_ranges(const char* arg_name, + std::vector >& vals, + char delim = ',', + char sep = ':') const { + std::vector ranges; + get_cmd_line_arguments(arg_name, ranges, delim); + + for (std::vector::const_iterator range = ranges.begin(); + range != ranges.end(); ++range) { + + std::vector range_vals; + separate_string(*range, range_vals, sep); + vals.push_back(range_vals); + } + } + + /** + * The number of pairs parsed + */ + int parsed_argc() const { return (int)keys.size(); } + + //------------------------------------------------------------------------- + // Utility functions + //------------------------------------------------------------------------- + + /// Tokenizes a comma-delimited list of string pairs delimited by ':' + static void tokenize(std::vector >& tokens, + std::string const& str, + char delim = ',', + char sep = ':') { + // Home-built to avoid Boost dependency + size_t s_idx = 0; + size_t d_idx = std::string::npos; + while (s_idx < str.size()) { + d_idx = str.find_first_of(delim, s_idx); + + size_t end_idx = (d_idx != std::string::npos ? 
d_idx : str.size()); + size_t sep_idx = str.find_first_of(sep, s_idx); + size_t offset = 1; + if (sep_idx == std::string::npos || sep_idx >= end_idx) { + sep_idx = end_idx; + offset = 0; + } + + std::pair item( + str.substr(s_idx, sep_idx - s_idx), + str.substr(sep_idx + offset, end_idx - sep_idx - offset)); + + tokens.push_back(item); + s_idx = end_idx + 1; + } + } + + /// Tokenizes a comma-delimited list of string pairs delimited by ':' + static void tokenize(std::vector& tokens, + std::string const& str, + char delim = ',', + char sep = ':') { + typedef std::vector > TokenVector; + typedef TokenVector::const_iterator token_iterator; + + std::vector > token_pairs; + tokenize(token_pairs, str, delim, sep); + for (token_iterator tok = token_pairs.begin(); tok != token_pairs.end(); ++tok) { + tokens.push_back(tok->first); + } + } + + template + static void separate_string(std::string const& str, + std::vector& vals, + char sep = ',') { + std::istringstream str_stream(str); + std::string::size_type old_pos = 0; + std::string::size_type new_pos = 0; + + // Iterate -delimited values + value_t val; + while ((new_pos = str.find(sep, old_pos)) != std::string::npos) { + if (new_pos != old_pos) { + str_stream.width(new_pos - old_pos); + str_stream >> val; + vals.push_back(val); + } + + // skip over delimiter + str_stream.ignore(1); + old_pos = new_pos + 1; + } + + // Read last value + str_stream >> val; + vals.push_back(val); + } +}; + +} // namespace cutlass diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/cublas_wrappers.hpp b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/cublas_wrappers.hpp new file mode 100644 index 0000000000000000000000000000000000000000..8ace1e0a232ea7cccbb2089ec8432783c49410dd --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/cublas_wrappers.hpp @@ -0,0 +1,528 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
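CommandLine above splits argv into --key=value pairs plus naked arguments and exposes typed getters with defaults, list parsing, and pair/range helpers. A small hypothetical usage sketch follows; the flag names (--m, --alpha, --modes, --verbose) are purely illustrative and are not flags defined anywhere in this patch.

// Hypothetical use of cutlass::CommandLine; flag names are illustrative only.
#include <iostream>
#include <string>
#include <vector>
#include "cutlass/util/command_line.h"  // path assumed from this header's location

int main(int argc, const char **argv) {
  cutlass::CommandLine cmd(argc, argv);

  int m = 0;
  cmd.get_cmd_line_argument("m", m, 1024);       // --m=4096, defaults to 1024 if absent

  float alpha = 1.0f;
  cmd.get_cmd_line_argument("alpha", alpha);     // keeps the current value if flag absent

  std::vector<std::string> modes;
  cmd.get_cmd_line_arguments("modes", modes);    // --modes=a,b,c

  bool verbose = false;
  cmd.get_cmd_line_argument("verbose", verbose, false);

  std::cout << "m=" << m << " alpha=" << alpha
            << " modes=" << modes.size() << " verbose=" << verbose << "\n";
  return 0;
}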
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include +#include + +//-- BLAM_DEBUG_OUT --------------------------------------------------------- +#ifdef BLAM_DEBUG +# include +# ifndef BLAM_DEBUG_OUT +# define BLAM_DEBUG_OUT(msg) std::cerr << "BLAM: " << msg << std::endl +# define BLAM_DEBUG_OUT_2(msg) std::cerr << msg << std::endl +# endif // BLAM_DEBUG_OUT +#else +# ifndef BLAM_DEBUG_OUT +# define BLAM_DEBUG_OUT(msg) +# define BLAM_DEBUG_OUT_2(msg) +# endif // BLAM_DEBUG_OUT +#endif // BLAM_DEBUG + +// User could potentially define ComplexFloat/ComplexDouble instead of std:: +#ifndef BLAM_COMPLEX_TYPES +#define BLAM_COMPLEX_TYPES 1 +#include "cutlass/cutlass.h" +#include CUDA_STD_HEADER(complex) + +namespace blam { +template +using Complex = cuda::std::complex; +using ComplexFloat = cuda::std::complex; +using ComplexDouble = cuda::std::complex; +} +#endif // BLAM_COMPLEX_TYPES + +// User could potentially define Half instead of cute:: +#ifndef BLAM_HALF_TYPE +#define BLAM_HALF_TYPE 1 +#include +namespace blam { +using Half = cute::half_t; +} +#endif // BLAM_HALF_TYPE + +namespace blam +{ +namespace cublas +{ + +inline const char* +cublas_get_error(cublasStatus_t status) +{ + switch (status) { + case CUBLAS_STATUS_SUCCESS: + return "CUBLAS_STATUS_SUCCESS"; + case CUBLAS_STATUS_NOT_INITIALIZED: + return "CUBLAS_STATUS_NOT_INITIALIZED -- The cuBLAS library was not initialized."; + case CUBLAS_STATUS_ALLOC_FAILED: + return "CUBLAS_STATUS_ALLOC_FAILED -- Resource allocation failed inside the cuBLAS library."; + case CUBLAS_STATUS_INVALID_VALUE: + return "CUBLAS_STATUS_INVALID_VALUE -- An unsupported value or parameter was passed to the function."; + case CUBLAS_STATUS_ARCH_MISMATCH: + return "CUBLAS_STATUS_ARCH_MISMATCH -- The function requires a feature absent from the device architecture."; + case CUBLAS_STATUS_MAPPING_ERROR: + return "CUBLAS_STATUS_MAPPING_ERROR -- An access to GPU memory space failed."; + case CUBLAS_STATUS_EXECUTION_FAILED: + return "CUBLAS_STATUS_EXECUTION_FAILED -- The GPU program failed to execute."; + case CUBLAS_STATUS_INTERNAL_ERROR: + return "CUBLAS_STATUS_INTERNAL_ERROR -- An internal cuBLAS operation failed."; + case CUBLAS_STATUS_NOT_SUPPORTED: + return "CUBLAS_STATUS_NOT_SUPPORTED -- The functionality requested is not supported."; + case CUBLAS_STATUS_LICENSE_ERROR: + return "CUBLAS_STATUS_LICENSE_ERROR -- An error was detected when checking the current licensing."; + default: + return "CUBLAS_ERROR -- "; + } +} + +inline bool +cublas_is_error(cublasStatus_t status) +{ + return status != CUBLAS_STATUS_SUCCESS; +} + + +// hgemm +inline cublasStatus_t +gemm(cublasHandle_t handle, + cublasOperation_t transA, cublasOperation_t transB, + int m, int n, int k, + const Half* alpha, + const Half* A, int ldA, + const Half* B, int ldB, + const Half* beta, + Half* C, int ldC) +{ + BLAM_DEBUG_OUT("cublasHgemm"); + + return cublasGemmEx(handle, transA, transB, + m, 
n, k, + reinterpret_cast(alpha), + reinterpret_cast(A), CUDA_R_16F, ldA, + reinterpret_cast(B), CUDA_R_16F, ldB, + reinterpret_cast(beta), + reinterpret_cast< __half*>(C), CUDA_R_16F, ldC, + CUDA_R_16F, CUBLAS_GEMM_DEFAULT_TENSOR_OP); +} + +// mixed hf gemm +inline cublasStatus_t +gemm(cublasHandle_t handle, + cublasOperation_t transA, cublasOperation_t transB, + int m, int n, int k, + const float* alpha, + const Half* A, int ldA, + const Half* B, int ldB, + const float* beta, + float* C, int ldC) +{ + BLAM_DEBUG_OUT("cublasGemmEx mixed half-float"); + + return cublasGemmEx(handle, transA, transB, + m, n, k, + alpha, + reinterpret_cast(A), CUDA_R_16F, ldA, + reinterpret_cast(B), CUDA_R_16F, ldB, + beta, + C, CUDA_R_32F, ldC, + CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP); +} + +// igemm +inline cublasStatus_t +gemm(cublasHandle_t handle, + cublasOperation_t transA, cublasOperation_t transB, + int m, int n, int k, + const int32_t* alpha, + const int8_t* A, int ldA, + const int8_t* B, int ldB, + const int32_t* beta, + int32_t* C, int ldC) +{ + BLAM_DEBUG_OUT("cublasIgemm"); + + return cublasGemmEx(handle, transA, transB, + m, n, k, + alpha, + A, CUDA_R_8I, ldA, + B, CUDA_R_8I, ldB, + beta, + C, CUDA_R_32I, ldC, + CUDA_R_32I, CUBLAS_GEMM_DEFAULT_TENSOR_OP); +} + +// sgemm +inline cublasStatus_t +gemm(cublasHandle_t handle, + cublasOperation_t transA, cublasOperation_t transB, + int m, int n, int k, + const float* alpha, + const float* A, int ldA, + const float* B, int ldB, + const float* beta, + float* C, int ldC) +{ + BLAM_DEBUG_OUT("cublasSgemm"); + + return cublasSgemm(handle, transA, transB, + m, n, k, + alpha, + A, ldA, + B, ldB, + beta, + C, ldC); +} + +// dgemm +inline cublasStatus_t +gemm(cublasHandle_t handle, + cublasOperation_t transA, cublasOperation_t transB, + int m, int n, int k, + const double* alpha, + const double* A, int ldA, + const double* B, int ldB, + const double* beta, + double* C, int ldC) +{ + BLAM_DEBUG_OUT("cublasDgemm"); + + return cublasDgemm(handle, transA, transB, + m, n, k, + alpha, + A, ldA, + B, ldB, + beta, + C, ldC); +} + +// cgemm +inline cublasStatus_t +gemm(cublasHandle_t handle, + cublasOperation_t transA, cublasOperation_t transB, + int m, int n, int k, + const ComplexFloat* alpha, + const ComplexFloat* A, int ldA, + const ComplexFloat* B, int ldB, + const ComplexFloat* beta, + ComplexFloat* C, int ldC) +{ + BLAM_DEBUG_OUT("cublasCgemm"); + + return cublasCgemm(handle, transA, transB, + m, n, k, + reinterpret_cast(alpha), + reinterpret_cast(A), ldA, + reinterpret_cast(B), ldB, + reinterpret_cast(beta), + reinterpret_cast(C), ldC); +} + +// zgemm +inline cublasStatus_t +gemm(cublasHandle_t handle, + cublasOperation_t transA, cublasOperation_t transB, + int m, int n, int k, + const ComplexDouble* alpha, + const ComplexDouble* A, int ldA, + const ComplexDouble* B, int ldB, + const ComplexDouble* beta, + ComplexDouble* C, int ldC) +{ + BLAM_DEBUG_OUT("cublasZgemm"); + + return cublasZgemm(handle, transA, transB, + m, n, k, + reinterpret_cast(alpha), + reinterpret_cast(A), ldA, + reinterpret_cast(B), ldB, + reinterpret_cast(beta), + reinterpret_cast(C), ldC); +} + +// hgemm +inline cublasStatus_t +gemm_batch(cublasHandle_t handle, + cublasOperation_t transA, cublasOperation_t transB, + int m, int n, int k, + const Half* alpha, + const Half* A, int ldA, int loA, + const Half* B, int ldB, int loB, + const Half* beta, + Half* C, int ldC, int loC, + int batch_size) +{ + BLAM_DEBUG_OUT("cublasHgemmStridedBatched"); + + return cublasHgemmStridedBatched(handle, 
transA, transB, + m, n, k, + reinterpret_cast(alpha), + reinterpret_cast(A), ldA, loA, + reinterpret_cast(B), ldB, loB, + reinterpret_cast(beta), + reinterpret_cast<__half*>(C), ldC, loC, + batch_size); +} + +// sgemm +inline cublasStatus_t +gemm_batch(cublasHandle_t handle, + cublasOperation_t transA, cublasOperation_t transB, + int m, int n, int k, + const float* alpha, + const float* A, int ldA, int loA, + const float* B, int ldB, int loB, + const float* beta, + float* C, int ldC, int loC, + int batch_size) +{ + BLAM_DEBUG_OUT("cublasSgemmStridedBatched"); + + return cublasSgemmStridedBatched(handle, transA, transB, + m, n, k, + alpha, + A, ldA, loA, + B, ldB, loB, + beta, + C, ldC, loC, + batch_size); +} + +// dgemm +inline cublasStatus_t +gemm_batch(cublasHandle_t handle, + cublasOperation_t transA, cublasOperation_t transB, + int m, int n, int k, + const double* alpha, + const double* A, int ldA, int loA, + const double* B, int ldB, int loB, + const double* beta, + double* C, int ldC, int loC, + int batch_size) +{ + BLAM_DEBUG_OUT("cublasDgemmStridedBatched"); + + return cublasDgemmStridedBatched(handle, transA, transB, + m, n, k, + alpha, + A, ldA, loA, + B, ldB, loB, + beta, + C, ldC, loC, + batch_size); +} + +// cgemm +inline cublasStatus_t +gemm_batch(cublasHandle_t handle, + cublasOperation_t transA, cublasOperation_t transB, + int m, int n, int k, + const ComplexFloat* alpha, + const ComplexFloat* A, int ldA, int loA, + const ComplexFloat* B, int ldB, int loB, + const ComplexFloat* beta, + ComplexFloat* C, int ldC, int loC, + int batch_size) +{ + BLAM_DEBUG_OUT("cublasCgemmStridedBatched"); + + return cublasCgemmStridedBatched(handle, transA, transB, + m, n, k, + reinterpret_cast(alpha), + reinterpret_cast(A), ldA, loA, + reinterpret_cast(B), ldB, loB, + reinterpret_cast(beta), + reinterpret_cast(C), ldC, loC, + batch_size); +} + +// zgemm +inline cublasStatus_t +gemm_batch(cublasHandle_t handle, + cublasOperation_t transA, cublasOperation_t transB, + int m, int n, int k, + const ComplexDouble* alpha, + const ComplexDouble* A, int ldA, int loA, + const ComplexDouble* B, int ldB, int loB, + const ComplexDouble* beta, + ComplexDouble* C, int ldC, int loC, + int batch_size) +{ + BLAM_DEBUG_OUT("cublasZgemmStridedBatched"); + + return cublasZgemmStridedBatched(handle, transA, transB, + m, n, k, + reinterpret_cast(alpha), + reinterpret_cast(A), ldA, loA, + reinterpret_cast(B), ldB, loB, + reinterpret_cast(beta), + reinterpret_cast(C), ldC, loC, + batch_size); +} + +// hgemm +inline cublasStatus_t +gemm_batch(cublasHandle_t handle, + cublasOperation_t transA, cublasOperation_t transB, + int m, int n, int k, + const Half* alpha, + const Half* const A[], int ldA, + const Half* const B[], int ldB, + const Half* beta, + Half* const C[], int ldC, + int batch_size) +{ + BLAM_DEBUG_OUT("cublasHgemmBatched"); + + return cublasHgemmBatched(handle, transA, transB, + m, n, k, + reinterpret_cast(alpha), + reinterpret_cast(const_cast(A)), ldA, + // A, ldA, // cuBLAS 9.2 + reinterpret_cast(const_cast(B)), ldB, + // B, ldB, // cuBLAS 9.2 + reinterpret_cast(beta), + reinterpret_cast<__half**>(const_cast(C)), ldC, + // C, ldC, // cuBLAS 9.2 + batch_size); +} + +// sgemm +inline cublasStatus_t +gemm_batch(cublasHandle_t handle, + cublasOperation_t transA, cublasOperation_t transB, + int m, int n, int k, + const float* alpha, + const float* const A[], int ldA, + const float* const B[], int ldB, + const float* beta, + float* const C[], int ldC, + int batch_size) +{ + 
BLAM_DEBUG_OUT("cublasSgemmBatched"); + + return cublasSgemmBatched(handle, transA, transB, + m, n, k, + alpha, + const_cast(A), ldA, + // A, ldA, // cuBLAS 9.2 + const_cast(B), ldB, + // B, ldB, // cuBLAS 9.2 + beta, + const_cast(C), ldC, + // C, ldC, // cuBLAS 9.2 + batch_size); +} + +// dgemm +inline cublasStatus_t +gemm_batch(cublasHandle_t handle, + cublasOperation_t transA, cublasOperation_t transB, + int m, int n, int k, + const double* alpha, + const double* const A[], int ldA, + const double* const B[], int ldB, + const double* beta, + double* const C[], int ldC, + int batch_size) +{ + BLAM_DEBUG_OUT("cublasDgemmBatched"); + + return cublasDgemmBatched(handle, transA, transB, + m, n, k, + alpha, + const_cast(A), ldA, + // A, ldA, // cuBLAS 9.2 + const_cast(B), ldB, + // B, ldB, // cuBLAS 9.2 + beta, + const_cast(C), ldC, + // C, ldC, // cuBLAS 9.2 + batch_size); +} + +// cgemm +inline cublasStatus_t +gemm_batch(cublasHandle_t handle, + cublasOperation_t transA, cublasOperation_t transB, + int m, int n, int k, + const ComplexFloat* alpha, + const ComplexFloat* const A[], int ldA, + const ComplexFloat* const B[], int ldB, + const ComplexFloat* beta, + ComplexFloat* const C[], int ldC, + int batch_size) +{ + BLAM_DEBUG_OUT("cublasCgemmBatched"); + + return cublasCgemmBatched(handle, transA, transB, + m, n, k, + reinterpret_cast(alpha), + const_cast(reinterpret_cast(A)), ldA, + //reinterpret_cast(A), ldA, // cuBLAS 9.2 + const_cast(reinterpret_cast(B)), ldB, + //reinterpret_cast(B), ldB, // cuBLAS 9.2 + reinterpret_cast(beta), + const_cast(reinterpret_cast(C)), ldC, + //reinterpret_cast(C), ldC, // cuBLAS 9.2 + batch_size); +} + +// zgemm +inline cublasStatus_t +gemm_batch(cublasHandle_t handle, + cublasOperation_t transA, cublasOperation_t transB, + int m, int n, int k, + const ComplexDouble* alpha, + const ComplexDouble* const A[], int ldA, + const ComplexDouble* const B[], int ldB, + const ComplexDouble* beta, + ComplexDouble* const C[], int ldC, + int batch_size) +{ + BLAM_DEBUG_OUT("cublasZgemmBatched"); + + return cublasZgemmBatched(handle, transA, transB, + m, n, k, + reinterpret_cast(alpha), + const_cast(reinterpret_cast(A)), ldA, + //reinterpret_cast(A), ldA, // cuBLAS 9.2 + const_cast(reinterpret_cast(B)), ldB, + //reinterpret_cast(B), ldB, // cuBLAS 9.2 + reinterpret_cast(beta), + const_cast(reinterpret_cast(C)), ldC, + //reinterpret_cast(C), ldC, // cuBLAS 9.2 + batch_size); +} + +} // end namespace cublas +} // end namespace blam diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/debug.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/debug.h new file mode 100644 index 0000000000000000000000000000000000000000..88481a82e0e08f06b54c07c946d28160d41f9f07 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/debug.h @@ -0,0 +1,143 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Contains code for debugging cutlass code +*/ + +#pragma once + +#include "device_dump.h" + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/****************************************************************************** + * Debug and logging macros + ******************************************************************************/ + +/** + * Formats and prints the given message to stdout + */ +#if !defined(CUDA_LOG) +#if !defined(__CUDA_ARCH__) +#define CUDA_LOG(format, ...) printf(format, __VA_ARGS__) +#else +#define CUDA_LOG(format, ...) \ + printf("[block (%d,%d,%d), thread (%d,%d,%d)]: " format, \ + blockIdx.x, \ + blockIdx.y, \ + blockIdx.z, \ + threadIdx.x, \ + threadIdx.y, \ + threadIdx.z, \ + __VA_ARGS__); +#endif +#endif + +/** + * Formats and prints the given message to stdout only if DEBUG is defined + */ +#if !defined(CUDA_LOG_DEBUG) +#ifdef DEBUG +#define CUDA_LOG_DEBUG(format, ...) CUDA_LOG(format, __VA_ARGS__) +#else +#define CUDA_LOG_DEBUG(format, ...) +#endif +#endif + +/** + * \brief The corresponding error message is printed to \p stderr (or \p stdout in device code) + * along with the supplied source context. + * + * \return The CUDA error. 
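+ *
+ * Host-side usage sketch (illustrative only; the CUDA_PERROR* macros that wrap this
+ * function are defined further below in this header):
+ *
+ * \code
+ *   float* d_buf = nullptr;
+ *   CUDA_PERROR_EXIT(cudaMalloc((void**)&d_buf, 1024 * sizeof(float)));  // prints and exits on failure
+ *   if (CUDA_PERROR(cudaMemset(d_buf, 0, 1024 * sizeof(float)))) {
+ *     // non-fatal: the error has been printed and returned; handle it here
+ *   }
+ *   CUDA_PERROR_EXIT(cudaFree(d_buf));
+ * \endcode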
+ */ +__host__ CUTLASS_DEVICE cudaError_t cuda_perror_impl(cudaError_t error, + const char* expression, + const char* filename, + int line) { + (void)filename; + (void)line; + if (error) { +#if !defined(__CUDA_ARCH__) + fprintf( + stderr, "CUDA error %d [%s, %d] in expression '%s': %s\n", error, filename, line, expression, cudaGetErrorString(error)); + fflush(stderr); +#else + printf("CUDA error %d [%s, %d] in expression '%s'\n", error, filename, line, expression); +#endif + } + return error; +} + +/** + * \brief Perror macro + */ +#ifndef CUDA_PERROR +#define CUDA_PERROR(e) cuda_perror_impl((cudaError_t)(e), #e, __FILE__, __LINE__) +#endif + +/** + * \brief Perror macro with exit + */ +#ifndef CUDA_PERROR_EXIT +#define CUDA_PERROR_EXIT(e) \ + do { if (cuda_perror_impl((cudaError_t)(e), #e, __FILE__, __LINE__)) { \ + exit(1); \ + } } while (0) +#endif + +/** + * \brief Perror macro only if DEBUG is defined + */ +#ifndef CUDA_PERROR_DEBUG +#ifdef DEBUG +#define CUDA_PERROR_DEBUG(e) CUDA_PERROR(e) +#else +#define CUDA_PERROR_DEBUG(e) (e) +#endif +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// A small helper class to dump a type at compile time +// Usage:: DumpType::Class +template +struct DebugType {}; + +template +void DebugTypeFunc(T const& t) { + T::t; +} + +// A small helper class to dump a compile time constant at compile time +// Usage: DumpValue::kConstant +template +struct DebugValue {}; diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_dump.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_dump.h new file mode 100644 index 0000000000000000000000000000000000000000..a73a8cfe79dd22c2d298fcb3be8cf25d5e3f5734 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_dump.h @@ -0,0 +1,187 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include +#include "cutlass/cutlass.h" + +/** + * \file + * \brief C++ interface to dump fragments and shared memory contents for + * debugging. + */ + +namespace cutlass { +namespace debug { + +/****************************************************************************** + * Dump the fragments + ******************************************************************************/ + +/// The first N threads dump the first M elements from their fragments with a +/// stride of S elements. If N is not specified, dump the data of all the +/// threads. If M is not specified, dump all the elements of the fragment. +template +CUTLASS_DEVICE void dump_fragment(Fragment const& frag, int N = 0, int M = 0, + int S = 1) { + int total_threads = blockDim.x * blockDim.y * blockDim.z; + int block_id = + blockIdx.x + blockIdx.y * gridDim.x + gridDim.x * gridDim.y * blockIdx.z; + int thread_id = (threadIdx.z * (blockDim.x * blockDim.y)) + + (threadIdx.y * blockDim.x) + threadIdx.x; + + if (N < 0 || N > total_threads) { + if (thread_id == 0 && block_id == 0) + printf("Thread number N = %d should between [1, %d].\n", N, + total_threads); + + __syncthreads(); + + return; + } + + int total_elements = int(frag.size()); + + if (M < 0 || M > total_elements) { + if (thread_id == 0 && block_id == 0) + printf("Element number M = %d should between [1, %d].\n", M, + total_elements); + + __syncthreads(); + + return; + } + + if (N == 0) N = total_threads; + + if (M == 0) M = total_elements; + + if (S < 1 || S > M) { + if (thread_id == 0 && block_id == 0) + printf("Stride S = %d should between [1, %d].\n", S, M); + + __syncthreads(); + + return; + } + + if (thread_id == 0 && block_id == 0) + printf("\n*******************Dumping the fragments*******************\n\n"); + + CUTLASS_PRAGMA_NO_UNROLL + for (int tid = 0; tid < N; ++tid) { + if (tid == thread_id) { + printf("TB%d W%d T%d: ", block_id, tid / 32, tid & 31); + CUTLASS_PRAGMA_NO_UNROLL + for (int i = 0; i < M; i += S) { + printf("%.0f ", float(typename Fragment::value_type(frag[i]))); + } + printf("\n"); + } + + __syncthreads(); + } + + if (thread_id == 0 && block_id == 0) + printf("\n***********************************************************\n\n"); + + __syncthreads(); + + return; +} + +/****************************************************************************** + * Dump the shared memory + ******************************************************************************/ + +#define SHMEM_ROW_SIZE 128 + +/// Dump the shared memory contents. ptr is the begin address, size specifies +/// the number of elements that need to be dumped, and S specifies the stride. 
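+/// Example (illustrative; `smem` stands for a shared-memory pointer owned by the
+/// calling kernel and is not provided by this header):
+///
+///   // dump the first 256 shared-memory elements of each threadblock with a stride of 4
+///   cutlass::debug::dump_shmem(smem, 256, /*S=*/4);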
+template +CUTLASS_DEVICE void dump_shmem(Element const* ptr, size_t size, int S = 1) { + int block_id = + blockIdx.x + blockIdx.y * gridDim.x + gridDim.x * gridDim.y * blockIdx.z; + int thread_id = (threadIdx.z * (blockDim.x * blockDim.y)) + + (threadIdx.y * blockDim.x) + threadIdx.x; + + if (ptr == nullptr) { + if (thread_id == 0 && block_id == 0) printf("ptr is null.\n"); + + __syncthreads(); + return; + } + + if (size < 1) { + if (thread_id == 0 && block_id == 0) + printf("Element size is less than 1\n"); + + __syncthreads(); + + return; + } + + int row_elements = SHMEM_ROW_SIZE / sizeof(Element); + + if (S < 1 || S > row_elements) { + if (thread_id == 0 && block_id == 0) + printf("Stride S = %d should between [1, %d].\n", S, row_elements); + + __syncthreads(); + + return; + } + + __syncthreads(); + + if (thread_id == 0) + printf("\n********Dumping the shared memory of TB %d*******\n\n", block_id); + + if (thread_id == 0) { + for (int i = 0; i < size; i += row_elements) { + for (int j = 0; j < row_elements; j += S) { + printf("%.0f ", float(ptr[i + j])); + } + + printf("\n"); + } + } + + if (thread_id == 0) + printf("\n***********************************************************\n\n"); + + __syncthreads(); + + return; +} +} // namespace debug +} // namespace cutlass diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_groupnorm.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_groupnorm.h new file mode 100644 index 0000000000000000000000000000000000000000..59457b2e8122f46e443844fe276b2c7fb35f3f56 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_groupnorm.h @@ -0,0 +1,402 @@ +/****************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + ******************************************************************************/ + +#pragma once + +/** + * \file + * \brief cuda kernels to do group norm on a device memory tensor with NHWC layout. The tensor will be divided into [N, H, W, G, C'] and then we do normalization on [H, W, C']. + */ + +#include "cutlass/cutlass.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_coord.h" +#include "cutlass/tensor_ref.h" +#include "device_utils.h" +#include + +namespace cutlass { + +/** \brief interface to do group norm on a device memory tensor with NHWC layout. + * \tparam T: data type + */ +template +void groupnorm(cutlass::Tensor4DCoord input_size, + const int num_groups, + const float eps, + TensorRef ref_output, + TensorRef ref_input, + TensorRef ref_gamma, + TensorRef ref_beta, + cudaStream_t stream); + +extern __shared__ char groupnorm_shm[]; + +// For small prod_dim1_to_last_dim/num_groups, to avoid multiple loads from global memory, +// we store the input in the shared memory. +// grid(num_groups, dim0) +// block(BLOCKSIZE) +// BLOCKSIZE * TVecs_PER_THREAD <= prod_dim1_to_last_dim/num_group +template +__global__ void groupnorm_twopass_store_locally(T* output, + const T* input, + const T* gamma, + const T* beta, + int num_groups, + int prod_dim1_to_last_dim, + int last_dim, + const float eps, + const int TVecs_PER_THREAD) +{ + const int bid = blockIdx.y; // index of batch + const int gid = blockIdx.x; // index of group + const int tid = threadIdx.x; // index of thread + const int bdimx = blockDim.x; + const int s_reduce_elements = prod_dim1_to_last_dim / num_groups; + const int v_reduce_elements = s_reduce_elements / T_PER_TVec; + const int s_group_stride = last_dim / num_groups; + const int v_group_stride = s_group_stride / T_PER_TVec; + const int offset_of_group = (bid * prod_dim1_to_last_dim + gid * s_group_stride) / T_PER_TVec; + const TVec* input_TVec_ptr = (const TVec*)(input) + offset_of_group; + TVec* output_TVec_ptr = (TVec*)(output) + offset_of_group; + T* local_val = ((T*)groupnorm_shm) + TVecs_PER_THREAD * T_PER_TVec * tid; + float local_sum[1] = {0.0f}; + +// load from global memory into shared memory +#pragma unroll + for (int i = 0; i < TVecs_PER_THREAD; i += 1) { + const int current_load_start_idx = (i * bdimx + tid) * T_PER_TVec; + const int offset_in_group = + ((current_load_start_idx / s_group_stride) * last_dim + (current_load_start_idx % s_group_stride)) + / T_PER_TVec; + if (current_load_start_idx < s_reduce_elements) { + TVec tmp_vec = input_TVec_ptr[offset_in_group]; + T* tmp_vec_ptr = (T*)(&tmp_vec); + const int local_val_offset = i * T_PER_TVec; +#pragma unroll + for (int j = 0; j < T_PER_TVec; j++) { + float tmp = static_cast(tmp_vec_ptr[j]); + local_sum[0] += tmp; + local_val[local_val_offset + j] = tmp_vec_ptr[j]; + } + } + } + __shared__ float s_mean, s_variance; + + // reduction for mean + if (bdimx <= 32) { + warpReduceSum(local_sum); + } + else { + blockReduceSum(local_sum); + } + if (tid == 0) { + s_mean = local_sum[0] / s_reduce_elements; + } + __syncthreads(); + + // reduction for std + local_sum[0] = 0.0f; +#pragma unroll + for (int i = 0; i < TVecs_PER_THREAD; i += 1) { + const int current_load_start_idx = (i * bdimx + tid) * T_PER_TVec; + if (current_load_start_idx < s_reduce_elements) { + const int local_val_offset = i * T_PER_TVec; +#pragma unroll + for (int j = 0; j < T_PER_TVec; j++) { + float tmp = static_cast(local_val[local_val_offset + j]); + tmp -= s_mean; + local_sum[0] += tmp * tmp; 
+ } + } + } + if (bdimx <= 32) { + warpReduceSum(local_sum); + } + else { + blockReduceSum(local_sum); + } + if (tid == 0) { + s_variance = rsqrtf(local_sum[0] / s_reduce_elements + eps); + } + __syncthreads(); + + // normalize + const int gamma_offset_of_group = gid * v_group_stride; + const TVec* gamma_TVec_ptr = (const TVec*)gamma + gamma_offset_of_group; + const TVec* beta_TVec_ptr = (const TVec*)beta + gamma_offset_of_group; +#pragma unroll + for (int i = 0; i < TVecs_PER_THREAD; i += 1) { + const int current_load_start_idx = (i * bdimx + tid) * T_PER_TVec; + const int offset_in_group = + ((current_load_start_idx / s_group_stride) * last_dim + (current_load_start_idx % s_group_stride)) + / T_PER_TVec; + const int gamma_offset_in_group = (current_load_start_idx % s_group_stride) / T_PER_TVec; + const int local_val_offset = i * T_PER_TVec; + if (current_load_start_idx < s_reduce_elements) { + TVec gamma_val = gamma_TVec_ptr[gamma_offset_in_group]; + TVec beta_val = beta_TVec_ptr[gamma_offset_in_group]; + T* gamma_val_ptr = (T*)(&gamma_val); + T* beta_val_ptr = (T*)(&beta_val); + TVec tmp_vec; + T* tmp_vec_ptr = (T*)(&tmp_vec); +#pragma unroll + for (int j = 0; j < T_PER_TVec; j++) { + float tmp = (static_cast(local_val[local_val_offset + j]) - s_mean) * s_variance + * static_cast(gamma_val_ptr[j]) + + static_cast(beta_val_ptr[j]); + if (sizeof(T) == sizeof(half)) { + tmp_vec_ptr[j] = T(__float2half_rn(tmp)); + } + else { + tmp_vec_ptr[j] = T(tmp); + } + } + output_TVec_ptr[offset_in_group] = tmp_vec; + } + } +} + +// For large prod_dim1_to_last_dim/num_groups, +// in which the data cannot be stored locally, +// we will load from global memory multiple times, +// grid(num_groups, dim0) +// block(BLOCKSIZE) +// BLOCKSIZE * TVecs_PER_THREAD <= prod_dim1_to_last_dim/num_group +template +__global__ void groupnorm_twopass_multiple_load(T* output, + const T* input, + const T* gamma, + const T* beta, + int num_groups, + int prod_dim1_to_last_dim, + int last_dim, + const float eps, + const int TVecs_PER_THREAD) +{ + const int bid = blockIdx.y; // index of batch + const int gid = blockIdx.x; // index of group + const int tid = threadIdx.x; // index of thread + const int bdimx = blockDim.x; + const int s_reduce_elements = prod_dim1_to_last_dim / num_groups; + const int v_reduce_elements = s_reduce_elements / T_PER_TVec; + const int s_group_stride = last_dim / num_groups; + const int v_group_stride = s_group_stride / T_PER_TVec; + const int offset_of_group = (bid * prod_dim1_to_last_dim + gid * s_group_stride) / T_PER_TVec; + const TVec* input_TVec_ptr = (const TVec*)(input) + offset_of_group; + TVec* output_TVec_ptr = (TVec*)(output) + offset_of_group; + float local_sum[1] = {0.0f}; + +#pragma unroll + for (int i = 0; i < TVecs_PER_THREAD; i += 1) { + const int current_load_start_idx = (i * bdimx + tid) * T_PER_TVec; + if (current_load_start_idx < s_reduce_elements) { + const int offset_in_group = + ((current_load_start_idx / s_group_stride) * last_dim + (current_load_start_idx % s_group_stride)) + / T_PER_TVec; + TVec tmp_vec = input_TVec_ptr[offset_in_group]; + T* tmp_vec_ptr = (T*)(&tmp_vec); +#pragma unroll + for (int j = 0; j < T_PER_TVec; j++) { + float tmp = static_cast(tmp_vec_ptr[j]); + local_sum[0] += tmp; + } + } + } + __shared__ float s_mean, s_variance; + + // reduction for mean + if (bdimx <= 32) { + warpReduceSum(local_sum); + } + else { + blockReduceSum(local_sum); + } + if (tid == 0) { + s_mean = local_sum[0] / s_reduce_elements; + } + __syncthreads(); + + // reduction for std + 
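+    // Unlike the *_store_locally variant, the values were not cached in shared memory,
+    // so this second pass re-reads the group's elements from global memory and
+    // accumulates their squared deviation from the block-wide mean s_mean.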
local_sum[0] = 0.0f; +#pragma unroll + for (int i = 0; i < TVecs_PER_THREAD; i += 1) { + const int current_load_start_idx = (i * bdimx + tid) * T_PER_TVec; + if (current_load_start_idx < s_reduce_elements) { + const int offset_in_group = + ((current_load_start_idx / s_group_stride) * last_dim + (current_load_start_idx % s_group_stride)) + / T_PER_TVec; + TVec tmp_vec = input_TVec_ptr[offset_in_group]; + T* tmp_vec_ptr = (T*)(&tmp_vec); +#pragma unroll + for (int j = 0; j < T_PER_TVec; j++) { + float tmp = static_cast(tmp_vec_ptr[j]); + tmp -= s_mean; + local_sum[0] += tmp * tmp; + } + } + } + if (bdimx <= 32) { + warpReduceSum(local_sum); + } + else { + blockReduceSum(local_sum); + } + if (tid == 0) { + s_variance = rsqrtf(local_sum[0] / s_reduce_elements + eps); + } + __syncthreads(); + + // normalize + const int gamma_offset_of_group = gid * v_group_stride; + const TVec* gamma_TVec_ptr = (const TVec*)gamma + gamma_offset_of_group; + const TVec* beta_TVec_ptr = (const TVec*)beta + gamma_offset_of_group; +#pragma unroll + for (int i = 0; i < TVecs_PER_THREAD; i += 1) { + const int current_load_start_idx = (i * bdimx + tid) * T_PER_TVec; + if (current_load_start_idx < s_reduce_elements) { + const int offset_in_group = + ((current_load_start_idx / s_group_stride) * last_dim + (current_load_start_idx % s_group_stride)) + / T_PER_TVec; + const int gamma_offset_in_group = (current_load_start_idx % s_group_stride) / T_PER_TVec; + TVec gamma_val = gamma_TVec_ptr[gamma_offset_in_group]; + TVec beta_val = beta_TVec_ptr[gamma_offset_in_group]; + T* gamma_val_ptr = (T*)(&gamma_val); + T* beta_val_ptr = (T*)(&beta_val); + TVec tmp_vec = input_TVec_ptr[offset_in_group]; + T* tmp_vec_ptr = (T*)(&tmp_vec); + TVec output_tmp_vec; + T* output_tmp_vec_ptr = (T*)(&output_tmp_vec); +#pragma unroll + for (int j = 0; j < T_PER_TVec; j++) { + float tmp = + (static_cast(tmp_vec_ptr[j]) - s_mean) * s_variance * static_cast(gamma_val_ptr[j]) + + static_cast(beta_val_ptr[j]); + if (sizeof(T) == sizeof(half)) { + output_tmp_vec_ptr[j] = T(__float2half_rn(tmp)); + } + else { + output_tmp_vec_ptr[j] = T(tmp); + } + } + output_TVec_ptr[offset_in_group] = output_tmp_vec; + } + } +} + +//ref_input & ref_output should be [N, H, W, C] +//ref_gamma & ref_beta should be [1, 1, 1, C] +template +void groupnorm(cutlass::Tensor4DCoord input_size, + const int num_groups, + const float eps, + TensorRef ref_output, + TensorRef ref_input, + TensorRef ref_gamma, + TensorRef ref_beta, + cudaStream_t stream){ + const int N = input_size.n(); + const int H = input_size.h(); + const int W = input_size.w(); + const int C = input_size.c(); + if (C % num_groups != 0){ + printf("[ERROR] C should be a multiple of num_groups.\n"); + } + T* output = ref_output.data(); + const T* input = ref_input.data(); + const T* gamma = ref_gamma.data(); + const T* beta = ref_beta.data(); + + const int dim0 = N; + const int last_dim = C; + const int prod_dim1_to_last_dim = H*W*C; + const int s_reduce_elements = prod_dim1_to_last_dim / num_groups; + const int s_group_stride = last_dim / num_groups; + dim3 grid(num_groups, dim0); + int threadblock_size = 32; + if (s_group_stride % 2 == 0) { + const int T_PER_TVec = 2; + while (threadblock_size < 1024) { + if (s_reduce_elements / T_PER_TVec / threadblock_size <= 8) + break; + threadblock_size *= 2; + } + dim3 block(threadblock_size); + const int TVec_PER_THREAD = (s_reduce_elements / T_PER_TVec + threadblock_size - 1) / threadblock_size; + const int shm_size = T_PER_TVec * TVec_PER_THREAD * threadblock_size * 
sizeof(T); + // for small s_reduce_elements, specific case for H=W=22, C=1280, num_groups=32; + // the size of grid & block may have better choice for different cases. + // ensure shared memory is smaller than 48KB + if (std::is_same::value){ + if (shm_size < 48 * 1024) { + groupnorm_twopass_store_locally<<>>( + output, input, gamma, beta, num_groups, prod_dim1_to_last_dim, last_dim, eps, TVec_PER_THREAD); + } + else { + groupnorm_twopass_multiple_load<<>>( + output, input, gamma, beta, num_groups, prod_dim1_to_last_dim, last_dim, eps, TVec_PER_THREAD); + } + } + else{ + if (shm_size < 48 * 1024) { + groupnorm_twopass_store_locally<<>>( + output, input, gamma, beta, num_groups, prod_dim1_to_last_dim, last_dim, eps, TVec_PER_THREAD); + } + else { + groupnorm_twopass_multiple_load<<>>( + output, input, gamma, beta, num_groups, prod_dim1_to_last_dim, last_dim, eps, TVec_PER_THREAD); + } + } + } + else { + const int T_PER_TVec = 1; + while (threadblock_size < 1024) { + if (s_reduce_elements / T_PER_TVec / threadblock_size <= 8) + break; + threadblock_size *= 2; + } + dim3 block(threadblock_size); + const int TVec_PER_THREAD = (s_reduce_elements / T_PER_TVec + threadblock_size - 1) / threadblock_size; + const int shm_size = T_PER_TVec * TVec_PER_THREAD * threadblock_size * sizeof(T); + if (shm_size < 48 * 1024) { + groupnorm_twopass_store_locally<<>>( + output, input, gamma, beta, num_groups, prod_dim1_to_last_dim, last_dim, eps, TVec_PER_THREAD); + } + else { + groupnorm_twopass_multiple_load<<>>( + output, input, gamma, beta, num_groups, prod_dim1_to_last_dim, last_dim, eps, TVec_PER_THREAD); + } + } + +} + +} //namespace cutlass diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_layernorm.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_layernorm.h new file mode 100644 index 0000000000000000000000000000000000000000..0fcbf5cb0f4bf3152a708c6e3845e89fd214cfac --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_layernorm.h @@ -0,0 +1,644 @@ +/****************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#pragma once + +/** + * \file + * \brief cuda kernels to do layernorm on a device memory tensor with RowMajor layout. + */ + +#include "cutlass/cutlass.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_coord.h" +#include "cutlass/tensor_ref.h" +#include "device_utils.h" +#include + +namespace cutlass { + +/** \brief interface to do layernorm on a device memory tensor with RowMajor layout. + * \tparam T: data type + */ +template +void layernorm(cutlass::MatrixCoord tensor_size, + TensorRef ref_output, + TensorRef ref_input, + TensorRef ref_gamma, + TensorRef ref_beta, + cudaStream_t stream); + +/** + * output [m, n] row-major + * input [m, n] row-major + * gamma [n] + * beta [n] + * grid(m) + * block(block_size) -- each block deals with n elements ; each thread deals with ITEM_PER_THREAD elements +*/ +template +__global__ void layernorm_twoPassAlgo_stored_locally_e1(T* output, + const T* input, + const T* gamma, + const T* beta, + const int m, + const int n) +{ + const int m_idx = blockIdx.x; + const int tid = threadIdx.x; + const int bdimx = blockDim.x; + __shared__ float s_mean, s_variance; + T local_val[ITEM_PER_THREAD]; + float local_sums[1] = {0.0f}; + int offset = m_idx * n; + input += offset; + output += offset; + + const T zero = T(0.0f); + #pragma unroll + for (int i = 0 ; i < ITEM_PER_THREAD ; i++){ + int index = tid + i*bdimx; + local_val[i] = index < n ? 
input[index] : zero; + local_sums[0] += static_cast(local_val[i]); + } + if (blockDim.x <= 32) { + warpReduceSum(local_sums); + } + else { + blockReduceSum(local_sums); + } + if (threadIdx.x == 0) { + s_mean = local_sums[0] / n; + } + __syncthreads(); + + local_sums[0] = 0.0f; + #pragma unroll + for (int i = 0 ; i < ITEM_PER_THREAD ; i++){ + int index = tid + i*bdimx; + if (index < n){ + const float tmp = static_cast(local_val[i]) - s_mean; + local_sums[0] += tmp * tmp; + } + } + + if (blockDim.x <= 32) { + warpReduceSum(local_sums); + } + else { + blockReduceSum(local_sums); + } + if (threadIdx.x == 0) { + s_variance = rsqrtf(local_sums[0] / n + 1e-5); + } + __syncthreads(); + + #pragma unroll + for (int i = 0 ; i < ITEM_PER_THREAD ; i++){ + int index = tid + i*bdimx; + if (index < n) { + const T gamma_val = gamma[index]; + const T beta_val = beta[index]; + output[index] = T((static_cast(local_val[i]) - s_mean) * s_variance * static_cast(gamma_val) + static_cast(beta_val)); + } + } +} + +/** + * output [m, n] row-major + * input [m, n] row-major + * gamma [n] + * beta [n] + * grid(m) + * block(block_size) -- each block deals with block_size*ITEM_PER_THREAD*2 elements; +*/ +template +__global__ void layernorm_twoPassAlgo_stored_locally_e2(T2* output, + const T2* input, + const T2* gamma, + const T2* beta, + const int m, + const int n) +{ + const int m_idx = blockIdx.x; + const int tid = threadIdx.x; + const int bdimx = blockDim.x; + __shared__ float s_mean, s_variance; + float local_sums[1] = {0.0f}; + T2 local_val[ITEM_PER_THREAD]; + const int n_2 = n / 2; + int offset = m_idx * n_2; + input += offset; + output += offset; + + const T2 zero = {T(0.0f), T(0.0f)}; + #pragma UNROLL + for (int i = 0; i < ITEM_PER_THREAD; i += 1) { + const int index = i*bdimx + tid; + local_val[i] = index < n_2 ? 
input[index] : zero; + local_sums[0] += static_cast(local_val[i].x) + static_cast(local_val[i].y); + } + + if (blockDim.x <= 32) { + warpReduceSum(local_sums); + } + else { + blockReduceSum(local_sums); + } + if (threadIdx.x == 0) { + s_mean = local_sums[0] / n; + } + __syncthreads(); + + local_sums[0] = 0.0f; + #pragma UNROLL + for (int i = 0; i < ITEM_PER_THREAD; i += 1) { + const int index = i*bdimx + tid; + if (index < n_2){ + const float2 tmp = {static_cast(local_val[i].x) - s_mean, + static_cast(local_val[i].y) - s_mean}; + local_sums[0] += tmp.x * tmp.x + tmp.y * tmp.y; + } + } + if (blockDim.x <= 32) { + warpReduceSum(local_sums); + } + else { + blockReduceSum(local_sums); + } + if (threadIdx.x == 0) { + s_variance = rsqrtf(local_sums[0] / n + 1e-5); + } + __syncthreads(); + + #pragma UNROLL + for (int i = 0; i < ITEM_PER_THREAD; i += 1) { + const int index = i*bdimx + tid; + if (index < n_2){ + const T2 gamma_val = gamma[index]; + const T2 beta_val = beta[index]; + T2 tmp; + tmp.x = T((static_cast(local_val[i].x) - s_mean)*s_variance*static_cast(gamma_val.x) + static_cast(beta_val.x)); + tmp.y = T((static_cast(local_val[i].y) - s_mean)*s_variance*static_cast(gamma_val.y) + static_cast(beta_val.y)); + output[index] = tmp; + } + } +} + +/** + * output [m, n] row-major + * input [m, n] row-major + * gamma [n] + * beta [n] + * grid(m) + * block(block_size) -- each block deals with block_size*ITEM_PER_THREAD*4 elements; +*/ +template +__global__ void layernorm_twoPassAlgo_stored_locally_e4(T4* output, + const T4* input, + const T4* gamma, + const T4* beta, + const int m, + const int n) +{ + const int m_idx = blockIdx.x; + const int tid = threadIdx.x; + const int bdimx = blockDim.x; + __shared__ float s_mean, s_variance; + float local_sums[1] = {0.0f}; + T4 local_val[ITEM_PER_THREAD]; + const int n_4 = n / 4; + int offset = m_idx * n_4; + input += offset; + output += offset; + + const T4 zero = {T(0.0f), T(0.0f), T(0.0f), T(0.0f)}; + #pragma UNROLL + for (int i = 0; i < ITEM_PER_THREAD; i += 1) { + const int index = i*bdimx + tid; + local_val[i] = index < n_4 ? 
input[index] : zero; + local_sums[0] += static_cast(local_val[i].x) + static_cast(local_val[i].y) + + static_cast(local_val[i].z) + static_cast(local_val[i].w); + } + + if (blockDim.x <= 32) { + warpReduceSum(local_sums); + } + else { + blockReduceSum(local_sums); + } + if (threadIdx.x == 0) { + s_mean = local_sums[0] / n; + } + __syncthreads(); + + local_sums[0] = 0.0f; + #pragma UNROLL + for (int i = 0; i < ITEM_PER_THREAD; i += 1) { + const int index = i*bdimx + tid; + if (index < n_4){ + const float4 tmp = {static_cast(local_val[i].x) - s_mean, + static_cast(local_val[i].y) - s_mean, + static_cast(local_val[i].z) - s_mean, + static_cast(local_val[i].w) - s_mean}; + local_sums[0] += tmp.x * tmp.x + tmp.y * tmp.y + tmp.z * tmp.z + tmp.w * tmp.w; + } + } + if (blockDim.x <= 32) { + warpReduceSum(local_sums); + } + else { + blockReduceSum(local_sums); + } + if (threadIdx.x == 0) { + s_variance = rsqrtf(local_sums[0] / n + 1e-5); + } + __syncthreads(); + + #pragma UNROLL + for (int i = 0; i < ITEM_PER_THREAD; i += 1) { + const int index = i*bdimx + tid; + if (index < n_4){ + const T4 gamma_val = gamma[index]; + const T4 beta_val = beta[index]; + T4 tmp; + tmp.x = T((static_cast(local_val[i].x) - s_mean)*s_variance*static_cast(gamma_val.x) + static_cast(beta_val.x)); + tmp.y = T((static_cast(local_val[i].y) - s_mean)*s_variance*static_cast(gamma_val.y) + static_cast(beta_val.y)); + tmp.z = T((static_cast(local_val[i].z) - s_mean)*s_variance*static_cast(gamma_val.z) + static_cast(beta_val.z)); + tmp.w = T((static_cast(local_val[i].w) - s_mean)*s_variance*static_cast(gamma_val.w) + static_cast(beta_val.w)); + output[index] = tmp; + } + } +} + +/** + * output [m, n] row-major + * input [m, n] row-major + * gamma [n] + * beta [n] + * grid(m) + * block(block_size) -- each block deals with n elements ; each thread deals with ITEM_PER_THREAD elements +*/ +template +__global__ void layernorm_twoPassAlgo_e1(T* output, + const T* input, + const T* gamma, + const T* beta, + const int m, + const int n) +{ + const int m_idx = blockIdx.x; + const int tid = threadIdx.x; + const int bdimx = blockDim.x; + __shared__ float s_mean, s_variance; + float local_sums[1] = {0.0f}; + int offset = m_idx * n; + input += offset; + output += offset; + + for (int index = tid ; index < n ; index += bdimx){ + float local_val = static_cast(input[index]); + local_sums[0] += local_val; + } + if (blockDim.x <= 32) { + warpReduceSum(local_sums); + } + else { + blockReduceSum(local_sums); + } + if (threadIdx.x == 0) { + s_mean = local_sums[0] / n; + } + __syncthreads(); + + local_sums[0] = 0.0f; + for (int index = tid ; index < n ; index += bdimx){ + float local_val = static_cast(input[index]); + local_val = local_val - s_mean; + local_sums[0] += local_val * local_val; + } + + if (blockDim.x <= 32) { + warpReduceSum(local_sums); + } + else { + blockReduceSum(local_sums); + } + if (threadIdx.x == 0) { + s_variance = rsqrtf(local_sums[0] / n + 1e-5); + } + __syncthreads(); + + for (int index = tid ; index < n ; index += bdimx){ + const T gamma_val = gamma[index]; + const T beta_val = beta[index]; + const T local_val = input[index]; + output[index] = T((static_cast(local_val) - s_mean) * s_variance * static_cast(gamma_val) + static_cast(beta_val)); + } +} + +/** + * output [m, n] row-major + * input [m, n] row-major + * gamma [n] + * beta [n] + * grid(m) + * block(block_size) -- each block deals with block_size*ITEM_PER_THREAD*2 elements; +*/ +template +__global__ void layernorm_twoPassAlgo_e2(T2* output, + const T2* input, + const 
T2* gamma, + const T2* beta, + const int m, + const int n) +{ + const int m_idx = blockIdx.x; + const int tid = threadIdx.x; + const int bdimx = blockDim.x; + __shared__ float s_mean, s_variance; + float local_sums[1] = {0.0f}; + const int n_2 = n / 2; + int offset = m_idx * n_2; + input += offset; + output += offset; + + for (int index = tid; index < n_2; index += bdimx) { + const T2 local_val = input[index]; + local_sums[0] += static_cast(local_val.x) + static_cast(local_val.y); + } + + if (blockDim.x <= 32) { + warpReduceSum(local_sums); + } + else { + blockReduceSum(local_sums); + } + if (threadIdx.x == 0) { + s_mean = local_sums[0] / n; + } + __syncthreads(); + + local_sums[0] = 0.0f; + for (int index = tid; index < n_2; index += bdimx) { + const T2 local_val = input[index]; + const float2 tmp = {static_cast(local_val.x) - s_mean, + static_cast(local_val.y) - s_mean}; + local_sums[0] += tmp.x * tmp.x + tmp.y * tmp.y; + } + if (blockDim.x <= 32) { + warpReduceSum(local_sums); + } + else { + blockReduceSum(local_sums); + } + if (threadIdx.x == 0) { + s_variance = rsqrtf(local_sums[0] / n + 1e-5); + } + __syncthreads(); + + for (int index = tid; index < n_2; index += bdimx) { + const T2 local_val = input[index]; + const T2 gamma_val = gamma[index]; + const T2 beta_val = beta[index]; + T2 tmp; + tmp.x = T((static_cast(local_val.x) - s_mean)*s_variance*static_cast(gamma_val.x) + static_cast(beta_val.x)); + tmp.y = T((static_cast(local_val.y) - s_mean)*s_variance*static_cast(gamma_val.y) + static_cast(beta_val.y)); + output[index] = tmp; + } +} + +template +void layernorm(cutlass::MatrixCoord tensor_size, + TensorRef ref_output, + TensorRef ref_input, + TensorRef ref_gamma, + TensorRef ref_beta, + cudaStream_t stream){ + const int m = tensor_size.row(); + const int n = tensor_size.column(); + T* output = ref_output.data(); + const T* input = ref_input.data(); + const T* gamma = ref_gamma.data(); + const T* beta = ref_beta.data(); + dim3 grid(m); + dim3 block((n + 31)/32*32); + if (block.x > 1024){ + block.x = 1024; + } + // TODO : There should be better configs for different cases, we only use several samples to show how to use here + // TODO : using registers to store values locally can reduce the loads from global memory and speedup the kernels. 
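+  // Illustrative dispatch example (hypothetical sizes): for T = half and (m, n) = (4096, 1024),
+  // n % 4 == 0 and 128 <= n <= 4096, so the vectorized half4 branch below is taken with
+  // grid = (4096) and block.x = (1024/4 + 31)/32*32 = 256 threads per row.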
+ if ((n % 4 == 0) && (n >= 128) && (n <= 4096)) { + block.x = (n/4 + 31)/32*32; + if (std::is_same::value) { + layernorm_twoPassAlgo_stored_locally_e4<<>>( + (float4*)output, + (const float4*)input, + (const float4*)gamma, + (const float4*)beta, + m, + n); + } // if (std::is_same::value) + else { + layernorm_twoPassAlgo_stored_locally_e4<<>>( + (half4*)output, + (const half4*)input, + (const half4*)gamma, + (const half4*)beta, + m, + n); + } + } //if ((n % 4 == 0) && (n >= 128) && (n <= 4096)) + else if (n % 2 == 0) { + if (n / 2 <= 1024) { + block.x = (n/2 + 31)/32*32; + if (std::is_same::value) { + layernorm_twoPassAlgo_stored_locally_e2<<>>( + (float2*)output, + (const float2*)input, + (const float2*)gamma, + (const float2*)beta, + m, + n); + } //if (std::is_same::value) + else { + layernorm_twoPassAlgo_stored_locally_e2<<>>( + (half2*)output, + (const half2*)input, + (const half2*)gamma, + (const half2*)beta, + m, + n); + } + } // if (n / 2 <= 1024) + else if (n <= 8192) { + block.x = ((n + 7)/8 + 31)/32*32; + if (std::is_same::value) { + layernorm_twoPassAlgo_stored_locally_e2<<>>( + (float2*)output, + (const float2*)input, + (const float2*)gamma, + (const float2*)beta, + m, + n); + } // if (std::is_same::value) + else { + layernorm_twoPassAlgo_stored_locally_e2<<>>( + (half2*)output, + (const half2*)input, + (const half2*)gamma, + (const half2*)beta, + m, + n); + } + } // if (n <= 8192) + else if (n <= 16384) { + block.x = ((n + 15)/ 16 + 31)/32*32; + if (std::is_same::value) { + layernorm_twoPassAlgo_stored_locally_e2<<>>( + (float2*)output, + (const float2*)input, + (const float2*)gamma, + (const float2*)beta, + m, + n); + } // if (std::is_same::value) + else { + layernorm_twoPassAlgo_stored_locally_e2<<>>( + (half2*)output, + (const half2*)input, + (const half2*)gamma, + (const half2*)beta, + m, + n); + } + } // if (n <= 16384) + else if (n <= 32768) { + block.x = ((n + 31)/32 + 31)/32*32; + if (std::is_same::value) { + layernorm_twoPassAlgo_stored_locally_e2<<>>( + (float2*)output, + (const float2*)input, + (const float2*)gamma, + (const float2*)beta, + m, + n); + } // if (std::is_same::value) + else { + layernorm_twoPassAlgo_stored_locally_e2<<>>( + (half2*)output, + (const half2*)input, + (const half2*)gamma, + (const half2*)beta, + m, + n); + } + } // if (n <= 32768) + else { + if (block.x > 512) + block.x = 512; + if (std::is_same::value) { + layernorm_twoPassAlgo_e2<<>>( + (float2 *)output, + (const float2 *)input, + (const float2 *)gamma, + (const float2 *)beta, + m, + n); + } // if (std::is_same::value) + else { + layernorm_twoPassAlgo_e2<<>>( + (half2 *)output, + (const half2 *)input, + (const half2 *)gamma, + (const half2 *)beta, + m, + n); + } + } + } // if (n % 2 == 0) + else { + if (n <= 1024) { + layernorm_twoPassAlgo_stored_locally_e1<<>>( + output, + input, + gamma, + beta, + m, + n); + } // if (n <= 1024) + else if (n <= 8192) { + block.x = ((n + 7)/8 + 31)/32*32; + layernorm_twoPassAlgo_stored_locally_e1<<>>( + output, + input, + gamma, + beta, + m, + n); + } // if (n <= 8192) + else if (n <= 16384) { + block.x = ((n + 15)/16 + 32)/32*32; + layernorm_twoPassAlgo_stored_locally_e1<<>>( + output, + input, + gamma, + beta, + m, + n); + } // if (n <= 16384) + else if (n <= 32768) { + block.x = ((n + 31)/32 + 31)/32*32; + layernorm_twoPassAlgo_stored_locally_e1<<>>( + output, + input, + gamma, + beta, + m, + n); + } // if (n <= 32768) + else{ + if (block.x > 512) { + block.x = 512; + } + layernorm_twoPassAlgo_e1<<>>( + output, + input, + gamma, + beta, + m, + n); + } 
+ } +} + +} //namespace cutlass diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_memory.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_memory.h new file mode 100644 index 0000000000000000000000000000000000000000..44f6a467a5d0938289e4bc127cddc13b9aeabdf3 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_memory.h @@ -0,0 +1,375 @@ +/****************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#pragma once + +/** + * \file + * \brief C++ interface to CUDA device memory management functions. 
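+ *
+ * Usage sketch (illustrative only):
+ *
+ * \code
+ *   std::vector<float> host(1024, 1.0f);
+ *   float* device_ptr = cutlass::device_memory::allocate<float>(host.size());
+ *   cutlass::device_memory::copy_to_device(device_ptr, host.data(), host.size());
+ *   // ... launch kernels that read/write device_ptr ...
+ *   cutlass::device_memory::copy_to_host(host.data(), device_ptr, host.size());
+ *   cutlass::device_memory::free(device_ptr);
+ * \endcode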
+ */ + +#include +#include + +#include "cutlass/platform/platform.h" +#include "cutlass/numeric_types.h" +#include "cutlass/trace.h" +#include "exceptions.h" + +namespace cutlass { +namespace device_memory { + +/****************************************************************************** + * Allocation lifetime + ******************************************************************************/ + +/// Allocate a buffer of \p count elements of type \p T on the current CUDA device +template +T* allocate(size_t count = 1) { + + T* ptr = 0; + size_t bytes = count * sizeof_bits::value / 8; + + cudaError_t cuda_error = cudaMalloc((void**)&ptr, bytes); + + if (cuda_error != cudaSuccess) { +#if (CUTLASS_DEBUG_TRACE_LEVEL > 0) + std::ostringstream os; + os << "cutlass::device_memory::allocate: cudaMalloc failed: bytes=" << bytes; + CUTLASS_TRACE_HOST(os.str()); +#endif + throw cuda_exception("Failed to allocate memory", cuda_error); + } +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + else { + std::ostringstream os; + os << "cutlass::device_memory::allocate: Successful cudaMalloc: bytes=" << bytes; + CUTLASS_TRACE_HOST(os.str()); + } +#endif + + return ptr; +} + +/// Free the buffer pointed to by \p ptr +template +void free(T* ptr) { + if (ptr) { + cudaError_t cuda_error = (cudaFree(ptr)); + if (cuda_error != cudaSuccess) { + throw cuda_exception("Failed to free device memory", cuda_error); + } + } +} + +/****************************************************************************** + * Data movement + ******************************************************************************/ + +template +void copy(T* dst, T const* src, size_t count, cudaMemcpyKind kind) { + size_t bytes = count * sizeof_bits::value / 8; + if (bytes == 0 && count > 0) { + bytes = 1; + } + cudaError_t cuda_error = (cudaMemcpy(dst, src, bytes, kind)); + if (cuda_error != cudaSuccess) { + std::ostringstream os; + os << "cutlass::device_memory::copy: cudaMemcpy() failed: " + << "dst=" << dst << ", src=" << src + << ", bytes=" << bytes << ", count=" << count; + if (kind == cudaMemcpyHostToDevice) { + os << ", kind=cudaMemcpyHostToDevice"; + } + else if (kind == cudaMemcpyDeviceToHost) { + os << ", kind=cudaMemcpyDeviceToHost"; + } + else if (kind == cudaMemcpyDeviceToDevice) { + os << ", kind=cudaMemcpyDeviceToDevice"; + } + else if (kind == cudaMemcpyHostToHost) { + os << ", kind=cudaMemcpyHostToHost"; + } + else if (kind == cudaMemcpyDefault) { + os << ", kind=cudaMemcpyDefault"; + } + else { + os << ", kind=Unknown"; + } + os << ", error: " << cudaGetErrorString(cuda_error); + + throw cuda_exception(os.str().c_str(), cuda_error); + } +} + +template +void copy_to_device(T* dst, T const* src, size_t count = 1) { + copy(dst, src, count, cudaMemcpyHostToDevice); +} + +template +void copy_to_host(T* dst, T const* src, size_t count = 1) { + copy(dst, src, count, cudaMemcpyDeviceToHost); +} + +template +void copy_device_to_device(T* dst, T const* src, size_t count = 1) { + copy(dst, src, count, cudaMemcpyDeviceToDevice); +} + +template +void copy_host_to_host(T* dst, T const* src, size_t count = 1) { + copy(dst, src, count, cudaMemcpyHostToHost); +} + +/// Copies elements from device memory to host-side range +template +void insert_to_host(OutputIterator begin, OutputIterator end, T const* device_begin) { + size_t elements = end - begin; + copy_to_host(&*begin, device_begin, elements); +} + +/// Copies elements to device memory from host-side range +template +void insert_to_device(T* device_begin, InputIterator begin, InputIterator end) { + 
size_t elements = end - begin; + copy_to_device(device_begin, &*begin, elements); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device_memory + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +class DeviceAllocation { +public: + + /// Delete functor for CUDA device memory + struct deleter { + void operator()(T* ptr) { + cudaError_t cuda_error = (cudaFree(ptr)); + if (cuda_error != cudaSuccess) { + // noexcept + // throw cuda_exception("cudaFree() failed", cuda_error); + return; + } + } + }; + +public: + // + // Data members + // + + /// Number of elements of T allocated on the current CUDA device + size_t capacity; + + /// Smart pointer + platform::unique_ptr smart_ptr; + +public: + + // + // Static methods + // + + /// Static member to compute the number of bytes needed for a given number of elements + static size_t bytes(size_t elements) { + if (sizeof_bits::value < 8) { + size_t const kElementsPerByte = 8 / sizeof_bits::value; + return elements / kElementsPerByte; + } + else { + size_t const kBytesPerElement = sizeof_bits::value / 8; + return elements * kBytesPerElement; + } + } + +public: + + // + // Methods + // + + /// Constructor: allocates no memory + DeviceAllocation() : capacity(0) {} + + /// Constructor: allocates \p capacity elements on the current CUDA device + DeviceAllocation(size_t _capacity) : + smart_ptr(device_memory::allocate(_capacity)), capacity(_capacity) {} + + /// Constructor: allocates \p capacity elements on the current CUDA device taking ownership of the allocation + DeviceAllocation(T *ptr, size_t _capacity) : smart_ptr(ptr), capacity(_capacity) {} + + /// Copy constructor + DeviceAllocation(DeviceAllocation const &p): + smart_ptr(device_memory::allocate(p.capacity)), capacity(p.capacity) { + + device_memory::copy_device_to_device(smart_ptr.get(), p.get(), capacity); + } + + /// Move constructor + DeviceAllocation(DeviceAllocation &&p): capacity(0) { + std::swap(smart_ptr, p.smart_ptr); + std::swap(capacity, p.capacity); + } + + /// Destructor + ~DeviceAllocation() { reset(); } + + /// Returns a pointer to the managed object + T* get() const { return smart_ptr.get(); } + + /// Releases the ownership of the managed object (without deleting) and resets capacity to zero + T* release() { + capacity = 0; + return smart_ptr.release(); + } + + /// Deletes the managed object and resets capacity to zero + void reset() { + capacity = 0; + smart_ptr.reset(); + } + + /// Deletes managed object, if owned, and allocates a new object + void reset(size_t _capacity) { + reset(device_memory::allocate(_capacity), _capacity); + } + + /// Deletes managed object, if owned, and replaces its reference with a given pointer and capacity + void reset(T* _ptr, size_t _capacity) { + smart_ptr.reset(_ptr); + capacity = _capacity; + } + + /// Allocates a new buffer and copies the old buffer into it. The old buffer is then released. 
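+  ///
+  /// Example (illustrative): grow a buffer while keeping its existing contents:
+  ///
+  ///   cutlass::DeviceAllocation<float> buf(1024);
+  ///   buf.reallocate(4096);   // the first 1024 elements are copied into the new buffer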
+ void reallocate(size_t new_capacity) { + + platform::unique_ptr new_allocation(device_memory::allocate(new_capacity)); + + device_memory::copy_device_to_device( + new_allocation.get(), + smart_ptr.get(), + std::min(new_capacity, capacity)); + + std::swap(smart_ptr, new_allocation); + std::swap(new_capacity, capacity); + } + + /// Returns the number of elements + size_t size() const { + return capacity; + } + + /// Returns the number of bytes needed to store the allocation + size_t bytes() const { + return bytes(capacity); + } + + /// Returns a pointer to the object owned by *this + T* operator->() const { return smart_ptr.get(); } + + /// Returns the deleter object which would be used for destruction of the managed object. + deleter& get_deleter() { return smart_ptr.get_deleter(); } + + /// Returns the deleter object which would be used for destruction of the managed object (const) + const deleter& get_deleter() const { return smart_ptr.get_deleter(); } + + /// Copies a device-side memory allocation + DeviceAllocation & operator=(DeviceAllocation const &p) { + if (capacity != p.capacity) { + smart_ptr.reset(device_memory::allocate(p.capacity)); + capacity = p.capacity; + } + device_memory::copy_device_to_device(smart_ptr.get(), p.get(), capacity); + return *this; + } + + /// Move assignment + DeviceAllocation & operator=(DeviceAllocation && p) { + std::swap(smart_ptr, p.smart_ptr); + std::swap(capacity, p.capacity); + return *this; + } + + /// Copies the entire allocation from another location in device memory. + void copy_from_device(T const *ptr) const { + copy_from_device(ptr, capacity); + } + + /// Copies a given number of elements from device memory + void copy_from_device(T const *ptr, size_t elements) const { + device_memory::copy_device_to_device(get(), ptr, elements); + } + + void copy_to_device(T *ptr) const { + copy_to_device(ptr, capacity); + } + + void copy_to_device(T *ptr, size_t elements) const { + device_memory::copy_device_to_device(ptr, get(), elements); + } + + void copy_from_host(T const *ptr) const { + copy_from_host(ptr, capacity); + } + + void copy_from_host(T const *ptr, size_t elements) const { + device_memory::copy_to_device(get(), ptr, elements); + } + + void copy_to_host(T *ptr) const { + copy_to_host(ptr, capacity); + } + + void copy_to_host(T *ptr, size_t elements) const { + device_memory::copy_to_host(ptr, get(), elements); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace device_memory { + +/// Device allocation abstraction that tracks size and capacity +template +using allocation = cutlass::DeviceAllocation; + +} // namespace device_memory + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_nchw_to_nhwc.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_nchw_to_nhwc.h new file mode 100644 index 0000000000000000000000000000000000000000..8e38029951d27c0be8da059b59d2a83fe2762ef1 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_nchw_to_nhwc.h @@ -0,0 +1,141 @@ +/****************************************************************************** + * Copyright (c) 
2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#pragma once + +/** + * \file + * \brief cuda kernels to transform a device memory tensor from NCHW layout to NHWC layout. + */ + +#include "cutlass/cutlass.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_coord.h" +#include "cutlass/tensor_ref.h" + +namespace cutlass { + +/** \brief interface to transform a device memory tensor from NCHW layout to NHWC layout. 
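+ *
+ * Usage sketch (illustrative; the TensorRefs, coordinate values, and stream are assumed to be
+ * prepared by the caller and must satisfy the size assert inside nchw_to_nhwc below):
+ *
+ *   cutlass::nchw_to_nhwc<cutlass::half_t>(nchw_size, nhwc_size, ref_nchw, ref_nhwc, stream);
+ *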
+ * \tparam T: data type + */ +template +void nchw_to_nhwc(cutlass::Tensor4DCoord input_tensor_size, + cutlass::Tensor4DCoord output_tensor_size, + TensorRef ref_input, + TensorRef ref_output, + cudaStream_t stream); + +template +__global__ void nchw_to_nhwc_kernel(T *output, + const T *input, + const int n, + const int h, + const int w, + const int c) { + const int hw = h*w; + const int chw = c*hw; + __shared__ T shbuf[32 * (32 + 1)]; + const int32_t tid = threadIdx.y*blockDim.x + threadIdx.x; + const int32_t wid = tid / 32; + const int32_t lid = tid % 32; + const int32_t ni = blockIdx.z; + const int32_t ci0 = blockIdx.y * 32; + const int32_t hwi0 = blockIdx.x * 32; + + const size_t input_idx = ni * chw + (ci0 + wid) * hw + hwi0; + const T *A = input + input_idx; + if (hwi0 + lid < hw) { + const int lid_x_33 = lid * 33; + if ((ci0 + 32) <= c) { + int ci = wid; // between 0 and 7 + CUTLASS_PRAGMA_UNROLL + for (int cLoopIdx = 0; cLoopIdx < 4; cLoopIdx++) { + shbuf[lid_x_33 + ci] = A[lid]; + A = &A[8 * hw]; + ci += 8; + } + } else { + for (int ci = wid; ci < 32; ci += 8) { + if ((ci + ci0) < c) { + shbuf[lid_x_33 + ci] = A[lid]; + } + A = &A[8 * hw]; + } + } + } + __syncthreads(); + + const int32_t ciOut = ci0 + lid; + output = &output[ni * chw + ciOut]; + if (ciOut < c) { + if (hwi0 + 32 < hw) { + int hwI = wid; + CUTLASS_PRAGMA_UNROLL + for (int hwLoopIdx = 0; hwLoopIdx < 4; ++hwLoopIdx) { + output[(hwi0 + hwI) * c] = shbuf[(hwI)*33 + lid]; + hwI += 8; + } + } else { + for (int hwI = wid; hwI < 32; hwI += 8) { + if (hwi0 + hwI < hw) { + output[(hwi0 + hwI) * c] = shbuf[(hwI)*33 + lid]; + } + } + } + } +} + +template +void nchw_to_nhwc(cutlass::Tensor4DCoord input_tensor_size, + cutlass::Tensor4DCoord output_tensor_size, + TensorRef ref_input, + TensorRef ref_output, + cudaStream_t stream) { + + assert( + input_tensor_size.n() == output_tensor_size.n() && + input_tensor_size.c() == output_tensor_size.h() && + input_tensor_size.h() == output_tensor_size.w() && + input_tensor_size.w() == output_tensor_size.c()); + + int n = output_tensor_size.n(); + int h = output_tensor_size.h(); + int w = output_tensor_size.w(); + int c = output_tensor_size.c(); + + dim3 grid((h*w + 31)/32, (c + 31)/32, n); + dim3 block(32, 8); + nchw_to_nhwc_kernel<<>>(ref_output.data(), ref_input.data(), + n, h, w, c); +} + +} //namespace cutlass diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_nhwc_padding.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_nhwc_padding.h new file mode 100644 index 0000000000000000000000000000000000000000..f58da62a35350b4a865f4521ec1cbb76ae87e874 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_nhwc_padding.h @@ -0,0 +1,276 @@ +/****************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#pragma once + +/** + * \file + * \brief cuda kernels for padding in device memory with NHWC layout. + */ + +#include "cutlass/cutlass.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_coord.h" +#include "cutlass/tensor_ref.h" + +namespace cutlass { + +/** \brief interface for padding in a device memory tensor with NHWC layout + * \tparam T: data type + */ +template +void nhwc_padding(cutlass::Tensor4DCoord input_tensor_size, + cutlass::Tensor4DCoord output_tensor_size, + TensorRef ref_input, + TensorRef ref_output, + cudaStream_t stream); + + +template +__global__ void nhwc_padding_kernel(const int32_t n, + const int32_t h, + const int32_t w, + const int32_t c_in, + const int32_t c_out, + const T zero, + const T *input, + T *output){ + + const int32_t idx_jump = blockDim.x * gridDim.x; + const int32_t total_elements = n * h * w * c_out; + + int32_t c_idx, w_idx, h_idx, n_idx, resudial; + + T value; + for (int32_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < total_elements; idx += idx_jump) { + + c_idx = idx%c_out; + if (c_idx >= c_in){ + value = zero; + } + else{ + resudial = idx/c_out; + w_idx = resudial%w; + resudial = resudial/w; + h_idx = resudial%h; + n_idx = resudial/h; + resudial = ((n_idx * h + h_idx) * w + w_idx) * c_in + c_idx; + value = input[resudial]; + } + output[idx] = value; + } +} + + +// fast kernel for c_in = 3 & c_out = 4 +template +__global__ void nhwc_padding_channel_3To4_kernel(const int32_t n, + const int32_t h, + const int32_t w, + const Tio *input, + Tio *output, + const int32_t max_output_element, + const int32_t max_input_element, + const Tio zero_io, + const Telement zero_element){ + __shared__ Tio shm[192]; + const int tidx = blockIdx.x * 192 + threadIdx.x; + const int threadidx = threadIdx.x; + + shm[threadIdx.x] = tidx >= max_input_element ? zero_io : input[tidx]; + __syncthreads(); + + const int output_offset = blockIdx.x * 256; + const int lower_bound = max_output_element < output_offset + 256 ? 
max_output_element : output_offset + 256; + for (int i = output_offset + threadidx, j = threadidx ; i < lower_bound ; i+=192, j+=192) + { + const Telement* shm_element = (const Telement*)shm + j*3*element_in_Tio/4; + Telement array[element_in_Tio]; + CUTLASS_PRAGMA_UNROLL + for (int k = 0 ; k < element_in_Tio ; k++) + array[k] = ((k+1)%4 == 0) ? zero_element : shm_element[(k > 3) ? (k - 1) : k]; + output[i] = *((const Tio *)array); + } +} + +// fast kernel for c_in = 3 & c_out = 8 +template +__global__ void nhwc_padding_channel_3To8_kernel(const int32_t n, + const int32_t h, + const int32_t w, + const Tio *input, + Tio *output, + const int32_t max_output_element, + const int32_t max_input_element, + const Tio zero_io, + const Telement zero_element){ + __shared__ Tio shm[192]; + const int tidx = blockIdx.x * 192 + threadIdx.x; + const int threadidx = threadIdx.x; + + shm[threadIdx.x] = tidx >= max_input_element ? zero_io : input[tidx]; + __syncthreads(); + + const int output_offset = blockIdx.x * 512; + const int lower_bound = max_output_element < output_offset + 512 ? max_output_element : output_offset + 512; + for (int i = output_offset + threadidx, j = threadidx ; i < lower_bound ; i+=192, j+=192) + { + const Telement* shm_element = (const Telement*)shm + (element_in_Tio == 4 ? j/2 : j)*3; + Telement array[element_in_Tio]; + //float + if (element_in_Tio == 4){ + CUTLASS_PRAGMA_UNROLL + for (int k = 0 ; k < element_in_Tio ; k++) + array[k] = ((j % 2) == 1) ? zero_element : ((k >= 3) ? zero_element : shm_element[k]); + } + //half + else{ + CUTLASS_PRAGMA_UNROLL + for (int k = 0 ; k < element_in_Tio ; k++) + array[k] = (k >= 3) ? zero_element : shm_element[k]; + } + output[i] = *((const Tio *)array); + } +} + +template +void nhwc_padding(cutlass::Tensor4DCoord input_tensor_size, + cutlass::Tensor4DCoord output_tensor_size, + TensorRef ref_input, + TensorRef ref_output, + cudaStream_t stream){ + assert( + input_tensor_size.n() == output_tensor_size.n() && + input_tensor_size.h() == output_tensor_size.h() && + input_tensor_size.w() == output_tensor_size.w() && + input_tensor_size.c() <= output_tensor_size.c()); + + int n = input_tensor_size.n(); + int h = input_tensor_size.h(); + int w = input_tensor_size.w(); + int c_in = input_tensor_size.c(); + int c_out = output_tensor_size.c(); + + //case 1 : channel == 3 padding to 4 or 8 + if ((c_out == 4 || c_out == 8) && c_in == 3 && (n*h*w % 8 == 0)){ + dim3 block(192); + const int nhw = n*h*w; + const int nhwc = nhw*c_in; + //for half_t + if (cutlass::sizeof_bits::value == 16){ + const int element_in_Tio = 8; + const int max_input_element = nhwc/element_in_Tio; + const int max_output_element = nhw*c_out/element_in_Tio; + const int4 zero_io = {0, 0, 0, 0}; + const half_t zero_element = static_cast(0.0f); + dim3 grid((nhwc + 192*element_in_Tio - 1)/(192*element_in_Tio)); + if (c_out == 4){ + nhwc_padding_channel_3To4_kernel<<>> + (n, h, w, + (const int4 *)ref_input.data(), + (int4 *)ref_output.data(), + max_output_element, + max_input_element, + zero_io, + zero_element); + } + else if (c_out == 8){ + nhwc_padding_channel_3To8_kernel<<>> + (n, h, w, + (const int4 *)ref_input.data(), + (int4 *)ref_output.data(), + max_output_element, + max_input_element, + zero_io, + zero_element); + } + } + //for float + else{ + const int element_in_Tio = 4; + const int max_input_element = nhwc/element_in_Tio; + const int max_output_element = nhw*c_out/element_in_Tio; + const float4 zero_io = {0.0f, 0.0f, 0.0f, 0.0f}; + const float zero_element = 0.0f; + dim3 
grid((nhwc + 192*element_in_Tio - 1)/(192*element_in_Tio)); + if (c_out == 4){ + nhwc_padding_channel_3To4_kernel<<>> + (n, h, w, + (const float4 *)ref_input.data(), + (float4 *)ref_output.data(), + max_output_element, + max_input_element, + zero_io, + zero_element); + } + else if (c_out == 8){ + nhwc_padding_channel_3To8_kernel<<>> + (n, h, w, + (const float4 *)ref_input.data(), + (float4 *)ref_output.data(), + max_output_element, + max_input_element, + zero_io, + zero_element); + } + } + } + //case 2 : even channel + else if ((c_out % 2) == 0 && (c_in % 2) == 0){ + int32_t total_elements = n * h * w * c_out / 2; + int block_size = 256; + dim3 grid((total_elements + 255)/256); + dim3 block(block_size); + //for half_t + if (cutlass::sizeof_bits::value == 16){ + const __half2 zero = {0.0f, 0.0f}; + nhwc_padding_kernel<<>>(n, h, w, c_in/2, c_out/2, zero, (const __half2*)ref_input.data(), (__half2*)ref_output.data()); + } + //for float + else{ + const float2 zero = {0.0f, 0.0f}; + nhwc_padding_kernel<<>>(n, h, w, c_in/2, c_out/2, zero, (const float2*)ref_input.data(), (float2*)ref_output.data()); + } + } + //case 3 : odd channel + else{ + int32_t total_elements = n * h * w * c_out; + int block_size = 256; + dim3 grid((total_elements + 255)/256); + dim3 block(block_size); + const T zero = static_cast(0.0f); + nhwc_padding_kernel<<>>(n, h, w, c_in, c_out, zero, ref_input.data(), ref_output.data()); + } +} + + +} //namespace cutlass diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_nhwc_pooling.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_nhwc_pooling.h new file mode 100644 index 0000000000000000000000000000000000000000..5633456c1412ff41366ec4c6ec5c3e6e3a2d6c19 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_nhwc_pooling.h @@ -0,0 +1,573 @@ +/****************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
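// Minimal usage sketch for the nhwc_padding() entry point defined above: padding a
// half_t NHWC tensor from 3 to 4 channels (the fast 3->4 path, taken when N*H*W is a
// multiple of 8). The layout::TensorNHWC template argument of TensorRef and the
// ::packed() stride helper are assumptions made for this sketch; the header's own
// template arguments are elided in the text above.

#include "cutlass/util/device_nhwc_padding.h"

void pad_3_to_4_channels(cutlass::half_t const *d_in, cutlass::half_t *d_out,
                         int N, int H, int W, cudaStream_t stream) {
  cutlass::Tensor4DCoord in_size(N, H, W, 3);
  cutlass::Tensor4DCoord out_size(N, H, W, 4);   // extra channel is zero-filled

  cutlass::TensorRef<cutlass::half_t, cutlass::layout::TensorNHWC> ref_in(
      const_cast<cutlass::half_t *>(d_in), cutlass::layout::TensorNHWC::packed(in_size));
  cutlass::TensorRef<cutlass::half_t, cutlass::layout::TensorNHWC> ref_out(
      d_out, cutlass::layout::TensorNHWC::packed(out_size));

  cutlass::nhwc_padding(in_size, out_size, ref_in, ref_out, stream);
}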
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#pragma once + +/** + * \file + * \brief cuda kernels to do avg/max pooling on a device memory tensor with NHWC layout. + */ + +#include "cutlass/cutlass.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_coord.h" +#include "cutlass/tensor_ref.h" +#include "device_utils.h" +#include + +namespace cutlass { + +/** \brief interface to do avg/max pooling on a device memory tensor with NHWC layout. + * \tparam T: data type + */ +template +void pooling_nhwc(cutlass::Tensor4DCoord input_tensor_size, + cutlass::Tensor4DCoord filter_tensor_size, + cutlass::Tensor4DCoord output_tensor_size, + cutlass::MatrixCoord padding, + cutlass::MatrixCoord stride, + TensorRef ref_input, + TensorRef ref_output, + int poolingType, //0 for avg pooling ; 1 for max pooling + cudaStream_t stream); + +/** get the output size of pooling + */ +inline int getOutputSize(int H_W, int padding, int kernel_size, int stride) +{ + return (H_W + 2 * padding - kernel_size) / stride + 1; +} + +/** + * input is [N, H, W, C] + * assume stride == kernel_size + * output_h = (H + 2*padding_H - kernel_H)/stride_H + * output_w = (W + 2*padding_W - kernel_W)/stride_W + * output is [N, output_h, output_w, C] + * grid(N, output_h, output_w) + * block(min(C, 256)) : + * each block deals with C elements of output when each thread deals with ((C + 255)/256 element of output) +*/ +template +__global__ void pooling_nhwc_element1_kernel(T* output, + const T* input, + const int N, + const int H, + const int W, + const int C, + const int output_H, + const int output_W, + const int kernel_H, + const int kernel_W, + const int stride_H, + const int stride_W, + const int padding_H, + const int padding_W) +{ + const int tid = threadIdx.x; + const int n_idx = blockIdx.x; + const int output_h_idx = blockIdx.y; + const int output_w_idx = blockIdx.z; + + int h_start_idx = output_h_idx * stride_H - padding_H; + int h_end_idx = h_start_idx + kernel_H; + h_start_idx = (h_start_idx < 0) ? 0 : h_start_idx; + h_end_idx = h_end_idx > H ? H : h_end_idx; + + int w_start_idx = output_w_idx * stride_W - padding_W; + int w_end_idx = w_start_idx + kernel_W; + w_start_idx = (w_start_idx < 0) ? 0 : w_start_idx; + w_end_idx = w_end_idx > W ? W : w_end_idx; + + input += n_idx * H * W * C; + output += ((n_idx * output_H + output_h_idx) * output_W + output_w_idx) * C; + const int kernel_size2 = kernel_H * kernel_W; + for (int c_idx = tid; c_idx < C; c_idx += blockDim.x) { + float pooling; + if (IS_AVG_POOLING){ + pooling = 0.0f; + } + else{ + pooling = -FLT_MAX; + } + for (int h = h_start_idx; h < h_end_idx; h++) { + for (int w = w_start_idx; w < w_end_idx; w++) { + const int idx = (h * W + w) * C; + const float tmp = static_cast(input[idx + c_idx]); + if (IS_AVG_POOLING){ + pooling = pooling + tmp; + } + else{ + pooling = pooling > tmp ? 
pooling : tmp; + } + } + } + + T output_val; + if (IS_AVG_POOLING){ + output_val = T(pooling/kernel_size2); + } + else{ + output_val = T(pooling); + } + output[c_idx] = output_val; + } +} + +template +__global__ void pooling_nhwc_element2_kernel(T2* output, + const T2* input, + const int N, + const int H, + const int W, + const int C, + const int output_H, + const int output_W, + const int kernel_H, + const int kernel_W, + const int stride_H, + const int stride_W, + const int padding_H, + const int padding_W) +{ + const int tid = threadIdx.x; + const int n_idx = blockIdx.x; + const int output_h_idx = blockIdx.y; + const int output_w_idx = blockIdx.z; + + int h_start_idx = output_h_idx * stride_H - padding_H; + int h_end_idx = h_start_idx + kernel_H; + h_start_idx = (h_start_idx < 0) ? 0 : h_start_idx; + h_end_idx = h_end_idx > H ? H : h_end_idx; + + int w_start_idx = output_w_idx * stride_W - padding_W; + int w_end_idx = w_start_idx + kernel_W; + w_start_idx = (w_start_idx < 0) ? 0 : w_start_idx; + w_end_idx = w_end_idx > W ? W : w_end_idx; + + input += n_idx * H * W * C; + output += ((n_idx * output_H + output_h_idx) * output_W + output_w_idx) * C; + const int kernel_size2 = kernel_H * kernel_W; + for (int c_idx = tid; c_idx < C; c_idx += blockDim.x) { + float2 pooling; + if (IS_AVG_POOLING) { + pooling = {0.0f, 0.0f}; + } + else { + pooling = {-FLT_MAX, -FLT_MAX}; + } + for (int h = h_start_idx; h < h_end_idx; h++) { + for (int w = w_start_idx; w < w_end_idx; w++) { + const int idx = (h * W + w) * C; + const T2 tmp = input[idx + c_idx]; + const float2 tmp_flt2 = {static_cast(tmp.x), static_cast(tmp.y)}; + if (IS_AVG_POOLING) { + pooling.x += tmp_flt2.x; + pooling.y += tmp_flt2.y; + } + else { + pooling.x = pooling.x > tmp_flt2.x ? pooling.x : tmp_flt2.x; + pooling.y = pooling.y > tmp_flt2.y ? pooling.y : tmp_flt2.y; + } + } + } + + T2 output_val; + if (IS_AVG_POOLING) { + output_val.x = T(pooling.x/kernel_size2); + output_val.y = T(pooling.y/kernel_size2); + } + else { + output_val.x = T(pooling.x); + output_val.y = T(pooling.y); + } + output[c_idx] = output_val; + } +} + +/** + * output [N, 1, 1, C] + * input [N, H, W, C] + * grid(C, N) + * block(block_size) -- each block deals with H*W/block_size elements; +*/ +template +__global__ void pooling_nxhTo1x1_element1_kernel( + T* output, const T* input, const int N, const int HW, const int C) +{ + const int c_idx = blockIdx.x; + const int n_idx = blockIdx.y; + float pooling[1]; + if (IS_AVG_POOLING) { + pooling[0] = 0.0f; + } + else { + pooling[0] = -FLT_MAX; + } + const size_t input_offset = n_idx * HW * C + c_idx; + input += input_offset; + const size_t output_offset = n_idx * C + c_idx; + output += output_offset; + int tid = threadIdx.x; + + for (int index = tid; index < HW; index += blockDim.x) { + float val = static_cast(input[index * C]); + if (IS_AVG_POOLING) { + pooling[0] += val; + } + else { + pooling[0] = pooling[0] > val ? 
pooling[0] : val; + } + } + if (blockDim.x <= 32) { + if (IS_AVG_POOLING) { + warpReduceSum(pooling); + } + else { + warpReduceMax(pooling); + } + } + else { + if (IS_AVG_POOLING) { + blockReduceSum(pooling); + } + else { + blockReduceMax(pooling); + } + } + __syncthreads(); + if (threadIdx.x == 0) { + T output_val; + if (IS_AVG_POOLING) { + output_val = T(pooling[0] / HW); + } + else { + output_val = T(pooling[0]); + } + output[0] = output_val; + } +} + + +/** + * output [N, 1, 1, C] + * input [N, H, W, C] + * grid(C/2, N) + * block(block_size) -- each thread deals with H*W/block_size * 2 elements; +*/ +template +__global__ void pooling_nxhTo1x1_element2_kernel( + T2* output, const T2* input, const int N, const int HW, const int C) +{ + const int c_idx = blockIdx.x; + const int n_idx = blockIdx.y; + float pooling[2]; + if (IS_AVG_POOLING) { + pooling[0] = pooling[1] = 0.0f; + } + else { + pooling[0] = pooling[1] = -FLT_MAX; + } + const int C_2 = C / 2; + const size_t input_offset = n_idx * HW * C_2 + c_idx; + input += input_offset; + const size_t output_offset = n_idx * C_2 + c_idx; + output += output_offset; + int tid = threadIdx.x; + + for (int index = tid; index < HW; index += blockDim.x) { + T2 val = input[index * C_2]; + float2 val_flt2 = {static_cast(val.x), static_cast(val.y)}; + if (IS_AVG_POOLING) { + pooling[0] += val_flt2.x; + pooling[1] += val_flt2.y; + } + else { + pooling[0] = pooling[0] > val_flt2.x ? pooling[0] : val_flt2.x; + pooling[1] = pooling[1] > val_flt2.y ? pooling[1] : val_flt2.y; + } + } + if (blockDim.x <= 32) { + if (IS_AVG_POOLING) { + warpReduceSum(pooling); + } + else { + warpReduceMax(pooling); + } + } + else { + if (IS_AVG_POOLING) { + blockReduceSum(pooling); + } + else { + blockReduceMax(pooling); + } + } + __syncthreads(); + if (threadIdx.x == 0) { + T2 output_val; + if (IS_AVG_POOLING) { + output_val.x = T(pooling[0] / HW); + output_val.y = T(pooling[1] / HW); + } + else { + output_val.x = T(pooling[0]); + output_val.y = T(pooling[1]); + } + output[0] = output_val; + } +} + +template +void pooling_nhwc(cutlass::Tensor4DCoord input_tensor_size, + cutlass::Tensor4DCoord filter_tensor_size, + cutlass::Tensor4DCoord output_tensor_size, + cutlass::Tensor4DCoord padding, + cutlass::MatrixCoord stride, + TensorRef ref_input, + TensorRef ref_output, + int poolingType, //0 for avg pooling ; 1 for max pooling + cudaStream_t stream) { + + assert(input_tensor_size.n() == output_tensor_size.n() && + input_tensor_size.c() == output_tensor_size.c()); + + const int N = input_tensor_size.n(); + const int H = input_tensor_size.h(); + const int W = input_tensor_size.w(); + const int C = input_tensor_size.c(); + const int padding_H = padding.h(); + const int padding_W = padding.w(); + const int kernel_H = filter_tensor_size.h(); + const int kernel_W = filter_tensor_size.w(); + const int stride_H = stride.row(); + const int stride_W = stride.column(); + + const int output_H = getOutputSize(H, padding_H, kernel_H, stride_H); + const int output_W = getOutputSize(W, padding_W, kernel_W, stride_W); + + assert(output_tensor_size.h() == output_H && + output_tensor_size.w() == output_W); + + if (C % 2 != 0) { + if ((H == kernel_H && padding_H == 0) && (W == kernel_W && padding_W == 0)) { + dim3 grid(C, N); + dim3 block(256); + if (H*W < block.x){ + block.x = (H*W + 31)/32*32; + } + if (poolingType == 0) { + pooling_nxhTo1x1_element1_kernel<<>>( + ref_output.data(), + ref_input.data(), + N, + H*W, + C); + } // if (poolingType == 0) + else { + pooling_nxhTo1x1_element1_kernel<<>>( 
+ ref_output.data(), + ref_input.data(), + N, + H*W, + C); + } + } // if ((H == kernel_H && padding_H == 0) && (W == kernel_W && padding_W == 0)) + else { + dim3 grid(N, output_H, output_W); + dim3 block(256); + if (C < block.x) { + block.x = C; + } + if (poolingType == 0) { + pooling_nhwc_element1_kernel<<>>( + ref_output.data(), + ref_input.data(), + N, + H, + W, + C, + output_H, + output_W, + kernel_H, + kernel_W, + stride_H, + stride_W, + padding_H, + padding_W); + } // if (poolingType == 0) + else { + pooling_nhwc_element1_kernel<<>>( + ref_output.data(), + ref_input.data(), + N, + H, + W, + C, + output_H, + output_W, + kernel_H, + kernel_W, + stride_H, + stride_W, + padding_H, + padding_W); + } + } + } // if (C % 2 != 0)) + else { + if ((H == kernel_H && padding_H == 0) && (W == kernel_W && padding_W == 0)) { + dim3 grid(C/2, N); + dim3 block(256); + if (H*W < block.x){ + block.x = (H*W + 31)/32*32; + } + if (poolingType == 0) { + if (std::is_same::value) { + pooling_nxhTo1x1_element2_kernel<<>>( + (float2*)(ref_output.data()), + (const float2*)(ref_input.data()), + N, + H*W, + C); + } // if (std::is_same::value) + else { + pooling_nxhTo1x1_element2_kernel<<>>( + (half2*)(ref_output.data()), + (const half2*)(ref_input.data()), + N, + H*W, + C); + } + } // if (poolingType == 0) + else { + if (std::is_same::value) { + pooling_nxhTo1x1_element2_kernel<<>>( + (float2*)(ref_output.data()), + (const float2*)(ref_input.data()), + N, + H*W, + C); + } // if (std::is_same::value) + else { + pooling_nxhTo1x1_element2_kernel<<>>( + (half2*)(ref_output.data()), + (const half2*)(ref_input.data()), + N, + H*W, + C); + } + } + } // if ((H == kernel_H && padding_H == 0) && (W == kernel_W && padding_W == 0)) + else { + dim3 grid(N, output_H, output_W); + dim3 block(256); + if (C/2 < block.x) { + block.x = C/2; + } + if (poolingType == 0) { + if (std::is_same::value) { + pooling_nhwc_element2_kernel<<>>( + (float2*)(ref_output.data()), + (const float2*)(ref_input.data()), + N, + H, + W, + C/2, + output_H, + output_W, + kernel_H, + kernel_W, + stride_H, + stride_W, + padding_H, + padding_W); + } // if (std::is_same::value) + else { + pooling_nhwc_element2_kernel<<>>( + (half2*)(ref_output.data()), + (const half2*)(ref_input.data()), + N, + H, + W, + C/2, + output_H, + output_W, + kernel_H, + kernel_W, + stride_H, + stride_W, + padding_H, + padding_W); + } + } // if (poolingType == 0) + else { + if (std::is_same::value) { + pooling_nhwc_element2_kernel<<>>( + (float2*)(ref_output.data()), + (const float2*)(ref_input.data()), + N, + H, + W, + C/2, + output_H, + output_W, + kernel_H, + kernel_W, + stride_H, + stride_W, + padding_H, + padding_W); + } // if (std::is_same::value) + else { + pooling_nhwc_element2_kernel<<>>( + (half2*)(ref_output.data()), + (const half2*)(ref_input.data()), + N, + H, + W, + C/2, + output_H, + output_W, + kernel_H, + kernel_W, + stride_H, + stride_W, + padding_H, + padding_W); + } + } + } + } +} + +} //namespace cutlass diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_nhwc_to_nchw.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_nhwc_to_nchw.h new file mode 100644 index 0000000000000000000000000000000000000000..babfecd39205ebff39794133868e4a95b7e9525c --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_nhwc_to_nchw.h @@ -0,0 +1,144 @@ 
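// Minimal usage sketch for the pooling_nhwc() entry point from device_nhwc_pooling.h
// above: 2x2 max pooling with stride 2 and no padding on an NHWC float tensor. The
// Tensor4DCoord padding parameter follows the definition above, and the
// layout::TensorNHWC TensorRef arguments are assumptions made for this sketch.

#include "cutlass/util/device_nhwc_pooling.h"

void max_pool_2x2(float const *d_in, float *d_out,
                  int N, int H, int W, int C, cudaStream_t stream) {
  // output_h = (H + 2*padding - kernel)/stride + 1, per getOutputSize() above
  int out_h = cutlass::getOutputSize(H, /*padding=*/0, /*kernel_size=*/2, /*stride=*/2);
  int out_w = cutlass::getOutputSize(W, /*padding=*/0, /*kernel_size=*/2, /*stride=*/2);

  cutlass::Tensor4DCoord in_size(N, H, W, C);
  cutlass::Tensor4DCoord filter_size(1, 2, 2, 1);   // only h()/w() are read
  cutlass::Tensor4DCoord out_size(N, out_h, out_w, C);
  cutlass::Tensor4DCoord padding(0, 0, 0, 0);       // only h()/w() are read
  cutlass::MatrixCoord stride(2, 2);

  cutlass::TensorRef<float, cutlass::layout::TensorNHWC> ref_in(
      const_cast<float *>(d_in), cutlass::layout::TensorNHWC::packed(in_size));
  cutlass::TensorRef<float, cutlass::layout::TensorNHWC> ref_out(
      d_out, cutlass::layout::TensorNHWC::packed(out_size));

  cutlass::pooling_nhwc(in_size, filter_size, out_size, padding, stride,
                        ref_in, ref_out, /*poolingType=*/1 /*max*/, stream);
}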
+/****************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#pragma once + +/** + * \file + * \brief cuda kernels to transform a device memory tensor from NHWC layout to NCHW layout. + */ + +#include "cutlass/cutlass.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_coord.h" +#include "cutlass/tensor_ref.h" + +namespace cutlass { + +/** \brief interface to transform a device memory tensor from NHWC layout to NCHW layout. 
+ * \tparam T: data type + */ +template +void nhwc_to_nchw(cutlass::Tensor4DCoord input_tensor_size, + cutlass::Tensor4DCoord output_tensor_size, + TensorRef ref_input, + TensorRef ref_output, + cudaStream_t stream); + + +template +__global__ void nhwc_to_nchw_kernel(T *output, + const T *input, + const int n, + const int h, + const int w, + const int c) { + + const int hw = h*w; + const int hwc = hw*c; + __shared__ T shbuf[32 * (32 + 1)]; + const int32_t tid = threadIdx.y*blockDim.x + threadIdx.x; + const int32_t wid = tid / 32; + const int32_t lid = tid % 32; + const int32_t ni = blockIdx.z; + const int32_t hwi0 = blockIdx.y * 32; + const int32_t ci0 = blockIdx.x * 32; + + const size_t input_idx = ni * hwc + (hwi0 + wid) * c + ci0; + const T *A = input + input_idx; + if (ci0 + lid < c) { + const int lid_x_33 = lid * 33; + if ((hwi0 + 32) <= hw) { + int hwi = wid; // between 0 and 7 + CUTLASS_PRAGMA_UNROLL + for (int cLoopIdx = 0; cLoopIdx < 4; cLoopIdx++) { + shbuf[lid_x_33 + hwi] = A[lid]; + A = &A[8 * c]; + hwi += 8; + } + } else { + for (int hwi = wid; hwi < 32; hwi += 8) { + if ((hwi + hwi0) < hw) { + shbuf[lid_x_33 + hwi] = A[lid]; + } + A = &A[8 * c]; + } + } + } + __syncthreads(); + + const int32_t hwiOut = hwi0 + lid; + output = &output[ni * hwc + hwiOut]; + if (hwiOut < hw) { + if (ci0 + 32 < c) { + int cI = wid; + CUTLASS_PRAGMA_UNROLL + for (int hwLoopIdx = 0; hwLoopIdx < 4; ++hwLoopIdx) { + output[(ci0 + cI) * hw] = shbuf[(cI)*33 + lid]; + cI += 8; + } + } else { + for (int cI = wid; cI < 32; cI += 8) { + if (ci0 + cI < c) { + output[(ci0 + cI) * hw] = shbuf[(cI)*33 + lid]; + } + } + } + } +} + +template +void nhwc_to_nchw(cutlass::Tensor4DCoord input_tensor_size, + cutlass::Tensor4DCoord output_tensor_size, + TensorRef ref_input, + TensorRef ref_output, + cudaStream_t stream) { + + assert( + input_tensor_size.n() == output_tensor_size.n() && + input_tensor_size.h() == output_tensor_size.c() && + input_tensor_size.w() == output_tensor_size.h() && + input_tensor_size.c() == output_tensor_size.w()); + + int n = input_tensor_size.n(); + int h = input_tensor_size.h(); + int w = input_tensor_size.w(); + int c = input_tensor_size.c(); + + dim3 grid((c + 31)/32, (h*w + 31)/32, n); + dim3 block(32, 8); + nhwc_to_nchw_kernel<<>>(ref_output.data(), ref_input.data(), + n, h, w, c); + +} + +} //namespace cutlass diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_rmsnorm.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_rmsnorm.h new file mode 100644 index 0000000000000000000000000000000000000000..0d1b1af56e4463640edc3e9c82533baf815c9b27 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_rmsnorm.h @@ -0,0 +1,186 @@ +/****************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_coord.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/util/device_utils.h" +#include + +namespace cutlass { + +__global__ void rmsnorm_twoPassAlgo_e8(float4 *output, const float4 *input, + const float4 *weight, + const int m, const int n, float epsilon) { + const int m_idx = blockIdx.x; + const int tid = threadIdx.x; + const int bdimx = blockDim.x; + __shared__ float s_mean; + float local_sums[1] = {0.0f}; + const int n_8 = n / 8; + int offset = m_idx * n_8; + input += offset; + output += offset; + + for (int index = tid; index < n_8; index += bdimx) { + const float4 local_val = input[index]; + const half2 *h1 = (half2 *)&local_val.x; + const half2 *h2 = (half2 *)&local_val.y; + const half2 *h3 = (half2 *)&local_val.z; + const half2 *h4 = (half2 *)&local_val.w; + local_sums[0] += static_cast(h1->x) * static_cast(h1->x) + + static_cast(h1->y) * static_cast(h1->y) + + static_cast(h2->x) * static_cast(h2->x) + + static_cast(h2->y) * static_cast(h2->y) + + static_cast(h3->x) * static_cast(h3->x) + + static_cast(h3->y) * static_cast(h3->y) + + static_cast(h4->x) * static_cast(h4->x) + + static_cast(h4->y) * static_cast(h4->y); + } + + if (blockDim.x <= 32) { + warpReduceSum(local_sums); + } else { + blockReduceSum(local_sums); + } + if (threadIdx.x == 0) { + s_mean = rsqrtf(local_sums[0] / n + epsilon); + } + __syncthreads(); + + for (int index = tid; index < n_8; index += bdimx) { + const float4 local_val = input[index]; + const float4 weight_val = weight[index]; + + const half2 *l1 = (half2 *)&local_val.x; + const half2 *l2 = (half2 *)&local_val.y; + const half2 *l3 = (half2 *)&local_val.z; + const half2 *l4 = (half2 *)&local_val.w; + + const half2 *g1 = (half2 *)&weight_val.x; + const half2 *g2 = (half2 *)&weight_val.y; + const half2 *g3 = (half2 *)&weight_val.z; + const half2 *g4 = (half2 *)&weight_val.w; + + float4 tmp; + half2 *h1 = (half2 *)&tmp.x; + half2 *h2 = (half2 *)&tmp.y; + half2 *h3 = (half2 *)&tmp.z; + half2 *h4 = (half2 *)&tmp.w; + + h1->x = half(static_cast(l1->x) * s_mean * static_cast(g1->x)); + h1->y = half(static_cast(l1->y) * s_mean * 
static_cast(g1->y)); + h2->x = half(static_cast(l2->x) * s_mean * static_cast(g2->x)); + h2->y = half(static_cast(l2->y) * s_mean * static_cast(g2->y)); + h3->x = half(static_cast(l3->x) * s_mean * static_cast(g3->x)); + h3->y = half(static_cast(l3->y) * s_mean * static_cast(g3->y)); + h4->x = half(static_cast(l4->x) * s_mean * static_cast(g4->x)); + h4->y = half(static_cast(l4->y) * s_mean * static_cast(g4->y)); + + output[index] = tmp; + } +} + +template +__global__ void rmsnorm_twoPassAlgo_e1(T* output, + const T* input, + const T* weight, + const int m, const int n, + float epsilon) +{ + const int m_idx = blockIdx.x; + const int tid = threadIdx.x; + const int bdimx = blockDim.x; + __shared__ float s_mean; + float local_sums[1] = {0.0f}; + int offset = m_idx * n; + input += offset; + output += offset; + + for (int index = tid ; index < n ; index += bdimx){ + float local_val = static_cast(input[index]); + local_sums[0] += local_val * local_val; + } + if (blockDim.x <= 32) { + warpReduceSum(local_sums); + } + else { + blockReduceSum(local_sums); + } + if (threadIdx.x == 0) { + s_mean = rsqrtf(local_sums[0] / n + epsilon); + } + __syncthreads(); + + for (int index = tid ; index < n ; index += bdimx){ + const T weight_val = weight[index]; + const T local_val = input[index]; + output[index] = T(static_cast(local_val) * s_mean * static_cast(weight_val)); + } +} + +template +void rmsnorm(cutlass::MatrixCoord tensor_size, + TensorRef ref_output, + TensorRef ref_input, + TensorRef ref_weight, + cudaStream_t stream, float epsilon = 1e-5f){ + const int m = tensor_size.row(); + const int n = tensor_size.column(); + T* output = ref_output.data(); + const T* input = ref_input.data(); + const T* weight = ref_weight.data(); + dim3 grid(m); + + if (n % 8 == 0 && std::is_same::value) { + dim3 block(cutlass::platform::min(1024, (n / 8 + 31) / 32 * 32)); + + rmsnorm_twoPassAlgo_e8<<>>( + (float4 *)output, (const float4 *)input, (const float4 *)weight, m, n, epsilon); + } else { + dim3 block(cutlass::platform::min(1024, ((n + 31)/32 + 31)/32*32)); + + rmsnorm_twoPassAlgo_e1<<>>( + output, input, weight, m, n, epsilon); + } + + auto result = cudaGetLastError(); + if (result != cudaSuccess) { + std::cerr << "CUDA error: " << cudaGetErrorString(result) << std::endl; + abort(); + } +} + +} // namespace cutlass diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_utils.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..9747d50975d7d35df287f6b056aedc489adb317c --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/device_utils.h @@ -0,0 +1,127 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
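// Minimal usage sketch for the rmsnorm() entry point from device_rmsnorm.h above,
// applied to an m x n half_t matrix with an n-element weight vector. The
// layout::RowMajor TensorRef arguments are an assumption for this sketch (the header's
// template arguments are elided in the text above); n % 8 == 0 with half_t inputs
// selects the vectorized float4 kernel.

#include "cutlass/util/device_rmsnorm.h"

void run_rmsnorm(cutlass::half_t *d_out, cutlass::half_t const *d_in,
                 cutlass::half_t const *d_weight, int m, int n, cudaStream_t stream) {
  cutlass::layout::RowMajor layout =
      cutlass::layout::RowMajor::packed(cutlass::MatrixCoord(m, n));

  cutlass::TensorRef<cutlass::half_t, cutlass::layout::RowMajor> ref_out(d_out, layout);
  cutlass::TensorRef<cutlass::half_t, cutlass::layout::RowMajor> ref_in(
      const_cast<cutlass::half_t *>(d_in), layout);
  cutlass::TensorRef<cutlass::half_t, cutlass::layout::RowMajor> ref_weight(
      const_cast<cutlass::half_t *>(d_weight), layout);

  cutlass::rmsnorm(cutlass::MatrixCoord(m, n), ref_out, ref_in, ref_weight,
                   stream, /*epsilon=*/1e-5f);
}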
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief utils code for device cutlass code +*/ + +#pragma once + +#include +#include +#define FINAL_MASK 0xffffffff + +struct half4 { + half x, y, z, w; +}; + +template +__inline__ __device__ T warpReduceSum(T* val) +{ +#pragma unroll + for (int i = 0; i < NUM; i++) { +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) + val[i] += __shfl_xor_sync(FINAL_MASK, val[i], mask, 32); + } + return (T)(0.0f); +} + +template +__inline__ __device__ T blockReduceSum(T* val) +{ + __shared__ T shared[NUM][33]; + int lane = threadIdx.x & 0x1f; + int wid = threadIdx.x >> 5; + + warpReduceSum(val); + + if (lane == 0) { +#pragma unroll + for (int i = 0; i < NUM; i++) { + shared[i][wid] = val[i]; + } + } + + __syncthreads(); + + bool is_mask = threadIdx.x < (blockDim.x / 32.f); +#pragma unroll + for (int i = 0; i < NUM; i++) { + val[i] = is_mask ? shared[i][lane] : (T)(0.0f); + } + warpReduceSum(val); + return (T)0.0f; +} + +template +__inline__ __device__ T warpReduceMax(T* val) +{ +#pragma unroll + for (int i = 0; i < NUM; i++) { +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) + val[i] = max(val[i], __shfl_xor_sync(FINAL_MASK, val[i], mask, 32)); + } + return (T)(0.0f); +} + +template +__inline__ __device__ T blockReduceMax(T* val) +{ + static __shared__ T shared[32][NUM]; + int lane = threadIdx.x & 0x1f; // in-warp idx + int wid = threadIdx.x >> 5; // warp idx + + warpReduceMax(val); // get maxx in each warp + + if (lane == 0) // record in-warp maxx by warp Idx + { +#pragma unroll + for (int i = 0; i < NUM; i++) { + shared[wid][i] = val[i]; + } + } + + __syncthreads(); + + // Modify from blockDim.x << 5 to blockDim.x / 32. to prevent + // blockDim.x is not divided by 32 + bool is_mask = threadIdx.x < (blockDim.x / 32.f); +#pragma unroll + for (int i = 0; i < NUM; i++) { + val[i] = is_mask ? 
shared[lane][i] : (T)(-FLT_MAX); + } + warpReduceMax(val); + + return (T)0.0f; +} + diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/distribution.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/distribution.h new file mode 100644 index 0000000000000000000000000000000000000000..6565aba9607ad68defacb6e98d9f9bbc944cd48d --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/distribution.h @@ -0,0 +1,157 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +/*! \file + \brief This header contains a class to parametrize a statistical distribution function. 
+*/ + +#include + +namespace cutlass { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Distribution type +struct Distribution { + /// Variant types + enum Kind { Invalid, Uniform, Gaussian, Identity, Sequential, AllZeros, AllOnes }; + + /// Distribution state + union { + /// Uniform distribution + struct { + double min; + double max; + // Percent elements set to NaN + double pnan; + } uniform; + + /// Gaussian distribution + struct { + double mean; + double stddev; + double pnz; + double pnzA; + double pnzB; + double pnzC; + } gaussian; + + /// Elements are linear combination of row and column index + struct { + double start; + double delta; + } sequential; + }; + + /// Active variant kind + Kind kind; + + /// Random values are cast to integer after scaling by this power of two + int int_scale; + + // + // Methods + // + + Distribution() : kind(Invalid), int_scale(0) {} + +/// Configures distribution as uniform random + Distribution &set_uniform(double _min, double _max, int _int_scale = 0, double _pnan = 0) { + kind = Uniform; + uniform.min = _min; + uniform.max = _max; + int_scale = _int_scale; + uniform.pnan = _pnan; + return *this; + } + + /// Configures distribution as Gaussian distribution + Distribution &set_gaussian(double _mean, double _stddev, int _int_scale = 0, double _pnz = 1.0) { + kind = Gaussian; + gaussian.mean = _mean; + gaussian.stddev = _stddev; + gaussian.pnz = _pnz; + gaussian.pnzA = _pnz; + gaussian.pnzB = _pnz; + gaussian.pnzC = _pnz; + int_scale = _int_scale; + return *this; + } + + /// Sets identity + Distribution &set_identity() { + kind = Identity; + return *this; + } + + /// Sets sequential + Distribution &set_sequential(double start, double delta, int _int_scale = 0) { + kind = Sequential; + sequential.start = start; + sequential.delta = delta; + int_scale = _int_scale; + return *this; + } +}; + +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Prints a Distribution to ostream +inline std::ostream &operator<<(std::ostream &out, cutlass::Distribution const &dist) { + switch (dist.kind) { + case cutlass::Distribution::Uniform: + out << "uniform, min: " << dist.uniform.min << ", max: " << dist.uniform.max + << ", pnan: " << dist.uniform.pnan; + break; + case cutlass::Distribution::Gaussian: + out << "gaussian, mean: " << dist.gaussian.mean << ", stddev: " << dist.gaussian.stddev + << ", pnzA: " << dist.gaussian.pnzA << ", pnzB: " + << dist.gaussian.pnzB << ", pnzC: " << dist.gaussian.pnzC; + break; + case cutlass::Distribution::Identity: + out << "identity"; + break; + case cutlass::Distribution::Sequential: + out << "sequential"; + break; + default: + out << "unknown"; + } + + out << ", int_scale: " << dist.int_scale; + + return out; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/exceptions.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/exceptions.h new file mode 100644 index 0000000000000000000000000000000000000000..f2b7df6cb1c465a312d76566768cb79fcdfffee4 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/exceptions.h @@ -0,0 +1,69 @@ +/****************************************************************************** + * 
Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +#pragma once + +/** + * \file + * \brief C++ exception semantics for CUDA error codes + */ + +#include +#include +#include + +#include "cutlass/platform/platform.h" + +namespace cutlass { + +/// C++ exception wrapper for CUDA \p cudaError_t +class cuda_exception : public std::exception { + public: + /// Constructor + cuda_exception(const char* msg = "", cudaError_t err = cudaErrorUnknown) : msg(msg), err(err) {} + + /// Returns the underlying CUDA \p cudaError_t + cudaError_t cudaError() const { return err; } + + protected: + /// Explanatory string + const char* msg; + + /// Underlying CUDA \p cudaError_t + cudaError_t err; +}; + +/// Writes a cuda_exception instance to an output stream +inline std::ostream& operator<<(std::ostream& out, cuda_exception const& e) { + return out << e.what() << ": " << cudaGetErrorString(e.cudaError()); +} + +} // namespace cutlass diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/gett_commandline.hpp b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/gett_commandline.hpp new file mode 100644 index 0000000000000000000000000000000000000000..be2264466e350c062900a50e27e923847186d084 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/gett_commandline.hpp @@ -0,0 +1,369 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
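// Minimal usage sketch for cutlass::cuda_exception from exceptions.h above: wrap a raw
// CUDA API call and report failures through the stream operator defined there. The
// checked_malloc() helper is hypothetical and exists only for this sketch.

#include <iostream>
#include <cuda_runtime.h>
#include "cutlass/util/exceptions.h"

void *checked_malloc(size_t bytes) {
  void *ptr = nullptr;
  cudaError_t err = cudaMalloc(&ptr, bytes);
  if (err != cudaSuccess) {
    throw cutlass::cuda_exception("cudaMalloc failed", err);
  }
  return ptr;
}

void allocate_or_report() {
  try {
    void *p = checked_malloc(size_t(1) << 20);
    cudaFree(p);
  } catch (cutlass::cuda_exception const &e) {
    std::cerr << e << std::endl;   // prints what() plus cudaGetErrorString(e.cudaError())
  }
}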
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief GETT command line parser to gather semantic modes, their stride order, and extents. +*/ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "cutlass/util/command_line.h" + +namespace cutlass { + +// Output shortcuts +std::ostream& operator<<(std::ostream& os, std::vector data) { + for (auto& a : data) os << a; + return os; +} + +template +std::ostream& operator<<(std::ostream& os, std::vector data) { + for (auto& a : data) os << a << " "; + return os; +} + +struct GettCommandLine { + struct GettProblem { + using extent_type = int; + using stride_type = int64_t; + + // Row modes: appear in A and C/D + std::vector M; + std::vector ldAm; + std::vector ldCm; + + // Column modes: appear in B and C/D + std::vector N; + std::vector ldBn; + std::vector ldCn; + + // Reduction modes: appear in A and B + std::vector K; + std::vector ldAk; + std::vector ldBk; + + // Batch modes: appear in all in/out tensors + std::vector L; + std::vector ldAl; + std::vector ldBl; + std::vector ldCl; + }; + + static GettProblem + parse(int argc, char const* argv[], bool parse_verbose = false) { + using extent_type = typename GettProblem::extent_type; + using stride_type = typename GettProblem::stride_type; + + cutlass::CommandLine cmd(argc, argv); + + // modeA + std::vector a_mode; + cmd.get_cmd_line_arguments("modeA", a_mode); + + // modeB + std::vector b_mode; + cmd.get_cmd_line_arguments("modeB", b_mode); + + // modeC + std::vector c_mode; + cmd.get_cmd_line_arguments("modeC", c_mode); + + + // mode_sizes + std::map mode_size; + // First, initialize all modes in a, b, c to make sure they're in map + for (char a : a_mode) mode_size[a] = 1; + for (char b : b_mode) mode_size[b] = 1; + for (char c : c_mode) mode_size[c] = 1; + + // Then, overwrite the ones in -extent + std::vector > extent_tokens; + cmd.get_cmd_line_argument_pairs("extents", extent_tokens); + for (auto e : extent_tokens) { + if (std::get<0>(e).size() > 1) { + std::cerr << "ERROR: Mode name must only be 1 character long.\n"; + print_usage(); + 
exit(1); + } + char label = std::get<0>(e)[0]; + int size = std::stoi(std::get<1>(e)); + mode_size[label] = size; + } + + // Print out symbolic modes and their extents + if (parse_verbose) { + std::cout << "C_" << c_mode << " = A_" << a_mode << " * B_" << b_mode << "\n"; + for (auto e : mode_size) std::cout << " " << std::get<0>(e) << " : " << std::get<1>(e) << "\n"; + } + + // + // Collect/Compute strides + // + + std::map mode_ldA; + std::map mode_ldB; + std::map mode_ldC; + + { + stride_type current; + + current = 1; + for (char a : a_mode) { mode_ldA[a] = current; current *= mode_size[a]; } + + current = 1; + for (char b : b_mode) { mode_ldB[b] = current; current *= mode_size[b]; } + + current = 1; + for (char c : c_mode) { mode_ldC[c] = current; current *= mode_size[c]; } + } + + // + // Collect mode categories + // + + std::vector row_mode; // rows + std::vector col_mode; // columns + std::vector red_mode; // reductions + std::vector bat_mode; // batches + + { + std::vector a_label = a_mode; + std::vector b_label = b_mode; + std::vector c_label = c_mode; + + std::sort(std::begin(a_label), std::end(a_label)); + std::sort(std::begin(b_label), std::end(b_label)); + std::sort(std::begin(c_label), std::end(c_label)); + + // std::set_intersections to find semantic category of each symbolic mode + std::set_intersection(std::begin(a_label), std::end(a_label), + std::begin(c_label), std::end(c_label), + std::back_inserter(row_mode)); + + std::set_intersection(std::begin(b_label), std::end(b_label), + std::begin(c_label), std::end(c_label), + std::back_inserter(col_mode)); + + std::set_intersection(std::begin(a_label), std::end(a_label), + std::begin(b_label), std::end(b_label), + std::back_inserter(red_mode)); + + std::set_intersection(std::begin(row_mode), std::end(row_mode), + std::begin(col_mode), std::end(col_mode), + std::back_inserter(bat_mode)); + + // std::set_difference to remove batch modes from other semantic modes + for (char l : bat_mode) { + row_mode.erase(std::remove(std::begin(row_mode), std::end(row_mode), l), std::end(row_mode)); + col_mode.erase(std::remove(std::begin(col_mode), std::end(col_mode), l), std::end(col_mode)); + red_mode.erase(std::remove(std::begin(red_mode), std::end(red_mode), l), std::end(red_mode)); + } + } + + // Print out the semantic association of each symbolic mode + if (parse_verbose) { + std::cout << " rows : " << row_mode << '\n'; + std::cout << " cols : " << col_mode << '\n'; + std::cout << " reds : " << red_mode << '\n'; + std::cout << " bats : " << bat_mode << '\n'; + } + + // + // Permute modes + // + + // Permute the batched modes to promote coalescing + // Sort the batched modes by min(ldAl,ldBl) and in case of a tie by the size + std::sort(std::begin(bat_mode), std::end(bat_mode), [&](char l1, char l2) { + return std::tie(std::min(mode_ldA[l1],mode_ldB[l1]),mode_size[l1]) + < std::tie(std::min(mode_ldA[l2],mode_ldB[l2]),mode_size[l2]); + }); + // Compute sizes and strides of ordered reduction modes + std::vector L; + std::vector ldAl; + std::vector ldBl; + std::vector ldCl; + for (char l : bat_mode) { + L.push_back(mode_size[l]); + ldAl.push_back(mode_ldA[l]); + ldBl.push_back(mode_ldB[l]); + ldCl.push_back(mode_ldC[l]); + } + + // Permute the reduction modes to promote coalescing + // Sort the reduction modes by min(ldAk,ldBk) and in case of a tie by the size + std::sort(std::begin(red_mode), std::end(red_mode), [&](char k1, char k2) { + return std::tie(std::min(mode_ldA[k1],mode_ldB[k1]),mode_size[k1]) + < 
std::tie(std::min(mode_ldA[k2],mode_ldB[k2]),mode_size[k2]); + }); + // Compute sizes and strides of ordered reduction modes + std::vector K; + std::vector ldAk; + std::vector ldBk; + for (char k : red_mode) { + K.push_back(mode_size[k]); + ldAk.push_back(mode_ldA[k]); + ldBk.push_back(mode_ldB[k]); + } + + // Permute the row modes to promote coalescing + // Sort the row modes by min(ldAm,ldCm) and in case of a tie by ldAm + std::sort(std::begin(row_mode), std::end(row_mode), [&](char m1, char m2) { + return std::tie(std::min(mode_ldA[m1],mode_ldC[m1]),mode_ldA[m1]) + < std::tie(std::min(mode_ldA[m2],mode_ldC[m2]),mode_ldA[m2]); + }); + // Compute sizes and strides of ordered row modes + std::vector M; + std::vector ldAm; + std::vector ldCm; + for (char m : row_mode) { + M.push_back(mode_size[m]); + ldAm.push_back(mode_ldA[m]); + ldCm.push_back(mode_ldC[m]); + } + + // Permute the col modes to promote coalescing + // Sort the col modes by min(ldBn,ldCn) and in case of a tie by ldBn + std::sort(std::begin(col_mode), std::end(col_mode), [&](char n1, char n2) { + return std::tie(std::min(mode_ldB[n1],mode_ldC[n1]),mode_ldB[n1]) + < std::tie(std::min(mode_ldB[n2],mode_ldC[n2]),mode_ldB[n2]); + }); + // Compute sizes and strides of ordered col modes + std::vector N; + std::vector ldBn; + std::vector ldCn; + for (char n : col_mode) { + N.push_back(mode_size[n]); + ldBn.push_back(mode_ldB[n]); + ldCn.push_back(mode_ldC[n]); + } + + if (parse_verbose) { + std::cout << "C_"; + if (! row_mode.empty()) { + std::cout << "(" << row_mode << ")"; + } + if (! col_mode.empty()) { + std::cout << "(" << col_mode << ")"; + } + if (! bat_mode.empty()) { + std::cout << "(" << bat_mode << ")"; + } + std::cout << " = A_"; + if (! row_mode.empty()) { + std::cout << "(" << row_mode << ")"; + } + if (! red_mode.empty()) { + std::cout << "(" << red_mode << ")"; + } + if (! bat_mode.empty()) { + std::cout << "(" << bat_mode << ")"; + } + std::cout << " * B_"; + if (! col_mode.empty()) { + std::cout << "(" << col_mode << ")"; + } + if (! red_mode.empty()) { + std::cout << "(" << red_mode << ")"; + } + if (! 
bat_mode.empty()) { + std::cout << "(" << bat_mode << ")"; + } + std::cout << '\n'; + + int M_size = std::accumulate(std::begin(M), std::end(M), 1, std::multiplies<>{}); + int N_size = std::accumulate(std::begin(N), std::end(N), 1, std::multiplies<>{}); + int K_size = std::accumulate(std::begin(K), std::end(K), 1, std::multiplies<>{}); + int L_size = std::accumulate(std::begin(L), std::end(L), 1, std::multiplies<>{}); + + std::cout << " M : (" << M_size << ") "; + for (char m : row_mode) std::cout << m << ":" << mode_size[m] << " "; + std::cout << '\n'; + std::cout << " N : (" << N_size << ") "; + for (char n : col_mode) std::cout << n << ":" << mode_size[n] << " "; + std::cout << '\n'; + std::cout << " K : (" << K_size << ") "; + for (char k : red_mode) std::cout << k << ":" << mode_size[k] << " "; + std::cout << '\n'; + std::cout << " L : (" << L_size << ") "; + for (char l : bat_mode) std::cout << l << ":" << mode_size[l] << " "; + std::cout << '\n'; + + std::cout << " ldAm : " << ldAm << '\n'; + std::cout << " ldAk : " << ldAk << '\n'; + std::cout << " ldAl : " << ldAl << '\n'; + std::cout << " ldBn : " << ldBn << '\n'; + std::cout << " ldBk : " << ldBk << '\n'; + std::cout << " ldBl : " << ldBl << '\n'; + std::cout << " ldCm : " << ldCm << '\n'; + std::cout << " ldCn : " << ldCn << '\n'; + std::cout << " ldCl : " << ldCl << '\n'; + } + + return {M, ldAm, ldCm, + N, ldBn, ldCn, + K, ldAk, ldBk, + L, ldAl, ldBl, ldCl}; + } + + static void + print_usage() { + std::cout << + "GETT problem command line parser:\n" + " --modeA=\n" + " A comma delimited list of characters that correspond to the row, reduction, and batch modes in the A tensor.\n" + " The semantic association of each symbolic mode is determined automatically.\n\n" + + " --modeB=\n" + " A comma delimited list of characters that correspond to the column, reduction, and batch modes in the B tensor.\n" + " The semantic association of each symbolic mode is determined automatically.\n\n" + + " --modeC=\n" + " A comma delimited list of characters that correspond to the row, column, and batch modes in the C tensor.\n" + " The semantic association of each symbolic mode is determined automatically.\n\n" + + " --extents=\n" + " A comma delimited list of symbolic modes and their corresponding extents.\n" + " Extents default to 1 if any are not provided.\n\n" + + "Example usage: gett.exe --modeC=m,n,l --modeA=m,k,l --modeB=k,n,l --extents=m:4096,n:4096,k:4096\n"; + } +}; + +} // namespace cutlass diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/helper_cuda.hpp b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/helper_cuda.hpp new file mode 100644 index 0000000000000000000000000000000000000000..58d08b860c9e665d170fd022ed0d95875e029019 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/helper_cuda.hpp @@ -0,0 +1,116 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2.
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include + +#include + +namespace cute +{ + +void +device_init(int device_id, bool quiet = false) +{ + cudaDeviceProp device_prop; + std::size_t device_free_physmem; + std::size_t device_total_physmem; + + CUTE_CHECK_ERROR(cudaSetDevice(device_id)); + CUTE_CHECK_ERROR(cudaMemGetInfo(&device_free_physmem, &device_total_physmem)); + CUTE_CHECK_ERROR(cudaGetDeviceProperties(&device_prop, device_id)); + + if (device_prop.major < 1) { + fprintf(stderr, "Device does not support CUDA.\n"); + exit(1); + } + + //float device_giga_bandwidth = float(device_prop.memoryBusWidth) * device_prop.memoryClockRate * 2 / 8 / 1000 / 1000; + + if (!quiet) { + printf("Using device %d: %s (SM%d, %d SMs)\n", + device_id, device_prop.name, + device_prop.major * 10 + device_prop.minor, + device_prop.multiProcessorCount); + fflush(stdout); + } +} + +/** + * Convert the SM version (e.g. v7.0, v7.5) to the physical number of cores. + */ +inline int +_ConvertSMVer2Cores(int major, int minor) +{ + // Defines for GPU Architecture types (using the SM version to determine + // the # of cores per SM + typedef struct { + int SM; // 0xMm (hexadecimal notation), M = SM Major version, + // and m = SM minor version + int Cores; + } sSMtoCores; + + sSMtoCores nGpuArchCoresPerSM[] = { + {0x30, 192}, + {0x32, 192}, + {0x35, 192}, + {0x37, 192}, + {0x50, 128}, + {0x52, 128}, + {0x53, 128}, + {0x60, 64}, + {0x61, 128}, + {0x62, 128}, + {0x70, 64}, + {0x72, 64}, + {0x75, 64}, + {-1, -1}}; + + int index = 0; + + while (nGpuArchCoresPerSM[index].SM != -1) { + if (nGpuArchCoresPerSM[index].SM == ((major << 4) + minor)) { + return nGpuArchCoresPerSM[index].Cores; + } + index++; + } + + // If we don't find the values, we default use the previous one + // to run properly + printf("MapSMtoCores for SM %d.%d is undefined." 
+ " Default to use %d Cores/SM\n", + major, minor, nGpuArchCoresPerSM[index - 1].Cores); + + return nGpuArchCoresPerSM[index - 1].Cores; +} + +} // end namespace cute diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/host_reorder.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/host_reorder.h new file mode 100644 index 0000000000000000000000000000000000000000..4e7718059dfaea0c77d7ebf67789f307b4ca0cf6 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/host_reorder.h @@ -0,0 +1,111 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief reorder data from the host side +*/ + +#pragma once + +#include "cutlass/coord.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/tensor_view.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/host/gemm.h" + +namespace cutlass { + +/// This is needed for the interleaved integer tensor core kernels. The purpose +/// is to use skip the shared memory part in the epilogue. 
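+/// For example, with Interleaved = 32 (so ElementsPerThread = 2 and ReorderedElementsPerThread = 8),
+/// the element at (k, n) in the source is written to row k, column
+/// (n / 32) * 32 + ((n % 8) / 2) * 8 + ((n % 32) / 8) * 2 + (n % 2)
+/// of the destination, matching the index arithmetic in the loop below.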
+template +void reorder_column(TensorRef dest, + TensorRef src, + cutlass::gemm::GemmCoord problem_size) { + const int InstructionShapeCol = 8; + // 4 threads per Quad + const int ElementsPerThread = InstructionShapeCol / 4; + // 4 threads per Quad + const int ReorderedElementsPerThread = + Interleaved / 4; + + for (int n = 0; n < problem_size.n(); n++) { + for (int k = 0; k < problem_size.k(); k++) { + dest.at({k, (n / Interleaved) * Interleaved + + ((n % ReorderedElementsPerThread) / ElementsPerThread) * + InstructionShapeCol + + ((n % Interleaved) / ReorderedElementsPerThread) * + ElementsPerThread + + (n % ElementsPerThread)}) = src.at({k, n}); + } + } +} + +template +void reorder_convK(TensorRef dest, + TensorRef src, + cutlass::gemm::GemmCoord problem_size) { + + TensorRef> mappedDest(dest.data(), dest.stride(0)); + TensorRef> mappedSrc(src.data(), src.stride(0)); + + reorder_column( + mappedDest, mappedSrc, problem_size); +} + +/// This is needed for the sparse tensor core kernels. The purpose +/// is to use ldmatrix to load from shared memory to the register file. +template +void reorder_meta(TensorRef dest, + TensorRef src, + cutlass::gemm::GemmCoord problem_size) { + for (int m = 0; m < problem_size.m(); m++) { + for (int k = 0; k < problem_size.k(); k++) { + // First reorder the rows. + int group = (sizeof(Element) == 2) ? 32 : 16; + int interweave = (sizeof(Element) == 2) ? 4 : 2; + + int dest_row = m / group * group + (m % 8) * interweave + (m % group) / 8; + int dest_col = k; + + // Next swizzle the 2x2 blocks from Z to N. + if (((dest_row % 2) == 0) && ((dest_col % 2) == 1)) { + ++dest_row; + --dest_col; + } else if (((dest_row % 2) == 1) && ((dest_col % 2) == 0)) { + --dest_row; + ++dest_col; + } + + dest.at({dest_row, dest_col}) = src.at({m, k}); + } + } +} +} // namespace cutlass diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/host_tensor.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/host_tensor.h new file mode 100644 index 0000000000000000000000000000000000000000..3226055ad0836e7a3059340ff16d54594987e0c8 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/host_tensor.h @@ -0,0 +1,541 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +/*! \file + \brief HostTensor contributes management for both host and device memory. + + HostTensor allocates host and device memory upon construction. Basic element-wise operations on + host memory synchronize device memory automatically. Explicit copy operations provide abstractions + for CUDA memcpy operations. + + Call {host, device}_{data, ref, view}() for accessing host or device memory. + + See cutlass/tensor_ref.h and cutlass/tensor_view.h for more details. +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" +#include "cutlass/fast_math.h" + +#include "device_memory.h" + +namespace cutlass { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Host tensor +template < + /// Data type of element stored within tensor (concept: NumericType) + typename Element_, + /// Defines a mapping from logical coordinate to linear memory (concept: Layout) + typename Layout_ +> +class HostTensor { +public: + + /// Data type of individual access + using Element = Element_; + + /// Mapping function from logical coordinate to linear memory + using Layout = Layout_; + + /// Logical rank of tensor index space + static int const kRank = Layout::kRank; + + /// Index type + using Index = typename Layout::Index; + + /// Long index used for pointer offsets + using LongIndex = typename Layout::LongIndex; + + /// Coordinate in logical tensor space + using TensorCoord = typename Layout::TensorCoord; + + /// Layout's stride vector + using Stride = typename Layout::Stride; + + /// Tensor reference to device memory + using TensorRef = TensorRef; + + /// Tensor reference to constant device memory + using ConstTensorRef = typename TensorRef::ConstTensorRef; + + /// Tensor reference to device memory + using TensorView = TensorView; + + /// Tensor reference to constant device memory + using ConstTensorView = typename TensorView::ConstTensorView; + + /// Reference to element in tensor + using Reference = typename TensorRef::Reference; + + /// Constant reference to element in tensor + using ConstReference = typename ConstTensorRef::Reference; + +private: + using StorageUnit = typename platform::conditional_t, uint8_t, // Avoid the std::vector specialization + typename platform::conditional_t::value % 8 == 0, // Handle subbyte types + Element, uint8_t>>; + using StorageContainerCalculator = cutlass::detail::StorageContainerCalculator; + static constexpr int kContainerTypeNumBits = StorageContainerCalculator::kContainerTypeNumBits; + static constexpr int kContainerTypeNumLogicalElements = StorageContainerCalculator::kContainerTypeNumLogicalElements; + static constexpr int kContainerTypeNumBytes = StorageContainerCalculator::kContainerTypeNumBytes; + static constexpr int kContainerTypeNumStorageUnit = 
StorageContainerCalculator::kContainerTypeNumStorageUnit; + + // + // Data members + // + + /// Extent of tensor in logical dimensions + TensorCoord extent_; + + /// Layout object + Layout layout_; + + /// Host-side memory allocation + std::vector host_; + + /// Device-side memory + device_memory::allocation device_; + + /// number of containers + size_t count_to_container_storage_unit_count(size_t count) { + return (count + kContainerTypeNumLogicalElements - 1) / kContainerTypeNumLogicalElements * kContainerTypeNumStorageUnit; + } + +public: + // + // Device and Host Methods + // + + /// Default constructor + HostTensor() {} + + /// Constructs a tensor given an extent. Assumes a packed layout + HostTensor( + TensorCoord const &extent, + bool device_backed = true + ) { + + this->reset(extent, Layout::packed(extent), device_backed); + } + + /// Constructs a tensor given an extent and layout + HostTensor( + TensorCoord const &extent, + Layout const &layout, + bool device_backed = true + ) { + + this->reset(extent, layout, device_backed); + } + + ~HostTensor() { } + + /// Clears the HostTensor allocation to size/capacity = 0 + void reset() { + extent_ = TensorCoord(); + layout_ = Layout::packed(extent_); + + host_.clear(); + device_.reset(); + } + + /// Resizes internal memory allocations without affecting layout or extent + void reserve( + size_t count, ///< size of tensor in elements + bool device_backed_ = true) { ///< if true, device memory is also allocated +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("cutlass::HostTensor::reserve(count=" << count << ", device_backed_=" << (device_backed_ ? "true" : "false") << ")"); +#endif + + device_.reset(); + host_.clear(); + + size_t count_container = count_to_container_storage_unit_count(count); +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("cutlass::HostTensor::reserve: host_.resize(" << count_container << ")"); +#endif + host_.resize(count_container); + + // Allocate memory + StorageUnit* device_memory = nullptr; + if (device_backed_) { +#if (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST("cutlass::HostTensor::reserve: device_memory::allocate(" << count_container << ")"); +#endif + device_memory = device_memory::allocate(count_container); + } + device_.reset(device_memory, device_backed_ ? count_container : 0); + } + + /// Updates the extent and layout of the HostTensor. Allocates memory according to the new + /// extent and layout. + void reset( + TensorCoord const &extent, ///< extent of logical tensor + Layout const &layout, ///< layout object of tensor + bool device_backed_ = true) { ///< if true, device memory is also allocated. + + extent_ = extent; + layout_ = layout; + + reserve(size_t(layout_.capacity(extent_)), device_backed_); + } + + /// Updates the extent and layout of the HostTensor. Allocates memory according to the new + /// extent and layout. Assumes a packed tensor configuration. + void reset( + TensorCoord const &extent, ///< extent of logical tensor + bool device_backed_ = true) { ///< if true, device memory is also allocated. + + reset(extent, Layout::packed(extent), device_backed_); + } + + /// Changes the size of the logical tensor. Only allocates memory if new capacity exceeds reserved capacity. + /// To force allocation, call reset(). + void resize( + TensorCoord const &extent, ///< extent of logical tensor + Layout const &layout, ///< layout object of tensor + bool device_backed_ = true) { ///< if true, device memory is also allocated. 
+ + extent_ = extent; + layout_ = layout; + + LongIndex new_size = size_t(layout_.capacity(extent_)); + LongIndex new_size_container = count_to_container_storage_unit_count((layout_.capacity(extent_))); + + if (static_cast(new_size_container) > host_.size()) { + reserve(new_size, device_backed_); + } + } + + /// Changes the size of the logical tensor. Only allocates memory if new capacity exceeds reserved capacity. + /// To force allocation, call reset(). Note, this form of resize() assumes a packed tensor configuration. + void resize( + TensorCoord const &extent, ///< extent of logical tensor + bool device_backed_ = true) { ///< if true, device memory is also allocated. + + resize(extent, Layout::packed(extent), device_backed_); + } + + /// Returns the logical number of elements stored in the host tensor + size_t size() const { + return layout_.capacity(extent_); + } + + /// Returns the logical capacity in terms of number of elements. May be larger than the size(). + LongIndex capacity() const { + return host_.size() / kContainerTypeNumStorageUnit * kContainerTypeNumLogicalElements; + } + + /// Gets pointer to host data + Element * host_data() { return reinterpret_cast(host_.data()); } + + /// Gets pointer to host data with a pointer offset + Element * host_data_ptr_offset(LongIndex ptr_element_offset) { return &ReferenceFactory::get(host_data(), ptr_element_offset); } + + /// Gets a reference to an element in host memory + Reference host_data(LongIndex idx) { + return ReferenceFactory::get(host_data(), idx); + } + + /// Gets pointer to host data + Element const * host_data() const { return reinterpret_cast(host_.data()); } + + /// Gets pointer to host data with a pointer offset + Element const * host_data_ptr_offset(LongIndex ptr_element_offset) const { return &ReferenceFactory::get(host_data(), ptr_element_offset); } + + /// Gets a constant reference to an element in host memory + ConstReference host_data(LongIndex idx) const { + return ReferenceFactory::get(host_data(), idx); + } + + /// Gets pointer to device data + Element * device_data() { return reinterpret_cast(device_.get()); } + + /// Gets pointer to device data + Element const * device_data() const { return reinterpret_cast(device_.get()); } + + /// Gets pointer to device data with a pointer offset + Element * device_data_ptr_offset(LongIndex ptr_element_offset) { return &ReferenceFactory::get(device_data(), ptr_element_offset); } + + /// Gets pointer to device data with a pointer offset + Element const * device_data_ptr_offset(LongIndex ptr_element_offset) const { return &ReferenceFactory::get(device_data(), ptr_element_offset); } + + /// Accesses the tensor reference pointing to data + TensorRef host_ref(LongIndex ptr_element_offset=0) { return TensorRef(host_data_ptr_offset(ptr_element_offset), layout_); } + + /// Accesses the tensor reference pointing to data + ConstTensorRef host_ref(LongIndex ptr_element_offset=0) const { return ConstTensorRef(host_data_ptr_offset(ptr_element_offset), layout_); } + + /// Accesses the tensor reference pointing to data + TensorRef device_ref(LongIndex ptr_element_offset=0) { + return TensorRef(device_data_ptr_offset(ptr_element_offset), layout_); + } + + /// Accesses the tensor reference pointing to data + ConstTensorRef device_ref(LongIndex ptr_element_offset=0) const { + return TensorRef(device_data_ptr_offset(ptr_element_offset), layout_); + } + + /// Accesses the tensor reference pointing to data + TensorView host_view(LongIndex ptr_element_offset=0) { + return 
TensorView(host_data_ptr_offset(ptr_element_offset), layout_, extent_); + } + + /// Accesses the tensor reference pointing to data + ConstTensorView host_view(LongIndex ptr_element_offset=0) const { + return ConstTensorView(host_data_ptr_offset(ptr_element_offset), layout_, extent_); + } + + /// Accesses the tensor reference pointing to data + TensorView device_view(LongIndex ptr_element_offset=0) { + return TensorView(device_data_ptr_offset(ptr_element_offset), layout_, extent_); + } + + /// Accesses the tensor reference pointing to data + ConstTensorView device_view(LongIndex ptr_element_offset=0) const { + return ConstTensorView(device_data_ptr_offset(ptr_element_offset), layout_, extent_); + } + + /// Returns true if device memory is allocated + bool device_backed() const { + return (device_.get() == nullptr) ? false : true; + } + + + /// Returns the layout object + Layout & layout() { + return layout_; + } + + /// Returns the layout object + Layout layout() const { + return layout_; + } + + /// Returns the layout object's stride vector + Stride stride() const { + return layout_.stride(); + } + + /// Returns the layout object's stride vector + Stride & stride() { + return layout_.stride(); + } + + /// Returns the layout object's stride in a given physical dimension + LongIndex stride(int dim) const { + return layout_.stride().at(dim); + } + + /// Returns the layout object's stride in a given physical dimension + LongIndex & stride(int dim) { + return layout_.stride().at(dim); + } + + /// Computes the offset of an index from the origin of the tensor + LongIndex offset(TensorCoord const& coord) const { + return layout_(coord); + } + + /// Returns a reference to the element at the logical Coord in host memory + Reference at(TensorCoord const& coord) { + return host_data(offset(coord)); + } + + /// Returns a const reference to the element at the logical Coord in host memory + ConstReference at(TensorCoord const& coord) const { + return host_data(offset(coord)); + } + + /// Returns the extent of the tensor + TensorCoord extent() const { + return extent_; + } + + /// Returns the extent of the tensor + TensorCoord & extent() { + return extent_; + } + + /// Copies data from device to host + void sync_host() { + if (device_backed()) { + device_memory::copy_to_host( + host_.data(), device_.get(), device_.size()); + } + } + + /// Copies data from host to device + void sync_device() { + if (device_backed()) { + device_memory::copy_to_device( + device_.get(), host_.data(), host_.size()); + } + } + + /// Copy data from a caller-supplied device pointer into host memory. + void copy_in_device_to_host( + Element const* ptr_device, ///< source device memory + LongIndex count = -1) { ///< number of elements to transfer; if negative, entire tensor is overwritten. + + if (count < 0) { + count = capacity(); + } + else { + count = __NV_STD_MIN(capacity(), count); + } + size_t container_count = count_to_container_storage_unit_count(count); + device_memory::copy_to_host( + host_.data(), reinterpret_cast(ptr_device), container_count); + } + + /// Copy data from a caller-supplied device pointer into host memory. + void copy_in_device_to_device( + Element const* ptr_device, ///< source device memory + LongIndex count = -1) { ///< number of elements to transfer; if negative, entire tensor is overwritten. 
+ + if (count < 0) { + count = capacity(); + } + else { + count = __NV_STD_MIN(capacity(), count); + } + size_t container_count = count_to_container_storage_unit_count(count); + device_memory::copy_device_to_device( + device_.get(), reinterpret_cast(ptr_device), container_count); + } + + /// Copy data from a caller-supplied device pointer into host memory. + void copy_in_host_to_device( + Element const* ptr_host, ///< source host memory + LongIndex count = -1) { ///< number of elements to transfer; if negative, entire tensor is overwritten. + + if (count < 0) { + count = capacity(); + } + else { + count = __NV_STD_MIN(capacity(), count); + } + size_t container_count = count_to_container_storage_unit_count(count); + device_memory::copy_to_device( + device_.get(), reinterpret_cast(ptr_host), container_count); + } + + /// Copy data from a caller-supplied device pointer into host memory. + void copy_in_host_to_host( + Element const* ptr_host, ///< source host memory + LongIndex count = -1) { ///< number of elements to transfer; if negative, entire tensor is overwritten. + + if (count < 0) { + count = capacity(); + } + else { + count = __NV_STD_MIN(capacity(), count); + } + size_t container_count = count_to_container_storage_unit_count(count); + device_memory::copy_host_to_host( + host_.data(), reinterpret_cast(ptr_host), container_count); + } + + /// Copy data from a caller-supplied device pointer into host memory. + void copy_out_device_to_host( + Element * ptr_host, ///< source device memory + LongIndex count = -1) const { ///< number of elements to transfer; if negative, entire tensor is overwritten. + + if (count < 0) { + count = capacity(); + } + else { + count = __NV_STD_MIN(capacity(), count); + } + size_t container_count = count_to_container_storage_unit_count(count); + device_memory::copy_to_host( + reinterpret_cast(ptr_host), device_.get(), container_count); + } + + /// Copy data from a caller-supplied device pointer into host memory. + void copy_out_device_to_device( + Element * ptr_device, ///< source device memory + LongIndex count = -1) const { ///< number of elements to transfer; if negative, entire tensor is overwritten. + + if (count < 0) { + count = capacity(); + } + else { + count = __NV_STD_MIN(capacity(), count); + } + size_t container_count = count_to_container_storage_unit_count(count); + device_memory::copy_device_to_device( + reinterpret_cast(ptr_device), device_.get(), container_count); + } + + /// Copy data from a caller-supplied device pointer into host memory. + void copy_out_host_to_device( + Element * ptr_device, ///< source host memory + LongIndex count = -1) const { ///< number of elements to transfer; if negative, entire tensor is overwritten. + + if (count < 0) { + count = capacity(); + } + else { + count = __NV_STD_MIN(capacity(), count); + } + size_t container_count = count_to_container_storage_unit_count(count); + device_memory::copy_to_device( + reinterpret_cast(ptr_device), host_.data(), container_count); + } + + /// Copy data from a caller-supplied device pointer into host memory. + void copy_out_host_to_host( + Element * ptr_host, ///< source host memory + LongIndex count = -1) const { ///< number of elements to transfer; if negative, entire tensor is overwritten. 
+ + if (count < 0) { + count = capacity(); + } + else { + count = __NV_STD_MIN(capacity(), count); + } + size_t container_count = count_to_container_storage_unit_count(count); + device_memory::copy_host_to_host( + reinterpret_cast(ptr_host), host_.data(), container_count); + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/host_tensor_planar_complex.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/host_tensor_planar_complex.h new file mode 100644 index 0000000000000000000000000000000000000000..ca770e4d76cfe2df16309baca0b2de8ab6de98c4 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/host_tensor_planar_complex.h @@ -0,0 +1,591 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +/*! \file + \brief HostTensor contributes management for both host and device memory. + + HostTensor allocates host and device memory upon construction. Basic element-wise operations on + host memory synchronize device memory automatically. Explicit copy operations provide abstractions + for CUDA memcpy operations. + + Call {host, device}_{data, ref, view}() for accessing host or device memory. + + See cutlass/tensor_ref.h and cutlass/tensor_view.h for more details. 
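+
+    A minimal usage sketch (illustrative only; it assumes a rank-2 cutlass::layout::ColumnMajor
+    layout and a float element type, neither of which is mandated by this header):
+
+      cutlass::HostTensorPlanarComplex<float, cutlass::layout::ColumnMajor> A({128, 64});
+      A.host_view_real().at({0, 0}) = 1.0f;  // write the real plane on the host
+      A.host_view_imag().at({0, 0}) = 2.0f;  // write the imaginary plane on the host
+      A.sync_device();                       // copy both planes into the device allocation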
+*/ + +#include + +#include "cutlass/cutlass.h" + +#include "cutlass/tensor_ref_planar_complex.h" +#include "cutlass/tensor_view_planar_complex.h" + +#include "device_memory.h" + +namespace cutlass { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Host tensor +template < + /// Data type of element stored within tensor (concept: NumericType) + typename Element_, + /// Defines a mapping from logical coordinate to linear memory (concept: Layout) + typename Layout_ +> +class HostTensorPlanarComplex { +public: + + /// Data type of individual access + using Element = Element_; + + /// Mapping function from logical coordinate to linear memory + using Layout = Layout_; + + /// Logical rank of tensor index space + static int const kRank = Layout::kRank; + + /// Index type + using Index = typename Layout::Index; + + /// Long index used for pointer offsets + using LongIndex = typename Layout::LongIndex; + + /// Coordinate in logical tensor space + using TensorCoord = typename Layout::TensorCoord; + + /// Layout's stride vector + using Stride = typename Layout::Stride; + + /// Tensor reference to device memory + using TensorRef = TensorRefPlanarComplex; + + /// Tensor reference to constant device memory + using ConstTensorRef = typename TensorRef::ConstTensorRef; + + /// Tensor reference to device memory + using TensorView = TensorViewPlanarComplex; + + /// Tensor reference to constant device memory + using ConstTensorView = typename TensorView::ConstTensorView; + + /// Reference to element in tensor + using Reference = typename TensorRef::Reference; + + /// Constant reference to element in tensor + using ConstReference = typename ConstTensorRef::Reference; + + private: + + // + // Data members + // + + /// Extent of tensor in logical dimensions + TensorCoord extent_; + + /// Layout object + Layout layout_; + + /// Host-side memory allocation + std::vector host_; + + /// Device-side memory + device_memory::allocation device_; + + public: + // + // Device and Host Methods + // + + /// Default constructor + HostTensorPlanarComplex() {} + + /// Constructs a tensor given an extent. Assumes a packed layout + HostTensorPlanarComplex( + TensorCoord const &extent, + bool device_backed = true + ) { + + this->reset(extent, Layout::packed(extent), device_backed); + } + + /// Constructs a tensor given an extent and layout + HostTensorPlanarComplex( + TensorCoord const &extent, + Layout const &layout, + bool device_backed = true + ) { + + this->reset(extent, layout, device_backed); + } + + ~HostTensorPlanarComplex() { } + + /// Clears the HostTensor allocation to size/capacity = 0 + void reset() { + extent_ = TensorCoord(); + layout_ = Layout::packed(extent_); + + host_.clear(); + device_.reset(); + } + + /// Resizes internal memory allocations without affecting layout or extent + void reserve( + size_t count, ///< size of tensor in elements + bool device_backed_ = true) { ///< if true, device memory is also allocated + + device_.reset(); + host_.clear(); + + host_.resize(count * 2); + + // Allocate memory + Element* device_memory = nullptr; + if (device_backed_) { + device_memory = device_memory::allocate(count * 2); + } + device_.reset(device_memory, device_backed_ ? count * 2 : 0); + } + + /// Updates the extent and layout of the HostTensor. Allocates memory according to the new + /// extent and layout. 
+ void reset( + TensorCoord const &extent, ///< extent of logical tensor + Layout const &layout, ///< layout object of tensor + bool device_backed_ = true) { ///< if true, device memory is also allocated. + + extent_ = extent; + layout_ = layout; + + reserve(size_t(layout_.capacity(extent_)), device_backed_); + } + + /// Updates the extent and layout of the HostTensor. Allocates memory according to the new + /// extent and layout. Assumes a packed tensor configuration. + void reset( + TensorCoord const &extent, ///< extent of logical tensor + bool device_backed_ = true) { ///< if true, device memory is also allocated. + + reset(extent, Layout::packed(extent), device_backed_); + } + + /// Changes the size of the logical tensor. Only allocates memory if new capacity exceeds reserved capacity. + /// To force allocation, call reset(). + void resize( + TensorCoord const &extent, ///< extent of logical tensor + Layout const &layout, ///< layout object of tensor + bool device_backed_ = true) { ///< if true, device memory is also allocated. + + extent_ = extent; + layout_ = layout; + + LongIndex new_size = size_t(layout_.capacity(extent_)); + + if (static_cast(new_size * 2) > host_.size()) { + reserve(new_size); + } + } + + /// Changes the size of the logical tensor. Only allocates memory if new capacity exceeds reserved capacity. + /// To force allocation, call reset(). Note, this form of resize() assumes a packed tensor configuration. + void resize( + TensorCoord const &extent, ///< extent of logical tensor + bool device_backed_ = true) { ///< if true, device memory is also allocated. + + resize(extent, Layout::packed(extent), device_backed_); + } + + /// Returns the number of elements stored in the host tensor + size_t size() const { + return host_.size() / 2; + } + + /// Returns the logical capacity based on extent and layout. May differ from size(). 
+ LongIndex capacity() const { + return layout_.capacity(extent_); + } + + /// Stride between real and imaginary parts + LongIndex imaginary_stride() const { + return host_.size() / 2; + } + + /// Gets pointer to host data + Element * host_data() { return host_.data(); } + + /// Gets pointer to host data imaginary part + Element * host_data_imag() { return host_.data() + imaginary_stride(); } + + /// Gets pointer to host data with a pointer offset + Element * host_data_ptr_offset(LongIndex ptr_element_offset) { return host_data() + ptr_element_offset; } + + /// Gets pointer to host data with a pointer offset + Element * host_data_imag_ptr_offset(LongIndex ptr_element_offset) { return host_data_imag() + ptr_element_offset; } + + /// Gets a reference to an element in host memory + Reference host_data(LongIndex idx) { + return PlanarComplexReference(host_data() + idx, host_data_imag() + idx); + } + + /// Gets pointer to host data + Element const * host_data() const { return host_.data(); } + + /// Gets pointer to host data imaginary part + Element const * host_data_imag() const { return host_.data() + imaginary_stride(); } + + /// Gets a constant reference to an element in host memory + ConstReference host_data(LongIndex idx) const { + return PlanarComplexReference(host_data() + idx, host_data_imag() + idx); + } + + /// Gets pointer to device data + Element * device_data() { return device_.get(); } + + /// Gets pointer to device data with a pointer offset + Element * device_data_ptr_offset(LongIndex ptr_element_offset) { return device_.get() + ptr_element_offset; } + + /// Gets pointer to device data + Element const * device_data() const { return device_.get(); } + + /// Gets pointer to device data with a pointer offset + Element const * device_data_ptr_offset(LongIndex ptr_element_offset) const { return device_.get() + ptr_element_offset; } + + /// Gets a pointer to the device data imaginary part + Element * device_data_imag() { return device_.get() + imaginary_stride(); } + + /// Accesses the tensor reference pointing to data + TensorRef host_ref(LongIndex ptr_element_offset=0) { + return TensorRef(host_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride()); + } + + /// Returns a tensor reference to the real part of the tensor + cutlass::TensorRef host_ref_real() { + return cutlass::TensorRef(host_data(), layout_); + } + + /// Returns a tensor reference to the real part of the tensor + cutlass::TensorRef host_ref_imag() { + return cutlass::TensorRef(host_data_ptr_offset(imaginary_stride()), layout_); + } + + /// Accesses the tensor reference pointing to data + ConstTensorRef host_ref(LongIndex ptr_element_offset=0) const { + return ConstTensorRef(host_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride()); + } + + /// Accesses the tensor reference pointing to data + TensorRef device_ref(LongIndex ptr_element_offset=0) { + return TensorRef(device_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride()); + } + + /// Accesses the tensor reference pointing to data + ConstTensorRef device_ref(LongIndex ptr_element_offset=0) const { + return TensorRef(device_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride()); + } + + /// Returns a tensor reference to the real part of the tensor + cutlass::TensorRef device_ref_real() { + return cutlass::TensorRef(device_data(), layout_); + } + + /// Returns a tensor reference to the real part of the tensor + cutlass::TensorRef device_ref_imag() { + return cutlass::TensorRef(device_data_ptr_offset(imaginary_stride()), 
layout_); + } + + /// Accesses the tensor reference pointing to data + TensorView host_view(LongIndex ptr_element_offset=0) { + return TensorView(host_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride(), extent_); + } + + /// Accesses the tensor reference pointing to data + ConstTensorView host_view(LongIndex ptr_element_offset=0) const { + return ConstTensorView(host_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride(), extent_); + } + + /// Accesses the tensor reference pointing to data + cutlass::TensorView host_view_real() { + return cutlass::TensorView(host_data(), layout_, extent_); + } + + /// Accesses the tensor reference pointing to data + cutlass::TensorView host_view_imag() { + return cutlass::TensorView(host_data_ptr_offset(imaginary_stride()), layout_, extent_); + } + + /// Accesses the tensor reference pointing to data + TensorView device_view(LongIndex ptr_element_offset=0) { + return TensorView(device_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride(), extent_); + } + + /// Accesses the tensor reference pointing to data + ConstTensorView device_view(LongIndex ptr_element_offset=0) const { + return ConstTensorView(device_data_ptr_offset(ptr_element_offset), layout_, imaginary_stride(), extent_); + } + + /// Accesses the tensor reference pointing to data + cutlass::TensorView device_view_real() { + return cutlass::TensorView(device_data(), layout_, extent_); + } + + /// Accesses the tensor reference pointing to data + cutlass::TensorView device_view_imag() { + return cutlass::TensorView(device_data_ptr_offset(imaginary_stride()), layout_, extent_); + } + + /// Returns true if device memory is allocated + bool device_backed() const { + return (device_.get() == nullptr) ? false : true; + } + + /// Returns the layout object + Layout layout() const { + return layout_; + } + + /// Returns the layout object's stride vector + Stride stride() const { + return layout_.stride(); + } + + /// Returns the layout object's stride in a given physical dimension + Index stride(int dim) const { + return layout_.stride().at(dim); + } + + /// Computes the offset of an index from the origin of the tensor + LongIndex offset(TensorCoord const& coord) const { + return layout_(coord); + } + + /// Returns a reference to the element at the logical Coord in host memory + Reference at(TensorCoord const& coord) { + return host_data(offset(coord)); + } + + /// Returns a const reference to the element at the logical Coord in host memory + ConstReference at(TensorCoord const& coord) const { + return host_data(offset(coord)); + } + + /// Returns the extent of the tensor + TensorCoord extent() const { + return extent_; + } + + /// Returns the extent of the tensor + TensorCoord & extent() { + return extent_; + } + + /// Copies data from device to host + void sync_host() { + if (device_backed()) { + device_memory::copy_to_host( + host_data(), device_data(), imaginary_stride() * 2); + } + } + + /// Copies data from host to device + void sync_device() { + if (device_backed()) { + device_memory::copy_to_device( + device_data(), host_data(), imaginary_stride() * 2); + } + } + + /// Copy data from a caller-supplied device pointer into host memory. + void copy_in_device_to_host( + Element const* ptr_device_real, ///< source device memory + Element const* ptr_device_imag, ///< source device memory + LongIndex count = -1) { ///< number of elements to transfer; if negative, entire tensor is overwritten. 
+ + if (count < 0) { + count = capacity(); + } + else { + count = __NV_STD_MIN(capacity(), count); + } + + device_memory::copy_to_host( + host_data(), ptr_device_real, count); + + device_memory::copy_to_host( + host_data_imag(), ptr_device_imag, count); + } + + /// Copy data from a caller-supplied device pointer into host memory. + void copy_in_device_to_device( + Element const* ptr_device_real, ///< source device memory + Element const* ptr_device_imag, ///< source device memory + LongIndex count = -1) { ///< number of elements to transfer; if negative, entire tensor is overwritten. + + if (count < 0) { + count = capacity(); + } + else { + count = __NV_STD_MIN(capacity(), count); + } + + device_memory::copy_device_to_device( + device_data(), ptr_device_real, count); + + device_memory::copy_device_to_device( + device_data_imag(), ptr_device_imag, count); + } + + /// Copy data from a caller-supplied device pointer into host memory. + void copy_in_host_to_device( + Element const* ptr_host_real, ///< source host memory + Element const* ptr_host_imag, ///< source host memory + LongIndex count = -1) { ///< number of elements to transfer; if negative, entire tensor is overwritten. + + if (count < 0) { + count = capacity(); + } + else { + count = __NV_STD_MIN(capacity(), count); + } + + device_memory::copy_to_device( + device_data(), ptr_host_real, count); + + device_memory::copy_to_device( + device_data_imag(), ptr_host_imag, count); + } + + /// Copy data from a caller-supplied device pointer into host memory. + void copy_in_host_to_host( + Element const* ptr_host_real, ///< source host memory + Element const* ptr_host_imag, ///< source host memory + LongIndex count = -1) { ///< number of elements to transfer; if negative, entire tensor is overwritten. + + if (count < 0) { + count = capacity(); + } + else { + count = __NV_STD_MIN(capacity(), count); + } + + device_memory::copy_host_to_host( + host_data(), ptr_host_real, count); + + device_memory::copy_host_to_host( + host_data_imag(), ptr_host_imag, count); + } + + /// Copy data from a caller-supplied device pointer into host memory. + void copy_out_device_to_host( + Element * ptr_host_real, ///< source device memory + Element * ptr_host_imag, ///< source device memory + LongIndex count = -1) const { ///< number of elements to transfer; if negative, entire tensor is overwritten. + + if (count < 0) { + count = capacity(); + } + else { + count = __NV_STD_MIN(capacity(), count); + } + + device_memory::copy_to_host( + ptr_host_real, device_data(), count); + + device_memory::copy_to_host( + ptr_host_imag, device_data_imag(), count); + } + + /// Copy data from a caller-supplied device pointer into host memory. + void copy_out_device_to_device( + Element * ptr_device_real, ///< source device memory + Element * ptr_device_imag, ///< source device memory + LongIndex count = -1) const { ///< number of elements to transfer; if negative, entire tensor is overwritten. + + if (count < 0) { + count = capacity(); + } + else { + count = __NV_STD_MIN(capacity(), count); + } + + device_memory::copy_device_to_device( + ptr_device_real, device_data(), count); + + device_memory::copy_device_to_device( + ptr_device_imag, device_data_imag(), count); + } + + /// Copy data from a caller-supplied device pointer into host memory. 
+ void copy_out_host_to_device( + Element * ptr_device_real, ///< source device memory + Element * ptr_device_imag, ///< source device memory + LongIndex count = -1) const { ///< number of elements to transfer; if negative, entire tensor is overwritten. + + if (count < 0) { + count = capacity(); + } + else { + count = __NV_STD_MIN(capacity(), count); + } + + device_memory::copy_to_device( + ptr_device_real, host_data(), count); + + device_memory::copy_to_device( + ptr_device_imag, host_data_imag(), count); + } + + /// Copy data from a caller-supplied device pointer into host memory. + void copy_out_host_to_host( + Element * ptr_host_real, ///< source host memory + Element * ptr_host_imag, ///< source host memory + LongIndex count = -1) const { ///< number of elements to transfer; if negative, entire tensor is overwritten. + + if (count < 0) { + count = capacity(); + } + else { + count = __NV_STD_MIN(capacity(), count); + } + + device_memory::copy_host_to_host( + ptr_host_real, host_data(), count); + + device_memory::copy_host_to_host( + ptr_host_imag, host_data_imag(), count); + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/host_uncompress.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/host_uncompress.h new file mode 100644 index 0000000000000000000000000000000000000000..9cd62927432c65ce1f0187f46306f7e1198a1182 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/host_uncompress.h @@ -0,0 +1,157 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +/*! \file + \brief uncompress sparse matrix from the host side +*/ +#pragma once + +#include "cutlass/coord.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/tensor_view.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/host/gemm.h" + +namespace cutlass { + +// uncompress sparse tensor core A matrix +template +void uncompress(TensorRef uncompressed_tensor_a, + TensorRef tensor_a, + TensorRef tensor_e, int row, int col) { + // How many uncompressed data we can get with ElementE meta data + int DecompressedElementsPerElementE = + 256 / cutlass::sizeof_bits::value; + + // Process 4bit meta data a time + int step; + + // 1:2 or 2:4 or 4:8 + int a, b; + + if (cutlass::sizeof_bits::value == 4) { + step = 8; + a = 4; + b = 8; + } else if (cutlass::sizeof_bits::value == 8) { + step = 4; + a = 2; + b = 4; + } else if (cutlass::sizeof_bits::value == 16) { + step = 4; + a = 2; + b = 4; + } else if (cutlass::sizeof_bits::value == 32) { + step = 2; + a = 1; + b = 2; + } + + int ElementsPerE = (cutlass::sizeof_bits::value == 4) ? 2 : 1; + + for (int r = 0; r < row; ++r) { + for (int c = 0; c < (col / DecompressedElementsPerElementE); ++c) { + + ElementE meta = tensor_e.at(MatrixCoord(r, c)); + + for (int i = 0; i < DecompressedElementsPerElementE; i += step) { + int e = (meta >> (i / step * 4)) & 0xf; + int idx0 = e & 0x3; + int idx1 = e >> 2; + + if (a == 1) idx0 = idx0 / 2; + + for (int ii = 0; ii < step; ii += ElementsPerE) { + int real_col = + c * DecompressedElementsPerElementE + i + ii; + int compressed_col = (real_col / b) * a; + + if (ii == (idx0 * ElementsPerE)) { + uncompressed_tensor_a.at(MatrixCoord(r, real_col)) = + tensor_a.at(MatrixCoord(r, compressed_col)); + if (ElementsPerE == 2) + uncompressed_tensor_a.at(MatrixCoord(r, real_col + 1)) = + tensor_a.at(MatrixCoord(r, compressed_col + 1)); + } else if ((ii == (idx1 * ElementsPerE)) && (a != 1)) { + uncompressed_tensor_a.at(MatrixCoord(r, real_col)) = + tensor_a.at(MatrixCoord(r, compressed_col + ElementsPerE)); + if (ElementsPerE == 2) + uncompressed_tensor_a.at(MatrixCoord(r, real_col + 1)) = + tensor_a.at( + MatrixCoord(r, compressed_col + ElementsPerE + 1)); + } else { + uncompressed_tensor_a.at(MatrixCoord(r, real_col)) = + ElementA(0); + if (ElementsPerE == 2) + uncompressed_tensor_a.at(MatrixCoord(r, real_col + 1)) = + ElementA(0); + } + } + } + } + } +} + +// uncompress ELL block sparse matrix +template +void uncompress_ell_block_sparse( + TensorRef uncompressed_tensor_a, + TensorRef tensor_a, + TensorRef ell_idx, + int rows, int cols, + int ell_num_cols, int ell_blocksize) { + + for (int r = 0; r < rows / ell_blocksize; ++r) { + for (int c = 0; c < ell_num_cols / ell_blocksize; ++c) { + + ElementE idx = ell_idx.at(MatrixCoord(r, c)); + + if (idx != -1) { + int row_begin = r * ell_blocksize; + int col_begin_real = idx * ell_blocksize; + int col_begin = c * ell_blocksize; + + for (int i = 0; i < ell_blocksize; ++i) { + for (int j = 0; j < ell_blocksize; ++j) { + uncompressed_tensor_a.at(MatrixCoord(row_begin + i, col_begin_real + j)) = + tensor_a.at( + MatrixCoord(row_begin + i, col_begin +j)); + } + } + } + } + } +} + +} // namespace cutlass + diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/index_sequence.h 
b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/index_sequence.h new file mode 100644 index 0000000000000000000000000000000000000000..6b72b043fc0c1271cf9f12e5cb9a81d29659cb0a --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/index_sequence.h @@ -0,0 +1,38 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" + +// integer_sequence moved to cutlass/numeric_types.h + diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/mixed_dtype_utils.hpp b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/mixed_dtype_utils.hpp new file mode 100644 index 0000000000000000000000000000000000000000..43f5a3f92d29f229703cc4c5f9071c11d0f89df4 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/mixed_dtype_utils.hpp @@ -0,0 +1,472 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Utilities for mixed input data type kernels. +*/ + +#pragma once + +#include +#include "cute/layout.hpp" +#include "cute/tensor.hpp" +#include "cute/arch/mma_sm90.hpp" +#include "cutlass/cutlass.h" +#include "cutlass/util/device_memory.h" +#include "cutlass/util/reference/device/tensor_fill.h" +#include "cute/util/type_traits.hpp" + +namespace cutlass { + +#define CUDA_CHECK(status) \ + { \ + cudaError_t error = status; \ + if (error != cudaSuccess) { \ + std::cerr << "Got bad cuda status: " << cudaGetErrorString(error) \ + << " at line: " << __LINE__ << std::endl; \ + exit(EXIT_FAILURE); \ + } \ + } + +template < + class QuantizedElement, + class DequantizedElement, + class OperandLayout, + class ElementScale, + class ElementZero, + class ScaleBroadCastLayout, + class ThrLayout> +__global__ void dequantize_kernel(DequantizedElement* dq_buffer, + QuantizedElement const* q_buffer, + OperandLayout const operand_layout, + ElementScale const* scale_buffer, + ElementZero const* zero_buffer, + ScaleBroadCastLayout const broadcasted_scale_layout, + ThrLayout thr_layout) { + using namespace cute; + + // Represent the full tensors to gmem elements. 
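The body of dequantize_kernel below is a tiled, predicated version of a simple per-element formula: convert the quantized value to the scale type, multiply by the group's scale, add the group's zero point, and narrow to the dequantized type, with scales and zeros broadcast over group_size consecutive K elements. A plain host-side sketch of that math (an illustration under those assumptions, not the vendored implementation):

#include <cstddef>
#include <vector>

template <class DQ, class Q, class S, class Z>
void dequantize_reference(std::vector<DQ>& dq, std::vector<Q> const& q,
                          std::vector<S> const& scale, std::vector<Z> const& zero,
                          std::size_t K, std::size_t group_size) {
  for (std::size_t k = 0; k < K; ++k) {
    std::size_t g = k / group_size;         // broadcast index along K (assumes K % group_size == 0)
    S v = S(q[k]) * scale[g] + S(zero[g]);  // scale, then shift, in the scale precision
    dq[k] = DQ(v);                          // narrow to the dequantized element type
  }
}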
+ // These are expected to have shape [MN, K, L] + cute::Tensor gmem_op_dq = cute::make_tensor(cute::make_gmem_ptr(dq_buffer), operand_layout); + cute::Tensor gmem_op_q = cute::make_tensor(cute::make_gmem_ptr(q_buffer), operand_layout); + // While the scales are expected to have shape [MN, G, L] but with a stride to allow broadcasting + // It is expected that K % G == 0 + cute::Tensor gmem_scale_broadcasted = cute::make_tensor(make_gmem_ptr(scale_buffer), broadcasted_scale_layout); + cute::Tensor gmem_zero_broadcasted = cute::make_tensor(make_gmem_ptr(zero_buffer), broadcasted_scale_layout); + + // Assign 1 thread per element in the thread block + auto blk_shape = cute::make_shape(size<0>(thr_layout), _1{}, _1{}); // + auto blk_coord = cute::make_coord(_, blockIdx.x, blockIdx.y); // (MN, K, L) + + // Tile across the block + auto gOp_dq = cute::local_tile(gmem_op_dq, blk_shape, blk_coord); + auto gScale = cute::local_tile(gmem_scale_broadcasted, blk_shape, blk_coord); + auto gZero = cute::local_tile(gmem_zero_broadcasted, blk_shape, blk_coord); + auto gOp_q = cute::local_tile(gmem_op_q, blk_shape, blk_coord); + + auto tOpDq_gOpDq = cute::local_partition(gOp_dq, thr_layout, threadIdx.x); + auto tScale_gScale = cute::local_partition(gScale, thr_layout, threadIdx.x); + auto tZero_gZero = cute::local_partition(gZero, thr_layout, threadIdx.x); + auto tOpQ_gOpQ = cute::local_partition(gOp_q, thr_layout, threadIdx.x); + + // Make a fragment of registers to hold gmem loads + cute::Tensor rmem_op_q = cute::make_fragment_like(tOpQ_gOpQ(_, _, _, 0)); + cute::Tensor rmem_scale = cute::make_fragment_like(tScale_gScale(_, _, _, 0)); + cute::Tensor rmem_zero = cute::make_fragment_like(tZero_gZero(_, _, _, 0)); + cute::Tensor rmem_op_dq = cute::make_fragment_like(tOpDq_gOpDq(_, _, _, 0)); + cute::Tensor rmem_op_scaled = cute::make_fragment_like(rmem_op_dq); + cute::Tensor rmem_zero_buf = cute::make_fragment_like(rmem_zero); + + cute::Tensor pred_id = cute::make_identity_tensor(shape(operand_layout)); + auto pred_blk_tile = cute::local_tile(pred_id, blk_shape, blk_coord); + auto pred_thr_partition = cute::local_partition(pred_blk_tile, thr_layout, threadIdx.x); + + const auto num_iters = cute::size<3>(tOpDq_gOpDq); + + for (int ii = 0; ii < num_iters; ++ii) { + const auto thread_offset = cute::get<0>(pred_thr_partition(0, 0, 0, ii)); + if (thread_offset < cute::size<0>(operand_layout)) { + cute::copy(tOpQ_gOpQ(_, _, _, ii), rmem_op_q); + cute::copy(tScale_gScale(_, _, _, ii), rmem_scale); + cute::copy(tZero_gZero(_, _, _, ii), rmem_zero); + cute::transform(rmem_op_q, rmem_op_scaled, [] (const QuantizedElement& elt) { return ElementScale(elt); } ); + cute::transform(rmem_zero, rmem_zero_buf, [] (const ElementZero& elt) { return ElementScale(elt); } ); + cute::transform(rmem_op_scaled, rmem_scale, rmem_op_scaled, cute::multiplies{}); + cute::transform(rmem_op_scaled, rmem_zero_buf, rmem_op_scaled, cute::plus{}); + cute::transform(rmem_op_scaled, rmem_op_dq, [] (const ElementScale& elt) { return DequantizedElement(elt); } ); + cute::copy(rmem_op_dq, tOpDq_gOpDq(_, _, _, ii)); + } + } +} + +template < + class QuantizedElement, + class DequantizedElement, + class OperandLayout, + class ElementScale, + class ElementZero, + class ScaleLayout> +static void dequantize(DequantizedElement* dq_buffer, + QuantizedElement const* q_buffer, + OperandLayout const operand_layout, + ElementScale const* scale_buffer, + ElementZero const* zero_buffer, + ScaleLayout const scale_layout, + int const group_size, + cudaStream_t 
&stream) { + using namespace cute; + + constexpr int tpb = 128; + auto thr_layout = make_layout(make_shape(Int{})); + + const auto num_rows = get<0>(shape(operand_layout)); + const auto gemm_k = get<1>(shape(operand_layout)); // [MN, K, L] + const auto batches = get<2>(shape(operand_layout)); // [MN, K, L] + const auto scale_k = get<1>(shape(scale_layout)); // [MN, Scale_K, L] + + if (num_rows != size<0>(scale_layout)) { + std::cerr << "Invalid first dimension for scales. Must match first dim for weights." + << " But got shapes " << shape(operand_layout) << " " << shape(scale_layout) + << std::endl; + exit(-1); + } + + const auto scale_stride0 = get<0>(stride(scale_layout)); + const auto scale_stride1 = get<1>(stride(scale_layout)); + const auto scale_stride2 = get<2>(stride(scale_layout)); + + auto scale_shape_bcast = make_shape(num_rows, make_shape(group_size, scale_k), batches); + auto scale_stride_bcast = make_stride(scale_stride0, make_stride(0, scale_stride1), scale_stride2); + auto scale_layout_bcast = make_layout(scale_shape_bcast, scale_stride_bcast); + + const auto blocks_x = gemm_k; + const auto blocks_y = batches; + + dim3 blocks(blocks_x, blocks_y, 1); + dequantize_kernel<<>>(dq_buffer, q_buffer, operand_layout, scale_buffer, zero_buffer, scale_layout_bcast, thr_layout); + CUDA_CHECK(cudaStreamSynchronize(stream)); +} + +template +class packed_scale_t { +public: + static_assert(cute::is_same_v || + cute::is_same_v || + cute::is_same_v || + cute::is_same_v, + "only 8 bit arithmetic types are supported."); + CUTLASS_HOST_DEVICE + explicit packed_scale_t(T val) { + if constexpr (!cute::is_unsigned_v) { + // Only pack negative values. The positive values are generated in flight in the mainloop. + storage[0] = pack4(T(float(val) * -8.f), T(float(val) * -7.f), T(float(val) * -6.f), T(float(val) * -5.f)); + storage[1] = pack4(T(float(val) * -4.f), T(float(val) * -3.f), T(float(val) * -2.f), -val); + } + else { + storage[0] = pack4(T(float(val) * 8.f), T(float(val) * 7.f), T(float(val) * 6.f), T(float(val) * 5.f)); + storage[1] = pack4(T(float(val) * 4.f), T(float(val) * 3.f), T(float(val) * 2.f), val); + } + } + CUTLASS_HOST_DEVICE + packed_scale_t() = default; + CUTLASS_HOST_DEVICE + explicit operator float() const { + return float(get()); + } + CUTLASS_HOST_DEVICE + bool operator==(packed_scale_t const& rhs) const { + return storage[0] == rhs.storage[0] && storage[1] == rhs.storage[1]; + } + CUTLASS_HOST_DEVICE + bool operator!=(packed_scale_t const& rhs) const { + return !(*this == rhs); + } + CUTLASS_HOST_DEVICE + friend packed_scale_t operator+(packed_scale_t const& lhs, packed_scale_t const& rhs) { + return packed_scale_t(lhs.get() + rhs.get()); + } + CUTLASS_HOST_DEVICE + friend packed_scale_t operator-(packed_scale_t const& lhs, packed_scale_t const& rhs) { + return packed_scale_t(lhs.get() - rhs.get()); + } + CUTLASS_HOST_DEVICE + friend packed_scale_t operator*(packed_scale_t const& lhs, packed_scale_t const& rhs) { + return packed_scale_t(lhs.get() * rhs.get()); + } + CUTLASS_HOST_DEVICE + friend packed_scale_t operator/(packed_scale_t const& lhs, packed_scale_t const& rhs) { + return packed_scale_t(lhs.get() / rhs.get()); + } + +private: + using Storage = uint32_t; + using Stage = uint8_t; + + Storage storage[2] {}; + + CUTLASS_HOST_DEVICE + static Storage pack4(T c1, T c2, T c3, T c4) { + Storage result = 0; + result |= (static_cast(reinterpret_cast(c4)) << 24); + result |= (static_cast(reinterpret_cast(c3)) << 16); + result |= (static_cast(reinterpret_cast(c2)) << 8); + 
result |= static_cast(reinterpret_cast(c1)); + return result; + } + CUTLASS_HOST_DEVICE + T get() const { + auto stage = static_cast(storage[0] >> 8); + #if defined(__CUDA_ARCH__) + return reinterpret_cast(stage); + #else + T tmp; + std::memcpy(&tmp, &stage, sizeof(Stage)); + return tmp; + #endif + } + CUTLASS_HOST_DEVICE + T get(int idx) const { + Stage stage; + if (idx < 4) stage = static_cast(storage[0] >> (8 * idx)); + else stage = static_cast(storage[1] >> (8 * idx - 32)); + #if defined(__CUDA_ARCH__) + return reinterpret_cast(stage); + #else + T tmp; + std::memcpy(&tmp, &stage, sizeof(Stage)); + return tmp; + #endif + } +}; + +// In the mainloop, PRMT selects 1 byte from only 8 bytes so the sign bit is handled in an extra PRMT. +// Here the encodings of positive values and negative values are unified (except for the sign bit). +// For instance, 1 becomes 0b0111, which is the same encoding as -1 (0b1111). +static bool unified_encode_int4b(cutlass::int4b_t const *block_in, cutlass::int4b_t *block_out, const size_t block_size) { + + using StorageType = cutlass::int4b_t::Storage; + constexpr int pack = cute::sizeof_bits_v / 4; + const size_t host_buf_size = block_size / pack; + std::vector host_buf(host_buf_size); + cutlass::device_memory::copy_to_host(host_buf.data(), (StorageType *) block_in, host_buf_size); + + for (auto&& d : host_buf) { + StorageType out = 0; + StorageType mask = 0x0f; + for (int i = 0; i < pack; i++) { + cutlass::int4b_t curr; + curr.storage = (d >> (i * 4)) & 0x0f; + switch (curr) { + case 1: curr.storage = StorageType(0b0111); break; // 2's complement + case 2: curr.storage = StorageType(0b0110); break; // 2's complement + case 3: curr.storage = StorageType(0b0101); break; // 2's complement + case 4: curr.storage = StorageType(0b0100); break; // 2's complement + case 5: curr.storage = StorageType(0b0011); break; // 2's complement + case 6: curr.storage = StorageType(0b0010); break; // 2's complement + case 7: curr.storage = StorageType(0b0001); break; // 2's complement + default: break; + } + out |= (curr.storage << (4 * i)) & mask; + mask <<= 4; + } + d = out; + } + + cutlass::device_memory::copy_to_device((StorageType*) block_out, host_buf.data(), host_buf_size); + return true; +} + +template +static bool pack_scale_fp8(ElementScale const *block_in, cutlass::Array *block_out, const size_t block_size) { + std::vector data_in(block_size); + std::vector> data_out(block_size); + + try { + cutlass::device_memory::copy_to_host(data_in.data(), block_in, block_size); + } + catch (cutlass::cuda_exception const& e) { + std::cerr << "CUDA Error: " << cudaGetErrorString(e.cudaError()) << std::endl; + return false; + } + + for (size_t i = 0; i < block_size; i++) { + cutlass::packed_scale_t tmp(data_in[i]); + data_out[i] = reinterpret_cast const&>(tmp); + } + + try { + cutlass::device_memory::copy_to_device(block_out, data_out.data(), block_size); + } + catch (cutlass::cuda_exception const& e) { + std::cerr << "CUDA Error: " << cudaGetErrorString(e.cudaError()) << std::endl; + return false; + } + return true; +} + +template +struct UnderlyingElement { + using type = T; +}; + +template +struct UnderlyingElement> { + using type = typename T::Element; +}; + +// Given a type of MMA instruction, compute a memory reordering atom that places all values +// owned by each thread in contiguous memory locations. 
This improves smem load vectorization, +// particularly for mixed dtype GEMMs where a narrow type is loaded in the thread/value order +// of the wider type and may result in inefficient sub-bank (8-bit or 16-bit) accesses. +// In addition, we can reorder the values across several MMA instructions to get even wider +// vectorization (AtomLayout parameter) and permute the values within each instruction to get +// more optimal conversion instruction sequences (ValLayout parameter). +template , + class ValLayout = cute::Layout> +constexpr auto compute_memory_reordering_atom(AtomLayout atom_layout = {}, ValLayout val_layout = {}) +{ + using namespace cute; + + static_assert(is_static_v, "ValLayout must be static"); + static_assert(is_static_v, "AtomLayout must be static"); + + // 1. Choose an MMA atom to access TV layout and MN shape + // Note: parameters like GMMA Major, TileShape, ElementC don't affect TV layout of A, use arbitrary + using MmaAtom = decltype(SM90::GMMA::rs_op_selector>()); + using MmaTraits = MMA_Traits; + auto mk_shape_mma = select<0,2>(typename MmaTraits::Shape_MNK{}); + auto tv_layout_mma = typename MmaTraits::ALayout{}; + static_assert(size<1>(tv_layout_mma) % size(val_layout) == 0, "Value layout must evenly divide the MMA value layout"); + + // 2. Create a single warp's TV layout from that of the whole MMA and invert to get (m,k -> thr,val) + // Note: this assumes A is partitioned between warps along M mode + auto tv_tiler_warp = make_shape(Int<32>{}, size<1>(tv_layout_mma)); + auto mk_shape_warp = shape_div(mk_shape_mma, size(typename MmaTraits::ThrID{}) / Int<32>{}); + auto tv_layout_mma_warp = make_layout_like(composition(tv_layout_mma, tv_tiler_warp)); + auto mk_layout_mma_warp = right_inverse(tv_layout_mma_warp).with_shape(mk_shape_warp); + + // 3. Repeat the warp layout NumAtoms times along K mode to get wider vectorization + auto mk_layout_mma_trgt = blocked_product(mk_layout_mma_warp, atom_layout); + + // 4. Compose with a contiguous layout of values in each thread (required for smem vectorization) + auto val_to_offset = logical_product(val_layout, size<1>(tv_layout_mma) / size(val_layout) * size(atom_layout)); + auto thr_to_offset = make_layout(size<0>(tv_layout_mma_warp)); + auto tv_to_offset = select<1,0>(logical_product(val_to_offset, thr_to_offset)); + auto layout_atom = composition(tv_to_offset, mk_layout_mma_trgt); + + return layout_atom; +} + +template +__global__ void reorder_tensor_kernel( + cute::Tensor S, + cute::Tensor D, + TiledCopy tiled_copy) +{ + using namespace cute; + + using T = typename EngineDst::value_type; + + Tensor gS = local_tile(S, TileShape{}, make_coord(blockIdx.x, _, blockIdx.z)); + Tensor gD = local_tile(D, TileShape{}, make_coord(blockIdx.x, _, blockIdx.z)); + + auto thread_copy = tiled_copy.get_slice(threadIdx.x); + Tensor tS = thread_copy.partition_S(gS); + Tensor tD = thread_copy.partition_D(gD); + + copy(tiled_copy, tS, tD); +} + +template +void reorder_tensor( + cute::Tensor S, + cute::Tensor D) +{ + using namespace cute; + + using T = typename EngineDst::value_type; + static_assert(is_same_v, T>, "Type mismatch"); + + // Construct a value layout that assigns at least 8 bits of contiguous elements in destination tensor to a thread + // This avoids a race condition when writing out subbyte types (e.g. int4b_t). 
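  // (Editorial elaboration, sketch only.) With cutlass::int4b_t the destination packs two
  // values per byte, so each thread must own at least 8 / 4 = 2 contiguous destination values;
  // otherwise two threads would read-modify-write the same output byte. That is exactly the
  // N computed just below: 8 bits divided by the element's bit width,
  // e.g. 8 / cute::sizeof_bits_v<cutlass::int4b_t> == 8 / 4 == 2.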
+ auto has_major_mode = [](auto s) { + return any_of(flatten(s), [](auto a){ return is_constant<1, decltype(a)>{}; }); + }; + static_assert(has_major_mode(stride<0>(LayoutDst{})) ^ has_major_mode(stride<1>(LayoutDst{})), + "Could not find stride-1 mode in destination layout"); + constexpr int N = shape_div(Int<8>{}, Int>{}); + auto val_layout = conditional_return(LayoutDst{}))>( + make_layout(make_shape(Int{}, Int<1>{}), GenColMajor{}), + make_layout(make_shape(Int<1>{}, Int{}), GenRowMajor{})); + + // Make a tiled copy with a simple row-major thread order and above layout + int constexpr NumThreads = 128; + auto const thr_layout = make_layout(make_shape(Int<1>{}, Int{})); + auto tiled_copy = make_tiled_copy(Copy_Atom{}, thr_layout, val_layout); + + // Assign a group of 16 rows to a threadblock; this matches the shuffle atom size for Hopper + using TileShape = Shape<_16>; + auto tiled_D = group_modes<3,rank_v>(tiled_divide(D, TileShape{})); + dim3 blocks{unsigned(size<1>(tiled_D)), 1u, unsigned(size<3>(tiled_D))}; + + reorder_tensor_kernel<<>>(S, D, tiled_copy); + CUDA_CHECK(cudaDeviceSynchronize()); +} + +// In-place version +template +void reorder_tensor( + T const* src, + LayoutSrc const& layout_src, + T * dst, + LayoutDst const& layout_dst) +{ + using namespace cute; + reorder_tensor(make_tensor(make_gmem_ptr(src), layout_src), + make_tensor(make_gmem_ptr(dst), layout_dst)); +} + +// In-place version +template +void reorder_tensor( + T * data, + LayoutSrc const& layout_src, + LayoutDst const& layout_dst) +{ + using namespace cute; + cutlass::DeviceAllocation temp(size(layout_src)); + reorder_tensor(data, layout_src, temp.get(), layout_dst); + cutlass::device_memory::copy_device_to_device(data, temp.get(), static_cast(size(layout_src))); +} + +#undef CUDA_CHECK + +} // namespace cutlass diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/packed_stride.hpp b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/packed_stride.hpp new file mode 100644 index 0000000000000000000000000000000000000000..811ba152ab7c6e8fafc1cebdbb3726798fd16b3c --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/packed_stride.hpp @@ -0,0 +1,570 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
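Returning briefly to unified_encode_int4b in mixed_dtype_utils.hpp above: the switch re-encodes a positive 4-bit magnitude m in 1..7 as 8 - m, so that below the sign bit it matches the two's-complement pattern of -m, while zero and negative values pass through unchanged. A minimal self-check of that mapping (illustrative only; unify_nibble is a hypothetical stand-in operating on raw 4-bit patterns):

#include <cstdint>

constexpr uint8_t unify_nibble(uint8_t nibble) {  // nibble is the raw 4-bit pattern
  bool negative = (nibble & 0x8) != 0;
  uint8_t magnitude = nibble & 0x7;
  if (!negative && magnitude != 0) {
    return uint8_t(8 - magnitude);                // e.g. +2 (0b0010) -> 0b0110
  }
  return nibble;                                  // 0 and negative values are unchanged
}

static_assert(unify_nibble(0b0001) == 0b0111, "+1 matches -1 (0b1111) below the sign bit");
static_assert(unify_nibble(0b0111) == 0b0001, "+7 matches -7 (0b1001) below the sign bit");
static_assert(unify_nibble(0b1111) == 0b1111, "-1 is unchanged");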
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Utilities for packing constructing canonical CuTe stride types for 3.x mainloop params. +*/ + +#pragma once + +#include "cute/layout.hpp" +#include "cute/container/array.hpp" // cute::array +#include "cutlass/conv/convolution.h" // cutlass::conv::Operator + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Strides without batch mode + +template +CUTLASS_HOST_DEVICE +cute::Stride> +make_cute_packed_stride(cute::Stride> s, cute::Shape shape_MKL) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + auto s_copy = s; + cute::get<0>(s_copy) = static_cast(cute::get<1>(shape_MKL)); + return s_copy; +} + +template +CUTLASS_HOST_DEVICE +cute::Stride, IntT> +make_cute_packed_stride(cute::Stride, IntT> s, cute::Shape shape_MKL) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + auto s_copy = s; + cute::get<1>(s_copy) = static_cast(cute::get<0>(shape_MKL)); + return s_copy; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Strides with batch mode + +template +CUTLASS_HOST_DEVICE +cute::Stride, int64_t> +make_cute_packed_stride(cute::Stride, int64_t> s, cute::Shape shape_MKL) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + auto s_copy = s; + cute::get<0>(s_copy) = static_cast(cute::get<1>(shape_MKL)); + int batch_count = cute::get<2>(shape_MKL); + if (batch_count > 1) { + cute::get<2>(s_copy) = static_cast(cute::get<0>(shape_MKL) * cute::get<1>(shape_MKL)); + } + else { + cute::get<2>(s_copy) = static_cast(0); + } + return s_copy; +} + +template +CUTLASS_HOST_DEVICE +cute::Stride, IntT, int64_t> +make_cute_packed_stride(cute::Stride, IntT, int64_t> s, cute::Shape shape_MKL) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. 
Static strides not supported."); + auto s_copy = s; + cute::get<1>(s_copy) = static_cast(cute::get<0>(shape_MKL)); + int batch_count = cute::get<2>(shape_MKL); + if (batch_count > 1) { + cute::get<2>(s_copy) = static_cast(cute::get<0>(shape_MKL) * cute::get<1>(shape_MKL)); + } + else { + cute::get<2>(s_copy) = static_cast(0); + } + return s_copy; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Strides with group mode + +template +CUTLASS_HOST_DEVICE +cute::Stride, cute::Int<0>> +make_cute_packed_stride(cute::Stride, cute::Int<0>> s, cute::Shape shape_MKL) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + auto s_copy = s; + cute::get<0>(s_copy) = static_cast(cute::get<1>(shape_MKL)); + return s_copy; +} + +template +CUTLASS_HOST_DEVICE +cute::Stride, StrideIntT, cute::Int<0>> +make_cute_packed_stride(cute::Stride, StrideIntT, cute::Int<0>> s, cute::Shape shape_MKL) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + auto s_copy = s; + cute::get<1>(s_copy) = static_cast(cute::get<0>(shape_MKL)); + return s_copy; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Strides for convolutions + +// Output cutlass::layout::TensorNDHWC -> rank-3 stride (InT,_1,_0) +// Note: For fprop/dgrad kernel, strides are assumed to be layout right in NZPQK/NDHWC order +// and therefore can be coalesced to just q/w. For wgrad kernel, strides are assumed to be layout +// right in KTRSC order and can be coalesced to just k. +// We enforce this condition here with asserts. +template +CUTLASS_HOST_DEVICE +cute::Stride, cute::Int<0>> +make_cute_packed_stride( + cute::Stride, cute::Int<0>> s, + cute::array shape_output, + cute::array stride_output, + cutlass::conv::Operator conv_op) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + static_assert(RankT_ >= 3u); + constexpr static int RankT = static_cast(RankT_); + + assert(stride_output[RankT-1] == 1); + cute::for_each(cute::make_seq{}, [&](auto i) { + assert(stride_output[i] == shape_output[i+1] * stride_output[i+1]); + }); + + auto s_copy = s; + cute::get<0>(s_copy) = (conv_op == cutlass::conv::Operator::kWgrad) ? + stride_output[0] : + stride_output[RankT-2]; + return s_copy; +} + +// +// Activation tensor ((w, h, d, n), _1) for fprop kernel +// + +// Activation cutlass::layout::TensorNWC -> rank-2 stride ((W,N),_1) +template +CUTLASS_HOST_DEVICE +cute::Stride, cute::Int<1>> +make_cute_packed_stride( + cute::Stride, cute::Int<1>> s, + cute::array stride_nwc, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + assert(stride_nwc[2] == 1); + auto s_copy = s; + cute::get<0,0>(s_copy) = stride_nwc[1]; + cute::get<0,1>(s_copy) = stride_nwc[0]; + return s_copy; +} + +// Activation cutlass::layout::TensorNHWC -> rank-2 stride ((W,H,N),_1) +template +CUTLASS_HOST_DEVICE +cute::Stride, cute::Int<1>> +make_cute_packed_stride( + cute::Stride, cute::Int<1>> s, + cute::array stride_nhwc, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. 
Static strides not supported."); + assert(stride_nhwc[3] == 1); + auto s_copy = s; + cute::for_each(cute::make_seq<3>{}, [&](auto i) { + cute::get<0,i>(s_copy) = stride_nhwc[2-i]; + }); + return s_copy; +} + +// Activation cutlass::layout::TensorNDHWC -> rank-2 stride ((W,H,D,N),_1) +template +CUTLASS_HOST_DEVICE +cute::Stride, cute::Int<1>> +make_cute_packed_stride( + cute::Stride, cute::Int<1>> s, + cute::array stride_ndhwc, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + + assert(stride_ndhwc[4] == 1); + auto s_copy = s; + cute::for_each(cute::make_seq<4>{}, [&](auto i) { + cute::get<0,i>(s_copy) = stride_ndhwc[3-i]; + }); + return s_copy; +} + +// +// Filter tensor (k, (_1, s, r, t)) for fprop kernel +// + +// Filter cutlass::layout::TensorNWC -> rank-2 stride (k, (_1, s)) +template +CUTLASS_HOST_DEVICE +cute::Stride, IntT>> +make_cute_packed_stride( + cute::Stride, IntT>> s, + cute::array stride_ksc, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + + assert(stride_ksc[2] == 1); + auto s_copy = s; + cute::get<0,0>(s_copy) = stride_ksc[0]; + cute::get<1,1>(s_copy) = stride_ksc[1]; + return s_copy; +} + +// Filter cutlass::layout::TensorNHWC -> rank-2 stride (k, (_1, s, r)) +template +CUTLASS_HOST_DEVICE +cute::Stride, IntT, IntT>> +make_cute_packed_stride( + cute::Stride, IntT, IntT>> s, + cute::array stride_krsc, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + + assert(stride_krsc[3] == 1); + auto s_copy = s; + cute::get<0,0>(s_copy) = stride_krsc[0]; + cute::for_each(cute::make_seq<2>{}, [&](auto i) { + cute::get<1,2-i>(s_copy) = stride_krsc[i+1]; + }); + return s_copy; +} + +// Filter cutlass::layout::TensorNDHWC -> rank-2 stride (k, (_1, s, r, t)) +template +CUTLASS_HOST_DEVICE +cute::Stride, IntT, IntT, IntT>> +make_cute_packed_stride( + cute::Stride, IntT, IntT, IntT>> s, + cute::array stride_ktrsc, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + + assert(stride_ktrsc[4] == 1); + auto s_copy = s; + cute::get<0,0>(s_copy) = stride_ktrsc[0]; + cute::for_each(cute::make_seq<3>{}, [&](auto i) { + cute::get<1,3-i>(s_copy) = stride_ktrsc[i+1]; + }); + return s_copy; +} + +// +// Activation tensor (_1, (w, h, d, n)) for wgrad kernel +// +// It is also Filter tensor ((_1), (k, s, r, t)) for dgrad kernel +// + +// Activation cutlass::layout::TensorNWC -> rank-2 stride (_1, (W,N)) in wgrad +// Filter cutlass::layout::TensorNWC -> rank-2 stride ((_1), (k, s)) in dgrad +template +CUTLASS_HOST_DEVICE +cute::Stride, cute::Stride> +make_cute_packed_stride( + cute::Stride, cute::Stride> s, + cute::array stride_nwc, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + + assert(stride_nwc[2] == 1); + auto s_copy = s; + if (ConvOp == cutlass::conv::Operator::kWgrad) { + cute::get<1,0>(s_copy) = stride_nwc[1]; + cute::get<1,1>(s_copy) = stride_nwc[0]; + } + else if (ConvOp == cutlass::conv::Operator::kDgrad) { + // stride_nwc in dgrad is ksc. 
+ cute::get<1,0>(s_copy) = stride_nwc[0]; + cute::get<1,1>(s_copy) = stride_nwc[1]; + } + return s_copy; +} + +// Activation cutlass::layout::TensorNHWC -> rank-2 stride (_1, (W,H,N)) in wgrad +// Filter cutlass::layout::TensorNHWC -> rank-2 stride ((_1), (k, s, r)) in dgrad +template +CUTLASS_HOST_DEVICE +cute::Stride, cute::Stride> +make_cute_packed_stride( + cute::Stride, cute::Stride> s, + cute::array stride_nhwc, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + + assert(stride_nhwc[3] == 1); + auto s_copy = s; + if (ConvOp == cutlass::conv::Operator::kWgrad) { + cute::for_each(cute::make_seq<3>{}, [&](auto i) { + cute::get<1,i>(s_copy) = stride_nhwc[2-i]; + }); + } + else if (ConvOp == cutlass::conv::Operator::kDgrad) { + // stride_nhwc in dgrad is krsc. + cute::get<1,0>(s_copy) = stride_nhwc[0]; + cute::for_each(cute::make_seq<2>{}, [&](auto i) { + cute::get<1,2-i>(s_copy) = stride_nhwc[i+1]; + }); + } + return s_copy; +} + +// Activation cutlass::layout::TensorNDHWC -> rank-2 stride (_1, (W,H,D,N)) in wgrad +// Filter cutlass::layout::TensorNDHWC -> rank-2 stride ((_1), (k, s, r, t)) in dgrad +template +CUTLASS_HOST_DEVICE +cute::Stride, cute::Stride> +make_cute_packed_stride( + cute::Stride, cute::Stride> s, + cute::array stride_ndhwc, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + + assert(stride_ndhwc[4] == 1); + auto s_copy = s; + if (ConvOp == cutlass::conv::Operator::kWgrad) { + cute::for_each(cute::make_seq<4>{}, [&](auto i) { + cute::get<1,i>(s_copy) = stride_ndhwc[3-i]; + }); + } + else if (ConvOp == cutlass::conv::Operator::kDgrad) { + // stride_ndhwc in dgrad is ktrsc. + cute::get<1,0>(s_copy) = stride_ndhwc[0]; + cute::for_each(cute::make_seq<3>{}, [&](auto i) { + cute::get<1,3-i>(s_copy) = stride_ndhwc[i+1]; + }); + } + return s_copy; +} + +// +// NZPQ tensor (_1, nzpq) for wgrad kernel +// + +// cutlass::layout::TensorNWC -> rank-2 stride (_1, nzpq) +template +CUTLASS_HOST_DEVICE +cute::Stride, IntT> +make_cute_packed_stride( + cute::Stride, IntT> s, + cute::array stride_nqk, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + + assert(stride_nqk[2] == 1); + auto s_copy = s; + cute::get<1>(s_copy) = stride_nqk[1]; + return s_copy; +} + +// cutlass::layout::TensorNHWC -> rank-2 stride (_1, nzpq) +template +CUTLASS_HOST_DEVICE +cute::Stride, IntT> +make_cute_packed_stride( + cute::Stride, IntT> s, + cute::array stride_npqk, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + + assert(stride_npqk[3] == 1); + auto s_copy = s; + cute::get<1>(s_copy) = stride_npqk[2]; + return s_copy; +} + +// cutlass::layout::TensorNDHWC -> rank-2 stride (_1, nzpq) +template +CUTLASS_HOST_DEVICE +cute::Stride, IntT> +make_cute_packed_stride( + cute::Stride, IntT> s, + cute::array stride_nzpqk, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. 
Static strides not supported."); + + assert(stride_nzpqk[4] == 1); + auto s_copy = s; + cute::get<1>(s_copy) = stride_nzpqk[3]; + return s_copy; +} + + + +// +// Wgrad output tensor (k, (_1, s, r, t), _0) +// + +// Filter cutlass::layout::TensorKCS -> rank-3 stride (k, (_1, s), _0) +template +CUTLASS_HOST_DEVICE +cute::Stride, IntT>, cute::Int<0>> +make_cute_packed_stride( + cute::Stride, IntT>, cute::Int<0>> s, + [[maybe_unused]] cute::array shape_output, + cute::array stride_ksc, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + + assert(stride_ksc[2] == 1); + auto s_copy = s; + cute::get<0,0>(s_copy) = stride_ksc[0]; + cute::get<1,1>(s_copy) = stride_ksc[1]; + return s_copy; +} + +// Filter cutlass::layout::TensorKCSR -> rank-3 stride (k, (_1, s, r), _0) +template +CUTLASS_HOST_DEVICE +cute::Stride, IntT, IntT>, cute::Int<0>> +make_cute_packed_stride( + cute::Stride, IntT, IntT>, cute::Int<0>> s, + [[maybe_unused]] cute::array shape_output, + cute::array stride_krsc, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + + assert(stride_krsc[3] == 1); + auto s_copy = s; + cute::get<0,0>(s_copy) = stride_krsc[0]; + cute::for_each(cute::make_seq<2>{}, [&](auto i) { + cute::get<1,2-i>(s_copy) = stride_krsc[i+1]; + }); + return s_copy; +} + +// Filter cutlass::layout::TensorKCSRT -> rank-3 stride (k, (_1, s, r, t), _0) +template +CUTLASS_HOST_DEVICE +cute::Stride, IntT, IntT, IntT>, cute::Int<0>> +make_cute_packed_stride( + cute::Stride, IntT, IntT, IntT>, cute::Int<0>> s, + [[maybe_unused]] cute::array shape_output, + cute::array stride_ktrsc, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + + assert(stride_ktrsc[4] == 1); + auto s_copy = s; + cute::get<0,0>(s_copy) = stride_ktrsc[0]; + cute::for_each(cute::make_seq<3>{}, [&](auto i) { + cute::get<1,3-i>(s_copy) = stride_ktrsc[i+1]; + }); + return s_copy; +} + + +// +// Wgrad output tensor ((_1, s, r, t), k, _0) +// + +// Filter cutlass::layout::TensorCSK -> rank-3 stride ((_1, s), k, _0) +template +CUTLASS_HOST_DEVICE +cute::Stride, IntT>, IntT, cute::Int<0>> +make_cute_packed_stride( + cute::Stride, IntT>, IntT, cute::Int<0>> s, + [[maybe_unused]] cute::array shape_output, + cute::array stride_ksc, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + + assert(stride_ksc[2] == 1); + auto s_copy = s; + cute::get<1,0>(s_copy) = stride_ksc[0]; + cute::get<0,1>(s_copy) = stride_ksc[1]; + return s_copy; +} + +// Filter cutlass::layout::TensorCSRK -> rank-3 stride ((_1, s, r), k, _0) +template +CUTLASS_HOST_DEVICE +cute::Stride, IntT, IntT>, IntT, cute::Int<0>> +make_cute_packed_stride( + cute::Stride, IntT, IntT>, IntT, cute::Int<0>> s, + [[maybe_unused]] cute::array shape_output, + cute::array stride_krsc, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. 
Static strides not supported."); + + assert(stride_krsc[3] == 1); + auto s_copy = s; + cute::get<1,0>(s_copy) = stride_krsc[0]; + cute::for_each(cute::make_seq<2>{}, [&](auto i) { + cute::get<0,2-i>(s_copy) = stride_krsc[i+1]; + }); + return s_copy; +} + +// Filter cutlass::layout::TensorCSRTK -> rank-3 stride ((_1, s, r, t), k, _0) +template +CUTLASS_HOST_DEVICE +cute::Stride, IntT, IntT, IntT>, IntT, cute::Int<0>> +make_cute_packed_stride( + cute::Stride, IntT, IntT, IntT>, IntT, cute::Int<0>> s, + [[maybe_unused]] cute::array shape_output, + cute::array stride_ktrsc, + conv::Operator ConvOp) { + static_assert(std::is_integral_v, + "Stride must have an integral type so it can be set dynamically. Static strides not supported."); + + assert(stride_ktrsc[4] == 1); + auto s_copy = s; + cute::get<1,0>(s_copy) = stride_ktrsc[0]; + cute::for_each(cute::make_seq<3>{}, [&](auto i) { + cute::get<0,3-i>(s_copy) = stride_ktrsc[i+1]; + }); + return s_copy; +} +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/print_error.hpp b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/print_error.hpp new file mode 100644 index 0000000000000000000000000000000000000000..c38ad3f710c18e5be1bb7e01dc66d7efcd2646d9 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/print_error.hpp @@ -0,0 +1,341 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
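Looking back at the GEMM-shaped make_cute_packed_stride overloads at the top of packed_stride.hpp above: for a K-major (row-major) operand of shape (M, K, L), the packed stride works out to (K, 1, M*K), with the batch stride set to 0 when L == 1 so a single matrix can be broadcast across the batch. A tiny stand-alone mimic of that logic (illustrative only, not the CuTe stride types):

#include <array>
#include <cstdint>

std::array<int64_t, 3> packed_stride_mkl(int M, int K, int L) {
  return {int64_t(K), 1, L > 1 ? int64_t(M) * int64_t(K) : 0};
}
// packed_stride_mkl(128, 64, 2) == {64, 1, 8192}
// packed_stride_mkl(128, 64, 1) == {64, 1, 0}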
+ * + **************************************************************************************************/ + +#pragma once + +#include +#include +#include +#include +#include + +#include +#include + +#include +#include + +#include + +// The computed infinity norm does not include +// any NaN column absolute-value sums. +struct matrix_inf_norm_result { + // Accumulate errors in double, as this is generally + // the highest precision that the examples use. + double inf_norm = 0.0; + bool found_nan = false; +}; + +// In theory, cute::Tensor, T> could be treated as a view type, +// and thus passed by value (as std::span or std::string_view would be). +// However, generic cute::Tensor are more like containers +// and thus are best passed by reference or const reference. +template +matrix_inf_norm_result +matrix_inf_norm(cute::Tensor const& host_matrix) +{ + using error_type = decltype(std::declval().inf_norm); + using element_type = typename EngineType::value_type; + + error_type inf_norm = 0.0; + bool found_nan = false; + + // Computing the infinity norm requires that we be able + // to treat the input as a matrix, with rows and columns. + const int64_t num_rows = cute::size<0>(host_matrix); + const int64_t num_cols = cute::size<1>(host_matrix); + + auto abs_fn = [] (element_type A_ij) { + if constexpr (not std::is_unsigned_v) { + using std::abs; + return abs(A_ij); + } + else { + return A_ij; + } + }; + + for (int64_t i = 0; i < num_rows; ++i) { + error_type row_abs_sum = 0.0; + for(int64_t j = 0; j < num_cols; ++j) { + row_abs_sum += abs_fn(host_matrix(i, j)); + } + if (std::isnan(row_abs_sum)) { + found_nan = true; + } + else { + inf_norm = row_abs_sum > inf_norm ? row_abs_sum : inf_norm; + } + } + + return {inf_norm, found_nan}; +} + +// Infinity norm of (X - Y). +template +matrix_inf_norm_result +matrix_diff_inf_norm(cute::Tensor const& X, + cute::Tensor const& Y) +{ + using error_type = decltype(std::declval().inf_norm); + using element_type = typename EngineType::value_type; + + auto abs_fn = [] (element_type A_ij) { + if constexpr (not std::is_unsigned_v) { + using std::abs; + return abs(A_ij); + } + else { + return A_ij; + } + }; + + assert(cute::size<0>(X) == cute::size<0>(Y)); + assert(cute::size<1>(X) == cute::size<1>(Y)); + + // Computing the infinity norm requires that we be able + // to treat the input as a matrix, with rows and columns. + const int64_t num_rows = cute::size<0>(X); + const int64_t num_cols = cute::size<1>(X); + + error_type inf_norm = 0.0; + bool found_nan = false; + + for (int64_t i = 0; i < num_rows; ++i) { + error_type row_abs_sum = 0.0; + for (int64_t j = 0; j < num_cols; ++j) { + row_abs_sum += error_type(abs_fn(element_type(X(i,j)) - + element_type(Y(i,j)))); + } + if (std::isnan(row_abs_sum)) { + found_nan = true; + } + else { + inf_norm = row_abs_sum > inf_norm ? row_abs_sum : inf_norm; + } + } + + return {inf_norm, found_nan}; +} + +template +auto +print_matrix_multiply_mollified_relative_error( + char const A_value_type_name[], + cute::Tensor const& A, + char const B_value_type_name[], + cute::Tensor const& B, + char const C_value_type_name[], + cute::Tensor const& C, + cute::Tensor const& C_ref) +{ + const auto [A_norm, A_has_nan] = matrix_inf_norm(A); + const auto [B_norm, B_has_nan] = matrix_inf_norm(B); + const auto [C_norm, C_has_nan] = matrix_inf_norm(C_ref); + const auto [diff_norm, diff_has_nan] = matrix_diff_inf_norm(C, C_ref); + + const auto A_norm_times_B_norm = A_norm * B_norm; + const auto relative_error = A_norm_times_B_norm == 0.0 ? 
+ diff_norm : (diff_norm / A_norm_times_B_norm); + + // For expected error bounds, please refer to the LAPACK Users' Guide, + // in particular https://netlib.org/lapack/lug/node108.html . + // Printing the infinity norm of C is a way to check + // that both the function being tested (C) + // and the reference implementation (C_ref) + // don't just do nothing (or fill with zeros). + using std::cout; + using cute::shape; + cout << "Matrix A: " << shape<0>(A) << "x" << shape<1>(A) << " of " << A_value_type_name << '\n' + << "Matrix B: " << shape<0>(B) << "x" << shape<1>(B) << " of " << B_value_type_name << '\n' + << "Matrix C: " << shape<0>(C) << "x" << shape<1>(C) << " of " << C_value_type_name << '\n' + << std::scientific + << "Infinity norm of A: " << A_norm << '\n' + << "Infinity norm of B: " << B_norm << '\n' + << "Infinity norm of C: " << C_norm << '\n' + << "Infinity norm of (C - C_ref): " << diff_norm << '\n'; + + if(A_norm_times_B_norm == 0.0) { + cout << "Mollified relative error: " << relative_error << '\n'; + } else { + cout << "Relative error: " << relative_error << '\n'; + } + + if (A_has_nan || B_has_nan || C_has_nan || diff_has_nan) { + cout << "Did we encounter NaN in A? " << (A_has_nan ? "yes" : "no") << '\n' + << "Did we encounter NaN in B? " << (B_has_nan ? "yes" : "no") << '\n' + << "Did we encounter NaN in C? " << (C_has_nan ? "yes" : "no") << '\n' + << "Did we encounter NaN in (C - C_ref)? " << (diff_has_nan ? "yes" : "no") << '\n'; + } + return relative_error; +} + +template +auto +print_matrix_multiply_mollified_relative_error( + const char value_type_name[], + const cute::Tensor& A, + const cute::Tensor& B, + const cute::Tensor& C_computed, + const cute::Tensor& C_expected) +{ + return print_matrix_multiply_mollified_relative_error(value_type_name, A, value_type_name, B, + value_type_name, C_computed, C_expected); +} + +// Take a CUTLASS HostTensor (or the like) as input, +// and return a const CuTe Tensor. +// This is useful for use with the above error printing functions. +// This implicitly "transposes" if the layout is RowMajor. +// Note that the HostTensor must be captured by nonconst reference +// in order for X.host_ref().data() to compile. +// (CUTLASS is a bit more container-y than CuTe.) +template +auto host_matrix_to_const_cute_tensor(CutlassHostTensorType& X) +{ + // The tensors were created with post-transposed extents. + const auto extents = X.extent(); + const auto shape = cute::Shape{extents[0], extents[1]}; + // Both RowMajor and ColumnMajor only store one stride. + const int LDX = X.stride(0); + const auto strides = [&]() { + using input_layout_type = typename std::decay_t::Layout; + if constexpr (std::is_same_v) { + return cute::Stride{1, LDX}; + } + else { + static_assert(std::is_same_v); + return cute::Stride{LDX, 1}; + } + }(); + const auto layout = cute::make_layout(shape, strides); + auto X_data = X.host_ref().data(); + auto X_data_const = const_cast >(X_data); + return cute::make_tensor(X_data_const, layout); +}; + + +// Returns EXIT_SUCCESS if the 2-norm relative error is exactly zero, else returns EXIT_FAILURE. +// This makes the return value suitable as the return value of main(). 
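// (Editorial summary, not part of the vendored header.) Two error measures appear in this file:
// print_matrix_multiply_mollified_relative_error above reports
//     ||C - C_ref||_inf / (||A||_inf * ||B||_inf),
// falling back to plain ||C - C_ref||_inf when the denominator is zero (the "mollified" case),
// while print_relative_error below reports the 2-norm ratio
//     sqrt( sum_i |data_i - ref_i|^2 / sum_i |ref_i|^2 ).
// Worked example: data = {1.0, 2.0}, reference = {1.0, 2.5} gives sqrt(0.25 / 7.25) ~= 0.186.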
+template +int +print_relative_error( + std::size_t n, + T1 const& data, + T2 const& reference, + bool print_verbose = false, + bool print_error = true, + double error_margin = 0.00001) { + using std::abs; using std::sqrt; + + // Use either double or complex for error computation + using value_type = cute::remove_cvref_t; + using error_type = std::conditional_t::value, + cute::complex, + double>; + + if (print_verbose) { + std::cout << "Idx:\t"<< "Val\t" << "RefVal\t" << "RelError" << std::endl; + } + + double eps = 1e-200; + + double tot_error_sq = 0; + double tot_norm_sq = 0; + double tot_ind_rel_err = 0; + double max_ind_rel_err = 0; + double max_diff = 0; + for (std::size_t i = 0; i < n; ++i) { + error_type val = data[i]; + error_type ref = reference[i]; + + double aref = abs(ref); + double diff = abs(ref - val); + double rel_error = diff / (aref + eps); + + // Individual relative error + tot_ind_rel_err += rel_error; + + // Maximum relative error + max_ind_rel_err = std::max(max_ind_rel_err, rel_error); + + // Maximum delta in value error + max_diff = std::max(max_diff, diff); + + // Total relative error + tot_error_sq += diff * diff; + tot_norm_sq += aref * aref; + + if (print_verbose) { + std::cout << i << ":\t" << val << "\t" << ref << "\t" << rel_error << std::endl; + } + } + + double ave_rel_err = tot_ind_rel_err / double(n); + if (print_error) { + printf("Average relative error: %.3e\n", ave_rel_err); + } + + if (print_error) { + printf("Maximum relative error: %.3e\n", max_ind_rel_err); + } + + if (print_error) { + printf("Maximum difference : %.3e\n", max_diff); + } + + double tot_rel_err = sqrt(tot_error_sq/(tot_norm_sq+eps)); + if (print_error) { + printf("Vector relative error: %.3e\n", tot_rel_err); + } + + printf("Vector reference norm: %.3e\n", sqrt(tot_norm_sq)); + + return (tot_rel_err <= error_margin) ? EXIT_SUCCESS : EXIT_FAILURE; +} + +// Overload for cute::Tensor<> +template +int +print_relative_error( + cute::Tensor data, + cute::Tensor reference, + bool print_verbose = false, + bool print_error = true, + double error_margin = 0.00001) { + assert(size(data) == size(reference)); + return print_relative_error(static_cast(size(data)), + data, reference, + print_verbose, print_error, error_margin); +} diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/detail/inner_product.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/detail/inner_product.h new file mode 100644 index 0000000000000000000000000000000000000000..8167c91bf2330d160a78ba210449357b395964ca --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/detail/inner_product.h @@ -0,0 +1,135 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for GEMM in host-side code. +*/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" + +namespace cutlass { +namespace reference { +namespace detail { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Template function to compute an inner product. +#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate with a + // host-only type +template +CUTLASS_HOST_DEVICE +Ctype inner_product(Atype a, Btype b, Ctype c) { + return Ctype(a) * Ctype(b) + c; +} + +/// Specialization for matrix multiplication with binary operands +template <> +CUTLASS_HOST_DEVICE +int inner_product, Array, int>( + Array a, + Array b, + int c) { + + int accum = 0; + for (int bit = 0; bit < 32; bit++) { + accum += a[bit] ^ b[bit]; + } + return accum + c; +} + +/* +/// Specialization for matrix multiplication with signed 4-bit integer operands +template <> +CUTLASS_HOST_DEVICE +int inner_product, Array, int>( + Array a, + Array b, + int c) { + + int accum = 0; + for (int k = 0; k < 8; k++) { + accum += a[k] * b[k]; + } + return accum + c; +} + +/// Specialization for matrix multiplication with unsigned 4-bit integer operands +template <> +CUTLASS_HOST_DEVICE +int inner_product, Array, int>( + Array a, + Array b, + int c) { + + int accum = 0; + for (int k = 0; k < 8; k++) { + accum += a[k] * b[k]; + } + return accum + c; +} +*/ + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Cast { + // Default behavior: convert to the destination type +#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex with a + // host-only type + CUTLASS_HOST_DEVICE + static DstType apply(SrcType src) { return static_cast(src); }; +}; + +template <> +struct Cast { + CUTLASS_HOST_DEVICE + static int8_t apply(float src) { + // Clamp to the range of signed 8-bit integers. + return static_cast(fmaxf(-128.f, fminf(127.f, src))); + }; +}; + +template <> +struct Cast { + CUTLASS_HOST_DEVICE + static uint8_t apply(float src) { + // Clamp to the range of signed 8-bit integers. 
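    // (Editorial note, illustrative only.) Here the clamp range is the unsigned one, [0, 255]:
    // e.g. 300.f saturates to 255 and -7.f saturates to 0, whereas the signed specialization
    // above would give 127 and -7 respectively.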
+ return static_cast(fmaxf(0.f, fminf(255.f, src))); + }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace detail +} // namespace reference +} // namespace cutlass + diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/detail/linear_to_coordinate.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/detail/linear_to_coordinate.h new file mode 100644 index 0000000000000000000000000000000000000000..652d622586cb202ecfe69ac892978b649b5d1be7 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/detail/linear_to_coordinate.h @@ -0,0 +1,94 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for GEMM in host-side code. 
+*/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/coord.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace reference { +namespace detail { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct LinearToCoordinateHelper { + + CUTLASS_HOST_DEVICE + void operator()(Coord &coord, int64_t idx, Coord const &extent) const { + + int64_t prod = 1; + + CUTLASS_PRAGMA_UNROLL + for (int i = Rank - Index; i < Rank; ++i) { + prod *= int64_t(extent[i]); + } + + coord[Rank - Index - 1] = int(idx / prod); + + int64_t residual = idx % prod; + LinearToCoordinateHelper()(coord, residual, extent); + } +}; + +template +struct LinearToCoordinateHelper { + + CUTLASS_HOST_DEVICE + void operator()(Coord &coord, int64_t idx, Coord const &) const { + coord[Rank - 1] = int(idx); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct LinearToCoordinate { + + CUTLASS_HOST_DEVICE + void operator()(Coord &coord, int64_t idx, Coord const &extent) const { + LinearToCoordinateHelper()(coord, idx, extent); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace detail +} // namespace reference +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/convolution.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/convolution.h new file mode 100644 index 0000000000000000000000000000000000000000..7c6f803c47f5c407cf058d40bc8274a448a36dc4 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/convolution.h @@ -0,0 +1,1549 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
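Back in linear_to_coordinate.h above: LinearToCoordinate unflattens a linear index against an extent with the last rank varying fastest. A worked example (illustrative rank-3 mimic, not the CUTLASS template):

#include <array>
#include <cstdint>

// Hypothetical rank-3 mimic of LinearToCoordinate: the last coordinate varies fastest.
std::array<int, 3> linear_to_coord3(int64_t idx, std::array<int, 3> extent) {
  std::array<int, 3> coord{};
  coord[0] = int(idx / (int64_t(extent[1]) * extent[2]));
  int64_t rest = idx % (int64_t(extent[1]) * extent[2]);
  coord[1] = int(rest / extent[2]);
  coord[2] = int(rest % extent[2]);
  return coord;
}
// linear_to_coord3(17, {2, 3, 4}) == {1, 1, 1}   (17 = 1*12 + 1*4 + 1)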
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Reference implementation for convolution in device-side code. +*/ + +#pragma once + +#include "cutlass/coord.h" +#include "cutlass/functional.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv2d_problem_size.h" +#include "cutlass/conv/conv3d_problem_size.h" + +namespace cutlass { +namespace reference { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace kernel { + +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// Conv2d device reference kernel +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Conv2d Fprop kernel - y = fprop(x, w) +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = ElementCompute, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add, + int kThreadM = 2, // shape of a thread's tile in the GEMM M dimension + int kThreadN = 4, // shape of a thread's tile in the GEMM N dimension + int kCtaShapeM = 16, // shape of a threadblock in units of threads + int kCtaShapeN = 8 // shape of a threadblock in units of threads +> +__global__ void Conv2dFprop( + conv::Conv2dProblemSize problem_size, + TensorRef tensor_x, + TensorRef tensor_w, + TensorRef tensor_y_in, + TensorRef tensor_y_out, + ElementCompute alpha, + ElementCompute beta + ) { + + ConvertOp convert_op; + InnerProductOp inner_product_op; + + ElementAccumulator element_A[kThreadM]; + ElementAccumulator element_B[kThreadN]; + ElementAccumulator accum[kThreadM][kThreadN]; + + int64_t npq_start = int64_t(blockIdx.x) * kCtaShapeM * kThreadM + threadIdx.x * kThreadM; + int k_start = blockIdx.y * kCtaShapeN * kThreadN + threadIdx.y * kThreadN; + + int thread_n[kThreadM]; + int thread_p[kThreadM]; + int thread_q[kThreadM]; + + // Compute N, P, Q coordinates for each row of a thread's tile + int64_t PQ = int64_t(problem_size.P) * problem_size.Q; + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + + int64_t npq = npq_start + m; + + thread_n[m] = int(npq / PQ); + + int64_t residual = npq % PQ; + thread_p[m] = int(residual / problem_size.Q); + thread_q[m] = int(residual % problem_size.Q); + } + + // Clear accumulators + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + accum[m][n] = ElementAccumulator(); + } + } + + int c_per_group = problem_size.C / problem_size.groups; + int k_per_group = problem_size.K / problem_size.groups; + + // 
Compute convolution + for (int R = 0; R < problem_size.R; ++R) { + for (int S = 0; S < problem_size.S; ++S) { + for (int C = 0; C < problem_size.C; ++C) { + + // Get group id of currnet channel + int c_group_idx = C / c_per_group; + + // Load from activations tensor + int filter_r = R; + int filter_s = S; + + if (problem_size.mode == cutlass::conv::Mode::kConvolution) { + filter_r = problem_size.R - 1 - R; + filter_s = problem_size.S - 1 - S; + } + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + int h = thread_p[m] * problem_size.stride_h - problem_size.pad_h + filter_r * problem_size.dilation_h; + int w = thread_q[m] * problem_size.stride_w - problem_size.pad_w + filter_s * problem_size.dilation_w; + + if (thread_n[m] < problem_size.N && h >= 0 && h < problem_size.H && w >= 0 && w < problem_size.W) { + element_A[m] = ElementAccumulator(tensor_x.at({thread_n[m], h, w, C})); + } + else { + element_A[m] = ElementAccumulator(); + } + } + + // Load from filters tensor + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + int thread_k = k_start + n; + int k_group_idx = thread_k / k_per_group; + + if (thread_k < problem_size.K && k_group_idx == c_group_idx) { + element_B[n] = ElementAccumulator(tensor_w.at({thread_k, R, S, C % c_per_group})); + } + else { + element_B[n] = ElementAccumulator(); + } + } + + // Accumulate matrix product + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + accum[m][n] = inner_product_op(element_A[m], element_B[n], accum[m][n]); + } + } + } + } + } + + // Write out the results + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + if (thread_n[m] < problem_size.N && thread_p[m] < problem_size.P && thread_q[m] < problem_size.Q) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + int thread_k = k_start + n; + if (thread_k < problem_size.K) { + + ElementCompute c_ref = ElementCompute(); + if (beta != ElementCompute()) { + c_ref = ElementCompute(tensor_y_in.at({thread_n[m], thread_p[m], thread_q[m], thread_k})); + } + + tensor_y_out.at({thread_n[m], thread_p[m], thread_q[m], thread_k}) = convert_op( + alpha * ElementCompute(accum[m][n]) + beta * c_ref); + } + } + } + } +} + +// Conv3d Fprop kernel - y = fprop(x, w) +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = ElementCompute, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add, + int kThreadM = 2, // shape of a thread's tile in the GEMM M dimension + int kThreadN = 4, // shape of a thread's tile in the GEMM N dimension + int kCtaShapeM = 16, // shape of a threadblock in units of threads + int kCtaShapeN = 8 // shape of a threadblock in units of threads +> +__global__ void Conv3dFprop( + conv::Conv3dProblemSize problem_size, + TensorRef tensor_x, + TensorRef tensor_w, + TensorRef tensor_y_in, + TensorRef tensor_y_out, + ElementCompute alpha, + ElementCompute beta + ) { + + ConvertOp convert_op; + InnerProductOp inner_product_op; + + ElementAccumulator element_A[kThreadM]; + ElementAccumulator element_B[kThreadN]; + ElementAccumulator accum[kThreadM][kThreadN]; + + int64_t nzpq_start = int64_t(blockIdx.x) * kCtaShapeM * kThreadM + threadIdx.x * kThreadM; + int k_start = blockIdx.y * kCtaShapeN * kThreadN + threadIdx.y * kThreadN; + + int thread_n[kThreadM]; + int thread_z[kThreadM]; + int thread_p[kThreadM]; + int 
thread_q[kThreadM]; + + // Compute N, Z, P, Q coordinates for each row of a thread's tile + int64_t PQ = int64_t(problem_size.P) * problem_size.Q; + int64_t ZPQ = PQ * problem_size.Z; + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + + int64_t nzpq = nzpq_start + m; + + thread_n[m] = int(nzpq / ZPQ); + + int64_t residual = nzpq % ZPQ; + thread_z[m] = int(residual / PQ); + + residual = residual % PQ; + thread_p[m] = int(residual / problem_size.Q); + thread_q[m] = int(residual % problem_size.Q); + } + + // Clear accumulators + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + accum[m][n] = ElementAccumulator(); + } + } + + // Compute convolution + for (int T = 0; T < problem_size.T; ++T) { + for (int R = 0; R < problem_size.R; ++R) { + for (int S = 0; S < problem_size.S; ++S) { + for (int C = 0; C < problem_size.C; ++C) { + + // Load from activations tensor + int filter_t = T; + int filter_r = R; + int filter_s = S; + + if (problem_size.mode == cutlass::conv::Mode::kConvolution) { + filter_t = problem_size.T - 1 - T; + filter_r = problem_size.R - 1 - R; + filter_s = problem_size.S - 1 - S; + } + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + int d = thread_z[m] * problem_size.stride_d - problem_size.pad_d + filter_t * problem_size.dilation_d; + int h = thread_p[m] * problem_size.stride_h - problem_size.pad_h + filter_r * problem_size.dilation_h; + int w = thread_q[m] * problem_size.stride_w - problem_size.pad_w + filter_s * problem_size.dilation_w; + + if (thread_n[m] < problem_size.N && + d >= 0 && d < problem_size.D && + h >= 0 && h < problem_size.H && + w >= 0 && w < problem_size.W) { + + element_A[m] = ElementAccumulator(tensor_x.at({thread_n[m], d, h, w, C})); + } + else { + element_A[m] = ElementAccumulator(); + } + } + + // Load from filters tensor + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + int thread_k = k_start + n; + + if (thread_k < problem_size.K) { + element_B[n] = ElementAccumulator(tensor_w.at({thread_k, T, R, S, C})); + } + else { + element_B[n] = ElementAccumulator(); + } + } + + // Accumulate matrix product + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + accum[m][n] = inner_product_op(element_A[m], element_B[n], accum[m][n]); + } + } + + } // for (C) + } // for (S) + } // for (R) + } // for (T) + + // Write out the results + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + + if (thread_n[m] < problem_size.N && + thread_z[m] < problem_size.Z && + thread_p[m] < problem_size.P && + thread_q[m] < problem_size.Q) { + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + int thread_k = k_start + n; + if (thread_k < problem_size.K) { + + ElementCompute c_ref = ElementCompute(); + if (beta != ElementCompute()) { + c_ref = ElementCompute(tensor_y_in.at({thread_n[m], thread_z[m], thread_p[m], thread_q[m], thread_k})); + } + + tensor_y_out.at({thread_n[m], thread_z[m], thread_p[m], thread_q[m], thread_k}) = convert_op( + alpha * ElementCompute(accum[m][n]) + beta * c_ref); + } + } // for (n) + + } + } // for (m) +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +// Conv2d dgrad kernel - dx = dgrad(dy, w) +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = 
ElementCompute, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add, + int kThreadM = 2, // shape of a thread's tile in the GEMM M dimension + int kThreadN = 4, // shape of a thread's tile in the GEMM N dimension + int kCtaShapeM = 16, // shape of a threadblock in units of threads + int kCtaShapeN = 8 // shape of a threadblock in units of threads +> +__global__ void Conv2dDgrad( + conv::Conv2dProblemSize problem_size, + TensorRef tensor_dy, + TensorRef tensor_w, + TensorRef tensor_dx_in, + TensorRef tensor_dx_out, + ElementCompute alpha, + ElementCompute beta + ) { + + ConvertOp convert_op; + InnerProductOp inner_product_op; + + ElementAccumulator element_A[kThreadM]; + ElementAccumulator element_B[kThreadN]; + ElementAccumulator accum[kThreadM][kThreadN]; + + int64_t nhw_start = int64_t(blockIdx.x) * kCtaShapeM * kThreadM + threadIdx.x * kThreadM; + int c_start = blockIdx.y * kCtaShapeN * kThreadN + threadIdx.y * kThreadN; + + int thread_n[kThreadM]; + int thread_h[kThreadM]; + int thread_w[kThreadM]; + + // Compute N, H, W coordinates for each row of a thread's tile + int64_t HW = int64_t(problem_size.H) * problem_size.W; + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + + int64_t nhw = nhw_start + m; + + thread_n[m] = int(nhw / HW); + + int64_t residual = nhw % HW; + thread_h[m] = int(residual / problem_size.W); + thread_w[m] = int(residual % problem_size.W); + } + + // Clear accumulators + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + accum[m][n] = ElementAccumulator(); + } + } + + // Compute convolution + for (int R = 0; R < problem_size.R; ++R) { + for (int S = 0; S < problem_size.S; ++S) { + for (int K = 0; K < problem_size.K; ++K) { + + // Load from activations tensor + int filter_r = R; + int filter_s = S; + + if (problem_size.mode == cutlass::conv::Mode::kConvolution) { + filter_r = problem_size.R - 1 - R; + filter_s = problem_size.S - 1 - S; + } + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + + int p = thread_h[m] + problem_size.pad_h - filter_r * problem_size.dilation_h; + int q = thread_w[m] + problem_size.pad_w - filter_s * problem_size.dilation_w; + + element_A[m] = ElementAccumulator(); + + if (p >= 0 && !(p % problem_size.stride_h) && q >= 0 && !(q % problem_size.stride_w)) { + + p = p / problem_size.stride_h; + q = q / problem_size.stride_w; + + if (thread_n[m] < problem_size.N && p < problem_size.P && q < problem_size.Q) { + element_A[m] = ElementAccumulator(tensor_dy.at({thread_n[m], p, q, K})); + } + } + } + + // Load from filters tensor + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + int thread_c = c_start + n; + + if (thread_c < problem_size.C) { + element_B[n] = ElementAccumulator(tensor_w.at({K, R, S, thread_c})); + } + else { + element_B[n] = ElementAccumulator(); + } + } + + // Accumulate matrix product + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + accum[m][n] = inner_product_op(element_A[m], element_B[n], accum[m][n]); + } + } + } + } + } + + // Write out the results + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + + if (thread_n[m] < problem_size.N && thread_h[m] < problem_size.H && thread_w[m] < problem_size.W) { + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + int thread_c = c_start + n; + if (thread_c < problem_size.C) { + + ElementCompute c_ref = ElementCompute(); + if (beta != 
ElementCompute()) { + c_ref = ElementCompute(tensor_dx_in.at({thread_n[m], thread_h[m], thread_w[m], thread_c})); + } + + tensor_dx_out.at({thread_n[m], thread_h[m], thread_w[m], thread_c}) = convert_op( + alpha * ElementCompute(accum[m][n]) + beta * c_ref); + } + } + } + } +} + +// Conv3d dgrad kernel - dx = dgrad(dy, w) +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = ElementCompute, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add, + int kThreadM = 2, // shape of a thread's tile in the GEMM M dimension + int kThreadN = 4, // shape of a thread's tile in the GEMM N dimension + int kCtaShapeM = 16, // shape of a threadblock in units of threads + int kCtaShapeN = 8 // shape of a threadblock in units of threads +> +__global__ void Conv3dDgrad( + conv::Conv3dProblemSize problem_size, + TensorRef tensor_dy, + TensorRef tensor_w, + TensorRef tensor_dx_in, + TensorRef tensor_dx_out, + ElementCompute alpha, + ElementCompute beta + ) { + + ConvertOp convert_op; + InnerProductOp inner_product_op; + + ElementAccumulator element_A[kThreadM]; + ElementAccumulator element_B[kThreadN]; + ElementAccumulator accum[kThreadM][kThreadN]; + + int64_t ndhw_start = int64_t(blockIdx.x) * kCtaShapeM * kThreadM + threadIdx.x * kThreadM; + int c_start = blockIdx.y * kCtaShapeN * kThreadN + threadIdx.y * kThreadN; + + int thread_n[kThreadM]; + int thread_d[kThreadM]; + int thread_h[kThreadM]; + int thread_w[kThreadM]; + + // Compute N, H, W coordinates for each row of a thread's tile + int64_t HW = int64_t(problem_size.H) * problem_size.W; + int64_t DHW = HW * problem_size.D; + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + + int64_t ndhw = ndhw_start + m; + + thread_n[m] = int(ndhw / DHW); + + int64_t residual = ndhw % DHW; + thread_d[m] = int(residual / HW); + + residual = residual % HW; + thread_h[m] = int(residual / problem_size.W); + thread_w[m] = int(residual % problem_size.W); + } + + // Clear accumulators + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + accum[m][n] = ElementAccumulator(); + } + } + + // Compute convolution + for (int T = 0; T < problem_size.T; ++T) { + for (int R = 0; R < problem_size.R; ++R) { + for (int S = 0; S < problem_size.S; ++S) { + for (int K = 0; K < problem_size.K; ++K) { + + // Load from activations tensor + int filter_t = T; + int filter_r = R; + int filter_s = S; + + if (problem_size.mode == cutlass::conv::Mode::kConvolution) { + filter_t = problem_size.T - 1 - T; + filter_r = problem_size.R - 1 - R; + filter_s = problem_size.S - 1 - S; + } + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + + int z = thread_d[m] + problem_size.pad_d - filter_t * problem_size.dilation_d; + int p = thread_h[m] + problem_size.pad_h - filter_r * problem_size.dilation_h; + int q = thread_w[m] + problem_size.pad_w - filter_s * problem_size.dilation_w; + + element_A[m] = ElementAccumulator(); + + if (z >= 0 && !(z % problem_size.stride_d) && + p >= 0 && !(p % problem_size.stride_h) && + q >= 0 && !(q % problem_size.stride_w)) { + + z = z / problem_size.stride_d; + p = p / problem_size.stride_h; + q = q / problem_size.stride_w; + + if (thread_n[m] < problem_size.N && z < problem_size.Z && p < problem_size.P && q < problem_size.Q) { + element_A[m] = ElementAccumulator(tensor_dy.at({thread_n[m], z, p, q, K})); + } + } 
+ } + + // Load from filters tensor + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + int thread_c = c_start + n; + + if (thread_c < problem_size.C) { + element_B[n] = ElementAccumulator(tensor_w.at({K, T, R, S, thread_c})); + } + else { + element_B[n] = ElementAccumulator(); + } + } + + // Accumulate matrix product + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + accum[m][n] = inner_product_op(element_A[m], element_B[n], accum[m][n]); + } + } + + } // for (C) + } // for (S) + } // for (R) + } // for (T) + + // Write out the results + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + + if (thread_n[m] < problem_size.N && + thread_d[m] < problem_size.D && + thread_h[m] < problem_size.H && + thread_w[m] < problem_size.W) { + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + int thread_c = c_start + n; + if (thread_c < problem_size.C) { + + ElementCompute c_ref = ElementCompute(); + if (beta != ElementCompute()) { + c_ref = ElementCompute(tensor_dx_in.at({thread_n[m], thread_d[m], thread_h[m], thread_w[m], thread_c})); + } + + tensor_dx_out.at({thread_n[m], thread_d[m], thread_h[m], thread_w[m], thread_c}) = convert_op( + alpha * ElementCompute(accum[m][n]) + beta * c_ref); + } + } + } + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +// Conv2d wgrad kernel - dw = wgrad(dy, x) +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = ElementCompute, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add, + int kThreadM = 2, // shape of a thread's tile in the GEMM M dimension + int kThreadN = 4, // shape of a thread's tile in the GEMM N dimension + int kCtaShapeM = 8, // shape of a threadblock in units of threads + int kCtaShapeN = 16 // shape of a threadblock in units of threads +> +__global__ void Conv2dWgrad( + conv::Conv2dProblemSize problem_size, + TensorRef tensor_dy, + TensorRef tensor_x, + TensorRef tensor_dw_in, + TensorRef tensor_dw_out, + ElementCompute alpha, + ElementCompute beta + ) { + + ConvertOp convert_op; + InnerProductOp inner_product_op; + + ElementAccumulator element_A[kThreadM]; + ElementAccumulator element_B[kThreadN]; + ElementAccumulator accum[kThreadM][kThreadN]; + + int k_start = blockIdx.x * kCtaShapeM * kThreadM + threadIdx.x * kThreadM; + int64_t rsc_start = int64_t(blockIdx.y) * kCtaShapeN * kThreadN + threadIdx.y * kThreadN; + + int thread_r[kThreadN]; + int thread_s[kThreadN]; + int thread_c[kThreadN]; + + // Compute R, S, C coordinates for each row of a thread's tile + int64_t SC = int64_t(problem_size.S) * problem_size.C; + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + + int64_t rsc = rsc_start + n; + int64_t residual = rsc % SC; + + thread_r[n] = int(rsc / SC); + thread_s[n] = int(residual / problem_size.C); + thread_c[n] = int(residual % problem_size.C); + } + + // Clear accumulators + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + accum[m][n] = ElementAccumulator(); + } + } + + // Compute convolution + for (int N = 0; N < problem_size.N; ++N) { + for (int P = 0; P < problem_size.P; ++P) { + for (int Q = 0; Q < problem_size.Q; ++Q) { + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + int thread_k = 
k_start + m; + + element_A[m] = ElementAccumulator(); + + if (thread_k < problem_size.K) { + element_A[m] = ElementAccumulator(tensor_dy.at({N, P, Q, thread_k})); + } + } + + // Load from filters tensor + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + + // Load from activations tensor + int filter_r = thread_r[n]; + int filter_s = thread_s[n]; + + if (problem_size.mode == cutlass::conv::Mode::kConvolution) { + filter_r = problem_size.R - 1 - filter_r; + filter_s = problem_size.S - 1 - filter_s; + } + + int h = P * problem_size.stride_h - problem_size.pad_h + filter_r * problem_size.dilation_h; + int w = Q * problem_size.stride_w - problem_size.pad_w + filter_s * problem_size.dilation_w; + + element_B[n] = ElementAccumulator(); + + if (h >= 0 && h < problem_size.H && w >= 0 && w < problem_size.W && thread_c[n] < problem_size.C) { + element_B[n] = ElementAccumulator(tensor_x.at({N, h, w, thread_c[n]})); + } + } + + // Accumulate matrix product + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + accum[m][n] = inner_product_op(element_A[m], element_B[n], accum[m][n]); + } + } + } + } + } + + // Write out the results + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + int thread_k = k_start + m; + + if (thread_k < problem_size.K) { + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + + if (thread_r[n] < problem_size.R && thread_s[n] < problem_size.S && thread_c[n] < problem_size.C) { + + ElementCompute c_ref = ElementCompute(); + + if (beta != ElementCompute()) { + c_ref = ElementCompute(tensor_dw_in.at({thread_k, thread_r[n], thread_s[n], thread_c[n]})); + } + + tensor_dw_out.at({thread_k, thread_r[n], thread_s[n], thread_c[n]}) = convert_op( + alpha * ElementCompute(accum[m][n]) + beta * c_ref); + } + } + } + } +} + +// Conv3d wgrad kernel - dw = wgrad(dy, x) +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = ElementCompute, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add, + int kThreadM = 2, // shape of a thread's tile in the GEMM M dimension + int kThreadN = 4, // shape of a thread's tile in the GEMM N dimension + int kCtaShapeM = 8, // shape of a threadblock in units of threads + int kCtaShapeN = 16 // shape of a threadblock in units of threads +> +__global__ void Conv3dWgrad( + conv::Conv3dProblemSize problem_size, + TensorRef tensor_dy, + TensorRef tensor_x, + TensorRef tensor_dw_in, + TensorRef tensor_dw_out, + ElementCompute alpha, + ElementCompute beta + ) { + + ConvertOp convert_op; + InnerProductOp inner_product_op; + + ElementAccumulator element_A[kThreadM]; + ElementAccumulator element_B[kThreadN]; + ElementAccumulator accum[kThreadM][kThreadN]; + + int k_start = blockIdx.x * kCtaShapeM * kThreadM + threadIdx.x * kThreadM; + int64_t trsc_start = int64_t(blockIdx.y) * kCtaShapeN * kThreadN + threadIdx.y * kThreadN; + + int thread_t[kThreadN]; + int thread_r[kThreadN]; + int thread_s[kThreadN]; + int thread_c[kThreadN]; + + // Compute R, S, C coordinates for each row of a thread's tile + int64_t SC = int64_t(problem_size.S) * problem_size.C; + int64_t RSC = SC * problem_size.R; + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + + int64_t trsc = trsc_start + n; + + thread_t[n] = int(trsc / RSC); + + int64_t residual = trsc % RSC; + thread_r[n] = int(residual / SC); + + residual 
= residual % SC; + thread_s[n] = int(residual / problem_size.C); + thread_c[n] = int(residual % problem_size.C); + } + + // Clear accumulators + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + accum[m][n] = ElementAccumulator(); + } + } + + // Compute convolution + for (int N = 0; N < problem_size.N; ++N) { + for (int Z = 0; Z < problem_size.Z; ++Z) { + for (int P = 0; P < problem_size.P; ++P) { + for (int Q = 0; Q < problem_size.Q; ++Q) { + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + int thread_k = k_start + m; + + element_A[m] = ElementAccumulator(); + + if (thread_k < problem_size.K) { + element_A[m] = ElementAccumulator(tensor_dy.at({N, Z, P, Q, thread_k})); + } + } + + // Load from filters tensor + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + + // Load from activations tensor + int filter_t = thread_t[n]; + int filter_r = thread_r[n]; + int filter_s = thread_s[n]; + + if (problem_size.mode == cutlass::conv::Mode::kConvolution) { + filter_t = problem_size.T - 1 - filter_t; + filter_r = problem_size.R - 1 - filter_r; + filter_s = problem_size.S - 1 - filter_s; + } + + int d = Z * problem_size.stride_d - problem_size.pad_d + filter_t * problem_size.dilation_d; + int h = P * problem_size.stride_h - problem_size.pad_h + filter_r * problem_size.dilation_h; + int w = Q * problem_size.stride_w - problem_size.pad_w + filter_s * problem_size.dilation_w; + + element_B[n] = ElementAccumulator(); + + if (d >= 0 && d < problem_size.D && + h >= 0 && h < problem_size.H && + w >= 0 && w < problem_size.W && + thread_c[n] < problem_size.C) { + + element_B[n] = ElementAccumulator(tensor_x.at({N, d, h, w, thread_c[n]})); + } + } + + // Accumulate matrix product + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + accum[m][n] = inner_product_op(element_A[m], element_B[n], accum[m][n]); + } + } + + } // for (Q) + } // for (P) + } // for (Z) + } // for (N) + + // Write out the results + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + int thread_k = k_start + m; + + if (thread_k < problem_size.K) { + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + + if (thread_t[n] < problem_size.T && + thread_r[n] < problem_size.R && + thread_s[n] < problem_size.S && + thread_c[n] < problem_size.C) { + + ElementCompute c_ref = ElementCompute(); + + if (beta != ElementCompute()) { + c_ref = ElementCompute(tensor_dw_in.at({thread_k, thread_t[n], thread_r[n], thread_s[n], thread_c[n]})); + } + + tensor_dw_out.at({thread_k, thread_t[n], thread_r[n], thread_s[n], thread_c[n]}) = convert_op( + alpha * ElementCompute(accum[m][n]) + beta * c_ref); + } + } + } + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Conv2d Fprop dispatcher - y = fprop(x, w) +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = ElementCompute, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add +> +Status Conv2dFprop( + conv::Conv2dProblemSize problem_size, + TensorRef tensor_x, + TensorRef tensor_w, + TensorRef tensor_y_in, + TensorRef tensor_y_out, + ElementCompute alpha, + 
ElementCompute beta, + cudaStream_t stream = nullptr) { + + // + // Blocking factors improve performance of reference implementation + // + + int const kThreadM = 4; // shape of a thread's tile in the GEMM M dimension + int const kThreadN = 4; // shape of a thread's tile in the GEMM N dimension + int const kCtaShapeM = 16; // shape of a threadblock in units of threads + int const kCtaShapeN = 8; // shape of a threadblock in units of threads + + int64_t npq = int64_t(problem_size.N) * problem_size.P * problem_size.Q; + int64_t blocks_m = (npq + (kCtaShapeM * kThreadM) - 1) / (kCtaShapeM * kThreadM); + + dim3 block(kCtaShapeM, kCtaShapeN); + dim3 grid(uint32_t(blocks_m), (problem_size.K + (kCtaShapeN * kThreadN) - 1) / (kCtaShapeN * kThreadN)); + + kernel::Conv2dFprop< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementCompute, + ElementAccumulator, + ConvertOp, + InnerProductOp, + kThreadM, + kThreadN, + kCtaShapeM, + kCtaShapeN + ><<< grid, block, 0, stream >>>( + problem_size, + tensor_x, + tensor_w, + tensor_y_in, + tensor_y_out, + alpha, + beta + ); + + cudaError_t result = cudaPeekAtLastError(); + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + + return Status::kSuccess; +} + +/// Conv3d Fprop dispatcher - y = fprop(x, w) +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = ElementCompute, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add +> +Status Conv3dFprop( + conv::Conv3dProblemSize problem_size, + TensorRef tensor_x, + TensorRef tensor_w, + TensorRef tensor_y_in, + TensorRef tensor_y_out, + ElementCompute alpha, + ElementCompute beta, + cudaStream_t stream = nullptr) { + + // + // Blocking factors improve performance of reference implementation + // + + int const kThreadM = 4; // shape of a thread's tile in the GEMM M dimension + int const kThreadN = 4; // shape of a thread's tile in the GEMM N dimension + int const kCtaShapeM = 16; // shape of a threadblock in units of threads + int const kCtaShapeN = 8; // shape of a threadblock in units of threads + + int64_t nzpq = int64_t(problem_size.N) * problem_size.Z * problem_size.P * problem_size.Q; + int64_t blocks_m = (nzpq + (kCtaShapeM * kThreadM) - 1) / (kCtaShapeM * kThreadM); + + dim3 block(kCtaShapeM, kCtaShapeN); + dim3 grid(uint32_t(blocks_m), (problem_size.K + (kCtaShapeN * kThreadN) - 1) / (kCtaShapeN * kThreadN)); + + kernel::Conv3dFprop< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementCompute, + ElementAccumulator, + ConvertOp, + InnerProductOp, + kThreadM, + kThreadN, + kCtaShapeM, + kCtaShapeN + ><<< grid, block, 0, stream >>>( + problem_size, + tensor_x, + tensor_w, + tensor_y_in, + tensor_y_out, + alpha, + beta + ); + + cudaError_t result = cudaPeekAtLastError(); + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + + return Status::kSuccess; +} + +/// Conv2d Dgrad dispatcher - dx = dgrad(dy, w) +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = ElementCompute, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add +> +Status Conv2dDgrad( + conv::Conv2dProblemSize problem_size, + TensorRef tensor_dy, + TensorRef tensor_w, + TensorRef tensor_dx_in, + TensorRef tensor_dx_out, + ElementCompute 
alpha, + ElementCompute beta, + cudaStream_t stream = nullptr) { + + // + // Blocking factors improve performance of reference implementation + // + + int const kThreadM = 2; // shape of a thread's tile in the GEMM M dimension + int const kThreadN = 4; // shape of a thread's tile in the GEMM N dimension + int const kCtaShapeM = 16; // shape of a threadblock in units of threads + int const kCtaShapeN = 8; // shape of a threadblock in units of threads + + int64_t nhw = int64_t(problem_size.N) * problem_size.H * problem_size.W; + int64_t blocks_m = (nhw + (kCtaShapeM * kThreadM) - 1) / (kCtaShapeM * kThreadM); + + dim3 block(kCtaShapeM, kCtaShapeN); + dim3 grid(uint32_t(blocks_m), (problem_size.C + (kCtaShapeN * kThreadN) - 1) / (kCtaShapeN * kThreadN)); + + kernel::Conv2dDgrad< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementCompute, + ElementAccumulator, + ConvertOp, + InnerProductOp, + kThreadM, + kThreadN, + kCtaShapeM, + kCtaShapeN + ><<< grid, block, 0, stream >>>( + problem_size, + tensor_dy, + tensor_w, + tensor_dx_in, + tensor_dx_out, + alpha, + beta + ); + + cudaError_t result = cudaPeekAtLastError(); + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + + return Status::kSuccess; +} + +/// Conv3d Dgrad dispatcher - dx = dgrad(dy, w) +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = ElementCompute, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add +> +Status Conv3dDgrad( + conv::Conv3dProblemSize problem_size, + TensorRef tensor_dy, + TensorRef tensor_w, + TensorRef tensor_dx_in, + TensorRef tensor_dx_out, + ElementCompute alpha, + ElementCompute beta, + cudaStream_t stream = nullptr) { + + // + // Blocking factors improve performance of reference implementation + // + + int const kThreadM = 2; // shape of a thread's tile in the GEMM M dimension + int const kThreadN = 4; // shape of a thread's tile in the GEMM N dimension + int const kCtaShapeM = 16; // shape of a threadblock in units of threads + int const kCtaShapeN = 8; // shape of a threadblock in units of threads + + int64_t ndhw = int64_t(problem_size.N) * problem_size.D * problem_size.H * problem_size.W; + int64_t blocks_m = (ndhw + (kCtaShapeM * kThreadM) - 1) / (kCtaShapeM * kThreadM); + + dim3 block(kCtaShapeM, kCtaShapeN); + dim3 grid(uint32_t(blocks_m), (problem_size.C + (kCtaShapeN * kThreadN) - 1) / (kCtaShapeN * kThreadN)); + + kernel::Conv3dDgrad< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementCompute, + ElementAccumulator, + ConvertOp, + InnerProductOp, + kThreadM, + kThreadN, + kCtaShapeM, + kCtaShapeN + ><<< grid, block, 0, stream >>>( + problem_size, + tensor_dy, + tensor_w, + tensor_dx_in, + tensor_dx_out, + alpha, + beta + ); + + cudaError_t result = cudaPeekAtLastError(); + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + + return Status::kSuccess; +} + +/// Conv2d Wgrad dispatcher - dw = wgrad(dy, x) +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = ElementCompute, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add +> +Status Conv2dWgrad( + conv::Conv2dProblemSize problem_size, + TensorRef tensor_dy, + TensorRef tensor_x, + TensorRef tensor_dw_in, + TensorRef 
tensor_dw_out, + ElementCompute alpha, + ElementCompute beta, + cudaStream_t stream = nullptr) { + + // + // Blocking factors improve performance of reference implementation + // + + int const kThreadM = 2; // shape of a thread's tile in the GEMM M dimension + int const kThreadN = 4; // shape of a thread's tile in the GEMM N dimension + int const kCtaShapeM = 8; // shape of a threadblock in units of threads + int const kCtaShapeN = 16; // shape of a threadblock in units of threads + + int64_t rsc = int64_t(problem_size.R) * problem_size.S * problem_size.C; + int64_t blocks_n = (rsc + (kCtaShapeN * kThreadN) - 1) / (kCtaShapeN * kThreadN); + + dim3 block(kCtaShapeM, kCtaShapeN); + dim3 grid((problem_size.K + (kCtaShapeM * kThreadM) - 1) / (kCtaShapeM * kThreadM), uint32_t(blocks_n)); + + kernel::Conv2dWgrad< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementCompute, + ElementAccumulator, + ConvertOp, + InnerProductOp, + kThreadM, + kThreadN, + kCtaShapeM, + kCtaShapeN + ><<< grid, block, 0, stream >>>( + problem_size, + tensor_dy, + tensor_x, + tensor_dw_in, + tensor_dw_out, + alpha, + beta + ); + + cudaError_t result = cudaPeekAtLastError(); + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + + return Status::kSuccess; +} + +/// Conv3d Wgrad dispatcher - dw = wgrad(dy, x) +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = ElementCompute, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add +> +Status Conv3dWgrad( + conv::Conv3dProblemSize problem_size, + TensorRef tensor_dy, + TensorRef tensor_x, + TensorRef tensor_dw_in, + TensorRef tensor_dw_out, + ElementCompute alpha, + ElementCompute beta, + cudaStream_t stream = nullptr) { + + // + // Blocking factors improve performance of reference implementation + // + + int const kThreadM = 2; // shape of a thread's tile in the GEMM M dimension + int const kThreadN = 4; // shape of a thread's tile in the GEMM N dimension + int const kCtaShapeM = 8; // shape of a threadblock in units of threads + int const kCtaShapeN = 16; // shape of a threadblock in units of threads + + int64_t trsc = int64_t(problem_size.T) * problem_size.R * problem_size.S * problem_size.C; + int64_t blocks_n = (trsc + (kCtaShapeN * kThreadN) - 1) / (kCtaShapeN * kThreadN); + + dim3 block(kCtaShapeM, kCtaShapeN); + dim3 grid((problem_size.K + (kCtaShapeM * kThreadM) - 1) / (kCtaShapeM * kThreadM), uint32_t(blocks_n)); + + kernel::Conv3dWgrad< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementCompute, + ElementAccumulator, + ConvertOp, + InnerProductOp, + kThreadM, + kThreadN, + kCtaShapeM, + kCtaShapeN + ><<< grid, block, 0, stream >>>( + problem_size, + tensor_dy, + tensor_x, + tensor_dw_in, + tensor_dw_out, + alpha, + beta + ); + + cudaError_t result = cudaPeekAtLastError(); + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + + return Status::kSuccess; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Generic 2D convolution targeting Conv2dFprop, Conv2dDgrad, and Conv2dWgrad. 
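+///
+/// Illustrative usage sketch (not taken from the upstream header): the tensor names and
+/// HostTensor setup are assumed, with float NHWC activation, filter, and output tensors
+/// named tensor_x, tensor_w, and tensor_y and a populated Conv2dProblemSize:
+///
+///   cutlass::reference::device::Conv2d<
+///       float, cutlass::layout::TensorNHWC,   // ElementA, LayoutA (activations)
+///       float, cutlass::layout::TensorNHWC,   // ElementB, LayoutB (filters)
+///       float, cutlass::layout::TensorNHWC,   // ElementC, LayoutC (output)
+///       float>(                               // ElementCompute
+///     cutlass::conv::Operator::kFprop,
+///     problem_size,
+///     tensor_x.device_ref(),
+///     tensor_w.device_ref(),
+///     tensor_y.device_ref(),                  // source operand, read only when beta != 0
+///     tensor_y.device_ref(),                  // destination operand
+///     1.0f /*alpha*/, 0.0f /*beta*/);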
+template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = ElementCompute, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add +> +Status Conv2d( + conv::Operator convolutional_operator, + conv::Conv2dProblemSize problem_size, + TensorRef tensor_A, + TensorRef tensor_B, + TensorRef tensor_C, + TensorRef tensor_D, + ElementCompute alpha, + ElementCompute beta, + cudaStream_t stream = nullptr) { + + switch (convolutional_operator) { + case conv::Operator::kFprop: + return Conv2dFprop< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementCompute, + ElementAccumulator, + ConvertOp, InnerProductOp + >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, stream); + break; + + case conv::Operator::kDgrad: + return Conv2dDgrad< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementCompute, + ElementAccumulator, + ConvertOp, InnerProductOp + >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, stream); + break; + + case conv::Operator::kWgrad: + return Conv2dWgrad< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementCompute, + ElementAccumulator, + ConvertOp, InnerProductOp + >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, stream); + break; + + default: break; + } + + return Status::kErrorNotSupported; +} + +/// Generic 3D convolution targeting Conv3dFprop, Conv3dDgrad, and Conv3dWgrad. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = ElementCompute, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add +> +Status Conv3d( + conv::Operator convolutional_operator, + conv::Conv3dProblemSize problem_size, + TensorRef tensor_A, + TensorRef tensor_B, + TensorRef tensor_C, + TensorRef tensor_D, + ElementCompute alpha, + ElementCompute beta, + cudaStream_t stream = nullptr) { + + switch (convolutional_operator) { + case conv::Operator::kFprop: + return Conv3dFprop< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementCompute, + ElementAccumulator, + ConvertOp, InnerProductOp + >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, stream); + + case conv::Operator::kDgrad: + return Conv3dDgrad< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementCompute, + ElementAccumulator, + ConvertOp, InnerProductOp + >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, stream); + + case conv::Operator::kWgrad: + return Conv3dWgrad< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementCompute, + ElementAccumulator, + ConvertOp, InnerProductOp + >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, stream); + + default: break; + } + + return Status::kErrorNotSupported; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace reference +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/gemm.h 
b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/gemm.h new file mode 100644 index 0000000000000000000000000000000000000000..7d575d522c1dd87d51f9bc58d09786393c5cfea3 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/gemm.h @@ -0,0 +1,385 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for GEMM in device-side code. +*/ + +#pragma once + +#include "cutlass/coord.h" + +#include "cutlass/numeric_types.h" +#include "cutlass/functional.h" +#include "cutlass/numeric_conversion.h" + +#include "cutlass/tensor_view.h" +#include "cutlass/gemm/gemm.h" + +#include "cutlass/util/reference/device/kernel/gemm.h" + +namespace cutlass { +namespace reference { +namespace device { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +/// +/// Explicitly naming types needed by this template can be cumbersome, particularly for the +/// accumulator type, so a function argument 'initial_accum' is exposed. Passing +/// AccumulatorType(0) as the last function argument can be easier than naming all template +/// arguments explicitly. 
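+///
+/// Illustrative usage sketch (not taken from the upstream header): the problem extents
+/// M, N, K, the scalars alpha and beta, and the device-resident TensorRefs ref_A, ref_B,
+/// ref_C, ref_D are assumed. It shows half_t inputs with float accumulation and
+/// column-major layouts:
+///
+///   cutlass::reference::device::compute_gemm<
+///       cutlass::half_t, cutlass::layout::ColumnMajor,   // ElementA, LayoutA
+///       cutlass::half_t, cutlass::layout::ColumnMajor,   // ElementB, LayoutB
+///       float, cutlass::layout::ColumnMajor,             // ElementC, LayoutC
+///       float, float>(                                   // ScalarType, AccumulatorType
+///     cutlass::gemm::GemmCoord(M, N, K),
+///     alpha, ref_A, ref_B,
+///     beta, ref_C, ref_D,
+///     0.0f /*initial_accum*/);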
+template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename AccumulatorType, + typename InnerProductOp = multiply_add, + typename ConvertOp = NumericConverter +> +void compute_gemm( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, + ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + AccumulatorType initial_accum) { + + static_assert( + LayoutA::kRank == 2 && + LayoutB::kRank == 2 && + LayoutC::kRank == 2, "Tensors must be of rank 2"); + + // Blocking structure potentially improves performance of reference implementation + // with a minor increase in complexity. + // + // Note, this reference implementation is NOT expected to approach peak performance. + using OutputTile = MatrixShape<4, 4>; + + dim3 block(16, 8); + + dim3 grid( + (problem_size.m() + block.x * OutputTile::kRow - 1) / (block.x * OutputTile::kRow), + (problem_size.n() + block.y * OutputTile::kColumn - 1) / (block.y * OutputTile::kColumn) + ); + + // Launch a GEMM kernel + kernel::Gemm< + TensorRef, + TensorRef, + TensorRef, + ScalarType, + AccumulatorType, + OutputTile, + InnerProductOp, + ConvertOp + ><<< grid, block >>>( + problem_size, + alpha, + tensor_a, + tensor_b, + beta, + tensor_c, + tensor_d, + initial_accum + ); +} +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +/// +/// This assumes the accumulator type is the same type as the scalars. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename AccumulatorType, + typename InnerProductOp = multiply_add, + typename ConvertOp = NumericConverter +> +void compute_gemm( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, + ScalarType beta, + TensorRef tensor_c, + AccumulatorType initial_accum) { + + compute_gemm( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_c, + initial_accum); +} + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename AccumulatorType, + typename InnerProductOp = cutlass::arch::OpMultiplyAdd +> +struct Gemm; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for multiply-add +template +struct Gemm { + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + AccumulatorType initial_accum = AccumulatorType(0)) { + + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_gemm>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum); + } + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + AccumulatorType initial_accum = AccumulatorType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_gemm>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum); 
+ } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for multiply-add-saturate +template +struct Gemm { + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + AccumulatorType initial_accum = AccumulatorType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_gemm, + NumericConverterClamp>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum); + } + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + AccumulatorType initial_accum = AccumulatorType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_gemm, + NumericConverterClamp>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for XOR-popc +template +struct Gemm { + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + AccumulatorType initial_accum = AccumulatorType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_gemm>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum); + } + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + AccumulatorType initial_accum = AccumulatorType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_gemm>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum); + } +}; + + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Batched GEMM +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a batch of GEMMs over a set of matrices of common dimension. +// +// TensorRefCollection* is a type satisfying the TensorRefCollection concept. +// +template < + typename TensorRefCollectionA, + typename TensorRefCollectionB, + typename TensorRefCollectionC, + typename ScalarType, + typename AccumulatorType, + typename InnerProductOp, + typename ConvertOp +> +void BatchedGemm( + gemm::GemmCoord problem_size, + int batch_count, + ScalarType alpha, + TensorRefCollectionA const& tensor_a, + TensorRefCollectionB const& tensor_b, + ScalarType beta, + TensorRefCollectionC &tensor_c, + AccumulatorType initial_accum) { + + static_assert( + TensorRefCollectionA::kRank == 2 && + TensorRefCollectionB::kRank == 2 && + TensorRefCollectionC::kRank == 2, "Tensors must be of rank 2"); + + // Blocking structure potentially improves performance of reference implementation + // with a minor increase in complexity. + // + // Note, this reference implementation is NOT expected to approach peak performance. 
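+  //
+  // For example, with the 16x8 thread block and 4x4 OutputTile chosen below, each
+  // threadblock covers a 64x32 tile of the output, so a 1024x1024 problem with
+  // batch_count batches launches a (16, 32, batch_count) grid.
+  //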
+ using OutputTile = MatrixShape<4, 4>; + + dim3 block(16, 8); + dim3 grid( + (problem_size.m() + block.x * OutputTile::kRow - 1) / (block.x * OutputTile::kRow), + (problem_size.n() + block.y * OutputTile::kColumn - 1) / (block.y * OutputTile::kColumn), + batch_count + ); + + // Launch a GEMM kernel + kernel::BatchedGemm< + TensorRefCollectionA, + TensorRefCollectionB, + TensorRefCollectionC, + ScalarType, + AccumulatorType, + OutputTile, + InnerProductOp, + ConvertOp + ><<< grid, block >>>( + problem_size, + alpha, + tensor_a, + tensor_b, + beta, + tensor_c, + initial_accum + ); +} + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +// +// TensorRefCollection* is a type satisfying the TensorRefCollection concept. +// +template < + typename TensorRefCollectionA, + typename TensorRefCollectionB, + typename TensorRefCollectionC, + typename ScalarType, + typename AccumulatorType +> +void BatchedGemm( + gemm::GemmCoord problem_size, + int batch_count, + ScalarType alpha, + TensorRefCollectionA const& tensor_a, + TensorRefCollectionB const& tensor_b, + ScalarType beta, + TensorRefCollectionC &tensor_c) { + + BatchedGemm(problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, ScalarType(0)); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace reference +} // namespace cutlass diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/gemm_complex.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/gemm_complex.h new file mode 100644 index 0000000000000000000000000000000000000000..bddf596214da62a7aa3177f758db3710dc1d2516 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/gemm_complex.h @@ -0,0 +1,350 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for complex-valued GEMM in device-side code. +*/ + +#pragma once + +#include "cutlass/coord.h" +#include "cutlass/complex.h" +#include "cutlass/numeric_types.h" +#include "cutlass/functional.h" +#include "cutlass/numeric_conversion.h" + +#include "cutlass/tensor_view.h" +#include "cutlass/gemm/gemm.h" + +namespace cutlass { +namespace reference { +namespace device { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace kernel { + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +/// +/// Explicitly naming types needed by this template can be cumbersome, particularly for the +/// accumulator type, so a function argument 'initial_accum' is exposed. Passing +/// AccumulatorType(0) as the last function argument can be easier than naming all template +/// arguments explicitly. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + typename ElementD = ElementC, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add, + int kMblock = 4, + int kNblock = 4 +> +__global__ void GemmComplex( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + ComplexTransform transform_a, + TensorRef tensor_b, + ComplexTransform transform_b, + ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum, + int batch_count = 1, + int64_t batch_stride_A = 0, + int64_t batch_stride_B = 0, + int64_t batch_stride_C = 0, + int64_t batch_stride_D = 0) { + + static_assert( + LayoutA::kRank == 2 && + LayoutB::kRank == 2 && + LayoutC::kRank == 2, "Tensors must be of rank 2"); + + int const M = problem_size.m(); + int const N = problem_size.n(); + int const K = problem_size.k(); + + ConvertOp convert_op; + InnerProductOp inner_product_op; + + int row_block = (blockIdx.x * blockDim.x + threadIdx.x) * kMblock; + int col_block = (blockIdx.y * blockDim.y + threadIdx.y) * kNblock; + int batch_idx = blockIdx.z; + + tensor_a.add_pointer_offset(batch_idx * batch_stride_A); + tensor_b.add_pointer_offset(batch_idx * batch_stride_B); + tensor_c.add_pointer_offset(batch_idx * batch_stride_C); + tensor_d.add_pointer_offset(batch_idx * batch_stride_D); + + for (; batch_idx < batch_count; batch_idx += gridDim.z) { + + // Compute matrix product using blocks + ComputeType accum[kMblock][kNblock]; + + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < kNblock; j++) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kMblock; i++) { + accum[i][j] = initial_accum; + } + } + + for (int k_block = 0; k_block < K; ++k_block) { + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < kNblock; j++) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kMblock; i++) { + int row = 
row_block + i; + int col = col_block + j; + + if (row < M && col < N) { + ElementA a = tensor_a.at(MatrixCoord(row, k_block)); + ElementB b = tensor_b.at(MatrixCoord(k_block, col)); + + ComputeType a_ik = ComputeType(a); + ComputeType b_kj = ComputeType(b); + + if (transform_a == ComplexTransform::kConjugate) { + a_ik = conj(a_ik); + } + + if (transform_b == ComplexTransform::kConjugate) { + b_kj = conj(b_kj); + } + + accum[i][j] = inner_product_op(a_ik, b_kj, accum[i][j]); + } + } + } + } + + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < kNblock; j++) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kMblock; i++) { + int row = row_block + i; + int col = col_block + j; + + MatrixCoord coord = MatrixCoord(row, col); + + if (row < M && col < N) { + + tensor_d.at(coord) = convert_op( + alpha * ScalarType(accum[i][j]) + + beta * ScalarType(tensor_c.at(coord))); + } + } + } + + tensor_a.add_pointer_offset(batch_stride_A * gridDim.z); + tensor_b.add_pointer_offset(batch_stride_B * gridDim.z); + tensor_c.add_pointer_offset(batch_stride_C * gridDim.z); + tensor_d.add_pointer_offset(batch_stride_D * gridDim.z); + + } // for (batch_idx) +} + +} // namespace kernel + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +/// +/// Explicitly naming types needed by this template can be cumbersome, particularly for the +/// accumulator type, so a function argument 'initial_accum' is exposed. Passing +/// AccumulatorType(0) as the last function argument can be easier than naming all template +/// arguments explicitly. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + typename ElementD = ElementC, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add +> +void GemmComplex( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + ComplexTransform transform_a, + TensorRef tensor_b, + ComplexTransform transform_b, + ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum, + int batch_count = 1, + int64_t batch_stride_A = 0, + int64_t batch_stride_B = 0, + int64_t batch_stride_C = 0, + int64_t batch_stride_D = 0) { + + static_assert( + LayoutA::kRank == 2 && + LayoutB::kRank == 2 && + LayoutC::kRank == 2, "Tensors must be of rank 2"); + + int const kMblock = 4; + int const kNblock = 4; + + dim3 block(16, 8); + dim3 grid( + (problem_size.m() + block.x * kMblock - 1) / (block.x * kMblock), + (problem_size.n() + block.y * kNblock - 1) / (block.y * kNblock), + batch_count % std::numeric_limits::max() + ); + + if (grid.y <= std::numeric_limits::max()) { + kernel::GemmComplex< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ScalarType, + ComputeType, + ElementD, + ConvertOp, + InnerProductOp, + kMblock, + kNblock + ><<< grid, block >>>( + problem_size, + alpha, + tensor_a, + transform_a, + tensor_b, + transform_b, + beta, + tensor_c, + tensor_d, + initial_accum, + batch_count, + batch_stride_A, + batch_stride_B, + batch_stride_C, + batch_stride_D + ); + } else { + // Using bigger thread tile size + int const kBigMblock = 4; + int const kBigNblock = 16; + + dim3 Bigblock(16, 8); + dim3 Biggrid( + (problem_size.m() + block.x * kBigMblock - 1) / (block.x * kBigMblock), + (problem_size.n() + block.y * kBigNblock - 1) / (block.y * 
kBigNblock), + batch_count % std::numeric_limits::max() + ); + + kernel::GemmComplex< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ScalarType, + ComputeType, + ElementD, + ConvertOp, + InnerProductOp, + kBigMblock, + kBigNblock + ><<< Biggrid, Bigblock >>>( + problem_size, + alpha, + tensor_a, + transform_a, + tensor_b, + transform_b, + beta, + tensor_c, + tensor_d, + initial_accum, + batch_count, + batch_stride_A, + batch_stride_B, + batch_stride_C, + batch_stride_D + ); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +/// +/// This assumes the accumulator type is the same type as the scalars. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ElementD = ElementC +> +void GemmComplex( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + ComplexTransform transform_a, + TensorRef tensor_b, + ComplexTransform transform_b, + ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d) { + + GemmComplex(problem_size, alpha, tensor_a, transform_a, tensor_b, transform_b, beta, tensor_c, tensor_d, ScalarType(0)); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace reference +} // namespace cutlass diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/gemm_planar_complex.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/gemm_planar_complex.h new file mode 100644 index 0000000000000000000000000000000000000000..48819cf6eaa565b3ec41dbbf78ae244666fd8a65 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/gemm_planar_complex.h @@ -0,0 +1,311 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for complex-valued GEMM in device code. +*/ + +#pragma once + +#include "cutlass/coord.h" +#include "cutlass/complex.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/numeric_types.h" +#include "cutlass/functional.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/tensor_ref_planar_complex.h" + +#include "cutlass/tensor_view.h" +#include "cutlass/gemm/gemm.h" + +namespace cutlass { +namespace reference { +namespace device { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace kernel { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +static int const kGemmPlanarComplexBlockSize = 4; + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add> +> +__global__ void GemmPlanarComplex( + gemm::GemmCoord problem_size, + complex alpha, + TensorRefPlanarComplex tensor_a, + ComplexTransform transform_a, + TensorRefPlanarComplex tensor_b, + ComplexTransform transform_b, + complex beta, + TensorRefPlanarComplex tensor_c, + TensorRefPlanarComplex tensor_d, + complex initial_accum) { + + int const kMblock = kGemmPlanarComplexBlockSize; + int const kNblock = kGemmPlanarComplexBlockSize; + + using ComplexA = typename TensorRefPlanarComplex::ComplexElement; + using ComplexB = typename TensorRefPlanarComplex::ComplexElement; + using ComplexC = typename TensorRefPlanarComplex::ComplexElement; + + // Note: batch is ignored. 
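+  // Unlike the strided-batched GemmComplex reference above, this kernel evaluates a single
+  // problem; the host-side wrapper later in this file launches it with gridDim.z == 1. A caller
+  // that needs batching is expected to advance the planar-complex tensor references itself and
+  // launch once per batch, roughly as sketched below (illustrative only; 'refs_A[b]' and the
+  // other per-batch references are assumed to be caller-provided, not part of this header):
+  //
+  //   for (int b = 0; b < batch_count; ++b) {
+  //     GemmPlanarComplex(problem_size, alpha, refs_A[b], transform_a,
+  //                       refs_B[b], transform_b, beta, refs_C[b], refs_D[b], initial_accum);
+  //   }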
+ int const M = problem_size.m(); + int const N = problem_size.n(); + int const K = problem_size.k(); + + ConvertOp convert_op; + InnerProductOp inner_product_op; + + complex accum[kMblock][kNblock]; + + int row_block = (blockIdx.x * blockDim.x + threadIdx.x) * kMblock; + int col_block = (blockIdx.y * blockDim.y + threadIdx.y) * kNblock; + + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < kNblock; j++) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kMblock; i++) { + accum[i][j] = initial_accum; + } + } + + CUTLASS_PRAGMA_NO_UNROLL + for (int k_block = 0; k_block < K; ++k_block) { + + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < kNblock; j++) { + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kMblock; i++) { + + int row = row_block + i; + int col = col_block + j; + + if (row < M && col < N) { + + ComplexA a_ik = tensor_a.at(MatrixCoord(row, k_block)); + ComplexB b_kj = tensor_b.at(MatrixCoord(k_block, col)); + + complex a = complex{ + ComputeType(a_ik.real()), + ComputeType(a_ik.imag()) + }; + + complex b = complex{ + ComputeType(b_kj.real()), + ComputeType(b_kj.imag()) + }; + + if (transform_a == ComplexTransform::kConjugate) { + a = conj(a); + } + + if (transform_b == ComplexTransform::kConjugate) { + b = conj(b); + } + + accum[i][j] = inner_product_op(a, b, accum[i][j]); + } + } + } + } + + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < kNblock; j++) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kMblock; i++) { + + int row = row_block + i; + int col = col_block + j; + + MatrixCoord coord = MatrixCoord(row, col); + + if (row < M && col < N) { + + complex acc{ + ScalarType(accum[i][j].real()), + ScalarType(accum[i][j].imag()) + }; + + ComplexC c_ij = ComplexC(); + + if (beta.real() != ScalarType() || beta.imag() != ScalarType()) { + c_ij = tensor_c.at(coord); + } + + complex src{ + ScalarType(c_ij.real()), + ScalarType(c_ij.imag()) + }; + + complex result = alpha * acc + beta * src; + + ComplexC d_ij; + + d_ij.real() = convert_op(result.real()); + d_ij.imag() = convert_op(result.imag()); + + tensor_d.at(coord) = d_ij; + } + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +/// +/// Explicitly naming types needed by this template can be cumbersome, particularly for the +/// accumulator type, so a function argument 'initial_accum' is exposed. Passing +/// AccumulatorType(0) as the last function argument can be easier than naming all template +/// arguments explicitly. 
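+///
+/// A minimal calling sketch (illustrative only; it assumes the caller has already constructed
+/// planar-complex tensor references ref_A, ref_B, ref_C and ref_D with float-valued planes and
+/// matching layouts, so the template arguments are deduced from the function arguments):
+///
+///   cutlass::reference::device::GemmPlanarComplex(
+///     cutlass::gemm::GemmCoord(m, n, k),
+///     cutlass::complex<float>(1.0f),                 // alpha
+///     ref_A, cutlass::ComplexTransform::kNone,
+///     ref_B, cutlass::ComplexTransform::kNone,
+///     cutlass::complex<float>(0.0f),                 // beta
+///     ref_C, ref_D,
+///     cutlass::complex<float>());                    // initial accumulator value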
+template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add> +> +void GemmPlanarComplex( + gemm::GemmCoord problem_size, + complex alpha, + TensorRefPlanarComplex tensor_a, + ComplexTransform transform_a, + TensorRefPlanarComplex tensor_b, + ComplexTransform transform_b, + complex beta, + TensorRefPlanarComplex tensor_c, + TensorRefPlanarComplex tensor_d, + complex initial_accum) { + + static_assert( + LayoutA::kRank == 2 && + LayoutB::kRank == 2 && + LayoutC::kRank == 2, "Tensors must be of rank 2"); + + int const kMblock = kernel::kGemmPlanarComplexBlockSize; + int const kNblock = kernel::kGemmPlanarComplexBlockSize; + + dim3 block(16, 8); + + dim3 grid( + (problem_size.m() + block.x * kMblock - 1) / (block.x * kMblock), + (problem_size.n() + block.y * kNblock - 1) / (block.y * kNblock), + 1); + + kernel::GemmPlanarComplex< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ScalarType, + ComputeType, + ConvertOp, + InnerProductOp + ><<< grid, block >>>( + problem_size, + alpha, + tensor_a, + transform_a, + tensor_b, + transform_b, + beta, + tensor_c, + tensor_d, + initial_accum + ); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +/// +/// This assumes the accumulator type is the same type as the scalars. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType +> +void GemmPlanarComplex( + gemm::GemmCoord problem_size, + complex alpha, + TensorRefPlanarComplex tensor_a, + ComplexTransform transform_a, + TensorRefPlanarComplex tensor_b, + ComplexTransform transform_b, + complex beta, + TensorRefPlanarComplex tensor_c, + TensorRefPlanarComplex tensor_d) { + + GemmPlanarComplex( + problem_size, + alpha, + tensor_a, transform_a, + tensor_b, transform_b, + beta, + tensor_c, + tensor_d, + complex()); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace reference +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/gett.hpp b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/gett.hpp new file mode 100644 index 0000000000000000000000000000000000000000..497a257d170c411d891942f62fa2c960453d03d5 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/gett.hpp @@ -0,0 +1,146 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief GETT device reference code +*/ +#pragma once + +#include + +namespace cutlass::reference::device { + +template < + class ATensor, + class BTensor, + class CTensor, + class DTensor, + class ElementAccumulator, + class ElementEpilogue> +__global__ static +void +gett_kernel( + DTensor D, + ATensor const A, + BTensor const B, + CTensor const C, + ElementEpilogue alpha, ElementEpilogue beta, + ElementAccumulator acc_init) +{ + using namespace cute; + + static_assert(DTensor::rank == 3, "(M,N,L)"); + static_assert(ATensor::rank == 3, "(M,K,L)"); + static_assert(BTensor::rank == 3, "(N,K,L)"); + static_assert(CTensor::rank == 3, "(M,N,L)"); + + assert(size<0>(A) == size<0>(D)); // M + assert(size<0>(C) == size<0>(D)); // M + assert(size<0>(B) == size<1>(D)); // N + assert(size<1>(C) == size<1>(D)); // N + assert(size<1>(A) == size<1>(B)); // K + assert(size<2>(A) == size<2>(D)); // L + assert(size<2>(B) == size<2>(D)); // L + assert(size<2>(C) == size<2>(D)); // L + + NumericConverter a_converter; + NumericConverter b_converter; + NumericConverter acc_converter; + NumericConverter source_converter; + NumericConverter output_converter; + + // Thread id to each element of D + for (int tid = threadIdx.x + blockDim.x * blockIdx.x; + tid < size(D); + tid += blockDim.x * gridDim.x) { + // (m,n,l) coordinate + auto mnl_coord = idx2crd(tid, product_each(shape(D))); + auto m = get<0>(mnl_coord); + auto n = get<1>(mnl_coord); + auto l = get<2>(mnl_coord); + + auto A_ml = A(m,_,l); + auto B_nl = B(n,_,l); + + ElementAccumulator accum = ElementAccumulator(0); + for (int k = 0; k < size<1>(A); ++k) { + ElementAccumulator a = a_converter(A_ml(k)); + ElementAccumulator b = b_converter(B_nl(k)); + accum += a * b; + } + + ElementEpilogue scaled_output = (alpha * acc_converter(accum)) + (beta * source_converter(C(m,n,l))); + D(m,n,l) = output_converter(scaled_output); + } +} + +// Most general version +template < + class ProblemShapeMNKL, + class ElementA, + class StrideA, + class ElementB, + class StrideB, + class ElementAccumulator, + class ElementC, + 
class StrideC, + class ElementD, + class StrideD, + class ElementEpilogue> +void +gett( + ProblemShapeMNKL problem_shape_mnkl, + ElementA const* ptr_A, StrideA stride_a_mkl, + ElementB const* ptr_B, StrideB stride_b_nkl, + ElementAccumulator _, + ElementC const* ptr_C, StrideC stride_c_mnl, + ElementD * ptr_D, StrideD stride_d_mnl, + ElementEpilogue alpha, ElementEpilogue beta, + cudaStream_t stream = 0) { + using namespace cute; + + static_assert(cute::rank(ProblemShapeMNKL{}) == 4); + auto M = get<0>(problem_shape_mnkl); + auto N = get<1>(problem_shape_mnkl); + auto K = get<2>(problem_shape_mnkl); + auto L = get<3>(problem_shape_mnkl); + + // Represent the full tensors + auto A = make_tensor(make_gmem_ptr(ptr_A), make_shape(M,K,L), stride_a_mkl); // (M,K,L) + auto B = make_tensor(make_gmem_ptr(ptr_B), make_shape(N,K,L), stride_b_nkl); // (N,K,L) + auto C = make_tensor(make_gmem_ptr(ptr_C), make_shape(M,N,L), stride_c_mnl); // (M,N,L) + auto D = make_tensor(make_gmem_ptr(ptr_D), make_shape(M,N,L), stride_d_mnl); // (M,N,L) + + dim3 dimBlock(256); + dim3 dimGrid(240); + gett_kernel<<< dimGrid, dimBlock, 0, stream >>>(D, A, B, C, alpha, beta, ElementAccumulator(0)); +} + +} // namespace cutlass::reference::device diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/kernel/gemm.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/kernel/gemm.h new file mode 100644 index 0000000000000000000000000000000000000000..6e131126a336420a2b0e843e3ead3d89fce637fa --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/kernel/gemm.h @@ -0,0 +1,162 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for GEMM in host-side code. +*/ + +#pragma once + +#include "cutlass/coord.h" +#include "cutlass/tensor_view.h" +#include "cutlass/gemm/gemm.h" + +#include "cutlass/util/reference/device/thread/gemm.h" + +namespace cutlass { +namespace reference { +namespace device { +namespace kernel { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +template < + typename TensorRefA, + typename TensorRefB, + typename TensorRefC, + typename ScalarType, + typename AccumulatorType, + typename OutputTile, + typename InnerProductOp, + typename ConvertOp +> +__global__ void Gemm( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRefA tensor_a, + TensorRefB tensor_b, + ScalarType beta, + TensorRefC tensor_c, + TensorRefC tensor_d, + AccumulatorType initial_accum) { + + // Map each thread to a unique tile of the output matrix + MatrixCoord output_coord( + MatrixCoord::Index((threadIdx.x + blockIdx.x * blockDim.x) * OutputTile::kRow), + MatrixCoord::Index((threadIdx.y + blockIdx.y * blockDim.y) * OutputTile::kColumn) + ); + + // Compute the general matrix product + thread::Gemm< + TensorRefA, + TensorRefB, + TensorRefC, + ScalarType, + AccumulatorType, + OutputTile, + InnerProductOp, + ConvertOp + > gemm(initial_accum); + + gemm.multiply_add( + problem_size, + tensor_a, + tensor_b, + output_coord); + + gemm.epilogue(problem_size, alpha, beta, tensor_c, tensor_d, output_coord); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. 
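+///
+/// blockIdx.z selects the batch, and each thread produces one OutputTile of its batch's output.
+/// One possible launch configuration, mirroring the non-batched kernel above, is sketched here
+/// (illustrative only; the actual grid and block shapes are chosen by the host-side wrapper):
+///
+///   dim3 block(16, 8);
+///   dim3 grid(
+///     (problem_size.m() + block.x * OutputTile::kRow - 1) / (block.x * OutputTile::kRow),
+///     (problem_size.n() + block.y * OutputTile::kColumn - 1) / (block.y * OutputTile::kColumn),
+///     batch_count);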
+template < + typename TensorRefCollectionA, + typename TensorRefCollectionB, + typename TensorRefCollectionC, + typename ScalarType, + typename AccumulatorType, + typename OutputTile, + typename InnerProductOp, + typename ConvertOp +> +__global__ void BatchedGemm( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRefCollectionA tensor_collection_a, + TensorRefCollectionB tensor_collection_b, + ScalarType beta, + TensorRefCollectionC tensor_collection_c, + AccumulatorType initial_accum) { + + // Obtain batch ID + int batch_id = blockIdx.z; + + // Dereference based on batch_id + typename TensorRefCollectionA::TensorRef tensor_a = tensor_collection_a.at(batch_id); + typename TensorRefCollectionB::TensorRef tensor_b = tensor_collection_b.at(batch_id); + typename TensorRefCollectionC::TensorRef tensor_c = tensor_collection_c.at(batch_id); + + // Map each thread to a unique tile of the output matrix + MatrixCoord output_coord( + (threadIdx.x + blockIdx.x * blockDim.x) * OutputTile::kColumn, + (threadIdx.y + blockIdx.y * blockDim.y) * OutputTile::kRow + ); + + // Compute the general matrix product + thread::Gemm< + typename TensorRefCollectionA::TensorRef, + typename TensorRefCollectionB::TensorRef, + typename TensorRefCollectionC::TensorRef, + ScalarType, + AccumulatorType, + OutputTile, + InnerProductOp, + ConvertOp + > gemm(initial_accum); + + gemm.multiply_add( + problem_size, + tensor_a, + tensor_b, + output_coord); + + gemm.epilogue(problem_size, alpha, beta, tensor_c, output_coord); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace device +} // namespace reference +} // namespace cutlass diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/kernel/tensor_elementwise.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/kernel/tensor_elementwise.h new file mode 100644 index 0000000000000000000000000000000000000000..149e4b2e00e2ac8130cee9dc189a539ba3a70297 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/kernel/tensor_elementwise.h @@ -0,0 +1,168 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include + +#include "cutlass/cutlass.h" + +namespace cutlass { +namespace reference { +namespace device { +namespace kernel { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Kernel to initialize tensor to uniform random distribution +template +__global__ void TensorInitializeUniform( + Distribution dist, int64_t seed, int dim_contiguous, int dim_strided, T *tensor, int ldm) { + __shared__ curandState_t rng_state[1024]; + + uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x + blockIdx.y * gridDim.x * blockDim.x; + + curand_init(seed, gtid, 0, &rng_state[threadIdx.x]); + + int c_idx = blockIdx.x * blockDim.x + threadIdx.x; + int s_idx = blockIdx.y * blockDim.x; + + tensor += s_idx * ldm + c_idx; + + for (int s_offset = 0; s_offset < blockDim.x; ++s_offset, ++s_idx) { + if (s_idx < dim_strided && c_idx < dim_contiguous) { + double range = dist.uniform.max - dist.uniform.min; + + double rnd = curand_uniform(&rng_state[threadIdx.x]); + + rnd = dist.uniform.min + range * rnd; + + // Random values are cast to integer after scaling by a power of two to facilitate error + // testing + if (dist.int_scale >= 0) { + rnd = double(int(rnd * double(1 << dist.int_scale))); + *tensor = T(rnd / double(1 << dist.int_scale)); + } else { + *tensor = T(rnd); + } + + tensor += ldm; + } + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Kernel to initialize tensor to uniform distribution +template +__global__ void TensorInitializeGaussian( + Distribution dist, int64_t seed, int dim_contiguous, int dim_strided, T *tensor, int ldm) { + __shared__ curandState_t rng_state[1024]; + + uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x + blockIdx.y * gridDim.x * blockDim.x; + + curand_init(seed, gtid, 0, &rng_state[threadIdx.x]); + + int c_idx = blockIdx.x * blockDim.x + threadIdx.x; + int s_idx = blockIdx.y * blockDim.x; + + tensor += s_idx * ldm + c_idx; + + for (int s_offset = 0; s_offset < blockDim.x; ++s_offset, ++s_idx) { + if (s_idx < dim_strided && c_idx < dim_contiguous) { + // Random values are cast to integer after scaling by a power of two to facilitate error + // testing + + double rnd = curand_normal(&rng_state[threadIdx.x]); + + rnd = dist.gaussian.mean + dist.gaussian.stddev * rnd; + + if (dist.int_scale >= 0) { + rnd = double(int(rnd * double(1 << dist.int_scale))); + *tensor = T(rnd / double(1 << dist.int_scale)); + } else { + *tensor = T(rnd); + } + } + } +} + +/// Kernel to initialize tensor to an identity matrix +template +__global__ void TensorInitializeLinear( + Distribution dist, int64_t seed, int dim_contiguous, int dim_strided, T *tensor, int ldm) { + __shared__ curandState_t rng_state[1024]; + + uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x + blockIdx.y * gridDim.x * blockDim.x; + + curand_init(seed, 
gtid, 0, &rng_state[threadIdx.x]); + + int c_idx = blockIdx.x * blockDim.x + threadIdx.x; + int s_idx = blockIdx.y * blockDim.x; + + tensor += s_idx * ldm + c_idx; + + for (int s_offset = 0; s_offset < blockDim.x; ++s_offset, ++s_idx) { + if (s_idx < dim_strided && c_idx < dim_contiguous) { + *tensor = + dist.linear.offset + dist.linear.delta_row * c_idx + dist.linear.delta_column * s_idx; + } + } +} + +/// Kernel to initialize tensor to an identity matrix +template +__global__ void TensorInitializeIdentity( + Distribution dist, int64_t seed, int dim_contiguous, int dim_strided, T *tensor, int ldm) { + __shared__ curandState_t rng_state[1024]; + + uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x + blockIdx.y * gridDim.x * blockDim.x; + + curand_init(seed, gtid, 0, &rng_state[threadIdx.x]); + + int c_idx = blockIdx.x * blockDim.x + threadIdx.x; + int s_idx = blockIdx.y * blockDim.x; + + tensor += s_idx * ldm + c_idx; + + for (int s_offset = 0; s_offset < blockDim.x; ++s_offset, ++s_idx) { + if (s_idx < dim_strided && c_idx < dim_contiguous) { + *tensor = (c_idx == s_idx ? T(1) : T(0)); + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace device +} // namespace reference +} // namespace cutlass diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/kernel/tensor_foreach.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/kernel/tensor_foreach.h new file mode 100644 index 0000000000000000000000000000000000000000..3223cb2056ba6d88f47f7b117392a56e325d0ce7 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/kernel/tensor_foreach.h @@ -0,0 +1,159 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/coord.h" +#include "cutlass/subbyte_reference.h" +#include "cutlass/fast_math.h" + +namespace cutlass { +namespace reference { +namespace device { +namespace kernel { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines several helpers +namespace detail { + +/// Helper to perform for-each operation +template +struct TensorForEachHelper { + + /// Constructor for general rank + __inline__ __device__ + TensorForEachHelper(Func &func, Coord const &size, Coord &coord, int64_t index) { + + int64_t product = 1; + + CUTLASS_PRAGMA_UNROLL + for (int i = Rank - RankRemaining; i < Rank; ++i) { + product *= size[i]; + } + + coord[Rank - 1 - RankRemaining] = index / product; + int64_t remaining = index % product; + + TensorForEachHelper(func, size, coord, remaining); + } +}; + +/// Helper to perform for-each operation +template +struct TensorForEachHelper { + + /// Constructor for fastest changing rank + __inline__ __device__ + TensorForEachHelper(Func &func, Coord const &size, Coord &coord, int64_t index) { + + coord[Rank - 1] = index; + + if (coord < size) { + func(coord); + } + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Kernel calls a functor for each element in a tensor's index space +template +__global__ void TensorForEach(Coord size, Params params = Params()) { + + Func func(params); + + int64_t index = threadIdx.x + blockIdx.x * blockDim.x; + int64_t max_index = 1; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Rank; ++i) { + max_index *= size[i]; + } + + CUTLASS_PRAGMA_NO_UNROLL + while (index < max_index) { + Coord coord; + + detail::TensorForEachHelper(func, size, coord, index); + index += blockDim.x * gridDim.x; + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Kernel calls a functor for each element along a tensor's diagonal +template +__global__ void TensorDiagonalForEach(Coord size, Params params, int start, int end) { + + Func func(params); + + int64_t index = threadIdx.x + blockIdx.x * blockDim.x + start; + + if (index < end) { + Coord coord; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Rank; ++i) { + coord[i] = index; + } + + func(coord); + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +__global__ void BlockForEach( + Element *ptr, + size_t capacity, + typename Func::Params params) { + + Func func(params); + + size_t index = threadIdx.x + blockIdx.x * blockDim.x; + + for (; index < capacity; index += blockDim.x * gridDim.x) { + ReferenceFactory::get(ptr, index) = func(); + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} 
// namespace kernel +} // namespace device +} // namespace reference +} // namespace cutlass + diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/rank_2k_complex.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/rank_2k_complex.h new file mode 100644 index 0000000000000000000000000000000000000000..2e76fe52b06f9bb1a033c736f94fa01961ce664d --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/rank_2k_complex.h @@ -0,0 +1,355 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for complex-valued GEMM in device-side code. +*/ + +#pragma once + +#include "cutlass/blas3.h" +#include "cutlass/complex.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/tensor_view.h" +#include "cutlass/gemm/gemm.h" + +namespace cutlass { +namespace reference { +namespace device { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace kernel { + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +/// +/// Explicitly naming types needed by this template can be cumbersome, particularly for the +/// accumulator type, so a function argument 'initial_accum' is exposed. Passing +/// AccumulatorType(0) as the last function argument can be easier than naming all template +/// arguments explicitly. 
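+///
+/// The quantity computed is a rank-2k update restricted to the triangle of C and D selected by
+/// fill_mode_c. In informal notation (op() applies the ComplexTransform flags, and the transpose
+/// becomes a conjugate transpose when blas_mode is BlasMode::kHermitian):
+///
+///   D = alpha * (op(A) * op(B)^T + op(B) * op(A)^T) + beta * C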
+template <
+  typename ElementA,
+  typename LayoutA,
+  typename ElementB,
+  typename LayoutB,
+  typename ElementC,
+  typename LayoutC,
+  typename ScalarType,
+  typename ComputeType,
+  typename ConvertOp = NumericConverter<ElementC, ScalarType>,
+  typename InnerProductOp = multiply_add<ComputeType>,
+  int kMblock = 4,
+  int kNblock = 4
+>
+__global__ void Rank2KComplex(
+  gemm::GemmCoord problem_size,
+  ScalarType alpha,
+  TensorRef<ElementA, LayoutA> tensor_a,
+  ComplexTransform transform_a,
+  TensorRef<ElementB, LayoutB> tensor_b,
+  ComplexTransform transform_b,
+  ScalarType beta,
+  TensorRef<ElementC, LayoutC> tensor_c,
+  TensorRef<ElementC, LayoutC> tensor_d,
+  ComputeType initial_accum,
+  FillMode fill_mode_c,
+  BlasMode blas_mode,
+  int batch_count = 1,
+  int64_t batch_stride_A = 0,
+  int64_t batch_stride_B = 0,
+  int64_t batch_stride_C = 0,
+  int64_t batch_stride_D = 0) {
+
+  static_assert(
+    LayoutA::kRank == 2 &&
+    LayoutB::kRank == 2 &&
+    LayoutC::kRank == 2, "Tensors must be of rank 2");
+
+  int const M = problem_size.m();
+  int const N = problem_size.n();
+  int const K = problem_size.k();
+
+  assert(M == N);
+
+  ConvertOp convert_op;
+  InnerProductOp inner_product_op;
+
+  int row_block = (blockIdx.x * blockDim.x + threadIdx.x) * kMblock;
+  int col_block = (blockIdx.y * blockDim.y + threadIdx.y) * kNblock;
+  int batch_idx = blockIdx.z;
+
+  tensor_a.add_pointer_offset(batch_idx * batch_stride_A);
+  tensor_b.add_pointer_offset(batch_idx * batch_stride_B);
+  tensor_c.add_pointer_offset(batch_idx * batch_stride_C);
+  tensor_d.add_pointer_offset(batch_idx * batch_stride_D);
+
+  for (; batch_idx < batch_count; batch_idx += gridDim.z) {
+
+    // Compute matrix product using blocks
+    ComputeType accum[kMblock][kNblock];
+
+    CUTLASS_PRAGMA_UNROLL
+    for (int j = 0; j < kNblock; j++) {
+      CUTLASS_PRAGMA_UNROLL
+      for (int i = 0; i < kMblock; i++) {
+        accum[i][j] = initial_accum;
+      }
+    }
+
+    for (int k_block = 0; k_block < K; ++k_block) {
+      CUTLASS_PRAGMA_UNROLL
+      for (int j = 0; j < kNblock; j++) {
+        CUTLASS_PRAGMA_UNROLL
+        for (int i = 0; i < kMblock; i++) {
+          int row = row_block + i;
+          int col = col_block + j;
+
+          if (row < M && col < N &&
+              ( (fill_mode_c == FillMode::kLower && row >= col) ||
+                (fill_mode_c == FillMode::kUpper && row <= col) )
+            ) {
+
+            // A x B^T (Symmetric) or A x B^H (Hermitian)
+            // complex conjugation on operand B (b_t) is a function of the blas3 computation
+            ElementA a = tensor_a.at(MatrixCoord(row, k_block));
+            ElementB b_t = (blas_mode == BlasMode::kHermitian) ?
+                            conj(tensor_b.at(MatrixCoord(col, k_block))) :
+                            tensor_b.at(MatrixCoord(col, k_block));
+
+            ComputeType a_ik = ComputeType(a);
+            ComputeType b_jk = ComputeType(b_t);
+
+            // complex conjugation is a function of operand layouts
+            if (transform_a == ComplexTransform::kConjugate) {
+              a_ik = conj(a_ik);
+            }
+            // complex conjugation is a function of operand layouts
+            if (transform_b == ComplexTransform::kConjugate) {
+              b_jk = conj(b_jk);
+            }
+
+            accum[i][j] = inner_product_op(a_ik, b_jk, accum[i][j]);
+
+            // B x A^T (Symmetric) or B x A^H (Hermitian)
+            // complex conjugation on operand A (a_t) is a function of the blas3 computation
+            ElementB b = tensor_b.at(MatrixCoord(row, k_block));
+            ElementA a_t = (blas_mode == BlasMode::kHermitian) ?
+                            conj(tensor_a.at(MatrixCoord(col, k_block))) :
+                            tensor_a.at(MatrixCoord(col, k_block));
+
+            ComputeType b_ik = ComputeType(b);
+            ComputeType a_jk = ComputeType(a_t);
+
+            // complex conjugation here is a function of operand layouts
+            if (transform_b == ComplexTransform::kConjugate) {
+              b_ik = conj(b_ik);
+            }
+            // complex conjugation here is a function of operand layouts
+            if (transform_a == ComplexTransform::kConjugate) {
+              a_jk = conj(a_jk);
+            }
+
+            accum[i][j] = inner_product_op(b_ik, a_jk, accum[i][j]);
+          }
+        }
+      }
+    }
+
+    CUTLASS_PRAGMA_UNROLL
+    for (int j = 0; j < kNblock; j++) {
+      CUTLASS_PRAGMA_UNROLL
+      for (int i = 0; i < kMblock; i++) {
+        int row = row_block + i;
+        int col = col_block + j;
+
+        MatrixCoord coord = MatrixCoord(row, col);
+
+        if (row < M && col < N &&
+            ((fill_mode_c == FillMode::kLower && row >= col) ||
+             (fill_mode_c == FillMode::kUpper && row <= col))
+          ) {
+
+          ScalarType c = tensor_c.at(coord);
+          // The imaginary part of a Hermitian matrix's diagonal is assumed to be zero,
+          // so only the real part of C is used on the diagonal.
+          if (blas_mode == BlasMode::kHermitian) {
+            c = (row == col) ? real(c) : c;
+          }
+
+          tensor_d.at(coord) = convert_op(
+            alpha * ScalarType(accum[i][j]) +
+            beta * c);
+        }
+      }
+    }
+
+    tensor_a.add_pointer_offset(batch_stride_A * gridDim.z);
+    tensor_b.add_pointer_offset(batch_stride_B * gridDim.z);
+    tensor_c.add_pointer_offset(batch_stride_C * gridDim.z);
+    tensor_d.add_pointer_offset(batch_stride_D * gridDim.z);
+
+  } // for (batch_idx)
+}
+
+} // namespace kernel
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef
+/// objects.
+///
+/// Explicitly naming types needed by this template can be cumbersome, particularly for the
+/// accumulator type, so a function argument 'initial_accum' is exposed. Passing
+/// AccumulatorType(0) as the last function argument can be easier than naming all template
+/// arguments explicitly.
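+///
+/// Typical usage is sketched below (illustrative only; ref_A, ref_B, ref_C and ref_D are assumed
+/// to be caller-constructed TensorRef objects over float data, and alpha/beta are float scalars,
+/// so the template arguments are deduced; the problem size must have m == n):
+///
+///   cutlass::reference::device::Rank2KComplex(
+///     cutlass::gemm::GemmCoord(n, n, k),
+///     alpha,
+///     ref_A, cutlass::ComplexTransform::kNone,
+///     ref_B, cutlass::ComplexTransform::kNone,
+///     beta, ref_C, ref_D,
+///     float(0),                         // initial accumulator value
+///     cutlass::FillMode::kLower,
+///     cutlass::BlasMode::kSymmetric);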
+template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add +> +void Rank2KComplex( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + ComplexTransform transform_a, + TensorRef tensor_b, + ComplexTransform transform_b, + ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum, + FillMode fill_mode_c, + BlasMode blas_mode, + int batch_count = 1, + int64_t batch_stride_A = 0, + int64_t batch_stride_B = 0, + int64_t batch_stride_C = 0, + int64_t batch_stride_D = 0) { + + static_assert( + LayoutA::kRank == 2 && + LayoutB::kRank == 2 && + LayoutC::kRank == 2, "Tensors must be of rank 2"); + + int const kMblock = 4; + int const kNblock = 4; + + dim3 block(16, 8); + dim3 grid( + (problem_size.m() + block.x * kMblock - 1) / (block.x * kMblock), + (problem_size.n() + block.y * kNblock - 1) / (block.y * kNblock), + batch_count % std::numeric_limits::max() + ); + + kernel::Rank2KComplex< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ScalarType, + ComputeType, + ConvertOp, + InnerProductOp, + kMblock, + kNblock + ><<< grid, block >>>( + problem_size, + alpha, + tensor_a, + transform_a, + tensor_b, + transform_b, + beta, + tensor_c, + tensor_d, + initial_accum, + fill_mode_c, + blas_mode, + batch_count, + batch_stride_A, + batch_stride_B, + batch_stride_C, + batch_stride_D + ); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +/// +/// This assumes the accumulator type is the same type as the scalars. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType +> +void Rank2KComplex( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + ComplexTransform transform_a, + TensorRef tensor_b, + ComplexTransform transform_b, + ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + FillMode fill_mode_c, + BlasMode blas_mode) { + + Rank2KComplex( + problem_size, alpha, + tensor_a, transform_a, + tensor_b, transform_b, + beta, tensor_c, tensor_d, + ScalarType(0), + fill_mode_c, + blas_mode); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace reference +} // namespace cutlass diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_compare.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_compare.h new file mode 100644 index 0000000000000000000000000000000000000000..1999730f6d24e69aef152aa332fae68af57a9c40 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_compare.h @@ -0,0 +1,250 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines host-side elementwise operations on TensorView. +*/ + +#pragma once +// Standard Library includes +#include + +// Cutlass includes +#include "cutlass/cutlass.h" +#include "cutlass/relatively_equal.h" + +#include "cutlass/util/distribution.h" + +#include "tensor_foreach.h" + +namespace cutlass { +namespace reference { +namespace device { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace kernel { + +template +__global__ void BlockCompareEqual( + int *equal, + Element const *ptr_A, + Element const *ptr_B, + size_t capacity) { + + size_t idx = threadIdx.x + blockDim.x * blockIdx.x; + + for (; idx < capacity; idx += gridDim.x * blockDim.x) { + + Element a = cutlass::ReferenceFactory::get(ptr_A, idx); + Element b = cutlass::ReferenceFactory::get(ptr_B, idx); + + if (a != b) { + *equal = 0; + + return; + } + } +} + +template +__global__ void BlockCompareRelativelyEqual( + int *equal, + Element const *ptr_A, + Element const *ptr_B, + size_t capacity, + Element epsilon, + Element nonzero_floor) { + + size_t idx = threadIdx.x + blockDim.x * blockIdx.x; + + for (; idx < capacity; idx += gridDim.x * blockDim.x) { + + Element a = cutlass::ReferenceFactory::get(ptr_A, idx); + Element b = cutlass::ReferenceFactory::get(ptr_B, idx); + + if (!relatively_equal(a, b, epsilon, nonzero_floor)) { + *equal = 0; + return; + } + } +} + +} // namespace kernel + + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Performs a bit-level equality check between two blocks +template +bool BlockCompareEqual( + Element const *ptr_A, + Element const *ptr_B, + size_t capacity, + int grid_size = 0, + int block_size = 0, + cudaStream_t stream = nullptr) { + + int equal_flag = 1; + int *device_equal_flag = nullptr; + + if (cudaMalloc((void 
**)&device_equal_flag, sizeof(int)) != cudaSuccess) { + throw std::runtime_error("Failed to allocate device flag."); + } + + if (cudaMemcpy( + device_equal_flag, + &equal_flag, + sizeof(int), + cudaMemcpyHostToDevice) != cudaSuccess) { + + throw std::runtime_error("Failed to copy equality flag to device."); + } + + if (!grid_size || !block_size) { + + // if grid_size or block_size are zero, query occupancy using the CUDA Occupancy API + cudaError_t result = cudaOccupancyMaxPotentialBlockSize( + &grid_size, + &block_size, + reinterpret_cast(kernel::BlockCompareEqual)); + + if (result != cudaSuccess) { + throw std::runtime_error("Failed to query occupancy."); + } + // Limit block size. This has the effect of increasing the number of items processed by a + // single thread and reduces the impact of initialization overhead. + block_size = (block_size < 128 ? block_size : 128); + } + + dim3 grid(grid_size, 1, 1); + dim3 block(block_size, 1, 1); + + kernel::BlockCompareEqual<<< grid, block, 0, stream >>>(device_equal_flag, ptr_A, ptr_B, capacity); + + cudaStreamSynchronize(stream); + + if (cudaMemcpy( + &equal_flag, + device_equal_flag, + sizeof(int), + cudaMemcpyDeviceToHost) != cudaSuccess) { + + cudaFree(device_equal_flag); + + throw std::runtime_error("Failed to copy equality flag from device."); + } + + cudaFree(device_equal_flag); + + return equal_flag; +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Performs a bit-level equality check between two blocks +template +bool BlockCompareRelativelyEqual( + Element const *ptr_A, + Element const *ptr_B, + size_t capacity, + Element epsilon, + Element nonzero_floor, + int grid_size = 0, + int block_size = 0, + cudaStream_t stream = nullptr) { + + int equal_flag = 1; + int *device_equal_flag = nullptr; + + if (cudaMalloc((void **)&device_equal_flag, sizeof(int)) != cudaSuccess) { + throw std::runtime_error("Failed to allocate device flag."); + } + + if (cudaMemcpy( + device_equal_flag, + &equal_flag, + sizeof(int), + cudaMemcpyHostToDevice) != cudaSuccess) { + + throw std::runtime_error("Failed to copy equality flag to device."); + } + + if (!grid_size || !block_size) { + + // if grid_size or block_size are zero, query occupancy using the CUDA Occupancy API + cudaError_t result = cudaOccupancyMaxPotentialBlockSize( + &grid_size, + &block_size, + reinterpret_cast(kernel::BlockCompareRelativelyEqual)); + + if (result != cudaSuccess) { + throw std::runtime_error("Failed to query occupancy."); + } + // Limit block size. This has the effect of increasing the number of items processed by a + // single thread and reduces the impact of initialization overhead. + block_size = (block_size < 128 ? 
block_size : 128); + } + + dim3 grid(grid_size, 1, 1); + dim3 block(block_size, 1, 1); + + kernel::BlockCompareRelativelyEqual<<< grid, block, 0, stream >>>( + device_equal_flag, + ptr_A, + ptr_B, + capacity, + epsilon, + nonzero_floor + ); + + cudaStreamSynchronize(stream); + + if (cudaMemcpy( + &equal_flag, + device_equal_flag, + sizeof(int), + cudaMemcpyDeviceToHost) != cudaSuccess) { + + cudaFree(device_equal_flag); + + throw std::runtime_error("Failed to copy equality flag from device."); + } + + cudaFree(device_equal_flag); + + return equal_flag; +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // device +} // reference +} // cutlass diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_fill.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_fill.h new file mode 100644 index 0000000000000000000000000000000000000000..a19b42825f6efb4a39466fe1cfc182ab7d831079 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_fill.h @@ -0,0 +1,2075 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines device-side elementwise operations on TensorView. Note, the operations defined + in this header are not specialized for any particular data layout and are therefore not + intended to offer the best possible performance. Rather, they are intended to be generic + reference implementations to support the CUTLASS unit tests. 
+*/ + +#pragma once + +#if !defined(__CUDACC_RTC__) + +// Standard Library includes +#include +#include +#include +#include +#include + +#endif + +// CUDA includes +#include + +// Cutlass includes +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/complex.h" +#include "cutlass/tensor_view.h" +#include "cutlass/blas3.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/layout/vector.h" + +#include "cutlass/util/reference/device/tensor_foreach.h" +#include "cutlass/util/distribution.h" + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace reference { +namespace device { + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template +CUTLASS_DEVICE +FloatType random_normal_float(curandState_t *state) { + return curand_normal(state); +} + +template <> +CUTLASS_DEVICE +double random_normal_float(curandState_t *state) { + return curand_normal_double(state); +} + +template +CUTLASS_DEVICE +FloatType random_uniform_float(curandState_t *state) { + return curand_uniform(state); +} + +template <> +CUTLASS_DEVICE +double random_uniform_float(curandState_t *state) { + return curand_uniform_double(state); +} + +template +struct RandomGaussianFunc { + + using FloatType = typename std::conditional<(sizeof(Element) > 4), double, float>::type; + using IntType = typename std::conditional<(sizeof(Element) > 4), int64_t, int>::type; + + /// Parameters structure + struct Params { + + // + // Data members + // + + uint64_t seed; + FloatType mean; + FloatType stddev; + int int_scale; + FloatType float_scale_up; + FloatType float_scale_down; + int exclude_zero; ///< If non-negative, excludes zeros + + // + // Methods + // + + /// Construction of Gaussian RNG functor. 
+ Params( + uint64_t seed_ = 0, + Element mean_ = 0, + Element stddev_ = 1, + int int_scale_ = -1, + int exclude_zero_ = -1 + ): + seed(seed_), + mean(static_cast(mean_)), + stddev(static_cast(stddev_)), + int_scale(int_scale_), + exclude_zero(exclude_zero_) { + + float_scale_up = FloatType(IntType(1) << int_scale); // scale up to clamp low order bits + float_scale_down = FloatType(1) / FloatType(IntType(1) << int_scale); + } + }; + + // + // Data members + // + + /// Parameters object + Params params; + + /// RNG state object + curandState_t rng_state; + + // + // Methods + // + + /// Device-side initialization of RNG + CUTLASS_DEVICE + RandomGaussianFunc(Params const ¶ms): params(params) { + + uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x; + + curand_init(params.seed, gtid, 0, &rng_state); + } + + /// Compute random value and update RNG state + CUTLASS_DEVICE + Element operator()() { + + FloatType rnd = random_normal_float(&rng_state); + rnd = params.mean + params.stddev * rnd; + + Element result; + if (params.int_scale >= 0) { + rnd = FloatType(std::llround(rnd * params.float_scale_up)); + result = Element(rnd * params.float_scale_down); + } + else { + result = Element(rnd); + } + + if (params.exclude_zero >=0 && result == Element(0.0)) { + if (rnd > FloatType(0)) { + rnd += FloatType(1); + } else { + rnd -= FloatType(1); + } + result = Element(rnd); + } + + return result; + } +}; + + +template +struct RandomGaussianFunc> { + + using Element = complex; + using FloatType = typename std::conditional<(sizeof(Real) > 4), double, float>::type; + using IntType = typename std::conditional<(sizeof(Real) > 4), int64_t, int>::type; + + /// Parameters structure + struct Params { + + // + // Data members + // + + uint64_t seed; + FloatType mean; + FloatType stddev; + int int_scale; + FloatType float_scale_up; + FloatType float_scale_down; + int exclude_zero; ///< If non-negative, excludes zeros + + // + // Methods + // + + /// Construction of Gaussian RNG functor. 
+ Params( + uint64_t seed_ = 0, + Real mean_ = 0, + Real stddev_ = 1, + int int_scale_ = -1, + int exclude_zero_ = -1 + ): + seed(seed_), + mean(static_cast(mean_)), + stddev(static_cast(stddev_)), + int_scale(int_scale_), + exclude_zero(exclude_zero_) { + + float_scale_up = FloatType(IntType(1) << int_scale); + float_scale_down = FloatType(1) / FloatType(IntType(1) << int_scale); + } + }; + + // + // Data members + // + + /// Parameters object + Params params; + + /// RNG state object + curandState_t rng_state; + + // + // Methods + // + + /// Device-side initialization of RNG + CUTLASS_DEVICE + RandomGaussianFunc(Params const ¶ms): params(params) { + + uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x; + + curand_init(params.seed, gtid, 0, &rng_state); + } + + /// Compute random value and update RNG state + CUTLASS_DEVICE + Element operator()() { + + FloatType rnd_r = random_normal_float(&rng_state); + FloatType rnd_i = random_normal_float(&rng_state); + rnd_r = params.mean + params.stddev * rnd_r; + rnd_i = params.mean + params.stddev * rnd_i; + + Element result; + if (params.int_scale >= 0) { + rnd_r = FloatType(std::llround(rnd_r * params.float_scale_up)); + rnd_i = FloatType(std::llround(rnd_i * params.float_scale_up)); + + result = { + Real(rnd_r * params.float_scale_down), + Real(rnd_i * params.float_scale_down) + }; + } + else { + result = Element(Real(rnd_r), Real(rnd_i)); + } + + if (params.exclude_zero >= 0 && + result.real() == Real(0.0) && + result.imag() == Real(0.0)) { + + if (rnd_r > FloatType(0)) { + rnd_r += FloatType(1); + } else { + rnd_r -= FloatType(1); + } + result = Element(Real(rnd_r), Real(rnd_i)); + } + + return result; + } +}; + +/// Computes a random Gaussian distribution +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorFillRandomGaussianFunc { + + /// View type + using TensorView = TensorView; + + /// Scalar type + typedef typename TensorView::Element T; + + /// Coordinate in tensor's index space + typedef typename TensorView::TensorCoord TensorCoord; + + using RandomFunc = RandomGaussianFunc; + + /// Parameters structure + struct Params { + + // + // Data members + // + + TensorView view; + typename RandomFunc::Params random; + + // + // Methods + // + + /// Construction of Gaussian RNG functor. + Params( + TensorView view_ = TensorView(), + typename RandomFunc::Params random_ = typename RandomFunc::Params() + ): + view(view_), random(random_) { + + } + }; + + // + // Data members + // + + Params params; + RandomFunc random; + + // + // Methods + // + + /// Device-side initialization of RNG + CUTLASS_DEVICE + TensorFillRandomGaussianFunc(Params const ¶ms): params(params), random(params.random) { + + } + + /// Compute random value and update RNG state + CUTLASS_DEVICE + void operator()(TensorCoord const &coord) { + + params.view.at(coord) = random(); + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor with random values with a Gaussian distribution. 
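+
+// Illustrative usage sketch (not part of the original header). It assumes
+// cutlass::HostTensor from cutlass/util/host_tensor.h, which this header does not
+// include itself:
+//
+//   cutlass::HostTensor<float, cutlass::layout::RowMajor> tensor({128, 256});
+//
+//   // Fill the device-side view with N(0, 1) samples; bits=10 truncates values to
+//   // 10 fractional bits, reducing precision to facilitate error testing.
+//   cutlass::reference::device::TensorFillRandomGaussian(
+//       tensor.device_view(), /*seed=*/2023ull, /*mean=*/0.0f, /*stddev=*/1.0f, /*bits=*/10);
+//
+//   tensor.sync_host();   // copy back if the host needs to inspect the values
+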
+template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillRandomGaussian( + TensorView view, ///< destination tensor + uint64_t seed, ///< seed for RNG + typename RealType::Type mean = Element(0), ///< Gaussian distribution's mean + typename RealType::Type stddev = Element(1), ///< Gaussian distribution's standard deviation + int bits = -1, ///< If non-negative, specifies number of fractional bits that + /// are not truncated to zero. Permits reducing precision of + /// data. + int exclude_zero = -1, ///< If non-negative, excludes zeros from tensor init + cudaStream_t stream = nullptr) { + + using RandomFunc = detail::RandomGaussianFunc; + using Func = detail::TensorFillRandomGaussianFunc; + using Params = typename Func::Params; + + TensorForEach( + view.extent(), + Params(view, typename RandomFunc::Params(seed, mean, stddev, bits, exclude_zero)), + /*grid_size*/0, /*block_size*/0, + stream + ); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor with random values with a Gaussian distribution. +template ///< Element type +void BlockFillRandomGaussian( + Element *ptr, + size_t capacity, + uint64_t seed, ///< seed for RNG + typename RealType::Type mean, ///< Gaussian distribution's mean + typename RealType::Type stddev, ///< Gaussian distribution's standard deviation + int bits = -1, ///< If non-negative, specifies number of fractional bits that + /// are not truncated to zero. Permits reducing precision of + /// data. + cudaStream_t stream = nullptr) { + + using RandomFunc = detail::RandomGaussianFunc; + + typename RandomFunc::Params params(seed, mean, stddev, bits); + + BlockForEach(ptr, capacity, params, /*grid_size*/0, /*block_size*/0, stream); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +/// Computes a random uniform distribution +template ///< Element type +struct RandomUniformFunc { + + using FloatType = typename std::conditional< + (sizeof(Element) > 4), + double, + float>::type; + + using IntType = typename std::conditional< + (sizeof(Element) > 4), + int64_t, + int>::type; + + /// Parameters structure + struct Params { + + // + // Data members + // + + uint64_t seed; + FloatType range; + FloatType max; + int int_scale; + double pnan; + FloatType float_scale_up; + FloatType float_scale_down; + int exclude_zero; ///< If non-negative, excludes zeros + + /// Default ctor + CUTLASS_HOST_DEVICE + Params() { } + + // + // Methods + // + + /// Construction of Gaussian RNG functor. + Params( + uint64_t seed_ = 0, + Element max_ = 1, + Element min = 0, + int int_scale_ = -1, + double pnan_ = 0, + int exclude_zero_ = -1 + ): + seed(seed_), + range(static_cast(max_) - static_cast(min)), + max(static_cast(max_)), + int_scale(int_scale_), + pnan(pnan_), + exclude_zero(exclude_zero_) { + + float_scale_up = FloatType(IntType(1) << int_scale); // scale up to clamp low order bits + float_scale_down = FloatType(1) / FloatType(IntType(1) << int_scale); + + // Handle cases where min = 0 or max = 0 for excluding zeros + if (exclude_zero >= 0) { + range = (min == Element(0)) ? range - FloatType(1): range; + max = (max_ == Element(0)) ? 
max - FloatType(1): max; + } + } + }; + + // + // Data members + // + + /// Parameters object + Params params; + + /// RNG state object + curandState_t rng_state; + + // + // Methods + // + + /// Device-side initialization of RNG + CUTLASS_DEVICE + RandomUniformFunc(Params const ¶ms): params(params) { + + uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x; + + curand_init(params.seed, gtid, 0, &rng_state); + } + + /// Compute random value and update RNG state + CUTLASS_DEVICE + Element operator()() { + + // Draw random float in [0.0, 1.0] to determine if element should be NaN. + if constexpr (std::numeric_limits::has_quiet_NaN) { + if (params.pnan > 0 && (curand_uniform(&rng_state) < (params.pnan))) { + return Element(NAN); + } + } + + FloatType rnd = random_uniform_float(&rng_state); + rnd = params.max - params.range * rnd; + + // Random values are cast to integer after scaling by a power of two to facilitate error + // testing + Element result; + + if (params.int_scale >= 0) { + rnd = FloatType(std::llround(rnd * params.float_scale_up)); + result = Element(rnd * params.float_scale_down); + } + else { + result = Element(rnd); + } + + if (params.exclude_zero >=0 && result == Element(0.0)) { + if (rnd > FloatType(0)) { + rnd = std::min(params.max, rnd + FloatType(1)); + } else { + rnd = std::max((params.max - params.range), rnd - FloatType(1)); + } + result = Element(rnd); + } + + return result; + } +}; + +/// Computes a random Gaussian distribution +template +struct RandomUniformFunc> { + + using Element = complex; + + using FloatType = typename std::conditional< + (sizeof(Real) > 4), + double, + float>::type; + + using IntType = typename std::conditional< + (sizeof(Real) > 4), + int64_t, + int>::type; + + /// Parameters structure + struct Params { + + // + // Data members + // + + uint64_t seed; + FloatType range; + FloatType min; + int int_scale; + double pnan; + FloatType float_scale_up; + FloatType float_scale_down; + int exclude_zero; ///< If non-negative, excludes zeros + + /// Default ctor + CUTLASS_HOST_DEVICE + Params() { } + + // + // Methods + // + + /// Construction of Gaussian RNG functor. + Params( + uint64_t seed_ = 0, + FloatType max = 1, + FloatType min_ = 0, + int int_scale_ = -1, + double pnan_ = 0, + int exclude_zero_ = -1 + ): + seed(seed_), + range(static_cast(max - min_)), + min(static_cast(min_)), + int_scale(int_scale_), + pnan(pnan_), + exclude_zero(exclude_zero_) { + + float_scale_up = FloatType(IntType(1) << int_scale); + float_scale_down = FloatType(1) / FloatType(IntType(1) << int_scale); + + // Handle cases where min = 0 or max = 0 for excluding zeros + if (exclude_zero >= 0) { + min = (min == FloatType(0)) ? min + FloatType(1): min; + range = (max == FloatType(0)) ? range - FloatType(1): range; + } + } + }; + + // + // Data members + // + + /// Parameters object + Params params; + + /// RNG state object + curandState_t rng_state; + + // + // Methods + // + + /// Device-side initialization of RNG + CUTLASS_DEVICE + RandomUniformFunc(Params const ¶ms): params(params) { + + uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x; + + curand_init(params.seed, gtid, 0, &rng_state); + } + + /// Compute random value and update RNG state + CUTLASS_DEVICE + Element operator()() { + + // Draw random float in [0.0, 1.0] to determine if element should be NaN. 
+ if constexpr (std::numeric_limits::has_quiet_NaN) { + if (params.pnan > 0 && (curand_uniform(&rng_state) < (params.pnan))) { + return Element(Real(NAN), Real(NAN)); + } + } + + FloatType rnd_r = random_uniform_float(&rng_state); + FloatType rnd_i = random_uniform_float(&rng_state); + + rnd_r = params.min + params.range * rnd_r; + rnd_i = params.min + params.range * rnd_i; + + // Random values are cast to integer after scaling by a power of two to facilitate error + // testing + Element result; + + if (params.int_scale >= 0) { + rnd_r = FloatType(std::llround(rnd_r * params.float_scale_up)); + rnd_i = FloatType(std::llround(rnd_i * params.float_scale_up)); + + result = { + Real(rnd_r * params.float_scale_down), + Real(rnd_i * params.float_scale_down) + }; + } + else { + result = Element(Real(rnd_r), Real(rnd_i)); + } + + if (params.exclude_zero >= 0 && + result.real() == Real(0.0) && + result.imag() == Real(0.0)) { + + if (rnd_r > FloatType(0)) { + rnd_r = std::min(params.min + params.range, rnd_r + FloatType(1)); + } else { + rnd_r = std::max((params.min), rnd_r - FloatType(1)); + } + result = Element(Real(rnd_r), Real(rnd_i)); + } + + return result; + } +}; + +/// Computes a random uniform distribution +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorFillRandomUniformFunc { + + /// View type + using TensorView = TensorView; + + /// Scalar type + typedef typename TensorView::Element T; + + /// Coordinate in tensor's index space + typedef typename TensorView::TensorCoord TensorCoord; + + using RandomFunc = RandomUniformFunc; + + /// Parameters structure + struct Params { + + // + // Data members + // + + TensorView view; + typename RandomFunc::Params random; + + /// Default ctor + CUTLASS_HOST_DEVICE + Params() { } + + // + // Methods + // + + /// Construction of Gaussian RNG functor. + Params( + TensorView view_ = TensorView(), + typename RandomFunc::Params random_ = RandomFunc::Params() + ): + view(view_), random(random_) { + + } + }; + + // + // Data members + // + + Params params; + RandomFunc random; + + // + // Methods + // + + /// Device-side initialization of RNG + CUTLASS_DEVICE + TensorFillRandomUniformFunc(Params const ¶ms): params(params), random(params.random) { + } + + /// Compute random value and update RNG state + CUTLASS_DEVICE + void operator()(TensorCoord const &coord) { + + params.view.at(coord) = random(); + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor with random values with a uniform random distribution. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillRandomUniform( + TensorView view, ///< destination tensor + uint64_t seed, ///< seed for RNG + typename RealType::Type max = Element(1), ///< upper bound of distribution + typename RealType::Type min = Element(0), ///< lower bound for distribution + int bits = -1, ///< If non-negative, specifies number of fractional bits that + /// are not truncated to zero. Permits reducing precision of + /// data. + double pnan = 0, ///< Percentage of NaN elements. 
+ int exclude_zero = -1, ///< If non-negative, excludes zeros from tensor init + cudaStream_t stream = nullptr) { + + using RandomFunc = detail::RandomUniformFunc; + using Func = detail::TensorFillRandomUniformFunc; + using Params = typename Func::Params; + + typename RandomFunc::Params random(seed, max, min, bits, pnan, exclude_zero); + + TensorForEach( + view.extent(), + Params(view, random), + /*grid_size*/0, /*block_size*/0, + stream + ); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor with random values with a uniform random distribution. +template +void BlockFillRandomUniform( + Element *ptr, + size_t capacity, + uint64_t seed, ///< seed for RNG + typename RealType::Type max, ///< upper bound of distribution + typename RealType::Type min, ///< lower bound for distribution + int bits = -1, ///< If non-negative, specifies number of fractional bits that + /// are not truncated to zero. Permits reducing precision of + /// data. + double pnan = 0, ///< Percentage of NaN elements. + cudaStream_t stream = nullptr) { + + using RandomFunc = detail::RandomUniformFunc; + + typename RandomFunc::Params params(seed, max, min, bits, pnan); + + BlockForEach(ptr, capacity, params, /*grid_size*/0, /*block_size*/0, stream); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +/// Computes a random sparse meta +template ///< Element type +struct RandomSparseMetaFunc { + + using FloatType = float; + + using IntType = int32_t; + + /// Parameters structure + struct Params { + + // + // Data members + // + + uint64_t seed; + FloatType range; + int MetaSizeInBits; + + /// Default ctor + CUTLASS_HOST_DEVICE + Params() { } + + // + // Methods + // + + /// Construction of Gaussian RNG functor. + Params( + uint64_t seed_ = 0, + int MetaSizeInBits_ = 2 + ): + seed(seed_), + MetaSizeInBits(MetaSizeInBits_) { + if (MetaSizeInBits_ == 2) { + range = 6; + } + else if (MetaSizeInBits_ == 4) { + range = 2; + } + else { + throw std::invalid_argument("Invalid MetaSizeInBits"); + } + } + }; + + // + // Data members + // + + /// Parameters object + Params params; + + /// RNG state object + curandState_t rng_state; + + // + // Methods + // + + /// Device-side initialization of RNG + CUTLASS_DEVICE + RandomSparseMetaFunc(Params const ¶ms): params(params) { + + uint64_t gtid = threadIdx.x + blockIdx.x * blockDim.x; + + curand_init(params.seed, gtid, 0, &rng_state); + } + + /// Compute random value and update RNG state + CUTLASS_DEVICE + Element operator()() { + Element FourToTwoMeta[6] = {0x4, 0x8, 0x9, 0xc, 0xd, 0xe}; + Element TwoToOneMeta[2] = {0x4, 0xe}; + + Element *MetaArray = + (params.MetaSizeInBits == 2) ? 
FourToTwoMeta : TwoToOneMeta; + + Element result = 0x0; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < cutlass::sizeof_bits::value / 4; ++i) { + FloatType rnd = random_uniform_float(&rng_state); + rnd = params.range * rnd; + Element meta = MetaArray[(int)rnd]; + + result = (Element)(result | ((Element)(meta << (i * 4)))); + } + + return result; + } +}; + +/// Computes a random Gaussian distribution +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorFillRandomSparseMetaFunc { + + /// View type + using TensorView = TensorView; + + /// Scalar type + typedef typename TensorView::Element T; + + /// Coordinate in tensor's index space + typedef typename TensorView::TensorCoord TensorCoord; + + using RandomFunc = RandomSparseMetaFunc; + + /// Parameters structure + struct Params { + + // + // Data members + // + + TensorView view; + typename RandomFunc::Params random; + + /// Default ctor + CUTLASS_HOST_DEVICE + Params() { } + + // + // Methods + // + + /// Construction of Gaussian RNG functor. + Params( + TensorView view_ = TensorView(), + typename RandomFunc::Params random_ = RandomFunc::Params() + ): + view(view_), random(random_) { + + } + }; + + // + // Data members + // + + Params params; + RandomFunc random; + + // + // Methods + // + + /// Device-side initialization of RNG + CUTLASS_DEVICE + TensorFillRandomSparseMetaFunc(Params const ¶ms): params(params), random(params.random) { + } + + /// Compute random value and update RNG state + CUTLASS_DEVICE + void operator()(TensorCoord const &coord) { + + params.view.at(coord) = random(); + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor with random values with a uniform random distribution. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillRandomSparseMeta( + TensorView view, ///< destination tensor + uint64_t seed, ///< seed for RNG + int MetaSizeInBits = 2, ///< meta data size + cudaStream_t stream = nullptr) { + + using RandomFunc = detail::RandomSparseMetaFunc; + using Func = detail::TensorFillRandomUniformFunc; + using Params = typename Func::Params; + + typename RandomFunc::Params random(seed, MetaSizeInBits); + + TensorForEach( + view.extent(), + Params(view, random), + /*grid_size*/0, /*block_size*/0, + stream + ); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor with random values with a uniform random distribution. +template +void BlockFillRandomSparseMeta( + Element *ptr, + size_t capacity, + uint64_t seed, ///< seed for RNG + int MetaSizeInBits = 2, ///< meta data size + cudaStream_t stream = nullptr) { + + using RandomFunc = detail::RandomSparseMetaFunc; + + typename RandomFunc::Params params(seed, MetaSizeInBits); + + BlockForEach(ptr, capacity, params, /*grid_size*/0, /*block_size*/0, stream); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +/// Functor to fill a tensor with zeros off the diagonal and a uniform value on the diagonal. 
+template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorFillDiagonalFunc { + + /// View type + using TensorView = TensorView; + + /// Scalar type + typedef typename TensorView::Element T; + + /// Coordinate in tensor's index space + typedef typename TensorView::TensorCoord TensorCoord; + + /// Parameters structure + struct Params { + + // + // Data members + // + + TensorView view; + Element diag; + Element other; + + /// Default ctor + CUTLASS_HOST_DEVICE + Params() { } + + // + // Methods + // + + Params( + TensorView view_ = TensorView(), + Element diag_ = Element(1), + Element other_ = Element(0) + ): + view(view_), diag(diag_), other(other_) { + + } + }; + + // + // Data members + // + + /// Parameters object + Params params; + + // + // Methods + // + + /// Device-side initialization of RNG + CUTLASS_DEVICE + TensorFillDiagonalFunc(Params const ¶ms): params(params) { + + } + + /// Updates the tensor + CUTLASS_DEVICE + void operator()(TensorCoord const &coord) { + + bool is_diag = true; + + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < Layout::kRank; ++i) { + if (coord[i] != coord[i - 1]) { + is_diag = false; + break; + } + } + + params.view.at(coord) = (is_diag ? params.diag : params.other); + } +}; + +// Overwrites the elements of a tensor with a uniform value depending on fill mode +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorFillPartialFunc { + + /// View type + using TensorView = TensorView; + + /// Scalar type + typedef typename TensorView::Element T; + + /// Coordinate in tensor's index space + typedef typename TensorView::TensorCoord TensorCoord; + + /// Parameters structure + struct Params { + + // + // Data members + // + + TensorView view; + Element element; + FillMode fill_mode; + + /// Default ctor + CUTLASS_HOST_DEVICE + Params(): fill_mode(FillMode::kNone) { } + + // + // Methods + // + + /// Construction of Gaussian RNG functor. + Params( + TensorView view_, + Element element_, + FillMode fill_mode_ + ): + view(view_), element(element_), fill_mode(fill_mode_) { + + } + }; + + // + // Data members + // + + /// Parameters object + Params params; + + // + // Methods + // + + CUTLASS_DEVICE + TensorFillPartialFunc(Params const ¶ms): params(params) { + + } + + /// Overwrites the element if it is within the covered region. 
+ CUTLASS_DEVICE + void operator()(TensorCoord const &coord) { + + bool predicate = true; + + switch (params.fill_mode) { + case FillMode::kFull: + predicate = true; + break; + + case FillMode::kLower: + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < Layout::kRank; ++i) { + if (coord[i - 1] < coord[i]) { + predicate = false; + break; + } + } + break; + + case FillMode::kUpper: + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < Layout::kRank; ++i) { + if (coord[i - 1] > coord[i]) { + predicate = false; + break; + } + } + break; + + case FillMode::kDiagonal: + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < Layout::kRank; ++i) { + if (coord[i - 1] != coord[i]) { + predicate = false; + break; + } + } + break; + + case FillMode::kNone: // fall-through + + default: + predicate = false; + break; + } + + if (predicate) { + params.view.at(coord) = params.element; + } + } +}; + + +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorClearPartialFunc { + + /// View type + using TensorView = TensorView; + + /// Scalar type + typedef typename TensorView::Element T; + + /// Coordinate in tensor's index space + typedef typename TensorView::TensorCoord TensorCoord; + + /// + static_assert((Layout::kRank == 2), "TensorClearPartial is only supported for matrices"); + + /// Parameters structure + struct Params { + TensorView view{}; + Element element{}; + FillMode fill_mode{FillMode::kNone}; + int alignment{0}; + }; + + // + // Data members + // + + /// Parameters object + Params params; + + // + // Methods + // + + CUTLASS_DEVICE + TensorClearPartialFunc(Params const ¶ms): params(params) { + + } + + /// Overwrites the element if it is within the covered region. + CUTLASS_DEVICE + void operator()(TensorCoord const &coord) { + + bool predicate = true; + + switch (params.fill_mode) { + + case FillMode::kLower: + if ((coord[0] >= coord[1]) || + ((coord[1] - coord[0]) >= params.alignment)) { + predicate = false; + break; + } + break; + + case FillMode::kUpper: + if ((coord[0] <= coord[1]) || + ((coord[0] - coord[1]) >= params.alignment)) { + predicate = false; + break; + } + break; + + case FillMode::kNone: // fall-through + + default: + predicate = false; + break; + } + + if (predicate) { + params.view.at(coord) = params.element; + } + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor everywhere with a unique value for its diagonal. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillDiagonal( + TensorView view, ///< destination tensor + Element diag = Element(1), ///< value to write in the diagonal + Element other = Element(0), ///< value to write off the diagonal + cudaStream_t stream = nullptr) { + + typedef detail::TensorFillDiagonalFunc Func; + typedef typename Func::Params Params; + + TensorForEach( + view.extent(), + Params(view, diag, other), + /*grid_size*/0, /*block_size*/0, + stream + ); +} + +/// Fills a tensor partially depending on fill mode. Elements not covered by the fillmode are +/// not written. 
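+
+// Illustrative usage sketch (not part of the original header), again assuming a
+// cutlass::HostTensor allocated via cutlass/util/host_tensor.h:
+//
+//   cutlass::HostTensor<float, cutlass::layout::ColumnMajor> A({64, 64});
+//
+//   // Write 1.0f to the lower-triangular part of A only; elements outside the
+//   // fill mode are left untouched.
+//   cutlass::reference::device::TensorFillPartial(
+//       A.device_view(), 1.0f, cutlass::FillMode::kLower);
+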
+template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillPartial( + TensorView view, ///< destination tensor + Element element, + FillMode fill_mode, + cudaStream_t stream = nullptr) { + + typedef detail::TensorFillPartialFunc Func; + typedef typename Func::Params Params; + + TensorForEach( + view.extent(), + Params(view, element, fill_mode), + stream + ); +} + +/// Clears a tensor partially depending on fill mode and alignment. Elements on the wrong-side +/// of fillmode (upto the alignment) are overwritten with the user supplied element (typically zeros) +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorClearPartial( + TensorView view, ///< destination tensor + Element element, + FillMode fill_mode, + int alignment, + cudaStream_t stream = nullptr) { + + typedef detail::TensorClearPartialFunc Func; + typedef typename Func::Params Params; + + TensorForEach( + view.extent(), + Params{view, element, fill_mode, alignment}, + /*grid_size*/0, /*block_size*/0, + stream + ); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor with a uniform value +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFill( + TensorView view, ///< destination tensor + Element val = Element(0), ///< value to uniformly fill it with + cudaStream_t stream = nullptr) { + + TensorFillDiagonal(view, val, val, stream); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor's diagonal with 1 and 0 everywhere else. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillIdentity( + TensorView view, ///< destination tensor + cudaStream_t stream = nullptr) { + + TensorFillDiagonal(view, Element(1), Element(0), stream); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +/// Computes a random Gaussian distribution +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorUpdateDiagonalFunc { + + /// View type + using TensorView = TensorView; + + /// Scalar type + typedef typename TensorView::Element T; + + /// Coordinate in tensor's index space + typedef typename TensorView::TensorCoord TensorCoord; + + /// Parameters structure + struct Params { + + // + // Data members + // + + TensorView view; + Element diag; + + /// Default ctor + CUTLASS_HOST_DEVICE + Params() { } + + // + // Methods + // + + /// Construction of Gaussian RNG functor. 
+ Params( + TensorView view_ = TensorView(), + Element diag_ = Element(1) + ): + view(view_), diag(diag_) { + + } + }; + + // + // Data members + // + + /// Parameters object + Params params; + + // + // Methods + // + + /// Device-side initialization of RNG + CUTLASS_DEVICE + TensorUpdateDiagonalFunc(Params const ¶ms): params(params) { + + } + + /// Compute random value and update RNG state + CUTLASS_DEVICE + void operator()(TensorCoord const &coord) { + + bool is_diag = true; + + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < Layout::kRank; ++i) { + if (coord[i] != coord[i - 1]) { + is_diag = false; + break; + } + } + + if (is_diag) { + params.view.at(coord) = params.diag; + } + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Writes a uniform value to the diagonal of a tensor without modifying off-diagonal elements. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorUpdateDiagonal( + TensorView view, ///< destination tensor + Element diag = Element(1), + cudaStream_t stream = nullptr) { + + typedef detail::TensorUpdateDiagonalFunc Func; + typedef typename Func::Params Params; + + TensorForEach( + view.extent(), + Params(view, diag), + /*grid_size*/0, /*block_size*/0, + stream + ); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +/// Computes a random Gaussian distribution +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorUpdateOffDiagonalFunc { + + /// View type + using TensorView = TensorView; + + /// Scalar type + typedef typename TensorView::Element T; + + /// Coordinate in tensor's index space + typedef typename TensorView::TensorCoord TensorCoord; + + /// Parameters structure + struct Params { + + // + // Data members + // + + TensorView view; + Element other; + + /// Default ctor + CUTLASS_HOST_DEVICE + Params() { } + + // + // Methods + // + + /// Construction of Gaussian RNG functor. + Params( + TensorView view_ = TensorView(), + Element other_ = Element(0) + ): + view(view_), other(other_) { + + } + }; + + // + // Data members + // + + /// Parameters object + Params params; + + // + // Methods + // + + /// Device-side initialization of RNG + CUTLASS_DEVICE + TensorUpdateOffDiagonalFunc(Params const ¶ms): params(params) { + + } + + /// Compute random value and update RNG state + CUTLASS_DEVICE + void operator()(TensorCoord const &coord) { + + bool is_diag = true; + + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < Layout::kRank; ++i) { + if (coord[i] != coord[i - 1]) { + is_diag = false; + break; + } + } + + if (!is_diag) { + params.view.at(coord) = params.other; + } + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Writes a uniform value to all elements in the tensor without modifying diagonal elements. 
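+
+// Illustrative usage sketch (not part of the original header), assuming `A` is a
+// cutlass::HostTensor allocated as in the earlier sketches:
+//
+//   // Overwrite every off-diagonal element with zero; diagonal entries are untouched.
+//   cutlass::reference::device::TensorUpdateOffDiagonal(A.device_view(), 0.0f);
+//
+//   // The complementary helper defined above, TensorUpdateDiagonal, writes only
+//   // the diagonal:
+//   cutlass::reference::device::TensorUpdateDiagonal(A.device_view(), 1.0f);
+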
+template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorUpdateOffDiagonal( + TensorView view, ///< destination tensor + Element other = Element(1), + cudaStream_t stream = nullptr) { + + typedef detail::TensorUpdateOffDiagonalFunc Func; + typedef typename Func::Params Params; + + TensorForEach( + view.extent(), + Params(view, other), + /*grid_size*/0, /*block_size*/0, + stream + ); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +/// Computes a random Gaussian distribution +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorFillLinearFunc { + + /// View type + using TensorView = TensorView; + + /// Scalar type + typedef typename TensorView::Element T; + + /// Coordinate in tensor's index space + typedef typename TensorView::TensorCoord TensorCoord; + + /// Parameters structure + struct Params { + + // + // Data members + // + + TensorView view; + Array v; + Element s; + + /// Default ctor + CUTLASS_HOST_DEVICE + Params() { } + + // + // Methods + // + + /// Construction of Gaussian RNG functor. + Params( + TensorView view_, ///< destination tensor + Array const & v_, + Element s_ = Element(0) + ): + view(view_), v(v_), s(s_) { + + } + }; + + // + // Data members + // + + /// Parameters object + Params params; + + // + // Methods + // + + /// Device-side initialization of RNG + CUTLASS_DEVICE + TensorFillLinearFunc(Params const ¶ms): params(params) { + + } + + /// Compute random value and update RNG state + CUTLASS_DEVICE + void operator()(TensorCoord const &coord) { + + Element sum = params.s; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Layout::kRank; ++i) { + if constexpr (is_complex::value) { + if constexpr (sizeof_bits::value <= 32) { + sum = Element(static_cast>(sum) + + static_cast>(params.v[i]) * static_cast>(coord[i])); + } + } + else if constexpr (sizeof_bits::value <= 32) { + if constexpr (std::numeric_limits::is_integer) { + sum = Element(static_cast(sum) + + static_cast(params.v[i]) * static_cast(coord[i])); + } + else { + sum = Element(static_cast(sum) + + static_cast(params.v[i]) * static_cast(coord[i])); + } + } + else { + sum += params.v[i] * coord[i]; + } + } + + params.view.at(coord) = sum; + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills tensor with a linear combination of its coordinate and another vector +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillLinear( + TensorView view, ///< destination tensor + Array const & v, + Element s = Element(0), + cudaStream_t stream = nullptr) { + + using Func = detail::TensorFillLinearFunc; + using Params = typename Func::Params; + + TensorForEach( + view.extent(), + Params(view, v, s), + /*grid_size*/0, /*block_size*/0, + stream + ); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor with random values from a distribution. 
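+
+// Illustrative usage sketch (not part of the original header). cutlass::Distribution
+// comes from cutlass/util/distribution.h, included above; its set_uniform() helper is
+// assumed here to take (min, max, int_scale):
+//
+//   cutlass::Distribution dist;
+//   dist.set_uniform(/*min=*/-4.0, /*max=*/4.0, /*int_scale=*/0);
+//
+//   // Dispatches to TensorFillRandomUniform defined earlier in this header; a
+//   // Gaussian Distribution would dispatch to TensorFillRandomGaussian instead.
+//   cutlass::reference::device::TensorFillRandom(A.device_view(), /*seed=*/7ull, dist);
+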
+template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillRandom( + TensorView view, ///< destination tensor + uint64_t seed, + Distribution dist, + cudaStream_t stream = nullptr, + int exclude_zero = -1 ///< If non-negative, excludes 0. + /// Note that setting this flag will result in more 1's, + /// as we use a simple mechanism to replace 0's by adding/subtracting 1's. + ) { + + using Real = typename RealType::Type; + + if (dist.kind == Distribution::Gaussian) { + TensorFillRandomGaussian( + view, + seed, + static_cast(dist.gaussian.mean), + static_cast(dist.gaussian.stddev), + dist.int_scale, + exclude_zero, + stream); + } else if (dist.kind == Distribution::Uniform) { + TensorFillRandomUniform( + view, + seed, + static_cast(dist.uniform.max), + static_cast(dist.uniform.min), + dist.int_scale, + dist.uniform.pnan, + exclude_zero, + stream); + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a block of data with sequential elements +template < + typename Element +> +void BlockFillSequential( + Element *ptr, + int64_t capacity, + Element v = Element(1), + Element s = Element(0)) { + + using Layout = layout::PackedVectorLayout; + Layout::TensorCoord size(static_cast(capacity)); // -Wconversion + Layout layout = Layout::packed(size); + TensorView view(ptr, layout, size); + + Array c{}; + c[0] = v; + + TensorFillLinear(view, c, s); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a block of data with sequential elements +template < + typename Element +> +void BlockFillRandom( + Element *ptr, + size_t capacity, + uint64_t seed, + Distribution dist, + cudaStream_t stream = nullptr) { + + using Real = typename RealType::Type; + + if (dist.kind == Distribution::Gaussian) { + BlockFillRandomGaussian( + ptr, + capacity, + seed, + static_cast(dist.gaussian.mean), + static_cast(dist.gaussian.stddev), + dist.int_scale, + stream); + } + else if (dist.kind == Distribution::Uniform) { + BlockFillRandomUniform( + ptr, + capacity, + seed, + static_cast(dist.uniform.max), + static_cast(dist.uniform.min), + dist.int_scale, + dist.uniform.pnan, + stream); + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +/// Computes a random Gaussian distribution +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorCopyDiagonalInFunc { + + /// View type + using TensorView = TensorView; + + /// Scalar type + typedef typename TensorView::Element T; + + /// Coordinate in tensor's index space + typedef typename TensorView::TensorCoord TensorCoord; + + /// Parameters structure + struct Params { + + // + // Data members + // + + TensorView view; + Element const *ptr; + + /// Default ctor + CUTLASS_HOST_DEVICE + Params() { } + + // + // Methods + // + + /// Construction of Gaussian RNG functor. 
+ Params( + TensorView view_, ///< destination tensor + Element const *ptr_ + ): + view(view_), ptr(ptr_) { + + } + }; + + // + // Data members + // + + /// Parameters object + Params params; + + // + // Methods + // + + /// Device-side initialization of RNG + CUTLASS_DEVICE + TensorCopyDiagonalInFunc(Params const ¶ms): params(params) { + + } + + /// Only update the diagonal element + CUTLASS_DEVICE + void operator()(TensorCoord const &coord) { + bool is_diagonal = true; + + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < Layout::kRank; ++i) { + if (coord[i] != coord[0]) { + is_diagonal = false; + } + } + if (is_diagonal) { + params.view.at(coord) = params.ptr[coord[0]]; + } + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Copies a diagonal in from host memory without modifying off-diagonal elements. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorCopyDiagonalIn( + TensorView view, ///< destination tensor + Element const *ptr, ///< dense buffer of elements + cudaStream_t stream = nullptr) { + + using Func = detail::TensorCopyDiagonalInFunc; + using Params = typename Func::Params; + + TensorForEach( + view.extent(), + Params(view, ptr), + /*grid_size*/0, /*block_size*/0, + stream + ); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + + +namespace detail { + +/// Computes a random Gaussian distribution +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorCopyDiagonalOutFunc { + + /// View type + using TensorView = TensorView; + + /// Scalar type + typedef typename TensorView::Element T; + + /// Coordinate in tensor's index space + typedef typename TensorView::TensorCoord TensorCoord; + + /// Parameters structure + struct Params { + + // + // Data members + // + + TensorView view; + Element *ptr; + + /// Default ctor + CUTLASS_HOST_DEVICE + Params() { } + + // + // Methods + // + + /// Construction of Gaussian RNG functor. + Params( + TensorView view_, ///< destination tensor + Element *ptr_ + ): + view(view_), ptr(ptr_) { + + } + }; + + // + // Data members + // + + /// Parameters object + Params params; + + // + // Methods + // + + /// Device-side initialization of RNG + CUTLASS_DEVICE + TensorCopyDiagonalOutFunc(Params const ¶ms): params(params) { + + } + + /// Compute random value and update RNG state + CUTLASS_DEVICE + void operator()(TensorCoord const &coord) { + bool is_diagonal = true; + + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < Layout::kRank; ++i) { + if (coord[i] != coord[0]) { + is_diagonal = false; + } + } + if (is_diagonal) { + params.ptr[coord[0]] = params.view.at(coord); + } + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Copies the diagonal of a tensor into a dense buffer in host memory. 
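+
+// Illustrative usage sketch (not part of the original header). Since the functor above
+// writes through `ptr` from device code, the sketch stores the diagonal in a device
+// allocation (cutlass::DeviceAllocation from cutlass/util/device_memory.h, which this
+// header does not include itself):
+//
+//   cutlass::DeviceAllocation<float> diag(64);   // one element per diagonal entry
+//
+//   cutlass::reference::device::TensorCopyDiagonalOut(diag.get(), A.device_view());
+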
+template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorCopyDiagonalOut( + Element *ptr, ///< dense buffer of elements + TensorView view, ///< source tensor + cudaStream_t stream = nullptr) { + + using Func = detail::TensorCopyDiagonalOutFunc; + using Params = typename Func::Params; + + TensorForEach( + view.extent(), + Params(view, ptr), + /*grid_size*/0, /*block_size*/0, + stream + ); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace reference +} // namespace cutlass diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_foreach.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_foreach.h new file mode 100644 index 0000000000000000000000000000000000000000..ba2dfd85c47b8c9450c348de32dccb7f1be9c3c1 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_foreach.h @@ -0,0 +1,142 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include +#include "cutlass/cutlass.h" +#include "cutlass/util/reference/device/kernel/tensor_foreach.h" + +namespace cutlass { +namespace reference { +namespace device { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Launches a kernel calling a functor for each element in a tensor's index space. 
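+
+// Illustrative usage sketch (not part of the original header). The fill functors in
+// tensor_fill.h are launched through this wrapper; a custom functor follows the same
+// shape. The template argument list <Func, Rank, Params> is assumed from the way the
+// fill helpers instantiate TensorForEach and may differ in detail.
+//
+//   struct ClearFunc {
+//     struct Params {
+//       cutlass::TensorView<float, cutlass::layout::RowMajor> view;
+//     };
+//     Params params;
+//     CUTLASS_DEVICE ClearFunc(Params const &params): params(params) { }
+//     CUTLASS_DEVICE void operator()(cutlass::MatrixCoord const &coord) {
+//       params.view.at(coord) = 0.0f;   // zero each element visited by the kernel
+//     }
+//   };
+//
+//   cutlass::TensorView<float, cutlass::layout::RowMajor> view = /* device-side view */;
+//   cutlass::reference::device::TensorForEach<ClearFunc, 2, ClearFunc::Params>(
+//       view.extent(), ClearFunc::Params{view});
+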
+template +struct TensorForEach { + + /// Constructor performs the operation. + TensorForEach( + Coord size, Params params = Params(), + int grid_size = 0, int block_size = 0, + cudaStream_t stream = nullptr) { + + if (!grid_size || !block_size) { + + // if grid_size or block_size are zero, query occupancy using the CUDA Occupancy API + cudaError_t result = cudaOccupancyMaxPotentialBlockSize( + &grid_size, + &block_size, + reinterpret_cast(kernel::TensorForEach)); + + if (result != cudaSuccess) { + throw std::runtime_error("Failed to query occupancy."); + } + // Limit block size. This has the effect of increasing the number of items processed by a + // single thread and reduces the impact of initialization overhead. + block_size = (block_size < 128 ? block_size : 128); + } + + dim3 grid(grid_size, 1, 1); + dim3 block(block_size, 1, 1); + + kernel::TensorForEach<<< grid, block, 0, stream >>>(size, params); + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Launches a kernel calling a functor for each element along a tensor's diagonal +template +struct TensorDiagonalForEach { + + /// Constructor performs the operation + TensorDiagonalForEach( + Coord size, Params params = Params(), + int start = 0, int end = -1, + int block_size = 128, cudaStream_t stream = nullptr) { + + if (end < 0) { + end = size.min(); + } + + dim3 block(block_size, 1, 1); + dim3 grid((end - start + block_size - 1) / block_size, 1, 1); + + kernel::TensorDiagonalForEach<<< grid, block, 0, stream >>>( + size, params, start, end); + } +}; + + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct BlockForEach { + + /// Constructor performs the operation. + BlockForEach( + Element *ptr, + size_t capacity, + typename Func::Params params = typename Func::Params(), + int grid_size = 0, + int block_size = 0, + cudaStream_t stream = nullptr) { + + if (!grid_size || !block_size) { + + // if grid_size or block_size are zero, query occupancy using the CUDA Occupancy API + cudaError_t result = cudaOccupancyMaxPotentialBlockSize( + &grid_size, + &block_size, + reinterpret_cast(kernel::BlockForEach)); + + if (result != cudaSuccess) { + throw std::runtime_error("Failed to query occupancy."); + } + // Limit block size. This has the effect of increasing the number of items processed by a + // single thread and reduces the impact of initialization overhead. + block_size = (block_size < 128 ? 
block_size : 128); + } + + dim3 grid(grid_size, 1, 1); + dim3 block(block_size, 1, 1); + + kernel::BlockForEach<<< grid, block, 0, stream >>>(ptr, capacity, params); + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace reference +} // namespace cutlass diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_reduce.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_reduce.h new file mode 100644 index 0000000000000000000000000000000000000000..3e6d7b300f34fec6aec96e72f78427cf677936b4 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_reduce.h @@ -0,0 +1,514 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/complex.h" +#include "cutlass/functional.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/tensor_view.h" +#include "cutlass/util/device_memory.h" +#include "cutlass/util/reference/detail/linear_to_coordinate.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace reference { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace kernel { + +template < + typename Element, + typename Layout, + typename ComputeType, + typename ReduceOp, + typename TransformOp, + int kBlockSize = 128 +> +__global__ void TensorTransformReducePartial( + TensorView view, /// View of the tensor to reduce over + ComputeType identity, /// Identity element of the reduction operation + ReduceOp reduce, /// Reduces an accumulated value with a transformed element: f(ComputeType, ComputeType) => ComputeType + TransformOp transform, /// Transforms the tensor element to ComputeType: g(Element) => ComputeType + ComputeType *workspace) { /// Device-side workspace for accumulating partial results. The reduced element is stored in workspace[0] + + int64_t idx = threadIdx.x + blockIdx.x * blockDim.x; + int64_t size = view.size(); + + __shared__ ComputeType scratchpad[kBlockSize]; + + for (; idx < size; idx += blockDim.x * gridDim.x) { + + // Map linear thread ID onto tensor coordinate + typename Layout::TensorCoord coord; + + cutlass::reference::detail::LinearToCoordinate()(coord, idx, view.extent()); + + if (view.contains(coord)) { + + // Fetch element + Element x = view.at(coord); + + // Transform + identity = reduce(identity, transform(x)); + } + } + + scratchpad[threadIdx.x] = identity; + + __syncthreads(); + + // One thread performs the final reduction and stores out. This could be enhanced via + // a tree reduction and pipelining. + if (threadIdx.x == 0) { + + for (int i = 1; i < kBlockSize; ++i) { + identity = reduce(identity, scratchpad[i]); + } + + workspace[blockIdx.x] = identity; + } +} + +template < + typename Element, + typename Layout, + typename ComputeType, + typename ReduceOp, + typename TransformOp, + int kBlockSize = 128 +> +__global__ void TensorTransformReducePartial( + TensorView view_A, /// View of the tensor to reduce over + TensorView view_B, /// View of the tensor to reduce over + ComputeType identity, /// Identity element of the reduction operation + ReduceOp reduce, /// Reduces an accumulated value with a transformed element: f(ComputeType, ComputeType) => ComputeType + TransformOp transform, /// Transforms the tensor element to ComputeType: g(Element) => ComputeType + ComputeType *workspace) { /// Device-side workspace for accumulating partial results. 
The reduced element is stored in workspace[0] + + int64_t idx = threadIdx.x + blockIdx.x * blockDim.x; + auto size = static_cast(view_A.size()); + + __shared__ ComputeType scratchpad[kBlockSize]; + + for (; idx < size; idx += blockDim.x * gridDim.x) { + + // Map linear thread ID onto tensor coordinate + typename Layout::TensorCoord coord; + + cutlass::reference::detail::LinearToCoordinate()(coord, idx, view_A.extent()); + + if (view_A.contains(coord)) { + + // Fetch element + Element a = view_A.at(coord); + Element b = view_B.at(coord); + + // Transform + identity = reduce(identity, transform(a, b)); + } + } + + scratchpad[threadIdx.x] = identity; + + __syncthreads(); + + // One thread performs the final reduction and stores out. This could be enhanced via + // a tree reduction and pipelining. + if (threadIdx.x == 0) { + + for (int i = 1; i < kBlockSize; ++i) { + identity = reduce(identity, scratchpad[i]); + } + + workspace[blockIdx.x] = identity; + } +} + + +template < + typename ComputeType, + typename ReduceOp, + int kBlockSize = 32 +> +__global__ void TensorTransformReduceFinalize( + ComputeType *workspace, + ComputeType identity, + int workspace_size, + ReduceOp reduce) { + + __shared__ ComputeType scratchpad[kBlockSize]; + + for (int idx = threadIdx.x; idx < workspace_size; idx += kBlockSize) { + identity = reduce(identity, workspace[idx]); + } + + scratchpad[threadIdx.x] = identity; + + __syncthreads(); + + if (threadIdx.x == 0) { + + for (int i = 1; i < kBlockSize; ++i) { + identity = reduce(identity, scratchpad[i]); + } + + workspace[0] = identity; + } +} + +} // namespace kernel + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Transform-reduce operation over the elements of a tensor +template < + typename Element, + typename Layout, + typename ComputeType, + typename ReduceOp, + typename TransformOp +> +ComputeType TensorTransformReduce( + TensorView view, /// View of the tensor to reduce over + ComputeType identity, /// Identity element of the reduction operation + ReduceOp reduce, /// Reduces an accumulated value with a transformed element: f(ComputeType, ComputeType) => ComputeType + TransformOp transform, /// Transforms the tensor element to ComputeType: g(Element) => ComputeType + ComputeType *workspace, /// Device-side workspace for accumulating partial results. The reduced element is stored in workspace[0] + int workspace_size, /// Number of elements in workspace + cudaStream_t stream = nullptr, /// CUDA stream to launch into + bool copy_out = true /// If true, the value of workspace[0] is copied to host and returned. Otherwise, `identity` is returned. 
+) { + + int const kBlockSize = 128; + + dim3 block(kBlockSize, 1); + dim3 grid(workspace_size, 1); + + kernel::TensorTransformReducePartial< + Element, Layout, ComputeType, ReduceOp, TransformOp, kBlockSize + ><<< grid, block, 0, stream >>>( + view, identity, reduce, transform, workspace + ); + + int const kFinalizeBlockSize = 32; + + kernel::TensorTransformReduceFinalize< + ComputeType, ReduceOp, kFinalizeBlockSize + ><<< dim3(1, 1), dim3(kFinalizeBlockSize, 1), 0, stream >>>( + workspace, identity, workspace_size, reduce + ); + + cudaStreamSynchronize(stream); + + if (copy_out) { + cudaError_t result = cudaMemcpy(&identity, workspace, sizeof(identity), cudaMemcpyDeviceToHost); + if (result != cudaSuccess) { + throw std::runtime_error("cudaMemcpy() failed"); + } + } + + return identity; +} + +/// Transform-reduce operation over the elements of two tensors, zipped together +template < + typename Element, + typename Layout, + typename ComputeType, + typename ReduceOp, + typename TransformOp +> +ComputeType TensorTransformReduce( + TensorView view_A, /// View of the tensor to reduce over + TensorView view_B, /// View of the tensor to reduce over + ComputeType identity, /// Identity element of the reduction operation + ReduceOp reduce, /// Reduces an accumulated value with a transformed element: f(ComputeType, ComputeType) => ComputeType + TransformOp transform, /// Transforms the tensor element to ComputeType: g(Element) => ComputeType + ComputeType *workspace, /// Device-side workspace for accumulating partial results. The reduced element is stored in workspace[0] + int workspace_size, /// Number of elements in workspace + cudaStream_t stream = nullptr, /// CUDA stream to launch into + bool copy_out = true /// If true, the value of workspace[0] is copied to host and returned. Otherwise, `identity` is returned. +) { + + if (view_A.extent() != view_B.extent()) { + throw std::runtime_error("Extents must be equal."); + } + + int const kBlockSize = 128; + + dim3 block(kBlockSize, 1); + dim3 grid(workspace_size, 1); + + kernel::TensorTransformReducePartial< + Element, Layout, ComputeType, ReduceOp, TransformOp, kBlockSize + ><<< grid, block, 0, stream >>>( + view_A, view_B, identity, reduce, transform, workspace + ); + + int const kFinalizeBlockSize = 32; + + kernel::TensorTransformReduceFinalize< + ComputeType, ReduceOp, kFinalizeBlockSize + ><<< dim3(1, 1), dim3(kFinalizeBlockSize, 1), 0, stream >>>( + workspace, identity, workspace_size, reduce + ); + + cudaStreamSynchronize(stream); + + if (copy_out) { + cudaError_t result = cudaMemcpy(&identity, workspace, sizeof(identity), cudaMemcpyDeviceToHost); + if (result != cudaSuccess) { + throw std::runtime_error("cudaMemcpy() failed"); + } + } + + return identity; +} + +/// Transform-reduce operation over the elements of a tensor. This helper allocates the device-side +/// workspace +template < + typename Element, + typename Layout, + typename ComputeType, + typename ReduceOp, + typename TransformOp +> +ComputeType TensorTransformReduce( + TensorView view, + ComputeType identity, + ReduceOp reduce, + TransformOp transform, + cudaStream_t stream = nullptr, + int workspace_size = 0 +) { + + // Optionally query for the SM count to size the workspace. 
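+  // Each thread block of the partial-reduction kernel writes one partial result into
+  // `workspace`, and the finalize kernel folds those partials into workspace[0], so the
+  // workspace size doubles as the grid size. Defaulting it to the SM count therefore
+  // launches roughly one block per SM.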
+ if (!workspace_size) { + + int device_idx = 0; + cudaDeviceProp prop; + + cudaError_t result = cudaGetDevice(&device_idx); + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() failed"); + } + + result = cudaGetDeviceProperties(&prop, device_idx); + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProp() failed"); + } + + workspace_size = int(prop.multiProcessorCount); + } + + DeviceAllocation workspace(workspace_size); + + ComputeType output = TensorTransformReduce( + view, + identity, + reduce, + transform, + workspace.get(), + workspace_size, + stream, + true); + + return output; +} + + +/// Transform-reduce operation over the elements of a tensor. This helper allocates the device-side +/// workspace +template < + typename Element, + typename Layout, + typename ComputeType, + typename ReduceOp, + typename TransformOp +> +ComputeType TensorTransformReduce( + TensorView view_A, + TensorView view_B, + ComputeType identity, + ReduceOp reduce, + TransformOp transform, + cudaStream_t stream = nullptr, + int workspace_size = 0 +) { + + // Optionally query for the SM count to size the workspace. + if (!workspace_size) { + + int device_idx = 0; + cudaDeviceProp prop; + + cudaError_t result = cudaGetDevice(&device_idx); + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() failed"); + } + + result = cudaGetDeviceProperties(&prop, device_idx); + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProp() failed"); + } + + workspace_size = int(prop.multiProcessorCount); + } + + DeviceAllocation workspace(workspace_size); + + ComputeType output = TensorTransformReduce( + view_A, + view_B, + identity, + reduce, + transform, + workspace.get(), + workspace_size, + stream, + true); + + return output; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to compute the sum of the elements of a tensor +template < + typename Element, + typename Layout, + typename ComputeType = Element +> +ComputeType TensorSum( + TensorView view, + ComputeType identity = ComputeType(), + cudaStream_t stream = nullptr, + int workspace_size = 0 +) { + + plus reduce; + NumericConverter transform; + + return TensorTransformReduce( + view, identity, reduce, transform, stream, workspace_size); +} + +/// Helper to compute the sum of the squares of the elements of a tensor +template < + typename Element, + typename Layout, + typename ComputeType = Element +> +ComputeType TensorSumSq( + TensorView view, + ComputeType identity = ComputeType(), + cudaStream_t stream = nullptr, + int workspace_size = 0 +) { + + plus reduce; + magnitude_squared transform; + + return TensorTransformReduce( + view, identity, reduce, transform, stream, workspace_size); +} + +/// Helper to compute the norm of the elements of a tensor. 
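Before the norm helper declared just below, a short sketch of how these device-side reductions are typically combined when checking a kernel's output against a reference; the `relative_error` wrapper, the double-precision compute type, and the calling convention are illustrative assumptions, not part of this header:

template <typename Element, typename Layout>
double relative_error(
    cutlass::TensorView<Element, Layout> test,   // device-side tensor produced by the kernel under test
    cutlass::TensorView<Element, Layout> ref,    // device-side reference tensor
    cudaStream_t stream = nullptr) {

  // ||test - ref|| and ||ref||, both reduced entirely on the device in double precision.
  double diff_norm = cutlass::reference::device::TensorNormDiff<Element, Layout, double>(test, ref, 0.0, stream);
  double ref_norm  = cutlass::reference::device::TensorNorm<Element, Layout, double>(ref, 0.0, stream);

  return ref_norm > 0.0 ? diff_norm / ref_norm : diff_norm;
}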
+template < + typename Element, + typename Layout, + typename ComputeType = double +> +ComputeType TensorNorm( + TensorView view, + ComputeType identity = ComputeType(), + cudaStream_t stream = nullptr, + int workspace_size = 0 +) { + + return std::sqrt(TensorSumSq(view, identity, stream, workspace_size)); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to compute the sum of the squares of the differences of two tensors +template < + typename Element, + typename Layout, + typename ComputeType = double +> +ComputeType TensorSumSqDiff( + TensorView view_A, + TensorView view_B, + ComputeType identity = ComputeType(), + cudaStream_t stream = nullptr, + int workspace_size = 0 +) { + + plus reduce; + magnitude_squared_difference transform; + + return TensorTransformReduce( + view_A, view_B, identity, reduce, transform, stream, workspace_size); +} + + +/// Helper to compute the norm of the tensor computed as the difference of two tensors in memory +template < + typename Element, + typename Layout, + typename ComputeType = double +> +ComputeType TensorNormDiff( + TensorView view_A, + TensorView view_B, + ComputeType identity = ComputeType(), + cudaStream_t stream = nullptr, + int workspace_size = 0 +) { + + return std::sqrt(TensorSumSqDiff(view_A, view_B, identity, stream, workspace_size)); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace reference +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_relu.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_relu.h new file mode 100644 index 0000000000000000000000000000000000000000..0e3d99ddf845810249f909fbdee4505a0a732c4f --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/tensor_relu.h @@ -0,0 +1,141 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines device-side elementwise operations on TensorView. Note, the operations defined + in this header are not specialized for any particular data layout and are therefore not + intended to offer the best possible performance. Rather, they are intended to be generic + reference implementations to support the CUTLASS unit tests. +*/ + +#pragma once + +// Cutlass includes +#include "cutlass/cutlass.h" +#include "cutlass/tensor_view.h" + +#include "cutlass/util/reference/device/tensor_foreach.h" + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace reference { +namespace device { + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorReLuFunc { + + /// View type + using TensorView = TensorView; + + /// Coordinate in tensor's index space + using TensorCoord = typename TensorView::TensorCoord; + + /// Parameters structure + struct Params { + + // + // Data members + // + + TensorView view; + Element threshold; + + + // + // Methods + // + + Params( + TensorView view_ = TensorView(), + Element threshold_ = Element(0) + ): + view(view_), threshold(threshold_) { + + } + }; + + // + // Data members + // + + Params params; + + // + // Methods + // + + CUTLASS_DEVICE + TensorReLuFunc(Params const ¶ms): params(params) { + + } + + CUTLASS_DEVICE + void operator()(TensorCoord const &coord) { + + Element const & value = params.view.at(coord); + params.view.at(coord) = (value < params.threshold) ? 
params.threshold : value; + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Apply ReLu on a tensor +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorReLu( + TensorView view, ///< destination tensor + Element threshold = Element(0)) { ///< ReLu threshold + + using Func = detail::TensorReLuFunc; + using Params = typename Func::Params; + + TensorForEach( + view.extent(), + Params(view, threshold) + ); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace reference +} // namespace cutlass diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/thread/gemm.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/thread/gemm.h new file mode 100644 index 0000000000000000000000000000000000000000..dd11f96bd92f6995590e61665e41a3e830bceacd --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/device/thread/gemm.h @@ -0,0 +1,186 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for GEMM in host-side code. 
+*/ + +#pragma once + +#include "cutlass/coord.h" +#include "cutlass/tensor_view.h" +#include "cutlass/gemm/gemm.h" + +namespace cutlass { +namespace reference { +namespace device { +namespace thread { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Thread-level blocked general matrix product. +// +// Note, this is a reference implementation. Performance is not expected to approach peak. +// +template < + typename TensorRefA, + typename TensorRefB, + typename TensorRefC, + typename ScalarType, + typename AccumulatorType, + typename OutputTile, + typename InnerProductOp = multiply_add, + typename ConvertOp = NumericConverter +> +struct Gemm { + + using ElementA = typename TensorRefA::Element; + using ElementB = typename TensorRefB::Element; + using ElementC = typename TensorRefC::Element; + + // + // Data members + // + + /// Tile for A operand + ElementA A_tile[OutputTile::kColumn]; + + /// Tile for B operand + ElementB B_tile[OutputTile::kRow]; + + /// Tile for Accumulator + AccumulatorType accum[OutputTile::kColumn][OutputTile::kRow]; + + // + // Methods + // + + /// Constructor + CUTLASS_HOST_DEVICE + Gemm(AccumulatorType initial_accum = AccumulatorType(0)) { + + // Clear fetch registers + for (int i = 0; i < OutputTile::kColumn; ++i) { + A_tile[i] = ElementA(0); + } + + for (int j = 0; j < OutputTile::kRow; ++j) { + B_tile[j] = ElementB(0); + } + + // Clear accumulators + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < OutputTile::kColumn; ++j) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < OutputTile::kRow; ++i) { + accum[j][i] = initial_accum; + } + } + } + + /// Computes a matrix product + CUTLASS_HOST_DEVICE + Gemm & multiply_add( + gemm::GemmCoord problem_size, + TensorRefA tensor_a, + TensorRefB tensor_b, + MatrixCoord output_coord = MatrixCoord()) { + + InnerProductOp inner_product_op; + + // Loop over the GEMM K dimension + CUTLASS_PRAGMA_NO_UNROLL + for (int k = 0; k < problem_size.k(); ++k) { + + // Fetch a slice of the A matrix + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < OutputTile::kColumn; ++i) { + if (output_coord.row() + i < problem_size.m()) { + A_tile[i] = tensor_a.at(make_Coord(output_coord.row() + i, k)); + } + } + + // Fetch a slice of the B matrix + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < OutputTile::kRow; ++j) { + if (output_coord.column() + j < problem_size.n()) { + B_tile[j] = tensor_b.at(make_Coord(k, output_coord.column() + j)); + } + } + + // Compute an accumulated matrix product + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < OutputTile::kRow; ++j) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < OutputTile::kColumn; ++i) { + accum[j][i] = inner_product_op(A_tile[i], B_tile[j], accum[j][i]); + } + } + } + + return *this; + } + + /// Performs linear scaling of matrix product and updates output tensor + CUTLASS_HOST_DEVICE + Gemm & epilogue( + gemm::GemmCoord problem_size, + ScalarType alpha, + ScalarType beta, + TensorRefC tensor_c, + TensorRefC tensor_d, + MatrixCoord output_coord = MatrixCoord()) { + + ConvertOp convert_op; + + // Update the output tensor + for (int j = 0; j < OutputTile::kRow; ++j) { + for (int i = 0; i < OutputTile::kColumn; ++i) { + MatrixCoord coord = output_coord + MatrixCoord(i, j); + if (coord.row() < problem_size.m() && coord.column() < problem_size.n()) { + + tensor_d.at(coord) = convert_op( + alpha * ScalarType(accum[j][i]) + + beta * ScalarType(tensor_c.at(coord)) + ); + } + } + } + + return *this; + } +}; + 
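As an aside, here is a minimal sketch of how this thread-level reference is typically driven for a single output tile; the float element types, ColumnMajor layouts, 4x4 tile shape, and the `reference_gemm_tile` wrapper are assumptions chosen for illustration, not part of the header:

// Illustrative only: drive the reference GEMM above for one 4x4 tile of D.
#include "cutlass/layout/matrix.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/util/reference/device/thread/gemm.h"

using RefA = cutlass::TensorRef<float, cutlass::layout::ColumnMajor>;
using RefC = cutlass::TensorRef<float, cutlass::layout::ColumnMajor>;
using Tile = cutlass::MatrixShape<4, 4>;   // OutputTile: accumulates a 4x4 block of D

CUTLASS_DEVICE
void reference_gemm_tile(
    cutlass::gemm::GemmCoord problem_size,
    RefA ref_A, RefA ref_B, RefC ref_C, RefC ref_D,
    cutlass::MatrixCoord tile_origin,
    float alpha, float beta) {

  cutlass::reference::device::thread::Gemm<
      RefA, RefA, RefC, float, float, Tile> gemm;

  // Accumulate A * B over the K dimension for this tile, then apply
  // D = alpha * accum + beta * C to the in-bounds elements of the tile.
  gemm.multiply_add(problem_size, ref_A, ref_B, tile_origin)
      .epilogue(problem_size, alpha, beta, ref_C, ref_D, tile_origin);
}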
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace thread +} // namespace device +} // namespace reference +} // namespace cutlass diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/conv.hpp b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/conv.hpp new file mode 100644 index 0000000000000000000000000000000000000000..57443325629ea4e5d855fe18f94c73b10a71a73a --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/conv.hpp @@ -0,0 +1,782 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for CONV in host-side code. 
+*/ +#pragma once + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#include "cutlass/complex.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/epilogue/thread/activation.h" + +#include "cute/tensor.hpp" + +#include + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::reference::host { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template +bool +is_activation_in_bounds( + cute::Tensor const& activation, + int32_t n_, int32_t d_, int32_t h_, int32_t w_, int32_t c_, int32_t g_) { + return ((g_ >= 0 && g_ < size<5>(activation)) && + (n_ >= 0 && n_ < size<4>(activation)) && + (d_ >= 0 && d_ < size<3>(activation)) && + (h_ >= 0 && h_ < size<2>(activation)) && + (w_ >= 0 && w_ < size<1>(activation)) && + (c_ >= 0 && c_ < size<0>(activation))); +} + +template +bool +is_activation_in_bounds( + cute::Tensor const& activation, + int32_t n_, int32_t h_, int32_t w_, int32_t c_, int32_t g_) { + return ((g_ >= 0 && g_ < size<4>(activation)) && + (n_ >= 0 && n_ < size<3>(activation)) && + (h_ >= 0 && h_ < size<2>(activation)) && + (w_ >= 0 && w_ < size<1>(activation)) && + (c_ >= 0 && c_ < size<0>(activation))); +} + +template +bool +is_activation_in_bounds( + cute::Tensor const& activation, + int32_t n_, int32_t w_, int32_t c_, int32_t g_) { + return ((g_ >= 0 && g_ < size<3>(activation)) && + (n_ >= 0 && n_ < size<2>(activation)) && + (w_ >= 0 && w_ < size<1>(activation)) && + (c_ >= 0 && c_ < size<0>(activation))); +} + +} // namespace detail + +template< + class ElementAcc_, + class ElementScalar_, + class ElementCompute_, + class ElementC_, + class ElementOut_, + bool ResidualAdd_, + class TensorAlpha_, + class TensorBeta_, + class TensorBias_, + class ActivationFunctor_ = cutlass::epilogue::thread::Identity +> +struct ConvEpilogueFusionParams { + using ElementAcc = ElementAcc_; + using ElementScalar = ElementScalar_; + using ElementCompute = ElementCompute_; + using ElementC = ElementC_; + using ElementOut = ElementOut_; + using TensorAlpha = TensorAlpha_; + using TensorBeta = TensorBeta_; + using TensorBias = TensorBias_; + using ActivationFunctor = ActivationFunctor_; + static constexpr bool ResidualAdd = ResidualAdd_; // Source added after activation + + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + + TensorAlpha tensor_alpha{}; + TensorBeta tensor_beta{}; + TensorBias tensor_bias{}; +}; + +template< + cutlass::conv::Operator ConvOp, + int NumSpatialDims, + class TensorA, + class TensorB, + class TensorC, + class TensorD, + class ShapePadding, + class StrideTraversal, + class ShapeDilation, + class EpilogueFusionParams +> +struct ConvReferenceImpl { + // Hard code accumlulator type to float to avoid data lost in accumulating add. 
+ using ElementAcc = cutlass::platform::conditional_t, double, float>; + using ElementC = typename EpilogueFusionParams::ElementC; + using ElementOut = typename EpilogueFusionParams::ElementOut; + using ElementScalar = typename EpilogueFusionParams::ElementScalar; + using ElementCompute = typename EpilogueFusionParams::ElementCompute; + using ElementBias = typename EpilogueFusionParams::TensorBias::value_type; + using ActivationFunctor = typename EpilogueFusionParams::ActivationFunctor; + + // Input related converter + NumericConverter acc_converter; + NumericConverter residual_converter; + NumericConverter bias_converter; + // Scale related converter + NumericConverter scale_converter; + // Output related converter + NumericConverter output_converter; + + EpilogueFusionParams& epi_fusion_params_; + TensorA const& tensor_a_; + TensorB const& tensor_b_; + TensorC const& tensor_c_; + TensorD& tensor_d_; + + ShapePadding const& padding_; + StrideTraversal const& tstride_; + ShapeDilation const& dilation_; + + // Epilogue activation operation + ActivationFunctor epi_activation; + + ConvReferenceImpl( + TensorA const& tensor_a, + TensorB const& tensor_b, + TensorC const& tensor_c, + TensorD& tensor_d, + ShapePadding const& padding, + StrideTraversal const& tstride, + ShapeDilation const& dilation, + EpilogueFusionParams& epi_fusion_params) + : tensor_a_(tensor_a), + tensor_b_(tensor_b), + tensor_c_(tensor_c), + tensor_d_(tensor_d), + padding_(padding), + tstride_(tstride), + dilation_(dilation), + epi_fusion_params_(epi_fusion_params) + { + static_assert(rank(ShapePadding{}) == rank(ShapeDilation{})); + static_assert(rank(ShapePadding{}) == rank(StrideTraversal{})); + } + + void compute_reference() { + if constexpr (ConvOp == cutlass::conv::Operator::kFprop) { + fprop_reference(cute::Int{}); + } + else if constexpr (ConvOp == cutlass::conv::Operator::kDgrad) { + dgrad_reference(cute::Int{}); + } + else { + wgrad_reference(cute::Int{}); + } + } + +private: + // Specialization for 1D fprop kernel + void fprop_reference(cute::Int<1> spatial_dims) { + int32_t G = size<3>(tensor_d_); + int32_t N = size<2>(tensor_d_); + int32_t Q = size<1>(tensor_d_); + int32_t K = size<0>(tensor_d_); + int32_t S = size<1>(tensor_b_); + int32_t C = size<0>(tensor_b_); + +#if defined(_OPENMP) + #pragma omp parallel for collapse(2) +#endif + for (int32_t g = 0; g < G; ++g) { + for (int32_t n = 0; n < N; ++n) { + for (int32_t q = 0; q < Q; ++q) { + for (int32_t k = 0; k < K; ++k) { + auto accumulator = ElementAcc(0); + for (int32_t s = 0; s < S; ++s) { + for (int32_t c = 0; c < C; ++c) { + int32_t w = q * cute::get<0>(tstride_) - cute::get<0>(padding_) + s * cute::get<0>(dilation_); + if (detail::is_activation_in_bounds(tensor_a_, n, w, c, g)) { + auto a = tensor_a_(c, w, n, g); + auto b = tensor_b_(c, s, k, g); + accumulator += ElementAcc(a * b); + } + } + } + ElementScalar alpha = raw_pointer_cast(epi_fusion_params_.tensor_alpha.data()) ? + epi_fusion_params_.tensor_alpha[k] : epi_fusion_params_.alpha; + ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data()) ? 
+ epi_fusion_params_.tensor_beta[k] : epi_fusion_params_.beta; + ElementCompute output = scale_converter(alpha) * acc_converter(accumulator); + if (not EpilogueFusionParams::ResidualAdd) { + output += scale_converter(beta) * residual_converter(tensor_c_(k, q, n, g)); + } + if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) { + output += bias_converter(epi_fusion_params_.tensor_bias[k]); + } + output = epi_activation(output); + if (EpilogueFusionParams::ResidualAdd) { + output += scale_converter(beta) * residual_converter(tensor_c_(k, q, n, g)); + } + tensor_d_(k, q, n, g) = output_converter(output); + } + } + } + } + + } + + // Specialization for 2D fprop kernel + void fprop_reference(cute::Int<2> spatial_dims) { + int32_t G = size<4>(tensor_d_); + int32_t N = size<3>(tensor_d_); + int32_t P = size<2>(tensor_d_); + int32_t Q = size<1>(tensor_d_); + int32_t K = size<0>(tensor_d_); + int32_t R = size<2>(tensor_b_); + int32_t S = size<1>(tensor_b_); + int32_t C = size<0>(tensor_b_); + +#if defined(_OPENMP) + #pragma omp parallel for collapse(3) +#endif + for (int32_t g = 0; g < G; ++g) { + for (int32_t n = 0; n < N; ++n) { + for (int32_t p = 0; p < P; ++p) { + for (int32_t q = 0; q < Q; ++q) { + for (int32_t k = 0; k < K; ++k) { + auto accumulator = ElementAcc(0); + for (int32_t r = 0; r < R; ++r) { + for (int32_t s = 0; s < S; ++s) { + for (int32_t c = 0; c < C; ++c) { + int32_t w = q * cute::get<0>(tstride_) - cute::get<0>(padding_) + s * cute::get<0>(dilation_); + int32_t h = p * cute::get<1>(tstride_) - cute::get<1>(padding_) + r * cute::get<1>(dilation_); + if (detail::is_activation_in_bounds(tensor_a_, n, h, w, c, g)) { + auto a = tensor_a_(c, w, h, n, g); + auto b = tensor_b_(c, s, r, k, g); + accumulator += ElementAcc(a * b); + } + } + } + } + ElementScalar alpha = raw_pointer_cast(epi_fusion_params_.tensor_alpha.data()) ? + epi_fusion_params_.tensor_alpha[k] : epi_fusion_params_.alpha; + ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data()) ? 
+ epi_fusion_params_.tensor_beta[k] : epi_fusion_params_.beta; + ElementCompute output = scale_converter(alpha) * acc_converter(accumulator); + if (not EpilogueFusionParams::ResidualAdd) { + output += scale_converter(beta) * residual_converter(tensor_c_(k, q, p, n, g)); + } + if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) { + output += bias_converter(epi_fusion_params_.tensor_bias[k]); + } + output = epi_activation(output); + if (EpilogueFusionParams::ResidualAdd) { + output += scale_converter(beta) * residual_converter(tensor_c_(k, q, p, n, g)); + } + tensor_d_(k, q, p, n, g) = output_converter(output); + } + } + } + } + } + + } + + // Specialization for 3D fprop kernel + void fprop_reference(cute::Int<3> spatial_dims) { + int32_t G = size<5>(tensor_d_); + int32_t N = size<4>(tensor_d_); + int32_t Z = size<3>(tensor_d_); + int32_t P = size<2>(tensor_d_); + int32_t Q = size<1>(tensor_d_); + int32_t K = size<0>(tensor_d_); + int32_t T = size<3>(tensor_b_); + int32_t R = size<2>(tensor_b_); + int32_t S = size<1>(tensor_b_); + int32_t C = size<0>(tensor_b_); + +#if defined(_OPENMP) + #pragma omp parallel for collapse(3) +#endif + for (int32_t g = 0; g < G; ++g) { + for (int32_t n = 0; n < N; ++n) { + for (int32_t z = 0; z < Z; ++z) { + for (int32_t p = 0; p < P; ++p) { + for (int32_t q = 0; q < Q; ++q) { + for (int32_t k = 0; k < K; ++k) { + auto accumulator = ElementAcc(0); + for (int32_t t = 0; t < T; ++t) { + for (int32_t r = 0; r < R; ++r) { + for (int32_t s = 0; s < S; ++s) { + for (int32_t c = 0; c < C; ++c) { + int32_t w = q * cute::get<0>(tstride_) - cute::get<0>(padding_) + s * cute::get<0>(dilation_); + int32_t h = p * cute::get<1>(tstride_) - cute::get<1>(padding_) + r * cute::get<1>(dilation_); + int32_t d = z * cute::get<2>(tstride_) - cute::get<2>(padding_) + t * cute::get<2>(dilation_); + if (detail::is_activation_in_bounds(tensor_a_, n, d, h, w, c, g)) { + auto a = tensor_a_(c, w, h, d, n, g); + auto b = tensor_b_(c, s, r, t, k, g); + accumulator += ElementAcc(a * b); + } + } + } + } + } + ElementScalar alpha = raw_pointer_cast(epi_fusion_params_.tensor_alpha.data()) ? + epi_fusion_params_.tensor_alpha[k] : epi_fusion_params_.alpha; + ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data()) ? 
+ epi_fusion_params_.tensor_beta[k] : epi_fusion_params_.beta; + ElementCompute output = scale_converter(alpha) * acc_converter(accumulator); + if (not EpilogueFusionParams::ResidualAdd) { + output += scale_converter(beta) * residual_converter(tensor_c_(k, q, p, z, n, g)); + } + if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) { + output += bias_converter(epi_fusion_params_.tensor_bias[k]); + } + output = epi_activation(output); + if (EpilogueFusionParams::ResidualAdd) { + output += scale_converter(beta) * residual_converter(tensor_c_(k, q, p, z, n, g)); + } + tensor_d_(k, q, p, z, n, g) = output_converter(output); + } + } + } + } + } + } + + } + + // Specialization for 1D dgrad kernel + void dgrad_reference(cute::Int<1> spatial_dims) { + int32_t G = size<3>(tensor_d_); + int32_t N = size<2>(tensor_d_); + int32_t W = size<1>(tensor_d_); + int32_t C = size<0>(tensor_d_); + int32_t K = size<2>(tensor_b_); + int32_t S = size<1>(tensor_b_); + +#if defined(_OPENMP) + #pragma omp parallel for collapse(2) +#endif + for (int32_t g = 0; g < G; ++g) { + for (int32_t n = 0; n < N; ++n) { + for (int32_t w = 0; w < W; ++w) { + for (int32_t c = 0; c < C; ++c) { + auto accumulator = ElementAcc(0); + for (int32_t k = 0; k < K; ++k) { + for (int32_t s = 0; s < S; ++s) { + int32_t q = w + cute::get<0>(padding_) - s * cute::get<0>(dilation_); + + if (q % cute::get<0>(tstride_) == 0) { + q /= cute::get<0>(tstride_); + } else { + continue; + } + + if (detail::is_activation_in_bounds(tensor_a_, n, q, k, g)) { + accumulator += ElementAcc(tensor_a_(k, q, n, g) * tensor_b_(c, s, k, g)); + } + } + } + ElementScalar alpha = raw_pointer_cast(epi_fusion_params_.tensor_alpha.data()) + ? epi_fusion_params_.tensor_alpha[c] : epi_fusion_params_.alpha; + ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data()) + ? 
epi_fusion_params_.tensor_beta[c] : epi_fusion_params_.beta; + ElementCompute output = scale_converter(alpha) * acc_converter(accumulator); + if (not EpilogueFusionParams::ResidualAdd) { + output += scale_converter(beta) * residual_converter(tensor_c_(c, w, n, g)); + } + if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) { + output += bias_converter(epi_fusion_params_.tensor_bias[c]); + } + output = epi_activation(output); + if (EpilogueFusionParams::ResidualAdd) { + output += scale_converter(beta) * residual_converter(tensor_c_(c, w, n, g)); + } + tensor_d_(c, w, n, g) = output_converter(output); + } + } + } + } + + } + + // Specialization for 2D dgrad kernel + void dgrad_reference(cute::Int<2> spatial_dims) { + int32_t G = size<4>(tensor_d_); + int32_t N = size<3>(tensor_d_); + int32_t H = size<2>(tensor_d_); + int32_t W = size<1>(tensor_d_); + int32_t C = size<0>(tensor_d_); + int32_t K = size<3>(tensor_b_); + int32_t R = size<2>(tensor_b_); + int32_t S = size<1>(tensor_b_); + +#if defined(_OPENMP) + #pragma omp parallel for collapse(3) +#endif + for (int32_t g = 0; g < G; ++g) { + for (int32_t n = 0; n < N; ++n) { + for (int32_t h = 0; h < H; ++h) { + for (int32_t w = 0; w < W; ++w) { + for (int32_t c = 0; c < C; ++c) { + auto accumulator = ElementAcc(0); + for (int32_t k = 0; k < K; ++k) { + for (int32_t r = 0; r < R; ++r) { + for (int32_t s = 0; s < S; ++s) { + int32_t q = w + cute::get<0>(padding_) - s * cute::get<0>(dilation_); + int32_t p = h + cute::get<1>(padding_) - r * cute::get<1>(dilation_); + + if (q % cute::get<0>(tstride_) == 0) { + q /= cute::get<0>(tstride_); + } else { + continue; + } + + if (p % cute::get<1>(tstride_) == 0) { + p /= cute::get<1>(tstride_); + } else { + continue; + } + + if (detail::is_activation_in_bounds(tensor_a_, n, p, q, k, g)) { + accumulator += ElementAcc(tensor_a_(k, q, p, n, g) * tensor_b_(c, s, r, k, g)); + } + } + } + } + ElementScalar alpha = raw_pointer_cast(epi_fusion_params_.tensor_alpha.data()) + ? epi_fusion_params_.tensor_alpha[c] : epi_fusion_params_.alpha; + ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data()) + ? 
epi_fusion_params_.tensor_beta[c] : epi_fusion_params_.beta; + ElementCompute output = scale_converter(alpha) * acc_converter(accumulator); + if (not EpilogueFusionParams::ResidualAdd) { + output += scale_converter(beta) * residual_converter(tensor_c_(c, w, h, n, g)); + } + if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) { + output += bias_converter(epi_fusion_params_.tensor_bias[c]); + } + output = epi_activation(output); + if (EpilogueFusionParams::ResidualAdd) { + output += scale_converter(beta) * residual_converter(tensor_c_(c, w, h, n, g)); + } + + tensor_d_(c, w, h, n, g) = output_converter(output); + } + } + } + } + } + + } + + // Specialization for 3D dgrad kernel + void dgrad_reference(cute::Int<3> spatial_dims) { + int32_t G = size<5>(tensor_d_); + int32_t N = size<4>(tensor_d_); + int32_t D = size<3>(tensor_d_); + int32_t H = size<2>(tensor_d_); + int32_t W = size<1>(tensor_d_); + int32_t C = size<0>(tensor_d_); + int32_t K = size<4>(tensor_b_); + int32_t T = size<3>(tensor_b_); + int32_t R = size<2>(tensor_b_); + int32_t S = size<1>(tensor_b_); + +#if defined(_OPENMP) + #pragma omp parallel for collapse(3) +#endif + for (int32_t g = 0; g < G; ++g) { + for (int32_t n = 0; n < N; ++n) { + for (int32_t d = 0; d < D; ++d) { + for (int32_t h = 0; h < H; ++h) { + for (int32_t w = 0; w < W; ++w) { + for (int32_t c = 0; c < C; ++c) { + auto accumulator = ElementAcc(0); + for (int32_t k = 0; k < K; ++k) { + for (int32_t t = 0; t < T; ++t) { + for (int32_t r = 0; r < R; ++r) { + for (int32_t s = 0; s < S; ++s) { + int32_t q = w + cute::get<0>(padding_) - s * cute::get<0>(dilation_); + int32_t p = h + cute::get<1>(padding_) - r * cute::get<1>(dilation_); + int32_t z = d + cute::get<2>(padding_) - t * cute::get<2>(dilation_); + + if (q % cute::get<0>(tstride_) == 0) { + q /= cute::get<0>(tstride_); + } else { + continue; + } + + if (p % cute::get<1>(tstride_) == 0) { + p /= cute::get<1>(tstride_); + } else { + continue; + } + + if (z % cute::get<2>(tstride_) == 0) { + z /= cute::get<2>(tstride_); + } else { + continue; + } + + if (detail::is_activation_in_bounds(tensor_a_, n, z, p, q, k, g)) { + accumulator += ElementAcc(tensor_a_(k, q, p, z, n, g) * tensor_b_(c, s, r, t, k, g)); + } + } + } + } + } + ElementScalar alpha = raw_pointer_cast(epi_fusion_params_.tensor_alpha.data()) + ? epi_fusion_params_.tensor_alpha[c] : epi_fusion_params_.alpha; + ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data()) + ? 
epi_fusion_params_.tensor_beta[c] : epi_fusion_params_.beta; + ElementCompute output = scale_converter(alpha) * acc_converter(accumulator); + if (not EpilogueFusionParams::ResidualAdd) { + output += scale_converter(beta) * residual_converter(tensor_c_(c, w, h, d, n, g)); + } + if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) { + output += bias_converter(epi_fusion_params_.tensor_bias[c]); + } + output = epi_activation(output); + if (EpilogueFusionParams::ResidualAdd) { + output += scale_converter(beta) * residual_converter(tensor_c_(c, w, h, d, n, g)); + } + tensor_d_(c, w, h, d, n, g) = output_converter(output); + } + } + } + } + } + } + + } + + // Specialization for 1D wgrad kernel + void wgrad_reference(cute::Int<1> spatial_dims) { + int32_t G = size<3>(tensor_d_); + int32_t N = + size<2>(tensor_a_); + int32_t Q = + size<1>(tensor_a_); + int32_t K = + size<0>(tensor_a_); + int32_t S = size<1>(tensor_d_); + int32_t C = size<0>(tensor_d_); + +#if defined(_OPENMP) + #pragma omp parallel for collapse(2) +#endif + for (int32_t g = 0; g < G; ++g) { + for (int32_t k = 0; k < K; ++k) { + for (int32_t s = 0; s < S; ++s) { + for (int32_t c = 0; c < C; ++c) { + auto accumulator = ElementAcc(0); + for (int32_t n = 0; n < N; ++n) { + for (int32_t q = 0; q < Q; ++q) { + int32_t w = q * cute::get<0>(tstride_) - cute::get<0>(padding_) + s * cute::get<0>(dilation_); + bool is_in_bounds = + detail::is_activation_in_bounds(tensor_b_, n, w, c, g); + if (is_in_bounds) { + auto act = + tensor_b_(c, w, n, g); + auto xformed_act = + tensor_a_(k, q, n, g); + accumulator += ElementAcc(act * xformed_act); + } + } + } + + ElementScalar alpha = raw_pointer_cast(epi_fusion_params_.tensor_alpha.data()) ? + epi_fusion_params_.tensor_alpha[c] : epi_fusion_params_.alpha; + ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data()) ? 
+ epi_fusion_params_.tensor_beta[c] : epi_fusion_params_.beta; + + ElementCompute output = scale_converter(alpha) * acc_converter(accumulator); + if (not EpilogueFusionParams::ResidualAdd) { + output += scale_converter(beta) * residual_converter(tensor_c_(c, s, k, g)); + } + if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) { + output += bias_converter(epi_fusion_params_.tensor_bias[c]); + } + output = epi_activation(output); + if (EpilogueFusionParams::ResidualAdd) { + output += scale_converter(beta) * residual_converter(tensor_c_(c, s, k, g)); + } + tensor_d_(c, s, k, g) = output_converter(output); + } + } + } + } + } + + // Specialization for 2D wgrad kernel + void wgrad_reference(cute::Int<2> spatial_dims) { + int32_t G = size<4>(tensor_d_); + int32_t N = + size<3>(tensor_a_); + int32_t P = + size<2>(tensor_a_); + int32_t Q = + size<1>(tensor_a_); + int32_t K = + size<0>(tensor_a_); + int32_t R = size<2>(tensor_d_); + int32_t S = size<1>(tensor_d_); + int32_t C = size<0>(tensor_d_); + +#if defined(_OPENMP) + #pragma omp parallel for collapse(3) +#endif + for (int32_t g = 0; g < G; ++g) { + for (int32_t k = 0; k < K; ++k) { + for (int32_t r = 0; r < R; ++r) { + for (int32_t s = 0; s < S; ++s) { + for (int32_t c = 0; c < C; ++c) { + auto accumulator = ElementAcc(0); + for (int32_t n = 0; n < N; ++n) { + for (int32_t p = 0; p < P; ++p) { + for (int32_t q = 0; q < Q; ++q) { + int32_t w = q * cute::get<0>(tstride_) - cute::get<0>(padding_) + s * cute::get<0>(dilation_); + int32_t h = p * cute::get<1>(tstride_) - cute::get<1>(padding_) + r * cute::get<1>(dilation_); + bool is_in_bounds = + detail::is_activation_in_bounds(tensor_b_, n, h, w, c, g); + if (is_in_bounds) { + auto act = + tensor_b_(c, w, h, n, g); + auto xformed_act = + tensor_a_(k, q, p, n, g); + accumulator += ElementAcc(act * xformed_act); + } + } + } + } + + ElementScalar alpha = raw_pointer_cast(epi_fusion_params_.tensor_alpha.data()) ? + epi_fusion_params_.tensor_alpha[c] : epi_fusion_params_.alpha; + ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data()) ? 
+ epi_fusion_params_.tensor_beta[c] : epi_fusion_params_.beta; + + ElementCompute output = scale_converter(alpha) * acc_converter(accumulator); + if (not EpilogueFusionParams::ResidualAdd) { + output += scale_converter(beta) * residual_converter(tensor_c_(c, s, r, k, g)); + } + if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) { + output += bias_converter(epi_fusion_params_.tensor_bias[c]); + } + output = epi_activation(output); + if (EpilogueFusionParams::ResidualAdd) { + output += scale_converter(beta) * residual_converter(tensor_c_(c, s, r, k, g)); + } + tensor_d_(c, s, r, k, g) = output_converter(output); + } + } + } + } + } + } + + // Specialization for 3D wgrad kernel + void wgrad_reference(cute::Int<3> spatial_dims) { + int32_t G = size<5>(tensor_d_); + int32_t N = + size<4>(tensor_a_); + int32_t Z = + size<3>(tensor_a_); + int32_t P = + size<2>(tensor_a_); + int32_t Q = + size<1>(tensor_a_); + int32_t K = + size<0>(tensor_a_); + int32_t T = size<3>(tensor_d_); + int32_t R = size<2>(tensor_d_); + int32_t S = size<1>(tensor_d_); + int32_t C = size<0>(tensor_d_); + +#if defined(_OPENMP) + #pragma omp parallel for collapse(3) +#endif + for (int32_t g = 0 ; g < G; ++g) { + for (int32_t k = 0; k < K; ++k) { + for (int32_t t = 0; t < T; ++t) { + for (int32_t r = 0; r < R; ++r) { + for (int32_t s = 0; s < S; ++s) { + for (int32_t c = 0; c < C; ++c) { + auto accumulator = ElementAcc(0); + for (int32_t n = 0; n < N; ++n) { + for (int32_t z = 0; z < Z; ++z) { + for (int32_t p = 0; p < P; ++p) { + for (int32_t q = 0; q < Q; ++q) { + int32_t w = q * cute::get<0>(tstride_) - cute::get<0>(padding_) + s * cute::get<0>(dilation_); + int32_t h = p * cute::get<1>(tstride_) - cute::get<1>(padding_) + r * cute::get<1>(dilation_); + int32_t d = z * cute::get<2>(tstride_) - cute::get<2>(padding_) + t * cute::get<2>(dilation_); + bool is_in_bounds = + detail::is_activation_in_bounds(tensor_b_, n, d, h, w, c, g); + if (is_in_bounds) { + auto act = + tensor_b_(c, w, h, d, n, g); + auto xformed_act = + tensor_a_(k, q, p, z, n, g); + accumulator += ElementAcc(act * xformed_act); + } + } + } + } + } + + ElementScalar alpha = raw_pointer_cast(epi_fusion_params_.tensor_alpha.data()) ? + epi_fusion_params_.tensor_alpha[c] : epi_fusion_params_.alpha; + ElementScalar beta = raw_pointer_cast(epi_fusion_params_.tensor_beta.data()) ? 
+ epi_fusion_params_.tensor_beta[c] : epi_fusion_params_.beta; + + ElementCompute output = scale_converter(alpha) * acc_converter(accumulator); + if (not EpilogueFusionParams::ResidualAdd) { + output += scale_converter(beta) * residual_converter(tensor_c_(c, s, r, t, k, g)); + } + if (raw_pointer_cast(epi_fusion_params_.tensor_bias.data())) { + output += bias_converter(epi_fusion_params_.tensor_bias[c]); + } + output = epi_activation(output); + if (EpilogueFusionParams::ResidualAdd) { + output += scale_converter(beta) * residual_converter(tensor_c_(c, s, r, t, k, g)); + } + tensor_d_(c, s, r, t, k, g) = output_converter(output); + } + } + } + } + } + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // cutlass::reference::host + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/convolution.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/convolution.h new file mode 100644 index 0000000000000000000000000000000000000000..73298e5794f0f2658ef18fb3f46466c400fc831e --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/convolution.h @@ -0,0 +1,802 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Reference implementation for convolution in host-side code. 
+*/ + +#pragma once + +#include "cutlass/coord.h" +#include "cutlass/functional.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv2d_problem_size.h" +#include "cutlass/conv/conv3d_problem_size.h" +#include + +namespace cutlass { +namespace reference { +namespace host { + +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// Forward propagation +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// y = conv2d(x, w) +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = ElementCompute, + typename ElementD = ElementC, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add +> +void Conv2dFprop( + conv::Conv2dProblemSize problem_size, + TensorRef tensor_x, + TensorRef tensor_w, + TensorRef tensor_y_in, + TensorRef tensor_y_out, + ElementCompute alpha, + ElementCompute beta) { + + ConvertOp convert_op; + InnerProductOp inner_product_op; + + // Apply MMA and accumulate ElementAccumulator + for (int n = 0; n < problem_size.N; ++n) { + for (int p = 0; p < problem_size.P; ++p) { + for (int q = 0; q < problem_size.Q; ++q) { + for (int k = 0; k < problem_size.K; ++k) { + + int group_idx = k / (problem_size.K / problem_size.groups); + int channels_per_group = problem_size.C / problem_size.groups; + + ElementAccumulator acc = ElementAccumulator(); + + for (int r = 0; r < problem_size.R; ++r) { + for (int s = 0; s < problem_size.S; ++s) { + for (int c = 0; c < channels_per_group; ++c) { + + int filter_r = r; + int filter_s = s; + + if (problem_size.mode == cutlass::conv::Mode::kConvolution) { + filter_r = problem_size.R - 1 - r; + filter_s = problem_size.S - 1 - s; + } + + int h = p * problem_size.stride_h - problem_size.pad_h + filter_r * problem_size.dilation_h; + int w = q * problem_size.stride_w - problem_size.pad_w + filter_s * problem_size.dilation_w; + + if (h >= 0 && h < problem_size.H && w >= 0 && w < problem_size.W) { + + ElementA a = tensor_x.at({n, h, w, c + group_idx * channels_per_group}); + ElementB b = tensor_w.at({k, r, s, c}); + + acc = inner_product_op(ElementAccumulator(a), ElementAccumulator(b), acc); + + } + } + } + } + + // Apply Epilogue, compute ElementCompute, convert and store ElementC + ElementC c_ref = ElementC(); + + if (beta != ElementCompute()) { + c_ref = tensor_y_in.at(cutlass::make_Coord(n, p, q, k)); + } + + tensor_y_out.at(cutlass::make_Coord(n, p, q, k)) = + convert_op(alpha * ElementCompute(acc) + beta * ElementCompute(c_ref)); + } + } + } + } +} + +/// Depthwise-separable convolution +template , + typename InnerProductOp = multiply_add> +void Depsep_Fprop(cutlass::TensorView tensor_A, + cutlass::TensorView tensor_B, + cutlass::TensorView tensor_C, + cutlass::TensorView tensor_D, + ElementCompute alpha, + ElementCompute beta, + cutlass::Tensor4DCoord padding = cutlass::Tensor4DCoord(), + cutlass::Coord<2> conv_stride = cutlass::Coord<2>(), + cutlass::Coord<2> dilation = cutlass::Coord<2>(), + cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation) { + + ConvertOp convert_op; + InnerProductOp inner_product_op; + + // Apply MMA and accumulate ElementAccumulator + for (int n = 0; n < 
tensor_C.extent().n(); ++n) { + for (int p = 0; p < tensor_C.extent().h(); ++p) { + for (int q = 0; q < tensor_C.extent().w(); ++q) { + for (int g = 0; g < tensor_C.extent().c(); ++g) { + ElementAccumulator acc = ElementAccumulator(); + for (int r = 0; r < tensor_B.extent().h(); ++r) { + for (int s = 0; s < tensor_B.extent().w(); ++s) { + + // input activation H and W + int h = p * conv_stride[0] - padding[0] + r * dilation[0]; + int w = q * conv_stride[1] - padding[2] + s * dilation[1]; + + if (h < tensor_A.extent().h() && h >= 0 && w < tensor_A.extent().w() && w >= 0) { + ElementA a = tensor_A.at(cutlass::make_Coord(n, h, w, g)); + + ElementB b = (mode == cutlass::conv::Mode::kCrossCorrelation) + ? tensor_B.at(cutlass::make_Coord(g, r, s, 0)) + : tensor_B.at(cutlass::make_Coord( + g, tensor_B.extent().h() - r - 1, tensor_B.extent().w() - s - 1, 0)); + + acc = inner_product_op(ElementAccumulator(a), ElementAccumulator(b), acc); + } + } + } + + // Apply Epilogue, compute ElementCompute, convert and store ElementC + ElementC c_ref = tensor_C.at(cutlass::make_Coord(n, p, q, g)); + tensor_D.at(cutlass::make_Coord(n, p, q, g)) = + convert_op(alpha * ElementCompute(acc) + beta * ElementCompute(c_ref)); + } + } + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// Dgrad / Deconv +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// dx = dgrad(dy, w) +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = ElementCompute, + typename ElementD = ElementC, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add +> +void Conv2dDgrad( + cutlass::conv::Conv2dProblemSize problem_size, + TensorRef tensor_dy, + TensorRef tensor_w, + TensorRef tensor_dx_in, + TensorRef tensor_dx_out, + ElementCompute alpha, + ElementCompute beta, + bool is_deconv = false) { + + ConvertOp convert_op; + InnerProductOp inner_product_op; + + // Apply MMA and accumulate ElementAccumulator + for (int n = 0; n < problem_size.N; ++n) { + for (int h = 0; h < problem_size.H; ++h) { + for (int w = 0; w < problem_size.W; ++w) { + for (int c = 0; c < problem_size.C; ++c) { + + ElementAccumulator acc = ElementAccumulator(); + + for (int r = 0; r < problem_size.R; ++r) { + for (int s = 0; s < problem_size.S; ++s) { + for (int k = 0; k < problem_size.K; ++k) { + + int filter_r = r; + int filter_s = s; + + if (problem_size.mode == cutlass::conv::Mode::kConvolution) { + filter_r = problem_size.R - 1 - r; + filter_s = problem_size.S - 1 - s; + } + + int p = h + problem_size.pad_h - filter_r * problem_size.dilation_h; + int q = w + problem_size.pad_w - filter_s * problem_size.dilation_w; + + if (p >= 0 && (p % problem_size.stride_h) == 0 && + q >= 0 && (q % problem_size.stride_w) == 0) { + + p = p / problem_size.stride_h; + q = q / problem_size.stride_w; +#if 0 + std::cout << "row:" + << n * problem_size.H * problem_size.W + + h * problem_size.W + + w << " " + << "n, p, q: (" + << n << ", " + << p << ", " + << q << ") * " + << "r, s: (" + << r << ", " + << s << ") [" + << ((p < problem_size.P && q < problem_size.Q) ? "true":"false") << "]" + << std::endl; +#endif + if (p < problem_size.P && q < problem_size.Q) { + + ElementA a = tensor_dy.at(cutlass::make_Coord(n, p, q, k)); + ElementB b = is_deconv ? 
tensor_w.at(cutlass::make_Coord(c, r, s, k)) + : tensor_w.at(cutlass::make_Coord(k, r, s, c)); + + acc = inner_product_op(ElementAccumulator(a), ElementAccumulator(b), acc); + } + } + + } // for (K) + } // for (S) + } // for (R) + + // Apply Epilogue, compute ElementCompute, convert and store ElementC + ElementC c_ref = ElementC(); + + if (beta != ElementCompute()) { + c_ref = tensor_dx_in.at(cutlass::make_Coord(n, h, w, c)); + } + + tensor_dx_out.at(cutlass::make_Coord(n, h, w, c)) = + convert_op(alpha * ElementCompute(acc) + beta * ElementCompute(c_ref)); + + } // for (C) + } // for (W) + } // for (H) + } // for (N) +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// Wgrad +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// dw = wgrad(dy, x) +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = ElementCompute, + typename ElementD = ElementC, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add +> +void Conv2dWgrad( + cutlass::conv::Conv2dProblemSize problem_size, + TensorRef tensor_dy, + TensorRef tensor_x, + TensorRef tensor_dw_in, + TensorRef tensor_dw_out, + ElementCompute alpha, + ElementCompute beta) { + + InnerProductOp inner_product_op; + ConvertOp convert_op; + + // Apply MMA and accumulate ElementAccumulator + for (int k = 0; k < problem_size.K; ++k) { + for (int r = 0; r < problem_size.R; ++r) { + for (int s = 0; s < problem_size.S; ++s) { + for (int c = 0; c < problem_size.C; ++c) { + + ElementAccumulator acc = ElementAccumulator(); + + for (int n = 0; n < problem_size.N; ++n) { + for (int p = 0; p < problem_size.P; ++p) { + for (int q = 0; q < problem_size.Q; ++q) { + + cutlass::Tensor4DCoord b_coord; + + int filter_r = r; + int filter_s = s; + + if (problem_size.mode == cutlass::conv::Mode::kConvolution) { + filter_r = problem_size.R - 1 - r; + filter_s = problem_size.S - 1 - s; + } + + b_coord = make_Coord( + n, + p * problem_size.stride_h - problem_size.pad_h + filter_r * problem_size.dilation_h, + q * problem_size.stride_w - problem_size.pad_w + filter_s * problem_size.dilation_w, + c); + + if (b_coord.h() < problem_size.H && b_coord.h() >= 0 && + b_coord.w() < problem_size.W && b_coord.w() >= 0) { + + ElementAccumulator a = ElementAccumulator(tensor_dy.at(cutlass::make_Coord(n, p, q, k))); + ElementAccumulator b = ElementAccumulator(tensor_x.at(b_coord)); + acc = inner_product_op(a, b, acc); + } + } + } + } + + // Apply Epilogue, compute ElementCompute, convert and store ElementC + ElementC c_ref = ElementC(); + + if (beta != ElementCompute()) { + c_ref = tensor_dw_in.at(cutlass::make_Coord(k, r, s, c)); + } + + tensor_dw_out.at(cutlass::make_Coord(k, r, s, c)) = + convert_op(alpha * ElementCompute(acc) + beta * ElementCompute(c_ref)); + + } // for (C) + } // for (S) + } // for (R) + } // for (K) +} + +/// Generic 2D convolution targeting Conv2dFprop, Conv2dDgrad, and Conv2dWgrad. 
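A short sketch of how the dispatcher defined just below is typically invoked from a host-side test; the problem dimensions, float element types, NHWC layouts, and the `check_fprop_reference` wrapper are assumptions for illustration only:

#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/convolution.h"

void check_fprop_reference() {
  // N=1, H=W=8, C=4 activations; K=8 filters of extent R=S=3; pad 1, stride 1, dilation 1.
  cutlass::conv::Conv2dProblemSize problem_size(
      {1, 8, 8, 4},    // activation extent (NHWC)
      {8, 3, 3, 4},    // filter extent (KRSC)
      {1, 1, 1, 1},    // padding
      {1, 1},          // traversal stride
      {1, 1});         // dilation

  using Layout = cutlass::layout::TensorNHWC;
  cutlass::HostTensor<float, Layout> x(problem_size.activation_extent());
  cutlass::HostTensor<float, Layout> w(problem_size.filter_extent());
  cutlass::HostTensor<float, Layout> y(problem_size.output_extent());       // source accumulator (C)
  cutlass::HostTensor<float, Layout> y_ref(problem_size.output_extent());   // reference output (D)

  // ... fill x, w (and y if beta != 0) with test data ...

  cutlass::reference::host::Conv2d<float, Layout, float, Layout, float, Layout, float>(
      cutlass::conv::Operator::kFprop, problem_size,
      x.host_ref(), w.host_ref(), y.host_ref(), y_ref.host_ref(),
      /*alpha=*/1.0f, /*beta=*/0.0f);
}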
+template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = ElementCompute, + typename ElementD = ElementC, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add +> +void Conv2d( + conv::Operator convolutional_operator, + conv::Conv2dProblemSize problem_size, + TensorRef tensor_A, + TensorRef tensor_B, + TensorRef tensor_C, + TensorRef tensor_D, + ElementCompute alpha, + ElementCompute beta) { + + switch (convolutional_operator) { + case conv::Operator::kFprop: + Conv2dFprop< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementCompute, + ElementAccumulator, + ElementD, + ConvertOp, InnerProductOp + >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta); + break; + + case conv::Operator::kDeconv: + case conv::Operator::kDgrad: + Conv2dDgrad< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementCompute, + ElementAccumulator, + ElementD, + ConvertOp, InnerProductOp + >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, (convolutional_operator == conv::Operator::kDeconv)); + break; + + case conv::Operator::kWgrad: + Conv2dWgrad< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementCompute, + ElementAccumulator, + ElementD, + ConvertOp, InnerProductOp + >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta); + break; + + default: + break; + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// 3D convolution +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// y = conv3d(x, w) +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = ElementCompute, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add +> +void Conv3dFprop( + conv::Conv3dProblemSize problem_size, + TensorRef tensor_x, + TensorRef tensor_w, + TensorRef tensor_y_in, + TensorRef tensor_y_out, + ElementCompute alpha, + ElementCompute beta) { + + ConvertOp convert_op; + InnerProductOp inner_product_op; + + // Apply MMA and accumulate ElementAccumulator + for (int n = 0; n < problem_size.N; ++n) { + for (int z = 0; z < problem_size.Z; ++z) { + for (int p = 0; p < problem_size.P; ++p) { + for (int q = 0; q < problem_size.Q; ++q) { + for (int k = 0; k < problem_size.K; ++k) { + + ElementAccumulator acc = ElementAccumulator(); + + for (int t = 0; t < problem_size.T; ++t) { + for (int r = 0; r < problem_size.R; ++r) { + for (int s = 0; s < problem_size.S; ++s) { + for (int c = 0; c < problem_size.C; ++c) { + + int filter_t = t; + int filter_r = r; + int filter_s = s; + + if (problem_size.mode == cutlass::conv::Mode::kConvolution) { + filter_t = problem_size.T - 1 - t; + filter_r = problem_size.R - 1 - r; + filter_s = problem_size.S - 1 - s; + } + + int d = z * problem_size.stride_d - problem_size.pad_d + filter_t * problem_size.dilation_d; + int h = p * problem_size.stride_h - problem_size.pad_h + filter_r * problem_size.dilation_h; + int w = q * problem_size.stride_w - problem_size.pad_w + filter_s * problem_size.dilation_w; + + if (d >= 0 && d < problem_size.D && + h >=0 && h < problem_size.H && + w >= 0 && w < problem_size.W) { + + ElementA a = tensor_x.at({n, d, h, w, c}); + ElementB b = 
tensor_w.at({k, t, r, s, c}); + + acc = inner_product_op(ElementAccumulator(a), ElementAccumulator(b), acc); + } + } + } + } + } + + // Apply Epilogue, compute ElementCompute, convert and store ElementC + ElementC c_ref = ElementC(); + + if (beta != ElementCompute()) { + c_ref = tensor_y_in.at(cutlass::make_Coord(n, z, p, q, k)); + } + + tensor_y_out.at(cutlass::make_Coord(n, z, p, q, k)) = + convert_op(alpha * ElementCompute(acc) + beta * ElementCompute(c_ref)); + } + } + } + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// Dgrad / Deconv +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// dx = dgrad(dy, w) +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = ElementCompute, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add +> +void Conv3dDgrad( + cutlass::conv::Conv3dProblemSize problem_size, + TensorRef tensor_dy, + TensorRef tensor_w, + TensorRef tensor_dx_in, + TensorRef tensor_dx_out, + ElementCompute alpha, + ElementCompute beta, + bool is_deconv = false) { + + ConvertOp convert_op; + InnerProductOp inner_product_op; + + // Apply MMA and accumulate ElementAccumulator + for (int n = 0; n < problem_size.N; ++n) { + for (int d = 0; d < problem_size.D; ++d) { + for (int h = 0; h < problem_size.H; ++h) { + for (int w = 0; w < problem_size.W; ++w) { + for (int c = 0; c < problem_size.C; ++c) { + + ElementAccumulator acc = ElementAccumulator(); + + for (int t = 0; t < problem_size.T; ++t) { + for (int r = 0; r < problem_size.R; ++r) { + for (int s = 0; s < problem_size.S; ++s) { + for (int k = 0; k < problem_size.K; ++k) { + + int filter_t = t; + int filter_r = r; + int filter_s = s; + + if (problem_size.mode == cutlass::conv::Mode::kConvolution) { + filter_t = problem_size.T - 1 - t; + filter_r = problem_size.R - 1 - r; + filter_s = problem_size.S - 1 - s; + } + + int z = d + problem_size.pad_d - filter_t * problem_size.dilation_d; + int p = h + problem_size.pad_h - filter_r * problem_size.dilation_h; + int q = w + problem_size.pad_w - filter_s * problem_size.dilation_w; + + if (z >= 0 && (z % problem_size.stride_d) == 0 && + p >= 0 && (p % problem_size.stride_h) == 0 && + q >= 0 && (q % problem_size.stride_w) == 0) { + + z = z / problem_size.stride_d; + p = p / problem_size.stride_h; + q = q / problem_size.stride_w; + + if (z < problem_size.Z && p < problem_size.P && q < problem_size.Q) { + + ElementA a = tensor_dy.at(cutlass::make_Coord(n, z, p, q, k)); + ElementB b = is_deconv ? 
tensor_w.at(cutlass::make_Coord(c, t, r, s, k)) + : tensor_w.at(cutlass::make_Coord(k, t, r, s, c)); + acc = inner_product_op(ElementAccumulator(a), ElementAccumulator(b), acc); + } + } + + } // for (K) + } // for (S) + } // for (R) + } // for (T) + + // Apply Epilogue, compute ElementCompute, convert and store ElementC + ElementC c_ref = ElementC(); + + if (beta != ElementCompute()) { + c_ref = tensor_dx_in.at(cutlass::make_Coord(n, d, h, w, c)); + } + + tensor_dx_out.at(cutlass::make_Coord(n, d, h, w, c)) = + convert_op(alpha * ElementCompute(acc) + beta * ElementCompute(c_ref)); + + } // for (C) + } // for (W) + } // for (H) + } // for (D) + } // for (N) +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// Wgrad +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// dw = wgrad(dy, x) +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = ElementCompute, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add +> +void Conv3dWgrad( + cutlass::conv::Conv3dProblemSize problem_size, + TensorRef tensor_dy, + TensorRef tensor_x, + TensorRef tensor_dw_in, + TensorRef tensor_dw_out, + ElementCompute alpha, + ElementCompute beta) { + + InnerProductOp inner_product_op; + ConvertOp convert_op; + + // Apply MMA and accumulate ElementAccumulator + for (int k = 0; k < problem_size.K; ++k) { + for (int t = 0; t < problem_size.T; ++t) { + for (int r = 0; r < problem_size.R; ++r) { + for (int s = 0; s < problem_size.S; ++s) { + for (int c = 0; c < problem_size.C; ++c) { + + ElementAccumulator acc = ElementAccumulator(); + + for (int n = 0; n < problem_size.N; ++n) { + for (int z = 0; z < problem_size.Z; ++z) { + for (int p = 0; p < problem_size.P; ++p) { + for (int q = 0; q < problem_size.Q; ++q) { + + int filter_t = t; + int filter_r = r; + int filter_s = s; + + if (problem_size.mode == cutlass::conv::Mode::kConvolution) { + filter_t = problem_size.T - 1 - t; + filter_r = problem_size.R - 1 - r; + filter_s = problem_size.S - 1 - s; + } + + Tensor5DCoord b_coord = make_Coord( + n, + z * problem_size.stride_d - problem_size.pad_d + filter_t * problem_size.dilation_d, + p * problem_size.stride_h - problem_size.pad_h + filter_r * problem_size.dilation_h, + q * problem_size.stride_w - problem_size.pad_w + filter_s * problem_size.dilation_w, + c); + + if (b_coord.d() < problem_size.D && b_coord.d() >= 0 && + b_coord.h() < problem_size.H && b_coord.h() >= 0 && + b_coord.w() < problem_size.W && b_coord.w() >= 0) { + + ElementAccumulator a = ElementAccumulator(tensor_dy.at(cutlass::make_Coord(n, z, p, q, k))); + ElementAccumulator b = ElementAccumulator(tensor_x.at(b_coord)); + + acc = inner_product_op(a, b, acc); + } + } + } + } + } + + // Apply Epilogue, compute ElementCompute, convert and store ElementC + ElementC c_ref = ElementC(); + + if (beta != ElementCompute()) { + c_ref = tensor_dw_in.at(cutlass::make_Coord(k, t, r, s, c)); + } + + tensor_dw_out.at(cutlass::make_Coord(k, t, r, s, c)) = + convert_op(alpha * ElementCompute(acc) + beta * ElementCompute(c_ref)); + + } // for (C) + } // for (S) + } // for (R) + } // for (T) + } // for (K) +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Generic 3D convolution targeting Conv2dFprop, Conv2dDgrad, and Conv2dWgrad. 
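+//
+// Illustrative usage sketch (comment only): dispatching the 3-D reference works the same way as
+// the 2-D case; `problem_size` and the NDHWC HostTensor objects named below are assumed to have
+// been constructed already and are placeholders for illustration.
+//
+//   cutlass::reference::host::Conv3d<
+//       float, cutlass::layout::TensorNDHWC,
+//       float, cutlass::layout::TensorNDHWC,
+//       float, cutlass::layout::TensorNDHWC,
+//       float
+//   >(cutlass::conv::Operator::kDgrad, problem_size,
+//     dy.host_ref(), w.host_ref(), dx_in.host_ref(), dx_out.host_ref(),
+//     1.0f /*alpha*/, 0.0f /*beta*/);
+//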
+template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementCompute, + typename ElementAccumulator = ElementCompute, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add +> +void Conv3d( + conv::Operator convolutional_operator, + conv::Conv3dProblemSize problem_size, + TensorRef tensor_A, + TensorRef tensor_B, + TensorRef tensor_C, + TensorRef tensor_D, + ElementCompute alpha, + ElementCompute beta) { + + switch (convolutional_operator) { + case conv::Operator::kFprop: + Conv3dFprop< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementCompute, + ElementAccumulator, + ConvertOp, InnerProductOp + >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta); + break; + + case conv::Operator::kDeconv: + case conv::Operator::kDgrad: + Conv3dDgrad< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementCompute, + ElementAccumulator, + ConvertOp, InnerProductOp + >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, (convolutional_operator == conv::Operator::kDeconv)); + break; + + case conv::Operator::kWgrad: + Conv3dWgrad< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementCompute, + ElementAccumulator, + ConvertOp, InnerProductOp + >(problem_size, tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta); + break; + + default: + break; + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/error_metrics.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/error_metrics.h new file mode 100644 index 0000000000000000000000000000000000000000..12ead83354b785096e8029b49f1ac353d5ce5f82 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/error_metrics.h @@ -0,0 +1,66 @@ + +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/complex.h" +#include "cutlass/util/reference/host/tensor_reduce.h" +#include "cutlass/core_io.h" + +namespace cutlass { +namespace reference { +namespace host { + +/// Helper to compute the relative error metric for tensor A_computed w.r.t. to tensor A_reference +template < + typename Element, + typename Layout, + typename ComputeType = double +> +ComputeType TensorRelativeErrorMetric( + TensorView view_A_computed, + TensorView view_B_reference, + ComputeType identity = ComputeType() +) { + + return cutlass::reference::host::TensorNormDiff(view_A_computed, view_B_reference, identity) / + cutlass::reference::host::TensorNorm(view_B_reference, identity); +} + + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/gemm.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/gemm.h new file mode 100644 index 0000000000000000000000000000000000000000..2afee7b36d9822cc196f0f167f9dbec4c295d1a6 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/gemm.h @@ -0,0 +1,531 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for GEMM in host-side code. +*/ + +#pragma once + +#include "cutlass/coord.h" +#include "cutlass/numeric_types.h" +#include "cutlass/functional.h" +#include "cutlass/numeric_conversion.h" + +#include "cutlass/tensor_view.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/arch/mma.h" +#include "cutlass/util/host_tensor.h" + +namespace cutlass { +namespace reference { +namespace host { + +template +struct CastIfScalar { + static Out cast(In in) { + return Out(in); + } +}; + +template +struct CastIfScalar, In> { + typedef cutlass::complex Out; + static Out cast(In in) { + return Out(static_cast(in)); + } +}; + +template +struct CastIfScalar, cutlass::complex> { + typedef cutlass::complex Out; + typedef cutlass::complex In; + static Out cast(In in) { + return Out(in); + } +}; + +template +Out cast_if_scalar(In in) { + return CastIfScalar::cast(in); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + typename InnerProductOp = multiply_add, + typename ConvertOp = NumericConverter +> +void compute_gemm( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, + ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum) { + + static_assert( + LayoutA::kRank == 2 && + LayoutB::kRank == 2 && + LayoutC::kRank == 2, "Tensors must be of rank 2"); + + + // Note: batch is ignored. 
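+  // The reference accumulates into a small Mblock x Nblock tile of ComputeType values so the
+  // innermost loops reuse data in cache; partial tiles at the right/bottom edges are handled by
+  // the (row < M && col < N) bounds checks below, and results are written back only in the
+  // epilogue after the full K reduction.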
+ int const M = problem_size.m(); + int const N = problem_size.n(); + int const K = problem_size.k(); + + // Blocking necessary to speedup reference implementation + int const Mblock = 16; + int const Nblock = 16; + + ConvertOp convert_op; + InnerProductOp inner_product_op; + + for (int row_block = 0; row_block < M; row_block += Mblock) { + for (int col_block = 0; col_block < N; col_block += Nblock) { + + ComputeType accum[Mblock][Nblock]; + + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + accum[i][j] = initial_accum; + } + } + + for (int k_block = 0; k_block < K; ++k_block) { + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + int row = row_block + i; + int col = col_block + j; + + if (row < M && col < N) { + ElementA a = tensor_a.at(MatrixCoord(row, k_block)); + ElementB b = tensor_b.at(MatrixCoord(k_block, col)); + + ComputeType compute_a(cast_if_scalar(a)); + ComputeType compute_b(cast_if_scalar(b)); + + accum[i][j] = inner_product_op(compute_a, compute_b, accum[i][j]); + } + } + } + } + + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + int row = row_block + i; + int col = col_block + j; + + MatrixCoord coord = MatrixCoord(row, col); + + if (row < M && col < N) { + tensor_d.at(coord) = convert_op( + alpha * ScalarType(accum[i][j]) + + beta * ScalarType(tensor_c.at(coord))); + } + } + } + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + typename InnerProductOp = multiply_add, + typename ConvertOp = NumericConverter +> +void compute_gemm( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, + ScalarType beta, + TensorRef tensor_c, + ComputeType initial_accum) { + compute_gemm( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_c, + initial_accum); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + typename InnerProductOp = cutlass::arch::OpMultiplyAdd +> +struct Gemm; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for multiply-add +template +struct Gemm { + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_gemm>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum); + } + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_gemm>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, 
tensor_d, initial_accum); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for multiply-add +template +struct Gemm { + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_gemm>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum); + } + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_gemm>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for multiply-add-saturate +template +struct Gemm { + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_gemm, + NumericConverterClamp>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum); + } + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_gemm, + NumericConverterClamp>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for XOR-popc +template +struct Gemm { + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_gemm>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum); + } + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_gemm>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum); + } +}; + +/// Partial specialization for AND-popc +template +struct Gemm { + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && 
LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_gemm>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum); + } + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_gemm>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for multiply-add +template +struct Gemm { + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_gemm>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum); + } + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_gemm>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Batched GEMM +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a batch of GEMMs over a set of matrices of common dimension. +// +// TensorRefCollection* is a type satisfying the TensorRefCollection concept. +// +template < + typename TensorRefCollectionA, + typename TensorRefCollectionB, + typename TensorRefCollectionC, + typename ScalarType, + typename AccumulatorType +> +void BatchedGemm( + gemm::GemmCoord problem_size, + int batch_count, + ScalarType alpha, + TensorRefCollectionA const& tensor_a, + TensorRefCollectionB const& tensor_b, + ScalarType beta, + TensorRefCollectionC &tensor_c, + AccumulatorType initial_accum) { + + typename TensorRefCollectionA::ConstIterator tensor_a_it = tensor_a.begin(); + typename TensorRefCollectionB::ConstIterator tensor_b_it = tensor_b.begin(); + typename TensorRefCollectionC::ConstIterator tensor_c_it = tensor_c.begin(); + + for (int batch = 0; + batch < batch_count; + ++batch, ++tensor_a_it, ++tensor_b_it, ++tensor_c_it) { + + Gemm + gemm; + + gemm(problem_size, alpha, *tensor_a_it, *tensor_b_it, beta, *tensor_c_it, + initial_accum); + } +} + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +// +// TensorRefCollection* is a type satisfying the TensorRefCollection concept. 
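+// This overload forwards to the variant above, initializing the accumulator to ScalarType(0).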
+// +template < + typename TensorRefCollectionA, + typename TensorRefCollectionB, + typename TensorRefCollectionC, + typename ScalarType, + typename AccumulatorType +> +void BatchedGemm( + gemm::GemmCoord problem_size, + int batch_count, + ScalarType alpha, + TensorRefCollectionA const& tensor_a, + TensorRefCollectionB const& tensor_b, + ScalarType beta, + TensorRefCollectionC &tensor_c) { + + BatchedGemm(problem_size, batch_count, alpha, tensor_a, tensor_b, beta, tensor_c, ScalarType(0)); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/gemm_complex.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/gemm_complex.h new file mode 100644 index 0000000000000000000000000000000000000000..221a6040854a74ce465af7b021bbbfae9b96a90b --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/gemm_complex.h @@ -0,0 +1,210 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for complex-valued GEMM in host-side code. 
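+
+    This reference supports optional conjugation of either operand (ComplexTransform) and
+    strided-batched execution via explicit batch strides for A, B, C, and D.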
+*/ + +#pragma once + +#include "cutlass/coord.h" +#include "cutlass/complex.h" +#include "cutlass/numeric_types.h" +#include "cutlass/functional.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/matrix_coord.h" + +#include "cutlass/tensor_view.h" + +#include "cutlass/gemm/gemm.h" + +namespace cutlass { +namespace reference { +namespace host { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +/// +/// Explicitly naming types needed by this template can be cumbersome, particularly for the +/// accumulator type, so a function argument 'initial_accum' is exposed. Passing +/// AccumulatorType(0) as the last function argument can be easier than naming all template +/// arguments explicitly. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + typename ElementD = ElementC, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add +> +void GemmComplex( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + ComplexTransform transform_a, + TensorRef tensor_b, + ComplexTransform transform_b, + ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum, + int batch_count = 1, + int64_t batch_stride_A = 0, + int64_t batch_stride_B = 0, + int64_t batch_stride_C = 0, + int64_t batch_stride_D = 0) { + + static_assert( + LayoutA::kRank == 2 && + LayoutB::kRank == 2 && + LayoutC::kRank == 2, "Tensors must be of rank 2"); + + // Note: batch is ignored. + int const M = problem_size.m(); + int const N = problem_size.n(); + int const K = problem_size.k(); + + // Blocking necessary to speedup reference implementation + int const Mblock = 16; + int const Nblock = 16; + + ConvertOp convert_op; + InnerProductOp inner_product_op; + + for (int batch_idx = 0; batch_idx < batch_count; ++batch_idx) { + + // Compute matrix product using blocks + for (int row_block = 0; row_block < M; row_block += Mblock) { + for (int col_block = 0; col_block < N; col_block += Nblock) { + + ComputeType accum[Mblock][Nblock]; + + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + accum[i][j] = initial_accum; + } + } + + for (int k_block = 0; k_block < K; ++k_block) { + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + int row = row_block + i; + int col = col_block + j; + + if (row < M && col < N) { + ElementA a = tensor_a.at(MatrixCoord(row, k_block)); + ElementB b = tensor_b.at(MatrixCoord(k_block, col)); + + ComputeType a_ik = ComputeType(a); + ComputeType b_kj = ComputeType(b); + + if (transform_a == ComplexTransform::kConjugate) { + a_ik = conj(a_ik); + } + + if (transform_b == ComplexTransform::kConjugate) { + b_kj = conj(b_kj); + } + + accum[i][j] = inner_product_op(a_ik, b_kj, accum[i][j]); + } + } + } + } + + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + int row = row_block + i; + int col = col_block + j; + + MatrixCoord coord = MatrixCoord(row, col); + + if (row < M && col < N) { + + tensor_d.at(coord) = convert_op( + alpha * ScalarType(accum[i][j]) + + beta * ScalarType(tensor_c.at(coord))); + } + } + } + + } // for (col_block) + } // for (row_block) + + tensor_a.add_pointer_offset(batch_stride_A); + tensor_b.add_pointer_offset(batch_stride_B); + 
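+    // advance C and D to the next batch as well (strided-batched addressing)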
tensor_c.add_pointer_offset(batch_stride_C); + tensor_d.add_pointer_offset(batch_stride_D); + + } // for (batch_idx) +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +/// +/// This assumes the accumulator type is the same type as the scalars. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ElementD = ElementC +> +void GemmComplex( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + ComplexTransform transform_a, + TensorRef tensor_b, + ComplexTransform transform_b, + ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d) { + + GemmComplex(problem_size, alpha, tensor_a, transform_a, tensor_b, transform_b, beta, tensor_c, tensor_d, ScalarType(0)); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/gemm_planar_complex.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/gemm_planar_complex.h new file mode 100644 index 0000000000000000000000000000000000000000..507c37d9eb5a8c998f1075d547e8430b2edc5685 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/gemm_planar_complex.h @@ -0,0 +1,228 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for complex-valued GEMM in host-side code. +*/ + +#pragma once + +#include "cutlass/coord.h" +#include "cutlass/complex.h" +#include "cutlass/numeric_types.h" +#include "cutlass/functional.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/tensor_ref_planar_complex.h" + +#include "cutlass/tensor_view.h" +#include "cutlass/gemm/gemm.h" + +namespace cutlass { +namespace reference { +namespace host { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +/// +/// Explicitly naming types needed by this template can be cumbersome, particularly for the +/// accumulator type, so a function argument 'initial_accum' is exposed. Passing +/// AccumulatorType(0) as the last function argument can be easier than naming all template +/// arguments explicitly. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add> +> +void GemmPlanarComplex( + gemm::GemmCoord problem_size, + complex alpha, + TensorRefPlanarComplex tensor_a, + ComplexTransform transform_a, + TensorRefPlanarComplex tensor_b, + ComplexTransform transform_b, + complex beta, + TensorRefPlanarComplex tensor_c, + TensorRefPlanarComplex tensor_d, + complex initial_accum) { + + static_assert( + LayoutA::kRank == 2 && + LayoutB::kRank == 2 && + LayoutC::kRank == 2, "Tensors must be of rank 2"); + + using ComplexA = typename TensorRefPlanarComplex::ComplexElement; + using ComplexB = typename TensorRefPlanarComplex::ComplexElement; + using ComplexC = typename TensorRefPlanarComplex::ComplexElement; + + // Note: batch is ignored. 
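+
+  // Planar-complex operands keep real and imaginary parts in separate planes; at() reassembles a
+  // complex element from both planes, and the requested ComplexTransform is applied by
+  // conjugating the loaded value before it enters the accumulation.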
+ int const M = problem_size.m(); + int const N = problem_size.n(); + int const K = problem_size.k(); + + // Blocking necessary to speedup reference implementation + int const Mblock = 16; + int const Nblock = 16; + + ConvertOp convert_op; + InnerProductOp inner_product_op; + + for (int row_block = 0; row_block < M; row_block += Mblock) { + for (int col_block = 0; col_block < N; col_block += Nblock) { + + complex accum[Mblock][Nblock]; + + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + accum[i][j] = initial_accum; + } + } + + for (int k_block = 0; k_block < K; ++k_block) { + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + int row = row_block + i; + int col = col_block + j; + + if (row < M && col < N) { + + ComplexA a_ik = tensor_a.at(MatrixCoord(row, k_block)); + ComplexB b_kj = tensor_b.at(MatrixCoord(k_block, col)); + + complex a = complex{ + ComputeType(a_ik.real()), + ComputeType(a_ik.imag()) + }; + + complex b = complex{ + ComputeType(b_kj.real()), + ComputeType(b_kj.imag()) + }; + + if (transform_a == ComplexTransform::kConjugate) { + a = conj(a); + } + + if (transform_b == ComplexTransform::kConjugate) { + b = conj(b); + } + + accum[i][j] = inner_product_op(a, b, accum[i][j]); + } + } + } + } + + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + int row = row_block + i; + int col = col_block + j; + + MatrixCoord coord = MatrixCoord(row, col); + + if (row < M && col < N) { + + complex acc{ + ScalarType(accum[i][j].real()), + ScalarType(accum[i][j].imag()) + }; + + ComplexC d_ij = tensor_c.at(coord); + + complex src{ + ScalarType(d_ij.real()), + ScalarType(d_ij.imag()) + }; + + complex result = alpha * acc + beta * src; + + d_ij.real() = convert_op(result.real()); + d_ij.imag() = convert_op(result.imag()); + + tensor_d.at(coord) = d_ij; + } + } + } + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +/// +/// This assumes the accumulator type is the same type as the scalars. 
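+//
+// Illustrative usage sketch (comment only), assuming planar-complex operands held in
+// cutlass::HostTensorPlanarComplex (cutlass/util/host_tensor_planar_complex.h); sizes are
+// placeholders:
+//
+//   cutlass::gemm::GemmCoord problem_size(128, 128, 64);
+//   cutlass::HostTensorPlanarComplex<float, cutlass::layout::ColumnMajor> A({128, 64});
+//   cutlass::HostTensorPlanarComplex<float, cutlass::layout::ColumnMajor> B({64, 128});
+//   cutlass::HostTensorPlanarComplex<float, cutlass::layout::ColumnMajor> C({128, 128});
+//   cutlass::HostTensorPlanarComplex<float, cutlass::layout::ColumnMajor> D({128, 128});
+//
+//   cutlass::reference::host::GemmPlanarComplex<
+//       float, cutlass::layout::ColumnMajor,
+//       float, cutlass::layout::ColumnMajor,
+//       float, cutlass::layout::ColumnMajor,
+//       float>(
+//     problem_size,
+//     cutlass::complex<float>(1.0f),
+//     A.host_ref(), cutlass::ComplexTransform::kNone,
+//     B.host_ref(), cutlass::ComplexTransform::kConjugate,
+//     cutlass::complex<float>(0.0f),
+//     C.host_ref(), D.host_ref());
+//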
+template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType +> +void GemmPlanarComplex( + gemm::GemmCoord problem_size, + complex alpha, + TensorRefPlanarComplex tensor_a, + ComplexTransform transform_a, + TensorRefPlanarComplex tensor_b, + ComplexTransform transform_b, + complex beta, + TensorRefPlanarComplex tensor_c, + TensorRefPlanarComplex tensor_d) { + + GemmPlanarComplex( + problem_size, + alpha, + tensor_a, transform_a, + tensor_b, transform_b, + beta, + tensor_c, + tensor_d, + complex()); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/gett.hpp b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/gett.hpp new file mode 100644 index 0000000000000000000000000000000000000000..dd54dc6e378d0d0f0549ec922da8357841ac558f --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/gett.hpp @@ -0,0 +1,916 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for GETT in host-side code. 
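+
+    GETT (general tensor-times-tensor contraction) generalizes GEMM to cute::Tensor operands with
+    (M, K, L), (N, K, L), and (M, N, L) modes. The mainloop and epilogue are configured through
+    the GettMainloopParams and GettEpilogueParams structures below, with block-scaled
+    specializations for scale-factor (narrow-precision) GEMMs.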
+*/ + +#pragma once + +///////////////////////////////////////////////////////////////////////////////////////////////// +#include "cutlass/gemm/gemm.h" +#include "cutlass/complex.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/epilogue/thread/activation.h" +#include "cutlass/relatively_equal.h" + +#include "cute/tensor.hpp" +#include "cute/pointer.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::reference::host { + +template +struct ElementTraits { + using type = T; +}; + +template +struct ElementTraits().get()), void> > > { + using type = decltype(std::declval().get()); +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/////////////////////////////////////////////////////////// +// +// Gett Mainloop Parameters +// +/////////////////////////////////////////////////////////// + +template< + class ElementAccumulator_, + class TensorA_, // (M, K, L) + class TensorB_ // (N, K, L) + + , class TensorSfA_ = TensorA_, + class TensorSfB_ = TensorB_ + +> +struct GettMainloopParams { + using ElementAccumulator = ElementAccumulator_; + using TensorA = TensorA_; + using TensorB = TensorB_; + using EngineA = typename TensorA::engine_type; + using LayoutA = typename TensorA::layout_type; + using EngineB = typename TensorB::engine_type; + using LayoutB = typename TensorB::layout_type; + + TensorA A{}; + TensorB B{}; + + ComplexTransform transform_A = ComplexTransform::kNone; + ComplexTransform transform_B = ComplexTransform::kNone; + + + using TensorSfA = TensorSfA_; + using TensorSfB = TensorSfB_; + using EngineSfA = typename TensorSfA::engine_type; + using LayoutSfA = typename TensorSfA::layout_type; + using EngineSfB = typename TensorSfB::engine_type; + using LayoutSfB = typename TensorSfB::layout_type; + TensorSfA_ SfA{}; + TensorSfB_ SfB{}; + + + GettMainloopParams() {} + + GettMainloopParams(TensorA tensor_A, TensorB tensor_B) + : A(tensor_A), B(tensor_B) {} + + + GettMainloopParams(TensorA tensor_A, TensorSfA tensor_SfA, TensorB tensor_B, TensorSfB tensor_SfB) + : A(tensor_A), SfA(tensor_SfA), + B(tensor_B), SfB(tensor_SfB) {} + + +}; + + + +//////////////////////////////////////////////////////////////////////// +// +// Gett Mainloop Parameter Specialization for Block Scaled GEMM kernels +// +//////////////////////////////////////////////////////////////////////// + +template< + class ElementAccumulator_, + class TensorA_, // (M, K, L) + class TensorSfA_, // (M, K, L) + class TensorB_, // (N, K, L) + class TensorSfB_ // (N, K, L) +> +struct GettBlockScalingMainloopParams : public GettMainloopParams { + using Base = GettMainloopParams; + using ElementAccumulator = typename Base::ElementAccumulator; + using TensorA = typename Base::TensorA; + using TensorB = typename Base::TensorB; + using EngineA = typename Base::EngineA; + using LayoutA = typename Base::LayoutA; + using EngineB = typename Base::EngineB; + using LayoutB = typename Base::LayoutB; + ComplexTransform transform_A = Base::transform_A; + ComplexTransform transform_B = Base::transform_B; + + using TensorSfA = typename Base::TensorSfA; + using TensorSfB = typename Base::TensorSfB; + using EngineSfA = typename Base::EngineSfA; + using LayoutSfA = typename Base::LayoutSfA; + using EngineSfB = typename Base::EngineSfB; + using LayoutSfB = typename Base::LayoutSfB; + + GettBlockScalingMainloopParams() {} + + GettBlockScalingMainloopParams(TensorA tensor_A, TensorSfA tensor_SfA, TensorB 
tensor_B, TensorSfB tensor_SfB) + : Base(tensor_A, tensor_SfA, tensor_B, tensor_SfB) {} + + +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +enum class SfStrategy { + None = 0, + SfDGen = 1 +}; + + +/////////////////////////////////////////////////////////// +// +// Gett Epilogue Parameters +// +/////////////////////////////////////////////////////////// + +template< + class ElementScalar_, + class ElementScalingFactor_, + class ElementAccumulator_, + class ElementCompute_, + class TensorC_, // (M, N, L) + class TensorD_, // (M, N, L) + class VectorBias_ = decltype(make_tensor(cute::recast_ptr(nullptr), typename TensorD_::layout_type{})), // (M, 1) + class TensorAux_ = decltype(make_tensor(cute::recast_ptr(nullptr), typename TensorD_::layout_type{})), // (M, N, L) + class VectorAlpha_ = decltype(make_tensor(cute::recast_ptr(nullptr), typename TensorD_::layout_type{})), // (M, 1) + class VectorBeta_ = VectorAlpha_, // (M, 1) + class ActivationFunctor_ = cutlass::epilogue::thread::Identity, + class TensorSFD_ = TensorD_, + class SFD_VectorSize_ = cute::Int<0>, + class BiasBinaryOp_ = cutlass::plus, + bool PerColumnBias_ = false + , + SfStrategy SfGenStrategy_ = SfStrategy::None +> +struct GettEpilogueParams { + using ElementScalar = ElementScalar_; + using ElementScalingFactor = ElementScalingFactor_; + using ElementAccumulator = ElementAccumulator_; + using ElementCompute = ElementCompute_; + using TensorC = TensorC_; + using TensorD = TensorD_; + using TensorAux = TensorAux_; + using VectorBias = VectorBias_; + using VectorAlpha = VectorAlpha_; + using VectorBeta = VectorBeta_; + using TensorSFD = TensorSFD_; + using SFD_VectorSize = SFD_VectorSize_; + using ActivationFunctor = ActivationFunctor_; + using BiasBinaryOp = BiasBinaryOp_; + + using EngineC = typename TensorC::engine_type; + using LayoutC = typename TensorC::layout_type; + using EngineD = typename TensorD::engine_type; + using LayoutD = typename TensorD::layout_type; + using EngineSfD = typename TensorSFD::engine_type; + using LayoutSfD = typename TensorSFD::layout_type; + static constexpr bool PerColumnBias = PerColumnBias_; + static constexpr SfStrategy SfGenStrategy = SfGenStrategy_; + + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + + TensorC C{}; + TensorD D{}; + VectorBias Bias{}; + TensorAux Aux{}; + VectorAlpha Valpha{}; + VectorBeta Vbeta{}; + TensorSFD SfD{}; + ElementCompute st = ElementCompute(1); + + ElementAccumulator* abs_max_D = nullptr; + ElementAccumulator* abs_max_Aux = nullptr; + + ElementScalingFactor scale_a = ElementScalingFactor(1); + ElementScalingFactor scale_b = ElementScalingFactor(1); + ElementScalingFactor scale_c = ElementScalingFactor(1); + ElementScalingFactor scale_d = ElementScalingFactor(1); + ElementScalingFactor scale_aux = ElementScalingFactor(1); + + bool beta_per_channel_scaling = false; + GettEpilogueParams() {} + + GettEpilogueParams(ElementScalar alpha, ElementScalar beta, TensorC tensor_C, TensorD tensor_D) + : alpha(alpha), beta(beta), C(tensor_C), D(tensor_D) {} + + + GettEpilogueParams(ElementScalar alpha, ElementScalar beta, TensorC tensor_C, TensorD tensor_D, TensorSFD tensor_SfD, ElementCompute epilogue_st) + : alpha(alpha), beta(beta), C(tensor_C), D(tensor_D), SfD(tensor_SfD), st(epilogue_st) {} + + + GettEpilogueParams( + ElementScalar alpha, ElementScalar beta, + TensorC tensor_C, TensorD tensor_D, + VectorBias bias, TensorAux tensor_aux, + VectorAlpha vector_alpha, VectorBeta 
vector_beta) + : alpha(alpha), beta(beta), + C(tensor_C), D(tensor_D), + Bias(bias), Aux(tensor_aux), + Valpha(vector_alpha), Vbeta(vector_beta) {} +}; + + + +//////////////////////////////////////////////////////////////////////// +// +// Gett Epilogue Parameters Specialization for Block Scaled GEMM kernels +// +//////////////////////////////////////////////////////////////////////// + +template< + class ElementScalar_, + class ElementAccumulator_, + class ElementCompute_, + class TensorC_, + class TensorD_, + class TensorSfD_ = TensorD_, + class SFD_VectorSize_ = cute::Int<0>, + SfStrategy SfGenStrategy_ = SfStrategy::None +> +struct GettBlockScalingEpilogueParams : public GettEpilogueParams< + ElementScalar_, // ElementScalar + ElementScalar_, // ElementScalingFactor + ElementAccumulator_, // ElementAccumulator + ElementCompute_, // ElementCompute + TensorC_, // TensorC (M, N, L) + TensorD_, // TensorD (M, N, L) + decltype(make_tensor(cute::recast_ptr(nullptr), typename TensorD_::layout_type{})), // VectorBias (M, 1) + decltype(make_tensor(cute::recast_ptr(nullptr), typename TensorD_::layout_type{})), // TensorAux (M, N, L) + decltype(make_tensor(cute::recast_ptr(nullptr), typename TensorD_::layout_type{})), // VectorAlpha (M, 1) + decltype(make_tensor(cute::recast_ptr(nullptr), typename TensorD_::layout_type{})), // VectorBeta (M, 1) + cutlass::epilogue::thread::Identity, // + TensorSfD_, // TensorSfD + SFD_VectorSize_, // SFD_VectorSize + cutlass::plus, // class BiasBinaryOp_ = + false, //PerColumnBias_ + SfGenStrategy_ // SfGenStrategy + > { + using Base = GettEpilogueParams< + ElementScalar_, // ElementScalar + ElementScalar_, // ElementScalingFactor + ElementAccumulator_, // ElementAccumulator + ElementCompute_, // ElementCompute + TensorC_, // TensorC (M, N, L) + TensorD_, // TensorD (M, N, L) + decltype(make_tensor(cute::recast_ptr(nullptr), typename TensorD_::layout_type{})), // VectorBias (M, 1) + decltype(make_tensor(cute::recast_ptr(nullptr), typename TensorD_::layout_type{})), // TensorAux (M, N, L) + decltype(make_tensor(cute::recast_ptr(nullptr), typename TensorD_::layout_type{})), // VectorAlpha (M, 1) + decltype(make_tensor(cute::recast_ptr(nullptr), typename TensorD_::layout_type{})), // VectorBeta (M, 1) + cutlass::epilogue::thread::Identity, // + TensorSfD_, // TensorSfD + SFD_VectorSize_, // SFD_VectorSize + cutlass::plus, // BiasBinaryOp + false, // PerColumnBias + SfGenStrategy_ // SfGenStrategy + >; + using ElementScalar = typename Base::ElementScalar; + using ElementScalingFactor = typename Base::ElementScalingFactor; + using ElementAccumulator = typename Base::ElementAccumulator; + using ElementCompute = typename Base::ElementCompute; + using TensorC = typename Base::TensorC; + using TensorD = typename Base::TensorD; + using TensorAux = typename Base::TensorAux; + using VectorBias = typename Base::VectorBias; + using VectorAlpha = typename Base::VectorAlpha; + using VectorBeta = typename Base::VectorBeta; + using TensorSFD = typename Base::TensorSFD; + using SFD_VectorSize = typename Base::SFD_VectorSize; + using ActivationFunctor = typename Base::ActivationFunctor; + using BiasBinaryOp = typename Base::BiasBinaryOp; + + using EngineC = typename Base::EngineC; + using LayoutC = typename Base::LayoutC; + using EngineD = typename Base::EngineD; + using LayoutD = typename Base::LayoutD; + using EngineSfD = typename Base::EngineSfD; + using LayoutSfD = typename Base::LayoutSfD; + static constexpr bool PerColumnBias = Base::PerColumnBias; + static constexpr 
SfStrategy SfGenStrategy = Base::SfGenStrategy; + + GettBlockScalingEpilogueParams() {} + + GettBlockScalingEpilogueParams(ElementScalar alpha, ElementScalar beta, TensorC tensor_C, TensorD tensor_D) + : Base(alpha, beta, tensor_C, tensor_D) {} + + GettBlockScalingEpilogueParams(ElementScalar alpha, ElementScalar beta, TensorC tensor_C, TensorD tensor_D, TensorSFD tensor_SfD) + : Base(alpha, beta, tensor_C, tensor_D, tensor_SfD, ElementCompute{0}) {} + + GettBlockScalingEpilogueParams(ElementScalar alpha, ElementScalar beta, TensorC tensor_C, TensorD tensor_D, TensorSFD tensor_SfD, ElementCompute epilogue_st) + : Base(alpha, beta, tensor_C, tensor_D, tensor_SfD, epilogue_st) {} +}; + + + + + +/////////////////////////////////////////////////////////// +// +// Generic Gett 3x Implementation +// +/////////////////////////////////////////////////////////// + + +///////////////////////////////////////////////////////////////////////////////////////////////// +template +void compute_1d_scaling_factor_and_quantized_output( + EpilogueParams const& epilogue_params, + TensorD &tensor_D, + TensorSFD &tensor_SfD, + int64_t m, + int64_t n, + int64_t l, + ElementCompute (&acc)[kBlockM][kBlockN]) +{ + using ElementD = typename ElementTraits::type; + using ElementSfD = typename ElementTraits::type; + + int const M = cute::size<0>(tensor_D.layout()); + int const N = cute::size<1>(tensor_D.layout()); + int const L = cute::size<2>(tensor_D.layout()); + + auto mul = cutlass::multiplies{}; + auto div = divides{}; + // Get FP max + ElementCompute fp_max = ElementCompute(std::numeric_limits::max()); + float scale_down_factor = div(1.0f, fp_max); + // Get st' = st / FP max + ElementCompute st_scaled_down = mul(epilogue_params.st, scale_down_factor); + + absolute_value_op abs_op; + maximum_with_nan_propogation max_op; + + if constexpr (cute::is_constant<1, decltype(cute::stride<0,0,1>(tensor_SfD))>::value) { + // MN major output + int const NumVecPerBlock = ceil_div(kBlockM, kVectorSize); + // Col major output + for (int n_b = 0; n_b < kBlockN; ++n_b) { + for (int v_b = 0; v_b < NumVecPerBlock; ++v_b) { + int64_t col = n + n_b; + + /// Step1: get max across a vector + ElementCompute accum_max = ElementCompute(0); + for (int v = 0; v < kVectorSize; v++) { + int accum_row = v_b * kVectorSize + v; + int64_t output_row = accum_row + m; + if (output_row < M && col < N) { + accum_max = max_op(accum_max, abs_op(acc[accum_row][n_b])); + } + } + + /// Step2: Compute Scale + ElementCompute pvscale = mul(accum_max, st_scaled_down); + ElementSfD qpvscale = static_cast(pvscale); + // Store the Scaling Factors + int64_t sf_row = m + kVectorSize * v_b; + if (sf_row < M && col < N) { + tensor_SfD(sf_row, col, l) = qpvscale; + } + + /// Step3: Compute quantized output values + ElementCompute qpvscale_up = NumericConverter{}(qpvscale); + // Get float reciprocal + ElementCompute qpvscale_rcp = div(1.0f, qpvscale_up); + ElementCompute acc_scale = mul(epilogue_params.st, qpvscale_rcp); + // Map INF to fp32::max + acc_scale = cutlass::minimum_with_nan_propagation{}(acc_scale, cutlass::platform::numeric_limits::max()); + // Store the intermediate_accum + for (int v = 0; v < kVectorSize; v++) { + int accum_row = v_b * kVectorSize + v; + int64_t output_row = accum_row + m; + if (output_row < M && col < N) { + acc[accum_row][n_b] = mul(acc[accum_row][n_b], acc_scale); + } + } + } + } + } + else { + int const NumVecPerBlock = ceil_div(kBlockN, kVectorSize); + // row major output + for (int m_b = 0; m_b < kBlockM; ++m_b) { + for (int v_b = 
0; v_b < NumVecPerBlock; ++v_b) { + int64_t row = m + m_b; + + /// Step1: get max across a vector + ElementCompute accum_max = ElementCompute(0); + for (int v = 0; v < kVectorSize; v++) { + int accum_col = v_b * kVectorSize + v; + int64_t output_col = accum_col + n; + if (row < M && output_col < N) { + accum_max = max_op(accum_max, abs_op(acc[m_b][accum_col])); + } + } + + /// Step2: Compute Scale + ElementCompute pvscale = mul(accum_max, st_scaled_down); + ElementSfD qpvscale = static_cast(pvscale); + // Store the Scaling Factors + int64_t sf_col = n + kVectorSize * v_b; + + if (row < M && sf_col < N) { + tensor_SfD(row, sf_col, l) = qpvscale; + } + + /// Step3: Compute quantized output values + ElementCompute qpvscale_up = NumericConverter{}(qpvscale); + // Get float reciprocal + ElementCompute qpvscale_rcp = div(1.0f, qpvscale_up); + ElementCompute acc_scale = mul(epilogue_params.st, qpvscale_rcp); + // Map INF to fp32::max + acc_scale = cutlass::minimum_with_nan_propagation{}(acc_scale, cutlass::platform::numeric_limits::max()); + // Store the intermediate_accum + for (int v = 0; v < kVectorSize; v++) { + int accum_col = v_b * kVectorSize + v; + int64_t output_col = accum_col + n; + if (row < M && output_col < N) { + acc[m_b][accum_col] = mul(acc[m_b][accum_col], acc_scale); + } + } + } + } + } +} + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// GETT - General Tensor-Tensor contraction reference kernel +template < + class MainloopParams, + class EpilogueParams +> +void Gett( + MainloopParams const& mainloop_params, + EpilogueParams const& epilogue_params) +{ + + static int constexpr kBlockM = 64; + static int constexpr kBlockN = 64; + +#if defined(_OPENMP) + #pragma omp parallel for collapse(3) +#endif + for (int64_t l = 0; l < cute::size<2>(mainloop_params.A.layout()); ++l) { + for (int64_t m = 0; m < cute::size<0>(mainloop_params.A.layout()); m += kBlockM) { + for (int64_t n = 0; n < cute::size<0>(mainloop_params.B.layout()); n += kBlockN) { + typename MainloopParams::ElementAccumulator acc[kBlockM][kBlockN]; + gett_mainloop(mainloop_params, m, n, l, acc); + gett_epilogue(epilogue_params, m, n, l, acc); + } + } + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// GETT - Mainloop +template +void gett_mainloop( + MainloopParams const& mainloop_params, + int64_t m, + int64_t n, + int64_t l, + ElementAccumulator (&acc)[kBlockM][kBlockN]) +{ + + static_assert(cute::rank(typename MainloopParams::LayoutA{}) == 3, "M, K, B"); + static_assert(cute::rank(typename MainloopParams::LayoutB{}) == 3, "N, K, B"); + + using cute::raw_pointer_cast; + + using ElementA = typename ElementTraits::type; + using ElementB = typename ElementTraits::type; + + + using ElementSFA = typename ElementTraits::type; + using ElementSFB = typename ElementTraits::type; + + + using RingOp = multiply_add; + RingOp fma_op; + + // Zero out accumulators + for (int m_b = 0; m_b < kBlockM; ++m_b) { + for (int n_b = 0; n_b < kBlockN; ++n_b) { + acc[m_b][n_b] = ElementAccumulator(0); // RingOp::AdditionIdentity + } + } + + // Compute on this k-block + for (int64_t k = 0; k < cute::size<1>(mainloop_params.A.layout()); ++k) { + // Load A + ElementAccumulator a_frag[kBlockM]; + for (int m_b = 0; m_b < kBlockM; ++m_b) { + if (m + m_b < cute::size<0>(mainloop_params.A.layout())) { + // Perform reference GEMM calculations at the accumulator's precision. Cast A value to accumulator type. 
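+        // If a scale-factor tensor SfA is provided, the loaded A value is additionally
+        // multiplied by its block scale factor in the constexpr branch below.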
+ a_frag[m_b] = static_cast(ElementA(mainloop_params.A(m + m_b, k, l))); + + + if constexpr (not cute::is_same_v){ + // Load SFA + auto sfa = static_cast(mainloop_params.SfA(m + m_b, k, l)); + a_frag[m_b] *= sfa; + } + + + if (mainloop_params.transform_A == ComplexTransform::kConjugate) { + a_frag[m_b] = conj(a_frag[m_b]); + } + } else { + a_frag[m_b] = ElementAccumulator(0); // RingOp::AdditionIdentity + } + } + + // Load B + ElementAccumulator b_frag[kBlockN]; + for (int n_b = 0; n_b < kBlockN; ++n_b) { + if (n + n_b < cute::size<0>(mainloop_params.B.layout())) { + // Perform reference GEMM calculations at the accumulator's precision. Cast A value to accumulator type. + b_frag[n_b] = static_cast(ElementB(mainloop_params.B(n + n_b, k, l))); + + + if constexpr (not cute::is_same_v){ + // Load SFB + auto sfb = static_cast(mainloop_params.SfB(n + n_b, k, l)); + b_frag[n_b] *= sfb; + } + + + if (mainloop_params.transform_B == ComplexTransform::kConjugate) { + b_frag[n_b] = conj(b_frag[n_b]); + } + } else { + b_frag[n_b] = ElementAccumulator(0); // RingOp::AdditionIdentity + } + } + + // do compute + for (int m_b = 0; m_b < kBlockM; ++m_b) { + for (int n_b = 0; n_b < kBlockN; ++n_b) { + acc[m_b][n_b] = fma_op(a_frag[m_b], b_frag[n_b], acc[m_b][n_b]); + } + } + + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// GETT - Epilogue +template +void gett_epilogue( + EpilogueParams const& epilogue_params, + int64_t m, + int64_t n, + int64_t l, + ElementAccumulator (&acc)[kBlockM][kBlockN]) +{ + static_assert(cute::rank(typename EpilogueParams::LayoutC{}) == 3, "M, K, B"); + static_assert(cute::rank(typename EpilogueParams::LayoutD{}) == 3, "N, K, B"); + + using cute::raw_pointer_cast; + + using ElementCompute = typename EpilogueParams::ElementCompute; + using ElementC = typename EpilogueParams::TensorC::value_type; + using ElementD = typename EpilogueParams::TensorD::value_type; + using ElementSfD = typename EpilogueParams::TensorSFD::value_type; + using ElementAux = typename EpilogueParams::TensorAux::value_type; + using ElementBias = typename EpilogueParams::VectorBias::value_type; + using ElementScalar = typename EpilogueParams::ElementScalar; + using ElementScalingFactor = typename EpilogueParams::ElementScalingFactor; + using ActivationFunctor = typename EpilogueParams::ActivationFunctor; + using BiasBinaryOp = typename EpilogueParams::BiasBinaryOp; + + constexpr bool PerColBias = EpilogueParams::PerColumnBias; + constexpr SfStrategy SfGenStrategy = EpilogueParams::SfGenStrategy; + + constexpr bool IsScalingAndAmaxOutputNeeded = + cute::is_same_v or + cute::is_same_v; + + constexpr bool IsScalingAndAmaxAuxOutputNeeded = + cute::is_same_v or + cute::is_same_v; + + constexpr bool IsReLUAuxNeeded = + (cute::is_same_v> or + cute::is_same_v>) and + cute::is_same_v; + constexpr bool UseReLU = + cute::is_same_v>; // Treat Clamp as ReLU + + constexpr bool IsBackpropFusion = + cute::is_same_v> or + cute::is_same_v>; + + // Input related converter + NumericConverter accumulator_converter; + NumericConverter source_converter; + NumericConverter bias_converter; + [[maybe_unused]] NumericConverter aux_source_converter; + + // Scale related converter + NumericConverter scale_converter; + NumericConverter scaling_factor_converter; + + // Abs max converter + [[maybe_unused]] NumericConverter abs_max_output_converter; + + // Output related converter + NumericConverter destination_converter; + [[maybe_unused]] NumericConverter 
aux_destination_converter; + NumericConverter dBias_converter; + + // Epilogue operations + multiply_add epilogue_fma; + multiplies mul; + plus add; + + // Activation operation + ActivationFunctor activation; + + // Bias binary operation + BiasBinaryOp bias_op; + + // Do conversion + ElementCompute converted_alpha = scale_converter(epilogue_params.alpha); + ElementCompute converted_beta = scale_converter(epilogue_params.beta); + ElementCompute converted_scale_a = scaling_factor_converter(epilogue_params.scale_a); + ElementCompute converted_scale_b = scaling_factor_converter(epilogue_params.scale_b); + ElementCompute converted_scale_c = scaling_factor_converter(epilogue_params.scale_c); + ElementCompute converted_scale_d = scaling_factor_converter(epilogue_params.scale_d); + ElementCompute converted_scale_aux = scaling_factor_converter(epilogue_params.scale_aux); + + // Init local var + [[maybe_unused]] ElementCompute local_abs_max_output = ElementCompute(0); + [[maybe_unused]] ElementCompute local_abs_max_aux_output = ElementCompute(0); + + converted_alpha = mul(converted_alpha, mul(converted_scale_a, converted_scale_b)); + converted_beta = mul(converted_beta, converted_scale_c); + + ElementCompute inter_accum[kBlockM][kBlockN]; + + for (int m_b = 0; m_b < kBlockM; ++m_b) { + ElementCompute local_dBias = ElementCompute(0); + + for (int n_b = 0; n_b < kBlockN; ++n_b) { + if (m + m_b < cute::size<0>(epilogue_params.D.layout()) && n + n_b < cute::size<1>(epilogue_params.D.layout())) { + // Convert every type to ElementCompute first, do compute, convert to output type, write it out + ElementCompute converted_acc = accumulator_converter(acc[m_b][n_b]); + // vector alpha + if (raw_pointer_cast(epilogue_params.Valpha.data())) { + converted_alpha = scale_converter(epilogue_params.Valpha(m + m_b, n + n_b, l)); + converted_alpha = mul(converted_alpha, mul(converted_scale_a, converted_scale_b)); + } + ElementCompute output = mul(converted_alpha, converted_acc); + + if (raw_pointer_cast(epilogue_params.Bias.data()) && not IsBackpropFusion) { + ElementCompute converted_bias = bias_converter(epilogue_params.Bias(PerColBias ? n + n_b : m + m_b)); + output = bias_op(output, converted_bias); + } + + if (raw_pointer_cast(epilogue_params.C.data())) { + ElementCompute converted_src = source_converter(epilogue_params.C(m + m_b, n + n_b, l)); + // vector beta + if (epilogue_params.Vbeta.data()) { + converted_beta = scale_converter(epilogue_params.Vbeta(m + m_b, n + n_b, l)); + converted_beta = mul(converted_beta, converted_scale_c); + } + output = epilogue_fma(converted_beta, converted_src, output); + } + + if constexpr (IsBackpropFusion) { + ElementAux aux_input = ElementAux(0); + if (raw_pointer_cast(epilogue_params.Aux.data())) { + aux_input = epilogue_params.Aux(m + m_b, n + n_b, l); + } + + output = activation(output, aux_source_converter(aux_input)); + local_dBias = add(local_dBias, output); + } + else { + if (raw_pointer_cast(epilogue_params.Aux.data())) { + auto aux_output = output; + if constexpr (IsScalingAndAmaxAuxOutputNeeded) { + maximum_absolute_value_reduction amax_op; + local_abs_max_aux_output = amax_op(local_abs_max_aux_output, aux_output); + aux_output = epilogue_fma(converted_scale_aux, aux_output, ElementCompute(0)); + } + + if constexpr (IsReLUAuxNeeded) { + epilogue_params.Aux(m + m_b, n + n_b, l) = not (aux_output < 0) ? 
uint1b_t(1) : uint1b_t(0); + } else { + epilogue_params.Aux(m + m_b, n + n_b, l) = aux_destination_converter(aux_output); + } + } + + if constexpr (UseReLU) { + cutlass::epilogue::thread::ReLU relu; + output = relu(output); + } + else { + output = activation(output); + } + } + + if constexpr (IsScalingAndAmaxOutputNeeded) { + maximum_absolute_value_reduction amax_op; + local_abs_max_output = amax_op(local_abs_max_output, output); + output = epilogue_fma(converted_scale_d, output, ElementCompute(0)); + } + + inter_accum[m_b][n_b] = ElementCompute(output); + } + } // n_b + + if (m + m_b < cute::size<0>(epilogue_params.D.layout()) && n < cute::size<1>(epilogue_params.D.layout())) { + if (raw_pointer_cast(epilogue_params.Bias.data()) && IsBackpropFusion) { + ElementCompute converted_dBias = bias_converter(epilogue_params.Bias(m + m_b)); + local_dBias = add(local_dBias, converted_dBias); + epilogue_params.Bias(m + m_b) = dBias_converter(local_dBias); + } + } + } // m_b + + if constexpr ( + SfGenStrategy == SfStrategy::SfDGen + ) { + // 1d scale factor generation + constexpr int kVectorSize = typename EpilogueParams::SFD_VectorSize{}; + if (epilogue_params.SfD.data() != nullptr) { + compute_1d_scaling_factor_and_quantized_output(epilogue_params, epilogue_params.D, epilogue_params.SfD, m, n, l, inter_accum); + } + } + + for (int m_b = 0; m_b < kBlockM; ++m_b) { + for (int n_b = 0; n_b < kBlockN; ++n_b) { + if (m + m_b < cute::size<0>(epilogue_params.D.layout()) && n + n_b < cute::size<1>(epilogue_params.D.layout())) { + epilogue_params.D(m + m_b, n + n_b, l) = destination_converter(inter_accum[m_b][n_b]); + } + } + } + +#if defined(_OPENMP) + #pragma omp critical(Abs_Max_Data_Update) +#endif + { + if constexpr (IsScalingAndAmaxOutputNeeded) { + if (epilogue_params.abs_max_D) { + *epilogue_params.abs_max_D = maximum_with_nan_propogation{}( + *epilogue_params.abs_max_D, abs_max_output_converter(local_abs_max_output)); + } + } + + if constexpr (IsScalingAndAmaxAuxOutputNeeded) { + if (epilogue_params.abs_max_Aux) { + *epilogue_params.abs_max_Aux = maximum_with_nan_propogation{}( + *epilogue_params.abs_max_Aux, abs_max_output_converter(local_abs_max_aux_output)); + } + } + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +auto make_layout_rank3(const TensorType& tensor) { + // append a batch mode of size 1 if we do not have tensors that are rank 3 + return make_layout( + make_shape(cute::get<0>(tensor.shape()), cute::get<1>(tensor.shape()), cute::Int<1>{}), + make_stride(cute::get<0>(tensor.stride()), cute::get<1>(tensor.stride()), int64_t(cosize(tensor.layout())))); +} + +/// GEMM - General Matrix-Matrix contraction without conjugation options +template < + class MainloopParams, + class EpilogueParams +> +void Gemm3x( + MainloopParams const& mainloop_params, + EpilogueParams const& epilogue_params) +{ + using namespace cute; + + static_assert(cute::rank(typename MainloopParams::LayoutA{}) == cute::rank(typename MainloopParams::LayoutB{})); + static_assert(cute::rank(typename EpilogueParams::LayoutC{}) == cute::rank(typename EpilogueParams::LayoutD{})); + static_assert(cute::rank(typename MainloopParams::LayoutA{}) == cute::rank(typename EpilogueParams::LayoutC{})); + + if constexpr (cute::rank(typename MainloopParams::LayoutA{}) == 2) { + cute::Layout layout_A = make_layout_rank3(mainloop_params.A); + cute::Layout layout_B = make_layout_rank3(mainloop_params.B); + cute::Layout layout_C = make_layout_rank3(epilogue_params.C); + 
cute::Layout layout_D = make_layout_rank3(epilogue_params.D); + cute::Layout layout_Aux = make_layout_rank3(epilogue_params.Aux); + cute::Layout layout_Bias = make_layout_rank3(epilogue_params.Bias); + cute::Layout layout_Valpha = make_layout_rank3(epilogue_params.Valpha); + cute::Layout layout_Vbeta = make_layout_rank3(epilogue_params.Vbeta); + + auto TensorA = make_tensor(mainloop_params.A.data(), layout_A); + auto TensorB = make_tensor(mainloop_params.B.data(), layout_B); + auto TensorC = make_tensor(epilogue_params.C.data(), layout_C); + auto TensorD = make_tensor(epilogue_params.D.data(), layout_D); + auto TensorAux = make_tensor(epilogue_params.Aux.data(), layout_Aux); + auto VectorBias = make_tensor(epilogue_params.Bias.data(), layout_Bias); + auto VectorAlpha = make_tensor(epilogue_params.Valpha.data(), layout_Valpha); + auto VectorBeta = make_tensor(epilogue_params.Vbeta.data(), layout_Vbeta); + + // Reconstruct mainloop params + GettMainloopParams + mainloop_params_converted{TensorA, + TensorB, + mainloop_params.transform_A, + mainloop_params.transform_B}; + + // Reconstruct epilogue params + GettEpilogueParams + epilogue_params_converted{epilogue_params.alpha, + epilogue_params.beta, + TensorC, + TensorD, + VectorBias, + TensorAux, + VectorAlpha, + VectorBeta, + epilogue_params.abs_amax_D, + epilogue_params.abs_amax_Aux, + epilogue_params.scale_a, + epilogue_params.scale_b, + epilogue_params.scale_c, + epilogue_params.scale_d, + epilogue_params.scale_aux + }; + + Gett(mainloop_params_converted, epilogue_params_converted); + } + else { + // if we already have a batch mode, just pass it through + Gett(mainloop_params, epilogue_params); + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // cutlass::reference::host + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/rank_2k.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/rank_2k.h new file mode 100644 index 0000000000000000000000000000000000000000..67867533d5783b6e0047ac2110dc47adaa277e25 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/rank_2k.h @@ -0,0 +1,261 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for Rank 2k update in host-side code. + + + +*/ + +#pragma once + +#include "cutlass/blas3.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/tensor_view.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/arch/mma.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" + +namespace cutlass { +namespace reference { +namespace host { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + FillMode FillModeC, + typename ScalarType, + typename ComputeType, + typename InnerProductOp = multiply_add, + typename ConvertOp = NumericConverter +> +void compute_rank2k( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, + ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum) { + + static_assert( + LayoutA::kRank == 2 && + LayoutB::kRank == 2 && + LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + static_assert( + FillModeC == FillMode::kLower || + FillModeC == FillMode::kUpper, + "Fill Mode can either be Lower or Upper."); + + using CompareOp = typename platform::conditional<(FillModeC == FillMode::kLower), + std::greater_equal, + std::less_equal>::type; + + // Note: batch is ignored. 
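+  // This reference forms D = alpha * (A * B^T + B * A^T) + beta * C, touching only the
+  // triangle selected by FillModeC.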
+ // Note: M is same as N for Rank 2k update + int const N = problem_size.n(); + int const K = problem_size.k(); + + // Blocking necessary to speedup reference implementation + int const Nblock = 16; + + ConvertOp convert_op; + InnerProductOp inner_product_op; + CompareOp compare_op; + + for (int row_block = 0; row_block < N; row_block += Nblock) { + for (int col_block = 0; col_block < N; col_block += Nblock) { + + ComputeType accum[Nblock][Nblock]; + + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Nblock; i++) { + accum[i][j] = initial_accum; + } + } + + for (int k_block = 0; k_block < K; ++k_block) { + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Nblock; i++) { + int row = row_block + i; + int col = col_block + j; + + if (row < N && col < N && compare_op(row, col)) + { + + // A x B^T + ElementA a = tensor_a.at(MatrixCoord(row, k_block)); + ElementB b_t = tensor_b.at(MatrixCoord(col, k_block)); + + ComputeType compute_a(cast_if_scalar(a)); + ComputeType compute_b_t(cast_if_scalar(b_t)); + + accum[i][j] = inner_product_op(compute_a, compute_b_t, accum[i][j]); + + // B x A^T + ElementB b = tensor_b.at(MatrixCoord(row, k_block)); + ElementA a_t = tensor_a.at(MatrixCoord(col, k_block)); + + ComputeType compute_b(cast_if_scalar(b)); + ComputeType compute_a_t(cast_if_scalar(a_t)); + + accum[i][j] = inner_product_op(compute_b, compute_a_t, accum[i][j]); + } + } + } + } + + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Nblock; i++) { + int row = row_block + i; + int col = col_block + j; + + MatrixCoord coord = MatrixCoord(row, col); + + if (row < N && col < N && + ( (FillModeC == FillMode::kLower && row >= col) || + (FillModeC == FillMode::kUpper && row <= col) ) + ) { + tensor_d.at(coord) = convert_op( + alpha * ScalarType(accum[i][j]) + + beta * ScalarType(tensor_c.at(coord))); + } + } + } + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general Rank 2k update (tensors of rank=2) pointed to by TensorRef +/// objects. 
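+/// A call sketch for the in-place overload below; the element types, layouts, and tensor
+/// names are illustrative:
+///
+///   using Layout = cutlass::layout::ColumnMajor;
+///   compute_rank2k<float, Layout, float, Layout, float, Layout,
+///                  FillMode::kLower, float, float>(
+///       problem_size, alpha, ref_a, ref_b, beta, ref_c, /*initial_accum=*/0.0f);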
+template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + FillMode FillModeC, + typename ScalarType, + typename ComputeType, + typename InnerProductOp = multiply_add, + typename ConvertOp = NumericConverter +> +void compute_rank2k( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, + ScalarType beta, + TensorRef tensor_c, + ComputeType initial_accum) { + compute_rank2k( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_c, + initial_accum); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + FillMode FillModeC, + typename ScalarType, + typename ComputeType, + typename InnerProductOp = cutlass::arch::OpMultiplyAdd +> +struct Rank2K; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for multiply-add +template +struct Rank2K { + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_rank2k>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum); + } + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_rank2k>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum); + } +}; + + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/rank_2k_complex.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/rank_2k_complex.h new file mode 100644 index 0000000000000000000000000000000000000000..a738101660f7ebbdd7c7796d46df244f1e3f5f70 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/rank_2k_complex.h @@ -0,0 +1,318 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for complex-valued Rank 2K update in host-side code. + + +*/ + +#pragma once + +#include "cutlass/blas3.h" +#include "cutlass/complex.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/tensor_view.h" +#include "cutlass/gemm/gemm.h" +#include + +namespace cutlass { +namespace reference { +namespace host { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +/// +/// Explicitly naming types needed by this template can be cumbersome, particularly for the +/// accumulator type, so a function argument 'initial_accum' is exposed. Passing +/// AccumulatorType(0) as the last function argument can be easier than naming all template +/// arguments explicitly. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add +> +void Rank2KComplex( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + ComplexTransform transform_a, + TensorRef tensor_b, + ComplexTransform transform_b, + ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum, + FillMode fill_mode_c, + BlasMode blas_mode, + int batch_count = 1, + int64_t batch_stride_A = 0, + int64_t batch_stride_B = 0, + int64_t batch_stride_C = 0, + int64_t batch_stride_D = 0) { + + static_assert( + LayoutA::kRank == 2 && + LayoutB::kRank == 2 && + LayoutC::kRank == 2, "Tensors must be of rank 2"); + + // Note: batch is ignored. 
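+  // Symmetric mode accumulates alpha * (A * B^T + B * A^T) + beta * C with a single epilogue;
+  // Hermitian mode (HER2K) runs two epilogues so that alpha and conj(alpha) weight the
+  // A * B^H and B * A^H terms separately.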
+ int const M = problem_size.m(); + int const N = problem_size.n(); + int const K = problem_size.k(); + + // Rank2K update operates on A=NxK, B=NxK, and C=NxN + assert(M==N); + + // Blocking necessary to speedup reference implementation + int const Mblock = 16; + int const Nblock = 16; + + ConvertOp convert_op; + InnerProductOp inner_product_op; + + for (int batch_idx = 0; batch_idx < batch_count; ++batch_idx) { + + // Compute matrix product using blocks + for (int row_block = 0; row_block < M; row_block += Mblock) { + for (int col_block = 0; col_block < N; col_block += Nblock) { + + ComputeType accum[Mblock][Nblock]; + + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + accum[i][j] = initial_accum; + } + } + + for (int k_block = 0; k_block < K; ++k_block) { + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + int row = row_block + i; + int col = col_block + j; + + if (row < M && col < N && + ( (fill_mode_c == FillMode::kLower && row >= col) || + (fill_mode_c == FillMode::kUpper && row <= col) ) + ) { + + // A x B^T (Symmetric) or A x B^H (Hermitian) + // complex conjugation on operandB (b_t) is function of blas3 computation + ElementA a = tensor_a.at(MatrixCoord(row, k_block)); + ElementB b_t = (blas_mode == BlasMode::kHermitian) ? + conj(tensor_b.at(MatrixCoord(col, k_block))) : + tensor_b.at(MatrixCoord(col, k_block)); + + ComputeType a_ik = ComputeType(a); + ComputeType b_jk = ComputeType(b_t); + + // complex conjugation is a function of operand layouts + if (transform_a == ComplexTransform::kConjugate) { + a_ik = conj(a_ik); + } + // complex conjugation is a function of operand layouts + if (transform_b == ComplexTransform::kConjugate) { + b_jk = conj(b_jk); + } + + accum[i][j] = inner_product_op(a_ik, b_jk, accum[i][j]); + } + } + } + } + + /* HER2K need two epilogues to handle complex alpha value */ + if ( blas_mode == BlasMode::kHermitian ) { + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + int row = row_block + i; + int col = col_block + j; + + MatrixCoord coord = MatrixCoord(row, col); + + if (row < M && col < N && + ((fill_mode_c == FillMode::kLower && row >= col) || + (fill_mode_c == FillMode::kUpper && row <= col)) + ) { + + ScalarType c = tensor_c.at(coord); + // The imaginary parts of the diagonal elements of + // a complex data type are assumed and set to zero + if (blas_mode == BlasMode::kHermitian) { + c = (row == col) ? real(c) : c; + } + + tensor_d.at(coord) = convert_op(alpha * + ScalarType(accum[i][j]) + + beta * c); + } + } + } + + /* Zeoring out accum for second HERK */ + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + accum[i][j] = initial_accum; + } + } + } + + for (int k_block = 0; k_block < K; ++k_block) { + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + int row = row_block + i; + int col = col_block + j; + + if (row < M && col < N && + ( (fill_mode_c == FillMode::kLower && row >= col) || + (fill_mode_c == FillMode::kUpper && row <= col) ) + ) { + + // B x A^T (Symmetric) or B x A^H (Hermitian) + // complex conjugation on operandB (a_t) is function of blas3 computation + ElementB b = tensor_b.at(MatrixCoord(row, k_block)); + ElementA a_t = (blas_mode == BlasMode::kHermitian) ? 
+ conj(tensor_a.at(MatrixCoord(col, k_block))): + tensor_a.at(MatrixCoord(col, k_block)); + + ComputeType b_ik = ComputeType(b); + ComputeType a_jk = ComputeType(a_t); + + // complex conjugation here is a function of operand layouts + if (transform_b == ComplexTransform::kConjugate) { + b_ik = conj(b_ik); + } + // complex conjugation here is a function of operand layouts + if (transform_a == ComplexTransform::kConjugate) { + a_jk = conj(a_jk); + } + + accum[i][j] = inner_product_op(b_ik, a_jk, accum[i][j]); + } + } + } + } + + ScalarType alpha_hermitian = (blas_mode == BlasMode::kHermitian) ? + conj(alpha) : alpha; + ScalarType beta_hermitian = (blas_mode == BlasMode::kHermitian) ? + 1 : beta; + + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + int row = row_block + i; + int col = col_block + j; + + MatrixCoord coord = MatrixCoord(row, col); + + if (row < M && col < N && + ((fill_mode_c == FillMode::kLower && row >= col) || + (fill_mode_c == FillMode::kUpper && row <= col)) + ) { + + ScalarType d = (blas_mode == BlasMode::kHermitian) ? + tensor_d.at(coord) : tensor_c.at(coord); + + ScalarType tmp_d = convert_op( + alpha_hermitian * ScalarType(accum[i][j]) + + beta_hermitian * d); + + if (blas_mode == BlasMode::kHermitian && row == col ) { + tensor_d.at(coord) = real(tmp_d); + } else { + tensor_d.at(coord) = tmp_d; + } + } + } + } + + } // for (col_block) + } // for (row_block) + + tensor_a.add_pointer_offset(batch_stride_A); + tensor_b.add_pointer_offset(batch_stride_B); + tensor_c.add_pointer_offset(batch_stride_C); + tensor_d.add_pointer_offset(batch_stride_D); + + } // for (batch_idx) +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +/// +/// This assumes the accumulator type is the same type as the scalars. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType +> +void Rank2KComplex( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + ComplexTransform transform_a, + TensorRef tensor_b, + ComplexTransform transform_b, + ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + FillMode fill_mode_c, + BlasMode blas_mode) { + + Rank2KComplex( + problem_size, alpha, + tensor_a, transform_a, + tensor_b, transform_b, + beta, tensor_c, tensor_d, + ScalarType(0), + fill_mode_c, + blas_mode); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/rank_k_complex.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/rank_k_complex.h new file mode 100644 index 0000000000000000000000000000000000000000..1aad33fd643b60752bc0845e403cebc43ad7d047 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/rank_k_complex.h @@ -0,0 +1,234 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for complex-valued Rank 2K update in host-side code. + + +*/ + +#pragma once + +#include "cutlass/blas3.h" +#include "cutlass/complex.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/tensor_view.h" +#include "cutlass/gemm/gemm.h" +#include + +namespace cutlass { +namespace reference { +namespace host { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +/// +/// Explicitly naming types needed by this template can be cumbersome, particularly for the +/// accumulator type, so a function argument 'initial_accum' is exposed. Passing +/// AccumulatorType(0) as the last function argument can be easier than naming all template +/// arguments explicitly. +template < + typename ElementA, + typename LayoutA, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + typename ConvertOp = NumericConverter, + typename InnerProductOp = multiply_add +> +void Rank2KComplex( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + ComplexTransform transform_a, + ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum, + FillMode fill_mode_c, + BlasMode blas_mode, + int batch_count = 1, + int64_t batch_stride_A = 0, + int64_t batch_stride_C = 0, + int64_t batch_stride_D = 0) { + + static_assert( + LayoutA::kRank == 2 && + LayoutC::kRank == 2, "Tensors must be of rank 2"); + + // Note: batch is ignored. 
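+  // Rank-k variant: only operand A is read, producing D = alpha * A * A^T + beta * C
+  // (A * A^H in Hermitian mode, with imaginary parts dropped on the diagonal), restricted
+  // to the triangle selected by fill_mode_c.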
+ int const M = problem_size.m(); + int const N = problem_size.n(); + int const K = problem_size.k(); + + // Rank2K update operates on A=NxK, B=NxK, and C=NxN + assert(M==N); + + // Blocking necessary to speedup reference implementation + int const Mblock = 16; + int const Nblock = 16; + + ConvertOp convert_op; + InnerProductOp inner_product_op; + + for (int batch_idx = 0; batch_idx < batch_count; ++batch_idx) { + + // Compute matrix product using blocks + for (int row_block = 0; row_block < M; row_block += Mblock) { + for (int col_block = 0; col_block < N; col_block += Nblock) { + + ComputeType accum[Mblock][Nblock]; + + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + accum[i][j] = initial_accum; + } + } + + for (int k_block = 0; k_block < K; ++k_block) { + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + int row = row_block + i; + int col = col_block + j; + + if (row < M && col < N && + ( (fill_mode_c == FillMode::kLower && row >= col) || + (fill_mode_c == FillMode::kUpper && row <= col) ) + ) { + + // A x A^T (Symmetric) or A x A^H (Hermitian) + // complex conjugation on operandB (a_t) (function of blas3 computation) + ElementA a = tensor_a.at(MatrixCoord(row, k_block)); + ElementA a_t = (blas_mode == BlasMode::kHermitian) ? + conj(tensor_a.at(MatrixCoord(col, k_block))) : + tensor_a.at(MatrixCoord(col, k_block)); + + ComputeType a_ik = ComputeType(a); + ComputeType b_jk = ComputeType(a_t); + + // complex conjugation (function of input layouts) + if (transform_a == ComplexTransform::kConjugate) { + a_ik = conj(a_ik); + } + // complex conjugation (function of input layouts) + if (transform_a == ComplexTransform::kConjugate) { + b_jk = conj(b_jk); + } + + accum[i][j] = inner_product_op(a_ik, b_jk, accum[i][j]); + + } + } + } + } + + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + int row = row_block + i; + int col = col_block + j; + + MatrixCoord coord = MatrixCoord(row, col); + + if (row < M && col < N && + ((fill_mode_c == FillMode::kLower && row >= col) || + (fill_mode_c == FillMode::kUpper && row <= col)) + ) { + + ScalarType c = tensor_c.at(coord); + // The imaginary parts of the diagonal elements of + // a complex data type are assumed and set to zero + if (blas_mode == BlasMode::kHermitian) { + c = (row == col) ? real(c) : c; + } + + ScalarType tmp_d = convert_op( + alpha * ScalarType(accum[i][j]) + + beta * c); + + if (blas_mode == BlasMode::kHermitian && row == col ) { + tensor_d.at(coord) = real(tmp_d); + } else { + tensor_d.at(coord) = tmp_d; + } + } + } + } + + } // for (col_block) + } // for (row_block) + + tensor_a.add_pointer_offset(batch_stride_A); + tensor_c.add_pointer_offset(batch_stride_C); + tensor_d.add_pointer_offset(batch_stride_D); + + } // for (batch_idx) +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +/// +/// This assumes the accumulator type is the same type as the scalars. 
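+/// A call sketch; the element type, layout, and tensor names are illustrative:
+///
+///   using Element = cutlass::complex<float>;
+///   // alpha, beta : Element; ref_a, ref_c, ref_d : TensorRef<Element, layout::ColumnMajor>
+///   RankKComplex(problem_size, alpha,
+///                ref_a, ComplexTransform::kNone,
+///                beta, ref_c, ref_d,
+///                FillMode::kLower, BlasMode::kHermitian);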
+template < + typename ElementA, + typename LayoutA, + typename ElementC, + typename LayoutC, + typename ScalarType +> +void RankKComplex( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + ComplexTransform transform_a, + ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + FillMode fill_mode_c, + BlasMode blas_mode) { + + Rank2KComplex( + problem_size, alpha, + tensor_a, transform_a, + beta, tensor_c, tensor_d, + ScalarType(0), + fill_mode_c, + blas_mode); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/symm.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/symm.h new file mode 100644 index 0000000000000000000000000000000000000000..34f9648f25f8965f6730999b7763220c360683a8 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/symm.h @@ -0,0 +1,285 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for SYMM update in host-side code. 
+ + + +*/ + +#pragma once + +#include "cutlass/blas3.h" +#include "cutlass/numeric_conversion.h" + +#include "cutlass/tensor_view.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/arch/mma.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" + +namespace cutlass { +namespace reference { +namespace host { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +template < + typename ElementA, + typename LayoutA, + SideMode SideModeA, + FillMode FillModeA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + typename InnerProductOp = multiply_add, + typename ConvertOp = NumericConverter +> +void compute_symm( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, + ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum) { + + static_assert( + LayoutA::kRank == 2 && + LayoutB::kRank == 2 && + LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + static_assert(SideModeA != SideMode::kInvalid + , "Side Mode can either be Left or Right."); + + static_assert( + FillModeA == FillMode::kLower || + FillModeA == FillMode::kUpper, + "Fill Mode can either be Lower or Upper."); + + using CompareOp_w_diag = typename TrMatrixCompareOp::Type; + using CompareOp_wo_diag = typename TrMatrixCompareOp::Type; + + // Note: batch is ignored. + int const M = problem_size.m(); + int const N = problem_size.n(); + // Assuming correct k-dimension value is passed + int const K = problem_size.k(); + + // Blocking necessary to speedup reference implementation + int const Mblock = 16; + int const Nblock = 16; + + ConvertOp convert_op; + InnerProductOp inner_product_op; + CompareOp_w_diag compare_op_1; + CompareOp_wo_diag compare_op_2; + + for (int row_block = 0; row_block < M; row_block += Mblock) { + for (int col_block = 0; col_block < N; col_block += Nblock) { + + ComputeType accum[Mblock][Nblock]; + + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + accum[i][j] = initial_accum; + } + } + + for (int k_block = 0; k_block < K; ++k_block) { + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + int row = row_block + i; + int col = col_block + j; + + if (row < M && col < N) { + ElementA a_1 = ElementA(); + ElementB b_1 = ElementB(); + ElementA a_2 = ElementA(); + ElementB b_2 = ElementB(); + + // A x B or B x A (with diagonal) + if (SideModeA == SideMode::kLeft) { + a_1 = (compare_op_1(row, k_block)) ? + (tensor_a.at(MatrixCoord(row, k_block))) : ElementA(); + b_1 = tensor_b.at(MatrixCoord(k_block, col)); + } else if (SideModeA == SideMode::kRight) { + a_1 = tensor_b.at(MatrixCoord(row, k_block)); + b_1 = (compare_op_1(k_block, col)) ? + tensor_a.at(MatrixCoord(k_block, col)) : ElementA(); + } + + ComputeType compute_a_1(cast_if_scalar(a_1)); + ComputeType compute_b_1(cast_if_scalar(b_1)); + + accum[i][j] = inner_product_op(compute_a_1, compute_b_1, accum[i][j]); + + // A^T x B or B x A^T (without diagonal) + if (SideModeA == SideMode::kLeft) { + a_2 = (compare_op_2(k_block, row)) ? + (tensor_a.at(MatrixCoord(k_block, row))) : ElementA(); + b_2 = tensor_b.at(MatrixCoord(k_block, col)); + } else if (SideModeA == SideMode::kRight) { + a_2 = tensor_b.at(MatrixCoord(row, k_block)); + b_2 = (compare_op_2(col, k_block)) ? 
+ tensor_a.at(MatrixCoord(col, k_block)) : ElementA(); + } + + ComputeType compute_a_2(cast_if_scalar(a_2)); + ComputeType compute_b_2(cast_if_scalar(b_2)); + + accum[i][j] = inner_product_op(compute_a_2, compute_b_2, accum[i][j]); + } + } + } + } + + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + int row = row_block + i; + int col = col_block + j; + + MatrixCoord coord = MatrixCoord(row, col); + + if (row < M && col < N) { + tensor_d.at(coord) = convert_op( + alpha * ScalarType(accum[i][j]) + + beta * ScalarType(tensor_c.at(coord))); + } + } + } + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general Symm update (tensors of rank=2) pointed to by TensorRef +/// objects. +template < + typename ElementA, + typename LayoutA, + SideMode SideModeA, + FillMode FillModeA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + typename InnerProductOp = multiply_add, + typename ConvertOp = NumericConverter +> +void compute_symm( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, + ScalarType beta, + TensorRef tensor_c, + ComputeType initial_accum) { + compute_symm( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_c, + initial_accum); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA, + typename LayoutA, + SideMode SideModeA, + FillMode FillModeA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + typename InnerProductOp = cutlass::arch::OpMultiplyAdd +> +struct Symm; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for multiply-add +template +struct Symm { + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_symm>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, initial_accum); + } + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutB::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_symm>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum); + } +}; + + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/symm_complex.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/symm_complex.h new file mode 100644 index 0000000000000000000000000000000000000000..79e146f69b784a92ce61a093f410e93a66005cf8 --- /dev/null +++ 
b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/symm_complex.h @@ -0,0 +1,319 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for complex-valued SYMM update in host-side code. + + +*/ + +#pragma once + +#include "cutlass/blas3.h" +#include "cutlass/complex.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/tensor_view.h" +#include "cutlass/gemm/gemm.h" +#include + +namespace cutlass { +namespace reference { +namespace host { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes a general matrix product among matrices (tensors of rank=2) pointed to by TensorRef +/// objects. +/// +/// Explicitly naming types needed by this template can be cumbersome, particularly for the +/// accumulator type, so a function argument 'initial_accum' is exposed. Passing +/// AccumulatorType(0) as the last function argument can be easier than naming all template +/// arguments explicitly. 
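+/// For SideMode::kLeft the product formed is D = alpha * A * B + beta * C; for
+/// SideMode::kRight it is D = alpha * B * A + beta * C, where A is symmetric (or
+/// Hermitian) and only its FillModeA triangle is referenced.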
+template < + typename ElementA, + typename LayoutA, + SideMode SideModeA, + FillMode FillModeA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + BlasMode BlasMode_ = BlasMode::kSymmetric, + typename InnerProductOp = multiply_add, + typename ConvertOp = NumericConverter +> +void compute_symm_complex( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, + ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum, + int batch_count = 1, + int64_t batch_stride_A = 0, + int64_t batch_stride_B = 0, + int64_t batch_stride_C = 0, + int64_t batch_stride_D = 0) { + + static SideMode const kSideModeA = SideModeA; + static FillMode const kFillModeA = FillModeA; + static BlasMode const kBlasMode = BlasMode_; + + static_assert( + LayoutA::kRank == 2 && + LayoutB::kRank == 2 && + LayoutC::kRank == 2, "Tensors must be of rank 2"); + + static_assert(kSideModeA != SideMode::kInvalid + , "Side Mode can either be Left or Right."); + + static_assert( + kFillModeA == FillMode::kLower || + kFillModeA == FillMode::kUpper, + "Fill Mode can either be Lower or Upper."); + + using CompareOp_w_diag = typename TrMatrixCompareOp::Type; + using CompareOp_wo_diag = typename TrMatrixCompareOp::Type; + + // Note: batch is ignored. + int const M = problem_size.m(); + int const N = problem_size.n(); + // Assuming correct k-dimension value is passed + int const K = problem_size.k(); + + // Blocking necessary to speedup reference implementation + int const Mblock = 16; + int const Nblock = 16; + + ConvertOp convert_op; + InnerProductOp inner_product_op; + CompareOp_w_diag compare_op_1; + CompareOp_wo_diag compare_op_2; + + for (int batch_idx = 0; batch_idx < batch_count; ++batch_idx) { + + // Compute matrix product using blocks + for (int row_block = 0; row_block < M; row_block += Mblock) { + for (int col_block = 0; col_block < N; col_block += Nblock) { + + ComputeType accum[Mblock][Nblock]; + + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + accum[i][j] = initial_accum; + } + } + + for (int k_block = 0; k_block < K; ++k_block) { + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + int row = row_block + i; + int col = col_block + j; + + if (row < M && col < N) + { + ElementA a_1 = ElementA(); + ElementB b_1 = ElementB(); + ElementA a_2 = ElementA(); + ElementB b_2 = ElementB(); + + // A x B or B x A (with diagonal) + if (kSideModeA == SideMode::kLeft) { + a_1 = (compare_op_1(row, k_block)) ? + (tensor_a.at(MatrixCoord(row, k_block))) : ElementA(); + b_1 = tensor_b.at(MatrixCoord(k_block, col)); + } else if (kSideModeA == SideMode::kRight) { + a_1 = tensor_b.at(MatrixCoord(row, k_block)); + b_1 = (compare_op_1(k_block, col)) ? 
+ tensor_a.at(MatrixCoord(k_block, col)) : ElementA(); + } + ComputeType compute_a_1 = ComputeType(a_1); + ComputeType compute_b_1 = ComputeType(b_1); + + // The imaginary parts of the diagonal elements of + // a complex data type are assumed and set to zero + if (kBlasMode == BlasMode::kHermitian && kSideModeA == SideMode::kLeft && row == k_block) { + compute_a_1 = real(compute_a_1); + } else if (kBlasMode == BlasMode::kHermitian && kSideModeA == SideMode::kRight && k_block == col) { + compute_b_1 = real(compute_b_1); + } + + accum[i][j] = inner_product_op(compute_a_1, compute_b_1, accum[i][j]); + + // A^T x B or B x A^T (without diagonal) + if (kSideModeA == SideMode::kLeft) { + a_2 = (compare_op_2(k_block, row)) ? + (tensor_a.at(MatrixCoord(k_block, row))) : ElementA(); + b_2 = tensor_b.at(MatrixCoord(k_block, col)); + if (kBlasMode == BlasMode::kHermitian) + a_2 = conj(a_2); + } else if (kSideModeA == SideMode::kRight) { + a_2 = tensor_b.at(MatrixCoord(row, k_block)); + b_2 = (compare_op_2(col, k_block)) ? + tensor_a.at(MatrixCoord(col, k_block)) : ElementA(); + if (kBlasMode == BlasMode::kHermitian) + b_2 = conj(b_2); + } + + ComputeType compute_a_2 = ComputeType(a_2); + ComputeType compute_b_2 = ComputeType(b_2); + + accum[i][j] = inner_product_op(compute_a_2, compute_b_2, accum[i][j]); + } + } + } + } + + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + int row = row_block + i; + int col = col_block + j; + + MatrixCoord coord = MatrixCoord(row, col); + + if (row < M && col < N) { + + ScalarType c = tensor_c.at(coord); + + tensor_d.at(coord) = convert_op( + alpha * ScalarType(accum[i][j]) + + beta * c); + } + } + } + + } // for (col_block) + } // for (row_block) + + tensor_a.add_pointer_offset(batch_stride_A); + tensor_b.add_pointer_offset(batch_stride_B); + tensor_c.add_pointer_offset(batch_stride_C); + tensor_d.add_pointer_offset(batch_stride_D); + + } // for (batch_idx) +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA, + typename LayoutA, + SideMode SideModeA, + FillMode FillModeA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + BlasMode BlasMode_ = cutlass::BlasMode::kSymmetric, + typename InnerProductOp = cutlass::arch::OpMultiplyAddComplex +> +struct SymmComplex; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for multiply-add +template +struct SymmComplex { + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_symm_complex>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for gaussian multiply-add +template +struct SymmComplex { + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, ScalarType beta, + TensorRef tensor_c, + TensorRef tensor_d, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + 
compute_symm_complex>( + problem_size, alpha, tensor_a, tensor_b, beta, tensor_c, tensor_d, initial_accum); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_compare.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_compare.h new file mode 100644 index 0000000000000000000000000000000000000000..d6b85ca1baf65ba811b7c8b3a224ca90bbce1680 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_compare.h @@ -0,0 +1,616 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines host-side elementwise operations on TensorView. 
+*/ + +#pragma once + +// Standard Library includes +#include + +// Cutlass includes +#include "cutlass/cutlass.h" +#include "cutlass/relatively_equal.h" +#include "cutlass/tensor_view.h" +#include "cutlass/tensor_view_planar_complex.h" + +#include "cutlass/util/distribution.h" +#include "tensor_foreach.h" + +namespace cutlass { +namespace reference { +namespace host { + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorGreatestErrorFunc { + + // + // Data members + // + + TensorView lhs; + TensorView rhs; + double result; + + /// Ctor + TensorGreatestErrorFunc( + TensorView const &lhs_, + TensorView const &rhs_ + ) : + lhs(lhs_), + rhs(rhs_), + result(0.0) { } + + /// Visits a coordinate + void operator()(Coord const &coord) { + + Element lhs_ = lhs.at(coord); + Element rhs_ = rhs.at(coord); + + result = std::max(result, std::abs(double(lhs_) - double(rhs_))); + } + + /// Returns true if equal + operator double() const { + return result; + } +}; + +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorMREFunc { + + // + // Data members + // + + TensorView lhs; + TensorView rhs; + double sum; + uint64_t count; + static constexpr double epsilon = 1e-6; + + /// Ctor + TensorMREFunc( + TensorView const &lhs_, + TensorView const &rhs_ + ) : + lhs(lhs_), + rhs(rhs_), + sum(0.0), + count(0) { } + + /// Visits a coordinate + void operator()(Coord const &coord) { + + Element lhs_ = lhs.at(coord); + Element rhs_ = rhs.at(coord); + + sum += std::abs(double(lhs_) - double(rhs_) / (double(rhs_) + epsilon)); + ++count; + } + + /// Returns true if equal + operator double() const { + return sum / double(count); + } +}; + +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorMSEFunc { + + // + // Data members + // + + TensorView lhs; + TensorView rhs; + double sum; + uint64_t count; + + /// Ctor + TensorMSEFunc( + TensorView const &lhs_, + TensorView const &rhs_ + ) : + lhs(lhs_), + rhs(rhs_), + sum(0.0), + count(0) { } + + /// Visits a coordinate + void operator()(Coord const &coord) { + + Element lhs_ = lhs.at(coord); + Element rhs_ = rhs.at(coord); + + sum += std::pow((double(lhs_) - double(rhs_)), 2); + ++count; + } + + /// Returns true if equal + operator double() const { + return sum / double(count); + } +}; + +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorEqualsFunc { + + // + // Data members + // + + TensorView lhs; + TensorView rhs; + bool result; + + /// Ctor + TensorEqualsFunc(): result(true) { } + + /// Ctor + TensorEqualsFunc( + TensorView const &lhs_, + TensorView const &rhs_ + ) : + lhs(lhs_), rhs(rhs_), result(true) { } + + /// Visits a coordinate + void operator()(Coord const &coord) { + + Element lhs_ = lhs.at(coord); + Element rhs_ = rhs.at(coord); + + if (lhs_ != rhs_) { + result = false; + } + } + + /// Returns true if equal + operator bool() const { + return result; + } +}; + +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorRelativelyEqualsFunc { + + // + // Data members + // + + TensorView lhs; + TensorView rhs; + Element epsilon; + Element nonzero_floor; + bool result; + + /// Ctor + 
TensorRelativelyEqualsFunc( + TensorView const &lhs_, + TensorView const &rhs_, + Element epsilon_, + Element nonzero_floor_ + ) : + lhs(lhs_), + rhs(rhs_), + epsilon(epsilon_), + nonzero_floor(nonzero_floor_), + result(true) { } + + /// Visits a coordinate + void operator()(Coord const &coord) { + + Element lhs_ = lhs.at(coord); + Element rhs_ = rhs.at(coord); + + if (!relatively_equal(lhs_, rhs_, epsilon, nonzero_floor)) { + result = false; + } + } + + /// Returns true if equal + operator bool() const { + return result; + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Returns the Mean Squared Error between two tensors. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +double TensorMSE( + TensorView const &lhs, + TensorView const &rhs) { + + // Extents must be identical + if (lhs.extent() != rhs.extent()) { + return -1; + } + + detail::TensorMSEFunc func(lhs, rhs); + TensorForEach( + lhs.extent(), + func + ); + + return double(func); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Returns the Mean Relative Error between two tensors. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +double TensorMRE( + TensorView const &lhs, + TensorView const &rhs) { + + // Extents must be identical + if (lhs.extent() != rhs.extent()) { + return -1; + } + + detail::TensorMREFunc func(lhs, rhs); + TensorForEach( + lhs.extent(), + func + ); + + return double(func); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Returns the greatest error between two tensors. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +double TensorGreatestError( + TensorView const &lhs, + TensorView const &rhs) { + + // Extents must be identical + if (lhs.extent() != rhs.extent()) { + return -1; + } + + detail::TensorGreatestErrorFunc func(lhs, rhs); + TensorForEach( + lhs.extent(), + func + ); + + return double(func); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Returns true if two tensor views are equal. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +bool TensorEquals( + TensorView const &lhs, + TensorView const &rhs) { + + // Extents must be identical + if (lhs.extent() != rhs.extent()) { + return false; + } + + detail::TensorEqualsFunc func(lhs, rhs); + TensorForEach( + lhs.extent(), + func + ); + + return bool(func); +} + +/// Returns true if two tensor views are equal. 
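Taken together, the helpers above form the usual host-side verification path: an exact `TensorEquals` check for integer or bit-exact data, and scalar error metrics for floating point where rounding differs. A hedged usage sketch follows; `HostTensor` and the include path are assumptions based on the CUTLASS utility layout rather than anything this header defines.

```cpp
// Hedged sketch: comparing a computed tensor against a host reference.
#include <iostream>
#include "cutlass/layout/matrix.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/tensor_compare.h"

void check_against_reference(
    cutlass::HostTensor<float, cutlass::layout::RowMajor> &computed,
    cutlass::HostTensor<float, cutlass::layout::RowMajor> &reference) {

  // Exact comparison: suitable for integer or bit-exact data.
  bool exact = cutlass::reference::host::TensorEquals(
      computed.host_view(), reference.host_view());

  // Error metrics: suitable for floating-point data.
  double mse = cutlass::reference::host::TensorMSE(
      computed.host_view(), reference.host_view());
  double max_err = cutlass::reference::host::TensorGreatestError(
      computed.host_view(), reference.host_view());

  std::cout << "exact=" << exact
            << " MSE=" << mse
            << " max_abs_err=" << max_err << std::endl;
}
```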
+template < + typename Element, ///< Element type + typename Layout> ///< Layout function +bool TensorEquals( + TensorViewPlanarComplex const &lhs, + TensorViewPlanarComplex const &rhs) { + + // Extents must be identical + if (lhs.extent() != rhs.extent()) { + return false; + } + + detail::TensorEqualsFunc real_func( + {lhs.data(), lhs.layout(), lhs.extent()}, + {rhs.data(), rhs.layout(), rhs.extent()} + ); + + TensorForEach( + lhs.extent(), + real_func + ); + + if (!bool(real_func)) { + return false; + } + + detail::TensorEqualsFunc imag_func( + {lhs.data() + lhs.imaginary_stride(), lhs.layout(), lhs.extent()}, + {rhs.data() + rhs.imaginary_stride(), rhs.layout(), rhs.extent()} + ); + + TensorForEach( + lhs.extent(), + imag_func + ); + + return bool(imag_func); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Returns true if two tensor views are relatively equal. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +bool TensorRelativelyEquals( + TensorView const &lhs, + TensorView const &rhs, + Element epsilon, + Element nonzero_floor) { + + // Extents must be identical + if (lhs.extent() != rhs.extent()) { + return false; + } + + detail::TensorRelativelyEqualsFunc func(lhs, rhs, epsilon, nonzero_floor); + TensorForEach( + lhs.extent(), + func + ); + + return bool(func); +} + +/// Returns true if two tensor views are relatively equal. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +bool TensorRelativelyEquals( + TensorViewPlanarComplex const &lhs, + TensorViewPlanarComplex const &rhs, + Element epsilon, + Element nonzero_floor) { + + // Extents must be identical + if (lhs.extent() != rhs.extent()) { + return false; + } + + detail::TensorRelativelyEqualsFunc real_func( + {lhs.data(), lhs.layout(), lhs.extent()}, + {rhs.data(), rhs.layout(), rhs.extent()}, + epsilon, + nonzero_floor + ); + + TensorForEach( + lhs.extent(), + real_func + ); + + if (!bool(real_func)) { + return false; + } + + detail::TensorEqualsFunc imag_func( + {lhs.data() + lhs.imaginary_stride(), lhs.layout(), lhs.extent()}, + {rhs.data() + rhs.imaginary_stride(), rhs.layout(), rhs.extent()}, + epsilon, + nonzero_floor + ); + + TensorForEach( + lhs.extent(), + imag_func + ); + + return bool(imag_func); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Returns true if two tensor views are NOT equal. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +bool TensorNotEquals( + TensorView const &lhs, + TensorView const &rhs) { + + // Extents must be identical + if (lhs.extent() != rhs.extent()) { + return true; + } + + detail::TensorEqualsFunc func(lhs, rhs); + TensorForEach( + lhs.extent(), + func + ); + + return !bool(func); +} + +/// Returns true if two tensor views are equal. 
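The `epsilon` / `nonzero_floor` pair used by `TensorRelativelyEquals` bounds the comparison relative to magnitude, with the floor keeping near-zero values from being judged by an unbounded relative error. The standalone sketch below illustrates that idea only; it is not the exact formula implemented by `cutlass::relatively_equal`.

```cpp
// Illustrative predicate only, assuming a magnitude-relative tolerance with a floor.
#include <algorithm>
#include <cmath>

bool roughly_relatively_equal(double a, double b,
                              double epsilon, double nonzero_floor) {
  double diff  = std::fabs(a - b);
  double scale = std::max({std::fabs(a), std::fabs(b), nonzero_floor});
  return diff <= epsilon * scale;
}

// With epsilon = 1e-3 and nonzero_floor = 1e-6:
//   roughly_relatively_equal(1000.0, 1000.5, 1e-3, 1e-6) -> true  (0.5 vs ~1.0 allowed)
//   roughly_relatively_equal(1e-9, 2e-9, 1e-3, 1e-6)     -> true  (judged against the floor)
```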
+template < + typename Element, ///< Element type + typename Layout> ///< Layout function +bool TensorNotEquals( + TensorViewPlanarComplex const &lhs, + TensorViewPlanarComplex const &rhs) { + + return !TensorEquals(lhs, rhs); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorContainsFunc { + + // + // Data members + // + + TensorView view; + Element value; + bool contains; + Coord location; + + // + // Methods + // + + /// Ctor + TensorContainsFunc(): contains(false) { } + + /// Ctor + TensorContainsFunc( + TensorView const &view_, + Element value_ + ) : + view(view_), value(value_), contains(false) { } + + /// Visits a coordinate + void operator()(Coord const &coord) { + + if (view.at(coord) == value) { + if (!contains) { + location = coord; + } + contains = true; + } + } + + /// Returns true if equal + operator bool() const { + return contains; + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Returns true if a value is present in a tensor +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +bool TensorContains( + TensorView const & view, + Element value) { + + detail::TensorContainsFunc func( + view, + value + ); + + TensorForEach( + view.extent(), + func + ); + + return bool(func); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Returns a pair containing a boolean of whether a value exists in a tensor and the location of +/// of the first occurrence. If the value is not contained in the tensor, the second element of the +/// pair is undefined. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +std::pair > TensorFind( + TensorView const & view, + Element value) { + + detail::TensorContainsFunc func( + view, + value + ); + + TensorForEach( + view.extent(), + func + ); + + return std::make_pair(bool(func), func.location); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_compare.hpp b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_compare.hpp new file mode 100644 index 0000000000000000000000000000000000000000..27ef969b4ff2b6d8f3a53f3d1a3e5ec3e5203ec3 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_compare.hpp @@ -0,0 +1,101 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Provides several functions for filling tensors with data. +*/ + +#pragma once + +// Standard Library includes +#include +#include +#include + +// Cute includes +#include "cute/tensor.hpp" + +// Cutlass includes +#include "cutlass/cutlass.h" +#include "cutlass/complex.h" +#include "cutlass/quaternion.h" +#include "cutlass/array.h" +#include "cutlass/numeric_types.h" + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace reference { +namespace host { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Returns true if two tensor views are equal. +template < + typename TensorL, + typename TensorR +> +bool TensorEquals( + TensorL lhs, + TensorR rhs) { + + // Extents must be identical + if (cute::size(lhs) != cute::size(rhs)) { + return false; + } + + for (int64_t idx = 0; idx < cute::size(lhs); ++idx) { + if (lhs(idx) != rhs(idx)) { + return false; + } + } + + return true; +} + +/// Returns true if two tensor views are NOT equal. 
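For completeness, here is a hedged sketch of the CuTe-flavored overload above, comparing two host buffers through non-owning `cute` tensor views. The `cute::` calls and the include path are written from memory and should be treated as assumptions.

```cpp
// Hedged sketch: element-by-element comparison through CuTe tensor views.
#include <vector>
#include "cute/tensor.hpp"
#include "cutlass/util/reference/host/tensor_compare.hpp"

bool buffers_match(std::vector<float> const &a, std::vector<float> const &b,
                   int M, int N) {
  auto layout = cute::make_layout(cute::make_shape(M, N));   // column-major by default
  auto ta = cute::make_tensor(a.data(), layout);
  auto tb = cute::make_tensor(b.data(), layout);
  // Compares the flattened extents element by element.
  return cutlass::reference::host::TensorEquals(ta, tb);
}
```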
+template <
+  typename TensorL,
+  typename TensorR
+>
+bool TensorNotEquals(
+  TensorL lhs,
+  TensorR rhs) {
+
+  return !TensorEquals(lhs, rhs);
+}
+
+///////////////////////////////////////////////////////////////////////////////////////////////////
+
+} // namespace host
+} // namespace reference
+} // namespace cutlass
+
+///////////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_copy.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_copy.h
new file mode 100644
index 0000000000000000000000000000000000000000..d2a43b1295c8ab18c7d649c79b0364b6d3e7c48c
--- /dev/null
+++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_copy.h
@@ -0,0 +1,256 @@
+/***************************************************************************************************
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/* \file
+   \brief Defines host-side elementwise operations on TensorView.
+*/
+
+#pragma once
+
+// Standard Library includes
+#include
+
+// Cutlass includes
+#include "cutlass/cutlass.h"
+#include "tensor_foreach.h"
+
+namespace cutlass {
+namespace reference {
+namespace host {
+
+///////////////////////////////////////////////////////////////////////////////////////////////////
+
+namespace detail {
+
+/// Helper to convert between types
+template <
+  typename DstElement,
+  typename SrcElement
+>
+struct TrivialConvert {
+
+  TrivialConvert() { }
+
+  DstElement operator()(SrcElement src) const {
+    return DstElement(src);
+  }
+};
+
+/// Helper to conditionally copy between tensor views.
+template < + typename DstElement, + typename DstLayout, + typename SrcElement, + typename SrcLayout, + typename F +> +struct TensorCopyIf { + + using DstTensorView = TensorView; + using SrcTensorView = TensorView; + + // + // Data members + // + + DstTensorView dst; + SrcTensorView src; + F convert; + + // + // Methods + // + + TensorCopyIf() { } + + TensorCopyIf( + DstTensorView const &dst_, + SrcTensorView const &src_, + F const &convert_): dst(dst_), src(src_), convert(convert_) {} + + /// Copies based on destination and source bounds + void operator()(Coord const &coord) { + if (dst.contains(coord) && src.contains(coord)) { + dst.at(coord) = convert(src.at(coord)); + } + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Copies elements from one tensor view into another, satisfying bounds of each tensor. +template < + typename DstElement, /// Destination tensor's element type + typename DstLayout, /// Destination tensor's layout + typename SrcElement, /// Source tensor's element type + typename SrcLayout, /// Source tensor's layout + typename F /// Transformation functor +> +void TensorCopy( + TensorView dst, + TensorView src, + F const &transform) { + + using CopyIf = detail::TensorCopyIf< + DstElement, + DstLayout, + SrcElement, + SrcLayout, + F>; + + CopyIf copy_if(dst, src, transform); + + TensorForEach(dst.extent(), copy_if); +} + + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Copies elements from a TensorRef into a TensorView. Assumes source tensor has sufficient extent +/// to avoid out of bounds accesses. +template < + typename DstElement, /// Destination tensor's element type + typename DstLayout, /// Destination tensor's layout + typename SrcElement, /// Source tensor's element type + typename SrcLayout, /// Source tensor's layout + typename F /// Transformation functor +> +void TensorCopy( + TensorView dst, + TensorRef src, + F const &transform) { + + using CopyIf = detail::TensorCopyIf< + DstElement, + DstLayout, + SrcElement, + SrcLayout, + F>; + + TensorView src_view(src, dst.extent()); + + CopyIf copy_if(dst, src_view, transform); + + TensorForEach(dst.extent(), copy_if); +} + +/// Copies elements from a TensorRef into a TensorView. Assumes source tensor has sufficient extent +/// to avoid out of bounds accesses. +template < + typename DstElement, /// Destination tensor's element type + typename DstLayout, /// Destination tensor's layout + typename SrcElement, /// Source tensor's element type + typename SrcLayout, /// Source tensor's layout + typename F /// Transformation functor +> +void TensorCopy( + TensorRef dst, + TensorView src, + F const &transform) { + + using CopyIf = detail::TensorCopyIf< + DstElement, + DstLayout, + SrcElement, + SrcLayout, + F>; + + TensorView dst_view(dst, src.extent()); + + CopyIf copy_if(dst_view, src, transform); + + TensorForEach(src.extent(), copy_if); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Copies elements from one tensor view into another, satisfying bounds of each tensor. Succeeds +/// if SrcElement can be converted to DstElement. 
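The transform-functor overload above is what makes `TensorCopy` useful for mixed-precision testing: the functor converts each element as it crosses the intersection of the two extents. A hedged sketch using `NumericConverter`, the standard CUTLASS conversion functor, follows; the include paths are assumptions.

```cpp
// Hedged sketch: down-converting float data to half precision during a copy.
#include "cutlass/layout/matrix.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/numeric_types.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/tensor_copy.h"

void downconvert_copy(
    cutlass::HostTensor<cutlass::half_t, cutlass::layout::RowMajor> &dst,
    cutlass::HostTensor<float, cutlass::layout::RowMajor> &src) {

  // Rounds each float element to half precision as it is copied; coordinates
  // outside either tensor's bounds are skipped by TensorCopyIf.
  cutlass::NumericConverter<cutlass::half_t, float> convert;
  cutlass::reference::host::TensorCopy(dst.host_view(), src.host_view(), convert);
}
```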
+template < + typename DstElement, /// Destination tensor's element type + typename DstLayout, /// Destination tensor's layout + typename SrcElement, /// Source tensor's element type + typename SrcLayout /// Source tensor's layout +> +void TensorCopy( + TensorView dst, + TensorView src) { + + detail::TrivialConvert convert; + + TensorCopy(dst, src, convert); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Copies elements from one tensor view into another, satisfying bounds of each tensor. Succeeds +/// if SrcElement can be converted to DstElement. +template < + typename DstElement, /// Destination tensor's element type + typename DstLayout, /// Destination tensor's layout + typename SrcElement, /// Source tensor's element type + typename SrcLayout, /// Source tensor's layout + typename F /// Transformation functor +> +void TensorCopy( + TensorView dst, + TensorRef src) { + + detail::TrivialConvert convert; + + TensorCopy(dst, src, convert); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Copies elements from one tensor view into another, satisfying bounds of each tensor. Succeeds +/// if SrcElement can be converted to DstElement. +template < + typename DstElement, /// Destination tensor's element type + typename DstLayout, /// Destination tensor's layout + typename SrcElement, /// Source tensor's element type + typename SrcLayout /// Source tensor's layout +> +void TensorCopy( + TensorRef dst, + TensorView src) { + + detail::TrivialConvert convert; + + TensorCopy(dst, src, convert); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_elementwise.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_elementwise.h new file mode 100644 index 0000000000000000000000000000000000000000..5470df29358799f6d5e6628e8722f0e3dc05485f --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_elementwise.h @@ -0,0 +1,341 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines host-side elementwise operations on TensorView. +*/ + +#pragma once + +// Cutlass includes +#include "cutlass/cutlass.h" +#include "cutlass/functional.h" + +#include "tensor_foreach.h" + +namespace cutlass { +namespace reference { +namespace host { + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to apply a binary operator in place +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementD, + typename LayoutD, + typename BinaryFunc> +struct TensorFuncBinaryOp { + + // + // Data members + // + + /// View of left-hand-side tensor + TensorView view_d; + TensorRef view_a; + TensorRef view_b; + BinaryFunc func; + + // + // Methods + // + + /// Constructor + TensorFuncBinaryOp() { } + + /// Constructor + TensorFuncBinaryOp( + TensorView const & view_d_, + TensorRef const & view_a_, + TensorRef const & view_b_, + BinaryFunc func = BinaryFunc() + ): + view_d(view_d_), view_a(view_a_), view_b(view_b_), func(func) { } + + /// Equality check + void operator()(Coord const &coord) const { + view_d.at(coord) = func( + ElementD(view_a.at(coord)), + ElementD(view_b.at(coord)) + ); + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Adds two tensors and stores in the destination tensor: d = a + b +template < + typename ElementD, + typename LayoutD, + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB +> +void TensorAdd( + TensorView d, ///< destination tensor view + TensorRef a, ///< A tensor reference + TensorRef b ///< B tensor reference +) { + + detail::TensorFuncBinaryOp< + ElementD, + LayoutD, + ElementA, + LayoutA, + ElementB, + LayoutB, + cutlass::plus + > func(d, a, b); + + TensorForEach( + d.extent(), + func); +} + +/// Adds a tensor in place: d = d .+ a +template < + typename ElementD, + typename LayoutD, + typename ElementA, + typename LayoutA +> +void TensorAdd( + TensorView d, ///< destination tensor view + TensorRef a ///< A tensor reference +) { + TensorAdd(d, d, a); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Subtracts two tensors and stores in the 
destination tensor: d = a - b +template < + typename ElementD, + typename LayoutD, + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB +> +void TensorSub( + TensorView d, ///< destination tensor view + TensorRef a, ///< A tensor reference + TensorRef b ///< B tensor reference + ) { + + detail::TensorFuncBinaryOp< + ElementD, + LayoutD, + ElementA, + LayoutA, + ElementB, + LayoutB, + cutlass::minus + > func(d, a, b); + + TensorForEach( + d.extent(), + func); +} + +/// Subtracts two tensors in place: d = d .- a +template < + typename ElementD, + typename LayoutD, + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB +> +void TensorSub( + TensorView d, ///< destination tensor view + TensorRef a ///< A tensor reference + ) { + + TensorSub(d, d, a); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Multiplies two tensors and stores in the destination tensor: d = a .* b +template < + typename ElementD, + typename LayoutD, + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB +> +void TensorMul( + TensorView d, ///< destination tensor view + TensorRef a, ///< A tensor reference + TensorRef b ///< B tensor reference +) { + + detail::TensorFuncBinaryOp< + ElementD, + LayoutD, + ElementA, + LayoutA, + ElementB, + LayoutB, + cutlass::multiplies + > func(d, a, b); + + TensorForEach( + d.extent(), + func); +} + +/// Multiplies tensors in place: d = d .* a +template < + typename ElementD, + typename LayoutD, + typename ElementA, + typename LayoutA +> +void TensorMul( + TensorView d, ///< destination tensor view + TensorRef a ///< A tensor reference +) { + TensorMul(d, d, a); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Divides two tensors and stores in the destination tensor: d = a ./ b +template < + typename ElementD, + typename LayoutD, + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB +> +void TensorDiv( + TensorView d, ///< destination tensor view + TensorRef a, ///< A tensor reference + TensorRef b ///< B tensor reference +) { + + detail::TensorFuncBinaryOp< + ElementD, + LayoutD, + ElementA, + LayoutA, + ElementB, + LayoutB, + cutlass::divides + > func(d, a, b); + + TensorForEach( + d.extent(), + func); +} + +/// Divides tensors in place: d = d ./ a +template < + typename ElementD, + typename LayoutD, + typename ElementA, + typename LayoutA +> +void TensorDiv( + TensorView d, ///< destination tensor view + TensorRef a ///< A tensor reference +) { + TensorDiv(d, d, a); +} + + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Divides two tensors and stores in the destination tensor: d = a ./ b +template < + typename ElementD, + typename LayoutD, + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB +> +void TensorModulus( + TensorView d, ///< destination tensor view + TensorRef a, ///< A tensor reference + TensorRef b ///< B tensor reference +) { + + detail::TensorFuncBinaryOp< + ElementD, + LayoutD, + ElementA, + LayoutA, + ElementB, + LayoutB, + cutlass::divides + > func(d, a, b); + + TensorForEach( + d.extent(), + func); +} + +/// Divides tensors in place: d = d ./ a +template < + typename ElementD, + typename LayoutD, + typename ElementA, + typename LayoutA +> +void TensorModulus( + TensorView d, ///< destination tensor view + TensorRef a ///< A tensor reference +) { + 
TensorDiv(d, d, a); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_fill.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_fill.h new file mode 100644 index 0000000000000000000000000000000000000000..645902f7dd7b62bc98a479e4956dfb4b437d46a7 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_fill.h @@ -0,0 +1,1718 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Provides several functions for filling tensors with data. 
+*/ + +#pragma once + +// Standard Library includes +#include +#include +#include +#include +#include + +// Cutlass includes +#include "cutlass/cutlass.h" +#include "cutlass/complex.h" +#include "cutlass/quaternion.h" +#include "cutlass/array.h" +#include "cutlass/numeric_types.h" +#include "cutlass/subbyte_reference.h" +#include "cutlass/tensor_view.h" +#include "cutlass/tensor_view_planar_complex.h" +#include "cutlass/blas3.h" + +#include "cutlass/util/distribution.h" +#include "tensor_foreach.h" + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace reference { +namespace host { + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorFillFunc { + + using TensorView = TensorView; + + // + // Data members + // + + TensorView view; + Element value; + + // + // Methods + // + + TensorFillFunc( + TensorView const &view_ = TensorView(), + Element value_ = Element(0) + ): view(view_), value(value_) { } + + void operator()(Coord const & coord) const { + view.at(coord) = value; + } +}; + +/// Returns a pair of values of the Gaussian distribution generated by the Box Muller method +struct BoxMullerFunc { + + BoxMullerFunc() {} + + void operator()( + double* rnd, ///< Size-2 vector to be filled with random values + double mean = 0, ///< Mean of the Gaussian distribution + double stddev = 1, ///< Standard deviation of the Gaussian distribution + double pi = std::acos(-1)) const { + + double u1 = double(std::rand()) / double(RAND_MAX); + double u2 = double(std::rand()) / double(RAND_MAX); + rnd[0] = std::sqrt(-2 * std::log(u1)) * std::cos(2 * pi * u2); + rnd[1] = std::sqrt(-2 * std::log(u1)) * std::sin(2 * pi * u2); + rnd[0] = mean + stddev * rnd[0]; + rnd[1] = mean + stddev * rnd[1]; + } +}; +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor with a uniform value +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFill( + TensorView dst, ///< destination tensor + Element val = Element(0)) { ///< value to uniformly fill it with + + detail::TensorFillFunc func(dst, val); + + TensorForEach( + dst.extent(), + func + ); +} + +/// Fills a tensor with a uniform value +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFill( + TensorViewPlanarComplex dst, ///< destination tensor + cutlass::complex val = cutlass::complex(0)) { ///< value to uniformly fill it with + + TensorFill(dst.view_real(), val.real()); + TensorFill(dst.view_imag(), val.imag()); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template +struct RandomGaussianFunc { + + uint64_t seed; + double mean; + double stddev; + int int_scale; + double pi; + double pnz; + bool exclude_zero; + + // + // Methods + // + RandomGaussianFunc( + uint64_t seed_ = 0, + double mean_ = 0, + double stddev_ = 1, + int int_scale_ = -1, + double pnz_ = 1.0, + bool exclude_zero_ = false + ): + seed(seed_), mean(mean_), stddev(stddev_), 
int_scale(int_scale_), pi(std::acos(-1)), pnz(pnz_), exclude_zero(exclude_zero_) { + std::srand((unsigned)seed); + } + + /// Compute random value and update RNG state + Element operator()() const { + + // Box-Muller transform to generate random numbers with Normal distribution + double u1 = double(std::rand()) / double(RAND_MAX); + double u2 = double(std::rand()) / double(RAND_MAX); + + // Compute Gaussian random value + double rnd = std::sqrt(-2 * std::log(u1)) * std::cos(2 * pi * u2); + rnd = mean + stddev * rnd; + + // Scale and convert final result + Element result; + + // Sample from the Bernoulli distribution, and use the result to sample from the Gaussian + std::random_device rnd_device; + std::mt19937 bernoulli_rnd(rnd_device()); + std::bernoulli_distribution bernoulli_dist(pnz); + bool bernoulli_result = bernoulli_dist(bernoulli_rnd); + + // Sample from the Gaussian distribution for a nonzero element + if (bernoulli_result) { + if (int_scale >= 0) { + rnd = double(std::llround(rnd * double(1 << int_scale))) / double(1 << int_scale); + result = static_cast(rnd); + } + else { + result = static_cast(rnd); + } + } + else { + result = static_cast(0); + } + + // Note that exclude_zero = true will disable the bernoulli_result above by unsetting zeros + if (exclude_zero && result == Element(0)) { + if (rnd > 0) { + rnd += 1; + } else { + rnd -= 1; + } + result = Element(rnd); + } + + return result; + } +}; + +/// Partial specialization for initializing a complex value. +template +struct RandomGaussianFunc > { + + uint64_t seed; + double mean; + double stddev; + int int_scale; + double pi; + double pnz; + bool exclude_zero; + + // + // Methods + // + RandomGaussianFunc( + uint64_t seed_ = 0, + double mean_ = 0, + double stddev_ = 1, + int int_scale_ = -1, + double pnz_ = 1.0, + bool exclude_zero_ = false + ): + seed(seed_), mean(mean_), stddev(stddev_), int_scale(int_scale_), pi(std::acos(-1)), pnz(pnz_), exclude_zero(exclude_zero_) { + std::srand((unsigned)seed); + } + + /// Compute random value and update RNG state + complex operator()() const { + + Element reals[2]; + + double rnd[2]; + detail::BoxMullerFunc func; + func(rnd, mean, stddev, pi); + + // Sample from the Bernoulli distribution, and use the result to sample from the Gaussian + std::random_device rnd_device; + std::mt19937 bernoulli_rnd(rnd_device()); + std::bernoulli_distribution bernoulli_dist(pnz); + bool bernoulli_result = bernoulli_dist(bernoulli_rnd); + + // Sample from the Gaussian distribution for a nonzero element + if (bernoulli_result) { + if (int_scale >= 0) { + rnd[0] = double(std::llround(rnd[0] * double(1 << int_scale))); + rnd[1] = double(std::llround(rnd[1] * double(1 << int_scale))); + reals[0] = from_real(rnd[0] / double(1 << int_scale)); + reals[1] = from_real(rnd[1] / double(1 << int_scale)); + } + else { + reals[0] = from_real(rnd[0]); + reals[1] = from_real(rnd[1]); + } + } + else { + reals[0] = from_real(0); + reals[1] = from_real(0); + } + + // Note that this will invalidate the above else statement because it unsets zero elements + if (exclude_zero && + reals[0] == from_real(0.0) && + reals[1] == from_real(0.0)) { + + if (rnd[0] > 0.0) { + rnd[0] += 1.0; + } else { + rnd[0] -= 1.0; + } + reals[0] = from_real(rnd[0]); + } + + return complex(reals[0], reals[1]); + } +}; + +/// Partial specialization for initializing a complex value. 
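The Gaussian functors above rely on the Box-Muller transform: two independent uniform samples in (0, 1] are mapped to two independent normal samples. A standalone sketch of that step is below, written with `<random>` for clarity even though the functors above deliberately use `std::rand()` so a fixed seed reproduces the same fill.

```cpp
// Standalone sketch of the Box-Muller step (illustrative, not the library's code path).
#include <cmath>
#include <limits>
#include <random>
#include <utility>

std::pair<double, double> box_muller(std::mt19937 &rng,
                                     double mean = 0.0, double stddev = 1.0) {
  // Draw from (0, 1] so that log(u1) stays finite.
  std::uniform_real_distribution<double> uniform(
      std::numeric_limits<double>::min(), 1.0);
  double u1 = uniform(rng);
  double u2 = uniform(rng);

  double r     = std::sqrt(-2.0 * std::log(u1));
  double theta = 2.0 * std::acos(-1.0) * u2;

  // Two independent N(mean, stddev^2) samples.
  return {mean + stddev * r * std::cos(theta),
          mean + stddev * r * std::sin(theta)};
}
```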
+template +struct RandomGaussianFunc > { + + uint64_t seed; + double mean; + double stddev; + int int_scale; + double pi; + double pnz; + bool exclude_zero; + + // + // Methods + // + RandomGaussianFunc( + uint64_t seed_ = 0, + double mean_ = 0, + double stddev_ = 1, + int int_scale_ = -1, + double pnz_ = 1.0, + bool exclude_zero_ = false + ): + seed(seed_), mean(mean_), stddev(stddev_), int_scale(int_scale_), pi(std::acos(-1)), pnz(pnz_), exclude_zero(exclude_zero_) { + std::srand((unsigned)seed); + } + + /// Compute random value and update RNG state + Quaternion operator()() const { + + Element reals[4]; + + double rnd1[2]; + double rnd2[2]; + detail::BoxMullerFunc func; + func(rnd1, mean, stddev, pi); + func(rnd2, mean, stddev, pi); + + // Sample from the Bernoulli distribution, and use the result to sample from the Gaussian + std::random_device rnd_device; + std::mt19937 bernoulli_rnd(rnd_device()); + std::bernoulli_distribution bernoulli_dist(pnz); + bool bernoulli_result = bernoulli_dist(bernoulli_rnd); + + // Sample from the Gaussian distribution for a nonzero element + if (bernoulli_result) { + if (int_scale >= 0) { + rnd1[0] = double(std::llround(rnd1[0] * double(1 << int_scale))); + rnd1[1] = double(std::llround(rnd1[1] * double(1 << int_scale))); + rnd2[0] = double(std::llround(rnd2[0] * double(1 << int_scale))); + rnd2[1] = double(std::llround(rnd2[1] * double(1 << int_scale))); + + reals[0] = from_real(rnd1[0] / double(1 << int_scale)); + reals[1] = from_real(rnd1[1] / double(1 << int_scale)); + reals[2] = from_real(rnd2[0] / double(1 << int_scale)); + reals[3] = from_real(rnd2[1] / double(1 << int_scale)); + } + else { + reals[0] = from_real(rnd1[0]); + reals[1] = from_real(rnd1[1]); + reals[2] = from_real(rnd2[0]); + reals[3] = from_real(rnd2[1]); + } + } + else { + reals[0] = from_real(0); + reals[1] = from_real(0); + reals[2] = from_real(0); + reals[3] = from_real(0); + } + + // Note that this will invalidate the above else statement because it unsets zero elements + if (exclude_zero && + reals[0] == from_real(0) && + reals[1] == from_real(0) && + reals[2] == from_real(0) && + reals[3] == from_real(0)) { + + if (rnd1[0] > 0.0) { + rnd1[0] += 1.0; + } else { + rnd1[0] -= 1.0; + } + reals[0] = from_real(rnd1[0]); + } + + return Quaternion(reals[0], reals[1], reals[2], reals[3]); + } +}; + +/// Computes a random Gaussian distribution +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorFillGaussianFunc { + + using TensorView = TensorView; + + // + // Data members + // + + TensorView view; + RandomGaussianFunc func; + + // + // Methods + // + + /// Construction of Gaussian RNG functor. + TensorFillGaussianFunc( + TensorView view_ = TensorView(), + RandomGaussianFunc func_ = RandomGaussianFunc() + ): + view(view_), func(func_) { + + } + + /// Compute random value and update RNG state + void operator()(Coord const &coord) const { + view.at(coord) = func(); + } +}; + +/// Computes a random Gaussian distribution for a rank-2 tensor +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorFillSymmetricGaussianFunc { + + using TensorView = TensorView; + + // + // Data members + // + + TensorView view; + RandomGaussianFunc func; + cutlass::FillMode fill_mode; + + // + // Methods + // + + /// Construction of Gaussian RNG functor. 
+ TensorFillSymmetricGaussianFunc( + TensorView view_ = TensorView(), + RandomGaussianFunc func_ = RandomGaussianFunc(), + cutlass::FillMode fill_mode_ = cutlass::FillMode::kInvalid + ): + view(view_), func(func_), fill_mode(fill_mode_) { + + } + + /// Compute random value and update RNG state + void operator()(Coord const &coord) const { + // Fill half of matrix based on FillMode + if (Layout::kRank == 2 && + fill_mode == cutlass::FillMode::kLower && + coord[0] >= coord[1]) { + view.at(coord) = func(); + } else if (Layout::kRank == 2 && + fill_mode == cutlass::FillMode::kUpper && + coord[0] <= coord[1]) { + view.at(coord) = func(); + } + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor with random values with a Gaussian distribution. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillRandomGaussian( + TensorView dst, ///< destination tensor + uint64_t seed, ///< seed for RNG + double mean = 0, ///< Gaussian distribution's mean + double stddev = 1, ///< Gaussian distribution's standard deviation + int bits = -1, ///< If non-negative, specifies number of fractional bits that + double pnz = 1.0, /// are not truncated to zero. Permits reducing precision of + /// data. + bool exclude_zero = false) { ///< Exclude zeros from tensor init. + + detail::RandomGaussianFunc random_func(seed, mean, stddev, bits, pnz, exclude_zero); + + detail::TensorFillGaussianFunc func( + dst, + random_func + ); + + TensorForEach( + dst.extent(), + func + ); +} + +/// Fills a tensor with random values with a Gaussian distribution. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillRandomGaussian( + TensorViewPlanarComplex dst, ///< destination tensor + uint64_t seed, ///< seed for RNG + double mean = 0, ///< Gaussian distribution's mean + double stddev = 1, ///< Gaussian distribution's standard deviation + int bits = -1, ///< If non-negative, specifies number of fractional bits that + double pnz = 1.0, /// are not truncated to zero. Permits reducing precision of + /// data. + bool exclude_zero = false) { ///< Exclude zeros from tensor init. + + TensorFillRandomGaussian(dst.view_real(), seed, mean, stddev, bits, pnz); + TensorFillRandomGaussian(dst.view_imag(), ~seed, mean, stddev, bits, pnz); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/// Fills the upper or lower part of a symmetric rank-2 tensor with random values of a Gaussian distribution. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillSymmetricRandomGaussian( + TensorView dst, ///< destination tensor + uint64_t seed, ///< seed for RNG + cutlass::FillMode fill_mode, ///< FillMode for symmetric matrices + double mean = 0, ///< Gaussian distribution's mean + double stddev = 1, ///< Gaussian distribution's standard deviation + int bits = -1, ///< If non-negative, specifies number of fractional bits that + double pnz = 1.0) { /// are not truncated to zero. Permits reducing precision of + /// data. 
+ + detail::RandomGaussianFunc random_func(seed, mean, stddev, bits, pnz); + + detail::TensorFillSymmetricGaussianFunc func( + dst, + random_func, + fill_mode + ); + + TensorForEach( + dst.extent(), + func + ); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor with random values of a Gaussian distribution. +template < + typename Element ///< Element type +> +void BlockFillRandomGaussian( + Element *ptr, ///< destination buffer + size_t capacity, ///< number of elements + uint64_t seed, ///< seed for RNG + double mean = 0, ///< Gaussian distribution's mean + double stddev = 1, ///< Gaussian distribution's standard deviation + int bits = -1, ///< If non-negative, specifies number of fractional bits that + double pnz = 1.0) { /// are not truncated to zero. Permits reducing precision of + /// data. + + + detail::RandomGaussianFunc random_func(seed, mean, stddev, bits, pnz); + + for (size_t i = 0; i < capacity; ++i) { + ReferenceFactory::get(ptr, i) = random_func(); + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template +struct RandomUniformFunc { + + using Real = typename RealType::Type; + + uint64_t seed; + double range; + double min; + int int_scale; + + double pnan; +private: + using engine_type = std::mt19937; +public: + engine_type bernoulli_rnd; + std::bernoulli_distribution bernoulli_dist; + + bool exclude_zero; + + RandomUniformFunc( + uint64_t seed_ = 0, + double max = 1, + double min_ = 0, + int int_scale_ = -1, + double pnan_ = 0, + bool exclude_zero_ = false + ): + seed(seed_), range(max - min_), min(min_), int_scale(int_scale_), pnan(pnan_) + , bernoulli_rnd{static_cast(seed_)} + , bernoulli_dist(pnan_) + , exclude_zero(exclude_zero_) + { + std::srand((unsigned)seed); + + // Handle cases where min = 0 or max = 0 for excluding zeros + if (exclude_zero) { + min = (min == 0.0) ? min + 1: min; + range = (max == 0.0) ? range - 1: range; + } + } + + + /// Compute random value and update RNG state + Element operator()() { + + // Sample from NaN distribution. + if constexpr (std::numeric_limits::has_quiet_NaN) { + if (pnan > 0 && bernoulli_dist(bernoulli_rnd)) { + return Element(NAN); + } + } + + double rnd = double(std::rand()) / double(RAND_MAX); + + rnd = min + range * rnd; + + // Random values are cast to integer after scaling by a power of two to facilitate error + // testing + Element result; + if (int_scale >= 0) { + rnd = double(std::llround(rnd * double(1 << int_scale))) / double(1 << int_scale); + result = static_cast(Real(rnd)); + } + else { + result = static_cast(Real(rnd)); + } + + if (exclude_zero && result == Element(0)) { + if (rnd > 0.0) { + rnd = std::min(min + range, rnd + 1.0); + } else { + rnd = std::max(min, rnd - 1.0); + } + result = static_cast(Real(rnd)); + } + + return result; + } +}; + +/// Partial specialization for initializing a complex value. 
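A minimal usage sketch of the Gaussian fill helpers declared above, assuming cutlass::HostTensor from cutlass/util/host_tensor.h; the extents, seed, and distribution parameters are illustrative only.

// Fill a row-major host tensor and a flat buffer with samples from N(0, 2),
// truncating to 8 fractional bits so host and device references can agree bitwise.
#include <cstdint>
#include <vector>
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/tensor_fill.h"

void fill_gaussian_example() {
  int const M = 128, N = 64;       // hypothetical problem size
  uint64_t const seed = 2025;

  cutlass::HostTensor<float, cutlass::layout::RowMajor> A({M, N});

  // TensorView overload: mean = 0, stddev = 2, keep 8 fractional bits.
  cutlass::reference::host::TensorFillRandomGaussian(
      A.host_view(), seed, /*mean=*/0.0, /*stddev=*/2.0, /*bits=*/8);

  // Raw-pointer variant for a flat buffer of the same element type.
  std::vector<float> buffer(size_t(M) * N);
  cutlass::reference::host::BlockFillRandomGaussian(
      buffer.data(), buffer.size(), seed, /*mean=*/0.0, /*stddev=*/2.0, /*bits=*/8);
}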
+template +struct RandomUniformFunc > { + + using Real = typename RealType::Type; + + uint64_t seed; + double range; + double min; + int int_scale; + + double pnan; +private: + using engine_type = std::mt19937; +public: + engine_type bernoulli_rnd; + std::bernoulli_distribution bernoulli_dist; + + bool exclude_zero; + + // + // Methods + // + + RandomUniformFunc( + uint64_t seed_ = 0, + double max = 1, + double min_ = 0, + int int_scale_ = -1, + double pnan_ = 0, + bool exclude_zero_ = false + ): + seed(seed_), range(max - min_), min(min_), int_scale(int_scale_), pnan(pnan_) + , bernoulli_rnd{static_cast(seed_)} + , bernoulli_dist(pnan_) + , exclude_zero(exclude_zero_) { + std::srand((unsigned)seed); + + // Handle cases where min = 0 or max = 0 for excluding zeros + if (exclude_zero) { + min = (min == 0.0) ? min + 1: min; + range = (max == 0.0) ? range - 1: range; + } + } + + + /// Compute random value and update RNG state + complex operator()() { + + // Sample from NaN distribution. + if constexpr (std::numeric_limits::has_quiet_NaN) { + if (pnan > 0 && bernoulli_dist(bernoulli_rnd)) { + return Element(NAN); + } + } + + Element reals[2]; + + for (int i = 0; i < 2; ++i) { + double rnd = double(std::rand()) / double(RAND_MAX); + + rnd = min + range * rnd; + + // Random values are cast to integer after scaling by a power of two to facilitate error + // testing + + if (int_scale >= 0) { + rnd = double(std::llround(rnd * double(1 << int_scale))); + reals[i] = from_real(Real(rnd / double(1 << int_scale))); + } + else { + reals[i] = from_real(Real(rnd)); + } + + if (exclude_zero && + i == 0 && + reals[0] == from_real(0.0)) { + + if (rnd > 0.0) { + rnd = std::min(min + range, rnd + 1.0); + } else { + rnd = std::max(min, rnd - 1.0); + } + reals[0] = from_real(Real(rnd)); + } + + } + + return complex(reals[0], reals[1]); + } +}; + +/// Partial specialization for initializing a Quaternion value. +template +struct RandomUniformFunc > { + + using Real = typename RealType::Type; + + uint64_t seed; + double range; + double min; + int int_scale; + + double pnan; +private: + using engine_type = std::mt19937; +public: + engine_type bernoulli_rnd; + std::bernoulli_distribution bernoulli_dist; + + // + // Methods + // + + RandomUniformFunc( + uint64_t seed_ = 0, + double max = 1, + double min_ = 0, + int int_scale_ = -1, + double pnan_ = 0 + ): + seed(seed_), range(max - min_), min(min_), int_scale(int_scale_), pnan(pnan_), + bernoulli_rnd{static_cast(seed_)}, + bernoulli_dist(pnan_) + { + std::srand((unsigned)seed); + } + + + /// Compute random value and update RNG state + Quaternion operator()() { + + // Sample from NaN distribution. 
+ if constexpr (std::numeric_limits::has_quiet_NaN) { + if (pnan > 0 && bernoulli_dist(bernoulli_rnd)) { + return Element(NAN); + } + } + + Element reals[4]; + + for (int i = 0; i < 4; ++i) { + double rnd = double(std::rand()) / double(RAND_MAX); + + rnd = min + range * rnd; + + // Random values are cast to integer after scaling by a power of two to facilitate error + // testing + + if (int_scale >= 0) { + rnd = double(std::llround(rnd * double(1 << int_scale))); + reals[i] = from_real(Real(rnd / double(1 << int_scale))); + } + else { + reals[i] = from_real(Real(rnd)); + } + } + + return make_Quaternion(reals[0], reals[1], reals[2], reals[3]); + } +}; + +/// Computes a random uniform distribution +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorFillRandomUniformFunc { + + using TensorView = TensorView; + + // + // Data members + // + + TensorView view; + RandomUniformFunc func; + + // + // Methods + // + + /// Construction of uniform RNG functor. + TensorFillRandomUniformFunc( + TensorView view_ = TensorView(), + RandomUniformFunc func_ = RandomUniformFunc() + ): + view(view_), func(func_) { + + } + + /// Compute random value and update RNG state + void operator()(Coord const &coord) { + + view.at(coord) = func(); + } +}; + +/// Fills the upper or lower part of a symmetric rank-2 tensor with random values of a uniform distribution. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorFillSymmetricRandomUniformFunc { + + using TensorView = TensorView; + + // + // Data members + // + + TensorView view; + RandomUniformFunc func; + cutlass::FillMode fill_mode; + + // + // Methods + // + + /// Construction of uniform RNG functor. + TensorFillSymmetricRandomUniformFunc( + TensorView view_ = TensorView(), + RandomUniformFunc func_ = RandomUniformFunc(), + cutlass::FillMode fill_mode_ = cutlass::FillMode::kInvalid + ): + view(view_), func(func_), fill_mode(fill_mode_) { + + } + + /// Compute random value and update RNG state + void operator()(Coord const &coord) { + // Fill half of matrix based on FillMode + if (Layout::kRank == 2 && + fill_mode == cutlass::FillMode::kLower && + coord[0] >= coord[1]) { + view.at(coord) = func(); + } else if (Layout::kRank == 2 && + fill_mode == cutlass::FillMode::kUpper && + coord[0] <= coord[1]) { + view.at(coord) = func(); + } + } +}; + +/// Computes a random Uniform distribution and pads diagonal with zeros +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorFillPadDiagonalRandomUniformFunc { + + using TensorView = TensorView; + + // + // Data members + // + + TensorView view; + RandomUniformFunc func; + cutlass::FillMode fill_mode; + int alignment; + + // + // Methods + // + + /// Construction of uniform RNG functor. 
+ TensorFillPadDiagonalRandomUniformFunc( + TensorView view_ = TensorView(), + RandomUniformFunc func_ = RandomUniformFunc(), + cutlass::FillMode fill_mode_ = cutlass::FillMode::kInvalid, + int alignment_ = 1 + ): + view(view_), func(func_), fill_mode(fill_mode_), alignment(alignment_) { + + } + + /// Compute random value and update RNG state + void operator()(Coord const &coord) { + // Fill half of matrix based on FillMode + if (Layout::kRank == 2 && + (fill_mode == cutlass::FillMode::kLower) && + (coord[0] >= coord[1]) || + ((coord[1] - coord[0]) >= alignment)) { + view.at(coord) = func(); + } else if (Layout::kRank == 2 && + fill_mode == cutlass::FillMode::kUpper && + (coord[0] <= coord[1]) || + ((coord[0] - coord[1]) >= alignment)) { + view.at(coord) = func(); + } + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor with random values of a uniform random distribution. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillRandomUniform( + TensorView dst, ///< destination tensor + uint64_t seed, ///< seed for RNG + double max = 1, ///< upper bound of distribution + double min = 0, ///< lower bound for distribution + int bits = -1, ///< If non-negative, specifies number of fractional bits that + /// are not truncated to zero. Permits reducing precision of + /// data. + double pnan = 0, ///< Percentage of NaN elements. + bool exclude_zero = false) { ///< Exclude zero from tensor init + detail::RandomUniformFunc random_func(seed, max, min, bits, pnan, exclude_zero); + + detail::TensorFillRandomUniformFunc func( + dst, + random_func + ); + + TensorForEach( + dst.extent(), + func + ); +} + +/// Fills a tensor with random values of a uniform random distribution. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillRandomUniform( + TensorViewPlanarComplex dst, ///< destination tensor + uint64_t seed, ///< seed for RNG + double max = 1, ///< upper bound of distribution + double min = 0, ///< lower bound for distribution + int bits = -1, ///< If non-negative, specifies number of fractional bits that + /// are not truncated to zero. Permits reducing precision of + /// data. + double pnan = 0, ///< Percentage of NaN elements. + bool exclude_zero = false) { ///< Exclude zero from tensor init + + TensorFillRandomUniform(dst.view_real(), seed, max, min, bits, pnan, exclude_zero); + TensorFillRandomUniform(dst.view_imag(), ~seed, max, min, bits, pnan, exclude_zero); +} + + +/// Fills a tensor with random values with a uniform random distribution. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillRandomUniform( + TensorView, Layout> dst, ///< destination tensor + uint64_t seed, ///< seed for RNG + double max = 1, ///< upper bound of distribution + double min = 0, ///< lower bound for distribution + int bits = -1) { ///< If non-negative, specifies number of fractional bits that + /// are not truncated to zero. Permits reducing precision of + /// data. + detail::RandomUniformFunc> random_func(seed, max, min, bits); + + detail::TensorFillRandomUniformFunc, Layout> func( + dst, + random_func + ); + + TensorForEach( + dst.extent(), + func + ); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor with random values with a uniform random distribution. 
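A minimal sketch of calling the uniform fill declared above with NaN injection and zero exclusion; the tensor type, extents, and parameter values are assumptions chosen for illustration.

// Draw values uniformly from [-4, 4], keep 2 fractional bits, make roughly 1%
// of the elements NaN, and replace exact zeros, mirroring the parameter list above.
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/tensor_fill.h"

void fill_uniform_example() {
  cutlass::HostTensor<float, cutlass::layout::ColumnMajor> B({256, 128});

  cutlass::reference::host::TensorFillRandomUniform(
      B.host_view(),
      /*seed=*/7,
      /*max=*/4.0,
      /*min=*/-4.0,
      /*bits=*/2,
      /*pnan=*/0.01,
      /*exclude_zero=*/true);
}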
+template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillSymmetricRandomUniform( + TensorView dst, ///< destination tensor + uint64_t seed, ///< seed for RNG + cutlass::FillMode fill_mode, ///< FillMode for symmetric matrices + double max = 1, ///< upper bound of distribution + double min = 0, ///< lower bound for distribution + int bits = -1) { ///< If non-negative, specifies number of fractional bits that + /// are not truncated to zero. Permits reducing precision of + /// data. + + detail::RandomUniformFunc random_func(seed, max, min, bits); + + detail::TensorFillSymmetricRandomUniformFunc func( + dst, + random_func, + fill_mode + ); + + TensorForEach( + dst.extent(), + func + ); +} + +/// Fills a tensor with random values with a uniform random distribution pads zeros along diagonal +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillPadDiagonalRandomUniform( + TensorView dst, ///< destination tensor + uint64_t seed, ///< seed for RNG + cutlass::FillMode fill_mode, ///< FillMode for symmetric matrices + double max = 1, ///< upper bound of distribution + double min = 0, ///< lower bound for distribution + int bits = -1, ///< If non-negative, specifies number of fractional bits that + /// are not truncated to zero. Permits reducing precision of + /// data. + int alignment = 1 +) { + + detail::RandomUniformFunc random_func(seed, max, min, bits); + + detail::TensorFillPadDiagonalRandomUniformFunc func( + dst, + random_func, + fill_mode, + alignment + ); + + TensorForEach( + dst.extent(), + func + ); +} +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor with a uniform value +template < + typename Element ///< Element type +> +void BlockFill( + Element *ptr, + size_t capacity, + Element val + ) { + for (size_t i = 0; i < capacity; ++i) { + ReferenceFactory::get(ptr, i) = val; + } +} + +/// Fills a tensor with random values with a uniform random distribution. +template < + typename Element ///< Element type +> +void BlockFillRandomUniform( + Element *ptr, + size_t capacity, + uint64_t seed, ///< seed for RNG + double max = 1, ///< upper bound of distribution + double min = 0, ///< lower bound for distribution + int bits = -1, ///< If non-negative, specifies number of fractional bits that + /// are not truncated to zero. Permits reducing precision of + /// data. + double pnan = 0) { ///< Percentage of NaN elements. 
+ detail::RandomUniformFunc random_func(seed, max, min, bits, pnan); + + for (size_t i = 0; i < capacity; ++i) { + ReferenceFactory::get(ptr, i) = random_func(); + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorFillDiagonalFunc { + + using TensorView = TensorView; + + // + // Data members + // + + TensorView view; + Element diag; + Element other; + + // + // Methods + // + + TensorFillDiagonalFunc( + TensorView const &view_ = TensorView(), + Element diag_ = Element(1), + Element other_ = Element(0) + ): + view(view_), diag(diag_), other(other_) { } + + void operator()(Coord const & coord) const { + bool is_diag = true; + + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < Layout::kRank; ++i) { + if (coord[i] != coord[i - 1]) { + is_diag = false; + break; + } + } + + view.at(coord) = (is_diag ? diag : other); + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor everywhere with a unique value for its diagonal. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillDiagonal( + TensorView dst, ///< destination tensor + Element diag = Element(1), ///< value to write in the diagonal + Element other = Element(0)) { ///< value to write off the diagonal + + detail::TensorFillDiagonalFunc func( + dst, + diag, + other + ); + + TensorForEach( + dst.extent(), + func + ); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to fill a tensor's diagonal with 1 and 0 everywhere else. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillIdentity( + TensorView dst) { ///< destination tensor + + TensorFillDiagonal(dst, Element(1), Element(0)); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Writes a uniform value to the diagonal of a tensor without modifying off-diagonal elements. 
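A minimal sketch of the diagonal fill helpers declared above; cutlass::HostTensor and the chosen extents and values are assumptions for illustration.

// Build a 64x64 identity matrix, and a matrix with 2 on the diagonal and -1
// everywhere else, using TensorFillIdentity and TensorFillDiagonal.
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/tensor_fill.h"

void fill_diagonal_example() {
  cutlass::HostTensor<float, cutlass::layout::RowMajor> I({64, 64});
  cutlass::HostTensor<float, cutlass::layout::RowMajor> T({64, 64});

  // Ones on the diagonal, zeros elsewhere.
  cutlass::reference::host::TensorFillIdentity(I.host_view());

  // 2 on the diagonal, -1 off the diagonal.
  cutlass::reference::host::TensorFillDiagonal(T.host_view(), 2.0f, -1.0f);
}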
+template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorUpdateDiagonal( + TensorView dst, ///< destination tensor + Element val = Element(1)) { + + typename Layout::Index extent = dst.extent().min(); + + for (typename Layout::Index i = 0; i < extent; ++i) { + Coord coord(i); + dst.at(coord) = val; + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorUpdateOffDiagonalFunc { + + using TensorView = TensorView; + + // + // Data members + // + + TensorView view; + Element other; + + // + // Methods + // + + TensorUpdateOffDiagonalFunc( + TensorView const &view_ = TensorView(), + Element other_ = Element(0) + ): + view(view_), other(other_) { } + + void operator()(Coord const & coord) const { + bool is_diag = true; + + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < Layout::kRank; ++i) { + if (coord[i] != coord[i - 1]) { + is_diag = false; + break; + } + } + + if (!is_diag) { + view.at(coord) = other; + } + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Writes a uniform value to all elements in the tensor without modifying diagonal elements. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorUpdateOffDiagonal( + TensorView dst, ///< destination tensor + Element other = Element(1)) { + + detail::TensorUpdateOffDiagonalFunc func( + dst, + other + ); + + TensorForEach( + dst.extent(), + func + ); +} + + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorFillLinearFunc { + + using TensorView = TensorView; + + // + // Data members + // + + TensorView view; + Array v; + Element s; + + // + // Methods + // + + TensorFillLinearFunc() { } + + /// Constructs functor + TensorFillLinearFunc( + TensorView const &view_, + Array const & v_, + Element s_ = Element(0) + ): + view(view_), v(v_), s(s_) { } + + /// Updates the tensor + void operator()(Coord const & coord) const { + + Element sum(s); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Layout::kRank; ++i) { + sum += Element(coord[i]) * v[i]; + } + + view.at(coord) = sum; + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills tensor with a linear combination of its coordinate and another vector +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillLinear( + TensorView dst, ///< destination tensor + Array const & v, + Element s = Element(0)) { + + detail::TensorFillLinearFunc func( + dst, + v, + s + ); + + TensorForEach( + dst.extent(), + func + ); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills tensor with a linear combination of its coordinate and another vector +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillSequential( + TensorView dst, ///< 
destination tensor + Element s = Element(0)) { + + Array stride; + + stride[0] = Element(1); + + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < Layout::kRank; ++i) { + stride[i] = stride[i - 1] * Element(dst.extent()[i - 1]); + } + + TensorFillLinear(dst, stride, s); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor with random values from a distribution. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillRandom( + TensorView view, ///< destination tensor + uint64_t seed, + Distribution dist, + bool exclude_zero = false ///< If true, excludes 0. + /// Note that setting this flag will result in more 1's, + /// as we use a simple mechanism to replace 0's by adding/subtracting 1's. +) { + + using Real = typename RealType::Type; + + if (dist.kind == Distribution::Gaussian) { + TensorFillRandomGaussian( + view, + seed, + dist.gaussian.mean, + dist.gaussian.stddev, + dist.int_scale, + dist.gaussian.pnz, + exclude_zero); + } else if (dist.kind == Distribution::Uniform) { + TensorFillRandomUniform( + view, + seed, + dist.uniform.max, + dist.uniform.min, + dist.int_scale, + dist.uniform.pnan, + exclude_zero); + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a block of data with sequential elements +template < + typename Element +> +void BlockFillSequential( + Element *ptr, + int64_t capacity, + Element v = Element(1), + Element s = Element(0)) { + int i = 0; + + while (i < capacity) { + cutlass::ReferenceFactory::value < + 8)>::get(ptr, i) = s; + + s = Element(s + v); + ++i; + } +} + +/// Fills a block of data with sequential elements +template < + typename Element +> +void BlockFillSequentialModN( + Element *ptr, + int64_t capacity, + int64_t mod, + int64_t v = int64_t(1), + int64_t s = int64_t(0)) { + int i = 0; + + while (i < capacity) { + cutlass::ReferenceFactory::value < + 8)>::get(ptr, i) = Element(s); + + s = int64_t(s + v) % mod; + ++i; + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a block of data with sequential elements +template < + typename Element +> +void BlockFillRandom( + Element *ptr, + size_t capacity, + uint64_t seed, + Distribution dist) { + + if (dist.kind == Distribution::Gaussian) { + BlockFillRandomGaussian( + ptr, + capacity, + seed, + dist.gaussian.mean, + dist.gaussian.stddev, + dist.int_scale, + dist.gaussian.pnz); + } + else if (dist.kind == Distribution::Uniform) { + BlockFillRandomUniform( + ptr, + capacity, + seed, + dist.uniform.max, + dist.uniform.min, + dist.int_scale, + dist.uniform.pnan); + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template +struct RandomSparseMetaFunc { + + uint64_t seed; + int range; + int MetaSizeInBits; + + // + // Methods + // + + RandomSparseMetaFunc( + uint64_t seed_ = 0, + int MetaSizeInBits_ = 2 + ): + seed(seed_), MetaSizeInBits(MetaSizeInBits_) { + 
std::srand((unsigned)seed); + if (MetaSizeInBits_ == 2) { + range = 6; + } + else if (MetaSizeInBits_ == 4) { + range = 2; + } + else { + throw std::invalid_argument("Invalid MetaSizeInBits"); + } + } + + /// Compute random value and update RNG state + Element operator()() const { + Element FourToTwoMeta[6] = {0x4, 0x8, 0x9, 0xc, 0xd, 0xe}; + Element TwoToOneMeta[2] = {0x4, 0xe}; + + Element * MetaArray = (MetaSizeInBits == 2) ? FourToTwoMeta : TwoToOneMeta; + + Element result = 0x0; + + for (int i = 0; i < cutlass::sizeof_bits::value / 4; ++i) { + int rnd = std::rand() % range; + Element meta = MetaArray[rnd]; + + result = (Element)(result | ((Element)(meta << (i * 4)))); + } + + return result; + } +}; + +/// Computes a random sparse meta +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +struct TensorFillRandomSparseMetaFunc { + + using TensorView = TensorView; + + // + // Data members + // + + TensorView view; + RandomSparseMetaFunc func; + + // + // Methods + // + + /// Construction of Gaussian RNG functor. + TensorFillRandomSparseMetaFunc( + TensorView view_ = TensorView(), + RandomSparseMetaFunc func_ = RandomSparseMetaFunc() + ): + view(view_), func(func_) { + + } + + /// Compute random value and update RNG state + void operator()(Coord const &coord) const { + + view.at(coord) = func(); + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor with random values with a uniform random distribution. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillRandomSparseMeta( + TensorView dst, ///< destination tensor + uint64_t seed, ///< seed for RNG + int MetaSizeInBits) { ///< 2 bit or 4 bit + + detail::RandomSparseMetaFunc random_func(seed, MetaSizeInBits); + + detail::TensorFillRandomSparseMetaFunc func( + dst, + random_func + ); + + TensorForEach( + dst.extent(), + func + ); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor with random values with a uniform random distribution. +template < + typename Element ///< Element type +> +void BlockFillRandomSparseMeta( + Element *ptr, + size_t capacity, + uint64_t seed, ///< seed for RNG + int MetaSizeInBits) { ///< 2 bit or 4bit + + detail::RandomSparseMetaFunc random_func(seed, MetaSizeInBits); + + for (size_t i = 0; i < capacity; ++i) { + ptr[i] = random_func(); + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a ell block index matrix with random values with a uniform random distribution. 
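A minimal sketch of generating random sparse metadata with the block helper declared above; the uint16_t storage type and buffer size are assumptions chosen for illustration.

// MetaSizeInBits = 2 selects the four-to-two patterns {0x4, 0x8, 0x9, 0xc, 0xd, 0xe};
// each 16-bit metadata word is assembled 4 bits at a time from that table.
#include <cstdint>
#include <vector>
#include "cutlass/util/reference/host/tensor_fill.h"

void fill_sparse_meta_example() {
  std::vector<uint16_t> meta(1024);

  cutlass::reference::host::BlockFillRandomSparseMeta(
      meta.data(), meta.size(), /*seed=*/11, /*MetaSizeInBits=*/2);
}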
+template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorFillRandomEllIdx( + TensorView dst, ///< destination tensor + uint64_t seed, ///< seed for RNG + int rows, int ell_cols, int cols) { ///< dimension of the matrix + + std::srand((unsigned)seed); + + for (int i = 0; i < rows; ++i) { + int col_idx = std::rand() % cols; + + for (int j = 0; j < ell_cols; ++j) { + dst.at({i, j}) = col_idx; + + if (col_idx != -1) { + if (col_idx == (cols - 1)) { + col_idx = -1; + } else { + col_idx = std::rand() % (cols - col_idx - 1) + col_idx + 1; + } + } + } + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Copies a diagonal in from host memory without modifying off-diagonal elements. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorCopyDiagonalIn( + TensorView dst, ///< destination tensor + Element const *ptr) { ///< dense buffer of elements + + typename Layout::Index extent = dst.extent().min(); + + for (typename Layout::Index i = 0; i < extent; ++i) { + Coord coord(i); + dst.at(coord) = ReferenceFactory::get(ptr, i); + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Copies the diagonal of a tensor into a dense buffer in host memory. +template < + typename Element, ///< Element type + typename Layout> ///< Layout function +void TensorCopyDiagonalOut( + Element *ptr, ///< dense buffer of elements + TensorView src) { ///< source tensor + + typename Layout::Index extent = src.extent().min(); + + for (typename Layout::Index i = 0; i < extent; ++i) { + Coord coord(i); + ReferenceFactory::get(ptr, i) = src.at(coord); + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_fill.hpp b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_fill.hpp new file mode 100644 index 0000000000000000000000000000000000000000..1b3df239a1b9d69fc12e7ec4be2de6f87b3a0e3c --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_fill.hpp @@ -0,0 +1,432 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Provides several functions for filling tensors with data. +*/ + +#pragma once + +// Standard Library includes +#include +#include +#include + +// Cute includes +#include "cute/tensor.hpp" + +// Cutlass includes +#include "cutlass/cutlass.h" +#include "cutlass/complex.h" +#include "cutlass/quaternion.h" +#include "cutlass/array.h" +#include "cutlass/numeric_types.h" + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace reference { +namespace host { + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Uniform and procedural tensor fills +// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor with a scalar element +template +void TensorFill(Tensor dst, typename Tensor::value_type element) { + + for (int64_t idx = 0; idx < cute::size(dst); ++idx) { + dst(idx) = element; + } +} + +/// Fills a tensor with the contents of its layout +template +void TensorFillSequential(Tensor dst) { + + auto layout = dst.layout(); + + for (int64_t idx = 0; idx < cute::size(dst); ++idx) { + dst(idx) = layout(idx); + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Random uniform values +// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template +struct RandomUniformFunc { + + using Real = typename RealType::Type; + + uint64_t seed; + double range; + double min; + int int_scale; + + // + // Methods + // + + RandomUniformFunc( + uint64_t seed_ = 0, + double max = 1, + double min_ = 0, + int int_scale_ = -1 + ): + seed(seed_), range(max - min_), min(min_), int_scale(int_scale_) { + std::srand((unsigned)seed); + } + + + /// Compute random value and update RNG state + Element operator()() const { + + double rnd = double(std::rand()) / double(RAND_MAX); + + rnd = min + range * rnd; + + // Random values are cast to integer after scaling by a power of two to facilitate error + // testing + Element result; + + if (int_scale >= 0) { + rnd = double(int64_t(rnd * double(1 << int_scale))) / double(1 << int_scale); + result = static_cast(Real(rnd)); + } + else { + result = static_cast(Real(rnd)); + } + + return result; + } +}; + +/// Partial specialization for initializing a 
complex value. +template +struct RandomUniformFunc > { + + using Real = typename RealType::Type; + + uint64_t seed; + double range; + double min; + int int_scale; + + // + // Methods + // + + RandomUniformFunc( + uint64_t seed_ = 0, + double max = 1, + double min_ = 0, + int int_scale_ = -1 + ): + seed(seed_), range(max - min_), min(min_), int_scale(int_scale_) { + std::srand((unsigned)seed); + } + + + /// Compute random value and update RNG state + complex operator()() const { + + Element reals[2]; + + for (int i = 0; i < 2; ++i) { + double rnd = double(std::rand()) / double(RAND_MAX); + + rnd = min + range * rnd; + + // Random values are cast to integer after scaling by a power of two to facilitate error + // testing + + if (int_scale >= 0) { + rnd = double(int(rnd * double(1 << int_scale))); + reals[i] = from_real(Real(rnd / double(1 << int_scale))); + } + else { + reals[i] = from_real(Real(rnd)); + } + } + + return complex(reals[0], reals[1]); + } +}; + +/// Partial specialization for initializing a Quaternion value. +template +struct RandomUniformFunc > { + + using Real = typename RealType::Type; + + uint64_t seed; + double range; + double min; + int int_scale; + + // + // Methods + // + + RandomUniformFunc( + uint64_t seed_ = 0, + double max = 1, + double min_ = 0, + int int_scale_ = -1 + ): + seed(seed_), range(max - min_), min(min_), int_scale(int_scale_) { + std::srand((unsigned)seed); + } + + + /// Compute random value and update RNG state + Quaternion operator()() const { + + Element reals[4]; + + for (int i = 0; i < 4; ++i) { + double rnd = double(std::rand()) / double(RAND_MAX); + + rnd = min + range * rnd; + + // Random values are cast to integer after scaling by a power of two to facilitate error + // testing + + if (int_scale >= 0) { + rnd = double(int(rnd * double(1 << int_scale))); + reals[i] = from_real(Real(rnd / double(1 << int_scale))); + } + else { + reals[i] = from_real(Real(rnd)); + } + } + + return make_Quaternion(reals[0], reals[1], reals[2], reals[3]); + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor with random values with a uniform random distribution. +template ///< Tensor object +void TensorFillRandomUniform( + Tensor dst, ///< destination tensor + uint64_t seed, ///< seed for RNG + double max = 1, ///< upper bound of distribution + double min = 0, ///< lower bound for distribution + int bits = -1) { ///< If non-negative, specifies number of fractional bits that + /// are not truncated to zero. Permits reducing precision of + /// data. + + detail::RandomUniformFunc random_func(seed, max, min, bits); + + for (int64_t idx = 0; idx < cute::size(dst); ++idx) { + dst(idx) = random_func(); + } +} + +/// Fills a block with random values with a uniform random distribution. +template < + typename Element ///< Element type +> +void BlockFillRandomUniform( + Element *ptr, + size_t capacity, + uint64_t seed, ///< seed for RNG + double max = 1, ///< upper bound of distribution + double min = 0, ///< lower bound for distribution + int bits = -1) { ///< If non-negative, specifies number of fractional bits that + /// are not truncated to zero. Permits reducing precision of + /// data. 
+ detail::RandomUniformFunc random_func(seed, max, min, bits); + + for (size_t i = 0; i < capacity; ++i) { + ptr[i] = random_func(); + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Random Gaussian +// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template +struct RandomGaussianFunc { + + uint64_t seed; + double mean; + double stddev; + int int_scale; + double pi; + + // + // Methods + // + RandomGaussianFunc( + uint64_t seed_ = 0, + double mean_ = 0, + double stddev_ = 1, + int int_scale_ = -1 + ): + seed(seed_), mean(mean_), stddev(stddev_), int_scale(int_scale_), pi(std::acos(-1)) { + std::srand((unsigned)seed); + } + + /// Compute random value and update RNG state + Element operator()() const { + + // Box-Muller transform to generate random numbers with Normal distribution + double u1 = double(std::rand()) / double(RAND_MAX); + double u2 = double(std::rand()) / double(RAND_MAX); + + // Compute Gaussian random value + double rnd = std::sqrt(-2 * std::log(u1)) * std::cos(2 * pi * u2); + rnd = mean + stddev * rnd; + + // Scale and convert final result + Element result; + + if (int_scale >= 0) { + rnd = double(int64_t(rnd * double(1 << int_scale))) / double(1 << int_scale); + result = static_cast(rnd); + } + else { + result = static_cast(rnd); + } + + return result; + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a tensor with random values with a Gaussian distribution. +template < + typename Tensor +> +void TensorFillRandomGaussian( + Tensor dst, ///< destination tensor + uint64_t seed, ///< seed for RNG + double mean = 0, ///< Gaussian distribution's mean + double stddev = 1, ///< Gaussian distribution's standard deviation + int bits = -1) { ///< If non-negative, specifies number of fractional bits that + /// are not truncated to zero. Permits reducing precision of + /// data. + + detail::RandomGaussianFunc random_func(seed, mean, stddev, bits); + + for (int64_t idx = 0; idx < cute::size(dst); ++idx) { + dst(idx) = random_func(); + } +} + +/// Fills a block with random values with a Gaussian distribution. +template < + typename Element ///< Element type +> +void BlockFillRandomGaussian( + Element *ptr, ///< destination buffer + size_t capacity, ///< number of elements + uint64_t seed, ///< seed for RNG + double mean = 0, ///< Gaussian distribution's mean + double stddev = 1, ///< Gaussian distribution's standard deviation + int bits = -1) { ///< If non-negative, specifies number of fractional bits that + /// are not truncated to zero. Permits reducing precision of + /// data. 
+ + detail::RandomGaussianFunc random_func(seed, mean, stddev, bits); + + for (size_t i = 0; i < capacity; ++i) { + ptr[i] = random_func(); + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Fills a block of data with sequential elements +template < + typename Element +> +void BlockFillSequential( + Element *ptr, + int64_t capacity, + Element v = Element(1), + Element s = Element(0)) { + int i = 0; + + while (i < capacity) { + + ptr[i] = Element(s + v); + ++i; + } +} + +/// Fills a block of data with sequential elements +template < + typename Element +> +void BlockFillSequentialModN( + Element *ptr, + int64_t capacity, + int64_t mod, + int64_t v = int64_t(1), + int64_t s = int64_t(0)) { + int i = 0; + + while (i < capacity) { + + ptr[i] = static_cast(int32_t(int64_t(s + v) % mod)); + ++i; + } +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_foreach.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_foreach.h new file mode 100644 index 0000000000000000000000000000000000000000..bcb1af995805e3fbcbdbf398ce7191ea2f0dbe8d --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_foreach.h @@ -0,0 +1,134 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include +#include "cutlass/cutlass.h" + +namespace cutlass { +namespace reference { +namespace host { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines several helpers +namespace detail { + +/// Helper to perform for-each operation +template +struct TensorForEachHelper { + + /// Index of the active rank + static int const kActiveRank = Rank - RankRemaining - 1; + + /// Constructor for general rank + TensorForEachHelper( + Func &func, + Coord const &extent, + Coord &coord) { + + for (int i = 0; i < extent.at(kActiveRank); ++i) { + coord[kActiveRank] = i; + TensorForEachHelper(func, extent, coord); + } + } +}; + +/// Helper to perform for-each operation +template +struct TensorForEachHelper { + + /// Index of the active rank + static int const kActiveRank = Rank - 1; + + /// Constructor for fastest changing rank + TensorForEachHelper( + Func &func, + Coord const &extent, + Coord &coord) { + + for (int i = 0; i < extent.at(kActiveRank); ++i) { + coord[kActiveRank] = i; + func(coord); + } + } +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Iterates over the index space of a tensor +template < + typename Func, ///< function applied to each point in a tensor's index space + int Rank> ///< rank of index space +void TensorForEach(Coord extent, Func & func) { + Coord coord; + detail::TensorForEachHelper(func, extent, coord); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Iterates over the index space of a tensor and calls a C++ lambda +template < + typename Func, ///< function applied to each point in a tensor's index space + int Rank> ///< rank of index space +void TensorForEachLambda(Coord extent, Func func) { + Coord coord; + detail::TensorForEachHelper(func, extent, coord); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct BlockForEach { + + /// Constructor performs the operation. + BlockForEach( + Element *ptr, + size_t capacity, + typename Func::Params params = typename Func::Params()) { + + Func func(params); + + for (size_t index = 0; index < capacity; ++index) { + ptr[index] = func(); + } + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_norm.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_norm.h new file mode 100644 index 0000000000000000000000000000000000000000..d44dda1f5472f13b7212f7e2e4020e254ff92f88 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_norm.h @@ -0,0 +1,42 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + + +#include "cutlass/cutlass.h" + +// The contents of this file have been moved to 'tensor_reduce' to cover other types of reductions. + +#include "cutlass/util/reference/host/tensor_reduce.h" + +/////////////////////////////////////////////////////////////////////////////////////////////////// + + diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_reduce.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_reduce.h new file mode 100644 index 0000000000000000000000000000000000000000..887c568059a90f749fc0ac75dd211ce77085a5a9 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_reduce.h @@ -0,0 +1,203 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/complex.h" +#include "cutlass/tensor_ref.h" + +#include "cutlass/util/reference/detail/linear_to_coordinate.h" +#include "cutlass/core_io.h" + +namespace cutlass { +namespace reference { +namespace host { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Transform-reduce operation over the elements of a tensor. This helper allocates the device-side +/// workspace +template < + typename Element, + typename Layout, + typename ComputeType, + typename ReduceOp, + typename TransformOp +> +ComputeType TensorTransformReduce( + TensorView view, + ComputeType identity, + ReduceOp reduce, + TransformOp transform +) { + + for (int64_t idx = 0; idx < int64_t(view.size()); ++idx) { + typename Layout::TensorCoord coord; + cutlass::reference::detail::LinearToCoordinate()(coord, idx, view.extent()); + + if (view.contains(coord)) { + Element x = view.at(coord); + identity = reduce(identity, transform(x)); + } + } + + return identity; +} + +/// Transform-reduce operation over the elements of a tensor. 
This helper allocates the device-side +/// workspace +template < + typename Element, + typename Layout, + typename ComputeType, + typename ReduceOp, + typename TransformOp +> +ComputeType TensorTransformReduce( + TensorView view_A, + TensorView view_B, + ComputeType identity, + ReduceOp reduce, + TransformOp transform) { + + if (view_A.extent() != view_B.extent()) { + throw std::runtime_error("Tensor extents must match."); + } + + for (int64_t idx = 0; idx < int64_t(view_A.size()); ++idx) { + + typename Layout::TensorCoord coord; + cutlass::reference::detail::LinearToCoordinate()(coord, idx, view_A.extent()); + + if (view_A.contains(coord)) { + Element a = view_A.at(coord); + Element b = view_B.at(coord); + identity = reduce(identity, transform(a, b)); + } + } + + return identity; +} + +/// Helper to compute the sum of the elements of a tensor +template < + typename Element, + typename Layout, + typename ComputeType = Element +> +ComputeType TensorSum( + TensorView view, + ComputeType identity = ComputeType() +) { + + plus reduce; + NumericConverter transform; + + return TensorTransformReduce( + view, identity, reduce, transform); +} + +/// Helper to compute the sum of the squares of the elements of a tensor +template < + typename Element, + typename Layout, + typename ComputeType = Element +> +ComputeType TensorSumSq( + TensorView view, + ComputeType identity = ComputeType() +) { + + plus reduce; + magnitude_squared transform; + + return TensorTransformReduce( + view, identity, reduce, transform); +} + +/// Helper to compute the norm of the elements of a tensor. +template < + typename Element, + typename Layout, + typename ComputeType = double +> +ComputeType TensorNorm( + TensorView view, + ComputeType identity = ComputeType() +) { + + return std::sqrt(TensorSumSq(view, identity)); +} + +/// Helper to compute the sum of the squares of the differences of two tensors +template < + typename Element, + typename Layout, + typename ComputeType = double +> +ComputeType TensorSumSqDiff( + TensorView view_A, + TensorView view_B, + ComputeType identity = ComputeType() +) { + + plus reduce; + magnitude_squared_difference transform; + + return TensorTransformReduce( + view_A, view_B, identity, reduce, transform); +} + + +/// Helper to compute the norm of the tensor computed as the difference of two tensors in memory +template < + typename Element, + typename Layout, + typename ComputeType = double +> +ComputeType TensorNormDiff( + TensorView view_A, + TensorView view_B, + ComputeType identity = ComputeType() +) { + + return std::sqrt(TensorSumSqDiff(view_A, view_B, identity)); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_reduce.hpp b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_reduce.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ea711466df86703aae1702605a928754c9f4e944 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/tensor_reduce.hpp @@ -0,0 +1,203 @@ 
+/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Provides several functions for filling tensors with data. +*/ + +#pragma once + +// Standard Library includes +#include +#include +#include + +// Cute includes +#include "cute/tensor.hpp" + +// Cutlass includes +#include "cutlass/cutlass.h" +#include "cutlass/complex.h" +#include "cutlass/functional.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/quaternion.h" +#include "cutlass/array.h" +#include "cutlass/numeric_types.h" + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace reference { +namespace host { + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Tensor reductions +// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Transform-reduce operation over the elements of a tensor. This helper allocates the device-side +/// workspace +template < + typename Tensor, + typename ComputeType, + typename ReduceOp, + typename TransformOp +> +ComputeType TensorTransformReduce( + Tensor view, + ComputeType identity, + ReduceOp reduce, + TransformOp transform +) { + + for (int64_t idx = 0; idx < cute::size(view); ++idx) { + identity = reduce(identity, transform(view(idx))); + } + + return identity; +} + +/// Transform-reduce operation over the elements of a tensor. 
This helper allocates the device-side +/// workspace +template < + typename TensorA, + typename TensorB, + typename ComputeType, + typename ReduceOp, + typename TransformOp +> +ComputeType TensorTransformReduce( + TensorA view_A, + TensorB view_B, + ComputeType identity, + ReduceOp reduce, + TransformOp transform) { + + if (cute::size(view_A) != cute::size(view_B)) { + throw std::runtime_error("Tensor sizes must match."); + } + + for (int64_t idx = 0; idx < cute::size(view_A); ++idx) { + identity = reduce(identity, transform(view_A(idx), view_B(idx))); + } + + return identity; +} + +/// Helper to compute the sum of the elements of a tensor +template < + typename Tensor, + typename ComputeType = typename Tensor::value_type +> +ComputeType TensorSum( + Tensor view, + ComputeType identity = ComputeType() +) { + + plus reduce; + NumericConverter transform; + + return TensorTransformReduce( + view, identity, reduce, transform); +} + +/// Helper to compute the sum of the squares of the elements of a tensor +template < + typename Tensor, + typename ComputeType = typename Tensor::value_type +> +ComputeType TensorSumSq( + Tensor view, + ComputeType identity = ComputeType() +) { + + plus reduce; + magnitude_squared transform; + + return TensorTransformReduce( + view, identity, reduce, transform); +} + +/// Helper to compute the norm of the elements of a tensor. +template < + typename Tensor, + typename ComputeType = double +> +ComputeType TensorNorm( + Tensor view, + ComputeType identity = ComputeType() +) { + + return std::sqrt(TensorSumSq(view, identity)); +} + +/// Helper to compute the sum of the squares of the differences of two tensors +template < + typename TensorA, + typename TensorB, + typename ComputeType = double +> +ComputeType TensorSumSqDiff( + TensorA view_A, + TensorB view_B, + ComputeType identity = ComputeType() +) { + + plus reduce; + magnitude_squared_difference transform; + + return TensorTransformReduce( + view_A, view_B, identity, reduce, transform); +} + + +/// Helper to compute the norm of the tensor computed as the difference of two tensors in memory +template < + typename TensorA, + typename TensorB, + typename ComputeType = double +> +ComputeType TensorNormDiff( + TensorA view_A, + TensorB view_B, + ComputeType identity = ComputeType() +) { + + return std::sqrt(TensorSumSqDiff(view_A, view_B, identity)); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass + +/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/trmm.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/trmm.h new file mode 100644 index 0000000000000000000000000000000000000000..09b1aff9c0ea9922af46c928a3dd61595be2e4cd --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/trmm.h @@ -0,0 +1,215 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for TRMM in host-side code. + + +*/ + +#pragma once + +#include "cutlass/blas3.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/tensor_view.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/arch/mma.h" +#include "cutlass/util/host_tensor.h" + +#include "cutlass/util/reference/host/gemm.h" + +namespace cutlass { +namespace reference { +namespace host { + +/// Computes a Triangular Matrix Multiplication (tensors of rank=2) pointed to by TensorRef +/// objects. +template < + typename ElementA, + typename LayoutA, + SideMode SideModeA, + FillMode FillModeA, + DiagType DiagTypeA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + typename InnerProductOp = multiply_add, + typename ConvertOp = NumericConverter +> +void compute_trmm( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, + TensorRef tensor_d, + ComputeType initial_accum) { + + static_assert( + LayoutA::kRank == 2 && + LayoutC::kRank == 2, "Tensors must be of rank 2"); + + static_assert(SideModeA != SideMode::kInvalid + , "Side Mode can either be Left or Right."); + + static_assert(FillModeA == FillMode::kLower || FillModeA == FillMode::kUpper + , "Fill Mode can either be Lower or Upper."); + + using CompareOp = typename TrMatrixCompareOp::Type; + + // Note: batch is ignored. 
+ int const M = problem_size.m(); + int const N = problem_size.n(); + // Assuming correct k-dimension value is passed + int const K = problem_size.k(); + + // Blocking necessary to speedup reference implementation + int const Mblock = 16; + int const Nblock = 16; + + ConvertOp convert_op; + InnerProductOp inner_product_op; + CompareOp compare_op; + + for (int row_block = 0; row_block < M; row_block += Mblock) { + for (int col_block = 0; col_block < N; col_block += Nblock) { + + ComputeType accum[Mblock][Nblock]; + + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + accum[i][j] = initial_accum; + } + } + + for (int k_block = 0; k_block < K; ++k_block) { + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + int row = row_block + i; + int col = col_block + j; + + if (row < M && col < N) { + ElementA a = ElementA(); + ElementB b = ElementB(); + + if (SideModeA == SideMode::kLeft) { + a = (compare_op(row, k_block)) ? + (tensor_a.at(MatrixCoord(row, k_block))) : ElementA(0); + if (row == k_block && DiagTypeA == DiagType::kUnit) { + a = ElementA(1); + } + b = tensor_b.at(MatrixCoord(k_block, col)); + } else if (SideModeA == SideMode::kRight) { + a = tensor_b.at(MatrixCoord(row, k_block)); + b = (compare_op(k_block, col)) ? + tensor_a.at(MatrixCoord(k_block, col)) : ElementA(0); + if (k_block == col && DiagTypeA == DiagType::kUnit) { + b = ElementA(1); + } + } + + ComputeType compute_a(cast_if_scalar(a)); + ComputeType compute_b(cast_if_scalar(b)); + + accum[i][j] = inner_product_op(compute_a, compute_b, accum[i][j]); + } + } + } + } + + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + int row = row_block + i; + int col = col_block + j; + + MatrixCoord coord = MatrixCoord(row, col); + + if (row < M && col < N) { + tensor_d.at(coord) = convert_op( + alpha * ScalarType(accum[i][j])); + } + } + } + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA, + typename LayoutA, + SideMode SideModeA, + FillMode FillModeA, + DiagType DiagTypeA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + typename InnerProductOp = cutlass::arch::OpMultiplyAdd +> +struct Trmm; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for multiply-add +template +struct Trmm { + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, + TensorRef tensor_d, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_trmm>( + problem_size, alpha, tensor_a, tensor_b, tensor_d, initial_accum); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/trmm_complex.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/trmm_complex.h new file mode 100644 index 0000000000000000000000000000000000000000..e8db2a4deaf8608882595d68e611f8ae79e134e8 --- /dev/null +++ 
b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/reference/host/trmm_complex.h @@ -0,0 +1,262 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Reference implementation for complex-valued TRMM in host-side code. + + +*/ + +#pragma once + +#include "cutlass/blas3.h" +#include "cutlass/complex.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/tensor_view.h" +#include "cutlass/gemm/gemm.h" + +#include "cutlass/util/reference/host/gemm.h" + +namespace cutlass { +namespace reference { +namespace host { + +/// Computes a Triangular Matrix Multiplication (tensors of rank=2) pointed to by TensorRef +/// objects. +template < + typename ElementA, + typename LayoutA, + ComplexTransform TransformA, + SideMode SideModeA, + FillMode FillModeA, + DiagType DiagTypeA, + typename ElementB, + typename LayoutB, + ComplexTransform TransformB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + typename InnerProductOp = multiply_add, + typename ConvertOp = NumericConverter +> +void compute_trmm_complex( + gemm::GemmCoord problem_size, + ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, + TensorRef tensor_d, + ComputeType initial_accum) { + + static_assert( + LayoutA::kRank == 2 && + LayoutC::kRank == 2, "Tensors must be of rank 2"); + + static_assert(SideModeA != SideMode::kInvalid + , "Side Mode can either be Left or Right."); + + static_assert(FillModeA == FillMode::kLower || FillModeA == FillMode::kUpper + , "Fill Mode can either be Lower or Upper."); + + using CompareOp = typename TrMatrixCompareOp::Type; + + // Note: batch is ignored. 
+ int const M = problem_size.m(); + int const N = problem_size.n(); + // Assuming correct k-dimension value is passed + int const K = problem_size.k(); + + // Blocking necessary to speedup reference implementation + int const Mblock = 16; + int const Nblock = 16; + + ConvertOp convert_op; + InnerProductOp inner_product_op; + CompareOp compare_op; + + for (int row_block = 0; row_block < M; row_block += Mblock) { + for (int col_block = 0; col_block < N; col_block += Nblock) { + + ComputeType accum[Mblock][Nblock]; + + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + accum[i][j] = initial_accum; + } + } + + for (int k_block = 0; k_block < K; ++k_block) { + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + int row = row_block + i; + int col = col_block + j; + + if (row < M && col < N) { + ElementA a = ElementA(); + ElementB b = ElementB(); + + if (SideModeA == SideMode::kLeft) { + a = (compare_op(row, k_block)) ? + (tensor_a.at(MatrixCoord(row, k_block))) : ElementA(0); + if (row == k_block && DiagTypeA == DiagType::kUnit) { + a = ElementA(1); + } + b = tensor_b.at(MatrixCoord(k_block, col)); + } else if (SideModeA == SideMode::kRight) { + a = tensor_b.at(MatrixCoord(row, k_block)); + b = (compare_op(k_block, col)) ? + tensor_a.at(MatrixCoord(k_block, col)) : ElementA(0); + if (k_block == col && DiagTypeA == DiagType::kUnit) { + b = ElementA(1); + } + } + + ComputeType a_ik = ComputeType(a); + ComputeType b_kj = ComputeType(b); + + // Conjugate, and hence hermitian, is only allowed for the triangular matrix + if (SideModeA == SideMode::kLeft && TransformA == ComplexTransform::kConjugate) { + a_ik = conj(a_ik); + } else if (SideModeA == SideMode::kRight && TransformA == ComplexTransform::kConjugate) { + b_kj = conj(b_kj); + } + + accum[i][j] = inner_product_op(a_ik, b_kj, accum[i][j]); + } + } + } + } + + for (int j = 0; j < Nblock; j++) { + for (int i = 0; i < Mblock; i++) { + int row = row_block + i; + int col = col_block + j; + + MatrixCoord coord = MatrixCoord(row, col); + + if (row < M && col < N) { + tensor_d.at(coord) = convert_op( + alpha * ScalarType(accum[i][j])); + } + } + } + } + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA, + typename LayoutA, + ComplexTransform TransformA, + SideMode SideModeA, + FillMode FillModeA, + DiagType DiagTypeA, + typename ElementB, + typename LayoutB, + ComplexTransform TransformB, + typename ElementC, + typename LayoutC, + typename ScalarType, + typename ComputeType, + typename InnerProductOp = cutlass::arch::OpMultiplyAddComplex +> +struct TrmmComplex; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for multiply-add +template +struct TrmmComplex { + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, + TensorRef tensor_d, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_trmm_complex>( + problem_size, alpha, tensor_a, tensor_b, tensor_d, initial_accum); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for gaussian multiply-add +template +struct TrmmComplex { + + void operator()(gemm::GemmCoord problem_size, ScalarType alpha, + TensorRef tensor_a, + TensorRef tensor_b, + 
TensorRef tensor_d, + ComputeType initial_accum = ComputeType(0)) { + static_assert( + LayoutA::kRank == 2 && LayoutC::kRank == 2, + "Tensors must be of rank 2"); + + compute_trmm_complex>( + problem_size, alpha, tensor_a, tensor_b, tensor_d, initial_accum); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace host +} // namespace reference +} // namespace cutlass diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/tensor_view_io.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/tensor_view_io.h new file mode 100644 index 0000000000000000000000000000000000000000..0ce1d8a65fdd66ace69f91525b678dd6ad132d24 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/tensor_view_io.h @@ -0,0 +1,270 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
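Taken together, compute_trmm and compute_trmm_complex above implement the BLAS-3 TRMM contract: the triangular operand multiplies B from the left or the right, only the selected triangle is read, a unit diagonal is substituted when DiagType::kUnit is set, and (in the complex case) conjugation is applied to the triangular operand only. A compact NumPy model of that contract, offered only as an illustrative sketch (NumPy, the helper name, and the sizes are assumptions):

```python
# Illustrative model of what compute_trmm / compute_trmm_complex produce.
import numpy as np

def trmm_reference(A, B, alpha, side='left', fill='lower',
                   unit_diag=False, conjugate=False):
    # Keep only the referenced triangle of A; the other triangle reads as zero.
    T = np.tril(A) if fill == 'lower' else np.triu(A)
    if unit_diag:
        # DiagType::kUnit: diagonal elements are implicitly 1, regardless of storage.
        np.fill_diagonal(T, 1)
    if conjugate:
        # ComplexTransform::kConjugate is applied to the triangular operand only.
        T = np.conj(T)
    # SideMode::kLeft computes alpha * op(A) @ B;
    # SideMode::kRight computes alpha * B @ op(A) (B is then m-by-k, A is k-by-k).
    return alpha * (T @ B) if side == 'left' else alpha * (B @ T)

A = np.random.randn(8, 8) + 1j * np.random.randn(8, 8)
B = np.random.randn(8, 5) + 1j * np.random.randn(8, 5)
D = trmm_reference(A, B, alpha=0.5, side='left', fill='lower',
                   unit_diag=True, conjugate=True)
print(D.shape)
```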
+* +**************************************************************************************************/ +#pragma once + +#include "cutlass/core_io.h" +#include "cutlass/tensor_view.h" +#include "cutlass/tensor_view_planar_complex.h" +#include "cutlass/complex.h" + +namespace cutlass { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +/// Helper to write the least significant rank of a TensorView +template < + typename Element, + typename Layout +> +inline std::ostream & TensorView_WriteLeastSignificantRank( + std::ostream& out, + TensorView const& view, + Coord const &start_coord, + int rank, + std::streamsize width) { + + for (int idx = 0; idx < view.extent(rank); ++idx) { + + Coord coord(start_coord); + coord[rank] = idx; + + if (idx) { + out.width(0); + out << ", "; + } + if (idx || coord) { + out.width(width); + } + out << ScalarIO(view.at(coord)); + } + + return out; +} + +/// Helper to write a rank of a TensorView +template < + typename Element, + typename Layout +> +inline std::ostream & TensorView_WriteRank( + std::ostream& out, + TensorView const& view, + Coord const &start_coord, + int rank, + std::streamsize width) { + + // If called on the least significant rank, write the result as a row + if (rank + 1 == Layout::kRank) { + return TensorView_WriteLeastSignificantRank(out, view, start_coord, rank, width); + } + + // Otherwise, write a sequence of rows and newlines + for (int idx = 0; idx < view.extent(rank); ++idx) { + + Coord coord(start_coord); + coord[rank] = idx; + + if (rank + 2 == Layout::kRank) { + // Write least significant ranks asa matrix with rows delimited by "\n" + if (idx) { + out << ",\n"; + } + TensorView_WriteLeastSignificantRank(out, view, coord, rank + 1, width); + } + else { + // Higher ranks are separated by newlines + if (idx) { + out << ",\n\n"; + } + TensorView_WriteRank(out, view, coord, rank + 1, width); + } + } + + return out; +} + +/// Helper to write the least significant rank of a TensorView +template < + typename Element, + typename Layout +> +inline std::ostream & TensorViewPlanarComplex_WriteLeastSignificantRank( + std::ostream& out, + TensorViewPlanarComplex const& view, + Coord const &start_coord, + int rank, + std::streamsize width) { + + for (int idx = 0; idx < view.extent(rank); ++idx) { + + Coord coord(start_coord); + coord[rank] = idx; + + if (idx) { + out.width(0); + out << ", "; + } + if (idx || coord) { + out.width(width); + } + + complex x = view.at(coord); + out << x; + } + + return out; +} + +/// Helper to write a rank of a TensorView +template < + typename Element, + typename Layout +> +inline std::ostream & TensorViewPlanarComplex_WriteRank( + std::ostream& out, + TensorViewPlanarComplex const& view, + Coord const &start_coord, + int rank, + std::streamsize width) { + + // If called on the least significant rank, write the result as a row + if (rank + 1 == Layout::kRank) { + return TensorViewPlanarComplex_WriteLeastSignificantRank(out, view, start_coord, rank, width); + } + + // Otherwise, write a sequence of rows and newlines + for (int idx = 0; idx < view.extent(rank); ++idx) { + + Coord coord(start_coord); + coord[rank] = idx; + + if (rank + 2 == Layout::kRank) { + // Write least significant ranks asa matrix with rows delimited by ";\n" + if (idx) { + out << ";\n"; + } + TensorViewPlanarComplex_WriteLeastSignificantRank(out, view, coord, rank + 1, width); + } + else { + // Higher ranks are separated by newlines + if (idx) { + out << "\n"; + } + 
TensorViewPlanarComplex_WriteRank(out, view, coord, rank + 1, width); + } + } + + return out; +} + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Prints human-readable representation of a TensorView to an ostream +template < + typename Element, + typename Layout +> +inline std::ostream& TensorViewWrite( + std::ostream& out, + TensorView const& view) { + + // Prints a TensorView according to the following conventions: + // - least significant rank is printed as rows separated by ";\n" + // - all greater ranks are delimited with newlines + // + // The result is effectively a whitespace-delimited series of 2D matrices. + + return detail::TensorView_WriteRank(out, view, Coord(), 0, out.width()); +} + +/// Prints human-readable representation of a TensorView to an ostream +template < + typename Element, + typename Layout +> +inline std::ostream& operator<<( + std::ostream& out, + TensorView const& view) { + + // Prints a TensorView according to the following conventions: + // - least significant rank is printed as rows separated by ";\n" + // - all greater ranks are delimited with newlines + // + // The result is effectively a whitespace-delimited series of 2D matrices. + + return TensorViewWrite(out, view); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Prints human-readable representation of a TensorView to an ostream +template < + typename Element, + typename Layout +> +inline std::ostream& TensorViewWrite( + std::ostream& out, + TensorViewPlanarComplex const& view) { + + // Prints a TensorView according to the following conventions: + // - least significant rank is printed as rows separated by ";\n" + // - all greater ranks are delimited with newlines + // + // The result is effectively a whitespace-delimited series of 2D matrices. + + return detail::TensorViewPlanarComplex_WriteRank(out, view, Coord(), 0, out.width()); +} + +/// Prints human-readable representation of a TensorView to an ostream +template < + typename Element, + typename Layout +> +inline std::ostream& operator<<( + std::ostream& out, + TensorViewPlanarComplex const& view) { + + // Prints a TensorView according to the following conventions: + // - least significant rank is printed as rows separated by ";\n" + // - all greater ranks are delimited with newlines + // + // The result is effectively a whitespace-delimited series of 2D matrices. + + return TensorViewWrite(out, view); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/type_traits.h b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/type_traits.h new file mode 100644 index 0000000000000000000000000000000000000000..5dfbfe274dec368cfac291a1c78ece6ffb203c72 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/include/cutlass/util/type_traits.h @@ -0,0 +1,238 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
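The writers in tensor_view_io.h above recurse over ranks: the least-significant rank is emitted as a comma-separated row, the second-to-last rank stacks those rows, and any higher ranks are separated by blank lines, so the output reads as a series of 2-D matrices. A compressed Python sketch of that recursion (nested lists stand in for a TensorView; field widths and the planar-complex variant are omitted):

```python
# Minimal sketch of the rank-recursive printing scheme used by
# TensorView_WriteRank: rows joined with ", ", matrices with ",\n",
# higher ranks with a blank line in between.
def write_rank(t):
    if not isinstance(t[0], list):                        # least-significant rank -> one row
        return ', '.join(str(v) for v in t)
    if not isinstance(t[0][0], list):                     # rank-2 block -> matrix of rows
        return ',\n'.join(write_rank(row) for row in t)
    return ',\n\n'.join(write_rank(sub) for sub in t)     # higher ranks -> blank-line separated

print(write_rank([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]))
```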
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Type traits for common CUDA types +*/ + +#pragma once + +#include +#include +#include + +#include "cutlass/numeric_types.h" +#include "cutlass/complex.h" + +namespace cutlass { +struct half_t; + +template +struct TypeTraits { + typedef T host_type; + typedef T device_type; + static inline T remove_negative_zero(T x) { return x; } + static inline T to_print(T x) { return x; } + static inline device_type to_device(host_type x) { return x; } +}; + +template <> +struct TypeTraits { + static cudaDataType_t const cublas_type = CUDA_R_8I; + typedef int8_t host_type; + typedef int8_t device_type; + typedef int8_t integer_type; + typedef uint8_t unsigned_type; + static inline int8_t remove_negative_zero(int8_t x) { return x; } + static inline int to_print(int8_t x) { return (int)x; } + static inline device_type to_device(host_type x) { return x; } +}; + +template <> +struct TypeTraits { + static cudaDataType_t const cublas_type = CUDA_R_8I; + typedef uint8_t host_type; + typedef uint8_t device_type; + typedef uint8_t integer_type; + typedef uint8_t unsigned_type; + static inline uint8_t remove_negative_zero(uint8_t x) { return x; } + static inline uint32_t to_print(uint8_t x) { return (uint32_t)x; } + static inline device_type to_device(host_type x) { return x; } +}; + +template <> +struct TypeTraits { + static cudaDataType_t const cublas_type = CUDA_R_32I; + typedef int host_type; + typedef int device_type; + typedef int32_t integer_type; + typedef uint32_t unsigned_type; + static inline int32_t remove_negative_zero(int32_t x) { return x; } + static inline int to_print(int x) { return x; } + static inline device_type to_device(host_type x) { return x; } +}; + +template <> +struct TypeTraits { + static cudaDataType_t const cublas_type = CUDA_R_32I; + typedef unsigned host_type; + typedef unsigned device_type; + typedef uint32_t 
integer_type; + typedef uint32_t unsigned_type; + static inline uint32_t remove_negative_zero(uint32_t x) { return x; } + static inline uint32_t to_print(uint32_t x) { return x; } + static inline device_type to_device(host_type x) { return x; } +}; + +template <> +struct TypeTraits { + static cudaDataType_t const cublas_type = CUDA_R_8I; + typedef int64_t host_type; + typedef int64_t device_type; + typedef int64_t integer_type; + typedef uint64_t unsigned_type; + static inline int64_t remove_negative_zero(int64_t x) { return x; } + static inline int64_t to_print(int64_t x) { return x; } + static inline device_type to_device(host_type x) { return x; } +}; + +template <> +struct TypeTraits { + static cudaDataType_t const cublas_type = CUDA_R_8I; + typedef uint64_t host_type; + typedef uint64_t device_type; + typedef uint64_t integer_type; + typedef uint64_t unsigned_type; + static inline uint64_t remove_negative_zero(uint64_t x) { return x; } + static inline uint64_t to_print(uint64_t x) { return x; } + static inline device_type to_device(host_type x) { return x; } +}; + +template <> +struct TypeTraits { + static cudaDataType_t const cublas_type = CUDA_R_16F; + typedef half_t host_type; + typedef half_t device_type; + typedef int16_t integer_type; + typedef uint16_t unsigned_type; + static inline half_t remove_negative_zero(half_t x) { + return (x.raw() == 0x8000 ? half_t::bitcast(0) : x); + } + static inline half_t to_print(half_t x) { return x; } + static inline device_type to_device(half_t x) { return reinterpret_cast(x); } +}; + +template <> +struct TypeTraits { + static cudaDataType_t const cublas_type = CUDA_R_32F; + typedef float host_type; + typedef float device_type; + typedef int32_t integer_type; + typedef uint32_t unsigned_type; + static inline float remove_negative_zero(float x) { return x == -0.f ? 0.f : x; } + static inline float to_print(float x) { return x; } + static inline device_type to_device(host_type x) { return x; } +}; + +template <> +struct TypeTraits { + static cudaDataType_t const cublas_type = CUDA_R_64F; + typedef double host_type; + typedef double device_type; + typedef int64_t integer_type; + typedef uint64_t unsigned_type; + static inline double remove_negative_zero(double x) { return x == -0.0 ? 0.0 : x; } + static inline double to_print(double x) { return x; } + static inline device_type to_device(host_type x) { return x; } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Complex types +// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct TypeTraits > { + static cudaDataType_t const cublas_type = CUDA_C_16F; + typedef complex host_type; + typedef complex device_type; + typedef int16_t integer_type; + typedef uint16_t unsigned_type; + static inline device_type to_device(complex x) { return reinterpret_cast(x); } +}; + +template <> +struct TypeTraits > { + static cudaDataType_t const cublas_type = CUDA_C_16F; + typedef complex host_type; + typedef complex device_type; + typedef int16_t integer_type; + typedef uint16_t unsigned_type; + static inline complex remove_negative_zero(complex x) { + return complex( + real(x) == -0_hf ? 0_hf : real(x), + imag(x) == -0_hf ? 
0_hf : imag(x) + ); + } + static inline complex to_print(complex x) { return x; } + static inline device_type to_device(complex x) { return reinterpret_cast(x); } +}; + +template <> +struct TypeTraits > { + + static cudaDataType_t const cublas_type = CUDA_C_32F; + typedef complex host_type; + typedef complex device_type; + typedef int64_t integer_type; + typedef uint64_t unsigned_type; + + static inline complex remove_negative_zero(complex x) { + return complex( + real(x) == -0.f ? 0.f : real(x), + imag(x) == -0.f ? 0.f : imag(x) + ); + } + + static inline complex to_print(complex x) { return x; } + static inline device_type to_device(complex x) { return reinterpret_cast(x); } +}; + +template <> +struct TypeTraits > { + static cudaDataType_t const cublas_type = CUDA_C_64F; + typedef complex host_type; + typedef complex device_type; + struct integer_type { int64_t real, imag; }; + struct unsigned_type { uint64_t real, imag; }; + static inline complex remove_negative_zero(complex x) { + return complex( + real(x) == -0.0 ? 0.0 : real(x), + imag(x) == -0.0 ? 0.0 : imag(x) + ); + } + static inline complex to_print(complex x) { return x; } + static inline device_type to_device(complex x) { return reinterpret_cast(x); } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass diff --git a/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/scripts/split_test_cmake.py b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/scripts/split_test_cmake.py new file mode 100644 index 0000000000000000000000000000000000000000..6541ce1b26722ff1f0dba0b4c034067a62f9b96d --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/include/third-party/cutlass/tools/util/scripts/split_test_cmake.py @@ -0,0 +1,356 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + + +""" +Given a set of test files to be included in a CMake target, this script extracts +the TEST definitions from each file, writes them into new files, and prints the names +of the new files so that they can be processed as part of a new CMake target. + +For example, given a set of --src_files test_a.cu test_b.cu containing 3 and 2 TEST +definitions, respectively, this script would produce: + test_a_000.cu + test_a_001.cu + test_a_002.cu + test_b_000.cu + test_b_001.cu + +The splitting follows a fairly rudimentary algorithm that does not support all valid C++ programs. +We walk through a given input test file line by line. Any lines that are not within a TEST definition is added to a running +"filler" text. When a TEST definition is encountered, the current filler text becomes the prefix +for that test. All subsequent lines are considered to be part of the TEST definition until the +number of starting function braces ('{') match the number of closing function braces ('}'). When +these counts are equal, the TEST definition is considered to be completed. At this point, we return +to adding lines to the "filler" text until a new TEST definition is encountered. Any "filler" text +following a TEST definition is added to the suffix of that TEST definition (this is useful for finishing +off #if statements, as is common in unit tests.). + +A state machine illustrating this algorithm at a high level is provided in the source below. + +Example: Suppose an input test `test.cu` has the following source: + // COPYRIGHT + #include + + #if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + + // Test #1 + TEST(SM90_a, 256x128x64_2x2x1) { + std::cout << "Test #1" << std::endl; + } + + // Test #2 + TEST(SM90_b, 256x128x64_1x1x1) { + std::cout << "Test #2" << std::endl; + } + + #endif defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +The contents of the two resulting test files will be: + $ cat test_000.cu + // COPYRIGHT + #include + + #if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + + // Test #1 + TEST(SM90_a, 256x128x64_2x2x1) { + std::cout << "Test #1" << std::endl; + } + + // Test #2 + + #endif defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + $ cat test_001.cu + // COPYRIGHT + #include + + #if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + + // Test #1 + + // Test #2 + TEST(SM90_b, 256x128x64_1x1x1) { + std::cout << "Test #2" << std::endl; + } + + #endif defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +Notice that each of test_000.cu and test_001.cu contain comments that appear outside +the TEST definitions not included in each file. This is by design, as these +would be considered "filler" text. + +As expected, some cases can't be handled. Below is a non-exhaustive list: + 1. New TEST following the closing '}' of a TEST case on the same line: + TEST(x, y) { + // Do stuff + } TEST(a, b) { + + In this case, "TEST(a, b) {" will be ignored + + 2. 
Preprocessor macros that occur midway through a test case and extend + beyond the conclusion of a testcase + + Example: + TEST(a, b) { + // Do stuff + #if X + // Do more stuff + } + #else + // Do other stuff + } + #endif +""" + + +import argparse +import enum +import os + + +parser = argparse.ArgumentParser() +parser.add_argument("cmake_target", type=str, + help="Name of the CMake target being generated.") +parser.add_argument("src_dir", type=str, + help="Path to the directory containing test files.") +parser.add_argument("--src_files", nargs='+', + help="Files containing TEST instances to split.") +parser.add_argument("--max_tests_per_file", type=int, default=1, + help="Maximum number of TEST instances per file.") +parser.add_argument("--dst_dir", type=str, + help="Path to the directory to which to write new test files. If not set, uses src_dir.") +args = parser.parse_args() + + +if args.dst_dir == None: + args.dst_dir = args.src_dir + + +class Testcase: + """ + Lightweight tracker of test-case processing status + """ + def __init__(self, prefix_text): + # Any text that preceded the TEST definition that was + # not part of another TEST definition + self.prefix = prefix_text + + # Any text within the TEST definition + self.test = "" + + # Any text that follows the completion of the TEST definition + # and is not included in other TEST definitions + self.suffix = "" + + # Whether the test's definition has concluded + self.completed = False + + # Current balance of opening and closing curly brackets in + # the TEST definition. '{' increments the count and '}' decrements it. + # A value of 0 (when self.completed == False) indicates that the test + # has completed. + self.curly_bracket_balance = 0 + + +class ParseState(enum.Enum): + """ + State machine for processing. + Transitions occur on each line encountered in the soruce file + + + Line does not contain 'TEST(' + +----+ + | | + | v 'TEST(' + +--------+ encountered +--------------------------+ + ------>| Filler | -----------------------> | TestDeclaredWaitingStart | + +--------+ +--------------------------+ + ^ | + Number of '{' | | First '{' encountered + equals number of | +--------+ | + '}' encountered +-----------| InTest | <------------------+ + +--------+ + | ^ + | | + +----+ + Number of '{' encountered + exceeds number of '}' encountered + """ + + + # Any text that is not part of a TEST case + Filler = 0 + + # Processing text within the first { of the TEST case + # and before the en of the final } of the TEST case + InTest = 1 + + # Processing text from the start of the TEST definition + # but before the first {. This could occur if the opening { + # occurs on a separate line than the TEST definition. + TestDeclaredWaitingStart = 2 + + +cmake_src_list = [] +for filename in args.src_files: + if '.' not in filename: + # Add any non-filename arguments to the command list by default + cmake_src_list.append(filename) + continue + + if '/' in filename: + raise Exception( + f"Source files passed to {__file__} must be within the same directory " + "as the CMakeLists defining the target using the files. " + f"Provided path {filename} is in a different directory.") + + full_filename = os.path.join(args.src_dir, filename) + with open(full_filename, 'r') as infile: + lines = infile.readlines() + + # Find the number of instances of "TEST(" + ntest = sum([1 for line in lines if "TEST(" in line]) + + if ntest <= args.max_tests_per_file: + # File contains fewer than max_tests_per_file TEST instances. 
It does + # not need to be split + cmake_src_list.append(filename) + continue + + # Current state of the parsing state machine. We start with filler text + state = ParseState.Filler + + # List of individual TESTs found + tests = [] + + # Ongoing text that is not included in a TEST definition. This will serve + # as the prefix for any yet-to-be encountered TEST definitions. + filler_text = "" + + def add_filler_text(text): + global filler_text + # Add new text to the ongoing filler text and to the suffixes of + # any completed tests + filler_text += text + for i in range(len(tests)): + if tests[i].completed: + tests[i].suffix += text + + for line in lines: + if state == ParseState.Filler: + # We are not currently within a TEST definition. + + if 'TEST(' in line: + # We have encountered a new TEST( case. Any text preceding this + # must be added to the filler text (e.g., if we have a line of the form: + # "static constexpr int Val = 4; TEST(blah) {" + # then "static constexpr int Val = 4;" needs to be included in filler + # text, as it could be used by subsequent tests.) + splits = line.split('TEST') + + # There should not be more than one TEST definition on a given line + assert len(splits) <= 2 + + if len(splits) > 1: + if not splits[0].isspace(): + # Only add text to filler if there are non-whitespace charcters + # preceding the TEST definition in the line + filler_text += splits[0] + + # The new line is just the TEST-related line + line = 'TEST' + splits[-1] + + # Add tests and transtion to TestDeclaredWaitingStart state. + # Do not add the line to the test text of the new test case; this + # will be done in either the TestDeclaredWaitingStart state processing + # below or in the InTest state processing below. + tests.append(Testcase(filler_text)) + state = ParseState.TestDeclaredWaitingStart + else: + # Any remaining filler text is added to the running filler_text + # which will be used as the prefix for any new tests, and to the + # suffix of any completed tests + add_filler_text(line) + + if state == ParseState.TestDeclaredWaitingStart: + # We have seen a TEST definition but have not yet seen its opening {. + + if '{' in line: + # The first curly bracket for the TEST definition has been found. + # Advance to state InTests. Do not add the line to the test's text + # or change the curly-brace balance of the test; these will be done + # when processing the state == ParseState.InTest condition below. + state = ParseState.InTest + else: + tests[-1].test += line + + if state == ParseState.InTest: + # We are currently within a TEST definition. + # Process lines character-by-character looking for opening and closing + # braces. If we reach parity between opening and closing braces, the + # test is considered done. + filler_text_to_add = "" + for char in line: + if not tests[-1].completed: + tests[-1].test += char + if char == '{': + tests[-1].curly_bracket_balance += 1 + elif char == '}': + tests[-1].curly_bracket_balance -= 1 + if tests[-1].curly_bracket_balance == 0: + tests[-1].completed = True + else: + filler_text_to_add += char + + if filler_text_to_add != "" and (not filler_text_to_add.isspace() or '\n' in filler_text_to_add): + add_filler_text('\n' + filler_text_to_add) + + if tests[-1].completed: + state = ParseState.Filler + + # Write out the new files for tests + filename_prefix, filename_suffix = filename.split('.') + for i, test in enumerate(tests): + assert test.completed + new_filename = filename_prefix + '_' + str(i).zfill(3) + '.' 
+ filename_suffix + full_new_filename = os.path.join(args.dst_dir, new_filename) + + # Replace any '\' with '/'. CMake doesn't like '\'. + full_new_filename = full_new_filename.replace('\\', '/') + + with open(full_new_filename, 'w') as outfile: + outfile.write(test.prefix + test.test + test.suffix) + cmake_src_list.append(full_new_filename) + + +for cmake_file in cmake_src_list: + print(cmake_file) diff --git a/build/torch210-cxx11-cu130-aarch64-linux/metadata.json b/build/torch210-cxx11-cu130-aarch64-linux/metadata.json new file mode 100644 index 0000000000000000000000000000000000000000..4899badb63d45293425e2164944268b6058af95d --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/metadata.json @@ -0,0 +1,11 @@ +{ + "version": 1, + "license": "MIT", + "python-depends": [], + "backend": { + "type": "cuda", + "archs": [ + "9.0a" + ] + } +} diff --git a/build/torch210-cxx11-cu130-aarch64-linux/testing/__init__.py b/build/torch210-cxx11-cu130-aarch64-linux/testing/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..13a9d78dea58a6492183f9ddc50f1510a679cbe6 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/testing/__init__.py @@ -0,0 +1,4 @@ +from . import bench, numeric, utils +from .bench import * +from .numeric import * +from .utils import * diff --git a/build/torch210-cxx11-cu130-aarch64-linux/testing/bench.py b/build/torch210-cxx11-cu130-aarch64-linux/testing/bench.py new file mode 100644 index 0000000000000000000000000000000000000000..2c752da2d3bb0aba7e03ef1921428432b396917a --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/testing/bench.py @@ -0,0 +1,137 @@ +import os +import sys +import torch + + +def bench(fn, num_warmups: int = 5, num_tests: int = 10, + high_precision: bool = False): + # Flush L2 cache with 256 MB data + torch.cuda.synchronize() + cache = torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda') + cache.zero_() + + # Warmup + for _ in range(num_warmups): + fn() + + # Add a large kernel to eliminate the CPU launch overhead + if high_precision: + x = torch.randn((8192, 8192), dtype=torch.float, device='cuda') + y = torch.randn((8192, 8192), dtype=torch.float, device='cuda') + x @ y + + # Testing + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record() + for i in range(num_tests): + fn() + end_event.record() + torch.cuda.synchronize() + + return start_event.elapsed_time(end_event) / num_tests / 1e3 + + +class empty_suppress: + def __enter__(self): + return self + + def __exit__(self, *_): + pass + + +class suppress_stdout_stderr: + def __enter__(self): + self.outnull_file = open(os.devnull, 'w') + self.errnull_file = open(os.devnull, 'w') + + self.old_stdout_fileno_undup = sys.stdout.fileno() + self.old_stderr_fileno_undup = sys.stderr.fileno() + + self.old_stdout_fileno = os.dup(sys.stdout.fileno()) + self.old_stderr_fileno = os.dup(sys.stderr.fileno()) + + self.old_stdout = sys.stdout + self.old_stderr = sys.stderr + + os.dup2(self.outnull_file.fileno(), self.old_stdout_fileno_undup) + os.dup2(self.errnull_file.fileno(), self.old_stderr_fileno_undup) + + sys.stdout = self.outnull_file + sys.stderr = self.errnull_file + return self + + def __exit__(self, *_): + sys.stdout = self.old_stdout + sys.stderr = self.old_stderr + + os.dup2(self.old_stdout_fileno, self.old_stdout_fileno_undup) + os.dup2(self.old_stderr_fileno, self.old_stderr_fileno_undup) + + os.close(self.old_stdout_fileno) + os.close(self.old_stderr_fileno) + + 
self.outnull_file.close() + self.errnull_file.close() + + +def bench_kineto(fn, kernel_names, num_tests: int = 30, + suppress_kineto_output: bool = False, + trace_path: str = None, flush_l2: bool = True, + with_multiple_kernels: bool = False): + assert isinstance(kernel_names, str) or isinstance(kernel_names, tuple) + is_tuple = isinstance(kernel_names, tuple) + + # Skip profiling + # Conflict with Nsight Systems, Nsight Compute and Compute Sanitizer + if int(os.environ.get('DG_USE_NVIDIA_TOOLS', 0)): + return (1, ) * len(kernel_names) if is_tuple else 1 + + # By default, flush L2 with an excessive 8 GB memset to give the GPU some (literal) chill time without full idle + flush_l2_size = int(8e9 // 4) + + # For some auto-tuning kernels with prints + fn() + + # Profile + suppress = suppress_stdout_stderr if suppress_kineto_output else empty_suppress + with suppress(): + schedule = torch.profiler.schedule(wait=1, warmup=0, active=1, repeat=1) + profiler = torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule) + with profiler: + for i in range(2): + for _ in range(num_tests): + if flush_l2: + torch.empty(flush_l2_size, dtype=torch.int, device='cuda').zero_() + fn() + profiler.step() + + # Parse the profiling table + prof_lines = profiler.key_averages().table(sort_by='cuda_time_total', max_name_column_width=100).split('\n') + kernel_names = (kernel_names, ) if isinstance(kernel_names, str) else kernel_names + if not with_multiple_kernels: + for name in kernel_names: + assert sum([name in line for line in prof_lines]) <= 1, f'Errors of the kernel {name} in the profiling table' + + # Save chrome traces + if trace_path is not None: + profiler.export_chrome_trace(trace_path) + + # Return average kernel times + units = {'ms': 1e3, 'us': 1e6} + kernel_times = [] + for name in kernel_names: + total_time = 0 + total_num = 0 + for line in prof_lines: + if name in line: + time_str = line.split()[-2] + num_str = line.split()[-1] + for unit, scale in units.items(): + if unit in time_str: + total_time += float(time_str.replace(unit, '')) / scale * int(num_str) + total_num += int(num_str) + break + kernel_times.append(total_time / total_num if total_num > 0 else 0) + + return tuple(kernel_times) if is_tuple else kernel_times[0] diff --git a/build/torch210-cxx11-cu130-aarch64-linux/testing/numeric.py b/build/torch210-cxx11-cu130-aarch64-linux/testing/numeric.py new file mode 100644 index 0000000000000000000000000000000000000000..a42c4318db47593c47a4ea89fbdbcb1ffb5cd30e --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/testing/numeric.py @@ -0,0 +1,21 @@ +import torch +from typing import Iterable + + +def calc_diff(x: torch.Tensor, y: torch.Tensor): + x, y = x.double(), y.double() + denominator = (x * x + y * y).sum() + if denominator == 0: # Which means that all elements in x and y are 0 + return 0.0 + sim = 2 * (x * y).sum() / denominator + return 1 - sim + + +def count_bytes(*tensors): + total = 0 + for t in tensors: + if isinstance(t, (tuple, list)): + total += count_bytes(*t) + elif t is not None: + total += t.numel() * t.element_size() + return total diff --git a/build/torch210-cxx11-cu130-aarch64-linux/testing/utils.py b/build/torch210-cxx11-cu130-aarch64-linux/testing/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..2d202d4192ed385f986ac5cc216acc69378d8ea9 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/testing/utils.py @@ -0,0 +1,38 @@ +import functools +import os +import torch +from typing import Callable + 
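The helpers above are the glue for micro-benchmarks and numeric checks: bench() times a callable with CUDA events after an L2-flushing memset and a warmup loop, returning the average seconds per call, while calc_diff() reports the symmetric relative error 1 - 2*sum(x*y)/sum(x*x + y*y), which is 0 for identical tensors. A hedged usage sketch, assuming the package is importable as `testing`, that a CUDA GPU is available, and that the matrix sizes are arbitrary:

```python
# Usage sketch for bench() and calc_diff(). The top-level import path
# `testing` is an assumption; adjust it to wherever this package lands.
import torch
from testing import bench, calc_diff

x = torch.randn(4096, 4096, device='cuda', dtype=torch.bfloat16)
y = torch.randn(4096, 4096, device='cuda', dtype=torch.bfloat16)

# Average seconds per matmul (CUDA-event timing; L2 is flushed once up front).
t = bench(lambda: x @ y)
print(f'{t * 1e6:.1f} us/iter')

# Symmetric relative error of the bf16 product against an fp32 reference.
ref = (x.float() @ y.float()).bfloat16()
print('diff vs. fp32 reference:', calc_diff(x @ y, ref))
```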
+def get_arch_major() -> int: + major, minor = torch.cuda.get_device_capability() + return major + + +def test_filter(condition: Callable): + def decorator(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + if condition(): + func(*args, **kwargs) + else: + print(f'{func.__name__}:') + print(f' > Filtered by {condition}') + print() + return wrapper + return decorator + + +def ignore_env(name: str, condition: Callable): + def decorator(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + if condition(): + saved = os.environ.pop(name, None) + func(*args, **kwargs) + if saved is not None: + os.environ[name] = saved + else: + func(*args, **kwargs) + + return wrapper + return decorator diff --git a/build/torch210-cxx11-cu130-aarch64-linux/utils/__init__.py b/build/torch210-cxx11-cu130-aarch64-linux/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e8f859a20726fcc0ea32c54ed8df37b19b3960a4 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/utils/__init__.py @@ -0,0 +1,3 @@ +from . import math, layout +from .layout import * +from .math import * diff --git a/build/torch210-cxx11-cu130-aarch64-linux/utils/layout.py b/build/torch210-cxx11-cu130-aarch64-linux/utils/layout.py new file mode 100644 index 0000000000000000000000000000000000000000..a6bc29d9aaae296a83b8c3546b832a083ade6b28 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/utils/layout.py @@ -0,0 +1,25 @@ +from .._ops import ops + + +def get_mk_alignment_for_contiguous_layout(): + return ops.get_mk_alignment_for_contiguous_layout() + + +def get_tma_aligned_size(mn: int, element_size: int): + return ops.get_tma_aligned_size(mn, element_size).item() + + +def get_mn_major_tma_aligned_tensor(sf): + return ops.get_mn_major_tma_aligned_tensor(sf) + + +def get_mn_major_tma_aligned_packed_ue8m0_tensor(sf): + return ops.get_mn_major_tma_aligned_packed_ue8m0_tensor(sf) + + +def get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(sf, ks_tensor, ks): + return ops.get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(sf, ks_tensor, ks) + + +get_m_alignment_for_contiguous_layout = get_mk_alignment_for_contiguous_layout +get_k_alignment_for_contiguous_layout = get_mk_alignment_for_contiguous_layout diff --git a/build/torch210-cxx11-cu130-aarch64-linux/utils/math.py b/build/torch210-cxx11-cu130-aarch64-linux/utils/math.py new file mode 100644 index 0000000000000000000000000000000000000000..c65026e54b87faf34b498d14d3f81a94759615f4 --- /dev/null +++ b/build/torch210-cxx11-cu130-aarch64-linux/utils/math.py @@ -0,0 +1,107 @@ +import torch +from typing import Tuple + + +def ceil_div(x: int, y: int) -> int: + return (x + y - 1) // y + + +def align(x: int, y: int) -> int: + return ceil_div(x, y) * y + + +def ceil_to_ue8m0(x: torch.Tensor): + assert x.view(-1).amax().item() > 0 + return torch.pow(2.0, torch.ceil(torch.log2(x.abs()))) + + +def per_token_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool, gran_k: int = 128) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.dim() == 2 + m, n = x.shape + padded_n = align(n, gran_k) + x_padded = torch.empty((m, padded_n), dtype=x.dtype, device=x.device).fill_(0) + x_padded[:, :n] = x + x_view = x_padded.view(m, -1, gran_k) + x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4) + sf = x_amax / 448.0 + sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf + return (x_view * (1.0 / sf.unsqueeze(2))).to(torch.float8_e4m3fn).view(m, padded_n)[:, :n].contiguous(), sf + + +def per_channel_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool, gran_k: 
int = 128) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.dim() == 2 and x.size(0) % gran_k == 0 + m, n = x.shape + x_view = x.view(-1, gran_k, n) + x_amax = x_view.abs().float().amax(dim=1).view(-1, n).clamp(1e-4) + sf = x_amax / 448.0 + sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf + return (x_view * (1.0 / sf.unsqueeze(1))).to(torch.float8_e4m3fn).view(m, n), sf + + +def per_block_cast_to_fp8(x: torch.Tensor, use_ue8m0: bool, gran_k: int = 128) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.dim() == 2 + m, n = x.shape + x_padded = torch.zeros((align(m, gran_k), align(n, gran_k)), dtype=x.dtype, device=x.device) + x_padded[:m, :n] = x + x_view = x_padded.view(-1, gran_k, x_padded.size(1) // gran_k, gran_k) + x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4) + sf = x_amax / 448.0 + sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf + x_scaled = (x_view * (1.0 / sf)).to(torch.float8_e4m3fn) + return x_scaled.view_as(x_padded)[:m, :n].contiguous(), sf.view(x_view.size(0), x_view.size(2)) + + +def per_custom_dims_cast_to_fp8(x: torch.Tensor, dims: Tuple, use_ue8m0: bool) -> Tuple[torch.Tensor, torch.Tensor]: + excluded_dims = tuple([i for i in range(x.dim()) if i not in set(dims)]) + x_amax = x.abs().float().amax(dim=excluded_dims, keepdim=True).clamp(1e-4) + sf = x_amax / 448.0 + sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf + x_scaled = (x * (1.0 / sf)).to(torch.float8_e4m3fn) + return x_scaled, sf.squeeze() + + +def _quantize_to_fp4_e2m1(x: torch.Tensor) -> torch.Tensor: + ax = x.abs().clamp_max(6.0) + # {0, 0.5, 1, 1.5, 2, 3, 4, 6} + # midpoints: 0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0 + boundaries = torch.tensor([0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0], + device=x.device, dtype=ax.dtype) + idx = torch.bucketize(ax, boundaries) + code = idx.to(torch.uint8) + sign = (x < 0) & (idx != 0) + code = code | (sign.to(torch.uint8) << 3) + return code # uint8, 0..15 + + +def per_token_cast_to_fp4(x: torch.Tensor, use_ue8m0: bool, gran_k: int = 128) -> Tuple[torch.Tensor, torch.Tensor]: + assert x.dim() == 2 + m, n = x.shape + assert n % 2 == 0 + padded_n = align(n, gran_k) + x_padded = torch.zeros((m, padded_n), dtype=x.dtype, device=x.device) + x_padded[:, :n] = x + x_view = x_padded.view(m, -1, gran_k) + x_amax = x_view.abs().float().amax(dim=2).clamp_min(1e-4) + sf = x_amax / 6.0 + sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf + x_scaled = x_view * (1.0 / sf.unsqueeze(2)) + codes = _quantize_to_fp4_e2m1(x_scaled).view(m, padded_n) # uint8, (m, padded_n) + codes2 = codes.view(m, padded_n // 2, 2) + packed = (codes2[:, :, 0] & 0x0F) | ((codes2[:, :, 1] & 0x0F) << 4) # uint8 + return packed[:, :n // 2].contiguous(), sf + + +def transpose_packed_fp4(a: torch.Tensor) -> torch.Tensor: + assert a.dtype == torch.uint8 + assert a.dim() == 2 + m, n2 = a.shape + n = n2 * 2 + assert (m % 2) == 0 + lo = a & 0x0F + hi = (a >> 4) & 0x0F + codes = torch.empty((m, n), device=a.device, dtype=torch.uint8) + codes[:, 0::2], codes[:, 1::2] = lo, hi + codes_t = codes.transpose(0, 1).contiguous() + codes2 = codes_t.view(n, m // 2, 2) + out = (codes2[:, :, 0] & 0x0F) | ((codes2[:, :, 1] & 0x0F) << 4) + return out.contiguous() \ No newline at end of file diff --git a/build/torch29-cxx11-cu126-aarch64-linux/__init__.py b/build/torch29-cxx11-cu126-aarch64-linux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9cd2069a1bc6bcd76e84fbe093d7f16d394efa05 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/__init__.py @@ -0,0 +1,684 @@ +import os +import subprocess +import 
torch + +# Import the compiled extension +from ._ops import ops +from . import utils + +__version__ = "2.3.0" + + +# Runtime + + +def set_num_sms(num_sms: int): + ops.set_num_sms(num_sms) + + +def get_num_sms() -> int: + return ops.get_num_sms() + + +def set_tc_util(tc_util: int): + ops.set_tc_util(tc_util) + + +def get_tc_util() -> int: + return ops.get_tc_util() + + +def get_mk_alignment_for_contiguous_layout() -> int: + return ops.get_mk_alignment_for_contiguous_layout() + + +# Layout utilities + + +def get_tma_aligned_size(mn: int, element_size: int) -> int: + return ops.get_tma_aligned_size(mn, element_size).item() + + +def get_mn_major_tma_aligned_tensor(sf): + return ops.get_mn_major_tma_aligned_tensor(sf) + + +def get_mn_major_tma_aligned_packed_ue8m0_tensor(sf): + return ops.get_mn_major_tma_aligned_packed_ue8m0_tensor(sf) + + +def get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor(sf, ks_tensor, ks): + ks_int = torch.tensor(ks, dtype=torch.int32, device="cpu") + return ops.get_k_grouped_mn_major_tma_aligned_packed_ue8m0_tensor( + sf, ks_tensor, ks_int + ) + + +def transform_sf_into_required_layout( + sf, + mn, + k, + recipe=None, + recipe_ab=None, + num_groups=None, + is_sfa=False, + disable_ue8m0_cast=False, +): + has_recipe = recipe is not None + r0, r1, r2 = recipe if has_recipe else (0, 0, 0) + has_recipe_ab = recipe_ab is not None + rab0, rab1 = recipe_ab if has_recipe_ab else (0, 0) + has_ng = num_groups is not None + ng = num_groups if has_ng else 0 + return ops.transform_sf_into_required_layout( + sf, + mn, + k, + r0, + r1, + r2, + has_recipe, + rab0, + rab1, + has_recipe_ab, + ng, + has_ng, + is_sfa, + disable_ue8m0_cast, + ) + + +# Aliases for contiguous layout alignment +get_m_alignment_for_contiguous_layout = get_mk_alignment_for_contiguous_layout +get_k_alignment_for_contiguous_layout = get_mk_alignment_for_contiguous_layout + + +# Helper to flatten recipe args + + +def _flatten_recipe(recipe, recipe_a=None, recipe_b=None): + has_recipe = recipe is not None + r0, r1, r2 = recipe if has_recipe else (0, 0, 0) + has_ra = recipe_a is not None + ra0, ra1 = recipe_a if has_ra else (0, 0) + has_rb = recipe_b is not None + rb0, rb1 = recipe_b if has_rb else (0, 0) + return r0, r1, r2, has_recipe, ra0, ra1, has_ra, rb0, rb1, has_rb + + +# FP8/FP4 GEMM ops + + +def fp8_fp4_gemm_nt( + a, + b, + d, + c=None, + recipe=None, + recipe_a=None, + recipe_b=None, + compiled_dims="nk", + disable_ue8m0_cast=False, +): + a_data, a_sf = a + b_data, b_sf = b + r0, r1, r2, hr, ra0, ra1, hra, rb0, rb1, hrb = _flatten_recipe( + recipe, recipe_a, recipe_b + ) + ops.fp8_fp4_gemm_nt( + a_data, + a_sf, + b_data, + b_sf, + d, + c, + r0, + r1, + r2, + hr, + ra0, + ra1, + hra, + rb0, + rb1, + hrb, + compiled_dims, + disable_ue8m0_cast, + ) + + +def fp8_fp4_gemm_nn( + a, + b, + d, + c=None, + recipe=None, + recipe_a=None, + recipe_b=None, + compiled_dims="nk", + disable_ue8m0_cast=False, +): + a_data, a_sf = a + b_data, b_sf = b + r0, r1, r2, hr, ra0, ra1, hra, rb0, rb1, hrb = _flatten_recipe( + recipe, recipe_a, recipe_b + ) + ops.fp8_fp4_gemm_nn( + a_data, + a_sf, + b_data, + b_sf, + d, + c, + r0, + r1, + r2, + hr, + ra0, + ra1, + hra, + rb0, + rb1, + hrb, + compiled_dims, + disable_ue8m0_cast, + ) + + +def fp8_fp4_gemm_tn( + a, + b, + d, + c=None, + recipe=None, + recipe_a=None, + recipe_b=None, + compiled_dims="mn", + disable_ue8m0_cast=False, +): + a_data, a_sf = a + b_data, b_sf = b + r0, r1, r2, hr, ra0, ra1, hra, rb0, rb1, hrb = _flatten_recipe( + recipe, recipe_a, recipe_b + ) + 
ops.fp8_fp4_gemm_tn( + a_data, + a_sf, + b_data, + b_sf, + d, + c, + r0, + r1, + r2, + hr, + ra0, + ra1, + hra, + rb0, + rb1, + hrb, + compiled_dims, + disable_ue8m0_cast, + ) + + +def fp8_fp4_gemm_tt( + a, + b, + d, + c=None, + recipe=None, + recipe_a=None, + recipe_b=None, + compiled_dims="mn", + disable_ue8m0_cast=False, +): + a_data, a_sf = a + b_data, b_sf = b + r0, r1, r2, hr, ra0, ra1, hra, rb0, rb1, hrb = _flatten_recipe( + recipe, recipe_a, recipe_b + ) + ops.fp8_fp4_gemm_tt( + a_data, + a_sf, + b_data, + b_sf, + d, + c, + r0, + r1, + r2, + hr, + ra0, + ra1, + hra, + rb0, + rb1, + hrb, + compiled_dims, + disable_ue8m0_cast, + ) + + +# FP8 aliases (same as FP8/FP4) +fp8_gemm_nt = fp8_fp4_gemm_nt +fp8_gemm_nn = fp8_fp4_gemm_nn +fp8_gemm_tn = fp8_fp4_gemm_tn +fp8_gemm_tt = fp8_fp4_gemm_tt + + +# M-grouped FP8/FP4 GEMM ops + + +def m_grouped_fp8_fp4_gemm_nt_contiguous( + a, + b, + d, + grouped_layout, + recipe=None, + recipe_a=None, + recipe_b=None, + compiled_dims="nk", + disable_ue8m0_cast=False, + use_psum_layout=False, + expected_m_for_psum_layout=None, +): + a_data, a_sf = a + b_data, b_sf = b + r0, r1, r2, hr, ra0, ra1, hra, rb0, rb1, hrb = _flatten_recipe( + recipe, recipe_a, recipe_b + ) + has_em = expected_m_for_psum_layout is not None + em = expected_m_for_psum_layout if has_em else 0 + ops.m_grouped_fp8_fp4_gemm_nt_contiguous( + a_data, + a_sf, + b_data, + b_sf, + d, + grouped_layout, + r0, + r1, + r2, + hr, + ra0, + ra1, + hra, + rb0, + rb1, + hrb, + compiled_dims, + disable_ue8m0_cast, + use_psum_layout, + em, + has_em, + ) + + +def m_grouped_fp8_fp4_gemm_nn_contiguous( + a, + b, + d, + grouped_layout, + recipe=None, + recipe_a=None, + recipe_b=None, + compiled_dims="nk", + disable_ue8m0_cast=False, + use_psum_layout=False, +): + a_data, a_sf = a + b_data, b_sf = b + r0, r1, r2, hr, ra0, ra1, hra, rb0, rb1, hrb = _flatten_recipe( + recipe, recipe_a, recipe_b + ) + ops.m_grouped_fp8_fp4_gemm_nn_contiguous( + a_data, + a_sf, + b_data, + b_sf, + d, + grouped_layout, + r0, + r1, + r2, + hr, + ra0, + ra1, + hra, + rb0, + rb1, + hrb, + compiled_dims, + disable_ue8m0_cast, + use_psum_layout, + ) + + +def m_grouped_fp8_fp4_gemm_nt_masked( + a, + b, + d, + masked_m, + expected_m, + recipe=None, + recipe_a=None, + recipe_b=None, + compiled_dims="nk", + disable_ue8m0_cast=False, +): + a_data, a_sf = a + b_data, b_sf = b + r0, r1, r2, hr, ra0, ra1, hra, rb0, rb1, hrb = _flatten_recipe( + recipe, recipe_a, recipe_b + ) + ops.m_grouped_fp8_fp4_gemm_nt_masked( + a_data, + a_sf, + b_data, + b_sf, + d, + masked_m, + expected_m, + r0, + r1, + r2, + hr, + ra0, + ra1, + hra, + rb0, + rb1, + hrb, + compiled_dims, + disable_ue8m0_cast, + ) + + +# M-grouped FP8 aliases +m_grouped_fp8_gemm_nt_contiguous = m_grouped_fp8_fp4_gemm_nt_contiguous +m_grouped_fp8_gemm_nn_contiguous = m_grouped_fp8_fp4_gemm_nn_contiguous +m_grouped_fp8_gemm_nt_masked = m_grouped_fp8_fp4_gemm_nt_masked + +# Legacy aliases +fp8_m_grouped_gemm_nt_masked = m_grouped_fp8_fp4_gemm_nt_masked + + +# K-grouped FP8 GEMM ops + + +def k_grouped_fp8_gemm_tn_contiguous( + a, b, d, ks, ks_tensor, c=None, recipe=(1, 1, 128), compiled_dims="mn" +): + a_data, a_sf = a + b_data, b_sf = b + r0, r1, r2 = recipe + ops.k_grouped_fp8_gemm_tn_contiguous( + a_data, a_sf, b_data, b_sf, d, ks_tensor, c, r0, r1, r2, compiled_dims + ) + + +def k_grouped_fp8_gemm_nt_contiguous( + a, b, d, ks, ks_tensor, c=None, recipe=(1, 1, 128), compiled_dims="mn" +): + a_data, a_sf = a + b_data, b_sf = b + r0, r1, r2 = recipe + 
ops.k_grouped_fp8_gemm_nt_contiguous( + a_data, a_sf, b_data, b_sf, d, ks_tensor, c, r0, r1, r2, compiled_dims + ) + + +# BF16 GEMM ops + + +def bf16_gemm_nt(a, b, d, c=None, compiled_dims="nk"): + ops.bf16_gemm_nt(a, b, d, c, compiled_dims) + + +def bf16_gemm_nn(a, b, d, c=None, compiled_dims="nk"): + ops.bf16_gemm_nn(a, b, d, c, compiled_dims) + + +def bf16_gemm_tn(a, b, d, c=None, compiled_dims="mn"): + ops.bf16_gemm_tn(a, b, d, c, compiled_dims) + + +def bf16_gemm_tt(a, b, d, c=None, compiled_dims="mn"): + ops.bf16_gemm_tt(a, b, d, c, compiled_dims) + + +# M-grouped BF16 GEMM ops + + +def m_grouped_bf16_gemm_nt_contiguous( + a, + b, + d, + grouped_layout, + compiled_dims="nk", + use_psum_layout=False, + expected_m_for_psum_layout=None, +): + has_em = expected_m_for_psum_layout is not None + em = expected_m_for_psum_layout if has_em else 0 + ops.m_grouped_bf16_gemm_nt_contiguous( + a, b, d, grouped_layout, compiled_dims, use_psum_layout, em, has_em + ) + + +def m_grouped_bf16_gemm_nn_contiguous( + a, b, d, grouped_layout, compiled_dims="nk", use_psum_layout=False +): + ops.m_grouped_bf16_gemm_nn_contiguous( + a, b, d, grouped_layout, compiled_dims, use_psum_layout + ) + + +def m_grouped_bf16_gemm_nt_masked(a, b, d, masked_m, expected_m, compiled_dims="nk"): + ops.m_grouped_bf16_gemm_nt_masked(a, b, d, masked_m, expected_m, compiled_dims) + + +# Legacy alias +bf16_m_grouped_gemm_nt_masked = m_grouped_bf16_gemm_nt_masked + + +# K-grouped BF16 GEMM ops + + +def k_grouped_bf16_gemm_tn_contiguous( + a, b, d, ks, ks_tensor, c=None, compiled_dims="mn" +): + ops.k_grouped_bf16_gemm_tn_contiguous(a, b, d, ks_tensor, c, compiled_dims) + + +# cuBLASLt GEMM ops + + +def cublaslt_gemm_nt(a, b, d, c=None): + ops.cublaslt_gemm_nt(a, b, d, c) + + +def cublaslt_gemm_nn(a, b, d, c=None): + ops.cublaslt_gemm_nn(a, b, d, c) + + +def cublaslt_gemm_tn(a, b, d, c=None): + ops.cublaslt_gemm_tn(a, b, d, c) + + +def cublaslt_gemm_tt(a, b, d, c=None): + ops.cublaslt_gemm_tt(a, b, d, c) + + +# Attention ops + + +def fp8_gemm_nt_skip_head_mid( + a, b, d, head_splits, recipe=None, compiled_dims="nk", disable_ue8m0_cast=False +): + a_data, a_sf = a + b_data, b_sf = b + left, mid, right = head_splits + has_recipe = recipe is not None + r0, r1, r2 = recipe if has_recipe else (0, 0, 0) + ops.fp8_gemm_nt_skip_head_mid( + a_data, + a_sf, + b_data, + b_sf, + d, + left, + mid, + right, + r0, + r1, + r2, + has_recipe, + compiled_dims, + disable_ue8m0_cast, + ) + + +def fp8_mqa_logits( + q, + kv, + weights, + cu_seq_len_k_start, + cu_seq_len_k_end, + clean_logits=True, + max_seqlen_k=0, +): + kv_data, kv_sf = kv + return ops.fp8_mqa_logits( + q, + kv_data, + kv_sf, + weights, + cu_seq_len_k_start, + cu_seq_len_k_end, + clean_logits, + max_seqlen_k, + ) + + +def get_paged_mqa_logits_metadata(context_lens, block_kv, num_sms): + return ops.get_paged_mqa_logits_metadata(context_lens, block_kv, num_sms) + + +def fp8_paged_mqa_logits( + q, + kv_cache, + weights, + context_lens, + block_table, + schedule_meta, + max_context_len, + clean_logits=False, +): + return ops.fp8_paged_mqa_logits( + q, + kv_cache, + weights, + context_lens, + block_table, + schedule_meta, + max_context_len, + clean_logits, + ) + + +# Einsum ops + + +def einsum(expr, a, b, d, c=None, use_cublaslt=False): + ops.einsum(expr, a, b, d, c, use_cublaslt) + + +def fp8_einsum(expr, a, b, d, c=None, recipe=(1, 128, 128)): + a_data, a_sf = a + b_data, b_sf = b + r0, r1, r2 = recipe + ops.fp8_einsum(expr, a_data, a_sf, b_data, b_sf, d, c, r0, r1, r2) + + +# 
Hyperconnection ops + + +def tf32_hc_prenorm_gemm(a, b, d, sqr_sum, num_splits=None): + has_ns = num_splits is not None + ns = num_splits if has_ns else 0 + ops.tf32_hc_prenorm_gemm(a, b, d, sqr_sum, ns, has_ns) + + +# Initialize the C++ runtime + + +def _find_cuda_home() -> str: + cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH") + if cuda_home is None: + try: + with open(os.devnull, "w") as devnull: + nvcc = ( + subprocess.check_output(["which", "nvcc"], stderr=devnull) + .decode() + .rstrip("\r\n") + ) + cuda_home = os.path.dirname(os.path.dirname(nvcc)) + except Exception: + cuda_home = "/usr/local/cuda" + if not os.path.exists(cuda_home): + cuda_home = None + assert cuda_home is not None, "Could not find CUDA installation" + return cuda_home + + +# Find the library root for JIT headers +# In development: use the repo's deep_gemm/ directory +# In installed wheel: use this package's directory +_lib_root = os.path.join( + os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "deep_gemm" +) +if not os.path.isdir(os.path.join(_lib_root, "include")): + # Fallback: try the parent package + _lib_root = os.path.dirname(os.path.abspath(__file__)) + +_initialized = False + +# Set DG_CUTLASS_INCLUDE for JIT kernel compilation (if not already set by user) +if "DG_CUTLASS_INCLUDE" not in os.environ: + _include = os.path.join(_lib_root, "include") + _cutlass_include_candidates = [ + _include, # legacy layout: include/cutlass + os.path.join(_include, "third-party", "cutlass", "include"), # submodule layout + ] + for _cutlass_include in _cutlass_include_candidates: + if os.path.isdir(os.path.join(_cutlass_include, "cutlass")): + os.environ["DG_CUTLASS_INCLUDE"] = _cutlass_include + break + else: + # Fall back to nvidia-cutlass pip package + try: + import nvidia.cutlass as _nc + os.environ["DG_CUTLASS_INCLUDE"] = os.path.join( + os.path.dirname(_nc.__file__), "include" + ) + except ImportError: + pass + +def _ensure_initialized(): + global _initialized + if _initialized: + return + _initialized = True + ops.init(_lib_root, _find_cuda_home()) + + +# Try to initialize eagerly, but don't fail if CUDA is not found +# (e.g., during build-time import checks). init() will be called +# lazily on first actual kernel use. +try: + _ensure_initialized() +except (AssertionError, RuntimeError): + pass diff --git a/build/torch29-cxx11-cu126-aarch64-linux/_deep_gemm_cuda_a68a39f.abi3.so b/build/torch29-cxx11-cu126-aarch64-linux/_deep_gemm_cuda_a68a39f.abi3.so new file mode 100644 index 0000000000000000000000000000000000000000..5a97e4a25f6189b2cdd8d3a3f542e5c8f2173510 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/_deep_gemm_cuda_a68a39f.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7e727962d2e9780eabead74d4b8d99037f9a4402802a824b15fe0136ec5ac94d +size 2825840 diff --git a/build/torch29-cxx11-cu126-aarch64-linux/_ops.py b/build/torch29-cxx11-cu126-aarch64-linux/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..a1598d972cf45498451b415d7a5869693e9b4a96 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _deep_gemm_cuda_a68a39f +ops = torch.ops._deep_gemm_cuda_a68a39f + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. 
+ """ + return f"_deep_gemm_cuda_a68a39f::{op_name}" diff --git a/build/torch29-cxx11-cu126-aarch64-linux/deep_gemm/__init__.py b/build/torch29-cxx11-cu126-aarch64-linux/deep_gemm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..03dbc1afe1cf156661a2b1b22003cd5f599a0309 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/deep_gemm/__init__.py @@ -0,0 +1,26 @@ +import ctypes +import sys + +import importlib +from pathlib import Path +from types import ModuleType + +def _import_from_path(file_path: Path) -> ModuleType: + # We cannot use the module name as-is, after adding it to `sys.modules`, + # it would also be used for other imports. So, we make a module name that + # depends on the path for it to be unique using the hex-encoded hash of + # the path. + path_hash = "{:x}".format(ctypes.c_size_t(hash(file_path.absolute())).value) + module_name = path_hash + spec = importlib.util.spec_from_file_location(module_name, file_path) + if spec is None: + raise ImportError(f"Cannot load spec for {module_name} from {file_path}") + module = importlib.util.module_from_spec(spec) + if module is None: + raise ImportError(f"Cannot load module {module_name} from spec") + sys.modules[module_name] = module + spec.loader.exec_module(module) # type: ignore + return module + + +globals().update(vars(_import_from_path(Path(__file__).parent.parent / "__init__.py"))) diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/deep_gemm/common/cute_tie.cuh b/build/torch29-cxx11-cu126-aarch64-linux/include/deep_gemm/common/cute_tie.cuh new file mode 100644 index 0000000000000000000000000000000000000000..cd2aace7a8b8dd642f4c149bfc974c3d21e5f5b5 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/deep_gemm/common/cute_tie.cuh @@ -0,0 +1,48 @@ +#pragma once + +namespace cute { + +struct ignore_t { + template + constexpr const ignore_t& operator=(T&&) const noexcept { + return *this; + } +}; + +inline constexpr ignore_t ignore{}; + +} // namespace cute + +#define CUTE_TIE_CONCAT_IMPL(A, B) A##B +#define CUTE_TIE_CONCAT(A, B) CUTE_TIE_CONCAT_IMPL(A, B) + +#define CUTE_TIE_GET_NTH_ARG(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, N, ...) N +#define CUTE_TIE_COUNT_ARGS(...) \ + CUTE_TIE_GET_NTH_ARG(__VA_ARGS__, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0) + +#define CUTE_TIE_OP_DECL(I, TUPLE, VAR) auto VAR = ::cute::get(TUPLE) +#define CUTE_TIE_OP_ASSIGN(I, TUPLE, VAR) VAR = ::cute::get(TUPLE) + +#define CUTE_TIE_APPLY_OP_1(OP, T, V1) OP(0, T, V1); +#define CUTE_TIE_APPLY_OP_2(OP, T, V1, V2) OP(0, T, V1); OP(1, T, V2); +#define CUTE_TIE_APPLY_OP_3(OP, T, V1, V2, V3) OP(0, T, V1); OP(1, T, V2); OP(2, T, V3); +#define CUTE_TIE_APPLY_OP_4(OP, T, V1, V2, V3, V4) OP(0, T, V1); OP(1, T, V2); OP(2, T, V3); OP(3, T, V4); +#define CUTE_TIE_APPLY_OP_5(OP, T, V1, V2, V3, V4, V5) OP(0, T, V1); OP(1, T, V2); OP(2, T, V3); OP(3, T, V4); OP(4, T, V5); + +#define CUTE_TIE_DECL(TUPLE_EXPR, ...) \ + auto&& CUTE_TIE_CONCAT(cute_tie__temp_tuple_, __LINE__) = (TUPLE_EXPR); \ + CUTE_TIE_CONCAT(CUTE_TIE_APPLY_OP_, CUTE_TIE_COUNT_ARGS(__VA_ARGS__)) ( \ + CUTE_TIE_OP_DECL, \ + CUTE_TIE_CONCAT(cute_tie__temp_tuple_, __LINE__), \ + __VA_ARGS__ \ + ) + +#define CUTE_TIE(TUPLE_EXPR, ...) 
\ + do { \ + auto&& CUTE_TIE_CONCAT(cute_tie__temp_tuple_, __LINE__) = (TUPLE_EXPR); \ + CUTE_TIE_CONCAT(CUTE_TIE_APPLY_OP_, CUTE_TIE_COUNT_ARGS(__VA_ARGS__)) ( \ + CUTE_TIE_OP_ASSIGN, \ + CUTE_TIE_CONCAT(cute_tie__temp_tuple_, __LINE__), \ + __VA_ARGS__ \ + ); \ + } while (0) diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/deep_gemm/common/epilogue_utils.cuh b/build/torch29-cxx11-cu126-aarch64-linux/include/deep_gemm/common/epilogue_utils.cuh new file mode 100644 index 0000000000000000000000000000000000000000..5f6a7a1956b64771ee5294035bead6bc2e3dd850 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/deep_gemm/common/epilogue_utils.cuh @@ -0,0 +1,27 @@ +#pragma once + +#include +#include + +namespace deep_gemm { + +struct EpilogueIdentity { + template + __device__ __forceinline__ static uint32_t apply_index_n(const uint32_t &n_idx) { + return n_idx; + } +}; + +template +struct EpilogueHeadSplits: EpilogueIdentity { + template + __device__ __forceinline__ static uint32_t apply_index_n(const uint32_t &n_idx) { + DG_STATIC_ASSERT(kLeft % STORE_BLOCK_N == 0 and kMid % STORE_BLOCK_N == 0 + and kRight % STORE_BLOCK_N == 0, "Invalid head splits config"); + return n_idx + (n_idx + kRight) / (kLeft + kRight) * kMid; + } +}; + +#pragma clang diagnostic pop + +} // namespace deep_gemm diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/deep_gemm/common/reduction.cuh b/build/torch29-cxx11-cu126-aarch64-linux/include/deep_gemm/common/reduction.cuh new file mode 100644 index 0000000000000000000000000000000000000000..d9e35f73128427d4c842c5382798c3866c27124f --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/deep_gemm/common/reduction.cuh @@ -0,0 +1,44 @@ +#pragma once + +#include +#include +#include +#include + +#include + +// Operation functors +template struct ReduceSum { __device__ T operator()(T a, T b) const { return a + b; } }; +template struct ReduceMax { __device__ T operator()(T a, T b) const { return a > b ? a : b; } }; +template struct ReduceMin { __device__ T operator()(T a, T b) const { return a < b ? 
a : b; } }; +template struct ReduceAnd { __device__ T operator()(T a, T b) const { return a & b; } }; +template struct ReduceOr { __device__ T operator()(T a, T b) const { return a | b; } }; + +// Unified reduction function +template +__forceinline__ __device__ T warp_reduce(T value, Op op) { + DG_STATIC_ASSERT(kNumLanesPerGroup == 32 or kNumLanesPerGroup == 16 or kNumLanesPerGroup == 8 or + kNumLanesPerGroup == 4 or kNumLanesPerGroup == 2 or kNumLanesPerGroup == 1, + "Invalid number of lanes"); + constexpr uint32_t mask = 0xffffffff; + if constexpr (kIntergroupReduce) { + if constexpr (kNumLanesPerGroup <= 1) value = op(value, __shfl_xor_sync(mask, value, 1)); + if constexpr (kNumLanesPerGroup <= 2) value = op(value, __shfl_xor_sync(mask, value, 2)); + if constexpr (kNumLanesPerGroup <= 4) value = op(value, __shfl_xor_sync(mask, value, 4)); + if constexpr (kNumLanesPerGroup <= 8) value = op(value, __shfl_xor_sync(mask, value, 8)); + if constexpr (kNumLanesPerGroup <= 16) value = op(value, __shfl_xor_sync(mask, value, 16)); + } else { + if constexpr (kNumLanesPerGroup >= 32) value = op(value, __shfl_xor_sync(mask, value, 16)); + if constexpr (kNumLanesPerGroup >= 16) value = op(value, __shfl_xor_sync(mask, value, 8)); + if constexpr (kNumLanesPerGroup >= 8) value = op(value, __shfl_xor_sync(mask, value, 4)); + if constexpr (kNumLanesPerGroup >= 4) value = op(value, __shfl_xor_sync(mask, value, 2)); + if constexpr (kNumLanesPerGroup >= 2) value = op(value, __shfl_xor_sync(mask, value, 1)); + } + return value; +} + +// Convenience aliases +template +__forceinline__ __device__ T warp_reduce_sum(T value) { + return warp_reduce(value, ReduceSum{}); +} diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/deep_gemm/common/scheduler.cuh b/build/torch29-cxx11-cu126-aarch64-linux/include/deep_gemm/common/scheduler.cuh new file mode 100644 index 0000000000000000000000000000000000000000..f93b96ee64b934cd24e2567dd2b6b54be913408b --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/deep_gemm/common/scheduler.cuh @@ -0,0 +1,288 @@ +#pragma once + +#include +#include + +namespace deep_gemm { + +enum class IndexType { + MN, + K, + SF_K, +}; + +template +static constexpr uint32_t get_num_1d_blocks_per_group() { + // Select the best from candidates + uint32_t num_best_blocks = 0, min_usage = cute::numeric_limits::max(); + for (const auto& candidate: {8u, 16u}) { + const auto& usage = kIsMulticastOnA ? 
+ candidate * BLOCK_N + constexpr_ceil_div(kNumSMs, candidate) * BLOCK_M: // Grouping on N + candidate * BLOCK_M + constexpr_ceil_div(kNumSMs, candidate) * BLOCK_N; // Grouping on M + if (usage < min_usage) + min_usage = usage, num_best_blocks = candidate; + } + return num_best_blocks; +} + +#pragma clang diagnostic push +#pragma ide diagnostic ignored "cppcoreguidelines-pro-type-member-init" +template ()> +struct Scheduler { + int current_iter = -1; + + // Block configs + uint32_t num_blocks; + uint32_t num_m_blocks; + uint32_t num_n_blocks; + + // For SM90 multicast checks + uint32_t num_blocks_in_group; + bool is_peer_cta_alive = true; + + // For grouped GEMM + int* grouped_layout; + uint32_t current_group_idx = 0; + // Only used for masked layout + uint32_t current_m_cumsum = 0; + // Only used for countiguous psum layout + uint32_t last_psum_m = 0, current_psum_m, current_m_block_cumsum = 0; + // Only used for k-grouped layout + uint32_t current_shape_k, current_num_valid_groups = 0, current_k_cumsum = 0, current_sf_k_cumsum = 0; + uint32_t next_group_idx, next_shape_k; + + // Only used for k-grouped gemm + __device__ __forceinline__ void get_next_k_group(uint32_t &group_idx, uint32_t &shape_k) const { + for (; group_idx < kNumGroups; ++ group_idx) { + shape_k = __ldg(grouped_layout + group_idx); + if (shape_k > 0) + break; + } + } + + // ReSharper disable once CppPossiblyUninitializedMember + __device__ __forceinline__ explicit Scheduler(const uint32_t& shape_m, const uint32_t& shape_n, const uint32_t& shape_k, + int* grouped_layout = nullptr) { + num_m_blocks = ceil_div(shape_m, BLOCK_M); + num_n_blocks = ceil_div(shape_n, BLOCK_N); + current_shape_k = shape_k; + if constexpr (kGemmType == GemmType::Normal or kGemmType == GemmType::Batched) { + num_blocks = num_m_blocks * num_n_blocks; + } else if constexpr (kGemmType == GemmType::MGroupedContiguous) { + num_blocks = num_m_blocks * num_n_blocks; + this->grouped_layout = grouped_layout; + } else if constexpr (kGemmType == GemmType::MGroupedMasked) { + this->grouped_layout = grouped_layout; + } else if constexpr (kGemmType == GemmType::MGroupedContiguousWithPsumLayout) { + this->grouped_layout = grouped_layout; + current_psum_m = __ldg(grouped_layout); + num_m_blocks = ceil_div(current_psum_m, BLOCK_M); + } else if constexpr (kGemmType == GemmType::KGroupedContiguous) { + this->grouped_layout = grouped_layout; + get_next_k_group(current_group_idx, current_shape_k); + next_group_idx = current_group_idx + 1; + get_next_k_group(next_group_idx, next_shape_k); + } + } + + __device__ __forceinline__ void get_swizzled_block_idx(const uint32_t& block_idx, uint32_t& m_block_idx, uint32_t& n_block_idx) { + DG_STATIC_ASSERT(kNum1DBlocksPerGroup % kNumMulticast == 0, "Invalid group size"); + + // Swizzle for better L2 usages + const auto& primary_num_blocks = kIsMulticastOnA ? num_n_blocks : num_m_blocks; + const auto& secondary_num_blocks = kIsMulticastOnA ? 
num_m_blocks : num_n_blocks; + const auto& num_blocks_per_group = secondary_num_blocks * kNum1DBlocksPerGroup; + const auto& group_idx = block_idx / num_blocks_per_group; + auto first_block_idx = group_idx * kNum1DBlocksPerGroup; + auto in_group_idx = block_idx % num_blocks_per_group; + num_blocks_in_group = min(kNum1DBlocksPerGroup, primary_num_blocks - first_block_idx); + + // Fix unaligned TMA multicast + // NOTES: for SM90 only, as SM90 can dynamically disable TMA multicast + // while SM100 uses 2-CTA, which can not be dynamically disabled +#if __CUDA_ARCH__ < 1000 + if (kNumMulticast > 1 and num_blocks_in_group % 2 != 0) { + if (in_group_idx < (num_blocks_in_group ^ 1) * secondary_num_blocks) { + num_blocks_in_group = num_blocks_in_group ^ 1; + } else { + in_group_idx = in_group_idx - (num_blocks_in_group ^ 1) * secondary_num_blocks; + first_block_idx += num_blocks_in_group ^ 1; + num_blocks_in_group = 1; + } + } +#endif + + // Convert to final M/N block indices + // `kIsMulticastOnA == true` leads to groups on N + if constexpr (kIsMulticastOnA) { + m_block_idx = in_group_idx / num_blocks_in_group; + n_block_idx = first_block_idx + in_group_idx % num_blocks_in_group; + } else { + m_block_idx = first_block_idx + in_group_idx % num_blocks_in_group; + n_block_idx = in_group_idx / num_blocks_in_group; + } + } + + template + __device__ __forceinline__ uint32_t get_global_idx(const uint32_t shape_dim, const uint32_t block_size, + const uint32_t& block_idx, const uint32_t& m_block_idx = 0) { + if constexpr (kGemmType == GemmType::Normal) { + return block_idx * block_size; + } else if constexpr (kGemmType == GemmType::MGroupedContiguous) { + const auto offset = kWithGroupOffset ? cute::max(0, __ldg(grouped_layout + m_block_idx * BLOCK_M)) : 0; + return offset * shape_dim + block_idx * block_size; + } else if constexpr (kGemmType == GemmType::MGroupedMasked or kGemmType == GemmType::MGroupedContiguousWithPsumLayout) { + const auto offset = kWithGroupOffset ? current_group_idx : 0; + return offset * shape_dim + block_idx * block_size; + } else if constexpr (kGemmType == GemmType::KGroupedContiguous) { + auto offset = 0; + if constexpr (kWithGroupOffset) { + if constexpr (kIndexType == IndexType::MN) + offset = current_group_idx * shape_dim; + else if constexpr (kIndexType == IndexType::K) + offset = current_k_cumsum; + else if constexpr (kIndexType == IndexType::SF_K) + offset = current_sf_k_cumsum; + } + return offset + block_idx * block_size; + } else if constexpr (kGemmType == GemmType::Batched) { + // Ignore kWithGroupOffset, and apply offset for IndexType::SF_K + const auto offset = kIndexType == IndexType::SF_K ? 
current_group_idx : 0; + return offset * shape_dim + block_idx * block_size; + } + } + + __device__ __forceinline__ bool get_next_block(uint32_t& m_block_idx, uint32_t& n_block_idx) { + const auto next_block_idx = (++ current_iter) * kNumSMs + blockIdx.x; + + if constexpr (kGemmType == GemmType::MGroupedMasked) { + while (true) { + // End of the task + if (current_group_idx == kNumGroups) + return false; + + // Within current group + num_m_blocks = ceil_div(static_cast(__ldg(grouped_layout + current_group_idx)), BLOCK_M); + const auto current_m_block_cumsum = current_m_cumsum + num_m_blocks; + if (next_block_idx < current_m_block_cumsum * num_n_blocks) + break; + + // Move to check the next group + current_group_idx ++, current_m_cumsum = current_m_block_cumsum; + } + + get_swizzled_block_idx(next_block_idx - current_m_cumsum * num_n_blocks, m_block_idx, n_block_idx); + } else if constexpr (kGemmType == GemmType::MGroupedContiguousWithPsumLayout) { + while (true) { + // Within current group + if (next_block_idx < (current_m_block_cumsum + num_m_blocks) * num_n_blocks) + break; + + // Move to check the next group + if (++ current_group_idx == kNumGroups) + return false; + + // NOTES: `num_m_blocks` varies with the increase of the group index + last_psum_m = align(current_psum_m, 128u); + current_psum_m = __ldg(grouped_layout + current_group_idx); + current_m_block_cumsum += num_m_blocks; + num_m_blocks = ceil_div(current_psum_m - last_psum_m, BLOCK_M); + } + + get_swizzled_block_idx(next_block_idx - current_m_block_cumsum * num_n_blocks, m_block_idx, n_block_idx); + + // NOTES: `last_psum_m` is aligned with 128 + m_block_idx += last_psum_m / BLOCK_M; + DG_STATIC_ASSERT(128 % BLOCK_M == 0, "Invalid BLOCK_M"); + } else if constexpr (kGemmType == GemmType::KGroupedContiguous) { + while (true) { + // End of the task + if (current_group_idx == kNumGroups) + return false; + + // Within current group + if (next_block_idx < (current_num_valid_groups + 1) * num_m_blocks * num_n_blocks) + break; + + // Move to check the next group + current_k_cumsum += current_shape_k; + current_sf_k_cumsum += ceil_div(current_shape_k, SF_K_ALIGNMENT); + current_num_valid_groups ++; + + current_group_idx = next_group_idx ++; + current_shape_k = next_shape_k; + get_next_k_group(next_group_idx, next_shape_k); + } + + get_swizzled_block_idx(next_block_idx - current_num_valid_groups * num_m_blocks * num_n_blocks, m_block_idx, n_block_idx); + } else if constexpr (kGemmType == GemmType::Batched) { + if (next_block_idx >= num_blocks * kNumGroups) + return false; + + current_group_idx = next_block_idx / num_blocks; + const auto& block_idx = next_block_idx - current_group_idx * num_blocks; + if constexpr (kIsMulticastOnA) { + m_block_idx = block_idx / num_n_blocks; + n_block_idx = block_idx % num_n_blocks; + } else { + m_block_idx = block_idx % num_m_blocks; + n_block_idx = block_idx / num_m_blocks; + } + } else { + if (next_block_idx >= num_blocks) + return false; + + // For SM90 only + // NOTES: we don't have to set `is_peer_cta_alive` for masked grouped GEMM, as it must be aligned + is_peer_cta_alive = num_n_blocks % kNumMulticast == 0 or // Always aligned on N (constant bypass) + num_m_blocks % kNumMulticast == 0 or // Always aligned on M (constant bypass) + (next_block_idx ^ 1) < num_blocks; // Peer CTA in bound + get_swizzled_block_idx(next_block_idx, m_block_idx, n_block_idx); + } + return true; + } + + // For SM90 only + __device__ __forceinline__ bool is_tma_multicast_valid(const uint32_t& m_block_idx) const { + if 
(num_blocks_in_group == 1) + return false; + if constexpr (kGemmType == GemmType::Normal or kGemmType == GemmType::MGroupedMasked or + kGemmType == GemmType::KGroupedContiguous or kGemmType == GemmType::Batched) { + return true; + } else { + DG_STATIC_ASSERT(kGemmType == GemmType::MGroupedContiguous, "Invalid Gemm type"); + if constexpr (kIsMulticastOnA) { + return true; + } else { + const auto& group_idx = __ldg(grouped_layout + m_block_idx * BLOCK_M); + const auto& peer_group_idx = __ldg(grouped_layout + (m_block_idx ^ 1) * BLOCK_M); + return group_idx == peer_group_idx; + } + } + } + + // For SM90 only + // ReSharper disable once CppNotAllPathsReturnValue + __device__ __forceinline__ bool is_computation_valid(const uint32_t& m_block_idx, const uint32_t& m_offset) const { + if constexpr (kGemmType == GemmType::Normal or kGemmType == GemmType::Batched) { + return true; + } else if constexpr (kGemmType == GemmType::MGroupedContiguous) { + return __ldg(grouped_layout + m_offset + m_block_idx * BLOCK_M) >= 0; + } else if constexpr (kGemmType == GemmType::MGroupedMasked) { + return m_offset + m_block_idx * BLOCK_M < __ldg(grouped_layout + current_group_idx); + } else { + // Unreachable + DG_TRAP_ONLY_DEVICE_ASSERT(false); + } + } +}; + +#pragma clang diagnostic pop + +} // namespace deep_gemm diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/deep_gemm/common/sm100_utils.cuh b/build/torch29-cxx11-cu126-aarch64-linux/include/deep_gemm/common/sm100_utils.cuh new file mode 100644 index 0000000000000000000000000000000000000000..537cbe0818bf0c2c8b693e5d49862dcfbc7adb11 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/deep_gemm/common/sm100_utils.cuh @@ -0,0 +1,266 @@ +#pragma once + +#include +#include +#include +#include + +#include +#include + +namespace deep_gemm::sm100 { + +__device__ __forceinline__ +cute::UMMA::SmemDescriptor make_smem_desc(cute::UMMA::LayoutType layout, void* smem_ptr, + uint32_t stride_byte_offset, uint32_t leading_byte_offset) { + cute::UMMA::SmemDescriptor desc; + + // Set the version for SM100 + desc.version_ = 1; + + // Legacy mode + desc.lbo_mode_ = 0; + + // Layout + desc.layout_type_ = static_cast(layout); + + // Start address + const auto uint_ptr = cute::cast_smem_ptr_to_uint(smem_ptr); + desc.start_address_ = static_cast(uint_ptr >> 4); + + // Base offset + desc.base_offset_ = 0; + + // SBO and LBO + desc.stride_byte_offset_ = stride_byte_offset >> 4; + desc.leading_byte_offset_ = leading_byte_offset >> 4; + + return desc; +} + +__device__ __forceinline__ +cute::UMMA::SmemDescriptor make_sf_desc(void* smem_ptr) { + // NOTES: the UTCCP layout is K-major by default + // Atom size: 8 x 128 bits + // {SBO, LBO} means the byte stride between atoms on {MN, K} + // Since the UTCCP we used is 128b-wide (only 1 atom on K), so LBO can be zero + return make_smem_desc(cute::UMMA::LayoutType::SWIZZLE_NONE, smem_ptr, 8 * 16, 0); +} + +__device__ __forceinline__ +void replace_smem_desc_addr(cute::UMMA::SmemDescriptor& desc, const void* smem_ptr) { + const auto uint_ptr = cute::cast_smem_ptr_to_uint(smem_ptr); + desc.start_address_ = static_cast(uint_ptr >> 4); +} + +__device__ __forceinline__ +static uint32_t get_atom_base(const cute::UMMA::LayoutType& layout_type) { + return layout_type == cute::UMMA::LayoutType::SWIZZLE_128B_BASE32B ? 
32 : 16; +} + +// ReSharper disable once CppNotAllPathsReturnValue +template +constexpr static cute::UMMA::LayoutType to_umma_layout_type() { + DG_STATIC_ASSERT(kSwizzleMode == 0 or kSwizzleMode == 16 or + kSwizzleMode == 32 or kSwizzleMode == 64 or + kSwizzleMode == 128, "Invalid swizzling mode"); + // A special case + if constexpr ((cute::is_same_v and kMajorMode == cute::UMMA::Major::MN) or kUseBase32) { + DG_STATIC_ASSERT(kUseBase32, "Invalid swizzling base"); + return cute::UMMA::LayoutType::SWIZZLE_128B_BASE32B; + } + + // Normal cases + if constexpr (kSwizzleMode == 0) return cute::UMMA::LayoutType::SWIZZLE_NONE; + if constexpr (kSwizzleMode == 16) return cute::UMMA::LayoutType::SWIZZLE_NONE; + if constexpr (kSwizzleMode == 32) return cute::UMMA::LayoutType::SWIZZLE_32B; + if constexpr (kSwizzleMode == 64) return cute::UMMA::LayoutType::SWIZZLE_64B; + if constexpr (kSwizzleMode == 128) return cute::UMMA::LayoutType::SWIZZLE_128B; +} + +template +__device__ __forceinline__ +constexpr uint32_t get_umma_desc_stride_k() { + return kMajorMode == cute::UMMA::Major::K ? 1 : get_inner_block_atom_size(); +} + +template +__device__ __forceinline__ +uint32_t advance_umma_desc_lo(const uint32_t& base, const uint32_t& offset, const uint32_t& k_idx) { + return base + (((offset + k_idx * get_umma_desc_stride_k()) * static_cast(sizeof(dtype_t))) >> 4u); +} + +template +__device__ __forceinline__ +cute::UMMA::SmemDescriptor make_umma_desc(dtype_t* base_smem_ptr, uint32_t mn_idx, uint32_t k_idx) { + const uint32_t stride_k = get_umma_desc_stride_k(); + const auto& layout_type = to_umma_layout_type(); + const auto& num_non_contiguous = 128 / get_atom_base(layout_type); + if constexpr (kMajorMode == cute::UMMA::Major::K) { + // NOTES: for K-major layout, the swizzle must be the same as `BLOCK_K * sizeof(dtype_t)` + // also, atom index must be 0, so that each block has exactly one swizzle atom on the K axis + DG_STATIC_ASSERT(kSwizzleMode == BLOCK_K * sizeof(dtype_t), "Unexpected value"); + + // Atom size: 8 x `kSwizzleMode` (in bytes, on K) + // {SBO, LBO} means the byte stride between atoms on {MN, K} + // NOTES: on K, there is only 1 atom as asserted previously, so LBO can be 0 + const uint32_t stride_byte_offset = num_non_contiguous * BLOCK_K * sizeof(dtype_t); + const uint32_t leading_byte_offset = 0; + return make_smem_desc(layout_type, + base_smem_ptr + mn_idx * BLOCK_K + k_idx * stride_k, + stride_byte_offset, leading_byte_offset); + } else { + constexpr uint32_t BLOCK_MN_ATOM = get_inner_block_atom_size(); + + // Must have no in-atom MN-idx + // NOTES: no worries for the runtime assert, the `mn_idx` are constants at compilation time + DG_DEVICE_ASSERT(mn_idx % BLOCK_MN_ATOM == 0); + DG_STATIC_ASSERT(kSwizzleMode > 0, "Invalid swizzling"); + + // Atom size: `kSwizzleMode` (in bytes, on MN) x 8 + // NOTES: `kSwizzleMode == 16` mean non-swizzling but interleaving + // {SBO, LBO} means the byte stride between atoms on {K, MN} for swizzling + // {SBO, LBO} means the byte stride between atoms on {MN, K} for non-swizzling + uint32_t stride_byte_offset = num_non_contiguous * BLOCK_MN_ATOM * sizeof(dtype_t); + uint32_t leading_byte_offset = BLOCK_K * BLOCK_MN_ATOM * sizeof(dtype_t); + if constexpr (kSwizzleMode == 16) + swap(stride_byte_offset, leading_byte_offset); + return make_smem_desc(layout_type, + base_smem_ptr + mn_idx * BLOCK_K + k_idx * stride_k, + stride_byte_offset, leading_byte_offset); + } +} + +__device__ __forceinline__ +uint64_t 
make_runtime_instr_desc_with_sf_id(cute::UMMA::InstrDescriptorBlockScaled desc, const uint32_t& sfa_id, const uint32_t& sfb_id) { + desc.a_sf_id_ = sfa_id, desc.b_sf_id_ = sfb_id; + return static_cast(static_cast(desc)) << 32; +} + +template +__device__ constexpr uint32_t get_num_aligned_tmem_cols() { + DG_STATIC_ASSERT(kNumCols <= 512, "Too many tensor memory columns"); + if (kNumCols <= 32) return 32; + if (kNumCols <= 64) return 64; + if (kNumCols <= 128) return 128; + if (kNumCols <= 256) return 256; + return 512; +} + +__device__ __forceinline__ void tcgen05_before_thread_sync() { + asm volatile("tcgen05.fence::before_thread_sync;"); +} + +__device__ __forceinline__ void tcgen05_after_thread_sync() { + asm volatile("tcgen05.fence::after_thread_sync;"); +} + +__device__ __forceinline__ +void tma_gather4(const void* desc_ptr, cutlass::arch::ClusterTransactionBarrier &mbarrier, void* smem_ptr, int col_idx, int4 row_idxs, uint64_t cache_hint) { + uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); + uint32_t mbarrier_addr = cute::cast_smem_ptr_to_uint(&mbarrier); + asm volatile( + "cp.async.bulk.tensor.2d.shared::cta.global.tile::gather4.mbarrier::complete_tx::bytes.cta_group::1.L2::cache_hint [%0], [%1, {%2, %3, %4, %5, %6}], [%7], %8;\n" + : + : "r"(smem_addr), "l"(desc_ptr), "r"(col_idx), + "r"(row_idxs.x), "r"(row_idxs.y), "r"(row_idxs.z), "r"(row_idxs.w), + "r"(mbarrier_addr), "l"(cache_hint) + : "memory" + ); +} + +// UMMA versions with relaxed assertions +struct SM100_MMA_F16BF16_SS { + __device__ static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scale_c, + uint64_t const& desc) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::1.kind::f16 [%0], %1, %2, %3, p; \n\t" + "}\n" + :: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast(desc >> 32)), "r"(scale_c)); + } +}; + +struct SM100_MMA_F16BF16_2x1SM_SS { + __device__ static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scale_c, + uint64_t const& desc) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::2.kind::f16 [%0], %1, %2, %3, p; \n\t" + "}\n" + :: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast(desc >> 32)), "r"(scale_c)); + } +}; + +struct SM100_MMA_MXF8F6F4_SS { + __device__ static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scale_c, + uint64_t const& desc, + uint32_t const& tmem_sfa, + uint32_t const& tmem_sfb) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::1.kind::mxf8f6f4.block_scale [%0], %1, %2, %3, [%5], [%6], p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast(desc >> 32)), "r"(scale_c), + "r"(tmem_sfa), "r"(tmem_sfb)); + } +}; + +struct SM100_MMA_MXF8F6F4_2x1SM_SS { + __device__ static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scale_c, + uint64_t const& desc, + uint32_t const& tmem_sfa, + uint32_t const& tmem_sfb) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::2.kind::mxf8f6f4.block_scale [%0], %1, %2, %3, [%5], [%6], p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast(desc >> 32)), "r"(scale_c), + "r"(tmem_sfa), "r"(tmem_sfb)); + } +}; + +struct SM100_MMA_F16BF16_WS_SS { + 
__device__ static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scale_c, + uint64_t const& desc) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.ws.cta_group::1.kind::f16 [%0], %1, %2, %3, p; \n\t" + "}\n" + :: "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(static_cast(desc >> 32)), "r"(scale_c)); + } +}; + +} // namespace `deep_gemm::sm100` diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/deep_gemm/common/sm90_utils.cuh b/build/torch29-cxx11-cu126-aarch64-linux/include/deep_gemm/common/sm90_utils.cuh new file mode 100644 index 0000000000000000000000000000000000000000..0874b675bc3e2eaaa90c5ff508a7da0bdd8f0830 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/deep_gemm/common/sm90_utils.cuh @@ -0,0 +1,332 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace deep_gemm::sm90 { + +template +struct FP8MMA { + + template + __forceinline__ __device__ static void call_fma_impl(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d, cute::index_sequence) { + using namespace cute::SM90::GMMA; + MMA::fma(desc_a, desc_b, d[Idx]..., (scale_d ? ScaleOut::One : ScaleOut::Zero)); + } + + __forceinline__ __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) { + call_fma_impl(desc_a, desc_b, d, scale_d, cute::make_index_sequence{}); + } + + static constexpr int M = 64; + static constexpr int N = N_; + static constexpr int K = 32; + static constexpr int kNumAccum = M * N / 128; +}; + +template +struct FP8MMASelector { + + static constexpr auto select_mma() { + using namespace cute::SM90::GMMA; + if constexpr (N == 8) return MMA_64x8x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 16) return MMA_64x16x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 24) return MMA_64x24x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 32) return MMA_64x32x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 40) return MMA_64x40x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 48) return MMA_64x48x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 56) return MMA_64x56x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 64) return MMA_64x64x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 72) return MMA_64x72x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 80) return MMA_64x80x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 88) return MMA_64x88x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 96) return MMA_64x96x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 104) return MMA_64x104x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 112) return MMA_64x112x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 120) return MMA_64x120x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 128) return MMA_64x128x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 136) return MMA_64x136x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 144) return MMA_64x144x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 152) return MMA_64x152x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 160) return MMA_64x160x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 168) return MMA_64x168x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 176) return MMA_64x176x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 184) return MMA_64x184x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 192) return MMA_64x192x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 200) return MMA_64x200x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 208) return MMA_64x208x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 216) return 
MMA_64x216x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 224) return MMA_64x224x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 232) return MMA_64x232x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 240) return MMA_64x240x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 248) return MMA_64x248x32_F32E4M3E4M3_SS_TN(); + if constexpr (N == 256) return MMA_64x256x32_F32E4M3E4M3_SS_TN(); + } + + static constexpr auto select_type() { + return FP8MMA(); + } + + using type = decltype(select_type()); +}; + +template +struct BF16MMA { + + template + __forceinline__ __device__ static void call_fma_impl(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d, cute::index_sequence) { + using namespace cute::SM90::GMMA; + MMA::fma(desc_a, desc_b, d[Idx]..., (scale_d ? ScaleOut::One : ScaleOut::Zero)); + } + + __forceinline__ __device__ static void wgmma(uint64_t const& desc_a, uint64_t const& desc_b, float* d, bool scale_d) { + call_fma_impl(desc_a, desc_b, d, scale_d, cute::make_index_sequence{}); + } + + static constexpr int M = 64; + static constexpr int N = N_; + static constexpr int K = 16; + static constexpr int kNumAccum = M * N / 128; +}; + +template +constexpr cute::SM90::GMMA::Major to_sm90_major() { + DG_STATIC_ASSERT(kMajor == cute::UMMA::Major::K or kMajor == cute::UMMA::Major::MN, "Invalid major-ness"); + return kMajor == cute::UMMA::Major::K ? cute::SM90::GMMA::Major::K : cute::SM90::GMMA::Major::MN; +} + +template +struct BF16MMASelector { + + static constexpr auto select_mma() { + using namespace cute::SM90::GMMA; + constexpr auto kGMMAMajorA = to_sm90_major(); + constexpr auto kGMMAMajorB = to_sm90_major(); + if constexpr (N == 8) return MMA_64x8x16_F32BF16BF16_SS(); + if constexpr (N == 16) return MMA_64x16x16_F32BF16BF16_SS(); + if constexpr (N == 24) return MMA_64x24x16_F32BF16BF16_SS(); + if constexpr (N == 32) return MMA_64x32x16_F32BF16BF16_SS(); + if constexpr (N == 40) return MMA_64x40x16_F32BF16BF16_SS(); + if constexpr (N == 48) return MMA_64x48x16_F32BF16BF16_SS(); + if constexpr (N == 56) return MMA_64x56x16_F32BF16BF16_SS(); + if constexpr (N == 64) return MMA_64x64x16_F32BF16BF16_SS(); + if constexpr (N == 72) return MMA_64x72x16_F32BF16BF16_SS(); + if constexpr (N == 80) return MMA_64x80x16_F32BF16BF16_SS(); + if constexpr (N == 88) return MMA_64x88x16_F32BF16BF16_SS(); + if constexpr (N == 96) return MMA_64x96x16_F32BF16BF16_SS(); + if constexpr (N == 104) return MMA_64x104x16_F32BF16BF16_SS(); + if constexpr (N == 112) return MMA_64x112x16_F32BF16BF16_SS(); + if constexpr (N == 120) return MMA_64x120x16_F32BF16BF16_SS(); + if constexpr (N == 128) return MMA_64x128x16_F32BF16BF16_SS(); + if constexpr (N == 136) return MMA_64x136x16_F32BF16BF16_SS(); + if constexpr (N == 144) return MMA_64x144x16_F32BF16BF16_SS(); + if constexpr (N == 152) return MMA_64x152x16_F32BF16BF16_SS(); + if constexpr (N == 160) return MMA_64x160x16_F32BF16BF16_SS(); + if constexpr (N == 168) return MMA_64x168x16_F32BF16BF16_SS(); + if constexpr (N == 176) return MMA_64x176x16_F32BF16BF16_SS(); + if constexpr (N == 184) return MMA_64x184x16_F32BF16BF16_SS(); + if constexpr (N == 192) return MMA_64x192x16_F32BF16BF16_SS(); + if constexpr (N == 200) return MMA_64x200x16_F32BF16BF16_SS(); + if constexpr (N == 208) return MMA_64x208x16_F32BF16BF16_SS(); + if constexpr (N == 216) return MMA_64x216x16_F32BF16BF16_SS(); + if constexpr (N == 224) return MMA_64x224x16_F32BF16BF16_SS(); + if constexpr (N == 232) return MMA_64x232x16_F32BF16BF16_SS(); + if constexpr (N == 240) return 
MMA_64x240x16_F32BF16BF16_SS(); + if constexpr (N == 248) return MMA_64x248x16_F32BF16BF16_SS(); + if constexpr (N == 256) return MMA_64x256x16_F32BF16BF16_SS(); + } + + static constexpr auto select_type() { + return BF16MMA(); + } + + using type = decltype(select_type()); +}; + +template +struct TF32MMARS { + + template + __forceinline__ __device__ static void call_fma_impl(uint32_t* a, uint64_t const& desc_b, float* d, bool scale_d, cute::index_sequence) { + using namespace cute::SM90::GMMA; + MMA::fma(a[0], a[1], a[2], a[3], desc_b, d[Idx]..., (scale_d ? ScaleOut::One : ScaleOut::Zero)); + } + + __forceinline__ __device__ static void wgmma(float* a, uint64_t const& desc_b, float* d, bool scale_d) { + call_fma_impl(reinterpret_cast(a), desc_b, d, scale_d, cute::make_index_sequence{}); + } + + static constexpr int M = 64; + static constexpr int N = N_; + static constexpr int K = 8; + static constexpr int kNumAccum = M * N / 128; +}; + +template +struct TF32MMASelector { + + static constexpr auto select_mma() { + using namespace cute::SM90::GMMA; + if constexpr (kUseRS) { + if constexpr (N == 8) return MMA_64x8x8_F32TF32TF32_RS_TN(); + if constexpr (N == 16) return MMA_64x16x8_F32TF32TF32_RS_TN(); + if constexpr (N == 32) return MMA_64x32x8_F32TF32TF32_RS_TN(); + if constexpr (N == 64) return MMA_64x64x8_F32TF32TF32_RS_TN(); + if constexpr (N == 128) return MMA_64x128x8_F32TF32TF32_RS_TN(); + if constexpr (N == 256) return MMA_64x256x8_F32TF32TF32_RS_TN(); + DG_STATIC_ASSERT(N == 8 or N == 16 or N == 32 or N == 64 or N == 128 or N == 256, "Invalid N"); + } + } + + static constexpr auto select_type() { + if constexpr (kUseRS) { + return TF32MMARS(); + } else { + DG_STATIC_ASSERT(kUseRS, "SS mode is not supported for TF32MMASelector for now"); + } + } + + using type = decltype(select_type()); +}; + +template +struct SM90_U32x2_STSM_N { + __device__ __forceinline__ static void + copy(dtype_t src_0, dtype_t src_1, void* smem_dst) { + const uint32_t src[2] = {*reinterpret_cast(&src_0), *reinterpret_cast(&src_1)}; + asm volatile("stmatrix.sync.aligned.x2.m8n8.shared.b16 [%0], {%1, %2};\n" + :: "l"(__cvta_generic_to_shared(smem_dst)), "r"(src[0]), "r"(src[1])); + } +}; + +struct SM90_U32x2_LDSM_N { + __device__ __forceinline__ static void + copy(uint32_t& dst_0, uint32_t& dst_1, void* smem_src) { + asm volatile("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n" + : "=r"(dst_0), "=r"(dst_1) + : "l"(__cvta_generic_to_shared(smem_src))); + } +}; + +struct SM90_U32x4_LDSM_N { + __device__ __forceinline__ static void + copy(uint32_t& dst_0, uint32_t& dst_1, uint32_t& dst_2, uint32_t& dst_3, void* smem_src) { + asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n" + : "=r"(dst_0), "=r"(dst_1), "=r"(dst_2), "=r"(dst_3) + : "l"(__cvta_generic_to_shared(smem_src))); + } +}; + +__forceinline__ __device__ void warpgroup_arrive() { + asm volatile("wgmma.fence.sync.aligned;\n" ::: "memory"); +} + +__forceinline__ __device__ void warpgroup_commit_batch() { + asm volatile("wgmma.commit_group.sync.aligned;\n" ::: "memory"); +} + +__forceinline__ __device__ void warpgroup_fence_operand(float& reg) { + asm volatile("" : "+f"(reg) :: "memory"); +} + +template +__forceinline__ __device__ void warpgroup_wait() { + DG_STATIC_ASSERT(N >= 0 and N <= 7, "WGMMA wait: N must be in range [0, 7]"); + asm volatile("wgmma.wait_group.sync.aligned %0;\n" :: "n"(N) : "memory"); +} + +template +__device__ cute::GmmaDescriptor make_smem_desc(PointerType smem_ptr, const int& layout_type, + const 
int& leading_byte_offset = 0, + const int& stride_byte_offset = 1024) { + // NOTES: the default LBO and SBO are for K-major types + cute::GmmaDescriptor desc; + const auto& uint_ptr = static_cast(__cvta_generic_to_shared(smem_ptr)); + desc.bitfield.start_address_ = uint_ptr >> 4; + desc.bitfield.layout_type_ = layout_type; + desc.bitfield.leading_byte_offset_ = leading_byte_offset >> 4; + desc.bitfield.stride_byte_offset_ = stride_byte_offset >> 4; + desc.bitfield.base_offset_ = 0; + return desc; +} + +template +constexpr uint32_t get_inner_block_atom_size() { + return kSwizzleMode == 0 ? BLOCK_INNER : kSwizzleMode / sizeof(dtype_t); +} + +template +__device__ __forceinline__ +constexpr uint32_t get_gmma_desc_stride_k() { + return kMajorMode == cute::UMMA::Major::K ? 1 : get_inner_block_atom_size(); +} + +// ReSharper disable once CppNotAllPathsReturnValue +template +constexpr static cute::SM90::GMMA::LayoutType to_gmma_layout_type() { + DG_STATIC_ASSERT(kSwizzleMode == 0 or kSwizzleMode == 16 or + kSwizzleMode == 32 or kSwizzleMode == 64 or + kSwizzleMode == 128, "Invalid swizzling mode"); + + // Normal cases + if constexpr (kSwizzleMode == 0) return cute::SM90::GMMA::LayoutType::INTERLEAVE; + if constexpr (kSwizzleMode == 16) return cute::SM90::GMMA::LayoutType::INTERLEAVE; + if constexpr (kSwizzleMode == 32) return cute::SM90::GMMA::LayoutType::B32; + if constexpr (kSwizzleMode == 64) return cute::SM90::GMMA::LayoutType::B64; + if constexpr (kSwizzleMode == 128) return cute::SM90::GMMA::LayoutType::B128; +} + +template +__device__ __forceinline__ +uint32_t advance_gmma_desc_lo(const uint32_t& base, const uint32_t& mn_idx, const uint32_t& k_idx, const uint32_t& offset = 0) { + return base + (((offset + mn_idx * BLOCK_K + k_idx * get_gmma_desc_stride_k()) * static_cast(sizeof(dtype_t))) >> 4u); +} + +template +__device__ __forceinline__ +cute::GmmaDescriptor make_gmma_desc(dtype_t* base_smem_ptr, uint32_t mn_idx, uint32_t k_idx) { + const uint32_t stride_k = get_gmma_desc_stride_k(); + const auto& layout_type = to_gmma_layout_type(); + constexpr uint32_t num_non_contiguous = 128 / 16; + if constexpr (kMajorMode == cute::UMMA::Major::K) { + // NOTES: for K-major layout, the swizzle must be 128B (also, atom index must be 0), as `BLOCK_K` is always 128 + DG_STATIC_ASSERT(kSwizzleMode == BLOCK_K * sizeof(dtype_t), "Unexpected value"); + + // Atom size: 8 x `kSwizzleMode` (in bytes, on K) + // {SBO, LBO} means the byte stride between atoms on {MN, K} + // NOTES: on K, there is only 1 atom as asserted previously, so LBO can be 0 + const uint32_t stride_byte_offset = num_non_contiguous * BLOCK_K * sizeof(dtype_t); + const uint32_t leading_byte_offset = 0; + return make_smem_desc(base_smem_ptr + mn_idx * BLOCK_K + k_idx * stride_k, static_cast(layout_type), + leading_byte_offset, stride_byte_offset); + } else { + constexpr uint32_t BLOCK_MN_ATOM = get_inner_block_atom_size(); + + // Must have no in-atom MN-idx + // NOTES: no worries for the runtime assert, the `mn_idx` are constants at compilation time + DG_DEVICE_ASSERT(mn_idx % BLOCK_MN_ATOM == 0); + DG_STATIC_ASSERT(kSwizzleMode > 0, "Invalid swizzling"); + + // Atom size: `kSwizzleMode` (in bytes, on MN) x 8 + // NOTES: `kSwizzleMode == 16` mean non-swizzling but interleaving + // {SBO, LBO} means the byte stride between atoms on {K, MN} for swizzling + // {SBO, LBO} means the byte stride between atoms on {MN, K} for non-swizzling + uint32_t stride_byte_offset = num_non_contiguous * BLOCK_MN_ATOM * sizeof(dtype_t); + uint32_t 
leading_byte_offset = BLOCK_K * BLOCK_MN_ATOM * sizeof(dtype_t); + if constexpr (kSwizzleMode == 16) + swap(stride_byte_offset, leading_byte_offset); + return make_smem_desc(base_smem_ptr + mn_idx * BLOCK_K + k_idx * stride_k, static_cast(layout_type), + leading_byte_offset, stride_byte_offset); + } +} + +} // namespace `deep_gemm::sm90` diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/deep_gemm/common/tma_utils.cuh b/build/torch29-cxx11-cu126-aarch64-linux/include/deep_gemm/common/tma_utils.cuh new file mode 100644 index 0000000000000000000000000000000000000000..bd54adc231812a0c0e2fbbc130e18a27697a497c --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/deep_gemm/common/tma_utils.cuh @@ -0,0 +1,116 @@ +#pragma once + +#include +#include +#include + +namespace deep_gemm { + +template +constexpr uint32_t get_inner_block_atom_size() { + return kSwizzleMode == 0 ? BLOCK_INNER : kSwizzleMode / sizeof(dtype_t); +} + +template +__device__ __forceinline__ void +tma_copy(void const* desc_ptr, cutlass::arch::ClusterTransactionBarrier* barrier_ptr, + dtype_t* smem_ptr, const uint32_t& inner_idx, const uint32_t& outer_idx, + const uint32_t& num_tma_multicast = 1, const uint32_t& batch_idx = 0) { + DG_STATIC_ASSERT(static_cast(cute::TMA::CacheHintSm90::EVICT_NORMAL) == + static_cast(cute::TMA::CacheHintSm100::EVICT_NORMAL), "Invalid cache hint"); + constexpr uint32_t BLOCK_INNER_ATOM = get_inner_block_atom_size(); + + if constexpr (not kIs3DTMA) { + if (num_tma_multicast == 1) { + #pragma unroll + for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) { + cute::SM90_TMA_LOAD_2D::copy(desc_ptr, reinterpret_cast(barrier_ptr), + static_cast(cute::TMA::CacheHintSm100::EVICT_NORMAL), + smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM, + inner_idx + i * BLOCK_INNER_ATOM, outer_idx); + } + } else { + #if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) + // 2-CTA function will send signals to the leader CTA only + #pragma unroll + for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) { + cute::SM100_TMA_2SM_LOAD_2D::copy(desc_ptr, reinterpret_cast(barrier_ptr), + static_cast(cute::TMA::CacheHintSm100::EVICT_NORMAL), + smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM, + inner_idx + i * BLOCK_INNER_ATOM, outer_idx); + } + #elif (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) + if (cute::block_rank_in_cluster() == 0) { + #pragma unroll + for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) { + cute::SM90_TMA_LOAD_MULTICAST_2D::copy(desc_ptr, reinterpret_cast(barrier_ptr), + (1 << num_tma_multicast) - 1, static_cast(cute::TMA::CacheHintSm90::EVICT_NORMAL), + smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM, + inner_idx + i * BLOCK_INNER_ATOM, outer_idx); + } + } + #endif + } + } else { + if (num_tma_multicast == 1) { + #pragma unroll + for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) { + cute::SM90_TMA_LOAD_3D::copy(desc_ptr, reinterpret_cast(barrier_ptr), + static_cast(cute::TMA::CacheHintSm100::EVICT_NORMAL), + smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM, + inner_idx + i * BLOCK_INNER_ATOM, outer_idx, batch_idx); + } + } else { + #if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) + // 2-CTA function will send signals to the leader CTA only + #pragma unroll + for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) { + cute::SM100_TMA_2SM_LOAD_3D::copy(desc_ptr, reinterpret_cast(barrier_ptr), + static_cast(cute::TMA::CacheHintSm100::EVICT_NORMAL), + smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM, + inner_idx + i * BLOCK_INNER_ATOM, 
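+                // NOTES (worked example): with kSwizzleMode = 128 and a 2-byte dtype,
+                // BLOCK_INNER_ATOM = 128 / 2 = 64, so a BLOCK_INNER = 128 tile is issued as
+                // 2 TMA loads of one swizzle atom each. For num_tma_multicast = 2 the mask
+                // (1 << num_tma_multicast) - 1 = 0b11 multicasts each load to both CTAs in
+                // the cluster (and on SM90 only the leader CTA issues it).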
outer_idx, batch_idx); + } + #elif (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) + if (cute::block_rank_in_cluster() == 0) { + #pragma unroll + for (uint32_t i = 0; i < BLOCK_INNER / BLOCK_INNER_ATOM; ++ i) { + cute::SM90_TMA_LOAD_MULTICAST_3D::copy(desc_ptr, reinterpret_cast(barrier_ptr), + (1 << num_tma_multicast) - 1, static_cast(cute::TMA::CacheHintSm90::EVICT_NORMAL), + smem_ptr + i * BLOCK_OUTER * BLOCK_INNER_ATOM, + inner_idx + i * BLOCK_INNER_ATOM, outer_idx, batch_idx); + } + } + #endif + } + } +} + +// Tensormap related +__device__ __forceinline__ void tensor_map_release_cta() { + asm volatile ("fence.proxy.tensormap::generic.release.cta;"); +} + +__device__ __forceinline__ void tensor_map_acquire_cta(const cute::TmaDescriptor* gmem_desc_ptr) { + auto gmem_int_desc = reinterpret_cast(gmem_desc_ptr); + asm volatile ("fence.proxy.tensormap::generic.acquire.cta [%0], 128;" :: "l"(gmem_int_desc) : "memory"); +} + +__device__ __forceinline__ void tensor_map_replace_global_addr_in_smem(cute::TmaDescriptor* smem_desc, const void* new_addr) { + auto smem_int_desc = static_cast(__cvta_generic_to_shared(smem_desc)); + const auto new_int64_addr = reinterpret_cast(new_addr); + asm volatile ("tensormap.replace.tile.global_address.shared::cta.b1024.b64 [%0], %1;" :: "r"(smem_int_desc), "l"(new_int64_addr)); +} + +__device__ __forceinline__ void tensor_map_replace_global_inner_dim_stride_in_smem(cute::TmaDescriptor* smem_desc, const uint32_t& new_dim, const uint64_t& new_stride) { + auto smem_int_desc = __cvta_generic_to_shared(smem_desc); + asm volatile ("tensormap.replace.tile.global_dim.shared::cta.b1024.b32 [%0], 0, %1;" :: "l"(smem_int_desc), "r"(new_dim)); +#if ((__CUDACC_VER_MAJOR__ > 12) or ((__CUDACC_VER_MAJOR__ == 12) and (__CUDACC_VER_MINOR__ >= 3))) + asm volatile("tensormap.replace.tile.global_stride.shared::cta.b1024.b64 [%0], 0, %1;" :: "l"(smem_int_desc), "l"(new_stride)); +#else + DG_STATIC_ASSERT(false, "Invalid CUDA version"); +#endif +} + +} // namespace `deep_gemm` diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/deep_gemm/common/types.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/deep_gemm/common/types.hpp new file mode 100644 index 0000000000000000000000000000000000000000..410c5469ef348500f339ae0ced5d7bf171fdd42c --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/deep_gemm/common/types.hpp @@ -0,0 +1,41 @@ +#pragma once + +namespace deep_gemm { + +enum class MmaKind { + BF16 = 0, + MXFP8FP4 = 1, +}; + +constexpr __host__ __device__ int get_element_size(const MmaKind& mma_kind) { + switch (mma_kind) { + case MmaKind::BF16: return 2; + case MmaKind::MXFP8FP4: return 1; + default: return 0; + } +} + +enum class GemmType { + Normal = 0, + MGroupedContiguous = 1, + MGroupedMasked = 2, + KGroupedContiguous = 3, + Batched = 4, + MGroupedContiguousWithPsumLayout = 5, +}; + +constexpr __host__ __device__ bool is_m_grouped_contiguous(const GemmType& gemm_type) { + switch (gemm_type) { + case GemmType::MGroupedContiguous: return true; + case GemmType::MGroupedContiguousWithPsumLayout: return true; + default: return false; + } +} + +enum class KernelType { + Kernel1D1D = 0, + Kernel1D2D = 1, + KernelNoSF = 2 +}; + +} // namespace deep_gemm diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/deep_gemm/common/utils.cuh b/build/torch29-cxx11-cu126-aarch64-linux/include/deep_gemm/common/utils.cuh new file mode 100644 index 0000000000000000000000000000000000000000..8fb6c2fc53b6d1eb067d13c113462a9f7de4133a --- /dev/null +++ 
b/build/torch29-cxx11-cu126-aarch64-linux/include/deep_gemm/common/utils.cuh @@ -0,0 +1,183 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include "cute_tie.cuh" + +#ifdef __CLION_IDE__ + +__host__ __device__ __forceinline__ void host_device_printf(const char* format, ...) { + asm volatile("trap;"); +} + +#define printf host_device_printf +#endif + +#ifndef DG_DEVICE_ASSERT +#define DG_DEVICE_ASSERT(cond) \ +do { \ + if (not (cond)) { \ + printf("Assertion failed: %s:%d, condition: %s\n", __FILE__, __LINE__, #cond); \ + asm("trap;"); \ + } \ +} while (0) +#endif + +#ifndef DG_TRAP_ONLY_DEVICE_ASSERT +#define DG_TRAP_ONLY_DEVICE_ASSERT(cond) \ +do { \ + if (not (cond)) \ + asm("trap;"); \ +} while (0) +#endif + +#ifndef DG_STATIC_ASSERT +#define DG_STATIC_ASSERT(cond, ...) static_assert(cond, __VA_ARGS__) +#endif + +namespace deep_gemm { + +template +struct PatternVisitor { + FuncT func; + + __device__ __host__ + explicit PatternVisitor(FuncT&& func): func(std::forward(func)) {} + + __device__ __host__ + auto operator [](const uint32_t& i) { + return func(i); + } +}; + +template +__device__ __host__ T ceil_div(T a, T b) { + return (a + b - 1) / b; +} + +template +__device__ __host__ constexpr T constexpr_ceil_div(T a, T b) { + return (a + b - 1) / b; +} + +template +__device__ __host__ T align(T a, T b) { + return ceil_div(a, b) * b; +} + +template +__device__ __host__ constexpr T constexpr_align(T a, T b) { + return constexpr_ceil_div(a, b) * b; +} + +template +__device__ __host__ constexpr T constexpr_gcd(T a, T b) { + return b == 0 ? a : constexpr_gcd(b, a % b); +} + +template +__forceinline__ __device__ void swap(T& a, T& b) { + T temp = a; + a = b; + b = temp; +} + +__forceinline__ __device__ uint32_t get_sm_idx() { + uint32_t sm_idx; + asm ("mov.u32 %0, %%smid;" : "=r"(sm_idx)); + return sm_idx; +} + +__forceinline__ __device__ uint32_t get_lane_idx() { + uint32_t lane_id; + asm ("mov.u32 %0, %laneid;" : "=r"(lane_id)); + return lane_id; +} + +__device__ __forceinline__ uint32_t ld_shared(const uint32_t* ptr) { + uint32_t ret; + asm volatile("ld.shared.u32 %0, [%1];" : "=r"(ret) : "l"(__cvta_generic_to_shared(ptr))); + return ret; +} + +__device__ __forceinline__ float2 ld_shared(const float2* ptr) { + float2 ret; + asm volatile("ld.shared.v2.f32 {%0, %1}, [%2];" : "=f"(ret.x), "=f"(ret.y) : "l"(__cvta_generic_to_shared(ptr))); + return ret; +} + +__device__ __forceinline__ float4 ld_shared(const float4* ptr) { + float4 ret; + asm volatile("ld.shared.v4.f32 {%0, %1, %2, %3}, [%4];" : "=f"(ret.x), "=f"(ret.y), "=f"(ret.z), "=f"(ret.w) : "l"(__cvta_generic_to_shared(ptr))); + return ret; +} + +__device__ __forceinline__ uint4 ld_shared(const uint4* ptr) { + uint4 ret; + asm volatile("ld.shared.v4.u32 {%0, %1, %2, %3}, [%4];" : "=r"(ret.x), "=r"(ret.y), "=r"(ret.z), "=r"(ret.w) : "l"(__cvta_generic_to_shared(ptr))); + return ret; +} + +__device__ __forceinline__ float ld_shared(const float* ptr) { + float ret; + asm volatile("ld.shared.f32 %0, [%1];" : "=f"(ret) : "l"(__cvta_generic_to_shared(ptr))); + return ret; +} + +__device__ __forceinline__ void st_shared(const float* ptr, float val) { + asm volatile("st.shared.f32 [%0], %1;" :: "l"(__cvta_generic_to_shared(ptr)), "f"(val)); +} + +__device__ __forceinline__ void st_shared(const float2* ptr, float2 val) { + asm volatile("st.shared.v2.f32 [%0], {%1, %2};" :: "l"(__cvta_generic_to_shared(ptr)), "f"(val.x), "f"(val.y)); +} + +__device__ __forceinline__ void st_shared(const uint32_t* ptr, uint32_t val) { + asm 
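+// NOTES (worked examples for the helpers above, assuming 32-bit unsigned arguments):
+//     ceil_div(300u, 128u)     == 3     // round-up division
+//     align(300u, 128u)        == 384   // next multiple of 128
+//     constexpr_gcd(128u, 96u) == 32
+// `PatternVisitor` simply turns a lambda into an indexable object, e.g.
+//     auto ptrs = PatternVisitor([&](const uint32_t& i) { return base + i * stride; });
+//     ptrs[2];  // == base + 2 * stride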
volatile("st.shared.u32 [%0], %1;" :: "l"(__cvta_generic_to_shared(ptr)), "r"(val)); +} + +__device__ __forceinline__ void st_shared(const void* ptr, uint32_t x, uint32_t y) { + asm volatile("st.shared.v2.u32 [%0], {%1, %2};" :: "l"(__cvta_generic_to_shared(ptr)), "r"(x), "r"(y)); +} + +__device__ __forceinline__ void st_shared(const void* ptr, uint32_t x, uint32_t y, uint32_t z, uint32_t w) { + asm volatile("st.shared.v4.u32 [%0], {%1, %2, %3, %4};" :: "l"(__cvta_generic_to_shared(ptr)), "r"(x), "r"(y), "r"(z), "r"(w)); +} + +__device__ __forceinline__ void st_shared(const __int128_t* ptr, __int128_t val) { + asm volatile("st.shared.b128 [%0], %1;" :: "l"(__cvta_generic_to_shared(ptr)), "q"(val)); +} + +template +__device__ __forceinline__ int cast_into_bf16_and_pack(old_t& x, old_t& y) { + auto bf16x2 = __float22bfloat162_rn({*reinterpret_cast(&x), *reinterpret_cast(&y)}); + return *reinterpret_cast(&bf16x2); +} + +__device__ __forceinline__ void prefetch_l1(void *ptr) { + asm volatile("prefetch.global.L1 [%0];" :: "l"(ptr)); +} + +template +struct Vectorized { + static auto zeros() { + // TODO: add `ulonglong4` for SM100 once `__ldg` support this + if constexpr (kNumBytes > 0 and kNumBytes % 16 == 0) { + return make_uint4(0, 0, 0, 0); + } else if constexpr (kNumBytes > 0 and kNumBytes % 8 == 0) { + return make_uint2(0, 0); + } else if constexpr (kNumBytes > 0 and kNumBytes % 4 == 0) { + return 0; + } else { + DG_STATIC_ASSERT(kNumBytes > 0 and kNumBytes % 4 == 0, "Invalid vectorization"); + } + } + + using vec_t = decltype(zeros()); +}; + +} // namespace `deep_gemm` diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/deep_gemm/impls/sm100_bf16_gemm.cuh b/build/torch29-cxx11-cu126-aarch64-linux/include/deep_gemm/impls/sm100_bf16_gemm.cuh new file mode 100644 index 0000000000000000000000000000000000000000..0227b3e80061409c4dcf89f3f402ce408751246f --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/deep_gemm/impls/sm100_bf16_gemm.cuh @@ -0,0 +1,482 @@ +#pragma once +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wunknown-attributes" + +#include + +#include +#include +#include + +namespace deep_gemm { + +using namespace deep_gemm::sm100; + +template +__global__ void __launch_bounds__(kNumNonEpilogueThreads + kNumEpilogueThreads, 1) +sm100_bf16_gemm_impl(int* grouped_layout, + uint32_t shape_m, uint32_t shape_n, uint32_t shape_k, + const __grid_constant__ cute::TmaDescriptor tensor_map_a, + const __grid_constant__ cute::TmaDescriptor tensor_map_b, + const __grid_constant__ cute::TmaDescriptor tensor_map_cd) { +#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) or defined(__CLION_IDE__) + // Enlarge `BLOCK_K` for some cases + // NOTES: this is for reducing the `umma_arrive()` overhead + constexpr bool kDoMergeStages = + kNumStages_ >= 8 and kGemmType == GemmType::Normal and + kMajorA == cute::UMMA::Major::K and kMajorB == cute::UMMA::Major::K; + // Ensure there are at least `kNumMinStages` stages after merge + constexpr uint32_t kNumMinStages = 8; + constexpr uint32_t kNumStagesPerMerge = kDoMergeStages ? 
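+    // NOTES (worked example): with kNumStages_ = 16, a K-major normal GEMM merges stages:
+    //     kNumStagesPerMerge = 16 / 8 = 2,  BLOCK_K = 64 * 2 = 128,  kNumStages = 16 / 2 = 8.
+    // With kNumStages_ = 6 the merge is disabled (kNumStagesPerMerge = 1) and the original
+    // BLOCK_K_ / kNumStages_ are used unchanged.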
kNumStages_ / kNumMinStages : 1; + constexpr uint32_t BLOCK_K = BLOCK_K_ * kNumStagesPerMerge; + constexpr uint32_t kNumStages = kNumStages_ / kNumStagesPerMerge; + + using Barrier = cutlass::arch::ClusterTransactionBarrier; + using Allocator = cute::conditional_t; + + // GEMM with accumulation must have FP32 output + if constexpr (kWithAccumulation) + DG_STATIC_ASSERT(cute::is_same_v, "Invalid C/D data dtype"); + + // Configs + constexpr uint32_t LAYOUT_AD_M = 128; + constexpr uint32_t WAVE_BLOCK_M = cute::min(BLOCK_M, LAYOUT_AD_M); + constexpr uint32_t kNumMWaves = BLOCK_M / WAVE_BLOCK_M; + constexpr uint32_t kNumTMAStoreStages = 2; + DG_STATIC_ASSERT(BLOCK_K_ == 64, "Invalid block K"); + DG_STATIC_ASSERT(BLOCK_M % WAVE_BLOCK_M == 0 and 2 % kNumMWaves == 0, "Invalid block M"); + DG_STATIC_ASSERT(sizeof(cutlass::bfloat16_t) * LAYOUT_AD_M % kSwizzleAMode == 0, "Invalid swizzle A mode"); + + // Overwrite shape constants if the compiler gives + shape_m = SHAPE_M != 0 ? SHAPE_M : shape_m; + shape_n = SHAPE_N != 0 ? SHAPE_N : shape_n; + shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k; + + // Utils + bool is_leader_cta = cute::block_rank_in_cluster() == 0; + const auto warp_idx = cutlass::canonical_warp_idx_sync(); + const auto lane_idx = get_lane_idx(); + + // Align to 1024 bytes for swizzle-128B + extern __shared__ __align__(1024) uint8_t smem_buffer[]; + + // 2-CTA MMA + constexpr uint32_t LOAD_BLOCK_M = BLOCK_M / (kIsMulticastOnA ? kNumMulticast: 1); + constexpr uint32_t LOAD_BLOCK_N = BLOCK_N / (kIsMulticastOnA ? 1 : kNumMulticast); + constexpr uint32_t STORE_BLOCK_M = cute::min(BLOCK_M, LAYOUT_AD_M); + constexpr uint32_t STORE_BLOCK_N = kSwizzleCDMode / sizeof(cd_dtype_t); + constexpr uint32_t kNumUMMAStoreThreads = STORE_BLOCK_M; + DG_STATIC_ASSERT(not kIsMulticastOnA or kNumMulticast == 1, "Invalid multicast"); + DG_STATIC_ASSERT(LOAD_BLOCK_M == BLOCK_M, "Only support tensor memory layout A/D"); + DG_STATIC_ASSERT(kNumMulticast == 1 or kNumMulticast == 2, "Only support 1/2 multicast"); + DG_STATIC_ASSERT(kNumUMMAStoreThreads % 32 == 0, "Invalid store block M"); + + // Share memory sizes + constexpr uint32_t SMEM_CD_SIZE_PER_STAGE = STORE_BLOCK_M * kSwizzleCDMode; + constexpr uint32_t SMEM_CD_SIZE = SMEM_CD_SIZE_PER_STAGE * kNumTMAStoreStages; + constexpr uint32_t SMEM_A_SIZE_PER_STAGE = LOAD_BLOCK_M * BLOCK_K * sizeof(cutlass::bfloat16_t); + constexpr uint32_t SMEM_B_SIZE_PER_STAGE = LOAD_BLOCK_N * BLOCK_K * sizeof(cutlass::bfloat16_t); + DG_STATIC_ASSERT(SMEM_CD_SIZE % 1024 == 0 and SMEM_A_SIZE_PER_STAGE % 1024 == 0 and SMEM_B_SIZE_PER_STAGE % 1024 == 0, + "Shared memory of A/B must be aligned to 1024 bytes"); + DG_STATIC_ASSERT(kNumTMAStoreStages >= 1, "Invalid number of TMA stages"); + + // NOTES: Make sure we have enough shared memory for UMMA padding + static constexpr uint32_t UMMA_A_SIZE_PER_STAGE = constexpr_align(LOAD_BLOCK_M, LAYOUT_AD_M) * BLOCK_K * sizeof(nv_bfloat16); + DG_STATIC_ASSERT(UMMA_A_SIZE_PER_STAGE <= SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE * kNumStages, "Memory Out of bound for UMMA"); + + // Automatically deduce the number of epilogue stages (1 or 2), according to the tensor memory size + // TODO: test cases of `kNumMWaves == 2 and kNumEpilogueStages == 2` + constexpr uint32_t kNumEpilogueStages = (2 * kNumMWaves * BLOCK_N) > 512 ? 
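+    // NOTES (worked example): each epilogue stage needs kNumMWaves * BLOCK_N accumulator columns
+    // out of the 512 tensor memory columns. E.g. BLOCK_M = 128 (kNumMWaves = 1), BLOCK_N = 208:
+    // 2 * 1 * 208 = 416 <= 512, so the accumulator is double-buffered; BLOCK_M = 256
+    // (kNumMWaves = 2), BLOCK_N = 160: 2 * 2 * 160 = 640 > 512, so only a single stage fits.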
1 : 2; + + // Real tensor memory size and offsets + constexpr uint32_t kNumAccumTmemCols = kNumEpilogueStages * kNumMWaves * BLOCK_N; + constexpr uint32_t kNumTmemCols = get_num_aligned_tmem_cols(); + + // Prefetch TMA descriptors at the very beginning + if (warp_idx == 0 and cute::elect_one_sync()) { + cute::prefetch_tma_descriptor(&tensor_map_a); + cute::prefetch_tma_descriptor(&tensor_map_b); + cute::prefetch_tma_descriptor(&tensor_map_cd); + } + + // D/A/B shared memory + auto smem_cd = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + i * SMEM_CD_SIZE_PER_STAGE); + }); + auto smem_a = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE); + }); + auto smem_b = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE); + }); + + // Fill barriers + auto barrier_start_ptr = reinterpret_cast(smem_buffer + SMEM_CD_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE)); + auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); }); + auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); }); + auto tmem_full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + i); }); + auto tmem_empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + kNumEpilogueStages + i); }); + auto tensor_core_full_barrier = barrier_start_ptr + kNumStages * 3 + kNumEpilogueStages * 2; + + // Fill the tensor memory pointer + auto tmem_ptr_in_smem = reinterpret_cast(barrier_start_ptr + kNumStages * 3 + kNumEpilogueStages * 2 + 1); + DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns"); + + // Initialize barriers + if (warp_idx == 1 and cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < kNumStages; ++ i) { + // Arrive only at the leader CTA + full_barriers[i]->init(kNumMulticast); + // Arrive at all CTAs + empty_barriers[i]->init(1); + } + #pragma unroll + for (uint32_t i = 0; i < kNumEpilogueStages; ++ i) { + // Arrive at all CTAs + tmem_full_barriers[i]->init(1); + // Arrive only at the leader CTA + tmem_empty_barriers[i]->init(kNumMulticast * kNumUMMAStoreThreads); + } + if constexpr (kTensorCoreUtilControl < 100) + tensor_core_full_barrier->init(1); + + // Make initialized barrier visible in async proxy + cutlass::arch::fence_barrier_init(); + } else if (warp_idx == 2) { + // Allocate tensor memory + Allocator().allocate(kNumTmemCols, tmem_ptr_in_smem); + } + kNumMulticast > 1 ? 
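+    // NOTES: as laid out by the pointer arithmetic above (S = kNumStages, E = kNumEpilogueStages,
+    // offsets in units of `Barrier`): [0, S) full, [S, 2S) empty, [2S, 2S+E) tmem_full,
+    // [2S+E, 2S+2E) tmem_empty; the tensor-core utilization barrier sits at offset 3S + 2E and
+    // the 32-bit tensor memory pointer immediately after it.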
cute::cluster_sync() : __syncthreads(); + + // Block scheduler + uint32_t m_block_idx, n_block_idx; + auto scheduler = Scheduler(shape_m, shape_n, shape_k, grouped_layout); + + // Pipeline and TMA phases + uint32_t stage_idx = 0, phase = 0, tensor_core_phase = 0; + auto advance_pipeline = [&](uint32_t& k_block_idx) { + ++ k_block_idx; + + // Flip phases only if reach the next first stage + stage_idx = (stage_idx + 1) % kNumStages; + phase ^= stage_idx == 0; + }; + + // Dispatch warps into different roles + if (warp_idx == 0 and cute::elect_one_sync()) { + // TMA load warp + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + const auto& num_total_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K); + for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) { + // Wait consumer release + empty_barriers[stage_idx]->wait(phase ^ 1); + + // Compute offsets + // NOTES: the group is always concatenated with the outer dimension + uint32_t m_idx = scheduler.template get_global_idx<(kGemmType == GemmType::MGroupedMasked), IndexType::MN> ( + shape_m, BLOCK_M, m_block_idx); + uint32_t n_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::K), IndexType::MN> ( + shape_n, BLOCK_N, n_block_idx, m_block_idx); + + // NOTES: `k_idx` is actually the k index default for K-major, while `k_b_idx` may be MN-major + // And for all m-grouped GEMMs, A must be K-majored + DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kGemmType == GemmType::KGroupedContiguous or kGemmType == GemmType::Batched or + kMajorA == cute::UMMA::Major::K, "Invalid major"); + uint32_t k_idx = k_block_idx * BLOCK_K; + uint32_t k_a_idx = scheduler.template get_global_idx<(kMajorA == cute::UMMA::Major::MN), IndexType::K> ( + shape_k, BLOCK_K, k_block_idx, m_block_idx); + uint32_t k_b_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::MN), IndexType::K> ( + shape_k, BLOCK_K, k_block_idx, m_block_idx); + + // Add 2 CTA offsets + if constexpr (kNumMulticast > 1) { + m_idx += kIsMulticastOnA ? (cute::block_rank_in_cluster() * LOAD_BLOCK_M) : 0; + n_idx += kIsMulticastOnA ? 0 : (cute::block_rank_in_cluster() * LOAD_BLOCK_N); + } + + // Issue TMAs + constexpr bool kIsBatchedMM = (kGemmType == GemmType::Batched); + const uint32_t batch_idx = (kIsBatchedMM ? 
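+            // NOTES (illustrative): `advance_pipeline` walks the stage ring and flips `phase`
+            // each time it wraps. E.g. with kNumStages = 4, (stage_idx, phase) follows
+            // (0,0) (1,0) (2,0) (3,0) (0,1) (1,1) ..., so `empty_barriers[stage_idx]->wait(phase ^ 1)`
+            // alternates the expected parity on every pass over the ring.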
scheduler.current_group_idx : 0); + if constexpr (kMajorA == cute::UMMA::Major::K) + tma_copy( + &tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_a_idx, m_idx, kNumMulticast, batch_idx); + if constexpr (kMajorA == cute::UMMA::Major::MN) + tma_copy( + &tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], m_idx, k_a_idx, kNumMulticast, batch_idx); + if constexpr (kMajorB == cute::UMMA::Major::K) + tma_copy( + &tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_b_idx, n_idx, kNumMulticast, batch_idx); + if constexpr (kMajorB == cute::UMMA::Major::MN) + tma_copy( + &tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], n_idx, k_b_idx, kNumMulticast, batch_idx); + + // Arrive at full barriers + constexpr uint32_t kNumArrivalBytes = SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE; + if (is_leader_cta) { + full_barriers[stage_idx]->arrive_and_expect_tx(kNumArrivalBytes * kNumMulticast); + } else { + full_barriers[stage_idx]->arrive(0u); + } + } + } + } else if (warp_idx == 1 and is_leader_cta) { + // MMA issue warp + // NOTES: only the leader CTA will do this + // Make instruction descriptor + // TODO: refactor `UMMA_M` calculation + constexpr uint32_t UMMA_M = LAYOUT_AD_M * (kIsMulticastOnA ? 1 : kNumMulticast); + constexpr uint32_t UMMA_N = BLOCK_N * (kIsMulticastOnA ? kNumMulticast : 1); + constexpr uint32_t UMMA_K = 32 / sizeof(cutlass::bfloat16_t); + auto instr_desc = cute::UMMA::make_instr_desc(); + + DG_STATIC_ASSERT(kNumStages <= 32, "Too many stages"); + // Merged stages only happens in NT normal GEMM cases + constexpr uint32_t BLOCK_ATOM_K = BLOCK_K / kNumStagesPerMerge; + auto a_desc = make_umma_desc(smem_a[0], 0, 0); + auto b_desc = make_umma_desc(smem_b[0], 0, 0); + uint32_t a_desc_lo = lane_idx < kNumStages ? a_desc.lo + lane_idx * SMEM_A_SIZE_PER_STAGE / 16 : 0u; + uint32_t b_desc_lo = lane_idx < kNumStages ? 
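+        // NOTES (illustrative): the low 32 bits of a UMMA shared memory descriptor encode the
+        // address in 16-byte units, so lane i (i < kNumStages) pre-computes the stage-i descriptor
+        // by adding i * SMEM_A_SIZE_PER_STAGE / 16; the MMA loop later fetches the right lane with
+        // `__shfl_sync(0xffffffff, a_desc_lo, stage_idx)` instead of rebuilding descriptors per stage.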
b_desc.lo + lane_idx * SMEM_B_SIZE_PER_STAGE / 16 : 0u; + + // Checks for MMA instructions + // NOTES: CUTLASS does not have such checks except the MMA traits, but we are not using these traits + DG_STATIC_ASSERT((UMMA_M == 64 and UMMA_N % 8 == 0 and 8 <= UMMA_N and UMMA_N <= 256) or + (UMMA_M == 128 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256) or + (UMMA_M == 256 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256), + "Invalid MMA instruction shape"); + + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + // Wait tensor memory empty barrier arrival + auto accum_stage_idx = scheduler.current_iter % kNumEpilogueStages; + auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1; + tmem_empty_barriers[accum_stage_idx]->wait(accum_phase_idx ^ 1); + tcgen05_after_thread_sync(); + + // UMMA and empty barrier arrival alias + auto umma_arrive = [](const uint64_t* barrier) { + if constexpr (kNumMulticast == 1) { + cutlass::arch::umma_arrive(barrier); + } else { + constexpr uint16_t kCTAMask = (1 << kNumMulticast) - 1; + cutlass::arch::umma_arrive_multicast_2x1SM(barrier, kCTAMask); + } + }; + auto empty_barrier_arrive = [&](const bool& do_tmem_full_arrive) { + umma_arrive(reinterpret_cast(empty_barriers[stage_idx])); + + // NOTES: the tensor memory accumulator pipeline has nothing to do with multicasting + if (do_tmem_full_arrive) + umma_arrive(reinterpret_cast(tmem_full_barriers[accum_stage_idx])); + }; + + // Launch MMAs + const auto& num_total_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K); + for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) { + // Wait TMA arrival + full_barriers[stage_idx]->wait(phase); + tcgen05_after_thread_sync(); + + // Issue UMMA in the leader CTA + using mma_t = cute::conditional_t; + const auto& runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc); + const auto& a_desc_base_lo = __shfl_sync(0xffffffff, a_desc_lo, static_cast(stage_idx)); + const auto& b_desc_base_lo = __shfl_sync(0xffffffff, b_desc_lo, static_cast(stage_idx)); + if (cute::elect_one_sync()) { + #pragma unroll + for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) { + uint32_t atom_k_idx = k * UMMA_K / BLOCK_ATOM_K; + b_desc.lo = advance_umma_desc_lo(b_desc_base_lo, atom_k_idx * LOAD_BLOCK_N * BLOCK_ATOM_K, k * UMMA_K % BLOCK_ATOM_K); + #pragma unroll + for (uint32_t w = 0; w < kNumMWaves; ++ w) { + DG_STATIC_ASSERT((WAVE_BLOCK_M * BLOCK_K) % 128 == 0, "Invalid swizzling offset"); + a_desc.lo = advance_umma_desc_lo(a_desc_base_lo, atom_k_idx * LOAD_BLOCK_M * BLOCK_ATOM_K + w * WAVE_BLOCK_M * BLOCK_ATOM_K, k * UMMA_K % BLOCK_ATOM_K); + mma_t::fma(a_desc, b_desc, + accum_stage_idx * kNumMWaves * BLOCK_N + w * BLOCK_N, + k_block_idx > 0 or k > 0, + runtime_instr_desc); + } + } + } + + // Commit to the mbarrier object + // No explicit `tcgen05.fence::before_thread_sync` is needed, as this is implicitly performed by `tcgen05.commit` + empty_barrier_arrive(k_block_idx == num_total_k_blocks - 1); + + // Let tensor cores relax for lower possibility of frequency drop + DG_STATIC_ASSERT(kTensorCoreUtilControl > 0, "Invalid tensor utilization control"); + if constexpr (kTensorCoreUtilControl < 100) { + // For utilization control + umma_arrive(reinterpret_cast(tensor_core_full_barrier)); + + // Wait for last UMMA to be done + tensor_core_full_barrier->wait(tensor_core_phase); + tensor_core_phase ^= 1; + + // Sleep for certain cycles + constexpr static uint64_t 
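+                    // NOTES (worked example): the line below estimates the UMMA duration as
+                    // 2*M*N*K FLOPs divided by an assumed 8192 FLOP/cycle throughput. E.g. with
+                    // LAYOUT_AD_M = 128, kNumMWaves = 1, BLOCK_N = 256, BLOCK_K = 64:
+                    //     kNumUMMACycles  = 2 * 128 * 256 * 64 / 8192 = 512 cycles,
+                    // and with kTensorCoreUtilControl = 80 the idle window is
+                    //     kNumDummyCycles = (100 - 80) * 512 / 80 = 128 cycles.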
kNumUMMACycles = (2ull * LAYOUT_AD_M * kNumMWaves * BLOCK_N * BLOCK_K) / 8192ull; + constexpr static uint64_t kNumDummyCycles = (100ull - kTensorCoreUtilControl) * kNumUMMACycles / kTensorCoreUtilControl; + const auto& start_clock = clock64(); + if (cute::elect_one_sync()) + while (clock64() - start_clock < kNumDummyCycles) {} + __syncwarp(); + } + } + } + + // To safely deconstruct barriers, we need another round of waits + const auto& iter_idx = scheduler.current_iter - 1; + if (kNumMulticast > 1 and iter_idx >= 0) { + const auto& accum_phase_idx = (iter_idx / kNumEpilogueStages) & 1; + tmem_empty_barriers[iter_idx % kNumEpilogueStages]->wait(accum_phase_idx); + } + } else if (warp_idx >= kNumNonEpilogueThreads / 32 and warp_idx < (kNumNonEpilogueThreads + kNumUMMAStoreThreads) / 32) { + // Epilogue warp groups + const auto epilogue_warp_idx = warp_idx - (kNumNonEpilogueThreads / 32); + + // NOTES: tensor memory addresses are simplified, as the hardware will ignore the warp index bits, + // i.e., no need for `tmem_ptr |= (epilogue_warp_idx * 32) << 16`. + // NOTES: we also forbid two CTAs to share the same SM and its tensor memory + DG_TRAP_ONLY_DEVICE_ASSERT(ld_shared(tmem_ptr_in_smem) == 0); + + // TMA checks + constexpr uint32_t kNumBankGroupBytes = 16; + constexpr uint32_t kNumElemsPerBankGroup = kNumBankGroupBytes / sizeof(cd_dtype_t); + DG_STATIC_ASSERT(kSwizzleCDMode > 0, "TMA D must be swizzled"); + DG_STATIC_ASSERT(STORE_BLOCK_N % kNumElemsPerBankGroup == 0, "Invalid swizzling"); + + // Share store pipeline between blocks + uint32_t tma_stage_idx = 0; + auto advance_store_pipeline = [&]() { + tma_stage_idx = (tma_stage_idx + 1) % kNumTMAStoreStages; + }; + + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + auto accum_stage_idx = scheduler.current_iter % kNumEpilogueStages; + auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1; + + // Wait UMMA arrival + tmem_full_barriers[accum_stage_idx]->wait(accum_phase_idx); + tcgen05_after_thread_sync(); + + // Load from tensor memory into registers, and write shared memory with STSM + DG_STATIC_ASSERT(kNumEpilogueThreads == 128, "Epilogue threads not enough"); + DG_STATIC_ASSERT(BLOCK_N % STORE_BLOCK_N == 0, "Invalid block sizes"); + + // Iterate over M waves + #pragma unroll + for (uint32_t w = 0; w < kNumMWaves; ++ w) { + // Issue every swizzled atom and pipeline STSM and TMA store + constexpr uint32_t kNumStores = BLOCK_N / STORE_BLOCK_N; + #pragma unroll + for (uint32_t s = 0; s < kNumStores; ++ s, advance_store_pipeline()) { + // Wait shared memory to be released + if (epilogue_warp_idx == 0) + cute::tma_store_wait(); + cutlass::arch::NamedBarrier::sync(kNumUMMAStoreThreads, 0); + + // The pipeline stage + const auto m_idx = scheduler.template get_global_idx<(not is_m_grouped_contiguous(kGemmType)), IndexType::MN>(shape_m, BLOCK_M, m_block_idx) + w * WAVE_BLOCK_M; + const auto n_idx = n_block_idx * BLOCK_N + s * STORE_BLOCK_N; + + // Store into shared memory + #pragma unroll + for (uint32_t i = 0; i < STORE_BLOCK_N / kNumElemsPerBankGroup; ++ i) { + // Calculate the index of the bank group to be written in the atom + auto bank_group_index = i + lane_idx * (kSwizzleCDMode / kNumBankGroupBytes); + + // Reshape the atom in another view and swizzle + // - original: `(LAYOUT_AD_M, kSwizzleCDMode / kNumBankGroupBytes)` + // - new: `(LAYOUT_AD_M * kSwizzleCDMode / kNumBankGroupBytes / 8, 8)` + // NOTES: "8" is the number of bank groups, "16" is the swizzling pattern + 
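+                    // NOTES (illustrative): with kSwizzleCDMode = 128 there are 128 / 16 = 8 bank
+                    // groups per row, so the shortcut path applies and `col ^= row % 8` below
+                    // spreads consecutive rows across different bank groups, mirroring the
+                    // 128-byte swizzle pattern that the CD tensor map expects for the TMA store.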
constexpr bool kHasShortcut = (kSwizzleCDMode / kNumBankGroupBytes) == 8; + auto row = kHasShortcut ? (i / 8 + lane_idx) : (bank_group_index / 8); + auto col = kHasShortcut ? (i) : (bank_group_index % 8); + col ^= row % (kSwizzleCDMode / 16); + + // Source and destination memory address + uint32_t tmem_addr = accum_stage_idx * kNumMWaves * BLOCK_N + // Accumulator offset + w * BLOCK_N + // Wave offset + s * STORE_BLOCK_N + i * kNumElemsPerBankGroup; // In-block offset + auto smem_ptr = reinterpret_cast(smem_cd[tma_stage_idx]) + // Base pointer + epilogue_warp_idx * 32 * kSwizzleCDMode + // Warp offset + row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset + + // Load from tensor memory, store into shared memory + uint32_t values[kNumElemsPerBankGroup]; + if constexpr (cute::is_same_v) { + // For FP32 output, read and store + DG_STATIC_ASSERT(kNumElemsPerBankGroup == 4, "Invalid type"); + cute::SM100_TMEM_LOAD_32dp32b4x::copy(tmem_addr, + values[0], values[1], values[2], values[3]); + cutlass::arch::fence_view_async_tmem_load(); + st_shared(smem_ptr, values[0], values[1], values[2], values[3]); + } else { + // For BF16 output, read, cast and store + DG_STATIC_ASSERT(kNumElemsPerBankGroup == 8 and cute::is_same_v, "Invalid type"); + cute::SM100_TMEM_LOAD_32dp32b8x::copy(tmem_addr, + values[0], values[1], values[2], values[3], + values[4], values[5], values[6], values[7]); + cutlass::arch::fence_view_async_tmem_load(); + st_shared(smem_ptr, + cast_into_bf16_and_pack(values[0], values[1]), + cast_into_bf16_and_pack(values[2], values[3]), + cast_into_bf16_and_pack(values[4], values[5]), + cast_into_bf16_and_pack(values[6], values[7])); + } + } + + // Notify tensor memory empty (only at the leader CTA) arrival ASAP + // NOTES: only the last stage needs to do this + if (w == kNumMWaves - 1 and s == BLOCK_N / STORE_BLOCK_N - 1) { + tcgen05_before_thread_sync(); + tmem_empty_barriers[accum_stage_idx]->arrive(0u); + } + __syncwarp(); + + // Synchronize all threads and issue TMA + cute::tma_store_fence(); + cutlass::arch::NamedBarrier::sync(kNumUMMAStoreThreads, 0); + if (epilogue_warp_idx == 0 and cute::elect_one_sync()) { + if constexpr (kGemmType == GemmType::Batched) { + using cute_tma_t = cute::conditional_t; + cute_tma_t::copy(&tensor_map_cd, smem_cd[tma_stage_idx], + n_idx, m_idx, scheduler.current_group_idx); + } else { + using cute_tma_t = cute::conditional_t; + cute_tma_t::copy(&tensor_map_cd, smem_cd[tma_stage_idx], n_idx, m_idx); + } + cute::tma_store_arrive(); + } + } + } + } + + // Deallocate tensor memory by the last UMMA store warp + // NOTES: warp 0 is waiting TMA store + if (epilogue_warp_idx == kNumUMMAStoreThreads / 32 - 1) + Allocator().free(0, kNumTmemCols); + } +#else + if (blockIdx.x == 0 and threadIdx.x == 0) + DG_DEVICE_ASSERT(false and "This kernel only support sm_100f"); +#endif +} + +}; // namespace deep_gemm + +#pragma clang diagnostic pop diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/deep_gemm/impls/sm100_bmk_bnk_mn.cuh b/build/torch29-cxx11-cu126-aarch64-linux/include/deep_gemm/impls/sm100_bmk_bnk_mn.cuh new file mode 100644 index 0000000000000000000000000000000000000000..86303347d9c7a3a93b65a16d6ad4a7b73eb2ad1a --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/deep_gemm/impls/sm100_bmk_bnk_mn.cuh @@ -0,0 +1,265 @@ +#pragma once + +#include +#include +#include + +#include +#include + +namespace deep_gemm { + +using namespace deep_gemm::sm100; + +template +__global__ void __launch_bounds__(kNumThreads, 1) 
+sm100_bmn_bnk_mn_gemm_impl(uint32_t shape_s, + const __grid_constant__ cute::TmaDescriptor tensor_map_a, + const __grid_constant__ cute::TmaDescriptor tensor_map_b, + const __grid_constant__ cute::TmaDescriptor tensor_map_d) { +#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) or defined(__CLION_IDE__) + using Barrier = cutlass::arch::ClusterTransactionBarrier; + + // Configs + constexpr uint32_t LAYOUT_AD_M = 128; + constexpr uint32_t kNumTMAStoreStages = 2; + + // Utils + const auto warp_idx = cutlass::canonical_warp_idx_sync(); + const auto lane_idx = get_lane_idx(); + DG_STATIC_ASSERT(BLOCK_M == LAYOUT_AD_M and BLOCK_N == 128 and BLOCK_K == 64, "Invalid block size"); + DG_STATIC_ASSERT(kSwizzleABMode == 128 and kSwizzleCDMode == 128, "Invalid swizzle mode"); + + // Align to 1024 bytes for swizzle-128B + extern __shared__ __align__(1024) uint8_t smem_buffer[]; + + // Shared memory sizes + constexpr uint32_t SMEM_CD_SIZE_PER_STAGE = BLOCK_M * kSwizzleCDMode; + constexpr uint32_t SMEM_CD_SIZE = SMEM_CD_SIZE_PER_STAGE * kNumTMAStoreStages; + constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(cutlass::bfloat16_t); + constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(cutlass::bfloat16_t); + + // Prefetch TMA descriptors at the very beginning + if (warp_idx == 0 and cute::elect_one_sync()) { + cute::prefetch_tma_descriptor(&tensor_map_a); + cute::prefetch_tma_descriptor(&tensor_map_b); + cute::prefetch_tma_descriptor(&tensor_map_d); + } + + // Real tensor memory size and offsets + constexpr uint32_t kNumTmemCols = get_num_aligned_tmem_cols(); + + // Fill D/A/B + auto smem_cd = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + (i * SMEM_CD_SIZE_PER_STAGE)); + }); + auto smem_a = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + (SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE)); + }); + auto smem_b = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + (SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE)); + }); + + // Fill barriers + auto barrier_start_ptr = reinterpret_cast(smem_buffer + SMEM_CD_SIZE + + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE)); + auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); }); + auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); }); + auto tmem_full_barrier = barrier_start_ptr + (kNumStages * 2); + + // Fill the tensor memory pointer + auto tmem_ptr_in_smem = reinterpret_cast(barrier_start_ptr + kNumStages * 2 + 1); + DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns"); + + // Initialize barriers + if (warp_idx == 1 and cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < kNumStages; ++ i) { + full_barriers[i]->init(1); + empty_barriers[i]->init(1); + } + tmem_full_barrier->init(1); + + // Make initialized barrier visible in async proxy + cutlass::arch::fence_barrier_init(); + } else if (warp_idx == 2) { + // Allocate tensor memory + cute::TMEM::Allocator1Sm().allocate(kNumTmemCols, tmem_ptr_in_smem); + } + __syncthreads(); + + // Block indices + const uint32_t num_n_blocks = ceil_div(SHAPE_N, BLOCK_N); + const uint32_t num_mn_blocks = num_n_blocks * ceil_div(SHAPE_M, BLOCK_M); + const uint32_t mn_block_idx = blockIdx.x % num_mn_blocks; + const uint32_t sk_block_idx = blockIdx.x / num_mn_blocks; + const uint32_t n_block_idx = mn_block_idx % 
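+    // NOTES (worked example): blockIdx.x is decomposed into (sk_block_idx, m_block_idx, n_block_idx).
+    // E.g. SHAPE_M = SHAPE_N = 256 with BLOCK_M = BLOCK_N = 128 gives num_n_blocks = 2 and
+    // num_mn_blocks = 4; blockIdx.x = 5 then maps to sk_block_idx = 1, n_block_idx = 1,
+    // m_block_idx = 0, i.e. this CTA reduces k-slices [kSplitFactor, 2 * kSplitFactor) of the
+    // flattened s*k dimension and accumulates its partial result into D via the TMA reduce-add below.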
num_n_blocks; + const uint32_t m_block_idx = mn_block_idx / num_n_blocks; + const uint32_t num_total_stages = cute::min(kSplitFactor, shape_s * (SHAPE_K / BLOCK_K) - sk_block_idx * kSplitFactor); + + if (warp_idx == 0) { + // TMA load warp + for (uint32_t s = 0; s < num_total_stages; ++ s) { + const auto& stage_idx = s % kNumStages; + empty_barriers[stage_idx]->wait(((s / kNumStages) & 1) ^ 1); + + uint32_t m_idx = BLOCK_M * m_block_idx; + uint32_t n_idx = BLOCK_N * n_block_idx; + uint32_t sk_idx = (sk_block_idx * kSplitFactor + s) * BLOCK_K; + uint32_t k_idx = sk_idx % SHAPE_K; + uint32_t s_idx = sk_idx / SHAPE_K; + + // Issue TMAs + if (cute::elect_one_sync()) { + tma_copy(&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_idx, m_idx + s_idx * SHAPE_M); + tma_copy(&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_idx, n_idx + s_idx * SHAPE_N); + } + + // Arrive at full barriers + constexpr uint32_t kNumArrivalBytes = SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE; + if (cute::elect_one_sync()) + full_barriers[stage_idx]->arrive_and_expect_tx(kNumArrivalBytes); + } + } else if (warp_idx == 1) { + // MMA issue warp + // NOTES: only the leader CTA will do this + // Make instruction descriptor + constexpr uint32_t UMMA_M = LAYOUT_AD_M; + constexpr uint32_t UMMA_N = BLOCK_N; + constexpr uint32_t UMMA_K = 32 / sizeof(cutlass::bfloat16_t); + auto instr_desc = cute::UMMA::make_instr_desc(); + + DG_STATIC_ASSERT(kNumStages <= 32, "Too many stages"); + auto a_desc = make_umma_desc(smem_a[0], 0, 0); + auto b_desc = make_umma_desc(smem_b[0], 0, 0); + uint32_t a_desc_lo = lane_idx < kNumStages ? a_desc.lo + lane_idx * SMEM_A_SIZE_PER_STAGE / 16 : 0u; + uint32_t b_desc_lo = lane_idx < kNumStages ? b_desc.lo + lane_idx * SMEM_B_SIZE_PER_STAGE / 16 : 0u; + + // Checks for MMA instructions + // NOTES: CUTLASS does not have such checks except the MMA traits, but we are not using these traits + DG_STATIC_ASSERT((UMMA_M == 64 and UMMA_N % 8 == 0 and 8 <= UMMA_N and UMMA_N <= 256) or + (UMMA_M == 128 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256) or + (UMMA_M == 256 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256), + "Invalid MMA instruction shape"); + + // Wait tensor memory empty barrier arrival + tcgen05_after_thread_sync(); + + // Launch MMAs + for (uint32_t s = 0; s < num_total_stages; ++ s) { + // Wait TMA arrival + const auto& stage_idx = s % kNumStages; + full_barriers[stage_idx]->wait((s / kNumStages) & 1); + tcgen05_after_thread_sync(); + + // Issue UMMA in the leader CTA + const auto& runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc); + const auto& a_desc_base_lo = __shfl_sync(0xffffffff, a_desc_lo, stage_idx); + const auto& b_desc_base_lo = __shfl_sync(0xffffffff, b_desc_lo, stage_idx); + if (cute::elect_one_sync()) { + #pragma unroll + for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) { + a_desc.lo = advance_umma_desc_lo(a_desc_base_lo, 0, k * UMMA_K); + b_desc.lo = advance_umma_desc_lo(b_desc_base_lo, 0, k * UMMA_K); + SM100_MMA_F16BF16_SS::fma(a_desc, b_desc, 0, s > 0 or k > 0, runtime_instr_desc); + } + } + + // Commit to the mbarrier object + // No explicit `tcgen05.fence::before_thread_sync` is needed, as this is implicitly performed by `tcgen05.commit` + cutlass::arch::umma_arrive(reinterpret_cast(empty_barriers[stage_idx])); + } + cutlass::arch::umma_arrive(reinterpret_cast(tmem_full_barrier)); + } + + // NOTES: tensor memory addresses are simplified, as the hardware will ignore the warp index bits, + // i.e., no need for `tmem_ptr 
|= (warp_idx * 32) << 16`. + // NOTES: we also forbid two CTAs to share the same SM and its tensor memory + if (warp_idx == 2) + DG_TRAP_ONLY_DEVICE_ASSERT(ld_shared(tmem_ptr_in_smem) == 0); + + // TMA checks + constexpr uint32_t kNumBankGroupBytes = 16; + constexpr uint32_t kNumElemsPerBankGroup = kNumBankGroupBytes / sizeof(float); + constexpr uint32_t STORE_BLOCK_N = kSwizzleCDMode / sizeof(float); + DG_STATIC_ASSERT(kSwizzleCDMode > 0, "TMA D must be swizzled"); + DG_STATIC_ASSERT(STORE_BLOCK_N % kNumElemsPerBankGroup == 0, "Invalid swizzling"); + + // Wait UMMA arrival + tmem_full_barrier->wait(0); + tcgen05_after_thread_sync(); + + // Load from tensor memory into registers, and write shared memory with STSM + DG_STATIC_ASSERT(BLOCK_N % STORE_BLOCK_N == 0, "Invalid block sizes"); + + // Issue every swizzled atom and pipeline STSM and TMA store + constexpr uint32_t kNumStores = BLOCK_N / STORE_BLOCK_N; + #pragma unroll + for (uint32_t s = 0; s < kNumStores; ++ s) { + // Wait shared memory to be released + if (s >= kNumTMAStoreStages) { + if (warp_idx == 0 and cute::elect_one_sync()) + cute::tma_store_wait(); + cutlass::arch::NamedBarrier(kNumThreads).sync(); + } + + // The pipeline stage + const auto tma_stage_idx = s % kNumTMAStoreStages; + const auto m_idx = m_block_idx * BLOCK_M; + const auto n_idx = n_block_idx * BLOCK_N + s * STORE_BLOCK_N; + + // Store into shared memory + #pragma unroll + for (uint32_t i = 0; i < STORE_BLOCK_N / kNumElemsPerBankGroup; ++ i) { + // Calculate the index of the bank group to be written in the atom + auto bank_group_index = i + lane_idx * (kSwizzleCDMode / kNumBankGroupBytes); + + // Reshape the atom in another view and swizzle + // - original: `(LAYOUT_AD_M, kSwizzleCDMode / kNumBankGroupBytes)` + // - new: `(LAYOUT_AD_M * kSwizzleCDMode / kNumBankGroupBytes / 8, 8)` + // NOTES: "8" is the number of bank groups, "16" is the swizzling pattern + constexpr bool kHasShortcut = (kSwizzleCDMode / kNumBankGroupBytes) == 8; + auto row = kHasShortcut ? (i / 8 + lane_idx) : (bank_group_index / 8); + auto col = kHasShortcut ? 
(i) : (bank_group_index % 8); + col ^= row % (kSwizzleCDMode / 16); + + // Source and destination memory address + uint32_t tmem_addr = s * STORE_BLOCK_N + i * kNumElemsPerBankGroup; // In-block offset + auto smem_ptr = reinterpret_cast(smem_cd[tma_stage_idx]) + // Base pointer + warp_idx * 32 * kSwizzleCDMode + // Warp offset + row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset + + // Load from tensor memory, store into shared memory + uint32_t values[kNumElemsPerBankGroup]; + DG_STATIC_ASSERT(kNumElemsPerBankGroup == 4, "Invalid type"); + cute::SM100_TMEM_LOAD_32dp32b4x::copy(tmem_addr, + values[0], values[1], values[2], values[3]); + cutlass::arch::fence_view_async_tmem_load(); + st_shared(smem_ptr, values[0], values[1], values[2], values[3]); + } + + // Synchronize all threads and issue TMA + cute::tma_store_fence(); + cutlass::arch::NamedBarrier(kNumThreads).sync(); + if (warp_idx == 0 and cute::elect_one_sync()) { + cute::SM90_TMA_REDUCE_ADD_2D::copy(&tensor_map_d, smem_cd[tma_stage_idx], n_idx, m_idx); + cute::tma_store_arrive(); + } + } + + // Deallocate tensor memory by warp 1 + // NOTES: warp 0 is doing TMA stores + if (warp_idx == 1) + cute::TMEM::Allocator1Sm().free(0, kNumTmemCols); + +#else + if (blockIdx.x == 0 and threadIdx.x == 0) + DG_DEVICE_ASSERT(false and "This kernel only support sm_100f"); +#endif +} + +} diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh b/build/torch29-cxx11-cu126-aarch64-linux/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh new file mode 100644 index 0000000000000000000000000000000000000000..45a603add3f494aed51dce7aec53b5545bdc23f4 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/deep_gemm/impls/sm100_fp8_gemm_1d1d.cuh @@ -0,0 +1,563 @@ +#pragma once +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wunknown-attributes" + +#include + +#include +#include +#include +#include + +namespace deep_gemm { + +using namespace deep_gemm::sm100; + +template +__global__ void __launch_bounds__(kNumNonEpilogueThreads + kNumEpilogueThreads, 1) +sm100_fp8_gemm_1d1d_impl(int* grouped_layout, + uint32_t shape_m, uint32_t shape_n, uint32_t shape_k, + const __grid_constant__ cute::TmaDescriptor tensor_map_a, + const __grid_constant__ cute::TmaDescriptor tensor_map_b, + const __grid_constant__ cute::TmaDescriptor tensor_map_sfa, + const __grid_constant__ cute::TmaDescriptor tensor_map_sfb, + const __grid_constant__ cute::TmaDescriptor tensor_map_cd) { +#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) or defined(__CLION_IDE__) + using Barrier = cutlass::arch::ClusterTransactionBarrier; + using Allocator = cute::conditional_t; + + // GEMM with accumulation must have FP32 output + if constexpr (kWithAccumulation) + DG_STATIC_ASSERT(cute::is_same_v, "Invalid C/D data dtype"); + + // Configs + constexpr uint32_t LAYOUT_AD_M = 128; + constexpr uint32_t WAVE_BLOCK_M = cute::min(BLOCK_M, LAYOUT_AD_M); + constexpr uint32_t kNumMWaves = BLOCK_M / WAVE_BLOCK_M; + constexpr uint32_t kNumTMAStoreStages = 2; + constexpr uint32_t kNumUTCCPAlignedElems = 128; + DG_STATIC_ASSERT(BLOCK_K == 128, "Invalid block K"); + DG_STATIC_ASSERT(BLOCK_M % WAVE_BLOCK_M == 0 and 2 % kNumMWaves == 0, "Invalid block M"); + + constexpr uint32_t kNumSFAStagesPerLoad = kGranKA == 32 ? 1 : 4; + constexpr uint32_t kNumSFBStagesPerLoad = kGranKB == 32 ? 
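+    // NOTES (illustrative reading of the constants above): one packed SF word appears to cover
+    // 4 scale groups, so with a K-granularity of 128 (= BLOCK_K) a single TMA fetch brings scales
+    // for 4 consecutive K stages and SFA/SFB are only reloaded every 4 pipeline stages; with a
+    // granularity of 32, the 4 packed scales cover exactly one BLOCK_K and a fresh load is issued
+    // every stage (matching shape_sfa_k / shape_sfb_k = ceil_div(shape_k, kGranK * 4)).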
1 : 4; + DG_STATIC_ASSERT(kGranKA == 32 or kGranKA == 128, "Invalid granularity K for A"); + DG_STATIC_ASSERT(kGranKB == 32 or kGranKB == 128, "Invalid granularity K for B"); + + // Overwrite shape constants if the compiler gives + shape_m = SHAPE_M != 0 ? SHAPE_M : shape_m; + shape_n = SHAPE_N != 0 ? SHAPE_N : shape_n; + shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k; + const uint32_t shape_sfa_k = ceil_div(shape_k, kGranKA * 4); + const uint32_t shape_sfb_k = ceil_div(shape_k, kGranKB * 4); + + // Utils + bool is_leader_cta = cute::block_rank_in_cluster() == 0; + const auto warp_idx = cutlass::canonical_warp_idx_sync(); + const auto lane_idx = get_lane_idx(); + + // Align to 1024 bytes for swizzle-128B + extern __shared__ __align__(1024) uint8_t smem_buffer[]; + + // 2-CTA MMA + constexpr uint32_t LOAD_BLOCK_M = BLOCK_M / (kIsMulticastOnA ? kNumMulticast: 1); + constexpr uint32_t LOAD_BLOCK_N = BLOCK_N / (kIsMulticastOnA ? 1 : kNumMulticast); + constexpr uint32_t STORE_BLOCK_M = cute::min(BLOCK_M, LAYOUT_AD_M); + constexpr uint32_t STORE_BLOCK_N = kSwizzleCDMode / sizeof(cd_dtype_t); + constexpr uint32_t kNumUMMAStoreThreads = STORE_BLOCK_M; + DG_STATIC_ASSERT(not kIsMulticastOnA or kNumMulticast == 1, "Invalid multicast"); + DG_STATIC_ASSERT(LOAD_BLOCK_M == BLOCK_M, "Only support tensor memory layout A/D"); + DG_STATIC_ASSERT(kNumMulticast == 1 or kNumMulticast == 2, "Only support 1/2 multicast"); + DG_STATIC_ASSERT(kNumUMMAStoreThreads % 32 == 0, "Invalid store block M"); + + // Share memory sizes + constexpr uint32_t SMEM_CD_SIZE_PER_STAGE = STORE_BLOCK_M * kSwizzleCDMode; + constexpr uint32_t SMEM_CD_SIZE = SMEM_CD_SIZE_PER_STAGE * kNumTMAStoreStages; + constexpr uint32_t SMEM_A_SIZE_PER_STAGE = LOAD_BLOCK_M * BLOCK_K * sizeof(a_dtype_t); + constexpr uint32_t SMEM_B_SIZE_PER_STAGE = LOAD_BLOCK_N * BLOCK_K * sizeof(b_dtype_t); + constexpr uint32_t SF_BLOCK_M = constexpr_align(BLOCK_M, kNumUTCCPAlignedElems); + constexpr uint32_t SF_BLOCK_N = constexpr_align(BLOCK_N, kNumUTCCPAlignedElems); + constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = SF_BLOCK_M * sizeof(uint32_t); + constexpr uint32_t SMEM_SFB_SIZE_PER_STAGE = SF_BLOCK_N * sizeof(uint32_t); + DG_STATIC_ASSERT(SMEM_CD_SIZE % 1024 == 0 and SMEM_A_SIZE_PER_STAGE % 1024 == 0 and SMEM_B_SIZE_PER_STAGE % 1024 == 0, + "Shared memory of A/B must be aligned to 1024 bytes"); + DG_STATIC_ASSERT(kNumTMAStoreStages >= 1, "Invalid number of TMA stages"); + + // NOTES: Make sure we have enough shared memory for UMMA padding + static constexpr uint32_t UMMA_A_SIZE_PER_STAGE = constexpr_align(LOAD_BLOCK_M, LAYOUT_AD_M) * BLOCK_K * sizeof(a_dtype_t); + DG_STATIC_ASSERT(UMMA_A_SIZE_PER_STAGE <= SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE * kNumStages, "Memory Out of bound for UMMA"); + + // Automatically deduce the number of epilogue stages (1 or 2), according to the tensor memory size + // TODO: test cases of `kNumMWaves == 2 and kNumEpilogueStages == 2` + constexpr uint32_t kNumSFATmemCols = SF_BLOCK_M / 32; + constexpr uint32_t kNumSFBTmemCols = SF_BLOCK_N / 32; + constexpr uint32_t kNumEpilogueStages = (2 * kNumMWaves * BLOCK_N + kNumSFATmemCols + kNumSFBTmemCols) > 512 ? 
1 : 2; + + // Real tensor memory size and offsets + constexpr uint32_t kNumAccumTmemCols = kNumEpilogueStages * kNumMWaves * BLOCK_N; + constexpr uint32_t kNumTmemCols = get_num_aligned_tmem_cols(); + constexpr uint32_t kTmemStartColOfSFA = kNumAccumTmemCols; + constexpr uint32_t kTmemStartColOfSFB = kNumAccumTmemCols + kNumSFATmemCols; + + // Prefetch TMA descriptors at the very beginning + if (warp_idx == 0 and cute::elect_one_sync()) { + cute::prefetch_tma_descriptor(&tensor_map_a); + cute::prefetch_tma_descriptor(&tensor_map_b); + cute::prefetch_tma_descriptor(&tensor_map_sfa); + cute::prefetch_tma_descriptor(&tensor_map_sfb); + cute::prefetch_tma_descriptor(&tensor_map_cd); + } + + // D/A/B shared memory + auto smem_cd = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + i * SMEM_CD_SIZE_PER_STAGE); + }); + auto smem_a = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE); + }); + auto smem_b = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE); + }); + + // SFA/SFB shared memory + auto sf_start_ptr = smem_buffer + SMEM_CD_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE); + auto smem_sfa = PatternVisitor([=](const uint32_t& i) { + return reinterpret_cast(sf_start_ptr + i * SMEM_SFA_SIZE_PER_STAGE); + }); + auto smem_sfb = PatternVisitor([=](const uint32_t& i) { + return reinterpret_cast(sf_start_ptr + kNumStages * SMEM_SFA_SIZE_PER_STAGE + i * SMEM_SFB_SIZE_PER_STAGE); + }); + + // Fill barriers + auto barrier_start_ptr = reinterpret_cast(smem_buffer + + SMEM_CD_SIZE + + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE) + + kNumStages * (SMEM_SFA_SIZE_PER_STAGE + SMEM_SFB_SIZE_PER_STAGE)); + auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); }); + auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); }); + auto with_sf_full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + i); }); + auto tmem_full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 3 + i); }); + auto tmem_empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 3 + kNumEpilogueStages + i); }); + + // Fill the tensor memory pointer + auto tmem_ptr_in_smem = reinterpret_cast(barrier_start_ptr + kNumStages * 3 + kNumEpilogueStages * 2); + DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns"); + + // Initialize barriers + if (warp_idx == 1 and cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < kNumStages; ++ i) { + // Arrive at all CTAs + full_barriers[i]->init(1); + empty_barriers[i]->init(1); + // Arrive only at the leader CTA + with_sf_full_barriers[i]->init(kNumMulticast * 32); + } + #pragma unroll + for (uint32_t i = 0; i < kNumEpilogueStages; ++ i) { + // Arrive at all CTAs + tmem_full_barriers[i]->init(1); + // Arrive only at the leader CTA + tmem_empty_barriers[i]->init(kNumMulticast * kNumUMMAStoreThreads); + } + + // Make initialized barrier visible in async proxy + cutlass::arch::fence_barrier_init(); + } else if (warp_idx == 2) { + // Allocate tensor memory + Allocator().allocate(kNumTmemCols, tmem_ptr_in_smem); + } + kNumMulticast > 1 ? 
cute::cluster_sync() : __syncthreads(); + + // Block scheduler + uint32_t m_block_idx, n_block_idx; + auto scheduler = Scheduler(shape_m, shape_n, shape_k, grouped_layout); + + // Pipeline and TMA phases + uint32_t stage_idx = 0, phase = 0; + auto advance_pipeline = [&](uint32_t& k_block_idx) { + ++ k_block_idx; + + // Flip phases only if reach the next first stage + stage_idx = stage_idx == kNumStages - 1 ? 0 : stage_idx + 1; + phase ^= stage_idx == 0; + }; + + // Dispatch warps into different roles + if (warp_idx == 0 and cute::elect_one_sync()) { + // TMA load warp + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + const auto& num_total_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K); + for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) { + // Wait consumer release + empty_barriers[stage_idx]->wait(phase ^ 1); + + // Compute offsets + // NOTES: the group is always concatenated with the outer dimension + uint32_t m_idx = scheduler.template get_global_idx<(kGemmType == GemmType::MGroupedMasked), IndexType::MN> ( + shape_m, BLOCK_M, m_block_idx); + uint32_t n_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::K), IndexType::MN> ( + shape_n, BLOCK_N, n_block_idx, m_block_idx); + + // NOTES: `k_idx` is actually the k index default for K-major, while `k_b_idx` may be MN-major + // And for all m-grouped GEMMs, A must be K-majored + DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kGemmType == GemmType::KGroupedContiguous or kGemmType == GemmType::Batched or + kMajorA == cute::UMMA::Major::K, "Invalid major"); + uint32_t k_idx = k_block_idx * BLOCK_K; + uint32_t k_a_idx = scheduler.template get_global_idx<(kMajorA == cute::UMMA::Major::MN), IndexType::K> ( + shape_k, BLOCK_K, k_block_idx, m_block_idx); + uint32_t k_b_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::MN), IndexType::K> ( + shape_k, BLOCK_K, k_block_idx, m_block_idx); + + // Add 2 CTA offsets + if constexpr (kNumMulticast > 1) { + m_idx += kIsMulticastOnA ? (cute::block_rank_in_cluster() * LOAD_BLOCK_M) : 0; + n_idx += kIsMulticastOnA ? 0 : (cute::block_rank_in_cluster() * LOAD_BLOCK_N); + } + + // Issue TMAs + constexpr bool kIsBatchedMM = (kGemmType == GemmType::Batched); + const uint32_t batch_idx = (kIsBatchedMM ? scheduler.current_group_idx : 0); + if constexpr (kMajorA == cute::UMMA::Major::K) + tma_copy( + &tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_a_idx, m_idx, 1, batch_idx); + if constexpr (kMajorA == cute::UMMA::Major::MN) + tma_copy( + &tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], m_idx, k_a_idx, 1, batch_idx); + if constexpr (kMajorB == cute::UMMA::Major::K) + tma_copy( + &tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_b_idx, n_idx, 1, batch_idx); + if constexpr (kMajorB == cute::UMMA::Major::MN) + tma_copy( + &tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], n_idx, k_b_idx, 1, batch_idx); + auto num_arrival_bytes = SMEM_A_SIZE_PER_STAGE / (std::is_same_v ? 1 : 2) + + SMEM_B_SIZE_PER_STAGE / (std::is_same_v ? 
1 : 2); + + // Issue SFA and SFB TMAs at certain stages + // No swizzling, so one TMA for one SF is enough + if (k_block_idx % kNumSFAStagesPerLoad == 0) { + tma_copy(&tensor_map_sfa, full_barriers[stage_idx], smem_sfa[stage_idx], m_block_idx * BLOCK_M, + scheduler.template get_global_idx<(not is_m_grouped_contiguous(kGemmType)), IndexType::SF_K>(shape_sfa_k, 1, ceil_div(k_idx, BLOCK_K * kNumSFAStagesPerLoad))); + num_arrival_bytes += BLOCK_M * sizeof(uint32_t); + } + if (k_block_idx % kNumSFBStagesPerLoad == 0) { + tma_copy(&tensor_map_sfb, full_barriers[stage_idx], smem_sfb[stage_idx], n_block_idx * BLOCK_N, + scheduler.template get_global_idx(shape_sfb_k, 1, ceil_div(k_idx, BLOCK_K * kNumSFBStagesPerLoad), m_block_idx)); + num_arrival_bytes += BLOCK_N * sizeof(uint32_t); + } + + // Arrive at full barriers + full_barriers[stage_idx]->arrive_and_expect_tx(num_arrival_bytes); + } + } + } else if (warp_idx == 1 and is_leader_cta) { + // MMA issue warp + // NOTES: only the leader CTA will do this + // Make instruction descriptor + // TODO: refactor `UMMA_M` calculation + constexpr uint32_t UMMA_M = LAYOUT_AD_M * (kIsMulticastOnA ? 1 : kNumMulticast); + constexpr uint32_t UMMA_N = BLOCK_N * (kIsMulticastOnA ? kNumMulticast : 1); + constexpr uint32_t UMMA_K = 32; + auto instr_desc = cute::UMMA::make_instr_desc_block_scaled(); + auto sf_desc = make_sf_desc(nullptr); + + DG_STATIC_ASSERT(kNumStages <= 32, "Too many stages"); + auto a_desc = make_umma_desc(smem_a[0], 0, 0); + auto b_desc = make_umma_desc(smem_b[0], 0, 0); + uint32_t a_desc_lo = lane_idx < kNumStages ? a_desc.lo + lane_idx * SMEM_A_SIZE_PER_STAGE / 16 : 0u; + uint32_t b_desc_lo = lane_idx < kNumStages ? b_desc.lo + lane_idx * SMEM_B_SIZE_PER_STAGE / 16 : 0u; + + // Checks for MMA instructions + // NOTES: CUTLASS does not have such checks except the MMA traits, but we are not using these traits + DG_STATIC_ASSERT((UMMA_M == 64 and UMMA_N % 8 == 0 and 8 <= UMMA_N and UMMA_N <= 256) or + (UMMA_M == 128 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256) or + (UMMA_M == 256 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256), + "Invalid MMA instruction shape"); + + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + // Wait tensor memory empty barrier arrival + auto accum_stage_idx = scheduler.current_iter % kNumEpilogueStages; + auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1; + tmem_empty_barriers[accum_stage_idx]->wait(accum_phase_idx ^ 1); + tcgen05_after_thread_sync(); + + // Empty barrier arrival + auto empty_barrier_arrive = [&](const bool& do_tmem_full_arrive) { + auto umma_arrive = [](const uint64_t* barrier) { + if constexpr (kNumMulticast == 1) { + cutlass::arch::umma_arrive(barrier); + } else { + constexpr uint16_t kCTAMask = (1 << kNumMulticast) - 1; + cutlass::arch::umma_arrive_multicast_2x1SM(barrier, kCTAMask); + } + }; + umma_arrive(reinterpret_cast(empty_barriers[stage_idx])); + + // NOTES: the tensor memory accumulator pipeline has nothing to do with multicasting + if (do_tmem_full_arrive) + umma_arrive(reinterpret_cast(tmem_full_barriers[accum_stage_idx])); + }; + + // Launch MMAs + const auto& num_total_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K); + for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) { + // Wait TMA and SF-transpose arrival + with_sf_full_barriers[stage_idx]->wait(phase); + tcgen05_after_thread_sync(); + + // Do SF copy at certain stages + // NOTES: CUTLASS 
UTCCP's interface does not have `elect_one_sync`, we must do it by ourselves + // TODO: process shared memory descriptor by addition + using cute_utccp_t = cute::conditional_t; + const uint32_t sfa_stage_in_group_idx = k_block_idx % kNumSFAStagesPerLoad; + if (sfa_stage_in_group_idx == 0 and cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < SF_BLOCK_M / kNumUTCCPAlignedElems; ++ i) { + auto smem_ptr = smem_sfa[stage_idx] + i * kNumUTCCPAlignedElems; + replace_smem_desc_addr(sf_desc, smem_ptr); + cute_utccp_t::copy(sf_desc, kTmemStartColOfSFA + i * 4); + } + } + const uint32_t sfb_stage_in_group_idx = k_block_idx % kNumSFBStagesPerLoad; + if (sfb_stage_in_group_idx == 0 and cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < SF_BLOCK_N / kNumUTCCPAlignedElems; ++ i) { + auto smem_ptr = smem_sfb[stage_idx] + i * kNumUTCCPAlignedElems; + replace_smem_desc_addr(sf_desc, smem_ptr); + cute_utccp_t::copy(sf_desc, kTmemStartColOfSFB + i * 4); + } + } + __syncwarp(); + + // Issue UMMA in the leader CTA + using mma_t = cute::conditional_t; + const auto& a_desc_base_lo = __shfl_sync(0xffffffff, a_desc_lo, static_cast(stage_idx)); + const auto& b_desc_base_lo = __shfl_sync(0xffffffff, b_desc_lo, static_cast(stage_idx)); + if (cute::elect_one_sync()) { + #pragma unroll + for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) { + const uint32_t sfa_id = (kGranKA == 32 ? k : sfa_stage_in_group_idx); + const uint32_t sfb_id = (kGranKB == 32 ? k : sfb_stage_in_group_idx); + const auto& runtime_instr_desc = make_runtime_instr_desc_with_sf_id(instr_desc, sfa_id, sfb_id); + + b_desc.lo = advance_umma_desc_lo(b_desc_base_lo, 0, k * UMMA_K); + #pragma unroll + for (uint32_t w = 0; w < kNumMWaves; ++ w) { + DG_STATIC_ASSERT((WAVE_BLOCK_M * BLOCK_K) % 128 == 0, "Invalid swizzling offset"); + a_desc.lo = advance_umma_desc_lo(a_desc_base_lo, w * WAVE_BLOCK_M * BLOCK_K, k * UMMA_K); + mma_t::fma(a_desc, b_desc, + accum_stage_idx * kNumMWaves * BLOCK_N + w * BLOCK_N, + k_block_idx > 0 or k > 0, + runtime_instr_desc, + kTmemStartColOfSFA + w * (kNumUTCCPAlignedElems / 32), + kTmemStartColOfSFB); + } + } + } + + // Commit to the mbarrier object + // No explicit `tcgen05.fence::before_thread_sync` is needed, as this is implicitly performed by `tcgen05.commit` + empty_barrier_arrive(k_block_idx == num_total_k_blocks - 1); + } + } + + // To safely deconstruct barriers, we need another round of waits + const auto& iter_idx = scheduler.current_iter - 1; + if (kNumMulticast > 1 and iter_idx >= 0) { + const auto& accum_phase_idx = (iter_idx / kNumEpilogueStages) & 1; + tmem_empty_barriers[iter_idx % kNumEpilogueStages]->wait(accum_phase_idx); + } + } else if (warp_idx == 2) { + // UTCCP transposer + auto utccp_required_smem_warp_transpose = [&](const uint32_t* smem_ptr) { + DG_STATIC_ASSERT(kNumUTCCPAlignedElems == 128, "Invalid aligned elements"); + uint32_t values[4]; + #pragma unroll + for (uint32_t i = 0; i < 4; ++ i) + values[i] = ld_shared(smem_ptr + (i ^ (lane_idx >> 3)) * 32 + lane_idx); + __syncwarp(); + #pragma unroll + for (uint32_t i = 0; i < 4; ++ i) + st_shared(smem_ptr + lane_idx * 4 + (i ^ (lane_idx >> 3)), values[i]); + }; + + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + const auto& num_total_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K); + for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) { + // Wait TMA arrival + full_barriers[stage_idx]->wait(phase); + + // Transpose for UTCCP at certain stages + if 
(k_block_idx % kNumSFAStagesPerLoad == 0) { + #pragma unroll + for (uint32_t i = 0; i < SF_BLOCK_M / kNumUTCCPAlignedElems; ++ i) + utccp_required_smem_warp_transpose(smem_sfa[stage_idx] + i * kNumUTCCPAlignedElems); + // TODO: figure out whether the proxy fence is valid for 2-CTA cases + cutlass::arch::fence_view_async_shared(); + } + if (k_block_idx % kNumSFBStagesPerLoad == 0) { + #pragma unroll + for (uint32_t i = 0; i < SF_BLOCK_N / kNumUTCCPAlignedElems; ++ i) + utccp_required_smem_warp_transpose(smem_sfb[stage_idx] + i * kNumUTCCPAlignedElems); + // TODO: figure out whether the proxy fence is valid for 2-CTA cases + cutlass::arch::fence_view_async_shared(); + } + + // Arrive + with_sf_full_barriers[stage_idx]->arrive(0u); + } + } + } else if (warp_idx >= kNumNonEpilogueThreads / 32 and warp_idx < (kNumNonEpilogueThreads + kNumUMMAStoreThreads) / 32) { + // Epilogue warp groups + const auto epilogue_warp_idx = warp_idx - (kNumNonEpilogueThreads / 32); + + // NOTES: tensor memory addresses are simplified, as the hardware will ignore the warp index bits, + // i.e., no need for `tmem_ptr |= (epilogue_warp_idx * 32) << 16`. + // NOTES: we also forbid two CTAs to share the same SM and its tensor memory + DG_TRAP_ONLY_DEVICE_ASSERT(ld_shared(tmem_ptr_in_smem) == 0); + + // TMA checks + constexpr uint32_t kNumBankGroupBytes = 16; + constexpr uint32_t kNumElemsPerBankGroup = kNumBankGroupBytes / sizeof(cd_dtype_t); + DG_STATIC_ASSERT(kSwizzleCDMode > 0, "TMA D must be swizzled"); + DG_STATIC_ASSERT(STORE_BLOCK_N % kNumElemsPerBankGroup == 0, "Invalid swizzling"); + + // Share store pipeline between blocks + uint32_t tma_stage_idx = 0; + auto advance_store_pipeline = [&]() { + tma_stage_idx = (tma_stage_idx + 1) % kNumTMAStoreStages; + }; + + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + auto accum_stage_idx = scheduler.current_iter % kNumEpilogueStages; + auto accum_phase_idx = (scheduler.current_iter / kNumEpilogueStages) & 1; + + // Wait UMMA arrival + tmem_full_barriers[accum_stage_idx]->wait(accum_phase_idx); + tcgen05_after_thread_sync(); + + // Load from tensor memory into registers, and write shared memory with STSM + DG_STATIC_ASSERT(kNumEpilogueThreads == 128, "Epilogue threads not enough"); + DG_STATIC_ASSERT(BLOCK_N % STORE_BLOCK_N == 0, "Invalid block sizes"); + + // Iterate over M waves + #pragma unroll + for (uint32_t w = 0; w < kNumMWaves; ++ w) { + // Issue every swizzled atom and pipeline STSM and TMA store + constexpr uint32_t kNumStores = BLOCK_N / STORE_BLOCK_N; + #pragma unroll + for (uint32_t s = 0; s < kNumStores; ++ s, advance_store_pipeline()) { + // Wait shared memory to be released + if (epilogue_warp_idx == 0) + cute::tma_store_wait(); + cutlass::arch::NamedBarrier::sync(kNumUMMAStoreThreads, 0); + + // The pipeline stage + const auto m_idx = scheduler.template get_global_idx<(not is_m_grouped_contiguous(kGemmType)), IndexType::MN>(shape_m, BLOCK_M, m_block_idx) + w * WAVE_BLOCK_M; + const auto n_idx = epilogue_type_t::apply_index_n(n_block_idx * BLOCK_N + s * STORE_BLOCK_N); + + // Store into shared memory + #pragma unroll + for (uint32_t i = 0; i < STORE_BLOCK_N / kNumElemsPerBankGroup; ++ i) { + // Calculate the index of the bank group to be written in the atom + auto bank_group_index = i + lane_idx * (kSwizzleCDMode / kNumBankGroupBytes); + + // Reshape the atom in another view and swizzle + // - original: `(LAYOUT_AD_M, kSwizzleCDMode / kNumBankGroupBytes)` + // - new: `(LAYOUT_AD_M * kSwizzleCDMode / 
kNumBankGroupBytes / 8, 8)` + // NOTES: "8" is the number of bank groups, "16" is the swizzling pattern + constexpr bool kHasShortcut = (kSwizzleCDMode / kNumBankGroupBytes) == 8; + auto row = kHasShortcut ? (i / 8 + lane_idx) : (bank_group_index / 8); + auto col = kHasShortcut ? (i) : (bank_group_index % 8); + col ^= row % (kSwizzleCDMode / 16); + + // Source and destination memory address + uint32_t tmem_addr = accum_stage_idx * kNumMWaves * BLOCK_N + // Accumulator offset + w * BLOCK_N + // Wave offset + s * STORE_BLOCK_N + i * kNumElemsPerBankGroup; // In-block offset + auto smem_ptr = reinterpret_cast(smem_cd[tma_stage_idx]) + // Base pointer + epilogue_warp_idx * 32 * kSwizzleCDMode + // Warp offset + row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset + + // Load from tensor memory, store into shared memory + uint32_t values[kNumElemsPerBankGroup]; + if constexpr (cute::is_same_v) { + // For FP32 output, read and store + DG_STATIC_ASSERT(kNumElemsPerBankGroup == 4, "Invalid type"); + cute::SM100_TMEM_LOAD_32dp32b4x::copy(tmem_addr, + values[0], values[1], values[2], values[3]); + cutlass::arch::fence_view_async_tmem_load(); + st_shared(smem_ptr, values[0], values[1], values[2], values[3]); + } else { + // For BF16 output, read, cast and store + DG_STATIC_ASSERT(kNumElemsPerBankGroup == 8 and cute::is_same_v, "Invalid type"); + cute::SM100_TMEM_LOAD_32dp32b8x::copy(tmem_addr, + values[0], values[1], values[2], values[3], + values[4], values[5], values[6], values[7]); + cutlass::arch::fence_view_async_tmem_load(); + st_shared(smem_ptr, + cast_into_bf16_and_pack(values[0], values[1]), + cast_into_bf16_and_pack(values[2], values[3]), + cast_into_bf16_and_pack(values[4], values[5]), + cast_into_bf16_and_pack(values[6], values[7])); + } + } + + // Notify tensor memory empty (only at the leader CTA) arrival ASAP + // NOTES: only the last stage needs to do this + if (w == kNumMWaves - 1 and s == BLOCK_N / STORE_BLOCK_N - 1) { + tcgen05_before_thread_sync(); + tmem_empty_barriers[accum_stage_idx]->arrive(0u); + } + + // Synchronize all threads and issue TMA + cute::tma_store_fence(); + cutlass::arch::NamedBarrier::sync(kNumUMMAStoreThreads, 0); + if (epilogue_warp_idx == 0 and cute::elect_one_sync()) { + if constexpr (kGemmType == GemmType::Batched) { + using cute_tma_t = cute::conditional_t; + cute_tma_t::copy(&tensor_map_cd, smem_cd[tma_stage_idx], + n_idx, m_idx, scheduler.current_group_idx); + } else { + using cute_tma_t = cute::conditional_t; + cute_tma_t::copy(&tensor_map_cd, smem_cd[tma_stage_idx], n_idx, m_idx); + } + cute::tma_store_arrive(); + } + } + } + } + + // Deallocate tensor memory by the last UMMA store warp + // NOTES: warp 0 is waiting TMA store + if (epilogue_warp_idx == kNumUMMAStoreThreads / 32 - 1) + Allocator().free(0, kNumTmemCols); + } +#else + if (blockIdx.x == 0 and threadIdx.x == 0) + DG_DEVICE_ASSERT(false and "This kernel only support sm_100f"); +#endif +} + +}; // namespace deep_gemm + +#pragma clang diagnostic pop diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/deep_gemm/impls/sm100_fp8_mqa_logits.cuh b/build/torch29-cxx11-cu126-aarch64-linux/include/deep_gemm/impls/sm100_fp8_mqa_logits.cuh new file mode 100644 index 0000000000000000000000000000000000000000..180a308b3279b38827741942917a31e103b15b52 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/deep_gemm/impls/sm100_fp8_mqa_logits.cuh @@ -0,0 +1,404 @@ +#pragma once + +#include +#include + +#include +#include + +#include +#include +#include + 
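+// [Illustrative note added in editing; not part of the generated header] The kernel below
+// drives every producer/consumer queue with the same circular-pipeline convention: an
+// iteration counter is folded into a buffer index plus an mbarrier wait-phase bit that
+// flips each time the index wraps around. A minimal sketch of that arithmetic, assuming a
+// hypothetical helper name `pipeline_stage_and_phase` (the kernel writes the same
+// expressions inline, e.g. in the `load_schedule` and `get_kv_pipeline` lambdas):
+//
+//   template <uint32_t kNumStages>
+//   __device__ __forceinline__ cute::tuple<uint32_t, uint32_t>
+//   pipeline_stage_and_phase(const uint32_t& iter_idx) {
+//       return {iter_idx % kNumStages,          // which stage buffer to use this iteration
+//               (iter_idx / kNumStages) & 1};   // phase parity passed to Barrier::wait()
+//   }
+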
+namespace deep_gemm { + +using namespace deep_gemm::sm90; +using namespace deep_gemm::sm100; + +template +__global__ __launch_bounds__(kNumSpecializedThreads + kNumMathThreads, 1) +void sm100_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, + const uint32_t max_seqlen_k, const uint64_t stride_logits, + uint32_t* cu_seq_len_k_start, + uint32_t* cu_seq_len_k_end, + float* logits, + const __grid_constant__ cute::TmaDescriptor tensor_map_q, + const __grid_constant__ cute::TmaDescriptor tensor_map_kv, + const __grid_constant__ cute::TmaDescriptor tensor_map_kv_scales, + const __grid_constant__ cute::TmaDescriptor tensor_map_weights) { + // TODO: consider TMA multicast + // Normally, `h (kNumHeads) == 32` and `d (kHeadDim) == 64` + // For one block, we process `[q_start:q_end, h, d] @ [kv_start:kv_end, d] -> [q_start:q_end, kv_start:kv_end]` + // Q should be load only at once for a block + const auto& num_q_blocks = ceil_div(seq_len, BLOCK_Q); + + // Types + using Barrier = cutlass::arch::ClusterTransactionBarrier; + + // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers + const auto& warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + const auto& warp_in_group_idx = warp_idx % 4; + const auto& warpgroup_idx = warp_idx / 4; + const auto& lane_idx = get_lane_idx(); + + // Prefetch TMA descriptors + DG_STATIC_ASSERT(kNumSpecializedThreads == 128 and kNumMathThreads % 128 == 0, "Invalid threads"); + if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) { + cute::prefetch_tma_descriptor(&tensor_map_q); + cute::prefetch_tma_descriptor(&tensor_map_kv); + cute::prefetch_tma_descriptor(&tensor_map_kv_scales); + cute::prefetch_tma_descriptor(&tensor_map_weights); + } + __syncwarp(); + + // Shared memory configs + // NOTES: weight may be unaligned + static constexpr uint32_t SMEM_Q_SIZE_PER_STAGE = BLOCK_Q * kNumHeads * kHeadDim * sizeof(__nv_fp8_e4m3); + static constexpr uint32_t SMEM_WEIGHT_SIZE_PER_STAGE = BLOCK_Q * kNumHeads * sizeof(float); + static constexpr uint32_t SMEM_KV_SIZE_PER_STAGE = BLOCK_KV * kHeadDim * sizeof(__nv_fp8_e4m3); + static constexpr uint32_t SMEM_KV_SCALE_SIZE_PER_STAGE = BLOCK_KV * sizeof(float); + static constexpr uint32_t ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE = constexpr_align(SMEM_KV_SCALE_SIZE_PER_STAGE, 512u); + + // Align to 512 bytes for swizzle-64B + extern __shared__ __align__(512) uint8_t smem_buffer[]; + DG_STATIC_ASSERT(SMEM_Q_SIZE_PER_STAGE % 512 == 0, "Unaligned TMA swizzling"); + DG_STATIC_ASSERT(SMEM_WEIGHT_SIZE_PER_STAGE % 512 == 0, "Unaligned TMA swizzling"); + DG_STATIC_ASSERT(SMEM_KV_SIZE_PER_STAGE % 512 == 0, "Unaligned TMA swizzling"); + + // TMA configs + constexpr uint32_t kNumTmemCols = BLOCK_Q * kNumHeads * kNumMathWarpGroups; + DG_STATIC_ASSERT(kNumTmemCols <= 512, "Too many tensor memory"); + + // Data on shared memory + auto smem_q = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + + SMEM_Q_SIZE_PER_STAGE * i); + }); + auto smem_weights = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + + SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_WEIGHT_SIZE_PER_STAGE * i); + }); + auto smem_kv = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + ( + SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_WEIGHT_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * i)); + }); + auto smem_kv_scales = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + + SMEM_Q_SIZE_PER_STAGE 
* kNumQStages + SMEM_WEIGHT_SIZE_PER_STAGE * kNumQStages + + SMEM_KV_SIZE_PER_STAGE * kNumKVStages + ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE * i); + }); + + // TMA barriers + auto barrier_ptr = reinterpret_cast(smem_kv_scales[kNumKVStages]); + auto full_q_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + i; }); + auto empty_q_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages + i); }); + auto full_kv_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + i); }); + auto empty_kv_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + kNumKVStages + i); }); + auto full_umma_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + kNumKVStages * 2 + i); }); + auto empty_umma_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + kNumKVStages * 2 + kNumMathWarpGroups + i); }); + + // Tensor memory allocation + auto tmem_ptr_in_smem = reinterpret_cast(barrier_ptr + kNumQStages * 2 + kNumKVStages * 2 + kNumMathWarpGroups * 2); + + // Initialize barriers + DG_STATIC_ASSERT(kNumSpecializedThreads % 128 == 0 and kNumSpecializedThreads >= 64, "Invalid threads"); + const bool& is_tma_load_warp = (warp_idx == (kNumMathThreads / 32)); + const bool& is_umma_warp = (warp_idx == (kNumMathThreads / 32 + 1)); + if (is_tma_load_warp and cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < kNumQStages; ++ i) { + full_q_barriers[i]->init(1); + empty_q_barriers[i]->init(kNumMathThreads); + } + #pragma unroll + for (uint32_t i = 0; i < kNumKVStages; ++ i) { + full_kv_barriers[i]->init(1); + empty_kv_barriers[i]->init(kNumMathThreads); + } + #pragma unroll + for (uint32_t i = 0; i < kNumMathWarpGroups; ++ i) { + full_umma_barriers[i]->init(1); + empty_umma_barriers[i]->init(128); + } + + // Make initialized barrier visible in async proxy + cutlass::arch::fence_barrier_init(); + } else if (is_umma_warp) { + // Allocate tensor memory + cute::TMEM::Allocator1Sm().allocate(kNumTmemCols, tmem_ptr_in_smem); + } + __syncthreads(); + + // Register reconfigurations + constexpr uint32_t kNumSpecializedRegisters = 24; + constexpr uint32_t kNumMathRegisters = 240; + + // Block scheduler + uint32_t block_q_idx = blockIdx.x, q_iter_idx = 0; + const auto& get_next_block_q_idx = [&]() -> cute::tuple { + return {block_q_idx + gridDim.x, q_iter_idx + 1}; + }; + uint32_t seq_k_start[BLOCK_Q], seq_k_end[BLOCK_Q]; + const auto& load_schedule = [&](const uint32_t& q_iter_offset = 0) -> cute::tuple { + uint32_t start = cute::numeric_limits::max(); + uint32_t end = cute::numeric_limits::min(); + + #pragma unroll + for (uint32_t i = 0; i < BLOCK_Q; ++ i) { + const auto& q_idx = min(block_q_idx * BLOCK_Q + i, seq_len - 1); + seq_k_start[i] = __ldg(cu_seq_len_k_start + q_idx); + seq_k_end[i] = __ldg(cu_seq_len_k_end + q_idx); + start = min(start, min(seq_k_start[i], seq_len_kv)); + end = max(end, min(seq_k_end[i], seq_len_kv)); + } + start = start / 4 * 4; + return {(q_iter_idx + q_iter_offset) % kNumQStages, // Q pipeline stage + ((q_iter_idx + q_iter_offset) / kNumQStages) & 1, // Q pipeline phase + start, ceil_div(end - start, BLOCK_KV)}; // Task info + }; + + // KV pipeline + uint32_t num_total_kv_blocks = 0; + const auto& get_kv_pipeline = [&](const uint32_t& kv_block_idx) -> cute::tuple { + return { + (num_total_kv_blocks + kv_block_idx) % kNumKVStages, // KV pipeline stage + ((num_total_kv_blocks + kv_block_idx) / 
kNumKVStages) & 1 // KV pipeline phase + }; + }; + + // UMMA settings + // Construct instruction with layout D + constexpr uint32_t UMMA_M = 128; + constexpr uint32_t UMMA_K = 32 / sizeof(cutlass::float_e4m3_t); + constexpr uint32_t UMMA_N = BLOCK_Q * kNumHeads; + + if (is_tma_load_warp) { + cutlass::arch::warpgroup_reg_dealloc(); + + // Prefetch + const auto& issue_tma_q = [&](const uint32_t& stage_idx, const auto& block_idx) { + tma_copy(&tensor_map_q, full_q_barriers[stage_idx], smem_q[stage_idx], 0, block_idx * BLOCK_Q * kNumHeads); + tma_copy(&tensor_map_weights, full_q_barriers[stage_idx], smem_weights[stage_idx], 0, block_idx * BLOCK_Q); + full_q_barriers[stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + SMEM_WEIGHT_SIZE_PER_STAGE); + }; + if (cute::elect_one_sync() and block_q_idx < num_q_blocks) + issue_tma_q(0, block_q_idx); + + // Only the first lane persistently schedules over blocks + if (cute::elect_one_sync()) { + while (block_q_idx < num_q_blocks) { + CUTE_TIE_DECL(load_schedule(1), q_stage_idx, q_phase, kv_start, num_kv_blocks); + + // Wait Q consumer release + empty_q_barriers[q_stage_idx]->wait(q_phase ^ 1); + + // Issue TMA Q + if (const auto& next_block_q_idx = cute::get<0>(get_next_block_q_idx()); next_block_q_idx < num_q_blocks) + issue_tma_q(q_stage_idx, next_block_q_idx); + + // Issue TMA KV + #pragma unroll + for (uint32_t kv_block_idx = 0; kv_block_idx < num_kv_blocks; ++ kv_block_idx) { + // Wait consumer release + CUTE_TIE_DECL(get_kv_pipeline(kv_block_idx), kv_stage_idx, kv_phase); + empty_kv_barriers[kv_stage_idx]->wait(kv_phase ^ 1); + + // Issue TMA KV + tma_copy(&tensor_map_kv, full_kv_barriers[kv_stage_idx], + smem_kv[kv_stage_idx], 0, kv_start + kv_block_idx * BLOCK_KV); + tma_copy(&tensor_map_kv_scales, full_kv_barriers[kv_stage_idx], + smem_kv_scales[kv_stage_idx], kv_start + kv_block_idx * BLOCK_KV, 0); + full_kv_barriers[kv_stage_idx]->arrive_and_expect_tx(SMEM_KV_SIZE_PER_STAGE + SMEM_KV_SCALE_SIZE_PER_STAGE); + } + num_total_kv_blocks += num_kv_blocks; + + // Jump to the next block + CUTE_TIE(get_next_block_q_idx(), block_q_idx, q_iter_idx); + } + } + } else if (is_umma_warp) { + cutlass::arch::warpgroup_reg_dealloc(); + + // Require full allocation + DG_TRAP_ONLY_DEVICE_ASSERT(ld_shared(tmem_ptr_in_smem) == 0); + + // Make UMMA desc + auto instr_desc = cute::UMMA::make_instr_desc(); + auto runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc); + + while (block_q_idx < num_q_blocks) { + CUTE_TIE_DECL(load_schedule(), q_stage_idx, q_phase, kv_start, num_kv_blocks); + + // Wait TMA Q arrival + full_q_barriers[q_stage_idx]->wait(q_phase); + + // Compute over KV blocks + #pragma unroll + for (uint32_t kv_block_idx = 0; kv_block_idx < num_kv_blocks; ++ kv_block_idx) { + // Compute `[BLOCK_Q * kNumHeads, kHeadDim] @ [BLOCK_KV, kHeadDim] -> [BLOCK_Q, BLOCK_KV]` + // Wait TMA KV arrival + CUTE_TIE_DECL(get_kv_pipeline(kv_block_idx), kv_stage_idx, kv_phase); + full_kv_barriers[kv_stage_idx]->wait(kv_phase); + + // Issue UMMA + DG_STATIC_ASSERT(BLOCK_KV == kNumMathThreads, "Invalid block size"); + DG_STATIC_ASSERT(kHeadDim % UMMA_K == 0, "Invalid head dim"); + #pragma unroll + for (uint32_t i = 0; i < kNumMathWarpGroups; ++ i) { + empty_umma_barriers[i]->wait(((num_total_kv_blocks + kv_block_idx) & 1) ^ 1); + tcgen05_after_thread_sync(); + #pragma unroll + for (uint32_t k = 0; k < kHeadDim / UMMA_K; ++ k) { + auto a_desc = make_umma_desc( + smem_kv[kv_stage_idx], i * UMMA_M, k * UMMA_K); + auto b_desc = make_umma_desc( + 
smem_q[q_stage_idx], 0, k * UMMA_K); + cute::SM100_MMA_F8F6F4_SS::fma(a_desc, b_desc, i * UMMA_N, k, runtime_instr_desc); + } + cutlass::arch::umma_arrive(reinterpret_cast(full_umma_barriers[i])); + } + } + num_total_kv_blocks += num_kv_blocks; + + // Jump to the next block + CUTE_TIE(get_next_block_q_idx(), block_q_idx, q_iter_idx); + } + } else if (warp_idx >= kNumMathThreads / 32) { + cutlass::arch::warpgroup_reg_dealloc(); + } else if (warp_idx < kNumMathThreads / 32) { + cutlass::arch::warpgroup_reg_alloc(); + + // Offsets + const auto& tmem_start = __shfl_sync(0xffffffff, warpgroup_idx * UMMA_N, 0); + const auto& warp_offset = warp_idx * 32; + const auto& v_offset = lane_idx; + + // Preload weights + constexpr uint32_t kNumWeightsInReg = cute::min(52, kNumHeads); + float weights[BLOCK_Q][kNumWeightsInReg]; + DG_STATIC_ASSERT(kNumWeightsInReg % 4 == 0, "Invalid number of weights in registers"); + + while (block_q_idx < num_q_blocks) { + CUTE_TIE_DECL(load_schedule(), q_stage_idx, q_phase, kv_start, num_kv_blocks); + + // Wait TMA Q arrival + full_q_barriers[q_stage_idx]->wait(q_phase); + + // Read weights + #pragma unroll + for (uint32_t i = 0; i < BLOCK_Q; ++ i) { + for (uint32_t j = 0; j < kNumWeightsInReg; ++ j) { + weights[i][j] = ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j); + } + } + + // Compute over KV blocks + #pragma unroll + for (uint32_t kv_block_idx = 0; kv_block_idx < num_kv_blocks; ++ kv_block_idx) { + // Compute `[BLOCK_Q * kNumHeads, kHeadDim] @ [BLOCK_KV, kHeadDim] -> [BLOCK_Q, BLOCK_KV]` + // Wait TMA KV arrival + CUTE_TIE_DECL(get_kv_pipeline(kv_block_idx), kv_stage_idx, kv_phase); + full_kv_barriers[kv_stage_idx]->wait(kv_phase); + + // Read per-KV scales + float scale_kv = ld_shared(smem_kv_scales[kv_stage_idx] + warp_offset + v_offset); + + // Wait UMMA arrival + full_umma_barriers[warpgroup_idx]->wait((num_total_kv_blocks + kv_block_idx) & 1); + tcgen05_after_thread_sync(); + + // Release KV empty + empty_kv_barriers[kv_stage_idx]->arrive(); + + // Reduce over the head dim and store + const auto& kv_offset = kv_start + kv_block_idx * BLOCK_KV + warp_offset; + static constexpr uint32_t kNumAccumPerReduce = kNumHeads / 2; + DG_STATIC_ASSERT(kNumHeads % 8 == 0, "Invalid head"); + + constexpr uint32_t kNumLDTMElems = kNumHeads * BLOCK_Q; + DG_STATIC_ASSERT(kNumLDTMElems == 32 or kNumLDTMElems == 64 or kNumLDTMElems == 128, "Invalid kNumLDTMElems"); + uint32_t shifted_accum[kNumLDTMElems]; + auto tmem_load = [&](auto... 
Is) { + if constexpr (kNumLDTMElems == 32) { + cute::SM100_TMEM_LOAD_32dp32b32x::copy(tmem_start, shifted_accum[Is]...); + } else if constexpr (kNumLDTMElems == 64) { + cute::SM100_TMEM_LOAD_32dp32b64x::copy(tmem_start, shifted_accum[Is]...); + } else if constexpr (kNumLDTMElems == 128) { + cute::SM100_TMEM_LOAD_32dp32b128x::copy(tmem_start, shifted_accum[Is]...); + } + }; + [&](cute::index_sequence) { tmem_load(Is...); }(cute::make_index_sequence{}); + cutlass::arch::fence_view_async_tmem_load(); + + tcgen05_before_thread_sync(); + empty_umma_barriers[warpgroup_idx]->arrive(); + + #pragma unroll + for (uint32_t i = 0; i < BLOCK_Q; ++ i) { + auto accum = reinterpret_cast(shifted_accum + i * kNumHeads); + + auto sum_0 = make_float2(0, 0); + auto sum_1 = make_float2(0, 0); + + const auto& transform_reg = [&](const uint32_t& j, const float2& sum) { + auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0)); + auto b = make_float2(weights[i][j], weights[i][j + 1]); + return __ffma2_rn(a, b, sum); + }; + + #pragma unroll + for (int j = 0; j < kNumWeightsInReg; j += 4) { + sum_0 = transform_reg(j, sum_0); + sum_1 = transform_reg(j + 2, sum_1); + } + + const auto& transform_smem = [&](const uint32_t& j, const float2& sum) { + auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0)); + auto b = make_float2(ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j), + ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j + 1)); + return __ffma2_rn(a, b, sum); + }; + + #pragma unroll + for (int j = kNumWeightsInReg; j < kNumHeads; j += 4) { + sum_0 = transform_smem(j, sum_0); + sum_1 = transform_smem(j + 2, sum_1); + } + + auto sum = __fadd2_rn(sum_0, sum_1); + float result = scale_kv * (sum.x + sum.y); + + // Store into the global memory + // NOTES: we have redundant writes here, consider more carefully + const uint32_t& q_idx = block_q_idx * BLOCK_Q + i; + if constexpr (kIsCompressedLogits) { + if (seq_k_start[i] <= kv_offset + v_offset and kv_offset + v_offset < seq_k_end[i]) + logits[q_idx * stride_logits + kv_offset + v_offset - seq_k_start[i]] = result; + } else { + logits[q_idx * stride_logits + kv_offset + v_offset] = result; + } + } + } + num_total_kv_blocks += num_kv_blocks; + + // Release Q empty + empty_q_barriers[q_stage_idx]->arrive(); + + // Jump to the next block + CUTE_TIE(get_next_block_q_idx(), block_q_idx, q_iter_idx); + } + } + + // Free tensor memory + __syncthreads(); + if (is_tma_load_warp) + cute::TMEM::Allocator1Sm().free(0, kNumTmemCols); +} + +} // namespace deep_gemm diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/deep_gemm/impls/sm100_fp8_paged_mqa_logits.cuh b/build/torch29-cxx11-cu126-aarch64-linux/include/deep_gemm/impls/sm100_fp8_paged_mqa_logits.cuh new file mode 100644 index 0000000000000000000000000000000000000000..7058c40f4f195de94184d3e7ebc6f9aa2eb3670f --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/deep_gemm/impls/sm100_fp8_paged_mqa_logits.cuh @@ -0,0 +1,398 @@ +#pragma once + +#include +#include + +#include +#include + +#include +#include +#include + +#include + +namespace deep_gemm { + +using namespace deep_gemm::sm90; +using namespace deep_gemm::sm100; + +template +__global__ __launch_bounds__(kNumSpecializedThreads + kNumMathThreads, 1) +void sm100_fp8_paged_mqa_logits(const uint32_t batch_size, + const uint64_t logits_stride, const uint64_t block_table_stride, + const uint32_t* context_lens, float* logits, + const uint32_t* block_table, const uint32_t* schedule_meta, + const __grid_constant__ 
cute::TmaDescriptor tensor_map_q, + const __grid_constant__ cute::TmaDescriptor tensor_map_kv, + const __grid_constant__ cute::TmaDescriptor tensor_map_kv_scales, + const __grid_constant__ cute::TmaDescriptor tensor_map_weights) { + using Barrier = cutlass::arch::ClusterTransactionBarrier; + + // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers + const auto& warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + const auto& warpgroup_idx = warp_idx / 4; + const auto& lane_idx = get_lane_idx(); + + // Prefetch TMA descriptors + DG_STATIC_ASSERT(kNumSpecializedThreads == 128 and kNumMathThreads % 128 == 0, "Invalid threads"); + if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) { + cute::prefetch_tma_descriptor(&tensor_map_q); + cute::prefetch_tma_descriptor(&tensor_map_kv); + cute::prefetch_tma_descriptor(&tensor_map_kv_scales); + cute::prefetch_tma_descriptor(&tensor_map_weights); + } + __syncwarp(); + + // Shared memory configs + static constexpr uint32_t kSwizzleAlignment = kHeadDim * 8; + static constexpr uint32_t SMEM_Q_SIZE_PER_STAGE = kNextN * kNumHeads * kHeadDim * sizeof(__nv_fp8_e4m3); + static constexpr uint32_t SMEM_KV_SIZE_PER_STAGE = SPLIT_KV * kHeadDim * sizeof(__nv_fp8_e4m3); + static constexpr uint32_t SMEM_KV_SCALE_SIZE_PER_STAGE = SPLIT_KV * sizeof(float); + static constexpr uint32_t SMEM_WEIGHT_SIZE_PER_STAGE = kNextN * kNumHeads * sizeof(float); + + // Align to swizzling alignment bytes + extern __shared__ __align__(kSwizzleAlignment) uint8_t smem_buffer[]; + DG_STATIC_ASSERT(SMEM_Q_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling"); + DG_STATIC_ASSERT(SMEM_KV_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling"); + + // Q and KV data on shared memory + auto smem_q = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_Q_SIZE_PER_STAGE * i); + }); + auto smem_kv = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * i); + }); + constexpr auto smem_offset = SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * kNumKVStages; + auto smem_kv_scales = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + smem_offset + SMEM_KV_SCALE_SIZE_PER_STAGE * i); + }); + auto smem_weights = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + smem_offset + SMEM_KV_SCALE_SIZE_PER_STAGE * kNumKVStages + SMEM_WEIGHT_SIZE_PER_STAGE * i); + }); + + // Barriers and TMEM pointer on shared memory + const auto barrier_ptr = reinterpret_cast(smem_weights[kNumQStages]); + auto full_q_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + i; }); + auto empty_q_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages + i; }); + auto full_kv_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages * 2 + i; }); + auto empty_kv_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + kNumQStages * 2 + kNumKVStages + i; }); + const auto umma_barrier_ptr = barrier_ptr + kNumQStages * 2 + kNumKVStages * 2; + auto full_umma_barriers = PatternVisitor([&](const uint32_t& i) { return umma_barrier_ptr + i; }); + auto empty_umma_barriers = PatternVisitor([&](const uint32_t& i) { return umma_barrier_ptr + kNumMathWarpGroups + i; }); + auto tmem_ptr_in_smem = reinterpret_cast(umma_barrier_ptr + kNumMathWarpGroups * 2); + + constexpr uint32_t 
kNumTmemCols = kNextN * kNumHeads * kNumMathWarpGroups; + DG_STATIC_ASSERT(kNumTmemCols <= 512, "Too many tensor memory"); + const bool& is_math_warp = (warp_idx < kNumMathWarpGroups * 4); + const bool& is_tma_load_warp = (warp_idx == kNumMathWarpGroups * 4); + const bool& is_umma_warp = (warp_idx == kNumMathWarpGroups * 4 + 1); + + // Initialize barriers + if (is_tma_load_warp and cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < kNumQStages; ++ i) { + full_q_barriers[i]->init(1); + empty_q_barriers[i]->init(kNumMathThreads); + } + #pragma unroll + for (uint32_t i = 0; i < kNumKVStages; ++ i) { + full_kv_barriers[i]->init(1); + empty_kv_barriers[i]->init(kNumMathThreads); + } + cutlass::arch::fence_barrier_init(); + } + if (is_umma_warp) { + if (cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < kNumMathWarpGroups; ++i) { + full_umma_barriers[i]->init(1); + empty_umma_barriers[i]->init(128); + } + cutlass::arch::fence_barrier_init(); + } + // Allocate tensor memory + cute::TMEM::Allocator1Sm().allocate(kNumTmemCols, tmem_ptr_in_smem); + } + __syncthreads(); + + // Register reconfigurations + constexpr uint32_t kNumSpecializedRegisters = 40; + constexpr uint32_t kNumMathRegisters = 232; + + // Scheduler + constexpr uint32_t kNumBlocksPerSplit = SPLIT_KV / BLOCK_KV; + auto scheduler = PagedMQALogitsScheduler(batch_size, blockIdx.x, context_lens, schedule_meta); + DG_STATIC_ASSERT(SPLIT_KV == BLOCK_KV * kNumBlocksPerSplit, "Invalid `SPLIT_KV`"); + + // Q and KV pipeline + const auto& get_q_pipeline = [=](const uint32_t& q_iter_idx) -> cute::tuple { + return {q_iter_idx % kNumQStages, (q_iter_idx / kNumQStages) & 1}; // Q pipeline stage and phase + }; + const auto& get_kv_pipeline = [=](const uint32_t& kv_iter_idx) -> cute::tuple { + return {kv_iter_idx % kNumKVStages, (kv_iter_idx / kNumKVStages) & 1}; // KV pipeline stage and phase + }; + uint32_t q_iter_idx = 0, kv_iter_idx = 0; + + // UMMA settings + // Construct instruction with layout D + constexpr uint32_t UMMA_M = 128; + constexpr uint32_t UMMA_K = 32 / sizeof(cutlass::float_e4m3_t); + constexpr uint32_t UMMA_N = kNextN * kNumHeads; + DG_STATIC_ASSERT(SPLIT_KV == UMMA_M * kNumMathWarpGroups, "Invalid `SPLIT_KV`"); + + if (is_tma_load_warp) { + // TMA warp-group for loading data + cutlass::arch::warpgroup_reg_dealloc(); + + const auto& issue_tma_q = [&](const uint32_t& stage_idx, const uint32_t& q_idx) { + if (cute::elect_one_sync()) { + tma_copy(&tensor_map_q, full_q_barriers[stage_idx], smem_q[stage_idx], 0, q_idx * kNextN * kNumHeads); + tma_copy(&tensor_map_weights, full_q_barriers[stage_idx], smem_weights[stage_idx], 0, q_idx); + full_q_barriers[stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + SMEM_WEIGHT_SIZE_PER_STAGE); + } + }; + + // Initialize `q_idx` outside `[0, batch_size)` to indicate it was none + uint32_t q_idx = batch_size, kv_idx, num_kv; + uint32_t next_q_idx, next_kv_idx, next_num_kv; + bool fetched_next_task; + + // Prefetch the first Q + if ((fetched_next_task = scheduler.fetch_next_task(next_q_idx, next_kv_idx, next_num_kv))) + issue_tma_q(0, next_q_idx), q_iter_idx = 1; + + int kv_block_idx_ptr = 32; + uint32_t kv_block_idx_storage; + + while (fetched_next_task) { + // Prefetch next Q when current Q changes + bool prefetch_q = (q_idx != next_q_idx and scheduler.exist_q_idx(next_q_idx + 1)); + q_idx = next_q_idx; + kv_idx = next_kv_idx; + num_kv = next_num_kv; + + // Read KV block index + // TODO: deal with `-1`? 
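+ // NOTES (editorial comment added for clarity): each lane of this TMA warp caches one
+ // block-table entry via `__ldg`, so the warp holds a 32-entry window of consecutive KV
+ // block indices in registers; `kv_block_idx_ptr` walks through that window and entries
+ // are broadcast `kNumBlocksPerSplit` at a time with `__shfl_sync`, meaning the global
+ // block table is only re-read when the window is exhausted or a new request starts.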
+ if (kv_idx == 0 or kv_block_idx_ptr == 32) { + kv_block_idx_ptr = 0; + kv_block_idx_storage = (kv_idx + lane_idx < num_kv ? __ldg(block_table + q_idx * block_table_stride + (kv_idx + lane_idx)) : 0); + } + DG_STATIC_ASSERT(32 % kNumBlocksPerSplit == 0, "Invalid `UMMA_M`"); + + // Wait Q consumer release and issue TMA Q + if (prefetch_q) { + CUTE_TIE_DECL(get_q_pipeline(q_iter_idx ++), q_stage_idx, q_phase); + empty_q_barriers[q_stage_idx]->wait(q_phase ^ 1); + issue_tma_q(q_stage_idx, q_idx + 1); + } + + int kv_block_idx[kNumBlocksPerSplit]; + #pragma unroll + for (int i = 0; i < kNumBlocksPerSplit; ++ i) + kv_block_idx[i] = __shfl_sync(0xffffffff, kv_block_idx_storage, kv_block_idx_ptr + i); + kv_block_idx_ptr += kNumBlocksPerSplit; + + // Wait KV consumer release + CUTE_TIE_DECL(get_kv_pipeline(kv_iter_idx ++), kv_stage_idx, kv_phase); + empty_kv_barriers[kv_stage_idx]->wait(kv_phase ^ 1); + + if (cute::elect_one_sync()) { + #pragma unroll + for (int i = 0; i < kNumBlocksPerSplit; ++ i) { + tma_copy(&tensor_map_kv, full_kv_barriers[kv_stage_idx], + smem_kv[kv_stage_idx] + (BLOCK_KV * kHeadDim) * i, + 0, 0, 1, kv_block_idx[i]); + tma_copy(&tensor_map_kv_scales, full_kv_barriers[kv_stage_idx], + smem_kv_scales[kv_stage_idx] + BLOCK_KV * i, + 0, kv_block_idx[i]); + } + full_kv_barriers[kv_stage_idx]->arrive_and_expect_tx(SMEM_KV_SIZE_PER_STAGE + SMEM_KV_SCALE_SIZE_PER_STAGE); + } + + // Fetch next task + fetched_next_task = scheduler.fetch_next_task(next_q_idx, next_kv_idx, next_num_kv); + } + } else if (is_umma_warp) { + cutlass::arch::warpgroup_reg_dealloc(); + + // Require full allocation + DG_TRAP_ONLY_DEVICE_ASSERT(ld_shared(tmem_ptr_in_smem) == 0); + + // Make UMMA desc + auto instr_desc = cute::UMMA::make_instr_desc(); + auto runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc); + + uint32_t q_idx = batch_size, kv_idx; + uint32_t next_q_idx, next_kv_idx, next_num_kv; + uint32_t q_stage_idx, q_phase; + uint32_t umma_phase = 1; + + while (scheduler.fetch_next_task(next_q_idx, next_kv_idx, next_num_kv)) { + if (q_idx != next_q_idx) { + CUTE_TIE(get_q_pipeline(q_iter_idx ++), q_stage_idx, q_phase); + full_q_barriers[q_stage_idx]->wait(q_phase); + } + + q_idx = next_q_idx; + kv_idx = next_kv_idx; + + CUTE_TIE_DECL(get_kv_pipeline(kv_iter_idx ++), kv_stage_idx, kv_phase); + full_kv_barriers[kv_stage_idx]->wait(kv_phase); + + DG_STATIC_ASSERT(kHeadDim % UMMA_K == 0, "Invalid head dim"); + #pragma unroll + for (uint32_t i = 0; i < kNumMathWarpGroups; ++ i) { + empty_umma_barriers[i]->wait(umma_phase); + tcgen05_after_thread_sync(); + #pragma unroll + for (uint32_t k = 0; k < kHeadDim / UMMA_K; ++ k) { + auto a_desc = make_umma_desc( + smem_kv[kv_stage_idx], i * UMMA_M, k * UMMA_K); + auto b_desc = make_umma_desc( + smem_q[q_stage_idx], 0, k * UMMA_K); + cute::SM100_MMA_F8F6F4_SS::fma(a_desc, b_desc, i * UMMA_N, k, runtime_instr_desc); + } + cutlass::arch::umma_arrive(reinterpret_cast(full_umma_barriers[i])); + } + umma_phase ^= 1; + } + } else if (is_math_warp) { + // Math warp-groups for WGMMA + cutlass::arch::warpgroup_reg_alloc(); + + // Offsets + const auto& tmem_start = __shfl_sync(0xffffffff, warpgroup_idx * UMMA_N, 0); + const uint32_t thread_idx = threadIdx.x; + + // Weights + constexpr uint32_t kNumWeightsInReg = (kNextN == 1 ? 
kNumHeads : cute::min(48, kNumHeads)); + float weights[kNextN][kNumWeightsInReg]; + DG_STATIC_ASSERT(kNumWeightsInReg % 4 == 0, "Invalid number of weights in registers"); + + // Initialize `q_idx` outside `[0, batch_size)` to indicate it was none + uint32_t q_idx = batch_size, kv_idx; + uint32_t next_q_idx, next_kv_idx, next_num_kv; + uint32_t q_stage_idx, q_phase; + uint32_t umma_phase = 0; + + while (scheduler.fetch_next_task(next_q_idx, next_kv_idx, next_num_kv)) { + // Current Q changes + if (q_idx != next_q_idx) { + // Release Last Q empty + if (q_iter_idx > 0) + empty_q_barriers[(q_iter_idx - 1) % kNumQStages]->arrive(); + + // Wait TMA Q arrival + CUTE_TIE(get_q_pipeline(q_iter_idx ++), q_stage_idx, q_phase); + full_q_barriers[q_stage_idx]->wait(q_phase); + + // Read weights + #pragma unroll + for (uint32_t i = 0; i < kNextN; ++ i) { + for (uint32_t j = 0; j < kNumWeightsInReg; ++ j) + weights[i][j] = ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j); + } + } + + // Get current Q and KV index + q_idx = next_q_idx; + kv_idx = next_kv_idx; + + // Calculate KV offset in advance + auto kv_offset = q_idx * kNextN * logits_stride + kv_idx * BLOCK_KV; + + // Compute `[kNextN * kNumHeads, kHeadDim] @ [SPLIT_KV, kHeadDim] -> [kNextN, SPLIT_KV]` + // Wait TMA KV arrival + CUTE_TIE_DECL(get_kv_pipeline(kv_iter_idx ++), kv_stage_idx, kv_phase); + full_kv_barriers[kv_stage_idx]->wait(kv_phase); + + // Read per-KV scales + float scale_kv = ld_shared(smem_kv_scales[kv_stage_idx] + thread_idx); + + // Wait UMMA arrival + full_umma_barriers[warpgroup_idx]->wait(umma_phase); + tcgen05_after_thread_sync(); + umma_phase ^= 1; + + // Release KV empty + empty_kv_barriers[kv_stage_idx]->arrive(); + + // Reduce over the head dim and store + DG_STATIC_ASSERT(kNumHeads % 8 == 0, "Invalid head"); + constexpr uint32_t kNumLDTMElems = kNumHeads * kNextN; + uint32_t shifted_accum[kNumLDTMElems]; + DG_STATIC_ASSERT(kNumLDTMElems == 32 or kNumLDTMElems == 64 or kNumLDTMElems == 128, "Invalid LDTM"); + auto tmem_load = [&](auto... 
Is) { + if constexpr (kNumLDTMElems == 32) { + cute::SM100_TMEM_LOAD_32dp32b32x::copy(tmem_start, shifted_accum[Is]...); + } else if constexpr (kNumLDTMElems == 64) { + cute::SM100_TMEM_LOAD_32dp32b64x::copy(tmem_start, shifted_accum[Is]...); + } else if constexpr (kNumLDTMElems == 128) { + cute::SM100_TMEM_LOAD_32dp32b128x::copy(tmem_start, shifted_accum[Is]...); + } + }; + [&](cute::index_sequence) { tmem_load(Is...); }(cute::make_index_sequence{}); + cutlass::arch::fence_view_async_tmem_load(); + + tcgen05_before_thread_sync(); + empty_umma_barriers[warpgroup_idx]->arrive(); + + #pragma unroll + for (uint32_t i = 0; i < kNextN; ++ i) { + auto accum = reinterpret_cast(shifted_accum + i * kNumHeads); + + auto sum_0 = make_float2(0, 0); + auto sum_1 = make_float2(0, 0); + + const auto& transform_reg = [&](const uint32_t& j, const float2& sum) { + auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0)); + auto b = make_float2(weights[i][j], weights[i][j + 1]); + return __ffma2_rn(a, b, sum); + }; + + #pragma unroll + for (int j = 0; j < kNumWeightsInReg; j += 4) { + sum_0 = transform_reg(j, sum_0); + sum_1 = transform_reg(j + 2, sum_1); + } + + const auto& transform_smem = [&](const uint32_t& j, const float2& sum) { + auto a = make_float2(fmaxf(accum[j], 0), fmaxf(accum[j + 1], 0)); + auto b = make_float2(ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j), + ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + j + 1)); + return __ffma2_rn(a, b, sum); + }; + + #pragma unroll + for (int j = kNumWeightsInReg; j < kNumHeads; j += 4) { + sum_0 = transform_smem(j, sum_0); + sum_1 = transform_smem(j + 2, sum_1); + } + + auto sum = __fadd2_rn(sum_0, sum_1); + float result = scale_kv * (sum.x + sum.y); + + // Store into the global memory + // NOTES: we have redundant writes here, consider more carefully + logits[kv_offset + i * logits_stride + thread_idx] = result; + } + } + } else { + cutlass::arch::warpgroup_reg_dealloc(); + } + + // Free tensor memory + __syncthreads(); + if (is_umma_warp) + cute::TMEM::Allocator1Sm().free(0, kNumTmemCols); +} + +} // namespace deep_gemm diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/deep_gemm/impls/sm100_tf32_hc_prenorm_gemm.cuh b/build/torch29-cxx11-cu126-aarch64-linux/include/deep_gemm/impls/sm100_tf32_hc_prenorm_gemm.cuh new file mode 100644 index 0000000000000000000000000000000000000000..4e4ff21d0746cff7bc7ecaf23a49278a2f5810cc --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/deep_gemm/impls/sm100_tf32_hc_prenorm_gemm.cuh @@ -0,0 +1,345 @@ +#pragma once +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wunknown-attributes" + +#include + +#include +#include +#include +#include + +namespace deep_gemm { + +using namespace deep_gemm::sm100; + +template +__device__ __forceinline__ +uint32_t get_swizzled_smem_offset(const uint32_t& offset, const uint32_t& lane_idx) { + // Calculate the index of the bank group to be written in the atom + const auto& bank_group_idx = offset + lane_idx * (kSwizzleMode / kSwizzleBase); + + // Reshape the atom in another view and swizzle + // - original: `(BLOCK_N, kSwizzleMode / kSwizzleBase)` + // - new: `(BLOCK_N * kSwizzleMode / kSwizzleBase / kNumBankGroups, kNumBankGroups)` + constexpr uint32_t kNumBankGroups = 128 / kSwizzleBase; + constexpr bool kHasShortcut = (kSwizzleMode / kSwizzleBase) == kNumBankGroups; + auto row = kHasShortcut ? (offset / kNumBankGroups + lane_idx) : (bank_group_idx / kNumBankGroups); + auto col = kHasShortcut ? 
(offset) : (bank_group_idx % kNumBankGroups); + col ^= row % (kSwizzleMode / kSwizzleBase); + + return row * 128 + col * kSwizzleBase; +} + +template +__global__ void __launch_bounds__(kNumMMAThreads + kNumCastAndReduceThreads, 1) +sm100_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, + const __grid_constant__ cute::TmaDescriptor tensor_map_a, + const __grid_constant__ cute::TmaDescriptor tensor_map_b, + const __grid_constant__ cute::TmaDescriptor tensor_map_d, + float* sqr_sum) { +#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 1000)) or defined(__CLION_IDE__) + using Barrier = cutlass::arch::ClusterTransactionBarrier; + + // Configs + constexpr uint32_t kNumCastStages = 2; + constexpr uint32_t kSwizzleAMode = cute::min(BLOCK_K * sizeof(nv_bfloat16), 128); + constexpr uint32_t kSwizzleBMode = cute::min(BLOCK_K * sizeof(float), 128); + constexpr auto kMajorA = cute::UMMA::Major::K; + constexpr auto kMajorB = cute::UMMA::Major::K; + DG_STATIC_ASSERT(kNumCastStages <= kNumStages, "Invalid cast stages"); + DG_STATIC_ASSERT(kSwizzleCDMode / sizeof(float) == BLOCK_N, "Invalid block N"); + DG_STATIC_ASSERT(kNumMMAThreads == 128, "Invalid MMA threads"); + + // Utils + const auto warp_idx = cutlass::canonical_warp_idx_sync(); + const auto lane_idx = get_lane_idx(); + + // Align to 1024 bytes for swizzle-128B + extern __shared__ __align__(1024) uint8_t smem_buffer[]; + + // Share memory sizes + constexpr uint32_t SMEM_CD_SIZE = BLOCK_M * kSwizzleCDMode; + constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(nv_bfloat16); + constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(float); + DG_STATIC_ASSERT(SMEM_CD_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes"); + + // Real tensor memory size and offsets + constexpr uint32_t kNumTmemCols = get_num_aligned_tmem_cols(); + + // Prefetch TMA descriptors at the very beginning + if (warp_idx == 0 and cute::elect_one_sync()) { + cute::prefetch_tma_descriptor(&tensor_map_a); + cute::prefetch_tma_descriptor(&tensor_map_b); + cute::prefetch_tma_descriptor(&tensor_map_d); + } + + // Data on shared memory (layout as ordered below) + // Fill D/A/B pointers + auto smem_cd = reinterpret_cast(smem_buffer); + auto smem_a = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + (SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE)); + }); + auto smem_b = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + (SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE)); + }); + + // Fill barriers + auto barrier_start_ptr = reinterpret_cast(smem_buffer + SMEM_CD_SIZE + + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE)); + auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); }); + auto full_cast_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); }); + auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 2 + i); }); + auto empty_cast_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages * 3 + i); }); + auto tmem_full_barrier = barrier_start_ptr + kNumStages * 4; + + // Fill the tensor memory pointer + auto tmem_ptr_in_smem = reinterpret_cast(barrier_start_ptr + kNumStages * 4 + 1); + DG_STATIC_ASSERT(32 <= kNumTmemCols and kNumTmemCols <= 512, "Invalid tensor memory columns"); + + // Initialize barriers + if (warp_idx == 1 and cute::elect_one_sync()) { + #pragma unroll + for 
(uint32_t i = 0; i < kNumStages; ++ i) { + full_barriers[i]->init(1); + full_cast_barriers[i]->init(kNumCastAndReduceThreads); + empty_barriers[i]->init(1); + empty_cast_barriers[i]->init(1); + } + tmem_full_barrier->init(1); + + // Make initialized barrier visible in async proxy + cutlass::arch::fence_barrier_init(); + } else if (warp_idx == 2) { + // Allocate tensor memory + cute::TMEM::Allocator1Sm().allocate(kNumTmemCols, tmem_ptr_in_smem); + } + __syncthreads(); + + constexpr uint32_t kNumKBlocks = constexpr_ceil_div(SHAPE_K, BLOCK_K); + constexpr uint32_t kNumKBlocksPerSplit = kNumKBlocks / kNumSplits; + constexpr uint32_t kRemainKBlocks = kNumKBlocks % kNumSplits; + const uint32_t block_idx = __shfl_sync(0xffffffff, blockIdx.x, 0); + const uint32_t m_block_idx = block_idx / kNumSplits; + const uint32_t k_split_idx = block_idx % kNumSplits; + const uint32_t k_offset = (k_split_idx * kNumKBlocksPerSplit + cute::min(k_split_idx, kRemainKBlocks)) * BLOCK_K; + const uint32_t m_offset = shape_m * k_split_idx; + const uint32_t num_total_stages = kNumKBlocksPerSplit + (k_split_idx < kRemainKBlocks); + + // Dispatch warps into different roles + if (warp_idx < kNumMMAThreads / 32) { + // TMA load warp + if (warp_idx == 0 and cute::elect_one_sync()) { + for (uint32_t s = 0; s < num_total_stages; ++ s) { + // Wait consumer release + const auto& stage_idx = s % kNumStages; + empty_barriers[stage_idx]->wait(((s / kNumStages) & 1) ^ 1); + + // Compute offsets + uint32_t m_idx = m_block_idx * BLOCK_M; + uint32_t k_idx = k_offset + s * BLOCK_K; + + // Issue TMAs + tma_copy(&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_idx, m_idx); + tma_copy(&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_idx, 0); + + // Arrive at full barriers + constexpr uint32_t kNumArrivalBytes = SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE; + full_barriers[stage_idx]->arrive_and_expect_tx(kNumArrivalBytes); + } + } + + // MMA issue warp + if (warp_idx == 1) { + // Make instruction descriptor + constexpr uint32_t UMMA_M = BLOCK_M; + constexpr uint32_t UMMA_N = BLOCK_N; + constexpr uint32_t UMMA_K = 32 / sizeof(float); + constexpr uint32_t BLOCK_SWIZZLED_BK = kSwizzleBMode / sizeof(float); + using umma_t = cute::SM100_MMA_TF32_TS; + auto instr_desc = cute::UMMA::make_instr_desc(); + const auto& runtime_instr_desc = cute::UMMA::make_runtime_instr_desc(instr_desc); + + DG_STATIC_ASSERT(kNumStages <= 32, "Too many stages"); + auto b_desc = make_umma_desc(smem_b[0], 0, 0); + const uint32_t& b_desc_lo = lane_idx < kNumStages ? 
b_desc.lo + lane_idx * SMEM_B_SIZE_PER_STAGE / 16 : 0u; + + // Checks for MMA instructions + // NOTES: CUTLASS does not have such checks except the MMA traits, but we are not using these traits + DG_STATIC_ASSERT((UMMA_M == 64 and UMMA_N % 8 == 0 and 8 <= UMMA_N and UMMA_N <= 256) or + (UMMA_M == 128 and UMMA_N % 8 == 0 and 8 <= UMMA_N and UMMA_N <= 256) or + (UMMA_M == 256 and UMMA_N % 16 == 0 and 16 <= UMMA_N and UMMA_N <= 256), + "Invalid MMA instruction shape"); + + // Launch MMAs + // We can not unroll this part + for (uint32_t s = 0; s < num_total_stages; ++ s) { + // Wait TMA arrival + const auto& stage_idx = s % kNumStages; + const auto& cast_stage_idx = s % kNumCastStages; + full_cast_barriers[cast_stage_idx]->wait((s / kNumCastStages) & 1); + tcgen05_after_thread_sync(); + + // Issue UMMA + const auto& b_desc_base_lo = __shfl_sync(0xffffffff, b_desc_lo, static_cast(stage_idx)); + #pragma unroll + for (uint32_t k = 0; k < BLOCK_K / UMMA_K; ++ k) { + const uint32_t& atom_idx = (k * UMMA_K) / BLOCK_SWIZZLED_BK; + const uint32_t& in_atom_idx = (k * UMMA_K) % BLOCK_SWIZZLED_BK; + const uint32_t& offset = atom_idx * BLOCK_N * BLOCK_SWIZZLED_BK; + b_desc.lo = advance_umma_desc_lo(b_desc_base_lo, offset, in_atom_idx); + umma_t::fma(BLOCK_K * cast_stage_idx + k * UMMA_K, b_desc, BLOCK_K * kNumCastStages, s > 0 or k > 0, runtime_instr_desc); + } + + // Commit + cutlass::arch::umma_arrive(reinterpret_cast(empty_cast_barriers[cast_stage_idx])); + cutlass::arch::umma_arrive(reinterpret_cast(empty_barriers[stage_idx])); + } + + // Commit to epilogue threads + cutlass::arch::umma_arrive(reinterpret_cast(tmem_full_barrier)); + } + + // TMA checks + constexpr uint32_t kNumBankGroupBytes = 16; + constexpr uint32_t kNumElemsPerBankGroup = kNumBankGroupBytes / sizeof(float); + DG_STATIC_ASSERT(kSwizzleCDMode > 0, "TMA D must be swizzled"); + DG_STATIC_ASSERT(BLOCK_N % kNumElemsPerBankGroup == 0, "Invalid swizzling"); + + // Only support layout F (M = 64) and D (M = 128) + DG_STATIC_ASSERT(BLOCK_M == 64 or BLOCK_M == 128, "Invalid block M"); + + // Wait UMMA arrival + tmem_full_barrier->wait(0); + tcgen05_after_thread_sync(); + + // Load from tensor memory into registers, and write shared memory with STSM + DG_STATIC_ASSERT(kNumMMAThreads == 128, "Epilogue threads not enough"); + + // Store into shared memory + #pragma unroll + for (uint32_t i = 0; i < BLOCK_N / kNumElemsPerBankGroup; ++ i) { + // Source and destination memory address + uint32_t tmem_addr = BLOCK_K * kNumCastStages + i * kNumElemsPerBankGroup; + auto smem_ptr = reinterpret_cast(smem_cd) + // Base pointer + warp_idx * BLOCK_M / 4 * kSwizzleCDMode + // Warp offset + get_swizzled_smem_offset(i, lane_idx); // In-atom offset + + // Load from tensor memory, store into shared memory + uint32_t values[kNumElemsPerBankGroup]; + DG_STATIC_ASSERT(kNumElemsPerBankGroup == 4, "Invalid type"); + cute::SM100_TMEM_LOAD_32dp32b4x::copy(tmem_addr, + values[0], values[1], values[2], values[3]); + cutlass::arch::fence_view_async_tmem_load(); + if (BLOCK_M == 128 or (BLOCK_M == 64 and lane_idx < 16)) + st_shared(smem_ptr, values[0], values[1], values[2], values[3]); + if constexpr (BLOCK_M == 64) + __syncwarp(); + } + + // Synchronize all threads and issue TMA + cute::tma_store_fence(); + cutlass::arch::NamedBarrier::sync(kNumMMAThreads, 0); + if (warp_idx == 0 and cute::elect_one_sync()) { + if constexpr (kNumSplits == 1) { + cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_cd, 0, m_block_idx * BLOCK_M); + } else { + 
cute::SM90_TMA_STORE_3D::copy(&tensor_map_d, smem_cd, 0, m_block_idx * BLOCK_M, k_split_idx); + } + cute::tma_store_arrive(); + } + + // Deallocate tensor memory by warp 1 + // NOTES: warp 0 is waiting TMA store + if (warp_idx == 1) + cute::TMEM::Allocator1Sm().free(0, kNumTmemCols); + } else { + DG_STATIC_ASSERT(BLOCK_M == 64, "Invalid block M"); + DG_STATIC_ASSERT(kNumCastAndReduceThreads == 128, "Invalid cast-and-reduce threads"); + constexpr uint32_t BLOCK_M_PER_WARP = BLOCK_M / 4; + const uint32_t sub_warp_idx = warp_idx - kNumMMAThreads / 32; + + // TODO: make even larger block K + DG_STATIC_ASSERT(BLOCK_K * sizeof(nv_bfloat16) == kSwizzleAMode, "Invalid block K"); + + // Launch reductions + float2 sum[2] = {float2{0, 0}, float2{0, 0}}; + #pragma unroll kNumStages + for (uint32_t s = 0; s < num_total_stages; ++ s) { + // Wait TMA arrival + const auto& stage_idx = s % kNumStages; + full_barriers[stage_idx]->wait((s / kNumStages) & 1); + + // Load from shared memory into tensor memory using movement shape `.16x256b` (shared memory part is 128b) + constexpr uint32_t kNumBankGroupBytes = 16; + constexpr uint32_t kNumElemsPerBankGroup = kNumBankGroupBytes / sizeof(nv_bfloat16); + constexpr uint32_t kNumLoads = BLOCK_K / kNumElemsPerBankGroup; + const auto& smem_base_ptr = reinterpret_cast(smem_a[stage_idx]) + // Base pointer + sub_warp_idx * BLOCK_M_PER_WARP * kSwizzleAMode; // Warp offset + + // 4 lanes shared a bank group + uint32_t uint32_values[2][kNumLoads]; + DG_STATIC_ASSERT(kNumLoads % 2 == 0, "Invalid number of loads"); + #pragma unroll + for (uint32_t i = 0; i < kNumLoads; i += 2) { + auto smem_ptr = smem_base_ptr + get_swizzled_smem_offset(i + lane_idx / 16, lane_idx % 16); + sm90::SM90_U32x4_LDSM_N::copy(uint32_values[0][i + 0], uint32_values[1][i + 0], + uint32_values[0][i + 1], uint32_values[1][i + 1], + smem_ptr); + } + + // Wait tensor memory empty + const auto& cast_stage_idx = s % kNumCastStages; + empty_cast_barriers[cast_stage_idx]->wait(((s / kNumCastStages) & 1) ^ 1); + + // Cast, reduce and store into tensor memory + float2 fp32x2_values[2][kNumLoads]; + const auto& upper_view = reinterpret_cast(&fp32x2_values[0]); + const auto& lower_view = reinterpret_cast(&fp32x2_values[1]); + #pragma unroll + for (uint32_t i = 0; i < kNumLoads; ++ i) { + #pragma unroll + for (uint32_t u = 0; u < 2; ++ u) { + fp32x2_values[u][i] = __bfloat1622float2(*reinterpret_cast(&uint32_values[u][i])); + sum[u] = __ffma2_rn(fp32x2_values[u][i], fp32x2_values[u][i], sum[u]); + } + + // Store upper and lower part at the same time + const auto idx_0 = i * 2, idx_1 = i * 2 + 1; + cute::SM100_TMEM_STORE_16dp256b1x::copy( + upper_view[idx_0], upper_view[idx_1], + lower_view[idx_0], lower_view[idx_1], + cast_stage_idx * BLOCK_K + i * 8); + } + cutlass::arch::fence_view_async_tmem_store(); + + // Arrive for issuing MMAs + tcgen05_before_thread_sync(); + full_cast_barriers[cast_stage_idx]->arrive(); + } + + // Intra-warp reduction and write back + #pragma unroll + for (uint32_t u = 0; u < 2; ++ u) { + const auto& reduced_sum = warp_reduce_sum<4>(sum[u].x + sum[u].y); + const auto& m_idx = m_block_idx * BLOCK_M + sub_warp_idx * BLOCK_M_PER_WARP + lane_idx / 4 + u * 8; + if (lane_idx % 4 == 0 and m_idx < shape_m) + sqr_sum[m_offset + m_idx] = reduced_sum; + } + } +#else + if (blockIdx.x == 0 and threadIdx.x == 0) + DG_DEVICE_ASSERT(false and "This kernel only support sm_100f"); +#endif +} + +} // namespace deep_gemm + +#pragma clang diagnostic pop diff --git 
a/build/torch29-cxx11-cu126-aarch64-linux/include/deep_gemm/impls/sm90_bf16_gemm.cuh b/build/torch29-cxx11-cu126-aarch64-linux/include/deep_gemm/impls/sm90_bf16_gemm.cuh new file mode 100644 index 0000000000000000000000000000000000000000..7a77e4e8fbbbffa56e8c8632ade7ae7938b30ee9 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/deep_gemm/impls/sm90_bf16_gemm.cuh @@ -0,0 +1,381 @@ +#pragma once + +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wunknown-attributes" + +#include +#include + +#include +#include +#include +#include + +#include +#include +#include + +namespace deep_gemm { + +using namespace deep_gemm::sm90; + +template +__global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void +sm90_bf16_gemm_impl(int* grouped_layout, + uint32_t shape_m, uint32_t shape_n, uint32_t shape_k, + const __grid_constant__ cute::TmaDescriptor tensor_map_a, + const __grid_constant__ cute::TmaDescriptor tensor_map_b, + const __grid_constant__ cute::TmaDescriptor tensor_map_cd) { +#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__) + // Enlarge `BLOCK_K` for some cases + // NOTES: this is for reducing the `warpgroup_wait<0>()` overhead + constexpr uint32_t kDoMergeStages = + kNumStages_ >= 10 and + kGemmType == GemmType::Normal and + kMajorA == cute::UMMA::Major::K and kMajorB == cute::UMMA::Major::K and + kNumMathThreads == 128; + // Ensure there are at least `kNumMinStages` stages after merge + constexpr uint32_t kNumMinStages = 5; + constexpr uint32_t kNumStagesPerMerge = kDoMergeStages ? kNumStages_ / kNumMinStages : 1; + constexpr uint32_t BLOCK_K = BLOCK_K_ * kNumStagesPerMerge; + constexpr uint32_t kNumStages = kNumStages_ / kNumStagesPerMerge; + + // Types + using WGMMA = typename BF16MMASelector::type; + using Barrier = cutlass::arch::ClusterTransactionBarrier; + DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0 or BLOCK_M < WGMMA::M, "Invalid block size"); + + // Overwrite shape constants if the compiler gives + shape_m = SHAPE_M != 0 ? SHAPE_M : shape_m; + shape_n = SHAPE_N != 0 ? SHAPE_N : shape_n; + shape_k = SHAPE_K != 0 ? 
SHAPE_K : shape_k; + + // Shared memory + static constexpr uint32_t SMEM_D_SIZE = constexpr_align(BLOCK_M * BLOCK_N * static_cast(sizeof(cd_dtype_t)), 1024u); + static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_bfloat16); + static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_bfloat16); + + // NOTES: Make sure we have enough shared memory for WGMMA padding + static constexpr uint32_t WGMMA_A_SIZE_PER_STAGE = WGMMA::M * BLOCK_K * sizeof(__nv_fp8_e4m3); + DG_STATIC_ASSERT(WGMMA_A_SIZE_PER_STAGE <= SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE * kNumStages, "Memory Out of bound for WGMMA"); + + // Configs + const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + const uint32_t lane_idx = get_lane_idx(); + + // Prefetch TMA descriptors at the very beginning + if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) { + cute::prefetch_tma_descriptor(&tensor_map_a); + cute::prefetch_tma_descriptor(&tensor_map_b); + cute::prefetch_tma_descriptor(&tensor_map_cd); + } + __syncwarp(); + + // Align to 1024 bytes for swizzle-128B + extern __shared__ __align__(1024) uint8_t smem_buffer[]; + DG_STATIC_ASSERT(SMEM_D_SIZE % 1024 == 0 and SMEM_A_SIZE_PER_STAGE % 1024 == 0 and SMEM_B_SIZE_PER_STAGE % 1024 == 0, + "Shared memory of A/B/D must be aligned to 1024 bytes"); + + // D/A/B shared memory + auto smem_d = reinterpret_cast(smem_buffer); + auto smem_a = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE); + }); + auto smem_b = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE); + }); + + // Fill barriers + auto barrier_start_ptr = reinterpret_cast(smem_buffer + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE)); + auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); }); + auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); }); + + // Initialize barriers + if (warp_idx == kNumMathThreads / 32 + 1 and cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < kNumStages; ++ i) { + full_barriers[i]->init(1); + empty_barriers[i]->init(kNumTMAMulticast * kNumMathThreads / 32); + } + + // Make initialized barrier visible in async proxy + cutlass::arch::fence_barrier_init(); + } + + // Synchronize all threads to make barrier visible in normal memory model + (kNumTMAMulticast > 1) ? cute::cluster_sync() : __syncthreads(); + + // Register reconfigurations + constexpr uint32_t kNumTMARegisters = 48; + constexpr uint32_t kNumMathRegisters = kNumMathThreads == 128 ? 248 : 224; + + // Block scheduler + uint32_t m_block_idx, n_block_idx; + auto scheduler = Scheduler(shape_m, shape_n, shape_k, grouped_layout); + + // Pipeline and TMA phases + uint32_t stage_idx = 0, phase = 0; + auto advance_pipeline = [&](uint32_t& k_block_idx) { + ++ k_block_idx; + + // Flip phases only if reach the next first stage + stage_idx = stage_idx == kNumStages - 1 ? 
0 : stage_idx + 1; + phase ^= stage_idx == 0; + }; + + if (warp_idx >= kNumMathThreads / 32) { + // TMA warp-group for loading data + cutlass::arch::warpgroup_reg_dealloc(); + + // NOTES: only one thread (or warp) will be used + // We use the third warp, as warp 0/1 may be doing WGMMA with `BLOCK_M == 32` + if (warp_idx == kNumMathThreads / 32 + 2 and cute::elect_one_sync()) { + DG_STATIC_ASSERT(kNumTMAThreads >= 128, "Need at least 128 threads for TMA warp-group"); + + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + // Assign TMA multicast number into A and B + // NOTES: there may be additional odd rows/columns or cases where multicast is not possible. + const bool is_tma_multicast_valid = scheduler.is_tma_multicast_valid(m_block_idx); + const uint32_t num_tma_multicast_a = (kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1; + const uint32_t num_tma_multicast_b = (not kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1; + DG_STATIC_ASSERT(kNumTMAMulticast <= 2, "Scheduler does not support > 2 TMA multicast"); + + const auto& num_total_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K); + for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) { + // Wait consumer release + empty_barriers[stage_idx]->wait(phase ^ 1); + + constexpr bool kWithGroupOffsetA = kGemmType == GemmType::MGroupedMasked; + auto& full_barrier = *full_barriers[stage_idx]; + + const auto m_idx = scheduler.template get_global_idx(shape_m, BLOCK_M, m_block_idx); + const auto n_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::K), IndexType::MN>(shape_n, BLOCK_N, n_block_idx, m_block_idx); + + DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kGemmType == GemmType::KGroupedContiguous or kMajorA == cute::UMMA::Major::K, "Invalid major"); + uint32_t k_a_idx = scheduler.template get_global_idx<(kMajorA == cute::UMMA::Major::MN), IndexType::K> ( + shape_k, BLOCK_K, k_block_idx, m_block_idx); + uint32_t k_b_idx = scheduler.template get_global_idx<(kMajorB == cute::UMMA::Major::MN), IndexType::K> ( + shape_k, BLOCK_K, k_block_idx, m_block_idx); + + // Issue TMAs + constexpr bool kIsBatchedMM = (kGemmType == GemmType::Batched); + const uint32_t batch_idx = (kIsBatchedMM ? 
scheduler.current_group_idx : 0); + if constexpr (kMajorA == cute::UMMA::Major::K) + tma_copy( + &tensor_map_a, &full_barrier, smem_a[stage_idx], k_a_idx, m_idx, num_tma_multicast_a, batch_idx); + if constexpr (kMajorA == cute::UMMA::Major::MN) + tma_copy( + &tensor_map_a, &full_barrier, smem_a[stage_idx], m_idx, k_a_idx, num_tma_multicast_a, batch_idx); + if constexpr (kMajorB == cute::UMMA::Major::K) + tma_copy( + &tensor_map_b, &full_barrier, smem_b[stage_idx], k_b_idx, n_idx, num_tma_multicast_b, batch_idx); + if constexpr (kMajorB == cute::UMMA::Major::MN) + tma_copy( + &tensor_map_b, &full_barrier, smem_b[stage_idx], n_idx, k_b_idx, num_tma_multicast_b, batch_idx); + + full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE); + } + } + + // To safely deconstruct distributed shared barriers, we need another round of empty waits + if constexpr (kNumTMAMulticast > 1) { + for (uint32_t i = 0; i < kNumStages; advance_pipeline(i)) + empty_barriers[stage_idx]->wait(phase ^ 1); + } + } + } else { + // Math warp-groups for WGMMA + cutlass::arch::warpgroup_reg_alloc(); + + // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers + const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / 128, 0); + + // Merged stages only happens in NT normal GEMM cases + constexpr uint32_t BLOCK_ATOM_K = BLOCK_K / kNumStagesPerMerge; + auto a_desc = make_gmma_desc(smem_a[0], math_wg_idx * WGMMA::M, 0); + auto b_desc = make_gmma_desc(smem_b[0], 0, 0); + const uint32_t a_desc_lo = __shfl_sync(0xffffffff, a_desc.reg32_[0], 0); + const uint32_t b_desc_lo = __shfl_sync(0xffffffff, b_desc.reg32_[0], 0); + + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + constexpr uint32_t WAVE_BLOCK_M = BLOCK_M <= WGMMA::M ? BLOCK_M : WGMMA::M * 2; + DG_STATIC_ASSERT(BLOCK_M % WAVE_BLOCK_M == 0, "Invalid block sizes"); + float accum[WGMMA::kNumAccum * (BLOCK_M / WAVE_BLOCK_M)] = {0}; + + // Pick threads whose WGMMA results are to be stored in shared memory + DG_STATIC_ASSERT(BLOCK_M >= 64 or kNumMathThreads == 128, "Only one math warp group for `BLOCK_M < 64`"); + constexpr uint32_t kNumWGMMAStoreThreads = WAVE_BLOCK_M * (128 / WGMMA::M); + const bool do_wgmma_store = BLOCK_M >= 64 or warp_idx < kNumWGMMAStoreThreads / 32; + + // Empty barrier arrival + auto empty_barrier_arrive = [&](uint32_t s) { + if constexpr (kNumTMAMulticast == 1) { + lane_idx == 0 ? empty_barriers[s]->arrive() : void(); + } else { + auto target_cta = scheduler.is_peer_cta_alive ? lane_idx : cute::block_rank_in_cluster(); + lane_idx < kNumTMAMulticast ? 
empty_barriers[s]->arrive(target_cta) : void(); + } + }; + + // TODO: remove some useless computation for unaligned Ms + const auto& num_total_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K); + for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) { + const auto& a_desc_base_lo = a_desc_lo + stage_idx * (SMEM_A_SIZE_PER_STAGE / 16); + const auto& b_desc_base_lo = b_desc_lo + stage_idx * (SMEM_B_SIZE_PER_STAGE / 16); + + // Wait TMA arrivals + full_barriers[stage_idx]->wait(phase); + + // Commit WGMMA instructions + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum * (BLOCK_M / WAVE_BLOCK_M); ++ i) + warpgroup_fence_operand(accum[i]); + warpgroup_arrive(); + #pragma unroll + for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) { + auto shifted_accum = accum + WGMMA::kNumAccum * local_idx; + #pragma unroll + for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) { + const uint32_t& atom_k_idx = k * WGMMA::K / BLOCK_ATOM_K; + a_desc.reg32_[0] = advance_gmma_desc_lo( + a_desc_base_lo, local_idx * WAVE_BLOCK_M, (k * WGMMA::K) % BLOCK_ATOM_K, atom_k_idx * BLOCK_M * BLOCK_ATOM_K); + b_desc.reg32_[0] = advance_gmma_desc_lo( + b_desc_base_lo, 0, (k * WGMMA::K) % BLOCK_ATOM_K, atom_k_idx * BLOCK_N * BLOCK_ATOM_K); + WGMMA::wgmma(a_desc, b_desc, shifted_accum, 1); + } + } + warpgroup_commit_batch(); + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum * (BLOCK_M / WAVE_BLOCK_M); ++ i) + warpgroup_fence_operand(accum[i]); + warpgroup_wait<0>(); + + // Notify barrier arrival + empty_barrier_arrive(stage_idx); + } + + // TMA checks + constexpr uint32_t kNumElemBytes = sizeof(nv_bfloat16); + constexpr uint32_t TMA_D_BLOCK_N = kSwizzleDMode == 0 ? BLOCK_N : (kSwizzleDMode / kNumElemBytes); + constexpr uint32_t WGMMA_M_PER_WARP = WGMMA::M / 4; + DG_STATIC_ASSERT(BLOCK_M % 8 == 0, "Invalid swizzling atom"); + DG_STATIC_ASSERT(BLOCK_N % TMA_D_BLOCK_N == 0 and BLOCK_N / TMA_D_BLOCK_N <= 32, + "Unaligned TMA store or too many TMA store instructions"); + DG_STATIC_ASSERT(TMA_D_BLOCK_N % 8 == 0, "Invalid TMA block N"); + + // Skip WGMMA store for the unfilled parts + if (not do_wgmma_store) + continue; + + // Wait last TMA store to be finished + if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) + cute::tma_store_wait<0>(); + cutlass::arch::NamedBarrier::sync(kNumWGMMAStoreThreads, 0); + + if constexpr (cute::is_same_v) { + // Write back to shared memory using STSM and issue TMA stores + DG_STATIC_ASSERT(kSwizzleDMode > 0, "Invalid swizzling type"); + DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization"); + #pragma unroll + for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) { + auto m_offset = local_idx * WAVE_BLOCK_M; + auto shifted_accum = accum + WGMMA::kNumAccum * local_idx; + #pragma unroll + for (auto i = 0; i < WGMMA::kNumAccum / 4; ++ i) { + // Swizzle or padding into the correct address + uint8_t* smem_ptr = nullptr; + if constexpr (kSwizzleDMode > 0) { + // Calculate the swizzling atom offset and in-atom offset + constexpr uint32_t kNumBankGroupBytes = 16; + auto atom_offset = i / (TMA_D_BLOCK_N / 8), in_atom_offset = i % (TMA_D_BLOCK_N / 8); + + // Calculate the index of the bank group to be written in the atom + auto bank_group_index = in_atom_offset + lane_idx * (kSwizzleDMode / kNumBankGroupBytes); + + // Reshape the atom in another view and swizzle + // - original: `(BLOCK_M, kSwizzleDMode / kNumBankGroupBytes)` + // - new: `(BLOCK_M * kSwizzleDMode / 
kNumBankGroupBytes / 8, 8)` + constexpr bool kHasShortcut = (kSwizzleDMode / kNumBankGroupBytes) == 8; + auto row = kHasShortcut ? (in_atom_offset / 8 + lane_idx) : (bank_group_index / 8); + auto col = kHasShortcut ? (in_atom_offset) : (bank_group_index % 8); + col ^= row % (kSwizzleDMode / 16); + + // Add back into the base pointer + // NOTES: think twice before modifying this, as changes may affect the number of instructions + smem_ptr = reinterpret_cast(smem_d) + // Base pointer + warp_idx * (WGMMA_M_PER_WARP * kSwizzleDMode) + // Warp offset + m_offset * kSwizzleDMode + // Wave offset + atom_offset * BLOCK_M * kSwizzleDMode + // Swizzle atom offset (constants) + row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset + } else { + // No swizzling + smem_ptr = reinterpret_cast(smem_d + (m_offset + warp_idx * WGMMA_M_PER_WARP + lane_idx) * BLOCK_N + i * 8); + } + + // NOTES: only 16 lanes' addresses are used + SM90_U32x2_STSM_N::copy( + __float22bfloat162_rn({shifted_accum[i * 4 + 0], shifted_accum[i * 4 + 1]}), + __float22bfloat162_rn({shifted_accum[i * 4 + 2], shifted_accum[i * 4 + 3]}), + smem_ptr + ); + } + } + } else { + // Use `st.shared` if STSM is not available + #pragma unroll + for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) { + auto m_offset = local_idx * WAVE_BLOCK_M; + auto shifted_accum = accum + WGMMA::kNumAccum * local_idx; + auto smem_d_0 = reinterpret_cast(smem_d + (m_offset + warp_idx * WGMMA_M_PER_WARP + lane_idx / 4 + 0) * BLOCK_N + (lane_idx % 4) * 2); + auto smem_d_1 = reinterpret_cast(smem_d + (m_offset + warp_idx * WGMMA_M_PER_WARP + lane_idx / 4 + 8) * BLOCK_N + (lane_idx % 4) * 2); + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum / 4; ++ i) { + st_shared(smem_d_0 + i * 4, make_float2(shifted_accum[i * 4 + 0], shifted_accum[i * 4 + 1])); + st_shared(smem_d_1 + i * 4, make_float2(shifted_accum[i * 4 + 2], shifted_accum[i * 4 + 3])); + } + } + } + cute::tma_store_fence(); + cutlass::arch::NamedBarrier::sync(kNumWGMMAStoreThreads, 0); + + // Use TMA store to write back to global memory + const auto m_idx = scheduler.template get_global_idx<(not is_m_grouped_contiguous(kGemmType)), IndexType::MN>(shape_m, BLOCK_M, m_block_idx); + DG_STATIC_ASSERT(kNumWGMMAStoreThreads >= BLOCK_N / TMA_D_BLOCK_N, "Too many TMA blocks"); + if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) { + auto in_block_n_offset = threadIdx.x * TMA_D_BLOCK_N; + auto smem_ptr = smem_d + in_block_n_offset * BLOCK_M; + if constexpr (kGemmType == GemmType::Batched) { + cute::SM90_TMA_STORE_3D::copy(&tensor_map_cd, smem_ptr, + n_block_idx * BLOCK_N + in_block_n_offset, + m_idx, scheduler.current_group_idx); + } else { + using cute_tma_t = cute::conditional_t; + cute_tma_t::copy(&tensor_map_cd, smem_ptr, + n_block_idx * BLOCK_N + in_block_n_offset, m_idx); + } + cute::tma_store_arrive(); + } + __syncwarp(); + } + } +#else + if (blockIdx.x == 0 and threadIdx.x == 0) + DG_DEVICE_ASSERT(false and "This kernel only support sm_90a"); +#endif +} + +}; // namespace deep_gemm + +#pragma clang diagnostic pop diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/deep_gemm/impls/sm90_bmk_bnk_mn.cuh b/build/torch29-cxx11-cu126-aarch64-linux/include/deep_gemm/impls/sm90_bmk_bnk_mn.cuh new file mode 100644 index 0000000000000000000000000000000000000000..191a4fe2c4ccf66b0743affedcbfd17950e2618f --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/deep_gemm/impls/sm90_bmk_bnk_mn.cuh @@ -0,0 +1,174 @@ +#pragma once + +#include +#include 
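+// NOTES (a reading of the kernel below, not upstream documentation): given BF16 inputs
+// A of logical shape [S, M, K] and B of logical shape [S, N, K], the kernel accumulates
+//     D[m, n] += sum over s, k of A[s, m, k] * B[s, n, k]
+// into an FP32 output D of shape [M, N]. Each CTA owns one (m, n) tile plus a slice of
+// `kSplitFactor` blocks of the flattened (S, K / BLOCK_K) dimension, and the partial tiles
+// are merged into global memory with `float2` atomic adds at the end.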
+#include + +#include +#include + +namespace deep_gemm { + +using namespace deep_gemm::sm90; + +template +__global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void +sm90_bmn_bnk_mn_gemm_impl(const uint32_t shape_s, + const __grid_constant__ cute::TmaDescriptor tensor_map_a, + const __grid_constant__ cute::TmaDescriptor tensor_map_b, + float *d) { +#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__) + // Types + using WGMMA = typename BF16MMASelector::type; + using Barrier = cutlass::arch::ClusterTransactionBarrier; + DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0, "Invalid block size"); + + // Shared memory + static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_bfloat16); + static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_bfloat16); + + // Configs + const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + const uint32_t lane_idx = get_lane_idx(); + DG_STATIC_ASSERT(BLOCK_M == 128, "Invalid block M"); + DG_STATIC_ASSERT(kNumTMAThreads == 128, "Invalid number of TMA threads"); + DG_STATIC_ASSERT(kNumMathThreads == 256, "Invalid number of math threads"); + + // Prefetch TMA descriptors at the very beginning + if (warp_idx == 0 and cute::elect_one_sync()) { + cute::prefetch_tma_descriptor(&tensor_map_a); + cute::prefetch_tma_descriptor(&tensor_map_b); + } + __syncwarp(); + + // Align to 1024 bytes for swizzle-128B + // Fill shared memory pointers + extern __shared__ __align__(1024) uint8_t smem_buffer[]; + auto smem_a = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast<__nv_bfloat16*>(smem_buffer + (i * SMEM_A_SIZE_PER_STAGE)); + }); + auto smem_b = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast<__nv_bfloat16*>(smem_buffer + (kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE)); + }); + + // Fill barriers + auto barrier_start_ptr = reinterpret_cast(smem_buffer + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE)); + auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); }); + auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); }); + + // Initialize barriers + if (warp_idx == 1 and cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < kNumStages; ++ i) { + full_barriers[i]->init(1); + empty_barriers[i]->init(kNumMathThreads); + } + + // Make initialized barrier visible in async proxy + cutlass::arch::fence_barrier_init(); + } + + // Synchronize all threads to make barrier visible in normal memory model + __syncthreads(); + + // Register reconfigurations + constexpr uint32_t kNumTMARegisters = 40; + constexpr uint32_t kNumMathRegisters = 232; + + // Block indices + const uint32_t num_n_blocks = ceil_div(SHAPE_N, BLOCK_N); + const uint32_t num_mn_blocks = num_n_blocks * ceil_div(SHAPE_M, BLOCK_M); + const uint32_t mn_block_idx = blockIdx.x % num_mn_blocks; + const uint32_t sk_block_idx = blockIdx.x / num_mn_blocks; + const uint32_t n_block_idx = mn_block_idx % num_n_blocks; + const uint32_t m_block_idx = mn_block_idx / num_n_blocks; + const uint32_t num_total_stages = cute::min(kSplitFactor, shape_s * (SHAPE_K / BLOCK_K) - sk_block_idx * kSplitFactor); + + if (warp_idx >= kNumMathThreads / 32) { + // TMA warp-group for loading data + cutlass::arch::warpgroup_reg_dealloc(); + + // NOTES: only one thread (or warp) will be used + if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) { + // Persistently schedule 
over blocks + #pragma unroll + for (uint32_t s = 0; s < num_total_stages; ++ s) { + // Wait consumer release + const auto& stage_idx = s % kNumStages; + empty_barriers[stage_idx]->wait((s / kNumStages + 1) & 1); + + auto& full_barrier = *full_barriers[stage_idx]; + const uint32_t& sk_idx = (sk_block_idx * kSplitFactor + s) * BLOCK_K; + const uint32_t& k_idx = sk_idx % SHAPE_K; + const uint32_t& s_idx = sk_idx / SHAPE_K; + + constexpr uint32_t kSwizzle = BLOCK_K * sizeof(nv_bfloat16); + tma_copy( + &tensor_map_a, &full_barrier, smem_a[stage_idx], k_idx, m_block_idx * BLOCK_M + s_idx * SHAPE_M, 1); + tma_copy( + &tensor_map_b, &full_barrier, smem_b[stage_idx], k_idx, n_block_idx * BLOCK_N + s_idx * SHAPE_N, 1); + full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE); + } + } + } else { + // Math warp-groups for WGMMA + cutlass::arch::warpgroup_reg_alloc(); + + // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers + const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / 128, 0); + float accum[WGMMA::kNumAccum] = {0}; + + // Launch MMAs + for (uint32_t s = 0; s < num_total_stages; ++ s) { + // Wait TMA arrivals + const auto& stage_idx = s % kNumStages; + full_barriers[stage_idx]->wait((s / kNumStages) & 1); + + // Commit WGMMA instructions + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + warpgroup_arrive(); + #pragma unroll + for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) { + auto desc_a = make_smem_desc(smem_a[stage_idx] + (math_wg_idx * WGMMA::M) * BLOCK_K + k * WGMMA::K, 1); + auto desc_b = make_smem_desc(smem_b[stage_idx] + k * WGMMA::K, 1); + WGMMA::wgmma(desc_a, desc_b, accum, 1); + } + warpgroup_commit_batch(); + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + warpgroup_wait<0>(); + + // Notify barrier arrival at the last warpgroup wave + empty_barriers[stage_idx]->arrive(); + } + + const auto& row = m_block_idx * BLOCK_M + warp_idx * 16 + lane_idx / 4; + const auto& col = n_block_idx * BLOCK_N + (lane_idx % 4) * 2; + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum / 4; ++ i) { + if (col + i * 8 >= SHAPE_N) + break; + if (row < SHAPE_M) { + atomicAdd(reinterpret_cast(d + (row + 0) * SHAPE_N + col + i * 8), + make_float2(accum[i * 4 + 0], accum[i * 4 + 1])); + } + if (row + 8 < SHAPE_M) { + atomicAdd(reinterpret_cast(d + (row + 8) * SHAPE_N + col + i * 8), + make_float2(accum[i * 4 + 2], accum[i * 4 + 3])); + } + } + } +#else + if (blockIdx.x == 0 and threadIdx.x == 0) + DG_DEVICE_ASSERT(false and "This kernel only support sm_90a"); +#endif +} + +}; // namespace deep_gemm diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/deep_gemm/impls/sm90_fp8_gemm_1d1d.cuh b/build/torch29-cxx11-cu126-aarch64-linux/include/deep_gemm/impls/sm90_fp8_gemm_1d1d.cuh new file mode 100644 index 0000000000000000000000000000000000000000..cdd28fcb59d3b038c84c007ef1da1477d7ca263a --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/deep_gemm/impls/sm90_fp8_gemm_1d1d.cuh @@ -0,0 +1,349 @@ +#pragma once + +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wunknown-attributes" + +#include +#include + +#include +#include +#include + +#include +#include +#include + +namespace deep_gemm { + +using namespace deep_gemm::sm90; + +template +__global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void +sm90_fp8_gemm_1d1d_impl(__nv_fp8_e4m3* gmem_a_ptr, __nv_fp8_e4m3* gmem_b_ptr, + int* 
grouped_layout, + cute::TmaDescriptor* tensor_map_buffer, + uint32_t shape_m, uint32_t shape_n, uint32_t shape_k, + const __grid_constant__ cute::TmaDescriptor tensor_map_a_base, + const __grid_constant__ cute::TmaDescriptor tensor_map_b_base, + const __grid_constant__ cute::TmaDescriptor tensor_map_sfa, + const __grid_constant__ cute::TmaDescriptor tensor_map_sfb, + const __grid_constant__ cute::TmaDescriptor tensor_map_cd) { +#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__) + // Scaling checks + DG_STATIC_ASSERT(kNumTMAThreads == 128 and kNumMathThreads % 128 == 0, "Invalid Threads"); + DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling"); + DG_STATIC_ASSERT(cute::is_same_v, "Invalid C/D data dtype"); + DG_STATIC_ASSERT(kGemmType == GemmType::Normal or kGemmType == GemmType::KGroupedContiguous, "Invalid GEMM type"); + + // Types + using WGMMA = typename FP8MMASelector::type; + using Barrier = cutlass::arch::ClusterTransactionBarrier; + DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0, "Invalid block size"); + + // Overwrite shape constants if the compiler gives + shape_m = SHAPE_M != 0 ? SHAPE_M : shape_m; + shape_n = SHAPE_N != 0 ? SHAPE_N : shape_n; + shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k; + + // Shared memory + static constexpr uint32_t SMEM_TENSOR_MAP_SIZE = (kGemmType == GemmType::KGroupedContiguous ? sizeof(cute::TmaDescriptor) * 4 : 0); + static constexpr uint32_t SMEM_D_SIZE = BLOCK_M * BLOCK_N * sizeof(float); + static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3); + static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3); + static constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = BLOCK_M * sizeof(float); + static constexpr uint32_t SMEM_SFB_SIZE_PER_STAGE = BLOCK_N * sizeof(float); + static constexpr uint32_t ALIGNED_SMEM_SFB_SIZE_PER_STAGE = constexpr_align(SMEM_SFB_SIZE_PER_STAGE, 128u); + DG_STATIC_ASSERT(SMEM_SFA_SIZE_PER_STAGE % 128 == 0, "Invalid TMA alignment"); + + // Configs + const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + const uint32_t lane_idx = threadIdx.x % 32; + + // Prefetch TMA descriptors at the very beginning + if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) { + cute::prefetch_tma_descriptor(&tensor_map_a_base); + cute::prefetch_tma_descriptor(&tensor_map_b_base); + cute::prefetch_tma_descriptor(&tensor_map_sfa); + cute::prefetch_tma_descriptor(&tensor_map_sfb); + cute::prefetch_tma_descriptor(&tensor_map_cd); + } + __syncwarp(); + + // Align to 1024 bytes for swizzle-128B + extern __shared__ __align__(1024) uint8_t smem_buffer[]; + DG_STATIC_ASSERT(SMEM_D_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes"); + + // Tensor maps on shared and global memory + auto smem_tensor_map_a = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + static_cast(sizeof(cute::TmaDescriptor)) * i); + }); + auto smem_tensor_map_b = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + static_cast(sizeof(cute::TmaDescriptor)) * (2 + i)); + }); + auto gmem_tensor_map_a = PatternVisitor([=](const uint32_t& i) { return tensor_map_buffer + blockIdx.x * 4 + i; }); + auto gmem_tensor_map_b = PatternVisitor([=](const uint32_t& i) { return tensor_map_buffer + blockIdx.x * 4 + 2 + i; }); + + // Data on shared memory + auto smem_d = reinterpret_cast(smem_buffer + SMEM_TENSOR_MAP_SIZE); + auto smem_a = PatternVisitor([&](const uint32_t& i) { + return 
reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + (SMEM_TENSOR_MAP_SIZE + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE)); + }); + auto smem_b = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + (SMEM_TENSOR_MAP_SIZE + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE)); + }); + constexpr auto SMEM_SF_OFFSET = SMEM_TENSOR_MAP_SIZE + SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE); + auto smem_sfa = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + (SMEM_SF_OFFSET + i * SMEM_SFA_SIZE_PER_STAGE)); + }); + auto smem_sfb = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + (SMEM_SF_OFFSET + kNumStages * SMEM_SFA_SIZE_PER_STAGE + i * ALIGNED_SMEM_SFB_SIZE_PER_STAGE)); + }); + + // Barriers on shared memory + constexpr auto SMEM_BARRIER_OFFSET = SMEM_SF_OFFSET + kNumStages * (SMEM_SFA_SIZE_PER_STAGE + ALIGNED_SMEM_SFB_SIZE_PER_STAGE); + auto full_barriers = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + (SMEM_BARRIER_OFFSET + i * static_cast(sizeof(Barrier)))); + }); + auto empty_barriers = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + (SMEM_BARRIER_OFFSET + (kNumStages + i) * static_cast(sizeof(Barrier)))); + }); + + if (warp_idx == kNumMathThreads / 32 + 1 and cute::elect_one_sync()) { + // Load tensormap A/B to shared memory + if constexpr (kGemmType == GemmType::KGroupedContiguous) { + *smem_tensor_map_a[0] = tensor_map_a_base; + *smem_tensor_map_a[1] = tensor_map_a_base; + *smem_tensor_map_b[0] = tensor_map_b_base; + *smem_tensor_map_b[1] = tensor_map_b_base; + } + + // Initialize barriers + // NOTES: we always use `lane_idx` to arrive for the `lane_idx`-th CTA in the cluster, + // even with TMA multicast disabled, we want to make the behavior aligned + #pragma unroll + for (uint32_t i = 0; i < kNumStages; ++ i) { + full_barriers[i]->init(1); + empty_barriers[i]->init(kNumTMAMulticast * kNumMathThreads / 32); + } + + // Make initialized barrier visible in async proxy + cutlass::arch::fence_barrier_init(); + } + + // Synchronize all threads to make barrier visible in normal memory model + (kNumTMAMulticast > 1) ? cute::cluster_sync() : __syncthreads(); + + // Pipeline unroll control + constexpr uint32_t kNumPipelineUnrolls = (kGemmType == GemmType::KGroupedContiguous ? 0 : kNumStages); + + // Register reconfigurations (more math registers are needed with unrolling) + constexpr uint32_t kNumTMARegisters = (kNumPipelineUnrolls == 0 ? 40 : 24); + constexpr uint32_t kNumMathRegisters = (kNumPipelineUnrolls == 0 ? 
232 : 240); + + // Block scheduler + uint32_t m_block_idx, n_block_idx; + auto scheduler = Scheduler(shape_m, shape_n, shape_k, grouped_layout); + + // TMA and MMA pipeline + const auto& get_pipeline = [=](const uint32_t& iter_idx) -> cute::tuple { + return {iter_idx % kNumStages, (iter_idx / kNumStages) & 1}; // Pipeline stage and phase + }; + uint32_t iter_idx = 0; + + if (warp_idx >= kNumMathThreads / 32) { + // TMA warp-group for loading data + cutlass::arch::warpgroup_reg_dealloc(); + + // NOTES: only one thread (or warp) will be used + if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) { + const cute::TmaDescriptor* current_tensor_map_a = &tensor_map_a_base; + const cute::TmaDescriptor* current_tensor_map_b = &tensor_map_b_base; + uint32_t last_group_idx = kNumGroups, sum_k = 0; + + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + // Assign TMA multicast number into A and B + // NOTES: there may be additional odd rows/columns or cases where multicast is not possible. + const bool is_tma_multicast_valid = scheduler.is_tma_multicast_valid(m_block_idx); + const uint32_t num_tma_multicast_a = (kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1; + const uint32_t num_tma_multicast_b = (not kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1; + DG_STATIC_ASSERT(kNumTMAMulticast <= 2, "Scheduler does not support > 2 TMA multicast"); + + const uint32_t& num_k_blocks = ceil_div(scheduler.current_shape_k, BLOCK_K); + const uint32_t& m_idx = m_block_idx * BLOCK_M; + const uint32_t& n_idx = n_block_idx * BLOCK_N; + + if (kGemmType == GemmType::KGroupedContiguous and last_group_idx != scheduler.current_group_idx) { + const uint32_t& stage_idx = scheduler.current_num_valid_groups & 1; + const uint32_t& next_stage_idx = stage_idx ^ 1; + last_group_idx = scheduler.current_group_idx; + + // Prepare next tensor map + sum_k += scheduler.current_shape_k; + if (scheduler.next_group_idx < kNumGroups) { + tensor_map_replace_global_addr_in_smem(smem_tensor_map_a[next_stage_idx], gmem_a_ptr + static_cast(sum_k) * shape_m); + tensor_map_replace_global_addr_in_smem(smem_tensor_map_b[next_stage_idx], gmem_b_ptr + static_cast(sum_k) * shape_n); + tensor_map_replace_global_inner_dim_stride_in_smem(smem_tensor_map_a[next_stage_idx], scheduler.next_shape_k, scheduler.next_shape_k); + tensor_map_replace_global_inner_dim_stride_in_smem(smem_tensor_map_b[next_stage_idx], scheduler.next_shape_k, scheduler.next_shape_k); + *(gmem_tensor_map_a[next_stage_idx]) = *(smem_tensor_map_a[next_stage_idx]); + *(gmem_tensor_map_b[next_stage_idx]) = *(smem_tensor_map_b[next_stage_idx]); + tensor_map_release_cta(); + } + + // Get current tensor map + if (scheduler.current_num_valid_groups > 0) { + tensor_map_acquire_cta(gmem_tensor_map_a[stage_idx]); + tensor_map_acquire_cta(gmem_tensor_map_b[stage_idx]); + current_tensor_map_a = gmem_tensor_map_a[stage_idx]; + current_tensor_map_b = gmem_tensor_map_b[stage_idx]; + } + } + + #pragma unroll kNumPipelineUnrolls + for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; ++ k_block_idx) { + // Wait consumer release + CUTE_TIE_DECL(get_pipeline(iter_idx ++), stage_idx, phase); + empty_barriers[stage_idx]->wait(phase ^ 1); + + // Issue TMA + auto& full_barrier = *full_barriers[stage_idx]; + const uint32_t& k_idx = k_block_idx * BLOCK_K; + const uint32_t& sf_k_idx = scheduler.current_sf_k_cumsum + k_block_idx; + tma_copy(&tensor_map_sfa, &full_barrier, smem_sfa[stage_idx], m_idx, 
sf_k_idx, num_tma_multicast_a); + tma_copy(&tensor_map_sfb, &full_barrier, smem_sfb[stage_idx], n_idx, sf_k_idx, num_tma_multicast_b); + tma_copy(current_tensor_map_a, &full_barrier, smem_a[stage_idx], k_idx, m_idx, num_tma_multicast_a); + tma_copy(current_tensor_map_b, &full_barrier, smem_b[stage_idx], k_idx, n_idx, num_tma_multicast_b); + full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SFA_SIZE_PER_STAGE + SMEM_SFB_SIZE_PER_STAGE); + } + } + + // To safely deconstruct distributed shared barriers, we need another round of empty waits + if constexpr (kNumTMAMulticast > 1) { + #pragma unroll + for (uint32_t s = 0; s < kNumStages; ++ s) { + CUTE_TIE_DECL(get_pipeline(iter_idx ++), stage_idx, phase); + empty_barriers[stage_idx]->wait(phase ^ 1); + } + } + } + } else { + // Math warp-groups for WGMMA + cutlass::arch::warpgroup_reg_alloc(); + + // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers + const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / 128, 0); + const auto row_idx = lane_idx / 4, col_idx = lane_idx % 4; + const auto r_0 = warp_idx * 16 + row_idx, r_1 = r_0 + 8; + + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + // Accumulation for WGMMA or CUDA promotion + DG_STATIC_ASSERT(BLOCK_M == WGMMA::M * (BLOCK_M <= 64 ? 1 : 2), "Invalid block sizes"); + const uint32_t& current_shape_k = (kGemmType == GemmType::KGroupedContiguous ? scheduler.current_shape_k : shape_k); + const uint32_t& current_group_idx = (kGemmType == GemmType::KGroupedContiguous ? scheduler.current_group_idx : 0); + const uint32_t& num_k_blocks = ceil_div(current_shape_k, BLOCK_K); + float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum] = {0}; + float2 scales_b[WGMMA::kNumAccum / 4]; + + // Empty barrier arrival + auto empty_barrier_arrive = [&](uint32_t s) { + if constexpr (kNumTMAMulticast == 1) { + lane_idx == 0 ? empty_barriers[s]->arrive() : void(); + } else { + auto target_cta = scheduler.is_peer_cta_alive ? lane_idx : cute::block_rank_in_cluster(); + lane_idx < kNumTMAMulticast ? 
empty_barriers[s]->arrive(target_cta) : void(); + } + }; + + #pragma unroll kNumPipelineUnrolls + for (uint32_t k_block_idx = 0; k_block_idx < num_k_blocks; ++ k_block_idx) { + // Wait TMA arrivals + CUTE_TIE_DECL(get_pipeline(iter_idx ++), stage_idx, phase); + full_barriers[stage_idx]->wait(phase); + + // Read A scales + // NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled block polluting the results + auto scale_a_0 = ld_shared(smem_sfa[stage_idx] + r_0); + auto scale_a_1 = ld_shared(smem_sfa[stage_idx] + r_1); + + // Read B scales + #pragma unroll + for (int i = 0; i < WGMMA::kNumAccum / 4; ++i) + scales_b[i] = ld_shared(reinterpret_cast(smem_sfb[stage_idx] + i * 8 + col_idx * 2)); + + // Commit WGMMA instructions + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + warpgroup_arrive(); + #pragma unroll + for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) { + auto desc_a = make_smem_desc(smem_a[stage_idx] + math_wg_idx * WGMMA::M * BLOCK_K + k * WGMMA::K, 1); + auto desc_b = make_smem_desc(smem_b[stage_idx] + k * WGMMA::K, 1); + WGMMA::wgmma(desc_a, desc_b, accum, k); + } + warpgroup_commit_batch(); + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + warpgroup_wait<0>(); + + // Notify barrier arrival + empty_barrier_arrive(stage_idx); + + // Promote with scales + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum / 4; ++ i) { + const float &scale_b_0 = scales_b[i].x; + const float &scale_b_1 = scales_b[i].y; + final_accum[i * 4 + 0] += scale_a_0 * scale_b_0 * accum[i * 4 + 0]; + final_accum[i * 4 + 1] += scale_a_0 * scale_b_1 * accum[i * 4 + 1]; + final_accum[i * 4 + 2] += scale_a_1 * scale_b_0 * accum[i * 4 + 2]; + final_accum[i * 4 + 3] += scale_a_1 * scale_b_1 * accum[i * 4 + 3]; + } + } + + // Flush previous stores + if (warp_idx % 4 == 0 and cute::elect_one_sync()) + cute::tma_store_wait<0>(); + cutlass::arch::NamedBarrier::sync(128, math_wg_idx); + + // Store to D shared memory + const auto& smem_d_0 = reinterpret_cast(smem_d + r_0 * BLOCK_N + col_idx * 2); + const auto& smem_d_1 = reinterpret_cast(smem_d + r_1 * BLOCK_N + col_idx * 2); + #pragma unroll + for (auto i = 0; i < WGMMA::kNumAccum / 4; ++ i) { + st_shared(smem_d_0 + i * 4, {final_accum[i * 4 + 0], final_accum[i * 4 + 1]}); + st_shared(smem_d_1 + i * 4, {final_accum[i * 4 + 2], final_accum[i * 4 + 3]}); + } + cute::tma_store_fence(); + cutlass::arch::NamedBarrier::sync(128, math_wg_idx); + + // Use TMA store to write back to global memory + if (warp_idx % 4 == 0 and cute::elect_one_sync()) { + cute::SM90_TMA_REDUCE_ADD_2D::copy( + &tensor_map_cd, smem_d_0, n_block_idx * BLOCK_N, + current_group_idx * shape_m + m_block_idx * BLOCK_M + r_0); + cute::tma_store_arrive(); + } + __syncwarp(); + } + } +#else + if (blockIdx.x == 0 and threadIdx.x == 0) + DG_DEVICE_ASSERT(false and "This kernel only support sm_90a"); +#endif +} + +}; // namespace deep_gemm + +#pragma clang diagnostic pop diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh b/build/torch29-cxx11-cu126-aarch64-linux/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh new file mode 100644 index 0000000000000000000000000000000000000000..9247304cdd17d8e2c3a5cdb31c78c191ae6b76ec --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/deep_gemm/impls/sm90_fp8_gemm_1d2d.cuh @@ -0,0 +1,440 @@ +#pragma once + +#pragma clang diagnostic push +#pragma clang diagnostic 
ignored "-Wunknown-attributes" + +#include +#include + +#include +#include +#include + +#include +#include +#include +#include + +namespace deep_gemm { + +using namespace deep_gemm::sm90; + +template +__device__ void dispatch_num_former_iters(uint32_t num_former_iters, const func_t& func) { + if (num_former_iters == kNumFormerIters) { + func(cute::Int{}); + return; + } + + if constexpr (kNumFormerIters + kGap <= kEnd) + dispatch_num_former_iters(num_former_iters, func); +} + +template +__global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) void +sm90_fp8_gemm_1d2d_impl(float* sfb, int* grouped_layout, + uint32_t shape_m, uint32_t shape_n, uint32_t shape_k, + const __grid_constant__ cute::TmaDescriptor tensor_map_a, + const __grid_constant__ cute::TmaDescriptor tensor_map_b, + const __grid_constant__ cute::TmaDescriptor tensor_map_d, + const __grid_constant__ cute::TmaDescriptor tensor_map_sfa) { +#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__) + // Scaling checks + DG_STATIC_ASSERT(BLOCK_K == 128, "Only support per-128-channel FP8 scaling"); + DG_STATIC_ASSERT(constexpr_ceil_div(BLOCK_N, BLOCK_K) == 1 or (constexpr_gcd(BLOCK_N, BLOCK_K) == BLOCK_N - BLOCK_K), "Too much B scales in a single block"); + + // Types + using WGMMA = typename FP8MMASelector::type; + using Barrier = cutlass::arch::ClusterTransactionBarrier; + DG_STATIC_ASSERT(BLOCK_M % WGMMA::M == 0 or BLOCK_M < WGMMA::M, "Invalid block size"); + + // Overwrite shape constants if the compiler gives + shape_m = SHAPE_M != 0 ? SHAPE_M : shape_m; + shape_n = SHAPE_N != 0 ? SHAPE_N : shape_n; + shape_k = SHAPE_K != 0 ? SHAPE_K : shape_k; + + // Shared memory + static constexpr bool kMustUseUniformedScaleB = (BLOCK_K % BLOCK_N == 0); + static constexpr uint32_t SMEM_D_SIZE = constexpr_align(BLOCK_M * BLOCK_N * static_cast(sizeof(__nv_bfloat16)), 1024u); + static constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(__nv_fp8_e4m3); + static constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(__nv_fp8_e4m3); + static constexpr uint32_t SMEM_SFA_SIZE_PER_STAGE = BLOCK_M * sizeof(float); + static constexpr uint32_t ALIGNED_SMEM_SFA_SIZE_PER_STAGE = constexpr_align(SMEM_SFA_SIZE_PER_STAGE, 128u); + const uint32_t& shape_k_scales = ceil_div(shape_k, BLOCK_K); + const uint32_t& shape_n_sfb = ceil_div(shape_n, BLOCK_K); + const uint32_t& smem_sfb_size = align(shape_k_scales * (kMustUseUniformedScaleB ? 
1 : 2) * sizeof(float), sizeof(Barrier)); + + // NOTES: Make sure we have enough shared memory for WGMMA padding + static constexpr uint32_t WGMMA_A_SIZE_PER_STAGE = WGMMA::M * BLOCK_K * sizeof(__nv_fp8_e4m3); + DG_STATIC_ASSERT(WGMMA_A_SIZE_PER_STAGE <= SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE * kNumStages, "Memory Out of bound for WGMMA"); + + // Configs + const uint32_t num_total_k_blocks = ceil_div(shape_k, BLOCK_K); + const uint32_t warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + const uint32_t lane_idx = get_lane_idx(); + + // Prefetch TMA descriptors at the very beginning + if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) { + cute::prefetch_tma_descriptor(&tensor_map_a); + cute::prefetch_tma_descriptor(&tensor_map_b); + cute::prefetch_tma_descriptor(&tensor_map_sfa); + cute::prefetch_tma_descriptor(&tensor_map_d); + } + __syncwarp(); + + // Align to 1024 bytes for swizzle-128B + extern __shared__ __align__(1024) uint8_t smem_buffer[]; + DG_STATIC_ASSERT(SMEM_D_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes"); + + // Data on shared memory + auto smem_d = reinterpret_cast<__nv_bfloat16*>(smem_buffer); + auto smem_a = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + i * SMEM_A_SIZE_PER_STAGE); + }); + auto smem_b = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_D_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE); + }); + constexpr uint32_t SMEM_SF_OFFSET = SMEM_D_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE); + auto smem_sfa = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + SMEM_SF_OFFSET + i * ALIGNED_SMEM_SFA_SIZE_PER_STAGE); + }); + auto smem_sfb = reinterpret_cast(smem_buffer + SMEM_SF_OFFSET + kNumStages * ALIGNED_SMEM_SFA_SIZE_PER_STAGE); + + // Fill barriers + auto barrier_start_ptr = reinterpret_cast(reinterpret_cast(smem_sfb) + smem_sfb_size); + auto full_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_start_ptr + i; }); + auto empty_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_start_ptr + kNumStages + i; }); + + // Initialize barriers + DG_STATIC_ASSERT(kNumTMAMulticast <= 32, "Too many TMA multicast"); + if (warp_idx == kNumMathThreads / 32 + 1 and cute::elect_one_sync()) { + // NOTES: we always use `lane_idx` to arrive for the `lane_idx`-th CTA in the cluster, + // even with TMA multicast disabled, we want to make the behavior aligned + #pragma unroll + for (uint32_t i = 0; i < kNumStages; ++ i) { + full_barriers[i]->init(1); + empty_barriers[i]->init(kNumTMAMulticast * kNumMathThreads / 32); + } + + // Make initialized barrier visible in async proxy + cutlass::arch::fence_barrier_init(); + } + + // Synchronize all threads to make barrier visible in normal memory model + (kNumTMAMulticast > 1) ? cute::cluster_sync() : __syncthreads(); + + // Register reconfigurations + constexpr uint32_t kNumTMARegisters = 40; + constexpr uint32_t kNumMathRegisters = kNumMathThreads == 128 ? 248 : 232; + + // Block scheduler + uint32_t m_block_idx, n_block_idx; + auto scheduler = Scheduler(shape_m, shape_n, shape_k, grouped_layout); + + // Pipeline and TMA phases + uint32_t stage_idx = 0, phase = 0; + auto advance_pipeline = [&](uint32_t& k_block_idx) { + ++ k_block_idx; + + // Flip phases only if reach the next first stage + stage_idx = stage_idx == kNumStages - 1 ? 
0 : stage_idx + 1; + phase ^= stage_idx == 0; + }; + + if (warp_idx >= kNumMathThreads / 32) { + // TMA warp-group for loading data + cutlass::arch::warpgroup_reg_dealloc(); + + // NOTES: only one thread (or warp) will be used + // We use the third warp, as warp 0/1 may be doing WGMMA with `BLOCK_M == 32` + if (warp_idx == kNumMathThreads / 32 + 2 and cute::elect_one_sync()) { + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + // Assign TMA multicast number into A and B + // NOTES: there may be additional odd rows/columns or cases where multicast is not possible. + const bool is_tma_multicast_valid = scheduler.is_tma_multicast_valid(m_block_idx); + const uint32_t num_tma_multicast_a = (kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1; + const uint32_t num_tma_multicast_b = (not kIsTMAMulticastOnA and is_tma_multicast_valid) ? kNumTMAMulticast : 1; + DG_STATIC_ASSERT(kNumTMAMulticast <= 2, "Scheduler does not support > 2 TMA multicast"); + + for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) { + // Wait consumer release + empty_barriers[stage_idx]->wait(phase ^ 1); + + // Issue TMA A + constexpr bool kIsBatchedMM = (kGemmType == GemmType::Batched); + const uint32_t batch_idx = (kIsBatchedMM ? scheduler.current_group_idx : 0); + + constexpr bool kWithGroupOffsetA = kGemmType == GemmType::MGroupedMasked; + auto& full_barrier = *full_barriers[stage_idx]; + const uint32_t k_idx = k_block_idx * BLOCK_K; + tma_copy(&tensor_map_a, &full_barrier, + smem_a[stage_idx], k_idx, scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx), + num_tma_multicast_a, batch_idx); + tma_copy(&tensor_map_sfa, &full_barrier, + smem_sfa[stage_idx], m_block_idx * BLOCK_M, scheduler.template get_global_idx(shape_k_scales, 1, k_block_idx), + num_tma_multicast_a); + + // Issue TMA B + tma_copy(&tensor_map_b, &full_barrier, + smem_b[stage_idx], k_idx, scheduler.get_global_idx(shape_n, BLOCK_N, n_block_idx, m_block_idx), + num_tma_multicast_b, batch_idx); + full_barrier.arrive_and_expect_tx(SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE + SMEM_SFA_SIZE_PER_STAGE); + } + } + + // To safely deconstruct distributed shared barriers, we need another round of empty waits + if constexpr (kNumTMAMulticast > 1) { + for (uint32_t i = 0; i < kNumStages; advance_pipeline(i)) + empty_barriers[stage_idx]->wait(phase ^ 1); + } + } + } else { + // Math warp-groups for WGMMA + cutlass::arch::warpgroup_reg_alloc(); + + // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers + const auto math_wg_idx = __shfl_sync(0xffffffff, threadIdx.x / 128, 0); + const auto r_0 = warp_idx * 16 + lane_idx / 4, r_1 = r_0 + 8; + + auto a_desc = make_smem_desc(smem_a[0] + math_wg_idx * WGMMA::M * BLOCK_K, 1); + auto b_desc = make_smem_desc(smem_b[0], 1); + const uint32_t a_desc_lo = __shfl_sync(0xffffffff, a_desc.reg32_[0], 0); + const uint32_t b_desc_lo = __shfl_sync(0xffffffff, b_desc.reg32_[0], 0); + + // Persistently schedule over blocks + while (scheduler.get_next_block(m_block_idx, n_block_idx)) { + // Decide the number of scales B to load + DG_TRAP_ONLY_DEVICE_ASSERT(shape_n % 8 == 0); + uint32_t num_former_iters = BLOCK_N / 8, num_full_iters = num_former_iters; + if constexpr (not kMustUseUniformedScaleB) { + num_former_iters = min(BLOCK_N, BLOCK_K - n_block_idx * BLOCK_N % BLOCK_K) / 8; + num_full_iters = min(shape_n - n_block_idx * BLOCK_N, BLOCK_N) / 8; + } + uint32_t num_sfb = shape_k_scales * (num_former_iters >= 
num_full_iters ? 1 : 2); + + // Load B scales with math warp-groups + // NOTES: except the first warp, we want to overlap loading B scales with TMA stores between tasks + if (threadIdx.x >= 32) { + auto previous_group_offset = scheduler.template get_global_idx(shape_n_sfb * shape_k_scales, 0, 0, m_block_idx); + const uint32_t stride_n_sfb = kMajorSFB == cute::UMMA::Major::MN ? 1 : shape_k_scales; + const uint32_t stride_k_sfb = kMajorSFB == cute::UMMA::Major::MN ? shape_n_sfb : 1; + auto local_sfb = sfb + previous_group_offset + ((n_block_idx * BLOCK_N) / BLOCK_K) * stride_n_sfb; + + #pragma unroll + for (uint32_t i = threadIdx.x - 32; i < num_sfb; i += kNumMathThreads - 32) + st_shared(smem_sfb + i, __ldg(i < shape_k_scales ? local_sfb + i * stride_k_sfb : local_sfb + (i - shape_k_scales) * stride_k_sfb + stride_n_sfb)); + } + cutlass::arch::NamedBarrier::sync(kNumMathThreads, 0); + + // Accumulation for WGMMA or CUDA promotion + constexpr uint32_t WAVE_BLOCK_M = BLOCK_M <= WGMMA::M ? BLOCK_M : WGMMA::M * 2; + DG_STATIC_ASSERT(BLOCK_M % WAVE_BLOCK_M == 0, "Invalid block sizes"); + float accum[WGMMA::kNumAccum], final_accum[WGMMA::kNumAccum * (BLOCK_M / WAVE_BLOCK_M)] = {0}; + + // Pick threads whose WGMMA results are to be stored in shared memory + DG_STATIC_ASSERT(BLOCK_M >= 64 or kNumMathThreads == 128, "Only one math warp group for `BLOCK_M < 64`"); + constexpr uint32_t kNumWGMMAStoreThreads = WAVE_BLOCK_M * (128 / WGMMA::M); + const bool do_wgmma_store = BLOCK_M >= WGMMA::M or warp_idx < kNumWGMMAStoreThreads / 32; + + // Empty barrier arrival + auto empty_barrier_arrive = [&]() { + if constexpr (kNumTMAMulticast == 1) { + lane_idx == 0 ? empty_barriers[stage_idx]->arrive() : void(); + } else { + auto target_cta = scheduler.is_peer_cta_alive ? lane_idx : cute::block_rank_in_cluster(); + lane_idx < kNumTMAMulticast ? empty_barriers[stage_idx]->arrive(target_cta) : void(); + } + }; + + // Skip useless computations + if (scheduler.is_computation_valid(m_block_idx, math_wg_idx * WGMMA::M)) { + // The compiler must know the dynamic variable `num_former_iters`'s real value + constexpr bool kShouldOptimize = BLOCK_K / constexpr_gcd(BLOCK_K, BLOCK_N) <= 4 and not kMustUseUniformedScaleB; + constexpr uint32_t kGap = constexpr_gcd(BLOCK_K, BLOCK_N) / 8; + constexpr uint32_t kEnd = kShouldOptimize ? BLOCK_K / 8 : 0; + + // Dispatch `num_former_iters` and launch MMAs + dispatch_num_former_iters<0, kGap, kEnd>(kShouldOptimize ? num_former_iters : 0, [&](auto _) { + #pragma unroll 8 + for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) { + const auto& a_desc_base_lo = a_desc_lo + stage_idx * (SMEM_A_SIZE_PER_STAGE / 16); + const auto& b_desc_base_lo = b_desc_lo + stage_idx * (SMEM_B_SIZE_PER_STAGE / 16); + + // Read B scales + float scale_b_0 = ld_shared(smem_sfb + k_block_idx), scale_b_1; + // NOTES: even some blocks do not need to read the second row, but we still load one to align with other blocks + if constexpr (not kMustUseUniformedScaleB) + scale_b_1 = ld_shared(smem_sfb + k_block_idx + shape_k_scales); + + // Wait TMA arrivals + full_barriers[stage_idx]->wait(phase); + + // TODO: remove some useless computation for unaligned Ms + #pragma unroll + for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) { + auto m_offset = local_idx * WAVE_BLOCK_M; + + // Read A scales + // NOTES: all shared memory read must be prior to `warpgroup_arrive` to avoid next scheduled block polluting the results + auto scale_a_0 = do_wgmma_store ? 
ld_shared(smem_sfa[stage_idx] + r_0 + m_offset) : 0; + auto scale_a_1 = do_wgmma_store ? ld_shared(smem_sfa[stage_idx] + r_1 + m_offset) : 0; + + // Commit WGMMA instructions + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + warpgroup_arrive(); + #pragma unroll + for (uint32_t k = 0; k < BLOCK_K / WGMMA::K; ++ k) { + a_desc.reg32_[0] = a_desc_base_lo + (m_offset * BLOCK_K + k * WGMMA::K) / 16; + b_desc.reg32_[0] = b_desc_base_lo + k * WGMMA::K / 16; + WGMMA::wgmma(a_desc, b_desc, accum, k); + } + warpgroup_commit_batch(); + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + warpgroup_wait<0>(); + + // Notify barrier arrival at the last warpgroup wave + if (local_idx == BLOCK_M / WAVE_BLOCK_M - 1) + empty_barrier_arrive(); + + // Skip promotion for the unfilled parts + if (not do_wgmma_store) + continue; + + // Promote with scales + // NOTES: making it as predicates is very important for performance, comparing to two loops + float scale_0_0 = scale_a_0 * scale_b_0, scale_1_0 = scale_a_1 * scale_b_0; + float scale_0_1, scale_1_1; + if constexpr (not kMustUseUniformedScaleB) + scale_0_1 = scale_a_0 * scale_b_1, scale_1_1 = scale_a_1 * scale_b_1; + + auto shifted_accum = final_accum + WGMMA::kNumAccum * local_idx; + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum / 4; ++ i) { + // NOTES: for unrolled `num_former_iters` cases, we expect the compiler to automatically make it a constant + const bool& predicate = kMustUseUniformedScaleB or i < num_former_iters; + shifted_accum[i * 4 + 0] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 0]; + shifted_accum[i * 4 + 1] += (predicate ? scale_0_0 : scale_0_1) * accum[i * 4 + 1]; + shifted_accum[i * 4 + 2] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 2]; + shifted_accum[i * 4 + 3] += (predicate ? scale_1_0 : scale_1_1) * accum[i * 4 + 3]; + } + } + } + }); + } else { + #pragma unroll + for (uint32_t k_block_idx = 0; k_block_idx < num_total_k_blocks; advance_pipeline(k_block_idx)) { + full_barriers[stage_idx]->wait(phase); + empty_barrier_arrive(); + } + } + + // TMA checks + constexpr uint32_t kNumElemBytes = sizeof(nv_bfloat16); + constexpr uint32_t TMA_D_BLOCK_N = kSwizzleDMode == 0 ? 
BLOCK_N : (kSwizzleDMode / kNumElemBytes); + constexpr uint32_t WGMMA_M_PER_WARP = WGMMA::M / 4; + DG_STATIC_ASSERT(BLOCK_M % 8 == 0, "Invalid swizzling atom"); + DG_STATIC_ASSERT(BLOCK_N % TMA_D_BLOCK_N == 0 and BLOCK_N / TMA_D_BLOCK_N <= 32, + "Unaligned TMA store or too many TMA store instructions"); + DG_STATIC_ASSERT(TMA_D_BLOCK_N % 8 == 0, "Invalid TMA block N"); + + // Skip WGMMA store for the unfilled parts + if (not do_wgmma_store) + continue; + + // Wait last TMA store to be finished + if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) + cute::tma_store_wait<0>(); + cutlass::arch::NamedBarrier::sync(kNumWGMMAStoreThreads, 1); + + // Write back to shared memory using STSM and issue TMA stores + DG_STATIC_ASSERT(WGMMA::kNumAccum % 4 == 0, "Invalid STSM x2 vectorization"); + #pragma unroll + for (uint32_t local_idx = 0; local_idx < BLOCK_M / WAVE_BLOCK_M; ++ local_idx) { + auto m_offset = local_idx * WAVE_BLOCK_M; + auto shifted_accum = final_accum + WGMMA::kNumAccum * local_idx; + #pragma unroll + for (auto i = 0; i < WGMMA::kNumAccum / 4; ++ i) { + // Swizzle or padding into the correct address + uint8_t* smem_ptr = nullptr; + if constexpr (kSwizzleDMode > 0) { + // Calculate the swizzling atom offset and in-atom offset + constexpr uint32_t kNumBankGroupBytes = 16; + auto atom_offset = i / (TMA_D_BLOCK_N / 8), in_atom_offset = i % (TMA_D_BLOCK_N / 8); + + // Calculate the index of the bank group to be written in the atom + auto bank_group_index = in_atom_offset + lane_idx * (kSwizzleDMode / kNumBankGroupBytes); + + // Reshape the atom in another view and swizzle + // - original: `(BLOCK_M, kSwizzleDMode / kNumBankGroupBytes)` + // - new: `(BLOCK_M * kSwizzleDMode / kNumBankGroupBytes / 8, 8)` + constexpr bool kHasShortcut = (kSwizzleDMode / kNumBankGroupBytes) == 8; + auto row = kHasShortcut ? (in_atom_offset / 8 + lane_idx) : (bank_group_index / 8); + auto col = kHasShortcut ? 
(in_atom_offset) : (bank_group_index % 8); + col ^= row % (kSwizzleDMode / 16); + + // Add back into the base pointer + // NOTES: think twice before modifying this, as changes may affect the number of instructions + smem_ptr = reinterpret_cast(smem_d) + // Base pointer + warp_idx * (WGMMA_M_PER_WARP * kSwizzleDMode) + // Warp offset + m_offset * kSwizzleDMode + // Wave offset + atom_offset * BLOCK_M * kSwizzleDMode + // Swizzle atom offset (constants) + row * (kNumBankGroupBytes * 8) + col * kNumBankGroupBytes; // In-atom offset + } else { + // No swizzling, just padding + smem_ptr = reinterpret_cast(smem_d + (m_offset + warp_idx * WGMMA_M_PER_WARP + lane_idx) * BLOCK_N + i * 8); + } + + // NOTES: only 16 lanes' addresses are used + SM90_U32x2_STSM_N::copy( + __float22bfloat162_rn({shifted_accum[i * 4 + 0], shifted_accum[i * 4 + 1]}), + __float22bfloat162_rn({shifted_accum[i * 4 + 2], shifted_accum[i * 4 + 3]}), + smem_ptr + ); + } + } + cute::tma_store_fence(); + cutlass::arch::NamedBarrier::sync(kNumWGMMAStoreThreads, 1); + + // Use TMA store to write back to global memory + // TODO: compatible with FP32 output + constexpr bool kWithGroupOffsetD = kGemmType == GemmType::MGroupedMasked; + DG_STATIC_ASSERT(kNumWGMMAStoreThreads >= BLOCK_N / TMA_D_BLOCK_N, "Too many TMA blocks"); + if (threadIdx.x < BLOCK_N / TMA_D_BLOCK_N) { + auto in_block_n_offset = threadIdx.x * TMA_D_BLOCK_N; + auto smem_ptr = smem_d + in_block_n_offset * BLOCK_M; + auto n_idx = epilogue_type_t::apply_index_n(n_block_idx * BLOCK_N + in_block_n_offset); + auto m_idx = scheduler.get_global_idx(shape_m, BLOCK_M, m_block_idx); + if constexpr (kGemmType == GemmType::Batched) { + cute::SM90_TMA_STORE_3D::copy(&tensor_map_d, smem_ptr, + n_idx, m_idx, scheduler.current_group_idx); + } else { + cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_ptr, n_idx, m_idx); + } + cute::tma_store_arrive(); + } + __syncwarp(); + } + } +#else + if (blockIdx.x == 0 and threadIdx.x == 0) + DG_DEVICE_ASSERT(false and "This kernel only support sm_90a"); +#endif +} + +}; // namespace deep_gemm + +#pragma clang diagnostic pop diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/deep_gemm/impls/sm90_fp8_mqa_logits.cuh b/build/torch29-cxx11-cu126-aarch64-linux/include/deep_gemm/impls/sm90_fp8_mqa_logits.cuh new file mode 100644 index 0000000000000000000000000000000000000000..d58c716242a09922157aa13e16cb8afac477904c --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/deep_gemm/impls/sm90_fp8_mqa_logits.cuh @@ -0,0 +1,329 @@ +#pragma once + +#include +#include + +#include +#include +#include + +#include +#include + +namespace deep_gemm { + +using namespace deep_gemm::sm90; + +// ReSharper disable once CppNotAllPathsReturnValue +template +static constexpr int to_swizzle_cute_type() { + DG_STATIC_ASSERT(kHeadDim == 32 or kHeadDim == 64 or kHeadDim == 128, "Invalid swizzling"); + if constexpr (kHeadDim == 32) + return static_cast(cute::SM90::GMMA::LayoutType::B32); + if constexpr (kHeadDim == 64) + return static_cast(cute::SM90::GMMA::LayoutType::B64); + if constexpr (kHeadDim == 128) + return static_cast(cute::SM90::GMMA::LayoutType::B128); +} + +template +__global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) +void sm90_fp8_mqa_logits(const uint32_t seq_len, const uint32_t seq_len_kv, + const uint32_t max_seqlen_k, const uint64_t stride_logits, + uint32_t* cu_seq_len_k_start, + uint32_t* cu_seq_len_k_end, + float* logits, + const __grid_constant__ cute::TmaDescriptor tensor_map_q, + const __grid_constant__ 
cute::TmaDescriptor tensor_map_kv, + const __grid_constant__ cute::TmaDescriptor tensor_map_kv_scales, + const __grid_constant__ cute::TmaDescriptor tensor_map_weights) { + // TODO: consider TMA multicast + // For one block, we process `[q_start:q_end, h, d] @ [kv_start:kv_end, d] -> [q_start:q_end, kv_start:kv_end]` + // Q should be load only at once for a block + const auto& num_q_blocks = ceil_div(seq_len, BLOCK_Q); + + // Types + using WGMMA = typename FP8MMASelector::type; + using Barrier = cutlass::arch::ClusterTransactionBarrier; + + // Prefetch TMA descriptors + DG_STATIC_ASSERT(kNumTMAThreads == 128 and kNumMathThreads % 128 == 0, "Invalid threads"); + if (threadIdx.x / 32 == kNumMathThreads / 32 and cute::elect_one_sync()) { + cute::prefetch_tma_descriptor(&tensor_map_q); + cute::prefetch_tma_descriptor(&tensor_map_kv); + cute::prefetch_tma_descriptor(&tensor_map_kv_scales); + cute::prefetch_tma_descriptor(&tensor_map_weights); + } + __syncwarp(); + + // Shared memory configs + // NOTES: weight may be unaligned + static constexpr uint32_t kSwizzleAlignment = kHeadDim * 8; + static constexpr uint32_t SMEM_Q_SIZE_PER_STAGE = BLOCK_Q * kNumHeads * kHeadDim * sizeof(__nv_fp8_e4m3); + static constexpr uint32_t SMEM_WEIGHT_SIZE_PER_STAGE = BLOCK_Q * kNumHeads * sizeof(float); + static constexpr uint32_t SMEM_KV_SIZE_PER_STAGE = BLOCK_KV * kHeadDim * sizeof(__nv_fp8_e4m3); + static constexpr uint32_t SMEM_KV_SCALE_SIZE_PER_STAGE = BLOCK_KV * sizeof(float); + + // Align to swizzling alignment bytes + extern __shared__ __align__(kSwizzleAlignment) uint8_t smem_buffer[]; + DG_STATIC_ASSERT(SMEM_Q_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling"); + DG_STATIC_ASSERT(SMEM_KV_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling"); + + // Data on shared memory + auto smem_q = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + + SMEM_Q_SIZE_PER_STAGE * i); + }); + auto smem_kv = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + ( + SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * i)); + }); + auto smem_weights = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + + SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * kNumKVStages + SMEM_WEIGHT_SIZE_PER_STAGE * i); + }); + auto smem_kv_scales = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + + SMEM_Q_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SIZE_PER_STAGE * kNumKVStages + + SMEM_WEIGHT_SIZE_PER_STAGE * kNumQStages + SMEM_KV_SCALE_SIZE_PER_STAGE * i); + }); + + // TMA barriers + auto barrier_ptr = reinterpret_cast(smem_kv_scales[kNumKVStages]); + auto full_q_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + i; }); + auto empty_q_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages + i); }); + auto full_kv_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + i); }); + auto empty_kv_barriers = PatternVisitor([&](const uint32_t& i) { return barrier_ptr + (kNumQStages * 2 + kNumKVStages + i); }); + + // Initialize barriers + const bool& is_tma_load_warp = kNumMathThreads <= threadIdx.x and threadIdx.x < kNumMathThreads + 32; + if (is_tma_load_warp and cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < kNumQStages; ++ i) { + full_q_barriers[i]->init(1); + empty_q_barriers[i]->init(kNumMathThreads); + } + #pragma unroll + for (uint32_t i = 0; 
i < kNumKVStages; ++ i) { + full_kv_barriers[i]->init(1); + empty_kv_barriers[i]->init(kNumMathThreads); + } + + // Make initialized barrier visible in async proxy + cutlass::arch::fence_barrier_init(); + } + __syncthreads(); + + // Register reconfigurations + constexpr uint32_t kNumTMARegisters = 32; + constexpr uint32_t kNumMathRegisters = 112; + + // Block scheduler + uint32_t block_q_idx = blockIdx.x, q_iter_idx = 0; + const auto& get_next_block_q_idx = [&]() -> cute::tuple { + return {block_q_idx + gridDim.x, q_iter_idx + 1}; + }; + uint32_t seq_k_start[BLOCK_Q], seq_k_end[BLOCK_Q]; + const auto& load_schedule = [&](const uint32_t& q_iter_offset = 0) -> cute::tuple { + uint32_t start = cute::numeric_limits::max(); + uint32_t end = cute::numeric_limits::min(); + + #pragma unroll + for (uint32_t i = 0; i < BLOCK_Q; ++ i) { + const auto& q_idx = min(block_q_idx * BLOCK_Q + i, seq_len - 1); + seq_k_start[i] = __ldg(cu_seq_len_k_start + q_idx); + seq_k_end[i] = __ldg(cu_seq_len_k_end + q_idx); + start = min(start, min(seq_k_start[i], seq_len_kv)); + end = max(end, min(seq_k_end[i], seq_len_kv)); + } + start = start / 4 * 4; + return {(q_iter_idx + q_iter_offset) % kNumQStages, // Q pipeline stage + ((q_iter_idx + q_iter_offset) / kNumQStages) & 1, // Q pipeline phase + start, ceil_div(end - start, BLOCK_KV)}; // Task info + }; + + // KV pipeline + uint32_t num_total_kv_blocks = 0; + const auto& get_kv_pipeline = [&](const uint32_t& kv_block_idx) -> cute::tuple { + return { + (num_total_kv_blocks + kv_block_idx) % kNumKVStages, // KV pipeline stage + ((num_total_kv_blocks + kv_block_idx) / kNumKVStages) & 1 // KV pipeline phase + }; + }; + + if (threadIdx.x >= kNumMathThreads) { + // TMA warp-group for loading data + cutlass::arch::warpgroup_reg_dealloc(); + + // Only the first warp remains + if (not is_tma_load_warp) + return; + + // Prefetch + const auto& issue_tma_q = [&](const uint32_t& stage_idx, const auto& block_idx) { + tma_copy(&tensor_map_q, full_q_barriers[stage_idx], smem_q[stage_idx], 0, block_idx * BLOCK_Q * kNumHeads); + tma_copy(&tensor_map_weights, full_q_barriers[stage_idx], smem_weights[stage_idx], 0, block_idx * BLOCK_Q); + full_q_barriers[stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + SMEM_WEIGHT_SIZE_PER_STAGE); + }; + if (cute::elect_one_sync() and block_q_idx < num_q_blocks) + issue_tma_q(0, block_q_idx); + + // Only the first lane persistently schedules over blocks + if (cute::elect_one_sync()) { + while (block_q_idx < num_q_blocks) { + CUTE_TIE_DECL(load_schedule(1), q_stage_idx, q_phase, kv_start, num_kv_blocks); + + // Wait Q consumer release + empty_q_barriers[q_stage_idx]->wait(q_phase ^ 1); + + // Issue TMA Q + if (const auto& next_block_q_idx = cute::get<0>(get_next_block_q_idx()); next_block_q_idx < num_q_blocks) + issue_tma_q(q_stage_idx, next_block_q_idx); + + // Issue TMA KV + #pragma unroll + for (uint32_t kv_block_idx = 0; kv_block_idx < num_kv_blocks; ++ kv_block_idx) { + // Wait consumer release + CUTE_TIE_DECL(get_kv_pipeline(kv_block_idx), kv_stage_idx, kv_phase); + empty_kv_barriers[kv_stage_idx]->wait(kv_phase ^ 1); + + // Issue TMA KV + tma_copy(&tensor_map_kv, full_kv_barriers[kv_stage_idx], + smem_kv[kv_stage_idx], 0, kv_start + kv_block_idx * BLOCK_KV); + tma_copy(&tensor_map_kv_scales, full_kv_barriers[kv_stage_idx], + smem_kv_scales[kv_stage_idx], kv_start + kv_block_idx * BLOCK_KV, 0); + full_kv_barriers[kv_stage_idx]->arrive_and_expect_tx(SMEM_KV_SIZE_PER_STAGE + SMEM_KV_SCALE_SIZE_PER_STAGE); + } + num_total_kv_blocks += 
num_kv_blocks; + + // Jump to the next block + CUTE_TIE(get_next_block_q_idx(), block_q_idx, q_iter_idx); + } + } + } else { + // Math warp-groups for WGMMA + cutlass::arch::warpgroup_reg_alloc(); + + // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers + const auto& thread_idx = threadIdx.x % kNumMathThreads; + const auto& warp_idx = __shfl_sync(0xffffffff, thread_idx / 32, 0); + const auto& warpgroup_idx = warp_idx / 4; + const auto& lane_idx = get_lane_idx(); + float accum[WGMMA::kNumAccum], weights[BLOCK_Q][kNumHeads / 4]; + + const auto& warp_offset = warp_idx * 16; + const auto& v_0_offset = lane_idx / 4 + 0; + const auto& v_1_offset = lane_idx / 4 + 8; + + while (block_q_idx < num_q_blocks) { + CUTE_TIE_DECL(load_schedule(), q_stage_idx, q_phase, kv_start, num_kv_blocks); + + // Wait TMA Q arrival + full_q_barriers[q_stage_idx]->wait(q_phase); + + // Read weights + #pragma unroll + for (uint32_t i = 0; i < BLOCK_Q; ++ i) { + #pragma unroll + for (uint32_t j = 0; j < kNumHeads / 4; ++ j) + weights[i][j] = ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + (j / 2) * 8 + (j & 1) + (lane_idx % 4) * 2); + } + + // Compute over KV blocks + #pragma unroll + for (uint32_t kv_block_idx = 0; kv_block_idx < num_kv_blocks; ++ kv_block_idx) { + // Compute `[BLOCK_Q * kNumHeads, kHeadDim] @ [BLOCK_KV, kHeadDim] -> [BLOCK_Q, BLOCK_KV]` + // Wait TMA KV arrival + CUTE_TIE_DECL(get_kv_pipeline(kv_block_idx), kv_stage_idx, kv_phase); + full_kv_barriers[kv_stage_idx]->wait(kv_phase); + + // Read per-KV scales + float scale_kv_0 = ld_shared(smem_kv_scales[kv_stage_idx] + warp_offset + v_0_offset); + float scale_kv_1 = ld_shared(smem_kv_scales[kv_stage_idx] + warp_offset + v_1_offset); + + // Issue WGMMA + DG_STATIC_ASSERT(BLOCK_KV == kNumMathThreads / 2, "Invalid block size"); + DG_STATIC_ASSERT(kHeadDim % WGMMA::K == 0, "Invalid head dim"); + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + warpgroup_arrive(); + #pragma unroll + for (uint32_t k = 0; k < kHeadDim / WGMMA::K; ++ k) { + auto desc_a = make_smem_desc(smem_kv[kv_stage_idx] + (warpgroup_idx * WGMMA::M) * kHeadDim + k * WGMMA::K, + to_swizzle_cute_type(), 0, kHeadDim * 8); + auto desc_b = make_smem_desc(smem_q[q_stage_idx] + k * WGMMA::K, + to_swizzle_cute_type(), 0, kHeadDim * 8); + WGMMA::wgmma(desc_a, desc_b, accum, k); + } + warpgroup_commit_batch(); + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + warpgroup_wait<0>(); + + // Release KV empty + empty_kv_barriers[kv_stage_idx]->arrive(); + + // Reduce over the head dim and store + const auto& kv_offset = kv_start + kv_block_idx * BLOCK_KV + warp_offset; + static constexpr uint32_t kNumAccumPerReduce = kNumHeads / 2; + DG_STATIC_ASSERT(WGMMA::kNumAccum % kNumAccumPerReduce == 0, "Invalid accumulation"); + DG_STATIC_ASSERT(WGMMA::kNumAccum / kNumAccumPerReduce == BLOCK_Q, "Invalid accumulation"); + DG_STATIC_ASSERT(kNumHeads % 8 == 0, "Invalid head"); + #pragma unroll + for (uint32_t i = 0; i < BLOCK_Q; ++ i) { + auto shifted_accum = accum + i * kNumAccumPerReduce; + const auto& transform = [&](const uint32_t& j) { + return fmaxf(shifted_accum[j], 0) * weights[i][(j / 4) * 2 + (j & 1)]; + }; + + // Intra-thread reduction + float sum[4] = {transform(0), transform(1), transform(2), transform(3)}; + #pragma unroll + for (uint32_t j = 1; j < kNumHeads / 8; ++ j) { + #pragma unroll + for (uint32_t k = 0; k < 4; k ++) + sum[k] += transform(j * 4 + k); + } + float 
v_0 = (sum[0] + sum[1]) * scale_kv_0; + float v_1 = (sum[2] + sum[3]) * scale_kv_1; + + // Inter-thread reduction + #pragma unroll + for (uint32_t j = 0; j < 2; ++ j) { + const auto& offset = static_cast(1u << j); + v_0 += __shfl_xor_sync(0xffffffffu, v_0, offset); + v_1 += __shfl_xor_sync(0xffffffffu, v_1, offset); + } + + // Store into the global memory + // NOTES: we have redundant writes here, consider more carefully + const uint32_t& q_idx = block_q_idx * BLOCK_Q + i; + if constexpr (kIsCompressedLogits) { + if (seq_k_start[i] <= kv_offset + v_0_offset and kv_offset + v_0_offset < seq_k_end[i]) + logits[q_idx * stride_logits + kv_offset + v_0_offset - seq_k_start[i]] = v_0; + if (seq_k_start[i] <= kv_offset + v_1_offset and kv_offset + v_1_offset < seq_k_end[i]) + logits[q_idx * stride_logits + kv_offset + v_1_offset - seq_k_start[i]] = v_1; + } else { + logits[q_idx * stride_logits + kv_offset + v_0_offset] = v_0; + logits[q_idx * stride_logits + kv_offset + v_1_offset] = v_1; + } + } + } + num_total_kv_blocks += num_kv_blocks; + + // Release Q empty + empty_q_barriers[q_stage_idx]->arrive(); + + // Jump to the next block + CUTE_TIE(get_next_block_q_idx(), block_q_idx, q_iter_idx); + } + } +} + +} // namespace deep_gemm diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/deep_gemm/impls/sm90_fp8_paged_mqa_logits.cuh b/build/torch29-cxx11-cu126-aarch64-linux/include/deep_gemm/impls/sm90_fp8_paged_mqa_logits.cuh new file mode 100644 index 0000000000000000000000000000000000000000..482a85a80fce29aa949b464070b0b20fb55ae030 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/deep_gemm/impls/sm90_fp8_paged_mqa_logits.cuh @@ -0,0 +1,413 @@ +#pragma once + +#include +#include + +#include +#include + +#include +#include +#include + +namespace deep_gemm { + +template +__global__ __launch_bounds__(32, 1) +void smxx_paged_mqa_logits_metadata(const uint32_t batch_size, const uint32_t next_n, const bool is_context_lens_2d, + const uint32_t* context_lens, uint32_t* schedule_metadata) { + DG_STATIC_ASSERT(kAlignedBatchSize % 32 == 0, "Invalid aligned batch size"); + const uint32_t lane_idx = get_lane_idx(); + + uint32_t num_segs[kAlignedBatchSize / 32]; + #pragma unroll + for (uint32_t k = 0; k < kAlignedBatchSize / 32; ++ k) { + const uint32_t q_idx = k * 32 + lane_idx; + const uint32_t lens_idx = (is_context_lens_2d ? q_idx * next_n + next_n - 1 : q_idx); + const uint32_t& context_len = (q_idx < batch_size ? __ldg(context_lens + lens_idx) : 0); + num_segs[k] = ceil_div(context_len, SPLIT_KV); + } + + __shared__ uint32_t prefix_sum[kAlignedBatchSize]; + uint32_t sum = 0; + #pragma unroll + for (uint32_t k = 0; k < kAlignedBatchSize / 32; ++ k) { + uint32_t x = num_segs[k]; + #pragma unroll + for (uint32_t offset = 1; offset < 32; offset <<= 1) { + const uint32_t& y = __shfl_up_sync(0xffffffff, x, offset); + x += (lane_idx >= offset ? y : 0); + } + x += sum; + prefix_sum[k * 32 + lane_idx] = x; + sum = __shfl_sync(0xffffffff, x, 31); + } + + const uint32_t& q = sum / kNumSMs, r = sum % kNumSMs; + for (uint32_t sm_idx = lane_idx; sm_idx <= kNumSMs; sm_idx += 32) { + uint32_t seg_starts = sm_idx * q + min(sm_idx, r); + uint32_t q_idx = 0; + while (q_idx < batch_size and prefix_sum[q_idx] <= seg_starts) + ++ q_idx; + const uint32_t& kv_split_idx = (q_idx == 0 ? 
seg_starts : seg_starts - prefix_sum[q_idx - 1]); + __syncwarp(); + + schedule_metadata[sm_idx * 2] = q_idx; + schedule_metadata[sm_idx * 2 + 1] = kv_split_idx; + } +} + +template +struct PagedMQALogitsScheduler { + uint32_t batch_size; + const uint32_t* context_lens; + + uint32_t current_q_idx, current_kv_idx; + uint32_t end_q_idx, end_kv_idx; + uint32_t current_num_kv; + + __device__ __forceinline__ uint32_t get_num_kv(const uint32_t& q_idx) { + const auto& lens_idx = (kIsContextLens2D ? q_idx * kNextN + kNextN - 1 : q_idx); + return q_idx < batch_size ? ceil_div(__ldg(context_lens + lens_idx), BLOCK_KV) : 0; + } + + __device__ __forceinline__ explicit PagedMQALogitsScheduler(const uint32_t& batch_size, const uint32_t& sm_idx, + const uint32_t* context_lens, const uint32_t* schedule_meta) { + this->batch_size = batch_size; + this->context_lens = context_lens; + + const auto& current_pack = __ldg(reinterpret_cast(schedule_meta) + sm_idx); + const auto& end_pack = __ldg(reinterpret_cast(schedule_meta) + sm_idx + 1); + current_q_idx = current_pack.x, current_kv_idx = current_pack.y * kNumBlocksPerSplit; + end_q_idx = end_pack.x, end_kv_idx = end_pack.y * kNumBlocksPerSplit; + + current_num_kv = get_num_kv(current_q_idx); + } + + __device__ __forceinline__ bool fetch_next_task(uint32_t &q_idx, uint32_t &kv_idx, uint32_t &num_kv) { + q_idx = current_q_idx; + kv_idx = current_kv_idx; + num_kv = current_num_kv; + + if (q_idx == end_q_idx and kv_idx == end_kv_idx) + return false; + + current_kv_idx += kNumBlocksPerSplit; + if (current_kv_idx >= current_num_kv) { + ++ current_q_idx; + current_kv_idx = 0; + current_num_kv = get_num_kv(current_q_idx); + } + + return true; + } + + __device__ __forceinline__ bool exist_q_idx(const uint32_t& q_idx) const { + return q_idx < end_q_idx or q_idx == end_q_idx and 0 < end_kv_idx; + } +}; + +using namespace deep_gemm::sm90; + +template +__global__ __launch_bounds__(kNumTMAThreads + kNumMathThreads, 1) +void sm90_fp8_paged_mqa_logits(const uint32_t batch_size, + const uint64_t logits_stride, const uint64_t block_table_stride, + const uint32_t* context_lens, float* logits, + const uint32_t* block_table, const uint32_t* schedule_meta, + const __grid_constant__ cute::TmaDescriptor tensor_map_q, + const __grid_constant__ cute::TmaDescriptor tensor_map_kv, + const __grid_constant__ cute::TmaDescriptor tensor_map_kv_scales, + const __grid_constant__ cute::TmaDescriptor tensor_map_weights) { + // Types + using WGMMA = typename FP8MMASelector::type; + using Barrier = cutlass::arch::ClusterTransactionBarrier; + + // NOTES: use `__shfl_sync` to encourage NVCC to use unified registers + const auto& warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + const auto& warpgroup_idx = warp_idx / 4; + const auto& lane_idx = get_lane_idx(); + + // Prefetch TMA descriptors + static constexpr uint32_t kNumMathWarpGroups = kNumMathThreads / 128; + DG_STATIC_ASSERT(kNumTMAThreads == 128 and kNumMathThreads % 128 == 0, "Invalid threads"); + DG_STATIC_ASSERT(SPLIT_KV == BLOCK_KV * kNumMathWarpGroups, "Invalid `SPLIT_KV`"); + if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) { + cute::prefetch_tma_descriptor(&tensor_map_q); + cute::prefetch_tma_descriptor(&tensor_map_kv); + cute::prefetch_tma_descriptor(&tensor_map_kv_scales); + cute::prefetch_tma_descriptor(&tensor_map_weights); + } + __syncwarp(); + + // Shared memory configs + static constexpr uint32_t kSwizzleAlignment = kHeadDim * 8; + static constexpr uint32_t SMEM_Q_SIZE_PER_STAGE = kNextN * kNumHeads * 
kHeadDim * sizeof(__nv_fp8_e4m3); + static constexpr uint32_t SMEM_WEIGHT_SIZE_PER_STAGE = kNextN * kNumHeads * sizeof(float); + static constexpr uint32_t ALIGNED_SMEM_WEIGHT_SIZE_PER_STAGE = constexpr_align(SMEM_WEIGHT_SIZE_PER_STAGE, kSwizzleAlignment); + static constexpr uint32_t SMEM_Q_PIPE_SIZE = kNumQStages * (SMEM_Q_SIZE_PER_STAGE + ALIGNED_SMEM_WEIGHT_SIZE_PER_STAGE) + + constexpr_align(kNumQStages * 8 * 2, kSwizzleAlignment); + + static constexpr uint32_t SMEM_KV_SIZE_PER_STAGE = BLOCK_KV * kHeadDim * sizeof(__nv_fp8_e4m3); + static constexpr uint32_t SMEM_KV_SCALE_SIZE_PER_STAGE = BLOCK_KV * sizeof(float); + static constexpr uint32_t ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE = constexpr_align(SMEM_KV_SCALE_SIZE_PER_STAGE, kSwizzleAlignment); + static constexpr uint32_t SMEM_KV_PIPE_SIZE = kNumKVStages * (SMEM_KV_SIZE_PER_STAGE + ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE) + + constexpr_align(kNumKVStages * 8 * 2, kSwizzleAlignment); + + // Align to swizzling alignment bytes + extern __shared__ __align__(kSwizzleAlignment) uint8_t smem_buffer[]; + DG_STATIC_ASSERT(SMEM_Q_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling"); + DG_STATIC_ASSERT(SMEM_KV_SIZE_PER_STAGE % kSwizzleAlignment == 0, "Unaligned TMA swizzling"); + + // Q data and barriers on shared memory + auto smem_q = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + SMEM_Q_SIZE_PER_STAGE * i); + }); + auto smem_weights = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + SMEM_Q_SIZE_PER_STAGE * kNumQStages + ALIGNED_SMEM_WEIGHT_SIZE_PER_STAGE * i); + }); + auto q_barrier_ptr = reinterpret_cast(smem_weights[kNumQStages]); + auto full_q_barriers = PatternVisitor([&](const uint32_t& i) { return q_barrier_ptr + i; }); + auto empty_q_barriers = PatternVisitor([&](const uint32_t& i) { return q_barrier_ptr + (kNumQStages + i); }); + + // Separate math warpgroups and tma load warps into KV groups + // Each math warpgroup corresponds to a tma load warp + const auto& kv_group_idx = __shfl_sync(0xffffffff, threadIdx.x >= kNumMathThreads ? 
(threadIdx.x - kNumMathThreads) / 32 : warpgroup_idx, 0); + + // Per group KV data and barriers on shared memory + const auto& smem_offset = SMEM_Q_PIPE_SIZE + SMEM_KV_PIPE_SIZE * kv_group_idx; + auto smem_kv = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast<__nv_fp8_e4m3*>(smem_buffer + smem_offset + SMEM_KV_SIZE_PER_STAGE * i); + }); + auto smem_kv_scales = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + smem_offset + SMEM_KV_SIZE_PER_STAGE * kNumKVStages + ALIGNED_SMEM_KV_SCALE_SIZE_PER_STAGE * i); + }); + auto kv_barrier_ptr = reinterpret_cast(smem_kv_scales[kNumKVStages]); + auto full_kv_barriers = PatternVisitor([&](const uint32_t& i) { return kv_barrier_ptr + i; }); + auto empty_kv_barriers = PatternVisitor([&](const uint32_t& i) { return kv_barrier_ptr + kNumKVStages + i; }); + + // Initialize barriers + if (warp_idx >= kNumMathThreads / 32 and cute::elect_one_sync()) { + if (kv_group_idx == 0) { + #pragma unroll + for (uint32_t i = 0; i < kNumQStages; ++ i) { + full_q_barriers[i]->init(1); + empty_q_barriers[i]->init(kNumMathThreads); + } + } + if (kv_group_idx < kNumMathWarpGroups) { + #pragma unroll + for (uint32_t i = 0; i < kNumKVStages; ++ i) { + full_kv_barriers[i]->init(1); + empty_kv_barriers[i]->init(128); + } + } + + // Make initialized barrier visible in async proxy + cutlass::arch::fence_barrier_init(); + } + __syncthreads(); + + // Register reconfigurations + constexpr uint32_t kNumTMARegisters = 64; + constexpr uint32_t kNumMathRegisters = 104; + + // Scheduler + auto scheduler = PagedMQALogitsScheduler(batch_size, blockIdx.x, context_lens, schedule_meta); + DG_STATIC_ASSERT(SPLIT_KV % BLOCK_KV == 0, "Unaligned SPLIT_KV"); + + // Q and KV pipeline + const auto& get_q_pipeline = [=](const uint32_t& q_iter_idx) -> cute::tuple { + return {q_iter_idx % kNumQStages, (q_iter_idx / kNumQStages) & 1}; // Q pipeline stage and phase + }; + const auto& get_kv_pipeline = [=](const uint32_t& kv_iter_idx) -> cute::tuple { + return {kv_iter_idx % kNumKVStages, (kv_iter_idx / kNumKVStages) & 1}; // KV pipeline stage and phase + }; + uint32_t q_iter_idx = 0, kv_iter_idx = 0; + + if (warp_idx >= kNumMathThreads / 32) { + // TMA warp-group for loading data + cutlass::arch::warpgroup_reg_dealloc(); + if (kv_group_idx >= kNumMathWarpGroups) + return; + + const auto& issue_tma_q = [&](const uint32_t& stage_idx, const uint32_t& q_idx) { + if (kv_group_idx == 0 and cute::elect_one_sync()) { + tma_copy(&tensor_map_q, full_q_barriers[stage_idx], smem_q[stage_idx], 0, q_idx * kNextN * kNumHeads); + tma_copy(&tensor_map_weights, full_q_barriers[stage_idx], smem_weights[stage_idx], 0, q_idx); + full_q_barriers[stage_idx]->arrive_and_expect_tx(SMEM_Q_SIZE_PER_STAGE + SMEM_WEIGHT_SIZE_PER_STAGE); + } + }; + + // Initialize `q_idx` outside `[0, batch_size)` to indicate it was none + uint32_t q_idx = batch_size, kv_idx, num_kv; + uint32_t next_q_idx, next_kv_idx, next_num_kv; + bool fetched_next_task; + + // Prefetch the first Q + if ((fetched_next_task = scheduler.fetch_next_task(next_q_idx, next_kv_idx, next_num_kv))) + issue_tma_q(0, next_q_idx), q_iter_idx = 1; + + int kv_block_idx_ptr = 32; + uint32_t kv_block_idx_storage; + + while (fetched_next_task) { + // Prefetch next Q when current Q changes + bool prefetch_q = (q_idx != next_q_idx and scheduler.exist_q_idx(next_q_idx + 1)); + q_idx = next_q_idx; + kv_idx = next_kv_idx; + num_kv = next_num_kv; + + // Wait Q consumer release and issue TMA Q + if (prefetch_q) { + 
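// The Q pipeline here is a small ring of kNumQStages buffers: `get_q_pipeline` maps the
// running iteration counter to a stage (`q_iter_idx % kNumQStages`) and a phase bit that
// flips each time the ring wraps (`(q_iter_idx / kNumQStages) & 1`). Waiting on
// `q_phase ^ 1` therefore blocks until the math warp-groups have released that slot from
// the previous wrap. For example, with kNumQStages = 2, iterations 0 and 1 fill stages 0
// and 1 under phase 0, and iterations 2 and 3 reuse the same stages under phase 1.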
CUTE_TIE_DECL(get_q_pipeline(q_iter_idx ++), q_stage_idx, q_phase); + empty_q_barriers[q_stage_idx]->wait(q_phase ^ 1); + issue_tma_q(q_stage_idx, q_idx + 1); + } + + // Read KV block index + // TODO: deal with `-1`? + if (kv_idx == 0 or kv_block_idx_ptr == 32) { + kv_block_idx_ptr = 0; + kv_block_idx_storage = (kv_idx + kv_group_idx + lane_idx * kNumMathWarpGroups < num_kv ? + __ldg(block_table + q_idx * block_table_stride + (kv_idx + kv_group_idx + lane_idx * kNumMathWarpGroups)) : 0); + } + const auto& kv_block_idx = __shfl_sync(0xffffffff, kv_block_idx_storage, kv_block_idx_ptr ++); + + // Wait KV consumer release + CUTE_TIE_DECL(get_kv_pipeline(kv_iter_idx ++), kv_stage_idx, kv_phase); + empty_kv_barriers[kv_stage_idx]->wait(kv_phase ^ 1); + + // Issue TMA KV + if (cute::elect_one_sync()) { + tma_copy(&tensor_map_kv, full_kv_barriers[kv_stage_idx], + smem_kv[kv_stage_idx], 0, 0, 1, kv_block_idx); + tma_copy(&tensor_map_kv_scales, full_kv_barriers[kv_stage_idx], + smem_kv_scales[kv_stage_idx], 0, kv_block_idx); + full_kv_barriers[kv_stage_idx]->arrive_and_expect_tx(SMEM_KV_SIZE_PER_STAGE + SMEM_KV_SCALE_SIZE_PER_STAGE); + } + + // Fetch next task + fetched_next_task = scheduler.fetch_next_task(next_q_idx, next_kv_idx, next_num_kv); + } + } else { + // Math warp-groups for WGMMA + cutlass::arch::warpgroup_reg_alloc(); + + float accum[WGMMA::kNumAccum], weights[kNextN][kNumHeads / 4]; + const auto& sub_warp_offset = (warp_idx % 4) * 16; + const auto& v_0_offset = lane_idx / 4 + 0; + const auto& v_1_offset = lane_idx / 4 + 8; + + // Initialize `q_idx` outside `[0, batch_size)` to indicate it was none + uint32_t q_idx = batch_size, kv_idx; + uint32_t next_q_idx, next_kv_idx, next_num_kv; + uint32_t q_stage_idx, q_phase; + + while (scheduler.fetch_next_task(next_q_idx, next_kv_idx, next_num_kv)) { + // Current Q changes + if (q_idx != next_q_idx) { + // Release Last Q empty + if (q_iter_idx > 0) + empty_q_barriers[(q_iter_idx - 1) % kNumQStages]->arrive(); + + // Wait TMA Q arrival + CUTE_TIE(get_q_pipeline(q_iter_idx ++), q_stage_idx, q_phase); + full_q_barriers[q_stage_idx]->wait(q_phase); + + // Read weights + #pragma unroll + for (uint32_t i = 0; i < kNextN; ++ i) { + #pragma unroll + for (uint32_t j = 0; j < kNumHeads / 4; ++ j) + weights[i][j] = ld_shared(smem_weights[q_stage_idx] + i * kNumHeads + (j / 2) * 8 + (j & 1) + (lane_idx % 4) * 2); + } + } + + // Get current Q and KV index + q_idx = next_q_idx; + kv_idx = next_kv_idx; + + // Calculate KV offset in advance + auto kv_offset = q_idx * kNextN * logits_stride + ((kv_idx + kv_group_idx) * BLOCK_KV + sub_warp_offset); + + // Compute `[kNextN * kNumHeads, kHeadDim] @ [BLOCK_KV, kHeadDim] -> [kNextN, BLOCK_KV]` + // Wait TMA KV arrival + CUTE_TIE_DECL(get_kv_pipeline(kv_iter_idx ++), kv_stage_idx, kv_phase); + full_kv_barriers[kv_stage_idx]->wait(kv_phase); + + // Issue WGMMA + DG_STATIC_ASSERT(BLOCK_KV == 64, "Invalid block size"); + DG_STATIC_ASSERT(kHeadDim % WGMMA::K == 0, "Invalid head dim"); + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + warpgroup_arrive(); + #pragma unroll + for (uint32_t k = 0; k < kHeadDim / WGMMA::K; ++ k) { + auto desc_a = make_smem_desc(smem_kv[kv_stage_idx] + k * WGMMA::K, to_swizzle_cute_type(), 0, kHeadDim * 8); + auto desc_b = make_smem_desc(smem_q[q_stage_idx] + k * WGMMA::K, to_swizzle_cute_type(), 0, kHeadDim * 8); + WGMMA::wgmma(desc_a, desc_b, accum, k); + } + warpgroup_commit_batch(); + #pragma unroll + for (uint32_t i = 0; i < 
WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + + // Read per-KV scales + float scale_kv_0 = ld_shared(smem_kv_scales[kv_stage_idx] + sub_warp_offset + v_0_offset); + float scale_kv_1 = ld_shared(smem_kv_scales[kv_stage_idx] + sub_warp_offset + v_1_offset); + + // Wait WGMMA + warpgroup_wait<0>(); + + // Release KV empty + empty_kv_barriers[kv_stage_idx]->arrive(); + + // Reduce over the head dim and store + static constexpr uint32_t kNumAccumPerReduce = kNumHeads / 2; + DG_STATIC_ASSERT(WGMMA::kNumAccum % kNumAccumPerReduce == 0, "Invalid accumulation"); + DG_STATIC_ASSERT(WGMMA::kNumAccum / kNumAccumPerReduce == kNextN, "Invalid accumulation"); + DG_STATIC_ASSERT(kNumHeads % 8 == 0, "Invalid head"); + #pragma unroll + for (uint32_t i = 0; i < kNextN; ++ i) { + auto shifted_accum = accum + i * kNumAccumPerReduce; + const auto& transform = [&](const uint32_t& j) { + return fmaxf(shifted_accum[j], 0) * weights[i][(j / 4) * 2 + (j & 1)]; + }; + + // Intra-thread reduction + float sum[4] = {transform(0), transform(1), transform(2), transform(3)}; + #pragma unroll + for (uint32_t j = 1; j < kNumHeads / 8; ++ j) { + #pragma unroll + for (uint32_t k = 0; k < 4; k ++) + sum[k] += transform(j * 4 + k); + } + float v_0 = (sum[0] + sum[1]) * scale_kv_0; + float v_1 = (sum[2] + sum[3]) * scale_kv_1; + + // Inter-thread reduction + #pragma unroll + for (uint32_t j = 0; j < 2; ++ j) { + const auto& offset = static_cast(1u << j); + v_0 += __shfl_xor_sync(0xffffffffu, v_0, offset); + v_1 += __shfl_xor_sync(0xffffffffu, v_1, offset); + } + + // Store into the global memory + // NOTES: we have redundant writes here, consider more carefully + logits[kv_offset + i * logits_stride + v_0_offset] = v_0; + logits[kv_offset + i * logits_stride + v_1_offset] = v_1; + } + } + } +} + +} // namespace deep_gemm diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/deep_gemm/impls/sm90_tf32_hc_prenorm_gemm.cuh b/build/torch29-cxx11-cu126-aarch64-linux/include/deep_gemm/impls/sm90_tf32_hc_prenorm_gemm.cuh new file mode 100644 index 0000000000000000000000000000000000000000..e3bf98478923a2bf560e69e6ecc802d218fb82c1 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/deep_gemm/impls/sm90_tf32_hc_prenorm_gemm.cuh @@ -0,0 +1,287 @@ +#pragma once +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wunknown-attributes" + +#include +#include + +#include +#include +#include + +namespace deep_gemm { + +using namespace deep_gemm::sm90; + +template +__device__ __forceinline__ +uint32_t get_swizzled_bank_group_idx(const uint32_t& offset, const uint32_t& lane_idx) { + constexpr uint32_t kGroupsInSwizzleRange = kSwizzleMode / kSwizzleBase; + + const auto& bank_group_idx = offset + lane_idx * kGroupsInSwizzleRange; + + constexpr uint32_t kNumBankGroups = 128 / kSwizzleBase; + constexpr bool kHasShortcut = kGroupsInSwizzleRange == kNumBankGroups; + auto row = kHasShortcut ? (offset / kNumBankGroups + lane_idx) : (bank_group_idx / kNumBankGroups); + auto col = kHasShortcut ? 
(offset) : (bank_group_idx % kNumBankGroups); + col ^= row % kGroupsInSwizzleRange; + + return (row * kNumBankGroups + col) % kGroupsInSwizzleRange; +} + +template +__global__ void __launch_bounds__(kNumMathThreads + kNumTMAThreads, 1) +sm90_tf32_hc_prenorm_gemm_impl(const uint32_t shape_m, + const __grid_constant__ cute::TmaDescriptor tensor_map_a, + const __grid_constant__ cute::TmaDescriptor tensor_map_b, + const __grid_constant__ cute::TmaDescriptor tensor_map_d, + float* sqr_sum) { +#if (defined(__CUDA_ARCH__) and (__CUDA_ARCH__ >= 900)) or defined(__CLION_IDE__) + using Barrier = cutlass::arch::ClusterTransactionBarrier; + + // kSwizzleAMode and kSwizzleBMode must be 128 for now + constexpr uint32_t kSwizzleAMode = cute::min(BLOCK_K * sizeof(nv_bfloat16), 128); + constexpr uint32_t kSwizzleBMode = cute::min(BLOCK_K * sizeof(float), 128); + DG_STATIC_ASSERT(BLOCK_K == 64, "Invalid block K"); + DG_STATIC_ASSERT(kSwizzleAMode == 128, "Invalid swizzle A mode"); + DG_STATIC_ASSERT(kSwizzleBMode == 128, "Invalid swizzle B mode"); + + DG_STATIC_ASSERT(kSwizzleCDMode / sizeof(float) == BLOCK_N, "Invalid block N"); + DG_STATIC_ASSERT(kNumMathThreads == 128, "Invalid MMA threads"); + + // Utils + const auto warp_idx = cutlass::canonical_warp_idx_sync(); + const auto lane_idx = get_lane_idx(); + + // Align to 1024 bytes for swizzle-128B + extern __shared__ __align__(1024) uint8_t smem_buffer[]; + + // Share memory sizes + constexpr uint32_t SMEM_CD_SIZE = BLOCK_M * kSwizzleCDMode; + constexpr uint32_t SMEM_A_SIZE_PER_STAGE = BLOCK_M * BLOCK_K * sizeof(nv_bfloat16); + constexpr uint32_t SMEM_B_SIZE_PER_STAGE = BLOCK_N * BLOCK_K * sizeof(float); + DG_STATIC_ASSERT(SMEM_CD_SIZE % 1024 == 0, "Shared memory of A/B must be aligned to 1024 bytes"); + + if (warp_idx == 0 and cute::elect_one_sync()) { + cute::prefetch_tma_descriptor(&tensor_map_a); + cute::prefetch_tma_descriptor(&tensor_map_b); + cute::prefetch_tma_descriptor(&tensor_map_d); + } + + // Data on shared memory (layout as ordered below) + // Fill D/A/B pointers + auto smem_cd = reinterpret_cast(smem_buffer); + auto smem_a = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + (SMEM_CD_SIZE + i * SMEM_A_SIZE_PER_STAGE)); + }); + auto smem_b = PatternVisitor([&](const uint32_t& i) { + return reinterpret_cast(smem_buffer + (SMEM_CD_SIZE + kNumStages * SMEM_A_SIZE_PER_STAGE + i * SMEM_B_SIZE_PER_STAGE)); + }); + + // Fill barriers + auto barrier_start_ptr = reinterpret_cast(smem_buffer + SMEM_CD_SIZE + kNumStages * (SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE)); + auto full_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (i); }); + auto empty_barriers = PatternVisitor([=](const uint32_t& i) { return barrier_start_ptr + (kNumStages + i); }); + + // Initialize barriers + if (warp_idx == 1 and cute::elect_one_sync()) { + #pragma unroll + for (uint32_t i = 0; i < kNumStages; ++ i) { + full_barriers[i]->init(1); + empty_barriers[i]->init(128); + } + + // Make initialized barrier visible in async proxy + cutlass::arch::fence_barrier_init(); + } + __syncthreads(); + + constexpr uint32_t kNumKBlocks = constexpr_ceil_div(SHAPE_K, BLOCK_K); + constexpr uint32_t kNumKBlocksPerSplit = kNumKBlocks / kNumSplits; + constexpr uint32_t kRemainKBlocks = kNumKBlocks % kNumSplits; + const uint32_t block_idx = __shfl_sync(0xffffffff, blockIdx.x, 0); + const uint32_t m_block_idx = block_idx / kNumSplits; + const uint32_t k_split_idx = block_idx % kNumSplits; + const uint32_t k_offset = (k_split_idx * 
kNumKBlocksPerSplit + cute::min(k_split_idx, kRemainKBlocks)) * BLOCK_K; + const uint32_t m_offset = shape_m * k_split_idx; + const uint32_t num_total_stages = kNumKBlocksPerSplit + (k_split_idx < kRemainKBlocks); + constexpr uint32_t kNumTMARegisters = 40; + constexpr uint32_t kNumMathRegisters = 256; + + // TMA load warp + if (warp_idx == kNumMathThreads / 32 and cute::elect_one_sync()) { + cutlass::arch::warpgroup_reg_dealloc(); + for (uint32_t s = 0; s < num_total_stages; ++ s) { + // Wait consumer release + const auto& stage_idx = s % kNumStages; + empty_barriers[stage_idx]->wait(((s / kNumStages) & 1) ^ 1); + + // Compute offsets + uint32_t m_idx = m_block_idx * BLOCK_M; + uint32_t k_idx = k_offset + s * BLOCK_K; + + // Issue TMAs + tma_copy(&tensor_map_a, full_barriers[stage_idx], smem_a[stage_idx], k_idx, m_idx); + tma_copy(&tensor_map_b, full_barriers[stage_idx], smem_b[stage_idx], k_idx, 0); + + // Arrive at full barriers + constexpr uint32_t kNumArrivalBytes = SMEM_A_SIZE_PER_STAGE + SMEM_B_SIZE_PER_STAGE; + full_barriers[stage_idx]->arrive_and_expect_tx(kNumArrivalBytes); + } + + for (uint32_t s = num_total_stages; s < num_total_stages + kNumStages; ++ s) { + const auto& stage_idx = s % kNumStages; + empty_barriers[stage_idx]->wait(((s / kNumStages) & 1) ^ 1); + } + } else if (warp_idx < kNumMathThreads / 32) { + cutlass::arch::warpgroup_reg_alloc(); + + DG_STATIC_ASSERT(BLOCK_M == 64, "Invalid block M"); + DG_STATIC_ASSERT(BLOCK_K * sizeof(nv_bfloat16) == kSwizzleAMode, "Invalid block K"); + constexpr uint32_t BLOCK_M_PER_WARP = BLOCK_M / 4; + constexpr uint32_t WGMMA_M = 64; + constexpr uint32_t WGMMA_N = BLOCK_N; + constexpr uint32_t WGMMA_K = 8; + + using WGMMA = typename TF32MMASelector::type; + float accum[WGMMA::kNumAccum] = {0}; + + constexpr uint32_t kNumBankGroupBytes = 16; + constexpr uint32_t kNumElemsPerBankGroup = kNumBankGroupBytes / sizeof(nv_bfloat16); + constexpr uint32_t kNumLoads = BLOCK_K / kNumElemsPerBankGroup; + float sqr_sum_acc_0 = 0; + float sqr_sum_acc_1 = 0; + + #pragma unroll kNumStages < 8 ? 
kNumStages : kNumStages / 2 + for (uint32_t s = 0; s < num_total_stages; ++ s) { + // Wait TMA arrival + const auto& stage_idx = s % kNumStages; + full_barriers[stage_idx]->wait((s / kNumStages) & 1); + + constexpr uint32_t kNumRegPerWgmma = WGMMA::M * WGMMA::K / 128; + constexpr uint32_t kNumWgmmaPerBlockK = BLOCK_K / WGMMA::K; + + float a[kNumRegPerWgmma * kNumWgmmaPerBlockK]; + // Assume swizzle A mode is 128 + DG_STATIC_ASSERT(kSwizzleAMode == 128, "Invalid swizzle A mode"); + + // Load BF16 A fragment from shared memory into registers, and transpose to FP32 + uint32_t row = warp_idx * 16 + lane_idx / 4; + #pragma unroll + for (uint32_t i = 0; i < kNumLoads; ++ i) { + // Refer to the A layout in https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n8-a + uint32_t bank_group_idx = (row ^ i) % 8; + nv_bfloat16* a_bf16_smem_ptr_upper = smem_a[stage_idx] + row * BLOCK_K + bank_group_idx * kNumElemsPerBankGroup; + nv_bfloat16* a_bf16_smem_ptr_lower = smem_a[stage_idx] + (row + 8) * BLOCK_K + bank_group_idx * kNumElemsPerBankGroup; + + uint32_t elem_offset = lane_idx % 4; + nv_bfloat16 a_bf16[kNumRegPerWgmma]; + a_bf16[0] = a_bf16_smem_ptr_upper[elem_offset]; + a_bf16[2] = a_bf16_smem_ptr_upper[elem_offset + 4]; + a_bf16[1] = a_bf16_smem_ptr_lower[elem_offset]; + a_bf16[3] = a_bf16_smem_ptr_lower[elem_offset + 4]; + + auto a_bf16x2_ptr = reinterpret_cast(a_bf16); + auto a_float2_ptr = reinterpret_cast(a); + float2 a_float2_0 = __bfloat1622float2(a_bf16x2_ptr[0]); + float2 a_float2_1 = __bfloat1622float2(a_bf16x2_ptr[1]); + a_float2_ptr[i * 2 + 0] = a_float2_0; + a_float2_ptr[i * 2 + 1] = a_float2_1; + sqr_sum_acc_0 += a_float2_0.x * a_float2_0.x + a_float2_1.x * a_float2_1.x; + sqr_sum_acc_1 += a_float2_0.y * a_float2_0.y + a_float2_1.y * a_float2_1.y; + } + + warpgroup_wait<0>(); + if (s > 0) + empty_barriers[(s - 1) % kNumStages]->arrive(); + + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + warpgroup_arrive(); + + constexpr int kNumElemsInSwizzleRange = 128 / sizeof(float); + constexpr uint32_t kNumWgmmaInSwizzleRange = kNumElemsInSwizzleRange / WGMMA::K; + DG_STATIC_ASSERT(BLOCK_K % kNumElemsInSwizzleRange == 0, "Invalid block K"); + + #pragma unroll + for (int i = 0; i < BLOCK_K / kNumElemsInSwizzleRange; i++) { + #pragma unroll + for (int k = 0; k < kNumElemsInSwizzleRange / WGMMA::K; k++) { + auto b_desc = make_smem_desc(smem_b[stage_idx] + i * BLOCK_N * kNumElemsInSwizzleRange + k * WGMMA::K, 1); + WGMMA::wgmma(a + (i * kNumWgmmaInSwizzleRange + k) * kNumRegPerWgmma, b_desc, accum, 1); + } + } + warpgroup_commit_batch(); + #pragma unroll + for (uint32_t i = 0; i < WGMMA::kNumAccum; ++ i) + warpgroup_fence_operand(accum[i]); + } + + const auto& reduced_sum_0 = warp_reduce_sum<4>(sqr_sum_acc_0); + const auto& reduced_sum_1 = warp_reduce_sum<4>(sqr_sum_acc_1); + + const auto& m_idx = m_block_idx * BLOCK_M + (warp_idx * BLOCK_M_PER_WARP + lane_idx / 4); + if (lane_idx % 4 == 0) { + if (m_idx < shape_m) + sqr_sum[m_offset + m_idx] = reduced_sum_0; + if (m_idx + 8 < shape_m) + sqr_sum[m_offset + m_idx + 8] = reduced_sum_1; + } + warpgroup_wait<0>(); + empty_barriers[(num_total_stages-1) % kNumStages]->arrive(); + + // Write accum to shared memory + // Every 2 threads (one pair) will write to the same bank group (16 bytes). 
+ // Refer to the D layout in https://docs.nvidia.com/cuda/parallel-thread-execution/#wgmma-64n8-d + uint32_t is_odd_pair = lane_idx / 2 % 2; + + // Four threads per group; write the data to the same row. + uint32_t row_idx = lane_idx / 4; + + // Even/odd index pairs write to the same column, we need to reorder idx: + // group even pair indices consecutively, and likewise for odd ones. + uint32_t reordered_pair_idx = is_odd_pair * 8 + row_idx; + + auto shifted_smem_ptr = reinterpret_cast(smem_cd) + + (warp_idx * BLOCK_M_PER_WARP + row_idx) * kSwizzleCDMode + // Row offset, each warp has 16 rows + lane_idx % 2 * 8; // One thread of a pair writes 8 bytes + + #pragma unroll + for (uint32_t i = 0; i < (kSwizzleCDMode / sizeof(float)) / 4; i += 2) { + // Get the swizzled bank group index (16 bytes per group) + uint32_t bank_group_idx = get_swizzled_bank_group_idx(i + is_odd_pair, reordered_pair_idx); + auto smem_ptr = shifted_smem_ptr + bank_group_idx * kNumBankGroupBytes; // Col offset, 16 bytes per group + + // 0/1 write to the same row, 2/3 write to another row + auto values = reinterpret_cast(accum + i * 2); + st_shared(smem_ptr, values[0], values[1]); + st_shared(smem_ptr + 8 * kSwizzleCDMode, values[2], values[3]); + } + cute::tma_store_fence(); + cutlass::arch::NamedBarrier::sync(128, 1); + + // Issue TMA stores + if (warp_idx == 0 and cute::elect_one_sync()) { + if constexpr (kNumSplits == 1) { + cute::SM90_TMA_STORE_2D::copy(&tensor_map_d, smem_cd, 0, m_block_idx * BLOCK_M); + } else { + cute::SM90_TMA_STORE_3D::copy(&tensor_map_d, smem_cd, 0, m_block_idx * BLOCK_M, k_split_idx); + } + cute::tma_store_arrive(); + } + } +#else + if (blockIdx.x == 0 and threadIdx.x == 0) + DG_DEVICE_ASSERT(false and "This kernel only support sm_90a"); +#endif +} + +} // namespace deep_gemm + +#pragma clang diagnostic pop diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/deep_gemm/impls/smxx_clean_logits.cuh b/build/torch29-cxx11-cu126-aarch64-linux/include/deep_gemm/impls/smxx_clean_logits.cuh new file mode 100644 index 0000000000000000000000000000000000000000..cc9e5e6b0c7ce95acf0b7149221dc4d4f0f83a21 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/deep_gemm/impls/smxx_clean_logits.cuh @@ -0,0 +1,67 @@ +#pragma once + +#include +#include + +#include + +namespace deep_gemm { + +template +__global__ __launch_bounds__(kNumWarps * 32, 1) +void smxx_clean_logits(const uint32_t seq_len, const uint32_t seq_len_kv, const uint64_t stride_logits, + const uint32_t* cu_seq_len_k_start, const uint32_t* cu_seq_len_k_end, float* logits) { + const uint32_t& num_sms = gridDim.x; + const uint32_t& sm_idx = blockIdx.x; + const uint32_t& warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + constexpr float neg_inf = -cute::numeric_limits::infinity(); + + // Allocate filled `-inf` shared memory + extern __shared__ __align__(1024) float smem_buffer[]; + #pragma unroll + for (uint32_t i = threadIdx.x; i < BLOCK_KV; i += kNumWarps * 32) + smem_buffer[i] = neg_inf; + cute::tma_store_fence(); + __syncthreads(); + + // Assign sequence to each warp + const auto& assign_task = [&](const uint32_t& num, const uint32_t& idx, + const uint32_t& start, const uint32_t& total) -> cute::tuple { + const auto& per = total / num, rem = total % num; + return {start + idx * per + min(idx, rem), per + (idx < rem)}; + }; + CUTE_TIE_DECL(assign_task(num_sms, sm_idx, 0, seq_len), sm_seq_start, sm_seq_len); + CUTE_TIE_DECL(assign_task(kNumWarps, warp_idx, sm_seq_start, sm_seq_len), warp_seq_start, warp_seq_len); + + 
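// `assign_task` splits `total` items as evenly as possible over `num` workers: each worker
// gets `total / num` items, the first `total % num` workers get one extra, and the start
// offsets accumulate accordingly. For example, 10 sequences over 4 warps yields the spans
// [0, 3), [3, 6), [6, 8) and [8, 10).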
if (cute::elect_one_sync()) { + for (uint32_t i = warp_seq_start; i < warp_seq_start + warp_seq_len; ++ i) { + const auto& ks = cu_seq_len_k_start == nullptr ? 0 : __ldg(cu_seq_len_k_start + i / kNextN); + const auto& ke = __ldg(cu_seq_len_k_end + i / kNextN) - kNextN + i % kNextN + 1; + const auto& aligned_ks = ks / 4 * 4, aligned_ke = (ke + 3) / 4 * 4; + + for (uint32_t left = 0; left < seq_len_kv; left += BLOCK_KV) { + const auto& right = min(left + BLOCK_KV, static_cast(stride_logits)); + if (right <= ks or ke <= left) { + cute::SM90_BULK_COPY_S2G::copy(smem_buffer, logits + i * stride_logits + left, (right - left) * sizeof(float)); + } else { + if (left < aligned_ks) + cute::SM90_BULK_COPY_S2G::copy(smem_buffer, logits + i * stride_logits + left, (aligned_ks - left) * sizeof(float)); + if (aligned_ke < right) + cute::SM90_BULK_COPY_S2G::copy(smem_buffer, logits + i * stride_logits + aligned_ke, (right - aligned_ke) * sizeof(float)); + } + } + } + } + + for (uint32_t i = warp_seq_start; i < warp_seq_start + warp_seq_len; ++ i) { + const auto& ks = cu_seq_len_k_start == nullptr ? 0 : __ldg(cu_seq_len_k_start + i / kNextN); + const auto& ke = __ldg(cu_seq_len_k_end + i / kNextN) - kNextN + i % kNextN + 1; + const auto& aligned_ks = ks / 4 * 4, aligned_ke = (ke + 3) / 4 * 4; + for (uint32_t j = aligned_ks; j < ks; ++ j) + logits[i * stride_logits + j] = neg_inf; + for (uint32_t j = ke; j < aligned_ke; ++ j) + logits[i * stride_logits + j] = neg_inf; + } +} + +} diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/deep_gemm/impls/smxx_layout.cuh b/build/torch29-cxx11-cu126-aarch64-linux/include/deep_gemm/impls/smxx_layout.cuh new file mode 100644 index 0000000000000000000000000000000000000000..bea7000276c3e382c1acfeff545d6181351849b6 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/deep_gemm/impls/smxx_layout.cuh @@ -0,0 +1,176 @@ +#pragma once + +#include + +namespace deep_gemm { + +template +__global__ void transpose_fp32(const float* sf, float* out, const uint32_t mn) { + typedef typename Vectorized::vec_t in_vec_t; + constexpr static uint32_t kNumElemsPerVec = sizeof(in_vec_t) / sizeof(float); + constexpr static uint32_t SF_VEC_K = SF_K / kNumElemsPerVec; + + // Shapes and strides + extern __shared__ float smem_buffer[]; + constexpr auto kNumTMAAlignedElems = static_cast(16 / sizeof(float)); + const auto in_block_mn = min(BLOCK_MN, mn - blockIdx.x * BLOCK_MN); + const auto tma_aligned_mn = align(mn, kNumTMAAlignedElems); + + // Shift into the block + sf = sf + static_cast(blockIdx.y) * mn * SF_K; + out = out + static_cast(blockIdx.y) * tma_aligned_mn * SF_K; + const auto& local_sf = reinterpret_cast(sf + static_cast(blockIdx.x) * (BLOCK_MN * SF_K)); + + // Load + for (uint32_t i = threadIdx.x; i < in_block_mn * SF_VEC_K; i += kNumThreads) { + auto in_vec = __ldg(local_sf + i); + const auto& in_values = reinterpret_cast(&in_vec); + + const auto& row = i / SF_VEC_K, col = (i % SF_VEC_K) * kNumElemsPerVec; + #pragma unroll + for (uint32_t j = 0; j < kNumElemsPerVec; ++ j) + smem_buffer[row * PADDED_SF_K + col + j] = in_values[j]; + } + __syncthreads(); + + // Store + #pragma unroll + for (uint32_t i = threadIdx.x; i < in_block_mn * SF_K; i += kNumThreads) { + const auto& sf_k_idx = i / in_block_mn, mn_idx = i % in_block_mn; + const auto& global_mn_idx = blockIdx.x * BLOCK_MN + mn_idx; + out[sf_k_idx * tma_aligned_mn + global_mn_idx] = ld_shared(smem_buffer + mn_idx * PADDED_SF_K + sf_k_idx); + } +} + +// NOTES: the two kernels below always pack the K dimension + 
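// A minimal host-side sketch of the bit manipulation the packing kernels below perform,
// assuming every scale factor is a positive power of two (so only the FP32 exponent byte,
// bits 30..23, carries information). The helper name `pack_ue8m0` and the standalone
// program are illustrative only and do not appear in this file.
#include <cstdint>
#include <cstring>
#include <cstdio>

// Pack four FP32 scale factors into one uint32_t: byte i of the result is the 8-bit
// exponent (the UE8M0 value) of sf[i], mirroring the shift/OR pattern used below.
static uint32_t pack_ue8m0(const float sf[4]) {
    uint32_t bits[4];
    std::memcpy(bits, sf, sizeof(bits));   // reinterpret the FP32 bit patterns
    return (bits[0] >> 23u) | (bits[1] >> 15u) | (bits[2] >> 7u) | (bits[3] << 1u);
}

int main() {
    const float sf[4] = {1.0f, 2.0f, 0.5f, 4.0f};  // exponents 127, 128, 126, 129
    std::printf("0x%08X\n", pack_ue8m0(sf));       // prints 0x817E807F
    return 0;
}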
+template +__global__ void transpose_and_pack_fp32_into_ue8m0(float* sf, uint32_t* out, const uint32_t mn) { + extern __shared__ uint32_t smem_buffer[]; + + // Shapes and strides + constexpr auto kNumPackedSFK = constexpr_ceil_div(SF_K, 4u); + constexpr auto kNumTMAAlignedElems = static_cast(16 / sizeof(int)); + const auto in_block_mn = min(BLOCK_MN, mn - blockIdx.x * BLOCK_MN); + const auto tma_aligned_mn = align(mn, kNumTMAAlignedElems); + + // Shift into the group + sf = sf + static_cast(blockIdx.y) * mn * SF_K; + out = out + static_cast(blockIdx.y) * tma_aligned_mn * kNumPackedSFK; + + // Load FP32 SFs + DG_STATIC_ASSERT(BLOCK_MN % 4 == 0, "Invalid block size"); + const auto local_sf = reinterpret_cast(sf + static_cast(blockIdx.x) * (BLOCK_MN * SF_K)); + const auto num_values = in_block_mn * SF_K; + const auto num_uint4 = num_values / 4; + #pragma unroll + for (uint32_t i = threadIdx.x; i < num_uint4; i += kNumThreads) { + const auto& [x, y, z, w] = __ldg(reinterpret_cast(local_sf) + i); + st_shared(reinterpret_cast(smem_buffer) + i, x, y, z, w); + } + + // Fill unaligned values as well + if (const auto unaligned_idx = num_uint4 * 4 + threadIdx.x; unaligned_idx < num_values) + st_shared(smem_buffer + unaligned_idx, __ldg(local_sf + unaligned_idx)); + __syncthreads(); + + // Pack into UE8M0 and store + #pragma unroll + for (uint32_t i = threadIdx.x; i < (kNumPackedSFK * BLOCK_MN); i += kNumThreads) { + const auto sf_k_pack_idx = i / BLOCK_MN, mn_idx = i % BLOCK_MN; + + // Load shared memory + uint32_t values[4]; + #pragma unroll + for (uint32_t j = 0; j < 4; ++ j) { + const auto sf_k_idx = sf_k_pack_idx * 4 + j; + values[j] = sf_k_idx < SF_K ? ld_shared(smem_buffer + mn_idx * SF_K + sf_k_idx) : 0; + } + + // Pack and store + uint32_t packed = 0; + packed |= (values[0] >> 23u); + packed |= (values[1] >> 15u); + packed |= (values[2] >> 7u); + packed |= (values[3] << 1u); + if (const auto global_mn_idx = blockIdx.x * BLOCK_MN + mn_idx; global_mn_idx < mn) + out[sf_k_pack_idx * tma_aligned_mn + global_mn_idx] = packed; + } +} + +template +__global__ void pack_fp32_into_ue8m0(float* sf, uint32_t* out, uint32_t* ks, + const uint32_t mn, uint32_t sf_k, const uint32_t packed_sf_k) { + // Always packing the K dimension + // NOTES: should also assert `mn % 4 == 0` at launch + DG_STATIC_ASSERT(kTransposed, "Currently only support transposed SFs (MN-major)"); + DG_STATIC_ASSERT(BLOCK_MN % 4 == 0, "Invalid block sizes"); + DG_STATIC_ASSERT(BLOCK_PACKED_SF_K == kNumThreads / 32, "Invalid block sizes"); + + // Shapes and strides + const auto in_block_mn = min(BLOCK_MN, mn - blockIdx.x * BLOCK_MN); + const auto in_block_mn_uint4 = in_block_mn / 4; + const auto in_block_packed_sf_k = min(BLOCK_PACKED_SF_K, packed_sf_k - blockIdx.y * BLOCK_PACKED_SF_K); + + // Shift into the right block along MN + sf += blockIdx.x * BLOCK_MN; + out += blockIdx.x * BLOCK_MN; + + // Each warp is responsible for a packed row + const auto warp_idx = threadIdx.x / 32; + const auto lane_idx = get_lane_idx(); + const auto packed_sf_k_idx = static_cast(blockIdx.y) * BLOCK_PACKED_SF_K + warp_idx; + if (warp_idx >= in_block_packed_sf_k) + return; + + // Make an offset on the input + uint32_t input_offset = 0; + if constexpr (kNumGroups > 1) { + // Load each group's size + DG_STATIC_ASSERT(kNumGroups <= 128, "Too many groups"); + uint32_t group_ks[4]; + #pragma unroll + for (uint32_t i = 0; i < 4; ++ i) { + const auto group_idx = lane_idx * 4 + i; + group_ks[i] = group_idx < kNumGroups ? 
__ldg(ks + group_idx) : 0; + } + __syncwarp(); + + // Make the offset + sf_k = 0; + auto sum_packed_sf_k = 0; + #pragma unroll + for (uint32_t i = 0; i < kNumGroups; ++ i) { + const auto sf_k_in_group = __shfl_sync(0xffffffff, group_ks[i % 4] / 128, i / 4); + sf_k += sf_k_in_group; + sum_packed_sf_k += ceil_div(sf_k_in_group, 4u); + if (packed_sf_k_idx < sum_packed_sf_k) + break; + if (const auto remainder = sf_k_in_group % 4; remainder > 0) + input_offset += 4 - remainder; + } + } + + for (uint32_t mn_idx = get_lane_idx(); mn_idx < in_block_mn_uint4; mn_idx += 32) { + // Load + uint4 values[4]; + #pragma unroll + for (uint32_t j = 0; j < 4; ++ j) { + values[j] = make_uint4(0, 0, 0, 0); + if (const auto sf_k_idx = packed_sf_k_idx * 4 + j - input_offset; sf_k_idx < sf_k) + values[j] = __ldg(reinterpret_cast(sf + sf_k_idx * mn) + mn_idx); + } + + // Pack and store + uint4 packed; + packed.x = (values[0].x >> 23u) | (values[1].x >> 15u) | (values[2].x >> 7u) | (values[3].x << 1u); + packed.y = (values[0].y >> 23u) | (values[1].y >> 15u) | (values[2].y >> 7u) | (values[3].y << 1u); + packed.z = (values[0].z >> 23u) | (values[1].z >> 15u) | (values[2].z >> 7u) | (values[3].z << 1u); + packed.w = (values[0].w >> 23u) | (values[1].w >> 15u) | (values[2].w >> 7u) | (values[3].w << 1u); + reinterpret_cast(out + packed_sf_k_idx * mn)[mn_idx] = packed; + } +} + +} // namespace deep_gemm diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/03_visualize_layout/options.h b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/03_visualize_layout/options.h new file mode 100644 index 0000000000000000000000000000000000000000..9cfb284e6716575c81a5f73e7746ce13664d2250 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/03_visualize_layout/options.h @@ -0,0 +1,121 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include +#include + +// Cutlass command line parser +#include "cutlass/util/command_line.h" + +class Options { +public: + + bool help; + bool good; + std::vector extent; ///< extent of tile to fill + std::vector stride; ///< stride vector for layout function + std::vector output_shape; ///< output shape + int vectorize; ///< sequences of consecutive output elements are concatenated into a vector + /// if, and only if, they were consecutive in source memory + +public: + + /// Options + Options(): + help(false), + good(true), + extent({32, 8}), + stride({32}), + output_shape({16, 8}), + vectorize(1) { + + } + + /// Constructs from command line parser + Options(cutlass::CommandLine const & cmd_line): help(false), good(true) { + + if (cmd_line.check_cmd_line_flag("help") || + cmd_line.check_cmd_line_flag("h")) { + + help = true; + } + + if (cmd_line.check_cmd_line_flag("extent")) { + cmd_line.get_cmd_line_arguments("extent", extent); + } + else { + extent = {32, 8}; + } + + if (cmd_line.check_cmd_line_flag("stride")) { + cmd_line.get_cmd_line_arguments("stride", stride); + } + + int default_output_shape[] = {16, 8}; + + if (cmd_line.check_cmd_line_flag("output-shape")) { + cmd_line.get_cmd_line_arguments("output-shape", output_shape); + } + + for (int i = int(output_shape.size()); i < 2; ++i) { + output_shape.push_back(default_output_shape[i]); + } + + if (cmd_line.check_cmd_line_flag("vectorize")) { + cmd_line.get_cmd_line_argument("vectorize", vectorize); + } + else { + vectorize = 1; + } + + if (output_shape.front() % vectorize) { + + std::cerr << "Error: --vectorize=" << vectorize + << " must divide contiguous elements in --output-shape=" + << output_shape.at(0) << "," << output_shape.at(1) << std::endl; + + good = false; + } + } + + /// Prints usage statement + static void print_usage(std::ostream &out) { + out + << " Options:\n" + << " --help Displays this help message.\n" + << " --extent= Specifies the layout-specific extent (as comma-delimited array).\n" + << " --stride= Specifies the layout-specific stride vector (comma-delimited array)\n" + << " --output-shape= Specifies the dimensions of a row-major output matrix. 
\n" + << " --vectorize= If possible, vectorizes the output into vectors of consecutive elements\n"; + } +}; diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/03_visualize_layout/register_layout.h b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/03_visualize_layout/register_layout.h new file mode 100644 index 0000000000000000000000000000000000000000..c840c90e0ddf6f149a8ec82aba8a2c0fc6691d2c --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/03_visualize_layout/register_layout.h @@ -0,0 +1,59 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! 
\file + \brief CUTLASS layout visualization example +*/ + +#pragma once + +#include +#include + +#include "options.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +struct VisualizeLayoutBase { + virtual bool visualize(Options const &) = 0; + virtual bool verify(bool verbose, std::ostream &out) = 0; + virtual void print_csv(std::ostream &out, char delim = '|', char new_line = '\n') = 0; + virtual std::ostream &print_help(std::ostream &out) { + return out; + } + virtual ~VisualizeLayoutBase() { } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +void RegisterLayouts(std::map > &layouts); + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/03_visualize_layout/visualize_layout.h b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/03_visualize_layout/visualize_layout.h new file mode 100644 index 0000000000000000000000000000000000000000..13318a05838039745e262a5a59a105701c4907cb --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/03_visualize_layout/visualize_layout.h @@ -0,0 +1,383 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! 
\file + \brief CUTLASS layout visualization example +*/ + +#pragma once + +#include +#include +#include + +#include "cutlass/coord.h" +#include "cutlass/util/reference/host/tensor_foreach.h" + +#include "register_layout.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Permits copying dynamic vectors into static-length vectors +template +struct vector_to_coord { + + vector_to_coord(TensorCoord &coord, std::vector const &vec) { + + coord[Rank - 1] = vec.at(Rank - 1); + + if (Rank > 1) { + vector_to_coord(coord, vec); + } + } +}; + +/// Permits copying dynamic vectors into static-length vectors +template +struct vector_to_coord { + + vector_to_coord(TensorCoord &coord, std::vector const &vec) { + + coord[0] = vec.at(0); + } +}; + +/// Permits copying dynamic vectors into static-length vectors +template +struct vector_to_coord { + + vector_to_coord(TensorCoord &coord, std::vector const &vec) { + + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +std::ostream &operator<<(std::ostream &out, std::vector const &vec) { + auto it = vec.begin(); + if (it != vec.end()) { + out << *it; + for (++it; it != vec.end(); ++it) { + out << ", " << *it; + } + } + return out; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Permits copying static-length vectors into dynamic vectors +template +struct coord_to_vector { + + coord_to_vector(std::vector &vec, TensorCoord const &coord) { + + vec.at(Rank - 1) = coord[Rank - 1]; + coord_to_vector(vec, coord); + } +}; + +/// Permits copying static-length vectors into dynamic vectors +template +struct coord_to_vector { + + coord_to_vector(std::vector &vec, TensorCoord const &coord) { + + vec.at(0) = coord[0]; + } +}; + +/// Permits copying static-length vectors into dynamic vectors +template +struct coord_to_vector { + + coord_to_vector(std::vector &vec, TensorCoord const &coord) { + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure representing an element in source memory +struct Element { + + std::vector coord; ///< logical coordinate of element (as vector) + int offset; ///< linear offset from source memory + int color; ///< enables coloring each element to indicate + + /// Default ctor + inline Element(): offset(-1), color(0) { } + + /// Construct from logical coordinate and initial offset + inline Element( + std::vector const &coord_, + int offset_, + int color_ = 0 + ): + coord(coord_), offset(offset_), color(color_) { } + + /// Returns true if element is in a defined state + inline bool valid() const { + return offset >= 0; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Visualizes memory layouts by constructing a 'shape' +template +class VisualizeLayout : public VisualizeLayoutBase { +public: + + using Layout = Layout_; + using TensorCoord = typename Layout::TensorCoord; + using Stride = typename Layout::Stride; + +public: + + Options options; + Layout layout; + TensorCoord extent; + std::vector elements; + +public: + + /// Initializes the problem space + VisualizeLayout() { + + } + + /// visualization method + bool visualize(Options const &options_) { + + options = options_; + + if (options.extent.size() != TensorCoord::kRank) { + + std::cerr + << "--extent must have rank " << TensorCoord::kRank + << " (given: " << 
options.extent.size() << ")" << std::endl; + + return false; + } + + vector_to_coord(extent, options.extent); + + // Construct the layout for a packed tensor + if (options.stride.empty()) { + + layout = Layout::packed(extent); + } + else if (options.stride.size() != Stride::kRank) { + + std::cerr + << "--stride must have rank " << Stride::kRank + << " (given: " << options.stride.size() << ")" << std::endl; + + return false; + } + else { + // Stride from + Stride stride; + vector_to_coord(stride, options.stride); + + layout = Layout(stride); + } + + // Resize elements, setting elements to 'undefined' state + elements.resize(layout.capacity(extent)); + + // enumerate points in tensor space and assign + cutlass::reference::host::TensorForEachLambda( + extent, + [&](TensorCoord coord) { + + std::vector coord_vec(TensorCoord::kRank, 0); + coord_to_vector(coord_vec, coord); + + int offset = int(layout(coord)); + + if (offset >= int(elements.size())) { + std::cerr + << "Layout error - " << coord_vec + << " is out of range (computed offset: " << offset + << ", capacity: " << elements.size() << std::endl; + + throw std::out_of_range("(TensorForEach) layout error - coordinate out of range"); + } + + elements.at(offset) = Element(coord_vec, offset); + }); + + return true; + } + + /// Verifies the layout satisfies vectorization requirements + bool verify(bool verbose, std::ostream &out) { + return true; + } + +private: + + /// returns a pair (is_vectorizable, one_changing_rank) to determine if a + /// vector exists (consecutive logical coordinates or uniformly invalid) + /// at the given location. + std::pair< bool, int > _is_vectorizable(int i) const { + // (all elements are invalid) or + // (all elements are valid AND + // exactly one rank is changing AND + // elements are consecutive) + + // Don't need vectorization. + if (options.vectorize <= 2) return std::make_pair(false, -1); + + // Boundary check. + if (i > int(elements.size()) || (i + options.vectorize - 1) > int(elements.size())) + return std::make_pair(false, -1); + + // Check if either all elements are valid or invalid. + bool all_elements_invalid = std::all_of( + elements.begin() + i, elements.begin() + i + options.vectorize, + [](Element const &e) { return !e.valid(); }); + + bool all_elements_valid = std::all_of( + elements.begin() + i, elements.begin() + i + options.vectorize, + [](Element const &e) { return e.valid(); }); + + if (!all_elements_invalid && !all_elements_valid) + return std::make_pair(false, -1); + + // From here, it is vectorizable. + if (all_elements_invalid) return std::make_pair(true, -1); + + // Check if only exactly one rank is changing. + int one_changing_rank = -1; + for (int j = 0; j < options.vectorize; ++j) { + for (int r = 0; r < TensorCoord::kRank; ++r) { + if (elements.at(i + j).coord.at(r) != elements.at(i).coord.at(r)) { + if (one_changing_rank == -1) { + one_changing_rank = r; + } else if (one_changing_rank != r) { + return std::make_pair(false, -1); + } + } + } + } + + return std::make_pair(true, one_changing_rank); + } + + /// Prints a vector of elements + void _print_vector(std::ostream &out, int i, int one_changing_rank) { + Element const &base_element = elements.at(i); + if (base_element.valid()) { + out << "("; + for (int r = 0; r < TensorCoord::kRank; ++r) { + if (r) { + out << ", "; + } + + if (r == one_changing_rank) { + out + << base_element.coord.at(r) + << ".." 
+ << (base_element.coord.at(r) + options.vectorize - 1); + } + else { + out << base_element.coord.at(r); + } + } + out << ")"; + } + else { + out << " "; + } + } + + /// Prints a single element + void _print_element(std::ostream &out, int k) { + Element const &element = elements.at(k); + if (element.valid()) { + out << "("; + for (int v = 0; v < TensorCoord::kRank; ++v) { + out << (v ? ", " : "") << element.coord.at(v); + } + out << ")"; + } + else { + out << " "; + } + } + +public: + + /// Pretty-prints the layout to the console + void print_csv(std::ostream &out, char delim = '|', char new_line = '\n') { + int row = -1; + + for (int i = 0; i < int(elements.size()); i += options.vectorize) { + if (i % options.output_shape.at(0)) { + out << delim; + } + else { + if (row >= 0) { + out << new_line; + } + ++row; + if (row == options.output_shape.at(1)) { + out << new_line; + row = 0; + } + } + + auto is_vector = _is_vectorizable(i); + + if (is_vector.first) { + _print_vector(out, i, is_vector.second); // print a vector starting at element i + } + else { + for (int j = 0; j < options.vectorize; ++j) { // print individual elements [i..i+j) + _print_element(out, i + j); + } + } + } + + out << new_line << std::flush; + } + + /// Help message + virtual std::ostream &print_help(std::ostream &out) { + out << "TensorCoord rank " << TensorCoord::kRank << ", Stride rank: " << Stride::kRank; + return out; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/b2b_conv2d_run.h b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/b2b_conv2d_run.h new file mode 100644 index 0000000000000000000000000000000000000000..df4cb76ad13df4d4a55b5fce810be5361fb081cd --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/b2b_conv2d_run.h @@ -0,0 +1,719 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include +#include +#include + +#include "cutlass/cutlass.h" + +#include "cutlass/conv/device/implicit_gemm_convolution.h" +#include "cutlass/reduction/device/reduce_split_k.h" +#include "cutlass/reduction/thread/reduction_operators.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" + +#include "cutlass/util/reference/host/convolution.h" +#include "cutlass/util/reference/device/convolution.h" +#include "cutlass/util/reference/device/tensor_relu.h" + +#include "cutlass/core_io.h" +#include "cutlass/util/tensor_view_io.h" + +#include "reference/device/tensor_scale_bias.h" +#include "helper.h" + +#define CHECK_GT(val1, val2) \ + if((val1) <= (val2)) \ + std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_GT failed\n"; +#define CHECK_TRUE(val) \ + if(!(val)) \ + std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_TRUE failed\n"; + + +template +class B2bNonFusedConv2dRun { +public: + + using Conv2d0 = Conv2d0_; + using Conv2d1 = Conv2d1_; + using ElementAccumulator = typename Conv2d0::ElementAccumulator; + using ElementCompute = typename Conv2d0::ElementCompute; + + static cutlass::conv::Operator const kConvolutionalOperator = Conv2d0::kConvolutionalOperator; + static_assert(kConvolutionalOperator == Conv2d1::kConvolutionalOperator, + "Fused convolution operators must be the same"); + +public: + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + cutlass::Distribution::Kind init_Bias; + uint64_t seed; + + cutlass::HostTensor tensor_A0; + cutlass::HostTensor tensor_B0; + cutlass::HostTensor tensor_C0; + cutlass::HostTensor tensor_Bias0; + cutlass::HostTensor tensor_D0_computed; + cutlass::HostTensor tensor_D0_reference; + + cutlass::HostTensor tensor_B1; + cutlass::HostTensor tensor_C1; + cutlass::HostTensor tensor_Bias1; + cutlass::HostTensor tensor_D1_computed; + cutlass::HostTensor tensor_D1_reference; + + +public: + + B2bNonFusedConv2dRun( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + init_A(init_A_), init_B(init_B_), init_C(init_C_), init_Bias(init_Bias_), seed(seed_) { + + } + + /// Helper to initialize a tensor view + template + void initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + int scope; + int bits = cutlass::sizeof_bits::value; + + if (bits <= 16) 
{ + scope = 2; + } + else { + scope = 8; + } + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope, -scope, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential(view.data(), view.capacity()); + } + else if (dist_kind == cutlass::Distribution::AllZeros) { + cutlass::reference::host::TensorFill(view, Element(0)); + } + else if (dist_kind == cutlass::Distribution::AllOnes) { + cutlass::reference::host::TensorFill(view, Element(1)); + } + else { + std::cerr << "Not implemented\n"; + } + } + + void initialize( + cutlass::conv::Conv2dProblemSize const &problem_size_0, + cutlass::conv::Conv2dProblemSize const &problem_size_1, + uint64_t seed = 2019) { + + tensor_A0.resize(implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size_0)); + tensor_B0.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_0)); + tensor_C0.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0)); + tensor_Bias0.resize({1, 1, 1, problem_size_0.K}); + tensor_D0_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0)); + tensor_D0_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0)); + tensor_B1.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_1)); + tensor_C1.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_1)); + tensor_Bias1.resize({1, 1, 1, problem_size_1.K}); + tensor_D1_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_1)); + tensor_D1_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_1)); + + initialize_tensor(tensor_A0.host_view(), init_A, seed); + initialize_tensor(tensor_B0.host_view(), init_B, seed * 17); + initialize_tensor(tensor_C0.host_view(), init_C, seed * 39); + initialize_tensor(tensor_Bias0.host_view(), init_Bias, seed * 83); + initialize_tensor(tensor_B1.host_view(), init_B, seed * 18); + initialize_tensor(tensor_C1.host_view(), init_C, seed * 40); + initialize_tensor(tensor_Bias1.host_view(), init_Bias, seed * 84); + + tensor_A0.sync_device(); + tensor_B0.sync_device(); + tensor_C0.sync_device(); + tensor_Bias0.sync_device(); + tensor_D0_computed.sync_device(); + tensor_D0_reference.sync_device(); + tensor_B1.sync_device(); + tensor_C1.sync_device(); + tensor_Bias1.sync_device(); + tensor_D1_computed.sync_device(); + tensor_D1_reference.sync_device(); + } + + /// Executes one test + bool run( + cutlass::conv::Conv2dProblemSize const &problem_size_0, + cutlass::conv::Conv2dProblemSize const &problem_size_1, + cutlass::conv::SplitKMode const &split_k_mode = cutlass::conv::SplitKMode::kSerial, + ElementCompute alpha0 = ElementCompute(1), + ElementCompute beta0 = ElementCompute(0), + ElementCompute alpha1 = ElementCompute(1), + ElementCompute beta1 = ElementCompute(0), + bool relu = true, + int warm_ups = 1, + int runs = 100) { + + initialize(problem_size_0, problem_size_1); + + // configure the operator + Conv2d0 conv2d_op_0; + Conv2d1 conv2d_op_1; + + typename Conv2d0::Arguments conv2d_args_0( + problem_size_0, + tensor_A0.device_ref(), + tensor_B0.device_ref(), + {tensor_Bias0.device_data(), typename Conv2d0::LayoutC::Stride(0)}, + 
tensor_D0_computed.device_ref(), + {alpha0, beta0}, + split_k_mode + ); + typename Conv2d1::Arguments conv2d_args_1( + problem_size_1, + tensor_D0_computed.device_ref(), + tensor_B1.device_ref(), + {tensor_Bias1.device_data(), typename Conv2d1::LayoutC::Stride(0)}, + tensor_D1_computed.device_ref(), + {alpha1, beta1}, + split_k_mode + ); + + + cutlass::Status status = conv2d_op_0.initialize(conv2d_args_0); + + CUTLASS_CHECK(status); + + status = conv2d_op_1.initialize(conv2d_args_1); + + CUTLASS_CHECK(status); + + for(int i = 0; i < warm_ups; i++) { + status = conv2d_op_0(); + CUTLASS_CHECK(status); + status = conv2d_op_1(); + CUTLASS_CHECK(status); + } + + // + // Run Conv2d + // + cudaEvent_t start, stop1, stop2; + cudaEventCreate(&start); + cudaEventCreate(&stop1); + cudaEventCreate(&stop2); + + cudaEventRecord(start); + + + for(int i = 0; i < runs; i++) { + // run conv2d operator + status = conv2d_op_0(); + CUTLASS_CHECK(status); + } + cudaEventRecord(stop1); + + for(int i = 0; i < runs; i++) { + // run conv2d operator + status = conv2d_op_1(); + CUTLASS_CHECK(status); + } + cudaEventRecord(stop2); + cudaDeviceSynchronize(); + float conv2d0Time, conv2d1Time, totalTime; + cudaEventElapsedTime(&conv2d0Time, start, stop1); + cudaEventElapsedTime(&conv2d1Time, stop1, stop2); + cudaEventElapsedTime(&totalTime, start, stop2); + std::cout << "conv2d 0 time " << conv2d0Time / (float)runs << " ms\n"; + std::cout << "conv2d 1 time " << conv2d1Time / (float)runs << " ms\n"; + std::cout << "Non-fusion time " << totalTime / (float)runs << " ms\n"; + + tensor_D0_computed.sync_host(); + tensor_D1_computed.sync_host(); + + bool passed = false; + + cutlass::reference::device::Conv2d< + typename Conv2d0::ElementA, + typename Conv2d0::LayoutA, + typename Conv2d0::ElementB, + typename Conv2d0::LayoutB, + typename Conv2d0::ElementC, + typename Conv2d0::LayoutC, + ElementCompute, + ElementAccumulator + >( + kConvolutionalOperator, + problem_size_0, + tensor_A0.device_ref(), + tensor_B0.device_ref(), + {tensor_Bias0.device_data(), typename Conv2d0::LayoutC::Stride(0)}, + tensor_D0_reference.device_ref(), + alpha0, + beta0); + + if(relu) { + cutlass::reference::device::TensorReLu(tensor_D0_reference.device_view()); + } + + cutlass::reference::device::Conv2d< + typename Conv2d1::ElementA, + typename Conv2d1::LayoutA, + typename Conv2d1::ElementB, + typename Conv2d1::LayoutB, + typename Conv2d1::ElementC, + typename Conv2d1::LayoutC, + ElementCompute, + ElementAccumulator + >( + kConvolutionalOperator, + problem_size_1, + tensor_D0_reference.device_ref(), + tensor_B1.device_ref(), + {tensor_Bias1.device_data(), typename Conv2d1::LayoutC::Stride(0)}, + tensor_D1_reference.device_ref(), + alpha1, + beta1); + + if(relu) { + cutlass::reference::device::TensorReLu(tensor_D1_reference.device_view()); + } + + cudaError_t result = cudaDeviceSynchronize(); + CHECK_TRUE(result == cudaSuccess); + + // sync host (copy device data to host) for dumping error output in case of mismatches + tensor_D0_reference.sync_host(); + tensor_D1_reference.sync_host(); + + CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D0_computed.host_view()), 0); + CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D0_reference.host_view()), 0); + CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1_computed.host_view()), 0); + CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1_reference.host_view()), 0); + + passed = cutlass::reference::host::TensorEquals( + tensor_D1_computed.host_view(), + tensor_D1_reference.host_view()); + + 
CHECK_TRUE(passed); + + if (!passed) { + std::stringstream fname; + + fname << "error_B2bImplicitGemm_device_nonfused.txt"; + std::cerr << "Dumping results in " << fname.str() << "\n"; + + std::ofstream results(fname.str()); + + results << problem_size_0 << std::endl; + results << problem_size_1 << std::endl; + + results + << "\nA0:\n" << tensor_A0.host_view() << "\n" + << "\nB0:\n" << tensor_B0.host_view() << "\n" + << "\nC0:\n" << tensor_C0.host_view() << "\n" + << "\nBias0:\n" << tensor_Bias0.host_view() << "\n" + << "\nD0 reference:\n" << tensor_D0_reference.host_view() << "\n" + << "\nD0 computed:\n" << tensor_D0_computed.host_view() << "\n" + << "\nB1:\n" << tensor_B1.host_view() << "\n" + << "\nC1:\n" << tensor_C1.host_view() << "\n" + << "\nBias1:\n" << tensor_Bias1.host_view() << "\n" + << "\nD1 reference:\n" << tensor_D1_reference.host_view() << "\n" + << "\nD1 computed:\n" << tensor_D1_computed.host_view(); + + + } + + return passed; + } + +}; + +template +class B2bFusedConv2dRun { +public: + + using B2bConv2d = B2bConv2d_; + using ElementAccumulator = typename B2bConv2d::ElementAccumulator; + using ElementCompute = typename B2bConv2d::ElementCompute; + + static cutlass::conv::Operator const kConvolutionalOperator = B2bConv2d::kConvolutionalOperator; + +public: + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + cutlass::Distribution::Kind init_Scale; + cutlass::Distribution::Kind init_Bias; + uint64_t seed; + + cutlass::HostTensor tensor_A0; + cutlass::HostTensor tensor_B0; + cutlass::HostTensor tensor_C0; + cutlass::HostTensor tensor_Scale0; + cutlass::HostTensor tensor_Bias0; + cutlass::HostTensor tensor_Z0_reference; + cutlass::HostTensor tensor_D0_reference; + + cutlass::HostTensor tensor_B1; + cutlass::HostTensor tensor_C1; + cutlass::HostTensor tensor_Bias1; + cutlass::HostTensor tensor_D1_computed; + cutlass::HostTensor tensor_D1_reference; + + +public: + + B2bFusedConv2dRun( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_Scale_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + init_A(init_A_), init_B(init_B_), init_C(init_C_), + init_Scale(init_Scale_), init_Bias(init_Bias_), seed(seed_) { + + } + + /// Helper to initialize a tensor view + template + void initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + int scope; + int bits = cutlass::sizeof_bits::value; + + if (bits <= 16) { + scope = 2; + } + else { + scope = 8; + } + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope, -scope, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential(view.data(), view.capacity()); + } + else if (dist_kind == cutlass::Distribution::AllZeros) { + cutlass::reference::host::TensorFill(view, Element(0)); + } + else if (dist_kind == cutlass::Distribution::AllOnes) { + 
cutlass::reference::host::TensorFill(view, Element(1)); + } + else { + } + } + + void initialize( + cutlass::conv::Conv2dProblemSize const &problem_size_0, + cutlass::conv::Conv2dProblemSize const &problem_size_1, + ElementCompute alpha0, + ElementCompute alpha1, + uint64_t seed = 2019) { + + tensor_A0.resize(implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size_0)); + tensor_B0.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_0)); + tensor_C0.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0)); + if(alpha0 == ElementCompute(0)) //per-channel scale + tensor_Scale0.resize({1, problem_size_0.K}); + tensor_Bias0.resize({1, problem_size_0.K}); + tensor_Z0_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0)); + tensor_D0_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0)); + tensor_B1.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_1)); + tensor_C1.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_1)); + tensor_Bias1.resize({1, 1, 1, problem_size_1.K}); + tensor_D1_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_1)); + tensor_D1_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_1)); + + initialize_tensor(tensor_A0.host_view(), init_A, seed); + initialize_tensor(tensor_B0.host_view(), init_B, seed * 17); + initialize_tensor(tensor_C0.host_view(), init_C, seed * 39); + if(alpha0 == ElementCompute(0)) //per-channel scale + initialize_tensor(tensor_Scale0.host_view(), init_Scale, seed * 61); + initialize_tensor(tensor_Bias0.host_view(), init_Bias, seed * 83); + initialize_tensor(tensor_B1.host_view(), init_B, seed * 18); + initialize_tensor(tensor_C1.host_view(), init_C, seed * 40); + initialize_tensor(tensor_Bias1.host_view(), init_Bias, seed * 84); + + tensor_A0.sync_device(); + tensor_B0.sync_device(); + tensor_C0.sync_device(); + if(alpha0 == ElementCompute(0)) //per-channel scale + tensor_Scale0.sync_device(); + tensor_Bias0.sync_device(); + tensor_D0_reference.sync_device(); + tensor_B1.sync_device(); + tensor_C1.sync_device(); + tensor_Bias1.sync_device(); + tensor_D1_computed.sync_device(); + tensor_D1_reference.sync_device(); + } + + /// Executes one test + bool run( + cutlass::conv::Conv2dProblemSize const &problem_size_0, + cutlass::conv::Conv2dProblemSize const &problem_size_1, + cutlass::conv::SplitKMode const &split_k_mode = cutlass::conv::SplitKMode::kSerial, + ElementCompute alpha0 = ElementCompute(1), + ElementCompute beta0 = ElementCompute(0), + ElementCompute alpha1 = ElementCompute(1), + ElementCompute beta1 = ElementCompute(0), + bool relu = true, + int warm_ups = 1, + int runs = 100) { + + initialize(problem_size_0, problem_size_1, alpha0, alpha1); + + // configure the operator + B2bConv2d b2b_conv2d_op; + + typename B2bConv2d::Arguments b2b_conv2d_args( + problem_size_0, + problem_size_1, + tensor_A0.device_ref(), + tensor_B0.device_ref(), + tensor_C0.device_ref(), + tensor_Scale0.device_ref(), + tensor_Bias0.device_ref(), + tensor_B1.device_ref(), + {tensor_Bias1.device_data(), typename B2bConv2d::LayoutC::Stride(0)}, + tensor_D1_computed.device_ref(), + {alpha0, beta0}, + {alpha1, beta1}, + split_k_mode + ); + + cutlass::Status status = b2b_conv2d_op.can_implement(b2b_conv2d_args); + + if(status != cutlass::Status::kSuccess) { + std::cout << "Problem sizes not supported.\n" + << "Requirments:\n" + << " problem_size_0.N*P*Q = 
problem_size_1.N*P*Q\n" + << " problem_size_0.K = problem_size_1.C\n" + << " problem_size_1.R = problem_size_1.S = 1\n" + << " ThreadblockShape0::kN = problem_size_0.K\n" + << " ThreadblockShape1::kN = problem_size_1.K" << std::endl; + } + + CUTLASS_CHECK(status); + + status = b2b_conv2d_op.initialize(b2b_conv2d_args); + + CUTLASS_CHECK(status); + + for(int i = 0; i < warm_ups; i++) { + status = b2b_conv2d_op(); + CUTLASS_CHECK(status); + } + + // + // Run the Conv2d + // + + cudaEvent_t start, stop; + cudaEventCreate(&start); + cudaEventCreate(&stop); + + cudaEventRecord(start); + + for(int i = 0; i < runs; i++) { + + // run conv2d operator + status = b2b_conv2d_op(); + CUTLASS_CHECK(status); + } + + cudaEventRecord(stop); + cudaDeviceSynchronize(); + float conv2dTime; + cudaEventElapsedTime(&conv2dTime, start, stop); + std::cout << "Fusion time " << conv2dTime / (float)runs << " ms\n"; + + tensor_D1_computed.sync_host(); + + bool passed = false; + + cutlass::reference::device::Conv2d< + typename B2bConv2d::ElementA, + typename B2bConv2d::LayoutA, + typename B2bConv2d::ElementB, + typename B2bConv2d::LayoutB, + ElementAccumulator, + typename B2bConv2d::LayoutC, + ElementAccumulator, + ElementAccumulator + >( + kConvolutionalOperator, + problem_size_0, + tensor_A0.device_ref(), + tensor_B0.device_ref(), + tensor_Z0_reference.device_ref(), + tensor_Z0_reference.device_ref(), + ElementAccumulator(1), // intermediate alpha = 1 + ElementAccumulator(0) // beta = 0 + ); + + cutlass::reference::device::TensorScaleBiasConv2d< + ElementAccumulator, + typename B2bConv2d::ElementC, + typename B2bConv2d::LayoutC, + ElementCompute, + typename B2bConv2d::LayoutScaleBias + >( + problem_size_0, + tensor_Z0_reference.device_ref(), + tensor_D0_reference.device_ref(), + alpha0, + tensor_Scale0.device_ref(), + tensor_Bias0.device_ref() + ); + + if(relu) { + cutlass::reference::device::TensorReLu(tensor_D0_reference.device_view()); + } + + cutlass::reference::device::Conv2d< + typename B2bConv2d::ElementA, + typename B2bConv2d::LayoutA, + typename B2bConv2d::ElementB, + typename B2bConv2d::LayoutB, + typename B2bConv2d::ElementC, + typename B2bConv2d::LayoutC, + ElementCompute, + ElementAccumulator + >( + kConvolutionalOperator, + problem_size_1, + tensor_D0_reference.device_ref(), + tensor_B1.device_ref(), + {tensor_Bias1.device_data(), typename B2bConv2d::LayoutC::Stride(0)}, + tensor_D1_reference.device_ref(), + alpha1, + beta1); + + if(relu) { + cutlass::reference::device::TensorReLu(tensor_D1_reference.device_view()); + } + + cudaError_t result = cudaDeviceSynchronize(); + CHECK_TRUE(result == cudaSuccess); + + // sync host (copy device data to host) for dumping error output in case of mismatches + tensor_D0_reference.sync_host(); + tensor_D1_reference.sync_host(); + + CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D0_reference.host_view()), 0); + CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1_computed.host_view()), 0); + CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1_reference.host_view()), 0); + + passed = cutlass::reference::host::TensorEquals( + tensor_D1_computed.host_view(), + tensor_D1_reference.host_view()); + + CHECK_TRUE(passed); + + if (!passed) { + std::stringstream fname; + + fname << "error_B2bImplicitGemm_device_fused.txt"; + std::cerr << "Dumping results in " << fname.str() << "\n"; + + std::ofstream results(fname.str()); + + results << problem_size_0 << std::endl; + results << problem_size_1 << std::endl; + + results + << "\nA0:\n" << tensor_A0.host_view() << "\n" 
+ << "\nB0:\n" << tensor_B0.host_view() << "\n" + << "\nC0:\n" << tensor_C0.host_view() << "\n" + << "\nScale0:\n" << tensor_Scale0.host_view() << "\n" + << "\nBias0:\n" << tensor_Bias0.host_view() << "\n" + << "\nB1:\n" << tensor_B1.host_view() << "\n" + << "\nC1:\n" << tensor_C1.host_view() << "\n" + << "\nBias1:\n" << tensor_Bias1.host_view() << "\n" + << "\nD1 reference:\n" << tensor_D1_reference.host_view() << "\n" + << "\nD1 computed:\n" << tensor_D1_computed.host_view(); + + + } + + return passed; + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/b2b_gemm_run.h b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/b2b_gemm_run.h new file mode 100644 index 0000000000000000000000000000000000000000..f0e85cda3a4f36e3f20bbfeb8642ec0c62339a25 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/b2b_gemm_run.h @@ -0,0 +1,763 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include +#include +#include + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/device/gemm.h" +#include "cutlass/util/reference/device/gemm_complex.h" +#include "cutlass/util/reference/device/tensor_relu.h" + +#include "reference/device/tensor_scale_bias.h" +#include "helper.h" + +#define CHECK_GT(val1, val2) \ + if((val1) <= (val2)) \ + std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_GT failed\n"; +#define CHECK_TRUE(val) \ + if(!(val)) \ + std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_TRUE failed\n"; + +//////////////////////////////////////////////////////////////////////////////// + +template +struct B2bNonFusedGemmRun +{ + + using Gemm0 = Gemm0_; + using Gemm1 = Gemm1_; + using ElementAccumulator = typename Gemm0::ElementAccumulator; + using ElementCompute = typename Gemm0::GemmKernel::Epilogue::OutputOp::ElementCompute; + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + cutlass::Distribution::Kind init_Bias; + uint64_t seed; + + // + // Methods + // + + B2bNonFusedGemmRun( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + init_A(init_A_), init_B(init_B_), init_C(init_C_), init_Bias(init_Bias_), seed(seed_) { } + + /// Helper to initialize a tensor view + template + bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, 2, -2, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential( + view.data(), view.capacity()); + } + else if (dist_kind == cutlass::Distribution::AllZeros) { + cutlass::reference::host::TensorFill(view, Element(0)); + } + else if (dist_kind == cutlass::Distribution::AllOnes) { + cutlass::reference::host::TensorFill(view, Element(1)); + } + else { + std::cerr << "Not implemented\n"; + return false; + } + + return true; + } + + + + + /// Executes one test + bool run( + cutlass::gemm::GemmCoord problem_size_0, + cutlass::gemm::GemmCoord problem_size_1, + ElementCompute alpha0 = ElementCompute(1), + ElementCompute beta0 = ElementCompute(0), + ElementCompute alpha1 = ElementCompute(1), + ElementCompute beta1 = ElementCompute(0), + bool relu = true, + int warm_ups = 1, + int runs = 100) { + + // + // Allocate the GEMM workspace + // + + cutlass::HostTensor< + typename Gemm0::ElementA, + typename Gemm0::LayoutA> tensor_A0(problem_size_0.mk()); + + cutlass::HostTensor< + 
typename Gemm0::ElementB, + typename Gemm0::LayoutB> tensor_B0(problem_size_0.kn()); + + cutlass::HostTensor< + typename Gemm0::ElementC, + typename Gemm0::LayoutC> tensor_C0(problem_size_0.mn()); + + cutlass::HostTensor< + ElementCompute, + typename Gemm0::LayoutC> tensor_Bias0({1, problem_size_0.n()}); + + cutlass::HostTensor< + typename Gemm0::ElementC, + typename Gemm0::LayoutC> tensor_D0(problem_size_0.mn()); + + cutlass::HostTensor< + typename Gemm0::ElementC, + typename Gemm0::LayoutC> reference_D0(problem_size_0.mn()); + + cutlass::HostTensor< + typename Gemm1::ElementB, + typename Gemm1::LayoutB> tensor_B1(problem_size_1.kn()); + + cutlass::HostTensor< + typename Gemm1::ElementC, + typename Gemm1::LayoutC> tensor_C1(problem_size_1.mn()); + + cutlass::HostTensor< + ElementCompute, + typename Gemm1::LayoutC> tensor_Bias1({1, problem_size_1.n()}); + + cutlass::HostTensor< + typename Gemm1::ElementC, + typename Gemm1::LayoutC> tensor_D1(problem_size_1.mn()); + + cutlass::HostTensor< + typename Gemm1::ElementC, + typename Gemm1::LayoutC> reference_D1(problem_size_1.mn()); + + + CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019)); + CHECK_TRUE(initialize_tensor(tensor_B0.host_view(), init_B, seed + 2018)); + CHECK_TRUE(initialize_tensor(tensor_C0.host_view(), init_C, seed + 2017)); + CHECK_TRUE(initialize_tensor(tensor_Bias0.host_view(), init_Bias, seed + 2014)); + CHECK_TRUE(initialize_tensor(tensor_B1.host_view(), init_B, seed + 2016)); + CHECK_TRUE(initialize_tensor(tensor_C1.host_view(), init_C, seed + 2015)); + CHECK_TRUE(initialize_tensor(tensor_Bias1.host_view(), init_Bias, seed + 2013)); + + cutlass::reference::host::TensorFill( + tensor_D0.host_view()); + cutlass::reference::host::TensorFill( + tensor_D1.host_view()); + cutlass::reference::host::TensorFill( + reference_D0.host_view()); + cutlass::reference::host::TensorFill( + reference_D1.host_view()); + + tensor_A0.sync_device(); + tensor_B0.sync_device(); + tensor_C0.sync_device(); + tensor_Bias0.sync_device(); + tensor_D0.sync_device(); + tensor_B1.sync_device(); + tensor_C1.sync_device(); + tensor_Bias1.sync_device(); + tensor_D1.sync_device(); + reference_D0.sync_device(); + reference_D1.sync_device(); + + // + // Initialize the GEMM operator + // + + typename Gemm0::Arguments arguments_0{ + problem_size_0, + tensor_A0.device_ref(), + tensor_B0.device_ref(), + {tensor_Bias0.device_data(), typename Gemm0::LayoutC::Stride(0)}, + tensor_D0.device_ref(), + {alpha0, beta0} + }; + + typename Gemm1::Arguments arguments_1{ + problem_size_1, + tensor_D0.device_ref(), + tensor_B1.device_ref(), + {tensor_Bias1.device_data(), typename Gemm1::LayoutC::Stride(0)}, + tensor_D1.device_ref(), + {alpha1, beta1} + }; + + + Gemm0 gemm_op_0; + Gemm1 gemm_op_1; + + cutlass::Status status = gemm_op_0.initialize(arguments_0); + + CUTLASS_CHECK(status); + + status = gemm_op_1.initialize(arguments_1); + + CUTLASS_CHECK(status); + + for(int i = 0; i < warm_ups; i++) { + status = gemm_op_0(); + CUTLASS_CHECK(status); + status = gemm_op_1(); + CUTLASS_CHECK(status); + } + + // + // Run the GEMM + // + cudaEvent_t start, stop1, stop2; + cudaEventCreate(&start); + cudaEventCreate(&stop1); + cudaEventCreate(&stop2); + + cudaEventRecord(start); + + for(int i = 0; i < runs; i++) { + status = gemm_op_0(); + + CUTLASS_CHECK(status); + } + cudaEventRecord(stop1); + for(int i = 0; i < runs; i++) { + status = gemm_op_1(); + + CUTLASS_CHECK(status); + } + + cudaEventRecord(stop2); + cudaDeviceSynchronize(); + float gemm0Time, gemm1Time, 
totalTime; + cudaEventElapsedTime(&gemm0Time, start, stop1); + cudaEventElapsedTime(&gemm1Time, stop1, stop2); + cudaEventElapsedTime(&totalTime, start, stop2); + std::cout << "gemm 0 time " << gemm0Time / (float)runs << " ms\n"; + std::cout << "gemm 1 time " << gemm1Time / (float)runs << " ms\n"; + std::cout << "Non-fusion time " << totalTime / (float)runs << " ms\n"; + + tensor_D0.sync_host(); + tensor_D1.sync_host(); + + // + // Verify + // + cutlass::reference::device::Gemm< + typename Gemm0::ElementA, typename Gemm0::LayoutA, + typename Gemm0::ElementB, typename Gemm0::LayoutB, + typename Gemm0::ElementC, typename Gemm0::LayoutC, ElementCompute, + ElementAccumulator, typename Gemm0::Operator> + reference_gemm_0; + + cutlass::reference::device::Gemm< + typename Gemm1::ElementA, typename Gemm1::LayoutA, + typename Gemm1::ElementB, typename Gemm1::LayoutB, + typename Gemm1::ElementC, typename Gemm1::LayoutC, ElementCompute, + ElementAccumulator, typename Gemm1::Operator> + reference_gemm_1; + + reference_gemm_0( + problem_size_0, + alpha0, + tensor_A0.device_ref(), + tensor_B0.device_ref(), + beta0, + {tensor_Bias0.device_data(), typename Gemm0::LayoutC::Stride(0)}, + reference_D0.device_ref() + ); + + if(relu) { + cutlass::reference::device::TensorReLu(reference_D0.device_view()); + } + + reference_gemm_1( + problem_size_1, + alpha1, + reference_D0.device_ref(), + tensor_B1.device_ref(), + beta1, + {tensor_Bias1.device_data(), typename Gemm1::LayoutC::Stride(0)}, + reference_D1.device_ref() + ); + + if(relu) { + cutlass::reference::device::TensorReLu(reference_D1.device_view()); + } + + // Wait for kernels to finish + cudaDeviceSynchronize(); + reference_D0.sync_host(); + reference_D1.sync_host(); + + CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D0.host_view()), 0); + CHECK_GT(cutlass::reference::host::TensorNorm(reference_D0.host_view()), 0); + CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1.host_view()), 0); + CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0); + + bool passed = cutlass::reference::host::TensorEquals( + reference_D1.host_view(), + tensor_D1.host_view()); + + CHECK_TRUE(passed); + if (!passed) { + + std::stringstream fname; + + fname << "error_B2bGemm_device_nonfused.txt"; + std::cerr << "Dumping results in " << fname.str() << "\n"; + + std::ofstream file(fname.str()); + + file + << "A0 =\n" << tensor_A0.host_view() + << "\nB0 =\n" << tensor_B0.host_view() + << "\nC0 =\n" << tensor_C0.host_view() + << "\nBias0:\n" << tensor_Bias0.host_view() << "\n" + << "\nD0 =\n" << tensor_D0.host_view() + << "\nB1 =\n" << tensor_B1.host_view() + << "\nC1 =\n" << tensor_C1.host_view() + << "\nBias1:\n" << tensor_Bias1.host_view() << "\n" + << "\n\nReference =\n" << reference_D1.host_view() + << "\nComputed =\n" << tensor_D1.host_view(); + } + return passed; + } +}; + +template +struct B2bFusedGemmRun +{ + + using B2bGemm = B2bGemm_; + using ElementAccumulator = typename B2bGemm::ElementAccumulator; + using ElementCompute = typename B2bGemm::B2bGemmKernel::Epilogue::OutputOp::ElementCompute; + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + cutlass::Distribution::Kind init_Scale; + cutlass::Distribution::Kind init_Bias; + uint64_t seed; + + // + // Methods + // + + B2bFusedGemmRun( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind 
init_C_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_Scale_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + init_A(init_A_), init_B(init_B_), init_C(init_C_), + init_Scale(init_Scale_), init_Bias(init_Bias_), seed(seed_) { } + + /// Helper to initialize a tensor view + template + bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, 2, -2, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential( + view.data(), view.capacity()); + } + else if (dist_kind == cutlass::Distribution::AllZeros) { + cutlass::reference::host::TensorFill(view, Element(0)); + } + else if (dist_kind == cutlass::Distribution::AllOnes) { + cutlass::reference::host::TensorFill(view, Element(1)); + } + else { + std::cerr << "Not implemented\n"; + return false; + } + + return true; + } + + + + + /// Executes one test + bool run( + cutlass::gemm::GemmCoord problem_size_0, + cutlass::gemm::GemmCoord problem_size_1, + ElementCompute alpha0 = ElementCompute(1), + ElementCompute beta0 = ElementCompute(0), + ElementCompute alpha1 = ElementCompute(1), + ElementCompute beta1 = ElementCompute(0), + cutlass::gemm::GemmUniversalMode mode = cutlass::gemm::GemmUniversalMode::kGemm, + + // batch_count is used as split-k when mode is kGemm according + // to the GemmUniversal interface + + int batch_count = 1, + int64_t batch_stride_A0 = 0, + int64_t batch_stride_B0 = 0, + int64_t batch_stride_C0 = 0, + int64_t batch_stride_B1 = 0, + int64_t batch_stride_C1 = 0, + int64_t batch_stride_D1 = 0, + int64_t batch_stride_Bias0 = 0, + int64_t batch_stride_Scale0 = 0, + bool relu = true, + int warm_ups = 1, + int runs = 100) { + + // + // Allocate the GEMM workspace + // + + cutlass::gemm::GemmCoord CoordA0(problem_size_0.m(), problem_size_0.n(), batch_count * problem_size_0.k()); + cutlass::gemm::GemmCoord CoordB0(problem_size_0.m(), problem_size_0.n(), batch_count * problem_size_0.k()); + cutlass::gemm::GemmCoord CoordC0(problem_size_0.m(), batch_count * problem_size_0.n(), problem_size_0.k()); + cutlass::gemm::GemmCoord CoordB1(problem_size_1.m(), problem_size_1.n(), batch_count * problem_size_1.k()); + cutlass::gemm::GemmCoord CoordC1(problem_size_1.m(), batch_count * problem_size_1.n(), problem_size_1.k()); + + cutlass::HostTensor< + typename B2bGemm::ElementA, + typename B2bGemm::LayoutA> tensor_A0(CoordA0.mk()); + + cutlass::HostTensor< + typename B2bGemm::ElementB, + typename B2bGemm::LayoutB> tensor_B0(CoordB0.kn()); + + cutlass::HostTensor< + typename B2bGemm::ElementC, + typename B2bGemm::LayoutC> tensor_C0(CoordC0.mn()); + + cutlass::HostTensor< + typename B2bGemm::ElementScaleBias, + typename B2bGemm::LayoutScaleBias> tensor_Scale0; + + if(alpha0 == ElementCompute(0)) //per-channel scale + tensor_Scale0.resize({1, batch_count * problem_size_0.n()}); + + cutlass::HostTensor< + typename B2bGemm::ElementScaleBias, + typename B2bGemm::LayoutScaleBias> tensor_Bias0({1, batch_count * problem_size_0.n()}); + + cutlass::HostTensor< + 
ElementAccumulator, + typename B2bGemm::LayoutC> reference_Z0(CoordC0.mn()); + + cutlass::HostTensor< + typename B2bGemm::ElementC, + typename B2bGemm::LayoutC> reference_D0(CoordC0.mn()); + + cutlass::HostTensor< + typename B2bGemm::ElementB, + typename B2bGemm::LayoutB> tensor_B1(CoordB1.kn()); + + cutlass::HostTensor< + typename B2bGemm::ElementC, + typename B2bGemm::LayoutC> tensor_C1(CoordC1.mn()); + + cutlass::HostTensor< + typename B2bGemm::ElementC, + typename B2bGemm::LayoutScaleBias> tensor_Bias1({1, batch_count * problem_size_1.n()}); + + cutlass::HostTensor< + typename B2bGemm::ElementC, + typename B2bGemm::LayoutC> tensor_D1(CoordC1.mn()); + + cutlass::HostTensor< + typename B2bGemm::ElementC, + typename B2bGemm::LayoutC> reference_D1(CoordC1.mn()); + + + CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019)); + CHECK_TRUE(initialize_tensor(tensor_B0.host_view(), init_B, seed + 2018)); + CHECK_TRUE(initialize_tensor(tensor_C0.host_view(), init_C, seed + 2017)); + if(alpha0 == ElementCompute(0)) //per-channel scale + CHECK_TRUE(initialize_tensor(tensor_Scale0.host_view(), init_Scale, seed + 2014)); + CHECK_TRUE(initialize_tensor(tensor_Bias0.host_view(), init_Bias, seed + 2013)); + CHECK_TRUE(initialize_tensor(tensor_B1.host_view(), init_B, seed + 2016)); + CHECK_TRUE(initialize_tensor(tensor_C1.host_view(), init_C, seed + 2015)); + CHECK_TRUE(initialize_tensor(tensor_Bias1.host_view(), init_Bias, seed + 2012)); + + cutlass::reference::host::TensorFill( + tensor_D1.host_view()); + cutlass::reference::host::TensorFill( + reference_D0.host_view()); + cutlass::reference::host::TensorFill( + reference_D1.host_view()); + + tensor_A0.sync_device(); + tensor_B0.sync_device(); + tensor_C0.sync_device(); + if(alpha0 == ElementCompute(0)) //per-channel scale + tensor_Scale0.sync_device(); + tensor_Bias0.sync_device(); + tensor_B1.sync_device(); + tensor_C1.sync_device(); + tensor_Bias1.sync_device(); + tensor_D1.sync_device(); + reference_D0.sync_device(); + reference_D1.sync_device(); + + // + // Initialize the GEMM operator + // + + typename B2bGemm::Arguments arguments{ + mode, + problem_size_0, + problem_size_1, + tensor_A0.device_ref(), + tensor_B0.device_ref(), + tensor_C0.device_ref(), + tensor_Scale0.device_ref(), + tensor_Bias0.device_ref(), + tensor_B1.device_ref(), + {tensor_Bias1.device_data(), typename B2bGemm::LayoutC::Stride(0)}, + tensor_D1.device_ref(), + batch_stride_A0, + batch_stride_B0, + batch_stride_B1, + batch_stride_C1, + batch_stride_D1, + batch_stride_Bias0, + batch_stride_Scale0, + {alpha0, beta0}, + {alpha1, beta1}, + batch_count, + }; + + B2bGemm b2b_gemm_op; + + cutlass::Status status = b2b_gemm_op.can_implement(arguments); + + if(status != cutlass::Status::kSuccess) { + std::cout << "Problem sizes not supported.\n" + << "Requirments:\n" + << " problem_size_0.M = problem_size_1.M\n" + << " problem_size_0.N = problem_size_1.K\n" + << " ThreadblockShape0::kN = problem_size_0.N\n" + << " ThreadblockShape1::kN = problem_size_1.N" << std::endl; + } + + status = b2b_gemm_op.initialize(arguments); + + CUTLASS_CHECK(status); + + for(int i = 0; i < warm_ups; i++) { + status = b2b_gemm_op(); + CUTLASS_CHECK(status); + } + + // + // Run the GEMM + // + + cudaEvent_t start, stop; + cudaEventCreate(&start); + cudaEventCreate(&stop); + + cudaEventRecord(start); + + for(int i = 0; i < runs; i++) { + status = b2b_gemm_op(); + + CUTLASS_CHECK(status); + } + + cudaEventRecord(stop); + cudaDeviceSynchronize(); + float gemmTime; + 
cudaEventElapsedTime(&gemmTime, start, stop); + std::cout << "Fusion time " << gemmTime / (float)runs << " ms\n"; + + tensor_D1.sync_host(); + + // + // Verify + // + + cutlass::reference::device::GemmComplex< + typename B2bGemm::ElementA, typename B2bGemm::LayoutA, + typename B2bGemm::ElementB, typename B2bGemm::LayoutB, + ElementAccumulator, typename B2bGemm::LayoutC, + ElementAccumulator, ElementAccumulator + >( + + problem_size_0, + ElementAccumulator(1), //intermediate alpha=1 + tensor_A0.device_ref(), + cutlass::ComplexTransform::kNone, + tensor_B0.device_ref(), + cutlass::ComplexTransform::kNone, + ElementAccumulator(0), //beta = 0 + reference_Z0.device_ref(), + reference_Z0.device_ref(), + ElementAccumulator(0), + int(batch_count), + batch_stride_A0, + batch_stride_B0, + batch_stride_C0, + batch_stride_C0 + ); + + cutlass::reference::device::TensorScaleBiasGemmBatched< + ElementAccumulator, typename B2bGemm::ElementC, typename B2bGemm::LayoutC, + ElementCompute, typename B2bGemm::LayoutScaleBias + > ( + problem_size_0, + reference_Z0.device_ref(), + reference_D0.device_ref(), + alpha0, + tensor_Scale0.device_ref(), + tensor_Bias0.device_ref(), + int(batch_count), + batch_stride_C0, + batch_stride_C0, + batch_stride_Scale0, + batch_stride_Bias0 + ); + + if(relu) { + cutlass::reference::device::TensorReLu(reference_D0.device_view()); + } + + cutlass::reference::device::GemmComplex< + typename B2bGemm::ElementA, typename B2bGemm::LayoutA, + typename B2bGemm::ElementB, typename B2bGemm::LayoutB, + typename B2bGemm::ElementC, typename B2bGemm::LayoutC, + ElementCompute, ElementAccumulator + >( + problem_size_1, + alpha1, //intermediate alpha=1 + reference_D0.device_ref(), + cutlass::ComplexTransform::kNone, + tensor_B1.device_ref(), + cutlass::ComplexTransform::kNone, + beta1, //beta = 0 + {tensor_Bias1.device_data(), typename B2bGemm::LayoutC::Stride(0)}, + reference_D1.device_ref(), + ElementAccumulator(0), + int(batch_count), + batch_stride_C0, + batch_stride_B1, + batch_stride_C1, + batch_stride_D1 + ); + + if(relu) { + cutlass::reference::device::TensorReLu(reference_D1.device_view()); + } + + cudaDeviceSynchronize(); + reference_D0.sync_host(); + reference_D1.sync_host(); + + CHECK_GT(cutlass::reference::host::TensorNorm(reference_D0.host_view()), 0); + CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1.host_view()), 0); + CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0); + + bool passed = cutlass::reference::host::TensorEquals( + reference_D1.host_view(), + tensor_D1.host_view()); + + CHECK_TRUE(passed); + if (!passed) + { + + std::stringstream fname; + + fname << "error_B2bGemm_device_fused.txt"; + std::cerr << "Dumping results in " << fname.str() << "\n"; + + std::ofstream file(fname.str()); + + file + << "A0 =\n" << tensor_A0.host_view() + << "\nB0 =\n" << tensor_B0.host_view() + << "\nC0 =\n" << tensor_C0.host_view() + << "\nScale0:\n" << tensor_Scale0.host_view() << "\n" + << "\nBias0:\n" << tensor_Bias0.host_view() << "\n" + << "\nB1 =\n" << tensor_B1.host_view() + << "\nC1 =\n" << tensor_C1.host_view() + << "\nBias1:\n" << tensor_Bias1.host_view() << "\n" + << "\n\nReference =\n" << reference_D1.host_view() + << "\nComputed =\n" << tensor_D1.host_view(); + } + return passed; + } + +}; + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/b2b_grouped_gemm_run.h 
b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/b2b_grouped_gemm_run.h new file mode 100644 index 0000000000000000000000000000000000000000..b6267a153b6ed399120585f885aaa33c24af9e24 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/b2b_grouped_gemm_run.h @@ -0,0 +1,450 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Containers for running grouped back-to-back GEMMs +*/ + +#pragma once + +#include +#include +#include + +#include "cutlass/util/device_memory.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/device/gemm.h" +#include "cutlass/util/reference/device/tensor_relu.h" + +#include "reference/device/tensor_scale_bias.h" +#include "helper.h" + +#define CHECK_GT(val1, val2) \ + if((val1) <= (val2)) \ + std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_GT failed\n"; +#define CHECK_TRUE(val) \ + if(!(val)) \ + std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_TRUE failed\n"; + +//////////////////////////////////////////////////////////////////////////////// + +template +struct B2bFusedGroupedGemmRun +{ + + using B2bGemm = B2bGemm_; + using ElementAccumulator = typename B2bGemm::ElementAccumulator; + using ElementCompute = typename B2bGemm::BaseKernel::Epilogue::OutputOp::ElementCompute; + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + cutlass::Distribution::Kind init_Scale; + cutlass::Distribution::Kind init_Bias; + uint64_t seed; + + // + // Methods + // + + B2bFusedGroupedGemmRun( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_Scale_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + init_A(init_A_), init_B(init_B_), init_C(init_C_), + init_Scale(init_Scale_), init_Bias(init_Bias_), seed(seed_) { } + + /// Helper to initialize a tensor view + template + bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, 1, -1, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential( + view.data(), view.capacity()); + } + else if (dist_kind == cutlass::Distribution::AllZeros) { + cutlass::reference::host::TensorFill(view, Element(0)); + } + else if (dist_kind == cutlass::Distribution::AllOnes) { + cutlass::reference::host::TensorFill(view, Element(1)); + } + else { + std::cerr << "Not implemented\n"; + return false; + } + + return true; + } + + /// Executes one test + bool run( + std::vector problem_sizes_0, + std::vector problem_sizes_1, + ElementCompute alpha0 = ElementCompute(1), + ElementCompute beta0 = ElementCompute(0), + ElementCompute alpha1 = ElementCompute(1), + ElementCompute beta1 = ElementCompute(0), + bool relu = true, + int warm_ups = 1, + int runs = 100) { + + using HostTensorA = cutlass::HostTensor; + using HostTensorB = cutlass::HostTensor; + using HostTensorC = cutlass::HostTensor; + 
using HostTensorScale = cutlass::HostTensor; + using HostTensorZ = cutlass::HostTensor; + using HostTensorBias = cutlass::HostTensor; + + int problem_count = (int)problem_sizes_0.size(); + + std::vector host_tensor_A0(problem_count); + std::vector host_tensor_B0(problem_count); + std::vector host_tensor_C0(problem_count); + std::vector host_tensor_Scale0(problem_count); + std::vector host_tensor_Bias0(problem_count); + std::vector host_tensor_B1(problem_count); + std::vector host_tensor_C1(problem_count); + std::vector host_tensor_Bias1(problem_count); + std::vector host_tensor_D1(problem_count); + std::vector host_tensor_Z(problem_count); + std::vector host_tensor_ref_D0(problem_count); + std::vector host_tensor_ref_D1(problem_count); + + std::vector ref_A0(problem_count); + std::vector ref_B0(problem_count); + std::vector ref_C0(problem_count); + std::vector ref_Scale0(problem_count); + std::vector ref_Bias0(problem_count); + std::vector ref_B1(problem_count); + std::vector ref_C1(problem_count); + std::vector ref_Bias1(problem_count); + std::vector ref_D1(problem_count); + std::vector ref_Z(problem_count); + std::vector ref_ref_D0(problem_count); + std::vector ref_ref_D1(problem_count); + + for (int i = 0; i < problem_count; ++i) { + // + // Allocate the GEMM workspace + // + + auto problem_size_0 = problem_sizes_0[i]; + auto problem_size_1 = problem_sizes_1[i]; + + host_tensor_A0.at(i) = HostTensorA(problem_size_0.mk()); + host_tensor_B0.at(i) = HostTensorB(problem_size_0.kn()); + host_tensor_C0.at(i) = HostTensorC(problem_size_0.mn()); + if (alpha0 == ElementCompute(0)) //per-channel scale + host_tensor_Scale0.at(i) = HostTensorScale(typename HostTensorZ::Layout::TensorCoord{1, problem_size_0.n()}); + host_tensor_Bias0.at(i) = HostTensorScale(typename HostTensorBias::Layout::TensorCoord{1, problem_size_0.n()}); + host_tensor_Z.at(i) = HostTensorZ(problem_size_0.mn()); + host_tensor_ref_D0.at(i) = HostTensorC(problem_size_0.mn()); + host_tensor_B1.at(i) = HostTensorB(problem_size_1.kn()); + host_tensor_C1.at(i) = HostTensorC(problem_size_1.mn()); + host_tensor_Bias1.at(i) = HostTensorScale(typename HostTensorBias::Layout::TensorCoord{1, problem_size_1.n()}); + host_tensor_D1.at(i) = HostTensorC(problem_size_1.mn()); + host_tensor_ref_D1.at(i) = HostTensorC(problem_size_1.mn()); + + CHECK_TRUE(initialize_tensor(host_tensor_A0.at(i).host_view(), init_A, seed + 2019)); + CHECK_TRUE(initialize_tensor(host_tensor_B0.at(i).host_view(), init_B, seed + 2018)); + CHECK_TRUE(initialize_tensor(host_tensor_C0.at(i).host_view(), init_C, seed + 2017)); + if (alpha0 == ElementCompute(0)) //per-channel scale + CHECK_TRUE(initialize_tensor(host_tensor_Scale0.at(i).host_view(), init_Scale, seed + 2014)); + CHECK_TRUE(initialize_tensor(host_tensor_Bias0.at(i).host_view(), init_Bias, seed + 2013)); + CHECK_TRUE(initialize_tensor(host_tensor_B1.at(i).host_view(), init_B, seed + 2016)); + CHECK_TRUE(initialize_tensor(host_tensor_C1.at(i).host_view(), init_C, seed + 2015)); + CHECK_TRUE(initialize_tensor(host_tensor_Bias1.at(i).host_view(), init_Bias, seed + 2012)); + + cutlass::reference::host::TensorFill( + host_tensor_D1.at(i).host_view()); + cutlass::reference::host::TensorFill( + host_tensor_ref_D0.at(i).host_view()); + cutlass::reference::host::TensorFill( + host_tensor_ref_D1.at(i).host_view()); + + host_tensor_A0.at(i).sync_device(); + host_tensor_B0.at(i).sync_device(); + host_tensor_C0.at(i).sync_device(); + if (alpha0 == ElementCompute(0)) //per-channel scale + 
host_tensor_Scale0.at(i).sync_device(); + host_tensor_Bias0.at(i).sync_device(); + host_tensor_B1.at(i).sync_device(); + host_tensor_C1.at(i).sync_device(); + host_tensor_Bias1.at(i).sync_device(); + host_tensor_D1.at(i).sync_device(); + host_tensor_ref_D0.at(i).sync_device(); + host_tensor_ref_D1.at(i).sync_device(); + + ref_A0.at(i) = (host_tensor_A0.at(i).device_ref()); + ref_B0.at(i) = (host_tensor_B0.at(i).device_ref()); + ref_C0.at(i) = (host_tensor_C0.at(i).device_ref()); + if (alpha0 == ElementCompute(0)) //per-channel scale + ref_Scale0.at(i) = (host_tensor_Scale0.at(i).device_ref()); + ref_Bias0.at(i) = (host_tensor_Bias0.at(i).device_ref()); + ref_B1.at(i) = (host_tensor_B1.at(i).device_ref()); + ref_C1.at(i) = {host_tensor_Bias1.at(i).device_data(), typename B2bGemm::LayoutC::Stride(0)}; + ref_Bias1.at(i) = (host_tensor_Bias1.at(i).device_ref()); + ref_D1.at(i) = (host_tensor_D1.at(i).device_ref()); + ref_Z.at(i) = (host_tensor_Z.at(i).device_ref()); + ref_ref_D0.at(i) = (host_tensor_ref_D0.at(i).device_ref()); + ref_ref_D1.at(i) = (host_tensor_ref_D1.at(i).device_ref()); + } + + // + // Initialize the GEMM operator + // + + cutlass::DeviceAllocation device_ref_A0(problem_count); + device_ref_A0.copy_from_host(ref_A0.data()); + cutlass::DeviceAllocation device_ref_B0(problem_count); + device_ref_B0.copy_from_host(ref_B0.data()); + cutlass::DeviceAllocation device_ref_C0(problem_count); + device_ref_C0.copy_from_host(ref_C0.data()); + cutlass::DeviceAllocation device_ref_Scale0(problem_count); + device_ref_Scale0.copy_from_host(ref_Scale0.data()); + cutlass::DeviceAllocation device_ref_Bias0(problem_count); + device_ref_Bias0.copy_from_host(ref_Bias0.data()); + cutlass::DeviceAllocation device_ref_B1(problem_count); + device_ref_B1.copy_from_host(ref_B1.data()); + cutlass::DeviceAllocation device_ref_C1(problem_count); + device_ref_C1.copy_from_host(ref_C1.data()); + cutlass::DeviceAllocation device_ref_Bias1(problem_count); + device_ref_Bias1.copy_from_host(ref_Bias1.data()); + cutlass::DeviceAllocation device_ref_D1(problem_count); + device_ref_D1.copy_from_host(ref_D1.data()); + + cutlass::DeviceAllocation device_problem_sizes_0(problem_count); + device_problem_sizes_0.copy_from_host(problem_sizes_0.data()); + cutlass::DeviceAllocation device_problem_sizes_1(problem_count); + device_problem_sizes_1.copy_from_host(problem_sizes_1.data()); + + B2bGemm b2b_gemm_op; + + int threadblock_count = B2bGemm::sufficient(problem_sizes_1.data(), problem_count); + if (!threadblock_count) { + std::cout << "Active CUDA device lacks hardware resources to run CUTLASS Grouped GEMM kernel." 
<< std::endl; + return false; + } + + typename B2bGemm::Arguments arguments{ + problem_count, + device_problem_sizes_0.get(), + device_problem_sizes_1.get(), + device_ref_A0.get(), + device_ref_B0.get(), + device_ref_C0.get(), + device_ref_Scale0.get(), + device_ref_Bias0.get(), + device_ref_B1.get(), + device_ref_C1.get(), + device_ref_D1.get(), + {alpha0, beta0}, + {alpha1, beta1}, + threadblock_count + }; + + cutlass::Status status = b2b_gemm_op.can_implement(arguments); + + if(status != cutlass::Status::kSuccess) { + std::cout << "Problem sizes not supported.\n" + << "Requirments:\n" + << " problem_size_0.M = problem_size_1.M\n" + << " problem_size_0.N = problem_size_1.K\n" + << " ThreadblockShape0::kN = problem_size_0.N\n" + << " ThreadblockShape1::kN = problem_size_1.N" << std::endl; + } + + status = b2b_gemm_op.initialize(arguments); + + CUTLASS_CHECK(status); + + for(int i = 0; i < warm_ups; i++) { + status = b2b_gemm_op(); + CUTLASS_CHECK(status); + } + + // + // Run the GEMM + // + + cudaEvent_t start, stop; + cudaEventCreate(&start); + cudaEventCreate(&stop); + + cudaEventRecord(start); + + for(int i = 0; i < runs; i++) { + status = b2b_gemm_op(); + CUTLASS_CHECK(status); + } + + cudaEventRecord(stop); + cudaDeviceSynchronize(); + float gemmTime; + cudaEventElapsedTime(&gemmTime, start, stop); + std::cout << "Fusion time " << gemmTime / (float)runs << " ms\n"; + + for (int i = 0; i < problem_count; ++i) { + host_tensor_D1.at(i).sync_host(); + + // + // Verify + // + + cutlass::reference::device::Gemm< + typename B2bGemm::ElementA, typename B2bGemm::LayoutA, + typename B2bGemm::ElementB, typename B2bGemm::LayoutB, + ElementAccumulator, typename B2bGemm::LayoutC, + ElementAccumulator, ElementAccumulator> + reference_gemm_0; + + cutlass::reference::device::Gemm< + typename B2bGemm::ElementA, typename B2bGemm::LayoutA, + typename B2bGemm::ElementB, typename B2bGemm::LayoutB, + typename B2bGemm::ElementC, typename B2bGemm::LayoutC, ElementCompute, + ElementAccumulator> + reference_gemm_1; + + auto problem_size_0 = problem_sizes_0[i]; + auto problem_size_1 = problem_sizes_1[i]; + + reference_gemm_0( + problem_size_0, + ElementAccumulator(1), //intermediate alpha=1 + ref_A0.at(i), + ref_B0.at(i), + ElementAccumulator(0), //beta = 0 + ref_Z.at(i), + ref_Z.at(i), + ElementAccumulator(0) + ); + + cutlass::reference::device::TensorScaleBiasGemm< + ElementAccumulator, typename B2bGemm::ElementC, typename B2bGemm::LayoutC, + ElementCompute, typename B2bGemm::LayoutC + > ( + problem_size_0, + ref_Z.at(i), + ref_ref_D0.at(i), + alpha0, + ref_Scale0.at(i), + ref_Bias0.at(i) + ); + + if(relu) { + cutlass::reference::device::TensorReLu(host_tensor_ref_D0.at(i).device_view()); + } + + reference_gemm_1( + problem_size_1, + alpha1, + ref_ref_D0.at(i), + ref_B1.at(i), + beta1, + {host_tensor_Bias1.at(i).device_data(), typename B2bGemm::LayoutC::Stride(0)}, + ref_ref_D1.at(i) + ); + if(relu) { + cutlass::reference::device::TensorReLu(host_tensor_ref_D1.at(i).device_view()); + } + cudaDeviceSynchronize(); + host_tensor_ref_D0.at(i).sync_host(); + host_tensor_ref_D1.at(i).sync_host(); + + CHECK_GT(cutlass::reference::host::TensorNorm(host_tensor_ref_D0.at(i).host_view()), 0); + CHECK_GT(cutlass::reference::host::TensorNorm(host_tensor_D1.at(i).host_view()), 0); + CHECK_GT(cutlass::reference::host::TensorNorm(host_tensor_ref_D1.at(i).host_view()), 0); + + bool passed = cutlass::reference::host::TensorEquals( + host_tensor_ref_D1.at(i).host_view(), + host_tensor_D1.at(i).host_view()); + + 
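+      // Exact elementwise comparison against the reference pipeline; a mismatch for any GEMM in the group dumps that problem's operands and fails the whole run.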
CHECK_TRUE(passed); + if (!passed) + { + + std::stringstream fname; + + fname << "error_B2bGemm_device_fused.txt"; + std::cerr << "Check failed for GEMM " << i << " in the group." << std::endl; + std::cerr << "Dumping results in " << fname.str() << "\n"; + + std::ofstream file(fname.str()); + + file + << "GEMM " << i << " in group\n" + << "A0 =\n" << host_tensor_A0.at(i).host_view() + << "\nB0 =\n" << host_tensor_B0.at(i).host_view() + << "\nC0 =\n" << host_tensor_C0.at(i).host_view() + << "\nScale0:\n" << host_tensor_Scale0.at(i).host_view() << "\n" + << "\nBias0:\n" << host_tensor_Bias0.at(i).host_view() << "\n" + << "\nB1 =\n" << host_tensor_B1.at(i).host_view() + << "\nC1 =\n" << host_tensor_C1.at(i).host_view() + << "\nBias1:\n" << host_tensor_Bias1.at(i).host_view() << "\n" + << "\n\nReference =\n" << host_tensor_ref_D1.at(i).host_view() + << "\nComputed =\n" << host_tensor_D1.at(i).host_view(); + + return false; + } + } + return true; + } + +}; + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/b2b_interleaved_conv2d_run.h b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/b2b_interleaved_conv2d_run.h new file mode 100644 index 0000000000000000000000000000000000000000..4693e86423d9cdc142ee8bc84348e706c1c83280 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/b2b_interleaved_conv2d_run.h @@ -0,0 +1,749 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once + +#include +#include +#include + +#include "cutlass/cutlass.h" + +#include "cutlass/conv/device/implicit_gemm_convolution.h" +#include "cutlass/reduction/device/reduce_split_k.h" +#include "cutlass/reduction/thread/reduction_operators.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/host_reorder.h" + +#include "cutlass/util/reference/host/convolution.h" +#include "cutlass/util/reference/device/convolution.h" +#include "cutlass/util/reference/device/tensor_relu.h" + +#include "cutlass/core_io.h" +#include "cutlass/util/tensor_view_io.h" + +#include "reference/device/tensor_scale_bias.h" +#include "helper.h" + +#define CHECK_GT(val1, val2) \ + if((val1) <= (val2)) \ + std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_GT failed\n"; +#define CHECK_TRUE(val) \ + if(!(val)) \ + std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_TRUE failed\n"; + + +template +class B2bInterleavedNonFusedConv2dRun { +public: + + using Conv2d0 = Conv2d0_; + using Conv2d1 = Conv2d1_; + using ElementAccumulator = typename Conv2d0::ElementAccumulator; + using ElementCompute = typename Conv2d0::ElementCompute; + + static cutlass::conv::Operator const kConvolutionalOperator = Conv2d0::kConvolutionalOperator; + static_assert(kConvolutionalOperator == Conv2d1::kConvolutionalOperator, + "Fused convolution operators must be the same"); + +public: + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + cutlass::Distribution::Kind init_Bias; + uint64_t seed; + + cutlass::HostTensor tensor_A0; + cutlass::HostTensor tensor_B0; + cutlass::HostTensor tensor_B0_reordered; + cutlass::HostTensor tensor_C0; + cutlass::HostTensor tensor_Bias0; + cutlass::HostTensor tensor_D0_computed; + cutlass::HostTensor tensor_D0_reference; + + cutlass::HostTensor tensor_B1; + cutlass::HostTensor tensor_B1_reordered; + cutlass::HostTensor tensor_C1; + cutlass::HostTensor tensor_Bias1; + cutlass::HostTensor tensor_D1_computed; + cutlass::HostTensor tensor_D1_reference; + + +public: + + B2bInterleavedNonFusedConv2dRun( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + init_A(init_A_), init_B(init_B_), init_C(init_C_), init_Bias(init_Bias_), seed(seed_) { + + } + + /// Helper to initialize a tensor view + template + void initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + int scope; + int bits = cutlass::sizeof_bits::value; + + if (bits <= 16) { + scope = 2; + } + else { + scope = 8; + } + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope, -scope, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + 
} + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential(view.data(), view.capacity()); + } + else if (dist_kind == cutlass::Distribution::AllZeros) { + cutlass::reference::host::TensorFill(view, Element(0)); + } + else if (dist_kind == cutlass::Distribution::AllOnes) { + cutlass::reference::host::TensorFill(view, Element(1)); + } + else { + } + } + + void initialize( + cutlass::conv::Conv2dProblemSize const &problem_size_0, + cutlass::conv::Conv2dProblemSize const &problem_size_1, uint64_t seed = 2019) { + + tensor_A0.resize(implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size_0)); + tensor_B0.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_0)); + tensor_B0_reordered.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_0)); + tensor_C0.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0)); + tensor_Bias0.resize({1, 1, 1, problem_size_0.K}); + tensor_D0_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0)); + tensor_D0_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0)); + tensor_B1.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_1)); + tensor_B1_reordered.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_1)); + tensor_C1.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_1)); + tensor_Bias1.resize({1, 1, 1, problem_size_1.K}); + tensor_D1_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_1)); + tensor_D1_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_1)); + + initialize_tensor(tensor_A0.host_view(), init_A, seed); + initialize_tensor(tensor_B0.host_view(), init_B, seed * 17); + initialize_tensor(tensor_C0.host_view(), init_C, seed * 39); + initialize_tensor(tensor_Bias0.host_view(), init_Bias, seed * 83); + initialize_tensor(tensor_B1.host_view(), init_B, seed * 18); + initialize_tensor(tensor_C1.host_view(), init_C, seed * 40); + + //Reorder B0 and B1 + cutlass::reorder_convK( + tensor_B0_reordered.host_ref(), tensor_B0.host_ref(), implicit_gemm_problem_size(kConvolutionalOperator, problem_size_0)); + cutlass::reorder_convK( + tensor_B1_reordered.host_ref(), tensor_B1.host_ref(), implicit_gemm_problem_size(kConvolutionalOperator, problem_size_1)); + + tensor_A0.sync_device(); + tensor_B0.sync_device(); + tensor_B0_reordered.sync_device(); + tensor_C0.sync_device(); + tensor_Bias0.sync_device(); + tensor_D0_computed.sync_device(); + tensor_D0_reference.sync_device(); + tensor_B1.sync_device(); + tensor_B1_reordered.sync_device(); + tensor_C1.sync_device(); + tensor_Bias1.sync_device(); + tensor_D1_computed.sync_device(); + tensor_D1_reference.sync_device(); + } + + /// Executes one test + bool run( + cutlass::conv::Conv2dProblemSize const &problem_size_0, + cutlass::conv::Conv2dProblemSize const &problem_size_1, + cutlass::conv::SplitKMode const &split_k_mode = cutlass::conv::SplitKMode::kSerial, + ElementCompute alpha0 = ElementCompute(1), + ElementCompute beta0 = ElementCompute(0), + ElementCompute alpha1 = ElementCompute(1), + ElementCompute beta1 = ElementCompute(0), + bool relu = true, + int warm_ups = 1, + int runs = 100) { + + initialize(problem_size_0, problem_size_1); + + // configure the operator + Conv2d0 conv2d_op_0; + Conv2d1 conv2d_op_1; + + typename Conv2d0::Arguments conv2d_args_0( + problem_size_0, + tensor_A0.device_ref(), + 
tensor_B0_reordered.device_ref(), + tensor_C0.device_ref(), + tensor_D0_computed.device_ref(), + {alpha0, beta0}, + split_k_mode + ); + typename Conv2d1::Arguments conv2d_args_1( + problem_size_1, + tensor_D0_computed.device_ref(), + tensor_B1_reordered.device_ref(), + tensor_C1.device_ref(), + tensor_D1_computed.device_ref(), + {alpha1, beta1}, + split_k_mode + ); + + + cutlass::Status status = conv2d_op_0.initialize(conv2d_args_0); + + CUTLASS_CHECK(status); + + status = conv2d_op_1.initialize(conv2d_args_1); + + CUTLASS_CHECK(status); + + for(int i = 0; i < warm_ups; i++) { + status = conv2d_op_0(); + CUTLASS_CHECK(status); + status = conv2d_op_1(); + CUTLASS_CHECK(status); + } + + // + // Run Conv2d + // + cudaEvent_t start, stop1, stop2; + cudaEventCreate(&start); + cudaEventCreate(&stop1); + cudaEventCreate(&stop2); + + cudaEventRecord(start); + + + for(int i = 0; i < runs; i++) { + // run conv2d operator + status = conv2d_op_0(); + CUTLASS_CHECK(status); + } + cudaEventRecord(stop1); + + for(int i = 0; i < runs; i++) { + // run conv2d operator + status = conv2d_op_1(); + CUTLASS_CHECK(status); + } + cudaEventRecord(stop2); + cudaDeviceSynchronize(); + float conv2d0Time, conv2d1Time, totalTime; + cudaEventElapsedTime(&conv2d0Time, start, stop1); + cudaEventElapsedTime(&conv2d1Time, stop1, stop2); + cudaEventElapsedTime(&totalTime, start, stop2); + std::cout << "conv2d 0 time " << conv2d0Time / (float)runs << " ms\n"; + std::cout << "conv2d 1 time " << conv2d1Time / (float)runs << " ms\n"; + std::cout << "Non-fusion time " << totalTime / (float)runs << " ms\n"; + + tensor_D0_computed.sync_host(); + tensor_D1_computed.sync_host(); + + bool passed = false; + + cutlass::reference::device::Conv2d< + typename Conv2d0::ElementA, + typename Conv2d0::LayoutA, + typename Conv2d0::ElementB, + typename Conv2d0::LayoutB, + typename Conv2d0::ElementC, + typename Conv2d0::LayoutC, + ElementCompute, + ElementAccumulator, + cutlass::NumericConverterClamp + >( + kConvolutionalOperator, + problem_size_0, + tensor_A0.device_ref(), + tensor_B0.device_ref(), + tensor_C0.device_ref(), + tensor_D0_reference.device_ref(), + alpha0, + beta0); + + if(relu) { + cutlass::reference::device::TensorReLu(tensor_D0_reference.device_view()); + } + + cutlass::reference::device::Conv2d< + typename Conv2d1::ElementA, + typename Conv2d1::LayoutA, + typename Conv2d1::ElementB, + typename Conv2d1::LayoutB, + typename Conv2d1::ElementC, + typename Conv2d1::LayoutC, + ElementCompute, + ElementAccumulator, + cutlass::NumericConverterClamp + >( + kConvolutionalOperator, + problem_size_1, + tensor_D0_reference.device_ref(), + tensor_B1.device_ref(), + tensor_C1.device_ref(), + tensor_D1_reference.device_ref(), + alpha1, + beta1); + + if(relu) { + cutlass::reference::device::TensorReLu(tensor_D1_reference.device_view()); + } + + cudaError_t result = cudaDeviceSynchronize(); + CHECK_TRUE(result == cudaSuccess); + + // sync host (copy device data to host) for dumping error output in case of mismatches + tensor_D0_reference.sync_host(); + tensor_D1_reference.sync_host(); + + CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D0_computed.host_view()), 0); + CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D0_reference.host_view()), 0); + CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1_computed.host_view()), 0); + CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1_reference.host_view()), 0); + + passed = cutlass::reference::host::TensorEquals( + tensor_D1_computed.host_view(), + tensor_D1_reference.host_view()); + + 
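+    // Exact elementwise comparison of the second conv2d's output against the device reference result.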
CHECK_TRUE(passed); + + if (!passed) { + std::stringstream fname; + + fname << "error_B2bImplicitGemm_device_interleaved_nonfused.txt"; + std::cerr << "Dumping results in " << fname.str() << "\n"; + + std::ofstream results(fname.str()); + + results << problem_size_0 << std::endl; + results << problem_size_1 << std::endl; + + results + << "\nA0:\n" << tensor_A0.host_view() << "\n" + << "\nB0:\n" << tensor_B0.host_view() << "\n" + << "\nB0_reordered:\n" << tensor_B0_reordered.host_view() << "\n" + << "\nC0:\n" << tensor_C0.host_view() << "\n" + << "\nBias0:\n" << tensor_Bias0.host_view() << "\n" + << "\nD0 reference:\n" << tensor_D0_reference.host_view() << "\n" + << "\nD0 computed:\n" << tensor_D0_computed.host_view() << "\n" + << "\nB1:\n" << tensor_B1.host_view() << "\n" + << "\nB1_reordered:\n" << tensor_B1_reordered.host_view() << "\n" + << "\nC1:\n" << tensor_C1.host_view() << "\n" + << "\nBias1:\n" << tensor_Bias1.host_view() << "\n" + << "\nD1 reference:\n" << tensor_D1_reference.host_view() << "\n" + << "\nD1 computed:\n" << tensor_D1_computed.host_view(); + + + } + + return passed; + } + +}; + +template +class B2bInterleavedFusedConv2dRun { +public: + + using B2bConv2d = B2bConv2d_; + using ElementAccumulator = typename B2bConv2d::ElementAccumulator; + using ElementCompute = typename B2bConv2d::ElementCompute; + + static cutlass::conv::Operator const kConvolutionalOperator = B2bConv2d::kConvolutionalOperator; + +public: + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + cutlass::Distribution::Kind init_Scale; + cutlass::Distribution::Kind init_Bias; + uint64_t seed; + + cutlass::HostTensor tensor_A0; + cutlass::HostTensor tensor_B0; + cutlass::HostTensor tensor_B0_reordered; + cutlass::HostTensor tensor_C0; + cutlass::HostTensor tensor_Scale0; + cutlass::HostTensor tensor_Bias0; + cutlass::HostTensor tensor_Z0_reference; + cutlass::HostTensor tensor_D0_reference; + + cutlass::HostTensor tensor_B1; + cutlass::HostTensor tensor_B1_reordered; + cutlass::HostTensor tensor_C1; + cutlass::HostTensor tensor_Bias1; + cutlass::HostTensor tensor_D1_computed; + cutlass::HostTensor tensor_D1_reference; + + +public: + + B2bInterleavedFusedConv2dRun( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_Scale_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + init_A(init_A_), init_B(init_B_), init_C(init_C_), + init_Scale(init_Scale_), init_Bias(init_Bias_), seed(seed_) { + + } + + /// Helper to initialize a tensor view + template + void initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + int scope; + int bits = cutlass::sizeof_bits::value; + + if (bits <= 16) { + scope = 2; + } + else { + scope = 8; + } + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope, -scope, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + 
cutlass::reference::host::BlockFillSequential(view.data(), view.capacity()); + } + else if (dist_kind == cutlass::Distribution::AllZeros) { + cutlass::reference::host::TensorFill(view, Element(0)); + } + else if (dist_kind == cutlass::Distribution::AllOnes) { + cutlass::reference::host::TensorFill(view, Element(1)); + } + else { + } + } + + void initialize( + cutlass::conv::Conv2dProblemSize const &problem_size_0, + cutlass::conv::Conv2dProblemSize const &problem_size_1, + ElementCompute alpha0, + ElementCompute alpha1, + uint64_t seed = 2019) { + + tensor_A0.resize(implicit_gemm_tensor_a_extent(kConvolutionalOperator, problem_size_0)); + tensor_B0.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_0)); + tensor_B0_reordered.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_0)); + tensor_C0.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0)); + if(alpha0 == ElementCompute(0)) //per-channel scale + tensor_Scale0.resize({1, problem_size_0.K}); + tensor_Bias0.resize({1, problem_size_0.K}); + tensor_Z0_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0)); + tensor_D0_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_0)); + tensor_B1.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_1)); + tensor_B1_reordered.resize(implicit_gemm_tensor_b_extent(kConvolutionalOperator, problem_size_1)); + tensor_C1.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_1)); + tensor_Bias1.resize({1, 1, 1, problem_size_1.K}); + tensor_D1_computed.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_1)); + tensor_D1_reference.resize(implicit_gemm_tensor_c_extent(kConvolutionalOperator, problem_size_1)); + + initialize_tensor(tensor_A0.host_view(), init_A, seed); + initialize_tensor(tensor_B0.host_view(), init_B, seed * 17); + initialize_tensor(tensor_C0.host_view(), init_C, seed * 39); + if(alpha0 == ElementCompute(0)) //per-channel scale + initialize_tensor(tensor_Scale0.host_view(), init_Scale, seed * 61); + initialize_tensor(tensor_Bias0.host_view(), init_Bias, seed * 83); + initialize_tensor(tensor_B1.host_view(), init_B, seed * 18); + initialize_tensor(tensor_C1.host_view(), init_C, seed * 40); + initialize_tensor(tensor_Bias1.host_view(), init_Bias, seed * 84); + + //Reorder B0 and B1 + cutlass::reorder_convK<16, InterleavedK>( + tensor_B0_reordered.host_ref(), tensor_B0.host_ref(), implicit_gemm_problem_size(kConvolutionalOperator, problem_size_0)); + cutlass::reorder_convK( + tensor_B1_reordered.host_ref(), tensor_B1.host_ref(), implicit_gemm_problem_size(kConvolutionalOperator, problem_size_1)); + + tensor_A0.sync_device(); + tensor_B0.sync_device(); + tensor_B0_reordered.sync_device(); + tensor_C0.sync_device(); + if(alpha0 == ElementCompute(0)) //per-channel scale + tensor_Scale0.sync_device(); + tensor_Bias0.sync_device(); + tensor_D0_reference.sync_device(); + tensor_B1.sync_device(); + tensor_B1_reordered.sync_device(); + tensor_C1.sync_device(); + tensor_Bias1.sync_device(); + tensor_D1_computed.sync_device(); + tensor_D1_reference.sync_device(); + } + + /// Executes one test + bool run( + cutlass::conv::Conv2dProblemSize const &problem_size_0, + cutlass::conv::Conv2dProblemSize const &problem_size_1, + cutlass::conv::SplitKMode const &split_k_mode = cutlass::conv::SplitKMode::kSerial, + ElementCompute alpha0 = ElementCompute(1), + ElementCompute beta0 = ElementCompute(0), + ElementCompute alpha1 
= ElementCompute(1), + ElementCompute beta1 = ElementCompute(0), + bool relu = true, + int warm_ups = 1, + int runs = 100) { + + initialize(problem_size_0, problem_size_1, alpha0, alpha1); + + // configure the operator + B2bConv2d b2b_conv2d_op; + + typename B2bConv2d::Arguments b2b_conv2d_args( + problem_size_0, + problem_size_1, + tensor_A0.device_ref(), + tensor_B0_reordered.device_ref(), + tensor_C0.device_ref(), + tensor_Scale0.device_ref(), + tensor_Bias0.device_ref(), + tensor_B1_reordered.device_ref(), + tensor_C1.device_ref(), + tensor_D1_computed.device_ref(), + {alpha0, beta0}, + {alpha1, beta1}, + split_k_mode + ); + + cutlass::Status status = b2b_conv2d_op.can_implement(b2b_conv2d_args); + + if(status != cutlass::Status::kSuccess) { + std::cout << "Problem sizes not supported.\n" + << "Requirments:\n" + << " problem_size_0.N*P*Q = problem_size_1.N*P*Q\n" + << " problem_size_0.K = problem_size_1.C\n" + << " problem_size_1.R = problem_size_1.S = 1\n" + << " ThreadblockShape0::kN = problem_size_0.K\n" + << " ThreadblockShape1::kN = problem_size_1.K" << std::endl; + } + + CUTLASS_CHECK(status); + + status = b2b_conv2d_op.initialize(b2b_conv2d_args); + + CUTLASS_CHECK(status); + + for(int i = 0; i < warm_ups; i++) { + status = b2b_conv2d_op(); + CUTLASS_CHECK(status); + } + + // + // Run the Conv2d + // + + cudaEvent_t start, stop; + cudaEventCreate(&start); + cudaEventCreate(&stop); + + cudaEventRecord(start); + + for(int i = 0; i < runs; i++) { + + // run conv2d operator + status = b2b_conv2d_op(); + CUTLASS_CHECK(status); + } + + cudaEventRecord(stop); + cudaDeviceSynchronize(); + float conv2dTime; + cudaEventElapsedTime(&conv2dTime, start, stop); + std::cout << "Fusion time " << conv2dTime / (float)runs << " ms\n"; + + tensor_D1_computed.sync_host(); + + bool passed = false; + + cutlass::reference::device::Conv2d< + typename B2bConv2d::ElementA, + typename B2bConv2d::LayoutA, + typename B2bConv2d::ElementB, + typename B2bConv2d::LayoutB, + ElementAccumulator, + typename B2bConv2d::LayoutC, + ElementAccumulator, + ElementAccumulator + >( + kConvolutionalOperator, + problem_size_0, + tensor_A0.device_ref(), + tensor_B0.device_ref(), + tensor_Z0_reference.device_ref(), + tensor_Z0_reference.device_ref(), + ElementAccumulator(1), // intermediate alpha = 1 + ElementAccumulator(0) // beta = 0 + ); + + cutlass::reference::device::TensorScaleBiasConv2d< + ElementAccumulator, + typename B2bConv2d::ElementC, + typename B2bConv2d::LayoutC, + ElementCompute, + typename B2bConv2d::LayoutScaleBias, + cutlass::NumericConverterClamp + >( + problem_size_0, + tensor_Z0_reference.device_ref(), + tensor_D0_reference.device_ref(), + alpha0, + tensor_Scale0.device_ref(), + tensor_Bias0.device_ref() + ); + + if(relu) { + cutlass::reference::device::TensorReLu(tensor_D0_reference.device_view()); + } + + cutlass::reference::device::Conv2d< + typename B2bConv2d::ElementA, + typename B2bConv2d::LayoutA, + typename B2bConv2d::ElementB, + typename B2bConv2d::LayoutB, + typename B2bConv2d::ElementC, + typename B2bConv2d::LayoutC, + ElementCompute, + ElementAccumulator, + cutlass::NumericConverterClamp + >( + kConvolutionalOperator, + problem_size_1, + tensor_D0_reference.device_ref(), + tensor_B1.device_ref(), + tensor_C1.device_ref(), + tensor_D1_reference.device_ref(), + alpha1, + beta1); + + if(relu) { + cutlass::reference::device::TensorReLu(tensor_D1_reference.device_view()); + } + + cudaError_t result = cudaDeviceSynchronize(); + CHECK_TRUE(result == cudaSuccess); + + // sync host (copy device data to 
host) for dumping error output in case of mismatches + tensor_D0_reference.sync_host(); + tensor_D1_reference.sync_host(); + + CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D0_reference.host_view()), 0); + CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1_computed.host_view()), 0); + CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1_reference.host_view()), 0); + + passed = cutlass::reference::host::TensorEquals( + tensor_D1_computed.host_view(), + tensor_D1_reference.host_view()); + + CHECK_TRUE(passed); + + if (!passed) { + std::stringstream fname; + + fname << "error_B2bImplicitGemm_device_interleaved_fused.txt"; + std::cerr << "Dumping results in " << fname.str() << "\n"; + + std::ofstream results(fname.str()); + + results << problem_size_0 << std::endl; + results << problem_size_1 << std::endl; + + results + << "\nA0:\n" << tensor_A0.host_view() << "\n" + << "\nB0:\n" << tensor_B0.host_view() << "\n" + << "\nB0_reordered:\n" << tensor_B0_reordered.host_view() << "\n" + << "\nC0:\n" << tensor_C0.host_view() << "\n" + << "\nScale0:\n" << tensor_Scale0.host_view() << "\n" + << "\nBias0:\n" << tensor_Bias0.host_view() << "\n" + << "\nB1:\n" << tensor_B1.host_view() << "\n" + << "\nB1_reordered:\n" << tensor_B1_reordered.host_view() << "\n" + << "\nC1:\n" << tensor_C1.host_view() << "\n" + << "\nBias1:\n" << tensor_Bias1.host_view() << "\n" + << "\nD1 reference:\n" << tensor_D1_reference.host_view() << "\n" + << "\nD1 computed:\n" << tensor_D1_computed.host_view(); + + + } + + return passed; + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/b2b_interleaved_gemm_run.h b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/b2b_interleaved_gemm_run.h new file mode 100644 index 0000000000000000000000000000000000000000..453f44cd0c44f9320085460935a3478bc884ff0e --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/b2b_interleaved_gemm_run.h @@ -0,0 +1,798 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include +#include +#include + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/host_reorder.h" +#include "cutlass/util/reference/device/gemm.h" +#include "cutlass/util/reference/device/gemm_complex.h" +#include "cutlass/util/reference/device/tensor_relu.h" + +#include "reference/device/tensor_scale_bias.h" +#include "helper.h" + +#define CHECK_GT(val1, val2) \ + if((val1) <= (val2)) \ + std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_GT failed\n"; +#define CHECK_TRUE(val) \ + if(!(val)) \ + std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_TRUE failed\n"; + +template +struct B2bInterleavedNonFusedGemmRun +{ + + using Gemm0 = Gemm0_; + using Gemm1 = Gemm1_; + using ElementAccumulator = typename Gemm0::ElementAccumulator; + using ElementCompute = typename Gemm0::GemmKernel::Epilogue::OutputOp::ElementCompute; + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + cutlass::Distribution::Kind init_Bias; + uint64_t seed; + + // + // Methods + // + + B2bInterleavedNonFusedGemmRun( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + init_A(init_A_), init_B(init_B_), init_C(init_C_), init_Bias(init_Bias_), seed(seed_) { } + + /// Helper to initialize a tensor view + template + bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, 2, -2, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential( + view.data(), view.capacity()); + } + else if (dist_kind == cutlass::Distribution::AllZeros) { + cutlass::reference::host::TensorFill(view, Element(0)); + } + else if (dist_kind == cutlass::Distribution::AllOnes) { + cutlass::reference::host::TensorFill(view, Element(1)); + } + else { + std::cerr << "Not implemented\n"; + return false; + } + + return true; + } + + + + + /// Executes one test + bool 
run( + cutlass::gemm::GemmCoord problem_size_0, + cutlass::gemm::GemmCoord problem_size_1, + ElementCompute alpha0 = ElementCompute(1), + ElementCompute beta0 = ElementCompute(0), + ElementCompute alpha1 = ElementCompute(1), + ElementCompute beta1 = ElementCompute(0), + bool relu = true, + int warm_ups = 1, + int runs = 100) { + + // + // Allocate the GEMM workspace + // + + cutlass::HostTensor< + typename Gemm0::ElementA, + typename Gemm0::LayoutA> tensor_A0(problem_size_0.mk()); + + cutlass::HostTensor< + typename Gemm0::ElementB, + typename Gemm0::LayoutB> tensor_B0(problem_size_0.kn()); + + cutlass::HostTensor< + typename Gemm0::ElementB, + typename Gemm0::LayoutB> tensor_B0_reordered(problem_size_0.kn()); + + cutlass::HostTensor< + typename Gemm0::ElementC, + typename Gemm0::LayoutC> tensor_C0(problem_size_0.mn()); + + cutlass::HostTensor< + typename Gemm0::ElementC, + typename Gemm0::LayoutC> tensor_Bias0({1, problem_size_0.n()}); + + cutlass::HostTensor< + typename Gemm0::ElementC, + typename Gemm0::LayoutC> tensor_D0(problem_size_0.mn()); + + cutlass::HostTensor< + typename Gemm0::ElementC, + typename Gemm0::LayoutC> reference_D0(problem_size_0.mn()); + + cutlass::HostTensor< + typename Gemm1::ElementB, + typename Gemm1::LayoutB> tensor_B1(problem_size_1.kn()); + + cutlass::HostTensor< + typename Gemm1::ElementB, + typename Gemm1::LayoutB> tensor_B1_reordered(problem_size_1.kn()); + + cutlass::HostTensor< + typename Gemm1::ElementC, + typename Gemm1::LayoutC> tensor_C1(problem_size_1.mn()); + + cutlass::HostTensor< + typename Gemm0::ElementC, + typename Gemm1::LayoutC> tensor_Bias1({1, problem_size_1.n()}); + + cutlass::HostTensor< + typename Gemm1::ElementC, + typename Gemm1::LayoutC> tensor_D1(problem_size_1.mn()); + + cutlass::HostTensor< + typename Gemm1::ElementC, + typename Gemm1::LayoutC> reference_D1(problem_size_1.mn()); + + CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019)); + CHECK_TRUE(initialize_tensor(tensor_B0.host_view(), init_B, seed + 2018)); + CHECK_TRUE(initialize_tensor(tensor_C0.host_view(), init_C, seed + 2017)); + CHECK_TRUE(initialize_tensor(tensor_Bias0.host_view(), init_Bias, seed + 2014)); + CHECK_TRUE(initialize_tensor(tensor_B1.host_view(), init_B, seed + 2016)); + CHECK_TRUE(initialize_tensor(tensor_C1.host_view(), init_C, seed + 2015)); + CHECK_TRUE(initialize_tensor(tensor_Bias1.host_view(), init_Bias, seed + 2013)); + + //Reorder B0 and B1 + cutlass::reorder_column( + tensor_B0_reordered.host_ref(), tensor_B0.host_ref(), problem_size_0); + cutlass::reorder_column( + tensor_B1_reordered.host_ref(), tensor_B1.host_ref(), problem_size_1); + + cutlass::reference::host::TensorFill( + tensor_D0.host_view()); + cutlass::reference::host::TensorFill( + tensor_D1.host_view()); + cutlass::reference::host::TensorFill( + reference_D0.host_view()); + cutlass::reference::host::TensorFill( + reference_D1.host_view()); + + tensor_A0.sync_device(); + tensor_B0.sync_device(); + tensor_B0_reordered.sync_device(); + tensor_C0.sync_device(); + tensor_Bias0.sync_device(); + tensor_D0.sync_device(); + tensor_B1.sync_device(); + tensor_B1_reordered.sync_device(); + tensor_C1.sync_device(); + tensor_Bias1.sync_device(); + tensor_D1.sync_device(); + reference_D0.sync_device(); + reference_D1.sync_device(); + + // + // Initialize the GEMM operator + // + + typename Gemm0::Arguments arguments_0{ + problem_size_0, + tensor_A0.device_ref(), + tensor_B0_reordered.device_ref(), + {tensor_Bias0.device_data(), typename Gemm0::LayoutC::Stride(0)}, + 
tensor_D0.device_ref(), + {alpha0, beta0} + }; + + typename Gemm1::Arguments arguments_1{ + problem_size_1, + tensor_D0.device_ref(), + tensor_B1_reordered.device_ref(), + {tensor_Bias1.device_data(), typename Gemm1::LayoutC::Stride(0)}, + tensor_D1.device_ref(), + {alpha1, beta1} + }; + + + Gemm0 gemm_op_0; + Gemm1 gemm_op_1; + + cutlass::Status status = gemm_op_0.initialize(arguments_0); + + CUTLASS_CHECK(status); + + status = gemm_op_1.initialize(arguments_1); + + CUTLASS_CHECK(status); + + for(int i = 0; i < warm_ups; i++) { + status = gemm_op_0(); + CUTLASS_CHECK(status); + status = gemm_op_1(); + CUTLASS_CHECK(status); + } + + // + // Run the GEMM + // + cudaEvent_t start, stop1, stop2; + cudaEventCreate(&start); + cudaEventCreate(&stop1); + cudaEventCreate(&stop2); + + cudaEventRecord(start); + + for(int i = 0; i < runs; i++) { + status = gemm_op_0(); + + CUTLASS_CHECK(status); + } + cudaEventRecord(stop1); + for(int i = 0; i < runs; i++) { + status = gemm_op_1(); + + CUTLASS_CHECK(status); + } + + cudaEventRecord(stop2); + cudaDeviceSynchronize(); + float gemm0Time, gemm1Time, totalTime; + cudaEventElapsedTime(&gemm0Time, start, stop1); + cudaEventElapsedTime(&gemm1Time, stop1, stop2); + cudaEventElapsedTime(&totalTime, start, stop2); + std::cout << "gemm 0 time " << gemm0Time / (float)runs << " ms\n"; + std::cout << "gemm 1 time " << gemm1Time / (float)runs << " ms\n"; + std::cout << "Non-fusion time " << totalTime / (float)runs << " ms\n"; + + tensor_D0.sync_host(); + tensor_D1.sync_host(); + + // + // Verify + // + cutlass::reference::device::Gemm< + typename Gemm0::ElementA, typename Gemm0::LayoutA, + typename Gemm0::ElementB, typename Gemm0::LayoutB, + typename Gemm0::ElementC, typename Gemm0::LayoutC, ElementCompute, + ElementAccumulator, typename Gemm0::Operator> + reference_gemm_0; + + cutlass::reference::device::Gemm< + typename Gemm1::ElementA, typename Gemm1::LayoutA, + typename Gemm1::ElementB, typename Gemm1::LayoutB, + typename Gemm1::ElementC, typename Gemm1::LayoutC, ElementCompute, + ElementAccumulator, typename Gemm1::Operator> + reference_gemm_1; + + reference_gemm_0( + problem_size_0, + alpha0, + tensor_A0.device_ref(), + tensor_B0.device_ref(), + beta0, + {tensor_Bias0.device_data(), typename Gemm0::LayoutC::Stride(0)}, + reference_D0.device_ref() + ); + + if(relu) { + cutlass::reference::device::TensorReLu(reference_D0.device_view()); + } + + reference_gemm_1( + problem_size_1, + alpha1, + reference_D0.device_ref(), + tensor_B1.device_ref(), + beta1, + {tensor_Bias1.device_data(), typename Gemm1::LayoutC::Stride(0)}, + reference_D1.device_ref() + ); + + if(relu) { + cutlass::reference::device::TensorReLu(reference_D1.device_view()); + } + + // Wait for kernels to finish + cudaDeviceSynchronize(); + reference_D0.sync_host(); + reference_D1.sync_host(); + + CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D0.host_view()), 0); + CHECK_GT(cutlass::reference::host::TensorNorm(reference_D0.host_view()), 0); + CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1.host_view()), 0); + CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0); + + bool passed = cutlass::reference::host::TensorEquals( + reference_D1.host_view(), + tensor_D1.host_view()); + + CHECK_TRUE(passed); + if (!passed) { + + std::stringstream fname; + + fname << "error_B2bGemm_device_interleaved_nonfused.txt"; + std::cerr << "Dumping results in " << fname.str() << "\n"; + + std::ofstream file(fname.str()); + + file + << "A0 =\n" << tensor_A0.host_view() + << "\nB0 =\n" << 
tensor_B0.host_view() + << "\nB0_reordered =\n" << tensor_B0_reordered.host_view() + << "\nC0 =\n" << tensor_C0.host_view() + << "\nBias0:\n" << tensor_Bias0.host_view() << "\n" + << "\nD0 =\n" << tensor_D0.host_view() + << "\nB1 =\n" << tensor_B1.host_view() + << "\nB1_reordered =\n" << tensor_B1_reordered.host_view() + << "\nC1 =\n" << tensor_C1.host_view() + << "\nBias1:\n" << tensor_Bias1.host_view() << "\n" + << "\n\nReference =\n" << reference_D1.host_view() + << "\nComputed =\n" << tensor_D1.host_view(); + } + return passed; + } +}; + +template +struct B2bInterleavedFusedGemmRun +{ + + using B2bGemm = B2bGemm_; + using ElementAccumulator = typename B2bGemm::ElementAccumulator; + using ElementCompute = typename B2bGemm::B2bGemmKernel::Epilogue::OutputOp::ElementCompute; + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + cutlass::Distribution::Kind init_Scale; + cutlass::Distribution::Kind init_Bias; + uint64_t seed; + + // + // Methods + // + + B2bInterleavedFusedGemmRun( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_Scale_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + init_A(init_A_), init_B(init_B_), init_C(init_C_), + init_Scale(init_Scale_), init_Bias(init_Bias_), seed(seed_) { } + + /// Helper to initialize a tensor view + template + bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, 2, -2, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential( + view.data(), view.capacity()); + } + else if (dist_kind == cutlass::Distribution::AllZeros) { + cutlass::reference::host::TensorFill(view, Element(0)); + } + else if (dist_kind == cutlass::Distribution::AllOnes) { + cutlass::reference::host::TensorFill(view, Element(1)); + } + else { + std::cerr << "Not implemented\n"; + return false; + } + + return true; + } + + + + + /// Executes one test + bool run( + cutlass::gemm::GemmCoord problem_size_0, + cutlass::gemm::GemmCoord problem_size_1, + ElementCompute alpha0 = ElementCompute(1), + ElementCompute beta0 = ElementCompute(0), + ElementCompute alpha1 = ElementCompute(1), + ElementCompute beta1 = ElementCompute(0), + cutlass::gemm::GemmUniversalMode mode = cutlass::gemm::GemmUniversalMode::kGemm, + + // batch_count is used as split-k when mode is kGemm according + // to the GemmUniversal interface + + int batch_count = 1, + + int64_t batch_stride_A0 = 0, + int64_t batch_stride_B0 = 0, + int64_t batch_stride_C0 = 0, + int64_t batch_stride_B1 = 0, + int64_t batch_stride_C1 = 0, + int64_t batch_stride_D1 = 0, + int64_t batch_stride_Bias0 = 0, + int64_t batch_stride_Scale0 = 0, + bool relu = true, + int warm_ups = 1, + int runs = 100) { + + // + // Allocate the GEMM workspace + // + + cutlass::gemm::GemmCoord 
CoordA0(problem_size_0.m(), problem_size_0.n(), batch_count * problem_size_0.k()); + cutlass::gemm::GemmCoord CoordB0(problem_size_0.m(), problem_size_0.n(), batch_count * problem_size_0.k()); + cutlass::gemm::GemmCoord CoordC0(problem_size_0.m(), batch_count * problem_size_0.n(), problem_size_0.k()); + cutlass::gemm::GemmCoord CoordB1(problem_size_1.m(), problem_size_1.n(), batch_count * problem_size_1.k()); + cutlass::gemm::GemmCoord CoordC1(problem_size_1.m(), batch_count * problem_size_1.n(), problem_size_1.k()); + + cutlass::HostTensor< + typename B2bGemm::ElementA, + typename B2bGemm::LayoutA> tensor_A0(CoordA0.mk()); + + cutlass::HostTensor< + typename B2bGemm::ElementB, + typename B2bGemm::LayoutB> tensor_B0(CoordB0.kn()); + + cutlass::HostTensor< + typename B2bGemm::ElementB, + typename B2bGemm::LayoutB> tensor_B0_reordered(CoordB0.kn()); + + cutlass::HostTensor< + typename B2bGemm::ElementC, + typename B2bGemm::LayoutC> tensor_C0(CoordC0.mn()); + + cutlass::HostTensor< + typename B2bGemm::ElementScaleBias, + typename B2bGemm::LayoutScaleBias> tensor_Scale0; + + if(alpha0 == ElementCompute(0)) //per-channel scale + tensor_Scale0.resize({1, batch_count * problem_size_0.n()}); + + cutlass::HostTensor< + typename B2bGemm::ElementScaleBias, + typename B2bGemm::LayoutScaleBias> tensor_Bias0({1, batch_count * problem_size_0.n()}); + + cutlass::HostTensor< + ElementAccumulator, + typename B2bGemm::LayoutC> reference_Z0(CoordC0.mn()); + + cutlass::HostTensor< + typename B2bGemm::ElementC, + typename B2bGemm::LayoutC> reference_D0(CoordC0.mn()); + + cutlass::HostTensor< + typename B2bGemm::ElementB, + typename B2bGemm::LayoutB> tensor_B1(CoordB1.kn()); + + cutlass::HostTensor< + typename B2bGemm::ElementB, + typename B2bGemm::LayoutB> tensor_B1_reordered(CoordB1.kn()); + + cutlass::HostTensor< + typename B2bGemm::ElementC, + typename B2bGemm::LayoutC> tensor_C1(CoordC1.mn()); + + cutlass::HostTensor< + typename B2bGemm::ElementC, + typename B2bGemm::LayoutScaleBias> tensor_Bias1({1, batch_count * problem_size_1.n()}); + + cutlass::HostTensor< + typename B2bGemm::ElementC, + typename B2bGemm::LayoutC> tensor_D1(CoordC1.mn()); + + cutlass::HostTensor< + typename B2bGemm::ElementC, + typename B2bGemm::LayoutC> reference_D1(CoordC1.mn()); + + + CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019)); + CHECK_TRUE(initialize_tensor(tensor_B0.host_view(), init_B, seed + 2018)); + CHECK_TRUE(initialize_tensor(tensor_C0.host_view(), init_C, seed + 2017)); + if(alpha0 == ElementCompute(0)) //per-channel scale + CHECK_TRUE(initialize_tensor(tensor_Scale0.host_view(), init_Scale, seed + 2014)); + CHECK_TRUE(initialize_tensor(tensor_Bias0.host_view(), init_Bias, seed + 2013)); + CHECK_TRUE(initialize_tensor(tensor_B1.host_view(), init_B, seed + 2016)); + CHECK_TRUE(initialize_tensor(tensor_C1.host_view(), init_C, seed + 2015)); + CHECK_TRUE(initialize_tensor(tensor_Bias1.host_view(), init_Bias, seed + 2012)); + + //Reorder B0 + cutlass::reorder_column<16>( + tensor_B0_reordered.host_ref(), tensor_B0.host_ref(), CoordB0); + cutlass::reorder_column( + tensor_B1_reordered.host_ref(), tensor_B1.host_ref(), CoordB1); + + cutlass::reference::host::TensorFill( + tensor_D1.host_view()); + cutlass::reference::host::TensorFill( + reference_D0.host_view()); + cutlass::reference::host::TensorFill( + reference_D1.host_view()); + + tensor_A0.sync_device(); + tensor_B0.sync_device(); + tensor_B0_reordered.sync_device(); + tensor_C0.sync_device(); + if(alpha0 == ElementCompute(0)) //per-channel 
scale + tensor_Scale0.sync_device(); + tensor_Bias0.sync_device(); + tensor_B1.sync_device(); + tensor_B1_reordered.sync_device(); + tensor_C1.sync_device(); + tensor_Bias1.sync_device(); + tensor_D1.sync_device(); + reference_D0.sync_device(); + reference_D1.sync_device(); + // tensor_Bias0_batched.sync_device(); + + // + // Initialize the GEMM operator + // + + typename B2bGemm::Arguments arguments{ + mode, + problem_size_0, + problem_size_1, + tensor_A0.device_ref(), + tensor_B0_reordered.device_ref(), + tensor_C0.device_ref(), + tensor_Scale0.device_ref(), + tensor_Bias0.device_ref(), + tensor_B1_reordered.device_ref(), + {tensor_Bias1.device_data(), typename B2bGemm::LayoutC::Stride(0)}, + tensor_D1.device_ref(), + batch_stride_A0, + batch_stride_B0, + batch_stride_B1, + batch_stride_C1, + batch_stride_D1, + batch_stride_Bias0, + batch_stride_Scale0, + {alpha0, beta0}, + {alpha1, beta1}, + batch_count, + }; + + B2bGemm b2b_gemm_op; + + cutlass::Status status = b2b_gemm_op.can_implement(arguments); + + if(status != cutlass::Status::kSuccess) { + std::cout << "Problem sizes not supported.\n" + << "Requirments:\n" + << " problem_size_0.M = problem_size_1.M\n" + << " problem_size_0.N = problem_size_1.K\n" + << " ThreadblockShape0::kN = problem_size_0.N\n" + << " ThreadblockShape1::kN = problem_size_1.N" << std::endl; + } + + status = b2b_gemm_op.initialize(arguments); + + CUTLASS_CHECK(status); + + for(int i = 0; i < warm_ups; i++) { + status = b2b_gemm_op(); + CUTLASS_CHECK(status); + } + + // + // Run the GEMM + // + + cudaEvent_t start, stop; + cudaEventCreate(&start); + cudaEventCreate(&stop); + + cudaEventRecord(start); + + for(int i = 0; i < runs; i++) { + status = b2b_gemm_op(); + + CUTLASS_CHECK(status); + } + + cudaEventRecord(stop); + cudaDeviceSynchronize(); + float gemmTime; + cudaEventElapsedTime(&gemmTime, start, stop); + std::cout << "Fusion time " << gemmTime / (float)runs << " ms\n"; + + tensor_D1.sync_host(); + + // + // Verify + // + + cutlass::reference::device::GemmComplex< + typename B2bGemm::ElementA, typename B2bGemm::LayoutA, + typename B2bGemm::ElementB, typename B2bGemm::LayoutB, + ElementAccumulator, typename B2bGemm::LayoutC, + ElementAccumulator, ElementAccumulator + >( + problem_size_0, + ElementAccumulator(1), //intermediate alpha=1 + tensor_A0.device_ref(), + cutlass::ComplexTransform::kNone, + tensor_B0.device_ref(), + cutlass::ComplexTransform::kNone, + ElementAccumulator(0), //beta = 0 + reference_Z0.device_ref(), + reference_Z0.device_ref(), + ElementAccumulator(0), + int(batch_count), + batch_stride_A0, + batch_stride_B0, + batch_stride_C0, + batch_stride_C0 + ); + + cutlass::reference::device::TensorScaleBiasGemmBatched< + ElementAccumulator, typename B2bGemm::ElementC, typename B2bGemm::LayoutC, + ElementCompute, typename B2bGemm::LayoutScaleBias + > ( + problem_size_0, + reference_Z0.device_ref(), + reference_D0.device_ref(), + alpha0, + tensor_Scale0.device_ref(), + tensor_Bias0.device_ref(), + int(batch_count), + batch_stride_C0, + batch_stride_C0, + batch_stride_Scale0, + batch_stride_Bias0 + ); + + if(relu) { + cutlass::reference::device::TensorReLu(reference_D0.device_view()); + } + + cutlass::reference::device::GemmComplex< + typename B2bGemm::ElementA, typename B2bGemm::LayoutA, + typename B2bGemm::ElementB, typename B2bGemm::LayoutB, + typename B2bGemm::ElementC, typename B2bGemm::LayoutC, + ElementCompute, ElementAccumulator + >( + problem_size_1, + alpha1, //intermediate alpha=1 + reference_D0.device_ref(), + 
cutlass::ComplexTransform::kNone, + tensor_B1.device_ref(), + cutlass::ComplexTransform::kNone, + beta1, //beta = 0 + {tensor_Bias1.device_data(), typename B2bGemm::LayoutC::Stride(0)}, + reference_D1.device_ref(), + ElementAccumulator(0), + int(batch_count), + batch_stride_C0, + batch_stride_B1, + batch_stride_C1, + batch_stride_D1 + ); + + if(relu) { + cutlass::reference::device::TensorReLu(reference_D1.device_view()); + } + + cudaDeviceSynchronize(); + reference_D0.sync_host(); + reference_D1.sync_host(); + + CHECK_GT(cutlass::reference::host::TensorNorm(reference_D0.host_view()), 0); + CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1.host_view()), 0); + CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0); + + bool passed = cutlass::reference::host::TensorEquals( + reference_D1.host_view(), + tensor_D1.host_view()); + + CHECK_TRUE(passed); + if (!passed) + { + + std::stringstream fname; + + fname << "error_B2bGemm_device_interleaved_fused.txt"; + std::cerr << "Dumping results in " << fname.str() << "\n"; + + std::ofstream file(fname.str()); + + file + << "A0 =\n" << tensor_A0.host_view() + << "\nB0 =\n" << tensor_B0.host_view() + << "\nB0_reordered =\n" << tensor_B0_reordered.host_view() + << "\nC0 =\n" << tensor_C0.host_view() + << "\nScale0:\n" << tensor_Scale0.host_view() << "\n" + << "\nBias0:\n" << tensor_Bias0.host_view() << "\n" + << "\nB1 =\n" << tensor_B1.host_view() + << "\nB1_reordered =\n" << tensor_B1_reordered.host_view() + << "\nC1 =\n" << tensor_C1.host_view() + << "\nBias1:\n" << tensor_Bias1.host_view() << "\n" + << "\n\nReference =\n" << reference_D1.host_view() + << "\nComputed =\n" << tensor_D1.host_view(); + } + return passed; + } + +}; + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/device/b2b_gemm.h b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/device/b2b_gemm.h new file mode 100644 index 0000000000000000000000000000000000000000..f9b2f49cd67708e655b17fdae086a731db66571a --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/device/b2b_gemm.h @@ -0,0 +1,352 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/arch/arch.h" +#include "cutlass/device_kernel.h" + +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" + +#include "cutlass/gemm/device/default_gemm_configuration.h" +#include "cutlass/epilogue/thread/linear_combination_relu.h" + +#include "kernel/b2b_gemm.h" +#include "kernel/default_b2b_gemm.h" +#include "kernel/default_b2b_gemm_smem_accumulator.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Element type for C and D matrix operands + typename ElementC_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Element type for internal accumulation + typename ElementAccumulator_ = ElementC_, + /// Operator class tag + typename OperatorClass_ = arch::OpClassSimt, + /// Tag indicating architecture to tune for + typename ArchTag_ = arch::Sm70, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape0_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::ThreadblockShape, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape1_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape0_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::WarpShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape1_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp0_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::EpilogueOutputOp, + /// Epilogue output operator + typename EpilogueOutputOp1_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + 
ElementAccumulator_>::EpilogueOutputOp, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle_ = threadblock::GemmIdentityThreadblockSwizzle<>, + /// Number of stages used in the pipelined mainloop + int Stages = + DefaultGemmConfiguration::kStages, + /// Stage accumulator in shared memory + bool SmemAccumulator = false, + /// Access granularity of A matrix in units of elements + int AlignmentA = + DefaultGemmConfiguration::kAlignmentA, + /// Access granularity of B matrix in units of elements + int AlignmentB = + DefaultGemmConfiguration::kAlignmentB, + /// Operation performed by GEMM + typename Operator_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::Operator> +class B2bGemm { + public: + + using ElementA = ElementA_; + using LayoutA = LayoutA_; + using TensorRefA = TensorRef; + using ElementB = ElementB_; + using LayoutB = LayoutB_; + using TensorRefB = TensorRef; + using ElementC = ElementC_; + using LayoutC = LayoutC_; + using TensorRefC = TensorRef; + using TensorRefD = TensorRef; + using ElementAccumulator = ElementAccumulator_; + using OperatorClass = OperatorClass_; + using ArchTag = ArchTag_; + using ThreadblockShape0 = ThreadblockShape0_; + using ThreadblockShape1 = ThreadblockShape1_; + using WarpShape0 = WarpShape0_; + using WarpShape1 = WarpShape1_; + using InstructionShape = InstructionShape_; + using EpilogueOutputOp0 = EpilogueOutputOp0_; + using EpilogueOutputOp1 = EpilogueOutputOp1_; + using ThreadblockSwizzle = ThreadblockSwizzle_; + using Operator = Operator_; + static int const kStages = Stages; + static int const kAlignmentA = AlignmentA; + static int const kAlignmentB = AlignmentB; + static int const kAlignmentC = EpilogueOutputOp1::kCount; + static ComplexTransform const kTransformA = ComplexTransform::kNone; + static ComplexTransform const kTransformB = ComplexTransform::kNone; + + /// Derived types + using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute; + using LayoutScaleBias = layout::RowMajor; + + /// Define the kernel + using B2bGemmKernel = typename kernel::DefaultB2bGemm< + ElementA, + LayoutA, + kAlignmentA, + ElementB, + LayoutB, + kAlignmentB, + ElementC, + LayoutC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape0, + ThreadblockShape1, + WarpShape0, + WarpShape1, + InstructionShape, + EpilogueOutputOp0, + EpilogueOutputOp1, + ThreadblockSwizzle, + kStages, + Operator, + SmemAccumulator + >::B2bGemmKernel; + + using Arguments = typename B2bGemmKernel::Arguments; + +private: + + /// Kernel parameters object + typename B2bGemmKernel::Params params_; + +public: + + /// Constructs the GEMM. + B2bGemm() { } + + /// Determines whether the GEMM can execute the given problem. 
+  static Status can_implement(Arguments const &args) {
+
+    Status status = B2bGemmKernel::can_implement(
+      args.problem_size_0,
+      args.problem_size_1,
+      args.ref_A0.non_const_ref(),
+      args.ref_B0.non_const_ref(),
+      args.ref_C0.non_const_ref(),
+      args.ref_B1.non_const_ref(),
+      args.ref_C1.non_const_ref(),
+      args.ref_D1
+    );
+
+    if (status != Status::kSuccess) {
+      return status;
+    }
+
+    return Status::kSuccess;
+  }
+
+  /// Gets the workspace size
+  static size_t get_workspace_size(Arguments const &args) {
+
+    size_t bytes = 0;
+
+    // Determine grid shape
+    ThreadblockSwizzle threadblock_swizzle;
+
+    cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape(
+      args.problem_size_0,
+      {ThreadblockShape0::kM, ThreadblockShape0::kN, ThreadblockShape0::kK},
+      args.batch_count);
+
+    return bytes;
+  }
+
+  /// Initializes GEMM state from arguments.
+  Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) {
+
+    // Determine grid shape
+    ThreadblockSwizzle threadblock_swizzle;
+
+    cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape(
+      args.problem_size_0,
+      {ThreadblockShape0::kM, ThreadblockShape0::kN, ThreadblockShape0::kK},
+      args.batch_count);
+//    cutlass::gemm::GemmCoord grid_shape_1 = threadblock_swizzle.get_tiled_shape(
+//      args.problem_size_1,
+//      {ThreadblockShape1::kM, ThreadblockShape1::kN, ThreadblockShape1::kK},
+//      args.batch_count);
+
+    // Initialize the Params structure
+    params_ = typename B2bGemmKernel::Params{
+      args.mode,
+      args.problem_size_0,
+      args.problem_size_1,
+      grid_shape,
+      args.ref_A0.non_const_ref(),
+      args.ref_B0.non_const_ref(),
+      args.ref_C0.non_const_ref(),
+      args.ref_Scale0.non_const_ref(),
+      args.ref_Bias0.non_const_ref(),
+      args.ref_B1.non_const_ref(),
+      args.ref_C1.non_const_ref(),
+      args.ref_D1,
+      args.batch_stride_A0,
+      args.batch_stride_B0,
+      args.batch_stride_B1,
+      args.batch_stride_C1,
+      args.batch_stride_D1,
+      args.batch_stride_Bias0,
+      args.batch_stride_Scale0,
+      args.epilogue0,
+      args.epilogue1,
+      static_cast<int *>(workspace),
+    };
+
+    return Status::kSuccess;
+  }
+
+  /// Lightweight update given a subset of arguments
+  Status update(Arguments const &args, void *workspace = nullptr) {
+
+    params_.ref_A0.reset(args.ref_A0.non_const_ref().data());
+    params_.ref_B0.reset(args.ref_B0.non_const_ref().data());
+    params_.ref_C0.reset(args.ref_C0.non_const_ref().data());
+    params_.ref_Scale0.reset(args.ref_Scale0.non_const_ref().data());
+    params_.ref_Bias0.reset(args.ref_Bias0.non_const_ref().data());
+    params_.ref_B1.reset(args.ref_B1.non_const_ref().data());
+    params_.ref_C1.reset(args.ref_C1.non_const_ref().data());
+    params_.ref_D1.reset(args.ref_D1.data());
+    params_.output_op_0 = args.epilogue0;
+    params_.output_op_1 = args.epilogue1;
+    params_.semaphore = static_cast<int *>(workspace);
+
+    return Status::kSuccess;
+  }
+
+  /// Runs the kernel using initialized state.
+  Status run(cudaStream_t stream = nullptr) {
+
+    ThreadblockSwizzle threadblock_swizzle;
+
+    dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape);
+    dim3 block(B2bGemmKernel::kThreadCount, 1, 1);
+
+    cudaError_t result;
+
+    int smem_size = int(sizeof(typename B2bGemmKernel::SharedStorage));
+    if (smem_size >= (48 << 10)) {
+      result = cudaFuncSetAttribute(Kernel<B2bGemmKernel>,
+                                    cudaFuncAttributeMaxDynamicSharedMemorySize,
+                                    smem_size);
+
+      if (result != cudaSuccess) {
+        return Status::kErrorInternal;
+      }
+    }
+
+    cutlass::Kernel<B2bGemmKernel><<<grid, block, smem_size, stream>>>(params_);
+
+    result = cudaGetLastError();
+
+    return result == cudaSuccess ?
Status::kSuccess : Status::kErrorInternal; + } + + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr) { + return run(stream); + } + + /// Runs the kernel using initialized state. + Status operator()( + Arguments const &args, + void *workspace = nullptr, + cudaStream_t stream = nullptr) { + + Status status = initialize(args, workspace, stream); + + if (status == Status::kSuccess) { + status = run(stream); + } + + return status; + } +}; + +} // namespace device +} // namespace gemm +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/device/b2b_implicit_gemm_convolution.h b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/device/b2b_implicit_gemm_convolution.h new file mode 100644 index 0000000000000000000000000000000000000000..e780537aff4cdc733a66044c771eb7bcc6a7ec03 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/device/b2b_implicit_gemm_convolution.h @@ -0,0 +1,300 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ *
+ **************************************************************************************************/
+/* \file
+   \brief Template for device-level Implicit GEMM
+*/
+
+#pragma once
+
+#include <limits>
+
+#include "cutlass/cutlass.h"
+#include "cutlass/device_kernel.h"
+#include "cutlass/conv/convolution.h"
+
+#include "kernel/b2b_implicit_gemm_convolution.h"
+#include "kernel/default_b2b_conv2d_fprop.h"
+#include "kernel/default_b2b_conv2d_fprop_sm75.h"
+#include "kernel/default_b2b_conv2d_fprop_sm80.h"
+#include "kernel/default_b2b_conv2d_fprop_smem_accumulator_sm75.h"
+#include "kernel/default_b2b_conv2d_fprop_smem_accumulator_sm80.h"
+
+namespace cutlass {
+namespace conv {
+namespace device {
+
+template <typename B2bImplicitGemmKernel_>
+class B2bImplicitGemmConvolution {
+public:
+
+  using B2bImplicitGemmKernel = B2bImplicitGemmKernel_;
+
+  using ElementA = typename B2bImplicitGemmKernel::ElementA;
+  using LayoutA = typename B2bImplicitGemmKernel::LayoutA;
+  using ElementB = typename B2bImplicitGemmKernel::ElementB;
+  using LayoutB = typename B2bImplicitGemmKernel::LayoutB;
+  using ElementC = typename B2bImplicitGemmKernel::ElementC;
+  using LayoutC = typename B2bImplicitGemmKernel::LayoutC;
+  using ElementAccumulator = typename B2bImplicitGemmKernel::ElementAccumulator;
+  using ElementCompute = typename B2bImplicitGemmKernel::ElementCompute;
+  using ElementScaleBias = typename B2bImplicitGemmKernel::ElementScaleBias;
+  using LayoutScaleBias = typename B2bImplicitGemmKernel::LayoutScaleBias;
+  using OperatorClass = typename B2bImplicitGemmKernel::OperatorClass;
+  using ArchTag = typename B2bImplicitGemmKernel::ArchTag;
+  using ThreadblockShape0 = typename B2bImplicitGemmKernel::ThreadblockShape0;
+  using ThreadblockShape1 = typename B2bImplicitGemmKernel::ThreadblockShape1;
+  using WarpShape0 = typename B2bImplicitGemmKernel::WarpShape0;
+  using WarpShape1 = typename B2bImplicitGemmKernel::WarpShape1;
+  using InstructionShape = typename B2bImplicitGemmKernel::InstructionShape;
+  using ThreadblockSwizzle = typename B2bImplicitGemmKernel::ThreadblockSwizzle;
+  using EpilogueOutputOp0 = typename B2bImplicitGemmKernel::EpilogueOutputOp0;
+  using EpilogueOutputOp1 = typename B2bImplicitGemmKernel::EpilogueOutputOp1;
+  static int const kStages = B2bImplicitGemmKernel::kStages;
+  static int const kConvDim = B2bImplicitGemmKernel::kConvDim;
+  using WarpMmaOperator0 = typename B2bImplicitGemmKernel::WarpMmaOperator0;
+  using WarpMmaOperator1 = typename B2bImplicitGemmKernel::WarpMmaOperator1;
+  using ArchMmaOperator = typename B2bImplicitGemmKernel::ArchMmaOperator;
+  using MathOperator = typename B2bImplicitGemmKernel::MathOperator;
+
+  static cutlass::conv::Operator const kConvolutionalOperator = B2bImplicitGemmKernel::kConvolutionalOperator;
+  static cutlass::conv::IteratorAlgorithm const kIteratorAlgorithm = B2bImplicitGemmKernel::kIteratorAlgorithm;
+
+  static int const kWarpCount =
+    (ThreadblockShape0::kM / WarpShape0::kM) *
+    (ThreadblockShape0::kN / WarpShape0::kN);
+
+  /// Argument structure
+  using Arguments = typename B2bImplicitGemmKernel::Arguments;
+
+private:
+
+  /// Kernel parameters object
+  typename B2bImplicitGemmKernel::Params params_;
+
+public:
+
+  /// Constructs Implicit GEMM
+  B2bImplicitGemmConvolution() { }
+
+  /// Determines whether the Implicit GEMM can execute the given problem.
+ static Status can_implement(Arguments const &args) { + + // dispatch to iterators + Status status = B2bImplicitGemmKernel::B2bMma::IteratorA0::can_implement(args.problem_size_0); + if (Status::kSuccess != status) { + return status; + } + + status = B2bImplicitGemmKernel::B2bMma::IteratorB0::can_implement(args.problem_size_0); + if (Status::kSuccess != status) { + return status; + } + + status = B2bImplicitGemmKernel::B2bMma::IteratorB1::can_implement(args.problem_size_1); + if (Status::kSuccess != status) { + return status; + } + + // Determine grid shape + ThreadblockSwizzle threadblock_swizzle; + + dim3 grid = threadblock_swizzle.get_grid_shape( + threadblock_swizzle.get_tiled_shape( + cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, args.problem_size_0), + {ThreadblockShape0::kM, ThreadblockShape0::kN, ThreadblockShape0::kK}, + args.problem_size_0.split_k_slices)); + + if (!(grid.y <= std::numeric_limits::max() && + grid.z <= std::numeric_limits::max())) { + + return Status::kErrorInvalidProblem; + } + + // Determine if fusion sizes are valid + + cutlass::gemm::GemmCoord problem_size_0 = implicit_gemm_problem_size(kConvolutionalOperator, args.problem_size_0); + cutlass::gemm::GemmCoord problem_size_1 = implicit_gemm_problem_size(kConvolutionalOperator, args.problem_size_1); + + if(problem_size_0.m() != problem_size_1.m()) + return Status::kErrorInvalidProblem; + + if(problem_size_0.n() != problem_size_1.k()) + return Status::kErrorInvalidProblem; + + if(args.problem_size_1.R != 1 || args.problem_size_1.S != 1) + return Status::kErrorInvalidProblem; + + if(problem_size_0.n() > ThreadblockShape0::kN) + return Status::kErrorInvalidProblem; + + if(problem_size_1.n() > ThreadblockShape1::kN) + return Status::kErrorInvalidProblem; + + return Status::kSuccess; + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const &args) { + + size_t workspace_bytes = 0; + + // Determine grid shape + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord grid_tiled_shape = threadblock_swizzle.get_tiled_shape( + cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, args.problem_size_0), + {ThreadblockShape0::kM, ThreadblockShape0::kN, ThreadblockShape0::kK}, + args.problem_size_0.split_k_slices); + + if(args.split_k_mode == SplitKMode::kParallel) { + + // Split-K parallel: CTAs in k-dimension write the partial results in a temporary workspace. + // The user needs to call a reduction operator to obtain the final output tensor + workspace_bytes = + sizeof(ElementAccumulator) * + size_t(cutlass::conv::implicit_gemm_tensor_c_size(kConvolutionalOperator, args.problem_size_0)) * + size_t(grid_tiled_shape.k()); + } + + else if(args.split_k_mode == SplitKMode::kSerial && args.problem_size_0.split_k_slices > 1) { + + // Split-K serial: The user workspace is used to store semaphore and serialize writing the + // final reduced output to user's output tensor + workspace_bytes = sizeof(int) * size_t(grid_tiled_shape.m()) * size_t(grid_tiled_shape.n()); + } + + return workspace_bytes; + } + + /// Initializes GEMM state from arguments. 
+  Status initialize(
+    Arguments const &args,
+    void *workspace = nullptr,
+    cudaStream_t stream = nullptr) {
+
+    if (args.problem_size_0.split_k_slices > 1) {
+
+      if (!workspace) {
+        return Status::kErrorWorkspaceNull;
+      }
+
+      cudaError_t status = cudaMemsetAsync(workspace, 0, get_workspace_size(args), stream);
+
+      if (status != cudaSuccess) {
+        return Status::kErrorInternal;
+      }
+    }
+
+    // initialize the params structure from the arguments
+    params_ = typename B2bImplicitGemmKernel::Params(
+      args,
+      static_cast<int *>(workspace)
+    );
+
+    int smem_size = int(sizeof(typename B2bImplicitGemmKernel::SharedStorage));
+
+    if (smem_size >= (48 << 10)) {
+      cudaError_t result = cudaFuncSetAttribute(cutlass::Kernel<B2bImplicitGemmKernel>,
+                                                cudaFuncAttributeMaxDynamicSharedMemorySize,
+                                                smem_size);
+
+      if (result != cudaSuccess) {
+        return Status::kErrorInternal;
+      }
+    }
+
+    return Status::kSuccess;
+  }
+
+  /// Initializes GEMM state from arguments.
+  Status update(Arguments const &args, void *workspace = nullptr) {
+
+    // update the params structure from the arguments
+    params_.ptr_A0 = args.ref_A0.data();
+    params_.ptr_B0 = args.ref_B0.data();
+    params_.ptr_C0 = args.ref_C0.data();
+    params_.ptr_Scale0 = args.ref_Scale0.data();
+    params_.ptr_Bias0 = args.ref_Bias0.data();
+    params_.ptr_B1 = args.ref_B1.data();
+    params_.ptr_C1 = args.ref_C1.data();
+    params_.ptr_D1 = args.ref_D1.data();
+    params_.output_op_0 = args.output_op_0;
+    params_.output_op_1 = args.output_op_1;
+    params_.semaphore = static_cast<int *>(workspace);
+
+    return Status::kSuccess;
+  }
+
+  /// Runs the kernel using initialized state.
+  Status run(cudaStream_t stream = nullptr) {
+
+    ThreadblockSwizzle threadblock_swizzle;
+
+    dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape);
+    dim3 block(32 * kWarpCount, 1, 1);
+
+    int smem_size = int(sizeof(typename B2bImplicitGemmKernel::SharedStorage));
+
+    cutlass::Kernel<B2bImplicitGemmKernel><<<grid, block, smem_size, stream>>>(params_);
+
+    cudaError_t result = cudaGetLastError();
+
+    return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal;
+  }
+
+  /// Runs the kernel using initialized state.
+  Status operator()(cudaStream_t stream = nullptr) {
+    return run(stream);
+  }
+
+  /// Runs the kernel using initialized state.
+  Status operator()(
+    Arguments const &args,
+    void *workspace = nullptr,
+    cudaStream_t stream = nullptr) {
+
+    Status status = initialize(args, workspace, stream);
+
+    if (status == Status::kSuccess) {
+      status = run(stream);
+    }
+
+    return status;
+  }
+};
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+} // namespace device
+} // namespace conv
+} // namespace cutlass
+/////////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/b2b_gemm.h b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/b2b_gemm.h
new file mode 100644
index 0000000000000000000000000000000000000000..ef11f4ab83fb0aff78d7f494bb57df7392474e48
--- /dev/null
+++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/b2b_gemm.h
@@ -0,0 +1,811 @@
+/***************************************************************************************************
+ * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K. +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/semaphore.h" + +#include "kernel/b2b_gemm_grouped_problem_visitor.h" +#include "threadblock/grouped_threadblock_swizzle.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +namespace detail { + +/// Utility struct for returning the type of the problem visitor used by the swizzling function, +/// if it is a grouped swizzling function, or a default visitor. This is used only for defining +/// the parameters of the problem visitor used in GroupedParams. +template < + typename B2bMma_, + typename ThreadblockSwizzle_, + typename Enable = void +> +struct ProblemVisitorOrDefault; + +/// Return a generic problem visitor for GEMM problems +template < + typename B2bMma_, + typename ThreadblockSwizzle_ +> +struct ProblemVisitorOrDefault::value + >::type> { + using value = B2bGemmGroupedProblemVisitor::value>; +}; + +/// Return the problem visitor specified by the swizzling function +template < + typename B2bMma_, + typename ThreadblockSwizzle_ +> +struct ProblemVisitorOrDefault::value + >::type> { + using value = typename ThreadblockSwizzle_::ProblemVisitor; +}; + +} // namespace detail + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename B2bMma_, ///! Threadblock-scoped matrix multiply-accumulate + typename Epilogue_, ///! Epilogue + typename ThreadblockSwizzle_ ///! 
Threadblock swizzling function +> +struct B2bGemm { + + using B2bMma = B2bMma_; + using Epilogue = Epilogue_; + using OutputOp0 = typename B2bMma::OutputOp; + using OutputOp1 = typename Epilogue::OutputOp; + using ThreadblockSwizzle = ThreadblockSwizzle_; + + using ElementA0 = typename B2bMma::IteratorA0::Element; + using LayoutA0 = typename B2bMma::IteratorA0::Layout; + using ElementB0 = typename B2bMma::IteratorB0::Element; + using LayoutB0 = typename B2bMma::IteratorB0::Layout; + using ElementB1 = typename B2bMma::IteratorB1::Element; + using LayoutB1 = typename B2bMma::IteratorB1::Layout; + using ElementC = typename Epilogue::OutputTileIterator::Element; + using LayoutC = typename Epilogue::OutputTileIterator::Layout; + + using ScaleBiasData = typename B2bMma::IteratorAccumulatorScaleBias::Element; + + /// Data types needed for higher-level containers. In some cases, a single type must be exposed + /// despite the B2b GEMM using two GEMMs under the hood. In such cases, we select the values from + /// the second GEMM (other than for ElementA/ElementB) + using ElementA = typename B2bMma::IteratorA0::Element; + using LayoutA = typename B2bMma::IteratorA0::Layout; + using ElementB = typename B2bMma::IteratorB0::Element; + using LayoutB = typename B2bMma::IteratorB0::Layout; + + static ComplexTransform const kTransformA = B2bMma::kTransformA; + static ComplexTransform const kTransformB = B2bMma::kTransformB; + using Operator = typename B2bMma::Operator0; + + using OperatorClass = typename Operator::OperatorClass; + using ThreadblockShape = typename B2bMma::Shape0; + using WarpShape = typename Operator::Shape; + using InstructionShape = typename Operator::InstructionShape; + using ArchTag = typename B2bMma::ArchTag; + + static int const kStages = B2bMma::kStages; + static int const kAlignmentA = B2bMma::IteratorA::AccessType::kElements; + static int const kAlignmentB = B2bMma::IteratorB::AccessType::kElements; + static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; + + using Mma = B2bMma; + using EpilogueOutputOp = OutputOp1; + + /// Warp count (concept: GemmShape) + using WarpCount0 = typename B2bMma::WarpCount0; + static int const kThreadCount = 32 * WarpCount0::kCount; + + /// Argument structure + struct Arguments { + + // + // Data members + // + + GemmUniversalMode mode = cutlass::gemm::GemmUniversalMode::kGemm; + GemmCoord problem_size_0{0,0,0}; + GemmCoord problem_size_1{0,0,0}; + typename B2bMma::IteratorA0::TensorRef ref_A0{}; + typename B2bMma::IteratorB0::TensorRef ref_B0{}; + typename Epilogue::OutputTileIterator::TensorRef ref_C0{}; + typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Scale0{}; + typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Bias0{}; + typename B2bMma::IteratorB1::TensorRef ref_B1{}; + typename Epilogue::OutputTileIterator::TensorRef ref_C1{}; + typename Epilogue::OutputTileIterator::TensorRef ref_D1{}; + int64_t batch_stride_A0{0}; + int64_t batch_stride_B0{0}; + int64_t batch_stride_B1{0}; + int64_t batch_stride_C1{0}; + int64_t batch_stride_D1{0}; + int64_t batch_stride_Bias0{0}; + int64_t batch_stride_Scale0{0}; + typename OutputOp0::Params epilogue0 {}; + typename OutputOp1::Params epilogue1 {}; + int batch_count{1}; + + // + // Methods + // + + /// Default ctor + Arguments() = default; + + /// Constructs an Arguments structure + CUTLASS_HOST_DEVICE + Arguments( + GemmUniversalMode mode_, + GemmCoord problem_size_0_, + GemmCoord problem_size_1_, + typename B2bMma::IteratorA0::TensorRef ref_A0_, + typename 
B2bMma::IteratorB0::TensorRef ref_B0_, + typename Epilogue::OutputTileIterator::TensorRef ref_C0_, + typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Scale0_, + typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Bias0_, + typename B2bMma::IteratorB1::TensorRef ref_B1_, + typename Epilogue::OutputTileIterator::TensorRef ref_C1_, + typename Epilogue::OutputTileIterator::TensorRef ref_D1_, + int64_t batch_stride_A0_, + int64_t batch_stride_B0_, + int64_t batch_stride_B1_, + int64_t batch_stride_C1_, + int64_t batch_stride_D1_, + int64_t batch_stride_Bias0_, + int64_t batch_stride_Scale0_, + typename OutputOp0::Params epilogue0_ = typename OutputOp0::Params(), + typename OutputOp1::Params epilogue1_ = typename OutputOp1::Params(), + int batch_count_ = 1 + ): + mode(mode_), + problem_size_0(problem_size_0_), + problem_size_1(problem_size_1_), + ref_A0(ref_A0_), + ref_B0(ref_B0_), + ref_C0(ref_C0_), + ref_Scale0(ref_Scale0_), + ref_Bias0(ref_Bias0_), + ref_B1(ref_B1_), + ref_C1(ref_C1_), + ref_D1(ref_D1_), + batch_stride_A0(batch_stride_A0_), + batch_stride_B0(batch_stride_B0_), + batch_stride_B1(batch_stride_B1_), + batch_stride_C1(batch_stride_C1_), + batch_stride_D1(batch_stride_D1_), + batch_stride_Bias0(batch_stride_Bias0_), + batch_stride_Scale0(batch_stride_Scale0_), + epilogue0(epilogue0_), + epilogue1(epilogue1_), + batch_count(batch_count_) { + } + }; + + // Arguments structure for grouped B2B problems + struct GroupedArguments { + GemmCoord* problem_size_0; + GemmCoord* problem_size_1; + typename B2bMma::IteratorA0::TensorRef* ref_A0; + typename B2bMma::IteratorB0::TensorRef* ref_B0; + typename Epilogue::OutputTileIterator::TensorRef* ref_C0; + typename B2bMma::IteratorAccumulatorScaleBias::TensorRef* ref_Scale0; + typename B2bMma::IteratorAccumulatorScaleBias::TensorRef* ref_Bias0; + typename B2bMma::IteratorB1::TensorRef* ref_B1; + typename Epilogue::OutputTileIterator::TensorRef* ref_C1; + typename Epilogue::OutputTileIterator::TensorRef* ref_D1; + + // Epilogue params remain constant across all problems in the group. Thus, + // the parameter here is not a pointer. 
+ typename OutputOp0::Params epilogue0; + typename OutputOp1::Params epilogue1; + + int problem_count; + int threadblock_count; + GemmCoord* host_problem_sizes; + + CUTLASS_HOST_DEVICE + GroupedArguments( + int problem_count, + GemmCoord* problem_size_0_, + GemmCoord* problem_size_1_, + typename B2bMma::IteratorA0::TensorRef* ref_A0_, + typename B2bMma::IteratorB0::TensorRef* ref_B0_, + typename Epilogue::OutputTileIterator::TensorRef* ref_C0_, + typename B2bMma::IteratorAccumulatorScaleBias::TensorRef* ref_Scale0_, + typename B2bMma::IteratorAccumulatorScaleBias::TensorRef* ref_Bias0_, + typename B2bMma::IteratorB1::TensorRef* ref_B1_, + typename Epilogue::OutputTileIterator::TensorRef* ref_C1_, + typename Epilogue::OutputTileIterator::TensorRef* ref_D1_, + typename OutputOp0::Params epilogue0_ = typename OutputOp0::Params(), + typename OutputOp1::Params epilogue1_ = typename OutputOp1::Params(), + int threadblock_count = 0 + ) : problem_size_0(problem_size_0_), problem_size_1(problem_size_1_), + ref_A0(ref_A0_), ref_B0(ref_B0_), ref_C0(ref_C0_), + ref_Scale0(ref_Scale0_), ref_Bias0(ref_Bias0_), ref_B1(ref_B1_), + ref_C1(ref_C1_), ref_D1(ref_D1_), epilogue0(epilogue0_), epilogue1(epilogue1_), + problem_count(problem_count), + threadblock_count(threadblock_count) + {} + }; + + /// Parameters structure + struct Params { + cutlass::gemm::GemmUniversalMode mode = cutlass::gemm::GemmUniversalMode::kGemm; + cutlass::gemm::GemmCoord problem_size_0{}; + cutlass::gemm::GemmCoord problem_size_1{}; + cutlass::gemm::GemmCoord grid_tiled_shape{}; + int swizzle_log_tile{0}; + typename B2bMma::IteratorA0::Params params_A0{}; + typename B2bMma::IteratorA0::TensorRef ref_A0{}; + typename B2bMma::IteratorB0::Params params_B0{}; + typename B2bMma::IteratorB0::TensorRef ref_B0{}; + typename Epilogue::OutputTileIterator::Params params_C0{}; + typename Epilogue::OutputTileIterator::TensorRef ref_C0{}; + typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Scale0{}; + typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Bias0{}; + typename B2bMma::IteratorB1::Params params_B1{}; + typename B2bMma::IteratorB1::TensorRef ref_B1{}; + typename Epilogue::OutputTileIterator::Params params_C1{}; + typename Epilogue::OutputTileIterator::TensorRef ref_C1{}; + typename Epilogue::OutputTileIterator::Params params_D1{}; + typename Epilogue::OutputTileIterator::TensorRef ref_D1{}; + typename OutputOp0::Params output_op_0{}; + typename OutputOp1::Params output_op_1{}; + int64_t batch_stride_A0{0}; + int64_t batch_stride_B0{0}; + int64_t batch_stride_B1{0}; + int64_t batch_stride_C1{0}; + int64_t batch_stride_D1{0}; + int64_t batch_stride_Bias0{0}; + int64_t batch_stride_Scale0{0}; + int *semaphore = nullptr; + int gemm_k_iterations_0{0}; + int gemm_k_size_0{0}; + int gemm_k_iterations_1{0}; + int gemm_k_size_1{0}; + + // + // Methods + // + + Params() = default; + + CUTLASS_HOST_DEVICE + Params( + cutlass::gemm::GemmUniversalMode mode, + cutlass::gemm::GemmCoord const & problem_size_0, + cutlass::gemm::GemmCoord const & problem_size_1, + cutlass::gemm::GemmCoord const & grid_tiled_shape, + typename B2bMma::IteratorA0::TensorRef ref_A0, + typename B2bMma::IteratorB0::TensorRef ref_B0, + typename Epilogue::OutputTileIterator::TensorRef ref_C0, + typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Scale0, + typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Bias0, + typename B2bMma::IteratorB1::TensorRef ref_B1, + typename Epilogue::OutputTileIterator::TensorRef ref_C1, + typename 
Epilogue::OutputTileIterator::TensorRef ref_D1, + int64_t batch_stride_A0, + int64_t batch_stride_B0, + int64_t batch_stride_B1, + int64_t batch_stride_C1, + int64_t batch_stride_D1, + int64_t batch_stride_Bias0, + int64_t batch_stride_Scale0, + typename OutputOp0::Params output_op_0 = typename OutputOp0::Params(), + typename OutputOp1::Params output_op_1 = typename OutputOp1::Params(), + int *workspace = nullptr + ): + mode(mode), + problem_size_0(problem_size_0), + problem_size_1(problem_size_1), + grid_tiled_shape(grid_tiled_shape), + swizzle_log_tile(ThreadblockSwizzle::get_log_tile(grid_tiled_shape)), + params_A0(ref_A0.layout()), + ref_A0(ref_A0), + params_B0(ref_B0.layout()), + ref_B0(ref_B0), + params_C0(ref_C0.layout()), + ref_C0(ref_C0), + ref_Scale0(ref_Scale0), + ref_Bias0(ref_Bias0), + params_B1(ref_B1.layout()), + ref_B1(ref_B1), + params_C1(ref_C1.layout()), + ref_C1(ref_C1), + params_D1(ref_D1.layout()), + ref_D1(ref_D1), + batch_stride_A0(batch_stride_A0), + batch_stride_B0(batch_stride_B0), + batch_stride_B1(batch_stride_B1), + batch_stride_C1(batch_stride_C1), + batch_stride_D1(batch_stride_D1), + batch_stride_Bias0(batch_stride_Bias0), + batch_stride_Scale0(batch_stride_Scale0), + output_op_0(output_op_0), + output_op_1(output_op_1) { + + int total_gemm_k_iterations_0 = (problem_size_0.k() + B2bMma::Shape0::kK - 1) / B2bMma::Shape0::kK; + int gemm_k_iterations_0 = (total_gemm_k_iterations_0 + grid_tiled_shape.k() - 1) / grid_tiled_shape.k(); + gemm_k_size_0 = gemm_k_iterations_0 * B2bMma::Shape0::kK; + int total_gemm_k_iterations_1 = (problem_size_1.k() + B2bMma::Shape1::kK - 1) / B2bMma::Shape1::kK; + int gemm_k_iterations_1 = (total_gemm_k_iterations_1 + grid_tiled_shape.k() - 1) / grid_tiled_shape.k(); + gemm_k_size_1 = gemm_k_iterations_1 * B2bMma::Shape1::kK; + + semaphore = workspace; + } + }; + + struct GroupedParams { + cutlass::gemm::GemmCoord* problem_size_0; + cutlass::gemm::GemmCoord* problem_size_1; + cutlass::gemm::GemmCoord* grid_tiled_shape; + typename B2bMma::IteratorA0::TensorRef* ref_A0; + typename B2bMma::IteratorB0::TensorRef* ref_B0; + typename Epilogue::OutputTileIterator::TensorRef* ref_C0; + typename B2bMma::IteratorAccumulatorScaleBias::TensorRef* ref_Scale0; + typename B2bMma::IteratorAccumulatorScaleBias::TensorRef* ref_Bias0; + typename B2bMma::IteratorB1::TensorRef* ref_B1; + typename Epilogue::OutputTileIterator::TensorRef* ref_C1; + typename Epilogue::OutputTileIterator::TensorRef* ref_D1; + + // Epilogue params remain constant across all problems in the group. Thus, + // the parameter here is not a pointer. 
+ typename OutputOp0::Params output_op_0; + typename OutputOp1::Params output_op_1; + + using ProblemVisitor = typename detail::ProblemVisitorOrDefault::value; + typename ProblemVisitor::Params problem_visitor; + int threadblock_count; + int* workspace; + + CUTLASS_HOST_DEVICE + GroupedParams() {} + + CUTLASS_HOST_DEVICE + GroupedParams( + GroupedArguments const &args, + void *workspace = nullptr, + int tile_count = 0 + ) : + problem_size_0(args.problem_size_0), problem_size_1(args.problem_size_1), + ref_A0(args.ref_A0), ref_B0(args.ref_B0), ref_C0(args.ref_C0), + ref_Scale0(args.ref_Scale0), ref_Bias0(args.ref_Bias0), ref_B1(args.ref_B1), ref_C1(args.ref_C1), ref_D1(args.ref_D1), + output_op_0(args.epilogue0), output_op_1(args.epilogue1), + problem_visitor(args.problem_size_0, args.problem_size_1, args.problem_count, workspace, tile_count), + threadblock_count(args.threadblock_count), + workspace(reinterpret_cast(workspace)) {} + + CUTLASS_HOST_DEVICE + void transpose() { + // Only row-major outputs are currently supported, so no transpose is performed + } + + /// Returns non-grouped parameters to be used as input to the kernel-level + /// operator for the problem indicated by problem_visitor. + CUTLASS_HOST_DEVICE + Params to_single_params(const ProblemVisitor& problem_visitor) const { + GemmCoord problem_size0 = problem_visitor.problem_size0(); + GemmCoord problem_size1 = problem_visitor.problem_size1(); + int32_t idx = problem_visitor.problem_index(); + GemmCoord grid_shape = problem_visitor.grid_shape(problem_size1); + + return Params( + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size0, + problem_size1, + grid_shape, + ref_A0[idx], + ref_B0[idx], + ref_C0[idx], + ref_Scale0[idx], + ref_Bias0[idx], + ref_B1[idx], + ref_C1[idx], + ref_D1[idx], + 0, 0, 0, 0, 0, 0, 0, // Batched B2B GEMMs within the grouped kernel are currently unsupported + output_op_0, + output_op_1, + workspace + ); + } + }; + + /// Shared memory storage structure + union SharedStorage { + typename B2bMma::B2bMmaSharedStorage main_loop; + typename Epilogue::SharedStorage epilogue; + }; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + B2bGemm() { } + + /// Determines whether kernel satisfies alignment + static Status can_implement( + cutlass::gemm::GemmCoord const & problem_size_0, + cutlass::gemm::GemmCoord const & problem_size_1, + typename B2bMma::IteratorA0::TensorRef ref_A0, + typename B2bMma::IteratorB0::TensorRef ref_B0, + typename Epilogue::OutputTileIterator::TensorRef ref_C0, + typename B2bMma::IteratorB1::TensorRef ref_B1, + typename Epilogue::OutputTileIterator::TensorRef ref_C1, + typename Epilogue::OutputTileIterator::TensorRef ref_D1) { + + static int const kAlignmentA = B2bMma::IteratorA0::AccessType::kElements; + static int const kAlignmentB = B2bMma::IteratorB0::AccessType::kElements; + static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; + + if (!TensorRef_aligned(ref_A0, kAlignmentA)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(ref_B0, kAlignmentB)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(ref_C0, kAlignmentC)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(ref_B1, kAlignmentB)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(ref_C1, kAlignmentC)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(ref_D1, kAlignmentC)) { + return Status::kErrorMisalignedOperand; + } + + if ((problem_size_0.m() % kAlignmentA) || 
      (problem_size_0.k() % kAlignmentA) ||
+      (problem_size_0.n() % kAlignmentB) || (problem_size_0.k() % kAlignmentB) ||
+      (problem_size_0.m() % kAlignmentC) || (problem_size_0.n() % kAlignmentC) ||
+      (problem_size_1.m() % kAlignmentA) || (problem_size_1.k() % kAlignmentA) ||
+      (problem_size_1.n() % kAlignmentB) || (problem_size_1.k() % kAlignmentB) ||
+      (problem_size_1.m() % kAlignmentC) || (problem_size_1.n() % kAlignmentC)) {
+
+      return Status::kErrorMisalignedOperand;
+    }
+
+    // Determine if fusion sizes are valid
+    if(problem_size_0.m() != problem_size_1.m())
+      return Status::kErrorInvalidProblem;
+
+    if(problem_size_0.n() != problem_size_1.k())
+      return Status::kErrorInvalidProblem;
+
+    if(problem_size_0.n() > B2bMma::Shape0::kN)
+      return Status::kErrorInvalidProblem;
+
+    if(problem_size_1.n() > B2bMma::Shape1::kN)
+      return Status::kErrorInvalidProblem;
+
+    return Status::kSuccess;
+  }
+
+  /// Executes one GEMM
+  CUTLASS_DEVICE
+  void operator()(Params const &params, SharedStorage &shared_storage) {
+    ThreadblockSwizzle threadblock_swizzle;
+    run_with_swizzle(params, shared_storage, threadblock_swizzle);
+  }
+
+  /// Executes one GEMM with an externally-provided swizzling function
+  CUTLASS_DEVICE
+  void run_with_swizzle(Params const &params, SharedStorage &shared_storage, ThreadblockSwizzle& threadblock_swizzle) {
+
+    cutlass::gemm::GemmCoord threadblock_tile_offset =
+        threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
+
+    // Early exit if CTA is out of range
+    if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() ||
+      params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) {
+
+      return;
+    }
+
+    ElementA0 *ptr_A0 = static_cast<ElementA0 *>(params.ref_A0.data());
+    ElementB0 *ptr_B0 = static_cast<ElementB0 *>(params.ref_B0.data());
+    ElementB1 *ptr_B1 = static_cast<ElementB1 *>(params.ref_B1.data());
+
+    ScaleBiasData *ptr_Bias0 = static_cast<ScaleBiasData *>(params.ref_Bias0.data());
+    ScaleBiasData *ptr_Scale0 = static_cast<ScaleBiasData *>(params.ref_Scale0.data());
+
+    int offset_k_0 = 0;
+    int offset_k_1 = 0;
+
+    int problem_size_k_0 = params.problem_size_0.k();
+    int problem_size_k_1 = params.problem_size_1.k();
+
+    if (params.mode == GemmUniversalMode::kGemm) {
+
+      // Problem size is a function of threadblock index in the K dimension
+      problem_size_k_0 = min(
+        problem_size_k_0,
+        (threadblock_tile_offset.k() + 1) * params.gemm_k_size_0);
+
+      // Problem size is a function of threadblock index in the K dimension
+      problem_size_k_1 = min(
+        problem_size_k_1,
+        (threadblock_tile_offset.k() + 1) * params.gemm_k_size_1);
+
+      offset_k_0 = threadblock_tile_offset.k() * params.gemm_k_size_0;
+      offset_k_1 = threadblock_tile_offset.k() * params.gemm_k_size_1;
+    }
+
+    else if (params.mode == GemmUniversalMode::kBatched) {
+      ptr_A0 += threadblock_tile_offset.k() * params.batch_stride_A0;
+      ptr_B0 += threadblock_tile_offset.k() * params.batch_stride_B0;
+      ptr_B1 += threadblock_tile_offset.k() * params.batch_stride_B1;
+      ptr_Bias0 += threadblock_tile_offset.k() * params.batch_stride_Bias0;
+      ptr_Scale0 += threadblock_tile_offset.k() * params.batch_stride_Scale0;
+    }
+
+    // Compute initial location in logical coordinates
+    cutlass::MatrixCoord tb_offset_A0{
+      threadblock_tile_offset.m() * B2bMma::Shape0::kM,
+      offset_k_0,
+    };
+
+    cutlass::MatrixCoord tb_offset_B0{
+      offset_k_0,
+      threadblock_tile_offset.n() * B2bMma::Shape0::kN
+    };
+
+    cutlass::MatrixCoord tb_offset_B1{
+      offset_k_1,
+      threadblock_tile_offset.n() * B2bMma::Shape1::kN
+    };
+
+    // Compute threadblock-scoped matrix multiply-add
+    int gemm_k_iterations_0 = (problem_size_k_0 -
tb_offset_A0.column() + B2bMma::Shape0::kK - 1) / B2bMma::Shape0::kK; + + // Compute threadblock-scoped matrix multiply-add + // int gemm_k_iterations_1 = (problem_size_k_1 - tb_offset_B1.row() + B2bMma::Shape1::kK - 1) / B2bMma::Shape1::kK; + + + // Compute position within threadblock + int thread_idx = threadIdx.x; + + // Construct iterators to A and B operands + typename B2bMma::IteratorA0 iterator_A0( + params.params_A0, + ptr_A0, + {params.problem_size_0.m(), problem_size_k_0}, + thread_idx, + tb_offset_A0); + + typename B2bMma::IteratorB0 iterator_B0( + params.params_B0, + ptr_B0, + {problem_size_k_0, params.problem_size_0.n()}, + thread_idx, + tb_offset_B0); + + typename B2bMma::IteratorB1 iterator_B1( + params.params_B1, + ptr_B1, + {problem_size_k_1, params.problem_size_1.n()}, + thread_idx, + tb_offset_B1); + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. + int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + int lane_idx = threadIdx.x % 32; + + // Construct iterators to accumulator scale/bias vector + typename B2bMma::IteratorAccumulatorScaleBias iterator_Scale0( + ptr_Scale0, + {1, params.problem_size_0.n()}, + thread_idx, + warp_idx, + MatrixCoord( + 0, threadblock_tile_offset.n() * B2bMma::Shape0::kN + ) + ); + + typename B2bMma::IteratorAccumulatorScaleBias iterator_Bias0( + ptr_Bias0, + {1, params.problem_size_0.n()}, + thread_idx, + warp_idx, + MatrixCoord( + 0, threadblock_tile_offset.n() * B2bMma::Shape0::kN + ) + ); + + // + // Main loop + // + + OutputOp0 output_op_0(params.output_op_0); + + if (cutlass::gemm::threadblock::detail::IsGroupedSwizzle::value) { + // Wait for all threads to finish their epilogue phases from the previous tile. + __syncthreads(); + } + + // Construct thread-scoped matrix multiply + B2bMma b2bMma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx, params.problem_size_0.n()); + + typename B2bMma::FragmentC0 src_accum; + typename B2bMma::FragmentC1 accumulators; + + src_accum.clear(); + accumulators.clear(); + + // Compute threadblock-scoped matrix multiply-add + b2bMma(gemm_k_iterations_0, accumulators, iterator_A0, iterator_B0, + iterator_Scale0, iterator_Bias0, iterator_B1, src_accum, output_op_0); + + // + // Epilogue + // + + OutputOp1 output_op_1(params.output_op_1); + + // + // Masked tile iterators constructed from members + // + + threadblock_tile_offset = + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + //assume identity swizzle + MatrixCoord threadblock_offset( + threadblock_tile_offset.m() * B2bMma::Shape1::kM, + threadblock_tile_offset.n() * B2bMma::Shape1::kN + ); + + int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); + + ElementC *ptr_C1 = static_cast(params.ref_C1.data()); + ElementC *ptr_D1 = static_cast(params.ref_D1.data()); + + // Construct the semaphore. + Semaphore semaphore(params.semaphore + block_idx, thread_idx); + + if (params.mode == GemmUniversalMode::kGemm) { + // If performing a reduction via split-K, fetch the initial synchronization + + if (params.grid_tiled_shape.k() > 1) { + // Fetch the synchronization lock initially but do not block. 
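+      // Serial split-K handshake, as implemented below: each k-partition of the
+      // grid updates the same output tile in order. The lock is fetched early so
+      // its latency can overlap iterator construction; the epilogue then waits
+      // until the semaphore reaches this partition's k index, and after writing
+      // its partial result the partition releases the semaphore with k + 1 (or
+      // resets it to 0 when it is the last partition) so the next partition, or
+      // the next grid, can proceed.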
+ semaphore.fetch(); + + // Indicate which position in a serial reduction the output operator is currently updating + output_op_1.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); + } + } + else if (params.mode == GemmUniversalMode::kBatched) { + ptr_C1 += threadblock_tile_offset.k() * params.batch_stride_C1; + ptr_D1 += threadblock_tile_offset.k() * params.batch_stride_D1; + } + + // Tile iterator loading from source tensor. + typename Epilogue::OutputTileIterator iterator_C1( + params.params_C1, + ptr_C1, + params.problem_size_1.mn(), + thread_idx, + threadblock_offset + ); + + // Tile iterator writing to destination tensor. + typename Epilogue::OutputTileIterator iterator_D1( + params.params_D1, + ptr_D1, + params.problem_size_1.mn(), + thread_idx, + threadblock_offset + ); + + Epilogue epilogue( + shared_storage.epilogue, + thread_idx, + warp_idx, + lane_idx); + + // Wait on the semaphore - this latency may have been covered by iterator construction + if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) { + + // For subsequent threadblocks, the source matrix is held in the 'D' tensor. + if (threadblock_tile_offset.k()) { + iterator_C1 = iterator_D1; + } + + semaphore.wait(threadblock_tile_offset.k()); + + __threadfence(); + } + + // Execute the epilogue operator to update the destination tensor. + epilogue(output_op_1, iterator_D1, accumulators, iterator_C1); + + // + // Release the semaphore + // + + if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) { + + int lock = 0; + if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) { + + // The final threadblock resets the semaphore for subsequent grids. + lock = 0; + } + else { + // Otherwise, the semaphore is incremented + lock = threadblock_tile_offset.k() + 1; + } + + __threadfence(); + semaphore.release(lock); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/b2b_gemm_grouped_problem_visitor.h b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/b2b_gemm_grouped_problem_visitor.h new file mode 100644 index 0000000000000000000000000000000000000000..cc35d91be6d8b780a7cfb1235e0f5a190ec843ab --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/b2b_gemm_grouped_problem_visitor.h @@ -0,0 +1,157 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Scheduler for grouped B2b GEMMs +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/gemm/kernel/grouped_problem_visitor.h" +#include "cutlass/gemm/kernel/gemm_grouped_problem_visitor.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Visitor class to abstract away the algorithm for iterating over tiles +template +struct B2bGemmGroupedProblemVisitor : public GroupedProblemVisitor< + detail::GemmGroupedProblemSizeHelper, + ThreadblockShape, + GroupScheduleMode_, + PrefetchTileCount, + ThreadCount> { + + using ProblemSizeHelper = detail::GemmGroupedProblemSizeHelper; + using Base = GroupedProblemVisitor; + using BaseParams = typename Base::Params; + using SharedStorage = typename Base::SharedStorage; + static bool const kTransposed = Transposed; + + cutlass::gemm::GemmCoord const *problem_sizes0; + cutlass::gemm::GemmCoord const *problem_sizes1; + + struct Params { + cutlass::gemm::GemmCoord const *problem_sizes0; + cutlass::gemm::GemmCoord const *problem_sizes1; + int32_t problem_count; + void const *workspace; + int32_t tile_count; + + // + // Methods + // + + /// Ctor + CUTLASS_HOST_DEVICE + Params(): problem_sizes0(nullptr), problem_sizes1(nullptr), + problem_count(0), workspace(nullptr), tile_count(0) { } + + /// Ctor + CUTLASS_HOST_DEVICE + Params( + cutlass::gemm::GemmCoord const *problem_sizes0, + cutlass::gemm::GemmCoord const *problem_sizes1, + int32_t problem_count, + void const *workspace = nullptr, + int32_t tile_count = 0 + ): + problem_sizes0(problem_sizes0), + problem_sizes1(problem_sizes1), + problem_count(problem_count), + workspace(workspace), + tile_count(tile_count) + {} + + /// Convert the B2b-GEMM-specific parameters to those used by the base class + CUTLASS_HOST_DEVICE + BaseParams to_base() const { + return BaseParams(// Set problem_sizes as problem_sizes0 because these determine + // shape of the grid used in the non-grouped B2b GEMM + problem_sizes0, + problem_count, + workspace, + tile_count); + } + + }; + + // + // Methods + // + CUTLASS_DEVICE + B2bGemmGroupedProblemVisitor( + Params const ¶ms_, + SharedStorage &shared_storage_, + int32_t block_idx + ): Base 
( + params_.to_base(), + shared_storage_, block_idx), + problem_sizes0(params_.problem_sizes0), + problem_sizes1(params_.problem_sizes1) + {} + + /// Returns the problem size 0 for the current problem + CUTLASS_HOST_DEVICE + cutlass::gemm::GemmCoord problem_size0() const { + GemmCoord problem = problem_sizes0[this->problem_idx]; + ProblemSizeHelper::possibly_transpose_problem(problem); + return problem; + } + + /// Returns the problem size 1 for the current problem + CUTLASS_HOST_DEVICE + cutlass::gemm::GemmCoord problem_size1() const { + GemmCoord problem = problem_sizes1[this->problem_idx]; + ProblemSizeHelper::possibly_transpose_problem(problem); + return problem; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/b2b_implicit_gemm_convolution.h b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/b2b_implicit_gemm_convolution.h new file mode 100644 index 0000000000000000000000000000000000000000..6794fcc1d977d2ba92fa1d678d327a1d8db39ee4 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/b2b_implicit_gemm_convolution.h @@ -0,0 +1,521 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a pipelined Implicit GEMM kernel. 
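+
+           This back-to-back (B2B) variant fuses two implicit GEMM convolutions:
+           the first GEMM's accumulators are scaled and biased by EpilogueOutputOp0
+           and consumed directly as the A operand of the second GEMM inside B2bMma,
+           so the intermediate activation stays on-chip; only the second epilogue
+           (EpilogueOutputOp1) writes the final D1 tensor to global memory.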
+*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/aligned_buffer.h" +#include "cutlass/array.h" +#include "cutlass/numeric_types.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/semaphore.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv2d_problem_size.h" +#include "cutlass/conv/conv3d_problem_size.h" +#include "cutlass/epilogue/threadblock/output_iterator_parameter.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename B2bMma_, ///! Threadblock-scoped matrix multiply-accumulate + typename Epilogue_, ///! Epilogue + typename ThreadblockSwizzle_, ///! Threadblock swizzling function + conv::Operator ConvOperator, ///! Convolutional operator (Fprop, Dgrad, Wgrad) + typename ConvProblemSize_ = Conv2dProblemSize ///! Convolutional operator on 2D or 3D problem +> +struct B2bImplicitGemmConvolution { + + using B2bMma = B2bMma_; + using Epilogue = Epilogue_; + using EpilogueOutputOp0 = typename B2bMma::OutputOp; + using EpilogueOutputOp1 = typename Epilogue::OutputOp; + using ThreadblockSwizzle = ThreadblockSwizzle_; + static Operator const kConvolutionalOperator = ConvOperator; + + using ElementA = typename B2bMma::IteratorA0::Element; + using LayoutA = typename B2bMma::IteratorA0::Layout; + using ElementB = typename B2bMma::IteratorB0::Element; + using LayoutB = typename B2bMma::IteratorB0::Layout; + using ElementC = typename EpilogueOutputOp1::ElementOutput; + + /// Set output tensor C layout + using LayoutC = LayoutA; + + using ElementAccumulator = typename EpilogueOutputOp0::ElementAccumulator; + using ElementCompute = typename EpilogueOutputOp0::ElementCompute; + + /// Scale and Bias + using ElementScaleBias = typename B2bMma::IteratorAccumulatorScaleBias::Element; + using LayoutScaleBias = typename B2bMma::IteratorAccumulatorScaleBias::Layout; + + using WarpMmaOperator0 = typename B2bMma::Policy0::Operator; + using WarpMmaOperator1 = typename B2bMma::Policy1::Operator; + + using ArchMmaOperator = typename WarpMmaOperator0::ArchMmaOperator; + using MathOperator = typename ArchMmaOperator::Operator; + + using OperatorClass = typename WarpMmaOperator0::OperatorClass; + using ArchTag = typename WarpMmaOperator0::ArchTag; + + using ThreadblockShape0 = typename B2bMma::Shape0; + using ThreadblockShape1 = typename B2bMma::Shape1; + using WarpShape0 = typename WarpMmaOperator0::Shape; + using WarpShape1 = typename WarpMmaOperator1::Shape; + using InstructionShape = typename ArchMmaOperator::Shape; + + static int const kStages = B2bMma::kStages; + static IteratorAlgorithm const kIteratorAlgorithm = B2bMma::IteratorA0::kIteratorAlgorithm; + + /// Warp count (concept: GemmShape) + using WarpCount0 = typename B2bMma::WarpCount0; + static int const kThreadCount = 32 * WarpCount0::kCount; + + using TensorRefA0 = typename B2bMma::IteratorA0::TensorRef; + using TensorRefB0 = typename B2bMma::IteratorB0::TensorRef; + using TensorRefScaleBias0 = typename B2bMma::IteratorAccumulatorScaleBias::TensorRef; + using TensorRefB1 = typename B2bMma::IteratorB1::TensorRef; + using TensorRefC = cutlass::TensorRef; + + /// Check iterator A and B convolution dimension are the same and + // set device::B2bImplicitGemmConvolution::kConvDim + 
static_assert(B2bMma::IteratorA0::kConvDim == B2bMma::IteratorB0::kConvDim, + "Convolution on different dimensions is not supported"); + static int const kConvDim = B2bMma::IteratorA0::kConvDim; + + /// Conv dimension and problem size structure (Conv2d or Conv3d) + using ConvProblemSize = ConvProblemSize_; + + /// Wgrad C stride idx for implicit gemm algorithm + // Conv2d row-major matrix C (KxRSC) + // Conv3d row-major matrix C (KxTRSC) + static int const kWgradCStrideIdx = + cutlass::platform::is_same::value ? 2 : 3; + + /// This chooses the appropriate stride element of the C tensor. + static int const kTensorCStrideIdx = + (kConvolutionalOperator == conv::Operator::kWgrad ? kWgradCStrideIdx : 0); + + // + // + // + using ConvOutputIteratorParameter = epilogue::threadblock::ConvOutputIteratorParameter< + LayoutC, + typename Epilogue::OutputTileIterator::Layout, + TensorRefC, + ConvOperator, + ConvProblemSize + >; + + /// Argument structure + struct Arguments { + + // + // Data members + // + + ConvProblemSize problem_size_0; + ConvProblemSize problem_size_1; + TensorRefA0 ref_A0; + TensorRefB0 ref_B0; + TensorRefC ref_C0; + TensorRefScaleBias0 ref_Scale0; + TensorRefScaleBias0 ref_Bias0; + TensorRefB1 ref_B1; + TensorRefC ref_C1; + TensorRefC ref_D1; + typename EpilogueOutputOp0::Params output_op_0; + typename EpilogueOutputOp1::Params output_op_1; + SplitKMode split_k_mode; + + // + // Methods + // + + /// Default ctor + CUTLASS_HOST_DEVICE + Arguments() { } + + CUTLASS_HOST_DEVICE + Arguments( + ConvProblemSize const & problem_size_0, + ConvProblemSize const & problem_size_1 + ): + problem_size_0(problem_size_0), + problem_size_1(problem_size_1) { } + + CUTLASS_HOST_DEVICE + Arguments( + ConvProblemSize const & problem_size_0, + ConvProblemSize const & problem_size_1, + TensorRefA0 const & ref_A0, + TensorRefB0 const & ref_B0, + TensorRefC const & ref_C0, + TensorRefScaleBias0 const & ref_Scale0, + TensorRefScaleBias0 const & ref_Bias0, + TensorRefB1 const & ref_B1, + TensorRefC const & ref_C1, + TensorRefC const & ref_D1, + typename EpilogueOutputOp0::Params const & output_op_0, + typename EpilogueOutputOp1::Params const & output_op_1, + SplitKMode const & split_k_mode = SplitKMode::kSerial + ): + problem_size_0(problem_size_0), + problem_size_1(problem_size_1), + ref_A0(ref_A0), + ref_B0(ref_B0), + ref_C0(ref_C0), + ref_Scale0(ref_Scale0), + ref_Bias0(ref_Bias0), + ref_B1(ref_B1), + ref_C1(ref_C1), + ref_D1(ref_D1), + output_op_0(output_op_0), + output_op_1(output_op_1), + split_k_mode(split_k_mode) + { + + } + + }; + + /// Parameters structure + struct Params { + ConvProblemSize problem_size_0; + ConvProblemSize problem_size_1; + cutlass::gemm::GemmCoord grid_tiled_shape; + gemm::GemmCoord implicit_gemm_problem_size_0; + gemm::GemmCoord implicit_gemm_problem_size_1; + int swizzle_log_tile; + int gemm_k_iterations_0; + int gemm_k_iterations_1; + typename B2bMma::IteratorA0::Params iterator_A0; + typename B2bMma::IteratorA0::Element const *ptr_A0; + typename B2bMma::IteratorB0::Params iterator_B0; + typename B2bMma::IteratorB0::Element const *ptr_B0; + typename Epilogue::OutputTileIterator::Params iterator_C0; + typename Epilogue::OutputTileIterator::Element *ptr_C0; + typename B2bMma::IteratorAccumulatorScaleBias::Element *ptr_Scale0; + typename B2bMma::IteratorAccumulatorScaleBias::Element *ptr_Bias0; + typename B2bMma::IteratorB1::Params iterator_B1; + typename B2bMma::IteratorB1::Element const *ptr_B1; + typename Epilogue::OutputTileIterator::Params iterator_C1; + typename 
Epilogue::OutputTileIterator::Element *ptr_C1; + typename Epilogue::OutputTileIterator::Params iterator_D1; + typename Epilogue::OutputTileIterator::Element *ptr_D1; + typename EpilogueOutputOp0::Params output_op_0; + typename EpilogueOutputOp1::Params output_op_1; + int *semaphore; + SplitKMode split_k_mode; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params(): swizzle_log_tile(0), gemm_k_iterations_0(0), gemm_k_iterations_1(0) { } + + /// + CUTLASS_HOST_DEVICE + Params( + Arguments const &args, + int *semaphore = nullptr + ): + problem_size_0(args.problem_size_0), + problem_size_1(args.problem_size_1), + implicit_gemm_problem_size_0(cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, args.problem_size_0)), + implicit_gemm_problem_size_1(cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, args.problem_size_1)), + iterator_A0(B2bMma::IteratorA0::getParams(args.problem_size_0, args.ref_A0.layout())), + ptr_A0(args.ref_A0.data()), + iterator_B0(args.problem_size_0, args.ref_B0.layout()), + ptr_B0(args.ref_B0.data()), + iterator_C0(ConvOutputIteratorParameter::layout(args.ref_C0)), + ptr_C0(args.ref_C0.data()), + ptr_Scale0(args.ref_Scale0.data()), + ptr_Bias0(args.ref_Bias0.data()), + iterator_B1(args.problem_size_1, args.ref_B1.layout()), + ptr_B1(args.ref_B1.data()), + iterator_C1(ConvOutputIteratorParameter::layout(args.ref_C1)), + ptr_C1(args.ref_C1.data()), + iterator_D1(ConvOutputIteratorParameter::layout(args.ref_D1)), + ptr_D1(args.ref_D1.data()), + output_op_0(args.output_op_0), + output_op_1(args.output_op_1), + semaphore(semaphore), + split_k_mode(args.split_k_mode) + { + gemm_k_iterations_0 = implicit_gemm_k_iterations(kConvolutionalOperator, ThreadblockShape0::kK, args.problem_size_0); + gemm_k_iterations_1 = implicit_gemm_k_iterations(kConvolutionalOperator, ThreadblockShape1::kK, args.problem_size_1); + + ThreadblockSwizzle threadblock_swizzle; + + grid_tiled_shape = threadblock_swizzle.get_tiled_shape( + implicit_gemm_problem_size_0, + {ThreadblockShape0::kM, ThreadblockShape0::kN, ThreadblockShape0::kK}, + args.problem_size_0.split_k_slices); + + swizzle_log_tile = ThreadblockSwizzle().get_log_tile(grid_tiled_shape); + } + }; + + /// Shared memory storage structure + union SharedStorage { + typename B2bMma::B2bMmaSharedStorage main_loop; + typename Epilogue::SharedStorage epilogue; + }; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + B2bImplicitGemmConvolution() { } + + /// Executes one ImplicitGEMM + CUTLASS_DEVICE + void operator()(Params const ¶ms, SharedStorage &shared_storage) { + + // Compute threadblock location + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord threadblock_tile_idx = + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // Early exit if CTA is out of range + if (params.grid_tiled_shape.m() <= threadblock_tile_idx.m() || + params.grid_tiled_shape.n() <= threadblock_tile_idx.n()) { + + return; + } + + // Compute position within threadblock + int thread_idx = threadIdx.x; + + // Construct iterators to A and B operands + typename B2bMma::IteratorA0 iterator_A0( + params.iterator_A0, + params.problem_size_0, + params.ptr_A0, + thread_idx, + MatrixCoord( + threadblock_tile_idx.m() * B2bMma::Shape0::kM, + threadblock_tile_idx.k() * B2bMma::Shape0::kK + ) + ); + + typename B2bMma::IteratorB0 iterator_B0( + params.iterator_B0, + params.problem_size_0, + params.ptr_B0, + thread_idx, + MatrixCoord( + threadblock_tile_idx.k() * B2bMma::Shape0::kK, + threadblock_tile_idx.n() * 
B2bMma::Shape0::kN + ) + ); + + typename B2bMma::IteratorB1 iterator_B1( + params.iterator_B1, + params.problem_size_1, + params.ptr_B1, + thread_idx, + MatrixCoord( + threadblock_tile_idx.k() * B2bMma::Shape1::kK, + threadblock_tile_idx.n() * B2bMma::Shape1::kN + ) + ); + + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. + int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + int lane_idx = threadIdx.x % 32; + + // Construct iterators to accumulator scale/bias vector + typename B2bMma::IteratorAccumulatorScaleBias iterator_Scale0( + params.ptr_Scale0, + {1, params.problem_size_0.K}, + thread_idx, + warp_idx, + MatrixCoord( + 0, threadblock_tile_idx.n() * B2bMma::Shape0::kN + ) + ); + + typename B2bMma::IteratorAccumulatorScaleBias iterator_Bias0( + params.ptr_Bias0, + {1, params.problem_size_0.K}, + thread_idx, + warp_idx, + MatrixCoord( + 0, threadblock_tile_idx.n() * B2bMma::Shape0::kN + ) + ); + + + // + // Main loop + // + + EpilogueOutputOp0 output_op_0(params.output_op_0); + + // Construct thread-scoped matrix multiply + B2bMma b2bMma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); + + typename B2bMma::FragmentC0 src_accum; + typename B2bMma::FragmentC1 accumulators; + + src_accum.clear(); + accumulators.clear(); + + // Compute threadblock-scoped matrix multiply-add + b2bMma(params.gemm_k_iterations_0, accumulators, iterator_A0, iterator_B0, + iterator_Scale0, iterator_Bias0, iterator_B1, src_accum, output_op_0); + + // + // Epilogue + // + + EpilogueOutputOp1 output_op_1(params.output_op_1); + + // Construct the semaphore. + int block_idx = threadblock_tile_idx.m() + threadblock_tile_idx.n() * params.grid_tiled_shape.m(); + + Semaphore semaphore(params.semaphore + block_idx, thread_idx); + + // Compute logical position within grid + threadblock_tile_idx = + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // If performing a reduction via split-K, fetch the initial synchronization + if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) { + + // Fetch the synchronization lock initially but do not block. + semaphore.fetch(); + + // Indicate which position in a serial reduction the output operator is currently updating + output_op_1.set_k_partition(threadblock_tile_idx.k(), params.grid_tiled_shape.k()); + } + + MatrixCoord threadblock_offset( + threadblock_tile_idx.m() * B2bMma::Shape1::kM, + threadblock_tile_idx.n() * B2bMma::Shape1::kN + ); + + // Tile iterator writing to destination tensor + typename Epilogue::OutputTileIterator iterator_D1( + params.iterator_D1, + params.ptr_D1, + ConvOutputIteratorParameter::extent(params.problem_size_1), + thread_idx, + threadblock_offset + ); + + // Tile iterator reading from source accumulator tensor + typename Epilogue::OutputTileIterator iterator_C1( + params.iterator_C1, + params.ptr_C1, + ConvOutputIteratorParameter::extent(params.problem_size_1), + thread_idx, + threadblock_offset + ); + + + // Construct the epilogue + Epilogue epilogue( + shared_storage.epilogue, + thread_idx, + warp_idx, + lane_idx); + + // Wait on the semaphore - this latency may have been covered by iterator construction + if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) { + + // For subsequent threadblocks, the source matrix is held in the 'D' tensor. 
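+    // Split-K handling below: in serial mode (kSerial) the k-partitions update the
+    // same D1 tile one after another, with the semaphore enforcing the ordering and
+    // partitions after the first reading the previously written D1 values as their
+    // epilogue source. In parallel mode (kParallel) each partition instead writes
+    // its partial result to a disjoint copy of the output tensor, offset by the
+    // full tensor size, and the final reduction is left to a subsequent pass.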
+ if (threadblock_tile_idx.k()) { + iterator_C1 = iterator_D1; + } + + semaphore.wait(threadblock_tile_idx.k()); + + __threadfence(); + } + // Each split-k-slice writes to a unique tensor location + else if (params.split_k_mode == SplitKMode::kParallel) { + iterator_D1.add_pointer_offset(threadblock_tile_idx.k() * + cutlass::conv::implicit_gemm_tensor_c_size(ConvOperator, params.problem_size_1)); + } + + // Run efficient epilogue + epilogue(output_op_1, iterator_D1, accumulators, iterator_C1); + + // + // Release the semaphore + // + + if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) { + + int lock = 0; + if (params.grid_tiled_shape.k() == threadblock_tile_idx.k() + 1) { + + // The final threadblock resets the semaphore for subsequent grids. + lock = 0; + } + else { + // Otherwise, the semaphore is incremented + lock = threadblock_tile_idx.k() + 1; + } + + semaphore.release(lock); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop.h b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop.h new file mode 100644 index 0000000000000000000000000000000000000000..46254b5343936d2b558dd92adc0e253ee4ce2696 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop.h @@ -0,0 +1,94 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +/*! \file + \brief + Default kernel-level implicit GEMM convolution definitions combine threadblock-scoped + matrix multiply-add with the appropriate threadblock-scoped epilogue. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/conv/kernel/default_conv2d.h" + +#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h" +#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h" +#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h" +#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h" + +#include "cutlass/transform/threadblock/predicated_vector_access_iterator.h" +#include "cutlass/transform/threadblock/vector_iterator.h" +#include "cutlass/transform/warp/vector_fragment_iterator.h" + +#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h" + +#include "kernel/b2b_implicit_gemm_convolution.h" +#include "threadblock/b2b_implicit_gemm_pipelined.h" +#include "threadblock/b2b_implicit_gemm_multistage.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Defines a kernel for Conv2dFprop +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename OperatorClass, + typename ArchTag, + typename ThreadblockShape0, + typename ThreadblockShape1, + typename WarpShape0, + typename WarpShape1, + typename InstructionShape, + typename EpilogueOutputOp0, + typename EpilogueOutputOp1, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kAnalytic, + bool SmemAccumulator = false +> struct DefaultB2bConv2dFprop; + +} // namespace kernel +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop_sm75.h b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop_sm75.h new file mode 100644 index 0000000000000000000000000000000000000000..dbb21aecd6544c87905832418513029e58e6a33e --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop_sm75.h @@ -0,0 +1,749 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief + Default kernel-level implicit GEMM convolution definitions combine threadblock-scoped + matrix multiply-add with the appropriate threadblock-scoped epilogue. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/conv/kernel/default_conv2d.h" + +#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h" +#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h" +#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h" +#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h" + +#include "cutlass/transform/threadblock/predicated_vector_access_iterator.h" +#include "cutlass/transform/threadblock/vector_iterator.h" +#include "cutlass/transform/warp/vector_fragment_iterator.h" + +#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h" + +#include "kernel/default_b2b_conv2d_fprop.h" +#include "kernel/b2b_implicit_gemm_convolution.h" +#include "threadblock/b2b_implicit_gemm_pipelined.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// +// OpClassTensorOp convolutions +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm +/// and 2 stage pipeline. 
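+///
+/// Every DefaultB2bConv2dFprop specialization in this file follows the same recipe:
+/// two DefaultMmaCore definitions (one per fused GEMM), conv tile iterators for the
+/// A0, B0 and B1 operands, a warp-level MmaTensorOpFragmentIterator that feeds the
+/// first GEMM's accumulators to the second GEMM as its A operand, vector iterators
+/// for the per-channel scale/bias applied between the two GEMMs, the pipelined
+/// B2bMma threadblock, and finally the epilogue and the fused Kernel type.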
+template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape0, + typename ThreadblockShape1, + typename WarpShape0, + typename WarpShape1, + typename InstructionShape, + typename EpilogueOutputOp0, + typename EpilogueOutputOp1, + typename ThreadblockSwizzle, + typename MathOperatorTag +> +struct DefaultB2bConv2dFprop < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape0, + ThreadblockShape1, + WarpShape0, + WarpShape1, + InstructionShape, + EpilogueOutputOp0, + EpilogueOutputOp1, + ThreadblockSwizzle, + 2, + MathOperatorTag, + IteratorAlgorithm::kAnalytic +> { + + // Define the core components from GEMM + using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + 2, MathOperatorTag>; + using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + 2, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA0 = typename MmaCore0::IteratorThreadMapA; + using IteratorA0 = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, LayoutA, + ThreadMapA0 + > + >; + + using SmemIteratorA0 = typename MmaCore0::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB0 = typename MmaCore0::IteratorThreadMapB; + using IteratorB0 = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, LayoutB, + ThreadMapB0 + > + >; + + using SmemIteratorB0 = typename MmaCore0::SmemIteratorB; + + // Use fragment iterator for A operand + using AccumulatorLayout = cutlass::layout::ColumnMajor; + using FragmentIteratorA1 = + cutlass::gemm::warp::MmaTensorOpFragmentIterator< + cutlass::MatrixShape, //warp shape + cutlass::MatrixShape, //accumulator shape + MmaCore1::Shape::kK, //kBlocksColumn + ElementAccumulator, ElementA, AccumulatorLayout, InstructionShape, EpilogueOutputOp0>; + + /// Define iterators over tiles from scale/bias vectors + using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute; + using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter + static int const kElementsPerAccess = 2; + using IteratorAccumulatorScaleBias = + cutlass::transform::threadblock::VectorIterator< + cutlass::transform::threadblock::PredicatedVectorAccessIterator< + cutlass::MatrixShape, + cutlass::MatrixShape, + ElementScaleBias, LayoutScaleBias, kElementsPerAccess> + >; + + // Warp-level iterators to load scale and bias vectors + using FragmentIteratorA1ScaleBias = cutlass::transform::warp::VectorFragmentIterator< + MatrixShape<1, IteratorAccumulatorScaleBias::Fragment::kElements>, ElementScaleBias, + LayoutScaleBias, InstructionShape, kElementsPerAccess>; + + // Define iterators over tiles from the B operand + using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB; + using IteratorB1 = + cutlass::conv::threadblock::TileIterator< + 
cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, LayoutB, + ThreadMapB1 + > + >; + + using SmemIteratorB1 = typename MmaCore1::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp; + using MmaPolicy0 = typename MmaCore0::MmaPolicy; + using MmaPolicy1 = typename MmaCore1::MmaPolicy; + + // Define the Mma + using B2bMma = threadblock::B2bImplicitGemmPipelined< + ThreadblockShape0, + IteratorA0, + SmemIteratorA0, + IteratorB0, + SmemIteratorB0, + ThreadblockShape1, + FragmentIteratorA1, + IteratorAccumulatorScaleBias, + FragmentIteratorA1ScaleBias, + IteratorB1, + SmemIteratorB1, + ElementC, + LayoutC, + EpilogueOutputOp0, + MmaPolicy0, + MmaPolicy1 + >; + + // Define the epilogue + using Epilogue = typename detail::DefaultConvEpilogue< + ArchTag, + ThreadblockShape1, + WarpMmaTensorOp1, + 1, + EpilogueOutputOp1 + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution< + B2bMma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and 2 stage +/// pipeline with interleaved layout. +template < + typename ElementA, + typename ElementB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape0, + typename ThreadblockShape1, + typename WarpShape0, + typename WarpShape1, + typename InstructionShape, + typename EpilogueOutputOp0, + typename EpilogueOutputOp1, + typename ThreadblockSwizzle, + typename MathOperatorTag, + int InterleavedK +> +struct DefaultB2bConv2dFprop < + ElementA, + layout::TensorNCxHWx, + ElementB, + layout::TensorCxRSKx, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape0, + ThreadblockShape1, + WarpShape0, + WarpShape1, + InstructionShape, + EpilogueOutputOp0, + EpilogueOutputOp1, + ThreadblockSwizzle, + 2, + MathOperatorTag, + IteratorAlgorithm::kAnalytic, + false +> { + + // Define the core components from GEMM + using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::ColumnMajorInterleaved, + ElementB, layout::RowMajorInterleaved, + ElementAccumulator, LayoutC, arch::OpClassTensorOp, + 2, MathOperatorTag, true>; + using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::ColumnMajorInterleaved, + ElementB, layout::RowMajorInterleaved, + ElementAccumulator, LayoutC, arch::OpClassTensorOp, + 2, MathOperatorTag, true>; + + // Define iterators over tiles from the A operand + // Note GEMM shared memory threadmap is used here because conv global memory + // layout needs to be mapped to fprop which is similar to the crosswise + // layout which is used by the interleaved GEMM shared memory threadmap. + // The Interleaved GEMM global memory layout is similar to the congruous + // layout. 
+ using ThreadMapA0 = typename MmaCore0::SmemThreadMapA; + using IteratorA0 = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, layout::TensorNCxHWx, + ThreadMapA0 + > + >; + + using SmemIteratorA0 = typename MmaCore0::SmemIteratorA; + + // Define iterators over tiles from the B operand + // Note GEMM shared memory threadmap is used here because conv global memory + // layout needs to be mapped to fprop which is similar to the crosswise + // layout which is used by the interleaved GEMM shared memory threadmap. + // The Interleaved GEMM global memory layout is similar to the congruous + // layout. + using ThreadMapB0 = typename MmaCore0::SmemThreadMapB; + using IteratorB0 = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, layout::TensorCxRSKx, + ThreadMapB0 + > + >; + + using SmemIteratorB0 = typename MmaCore0::SmemIteratorB; + + // Use fragment iterator for A operand + using AccumulatorLayout = cutlass::layout::RowMajor; + using FragmentIteratorA1 = + cutlass::gemm::warp::MmaTensorOpFragmentIterator< + cutlass::MatrixShape, //warp shape + cutlass::MatrixShape, //accumulator shape + MmaCore1::Shape::kK, //kBlocksColumn + ElementAccumulator, ElementA, AccumulatorLayout, InstructionShape, EpilogueOutputOp0>; + + /// Define iterators over tiles from scale/bias vectors + using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute; + using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter + static int const kElementsPerAccess = 4; + using IteratorAccumulatorScaleBias = + cutlass::transform::threadblock::VectorIterator< + cutlass::transform::threadblock::PredicatedVectorAccessIterator< + cutlass::MatrixShape, + cutlass::MatrixShape, + ElementScaleBias, LayoutScaleBias, kElementsPerAccess> + >; + + // Warp-level iterators to load scale and bias vectors + using FragmentIteratorA1ScaleBias = cutlass::transform::warp::VectorFragmentIterator< + MatrixShape<1, IteratorAccumulatorScaleBias::Fragment::kElements>, ElementScaleBias, + LayoutScaleBias, InstructionShape, kElementsPerAccess>; + + // Define iterators over tiles from the B operand + using ThreadMapB1 = typename MmaCore1::SmemThreadMapB; + using IteratorB1 = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, layout::TensorCxRSKx, + ThreadMapB1 + > + >; + + using SmemIteratorB1 = typename MmaCore1::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp; + using MmaPolicy0 = typename MmaCore0::MmaPolicy; + using MmaPolicy1 = typename MmaCore1::MmaPolicy; + + // Define the Mma + using B2bMma = threadblock::B2bImplicitGemmPipelined< + ThreadblockShape0, + IteratorA0, + SmemIteratorA0, + IteratorB0, + SmemIteratorB0, + ThreadblockShape1, + FragmentIteratorA1, + IteratorAccumulatorScaleBias, + FragmentIteratorA1ScaleBias, + IteratorB1, + SmemIteratorB1, + ElementC, + LayoutC, + EpilogueOutputOp0, + MmaPolicy0, + MmaPolicy1 + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultInterleavedConvEpilogue< + ThreadblockShape1, + WarpMmaTensorOp1, + 1, + EpilogueOutputOp1, + EpilogueOutputOp1::kCount, + InterleavedK + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution< + 
B2bMma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm +/// and 2 stage pipeline. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape0, + typename ThreadblockShape1, + typename WarpShape0, + typename WarpShape1, + typename InstructionShape, + typename EpilogueOutputOp0, + typename EpilogueOutputOp1, + typename ThreadblockSwizzle, + typename MathOperatorTag +> +struct DefaultB2bConv2dFprop < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape0, + ThreadblockShape1, + WarpShape0, + WarpShape1, + InstructionShape, + EpilogueOutputOp0, + EpilogueOutputOp1, + ThreadblockSwizzle, + 2, + MathOperatorTag, + IteratorAlgorithm::kOptimized +> { + + // Define the core components from GEMM + using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + 2, MathOperatorTag>; + using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + 2, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA0 = typename MmaCore0::IteratorThreadMapA; + using IteratorA0 = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, LayoutA, + ThreadMapA0 + > + >; + + using SmemIteratorA0 = typename MmaCore0::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB0 = typename MmaCore0::IteratorThreadMapB; + using IteratorB0 = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, LayoutB, + ThreadMapB0 + > + >; + + using SmemIteratorB0 = typename MmaCore0::SmemIteratorB; + + // Use fragment iterator for A operand + using AccumulatorLayout = cutlass::layout::ColumnMajor; + using FragmentIteratorA1 = + cutlass::gemm::warp::MmaTensorOpFragmentIterator< + cutlass::MatrixShape, //warp shape + cutlass::MatrixShape, //accumulator shape + MmaCore1::Shape::kK, //kBlocksColumn + ElementAccumulator, ElementA, AccumulatorLayout, InstructionShape, EpilogueOutputOp0>; + + /// Define iterators over tiles from scale/bias vectors + using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute; + using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter + static int const kElementsPerAccess = 2; + using IteratorAccumulatorScaleBias = + cutlass::transform::threadblock::VectorIterator< + cutlass::transform::threadblock::PredicatedVectorAccessIterator< + cutlass::MatrixShape, + cutlass::MatrixShape, + ElementScaleBias, LayoutScaleBias, kElementsPerAccess> + >; + + // Warp-level iterators to load scale and bias vectors + using FragmentIteratorA1ScaleBias = cutlass::transform::warp::VectorFragmentIterator< + MatrixShape<1, 
IteratorAccumulatorScaleBias::Fragment::kElements>, ElementScaleBias, + LayoutScaleBias, InstructionShape, kElementsPerAccess>; + + // Define iterators over tiles from the B operand + using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB; + using IteratorB1 = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, LayoutB, + ThreadMapB1 + > + >; + + using SmemIteratorB1 = typename MmaCore1::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp; + using MmaPolicy0 = typename MmaCore0::MmaPolicy; + using MmaPolicy1 = typename MmaCore1::MmaPolicy; + + // Define the Mma + using B2bMma = threadblock::B2bImplicitGemmPipelined< + ThreadblockShape0, + IteratorA0, + SmemIteratorA0, + IteratorB0, + SmemIteratorB0, + ThreadblockShape1, + FragmentIteratorA1, + IteratorAccumulatorScaleBias, + FragmentIteratorA1ScaleBias, + IteratorB1, + SmemIteratorB1, + ElementC, + LayoutC, + EpilogueOutputOp0, + MmaPolicy0, + MmaPolicy1 + >; + + // Define the epilogue + using Epilogue = typename detail::DefaultConvEpilogue< + ArchTag, + ThreadblockShape1, + WarpMmaTensorOp1, + 1, + EpilogueOutputOp1 + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution< + B2bMma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm and 2 stage +/// pipeline with interleaved layout. +template < + typename ElementA, + typename ElementB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape0, + typename ThreadblockShape1, + typename WarpShape0, + typename WarpShape1, + typename InstructionShape, + typename EpilogueOutputOp0, + typename EpilogueOutputOp1, + typename ThreadblockSwizzle, + typename MathOperatorTag, + int InterleavedK +> +struct DefaultB2bConv2dFprop < + ElementA, + layout::TensorNCxHWx, + ElementB, + layout::TensorCxRSKx, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape0, + ThreadblockShape1, + WarpShape0, + WarpShape1, + InstructionShape, + EpilogueOutputOp0, + EpilogueOutputOp1, + ThreadblockSwizzle, + 2, + MathOperatorTag, + IteratorAlgorithm::kOptimized +> { + + // Define the core components from GEMM + using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::ColumnMajorInterleaved, + ElementB, layout::RowMajorInterleaved, + ElementAccumulator, LayoutC, arch::OpClassTensorOp, + 2, MathOperatorTag, true>; + using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::ColumnMajorInterleaved, + ElementB, layout::RowMajorInterleaved, + ElementAccumulator, LayoutC, arch::OpClassTensorOp, + 2, MathOperatorTag, true>; + + // Define iterators over tiles from the A operand + // Note GEMM shared memory threadmap is used here because conv global memory + // layout needs to be mapped to fprop which is similar to the crosswise + // layout which is used by the interleaved GEMM shared memory threadmap. + // The Interleaved GEMM global memory layout is similar to the congruous + // layout. 
+ + // Define iterators over tiles from the A operand + using ThreadMapA0 = typename MmaCore0::SmemThreadMapA; + using IteratorA0 = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, layout::TensorNCxHWx, + ThreadMapA0 + > + >; + + using SmemIteratorA0 = typename MmaCore0::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB0 = typename MmaCore0::SmemThreadMapB; + using IteratorB0 = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, layout::TensorCxRSKx, + ThreadMapB0 + > + >; + + using SmemIteratorB0 = typename MmaCore0::SmemIteratorB; + + // Use fragment iterator for A operand + using AccumulatorLayout = cutlass::layout::RowMajor; + using FragmentIteratorA1 = + cutlass::gemm::warp::MmaTensorOpFragmentIterator< + cutlass::MatrixShape, //warp shape + cutlass::MatrixShape, //accumulator shape + MmaCore1::Shape::kK, //kBlocksColumn + ElementAccumulator, ElementA, AccumulatorLayout, InstructionShape, EpilogueOutputOp0>; + + /// Define iterators over tiles from scale/bias vectors + using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute; + using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter + static int const kElementsPerAccess = 4; + using IteratorAccumulatorScaleBias = + cutlass::transform::threadblock::VectorIterator< + cutlass::transform::threadblock::PredicatedVectorAccessIterator< + cutlass::MatrixShape, + cutlass::MatrixShape, + ElementScaleBias, LayoutScaleBias, kElementsPerAccess> + >; + + // Warp-level iterators to load scale and bias vectors + using FragmentIteratorA1ScaleBias = cutlass::transform::warp::VectorFragmentIterator< + MatrixShape<1, IteratorAccumulatorScaleBias::Fragment::kElements>, ElementScaleBias, + LayoutScaleBias, InstructionShape, kElementsPerAccess>; + + using ThreadMapB1 = typename MmaCore1::SmemThreadMapB; + using IteratorB1 = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, layout::TensorCxRSKx, + ThreadMapB1 + > + >; + + using SmemIteratorB1 = typename MmaCore1::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp; + using MmaPolicy0 = typename MmaCore0::MmaPolicy; + using MmaPolicy1 = typename MmaCore1::MmaPolicy; + + // Define the Mma + using B2bMma = threadblock::B2bImplicitGemmPipelined< + ThreadblockShape0, + IteratorA0, + SmemIteratorA0, + IteratorB0, + SmemIteratorB0, + ThreadblockShape1, + FragmentIteratorA1, + IteratorAccumulatorScaleBias, + FragmentIteratorA1ScaleBias, + IteratorB1, + SmemIteratorB1, + ElementC, + LayoutC, + EpilogueOutputOp0, + MmaPolicy0, + MmaPolicy1 + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultInterleavedConvEpilogue< + ThreadblockShape1, + WarpMmaTensorOp1, + 1, + EpilogueOutputOp1, + EpilogueOutputOp1::kCount, + InterleavedK + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution< + B2bMma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace conv +} // namespace cutlass + 
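+
+// Illustrative usage (a sketch, not part of this header): one way the specializations
+// above might be instantiated to obtain a fused kernel type. The concrete element
+// types, tile shapes, swizzle and epilogue functors below are assumptions chosen only
+// for illustration; EpilogueOutputOp0 / EpilogueOutputOp1 stand for the per-GEMM
+// epilogue functors (e.g. cutlass::epilogue::thread::LinearCombinationRelu
+// instantiations).
+//
+//   using ElementInput       = cutlass::half_t;
+//   using ElementOutput      = cutlass::half_t;
+//   using ElementAccumulator = float;
+//
+//   using B2bFpropKernel = typename cutlass::conv::kernel::DefaultB2bConv2dFprop<
+//     ElementInput,  cutlass::layout::TensorNHWC,     // A0 (activations)
+//     ElementInput,  cutlass::layout::TensorNHWC,     // B0 / B1 (filters)
+//     ElementOutput, cutlass::layout::TensorNHWC,     // C / D1
+//     ElementAccumulator,
+//     cutlass::arch::OpClassTensorOp,
+//     cutlass::arch::Sm75,
+//     cutlass::gemm::GemmShape<64, 64, 32>,           // ThreadblockShape0
+//     cutlass::gemm::GemmShape<64, 128, 32>,          // ThreadblockShape1
+//     cutlass::gemm::GemmShape<16, 64, 32>,           // WarpShape0
+//     cutlass::gemm::GemmShape<16, 128, 32>,          // WarpShape1
+//     cutlass::gemm::GemmShape<16, 8, 8>,             // InstructionShape (Sm75 HMMA)
+//     EpilogueOutputOp0,
+//     EpilogueOutputOp1,
+//     cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
+//     2,                                              // Stages (this header is 2-stage)
+//     cutlass::arch::OpMultiplyAdd,
+//     cutlass::conv::IteratorAlgorithm::kOptimized
+//   >::Kernel;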
+///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop_sm80.h b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop_sm80.h new file mode 100644 index 0000000000000000000000000000000000000000..6e0af4f96592a835a4ca818b13b658c9947d7cd8 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop_sm80.h @@ -0,0 +1,740 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief + Default kernel-level implicit GEMM convolution definitions combine threadblock-scoped + matrix multiply-add with the appropriate threadblock-scoped epilogue. 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/conv/kernel/default_conv2d.h" + +#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h" +#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h" +#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h" +#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h" + +#include "cutlass/transform/threadblock/predicated_vector_access_iterator.h" +#include "cutlass/transform/threadblock/vector_iterator.h" +#include "cutlass/transform/warp/vector_fragment_iterator.h" + +#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h" + +#include "kernel/default_b2b_conv2d_fprop.h" +#include "kernel/b2b_implicit_gemm_convolution.h" +#include "threadblock/b2b_implicit_gemm_multistage.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// +// OpClassTensorOp convolutions +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and multistage +/// pipeline. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape0, + typename ThreadblockShape1, + typename WarpShape0, + typename WarpShape1, + typename InstructionShape, + typename EpilogueOutputOp0, + typename EpilogueOutputOp1, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag +> +struct DefaultB2bConv2dFprop < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape0, + ThreadblockShape1, + WarpShape0, + WarpShape1, + InstructionShape, + EpilogueOutputOp0, + EpilogueOutputOp1, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kAnalytic +> { + + // Define the core components from GEMM + using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + Stages, MathOperatorTag>; + using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA0 = typename MmaCore0::IteratorThreadMapA; + using IteratorA0 = + cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, LayoutA, + ThreadMapA0 + >; + + using SmemIteratorA0 = typename MmaCore0::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB0 = typename MmaCore0::IteratorThreadMapB; + using IteratorB0 = + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, LayoutB, + ThreadMapB0 + >; + + using SmemIteratorB0 = typename MmaCore0::SmemIteratorB; + + // Use fragment iterator for A operand + using 
AccumulatorLayout = cutlass::layout::ColumnMajor; + using FragmentIteratorA1 = + cutlass::gemm::warp::MmaTensorOpFragmentIterator< + cutlass::MatrixShape, //warp shape + cutlass::MatrixShape, //accumulator shape + MmaCore1::Shape::kK, //kBlocksColumn + ElementAccumulator, ElementA, AccumulatorLayout, InstructionShape, EpilogueOutputOp0>; + + /// Define iterators over tiles from scale/bias vectors + using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute; + using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter + static int const kElementsPerAccess = 2; + using IteratorAccumulatorScaleBias = + cutlass::transform::threadblock::VectorIterator< + cutlass::transform::threadblock::PredicatedVectorAccessIterator< + cutlass::MatrixShape, + cutlass::MatrixShape, + ElementScaleBias, LayoutScaleBias, kElementsPerAccess> + >; + + // Warp-level iterators to load scale and bias vectors + using FragmentIteratorA1ScaleBias = cutlass::transform::warp::VectorFragmentIterator< + MatrixShape<1, IteratorAccumulatorScaleBias::Fragment::kElements>, ElementScaleBias, + LayoutScaleBias, InstructionShape, kElementsPerAccess>; + + // Define iterators over tiles from the B operand + using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB; + using IteratorB1 = + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, LayoutB, + ThreadMapB1 + >; + + using SmemIteratorB1 = typename MmaCore1::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp; + using MmaPolicy0 = typename MmaCore0::MmaPolicy; + using MmaPolicy1 = typename MmaCore1::MmaPolicy; + + // Define the Mma + using B2bMma = threadblock::B2bImplicitGemmMultistage< + ThreadblockShape0, + IteratorA0, + SmemIteratorA0, + arch::CacheOperation::Always, + IteratorB0, + SmemIteratorB0, + arch::CacheOperation::Global, + ThreadblockShape1, + FragmentIteratorA1, + IteratorAccumulatorScaleBias, + FragmentIteratorA1ScaleBias, + IteratorB1, + SmemIteratorB1, + arch::CacheOperation::Global, + EpilogueOutputOp0, + MmaPolicy0, + MmaPolicy1, + Stages + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< + ThreadblockShape1, + WarpMmaTensorOp1, + 1, + EpilogueOutputOp1, + EpilogueOutputOp1::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution< + B2bMma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and multistage +/// pipeline with interleaved layout. 
+template < + typename ElementA, + typename ElementB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape0, + typename ThreadblockShape1, + typename WarpShape0, + typename WarpShape1, + typename InstructionShape, + typename EpilogueOutputOp0, + typename EpilogueOutputOp1, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + int InterleavedK +> +struct DefaultB2bConv2dFprop < + ElementA, + layout::TensorNCxHWx, + ElementB, + layout::TensorCxRSKx, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape0, + ThreadblockShape1, + WarpShape0, + WarpShape1, + InstructionShape, + EpilogueOutputOp0, + EpilogueOutputOp1, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kAnalytic +> { + + // Define the core components from GEMM + using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::ColumnMajorInterleaved, + ElementB, layout::RowMajorInterleaved, + ElementAccumulator, LayoutC, arch::OpClassTensorOp, + Stages, MathOperatorTag, true>; + using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::ColumnMajorInterleaved, + ElementB, layout::RowMajorInterleaved, + ElementAccumulator, LayoutC, arch::OpClassTensorOp, + Stages, MathOperatorTag, true>; + + // Define iterators over tiles from the A operand + // Note GEMM shared memory threadmap is used here because conv global memory + // layout needs to be mapped to fprop which is similar to the crosswise + // layout which is used by the interleaved GEMM shared memory threadmap. + // The Interleaved GEMM global memory layout is similar to the congruous + // layout. + using ThreadMapA0 = typename MmaCore0::SmemThreadMapA; + using IteratorA0 = + cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, layout::TensorNCxHWx, + ThreadMapA0 + >; + + using SmemIteratorA0 = typename MmaCore0::SmemIteratorA; + + // Define iterators over tiles from the B operand + // Note GEMM shared memory threadmap is used here because conv global memory + // layout needs to be mapped to fprop which is similar to the crosswise + // layout which is used by the interleaved GEMM shared memory threadmap. + // The Interleaved GEMM global memory layout is similar to the congruous + // layout. 
+ using ThreadMapB0 = typename MmaCore0::SmemThreadMapB; + using IteratorB0 = + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, layout::TensorCxRSKx, + ThreadMapB0 + >; + + using SmemIteratorB0 = typename MmaCore0::SmemIteratorB; + + // Use fragment iterator for A operand + using AccumulatorLayout = cutlass::layout::RowMajor; + using FragmentIteratorA1 = + cutlass::gemm::warp::MmaTensorOpFragmentIterator< + cutlass::MatrixShape, //warp shape + cutlass::MatrixShape, //accumulator shape + MmaCore1::Shape::kK, //kBlocksColumn + ElementAccumulator, ElementA, AccumulatorLayout, InstructionShape, EpilogueOutputOp0>; + + /// Define iterators over tiles from scale/bias vectors + using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute; + using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter + static int const kElementsPerAccess = 4; + using IteratorAccumulatorScaleBias = + cutlass::transform::threadblock::VectorIterator< + cutlass::transform::threadblock::PredicatedVectorAccessIterator< + cutlass::MatrixShape, + cutlass::MatrixShape, + ElementScaleBias, LayoutScaleBias, kElementsPerAccess> + >; + + // Warp-level iterators to load scale and bias vectors + using FragmentIteratorA1ScaleBias = cutlass::transform::warp::VectorFragmentIterator< + MatrixShape<1, IteratorAccumulatorScaleBias::Fragment::kElements>, ElementScaleBias, + LayoutScaleBias, InstructionShape, kElementsPerAccess>; + + using ThreadMapB1 = typename MmaCore1::SmemThreadMapB; + using IteratorB1 = + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, layout::TensorCxRSKx, + ThreadMapB1 + >; + + using SmemIteratorB1 = typename MmaCore1::SmemIteratorB; + + + // Warp-level GEMM components + using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp; + using MmaPolicy0 = typename MmaCore0::MmaPolicy; + using MmaPolicy1 = typename MmaCore1::MmaPolicy; + + // Define the Mma + using B2bMma = threadblock::B2bImplicitGemmMultistage< + ThreadblockShape0, + IteratorA0, + SmemIteratorA0, + arch::CacheOperation::Always, + IteratorB0, + SmemIteratorB0, + arch::CacheOperation::Global, + ThreadblockShape1, + FragmentIteratorA1, + IteratorAccumulatorScaleBias, + FragmentIteratorA1ScaleBias, + IteratorB1, + SmemIteratorB1, + arch::CacheOperation::Global, + EpilogueOutputOp0, + MmaPolicy0, + MmaPolicy1, + Stages + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultInterleavedConvEpilogue< + ThreadblockShape1, + WarpMmaTensorOp1, + 1, + EpilogueOutputOp1, + EpilogueOutputOp1::kCount, + InterleavedK + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution< + B2bMma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm and +/// multistage pipeline. 
+template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape0, + typename ThreadblockShape1, + typename WarpShape0, + typename WarpShape1, + typename InstructionShape, + typename EpilogueOutputOp0, + typename EpilogueOutputOp1, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag +> +struct DefaultB2bConv2dFprop < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape0, + ThreadblockShape1, + WarpShape0, + WarpShape1, + InstructionShape, + EpilogueOutputOp0, + EpilogueOutputOp1, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kOptimized +> { + + // Define the core components from GEMM + using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + Stages, MathOperatorTag>; + using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA0 = typename MmaCore0::IteratorThreadMapA; + using IteratorA0 = + cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, LayoutA, + ThreadMapA0 + >; + + using SmemIteratorA0 = typename MmaCore0::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB0 = typename MmaCore0::IteratorThreadMapB; + using IteratorB0 = + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, LayoutB, + ThreadMapB0 + >; + + using SmemIteratorB0 = typename MmaCore0::SmemIteratorB; + + // Use fragment iterator for A operand + using AccumulatorLayout = cutlass::layout::ColumnMajor; + using FragmentIteratorA1 = + cutlass::gemm::warp::MmaTensorOpFragmentIterator< + cutlass::MatrixShape, //warp shape + cutlass::MatrixShape, //accumulator shape + MmaCore1::Shape::kK, //kBlocksColumn + ElementAccumulator, ElementA, AccumulatorLayout, InstructionShape, EpilogueOutputOp0>; + + /// Define iterators over tiles from scale/bias vectors + using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute; + using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter + static int const kElementsPerAccess = 2; + using IteratorAccumulatorScaleBias = + cutlass::transform::threadblock::VectorIterator< + cutlass::transform::threadblock::PredicatedVectorAccessIterator< + cutlass::MatrixShape, + cutlass::MatrixShape, + ElementScaleBias, LayoutScaleBias, kElementsPerAccess> + >; + + // Warp-level iterators to load scale and bias vectors + using FragmentIteratorA1ScaleBias = cutlass::transform::warp::VectorFragmentIterator< + MatrixShape<1, IteratorAccumulatorScaleBias::Fragment::kElements>, ElementScaleBias, + LayoutScaleBias, InstructionShape, kElementsPerAccess>; + + // Define iterators over tiles from the B operand + using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB; + using IteratorB1 = + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + 
ElementB, LayoutB, + ThreadMapB1 + >; + + using SmemIteratorB1 = typename MmaCore1::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp; + using MmaPolicy0 = typename MmaCore0::MmaPolicy; + using MmaPolicy1 = typename MmaCore1::MmaPolicy; + + // Define the Mma + using B2bMma = threadblock::B2bImplicitGemmMultistage< + ThreadblockShape0, + IteratorA0, + SmemIteratorA0, + arch::CacheOperation::Always, + IteratorB0, + SmemIteratorB0, + arch::CacheOperation::Global, + ThreadblockShape1, + FragmentIteratorA1, + IteratorAccumulatorScaleBias, + FragmentIteratorA1ScaleBias, + IteratorB1, + SmemIteratorB1, + arch::CacheOperation::Global, + EpilogueOutputOp0, + MmaPolicy0, + MmaPolicy1, + Stages + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< + ThreadblockShape1, + WarpMmaTensorOp1, + 1, + EpilogueOutputOp1, + EpilogueOutputOp1::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution< + B2bMma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm and +// multistage pipeline with interleaved layout. +template < + typename ElementA, + typename ElementB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape0, + typename ThreadblockShape1, + typename WarpShape0, + typename WarpShape1, + typename InstructionShape, + typename EpilogueOutputOp0, + typename EpilogueOutputOp1, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + int InterleavedK +> +struct DefaultB2bConv2dFprop < + ElementA, + layout::TensorNCxHWx, + ElementB, + layout::TensorCxRSKx, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape0, + ThreadblockShape1, + WarpShape0, + WarpShape1, + InstructionShape, + EpilogueOutputOp0, + EpilogueOutputOp1, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kOptimized +> { + + // Define the core components from GEMM + using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::ColumnMajorInterleaved, + ElementB, layout::RowMajorInterleaved, + ElementAccumulator, LayoutC, arch::OpClassTensorOp, + Stages, MathOperatorTag, true>; + using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::ColumnMajorInterleaved, + ElementB, layout::RowMajorInterleaved, + ElementAccumulator, LayoutC, arch::OpClassTensorOp, + Stages, MathOperatorTag, true>; + + // Define iterators over tiles from the A operand + // Note GEMM shared memory threadmap is used here because conv global memory + // layout needs to be mapped to fprop which is similar to the crosswise + // layout which is used by the interleaved GEMM shared memory threadmap. + // The Interleaved GEMM global memory layout is similar to the congruous + // layout. 
+ using ThreadMapA0 = typename MmaCore0::SmemThreadMapA; + using IteratorA0 = + cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, layout::TensorNCxHWx, + ThreadMapA0 + >; + + using SmemIteratorA0 = typename MmaCore0::SmemIteratorA; + + // Define iterators over tiles from the B operand + // Note GEMM shared memory threadmap is used here because conv global memory + // layout needs to be mapped to fprop which is similar to the crosswise + // layout which is used by the interleaved GEMM shared memory threadmap. + // The Interleaved GEMM global memory layout is similar to the congruous + // layout. + using ThreadMapB0 = typename MmaCore0::SmemThreadMapB; + using IteratorB0 = + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, layout::TensorCxRSKx, + ThreadMapB0 + >; + + using SmemIteratorB0 = typename MmaCore0::SmemIteratorB; + + // Use fragment iterator for A operand + using AccumulatorLayout = cutlass::layout::RowMajor; + using FragmentIteratorA1 = + cutlass::gemm::warp::MmaTensorOpFragmentIterator< + cutlass::MatrixShape, //warp shape + cutlass::MatrixShape, //accumulator shape + MmaCore1::Shape::kK, //kBlocksColumn + ElementAccumulator, ElementA, AccumulatorLayout, InstructionShape, EpilogueOutputOp0>; + + /// Define iterators over tiles from scale/bias vectors + using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute; + using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter + static int const kElementsPerAccess = 4; + using IteratorAccumulatorScaleBias = + cutlass::transform::threadblock::VectorIterator< + cutlass::transform::threadblock::PredicatedVectorAccessIterator< + cutlass::MatrixShape, + cutlass::MatrixShape, + ElementScaleBias, LayoutScaleBias, kElementsPerAccess> + >; + + // Warp-level iterators to load scale and bias vectors + using FragmentIteratorA1ScaleBias = cutlass::transform::warp::VectorFragmentIterator< + MatrixShape<1, IteratorAccumulatorScaleBias::Fragment::kElements>, ElementScaleBias, + LayoutScaleBias, InstructionShape, kElementsPerAccess>; + + using ThreadMapB1 = typename MmaCore1::SmemThreadMapB; + using IteratorB1 = + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, layout::TensorCxRSKx, + ThreadMapB1 + >; + + using SmemIteratorB1 = typename MmaCore1::SmemIteratorB; + + + // Warp-level GEMM components + using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp; + using MmaPolicy0 = typename MmaCore0::MmaPolicy; + using MmaPolicy1 = typename MmaCore1::MmaPolicy; + + // Define the Mma + using B2bMma = threadblock::B2bImplicitGemmMultistage< + ThreadblockShape0, + IteratorA0, + SmemIteratorA0, + arch::CacheOperation::Always, + IteratorB0, + SmemIteratorB0, + arch::CacheOperation::Global, + ThreadblockShape1, + FragmentIteratorA1, + IteratorAccumulatorScaleBias, + FragmentIteratorA1ScaleBias, + IteratorB1, + SmemIteratorB1, + arch::CacheOperation::Global, + EpilogueOutputOp0, + MmaPolicy0, + MmaPolicy1, + Stages + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultInterleavedConvEpilogue< + ThreadblockShape1, + WarpMmaTensorOp1, + 1, + EpilogueOutputOp1, + EpilogueOutputOp1::kCount, + InterleavedK + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution< + B2bMma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop + >; +}; + 
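+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+// Usage sketch (illustrative only, compiled out). Relative to the SM75 sketch earlier in this
+// diff, the specializations in this file are selected by passing arch::Sm80, an SM80 instruction
+// shape, and Stages > 2, which routes the mainloop to B2bImplicitGemmMultistage; supplying
+// interleaved TensorNCxHWx / TensorCxRSKx layouts for A and B would instead pick the interleaved
+// specializations above. All concrete types and shapes below are assumptions for illustration.
+#if 0
+using ElementInput       = cutlass::half_t;
+using ElementOutput      = cutlass::half_t;
+using ElementAccumulator = float;
+using InstructionShape   = cutlass::gemm::GemmShape<16, 8, 16>;  // SM80 Tensor Core MMA shape
+
+using B2bFpropKernelSm80 = cutlass::conv::kernel::DefaultB2bConv2dFprop<
+    ElementInput,  cutlass::layout::TensorNHWC,
+    ElementInput,  cutlass::layout::TensorNHWC,
+    ElementOutput, cutlass::layout::TensorNHWC,
+    ElementAccumulator,
+    cutlass::arch::OpClassTensorOp,
+    cutlass::arch::Sm80,
+    cutlass::gemm::GemmShape<64, 64, 32>,      // threadblock tile, convolution 0
+    cutlass::gemm::GemmShape<64, 128, 32>,     // threadblock tile, convolution 1
+    cutlass::gemm::GemmShape<32, 64, 32>,      // warp tile, convolution 0
+    cutlass::gemm::GemmShape<32, 128, 32>,     // warp tile, convolution 1
+    InstructionShape,
+    cutlass::epilogue::thread::LinearCombinationRelu<
+        ElementOutput, InstructionShape::kM * InstructionShape::kN / 32,
+        ElementAccumulator, ElementAccumulator,
+        cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling>,   // fused scale/bias + ReLU
+    cutlass::epilogue::thread::LinearCombinationRelu<
+        ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
+        ElementAccumulator, ElementAccumulator>,                   // final epilogue
+    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
+    3,                                         // Stages > 2 selects the multistage pipeline
+    cutlass::arch::OpMultiplyAdd,
+    cutlass::conv::IteratorAlgorithm::kOptimized
+>::Kernel;
+#endif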
+///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop_smem_accumulator_sm75.h b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop_smem_accumulator_sm75.h new file mode 100644 index 0000000000000000000000000000000000000000..35a5681dd429f2497a20820dfaafa297c2803c39 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop_smem_accumulator_sm75.h @@ -0,0 +1,817 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief + Default kernel-level implicit GEMM convolution definitions combine threadblock-scoped + matrix multiply-add with the appropriate threadblock-scoped epilogue. 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/conv/kernel/default_conv2d.h" + +#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h" +#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h" +#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h" +#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h" + +#include "cutlass/transform/threadblock/predicated_vector_access_iterator.h" +#include "cutlass/transform/threadblock/vector_iterator.h" +#include "cutlass/transform/warp/vector_fragment_iterator.h" + +#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h" + +#include "kernel/default_b2b_conv2d_fprop.h" +#include "kernel/b2b_implicit_gemm_convolution.h" +#include "threadblock/b2b_implicit_gemm_pipelined_smem_accumulator.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm +/// and 2 stage pipeline. +/// Accumulator will be staged in shared memory. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape0, + typename ThreadblockShape1, + typename WarpShape0, + typename WarpShape1, + typename InstructionShape, + typename EpilogueOutputOp0, + typename EpilogueOutputOp1, + typename ThreadblockSwizzle, + typename MathOperatorTag +> +struct DefaultB2bConv2dFprop < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape0, + ThreadblockShape1, + WarpShape0, + WarpShape1, + InstructionShape, + EpilogueOutputOp0, + EpilogueOutputOp1, + ThreadblockSwizzle, + 2, + MathOperatorTag, + IteratorAlgorithm::kAnalytic, + true +> { + + // Define the core components from GEMM + using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + 2, MathOperatorTag>; + using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + 2, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA0 = typename MmaCore0::IteratorThreadMapA; + using IteratorA0 = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, LayoutA, + ThreadMapA0 + > + >; + + using SmemIteratorA0 = typename MmaCore0::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB0 = typename MmaCore0::IteratorThreadMapB; + using IteratorB0 = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, LayoutB, + ThreadMapB0 + > + >; + + using SmemIteratorB0 = typename MmaCore0::SmemIteratorB; + + /// Define iterators over tiles from 
scale/bias vectors + using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute; + using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter + static int const kElementsPerAccess = 2; + using IteratorAccumulatorScaleBias = + cutlass::transform::threadblock::VectorIterator< + cutlass::transform::threadblock::PredicatedVectorAccessIterator< + cutlass::MatrixShape, + cutlass::MatrixShape, + ElementScaleBias, LayoutScaleBias, kElementsPerAccess> + >; + + // Define iterators over tiles from the B operand + using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB; + using IteratorB1 = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, LayoutB, + ThreadMapB1 + > + >; + + using SmemIteratorB1 = typename MmaCore1::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp0 = typename MmaCore0::MmaTensorOp; + using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp; + using MmaPolicy0 = typename MmaCore0::MmaPolicy; + using MmaPolicy1 = typename MmaCore1::MmaPolicy; + + // Use fragment iterator for the accumulator + using SmemAccumulatorLayout = cutlass::layout::RowMajor; + using FragmentIteratorAccumulator = cutlass::epilogue::warp::FragmentIteratorTensorOp< + WarpShape0, InstructionShape, + ElementAccumulator, + typename WarpMmaTensorOp0::Policy::Operator::FragmentC, + SmemAccumulatorLayout + >; + + // Store Accumulator tiles to Shared Memory + using SmemIteratorD0 = + cutlass::epilogue::warp::TileIteratorTensorOp< + WarpShape0, + InstructionShape, + ElementC, + SmemAccumulatorLayout + >; + + static int const kThreadCount = 32; + // load warp tile from Shared Memory accumulator + using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileIterator< + MatrixShape, cutlass::gemm::Operand::kA, + ElementA, SmemAccumulatorLayout, + MatrixShape, + WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount>; + + // Define the Mma + using B2bMma = threadblock::B2bImplicitGemmPipelinedSmemAccumulator< + ThreadblockShape0, + IteratorA0, + SmemIteratorA0, + IteratorB0, + SmemIteratorB0, + IteratorAccumulatorScaleBias, + FragmentIteratorAccumulator, + SmemIteratorD0, + ThreadblockShape1, + WarpIteratorA1, + IteratorB1, + SmemIteratorB1, + ElementC, + LayoutC, + EpilogueOutputOp0, + MmaPolicy0, + MmaPolicy1 + >; + + // Define the epilogue + using Epilogue = typename detail::DefaultConvEpilogue< + ArchTag, + ThreadblockShape1, + WarpMmaTensorOp1, + 1, + EpilogueOutputOp1 + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution< + B2bMma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and 2 stage +/// pipeline with interleaved layout. +/// Accumulator will be staged in shared memory. 
+template < + typename ElementA, + typename ElementB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape0, + typename ThreadblockShape1, + typename WarpShape0, + typename WarpShape1, + typename InstructionShape, + typename EpilogueOutputOp0, + typename EpilogueOutputOp1, + typename ThreadblockSwizzle, + typename MathOperatorTag, + int InterleavedK +> +struct DefaultB2bConv2dFprop < + ElementA, + layout::TensorNCxHWx, + ElementB, + layout::TensorCxRSKx, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape0, + ThreadblockShape1, + WarpShape0, + WarpShape1, + InstructionShape, + EpilogueOutputOp0, + EpilogueOutputOp1, + ThreadblockSwizzle, + 2, + MathOperatorTag, + IteratorAlgorithm::kAnalytic, + true +> { + + // Define the core components from GEMM + using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::ColumnMajorInterleaved, + ElementB, layout::RowMajorInterleaved, + ElementAccumulator, LayoutC, arch::OpClassTensorOp, + 2, MathOperatorTag, true>; + using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::ColumnMajorInterleaved, + ElementB, layout::RowMajorInterleaved, + ElementAccumulator, LayoutC, arch::OpClassTensorOp, + 2, MathOperatorTag, true>; + + // Define iterators over tiles from the A operand + // Note GEMM shared memory threadmap is used here because conv global memory + // layout needs to be mapped to fprop which is similar to the crosswise + // layout which is used by the interleaved GEMM shared memory threadmap. + // The Interleaved GEMM global memory layout is similar to the congruous + // layout. + using ThreadMapA0 = typename MmaCore0::SmemThreadMapA; + using IteratorA0 = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, layout::TensorNCxHWx, + ThreadMapA0 + > + >; + + using SmemIteratorA0 = typename MmaCore0::SmemIteratorA; + + // Define iterators over tiles from the B operand + // Note GEMM shared memory threadmap is used here because conv global memory + // layout needs to be mapped to fprop which is similar to the crosswise + // layout which is used by the interleaved GEMM shared memory threadmap. + // The Interleaved GEMM global memory layout is similar to the congruous + // layout. 
+ using ThreadMapB0 = typename MmaCore0::SmemThreadMapB; + using IteratorB0 = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, layout::TensorCxRSKx, + ThreadMapB0 + > + >; + + using SmemIteratorB0 = typename MmaCore0::SmemIteratorB; + + /// Define iterators over tiles from scale/bias vectors + using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute; + using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter + static int const kElementsPerAccess = 4; //For interleaved layout + using IteratorAccumulatorScaleBias = + cutlass::transform::threadblock::VectorIterator< + cutlass::transform::threadblock::PredicatedVectorAccessIterator< + cutlass::MatrixShape, + cutlass::MatrixShape, + ElementScaleBias, LayoutScaleBias, kElementsPerAccess> + >; + + // Define iterators over tiles from the B operand + using ThreadMapB1 = typename MmaCore1::SmemThreadMapB; + using IteratorB1 = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, layout::TensorCxRSKx, + ThreadMapB1 + > + >; + + using SmemIteratorB1 = typename MmaCore1::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp0 = typename MmaCore0::MmaTensorOp; + using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp; + using MmaPolicy0 = typename MmaCore0::MmaPolicy; + using MmaPolicy1 = typename MmaCore1::MmaPolicy; + + // Use fragment iterator for the accumulator + using SmemAccumulatorLayout = cutlass::layout::ColumnMajorInterleaved<16>; + using FragmentIteratorAccumulator = cutlass::epilogue::warp::FragmentIteratorTensorOp< + WarpShape0, InstructionShape, + ElementAccumulator, + typename WarpMmaTensorOp0::Policy::Operator::FragmentC, + SmemAccumulatorLayout + >; + + + // Store Accumulator tiles to Shared Memory + using SmemIteratorD0 = + cutlass::epilogue::warp::TileIteratorTensorOp< + WarpShape0, + InstructionShape, + ElementC, + SmemAccumulatorLayout + >; + + static int const kThreadCount = 32; + // load warp tile from Shared Memory accumulator + using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileIteratorCanonical< + MatrixShape, cutlass::gemm::Operand::kA, + ElementA, SmemAccumulatorLayout, + MatrixShape, + WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount>; + + // Define the Mma + using B2bMma = threadblock::B2bImplicitGemmPipelinedSmemAccumulator< + ThreadblockShape0, + IteratorA0, + SmemIteratorA0, + IteratorB0, + SmemIteratorB0, + IteratorAccumulatorScaleBias, + FragmentIteratorAccumulator, + SmemIteratorD0, + ThreadblockShape1, + WarpIteratorA1, + IteratorB1, + SmemIteratorB1, + ElementC, + LayoutC, + EpilogueOutputOp0, + MmaPolicy0, + MmaPolicy1 + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultInterleavedConvEpilogue< + ThreadblockShape1, + WarpMmaTensorOp1, + 1, + EpilogueOutputOp1, + EpilogueOutputOp1::kCount, + InterleavedK + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution< + B2bMma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm +/// and 2 stage pipeline. +/// Accumulator will be staged in shared memory. 
+template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape0, + typename ThreadblockShape1, + typename WarpShape0, + typename WarpShape1, + typename InstructionShape, + typename EpilogueOutputOp0, + typename EpilogueOutputOp1, + typename ThreadblockSwizzle, + typename MathOperatorTag +> +struct DefaultB2bConv2dFprop < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape0, + ThreadblockShape1, + WarpShape0, + WarpShape1, + InstructionShape, + EpilogueOutputOp0, + EpilogueOutputOp1, + ThreadblockSwizzle, + 2, + MathOperatorTag, + IteratorAlgorithm::kOptimized, + true +> { + + // Define the core components from GEMM + using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + 2, MathOperatorTag>; + using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + 2, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA0 = typename MmaCore0::IteratorThreadMapA; + using IteratorA0 = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, LayoutA, + ThreadMapA0 + > + >; + + using SmemIteratorA0 = typename MmaCore0::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB0 = typename MmaCore0::IteratorThreadMapB; + using IteratorB0 = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, LayoutB, + ThreadMapB0 + > + >; + + using SmemIteratorB0 = typename MmaCore0::SmemIteratorB; + + /// Define iterators over tiles from scale/bias vectors + using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute; + using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter + static int const kElementsPerAccess = 2; + using IteratorAccumulatorScaleBias = + cutlass::transform::threadblock::VectorIterator< + cutlass::transform::threadblock::PredicatedVectorAccessIterator< + cutlass::MatrixShape, + cutlass::MatrixShape, + ElementScaleBias, LayoutScaleBias, kElementsPerAccess> + >; + + // Define iterators over tiles from the B operand + using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB; + using IteratorB1 = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, LayoutB, + ThreadMapB1 + > + >; + + using SmemIteratorB1 = typename MmaCore1::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp0 = typename MmaCore0::MmaTensorOp; + using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp; + using MmaPolicy0 = typename MmaCore0::MmaPolicy; + using MmaPolicy1 = typename MmaCore1::MmaPolicy; + + // Use fragment iterator for the accumulator + using SmemAccumulatorLayout = cutlass::layout::RowMajor; + using FragmentIteratorAccumulator = cutlass::epilogue::warp::FragmentIteratorTensorOp< + WarpShape0, 
InstructionShape, + ElementAccumulator, + typename WarpMmaTensorOp0::Policy::Operator::FragmentC, + SmemAccumulatorLayout + >; + + // Store Accumulator tiles to Shared Memory + using SmemIteratorD0 = + cutlass::epilogue::warp::TileIteratorTensorOp< + WarpShape0, + InstructionShape, + ElementC, + SmemAccumulatorLayout + >; + + static int const kThreadCount = 32; + // load warp tile from Shared Memory accumulator + using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileIterator< + MatrixShape, cutlass::gemm::Operand::kA, + ElementA, SmemAccumulatorLayout, + MatrixShape, + WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount>; + + // Define the Mma + using B2bMma = threadblock::B2bImplicitGemmPipelinedSmemAccumulator< + ThreadblockShape0, + IteratorA0, + SmemIteratorA0, + IteratorB0, + SmemIteratorB0, + IteratorAccumulatorScaleBias, + FragmentIteratorAccumulator, + SmemIteratorD0, + ThreadblockShape1, + WarpIteratorA1, + IteratorB1, + SmemIteratorB1, + ElementC, + LayoutC, + EpilogueOutputOp0, + MmaPolicy0, + MmaPolicy1 + >; + + // Define the epilogue + using Epilogue = typename detail::DefaultConvEpilogue< + ArchTag, + ThreadblockShape1, + WarpMmaTensorOp1, + 1, + EpilogueOutputOp1 + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution< + B2bMma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm and 2 stage +/// pipeline with interleaved layout. +/// Accumulator will be staged in shared memory. +template < + typename ElementA, + typename ElementB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape0, + typename ThreadblockShape1, + typename WarpShape0, + typename WarpShape1, + typename InstructionShape, + typename EpilogueOutputOp0, + typename EpilogueOutputOp1, + typename ThreadblockSwizzle, + typename MathOperatorTag, + int InterleavedK +> +struct DefaultB2bConv2dFprop < + ElementA, + layout::TensorNCxHWx, + ElementB, + layout::TensorCxRSKx, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape0, + ThreadblockShape1, + WarpShape0, + WarpShape1, + InstructionShape, + EpilogueOutputOp0, + EpilogueOutputOp1, + ThreadblockSwizzle, + 2, + MathOperatorTag, + IteratorAlgorithm::kOptimized, + true +> { + + // Define the core components from GEMM + using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::ColumnMajorInterleaved, + ElementB, layout::RowMajorInterleaved, + ElementAccumulator, LayoutC, arch::OpClassTensorOp, + 2, MathOperatorTag, true>; + using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::ColumnMajorInterleaved, + ElementB, layout::RowMajorInterleaved, + ElementAccumulator, LayoutC, arch::OpClassTensorOp, + 2, MathOperatorTag, true>; + + // Define iterators over tiles from the A operand + // Note GEMM shared memory threadmap is used here because conv global memory + // layout needs to be mapped to fprop which is similar to the crosswise + // layout which is used by the interleaved GEMM shared memory threadmap. + // The Interleaved GEMM global memory layout is similar to the congruous + // layout. 
+ + // Define iterators over tiles from the A operand + using ThreadMapA0 = typename MmaCore0::SmemThreadMapA; + using IteratorA0 = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, layout::TensorNCxHWx, + ThreadMapA0 + > + >; + + using SmemIteratorA0 = typename MmaCore0::SmemIteratorA; + + // Define iterators over tiles from the B operand + // Note GEMM shared memory threadmap is used here because conv global memory + // layout needs to be mapped to fprop which is similar to the crosswise + // layout which is used by the interleaved GEMM shared memory threadmap. + // The Interleaved GEMM global memory layout is similar to the congruous + // layout. + using ThreadMapB0 = typename MmaCore0::SmemThreadMapB; + using IteratorB0 = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, layout::TensorCxRSKx, + ThreadMapB0 + > + >; + + using SmemIteratorB0 = typename MmaCore0::SmemIteratorB; + + /// Define iterators over tiles from scale/bias vectors + using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute; + using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter + static int const kElementsPerAccess = 4; //For interleaved layout + using IteratorAccumulatorScaleBias = + cutlass::transform::threadblock::VectorIterator< + cutlass::transform::threadblock::PredicatedVectorAccessIterator< + cutlass::MatrixShape, + cutlass::MatrixShape, + ElementScaleBias, LayoutScaleBias, kElementsPerAccess> + >; + + using ThreadMapB1 = typename MmaCore1::SmemThreadMapB; + using IteratorB1 = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, layout::TensorCxRSKx, + ThreadMapB1 + > + >; + + using SmemIteratorB1 = typename MmaCore1::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp0 = typename MmaCore0::MmaTensorOp; + using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp; + using MmaPolicy0 = typename MmaCore0::MmaPolicy; + using MmaPolicy1 = typename MmaCore1::MmaPolicy; + + // Use fragment iterator for the accumulator + using SmemAccumulatorLayout = cutlass::layout::ColumnMajorInterleaved<16>; + using FragmentIteratorAccumulator = cutlass::epilogue::warp::FragmentIteratorTensorOp< + WarpShape0, InstructionShape, + ElementAccumulator, + typename WarpMmaTensorOp0::Policy::Operator::FragmentC, + SmemAccumulatorLayout + >; + + + // Store Accumulator tiles to Shared Memory + using SmemIteratorD0 = + cutlass::epilogue::warp::TileIteratorTensorOp< + WarpShape0, + InstructionShape, + ElementC, + SmemAccumulatorLayout + >; + + static int const kThreadCount = 32; + // load warp tile from Shared Memory accumulator + using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileIteratorCanonical< + MatrixShape, cutlass::gemm::Operand::kA, + ElementA, SmemAccumulatorLayout, + MatrixShape, + WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount>; + + // Define the Mma + using B2bMma = threadblock::B2bImplicitGemmPipelinedSmemAccumulator< + ThreadblockShape0, + IteratorA0, + SmemIteratorA0, + IteratorB0, + SmemIteratorB0, + IteratorAccumulatorScaleBias, + FragmentIteratorAccumulator, + SmemIteratorD0, + ThreadblockShape1, + WarpIteratorA1, + IteratorB1, + SmemIteratorB1, + ElementC, + LayoutC, + EpilogueOutputOp0, + MmaPolicy0, + MmaPolicy1 + >; + + // 
Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultInterleavedConvEpilogue< + ThreadblockShape1, + WarpMmaTensorOp1, + 1, + EpilogueOutputOp1, + EpilogueOutputOp1::kCount, + InterleavedK + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution< + B2bMma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop_smem_accumulator_sm80.h b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop_smem_accumulator_sm80.h new file mode 100644 index 0000000000000000000000000000000000000000..ac80be8e77989bb6b53b35cb5f6a67004040aa87 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop_smem_accumulator_sm80.h @@ -0,0 +1,804 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief + Default kernel-level implicit GEMM convolution definitions combine threadblock-scoped + matrix multiply-add with the appropriate threadblock-scoped epilogue. 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/conv/kernel/default_conv2d.h" + +#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h" +#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h" +#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h" +#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h" + +#include "cutlass/transform/threadblock/predicated_vector_access_iterator.h" +#include "cutlass/transform/threadblock/vector_iterator.h" +#include "cutlass/transform/warp/vector_fragment_iterator.h" + +#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h" + +#include "kernel/default_b2b_conv2d_fprop.h" +#include "kernel/b2b_implicit_gemm_convolution.h" +#include "threadblock/b2b_implicit_gemm_multistage_smem_accumulator.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and multistage +/// pipeline. +/// Accumulator will be staged in shared memory. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape0, + typename ThreadblockShape1, + typename WarpShape0, + typename WarpShape1, + typename InstructionShape, + typename EpilogueOutputOp0, + typename EpilogueOutputOp1, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag +> +struct DefaultB2bConv2dFprop < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape0, + ThreadblockShape1, + WarpShape0, + WarpShape1, + InstructionShape, + EpilogueOutputOp0, + EpilogueOutputOp1, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kAnalytic, + true +> { + + // Define the core components from GEMM + using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + Stages, MathOperatorTag>; + using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA0 = typename MmaCore0::IteratorThreadMapA; + using IteratorA0 = + cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, LayoutA, + ThreadMapA0 + >; + + using SmemIteratorA0 = typename MmaCore0::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB0 = typename MmaCore0::IteratorThreadMapB; + using IteratorB0 = + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, LayoutB, + ThreadMapB0 + >; + + using SmemIteratorB0 = typename MmaCore0::SmemIteratorB; + + /// Define iterators over tiles from scale/bias vectors + using ElementScaleBias = typename 
EpilogueOutputOp0::ElementCompute; + using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter + static int const kElementsPerAccess = 2; + using IteratorAccumulatorScaleBias = + cutlass::transform::threadblock::VectorIterator< + cutlass::transform::threadblock::PredicatedVectorAccessIterator< + cutlass::MatrixShape, + cutlass::MatrixShape, + ElementScaleBias, LayoutScaleBias, kElementsPerAccess> + >; + + // Define iterators over tiles from the B operand + using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB; + using IteratorB1 = + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, LayoutB, + ThreadMapB1 + >; + + using SmemIteratorB1 = typename MmaCore1::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp0 = typename MmaCore0::MmaTensorOp; + using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp; + using MmaPolicy0 = typename MmaCore0::MmaPolicy; + using MmaPolicy1 = typename MmaCore1::MmaPolicy; + + // Use fragment iterator for the accumulator + using SmemAccumulatorLayout = cutlass::layout::RowMajor; + using FragmentIteratorAccumulator = cutlass::epilogue::warp::FragmentIteratorTensorOp< + WarpShape0, InstructionShape, + ElementAccumulator, + typename WarpMmaTensorOp0::Policy::Operator::FragmentC, + SmemAccumulatorLayout + >; + + // Store Accumulator tiles to Shared Memory + using SmemIteratorD0 = + cutlass::epilogue::warp::TileIteratorTensorOp< + WarpShape0, + InstructionShape, + ElementC, + SmemAccumulatorLayout + >; + + static int const kThreadCount = 32; + // load warp tile from Shared Memory accumulator + using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileIterator< + MatrixShape, cutlass::gemm::Operand::kA, + ElementA, SmemAccumulatorLayout, + MatrixShape, + WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount>; + + // Define the Mma + using B2bMma = threadblock::B2bImplicitGemmMultistageSmemAccumulator< + ThreadblockShape0, + IteratorA0, + SmemIteratorA0, + arch::CacheOperation::Always, + IteratorB0, + SmemIteratorB0, + arch::CacheOperation::Global, + IteratorAccumulatorScaleBias, + FragmentIteratorAccumulator, + SmemIteratorD0, + ThreadblockShape1, + WarpIteratorA1, + IteratorB1, + SmemIteratorB1, + arch::CacheOperation::Global, + EpilogueOutputOp0, + MmaPolicy0, + MmaPolicy1, + Stages + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< + ThreadblockShape1, + WarpMmaTensorOp1, + 1, + EpilogueOutputOp1, + EpilogueOutputOp1::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution< + B2bMma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and multistage +/// pipeline with interleaved layout. +/// Accumulator will be staged in shared memory. 
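+///
+/// A rough composition sketch (not a definitive configuration): the element types and the
+/// interleave factor below are illustrative placeholders, and ThreadblockShape0/1,
+/// WarpShape0/1, InstructionShape, the epilogue output operators, the swizzle, and
+/// MathOperatorTag stand for user-chosen types. The specialization is reached through the
+/// same DefaultB2bConv2dFprop entry point, with interleaved tensor layouts and the trailing
+/// shared-memory-accumulator flag set to true:
+///
+///   using B2bFpropKernel = typename cutlass::conv::kernel::DefaultB2bConv2dFprop<
+///       int8_t,  cutlass::layout::TensorNCxHWx<32>,   // A: activations (placeholder)
+///       int8_t,  cutlass::layout::TensorCxRSKx<32>,   // B: filters (placeholder)
+///       int8_t,  cutlass::layout::TensorNCxHWx<32>,   // C/D: output (placeholder)
+///       int32_t,                                       // accumulator
+///       cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
+///       ThreadblockShape0, ThreadblockShape1,
+///       WarpShape0, WarpShape1, InstructionShape,
+///       EpilogueOutputOp0, EpilogueOutputOp1, ThreadblockSwizzle,
+///       /*Stages=*/3, MathOperatorTag,
+///       cutlass::conv::IteratorAlgorithm::kAnalytic,
+///       /*SmemAccumulator=*/true>::Kernel;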
+template < + typename ElementA, + typename ElementB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape0, + typename ThreadblockShape1, + typename WarpShape0, + typename WarpShape1, + typename InstructionShape, + typename EpilogueOutputOp0, + typename EpilogueOutputOp1, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + int InterleavedK +> +struct DefaultB2bConv2dFprop < + ElementA, + layout::TensorNCxHWx, + ElementB, + layout::TensorCxRSKx, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape0, + ThreadblockShape1, + WarpShape0, + WarpShape1, + InstructionShape, + EpilogueOutputOp0, + EpilogueOutputOp1, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kAnalytic, + true +> { + + // Define the core components from GEMM + using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::ColumnMajorInterleaved, + ElementB, layout::RowMajorInterleaved, + ElementAccumulator, LayoutC, arch::OpClassTensorOp, + Stages, MathOperatorTag, true>; + using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::ColumnMajorInterleaved, + ElementB, layout::RowMajorInterleaved, + ElementAccumulator, LayoutC, arch::OpClassTensorOp, + Stages, MathOperatorTag, true>; + + // Define iterators over tiles from the A operand + // Note GEMM shared memory threadmap is used here because conv global memory + // layout needs to be mapped to fprop which is similar to the crosswise + // layout which is used by the interleaved GEMM shared memory threadmap. + // The Interleaved GEMM global memory layout is similar to the congruous + // layout. + using ThreadMapA0 = typename MmaCore0::SmemThreadMapA; + using IteratorA0 = + cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, layout::TensorNCxHWx, + ThreadMapA0 + >; + + using SmemIteratorA0 = typename MmaCore0::SmemIteratorA; + + // Define iterators over tiles from the B operand + // Note GEMM shared memory threadmap is used here because conv global memory + // layout needs to be mapped to fprop which is similar to the crosswise + // layout which is used by the interleaved GEMM shared memory threadmap. + // The Interleaved GEMM global memory layout is similar to the congruous + // layout. 
+ using ThreadMapB0 = typename MmaCore0::SmemThreadMapB; + using IteratorB0 = + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, layout::TensorCxRSKx, + ThreadMapB0 + >; + + using SmemIteratorB0 = typename MmaCore0::SmemIteratorB; + + /// Define iterators over tiles from scale/bias vectors + using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute; + using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter + static int const kElementsPerAccess = 4; + using IteratorAccumulatorScaleBias = + cutlass::transform::threadblock::VectorIterator< + cutlass::transform::threadblock::PredicatedVectorAccessIterator< + cutlass::MatrixShape, + cutlass::MatrixShape, + ElementScaleBias, LayoutScaleBias, kElementsPerAccess> + >; + + using ThreadMapB1 = typename MmaCore1::SmemThreadMapB; + using IteratorB1 = + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, layout::TensorCxRSKx, + ThreadMapB1 + >; + + using SmemIteratorB1 = typename MmaCore1::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp0 = typename MmaCore0::MmaTensorOp; + using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp; + using MmaPolicy0 = typename MmaCore0::MmaPolicy; + using MmaPolicy1 = typename MmaCore1::MmaPolicy; + + // Use fragment iterator for the accumulator + using SmemAccumulatorLayout = cutlass::layout::ColumnMajorInterleaved<16>; + using FragmentIteratorAccumulator = cutlass::epilogue::warp::FragmentIteratorTensorOp< + WarpShape0, InstructionShape, + ElementAccumulator, + typename WarpMmaTensorOp0::Policy::Operator::FragmentC, + SmemAccumulatorLayout + >; + + + // Store Accumulator tiles to Shared Memory + using SmemIteratorD0 = + cutlass::epilogue::warp::TileIteratorTensorOp< + WarpShape0, + InstructionShape, + ElementC, + SmemAccumulatorLayout + >; + + static int const kThreadCount = 32; + // load warp tile from Shared Memory accumulator + using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileIteratorCanonical< + MatrixShape, cutlass::gemm::Operand::kA, + ElementA, SmemAccumulatorLayout, + MatrixShape, + WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount>; + + // Define the Mma + using B2bMma = threadblock::B2bImplicitGemmMultistageSmemAccumulator< + ThreadblockShape0, + IteratorA0, + SmemIteratorA0, + arch::CacheOperation::Always, + IteratorB0, + SmemIteratorB0, + arch::CacheOperation::Global, + IteratorAccumulatorScaleBias, + FragmentIteratorAccumulator, + SmemIteratorD0, + ThreadblockShape1, + WarpIteratorA1, + IteratorB1, + SmemIteratorB1, + arch::CacheOperation::Global, + EpilogueOutputOp0, + MmaPolicy0, + MmaPolicy1, + Stages + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultInterleavedConvEpilogue< + ThreadblockShape1, + WarpMmaTensorOp1, + 1, + EpilogueOutputOp1, + EpilogueOutputOp1::kCount, + InterleavedK + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution< + B2bMma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm and +/// multistage pipeline. +/// Accumulator will be staged in shared memory. 
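+///
+/// Relative to the kAnalytic specializations above, the only structural change in the
+/// kOptimized specializations is the use of the Conv2dFprop*TileAccessIteratorOptimized
+/// global-memory iterators, which rely on precomputed address deltas rather than
+/// re-deriving the implicit-GEMM coordinate mapping on every iteration; the shared-memory
+/// accumulator plumbing (FragmentIteratorAccumulator -> SmemIteratorD0 -> WarpIteratorA1)
+/// is unchanged.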
+template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape0, + typename ThreadblockShape1, + typename WarpShape0, + typename WarpShape1, + typename InstructionShape, + typename EpilogueOutputOp0, + typename EpilogueOutputOp1, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag +> +struct DefaultB2bConv2dFprop < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape0, + ThreadblockShape1, + WarpShape0, + WarpShape1, + InstructionShape, + EpilogueOutputOp0, + EpilogueOutputOp1, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kOptimized, + true +> { + + // Define the core components from GEMM + using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + Stages, MathOperatorTag>; + using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA0 = typename MmaCore0::IteratorThreadMapA; + using IteratorA0 = + cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, LayoutA, + ThreadMapA0 + >; + + using SmemIteratorA0 = typename MmaCore0::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB0 = typename MmaCore0::IteratorThreadMapB; + using IteratorB0 = + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, LayoutB, + ThreadMapB0 + >; + + using SmemIteratorB0 = typename MmaCore0::SmemIteratorB; + + /// Define iterators over tiles from scale/bias vectors + using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute; + using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter + static int const kElementsPerAccess = 2; + using IteratorAccumulatorScaleBias = + cutlass::transform::threadblock::VectorIterator< + cutlass::transform::threadblock::PredicatedVectorAccessIterator< + cutlass::MatrixShape, + cutlass::MatrixShape, + ElementScaleBias, LayoutScaleBias, kElementsPerAccess> + >; + + // Define iterators over tiles from the B operand + using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB; + using IteratorB1 = + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, LayoutB, + ThreadMapB1 + >; + + using SmemIteratorB1 = typename MmaCore1::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp0 = typename MmaCore0::MmaTensorOp; + using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp; + using MmaPolicy0 = typename MmaCore0::MmaPolicy; + using MmaPolicy1 = typename MmaCore1::MmaPolicy; + + // Use fragment iterator for the accumulator + using SmemAccumulatorLayout = cutlass::layout::RowMajor; + using FragmentIteratorAccumulator = cutlass::epilogue::warp::FragmentIteratorTensorOp< + WarpShape0, InstructionShape, + ElementAccumulator, + typename WarpMmaTensorOp0::Policy::Operator::FragmentC, + 
SmemAccumulatorLayout + >; + + // Store Accumulator tiles to Shared Memory + using SmemIteratorD0 = + cutlass::epilogue::warp::TileIteratorTensorOp< + WarpShape0, + InstructionShape, + ElementC, + SmemAccumulatorLayout + >; + + static int const kThreadCount = 32; + // load warp tile from Shared Memory accumulator + using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileIterator< + MatrixShape, cutlass::gemm::Operand::kA, + ElementA, SmemAccumulatorLayout, + MatrixShape, + WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount>; + + // Define the Mma + using B2bMma = threadblock::B2bImplicitGemmMultistageSmemAccumulator< + ThreadblockShape0, + IteratorA0, + SmemIteratorA0, + arch::CacheOperation::Always, + IteratorB0, + SmemIteratorB0, + arch::CacheOperation::Global, + IteratorAccumulatorScaleBias, + FragmentIteratorAccumulator, + SmemIteratorD0, + ThreadblockShape1, + WarpIteratorA1, + IteratorB1, + SmemIteratorB1, + arch::CacheOperation::Global, + EpilogueOutputOp0, + MmaPolicy0, + MmaPolicy1, + Stages + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< + ThreadblockShape1, + WarpMmaTensorOp1, + 1, + EpilogueOutputOp1, + EpilogueOutputOp1::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution< + B2bMma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm and +// multistage pipeline with interleaved layout. +/// Accumulator will be staged in shared memory. +template < + typename ElementA, + typename ElementB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape0, + typename ThreadblockShape1, + typename WarpShape0, + typename WarpShape1, + typename InstructionShape, + typename EpilogueOutputOp0, + typename EpilogueOutputOp1, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + int InterleavedK +> +struct DefaultB2bConv2dFprop < + ElementA, + layout::TensorNCxHWx, + ElementB, + layout::TensorCxRSKx, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape0, + ThreadblockShape1, + WarpShape0, + WarpShape1, + InstructionShape, + EpilogueOutputOp0, + EpilogueOutputOp1, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kOptimized, + true +> { + + // Define the core components from GEMM + using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape0, WarpShape0, InstructionShape, ElementA, layout::ColumnMajorInterleaved, + ElementB, layout::RowMajorInterleaved, + ElementAccumulator, LayoutC, arch::OpClassTensorOp, + Stages, MathOperatorTag, true>; + using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape1, WarpShape1, InstructionShape, ElementA, layout::ColumnMajorInterleaved, + ElementB, layout::RowMajorInterleaved, + ElementAccumulator, LayoutC, arch::OpClassTensorOp, + Stages, MathOperatorTag, true>; + + // Define iterators over tiles from the A operand + // Note GEMM shared memory threadmap is used here because conv global memory + // layout needs to be mapped to fprop which is similar to the crosswise + // layout which is used by the interleaved GEMM shared memory threadmap. 
+ // The Interleaved GEMM global memory layout is similar to the congruous + // layout. + using ThreadMapA0 = typename MmaCore0::SmemThreadMapA; + using IteratorA0 = + cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, layout::TensorNCxHWx, + ThreadMapA0 + >; + + using SmemIteratorA0 = typename MmaCore0::SmemIteratorA; + + // Define iterators over tiles from the B operand + // Note GEMM shared memory threadmap is used here because conv global memory + // layout needs to be mapped to fprop which is similar to the crosswise + // layout which is used by the interleaved GEMM shared memory threadmap. + // The Interleaved GEMM global memory layout is similar to the congruous + // layout. + using ThreadMapB0 = typename MmaCore0::SmemThreadMapB; + using IteratorB0 = + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, layout::TensorCxRSKx, + ThreadMapB0 + >; + + using SmemIteratorB0 = typename MmaCore0::SmemIteratorB; + + /// Define iterators over tiles from scale/bias vectors + using ElementScaleBias = typename EpilogueOutputOp0::ElementCompute; + using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter + static int const kElementsPerAccess = 4; + using IteratorAccumulatorScaleBias = + cutlass::transform::threadblock::VectorIterator< + cutlass::transform::threadblock::PredicatedVectorAccessIterator< + cutlass::MatrixShape, + cutlass::MatrixShape, + ElementScaleBias, LayoutScaleBias, kElementsPerAccess> + >; + + using ThreadMapB1 = typename MmaCore1::SmemThreadMapB; + using IteratorB1 = + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, layout::TensorCxRSKx, + ThreadMapB1 + >; + + using SmemIteratorB1 = typename MmaCore1::SmemIteratorB; + + + // Warp-level GEMM components + using WarpMmaTensorOp0 = typename MmaCore0::MmaTensorOp; + using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp; + using MmaPolicy0 = typename MmaCore0::MmaPolicy; + using MmaPolicy1 = typename MmaCore1::MmaPolicy; + + // Use fragment iterator for the accumulator + using SmemAccumulatorLayout = cutlass::layout::ColumnMajorInterleaved<16>; + using FragmentIteratorAccumulator = cutlass::epilogue::warp::FragmentIteratorTensorOp< + WarpShape0, InstructionShape, + ElementAccumulator, + typename WarpMmaTensorOp0::Policy::Operator::FragmentC, + SmemAccumulatorLayout + >; + + + // Store Accumulator tiles to Shared Memory + using SmemIteratorD0 = + cutlass::epilogue::warp::TileIteratorTensorOp< + WarpShape0, + InstructionShape, + ElementC, + SmemAccumulatorLayout + >; + + static int const kThreadCount = 32; + // load warp tile from Shared Memory accumulator + using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileIteratorCanonical< + MatrixShape, cutlass::gemm::Operand::kA, + ElementA, SmemAccumulatorLayout, + MatrixShape, + WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount>; + + // Define the Mma + using B2bMma = threadblock::B2bImplicitGemmMultistageSmemAccumulator< + ThreadblockShape0, + IteratorA0, + SmemIteratorA0, + arch::CacheOperation::Always, + IteratorB0, + SmemIteratorB0, + arch::CacheOperation::Global, + IteratorAccumulatorScaleBias, + FragmentIteratorAccumulator, + SmemIteratorD0, + ThreadblockShape1, + WarpIteratorA1, + IteratorB1, + SmemIteratorB1, + arch::CacheOperation::Global, + EpilogueOutputOp0, + MmaPolicy0, + MmaPolicy1, + Stages + >; + + // Define the epilogue + using Epilogue = typename 
epilogue::threadblock::DefaultInterleavedConvEpilogue< + ThreadblockShape1, + WarpMmaTensorOp1, + 1, + EpilogueOutputOp1, + EpilogueOutputOp1::kCount, + InterleavedK + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::B2bImplicitGemmConvolution< + B2bMma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop + >; +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_gemm.h b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_gemm.h new file mode 100644 index 0000000000000000000000000000000000000000..e2cc94377c0e972ee39fc051c6798368c69d2041 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_gemm.h @@ -0,0 +1,503 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief + Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with + the appropriate threadblock-scoped epilogue. + + Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are + accommodated by exchanging A and B operands and assuming transposed layouts. Partial + specializations here choose 'device::GemmTransposed' to implement this functionality. 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/layout/matrix.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/epilogue/threadblock/epilogue.h" +#include "cutlass/epilogue/thread/linear_combination.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/kernel/gemm_pipelined.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm75.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm70.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" +#include "cutlass/gemm/threadblock/default_mma_core_simt.h" +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" +#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h" +#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h" +#include "cutlass/epilogue/threadblock/default_epilogue_simt.h" + +#include "cutlass/transform/threadblock/predicated_tile_iterator.h" + +#include "kernel/b2b_gemm.h" +#include "kernel/grouped.h" +#include "threadblock/default_b2b_mma.h" +#include "threadblock/grouped_threadblock_swizzle.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +//////////////////////////////////////////////////////////////////////////////// + +template +using IsGroupedSwizzle = cutlass::gemm::threadblock::detail::IsGroupedSwizzle; + +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for C and D matrix operands + typename ElementC_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape0, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape1, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape0, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape1, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp0, + /// Epilogue output operator + typename EpilogueOutputOp1, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Operation performed by GEMM + typename Operator, + /// Stage accumulator in shared memory + bool SmemAccumulator = false, + /// Whether or not the operation is grouped + typename Enable = void +> +struct DefaultB2bGemm; + +//////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for Ampere Architecture +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of A matrix in units of elements + int kAlignmentB, + /// 
Element type for C and D matrix operands + typename ElementC, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape0, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape1, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape0, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape1, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp0, + /// Epilogue output operator + typename EpilogueOutputOp1, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Operation performed by GEMM + typename Operator> +struct DefaultB2bGemm::value>::type> { + /// Define the threadblock-scoped matrix multiply-accumulate + using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma< + ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, + ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm80, + ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1, + InstructionShape, Stages, Operator, EpilogueOutputOp0>::ThreadblockB2bMma; + + static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK; + + /// Define the epilogue + using Epilogue = + typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< + ThreadblockShape1, typename B2bMma::Operator1, kPartitionsK1, EpilogueOutputOp1, + EpilogueOutputOp1::kCount>::Epilogue; + + /// Define the kernel-level GEMM operator. + using B2bGemmKernel = kernel::B2bGemm; +}; + +/// Partial specialization for Ampere Architecture with grouped operation +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of A matrix in units of elements + int kAlignmentB, + /// Element type for C and D matrix operands + typename ElementC, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape0, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape1, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape0, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape1, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp0, + /// Epilogue output operator + typename EpilogueOutputOp1, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Operation performed by GEMM + typename Operator> +struct DefaultB2bGemm::value>::type> { + /// Define the threadblock-scoped matrix multiply-accumulate + using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma< + ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, + ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm80, + ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1, + InstructionShape, Stages, Operator, EpilogueOutputOp0>::ThreadblockB2bMma; + + static const int kPartitionsK1 = ThreadblockShape1::kK 
/ WarpShape1::kK; + + /// Define the epilogue + using Epilogue = + typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< + ThreadblockShape1, typename B2bMma::Operator1, kPartitionsK1, EpilogueOutputOp1, + EpilogueOutputOp1::kCount>::Epilogue; + + /// Define the kernel-level GEMM operator. + using UnderlyingB2bGemmKernel = kernel::B2bGemm; + + using B2bGemmKernel = kernel::GroupedKernel; +}; + + +//////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for Turing Architecture +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for C and D matrix operands + typename ElementC, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape0, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape1, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape0, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape1, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp0, + /// Epilogue output operator + typename EpilogueOutputOp1, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Operation performed by GEMM + typename Operator +> +struct DefaultB2bGemm< + ElementA, LayoutA, kAlignmentA, + ElementB, LayoutB, kAlignmentB, + ElementC, layout::RowMajor, + ElementAccumulator, + arch::OpClassTensorOp, + arch::Sm75, + ThreadblockShape0, + ThreadblockShape1, + WarpShape0, + WarpShape1, + InstructionShape, + EpilogueOutputOp0, + EpilogueOutputOp1, + ThreadblockSwizzle, + 2, + Operator, + false, + typename platform::enable_if::value>::type +> { + + /// Define the threadblock-scoped matrix multiply-accumulate + using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma< + ElementA, + LayoutA, + kAlignmentA, + ElementB, + LayoutB, + kAlignmentB, + ElementAccumulator, + layout::RowMajor, + arch::OpClassTensorOp, + arch::Sm75, + ThreadblockShape0, + ThreadblockShape1, + WarpShape0, + WarpShape1, + InstructionShape, + 2, + Operator, + EpilogueOutputOp0 + >::ThreadblockB2bMma; + + static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK; + + /// Define the epilogue + using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< + ThreadblockShape1, + typename B2bMma::Operator1, + kPartitionsK1, + EpilogueOutputOp1, + EpilogueOutputOp1::kCount + >::Epilogue; + + /// Define the kernel-level GEMM operator. 
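+  /// kernel::B2bGemm composes the fused two-GEMM mainloop (B2bMma) with the second GEMM's
+  /// epilogue under the given threadblock swizzle. Note that this Sm75 path is fixed to a
+  /// two-stage pipelined mainloop, whereas the Sm80 specializations above expose the stage
+  /// count as a template parameter for the cp.async-based multistage mainloop.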
+ using B2bGemmKernel = kernel::B2bGemm; +}; + + +/// Partial specialization for Ampere Integer Matrix Multiply Interleaved layout +template < + /// Element type for A matrix operand + typename ElementA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for C and D matrix operands + typename ElementC, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape0, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape1, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape0, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape1, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp0, + /// Epilogue output operator + typename EpilogueOutputOp1, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Number of Interleaved k + int InterleavedK, + /// Operation performed by GEMM + typename Operator> +struct DefaultB2bGemm< + ElementA, layout::ColumnMajorInterleaved, kAlignmentA, + ElementB, layout::RowMajorInterleaved, kAlignmentB, + ElementC, layout::ColumnMajorInterleaved, int32_t, + arch::OpClassTensorOp, arch::Sm80, + ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1, + InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1, + ThreadblockSwizzle, Stages, + Operator, false, typename platform::enable_if::value>::type> { + using LayoutA = layout::ColumnMajorInterleaved; + using LayoutB = layout::RowMajorInterleaved; + using LayoutC = layout::ColumnMajorInterleaved; + + using ElementAccumulator = int32_t; + + /// Define the threadblock-scoped matrix multiply-accumulate + using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma< + ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, + ElementAccumulator, LayoutC, arch::OpClassTensorOp, arch::Sm80, + ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1, + InstructionShape, Stages, Operator, EpilogueOutputOp0, + true>::ThreadblockB2bMma; + + static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK; + + /// Define the epilogue + using Epilogue = typename cutlass::epilogue::threadblock:: + DefaultInterleavedEpilogueTensorOp< + ThreadblockShape1, typename B2bMma::Operator1, kPartitionsK1, EpilogueOutputOp1, + 64 / sizeof_bits::value, InterleavedK>::Epilogue; + + /// Define the kernel-level GEMM operator. 
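+  /// (Note: the interleaved epilogue above is instantiated with an access width equal to
+  /// 64 bits' worth of output elements, per the 64 / sizeof_bits expression, so stores to
+  /// the ColumnMajorInterleaved output are 64-bit vectorized accesses.)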
+ using B2bGemmKernel = kernel::B2bGemm; +}; + +//////////////////////////////////////////////////////////////////////////////// + + +/// Partial specialization for Turing Integer Tensor Core Interleaved layout +template < + /// Element type for A matrix operand + typename ElementA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for C and D matrix operands + typename ElementC, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape0, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape1, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape0, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape1, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp0, + /// Epilogue output operator + typename EpilogueOutputOp1, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Number of Interleaved k + int InterleavedK, + /// Operation performed by GEMM + typename Operator> +struct DefaultB2bGemm, + kAlignmentA, ElementB, + layout::RowMajorInterleaved, kAlignmentB, + ElementC, layout::ColumnMajorInterleaved, + int32_t, arch::OpClassTensorOp, arch::Sm75, + ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1, + InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1, + ThreadblockSwizzle, 2, Operator, false, + typename platform::enable_if::value>::type> { + using LayoutA = layout::ColumnMajorInterleaved; + using LayoutB = layout::RowMajorInterleaved; + using LayoutC = layout::ColumnMajorInterleaved; + + using ElementAccumulator = int32_t; + + /// Define the threadblock-scoped matrix multiply-accumulate + using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma< + ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementAccumulator, LayoutC, + arch::OpClassTensorOp, arch::Sm75, ThreadblockShape0, ThreadblockShape1, + WarpShape0, WarpShape1, InstructionShape, 2, Operator, EpilogueOutputOp0, true>::ThreadblockB2bMma; + + static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK; + + /// Define the epilogue for the 2nd Gemm + using Epilogue = typename cutlass::epilogue::threadblock:: + DefaultInterleavedEpilogueTensorOp< + ThreadblockShape1, typename B2bMma::Operator1, kPartitionsK1, EpilogueOutputOp1, + 64 / sizeof_bits::value, InterleavedK>::Epilogue; + + /// Define the kernel-level GEMM operator. 
+ using B2bGemmKernel = kernel::B2bGemm; +}; + +//////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_gemm_smem_accumulator.h b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_gemm_smem_accumulator.h new file mode 100644 index 0000000000000000000000000000000000000000..0a4530f6ce82c18656d7bb64452ca08b1eebfe18 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/default_b2b_gemm_smem_accumulator.h @@ -0,0 +1,384 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief + Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with + the appropriate threadblock-scoped epilogue. + + Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are + accommodated by exchanging A and B operands and assuming transposed layouts. Partial + specializations here choose 'device::GemmTransposed' to implement this functionality. 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/layout/matrix.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/epilogue/threadblock/epilogue.h" +#include "cutlass/epilogue/thread/linear_combination.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/kernel/gemm_pipelined.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm75.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm70.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" +#include "cutlass/gemm/threadblock/default_mma_core_simt.h" +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" +#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h" +#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h" +#include "cutlass/epilogue/threadblock/default_epilogue_simt.h" + +#include "cutlass/transform/threadblock/predicated_tile_iterator.h" +#include "cutlass/transform/threadblock/vector_iterator.h" +#include "cutlass/transform/threadblock/predicated_vector_access_iterator.h" + +#include "kernel/b2b_gemm.h" +#include "threadblock/default_b2b_mma.h" +#include "threadblock/default_b2b_mma_smem_accumulator.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +//////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for Ampere Architecture +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of A matrix in units of elements + int kAlignmentB, + /// Element type for C and D matrix operands + typename ElementC, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape0, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape1, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape0, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape1, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp0, + /// Epilogue output operator + typename EpilogueOutputOp1, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Operation performed by GEMM + typename Operator> +struct DefaultB2bGemm { + /// Define the threadblock-scoped matrix multiply-accumulate + using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma< + ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, + ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm80, + ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1, + InstructionShape, Stages, Operator, EpilogueOutputOp0, false, true>::ThreadblockB2bMma; + + static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK; + + /// Define the epilogue + using Epilogue = + typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< + ThreadblockShape1, typename B2bMma::Operator1, kPartitionsK1, EpilogueOutputOp1, + EpilogueOutputOp1::kCount>::Epilogue; + + /// Define the kernel-level GEMM operator. 
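+  /// In contrast to the specializations in default_b2b_gemm.h, the DefaultB2bGemm
+  /// specializations in this header are selected with the shared-memory-accumulator option
+  /// enabled, so DefaultB2bMma is instantiated (via its trailing boolean arguments) to stage
+  /// the first GEMM's accumulator tile in shared memory instead of keeping it resident in
+  /// the register file of the producing warps.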
+ using B2bGemmKernel = kernel::B2bGemm; +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for Turing Architecture +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for C and D matrix operands + typename ElementC, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape0, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape1, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape0, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape1, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp0, + /// Epilogue output operator + typename EpilogueOutputOp1, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Operation performed by GEMM + typename Operator +> +struct DefaultB2bGemm< + ElementA, LayoutA, kAlignmentA, + ElementB, LayoutB, kAlignmentB, + ElementC, layout::RowMajor, + ElementAccumulator, + arch::OpClassTensorOp, + arch::Sm75, + ThreadblockShape0, + ThreadblockShape1, + WarpShape0, + WarpShape1, + InstructionShape, + EpilogueOutputOp0, + EpilogueOutputOp1, + ThreadblockSwizzle, + 2, + Operator, + true +> { + + /// Define the threadblock-scoped matrix multiply-accumulate + using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma< + ElementA, + LayoutA, + kAlignmentA, + ElementB, + LayoutB, + kAlignmentB, + ElementAccumulator, + layout::RowMajor, + arch::OpClassTensorOp, + arch::Sm75, + ThreadblockShape0, + ThreadblockShape1, + WarpShape0, + WarpShape1, + InstructionShape, + 2, + Operator, + EpilogueOutputOp0, + false, + true + >::ThreadblockB2bMma; + + static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK; + + /// Define the epilogue + using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< + ThreadblockShape1, + typename B2bMma::Operator1, + kPartitionsK1, + EpilogueOutputOp1, + EpilogueOutputOp1::kCount + >::Epilogue; + + /// Define the kernel-level GEMM operator. 
+ using B2bGemmKernel = kernel::B2bGemm; +}; + + +/// Partial specialization for Ampere Integer Matrix Multiply Interleaved layout +template < + /// Element type for A matrix operand + typename ElementA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for C and D matrix operands + typename ElementC, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape0, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape1, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape0, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape1, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp0, + /// Epilogue output operator + typename EpilogueOutputOp1, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Number of Interleaved k + int InterleavedK, + /// Operation performed by GEMM + typename Operator> +struct DefaultB2bGemm< + ElementA, layout::ColumnMajorInterleaved, kAlignmentA, + ElementB, layout::RowMajorInterleaved, kAlignmentB, + ElementC, layout::ColumnMajorInterleaved, int32_t, + arch::OpClassTensorOp, arch::Sm80, + ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1, + InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1, + ThreadblockSwizzle, Stages, + Operator, true> { + using LayoutA = layout::ColumnMajorInterleaved; + using LayoutB = layout::RowMajorInterleaved; + using LayoutC = layout::ColumnMajorInterleaved; + + using ElementAccumulator = int32_t; + + /// Define the threadblock-scoped matrix multiply-accumulate + using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma< + ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, + ElementAccumulator, LayoutC, arch::OpClassTensorOp, arch::Sm80, + ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1, + InstructionShape, Stages, Operator, EpilogueOutputOp0, + true, true>::ThreadblockB2bMma; + + static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK; + + /// Define the epilogue + using Epilogue = typename cutlass::epilogue::threadblock:: + DefaultInterleavedEpilogueTensorOp< + ThreadblockShape1, typename B2bMma::Operator1, kPartitionsK1, EpilogueOutputOp1, + 64 / sizeof_bits::value, InterleavedK>::Epilogue; + + /// Define the kernel-level GEMM operator. 
+ using B2bGemmKernel = kernel::B2bGemm; +}; + +//////////////////////////////////////////////////////////////////////////////// + + +/// Partial specialization for Turing Integer Tensor Core Interleaved layout +template < + /// Element type for A matrix operand + typename ElementA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for C and D matrix operands + typename ElementC, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape0, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape1, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape0, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape1, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp0, + /// Epilogue output operator + typename EpilogueOutputOp1, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Number of Interleaved k + int InterleavedK, + /// Operation performed by GEMM + typename Operator> +struct DefaultB2bGemm, + kAlignmentA, ElementB, + layout::RowMajorInterleaved, kAlignmentB, + ElementC, layout::ColumnMajorInterleaved, + int32_t, arch::OpClassTensorOp, arch::Sm75, + ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1, + InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1, + ThreadblockSwizzle, 2, Operator, true> { + using LayoutA = layout::ColumnMajorInterleaved; + using LayoutB = layout::RowMajorInterleaved; + using LayoutC = layout::ColumnMajorInterleaved; + + using ElementAccumulator = int32_t; + + /// Define the threadblock-scoped matrix multiply-accumulate + using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma< + ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, + ElementAccumulator, LayoutC, arch::OpClassTensorOp, arch::Sm75, + ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1, + InstructionShape, 2, Operator, EpilogueOutputOp0, true, true>::ThreadblockB2bMma; + + static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK; + + /// Define the epilogue for the 2nd Gemm + using Epilogue = typename cutlass::epilogue::threadblock:: + DefaultInterleavedEpilogueTensorOp< + ThreadblockShape1, typename B2bMma::Operator1, kPartitionsK1, EpilogueOutputOp1, + 64 / sizeof_bits::value, InterleavedK>::Epilogue; + + /// Define the kernel-level GEMM operator. 
+ using B2bGemmKernel = kernel::B2bGemm; +}; + +//////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/grouped.h b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/grouped.h new file mode 100644 index 0000000000000000000000000000000000000000..0ac841d4edf08b2d846109f64e343525165b07b9 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/kernel/grouped.h @@ -0,0 +1,168 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief High-level interface for running a grouped version of a CUTLASS kernel +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/complex.h" +#include "cutlass/semaphore.h" + +#include "cutlass/layout/matrix.h" +#include "cutlass/trace.h" +#include "cutlass/gemm/kernel/gemm_transpose_operands.h" +#include "cutlass/gemm/kernel/gemm_grouped_problem_visitor.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// High-level interface for running a grouped version of a CUTLASS kernel +template < + typename BaseKernel_ ///! 
Kernel-scoped matrix multiply-accumulate +> +struct GroupedKernel { +public: + + using BaseKernel = BaseKernel_; + using Epilogue = typename BaseKernel::Epilogue; + + /// Types that need to be exported to work properly with device::BaseGrouped + using ElementA = typename BaseKernel::ElementA; + using LayoutA = typename BaseKernel::LayoutA; + using TensorRefA = TensorRef; + static ComplexTransform const kTransformA = BaseKernel::kTransformA; + static int const kAlignmentA = BaseKernel::kAlignmentA; + + using ElementB = typename BaseKernel::ElementB; + using LayoutB = typename BaseKernel::LayoutB; + using TensorRefB = TensorRef; + static ComplexTransform const kTransformB = BaseKernel::kTransformB; + static int const kAlignmentB = BaseKernel::kAlignmentB; + + using ElementC = typename BaseKernel::ElementC; + using LayoutC = typename BaseKernel::LayoutC; + using TensorRefC = TensorRef; + using TensorRefD = TensorRef; + static int const kAlignmentC = BaseKernel::kAlignmentC; + + using ElementAccumulator = typename BaseKernel::Mma::Policy::Operator::ElementC; + + using EpilogueOutputOp = typename BaseKernel::EpilogueOutputOp; + using ThreadblockSwizzle = typename BaseKernel::ThreadblockSwizzle; + + using Operator = typename BaseKernel::Operator; + using WarpMmaOperator = typename BaseKernel::Mma::Policy::Operator; + + using ArchMmaOperator = typename WarpMmaOperator::ArchMmaOperator; + using MathOperator = typename WarpMmaOperator::MathOperator; + using OperatorClass = typename WarpMmaOperator::OperatorClass; + using ArchTag = typename WarpMmaOperator::ArchTag; + using ThreadblockShape = typename BaseKernel::Mma::Shape; + using WarpShape = typename BaseKernel::WarpShape; + using InstructionShape = typename BaseKernel::InstructionShape; + static int const kStages = BaseKernel::Mma::kStages; + + using Mma = typename BaseKernel::Mma; + + using Arguments = typename BaseKernel::GroupedArguments; + using Params = typename BaseKernel::GroupedParams; + using ProblemVisitor = typename ThreadblockSwizzle::ProblemVisitor; + + static int const kThreadCount = BaseKernel::kThreadCount; + + /// Shared memory storage structure + struct SharedStorage { + typename BaseKernel::SharedStorage kernel; + + // ProblemVisitor shared storage can't be overlapped with others + typename ProblemVisitor::SharedStorage problem_visitor; + }; + +public: + + // + // Methods + // + + CUTLASS_DEVICE + GroupedKernel() { } + + /// Determines whether kernel satisfies alignment + static Status can_implement(cutlass::gemm::GemmCoord const & problem_size) { + return Status::kSuccess; + } + + static Status can_implement(Arguments const &args) { + return Status::kSuccess; + } + + /// Executes a kernel-level GEMM in a loop + CUTLASS_DEVICE + void operator()(Params ¶ms, SharedStorage &shared_storage) { + + ThreadblockSwizzle swizzle(params.problem_visitor, shared_storage.problem_visitor, blockIdx.x); + + if (ProblemVisitor::kTransposed) { + params.transpose(); + } + + BaseKernel mma; + + // Outer 'persistent' loop to iterate over tiles + while (swizzle.problem_visitor.next_tile()) { + + typename BaseKernel::Params mma_params = params.to_single_params(swizzle.problem_visitor); + mma.run_with_swizzle(mma_params, shared_storage.kernel, swizzle); + + // Next tile + swizzle.problem_visitor.advance(gridDim.x); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + 
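+//
+// A rough usage sketch (the grouped swizzle below is a placeholder for whichever grouped
+// threadblock swizzle the example defines in threadblock/grouped_threadblock_swizzle.h;
+// element types, alignments, and tile shapes are likewise user-chosen): picking a swizzle
+// for which IsGroupedSwizzle<...>::value holds makes the DefaultB2bGemm specialization wrap
+// the plain B2bGemm kernel in this GroupedKernel adapter:
+//
+//   using GroupedSwizzle = /* a grouped threadblock swizzle type */;
+//   using B2bGemmGrouped = typename cutlass::gemm::kernel::DefaultB2bGemm<
+//       ElementA, LayoutA, kAlignmentA,
+//       ElementB, LayoutB, kAlignmentB,
+//       ElementC, cutlass::layout::RowMajor, ElementAccumulator,
+//       cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
+//       ThreadblockShape0, ThreadblockShape1,
+//       WarpShape0, WarpShape1, InstructionShape,
+//       EpilogueOutputOp0, EpilogueOutputOp1,
+//       GroupedSwizzle, /*Stages=*/3, Operator>::B2bGemmKernel;
+//
+//   B2bGemmGrouped is then kernel::GroupedKernel<kernel::B2bGemm<...>>: each threadblock
+//   runs a persistent loop in which the swizzle's ProblemVisitor hands out tiles and
+//   run_with_swizzle() executes one fused B2b GEMM tile per iteration.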
+///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/reference/device/tensor_scale_bias.h b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/reference/device/tensor_scale_bias.h new file mode 100644 index 0000000000000000000000000000000000000000..4bf3c532559cb27b7d638319276b447817a3cc49 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/reference/device/tensor_scale_bias.h @@ -0,0 +1,395 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/* \file + \brief Defines device-side elementwise operations on TensorView. Note, the operations defined + in this header are not specialized for any particular data layout and are therefore not + intended to offer the best possible performance. Rather, they are intended to be generic + reference implementations to support the CUTLASS unit tests. 
+*/ + +#pragma once + +// Cutlass includes +#include "cutlass/cutlass.h" +#include "cutlass/tensor_view.h" + +#include "cutlass/gemm/gemm.h" + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace reference { +namespace device { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace kernel { + +template < + typename TensorRefIn, ///< Input TensorRef Type + typename TensorRefOut, ///< Output TensorRef Type + typename ScalarType, ///< alpha Type + typename TensorRefScalar, ///< Scale/Bias TensorRef Type + typename OutputTile, + typename ConvertOp = NumericConverter +> +__global__ void TensorScaleBiasGemm( + gemm::GemmCoord problem_size, + TensorRefIn tensor_in, ///< input tensor + TensorRefOut tensor_out, ///< output tensor + ScalarType alpha, ///< alpha + TensorRefScalar tensor_scale, ///< scale tensor + TensorRefScalar tensor_bias ///< bias tensor +) { + + ConvertOp convert_op; + + MatrixCoord output_coord( + MatrixCoord::Index((threadIdx.x + blockIdx.x * blockDim.x) * OutputTile::kRow), + MatrixCoord::Index((threadIdx.y + blockIdx.y * blockDim.y) * OutputTile::kColumn) + ); + + // Update the output tensor + for (int j = 0; j < OutputTile::kRow; ++j) { + for (int i = 0; i < OutputTile::kColumn; ++i) { + MatrixCoord coord = output_coord + MatrixCoord(i, j); + if (coord.row() < problem_size.m() && coord.column() < problem_size.n()) { + + ScalarType scale = alpha; + if(tensor_scale.good()) + scale = tensor_scale.at({0, coord.column()}); + + ScalarType bias = ScalarType(0); + + if(tensor_bias.good()) + bias = tensor_bias.at({0, coord.column()}); + + tensor_out.at(coord) = convert_op( + scale * ScalarType(tensor_in.at(coord)) + bias); + } + } + } +} + +template < + typename TensorRefIn, ///< Input TensorRef Type + typename TensorRefOut, ///< Output TensorRef Type + typename ScalarType, ///< alpha Type + typename TensorRefScalar, ///< Scale/Bias TensorRef Type + typename ConvertOp = NumericConverter, + int kMblock = 4, + int kNblock = 4 +> +__global__ void TensorScaleBiasGemmBatched( + gemm::GemmCoord problem_size, + TensorRefIn tensor_in, ///< input tensor + TensorRefOut tensor_out, ///< output tensor + ScalarType alpha, ///< alpha + TensorRefScalar tensor_scale, ///< scale tensor + TensorRefScalar tensor_bias, ///< bias tensor + int batch_count = 1, + int64_t batch_stride_tensor_in = 0, + int64_t batch_stride_tensor_out = 0, + int64_t batch_stride_tensor_scale = 0, + int64_t batch_stride_tensor_bias = 0 +) { + + ConvertOp convert_op; + int row_block = (blockIdx.x * blockDim.x + threadIdx.x) * kMblock; + int col_block = (blockIdx.y * blockDim.y + threadIdx.y) * kNblock; + int batch_idx = blockIdx.z; + + tensor_in.add_pointer_offset(batch_idx * batch_stride_tensor_in); + tensor_out.add_pointer_offset(batch_idx * batch_stride_tensor_out); + tensor_scale.add_pointer_offset(batch_idx * batch_stride_tensor_scale); + tensor_bias.add_pointer_offset(batch_idx * batch_stride_tensor_bias); + + for (; batch_idx < batch_count; batch_idx += gridDim.z) { + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < kNblock; j++) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kMblock; i++) { + int row = row_block + i; + int col = col_block + j; + MatrixCoord coord = MatrixCoord(row, col); + if (coord.row() < problem_size.m() && coord.column() < problem_size.n()) { + + ScalarType scale = alpha; + if(tensor_scale.good()) + scale = tensor_scale.at({0, coord.column()}); + + ScalarType 
bias = ScalarType(0); + + if(tensor_bias.good()) + bias = tensor_bias.at({0, coord.column()}); + + tensor_out.at(coord) = convert_op( + scale * ScalarType(tensor_in.at(coord)) + bias); + } + } + } + tensor_in.add_pointer_offset(batch_stride_tensor_in * gridDim.z); + tensor_out.add_pointer_offset(batch_stride_tensor_out * gridDim.z); + tensor_scale.add_pointer_offset(batch_stride_tensor_scale * gridDim.z); + tensor_bias.add_pointer_offset(batch_stride_tensor_bias * gridDim.z); + } +} + +template < + typename TensorRefIn, ///< Input TensorRef Type + typename TensorRefOut, ///< Output TensorRef Type + typename ScalarType, ///< alpha Type + typename TensorRefScalar, ///< Scale/Bias TensorRef Type + typename ConvertOp = NumericConverter, + int kThreadM = 4, // shape of a thread's tile in the GEMM M dimension + int kThreadN = 4, // shape of a thread's tile in the GEMM N dimension + int kCtaShapeM = 16, // shape of a threadblock in units of threads + int kCtaShapeN = 8 // shape of a threadblock in units of threads +> +__global__ void TensorScaleBiasConv2d( + conv::Conv2dProblemSize problem_size, + TensorRefIn tensor_in, ///< input tensor + TensorRefOut tensor_out, ///< output tensor + ScalarType alpha, ///< alpha + TensorRefScalar tensor_scale, ///< scale tensor + TensorRefScalar tensor_bias ///< bias tensor +) { + + ConvertOp convert_op; + + int64_t npq_start = int64_t(blockIdx.x) * kCtaShapeM * kThreadM + threadIdx.x * kThreadM; + int k_start = blockIdx.y * kCtaShapeN * kThreadN + threadIdx.y * kThreadN; + + int thread_n[kThreadM]; + int thread_p[kThreadM]; + int thread_q[kThreadM]; + + // Compute N, P, Q coordinates for each row of a thread's tile + int64_t PQ = int64_t(problem_size.P) * problem_size.Q; + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + + int64_t npq = npq_start + m; + + thread_n[m] = int(npq / PQ); + + int64_t residual = npq % PQ; + thread_p[m] = int(residual / problem_size.Q); + thread_q[m] = int(residual % problem_size.Q); + } + + // Write out the results + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < kThreadM; ++m) { + if (thread_n[m] < problem_size.N && thread_p[m] < problem_size.P && thread_q[m] < problem_size.Q) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < kThreadN; ++n) { + int thread_k = k_start + n; + if (thread_k < problem_size.K) { + + ScalarType scale = alpha; + if(tensor_scale.good()) + scale = tensor_scale.at({0, thread_k}); + + ScalarType bias = ScalarType(0); + if(tensor_bias.good()) + bias = tensor_bias.at({0, thread_k}); + + tensor_out.at({thread_n[m], thread_p[m], thread_q[m], thread_k}) = convert_op( + scale * ScalarType( + tensor_in.at({thread_n[m], thread_p[m], thread_q[m], thread_k}) + ) + bias); + } + } + } + } + +} + +} + +/// Apply scale and bias on a tensor +template < + typename ElementIn, ///< Input Type + typename ElementOut, ///< Output Type + typename Layout, ///< Layout of input/output tensor + typename ScalarType, ///< alpha Type + typename LayoutScaleBias, ///< Layout of scale and bias + typename ConvertOp = NumericConverter +> +void TensorScaleBiasGemm( + gemm::GemmCoord problem_size, + TensorRef tensor_in, ///< input tensor + TensorRef tensor_out, ///< output tensor + ScalarType alpha, ///< alpha + TensorRef tensor_scale, ///< scale tensor + TensorRef tensor_bias ///< bias tensor +) { + + using OutputTile = MatrixShape<4, 4>; + + dim3 block(16, 8); + + dim3 grid( + (problem_size.m() + block.x * OutputTile::kRow - 1) / (block.x * OutputTile::kRow), + (problem_size.n() + block.y * OutputTile::kColumn - 1) / (block.y * 
OutputTile::kColumn) + ); + + kernel::TensorScaleBiasGemm< + TensorRef, + TensorRef, + ScalarType, + TensorRef, + OutputTile, + ConvertOp + ><<< grid, block >>> ( + problem_size, + tensor_in, + tensor_out, + alpha, + tensor_scale, + tensor_bias + ); +} + +/// Apply scale and bias on a tensor +template < + typename ElementIn, ///< Input Type + typename ElementOut, ///< Output Type + typename Layout, ///< Layout of input/output tensor + typename ScalarType, ///< alpha Type + typename LayoutScaleBias, ///< Layout of scale and bias + typename ConvertOp = NumericConverter +> +void TensorScaleBiasGemmBatched( + gemm::GemmCoord problem_size, + TensorRef tensor_in, ///< input tensor + TensorRef tensor_out, ///< output tensor + ScalarType alpha, ///< alpha + TensorRef tensor_scale, ///< scale tensor + TensorRef tensor_bias, ///< bias tensor + int batch_count = 1, + int64_t batch_stride_tensor_in = 0, + int64_t batch_stride_tensor_out = 0, + int64_t batch_stride_tensor_scale = 0, + int64_t batch_stride_tensor_bias = 0 +) { + + int const kMblock = 4; + int const kNblock = 4; + + dim3 block(16, 8); + dim3 grid( + (problem_size.m() + block.x * kMblock - 1) / (block.x * kMblock), + (problem_size.n() + block.y * kNblock - 1) / (block.y * kNblock), + batch_count % std::numeric_limits::max() + ); + + kernel::TensorScaleBiasGemmBatched< + TensorRef, + TensorRef, + ScalarType, + TensorRef, + ConvertOp, + kMblock, + kNblock + ><<< grid, block >>> ( + problem_size, + tensor_in, + tensor_out, + alpha, + tensor_scale, + tensor_bias, + batch_count, + batch_stride_tensor_in, + batch_stride_tensor_out, + batch_stride_tensor_scale, + batch_stride_tensor_bias + ); +} + +/// Apply scale and bias on a tensor +template < + typename ElementIn, ///< Input Type + typename ElementOut, ///< Output Type + typename Layout, ///< Layout of input/output tensor + typename ScalarType, ///< alpha Type + typename LayoutScaleBias, ///< Layout of scale and bias + typename ConvertOp = NumericConverter +> +void TensorScaleBiasConv2d( + conv::Conv2dProblemSize problem_size, + TensorRef tensor_in, ///< input tensor + TensorRef tensor_out, ///< output tensor + ScalarType alpha, ///< alpha + TensorRef tensor_scale, ///< scale tensor + TensorRef tensor_bias ///< bias tensor +) { + + int const kThreadM = 4; // shape of a thread's tile in the GEMM M dimension + int const kThreadN = 4; // shape of a thread's tile in the GEMM N dimension + int const kCtaShapeM = 16; // shape of a threadblock in units of threads + int const kCtaShapeN = 8; // shape of a threadblock in units of threads + + int64_t npq = int64_t(problem_size.N) * problem_size.P * problem_size.Q; + int64_t blocks_m = (npq + (kCtaShapeM * kThreadM) - 1) / (kCtaShapeM * kThreadM); + + dim3 block(kCtaShapeM, kCtaShapeN); + dim3 grid(uint32_t(blocks_m), (problem_size.K + (kCtaShapeN * kThreadN) - 1) / (kCtaShapeN * kThreadN)); + + + kernel::TensorScaleBiasConv2d< + TensorRef, + TensorRef, + ScalarType, + TensorRef, + ConvertOp, + kThreadM, + kThreadN, + kCtaShapeM, + kCtaShapeN + ><<< grid, block >>> ( + problem_size, + tensor_in, + tensor_out, + alpha, + tensor_scale, + tensor_bias + ); +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace device +} // namespace reference +} // namespace cutlass diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/test_run.h 
b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/test_run.h new file mode 100644 index 0000000000000000000000000000000000000000..1fba44d69655d8d8f14442c411da4cafd3db598f --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/test_run.h @@ -0,0 +1,95 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + +#include + +// Run tests on GPUs + +int testRun(int arch, std::vector & test_funcs, const std::string & test_name) { + + bool supported = false; + + int arch_major = arch / 10; + int arch_minor = arch - arch / 10 * 10; + + if(arch_major >= 8) { + // Ampere Tensor Core operations exposed with mma.sync are first available in CUDA 11.0. + // + // CUTLASS must be compiled with CUDA 11 Toolkit to run Conv2dFprop examples. + if (__CUDACC_VER_MAJOR__ > 11 || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0)) { + supported = true; + } + } + else if(arch_major >= 7) { + // Turing Tensor Core operations exposed with mma.sync are first available in CUDA 10.2. + // + // CUTLASS must be compiled with CUDA 10.2 Toolkit to run these examples. + if (__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2)) { + supported = true; + } + } + + cudaDeviceProp props; + + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (error != cudaSuccess) { + std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; + return -1; + } + + if (!(props.major == arch_major && props.minor == arch_minor)) { + supported = false; + } + + if (!supported) { + // Returning zero so this test passes on older Toolkits. Its actions are no-op. 
+ std::cout << "This example isn't supported on current architecture" << std::endl; + return 0; + } + + bool pass = true; + + std::cout << "Device: " << props.name << std::endl; + std::cout << "Arch: SM" << arch << std::endl; + std::cout << "Test: " << test_name << std::endl; + for(auto func : test_funcs) { + pass &= func(); + } + + + if(pass) + return 0; + else + return -1; + +} + diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_implicit_gemm_multistage.h b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_implicit_gemm_multistage.h new file mode 100644 index 0000000000000000000000000000000000000000..1feb71cf11ceb30517e99dc25275644e6060b2db --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_implicit_gemm_multistage.h @@ -0,0 +1,831 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a multistage threadblock-scoped Implicit GEMM Convolution kernel. 
+*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" +#include "cutlass/arch/cache_operation.h" +#include "cutlass/gemm/threadblock/mma_base.h" +#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h" + +#include "threadblock/b2b_mma_base.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape0_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA0_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA0_, + /// Cache operation for operand A + cutlass::arch::CacheOperation::Kind CacheOpA0, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB0_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB0_, + /// Cache operation for operand B + cutlass::arch::CacheOperation::Kind CacheOpB0, + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape1_, + /// Iterates over the intermediate accumulator tile + // (concept::MmaTensorOpFragmentIterator) + typename FragmentIteratorA1_, + /// Iterates over vectors of scale and bias vector in global memory + // (concept: VectorIterator) + typename IteratorAccumulatorScaleBias_, + /// WarpIterator to load Scale or Bias vector from threadblock fragment + typename FragmentIteratorA1ScaleBias_, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB1_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB1_, + /// Cache operation for operand B + cutlass::arch::CacheOperation::Kind CacheOpB1, + /// Output operator for 1st Gemm(concept: epilogue::thread::LinearCombinationClamp, etc...) 
+ typename OutputOp_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy0_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy1_, + /// Number of stages, + int Stages, + /// Used for partial specialization + typename Enable = bool> +class B2bImplicitGemmMultistage : + public gemm::threadblock::B2bMmaBase { +public: + ///< Base class + using Base = gemm::threadblock::B2bMmaBase; + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape0 = Shape0_; + ///< Iterates over tiles of A operand in global memory + using IteratorA0 = IteratorA0_; + ///< Iterates over tiles of B operand in global memory + using IteratorB0 = IteratorB0_; + ///< Policy describing tuning details + using Policy0 = Policy0_; + + using SmemIteratorA0 = SmemIteratorA0_; + using SmemIteratorB0 = SmemIteratorB0_; + + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape1 = Shape1_; + ///< Iterates over tiles of A operand in global memory + using FragmentIteratorA1 = FragmentIteratorA1_; + ///< Iterates over tiles of the scale and bias vectors in global memory + using IteratorAccumulatorScaleBias = IteratorAccumulatorScaleBias_; + ///< WarpIterator to load Scale or Bias vector from threadblock fragment + using FragmentIteratorA1ScaleBias = FragmentIteratorA1ScaleBias_; + ///< Iterates over tiles of B operand in global memory + using IteratorB1 = IteratorB1_; + ///< Policy describing tuning details + using Policy1 = Policy1_; + + using SmemIteratorB1 = SmemIteratorB1_; + + ///< Epilogue after 1st Gemm + using OutputOp = OutputOp_; + + static const bool PerChannelScale = (OutputOp::kScale == + epilogue::thread::ScaleType::OnlyAlphaPerChannelScaling); + + static cutlass::arch::CacheOperation::Kind const kCacheOpA0 = CacheOpA0; + static cutlass::arch::CacheOperation::Kind const kCacheOpB0 = CacheOpB0; + static cutlass::arch::CacheOperation::Kind const kCacheOpB1 = CacheOpB1; + + // + // Dependent types + // + + using ElementC = typename Policy0::Operator::ElementC; + + /// Fragment of accumulator tile + using FragmentC0 = typename Policy0::Operator::FragmentC; + + /// Warp-level Mma + using Operator0 = typename Policy0::Operator; + + /// Fragment of Scale and Bias loaded from global memory + using FragmentA1ScaleBias = typename IteratorAccumulatorScaleBias::Fragment; + + /// Fragment of accumulator tile + using FragmentC1 = typename Policy1::Operator::FragmentC; + + /// Warp-level Mma + using Operator1 = typename Policy1::Operator; + + /// Internal structure exposed for introspection. 
+ struct Detail { + + static_assert(Base::kWarpGemmIterations0 > 1, + "The pipelined structure requires at least two warp-level " + "GEMM operations."); + static_assert(Base::kWarpGemmIterations1 > 1, + "The pipelined structure requires at least two warp-level " + "GEMM operations."); + + /// Number of cp.async instructions to load one stage of operand A + static int const AsyncCopyIterationsPerStageA0 = + IteratorA0::ThreadMap::Iterations::kCount; + + /// Number of cp.async instructions to load one stage of operand B + static int const AsyncCopyIterationsPerStageB0 = + IteratorB0::ThreadMap::Iterations::kCount; + + /// Number of cp.async instructions to load one stage of operand B + static int const AsyncCopyIterationsPerStageB1 = + IteratorB1::ThreadMap::Iterations::kCount; + + /// Number of stages + static int const kStages = Stages; + + /// Number of cp.async instructions to load on group of operand A + static int const kAccessesPerGroupA0 = + (AsyncCopyIterationsPerStageA0 + Base::kWarpGemmIterations0 - 1) / Base::kWarpGemmIterations0; + + /// Number of cp.async instructions to load on group of operand B + static int const kAccessesPerGroupB0 = + (AsyncCopyIterationsPerStageB0 + Base::kWarpGemmIterations0 - 1) / Base::kWarpGemmIterations0; + + /// Number of cp.async instructions to load on group of operand B + static int const kAccessesPerGroupB1 = + (AsyncCopyIterationsPerStageB1 + Base::kWarpGemmIterations1 - 1) / Base::kWarpGemmIterations1; + }; + + private: + + using WarpLoadedFragmentA0 = typename Operator0::FragmentA; + using WarpLoadedFragmentB0 = typename Operator0::FragmentB; + /// Warp Fragment of operand A1 loaded from accmulator tile + using WarpLoadedFragmentA1 = typename FragmentIteratorA1::Fragment; + using WarpLoadedFragmentA1ScaleBias = + typename FragmentIteratorA1ScaleBias::Fragment; + using WarpLoadedFragmentB1 = typename Operator1::FragmentB; + using WarpTransformedFragmentA0 = typename Operator0::TransformedFragmentA; + using WarpTransformedFragmentB0 = typename Operator0::TransformedFragmentB; + using WarpTransformedFragmentA1 = typename Operator1::TransformedFragmentA; + using WarpTransformedFragmentB1 = typename Operator1::TransformedFragmentB; + + private: + + // + // Data members + // + + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA0 smem_iterator_A0_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB0 smem_iterator_B0_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB1 smem_iterator_B1_; + +public: + + /// Construct from tensor references + CUTLASS_DEVICE + B2bImplicitGemmMultistage( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + typename Base::B2bMmaSharedStorage &shared_storage, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx + ): + Base(shared_storage, thread_idx, warp_idx, lane_idx), + smem_iterator_A0_(shared_storage.shared_storage0.operand_A_ref(), thread_idx), + smem_iterator_B0_(shared_storage.shared_storage0.operand_B_ref(), thread_idx), + smem_iterator_B1_(shared_storage.shared_storage1.operand_B_ref(), thread_idx) + { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position 
within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount0::kM * Base::WarpCount0::kN); + int warp_idx_k = warp_idx / (Base::WarpCount0::kM * Base::WarpCount0::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount0::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount0::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A0_.add_tile_offset( + {warp_idx_m, Base::kWarpGemmIterations0 * warp_idx_k}); + this->warp_tile_iterator_B0_.add_tile_offset( + {Base::kWarpGemmIterations0 * warp_idx_k, warp_idx_n}); + this->warp_tile_iterator_B1_.add_tile_offset( + {Base::kWarpGemmIterations1 * warp_idx_k, warp_idx_n}); + } + + CUTLASS_DEVICE + void copy_tiles_and_advance_0( + IteratorA0 &iterator_A0, IteratorB0 &iterator_B0, + int group_start_A0 = 0, int group_start_B0 = 0) { + + iterator_A0.set_iteration_index(group_start_A0); + this->smem_iterator_A0_.set_iteration_index(group_start_A0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupA0; ++j) { + + if (group_start_A0 + j < Detail::AsyncCopyIterationsPerStageA0) { + typename IteratorA0::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_A0_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorA0::ThreadMap::kElementsPerAccess / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr, iterator_A0.get(), iterator_A0.valid()); + + ++iterator_A0; + + ++this->smem_iterator_A0_; + } + } + + iterator_B0.set_iteration_index(group_start_B0); + + this->smem_iterator_B0_.set_iteration_index(group_start_B0); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupB0; ++j) { + if (group_start_B0 + j < Detail::AsyncCopyIterationsPerStageB0) { + typename IteratorB0::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_B0_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorB0::ThreadMap::kElementsPerAccess / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr, iterator_B0.get(), iterator_B0.valid()); + + ++iterator_B0; + ++this->smem_iterator_B0_; + } + } + } + + CUTLASS_DEVICE + void copy_tiles_and_advance_1( + IteratorB1 &iterator_B1, + int group_start_B1 = 0) { + + iterator_B1.set_iteration_index(group_start_B1); + + this->smem_iterator_B1_.set_iteration_index(group_start_B1); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupB1; ++j) { + if (group_start_B1 + j < Detail::AsyncCopyIterationsPerStageB1) { + typename IteratorB1::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_B1_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorB1::ThreadMap::kElementsPerAccess / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr, iterator_B1.get(), iterator_B1.valid()); + + ++iterator_B1; + ++this->smem_iterator_B1_; + } + } + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + ///< problem size of GEMM + int gemm_k_iterations_0, + ///< destination accumulator tile + FragmentC1 &accum, + ///< iterator over A0 operand in global memory + IteratorA0 iterator_A0, + ///< iterator over B0 operand in global memory + IteratorB0 iterator_B0, + ///< iterator over A1 operand scale vector in global memory + IteratorAccumulatorScaleBias iterator_A1_scale, + ///< iterator over A1 operand bias vector in global memory + IteratorAccumulatorScaleBias iterator_A1_bias, + ///< iterator over B1 operand in global memory + IteratorB1 iterator_B1, + ///< initial 
value of accumulator + FragmentC0 const &src_accum, + ///< epilogue operation after 1st Gemm + OutputOp output_op_0, + ///< Imaginary strides used for planar-complex only - ignored here + int64_t imag_stride_A = 0, + int64_t imag_stride_B = 0) { + + // + // Prologue + // + + // Issue several complete stages + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < Base::kStages - 1; + ++stage, --gemm_k_iterations_0) { + + iterator_A0.set_iteration_index(0); + this->smem_iterator_A0_.set_iteration_index(0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA0; ++j) { + typename IteratorA0::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_A0_.get()); + + int const kSrcBytes = + sizeof_bits::value * + IteratorA0::ThreadMap::kElementsPerAccess / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr, iterator_A0.get(), iterator_A0.valid()); + + ++iterator_A0; + ++this->smem_iterator_A0_; + } + + iterator_B0.set_iteration_index(0); + this->smem_iterator_B0_.set_iteration_index(0); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB0; ++j) { + typename IteratorB0::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_B0_.get()); + + int const kSrcBytes = + sizeof_bits::value * + IteratorB0::ThreadMap::kElementsPerAccess / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr, iterator_B0.get(), iterator_B0.valid()); + + ++iterator_B0; + ++this->smem_iterator_B0_; + } + + // Move to the next stage + iterator_A0.advance(); + iterator_B0.advance(); + + this->smem_iterator_A0_.add_tile_offset({0, 1}); + this->smem_iterator_B0_.add_tile_offset({1, 0}); + + // Inserts a fence to group cp.async instructions into stages. + cutlass::arch::cp_async_fence(); + } + + // Perform accumulation in the 'd' output operand + FragmentC0 accum0 = src_accum; + + // Waits until kStages-2 stages have committed. + cutlass::arch::cp_async_wait(); + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpLoadedFragmentA0 warp_loaded_frag_A0[2]; + WarpLoadedFragmentB0 warp_loaded_frag_B0[2]; + WarpTransformedFragmentA0 warp_transformed_frag_A0[2]; + WarpTransformedFragmentB0 warp_transformed_frag_B0[2]; + + Operator0 warp_mma0; + + this->warp_tile_iterator_A0_.set_kgroup_index(0); + this->warp_tile_iterator_B0_.set_kgroup_index(0); + + this->warp_tile_iterator_A0_.load(warp_loaded_frag_A0[0]); + this->warp_tile_iterator_B0_.load(warp_loaded_frag_B0[0]); + + ++this->warp_tile_iterator_A0_; + ++this->warp_tile_iterator_B0_; + + // Start issuing the first group of the next stage outside of the mainloop + copy_tiles_and_advance_0(iterator_A0, iterator_B0); + + int smem_write_stage_idx = Base::kStages - 1; + int smem_read_stage_idx = 0; + + warp_mma0.transform(warp_transformed_frag_A0[0], warp_transformed_frag_B0[0], + warp_loaded_frag_A0[0], warp_loaded_frag_B0[0]); + + // + // Mainloop + // + + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations_0 > (-Base::kStages + 1);) { + + // + // Loop over GEMM K dimension + // + + // Computes a warp-level GEMM on data held in shared memory + // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations0; + ++warp_mma_k) { + + // Load warp-level tiles from shared memory, wrapping to k offset if + // this is the last group as the case may be. 
+ + this->warp_tile_iterator_A0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0); + this->warp_tile_iterator_B0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0); + + this->warp_tile_iterator_A0_.load(warp_loaded_frag_A0[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_B0_.load(warp_loaded_frag_B0[(warp_mma_k + 1) % 2]); + + ++this->warp_tile_iterator_A0_; + ++this->warp_tile_iterator_B0_; + + if (warp_mma_k > 0) + warp_mma0.transform(warp_transformed_frag_A0[warp_mma_k % 2], + warp_transformed_frag_B0[warp_mma_k % 2], + warp_loaded_frag_A0[warp_mma_k % 2], + warp_loaded_frag_B0[warp_mma_k % 2]); + + // Issue global->shared copies for the next stage + int group_start_iteration_A0, group_start_iteration_B0; + + if (warp_mma_k + 1 == Base::kWarpGemmIterations0) { + group_start_iteration_A0 = 0; + group_start_iteration_B0 = 0; + } else { + group_start_iteration_A0 = + (warp_mma_k + 1) * Detail::kAccessesPerGroupA0; + group_start_iteration_B0 = + (warp_mma_k + 1) * Detail::kAccessesPerGroupB0; + } + + copy_tiles_and_advance_0(iterator_A0, iterator_B0, group_start_iteration_A0, + group_start_iteration_B0); + + warp_mma0( + accum0, + warp_transformed_frag_A0[warp_mma_k % 2], + warp_transformed_frag_B0[warp_mma_k % 2], + accum0 + ); + + if (warp_mma_k + 1 == Base::kWarpGemmIterations0) + warp_mma0.transform(warp_transformed_frag_A0[(warp_mma_k + 1) % 2], + warp_transformed_frag_B0[(warp_mma_k + 1) % 2], + warp_loaded_frag_A0[(warp_mma_k + 1) % 2], + warp_loaded_frag_B0[(warp_mma_k + 1) % 2]); + + if (warp_mma_k + 2 == Base::kWarpGemmIterations0) { + // Inserts a fence to group cp.async instructions into stages. + cutlass::arch::cp_async_fence(); + + // Waits until kStages-2 stages of cp.async have committed + arch::cp_async_wait(); + __syncthreads(); + + // Move to the next stage + iterator_A0.advance(); + iterator_B0.advance(); + + this->smem_iterator_A0_.add_tile_offset({0, 1}); + this->smem_iterator_B0_.add_tile_offset({1, 0}); + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory + if (smem_write_stage_idx == (Base::kStages - 1)) { + this->smem_iterator_A0_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B0_.add_tile_offset({-Base::kStages, 0}); + smem_write_stage_idx = 0; + } else { + ++smem_write_stage_idx; + } + + if (smem_read_stage_idx == (Base::kStages - 1)) { + this->warp_tile_iterator_A0_.add_tile_offset( + {0, -Base::kStages * Policy0::kPartitionsK * + Base::kWarpGemmIterations0}); + this->warp_tile_iterator_B0_.add_tile_offset( + {-Base::kStages * Policy0::kPartitionsK * + Base::kWarpGemmIterations0, + 0}); + smem_read_stage_idx = 0; + } else { + ++smem_read_stage_idx; + } + + --gemm_k_iterations_0; + } + } + + } + + // Insert fence and wait for all outstanding cp.async operations to commit. 
+ cutlass::arch::cp_async_fence(); + cutlass::arch::cp_async_wait<0>(); + __syncthreads(); + + + // 2nd Implicit Gemm + + /// Iterator to load a warp-scoped tile of A1 operand from intermediate accumulator tile + FragmentIteratorA1 warp_tile_iterator_A1_(accum0); + FragmentA1ScaleBias tb_frag_A1_scale; + FragmentA1ScaleBias tb_frag_A1_bias; + FragmentIteratorA1ScaleBias warp_tile_iterator_A1_scale_(tb_frag_A1_scale); + FragmentIteratorA1ScaleBias warp_tile_iterator_A1_bias_(tb_frag_A1_bias); + + if(PerChannelScale) { + tb_frag_A1_scale.clear(); + iterator_A1_scale.load(tb_frag_A1_scale); + ++iterator_A1_scale; + } + tb_frag_A1_bias.clear(); + iterator_A1_bias.load(tb_frag_A1_bias); + ++iterator_A1_bias; + + + // + // Prologue + // + int gemm_k_iterations_1 = FragmentIteratorA1::Policy::kIterations / Base::kWarpGemmIterations1; + + // Issue several complete stages + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < Base::kStages - 1; + ++stage, --gemm_k_iterations_1) { + + iterator_B1.set_iteration_index(0); + this->smem_iterator_B1_.set_iteration_index(0); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB1; ++j) { + typename IteratorB1::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_B1_.get()); + + int const kSrcBytes = + sizeof_bits::value * + IteratorB1::ThreadMap::kElementsPerAccess / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr, iterator_B1.get(), iterator_B1.valid()); + + ++iterator_B1; + ++this->smem_iterator_B1_; + } + + // Move to the next stage + iterator_B1.advance(); + + this->smem_iterator_B1_.add_tile_offset({1, 0}); + + // Inserts a fence to group cp.async instructions into stages. + cutlass::arch::cp_async_fence(); + } + + // Waits until kStages-2 stages have committed. 
+ cutlass::arch::cp_async_wait(); + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpLoadedFragmentA1 warp_loaded_frag_A1[2]; + WarpLoadedFragmentA1ScaleBias warp_loaded_frag_A1_scale[2]; + WarpLoadedFragmentA1ScaleBias warp_loaded_frag_A1_bias[2]; + WarpLoadedFragmentB1 warp_loaded_frag_B1[2]; + WarpTransformedFragmentA1 warp_transformed_frag_A1[2]; + WarpTransformedFragmentB1 warp_transformed_frag_B1[2]; + + Operator1 warp_mma1; + + if(PerChannelScale) { + warp_tile_iterator_A1_scale_.load(warp_loaded_frag_A1_scale[0]); + ++warp_tile_iterator_A1_scale_; + } + warp_tile_iterator_A1_bias_.load(warp_loaded_frag_A1_bias[0]); + ++warp_tile_iterator_A1_bias_; + + warp_tile_iterator_A1_.load(warp_loaded_frag_A1[0], + warp_loaded_frag_A1_scale[0], + warp_loaded_frag_A1_bias[0], + output_op_0); + ++warp_tile_iterator_A1_; + + this->warp_tile_iterator_B1_.set_kgroup_index(0); + this->warp_tile_iterator_B1_.load(warp_loaded_frag_B1[0]); + ++this->warp_tile_iterator_B1_; + + // Start issuing the first group of the next stage outside of the mainloop + copy_tiles_and_advance_1(iterator_B1); + + smem_write_stage_idx = Base::kStages - 1; + smem_read_stage_idx = 0; + + warp_mma1.transform(warp_transformed_frag_A1[0], warp_transformed_frag_B1[0], + warp_loaded_frag_A1[0], warp_loaded_frag_B1[0]); + + + // + // Mainloop + // + + CUTLASS_PRAGMA_UNROLL + for (gemm_k_iterations_1 = FragmentIteratorA1::Policy::kIterations / Base::kWarpGemmIterations1 - (Base::kStages - 1); + gemm_k_iterations_1 > (-Base::kStages + 1); gemm_k_iterations_1--) { + // + // Loop over GEMM K dimension + // + + // Computes a warp-level GEMM on data held in shared memory + // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations1; + ++warp_mma_k) { + + // Load threadblock-level scale/bias vector from global memory + if (warp_mma_k + 1 == Base::kWarpGemmIterations1) { + if(PerChannelScale) { + tb_frag_A1_scale.clear(); + iterator_A1_scale.load(tb_frag_A1_scale); + ++iterator_A1_scale; + } + tb_frag_A1_bias.clear(); + iterator_A1_bias.load(tb_frag_A1_bias); + ++iterator_A1_bias; + } + + // Load warp-level scale bias fragment from threadblock scale/bias vector + if(PerChannelScale) { + warp_tile_iterator_A1_scale_.load(warp_loaded_frag_A1_scale[(warp_mma_k + 1) % 2]); + ++warp_tile_iterator_A1_scale_; + } + warp_tile_iterator_A1_bias_.load(warp_loaded_frag_A1_bias[(warp_mma_k + 1) % 2]); + ++warp_tile_iterator_A1_bias_; + + // Load warp-level tile from accumulator fragment + warp_tile_iterator_A1_.load(warp_loaded_frag_A1[(warp_mma_k + 1) % 2], + warp_loaded_frag_A1_scale[(warp_mma_k + 1) % 2], + warp_loaded_frag_A1_bias[(warp_mma_k + 1) % 2], + output_op_0); + ++warp_tile_iterator_A1_; + + // Load warp-level tiles from shared memory, wrapping to k offset if + // this is the last group as the case may be. 
+ this->warp_tile_iterator_B1_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations1); + this->warp_tile_iterator_B1_.load(warp_loaded_frag_B1[(warp_mma_k + 1) % 2]); + ++this->warp_tile_iterator_B1_; + + if (warp_mma_k > 0) + warp_mma1.transform(warp_transformed_frag_A1[warp_mma_k % 2], + warp_transformed_frag_B1[warp_mma_k % 2], + warp_loaded_frag_A1[warp_mma_k % 2], + warp_loaded_frag_B1[warp_mma_k % 2]); + + // Issue global->shared copies for the next stage + int group_start_iteration_B1; + + if (warp_mma_k + 1 == Base::kWarpGemmIterations1) { + group_start_iteration_B1 = 0; + } else { + group_start_iteration_B1 = + (warp_mma_k + 1) * Detail::kAccessesPerGroupB1; + } + + copy_tiles_and_advance_1(iterator_B1, + group_start_iteration_B1); + + warp_mma1( + accum, + warp_transformed_frag_A1[warp_mma_k % 2], + warp_transformed_frag_B1[warp_mma_k % 2], + accum + ); + + if (warp_mma_k + 1 == Base::kWarpGemmIterations1) + warp_mma1.transform(warp_transformed_frag_A1[(warp_mma_k + 1) % 2], + warp_transformed_frag_B1[(warp_mma_k + 1) % 2], + warp_loaded_frag_A1[(warp_mma_k + 1) % 2], + warp_loaded_frag_B1[(warp_mma_k + 1) % 2]); + + if (warp_mma_k + 2 == Base::kWarpGemmIterations1) { + // Inserts a fence to group cp.async instructions into stages. + cutlass::arch::cp_async_fence(); + + // Waits until kStages-2 stages of cp.async have committed + arch::cp_async_wait(); + __syncthreads(); + + // Move to the next stage + iterator_B1.advance(); + + this->smem_iterator_B1_.add_tile_offset({1, 0}); + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory + if (smem_write_stage_idx == (Base::kStages - 1)) { + this->smem_iterator_B1_.add_tile_offset({-Base::kStages, 0}); + smem_write_stage_idx = 0; + } else { + ++smem_write_stage_idx; + } + + if (smem_read_stage_idx == (Base::kStages - 1)) { + this->warp_tile_iterator_B1_.add_tile_offset( + {-Base::kStages * Policy1::kPartitionsK * + Base::kWarpGemmIterations1, + 0}); + smem_read_stage_idx = 0; + } else { + ++smem_read_stage_idx; + } + + } + } + + } + + // Insert fence and wait for all outstanding cp.async operations to commit. + cutlass::arch::cp_async_fence(); + cutlass::arch::cp_async_wait<0>(); + __syncthreads(); + + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_implicit_gemm_multistage_smem_accumulator.h b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_implicit_gemm_multistage_smem_accumulator.h new file mode 100644 index 0000000000000000000000000000000000000000..64181870b5a76afd7112d74f5f0dc1913ada6323 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_implicit_gemm_multistage_smem_accumulator.h @@ -0,0 +1,816 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a multistage threadblock-scoped Implicit GEMM Convolution kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" +#include "cutlass/arch/cache_operation.h" +#include "cutlass/gemm/threadblock/mma_base.h" +#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h" + +#include "threadblock/b2b_mma_base_smem_accumulator.h" +#include "cutlass/epilogue/threadblock/epilogue_smem_accumulator.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. 
+template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape0_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA0_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA0_, + /// Cache operation for operand A + cutlass::arch::CacheOperation::Kind CacheOpA0, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB0_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB0_, + /// Cache operation for operand B + cutlass::arch::CacheOperation::Kind CacheOpB0, + /// Iterates over vectors of scale and bias vector in global memory + // (concept: VectorIterator) + typename IteratorAccumulatorScaleBias_, + /// Iterates over accumulator tile + typename FragmentIteratorAccumulator_, + /// Iterates over accumulator tile in shared memory + typename SmemIteratorD0_, + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape1_, + /// Iterates over the intermediate accumulator tile + // (concept::MmaTensorOpFragmentIterator) + typename WarpIteratorA1_, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB1_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB1_, + /// Cache operation for operand B + cutlass::arch::CacheOperation::Kind CacheOpB1, + /// Output operator for 1st Gemm(concept: epilogue::thread::LinearCombinationClamp, etc...) 
+ typename OutputOp_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy0_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy1_, + /// Number of stages, + int Stages, + /// Used for partial specialization + typename Enable = bool> +class B2bImplicitGemmMultistageSmemAccumulator : + public gemm::threadblock::B2bMmaBaseSmemAccumulator { +public: + ///< Base class + using Base = gemm::threadblock::B2bMmaBaseSmemAccumulator; + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape0 = Shape0_; + ///< Iterates over tiles of A operand in global memory + using IteratorA0 = IteratorA0_; + ///< Iterates over tiles of B operand in global memory + using IteratorB0 = IteratorB0_; + ///< Iterates over tiles of the scale and bias vectors in global memory + using IteratorAccumulatorScaleBias = IteratorAccumulatorScaleBias_; + ///< Policy describing tuning details + using Policy0 = Policy0_; + + using SmemIteratorA0 = SmemIteratorA0_; + using SmemIteratorB0 = SmemIteratorB0_; + using SmemIteratorD0 = SmemIteratorD0_; ///< Iterates over accumulator tile in shared memory + + using FragmentIteratorAccumulator = FragmentIteratorAccumulator_; ///< Iterates over accumulator tile + + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape1 = Shape1_; + ///< Iterates over tiles of B operand in global memory + using IteratorB1 = IteratorB1_; + ///< Policy describing tuning details + using Policy1 = Policy1_; + + using SmemIteratorB1 = SmemIteratorB1_; + using WarpIteratorA1 = WarpIteratorA1_; ///< Iterates over the intermediate accumulator tile in shared memory + + ///< Epilogue after 1st Gemm + using OutputOp = OutputOp_; + + static const bool PerChannelScale = (OutputOp::kScale == + epilogue::thread::ScaleType::OnlyAlphaPerChannelScaling); + + static cutlass::arch::CacheOperation::Kind const kCacheOpA0 = CacheOpA0; + static cutlass::arch::CacheOperation::Kind const kCacheOpB0 = CacheOpB0; + static cutlass::arch::CacheOperation::Kind const kCacheOpB1 = CacheOpB1; + + // + // Dependent types + // + + using ElementC = typename Policy0::Operator::ElementC; + + /// Fragment of accumulator tile + using FragmentC0 = typename Policy0::Operator::FragmentC; + + /// Warp-level Mma + using Operator0 = typename Policy0::Operator; + + /// Fragment of Scale and Bias loaded from global memory + using FragmentA1ScaleBias = typename IteratorAccumulatorScaleBias::Fragment; + + /// Fragment of accumulator tile + using FragmentC1 = typename Policy1::Operator::FragmentC; + + /// Warp-level Mma + using Operator1 = typename Policy1::Operator; + + /// Epilog in shared memory + using Epilogue0 = epilogue::threadblock::EpilogueSmemAccumulator< + SmemIteratorD0, ///< SmemTileIterator + FragmentIteratorAccumulator, ///< AccumulatorFragmentIterator + IteratorAccumulatorScaleBias, ///< ScaleBiasIterator + OutputOp>; ///< Output operator + + /// Internal structure exposed for introspection. 
+ struct Detail { + + static_assert(Base::kWarpGemmIterations0 > 1, + "The pipelined structure requires at least two warp-level " + "GEMM operations."); + static_assert(Base::kWarpGemmIterations1 > 1, + "The pipelined structure requires at least two warp-level " + "GEMM operations."); + + /// Number of cp.async instructions to load one stage of operand A + static int const AsyncCopyIterationsPerStageA0 = + IteratorA0::ThreadMap::Iterations::kCount; + + /// Number of cp.async instructions to load one stage of operand B + static int const AsyncCopyIterationsPerStageB0 = + IteratorB0::ThreadMap::Iterations::kCount; + + /// Number of cp.async instructions to load one stage of operand B + static int const AsyncCopyIterationsPerStageB1 = + IteratorB1::ThreadMap::Iterations::kCount; + + /// Number of stages + static int const kStages = Stages; + + /// Number of cp.async instructions to load on group of operand A + static int const kAccessesPerGroupA0 = + (AsyncCopyIterationsPerStageA0 + Base::kWarpGemmIterations0 - 1) / Base::kWarpGemmIterations0; + + /// Number of cp.async instructions to load on group of operand B + static int const kAccessesPerGroupB0 = + (AsyncCopyIterationsPerStageB0 + Base::kWarpGemmIterations0 - 1) / Base::kWarpGemmIterations0; + + /// Number of cp.async instructions to load on group of operand B + static int const kAccessesPerGroupB1 = + (AsyncCopyIterationsPerStageB1 + Base::kWarpGemmIterations1 - 1) / Base::kWarpGemmIterations1; + }; + + private: + + using WarpLoadedFragmentA0 = typename Operator0::FragmentA; + using WarpLoadedFragmentB0 = typename Operator0::FragmentB; + using WarpLoadedFragmentA1 = typename Operator1::FragmentA; + using WarpLoadedFragmentB1 = typename Operator1::FragmentB; + using WarpTransformedFragmentA0 = typename Operator0::TransformedFragmentA; + using WarpTransformedFragmentB0 = typename Operator0::TransformedFragmentB; + using WarpTransformedFragmentA1 = typename Operator1::TransformedFragmentA; + using WarpTransformedFragmentB1 = typename Operator1::TransformedFragmentB; + + private: + + // + // Data members + // + + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA0 smem_iterator_A0_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB0 smem_iterator_B0_; + + /// Shared Memory Iterator to store accumulator tile + SmemIteratorD0 smem_iterator_D0_; + + /// Iterator to load a warp-scoped tile of A1 operand from intermediate accumulator tile + WarpIteratorA1 warp_tile_iterator_A1_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB1 smem_iterator_B1_; + +public: + + /// Construct from tensor references + CUTLASS_DEVICE + B2bImplicitGemmMultistageSmemAccumulator( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + typename Base::B2bMmaSharedStorage &shared_storage, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx + ): + Base(shared_storage, thread_idx, warp_idx, lane_idx), + smem_iterator_A0_(shared_storage.b2b_mma_shared_storage.shared_storage0.operand_A_ref(), thread_idx), + smem_iterator_B0_(shared_storage.b2b_mma_shared_storage.shared_storage0.operand_B_ref(), thread_idx), + smem_iterator_D0_(shared_storage.accumulator_shared_storage0.accum_ref(), lane_idx), + warp_tile_iterator_A1_(shared_storage.accumulator_shared_storage0.accum_ref(), lane_idx), + 
smem_iterator_B1_(shared_storage.b2b_mma_shared_storage.shared_storage1.operand_B_ref(), thread_idx) + { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn_0 = warp_idx % (Base::WarpCount0::kM * Base::WarpCount0::kN); + int warp_idx_k_0 = warp_idx / (Base::WarpCount0::kM * Base::WarpCount0::kN); + + int warp_idx_m_0 = warp_idx_mn_0 % Base::WarpCount0::kM; + int warp_idx_n_0 = warp_idx_mn_0 / Base::WarpCount0::kM; + + int warp_idx_mn_1 = warp_idx % (Base::WarpCount1::kM * Base::WarpCount1::kN); + int warp_idx_k_1 = warp_idx / (Base::WarpCount1::kM * Base::WarpCount1::kN); + + int warp_idx_m_1 = warp_idx_mn_1 % Base::WarpCount1::kM; + int warp_idx_n_1 = warp_idx_mn_1 / Base::WarpCount1::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A0_.add_tile_offset( + {warp_idx_m_0, Base::kWarpGemmIterations0 * warp_idx_k_0}); + this->warp_tile_iterator_B0_.add_tile_offset( + {Base::kWarpGemmIterations0 * warp_idx_k_0, warp_idx_n_0}); + warp_tile_iterator_A1_.add_tile_offset( + {warp_idx_m_1, Base::kWarpGemmIterations1 * warp_idx_k_1}); + this->warp_tile_iterator_B1_.add_tile_offset( + {Base::kWarpGemmIterations1 * warp_idx_k_1, warp_idx_n_1}); + + // Add smem accumulator iterator warp offset + smem_iterator_D0_.add_tile_offset({ warp_idx_m_0 * SmemIteratorD0::TileIterations::kRow, + warp_idx_n_0 * SmemIteratorD0::TileIterations::kColumn}); + } + + CUTLASS_DEVICE + void copy_tiles_and_advance_0( + IteratorA0 &iterator_A0, IteratorB0 &iterator_B0, + int group_start_A0 = 0, int group_start_B0 = 0) { + + iterator_A0.set_iteration_index(group_start_A0); + this->smem_iterator_A0_.set_iteration_index(group_start_A0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupA0; ++j) { + + if (group_start_A0 + j < Detail::AsyncCopyIterationsPerStageA0) { + typename IteratorA0::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_A0_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorA0::ThreadMap::kElementsPerAccess / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr, iterator_A0.get(), iterator_A0.valid()); + + ++iterator_A0; + + ++this->smem_iterator_A0_; + } + } + + iterator_B0.set_iteration_index(group_start_B0); + + this->smem_iterator_B0_.set_iteration_index(group_start_B0); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupB0; ++j) { + if (group_start_B0 + j < Detail::AsyncCopyIterationsPerStageB0) { + typename IteratorB0::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_B0_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorB0::ThreadMap::kElementsPerAccess / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr, iterator_B0.get(), iterator_B0.valid()); + + ++iterator_B0; + ++this->smem_iterator_B0_; + } + } + } + + CUTLASS_DEVICE + void copy_tiles_and_advance_1( + IteratorB1 &iterator_B1, + int group_start_B1 = 0) { + + iterator_B1.set_iteration_index(group_start_B1); + + this->smem_iterator_B1_.set_iteration_index(group_start_B1); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupB1; ++j) { + if (group_start_B1 + j < Detail::AsyncCopyIterationsPerStageB1) { + typename 
IteratorB1::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_B1_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorB1::ThreadMap::kElementsPerAccess / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr, iterator_B1.get(), iterator_B1.valid()); + + ++iterator_B1; + ++this->smem_iterator_B1_; + } + } + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + ///< problem size of GEMM + int gemm_k_iterations_0, + ///< destination accumulator tile + FragmentC1 &accum, + ///< iterator over A0 operand in global memory + IteratorA0 iterator_A0, + ///< iterator over B0 operand in global memory + IteratorB0 iterator_B0, + ///< iterator over A1 operand scale vector in global memory + IteratorAccumulatorScaleBias iterator_accum0_scale, + ///< iterator over A1 operand bias vector in global memory + IteratorAccumulatorScaleBias iterator_accum0_bias, + ///< iterator over B1 operand in global memory + IteratorB1 iterator_B1, + ///< initial value of accumulator + FragmentC0 const &src_accum, + ///< epilogue operation after 1st Gemm + OutputOp output_op_0, + ///< Imaginary strides used for planar-complex only - ignored here + int64_t imag_stride_A = 0, + int64_t imag_stride_B = 0) { + + // + // Prologue + // + + // Issue several complete stages + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < Base::kStages - 1; + ++stage, --gemm_k_iterations_0) { + + iterator_A0.set_iteration_index(0); + this->smem_iterator_A0_.set_iteration_index(0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA0; ++j) { + typename IteratorA0::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_A0_.get()); + + int const kSrcBytes = + sizeof_bits::value * + IteratorA0::ThreadMap::kElementsPerAccess / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr, iterator_A0.get(), iterator_A0.valid()); + + ++iterator_A0; + ++this->smem_iterator_A0_; + } + + iterator_B0.set_iteration_index(0); + this->smem_iterator_B0_.set_iteration_index(0); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB0; ++j) { + typename IteratorB0::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_B0_.get()); + + int const kSrcBytes = + sizeof_bits::value * + IteratorB0::ThreadMap::kElementsPerAccess / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr, iterator_B0.get(), iterator_B0.valid()); + + ++iterator_B0; + ++this->smem_iterator_B0_; + } + + // Move to the next stage + iterator_A0.advance(); + iterator_B0.advance(); + + this->smem_iterator_A0_.add_tile_offset({0, 1}); + this->smem_iterator_B0_.add_tile_offset({1, 0}); + + // Inserts a fence to group cp.async instructions into stages. + cutlass::arch::cp_async_fence(); + } + + // Perform accumulation in the 'd' output operand + FragmentC0 accum0 = src_accum; + + // Waits until kStages-2 stages have committed. 
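+    // Leaving kStages-2 stages still in flight guarantees that the oldest
+    // stage has fully arrived in shared memory, so the warp-level iterators
+    // can begin reading it while the remaining copies continue in the
+    // background.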
+ cutlass::arch::cp_async_wait(); + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpLoadedFragmentA0 warp_loaded_frag_A0[2]; + WarpLoadedFragmentB0 warp_loaded_frag_B0[2]; + WarpTransformedFragmentA0 warp_transformed_frag_A0[2]; + WarpTransformedFragmentB0 warp_transformed_frag_B0[2]; + + Operator0 warp_mma0; + + this->warp_tile_iterator_A0_.set_kgroup_index(0); + this->warp_tile_iterator_B0_.set_kgroup_index(0); + + this->warp_tile_iterator_A0_.load(warp_loaded_frag_A0[0]); + this->warp_tile_iterator_B0_.load(warp_loaded_frag_B0[0]); + + ++this->warp_tile_iterator_A0_; + ++this->warp_tile_iterator_B0_; + + // Start issuing the first group of the next stage outside of the mainloop + copy_tiles_and_advance_0(iterator_A0, iterator_B0); + + int smem_write_stage_idx = Base::kStages - 1; + int smem_read_stage_idx = 0; + + warp_mma0.transform(warp_transformed_frag_A0[0], warp_transformed_frag_B0[0], + warp_loaded_frag_A0[0], warp_loaded_frag_B0[0]); + + // + // Mainloop + // + + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations_0 > (-Base::kStages + 1);) { + + // + // Loop over GEMM K dimension + // + + // Computes a warp-level GEMM on data held in shared memory + // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations0; + ++warp_mma_k) { + + // Load warp-level tiles from shared memory, wrapping to k offset if + // this is the last group as the case may be. + + this->warp_tile_iterator_A0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0); + this->warp_tile_iterator_B0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0); + + this->warp_tile_iterator_A0_.load(warp_loaded_frag_A0[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_B0_.load(warp_loaded_frag_B0[(warp_mma_k + 1) % 2]); + + ++this->warp_tile_iterator_A0_; + ++this->warp_tile_iterator_B0_; + + if (warp_mma_k > 0) + warp_mma0.transform(warp_transformed_frag_A0[warp_mma_k % 2], + warp_transformed_frag_B0[warp_mma_k % 2], + warp_loaded_frag_A0[warp_mma_k % 2], + warp_loaded_frag_B0[warp_mma_k % 2]); + + // Issue global->shared copies for the next stage + int group_start_iteration_A0, group_start_iteration_B0; + + if (warp_mma_k + 1 == Base::kWarpGemmIterations0) { + group_start_iteration_A0 = 0; + group_start_iteration_B0 = 0; + } else { + group_start_iteration_A0 = + (warp_mma_k + 1) * Detail::kAccessesPerGroupA0; + group_start_iteration_B0 = + (warp_mma_k + 1) * Detail::kAccessesPerGroupB0; + } + + copy_tiles_and_advance_0(iterator_A0, iterator_B0, group_start_iteration_A0, + group_start_iteration_B0); + + warp_mma0( + accum0, + warp_transformed_frag_A0[warp_mma_k % 2], + warp_transformed_frag_B0[warp_mma_k % 2], + accum0 + ); + + if (warp_mma_k + 1 == Base::kWarpGemmIterations0) + warp_mma0.transform(warp_transformed_frag_A0[(warp_mma_k + 1) % 2], + warp_transformed_frag_B0[(warp_mma_k + 1) % 2], + warp_loaded_frag_A0[(warp_mma_k + 1) % 2], + warp_loaded_frag_B0[(warp_mma_k + 1) % 2]); + + if (warp_mma_k + 2 == Base::kWarpGemmIterations0) { + // Inserts a fence to group cp.async instructions into stages. 
+ cutlass::arch::cp_async_fence(); + + // Waits until kStages-2 stages of cp.async have committed + arch::cp_async_wait(); + __syncthreads(); + + // Move to the next stage + iterator_A0.advance(); + iterator_B0.advance(); + + this->smem_iterator_A0_.add_tile_offset({0, 1}); + this->smem_iterator_B0_.add_tile_offset({1, 0}); + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory + if (smem_write_stage_idx == (Base::kStages - 1)) { + this->smem_iterator_A0_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B0_.add_tile_offset({-Base::kStages, 0}); + smem_write_stage_idx = 0; + } else { + ++smem_write_stage_idx; + } + + if (smem_read_stage_idx == (Base::kStages - 1)) { + this->warp_tile_iterator_A0_.add_tile_offset( + {0, -Base::kStages * Policy0::kPartitionsK * + Base::kWarpGemmIterations0}); + this->warp_tile_iterator_B0_.add_tile_offset( + {-Base::kStages * Policy0::kPartitionsK * + Base::kWarpGemmIterations0, + 0}); + smem_read_stage_idx = 0; + } else { + ++smem_read_stage_idx; + } + + --gemm_k_iterations_0; + } + } + + } + + // Insert fence and wait for all outstanding cp.async operations to commit. + cutlass::arch::cp_async_fence(); + cutlass::arch::cp_async_wait<0>(); + __syncthreads(); + + /// Epilogue for the first Implicit Gemm + Epilogue0 epilogue0; + + epilogue0(output_op_0, smem_iterator_D0_, accum0, iterator_accum0_scale, iterator_accum0_bias); + + __syncthreads(); + + // 2nd Implicit Gemm + + // + // Prologue + // + int gemm_k_iterations_1 = Shape0::kN / Shape1::kK; + + // Issue several complete stages + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < Base::kStages - 1; + ++stage, --gemm_k_iterations_1) { + + iterator_B1.set_iteration_index(0); + this->smem_iterator_B1_.set_iteration_index(0); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB1; ++j) { + typename IteratorB1::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_B1_.get()); + + int const kSrcBytes = + sizeof_bits::value * + IteratorB1::ThreadMap::kElementsPerAccess / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr, iterator_B1.get(), iterator_B1.valid()); + + ++iterator_B1; + ++this->smem_iterator_B1_; + } + + // Move to the next stage + iterator_B1.advance(); + + this->smem_iterator_B1_.add_tile_offset({1, 0}); + + // Inserts a fence to group cp.async instructions into stages. + cutlass::arch::cp_async_fence(); + } + + // Waits until kStages-2 stages have committed. 
+ cutlass::arch::cp_async_wait(); + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpLoadedFragmentA1 warp_loaded_frag_A1[2]; + WarpLoadedFragmentB1 warp_loaded_frag_B1[2]; + WarpTransformedFragmentA1 warp_transformed_frag_A1[2]; + WarpTransformedFragmentB1 warp_transformed_frag_B1[2]; + + Operator1 warp_mma1; + + warp_tile_iterator_A1_.load(warp_loaded_frag_A1[0]); + ++warp_tile_iterator_A1_; + + this->warp_tile_iterator_B1_.set_kgroup_index(0); + this->warp_tile_iterator_B1_.load(warp_loaded_frag_B1[0]); + ++this->warp_tile_iterator_B1_; + + // Start issuing the first group of the next stage outside of the mainloop + copy_tiles_and_advance_1(iterator_B1); + + smem_write_stage_idx = Base::kStages - 1; + smem_read_stage_idx = 0; + + warp_mma1.transform(warp_transformed_frag_A1[0], warp_transformed_frag_B1[0], + warp_loaded_frag_A1[0], warp_loaded_frag_B1[0]); + + + // + // Mainloop + // + + CUTLASS_PRAGMA_UNROLL + for ( gemm_k_iterations_1 = Shape0::kN / Shape1::kK - (Base::kStages - 1); + gemm_k_iterations_1 > (-Base::kStages + 1); gemm_k_iterations_1--) { + // + // Loop over GEMM K dimension + // + + // Computes a warp-level GEMM on data held in shared memory + // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations1; + ++warp_mma_k) { + + // Load warp-level tile from accumulator fragment + // skip warp tile loading for the last kgroup + if(gemm_k_iterations_1 > (-Base::kStages + 2) || warp_mma_k < Base::kWarpGemmIterations1 - 1) { + warp_tile_iterator_A1_.load(warp_loaded_frag_A1[(warp_mma_k + 1) % 2]); + } + ++warp_tile_iterator_A1_; + + // Load warp-level tiles from shared memory, wrapping to k offset if + // this is the last group as the case may be. + this->warp_tile_iterator_B1_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations1); + this->warp_tile_iterator_B1_.load(warp_loaded_frag_B1[(warp_mma_k + 1) % 2]); + ++this->warp_tile_iterator_B1_; + + + if (warp_mma_k > 0) + warp_mma1.transform(warp_transformed_frag_A1[warp_mma_k % 2], + warp_transformed_frag_B1[warp_mma_k % 2], + warp_loaded_frag_A1[warp_mma_k % 2], + warp_loaded_frag_B1[warp_mma_k % 2]); + + // Issue global->shared copies for the next stage + int group_start_iteration_B1; + + if (warp_mma_k + 1 == Base::kWarpGemmIterations1) { + group_start_iteration_B1 = 0; + } else { + group_start_iteration_B1 = + (warp_mma_k + 1) * Detail::kAccessesPerGroupB1; + } + + copy_tiles_and_advance_1(iterator_B1, + group_start_iteration_B1); + + warp_mma1( + accum, + warp_transformed_frag_A1[warp_mma_k % 2], + warp_transformed_frag_B1[warp_mma_k % 2], + accum + ); + + if (warp_mma_k + 1 == Base::kWarpGemmIterations1) + warp_mma1.transform(warp_transformed_frag_A1[(warp_mma_k + 1) % 2], + warp_transformed_frag_B1[(warp_mma_k + 1) % 2], + warp_loaded_frag_A1[(warp_mma_k + 1) % 2], + warp_loaded_frag_B1[(warp_mma_k + 1) % 2]); + + if (warp_mma_k + 2 == Base::kWarpGemmIterations1) { + // Inserts a fence to group cp.async instructions into stages. 
+ cutlass::arch::cp_async_fence(); + + // Waits until kStages-2 stages of cp.async have committed + arch::cp_async_wait(); + __syncthreads(); + + // Move to the next stage + iterator_B1.advance(); + + this->smem_iterator_B1_.add_tile_offset({1, 0}); + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory + if (smem_write_stage_idx == (Base::kStages - 1)) { + this->smem_iterator_B1_.add_tile_offset({-Base::kStages, 0}); + smem_write_stage_idx = 0; + } else { + ++smem_write_stage_idx; + } + + if (smem_read_stage_idx == (Base::kStages - 1)) { + this->warp_tile_iterator_B1_.add_tile_offset( + {-Base::kStages * Policy1::kPartitionsK * + Base::kWarpGemmIterations1, + 0}); + smem_read_stage_idx = 0; + } else { + ++smem_read_stage_idx; + } + + } + } + + } + + // Insert fence and wait for all outstanding cp.async operations to commit. + cutlass::arch::cp_async_fence(); + cutlass::arch::cp_async_wait<0>(); + __syncthreads(); + + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_implicit_gemm_pipelined.h b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_implicit_gemm_pipelined.h new file mode 100644 index 0000000000000000000000000000000000000000..97466a1ce72e632f6436b49a3f0c578b03f543da --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_implicit_gemm_pipelined.h @@ -0,0 +1,553 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/aligned_buffer.h" +#include "cutlass/numeric_conversion.h" + +#include "cutlass/numeric_types.h" +#include "cutlass/matrix_shape.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h" + +#include "threadblock/b2b_mma_base.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape0_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) + typename IteratorA0_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA0_, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) + typename IteratorB0_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB0_, + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape1_, + /// Iterates over the intermediate accumulator tile + // (concept::MmaTensorOpFragmentIterator) + typename FragmentIteratorA1_, + /// Iterates over vectors of scale and bias vector in global memory + // (concept: VectorIterator) + typename IteratorAccumulatorScaleBias_, + /// FragmentIterator to load Scale or Bias vector from threadblock fragment + typename FragmentIteratorA1ScaleBias_, + // (concept: VectorFragmentIterator) + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) + typename IteratorB1_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB1_, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Output operator for 1st Gemm(concept: epilogue::thread::LinearCombinationClamp, etc...) 
+ typename OutputOp_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy0_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy1_, + /// Transformation applied to A operand + typename TransformA0_ = NumericArrayConverter< + typename SmemIteratorA0_::Element, + typename IteratorA0_::Element, + IteratorA0_::Fragment::kElements>, + /// + /// Transformation applied to B operand + typename TransformB0_ = NumericArrayConverter< + typename SmemIteratorB0_::Element, + typename IteratorB0_::Element, + IteratorB0_::Fragment::kElements>, + /// + /// Transformation applied to B operand + typename TransformB1_ = NumericArrayConverter< + typename SmemIteratorB1_::Element, + typename IteratorB1_::Element, + IteratorB1_::Fragment::kElements>, + /// Used for partial specialization + typename Enable = bool +> +class B2bImplicitGemmPipelined : + public gemm::threadblock::B2bMmaBase { +public: + + ///< Base class + using Base = gemm::threadblock::B2bMmaBase; + + using Shape0 = Shape0_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using IteratorA0 = IteratorA0_; ///< Iterates over tiles of A operand in global memory + using IteratorB0 = IteratorB0_; ///< Iterates over tiles of B operand in global memory + using Policy0 = Policy0_; ///< Policy0 describing tuning details + + using SmemIteratorA0 = SmemIteratorA0_; + using SmemIteratorB0 = SmemIteratorB0_; + + using Shape1 = Shape1_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using FragmentIteratorA1 = FragmentIteratorA1_; ///< Iterates over tiles of A1 operand from accumulator tile + using IteratorAccumulatorScaleBias = IteratorAccumulatorScaleBias_; ///< Iterates over tiles of the scale and bias vectors in global memory + using FragmentIteratorA1ScaleBias = + FragmentIteratorA1ScaleBias_; ///< WarpIterator to load Scale or Bias vector from the threadblock fragment + using IteratorB1 = IteratorB1_; ///< Iterates over tiles of B operand in global memory + using Policy1 = Policy1_; ///< Policy1 describing tuning details + + using SmemIteratorB1 = SmemIteratorB1_; + + + using ElementC = ElementC_; ///< Data type of accumulator matrix + using LayoutC = LayoutC_; ///< Layout of accumulator matrix + + using OutputOp = OutputOp_; ///< Epilogue after 1st Gemm + + static const bool PerChannelScale = (OutputOp::kScale == + epilogue::thread::ScaleType::OnlyAlphaPerChannelScaling); + + using TransformA0 = TransformA0_; + using TransformB0 = TransformB0_; + using TransformB1 = TransformB1_; + + // + // Dependent types + // + + /// Fragment of operand A loaded from global memory + using FragmentA0 = typename IteratorA0::Fragment; + + /// Fragment of operand B loaded from global memory + using FragmentB0 = typename IteratorB0::Fragment; + + /// Fragment of accumulator tile + using FragmentC0 = typename Policy0::Operator::FragmentC; + + /// Warp-level Mma + using Operator0 = typename Policy0::Operator; + + /// Fragment of Scale and Bias loaded from global memory + using FragmentA1ScaleBias = typename IteratorAccumulatorScaleBias::Fragment; + + /// Fragment of operand B loaded from global memory + using FragmentB1 = typename IteratorB1::Fragment; + + /// Fragment of accumulator tile + using FragmentC1 = typename Policy1::Operator::FragmentC; + + /// Warp-level Mma + using Operator1 = typename Policy1::Operator; + + /// Obtain the arch tag from the warp-level operator + using ArchTag = typename Policy0::Operator::ArchTag; + + /// Complex transform on A0 operand + static ComplexTransform const 
kTransformA0 = Operator0::kTransformA; + + /// Complex transform on B0 operand + static ComplexTransform const kTransformB0 = Operator0::kTransformB; + + /// Complex transform on B1 operand + static ComplexTransform const kTransformB1 = Operator1::kTransformB; + + // staticaly assert kStages for MmaPipelined is two (Double-buffered pipeline) + static_assert((Base::kStages==2), "MmaPipelined requires kStages set to value 2"); + +private: + + using WarpFragmentA0 = typename Operator0::FragmentA; + using WarpFragmentB0 = typename Operator0::FragmentB; + /// Warp Fragment of operand A1 loaded from accmulator tile + using WarpFragmentA1 = typename FragmentIteratorA1::Fragment; + /// Warp Fragment of operand A1 scale and bias loaded from threadblock fragment + using WarpFragmentA1ScaleBias = + typename FragmentIteratorA1ScaleBias::Fragment; + using WarpFragmentB1 = typename Operator1::FragmentB; + +protected: + + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA0 smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B0 operand to shared memory + SmemIteratorB0 smem_iterator_B0_; + + /// Iterator to write threadblock-scoped tile of B1 operand to shared memory + SmemIteratorB1 smem_iterator_B1_; + +public: + + /// Construct from tensor references + CUTLASS_DEVICE + B2bImplicitGemmPipelined( + typename Base::B2bMmaSharedStorage &shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM + int thread_idx, ///< ID within the threadblock + int warp_idx, ///< ID of warp + int lane_idx ///< ID of each thread within a warp + ): + Base(shared_storage, thread_idx, warp_idx, lane_idx), + smem_iterator_A_(shared_storage.shared_storage0.operand_A_ref(), thread_idx), + smem_iterator_B0_(shared_storage.shared_storage0.operand_B_ref(), thread_idx), + smem_iterator_B1_(shared_storage.shared_storage1.operand_B_ref(), thread_idx) { + + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount0::kM * Base::WarpCount0::kN); + int warp_idx_k = warp_idx / (Base::WarpCount0::kM * Base::WarpCount0::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount0::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount0::kM; + + //These may change across different GEMM layers + int tile_offset_k_0 = Base::kWarpGemmIterations0 * warp_idx_k; + int tile_offset_k_1 = Base::kWarpGemmIterations1 * warp_idx_k; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A0_.add_tile_offset({warp_idx_m, tile_offset_k_0}); + this->warp_tile_iterator_B0_.add_tile_offset({tile_offset_k_0, warp_idx_n}); + this->warp_tile_iterator_B1_.add_tile_offset({tile_offset_k_1, warp_idx_n}); + + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + int gemm_k_iterations_0, ///< number of iterations of the mainloop + FragmentC1 &accum, ///< destination accumulator tile + IteratorA0 iterator_A, ///< iterator over A operand in global memory + IteratorB0 iterator_B0, ///< iterator over B0 operand in global memory + IteratorAccumulatorScaleBias iterator_A1_scale, ///< iterator over A1 operand scale vectors in global memory + IteratorAccumulatorScaleBias iterator_A1_bias, ///< iterator over A1 
operand bias vectors in global memory + IteratorB1 iterator_B1, ///< iterator over B1 operand in global memory + FragmentC0 const &src_accum, ///< source accumulator tile + OutputOp output_op_0, ///< epilogue operation after 1st Gemm + TransformA0 transform_A0 = TransformA0(), ///< transformation applied to A0 fragment + TransformB0 transform_B0 = TransformB0(), ///< transformation applied to B0 fragment + TransformB1 transform_B1 = TransformB1()) { ///< transformation applied to B1 fragment + + // + // Prologue + // + + // Perform accumulation in the 'd' output operand + FragmentC0 accum0 = src_accum; + + FragmentA0 tb_frag_A; + FragmentB0 tb_frag_B0; + + tb_frag_A.clear(); + tb_frag_B0.clear(); + + // The last kblock is loaded in the prolog + iterator_A.load(tb_frag_A); + iterator_B0.load(tb_frag_B0); + + ++iterator_A; + ++iterator_B0; + + this->smem_iterator_A_.store(transform_A0(tb_frag_A)); + this->smem_iterator_B0_.store(transform_B0(tb_frag_B0)); + + ++this->smem_iterator_A_; + ++this->smem_iterator_B0_; + + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math instructions + WarpFragmentA0 warp_frag_A0[2]; + WarpFragmentB0 warp_frag_B0[2]; + + this->warp_tile_iterator_A0_.set_kgroup_index(0); + this->warp_tile_iterator_B0_.set_kgroup_index(0); + + this->warp_tile_iterator_A0_.load(warp_frag_A0[0]); + this->warp_tile_iterator_B0_.load(warp_frag_B0[0]); + + ++this->warp_tile_iterator_A0_; + ++this->warp_tile_iterator_B0_; + + Operator0 warp_mma0; + + int smem_write_stage_idx = 1; + + // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing + // shared memory loads (which have the tightest latency requirement). + + // + // Mainloop + // + + // Note: The main loop does not support Base::kWarpGemmIterations == 2. + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations_0 > 0; --gemm_k_iterations_0) { + // + // Loop over GEMM K dimension + // + + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations0; ++warp_mma_k) { + + // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group + // as the case may be. 
+ + if (warp_mma_k == Base::kWarpGemmIterations0 - 1) { + + // Write fragments to shared memory + this->smem_iterator_A_.store(transform_A0(tb_frag_A)); + + this->smem_iterator_B0_.store(transform_B0(tb_frag_B0)); + + __syncthreads(); + + ++this->smem_iterator_A_; + ++this->smem_iterator_B0_; + + // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory + if (smem_write_stage_idx == 1) { + this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B0_.add_tile_offset({-Base::kStages, 0}); + } + else { + this->warp_tile_iterator_A0_.add_tile_offset( + {0, -Base::kStages * Policy0::kPartitionsK * Base::kWarpGemmIterations0}); + this->warp_tile_iterator_B0_.add_tile_offset( + {-Base::kStages * Policy0::kPartitionsK * Base::kWarpGemmIterations0, + 0}); + } + + smem_write_stage_idx ^= 1; + } + + this->warp_tile_iterator_A0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0); + this->warp_tile_iterator_B0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0); + + this->warp_tile_iterator_A0_.load(warp_frag_A0[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_B0_.load(warp_frag_B0[(warp_mma_k + 1) % 2]); + + ++this->warp_tile_iterator_A0_; + ++this->warp_tile_iterator_B0_; + + if (warp_mma_k == 0) { + + iterator_A.load(tb_frag_A); + iterator_B0.load(tb_frag_B0); + + ++iterator_A; + ++iterator_B0; + } + + warp_mma0(accum0, warp_frag_A0[warp_mma_k % 2], + warp_frag_B0[warp_mma_k % 2], accum0); + + } + } + + + //2nd Implicit Gemm + + /// Iterator to load a warp-scoped tile of A1 operand from intermediate accumulator tile + FragmentIteratorA1 warp_tile_iterator_A1_(accum0); + + + + // + // Prologue + // + + FragmentA1ScaleBias tb_frag_A1_scale; + FragmentA1ScaleBias tb_frag_A1_bias; + FragmentIteratorA1ScaleBias warp_tile_iterator_A1_scale_(tb_frag_A1_scale); + FragmentIteratorA1ScaleBias warp_tile_iterator_A1_bias_(tb_frag_A1_bias); + FragmentB1 tb_frag_B1; + + if(PerChannelScale) + tb_frag_A1_scale.clear(); + tb_frag_A1_bias.clear(); + tb_frag_B1.clear(); + + // The last kblock is loaded in the prolog + if(PerChannelScale) + iterator_A1_scale.load(tb_frag_A1_scale); + iterator_A1_bias.load(tb_frag_A1_bias); + iterator_B1.load(tb_frag_B1); + + + if(PerChannelScale) + ++iterator_A1_scale; + ++iterator_A1_bias; + ++iterator_B1; + + this->smem_iterator_B1_.store(transform_B1(tb_frag_B1)); + + ++this->smem_iterator_B1_; + + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math instructions + WarpFragmentA1ScaleBias warp_frag_A1_scale[2]; + WarpFragmentA1ScaleBias warp_frag_A1_bias[2]; + WarpFragmentA1 warp_frag_A1[2]; + WarpFragmentB1 warp_frag_B1[2]; + + this->warp_tile_iterator_B1_.set_kgroup_index(0); + + if(PerChannelScale) + warp_tile_iterator_A1_scale_.load(warp_frag_A1_scale[0]); + warp_tile_iterator_A1_bias_.load(warp_frag_A1_bias[0]); + warp_tile_iterator_A1_.load(warp_frag_A1[0], warp_frag_A1_scale[0], + warp_frag_A1_bias[0], output_op_0); + this->warp_tile_iterator_B1_.load(warp_frag_B1[0]); + + ++warp_tile_iterator_A1_; + if(PerChannelScale) + ++warp_tile_iterator_A1_scale_; + ++warp_tile_iterator_A1_bias_; + ++this->warp_tile_iterator_B1_; + + Operator1 warp_mma1; + + smem_write_stage_idx = 1; + + int gemm_k_iterations_1 = FragmentIteratorA1::Policy::kIterations / Base::kWarpGemmIterations1; + + // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing + // shared memory loads (which have the tightest latency requirement). 
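+    // The two-entry warp fragment arrays double-buffer the operands: index
+    // (warp_mma_k % 2) feeds the current warp-level MMA while index
+    // ((warp_mma_k + 1) % 2) is being loaded from shared memory for the next
+    // iteration.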
+ + // + // Mainloop + // + + // Note: The main loop does not support Base::kWarpGemmIterations == 2. + CUTLASS_PRAGMA_UNROLL + for (; gemm_k_iterations_1 > 0; --gemm_k_iterations_1) { + // + // Loop over GEMM K dimension + // + + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations1; ++warp_mma_k) { + + // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group + // as the case may be. + + if (warp_mma_k == Base::kWarpGemmIterations1 - 1) { + + this->smem_iterator_B1_.store(transform_B1(tb_frag_B1)); + + __syncthreads(); + + ++this->smem_iterator_B1_; + + // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory + if (smem_write_stage_idx == 1) { + this->smem_iterator_B1_.add_tile_offset({-Base::kStages, 0}); + } + else { + this->warp_tile_iterator_B1_.add_tile_offset( + {-Base::kStages * Policy1::kPartitionsK * Base::kWarpGemmIterations1, + 0}); + } + + smem_write_stage_idx ^= 1; + + if(PerChannelScale) { + tb_frag_A1_scale.clear(); + iterator_A1_scale.load(tb_frag_A1_scale); + ++iterator_A1_scale; + } + tb_frag_A1_bias.clear(); + iterator_A1_bias.load(tb_frag_A1_bias); + ++iterator_A1_bias; + } + + this->warp_tile_iterator_B1_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations1); + + if(PerChannelScale) + warp_tile_iterator_A1_scale_.load(warp_frag_A1_scale[(warp_mma_k + 1) % 2]); + warp_tile_iterator_A1_bias_.load(warp_frag_A1_bias[(warp_mma_k + 1) % 2]); + warp_tile_iterator_A1_.load(warp_frag_A1[(warp_mma_k + 1) % 2], + warp_frag_A1_scale[(warp_mma_k + 1) % 2], + warp_frag_A1_bias[(warp_mma_k + 1) % 2], + output_op_0); + this->warp_tile_iterator_B1_.load(warp_frag_B1[(warp_mma_k + 1) % 2]); + + if(PerChannelScale) + ++warp_tile_iterator_A1_scale_; + ++warp_tile_iterator_A1_bias_; + ++warp_tile_iterator_A1_; + ++this->warp_tile_iterator_B1_; + + if (warp_mma_k == 0) { + + iterator_B1.load(tb_frag_B1); + + ++iterator_B1; + } + + warp_mma1(accum, warp_frag_A1[warp_mma_k % 2], + warp_frag_B1[warp_mma_k % 2], accum); + + } + } + + + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_implicit_gemm_pipelined_smem_accumulator.h b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_implicit_gemm_pipelined_smem_accumulator.h new file mode 100644 index 0000000000000000000000000000000000000000..e5d91b127ca45d99a57f1cc656d59e491335dcbe --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_implicit_gemm_pipelined_smem_accumulator.h @@ -0,0 +1,535 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/aligned_buffer.h" +#include "cutlass/numeric_conversion.h" + +#include "cutlass/numeric_types.h" +#include "cutlass/matrix_shape.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h" + +#include "threadblock/b2b_mma_base_smem_accumulator.h" +#include "cutlass/epilogue/threadblock/epilogue_smem_accumulator.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. 
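+/// In this variant the first GEMM's accumulators are written to shared memory
+/// through Epilogue0 (which applies the scale/bias output operator), and the
+/// second GEMM then reads that intermediate tile back as its A operand via
+/// WarpIteratorA1.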
+template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape0_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) + typename IteratorA0_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA0_, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) + typename IteratorB0_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB0_, + /// Iterates over vectors of scale and bias vector in global memory + // (concept: VectorIterator) + typename IteratorAccumulatorScaleBias_, + /// Iterates over accumulator tile + typename FragmentIteratorAccumulator_, + /// Iterates over accumulator tile in shared memory + typename SmemIteratorD0_, + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape1_, + /// Iterates over the intermediate accumulator tile in shared memory + typename WarpIteratorA1_, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) + typename IteratorB1_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB1_, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Output operator for 1st Gemm(concept: epilogue::thread::LinearCombinationClamp, etc...) + typename OutputOp_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy0_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy1_, + /// Transformation applied to A0 operand + typename TransformA0_ = NumericArrayConverter< + typename SmemIteratorA0_::Element, + typename IteratorA0_::Element, + IteratorA0_::Fragment::kElements>, + /// + /// Transformation applied to B0 operand + typename TransformB0_ = NumericArrayConverter< + typename SmemIteratorB0_::Element, + typename IteratorB0_::Element, + IteratorB0_::Fragment::kElements>, + /// + /// Transformation applied to B1 operand + typename TransformB1_ = NumericArrayConverter< + typename SmemIteratorB1_::Element, + typename IteratorB1_::Element, + IteratorB1_::Fragment::kElements>, + /// Used for partial specialization + typename Enable = bool +> +class B2bImplicitGemmPipelinedSmemAccumulator : + public gemm::threadblock::B2bMmaBaseSmemAccumulator { +public: + + ///< Base class + using Base = gemm::threadblock::B2bMmaBaseSmemAccumulator; + + using Shape0 = Shape0_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using IteratorA0 = IteratorA0_; ///< Iterates over tiles of A operand in global memory + using IteratorB0 = IteratorB0_; ///< Iterates over tiles of B operand in global memory + using IteratorAccumulatorScaleBias = IteratorAccumulatorScaleBias_; ///< Iterates over tiles of the scale and bias vectors in global memory + using Policy0 = Policy0_; ///< Policy0 describing tuning details + + using SmemIteratorA0 = SmemIteratorA0_; + using SmemIteratorB0 = SmemIteratorB0_; + using SmemIteratorD0 = SmemIteratorD0_; ///< Iterates over accumulator tile in shared memory + + using FragmentIteratorAccumulator = FragmentIteratorAccumulator_; ///< Iterates over accumulator tile + + using Shape1 = 
Shape1_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using IteratorB1 = IteratorB1_; ///< Iterates over tiles of B operand in global memory + using Policy1 = Policy1_; ///< Policy1 describing tuning details + + using SmemIteratorB1 = SmemIteratorB1_; + using WarpIteratorA1 = WarpIteratorA1_; ///< Iterates over the intermediate accumulator tile in shared memory + + + using ElementC = ElementC_; ///< Data type of accumulator matrix + using LayoutC = LayoutC_; ///< Layout of accumulator matrix + + using OutputOp = OutputOp_; ///< Epilogue after 1st Gemm + + using TransformA0 = TransformA0_; + using TransformB0 = TransformB0_; + using TransformB1 = TransformB1_; + + // + // Dependent types + // + + /// Fragment of operand A loaded from global memory + using FragmentA0 = typename IteratorA0::Fragment; + + /// Fragment of operand B loaded from global memory + using FragmentB0 = typename IteratorB0::Fragment; + + /// Fragment of accumulator tile + using FragmentC0 = typename Policy0::Operator::FragmentC; + + /// Warp-level Mma + using Operator0 = typename Policy0::Operator; + + /// Fragment of operand B loaded from global memory + using FragmentB1 = typename IteratorB1::Fragment; + + /// Fragment of accumulator tile + using FragmentC1 = typename Policy1::Operator::FragmentC; + + /// Warp-level Mma + using Operator1 = typename Policy1::Operator; + + /// Obtain the arch tag from the warp-level operator + using ArchTag = typename Policy0::Operator::ArchTag; + + /// Complex transform on A0 operand + static ComplexTransform const kTransformA0 = Operator0::kTransformA; + + /// Complex transform on B0 operand + static ComplexTransform const kTransformB0 = Operator0::kTransformB; + + /// Complex transform on B1 operand + static ComplexTransform const kTransformB1 = Operator1::kTransformB; + + /// staticaly assert kStages for MmaPipelined is two (Double-buffered pipeline) + static_assert((Base::kStages==2), "MmaPipelined requires kStages set to value 2"); + + /// Epilog in shared memory + using Epilogue0 = epilogue::threadblock::EpilogueSmemAccumulator< + SmemIteratorD0, ///< SmemTileIterator + FragmentIteratorAccumulator, ///< AccumulatorFragmentIterator + IteratorAccumulatorScaleBias, ///< ScaleBiasIterator + OutputOp>; ///< Output operator + + + +private: + + using WarpFragmentA0 = typename Operator0::FragmentA; + using WarpFragmentB0 = typename Operator0::FragmentB; + using WarpFragmentA1 = typename Operator1::FragmentA; + using WarpFragmentB1 = typename Operator1::FragmentB; + +protected: + + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA0 smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B0 operand to shared memory + SmemIteratorB0 smem_iterator_B0_; + + /// Shared Memory Iterator to store accumulator tile + SmemIteratorD0 smem_iterator_D0_; + + /// Iterator to load a warp-scoped tile of A1 operand from intermediate accumulator tile + WarpIteratorA1 warp_tile_iterator_A1_; + + /// Iterator to write threadblock-scoped tile of B1 operand to shared memory + SmemIteratorB1 smem_iterator_B1_; + +public: + + /// Construct from tensor references + CUTLASS_DEVICE + B2bImplicitGemmPipelinedSmemAccumulator( + typename Base::B2bMmaSharedStorage &shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM + int thread_idx, ///< ID within the threadblock + int warp_idx, ///< ID of warp + int lane_idx ///< ID of each thread within a warp + ): + Base(shared_storage, thread_idx, warp_idx, lane_idx), + 
smem_iterator_A_(shared_storage.b2b_mma_shared_storage.shared_storage0.operand_A_ref(), thread_idx), + smem_iterator_B0_(shared_storage.b2b_mma_shared_storage.shared_storage0.operand_B_ref(), thread_idx), + smem_iterator_D0_(shared_storage.accumulator_shared_storage0.accum_ref(), lane_idx), + warp_tile_iterator_A1_(shared_storage.accumulator_shared_storage0.accum_ref(), lane_idx), + smem_iterator_B1_(shared_storage.b2b_mma_shared_storage.shared_storage1.operand_B_ref(), thread_idx) { + + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn_0 = warp_idx % (Base::WarpCount0::kM * Base::WarpCount0::kN); + int warp_idx_k_0 = warp_idx / (Base::WarpCount0::kM * Base::WarpCount0::kN); + + int warp_idx_m_0 = warp_idx_mn_0 % Base::WarpCount0::kM; + int warp_idx_n_0 = warp_idx_mn_0 / Base::WarpCount0::kM; + + int tile_offset_k_0 = Base::kWarpGemmIterations0 * warp_idx_k_0; + + int warp_idx_mn_1 = warp_idx % (Base::WarpCount1::kM * Base::WarpCount1::kN); + int warp_idx_k_1 = warp_idx / (Base::WarpCount1::kM * Base::WarpCount1::kN); + + int warp_idx_m_1 = warp_idx_mn_1 % Base::WarpCount1::kM; + int warp_idx_n_1 = warp_idx_mn_1 / Base::WarpCount1::kM; + + int tile_offset_k_1 = Base::kWarpGemmIterations1 * warp_idx_k_1; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A0_.add_tile_offset({warp_idx_m_0, tile_offset_k_0}); + this->warp_tile_iterator_B0_.add_tile_offset({tile_offset_k_0, warp_idx_n_0}); + warp_tile_iterator_A1_.add_tile_offset({warp_idx_m_1, tile_offset_k_1}); + this->warp_tile_iterator_B1_.add_tile_offset({tile_offset_k_1, warp_idx_n_1}); + + // Add smem accumulator iterator warp offset + smem_iterator_D0_.add_tile_offset({ warp_idx_m_0 * SmemIteratorD0::TileIterations::kRow, + warp_idx_n_0 * SmemIteratorD0::TileIterations::kColumn}); + + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + int gemm_k_iterations_0, ///< number of iterations of the mainloop + FragmentC1 &accum, ///< destination accumulator tile + IteratorA0 iterator_A, ///< iterator over A operand in global memory + IteratorB0 iterator_B0, ///< iterator over B0 operand in global memory + IteratorAccumulatorScaleBias iterator_accum0_scale, ///< iterator over D0 scale vector in global memory + IteratorAccumulatorScaleBias iterator_accum0_bias, ///< iterator over D0 bias vector in global memory + IteratorB1 iterator_B1, ///< iterator over B1 operand in global memory + FragmentC0 const &src_accum, ///< source accumulator tile + OutputOp output_op_0, ///< epilogue operation after 1st Gemm + TransformA0 transform_A0 = TransformA0(), ///< transformation applied to A0 fragment + TransformB0 transform_B0 = TransformB0(), ///< transformation applied to B0 fragment + TransformB1 transform_B1 = TransformB1()) { ///< transformation applied to B1 fragment + + // + // Prologue + // + + // Perform accumulation in the 'd' output operand + FragmentC0 accum0 = src_accum; + + FragmentA0 tb_frag_A; + FragmentB0 tb_frag_B0; + + tb_frag_A.clear(); + tb_frag_B0.clear(); + + // The last kblock is loaded in the prolog + iterator_A.load(tb_frag_A); + iterator_B0.load(tb_frag_B0); + + ++iterator_A; + ++iterator_B0; + + this->smem_iterator_A_.store(transform_A0(tb_frag_A)); + 
this->smem_iterator_B0_.store(transform_B0(tb_frag_B0)); + + ++this->smem_iterator_A_; + ++this->smem_iterator_B0_; + + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math instructions + WarpFragmentA0 warp_frag_A0[2]; + WarpFragmentB0 warp_frag_B0[2]; + + this->warp_tile_iterator_A0_.set_kgroup_index(0); + this->warp_tile_iterator_B0_.set_kgroup_index(0); + + this->warp_tile_iterator_A0_.load(warp_frag_A0[0]); + this->warp_tile_iterator_B0_.load(warp_frag_B0[0]); + + ++this->warp_tile_iterator_A0_; + ++this->warp_tile_iterator_B0_; + + Operator0 warp_mma0; + + int smem_write_stage_idx = 1; + + // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing + // shared memory loads (which have the tightest latency requirement). + + // + // Mainloop + // + + // Note: The main loop does not support Base::kWarpGemmIterations == 2. + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations_0 > 0; --gemm_k_iterations_0) { + // + // Loop over GEMM K dimension + // + + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations0; ++warp_mma_k) { + + // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group + // as the case may be. + + if (warp_mma_k == Base::kWarpGemmIterations0 - 1) { + + // Write fragments to shared memory + this->smem_iterator_A_.store(transform_A0(tb_frag_A)); + + this->smem_iterator_B0_.store(transform_B0(tb_frag_B0)); + + __syncthreads(); + + ++this->smem_iterator_A_; + ++this->smem_iterator_B0_; + + // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory + if (smem_write_stage_idx == 1) { + this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B0_.add_tile_offset({-Base::kStages, 0}); + } + else { + this->warp_tile_iterator_A0_.add_tile_offset( + {0, -Base::kStages * Policy0::kPartitionsK * Base::kWarpGemmIterations0}); + this->warp_tile_iterator_B0_.add_tile_offset( + {-Base::kStages * Policy0::kPartitionsK * Base::kWarpGemmIterations0, + 0}); + } + + smem_write_stage_idx ^= 1; + } + + this->warp_tile_iterator_A0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0); + this->warp_tile_iterator_B0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0); + + this->warp_tile_iterator_A0_.load(warp_frag_A0[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_B0_.load(warp_frag_B0[(warp_mma_k + 1) % 2]); + + ++this->warp_tile_iterator_A0_; + ++this->warp_tile_iterator_B0_; + + if (warp_mma_k == 0) { + + iterator_A.load(tb_frag_A); + iterator_B0.load(tb_frag_B0); + ++iterator_A; + ++iterator_B0; + } + + warp_mma0(accum0, warp_frag_A0[warp_mma_k % 2], + warp_frag_B0[warp_mma_k % 2], accum0); + + } + } + + /// Epilogue for the first Implicit Gemm + Epilogue0 epilogue0; + + epilogue0(output_op_0, smem_iterator_D0_, accum0, iterator_accum0_scale, iterator_accum0_bias); + + __syncthreads(); + + /// 2nd Implicit Gemm + + + // + // Prologue + // + + FragmentB1 tb_frag_B1; + + tb_frag_B1.clear(); + + // The last kblock is loaded in the prolog + iterator_B1.load(tb_frag_B1); + + ++iterator_B1; + + this->smem_iterator_B1_.store(transform_B1(tb_frag_B1)); + + ++this->smem_iterator_B1_; + + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math instructions + WarpFragmentA1 warp_frag_A1[2]; + WarpFragmentB1 warp_frag_B1[2]; + + this->warp_tile_iterator_B1_.set_kgroup_index(0); + + warp_tile_iterator_A1_.load(warp_frag_A1[0]); + 
this->warp_tile_iterator_B1_.load(warp_frag_B1[0]); + + ++warp_tile_iterator_A1_; + ++this->warp_tile_iterator_B1_; + + Operator1 warp_mma1; + + smem_write_stage_idx = 1; + + int gemm_k_iterations_1 = Shape0::kN / Shape1::kK; + + + // + // Mainloop + // + + // Note: The main loop does not support Base::kWarpGemmIterations == 2. + CUTLASS_PRAGMA_UNROLL + for (; gemm_k_iterations_1 > 0; --gemm_k_iterations_1) { + // + // Loop over GEMM K dimension + // + + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations1; ++warp_mma_k) { + + // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group + // as the case may be. + + if (warp_mma_k == Base::kWarpGemmIterations1 - 1) { + + // Write fragments to shared memory + this->smem_iterator_B1_.store(transform_B1(tb_frag_B1)); + + __syncthreads(); + + ++this->smem_iterator_B1_; + + // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory + if (smem_write_stage_idx == 1) { + this->smem_iterator_B1_.add_tile_offset({-Base::kStages, 0}); + } + else { + this->warp_tile_iterator_B1_.add_tile_offset( + {-Base::kStages * Policy1::kPartitionsK * + Base::kWarpGemmIterations1, + 0}); + } + + smem_write_stage_idx ^= 1; + + } + + this->warp_tile_iterator_B1_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations1); + + // skip warp tile loading for the last kgroup + if(gemm_k_iterations_1 > 1 || warp_mma_k < Base::kWarpGemmIterations1 - 1) + warp_tile_iterator_A1_.load(warp_frag_A1[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_B1_.load(warp_frag_B1[(warp_mma_k + 1) % 2]); + + ++warp_tile_iterator_A1_; + ++this->warp_tile_iterator_B1_; + + if (warp_mma_k == 0) { + + iterator_B1.load(tb_frag_B1); + + ++iterator_B1; + } + + warp_mma1(accum, warp_frag_A1[warp_mma_k % 2], + warp_frag_B1[warp_mma_k % 2], accum); + + } + } + + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_base.h b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_base.h new file mode 100644 index 0000000000000000000000000000000000000000..c845f2023f83e09303efc0df2dc0d560b5f080ca --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_base.h @@ -0,0 +1,236 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape0_, + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape1_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy0_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy1_, + /// Number of stages, + int Stages, + /// Used for partial specialization + typename Enable = bool> +class B2bMmaBase { + public: + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape0 = Shape0_; + using Shape1 = Shape1_; + + ///< Policy describing tuning details + using Policy0 = Policy0_; + using Policy1 = Policy1_; + + // + // Dependent types + // + + /// Warp-level Mma + using Operator0 = typename Policy0::Operator; + using Operator1 = typename Policy1::Operator; + + /// Shape describing the overall GEMM computed from shared memory + /// by each warp. 
+ using WarpGemm0 = typename Policy0::Operator::Shape; + using WarpGemm1 = typename Policy1::Operator::Shape; + + /// Shape describing the number of warps filling the CTA + using WarpCount0 = GemmShape; + using WarpCount1 = GemmShape; + + /// Number of warp-level GEMM oeprations + static int const kWarpGemmIterations0 = + (WarpGemm0::kK / Operator0::Policy::MmaShape::kK); + static int const kWarpGemmIterations1 = + (WarpGemm1::kK / Operator1::Policy::MmaShape::kK); + + /// Number of stages + static int const kStages = Stages; + + // + // Nested structs + // + + /// Shared storage object needed by threadblock-scoped GEMM + template< + typename Shape_, + typename Policy_ + > + class SharedStorage { + public: + // + // Type definitions + // + using Shape = Shape_; + using Policy = Policy_; + using Operator = typename Policy::Operator; + + /// Tensor reference to the A operand + using TensorRefA = TensorRef; + + /// Tensor reference to the B operand + using TensorRefB = TensorRef; + + + /// Shape of the A matrix operand in shared memory + using ShapeA = MatrixShape; + + /// Shape of the B matrix operand in shared memory + using ShapeB = + MatrixShape; + + public: + // + // Data members + // + + /// Buffer for A operand + AlignedBuffer operand_A; + + /// Buffer for B operand + AlignedBuffer operand_B; + + public: + + // + // Methods + // + + /// Returns a layout object for the A matrix + CUTLASS_DEVICE + static typename Operator::LayoutA LayoutA() { + return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn}); + } + + /// Returns a layout object for the B matrix + CUTLASS_HOST_DEVICE + static typename Operator::LayoutB LayoutB() { + return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn}); + } + + /// Returns a TensorRef to the A operand + CUTLASS_HOST_DEVICE + TensorRefA operand_A_ref() { + return TensorRefA{operand_A.data(), LayoutA()}; + } + + /// Returns a TensorRef to the B operand + CUTLASS_HOST_DEVICE + TensorRefB operand_B_ref() { + return TensorRefB{operand_B.data(), LayoutB()}; + } + }; + + using SharedStorage0 = SharedStorage; + using SharedStorage1 = SharedStorage; + union B2bMmaSharedStorage { + SharedStorage0 shared_storage0; + SharedStorage1 shared_storage1; + }; + + + protected: + + // + // Data members + // + + /// Iterator to load a warp-scoped tile of A0 operand from shared memory + typename Operator0::IteratorA warp_tile_iterator_A0_; + + /// Iterator to load a warp-scoped tile of B0 operand from shared memory + typename Operator0::IteratorB warp_tile_iterator_B0_; + + /// Iterator to load a warp-scoped tile of B1 operand from shared memory + typename Operator1::IteratorB warp_tile_iterator_B1_; + +public: + + /// Construct from tensor references + CUTLASS_DEVICE + B2bMmaBase( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + B2bMmaSharedStorage &shared_storage, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx + ): + warp_tile_iterator_A0_(shared_storage.shared_storage0.operand_A_ref(), lane_idx), + warp_tile_iterator_B0_(shared_storage.shared_storage0.operand_B_ref(), lane_idx), + warp_tile_iterator_B1_(shared_storage.shared_storage1.operand_B_ref(), lane_idx) { + + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// 
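In the B2bMmaBase class above, the warp-count and warp-GEMM-iteration constants are compile-time divisions: the threadblock tile is split by the warp tile to count warps along each mode, and the warp tile's K extent is split by the instruction K extent to count warp-level MMAs per staged k-block. A minimal sketch of that arithmetic, with hypothetical tile sizes (128x128x32 threadblock, 64x64x32 warp, K = 16 instruction; none of these values come from this diff):

// All values below are hypothetical, chosen only to illustrate the arithmetic.
constexpr int kCtaM = 128, kCtaN = 128, kCtaK = 32;    // threadblock tile (Shape0)
constexpr int kWarpM = 64, kWarpN = 64, kWarpK = 32;   // warp tile (WarpGemm0)
constexpr int kInstrK = 16;                            // tensor-op instruction K

constexpr int kWarpCountM = kCtaM / kWarpM;            // 2 warps along M
constexpr int kWarpCountN = kCtaN / kWarpN;            // 2 warps along N
constexpr int kWarpCountK = kCtaK / kWarpK;            // 1 warp along K
constexpr int kWarpGemmIterations = kWarpK / kInstrK;  // 2 warp MMAs per k-block

static_assert(kWarpCountM * kWarpCountN * kWarpCountK == 4, "four warps tile this CTA");
static_assert(kWarpGemmIterations == 2, "two warp-level MMAs consume one staged k-block");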
diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_base_smem_accumulator.h b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_base_smem_accumulator.h new file mode 100644 index 0000000000000000000000000000000000000000..c0356df46fb18cedbf984a0058e7d680e2a6650a --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_base_smem_accumulator.h @@ -0,0 +1,179 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" +#include "threadblock/b2b_mma_base.h" +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. 
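// Editorial sketch, not part of the upstream header: the class declared below extends
// B2bMmaBase with one more shared-memory buffer that keeps the first GEMM's
// accumulator tile resident between the two GEMMs. Its footprint is the threadblock
// tile extent (plus any padding) times the element size; with hypothetical values of
// a 128x128 tile of 16-bit elements and no padding:
constexpr int kSketchAccumTileM = 128;       // hypothetical Shape0::kM
constexpr int kSketchAccumTileN = 128;       // hypothetical Shape0::kN
constexpr int kSketchAccumElementBytes = 2;  // hypothetical 16-bit accumulator element
constexpr int kSketchAccumBytes =
    kSketchAccumTileM * kSketchAccumTileN * kSketchAccumElementBytes;
static_assert(kSketchAccumBytes == 32 * 1024,
              "a 128x128 tile of 16-bit values occupies 32 KiB of shared memory");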
+template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape0_, + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape1_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy0_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy1_, + /// Shared Memory Accumulator Iterator + typename SmemAccumulatorIterator0_, + /// Number of stages, + int Stages, + /// Used for partial specialization + typename Enable = bool> +class B2bMmaBaseSmemAccumulator : + public B2bMmaBase { + + public: + ///< Base class + using Base = B2bMmaBase; + + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape0 = Shape0_; + using Shape1 = Shape1_; + + ///< Policy describing tuning details + using Policy0 = Policy0_; + using Policy1 = Policy1_; + + + using SmemAccumulatorIterator0 = SmemAccumulatorIterator0_; + + // + // Nested structs + // + /// Shared storage object needed by accumulator + template< + typename Shape_, + typename Element_, + typename Layout_, + typename Padding_ + > + class AccumulatorSharedStorage { + public: + // + // Type definitions + // + using Shape = Shape_; + using Element = Element_; + using Layout = Layout_; + using Padding = Padding_; + + /// Tensor reference to the accumulator + using TensorRefAccum = TensorRef; + + /// Shape of the accumulator matrix in shared memory + using ShapeAccum = MatrixShape; + + public: + // + // Data members + // + + /// Buffer for accumulator + AlignedBuffer accum; + + public: + + // + // Methods + // + + /// Returns a layout object for the Accum matrix + CUTLASS_DEVICE + static Layout LayoutAccum() { + return Layout::packed({ShapeAccum::kRow, ShapeAccum::kColumn}); + } + + /// Returns a TensorRef to the Accumulator + CUTLASS_HOST_DEVICE + TensorRefAccum accum_ref() { + return TensorRefAccum{accum.data(), LayoutAccum()}; + } + + }; + + using AccumulatorSharedStorage0 = AccumulatorSharedStorage< + Shape0, typename SmemAccumulatorIterator0::Element, + typename SmemAccumulatorIterator0::TensorLayout, + typename SmemAccumulatorIterator0::Padding>; + + struct B2bMmaSharedStorage { + typename Base::B2bMmaSharedStorage b2b_mma_shared_storage; + AccumulatorSharedStorage0 accumulator_shared_storage0; + }; + +public: + + /// Construct from tensor references + CUTLASS_DEVICE + B2bMmaBaseSmemAccumulator( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + B2bMmaSharedStorage &shared_storage, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx + ): + Base(shared_storage.b2b_mma_shared_storage, thread_idx, warp_idx, lane_idx) { + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_multistage.h b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_multistage.h new file mode 100644 index 0000000000000000000000000000000000000000..b9388a73316d4c806ab62a548e67312008ab9527 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_multistage.h @@ -0,0 
+1,903 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h" + +#include "threadblock/b2b_mma_base.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. 
+template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape0_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA0_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA0_, + /// Cache operation for operand A + cutlass::arch::CacheOperation::Kind CacheOpA0, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB0_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB0_, + /// Cache operation for operand B + cutlass::arch::CacheOperation::Kind CacheOpB0, + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape1_, + /// Iterates over the intermediate accumulator tile + // (concept::MmaTensorOpFragmentIterator) + typename FragmentIteratorA1_, + /// Iterates over vectors of scale and bias vector in global memory + // (concept: VectorIterator) + typename IteratorAccumulatorScaleBias_, + /// WarpIterator to load Scale or Bias vector from threadblock fragment + typename FragmentIteratorA1ScaleBias_, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB1_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB1_, + /// Cache operation for operand B + cutlass::arch::CacheOperation::Kind CacheOpB1, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Output operator for 1st Gemm(concept: epilogue::thread::LinearCombinationClamp, etc...) 
+ typename OutputOp_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy0_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy1_, + /// Number of stages, + int Stages, + /// Used for partial specialization + typename Enable = bool> +class B2bMmaMultistage : + public B2bMmaBase { +public: + ///< Base class + using Base = B2bMmaBase; + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape0 = Shape0_; + ///< Iterates over tiles of A operand in global memory + using IteratorA0 = IteratorA0_; + using IteratorA = IteratorA0; + ///< Iterates over tiles of B operand in global memory + using IteratorB0 = IteratorB0_; + using IteratorB = IteratorB0; + ///< Policy describing tuning details + using Policy0 = Policy0_; + + using SmemIteratorA0 = SmemIteratorA0_; + using SmemIteratorB0 = SmemIteratorB0_; + + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape1 = Shape1_; + ///< Iterates over intermediate accumulator tile + using FragmentIteratorA1 = FragmentIteratorA1_; + ///< Iterates over tiles of the scale and bias vectors in global memory + using IteratorAccumulatorScaleBias = IteratorAccumulatorScaleBias_; + ///< WarpIterator to load Scale or Bias vector from threadblock fragment + using FragmentIteratorA1ScaleBias = FragmentIteratorA1ScaleBias_; + ///< Iterates over tiles of B operand in global memory + using IteratorB1 = IteratorB1_; + ///< Policy describing tuning details + using Policy1 = Policy1_; + + ///< Export Policy0 as the threadblock-level Mma's policy + using Policy = Policy0; + using Shape = Shape0; + + using SmemIteratorB1 = SmemIteratorB1_; + + ///< Data type of accumulator matrix + using ElementC = ElementC_; + ///< Layout of accumulator matrix + using LayoutC = LayoutC_; + + ///< Epilogue after 1st Gemm + using OutputOp = OutputOp_; + + static const bool PerChannelScale = (OutputOp::kScale == + epilogue::thread::ScaleType::OnlyAlphaPerChannelScaling); + + static cutlass::arch::CacheOperation::Kind const kCacheOpA0 = CacheOpA0; + static cutlass::arch::CacheOperation::Kind const kCacheOpB0 = CacheOpB0; + static cutlass::arch::CacheOperation::Kind const kCacheOpB1 = CacheOpB1; + + // + // Dependent types + // + + /// Fragment of accumulator tile + using FragmentC0 = typename Policy0::Operator::FragmentC; + + /// Warp-level Mma + using Operator0 = typename Policy0::Operator; + + /// Fragment of Scale and Bias loaded from global memory + using FragmentA1ScaleBias = typename IteratorAccumulatorScaleBias::Fragment; + + /// Fragment of accumulator tile + using FragmentC1 = typename Policy1::Operator::FragmentC; + + /// Warp-level Mma + using Operator1 = typename Policy1::Operator; + + /// Minimum architecture is Sm80 to support cp.async + using ArchTag = arch::Sm80; + + /// Complex transform on A operand + static ComplexTransform const kTransformA0 = Operator0::kTransformA; + + /// Complex transform on B operand + static ComplexTransform const kTransformB0 = Operator0::kTransformB; + + /// Complex transform on B operand + static ComplexTransform const kTransformB1 = Operator1::kTransformB; + + /// Complex transform exports needed by higher-level kernels + static ComplexTransform const kTransformA = kTransformA0; + static ComplexTransform const kTransformB = kTransformB0; + + /// Internal structure exposed for introspection. 
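  // Editorial sketch, not part of the upstream header: the Detail structure below
  // spreads the cp.async instructions of one shared-memory stage across the warp-level
  // MMA iterations using a rounded-up division, so no access is dropped when the two
  // counts do not divide evenly. With hypothetical counts of 9 threadblock load
  // iterations and 4 warp MMAs per k-block:
  static constexpr int kSketchTBLoadIterations = 9;    // hypothetical
  static constexpr int kSketchWarpGemmIterations = 4;  // hypothetical
  static constexpr int kSketchAccessesPerGroup =
      (kSketchTBLoadIterations + kSketchWarpGemmIterations - 1) / kSketchWarpGemmIterations;
  static_assert(kSketchAccessesPerGroup == 3,
                "ceil(9 / 4) == 3 cp.async accesses issued alongside each warp MMA");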
+ struct Detail { + + static_assert(Base::kWarpGemmIterations0 > 1, + "The pipelined structure requires at least two warp-level " + "GEMM operations."); + static_assert(Base::kWarpGemmIterations1 > 1, + "The pipelined structure requires at least two warp-level " + "GEMM operations."); + + /// Number of cp.async instructions to load one stage of operand A + static int const TBLoadIterationsA0 = + IteratorA0::ThreadMap::Iterations::kCount; + + /// Number of cp.async instructions to load one stage of operand B + static int const TBLoadIterationsB0 = + IteratorB0::ThreadMap::Iterations::kCount; + + /// Number of cp.async instructions to load one stage of operand B + static int const TBLoadIterationsB1 = + IteratorB1::ThreadMap::Iterations::kCount; + + /// Number of stages + static int const kStages = Stages; + + /// Number of cp.async instructions to load on group of operand A + static int const kAccessesPerGroupA0 = + (TBLoadIterationsA0 + Base::kWarpGemmIterations0 - 1) / Base::kWarpGemmIterations0; + + /// Number of cp.async instructions to load on group of operand B + static int const kAccessesPerGroupB0 = + (TBLoadIterationsB0 + Base::kWarpGemmIterations0 - 1) / Base::kWarpGemmIterations0; + + /// Number of cp.async instructions to load on group of operand B + static int const kAccessesPerGroupB1 = + (TBLoadIterationsB1 + Base::kWarpGemmIterations1 - 1) / Base::kWarpGemmIterations1; + }; + + private: + + using WarpLoadedFragmentA0 = typename Operator0::FragmentA; + using WarpLoadedFragmentB0 = typename Operator0::FragmentB; + /// Warp Fragment of operand A1 loaded from accmulator tile + using WarpLoadedFragmentA1 = typename FragmentIteratorA1::Fragment; + using WarpLoadedFragmentA1ScaleBias = + typename FragmentIteratorA1ScaleBias::Fragment; + using WarpLoadedFragmentB1 = typename Operator1::FragmentB; + using WarpTransformedFragmentA0 = typename Operator0::TransformedFragmentA; + using WarpTransformedFragmentB0 = typename Operator0::TransformedFragmentB; + using WarpTransformedFragmentA1 = typename Operator1::TransformedFragmentA; + using WarpTransformedFragmentB1 = typename Operator1::TransformedFragmentB; + + private: + + // + // Data members + // + + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA0 smem_iterator_A0_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB0 smem_iterator_B0_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB1 smem_iterator_B1_; + +public: + + /// Construct from tensor references + CUTLASS_DEVICE + B2bMmaMultistage( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + typename Base::B2bMmaSharedStorage &shared_storage, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx, + ///< GEMM0 N is used for accumulator extent + int problem_size_0_n + ): + Base(shared_storage, thread_idx, warp_idx, lane_idx), + smem_iterator_A0_(shared_storage.shared_storage0.operand_A_ref(), thread_idx), + smem_iterator_B0_(shared_storage.shared_storage0.operand_B_ref(), thread_idx), + smem_iterator_B1_(shared_storage.shared_storage1.operand_B_ref(), thread_idx) + { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within 
the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount0::kM * Base::WarpCount0::kN); + int warp_idx_k = warp_idx / (Base::WarpCount0::kM * Base::WarpCount0::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount0::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount0::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A0_.add_tile_offset( + {warp_idx_m, Base::kWarpGemmIterations0 * warp_idx_k}); + this->warp_tile_iterator_B0_.add_tile_offset( + {Base::kWarpGemmIterations0 * warp_idx_k, warp_idx_n}); + this->warp_tile_iterator_B1_.add_tile_offset( + {Base::kWarpGemmIterations1 * warp_idx_k, warp_idx_n}); + } + + CUTLASS_DEVICE + void copy_tiles_and_advance_0(IteratorA0 &iterator_A0, IteratorB0 &iterator_B0, + int group_start_A0 = 0, int group_start_B0 = 0) { + iterator_A0.set_iteration_index(group_start_A0 * + IteratorA0::kAccessesPerVector); + this->smem_iterator_A0_.set_iteration_index(group_start_A0); + + // Load for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupA0; ++j) { + if (group_start_A0 + j < Detail::TBLoadIterationsA0) { + typename IteratorA0::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_A0_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorA0::ThreadMap::kElementsPerAccess / + IteratorA0::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA0::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_A0.get(); + + cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, iterator_A0.valid()); + + ++iterator_A0; + } + + ++this->smem_iterator_A0_; + } + } + + iterator_B0.set_iteration_index(group_start_B0 * + IteratorB0::kAccessesPerVector); + this->smem_iterator_B0_.set_iteration_index(group_start_B0); + + // Load for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupB0; ++j) { + if (group_start_B0 + j < Detail::TBLoadIterationsB0) { + typename IteratorB0::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_B0_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorB0::ThreadMap::kElementsPerAccess / + IteratorB0::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB0::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_B0.get(); + + cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, iterator_B0.valid()); + + ++iterator_B0; + } + ++this->smem_iterator_B0_; + } + } + } + + CUTLASS_DEVICE + void copy_tiles_and_advance_1(IteratorB1 &iterator_B1, + int group_start_B1 = 0) { + iterator_B1.set_iteration_index(group_start_B1 * + IteratorB1::kAccessesPerVector); + this->smem_iterator_B1_.set_iteration_index(group_start_B1); + + // Load for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupB1; ++j) { + if (group_start_B1 + j < Detail::TBLoadIterationsB1) { + typename IteratorB1::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_B1_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorB1::ThreadMap::kElementsPerAccess / + IteratorB1::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB1::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_B1.get(); + + cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, iterator_B1.valid()); + + ++iterator_B1; + } + ++this->smem_iterator_B1_; + } + } + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + ///< problem size of GEMM + int gemm_k_iterations_0, + ///< destination accumulator 
tile + FragmentC1 &accum, + ///< iterator over A0 operand in global memory + IteratorA0 iterator_A0, + ///< iterator over B0 operand in global memory + IteratorB0 iterator_B0, + ///< iterator over A1 operand scale vector in global memory + IteratorAccumulatorScaleBias iterator_A1_scale, + ///< iterator over A1 operand bias vector in global memory + IteratorAccumulatorScaleBias iterator_A1_bias, + ///< iterator over B1 operand in global memory + IteratorB1 iterator_B1, + ///< initial value of accumulator + FragmentC0 const &src_accum, + ///< epilogue operation after 1st Gemm + OutputOp output_op_0) + { + // + // Prologue + // + + // Issue several complete stages + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < Base::kStages - 1; + ++stage, --gemm_k_iterations_0) { + + iterator_A0.clear_mask(gemm_k_iterations_0 == 0); + iterator_B0.clear_mask(gemm_k_iterations_0 == 0); + + iterator_A0.set_iteration_index(0); + this->smem_iterator_A0_.set_iteration_index(0); + + // Load for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::TBLoadIterationsA0; ++j) { + typename IteratorA0::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_A0_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA0::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorA0::ThreadMap::kElementsPerAccess / + IteratorA0::kAccessesPerVector / 8; + + int src_bytes = (iterator_A0.valid() ? kSrcBytes : 0); + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_A0.get(), iterator_A0.valid()); + + ++iterator_A0; + } + + ++this->smem_iterator_A0_; + } + + iterator_B0.set_iteration_index(0); + this->smem_iterator_B0_.set_iteration_index(0); + + // Load for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::TBLoadIterationsB0; ++j) { + typename IteratorB0::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_B0_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB0::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorB0::ThreadMap::kElementsPerAccess / + IteratorB0::kAccessesPerVector / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_B0.get(), iterator_B0.valid()); + + ++iterator_B0; + } + + ++this->smem_iterator_B0_; + } + + // Move to the next stage + iterator_A0.add_tile_offset({0, 1}); + iterator_B0.add_tile_offset({1, 0}); + + this->smem_iterator_A0_.add_tile_offset({0, 1}); + this->smem_iterator_B0_.add_tile_offset({1, 0}); + + // Defines the boundary of a stage of cp.async. 
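      // Editorial sketch, not part of the upstream kernel: each prologue stage above
      // issues TBLoadIterationsA0 * kAccessesPerVector cp.async copies for A0 (and the
      // analogous count for B0), each moving kSrcBytes = element bits * elements per
      // access / accesses per vector / 8 bytes, and the cp_async_fence() below commits
      // the stage as one group. With hypothetical values (16-bit elements, 8 elements
      // per access, one access per vector):
      {
        constexpr int kSketchElementBits = 16;       // hypothetical element width
        constexpr int kSketchElementsPerAccess = 8;  // hypothetical thread-map value
        constexpr int kSketchAccessesPerVector = 1;  // hypothetical
        constexpr int kSketchSrcBytes =
            kSketchElementBits * kSketchElementsPerAccess / kSketchAccessesPerVector / 8;
        static_assert(kSketchSrcBytes == 16, "each copy would be a 16-byte cp.async");
      }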
+ cutlass::arch::cp_async_fence(); + } + + // Perform accumulation in the 'd' output operand + FragmentC0 accum0 = src_accum; + + // DEPBAR+SYNC + cutlass::arch::cp_async_wait(); + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpLoadedFragmentA0 warp_loaded_frag_A0[2]; + WarpLoadedFragmentB0 warp_loaded_frag_B0[2]; + WarpTransformedFragmentA0 warp_transformed_frag_A0[2]; + WarpTransformedFragmentB0 warp_transformed_frag_B0[2]; + + Operator0 warp_mma0; + + this->warp_tile_iterator_A0_.set_kgroup_index(0); + this->warp_tile_iterator_B0_.set_kgroup_index(0); + + this->warp_tile_iterator_A0_.load(warp_loaded_frag_A0[0]); + this->warp_tile_iterator_B0_.load(warp_loaded_frag_B0[0]); + + ++this->warp_tile_iterator_A0_; + ++this->warp_tile_iterator_B0_; + + iterator_A0.clear_mask(gemm_k_iterations_0 == 0); + iterator_B0.clear_mask(gemm_k_iterations_0 == 0); + + int smem_write_stage_idx = Base::kStages - 1; + int smem_read_stage_idx = 0; + + warp_mma0.transform(warp_transformed_frag_A0[0], warp_transformed_frag_B0[0], + warp_loaded_frag_A0[0], warp_loaded_frag_B0[0]); + + // + // Mainloop + // + + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations_0 > (-Base::kStages + 1);) { + // + // Loop over GEMM K dimension + // + + // Computes a warp-level GEMM on data held in shared memory + // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations0; + ++warp_mma_k) { + + // Load warp-level tiles from shared memory, wrapping to k offset if + // this is the last group as the case may be. + + this->warp_tile_iterator_A0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0); + this->warp_tile_iterator_B0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0); + + this->warp_tile_iterator_A0_.load(warp_loaded_frag_A0[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_B0_.load(warp_loaded_frag_B0[(warp_mma_k + 1) % 2]); + + ++this->warp_tile_iterator_A0_; + ++this->warp_tile_iterator_B0_; + + if (warp_mma_k > 0) + warp_mma0.transform(warp_transformed_frag_A0[warp_mma_k % 2], + warp_transformed_frag_B0[warp_mma_k % 2], + warp_loaded_frag_A0[warp_mma_k % 2], + warp_loaded_frag_B0[warp_mma_k % 2]); + + warp_mma0( + accum0, + warp_transformed_frag_A0[warp_mma_k % 2], + warp_transformed_frag_B0[warp_mma_k % 2], + accum0 + ); + + // Issue global->shared copies for the this stage + if (warp_mma_k < Base::kWarpGemmIterations0 - 1) { + int group_start_iteration_A0, group_start_iteration_B0; + + group_start_iteration_A0 = warp_mma_k * Detail::kAccessesPerGroupA0; + group_start_iteration_B0 = warp_mma_k * Detail::kAccessesPerGroupB0; + + copy_tiles_and_advance_0(iterator_A0, iterator_B0, group_start_iteration_A0, + group_start_iteration_B0); + } + + if (warp_mma_k + 2 == Base::kWarpGemmIterations0) { + int group_start_iteration_A0, group_start_iteration_B0; + group_start_iteration_A0 = + (warp_mma_k + 1) * Detail::kAccessesPerGroupA0; + group_start_iteration_B0 = + (warp_mma_k + 1) * Detail::kAccessesPerGroupB0; + + copy_tiles_and_advance_0(iterator_A0, iterator_B0, group_start_iteration_A0, + group_start_iteration_B0); + + // Inserts a memory fence between stages of cp.async instructions. + cutlass::arch::cp_async_fence(); + + // Waits until kStages-2 stages have committed. 
+ arch::cp_async_wait(); + __syncthreads(); + + // Move to the next stage + iterator_A0.add_tile_offset({0, 1}); + iterator_B0.add_tile_offset({1, 0}); + + this->smem_iterator_A0_.add_tile_offset({0, 1}); + this->smem_iterator_B0_.add_tile_offset({1, 0}); + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory + if (smem_write_stage_idx == (Base::kStages - 1)) { + this->smem_iterator_A0_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B0_.add_tile_offset({-Base::kStages, 0}); + smem_write_stage_idx = 0; + } else { + ++smem_write_stage_idx; + } + + if (smem_read_stage_idx == (Base::kStages - 1)) { + this->warp_tile_iterator_A0_.add_tile_offset( + {0, -Base::kStages * Policy0::kPartitionsK * + Base::kWarpGemmIterations0}); + this->warp_tile_iterator_B0_.add_tile_offset( + {-Base::kStages * Policy0::kPartitionsK * + Base::kWarpGemmIterations0, + 0}); + smem_read_stage_idx = 0; + } else { + ++smem_read_stage_idx; + } + + --gemm_k_iterations_0; + iterator_A0.clear_mask(gemm_k_iterations_0 == 0); + iterator_B0.clear_mask(gemm_k_iterations_0 == 0); + } + + // Do any conversions feeding the first stage at the end of the loop so + // we can start right away on mma instructions + if (warp_mma_k + 1 == Base::kWarpGemmIterations0) + warp_mma0.transform(warp_transformed_frag_A0[(warp_mma_k + 1) % 2], + warp_transformed_frag_B0[(warp_mma_k + 1) % 2], + warp_loaded_frag_A0[(warp_mma_k + 1) % 2], + warp_loaded_frag_B0[(warp_mma_k + 1) % 2]); + } + + } + + // Commit and drain all pending and predicated cp.async pnz from the GEMM mainloop + cutlass::arch::cp_async_fence(); + cutlass::arch::cp_async_wait<0>(); + __syncthreads(); + + // 2nd Gemm + + /// Iterator to load a warp-scoped tile of A1 operand from intermediate accumulator tile + FragmentIteratorA1 warp_tile_iterator_A1_(accum0); + FragmentA1ScaleBias tb_frag_A1_scale; + FragmentA1ScaleBias tb_frag_A1_bias; + FragmentIteratorA1ScaleBias warp_tile_iterator_A1_scale_(tb_frag_A1_scale); + FragmentIteratorA1ScaleBias warp_tile_iterator_A1_bias_(tb_frag_A1_bias); + + if(PerChannelScale) { + tb_frag_A1_scale.clear(); + iterator_A1_scale.load(tb_frag_A1_scale); + ++iterator_A1_scale; + } + tb_frag_A1_bias.clear(); + iterator_A1_bias.load(tb_frag_A1_bias); + ++iterator_A1_bias; + + // + // Prologue + // + int gemm_k_iterations_1 = (FragmentIteratorA1::Policy::kIterations + Base::kWarpGemmIterations1 - 1) / Base::kWarpGemmIterations1; + + // Issue several complete stages + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < Base::kStages - 1; + ++stage, --gemm_k_iterations_1) { + + iterator_B1.clear_mask(gemm_k_iterations_1 == 0); + + iterator_B1.set_iteration_index(0); + this->smem_iterator_B1_.set_iteration_index(0); + + // Load for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::TBLoadIterationsB1; ++j) { + typename IteratorB1::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_B1_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB1::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorB1::ThreadMap::kElementsPerAccess / + IteratorB1::kAccessesPerVector / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_B1.get(), iterator_B1.valid()); + + ++iterator_B1; + } + + ++this->smem_iterator_B1_; + } + + // Move to the next stage + iterator_B1.add_tile_offset({1, 0}); + + this->smem_iterator_B1_.add_tile_offset({1, 0}); + + // Defines the boundary of a stage of cp.async. 
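      // Editorial sketch, not part of the upstream kernel: the second GEMM's A operand
      // is the first GEMM's accumulator, so its k extent above is counted in
      // accumulator fragments, gemm_k_iterations_1 =
      // ceil(FragmentIteratorA1::Policy::kIterations / kWarpGemmIterations1), rather
      // than in global-memory tiles. With hypothetical counts of 8 fragment iterations
      // and 2 warp MMAs per k group:
      {
        constexpr int kSketchFragmentIterations = 8;    // hypothetical
        constexpr int kSketchWarpGemmIterations1 = 2;   // hypothetical
        constexpr int kSketchKIterations1 =
            (kSketchFragmentIterations + kSketchWarpGemmIterations1 - 1) / kSketchWarpGemmIterations1;
        static_assert(kSketchKIterations1 == 4,
                      "ceil(8 / 2) == 4 k iterations for the second GEMM");
      }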
+ cutlass::arch::cp_async_fence(); + } + + // DEPBAR+SYNC + cutlass::arch::cp_async_wait(); + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpLoadedFragmentA1 warp_loaded_frag_A1[2]; + WarpLoadedFragmentA1ScaleBias warp_loaded_frag_A1_scale[2]; + WarpLoadedFragmentA1ScaleBias warp_loaded_frag_A1_bias[2]; + WarpLoadedFragmentB1 warp_loaded_frag_B1[2]; + WarpTransformedFragmentA1 warp_transformed_frag_A1[2]; + WarpTransformedFragmentB1 warp_transformed_frag_B1[2]; + + Operator1 warp_mma1; + + if(PerChannelScale) { + warp_tile_iterator_A1_scale_.load(warp_loaded_frag_A1_scale[0]); + ++warp_tile_iterator_A1_scale_; + } + warp_tile_iterator_A1_bias_.load(warp_loaded_frag_A1_bias[0]); + ++warp_tile_iterator_A1_bias_; + + warp_tile_iterator_A1_.load(warp_loaded_frag_A1[0], + warp_loaded_frag_A1_scale[0], + warp_loaded_frag_A1_bias[0], + output_op_0); + ++warp_tile_iterator_A1_; + + this->warp_tile_iterator_B1_.set_kgroup_index(0); + this->warp_tile_iterator_B1_.load(warp_loaded_frag_B1[0]); + ++this->warp_tile_iterator_B1_; + + iterator_B1.clear_mask(gemm_k_iterations_1 == 0); + + smem_write_stage_idx = Base::kStages - 1; + smem_read_stage_idx = 0; + + warp_mma1.transform(warp_transformed_frag_A1[0], warp_transformed_frag_B1[0], + warp_loaded_frag_A1[0], warp_loaded_frag_B1[0]); + + // + // Mainloop + // + + gemm_k_iterations_1 = (FragmentIteratorA1::Policy::kIterations + Base::kWarpGemmIterations1 - 1) / Base::kWarpGemmIterations1 - (Base::kStages - 1); + CUTLASS_PRAGMA_UNROLL + for (; gemm_k_iterations_1 > (-Base::kStages + 1); gemm_k_iterations_1--) { + // + // Loop over GEMM K dimension + // + + // Computes a warp-level GEMM on data held in shared memory + // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations1; + ++warp_mma_k) { + + // Load threadblock-level scale/bias vector from global memory + if (warp_mma_k + 1 == Base::kWarpGemmIterations1) { + if(PerChannelScale) { + tb_frag_A1_scale.clear(); + iterator_A1_scale.load(tb_frag_A1_scale); + ++iterator_A1_scale; + } + tb_frag_A1_bias.clear(); + iterator_A1_bias.load(tb_frag_A1_bias); + ++iterator_A1_bias; + } + + // Load warp-level scale bias fragment from threadblock scale/bias vector + if(PerChannelScale) { + warp_tile_iterator_A1_scale_.load(warp_loaded_frag_A1_scale[(warp_mma_k + 1) % 2]); + ++warp_tile_iterator_A1_scale_; + } + warp_tile_iterator_A1_bias_.load(warp_loaded_frag_A1_bias[(warp_mma_k + 1) % 2]); + ++warp_tile_iterator_A1_bias_; + + // Load warp-level tile from accumulator fragment + warp_tile_iterator_A1_.load(warp_loaded_frag_A1[(warp_mma_k + 1) % 2], + warp_loaded_frag_A1_scale[(warp_mma_k + 1) % 2], + warp_loaded_frag_A1_bias[(warp_mma_k + 1) % 2], + output_op_0); + ++warp_tile_iterator_A1_; + + // Load warp-level tiles from shared memory, wrapping to k offset if + // this is the last group as the case may be. 
+ this->warp_tile_iterator_B1_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations1); + this->warp_tile_iterator_B1_.load(warp_loaded_frag_B1[(warp_mma_k + 1) % 2]); + ++this->warp_tile_iterator_B1_; + + if (warp_mma_k > 0) + warp_mma1.transform(warp_transformed_frag_A1[warp_mma_k % 2], + warp_transformed_frag_B1[warp_mma_k % 2], + warp_loaded_frag_A1[warp_mma_k % 2], + warp_loaded_frag_B1[warp_mma_k % 2]); + + + warp_mma1( + accum, + warp_transformed_frag_A1[warp_mma_k % 2], + warp_transformed_frag_B1[warp_mma_k % 2], + accum + ); + + // Issue global->shared copies for the this stage + if (warp_mma_k < Base::kWarpGemmIterations1 - 1) { + int group_start_iteration_B1; + + group_start_iteration_B1 = warp_mma_k * Detail::kAccessesPerGroupB1; + + copy_tiles_and_advance_1(iterator_B1, group_start_iteration_B1); + } + + if (warp_mma_k + 2 == Base::kWarpGemmIterations1) { + int group_start_iteration_B1; + group_start_iteration_B1 = + (warp_mma_k + 1) * Detail::kAccessesPerGroupB1; + + copy_tiles_and_advance_1(iterator_B1, group_start_iteration_B1); + + // Inserts a memory fence between stages of cp.async instructions. + cutlass::arch::cp_async_fence(); + + // Waits until kStages-2 stages have committed. + arch::cp_async_wait(); + __syncthreads(); + + // Move to the next stage + iterator_B1.add_tile_offset({1, 0}); + + this->smem_iterator_B1_.add_tile_offset({1, 0}); + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory + if (smem_write_stage_idx == (Base::kStages - 1)) { + this->smem_iterator_B1_.add_tile_offset({-Base::kStages, 0}); + smem_write_stage_idx = 0; + } else { + ++smem_write_stage_idx; + } + + if (smem_read_stage_idx == (Base::kStages - 1)) { + this->warp_tile_iterator_B1_.add_tile_offset( + {-Base::kStages * Policy1::kPartitionsK * + Base::kWarpGemmIterations1, + 0}); + smem_read_stage_idx = 0; + } else { + ++smem_read_stage_idx; + } + + iterator_B1.clear_mask(gemm_k_iterations_1 == 1); + } + + // Do any conversions feeding the first stage at the end of the loop so + // we can start right away on mma instructions + if (warp_mma_k + 1 == Base::kWarpGemmIterations1) + warp_mma1.transform(warp_transformed_frag_A1[(warp_mma_k + 1) % 2], + warp_transformed_frag_B1[(warp_mma_k + 1) % 2], + warp_loaded_frag_A1[(warp_mma_k + 1) % 2], + warp_loaded_frag_B1[(warp_mma_k + 1) % 2]); + } + + } + + // Commit and drain all pending and predicated cp.async pnz from the GEMM mainloop + cutlass::arch::cp_async_fence(); + cutlass::arch::cp_async_wait<0>(); + __syncthreads(); + + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_multistage_smem_accumulator.h b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_multistage_smem_accumulator.h new file mode 100644 index 0000000000000000000000000000000000000000..b089bdba2bea2069dd9755b09e680278a771fe2a --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_multistage_smem_accumulator.h @@ -0,0 +1,887 @@ 
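A brief note on the multistage kernel defined in the preceding file: its prologue issues kStages - 1 complete cp.async stages before any math, and the mainloop keeps iterating until its counter falls to -(kStages - 1), so the final kStages - 1 passes only consume data already resident in shared memory while the load masks are cleared. A minimal sketch of that accounting, assuming hypothetical values of kStages = 3 and 8 k-blocks:

// Hypothetical values, used only to illustrate the loop accounting.
constexpr int kSketchStages = 3;
constexpr int kSketchKBlocks = 8;

constexpr int kAfterPrologue = kSketchKBlocks - (kSketchStages - 1);  // counter entering the mainloop
constexpr int kLoopBound = -(kSketchStages - 1);                      // loop runs while counter > bound
constexpr int kMainloopPasses = kAfterPrologue - kLoopBound;          // one decrement per pass

static_assert(kMainloopPasses == kSketchKBlocks,
              "every k-block gets one mainloop pass; the last kStages - 1 passes only drain smem");

// The shared-memory read and write stage indices walk the same circular buffer,
// wrapping back to zero after stage kStages - 1.
constexpr int advance_stage(int s) { return (s == kSketchStages - 1) ? 0 : s + 1; }
static_assert(advance_stage(2) == 0 && advance_stage(0) == 1,
              "stage indices wrap modulo kStages");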
+/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h" + +#include "threadblock/b2b_mma_base_smem_accumulator.h" +#include "cutlass/epilogue/threadblock/epilogue_smem_accumulator.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. 
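// Editorial note, not part of the upstream header: unlike the register-to-register
// path in the previous file, the class below runs the first GEMM's epilogue into
// shared memory (EpilogueSmemAccumulator writing through SmemIteratorD0), and the
// second GEMM re-reads its A operand from that buffer via WarpIteratorA1. A minimal
// per-element sketch of such an epilogue, assuming a plain scale-plus-bias output op
// (the function name and formula are illustrative, not taken from this diff):
inline float epilogue0_elementwise_sketch(float accum, float scale, float bias) {
  return scale * accum + bias;  // value written to the shared-memory accumulator tile
}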
+template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape0_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA0_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA0_, + /// Cache operation for operand A + cutlass::arch::CacheOperation::Kind CacheOpA0, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB0_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB0_, + /// Cache operation for operand B + cutlass::arch::CacheOperation::Kind CacheOpB0, + /// Iterates over vectors of scale and bias vector in global memory + // (concept: VectorIterator) + typename IteratorAccumulatorScaleBias_, + /// Iterates over accumulator tile + typename FragmentIteratorAccumulator_, + /// Iterates over accumulator tile in shared memory + typename SmemIteratorD0_, + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape1_, + /// Iterates over the intermediate accumulator tile in shared memory + typename WarpIteratorA1_, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB1_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB1_, + /// Cache operation for operand B + cutlass::arch::CacheOperation::Kind CacheOpB1, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Output operator for 1st Gemm(concept: epilogue::thread::LinearCombinationClamp, etc...) 
+ typename OutputOp_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy0_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy1_, + /// Number of stages, + int Stages, + /// Used for partial specialization + typename Enable = bool> +class B2bMmaMultistageSmemAccumulator : + public gemm::threadblock::B2bMmaBaseSmemAccumulator { +public: + ///< Base class + using Base = gemm::threadblock::B2bMmaBaseSmemAccumulator; + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape0 = Shape0_; + ///< Iterates over tiles of A operand in global memory + using IteratorA0 = IteratorA0_; + using IteratorA = IteratorA0; + ///< Iterates over tiles of B operand in global memory + using IteratorB0 = IteratorB0_; + using IteratorB = IteratorB0; + ///< Iterates over tiles of the scale and bias vectors in global memory + using IteratorAccumulatorScaleBias = IteratorAccumulatorScaleBias_; + ///< Policy describing tuning details + using Policy0 = Policy0_; + + using SmemIteratorA0 = SmemIteratorA0_; + using SmemIteratorB0 = SmemIteratorB0_; + using SmemIteratorD0 = SmemIteratorD0_; ///< Iterates over accumulator tile in shared memory + + using FragmentIteratorAccumulator = FragmentIteratorAccumulator_; ///< Iterates over accumulator tile + + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape1 = Shape1_; + ///< Iterates over tiles of B operand in global memory + using IteratorB1 = IteratorB1_; + ///< Policy describing tuning details + using Policy1 = Policy1_; + + ///< Export Policy0 as the threadblock-level Mma's policy + using Policy = Policy0; + using Shape = Shape0; + + using SmemIteratorB1 = SmemIteratorB1_; + using WarpIteratorA1 = WarpIteratorA1_; ///< Iterates over the intermediate accumulator tile in shared memory + + ///< Data type of accumulator matrix + using ElementC = ElementC_; + ///< Layout of accumulator matrix + using LayoutC = LayoutC_; + + ///< Epilogue after 1st Gemm + using OutputOp = OutputOp_; + + static cutlass::arch::CacheOperation::Kind const kCacheOpA0 = CacheOpA0; + static cutlass::arch::CacheOperation::Kind const kCacheOpB0 = CacheOpB0; + static cutlass::arch::CacheOperation::Kind const kCacheOpB1 = CacheOpB1; + + // + // Dependent types + // + + /// Fragment of accumulator tile + using FragmentC0 = typename Policy0::Operator::FragmentC; + + /// Warp-level Mma + using Operator0 = typename Policy0::Operator; + + /// Fragment of Scale and Bias loaded from global memory + using FragmentA1ScaleBias = typename IteratorAccumulatorScaleBias::Fragment; + + /// Fragment of accumulator tile + using FragmentC1 = typename Policy1::Operator::FragmentC; + + /// Warp-level Mma + using Operator1 = typename Policy1::Operator; + + /// Epilog in shared memory + using Epilogue0 = epilogue::threadblock::EpilogueSmemAccumulator< + SmemIteratorD0, ///< SmemTileIterator + FragmentIteratorAccumulator, ///< AccumulatorFragmentIterator + IteratorAccumulatorScaleBias, ///< ScaleBiasIterator + OutputOp>; ///< Output operator + + /// Minimum architecture is Sm80 to support cp.async + using ArchTag = arch::Sm80; + + /// Complex transform on A operand + static ComplexTransform const kTransformA0 = Operator0::kTransformA; + + /// Complex transform on B operand + static ComplexTransform const kTransformB0 = Operator0::kTransformB; + + /// Complex transform on B operand + static ComplexTransform const kTransformB1 = Operator1::kTransformB; + + /// Complex transform exports needed by higher-level kernels + static ComplexTransform 
const kTransformA = kTransformA0; + static ComplexTransform const kTransformB = kTransformB0; + + /// Internal structure exposed for introspection. + struct Detail { + + static_assert(Base::kWarpGemmIterations0 > 1, + "The pipelined structure requires at least two warp-level " + "GEMM operations."); + static_assert(Base::kWarpGemmIterations1 > 1, + "The pipelined structure requires at least two warp-level " + "GEMM operations."); + + /// Number of cp.async instructions to load one stage of operand A + static int const TBLoadIterationsA0 = + IteratorA0::ThreadMap::Iterations::kCount; + + /// Number of cp.async instructions to load one stage of operand B + static int const TBLoadIterationsB0 = + IteratorB0::ThreadMap::Iterations::kCount; + + /// Number of cp.async instructions to load one stage of operand B + static int const TBLoadIterationsB1 = + IteratorB1::ThreadMap::Iterations::kCount; + + /// Number of stages + static int const kStages = Stages; + + /// Number of cp.async instructions to load on group of operand A + static int const kAccessesPerGroupA0 = + (TBLoadIterationsA0 + Base::kWarpGemmIterations0 - 1) / Base::kWarpGemmIterations0; + + /// Number of cp.async instructions to load on group of operand B + static int const kAccessesPerGroupB0 = + (TBLoadIterationsB0 + Base::kWarpGemmIterations0 - 1) / Base::kWarpGemmIterations0; + + /// Number of cp.async instructions to load on group of operand B + static int const kAccessesPerGroupB1 = + (TBLoadIterationsB1 + Base::kWarpGemmIterations1 - 1) / Base::kWarpGemmIterations1; + }; + + private: + + using WarpLoadedFragmentA0 = typename Operator0::FragmentA; + using WarpLoadedFragmentB0 = typename Operator0::FragmentB; + using WarpLoadedFragmentA1 = typename Operator1::FragmentA; + using WarpLoadedFragmentB1 = typename Operator1::FragmentB; + using WarpTransformedFragmentA0 = typename Operator0::TransformedFragmentA; + using WarpTransformedFragmentB0 = typename Operator0::TransformedFragmentB; + using WarpTransformedFragmentA1 = typename Operator1::TransformedFragmentA; + using WarpTransformedFragmentB1 = typename Operator1::TransformedFragmentB; + + private: + + // + // Data members + // + + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA0 smem_iterator_A0_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB0 smem_iterator_B0_; + + /// Shared Memory Iterator to store accumulator tile + SmemIteratorD0 smem_iterator_D0_; + + /// Iterator to load a warp-scoped tile of A1 operand from intermediate accumulator tile + WarpIteratorA1 warp_tile_iterator_A1_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB1 smem_iterator_B1_; + +public: + + /// Construct from tensor references + CUTLASS_DEVICE + B2bMmaMultistageSmemAccumulator( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + typename Base::B2bMmaSharedStorage &shared_storage, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx, + ///< GEMM0 N is used for accumulator extent + int problem_size_0_n + ): + Base(shared_storage, thread_idx, warp_idx, lane_idx), + smem_iterator_A0_(shared_storage.b2b_mma_shared_storage.shared_storage0.operand_A_ref(), thread_idx), + smem_iterator_B0_(shared_storage.b2b_mma_shared_storage.shared_storage0.operand_B_ref(), thread_idx), + smem_iterator_D0_(shared_storage.accumulator_shared_storage0.accum_ref(), lane_idx), + 
warp_tile_iterator_A1_(shared_storage.accumulator_shared_storage0.accum_ref(), {Base::WarpGemm1::kM, problem_size_0_n}, lane_idx ), + smem_iterator_B1_(shared_storage.b2b_mma_shared_storage.shared_storage1.operand_B_ref(), thread_idx) + { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn_0 = warp_idx % (Base::WarpCount0::kM * Base::WarpCount0::kN); + int warp_idx_k_0 = warp_idx / (Base::WarpCount0::kM * Base::WarpCount0::kN); + + int warp_idx_m_0 = warp_idx_mn_0 % Base::WarpCount0::kM; + int warp_idx_n_0 = warp_idx_mn_0 / Base::WarpCount0::kM; + + int warp_idx_mn_1 = warp_idx % (Base::WarpCount1::kM * Base::WarpCount1::kN); + int warp_idx_k_1 = warp_idx / (Base::WarpCount1::kM * Base::WarpCount1::kN); + + int warp_idx_m_1 = warp_idx_mn_1 % Base::WarpCount1::kM; + int warp_idx_n_1 = warp_idx_mn_1 / Base::WarpCount1::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A0_.add_tile_offset( + {warp_idx_m_0, Base::kWarpGemmIterations0 * warp_idx_k_0}); + this->warp_tile_iterator_B0_.add_tile_offset( + {Base::kWarpGemmIterations0 * warp_idx_k_0, warp_idx_n_0}); + warp_tile_iterator_A1_.add_tile_offset( + {warp_idx_m_1, Base::kWarpGemmIterations1 * warp_idx_k_1}); + this->warp_tile_iterator_B1_.add_tile_offset( + {Base::kWarpGemmIterations1 * warp_idx_k_1, warp_idx_n_1}); + + // Add smem accumulator iterator warp offset + smem_iterator_D0_.add_tile_offset({ warp_idx_m_0 * SmemIteratorD0::TileIterations::kRow, + warp_idx_n_0 * SmemIteratorD0::TileIterations::kColumn}); + } + + CUTLASS_DEVICE + void copy_tiles_and_advance_0(IteratorA0 &iterator_A0, IteratorB0 &iterator_B0, + int group_start_A0 = 0, int group_start_B0 = 0) { + iterator_A0.set_iteration_index(group_start_A0 * + IteratorA0::kAccessesPerVector); + this->smem_iterator_A0_.set_iteration_index(group_start_A0); + + // cp.async for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupA0; ++j) { + if (group_start_A0 + j < Detail::TBLoadIterationsA0) { + typename IteratorA0::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_A0_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorA0::ThreadMap::kElementsPerAccess / + IteratorA0::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA0::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_A0.get(); + + cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, iterator_A0.valid()); + + ++iterator_A0; + } + + ++this->smem_iterator_A0_; + } + } + + iterator_B0.set_iteration_index(group_start_B0 * + IteratorB0::kAccessesPerVector); + this->smem_iterator_B0_.set_iteration_index(group_start_B0); + + // cp.async for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupB0; ++j) { + if (group_start_B0 + j < Detail::TBLoadIterationsB0) { + typename IteratorB0::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_B0_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorB0::ThreadMap::kElementsPerAccess / + IteratorB0::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB0::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_B0.get(); + + cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, iterator_B0.valid()); + + ++iterator_B0; + } 
+ ++this->smem_iterator_B0_; + } + } + } + + CUTLASS_DEVICE + void copy_tiles_and_advance_1(IteratorB1 &iterator_B1, + int group_start_B1 = 0) { + iterator_B1.set_iteration_index(group_start_B1 * + IteratorB1::kAccessesPerVector); + this->smem_iterator_B1_.set_iteration_index(group_start_B1); + + // cp.async for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupB1; ++j) { + if (group_start_B1 + j < Detail::TBLoadIterationsB1) { + typename IteratorB1::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_B1_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorB1::ThreadMap::kElementsPerAccess / + IteratorB1::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB1::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_B1.get(); + + cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, iterator_B1.valid()); + + ++iterator_B1; + } + ++this->smem_iterator_B1_; + } + } + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + ///< problem size of GEMM + int gemm_k_iterations_0, + ///< destination accumulator tile + FragmentC1 &accum, + ///< iterator over A0 operand in global memory + IteratorA0 iterator_A0, + ///< iterator over B0 operand in global memory + IteratorB0 iterator_B0, + ///< iterator over A1 operand scale vector in global memory + IteratorAccumulatorScaleBias iterator_accum0_scale, + ///< iterator over A1 operand bias vector in global memory + IteratorAccumulatorScaleBias iterator_accum0_bias, + ///< iterator over B1 operand in global memory + IteratorB1 iterator_B1, + ///< initial value of accumulator + FragmentC0 const &src_accum, + ///< epilogue operation after 1st Gemm + OutputOp output_op_0) + { + // + // Prologue + // + + // Issue several complete stages + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < Base::kStages - 1; + ++stage, --gemm_k_iterations_0) { + + iterator_A0.clear_mask(gemm_k_iterations_0 == 0); + iterator_B0.clear_mask(gemm_k_iterations_0 == 0); + + iterator_A0.set_iteration_index(0); + this->smem_iterator_A0_.set_iteration_index(0); + + // cp.async for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::TBLoadIterationsA0; ++j) { + typename IteratorA0::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_A0_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA0::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorA0::ThreadMap::kElementsPerAccess / + IteratorA0::kAccessesPerVector / 8; + + int src_bytes = (iterator_A0.valid() ? 
kSrcBytes : 0); + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_A0.get(), iterator_A0.valid()); + + ++iterator_A0; + } + + ++this->smem_iterator_A0_; + } + + iterator_B0.set_iteration_index(0); + this->smem_iterator_B0_.set_iteration_index(0); + + // cp.async for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::TBLoadIterationsB0; ++j) { + typename IteratorB0::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_B0_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB0::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorB0::ThreadMap::kElementsPerAccess / + IteratorB0::kAccessesPerVector / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_B0.get(), iterator_B0.valid()); + + ++iterator_B0; + } + + ++this->smem_iterator_B0_; + } + + // Move to the next stage + iterator_A0.add_tile_offset({0, 1}); + iterator_B0.add_tile_offset({1, 0}); + + this->smem_iterator_A0_.add_tile_offset({0, 1}); + this->smem_iterator_B0_.add_tile_offset({1, 0}); + + // Defines the boundary of a stage of cp.async. + cutlass::arch::cp_async_fence(); + } + + // Perform accumulation in the 'd' output operand + FragmentC0 accum0 = src_accum; + + // DEPBAR+SYNC + cutlass::arch::cp_async_wait(); + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpLoadedFragmentA0 warp_loaded_frag_A0[2]; + WarpLoadedFragmentB0 warp_loaded_frag_B0[2]; + WarpTransformedFragmentA0 warp_transformed_frag_A0[2]; + WarpTransformedFragmentB0 warp_transformed_frag_B0[2]; + + Operator0 warp_mma0; + + this->warp_tile_iterator_A0_.set_kgroup_index(0); + this->warp_tile_iterator_B0_.set_kgroup_index(0); + + this->warp_tile_iterator_A0_.load(warp_loaded_frag_A0[0]); + this->warp_tile_iterator_B0_.load(warp_loaded_frag_B0[0]); + + ++this->warp_tile_iterator_A0_; + ++this->warp_tile_iterator_B0_; + + iterator_A0.clear_mask(gemm_k_iterations_0 == 0); + iterator_B0.clear_mask(gemm_k_iterations_0 == 0); + + int smem_write_stage_idx = Base::kStages - 1; + int smem_read_stage_idx = 0; + + warp_mma0.transform(warp_transformed_frag_A0[0], warp_transformed_frag_B0[0], + warp_loaded_frag_A0[0], warp_loaded_frag_B0[0]); + + // + // Mainloop + // + + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations_0 > (-Base::kStages + 1);) { + // + // Loop over GEMM K dimension + // + + // Computes a warp-level GEMM on data held in shared memory + // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations0; + ++warp_mma_k) { + + // Load warp-level tiles from shared memory, wrapping to k offset if + // this is the last group as the case may be. 
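The prologue above commits kStages - 1 tile loads before any math is issued, which is why the mainloop condition lets gemm_k_iterations_0 run down past zero to -(kStages - 1): the trailing iterations drain tiles that are already in flight. A small host-side model of that accounting, with a made-up pipeline depth and tile count:

#include <cassert>

int main() {
  const int kStages = 3;   // hypothetical pipeline depth
  const int k_tiles = 7;   // hypothetical number of K tiles in the problem

  int gemm_k_iterations = k_tiles;
  int loads_issued = 0, mma_tiles = 0;

  // Prologue: kStages - 1 stages are loaded up front.
  for (int stage = 0; stage < kStages - 1; ++stage, --gemm_k_iterations)
    ++loads_issued;

  // Mainloop: each trip consumes one resident tile and, while real tiles remain,
  // issues the load for a future stage (clear_mask predicates it off afterwards).
  for (; gemm_k_iterations > (-kStages + 1); --gemm_k_iterations) {
    ++mma_tiles;
    if (gemm_k_iterations > 0)
      ++loads_issued;
  }

  assert(loads_issued == k_tiles);  // every K tile loaded exactly once
  assert(mma_tiles == k_tiles);     // and consumed exactly once
  return 0;
}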
+ + this->warp_tile_iterator_A0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0); + this->warp_tile_iterator_B0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0); + + this->warp_tile_iterator_A0_.load(warp_loaded_frag_A0[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_B0_.load(warp_loaded_frag_B0[(warp_mma_k + 1) % 2]); + + ++this->warp_tile_iterator_A0_; + ++this->warp_tile_iterator_B0_; + + if (warp_mma_k > 0) + warp_mma0.transform(warp_transformed_frag_A0[warp_mma_k % 2], + warp_transformed_frag_B0[warp_mma_k % 2], + warp_loaded_frag_A0[warp_mma_k % 2], + warp_loaded_frag_B0[warp_mma_k % 2]); + + warp_mma0( + accum0, + warp_transformed_frag_A0[warp_mma_k % 2], + warp_transformed_frag_B0[warp_mma_k % 2], + accum0 + ); + + // Issue global->shared copies for the this stage + if (warp_mma_k < Base::kWarpGemmIterations0 - 1) { + int group_start_iteration_A0, group_start_iteration_B0; + + group_start_iteration_A0 = warp_mma_k * Detail::kAccessesPerGroupA0; + group_start_iteration_B0 = warp_mma_k * Detail::kAccessesPerGroupB0; + + copy_tiles_and_advance_0(iterator_A0, iterator_B0, group_start_iteration_A0, + group_start_iteration_B0); + } + + if (warp_mma_k + 2 == Base::kWarpGemmIterations0) { + int group_start_iteration_A0, group_start_iteration_B0; + group_start_iteration_A0 = + (warp_mma_k + 1) * Detail::kAccessesPerGroupA0; + group_start_iteration_B0 = + (warp_mma_k + 1) * Detail::kAccessesPerGroupB0; + + copy_tiles_and_advance_0(iterator_A0, iterator_B0, group_start_iteration_A0, + group_start_iteration_B0); + + // Inserts a memory fence between stages of cp.async instructions. + cutlass::arch::cp_async_fence(); + + // Waits until kStages-2 stages have committed. + arch::cp_async_wait(); + __syncthreads(); + + // Move to the next stage + iterator_A0.add_tile_offset({0, 1}); + iterator_B0.add_tile_offset({1, 0}); + + this->smem_iterator_A0_.add_tile_offset({0, 1}); + this->smem_iterator_B0_.add_tile_offset({1, 0}); + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory + if (smem_write_stage_idx == (Base::kStages - 1)) { + this->smem_iterator_A0_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B0_.add_tile_offset({-Base::kStages, 0}); + smem_write_stage_idx = 0; + } else { + ++smem_write_stage_idx; + } + + if (smem_read_stage_idx == (Base::kStages - 1)) { + this->warp_tile_iterator_A0_.add_tile_offset( + {0, -Base::kStages * Policy0::kPartitionsK * + Base::kWarpGemmIterations0}); + this->warp_tile_iterator_B0_.add_tile_offset( + {-Base::kStages * Policy0::kPartitionsK * + Base::kWarpGemmIterations0, + 0}); + smem_read_stage_idx = 0; + } else { + ++smem_read_stage_idx; + } + + --gemm_k_iterations_0; + iterator_A0.clear_mask(gemm_k_iterations_0 == 0); + iterator_B0.clear_mask(gemm_k_iterations_0 == 0); + } + + // Do any conversions feeding the first stage at the end of the loop so + // we can start right away on mma instructions + if (warp_mma_k + 1 == Base::kWarpGemmIterations0) + warp_mma0.transform(warp_transformed_frag_A0[(warp_mma_k + 1) % 2], + warp_transformed_frag_B0[(warp_mma_k + 1) % 2], + warp_loaded_frag_A0[(warp_mma_k + 1) % 2], + warp_loaded_frag_B0[(warp_mma_k + 1) % 2]); + } + + } + + // Insert fence and wait for all outstanding cp.async operations to commit. 
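smem_write_stage_idx and smem_read_stage_idx above track the circular shared-memory buffer; they start kStages - 1 apart, advance once per K tile, and are reset to zero at the wrap points where the tile offsets are rewound. A standalone sketch of that bookkeeping (the stage count is arbitrary here) shows the producer stays exactly kStages - 1 buffers ahead of the consumer:

#include <cassert>

int main() {
  const int kStages = 4;                   // arbitrary stage count for the sketch
  int smem_write_stage_idx = kStages - 1;  // matches the initialisation above
  int smem_read_stage_idx  = 0;

  for (int tile = 0; tile < 32; ++tile) {
    // In the kernel the wrap is where add_tile_offset({0, -kStages}) and the
    // warp-iterator rewind are applied; here only the indices are tracked.
    smem_write_stage_idx = (smem_write_stage_idx == kStages - 1) ? 0 : smem_write_stage_idx + 1;
    smem_read_stage_idx  = (smem_read_stage_idx  == kStages - 1) ? 0 : smem_read_stage_idx  + 1;

    // The write pointer is always kStages - 1 buffers ahead of the read pointer.
    assert((smem_write_stage_idx - smem_read_stage_idx + kStages) % kStages == kStages - 1);
  }
  return 0;
}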
+ cutlass::arch::cp_async_fence(); + cutlass::arch::cp_async_wait<0>(); + __syncthreads(); + + /// Epilogue for the first Implicit Gemm + Epilogue0 epilogue0; + + epilogue0(output_op_0, smem_iterator_D0_, accum0, iterator_accum0_scale, iterator_accum0_bias); + + __syncthreads(); + + + // 2nd Gemm + + // + // Prologue + // + int gemm_k_iterations_1 = Shape0::kN / Shape1::kK; + + // Issue several complete stages + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < Base::kStages - 1; + ++stage, --gemm_k_iterations_1) { + + iterator_B1.clear_mask(gemm_k_iterations_1 == 0); + + iterator_B1.set_iteration_index(0); + this->smem_iterator_B1_.set_iteration_index(0); + + // cp.async for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::TBLoadIterationsB1; ++j) { + typename IteratorB1::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_B1_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB1::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorB1::ThreadMap::kElementsPerAccess / + IteratorB1::kAccessesPerVector / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_B1.get(), iterator_B1.valid()); + + ++iterator_B1; + } + + ++this->smem_iterator_B1_; + } + + // Move to the next stage + iterator_B1.add_tile_offset({1, 0}); + + this->smem_iterator_B1_.add_tile_offset({1, 0}); + + // Defines the boundary of a stage of cp.async. + cutlass::arch::cp_async_fence(); + } + + // DEPBAR+SYNC + cutlass::arch::cp_async_wait(); + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpLoadedFragmentA1 warp_loaded_frag_A1[2]; + WarpLoadedFragmentB1 warp_loaded_frag_B1[2]; + WarpTransformedFragmentA1 warp_transformed_frag_A1[2]; + WarpTransformedFragmentB1 warp_transformed_frag_B1[2]; + + Operator1 warp_mma1; + + warp_tile_iterator_A1_.load(warp_loaded_frag_A1[0]); + ++warp_tile_iterator_A1_; + + this->warp_tile_iterator_B1_.set_kgroup_index(0); + this->warp_tile_iterator_B1_.load(warp_loaded_frag_B1[0]); + ++this->warp_tile_iterator_B1_; + + iterator_B1.clear_mask(gemm_k_iterations_1 == 0); + + smem_write_stage_idx = Base::kStages - 1; + smem_read_stage_idx = 0; + + warp_mma1.transform(warp_transformed_frag_A1[0], warp_transformed_frag_B1[0], + warp_loaded_frag_A1[0], warp_loaded_frag_B1[0]); + + // + // Mainloop + // + + CUTLASS_PRAGMA_UNROLL + for ( gemm_k_iterations_1 = Shape0::kN / Shape1::kK - (Base::kStages - 1); + gemm_k_iterations_1 > (-Base::kStages + 1); gemm_k_iterations_1--) { + // + // Loop over GEMM K dimension + // + + // Computes a warp-level GEMM on data held in shared memory + // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations1; + ++warp_mma_k) { + + // Load warp-level tile from accumulator fragment + // skip warp tile loading for the last kgroup + if(gemm_k_iterations_1 > (-Base::kStages + 2) || warp_mma_k < Base::kWarpGemmIterations1 - 1) { + warp_tile_iterator_A1_.load(warp_loaded_frag_A1[(warp_mma_k + 1) % 2]); + } + ++warp_tile_iterator_A1_; + + // Load warp-level tiles from shared memory, wrapping to k offset if + // this is the last group as the case may be. 
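The iteration count for this second mainloop was set above as Shape0::kN / Shape1::kK: the first GEMM's threadblock N extent is consumed as the second GEMM's K extent. A tiny check of that shape relationship with made-up threadblock tiles:

#include <cassert>

struct GemmShape { int m, n, k; };  // stand-in for cutlass::gemm::GemmShape<>

int main() {
  GemmShape shape0{128, 64, 32};    // hypothetical threadblock tile for GEMM0
  GemmShape shape1{128, 128, 32};   // hypothetical threadblock tile for GEMM1

  // Fusion only works if GEMM0's output tile can feed GEMM1's K dimension.
  assert(shape0.n % shape1.k == 0);
  int gemm_k_iterations_1 = shape0.n / shape1.k;
  assert(gemm_k_iterations_1 == 2);  // two K tiles for the second GEMM
  return 0;
}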
+ this->warp_tile_iterator_B1_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations1); + this->warp_tile_iterator_B1_.load(warp_loaded_frag_B1[(warp_mma_k + 1) % 2]); + ++this->warp_tile_iterator_B1_; + + + if (warp_mma_k > 0) + warp_mma1.transform(warp_transformed_frag_A1[warp_mma_k % 2], + warp_transformed_frag_B1[warp_mma_k % 2], + warp_loaded_frag_A1[warp_mma_k % 2], + warp_loaded_frag_B1[warp_mma_k % 2]); + + + warp_mma1( + accum, + warp_transformed_frag_A1[warp_mma_k % 2], + warp_transformed_frag_B1[warp_mma_k % 2], + accum + ); + + // Issue global->shared copies for the this stage + if (warp_mma_k < Base::kWarpGemmIterations1 - 1) { + int group_start_iteration_B1; + + group_start_iteration_B1 = warp_mma_k * Detail::kAccessesPerGroupB1; + + copy_tiles_and_advance_1(iterator_B1, group_start_iteration_B1); + } + + if (warp_mma_k + 2 == Base::kWarpGemmIterations1) { + int group_start_iteration_B1; + group_start_iteration_B1 = + (warp_mma_k + 1) * Detail::kAccessesPerGroupB1; + + copy_tiles_and_advance_1(iterator_B1, group_start_iteration_B1); + + // Inserts a memory fence between stages of cp.async instructions. + cutlass::arch::cp_async_fence(); + + // Waits until kStages-2 stages have committed. + arch::cp_async_wait(); + __syncthreads(); + + // Move to the next stage + iterator_B1.add_tile_offset({1, 0}); + + this->smem_iterator_B1_.add_tile_offset({1, 0}); + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory + if (smem_write_stage_idx == (Base::kStages - 1)) { + this->smem_iterator_B1_.add_tile_offset({-Base::kStages, 0}); + smem_write_stage_idx = 0; + } else { + ++smem_write_stage_idx; + } + + if (smem_read_stage_idx == (Base::kStages - 1)) { + this->warp_tile_iterator_B1_.add_tile_offset( + {-Base::kStages * Policy1::kPartitionsK * + Base::kWarpGemmIterations1, + 0}); + smem_read_stage_idx = 0; + } else { + ++smem_read_stage_idx; + } + + iterator_B1.clear_mask(gemm_k_iterations_1 == 1); + } + + // Do any conversions feeding the first stage at the end of the loop so + // we can start right away on mma instructions + if (warp_mma_k + 1 == Base::kWarpGemmIterations1) + warp_mma1.transform(warp_transformed_frag_A1[(warp_mma_k + 1) % 2], + warp_transformed_frag_B1[(warp_mma_k + 1) % 2], + warp_loaded_frag_A1[(warp_mma_k + 1) % 2], + warp_loaded_frag_B1[(warp_mma_k + 1) % 2]); + } + + } + + // Commit and drain all pending and predicated cp.async pnz from the GEMM mainloop + cutlass::arch::cp_async_fence(); + cutlass::arch::cp_async_wait<0>(); + __syncthreads(); + + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_pipelined.h b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_pipelined.h new file mode 100644 index 0000000000000000000000000000000000000000..28fcc94f755df181b085a5d0019ee97b64e5d340 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_pipelined.h @@ -0,0 +1,562 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 
NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped Back-to-back fused GEMM kernel. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/aligned_buffer.h" +#include "cutlass/numeric_conversion.h" + +#include "cutlass/numeric_types.h" +#include "cutlass/matrix_shape.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h" + +#include "threadblock/b2b_mma_base.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. 
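As a reference for what this back-to-back fused GEMM computes, the sketch below evaluates D = output_op0(A * B0) * B1 with the intermediate kept on-chip rather than written back to global memory; output_op0 is modelled as a per-channel scale plus bias, in the spirit of the scale/bias iterators this class takes, and all sizes and values are invented for illustration:

#include <cassert>
#include <vector>

int main() {
  const int kM = 2, kK = 2, kN0 = 3, kN1 = 2;
  std::vector<float> A  = {1, 0, 0, 1};                    // kM x kK (identity)
  std::vector<float> B0 = {1, 2, 3, 4, 5, 6};              // kK x kN0
  std::vector<float> B1 = {1, 0, 0, 1, 1, 1};              // kN0 x kN1
  std::vector<float> scale = {1, 2, 3}, bias = {0, 1, 2};  // per output channel of GEMM0

  // GEMM0 accumulators (these stay "in registers" in the RF-resident variant).
  std::vector<float> accum0(kM * kN0, 0.f);
  for (int m = 0; m < kM; ++m)
    for (int n = 0; n < kN0; ++n)
      for (int k = 0; k < kK; ++k)
        accum0[m * kN0 + n] += A[m * kK + k] * B0[k * kN0 + n];

  // GEMM1 consumes output_op0(accum0) as its A operand without leaving the chip.
  std::vector<float> D(kM * kN1, 0.f);
  for (int m = 0; m < kM; ++m)
    for (int n = 0; n < kN1; ++n)
      for (int k = 0; k < kN0; ++k)
        D[m * kN1 + n] += (scale[k] * accum0[m * kN0 + k] + bias[k]) * B1[k * kN1 + n];

  // Row 0 of GEMM0 is {1, 2, 3}; after scale/bias it is {1, 5, 11}; dotted with
  // B1 column 1 = {0, 1, 1} this gives 16.
  assert(D[1] == 16.f);
  return 0;
}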
+template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape0_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) + typename IteratorA0_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA0_, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) + typename IteratorB0_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB0_, + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape1_, + /// Iterates over the intermediate accumulator tile + // (concept::MmaTensorOpFragmentIterator) + typename FragmentIteratorA1_, + /// Iterates over vectors of scale and bias vector in global memory + // (concept: VectorIterator) + typename IteratorAccumulatorScaleBias_, + /// FragmentIterator to load Scale or Bias vector from threadblock fragment + typename FragmentIteratorA1ScaleBias_, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) + typename IteratorB1_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB1_, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Output operator for 1st Gemm(concept: epilogue::thread::LinearCombinationClamp, etc...) + typename OutputOp_, + /// Policy describing tuning details (concept: MmaPipelinedPolicy) + typename Policy0_, + /// Policy describing tuning details (concept: MmaPipelinedPolicy) + typename Policy1_, + /// Transformation applied to A0 operand + typename TransformA0_ = NumericArrayConverter< + typename SmemIteratorA0_::Element, + typename IteratorA0_::Element, + IteratorA0_::Fragment::kElements>, + /// + /// Transformation applied to B0 operand + typename TransformB0_ = NumericArrayConverter< + typename SmemIteratorB0_::Element, + typename IteratorB0_::Element, + IteratorB0_::Fragment::kElements>, + /// + /// Transformation applied to B1 operand + typename TransformB1_ = NumericArrayConverter< + typename SmemIteratorB1_::Element, + typename IteratorB1_::Element, + IteratorB1_::Fragment::kElements>, + /// Used for partial specialization + typename Enable = bool +> +class B2bMmaPipelined : + public B2bMmaBase { +public: + + ///< Base class + using Base = B2bMmaBase; + + using Shape0 = Shape0_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using IteratorA0 = IteratorA0_; ///< Iterates over tiles of A operand in global memory + using IteratorA = IteratorA0; + using IteratorB0 = IteratorB0_; ///< Iterates over tiles of B operand in global memory + using IteratorB = IteratorB0; + using Policy0 = Policy0_; ///< Policy describing tuning details + + using SmemIteratorA0 = SmemIteratorA0_; + using SmemIteratorB0 = SmemIteratorB0_; + + using Shape1 = Shape1_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using FragmentIteratorA1 = FragmentIteratorA1_; ///< Iterates over intermediate accumulator tile + using IteratorAccumulatorScaleBias = IteratorAccumulatorScaleBias_; ///< Iterates over tiles of the scale and bias vectors in global memory + using FragmentIteratorA1ScaleBias = + 
FragmentIteratorA1ScaleBias_; ///< WarpIterator to load Scale or Bias vector from the threadblock fragment + using IteratorB1 = IteratorB1_; ///< Iterates over tiles of B operand in global memory + using Policy1 = Policy1_; ///< Policy describing tuning details + using Policy = Policy1; ///< Export Policy1 as the threadblock-level Mma's policy + using Shape = Shape1; + + using SmemIteratorB1 = SmemIteratorB1_; + + + using ElementC = ElementC_; ///< Data type of accumulator matrix + using LayoutC = LayoutC_; ///< Layout of accumulator matrix + + using OutputOp = OutputOp_; ///< Epilogue after 1st Gemm + + static const bool PerChannelScale = (OutputOp::kScale == + epilogue::thread::ScaleType::OnlyAlphaPerChannelScaling); + + using TransformA0 = TransformA0_; + using TransformB0 = TransformB0_; + using TransformB1 = TransformB1_; + + // + // Dependent types + // + + /// Fragment of operand A loaded from global memory + using FragmentA0 = typename IteratorA0::Fragment; + + /// Fragment of operand B loaded from global memory + using FragmentB0 = typename IteratorB0::Fragment; + + /// Fragment of accumulator tile + using FragmentC0 = typename Policy0::Operator::FragmentC; + + /// Warp-level Mma + using Operator0 = typename Policy0::Operator; + + /// Fragment of Scale and Bias loaded from global memory + using FragmentA1ScaleBias = typename IteratorAccumulatorScaleBias::Fragment; + + /// Fragment of operand B loaded from global memory + using FragmentB1 = typename IteratorB1::Fragment; + + /// Fragment of accumulator tile + using FragmentC1 = typename Policy1::Operator::FragmentC; + + /// Warp-level Mma + using Operator1 = typename Policy1::Operator; + + /// Obtain the arch tag from the warp-level operator + using ArchTag = typename Policy0::Operator::ArchTag; + + /// Complex transform on A0 operand + static ComplexTransform const kTransformA0 = Operator0::kTransformA; + + /// Complex transform on B0 operand + static ComplexTransform const kTransformB0 = Operator0::kTransformB; + + /// Complex transform on B1 operand + static ComplexTransform const kTransformB1 = Operator1::kTransformB; + + /// Complex transform exports needed by higher-level kernels + static ComplexTransform const kTransformA = kTransformA0; + static ComplexTransform const kTransformB = kTransformB0; + + /// staticaly assert kStages for MmaPipelined is two (Double-buffered pipeline) + static_assert((Base::kStages==2), "MmaPipelined requires kStages set to value 2"); + +private: + + using WarpFragmentA0 = typename Operator0::FragmentA; + using WarpFragmentB0 = typename Operator0::FragmentB; + /// Warp Fragment of operand A1 loaded from accmulator tile + using WarpFragmentA1 = typename FragmentIteratorA1::Fragment; + /// Warp Fragment of operand A1 scale and bias loaded from threadblock fragment + using WarpFragmentA1ScaleBias = + typename FragmentIteratorA1ScaleBias::Fragment; + using WarpFragmentB1 = typename Operator1::FragmentB; + +protected: + + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA0 smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B0 operand to shared memory + SmemIteratorB0 smem_iterator_B0_; + + /// Iterator to write threadblock-scoped tile of B1 operand to shared memory + SmemIteratorB1 smem_iterator_B1_; + +public: + + /// Construct from tensor references + CUTLASS_DEVICE + B2bMmaPipelined( + typename Base::B2bMmaSharedStorage &shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM + int thread_idx, ///< ID within 
the threadblock + int warp_idx, ///< ID of warp + int lane_idx, ///< ID of each thread within a warp + int problem_size_0_n ///< GEMM0 N is used for accumulator extent + ): + Base(shared_storage, thread_idx, warp_idx, lane_idx), + smem_iterator_A_(shared_storage.shared_storage0.operand_A_ref(), thread_idx), + smem_iterator_B0_(shared_storage.shared_storage0.operand_B_ref(), thread_idx), + smem_iterator_B1_(shared_storage.shared_storage1.operand_B_ref(), thread_idx) { + + + // Compute warp location within threadblock tile by mapping the warp_id to three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + //These should stay the same across different GEMM layers + int warp_idx_mn = warp_idx % (Base::WarpCount0::kM * Base::WarpCount0::kN); + int warp_idx_k = warp_idx / (Base::WarpCount0::kM * Base::WarpCount0::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount0::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount0::kM; + + //These may change across different GEMM layers + int tile_offset_k_0 = Base::kWarpGemmIterations0 * warp_idx_k; + int tile_offset_k_1 = Base::kWarpGemmIterations1 * warp_idx_k; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A0_.add_tile_offset({warp_idx_m, tile_offset_k_0}); + this->warp_tile_iterator_B0_.add_tile_offset({tile_offset_k_0, warp_idx_n}); + this->warp_tile_iterator_B1_.add_tile_offset({tile_offset_k_1, warp_idx_n}); + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + int gemm_k_iterations_0, ///< number of iterations of the mainloop + FragmentC1 &accum, ///< destination accumulator tile + IteratorA0 iterator_A, ///< iterator over A operand in global memory + IteratorB0 iterator_B0, ///< iterator over B0 operand in global memory + IteratorAccumulatorScaleBias iterator_A1_scale, ///< iterator over A1 operand scale vectors in global memory + IteratorAccumulatorScaleBias iterator_A1_bias, ///< iterator over A1 operand bias vectors in global memory + IteratorB1 iterator_B1, ///< iterator over B1 operand in global memory + FragmentC0 const &src_accum, ///< source accumulator tile + OutputOp output_op_0, ///< epilogue operation after 1st Gemm + TransformA0 transform_A0 = TransformA0(), ///< transformation applied to A0 fragment + TransformB0 transform_B0 = TransformB0(), ///< transformation applied to B0 fragment + TransformB1 transform_B1 = TransformB1()) { ///< transformation applied to B1 fragment + + // + // Prologue + // + + // Perform accumulation in the 'd' output operand + FragmentC0 accum0 = src_accum; + + FragmentA0 tb_frag_A; + FragmentB0 tb_frag_B0; + + tb_frag_A.clear(); + tb_frag_B0.clear(); + + // The last kblock is loaded in the prolog + iterator_A.load(tb_frag_A); + iterator_B0.load(tb_frag_B0); + + ++iterator_A; + ++iterator_B0; + + this->smem_iterator_A_.store(transform_A0(tb_frag_A)); + this->smem_iterator_B0_.store(transform_B0(tb_frag_B0)); + + ++this->smem_iterator_A_; + ++this->smem_iterator_B0_; + + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math instructions + WarpFragmentA0 warp_frag_A0[2]; + WarpFragmentB0 warp_frag_B0[2]; + + this->warp_tile_iterator_A0_.set_kgroup_index(0); + this->warp_tile_iterator_B0_.set_kgroup_index(0); + + this->warp_tile_iterator_A0_.load(warp_frag_A0[0]); + 
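transform_A0 and transform_B0 above are applied to each global-memory fragment before it is stored to shared memory; by default they are NumericArrayConverter functors, that is, element-wise type conversions over a whole fragment. A simplified stand-in for that idea (the types here are assumptions for the sketch, not the CUTLASS implementation):

#include <array>
#include <cassert>
#include <cstdint>

// Stand-in for a NumericArrayConverter<Dst, Src, N>-style functor: convert a fragment.
template <typename Dst, typename Src, int N>
struct ArrayConverter {
  std::array<Dst, N> operator()(std::array<Src, N> const &src) const {
    std::array<Dst, N> dst{};
    for (int i = 0; i < N; ++i)
      dst[i] = static_cast<Dst>(src[i]);  // per-element conversion
    return dst;
  }
};

int main() {
  // Hypothetical fragment: float in registers, stored as int8_t in shared memory.
  std::array<float, 4> frag{1.f, 2.f, 3.f, 4.f};
  auto smem_frag = ArrayConverter<std::int8_t, float, 4>{}(frag);
  assert(smem_frag[2] == 3);
  return 0;
}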
this->warp_tile_iterator_B0_.load(warp_frag_B0[0]); + + ++this->warp_tile_iterator_A0_; + ++this->warp_tile_iterator_B0_; + + Operator0 warp_mma0; + + int smem_write_stage_idx = 1; + + // Avoid reading out of bounds + iterator_A.clear_mask(gemm_k_iterations_0 <= 1); + iterator_B0.clear_mask(gemm_k_iterations_0 <= 1); + + // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing + // shared memory loads (which have the tightest latency requirement). + + // + // Mainloop + // + + // Note: The main loop does not support Base::kWarpGemmIterations == 2. + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations_0 > 0; --gemm_k_iterations_0) { + // + // Loop over GEMM K dimension + // + + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations0; ++warp_mma_k) { + + // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group + // as the case may be. + + if (warp_mma_k == Base::kWarpGemmIterations0 - 1) { + + // Write fragments to shared memory + this->smem_iterator_A_.store(transform_A0(tb_frag_A)); + + this->smem_iterator_B0_.store(transform_B0(tb_frag_B0)); + + __syncthreads(); + + ++this->smem_iterator_A_; + ++this->smem_iterator_B0_; + + // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory + if (smem_write_stage_idx == 1) { + this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B0_.add_tile_offset({-Base::kStages, 0}); + } + else { + this->warp_tile_iterator_A0_.add_tile_offset( + {0, -Base::kStages * Policy0::kPartitionsK * Base::kWarpGemmIterations0}); + this->warp_tile_iterator_B0_.add_tile_offset( + {-Base::kStages * Policy0::kPartitionsK * Base::kWarpGemmIterations0, + 0}); + } + + smem_write_stage_idx ^= 1; + } + + this->warp_tile_iterator_A0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0); + this->warp_tile_iterator_B0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0); + + this->warp_tile_iterator_A0_.load(warp_frag_A0[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_B0_.load(warp_frag_B0[(warp_mma_k + 1) % 2]); + + ++this->warp_tile_iterator_A0_; + ++this->warp_tile_iterator_B0_; + + if (warp_mma_k == 0) { + + iterator_A.load(tb_frag_A); + iterator_B0.load(tb_frag_B0); + ++iterator_A; + ++iterator_B0; + + // Avoid reading out of bounds if this was the last loop iteration + iterator_A.clear_mask(gemm_k_iterations_0 <= 2); + iterator_B0.clear_mask(gemm_k_iterations_0 <= 2); + } + + warp_mma0(accum0, warp_frag_A0[warp_mma_k % 2], + warp_frag_B0[warp_mma_k % 2], accum0); + } + } + + //2nd Gemm + + /// Iterator to load a warp-scoped tile of A1 operand from intermediate accumulator tile + FragmentIteratorA1 warp_tile_iterator_A1_(accum0); + + // + // Prologue + // + + FragmentA1ScaleBias tb_frag_A1_scale; + FragmentA1ScaleBias tb_frag_A1_bias; + FragmentIteratorA1ScaleBias warp_tile_iterator_A1_scale_(tb_frag_A1_scale); + FragmentIteratorA1ScaleBias warp_tile_iterator_A1_bias_(tb_frag_A1_bias); + FragmentB1 tb_frag_B1; + + if(PerChannelScale) + tb_frag_A1_scale.clear(); + tb_frag_A1_bias.clear(); + tb_frag_B1.clear(); + + // The last kblock is loaded in the prolog + if(PerChannelScale) + iterator_A1_scale.load(tb_frag_A1_scale); + iterator_A1_bias.load(tb_frag_A1_bias); + iterator_B1.load(tb_frag_B1); + + if(PerChannelScale) + ++iterator_A1_scale; + ++iterator_A1_bias; + ++iterator_B1; + + this->smem_iterator_B1_.store(transform_B1(tb_frag_B1)); + + ++this->smem_iterator_B1_; + + __syncthreads(); + + // Pair of 
fragments used to overlap shared memory loads and math instructions + WarpFragmentA1ScaleBias warp_frag_A1_scale[2]; + WarpFragmentA1ScaleBias warp_frag_A1_bias[2]; + WarpFragmentA1 warp_frag_A1[2]; + WarpFragmentB1 warp_frag_B1[2]; + + this->warp_tile_iterator_B1_.set_kgroup_index(0); + + if(PerChannelScale) + warp_tile_iterator_A1_scale_.load(warp_frag_A1_scale[0]); + warp_tile_iterator_A1_bias_.load(warp_frag_A1_bias[0]); + warp_tile_iterator_A1_.load(warp_frag_A1[0], warp_frag_A1_scale[0], + warp_frag_A1_bias[0], output_op_0); + this->warp_tile_iterator_B1_.load(warp_frag_B1[0]); + + ++warp_tile_iterator_A1_; + if(PerChannelScale) + ++warp_tile_iterator_A1_scale_; + ++warp_tile_iterator_A1_bias_; + ++this->warp_tile_iterator_B1_; + + Operator1 warp_mma1; + + smem_write_stage_idx = 1; + + int gemm_k_iterations_1 = FragmentIteratorA1::Policy::kIterations / Base::kWarpGemmIterations1; + + // Avoid reading out of bounds + iterator_B1.clear_mask(gemm_k_iterations_1 <= 1); + + // + // Mainloop + // + + // Note: The main loop does not support Base::WarpGemmIterations == 2. + CUTLASS_PRAGMA_UNROLL + for (; gemm_k_iterations_1 > 0; --gemm_k_iterations_1) { + + // + // Loop over GEMM K dimension + // + + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations1; ++warp_mma_k) { + + // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group + // as the case may be. + + if (warp_mma_k == Base::kWarpGemmIterations1 - 1) { + + // Write fragments to shared memory + this->smem_iterator_B1_.store(transform_B1(tb_frag_B1)); + + __syncthreads(); + ++this->smem_iterator_B1_; + + // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory + if (smem_write_stage_idx == 1) { + this->smem_iterator_B1_.add_tile_offset({-Base::kStages, 0}); + } + else { + this->warp_tile_iterator_B1_.add_tile_offset( + {-Base::kStages * Policy1::kPartitionsK * + Base::kWarpGemmIterations1, + 0}); + } + + smem_write_stage_idx ^= 1; + + if(PerChannelScale) { + tb_frag_A1_scale.clear(); + iterator_A1_scale.load(tb_frag_A1_scale); + ++iterator_A1_scale; + } + tb_frag_A1_bias.clear(); + iterator_A1_bias.load(tb_frag_A1_bias); + ++iterator_A1_bias; + } + + this->warp_tile_iterator_B1_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations1); + + if(PerChannelScale) + warp_tile_iterator_A1_scale_.load(warp_frag_A1_scale[(warp_mma_k + 1) % 2]); + warp_tile_iterator_A1_bias_.load(warp_frag_A1_bias[(warp_mma_k + 1) % 2]); + warp_tile_iterator_A1_.load(warp_frag_A1[(warp_mma_k + 1) % 2], + warp_frag_A1_scale[(warp_mma_k + 1) % 2], + warp_frag_A1_bias[(warp_mma_k + 1) % 2], + output_op_0); + this->warp_tile_iterator_B1_.load(warp_frag_B1[(warp_mma_k + 1) % 2]); + + if(PerChannelScale) + ++warp_tile_iterator_A1_scale_; + ++warp_tile_iterator_A1_bias_; + ++warp_tile_iterator_A1_; + ++this->warp_tile_iterator_B1_; + + if (warp_mma_k == 0) { + + iterator_B1.load(tb_frag_B1); + ++iterator_B1; + + // Avoid reading out of bounds if this was the last loop iteration + iterator_B1.clear_mask(gemm_k_iterations_1 <= 2); + } + + warp_mma1(accum, warp_frag_A1[warp_mma_k % 2], + warp_frag_B1[warp_mma_k % 2], accum); + } + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass diff --git 
a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_pipelined_smem_accumulator.h b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_pipelined_smem_accumulator.h new file mode 100644 index 0000000000000000000000000000000000000000..b3754945b3c14ef563603850945914357d7a00b8 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_pipelined_smem_accumulator.h @@ -0,0 +1,552 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped Back-to-back fused GEMM kernel. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/aligned_buffer.h" +#include "cutlass/numeric_conversion.h" + +#include "cutlass/numeric_types.h" +#include "cutlass/matrix_shape.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h" + +#include "threadblock/b2b_mma_base_smem_accumulator.h" +#include "cutlass/epilogue/threadblock/epilogue_smem_accumulator.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. 
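This variant differs from the register-resident one in that GEMM0's accumulators are run through an epilogue and staged in shared memory, then read back as GEMM1's A operand. A minimal model of that hand-off, with an invented linear output op and made-up tile sizes:

#include <cassert>
#include <vector>

int main() {
  const int kM = 2, kN0 = 3, kN1 = 2;
  std::vector<float> accum0 = {1, 2, 3, 4, 5, 6};  // GEMM0 output tile, kM x kN0
  std::vector<float> B1     = {1, 0, 0, 1, 1, 1};  // kN0 x kN1
  std::vector<float> smem_D0(kM * kN0);            // staging buffer standing in for smem

  // "Epilogue0": apply the output op element-wise while storing to shared memory.
  const float alpha = 2.0f, bias = 1.0f;
  for (int i = 0; i < kM * kN0; ++i)
    smem_D0[i] = alpha * accum0[i] + bias;

  // GEMM1 consumes the staged tile as A1 (GEMM1's K extent equals GEMM0's N extent).
  std::vector<float> D1(kM * kN1, 0.0f);
  for (int m = 0; m < kM; ++m)
    for (int n = 0; n < kN1; ++n)
      for (int k = 0; k < kN0; ++k)
        D1[m * kN1 + n] += smem_D0[m * kN0 + k] * B1[k * kN1 + n];

  // Spot-check: row 0 of the staged tile is {3, 5, 7}; column 0 of B1 is {1, 0, 1}.
  assert(D1[0] == 3 * 1 + 5 * 0 + 7 * 1);
  return 0;
}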
+template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape0_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) + typename IteratorA0_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA0_, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) + typename IteratorB0_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB0_, + /// Iterates over vectors of scale and bias vector in global memory + // (concept: VectorIterator) + typename IteratorAccumulatorScaleBias_, + /// Iterates over accumulator tile + typename FragmentIteratorAccumulator_, + /// Iterates over accumulator tile in shared memory + typename SmemIteratorD0_, + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape1_, + /// Iterates over the intermediate accumulator tile in shared memory + typename WarpIteratorA1_, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) + typename IteratorB1_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB1_, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Output operator for 1st Gemm(concept: epilogue::thread::LinearCombinationClamp, etc...) + typename OutputOp_, + /// Policy describing tuning details (concept: MmaPipelinedPolicy) + typename Policy0_, + /// Policy describing tuning details (concept: MmaPipelinedPolicy) + typename Policy1_, + /// Transformation applied to A0 operand + typename TransformA0_ = NumericArrayConverter< + typename SmemIteratorA0_::Element, + typename IteratorA0_::Element, + IteratorA0_::Fragment::kElements>, + /// + /// Transformation applied to B0 operand + typename TransformB0_ = NumericArrayConverter< + typename SmemIteratorB0_::Element, + typename IteratorB0_::Element, + IteratorB0_::Fragment::kElements>, + /// + /// Transformation applied to B1 operand + typename TransformB1_ = NumericArrayConverter< + typename SmemIteratorB1_::Element, + typename IteratorB1_::Element, + IteratorB1_::Fragment::kElements>, + /// Used for partial specialization + typename Enable = bool +> +class B2bMmaPipelinedSmemAccumulator : + public B2bMmaBaseSmemAccumulator { +public: + + ///< Base class + using Base = B2bMmaBaseSmemAccumulator; + + using Shape0 = Shape0_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using IteratorA0 = IteratorA0_; ///< Iterates over tiles of A operand in global memory + using IteratorA = IteratorA0; + using IteratorB0 = IteratorB0_; ///< Iterates over tiles of B operand in global memory + using IteratorB = IteratorB0; + using IteratorAccumulatorScaleBias = IteratorAccumulatorScaleBias_; ///< Iterates over tiles of the scale and bias vectors in global memory + using Policy0 = Policy0_; ///< Policy0 describing tuning details + + using SmemIteratorA0 = SmemIteratorA0_; + using SmemIteratorB0 = SmemIteratorB0_; + using SmemIteratorD0 = SmemIteratorD0_; ///< Iterates over accumulator tile in shared memory + + using FragmentIteratorAccumulator = FragmentIteratorAccumulator_; ///< Iterates over 
accumulator tile + + using Shape1 = Shape1_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using IteratorB1 = IteratorB1_; ///< Iterates over tiles of B operand in global memory + using Policy1 = Policy1_; ///< Policy1 describing tuning details + using Policy = Policy1; ///< Export Policy1 as the threadblock-level Mma's policy + using Shape = Shape1; + + using SmemIteratorB1 = SmemIteratorB1_; + using WarpIteratorA1 = WarpIteratorA1_; ///< Iterates over the intermediate accumulator tile in shared memory + + + using ElementC = ElementC_; ///< Data type of accumulator matrix + using LayoutC = LayoutC_; ///< Layout of accumulator matrix + + using OutputOp = OutputOp_; ///< Epilogue after 1st Gemm + + using TransformA0 = TransformA0_; + using TransformB0 = TransformB0_; + using TransformB1 = TransformB1_; + + // + // Dependent types + // + + /// Fragment of operand A loaded from global memory + using FragmentA0 = typename IteratorA0::Fragment; + + /// Fragment of operand B loaded from global memory + using FragmentB0 = typename IteratorB0::Fragment; + + /// Fragment of accumulator tile + using FragmentC0 = typename Policy0::Operator::FragmentC; + + /// Warp-level Mma + using Operator0 = typename Policy0::Operator; + + /// Fragment of operand B loaded from global memory + using FragmentB1 = typename IteratorB1::Fragment; + + /// Fragment of accumulator tile + using FragmentC1 = typename Policy1::Operator::FragmentC; + + /// Warp-level Mma + using Operator1 = typename Policy1::Operator; + + /// Obtain the arch tag from the warp-level operator + using ArchTag = typename Policy0::Operator::ArchTag; + + /// Complex transform on A0 operand + static ComplexTransform const kTransformA0 = Operator0::kTransformA; + + /// Complex transform on B0 operand + static ComplexTransform const kTransformB0 = Operator0::kTransformB; + + /// Complex transform on B1 operand + static ComplexTransform const kTransformB1 = Operator1::kTransformB; + + /// Complex transform exports needed by higher-level kernels + static ComplexTransform const kTransformA = kTransformA0; + static ComplexTransform const kTransformB = kTransformB0; + + /// staticaly assert kStages for MmaPipelined is two (Double-buffered pipeline) + static_assert((Base::kStages==2), "MmaPipelined requires kStages set to value 2"); + + /// Epilog in shared memory + using Epilogue0 = epilogue::threadblock::EpilogueSmemAccumulator< + SmemIteratorD0, ///< SmemTileIterator + FragmentIteratorAccumulator, ///< AccumulatorFragmentIterator + IteratorAccumulatorScaleBias, ///< ScaleBiasIterator + OutputOp>; ///< Output operator + + + +private: + + using WarpFragmentA0 = typename Operator0::FragmentA; + using WarpFragmentB0 = typename Operator0::FragmentB; + using WarpFragmentA1 = typename Operator1::FragmentA; + using WarpFragmentB1 = typename Operator1::FragmentB; + +protected: + + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA0 smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B0 operand to shared memory + SmemIteratorB0 smem_iterator_B0_; + + /// Shared Memory Iterator to store accumulator tile + SmemIteratorD0 smem_iterator_D0_; + + /// Iterator to load a warp-scoped tile of A1 operand from intermediate accumulator tile + WarpIteratorA1 warp_tile_iterator_A1_; + + /// Iterator to write threadblock-scoped tile of B1 operand to shared memory + SmemIteratorB1 smem_iterator_B1_; + +public: + + /// Construct from tensor references + CUTLASS_DEVICE + B2bMmaPipelinedSmemAccumulator( + 
typename Base::B2bMmaSharedStorage &shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM + int thread_idx, ///< ID within the threadblock + int warp_idx, ///< ID of warp + int lane_idx, ///< ID of each thread within a warp + int problem_size_0_n ///< GEMM0 N is used for accumulator extent + ): + Base(shared_storage, thread_idx, warp_idx, lane_idx), + smem_iterator_A_(shared_storage.b2b_mma_shared_storage.shared_storage0.operand_A_ref(), thread_idx), + smem_iterator_B0_(shared_storage.b2b_mma_shared_storage.shared_storage0.operand_B_ref(), thread_idx), + smem_iterator_D0_(shared_storage.accumulator_shared_storage0.accum_ref(), lane_idx), + warp_tile_iterator_A1_(shared_storage.accumulator_shared_storage0.accum_ref(), {Base::WarpGemm1::kM, problem_size_0_n}, lane_idx), + smem_iterator_B1_(shared_storage.b2b_mma_shared_storage.shared_storage1.operand_B_ref(), thread_idx) { + + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn_0 = warp_idx % (Base::WarpCount0::kM * Base::WarpCount0::kN); + int warp_idx_k_0 = warp_idx / (Base::WarpCount0::kM * Base::WarpCount0::kN); + + int warp_idx_m_0 = warp_idx_mn_0 % Base::WarpCount0::kM; + int warp_idx_n_0 = warp_idx_mn_0 / Base::WarpCount0::kM; + + int tile_offset_k_0 = Base::kWarpGemmIterations0 * warp_idx_k_0; + + int warp_idx_mn_1 = warp_idx % (Base::WarpCount1::kM * Base::WarpCount1::kN); + int warp_idx_k_1 = warp_idx / (Base::WarpCount1::kM * Base::WarpCount1::kN); + + int warp_idx_m_1 = warp_idx_mn_1 % Base::WarpCount1::kM; + int warp_idx_n_1 = warp_idx_mn_1 / Base::WarpCount1::kM; + + int tile_offset_k_1 = Base::kWarpGemmIterations1 * warp_idx_k_1; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A0_.add_tile_offset({warp_idx_m_0, tile_offset_k_0}); + this->warp_tile_iterator_B0_.add_tile_offset({tile_offset_k_0, warp_idx_n_0}); + warp_tile_iterator_A1_.add_tile_offset({warp_idx_m_1, tile_offset_k_1}); + this->warp_tile_iterator_B1_.add_tile_offset({tile_offset_k_1, warp_idx_n_1}); + + // Add smem accumulator iterator warp offset + smem_iterator_D0_.add_tile_offset({ warp_idx_m_0 * SmemIteratorD0::TileIterations::kRow, + warp_idx_n_0 * SmemIteratorD0::TileIterations::kColumn}); + + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + int gemm_k_iterations_0, ///< number of iterations of the mainloop + FragmentC1 &accum, ///< destination accumulator tile + IteratorA0 iterator_A, ///< iterator over A operand in global memory + IteratorB0 iterator_B0, ///< iterator over B0 operand in global memory + IteratorAccumulatorScaleBias iterator_accum0_scale, ///< iterator over D0 scale vector in global memory + IteratorAccumulatorScaleBias iterator_accum0_bias, ///< iterator over D0 bias vector in global memory + IteratorB1 iterator_B1, ///< iterator over B1 operand in global memory + FragmentC0 const &src_accum, ///< source accumulator tile + OutputOp output_op_0, ///< epilogue operation after 1st Gemm + TransformA0 transform_A0 = TransformA0(), ///< transformation applied to A0 fragment + TransformB0 transform_B0 = TransformB0(), ///< transformation applied to B0 fragment + TransformB1 transform_B1 = TransformB1()) { ///< transformation applied to B1 
fragment + + // + // Prologue + // + + // Perform accumulation in the 'd' output operand + FragmentC0 accum0 = src_accum; + + FragmentA0 tb_frag_A; + FragmentB0 tb_frag_B0; + + tb_frag_A.clear(); + tb_frag_B0.clear(); + + // The last kblock is loaded in the prolog + iterator_A.load(tb_frag_A); + iterator_B0.load(tb_frag_B0); + + ++iterator_A; + ++iterator_B0; + + this->smem_iterator_A_.store(transform_A0(tb_frag_A)); + this->smem_iterator_B0_.store(transform_B0(tb_frag_B0)); + + ++this->smem_iterator_A_; + ++this->smem_iterator_B0_; + + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math instructions + WarpFragmentA0 warp_frag_A0[2]; + WarpFragmentB0 warp_frag_B0[2]; + + this->warp_tile_iterator_A0_.set_kgroup_index(0); + this->warp_tile_iterator_B0_.set_kgroup_index(0); + + this->warp_tile_iterator_A0_.load(warp_frag_A0[0]); + this->warp_tile_iterator_B0_.load(warp_frag_B0[0]); + + ++this->warp_tile_iterator_A0_; + ++this->warp_tile_iterator_B0_; + + Operator0 warp_mma0; + + int smem_write_stage_idx = 1; + + // Avoid reading out of bounds + iterator_A.clear_mask(gemm_k_iterations_0 <= 1); + iterator_B0.clear_mask(gemm_k_iterations_0 <= 1); + + // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing + // shared memory loads (which have the tightest latency requirement). + + // + // Mainloop + // + + // Note: The main loop does not support Base::kWarpGemmIterations == 2. + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations_0 > 0; --gemm_k_iterations_0) { + // + // Loop over GEMM K dimension + // + + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations0; ++warp_mma_k) { + + // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group + // as the case may be. 
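clear_mask(gemm_k_iterations_0 <= 1) above, and clear_mask(gemm_k_iterations_0 <= 2) inside the loop, turn the global-memory iterators into no-ops once the next prefetch would step past the operand. A toy model of that predication (the iterator type here is invented for the sketch, not the CUTLASS one):

#include <cassert>
#include <cstddef>
#include <vector>

struct ToyTileIterator {
  const std::vector<int> *data = nullptr;
  std::size_t pos = 0;
  bool mask = true;

  void clear_mask(bool clear) { if (clear) mask = false; }
  int load() const { return mask ? (*data)[pos] : 0; }  // zero-fill when masked off
  ToyTileIterator &operator++() { if (mask) ++pos; return *this; }
};

int main() {
  std::vector<int> k_tiles{10, 20, 30};                  // three K tiles in "global memory"
  int gemm_k_iterations = static_cast<int>(k_tiles.size());

  ToyTileIterator it;
  it.data = &k_tiles;

  int frag = it.load(); ++it;                            // prologue: fetch the first tile
  it.clear_mask(gemm_k_iterations <= 1);

  int sum = 0;
  for (; gemm_k_iterations > 0; --gemm_k_iterations) {
    sum += frag;                                         // "math" on the resident tile
    frag = it.load(); ++it;                              // prefetch the next tile
    it.clear_mask(gemm_k_iterations <= 2);               // last real tile is already fetched
  }

  assert(sum == 10 + 20 + 30);                           // all tiles read once, none out of bounds
  return 0;
}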
+ + if (warp_mma_k == Base::kWarpGemmIterations0 - 1) { + + // Write fragments to shared memory + this->smem_iterator_A_.store(transform_A0(tb_frag_A)); + + this->smem_iterator_B0_.store(transform_B0(tb_frag_B0)); + + __syncthreads(); + + ++this->smem_iterator_A_; + ++this->smem_iterator_B0_; + + // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory + if (smem_write_stage_idx == 1) { + this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B0_.add_tile_offset({-Base::kStages, 0}); + } + else { + this->warp_tile_iterator_A0_.add_tile_offset( + {0, -Base::kStages * Policy0::kPartitionsK * Base::kWarpGemmIterations0}); + this->warp_tile_iterator_B0_.add_tile_offset( + {-Base::kStages * Policy0::kPartitionsK * Base::kWarpGemmIterations0, + 0}); + } + + smem_write_stage_idx ^= 1; + } + + this->warp_tile_iterator_A0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0); + this->warp_tile_iterator_B0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0); + + this->warp_tile_iterator_A0_.load(warp_frag_A0[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_B0_.load(warp_frag_B0[(warp_mma_k + 1) % 2]); + + ++this->warp_tile_iterator_A0_; + ++this->warp_tile_iterator_B0_; + + if (warp_mma_k == 0) { + + iterator_A.load(tb_frag_A); + iterator_B0.load(tb_frag_B0); + ++iterator_A; + ++iterator_B0; + + // Avoid reading out of bounds if this was the last loop iteration + iterator_A.clear_mask(gemm_k_iterations_0 <= 2); + iterator_B0.clear_mask(gemm_k_iterations_0 <= 2); + } + + warp_mma0(accum0, warp_frag_A0[warp_mma_k % 2], + warp_frag_B0[warp_mma_k % 2], accum0); + } + } + + /// Epilogue for the first Implicit Gemm + Epilogue0 epilogue0; + + epilogue0(output_op_0, smem_iterator_D0_, accum0, iterator_accum0_scale, iterator_accum0_bias); + + __syncthreads(); + + //2nd Gemm + + // + // Prologue + // + + FragmentB1 tb_frag_B1; + + tb_frag_B1.clear(); + + // The last kblock is loaded in the prolog + iterator_B1.load(tb_frag_B1); + + ++iterator_B1; + + this->smem_iterator_B1_.store(transform_B1(tb_frag_B1)); + + ++this->smem_iterator_B1_; + + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math instructions + WarpFragmentA1 warp_frag_A1[2]; + WarpFragmentB1 warp_frag_B1[2]; + + this->warp_tile_iterator_B1_.set_kgroup_index(0); + + warp_tile_iterator_A1_.load(warp_frag_A1[0]); + this->warp_tile_iterator_B1_.load(warp_frag_B1[0]); + + ++warp_tile_iterator_A1_; + ++this->warp_tile_iterator_B1_; + + Operator1 warp_mma1; + + smem_write_stage_idx = 1; + + int gemm_k_iterations_1 = Shape0::kN / Shape1::kK; + + // Avoid reading out of bounds + iterator_B1.clear_mask(gemm_k_iterations_1 <= 1); + + // + // Mainloop + // + + // Note: The main loop does not support Base::kWarpGemmIterations == 2. + CUTLASS_PRAGMA_UNROLL + for (; gemm_k_iterations_1 > 0; --gemm_k_iterations_1) { + // + // Loop over GEMM K dimension + // + + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations1; ++warp_mma_k) { + + // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group + // as the case may be. 
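The warp_frag_A1 / warp_frag_B1 pairs above are two-entry register buffers: iteration warp_mma_k computes on index warp_mma_k % 2 while the load for the next k-group fills index (warp_mma_k + 1) % 2, so shared-memory loads overlap the math. A compact model of that ping-pong, with invented data and sizes:

#include <cassert>
#include <cstddef>
#include <vector>

int main() {
  std::vector<int> smem_kgroups{1, 2, 3, 4, 5};  // stand-in for k-groups resident in smem
  int warp_frag[2];

  warp_frag[0] = smem_kgroups[0];                // prologue load of k-group 0
  int acc = 0;
  for (std::size_t k = 0; k < smem_kgroups.size(); ++k) {
    // Prefetch the next k-group into the buffer the math is not reading this trip.
    if (k + 1 < smem_kgroups.size())
      warp_frag[(k + 1) % 2] = smem_kgroups[k + 1];
    acc += warp_frag[k % 2];                     // "warp_mma" on the current buffer
  }

  assert(acc == 1 + 2 + 3 + 4 + 5);
  return 0;
}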
+ + if (warp_mma_k == Base::kWarpGemmIterations1 - 1) { + + // Write fragments to shared memory + this->smem_iterator_B1_.store(transform_B1(tb_frag_B1)); + + __syncthreads(); + + ++this->smem_iterator_B1_; + + // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory + if (smem_write_stage_idx == 1) { + this->smem_iterator_B1_.add_tile_offset({-Base::kStages, 0}); + } + else { + this->warp_tile_iterator_B1_.add_tile_offset( + {-Base::kStages * Policy1::kPartitionsK * + Base::kWarpGemmIterations1, + 0}); + } + + smem_write_stage_idx ^= 1; + + } + + this->warp_tile_iterator_B1_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations1); + + // skip warp tile loading for the last kgroup + if(gemm_k_iterations_1 > 1 || warp_mma_k < Base::kWarpGemmIterations1 - 1) + warp_tile_iterator_A1_.load(warp_frag_A1[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_B1_.load(warp_frag_B1[(warp_mma_k + 1) % 2]); + + ++warp_tile_iterator_A1_; + ++this->warp_tile_iterator_B1_; + + if (warp_mma_k == 0) { + + iterator_B1.load(tb_frag_B1); + + ++iterator_B1; + + // Avoid reading out of bounds if this was the last loop iteration + iterator_B1.clear_mask(gemm_k_iterations_1 <= 2); + } + + warp_mma1(accum, warp_frag_A1[warp_mma_k % 2], + warp_frag_B1[warp_mma_k % 2], accum); + + } + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/threadblock/default_b2b_mma.h b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/threadblock/default_b2b_mma.h new file mode 100644 index 0000000000000000000000000000000000000000..cbbc24a8f34ad9cbe64d39205a918b2bd3bfd040 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/threadblock/default_b2b_mma.h @@ -0,0 +1,584 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/arch/arch.h" + +#include "cutlass/transform/threadblock/predicated_tile_iterator.h" +#include "cutlass/transform/threadblock/predicated_tile_iterator_2dthreadtile.h" +#include "cutlass/transform/threadblock/predicated_vector_access_iterator.h" +#include "cutlass/transform/threadblock/vector_iterator.h" +#include "cutlass/transform/warp/vector_fragment_iterator.h" + +#include "cutlass/gemm/threadblock/default_mma_core_sm70.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm75.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" +#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h" + +#include "threadblock/b2b_mma_pipelined.h" +#include "threadblock/b2b_mma_multistage.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Operator class tag + typename OperatorClass_, + /// Tag indicating architecture to tune for + typename ArchTag_, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape0_, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape1_, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape0_, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape1_, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape_, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Operation performed by GEMM + typename Operator, + /// Epilogue output operator + typename EpilogueOutputOp, + /// Store the accumulators in row major or column major. Row major is used + /// when output layout is interleaved. + bool AccumulatorsInRowMajor = false, + /// Staging the accumulators in shared memory. 
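+  /// When false, the first GEMM's accumulators stay in the register file and feed the second
+  /// GEMM through a warp-level fragment iterator; when true, they are written to shared
+  /// memory and re-read as the second GEMM's A operand (see default_b2b_mma_smem_accumulator.h).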
+ bool SmemAccumulator = false> +struct DefaultB2bMma; + +//////////////////////////////////////////////////////////////////////////////// +/// Specialization for row-major output with 2-stage pipeline +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape0, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape1, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape0, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape1, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator, + /// Epilogue output operator + typename EpilogueOutputOp> +struct DefaultB2bMma { + // Define the MmaCore components + using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape0, WarpShape0, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementAccumulator, layout::RowMajor, + arch::OpClassTensorOp, 2, Operator>; + using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape1, WarpShape1, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementAccumulator, layout::RowMajor, + arch::OpClassTensorOp, 2, Operator>; + + // Define iterators over tiles from the A operand + using IteratorA0 = + cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, + ElementA, LayoutA, 1, typename MmaCore0::IteratorThreadMapA, kAlignmentA>; + + // Define iterators over tiles from the B operand + using IteratorB0 = + cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, + ElementB, LayoutB, 0, typename MmaCore0::IteratorThreadMapB, kAlignmentB>; + + // Use fragment iterator for A operand + using AccumulatorLayout = cutlass::layout::ColumnMajor; + using FragmentIteratorA1 = + cutlass::gemm::warp::MmaTensorOpFragmentIterator< + cutlass::MatrixShape, //warp shape + cutlass::MatrixShape, //accumulator shape + MmaCore1::Shape::kK, //kBlocksColumn + ElementAccumulator, ElementA, AccumulatorLayout, InstructionShape, EpilogueOutputOp>; + + using ElementScaleBias = typename EpilogueOutputOp::ElementCompute; + using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter + static int const kElementsPerAccess = 2; + using IteratorAccumulatorScaleBias = + cutlass::transform::threadblock::VectorIterator< + cutlass::transform::threadblock::PredicatedVectorAccessIterator< + cutlass::MatrixShape, + cutlass::MatrixShape, + ElementScaleBias, LayoutScaleBias, kElementsPerAccess> + >; + + // Warp-level iterators to load scale and bias vectors + using FragmentIteratorA1ScaleBias = cutlass::transform::warp::VectorFragmentIterator< + MatrixShape<1, IteratorAccumulatorScaleBias::Fragment::kElements>, ElementScaleBias, + LayoutScaleBias, InstructionShape, kElementsPerAccess>; + + // Define iterators over tiles from the B operand + using IteratorB1 = + cutlass::transform::threadblock::PredicatedTileIterator< + 
cutlass::MatrixShape, + ElementB, LayoutB, 0, typename MmaCore1::IteratorThreadMapB, kAlignmentB>; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockB2bMma = cutlass::gemm::threadblock::B2bMmaPipelined< + typename MmaCore0::Shape, IteratorA0, typename MmaCore0::SmemIteratorA, + IteratorB0, typename MmaCore0::SmemIteratorB, + typename MmaCore1::Shape, FragmentIteratorA1, + IteratorAccumulatorScaleBias, FragmentIteratorA1ScaleBias, + IteratorB1, typename MmaCore1::SmemIteratorB, + ElementAccumulator, layout::RowMajor, + EpilogueOutputOp, + typename MmaCore0::MmaPolicy, typename MmaCore1::MmaPolicy>; + +}; + +//////////////////////////////////////////////////////////////////////////////// +/// Specialization for row-major output for multi-stage +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape0, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape1, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape0, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape1, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Number of stages used in the multistage mainloop + int Stages, + /// Operation performed by GEMM + typename Operator, + /// Epilogue output operator + typename EpilogueOutputOp> +struct DefaultB2bMma { + + static cutlass::arch::CacheOperation::Kind const CacheOpA = + ((sizeof_bits::value * kAlignmentA) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + static cutlass::arch::CacheOperation::Kind const CacheOpB = + ((sizeof_bits::value * kAlignmentB) == 128) + ? 
cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + + // Define the MmaCore components + using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape0, WarpShape0, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + Stages, Operator, false, CacheOpA, CacheOpB>; + using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape1, WarpShape1, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + Stages, Operator, false, CacheOpA, CacheOpB>; + + // Define iterators over tiles from the A operand + using ThreadMapA0 = typename MmaCore0::IteratorThreadMapA; + using AccessTypeA0 = cutlass::Array; + using IteratorA0 = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementA, LayoutA, 1, ThreadMapA0, AccessTypeA0>; + + // Define iterators over tiles from the B operand + using ThreadMapB0 = typename MmaCore0::IteratorThreadMapB; + using AccessTypeB0 = cutlass::Array; + using IteratorB0 = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementB, LayoutB, 0, ThreadMapB0, AccessTypeB0>; + + // Use fragment iterator for A operand + using AccumulatorLayout = cutlass::layout::ColumnMajor; + using FragmentIteratorA1 = + cutlass::gemm::warp::MmaTensorOpFragmentIterator< + cutlass::MatrixShape, //warp shape + cutlass::MatrixShape, //accumulator shape + MmaCore1::Shape::kK, //kBlocksColumn + ElementAccumulator, ElementA, AccumulatorLayout, InstructionShape, EpilogueOutputOp>; + + /// Define iterators over tiles from scale/bias vectors + using ElementScaleBias = typename EpilogueOutputOp::ElementCompute; + using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter + static int const kElementsPerAccess = 2; + using IteratorAccumulatorScaleBias = + cutlass::transform::threadblock::VectorIterator< + cutlass::transform::threadblock::PredicatedVectorAccessIterator< + cutlass::MatrixShape, + cutlass::MatrixShape, + ElementScaleBias, LayoutScaleBias, kElementsPerAccess> + >; + + // Warp-level iterators to load scale and bias vectors + using FragmentIteratorA1ScaleBias = cutlass::transform::warp::VectorFragmentIterator< + MatrixShape<1, IteratorAccumulatorScaleBias::Fragment::kElements>, ElementScaleBias, + LayoutScaleBias, InstructionShape, kElementsPerAccess>; + + + // Define iterators over tiles from the B operand + using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB; + using AccessTypeB1 = cutlass::Array; + using IteratorB1 = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementB, LayoutB, 0, ThreadMapB1, AccessTypeB1>; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockB2bMma = cutlass::gemm::threadblock::B2bMmaMultistage< + typename MmaCore0::Shape, IteratorA0, typename MmaCore0::SmemIteratorA, + MmaCore0::kCacheOpA, + IteratorB0, typename MmaCore0::SmemIteratorB, MmaCore0::kCacheOpB, + typename MmaCore1::Shape, FragmentIteratorA1, + IteratorAccumulatorScaleBias, FragmentIteratorA1ScaleBias, + IteratorB1, typename MmaCore1::SmemIteratorB, MmaCore1::kCacheOpB, + ElementAccumulator, layout::RowMajor, + EpilogueOutputOp, + typename MmaCore0::MmaPolicy, typename MmaCore1::MmaPolicy, Stages>; + +}; + + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization for 
column-major-interleaved output with 2-stage pipeline +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename OperatorClass, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape0, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape1, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape0, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape1, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator, + /// Epilogue output operator + typename EpilogueOutputOp, + /// Number of Interleaved K + int InterleavedK> +struct DefaultB2bMma, OperatorClass, arch::Sm75, + ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1, + InstructionShape, 2, Operator, EpilogueOutputOp, true> { + // Define the MmaCore components + using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape0, WarpShape0, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementAccumulator, + layout::ColumnMajorInterleaved, OperatorClass, 2, Operator, + true>; + using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape1, WarpShape1, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementAccumulator, + layout::ColumnMajorInterleaved, OperatorClass, 2, Operator, + true>; + + static_assert(kAlignmentA == 128 / sizeof_bits::value, + "Alignment must match thread data map's vector length"); + + static_assert(kAlignmentB ==128 / sizeof_bits::value, + "Alignment must match thread data map's vector length"); + + // Define iterators over tiles from the A operand + using IteratorA0 = cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, ElementA, + LayoutA, 1, typename MmaCore0::IteratorThreadMapA>; + + // Define iterators over tiles from the B operand + using IteratorB0 = cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, ElementB, + LayoutB, 0, typename MmaCore0::IteratorThreadMapB>; + + // Use fragment iterator for A1 operand + using AccumulatorLayout = cutlass::layout::RowMajor; //AccumulatorsInRowMajor = true + using FragmentIteratorA1 = + cutlass::gemm::warp::MmaTensorOpFragmentIterator< + cutlass::MatrixShape, //warp shape + cutlass::MatrixShape, //accumulator shape + MmaCore1::Shape::kK, //kBlocksColumn + ElementAccumulator, ElementA, AccumulatorLayout, + InstructionShape, EpilogueOutputOp>; + + using ElementScaleBias = typename EpilogueOutputOp::ElementCompute; + using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter + static int const kElementsPerAccess = 4; + using IteratorAccumulatorScaleBias = + cutlass::transform::threadblock::VectorIterator< + cutlass::transform::threadblock::PredicatedVectorAccessIterator< + cutlass::MatrixShape, + cutlass::MatrixShape, + ElementScaleBias, LayoutScaleBias, kElementsPerAccess> + >; + + // Warp-level iterators to load scale and bias vectors + using FragmentIteratorA1ScaleBias = 
cutlass::transform::warp::VectorFragmentIterator< + MatrixShape<1, IteratorAccumulatorScaleBias::Fragment::kElements>, ElementScaleBias, + LayoutScaleBias, InstructionShape, kElementsPerAccess>; + + // Define iterators over tiles from the B operand + using IteratorB1 = + cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, + ElementB, LayoutB, 0, typename MmaCore1::IteratorThreadMapB>; + + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockB2bMma = cutlass::gemm::threadblock::B2bMmaPipelined< + typename MmaCore0::Shape, IteratorA0, typename MmaCore0::SmemIteratorA, + IteratorB0, typename MmaCore0::SmemIteratorB, + typename MmaCore1::Shape, FragmentIteratorA1, + IteratorAccumulatorScaleBias, FragmentIteratorA1ScaleBias, + IteratorB1, typename MmaCore1::SmemIteratorB, + ElementAccumulator, layout::ColumnMajorInterleaved, + EpilogueOutputOp, + typename MmaCore0::MmaPolicy, typename MmaCore1::MmaPolicy>; +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization for column-major-interleaved output with multi-stage +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape0, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape1, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape0, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape1, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Number of stages used in the multistage mainloop + int Stages, + /// Operation performed by GEMM + typename Operator, + /// Epilogue output operator + typename EpilogueOutputOp, + /// Number of Interleaved K + int InterleavedK> +struct DefaultB2bMma, OperatorClass, ArchTag, + ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1, + InstructionShape, Stages, Operator, EpilogueOutputOp, true> { + // Define the MmaCore components + using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape0, WarpShape0, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementAccumulator, + layout::ColumnMajorInterleaved, OperatorClass, Stages, + Operator, true>; + using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape1, WarpShape1, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementAccumulator, + layout::ColumnMajorInterleaved, OperatorClass, Stages, + Operator, true>; + + // Define iterators over tiles from the A operand + using ThreadMapA0 = typename MmaCore0::IteratorThreadMapA; + using AccessTypeA = cutlass::Array; + using IteratorA0 = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementA, LayoutA, 1, ThreadMapA0, AccessTypeA>; + + // Define iterators over tiles from the B operand + using ThreadMapB0 = typename MmaCore0::IteratorThreadMapB; + using 
AccessTypeB = cutlass::Array; + using IteratorB0 = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementB, LayoutB, 0, ThreadMapB0, AccessTypeB>; + + // Use fragment iterator for A1 operand + using AccumulatorLayout = cutlass::layout::RowMajor; //AccumulatorsInRowMajor = true + using FragmentIteratorA1 = + cutlass::gemm::warp::MmaTensorOpFragmentIterator< + cutlass::MatrixShape, //warp shape + cutlass::MatrixShape, //accumulator shape + MmaCore1::Shape::kK, //kBlocksColumn + ElementAccumulator, ElementA, AccumulatorLayout, + InstructionShape, EpilogueOutputOp>; + + /// Define iterators over tiles from scale/bias vectors + using ElementScaleBias = typename EpilogueOutputOp::ElementCompute; + using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter + static int const kElementsPerAccess = 4; + using IteratorAccumulatorScaleBias = + cutlass::transform::threadblock::VectorIterator< + cutlass::transform::threadblock::PredicatedVectorAccessIterator< + cutlass::MatrixShape, + cutlass::MatrixShape, + ElementScaleBias, LayoutScaleBias, kElementsPerAccess> + >; + + // Warp-level iterators to load scale and bias vectors + using FragmentIteratorA1ScaleBias = cutlass::transform::warp::VectorFragmentIterator< + MatrixShape<1, IteratorAccumulatorScaleBias::Fragment::kElements>, ElementScaleBias, + LayoutScaleBias, InstructionShape, kElementsPerAccess>; + + // Define iterators over tiles from the B operand + using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB; + using IteratorB1 = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementB, LayoutB, 0, ThreadMapB1, AccessTypeB>; + + + + // Define the threadblock-scoped multistage matrix multiply + using ThreadblockB2bMma = cutlass::gemm::threadblock::B2bMmaMultistage< + typename MmaCore0::Shape, IteratorA0, typename MmaCore0::SmemIteratorA, + MmaCore0::kCacheOpA, + IteratorB0, typename MmaCore0::SmemIteratorB, MmaCore0::kCacheOpB, + typename MmaCore1::Shape, FragmentIteratorA1, + IteratorAccumulatorScaleBias, FragmentIteratorA1ScaleBias, + IteratorB1, typename MmaCore1::SmemIteratorB, MmaCore1::kCacheOpB, + ElementAccumulator, layout::ColumnMajorInterleaved, + EpilogueOutputOp, + typename MmaCore0::MmaPolicy, typename MmaCore1::MmaPolicy, Stages>; +}; + +//////////////////////////////////////////////////////////////////////////////// + + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/threadblock/default_b2b_mma_smem_accumulator.h b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/threadblock/default_b2b_mma_smem_accumulator.h new file mode 100644 index 0000000000000000000000000000000000000000..a848f5c43995682d3e21d70a5cdb12b3855ae4f9 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/threadblock/default_b2b_mma_smem_accumulator.h @@ -0,0 +1,605 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/arch/arch.h" + +#include "cutlass/transform/threadblock/predicated_tile_iterator.h" +#include "cutlass/transform/threadblock/predicated_tile_iterator_2dthreadtile.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm70.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm75.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" +#include "cutlass/gemm/warp/mma_tensor_op_tile_access_iterator.h" + +#include "threadblock/b2b_mma_pipelined_smem_accumulator.h" +#include "threadblock/b2b_mma_multistage_smem_accumulator.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// +/// Specialization for row-major output with 2-stage pipeline +/// Accumulator will be staged in shared memory. 
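+///
+/// The accumulator tile of the first GEMM is run through the scale/bias epilogue, stored to
+/// shared memory through SmemIteratorD0, and streamed back as the A operand of the second
+/// GEMM through WarpIteratorA1, so the intermediate result never travels through global memory.
+///
+/// Illustrative instantiation (a sketch only; the element types, tile shapes and epilogue
+/// functor below are assumptions chosen for illustration, not values mandated by this header):
+///
+///   using EpilogueOp0 = cutlass::epilogue::thread::LinearCombinationRelu<
+///       cutlass::half_t, 8, float, float>;
+///   using B2bMma = cutlass::gemm::threadblock::DefaultB2bMma<
+///       cutlass::half_t, cutlass::layout::RowMajor,    8,   // A: element, layout, alignment
+///       cutlass::half_t, cutlass::layout::ColumnMajor, 8,   // B: element, layout, alignment
+///       float, cutlass::layout::RowMajor,                   // accumulator element / C layout
+///       cutlass::arch::OpClassTensorOp, cutlass::arch::Sm75,
+///       cutlass::gemm::GemmShape<64, 64, 32>,                // threadblock tile, GEMM0
+///       cutlass::gemm::GemmShape<64, 128, 32>,               // threadblock tile, GEMM1
+///       cutlass::gemm::GemmShape<32, 64, 32>,                // warp tile, GEMM0
+///       cutlass::gemm::GemmShape<32, 128, 32>,               // warp tile, GEMM1
+///       cutlass::gemm::GemmShape<16, 8, 8>,                  // Tensor Core instruction shape
+///       2, cutlass::arch::OpMultiplyAdd, EpilogueOp0,
+///       false, /*SmemAccumulator=*/true>::ThreadblockB2bMma;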
+template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape0, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape1, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape0, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape1, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator, + /// Epilogue output operator + typename EpilogueOutputOp> +struct DefaultB2bMma { + // Define the MmaCore components + using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape0, WarpShape0, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementAccumulator, layout::RowMajor, + arch::OpClassTensorOp, 2, Operator>; + using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape1, WarpShape1, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementAccumulator, layout::RowMajor, + arch::OpClassTensorOp, 2, Operator>; + + // Define iterators over tiles from the A operand + using IteratorA0 = + cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, + ElementA, LayoutA, 1, typename MmaCore0::IteratorThreadMapA, kAlignmentA>; + + // Define iterators over tiles from the B operand + using IteratorB0 = + cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, + ElementB, LayoutB, 0, typename MmaCore0::IteratorThreadMapB, kAlignmentB>; + + // Define iterators over tiles from the B operand + using IteratorB1 = + cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, + ElementB, LayoutB, 0, typename MmaCore1::IteratorThreadMapB, kAlignmentB>; + + // Warp-level GEMM components + using WarpMmaTensorOp0 = typename MmaCore0::MmaTensorOp; + using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp; + + // Use fragment iterator for the accumulator + using SmemAccumulatorLayout = cutlass::layout::RowMajor; + using FragmentIteratorAccumulator = cutlass::epilogue::warp::FragmentIteratorTensorOp< + WarpShape0, InstructionShape, + ElementAccumulator, + typename WarpMmaTensorOp0::Policy::Operator::FragmentC, + SmemAccumulatorLayout + >; + + /// Define iterators over tiles from scale/bias vectors + using ElementScaleBias = typename EpilogueOutputOp::ElementCompute; + using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter + static int const kElementsPerAccess = 2; + using IteratorAccumulatorScaleBias = + cutlass::transform::threadblock::VectorIterator< + cutlass::transform::threadblock::PredicatedVectorAccessIterator< + cutlass::MatrixShape, + cutlass::MatrixShape, + ElementScaleBias, LayoutScaleBias, kElementsPerAccess> + >; + + // Store Accumulator tiles to Shared Memory + using SmemIteratorD0 = + cutlass::epilogue::warp::TileIteratorTensorOp< + WarpShape0, + InstructionShape, + typename EpilogueOutputOp::ElementOutput, + SmemAccumulatorLayout + >; + + static 
int const kThreadCount = 32; + // load warp tile from Shared Memory accumulator + using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileAccessIterator< + MatrixShape, cutlass::gemm::Operand::kA, + ElementA, SmemAccumulatorLayout, + MatrixShape, + WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount, true>; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockB2bMma = cutlass::gemm::threadblock::B2bMmaPipelinedSmemAccumulator< + typename MmaCore0::Shape, IteratorA0, typename MmaCore0::SmemIteratorA, + IteratorB0, typename MmaCore0::SmemIteratorB, + IteratorAccumulatorScaleBias, + FragmentIteratorAccumulator, SmemIteratorD0, + typename MmaCore1::Shape, WarpIteratorA1, + IteratorB1, typename MmaCore1::SmemIteratorB, + ElementAccumulator, layout::RowMajor, + EpilogueOutputOp, + typename MmaCore0::MmaPolicy, typename MmaCore1::MmaPolicy>; + +}; + +//////////////////////////////////////////////////////////////////////////////// +/// Specialization for row-major output for multi-stage +/// Accumulator will be staged in shared memory. +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape0, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape1, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape0, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape1, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Number of stages used in the multistage mainloop + int Stages, + /// Operation performed by GEMM + typename Operator, + /// Epilogue output operator + typename EpilogueOutputOp> +struct DefaultB2bMma { + + static cutlass::arch::CacheOperation::Kind const CacheOpA = + ((sizeof_bits::value * kAlignmentA) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + static cutlass::arch::CacheOperation::Kind const CacheOpB = + ((sizeof_bits::value * kAlignmentB) == 128) + ? 
cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + + // Define the MmaCore components + using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape0, WarpShape0, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + Stages, Operator, false, CacheOpA, CacheOpB>; + using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape1, WarpShape1, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + Stages, Operator, false, CacheOpA, CacheOpB>; + + // Define iterators over tiles from the A operand + using ThreadMapA0 = typename MmaCore0::IteratorThreadMapA; + using AccessTypeA0 = cutlass::Array; + using IteratorA0 = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementA, LayoutA, 1, ThreadMapA0, AccessTypeA0>; + + // Define iterators over tiles from the B operand + using ThreadMapB0 = typename MmaCore0::IteratorThreadMapB; + using AccessTypeB0 = cutlass::Array; + using IteratorB0 = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementB, LayoutB, 0, ThreadMapB0, AccessTypeB0>; + + // Define iterators over tiles from the B operand + using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB; + using AccessTypeB1 = cutlass::Array; + using IteratorB1 = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementB, LayoutB, 0, ThreadMapB1, AccessTypeB1>; + + // Warp-level GEMM components + using WarpMmaTensorOp0 = typename MmaCore0::MmaTensorOp; + using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp; + + // Use fragment iterator for the accumulator + using SmemAccumulatorLayout = cutlass::layout::RowMajor; + using FragmentIteratorAccumulator = cutlass::epilogue::warp::FragmentIteratorTensorOp< + WarpShape0, InstructionShape, + ElementAccumulator, + typename WarpMmaTensorOp0::Policy::Operator::FragmentC, + SmemAccumulatorLayout + >; + + /// Define iterators over tiles from scale/bias vectors + using ElementScaleBias = typename EpilogueOutputOp::ElementCompute; + using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter + static int const kElementsPerAccess = 2; + using IteratorAccumulatorScaleBias = + cutlass::transform::threadblock::VectorIterator< + cutlass::transform::threadblock::PredicatedVectorAccessIterator< + cutlass::MatrixShape, + cutlass::MatrixShape, + ElementScaleBias, LayoutScaleBias, kElementsPerAccess> + >; + + + // Store Accumulator tiles to Shared Memory + using SmemIteratorD0 = + cutlass::epilogue::warp::TileIteratorTensorOp< + WarpShape0, + InstructionShape, + typename EpilogueOutputOp::ElementOutput, + SmemAccumulatorLayout + >; + + static int const kThreadCount = 32; + // load warp tile from Shared Memory accumulator + using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileAccessIterator< + MatrixShape, cutlass::gemm::Operand::kA, + ElementA, SmemAccumulatorLayout, + MatrixShape, + WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount, true>; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockB2bMma = cutlass::gemm::threadblock::B2bMmaMultistageSmemAccumulator< + typename MmaCore0::Shape, IteratorA0, typename MmaCore0::SmemIteratorA, + MmaCore0::kCacheOpA, + IteratorB0, typename MmaCore0::SmemIteratorB, MmaCore0::kCacheOpB, + IteratorAccumulatorScaleBias, + 
FragmentIteratorAccumulator, SmemIteratorD0, + typename MmaCore1::Shape, WarpIteratorA1, + IteratorB1, typename MmaCore1::SmemIteratorB, MmaCore1::kCacheOpB, + ElementAccumulator, layout::RowMajor, + EpilogueOutputOp, + typename MmaCore0::MmaPolicy, typename MmaCore1::MmaPolicy, Stages>; + +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization for column-major-interleaved output with 2-stage pipeline +/// Accumulator will be staged in shared memory. +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename OperatorClass, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape0, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape1, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape0, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape1, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Operation performed by GEMM + typename Operator, + /// Epilogue output operator + typename EpilogueOutputOp, + /// Number of Interleaved K + int InterleavedK> +struct DefaultB2bMma, OperatorClass, arch::Sm75, + ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1, + InstructionShape, 2, Operator, EpilogueOutputOp, true, true> { + // Define the MmaCore components + using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape0, WarpShape0, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementAccumulator, + layout::ColumnMajorInterleaved, OperatorClass, 2, Operator, + true>; + using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape1, WarpShape1, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementAccumulator, + layout::ColumnMajorInterleaved, OperatorClass, 2, Operator, + true>; + + static_assert(kAlignmentA == 128 / sizeof_bits::value, + "Alignment must match thread data map's vector length"); + + static_assert(kAlignmentB ==128 / sizeof_bits::value, + "Alignment must match thread data map's vector length"); + + // Define iterators over tiles from the A operand + using IteratorA0 = cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, ElementA, + LayoutA, 1, typename MmaCore0::IteratorThreadMapA>; + + // Define iterators over tiles from the B operand + using IteratorB0 = cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, ElementB, + LayoutB, 0, typename MmaCore0::IteratorThreadMapB>; + + // Define iterators over tiles from the B operand + using IteratorB1 = + cutlass::transform::threadblock::PredicatedTileIterator< + cutlass::MatrixShape, + ElementB, LayoutB, 0, typename MmaCore1::IteratorThreadMapB>; + + // Warp-level GEMM components + using WarpMmaTensorOp0 = typename MmaCore0::MmaTensorOp; + using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp; + + // Use fragment iterator for the accumulator + using SmemAccumulatorLayout = cutlass::layout::ColumnMajorInterleaved<16>; + using 
FragmentIteratorAccumulator = cutlass::epilogue::warp::FragmentIteratorTensorOp< + WarpShape0, InstructionShape, + ElementAccumulator, + typename WarpMmaTensorOp0::Policy::Operator::FragmentC, + SmemAccumulatorLayout + >; + + /// Define iterators over tiles from scale/bias vectors + using ElementScaleBias = typename EpilogueOutputOp::ElementCompute; + using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter + static int const kElementsPerAccess = 4; //For interleaved layout + using IteratorAccumulatorScaleBias = + cutlass::transform::threadblock::VectorIterator< + cutlass::transform::threadblock::PredicatedVectorAccessIterator< + cutlass::MatrixShape, + cutlass::MatrixShape, + ElementScaleBias, LayoutScaleBias, kElementsPerAccess> + >; + + // Store Accumulator tiles to Shared Memory + using SmemIteratorD0 = + cutlass::epilogue::warp::TileIteratorTensorOp< + WarpShape0, + InstructionShape, + typename EpilogueOutputOp::ElementOutput, + SmemAccumulatorLayout + >; + + static int const kThreadCount = 32; + // load warp tile from Shared Memory accumulator + using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileAccessIterator< + MatrixShape, cutlass::gemm::Operand::kA, + ElementA, SmemAccumulatorLayout, + MatrixShape, + WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount, true>; + + // Define the threadblock-scoped pipelined matrix multiply + using ThreadblockB2bMma = cutlass::gemm::threadblock::B2bMmaPipelinedSmemAccumulator< + typename MmaCore0::Shape, IteratorA0, typename MmaCore0::SmemIteratorA, + IteratorB0, typename MmaCore0::SmemIteratorB, + IteratorAccumulatorScaleBias, + FragmentIteratorAccumulator, SmemIteratorD0, + typename MmaCore1::Shape, WarpIteratorA1, + IteratorB1, typename MmaCore1::SmemIteratorB, + ElementAccumulator, layout::ColumnMajorInterleaved, + EpilogueOutputOp, + typename MmaCore0::MmaPolicy, typename MmaCore1::MmaPolicy>; + +}; + +//////////////////////////////////////////////////////////////////////////////// +/// Specialization for column-major-interleaved output with multi-stage +/// Accumulator will be staged in shared memory. 
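+///
+/// Here the first GEMM's accumulator tile is staged in a ColumnMajorInterleaved<16> shared
+/// memory layout, and the per-channel scale/bias vectors are fetched four elements per access
+/// to match the interleaved output. The mainloop is the multi-stage
+/// B2bMmaMultistageSmemAccumulator pipeline with Stages shared-memory buffers.
+///
+/// As in the pipelined mainloop, the second GEMM consumes the first GEMM's N tile as its K
+/// dimension, i.e. it runs
+///
+///   gemm_k_iterations_1 = ThreadblockShape0::kN / ThreadblockShape1::kK
+///
+/// iterations, so ThreadblockShape1::kK is expected to divide ThreadblockShape0::kN evenly.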
+template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Tag indicating architecture to tune for + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape0, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape1, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape0, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape1, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Number of stages used in the multistage mainloop + int Stages, + /// Operation performed by GEMM + typename Operator, + /// Epilogue output operator + typename EpilogueOutputOp, + /// Number of Interleaved K + int InterleavedK> +struct DefaultB2bMma, OperatorClass, ArchTag, + ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1, + InstructionShape, Stages, Operator, EpilogueOutputOp, true, true> { + // Define the MmaCore components + using MmaCore0 = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape0, WarpShape0, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementAccumulator, + layout::ColumnMajorInterleaved, OperatorClass, Stages, + Operator, true>; + using MmaCore1 = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape1, WarpShape1, InstructionShape, ElementA, LayoutA, + ElementB, LayoutB, ElementAccumulator, + layout::ColumnMajorInterleaved, OperatorClass, Stages, + Operator, true>; + + // Define iterators over tiles from the A operand + using ThreadMapA0 = typename MmaCore0::IteratorThreadMapA; + using AccessTypeA = cutlass::Array; + using IteratorA0 = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementA, LayoutA, 1, ThreadMapA0, AccessTypeA>; + + // Define iterators over tiles from the B operand + using ThreadMapB0 = typename MmaCore0::IteratorThreadMapB; + using AccessTypeB = cutlass::Array; + using IteratorB0 = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementB, LayoutB, 0, ThreadMapB0, AccessTypeB>; + + // Define iterators over tiles from the B operand + using ThreadMapB1 = typename MmaCore1::IteratorThreadMapB; + using IteratorB1 = + cutlass::transform::threadblock::PredicatedTileAccessIterator< + cutlass::MatrixShape, + ElementB, LayoutB, 0, ThreadMapB1, AccessTypeB>; + + // Warp-level GEMM components + using WarpMmaTensorOp0 = typename MmaCore0::MmaTensorOp; + using WarpMmaTensorOp1 = typename MmaCore1::MmaTensorOp; + using MmaPolicy0 = typename MmaCore0::MmaPolicy; + using MmaPolicy1 = typename MmaCore1::MmaPolicy; + + // Use fragment iterator for the accumulator + using SmemAccumulatorLayout = cutlass::layout::ColumnMajorInterleaved<16>; + using FragmentIteratorAccumulator = cutlass::epilogue::warp::FragmentIteratorTensorOp< + WarpShape0, InstructionShape, + ElementAccumulator, + typename WarpMmaTensorOp0::Policy::Operator::FragmentC, + SmemAccumulatorLayout + >; + + /// Define iterators over tiles from scale/bias 
vectors + using ElementScaleBias = typename EpilogueOutputOp::ElementCompute; + using LayoutScaleBias = layout::RowMajor; //vector layout doesn't really matter + static int const kElementsPerAccess = 4; + using IteratorAccumulatorScaleBias = + cutlass::transform::threadblock::VectorIterator< + cutlass::transform::threadblock::PredicatedVectorAccessIterator< + cutlass::MatrixShape, + cutlass::MatrixShape, + ElementScaleBias, LayoutScaleBias, kElementsPerAccess> + >; + + // Store Accumulator tiles to Shared Memory + using SmemIteratorD0 = + cutlass::epilogue::warp::TileIteratorTensorOp< + WarpShape0, + InstructionShape, + typename EpilogueOutputOp::ElementOutput, + SmemAccumulatorLayout + >; + + static int const kThreadCount = 32; + // load warp tile from Shared Memory accumulator + using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileAccessIterator< + MatrixShape, cutlass::gemm::Operand::kA, + ElementA, SmemAccumulatorLayout, + MatrixShape, + WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount, true >; + + + // Define the threadblock-scoped multistage matrix multiply + using ThreadblockB2bMma = cutlass::gemm::threadblock::B2bMmaMultistageSmemAccumulator< + typename MmaCore0::Shape, IteratorA0, typename MmaCore0::SmemIteratorA, + MmaCore0::kCacheOpA, + IteratorB0, typename MmaCore0::SmemIteratorB, MmaCore0::kCacheOpB, + IteratorAccumulatorScaleBias, + FragmentIteratorAccumulator, SmemIteratorD0, + typename MmaCore1::Shape, WarpIteratorA1, + IteratorB1, typename MmaCore1::SmemIteratorB, MmaCore1::kCacheOpB, + ElementAccumulator, layout::ColumnMajorInterleaved, + EpilogueOutputOp, + typename MmaCore0::MmaPolicy, typename MmaCore1::MmaPolicy, Stages>; +}; + +//////////////////////////////////////////////////////////////////////////////// + + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/threadblock/grouped_threadblock_swizzle.h b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/threadblock/grouped_threadblock_swizzle.h new file mode 100644 index 0000000000000000000000000000000000000000..e033138057e6ddc5b316822e18a8f473c2f59913 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/13_two_tensor_op_fusion/threadblock/grouped_threadblock_swizzle.h @@ -0,0 +1,125 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/*! \file
+    \brief Implements several threadblock-swizzling functions for grouped kernels
+*/
+
+#pragma once
+
+#include "cutlass/cutlass.h"
+#include "cutlass/gemm/kernel/grouped_problem_visitor.h"
+#include "cutlass/gemm/kernel/gemm_grouped_problem_visitor.h"
+#include "kernel/b2b_gemm_grouped_problem_visitor.h"
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+namespace cutlass {
+namespace gemm {
+namespace threadblock {
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+namespace detail {
+
+struct GroupedThreadblockSwizzleBase {};
+
+/// Helper for determining if a swizzling function is specialized for grouped operation
+template <typename Swizzle>
+struct IsGroupedSwizzle {
+  static bool const value = cutlass::platform::is_base_of<GroupedThreadblockSwizzleBase, Swizzle>::value;
+};
+
+} // namespace detail
+
+/// Swizzling function for grouped kernels
+template <typename ProblemVisitor_>
+struct GroupedThreadblockSwizzle : detail::GroupedThreadblockSwizzleBase {
+
+  using ProblemVisitor = ProblemVisitor_;
+  ProblemVisitor problem_visitor;
+
+  CUTLASS_HOST_DEVICE
+  GroupedThreadblockSwizzle(typename ProblemVisitor::Params& params,
+                            typename ProblemVisitor::SharedStorage& shared_storage,
+                            int block_idx) : problem_visitor(params, shared_storage, block_idx) {}
+
+  /// Obtains the threadblock offset (in units of threadblock-scoped tiles)
+  CUTLASS_DEVICE
+  GemmCoord get_tile_offset(int /*log_tile*/) const {
+    GemmCoord problem_size = problem_visitor.problem_size();
+    int32_t threadblock_idx = int32_t(problem_visitor.threadblock_idx());
+    GemmCoord grid_shape = problem_visitor.grid_shape(problem_size);
+
+    return GemmCoord(int(threadblock_idx / grid_shape.n()),
+                     int(threadblock_idx % grid_shape.n()),
+                     0);
+  }
+
+  /// Dummy method to satisfy API for threadblock swizzling functions
+  CUTLASS_HOST_DEVICE
+  static int get_log_tile(GemmCoord /*tiled_shape*/) {
+    return 0;
+  }
+};
+
+template <
+  typename ThreadblockShape,
+  typename LayoutC,
+  cutlass::gemm::kernel::GroupScheduleMode GroupScheduleMode_ = cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly,
+  int PrefetchTileCount = 128,
+  int ThreadCount = PrefetchTileCount>
+struct B2bGemmGroupedThreadblockSwizzle : GroupedThreadblockSwizzle<
+                                            cutlass::gemm::kernel::B2bGemmGroupedProblemVisitor<
+                                              ThreadblockShape,
+                                              GroupScheduleMode_,
+                                              PrefetchTileCount,
+                                              ThreadCount,
+                                              platform::is_same<LayoutC, cutlass::layout::ColumnMajor>::value
+                                            >
+                                          > {
+  using Base = GroupedThreadblockSwizzle<cutlass::gemm::kernel::B2bGemmGroupedProblemVisitor<
+                                           ThreadblockShape,
+                                           GroupScheduleMode_,
+                                           PrefetchTileCount,
+                                           ThreadCount,
+                                           platform::is_same<LayoutC, cutlass::layout::ColumnMajor>::value>>;
+
+  CUTLASS_HOST_DEVICE
+  B2bGemmGroupedThreadblockSwizzle(typename Base::ProblemVisitor::Params& params,
+                                   typename Base::ProblemVisitor::SharedStorage& shared_storage,
+ int block_idx) : Base(params, shared_storage, block_idx) {} +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/35_gemm_softmax/gemm_with_epilogue_visitor.h b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/35_gemm_softmax/gemm_with_epilogue_visitor.h new file mode 100644 index 0000000000000000000000000000000000000000..fc8f96c1fd8b0b485d48e2873afb52a0f8c2516b --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/35_gemm_softmax/gemm_with_epilogue_visitor.h @@ -0,0 +1,536 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief GEMM kernel to support the epilogue visitor model + for customized softmax partial reduction epilogue fusion. + + This source file will likely be moved to `include/cutlass/gemm/kernel/` in the future once + its usage has been stabilized. For now, it is included in this example to demonstrate + some basic output fusion options. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/complex.h" +#include "cutlass/semaphore.h" + +#include "cutlass/trace.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Mma_, ///! 
Threadblock-scoped matrix multiply-accumulate + typename Epilogue_, ///! Epilogue + typename ThreadblockSwizzle_ ///! Threadblock swizzling function +> +struct GemmWithEpilogueVisitor { +public: + + using Mma = Mma_; + using Epilogue = Epilogue_; + using EpilogueVisitor = typename Epilogue::Visitor; + using ThreadblockSwizzle = ThreadblockSwizzle_; + + using ElementA = typename Mma::IteratorA::Element; + using LayoutA = typename Mma::IteratorA::Layout; + using TensorRefA = TensorRef; + + using ElementB = typename Mma::IteratorB::Element; + using LayoutB = typename Mma::IteratorB::Layout; + using TensorRefB = TensorRef; + + using ElementC = typename EpilogueVisitor::ElementOutput; + using LayoutC = typename Epilogue::Layout; + using TensorRefC = TensorRef; + + static ComplexTransform const kTransformA = Mma::kTransformA; + static ComplexTransform const kTransformB = Mma::kTransformB; + using Operator = typename Mma::Operator; + + using OperatorClass = typename Mma::Operator::OperatorClass; + using ThreadblockShape = typename Mma::Shape; + using WarpShape = typename Mma::Operator::Shape; + using InstructionShape = typename Mma::Policy::Operator::InstructionShape; + using ArchTag = typename Mma::ArchTag; + + using ElementNorm = typename EpilogueVisitor::ElementNorm; + using ElementSum = typename EpilogueVisitor::ElementSum; + + static int const kStages = Mma::kStages; + static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; + static int const kAlignmentC = EpilogueVisitor::kElementsPerAccess; + + /// Warp count (concept: GemmShape) + using WarpCount = typename Mma::WarpCount; + static int const kThreadCount = 32 * WarpCount::kCount; + + /// Split-K preserves splits that are 128b aligned + static int const kSplitKAlignment = const_max( + 128 / sizeof_bits::value, + 128 / sizeof_bits::value + ); + + // + // Structures + // + + /// Argument structure + struct Arguments { + + // + // Data members + // + + GemmUniversalMode mode; + GemmCoord problem_size; + int batch_count; + + TensorRefA ref_A; + TensorRefB ref_B; + TensorRefC ref_C; + TensorRefC ref_D; + + ElementNorm *ptr_Max; + ElementSum *ptr_Sum; + + int64_t batch_stride_A; + int64_t batch_stride_B; + + typename EpilogueVisitor::Arguments epilogue_visitor; + + // + // Methods + // + + Arguments(): + mode(GemmUniversalMode::kGemm), + batch_count(1) + { } + + + /// constructs an arguments structure + Arguments( + GemmUniversalMode mode_, + GemmCoord problem_size_, + int batch_count_, + TensorRefA ref_A_, + TensorRefB ref_B_, + TensorRefC ref_C_, + TensorRefC ref_D_, + ElementNorm *ptr_Max_, + ElementSum *ptr_Sum_, + int64_t batch_stride_A_, + int64_t batch_stride_B_, + typename EpilogueVisitor::Arguments epilogue_visitor_ + ): + mode(mode_), + problem_size(problem_size_), + batch_count(batch_count_), + ref_A(ref_A_), + ref_B(ref_B_), + ref_C(ref_C_), + ref_D(ref_D_), + ptr_Max(ptr_Max_), + ptr_Sum(ptr_Sum_), + batch_stride_A(batch_stride_A_), + batch_stride_B(batch_stride_B_), + epilogue_visitor(epilogue_visitor_) + { + + } + }; + + // + // Structure for precomputing values in host memory and passing to kernels + // + + /// Parameters structure + struct Params { + + cutlass::gemm::GemmCoord problem_size; + cutlass::gemm::GemmCoord grid_tiled_shape; + int swizzle_log_tile; + + typename Mma::IteratorA::Params params_A; + typename Mma::IteratorB::Params params_B; + typename EpilogueVisitor::OutputTileIterator::Params params_C; + typename 
EpilogueVisitor::OutputTileIterator::Params params_D; + + GemmUniversalMode mode; + int batch_count; + int gemm_k_size; + + void * ptr_A; + void * ptr_B; + ElementC * ptr_C; + ElementC * ptr_D; + + ElementNorm * ptr_Max; + ElementSum * ptr_Sum; + + int64_t batch_stride_A; + int64_t batch_stride_B; + + typename EpilogueVisitor::Params epilogue_visitor; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params(): + swizzle_log_tile(0), + params_A(0), + params_B(0), + params_C(0), + params_D(0), + batch_count(0), + gemm_k_size(0), + mode(cutlass::gemm::GemmUniversalMode::kGemm), + ptr_A(nullptr), + ptr_B(nullptr), + ptr_C(nullptr), + ptr_D(nullptr), + ptr_Max(nullptr), + ptr_Sum(nullptr), + batch_stride_A(0), + batch_stride_B(0) + { } + + + Params( + Arguments const &args + ): + problem_size(args.problem_size), + swizzle_log_tile(0), + params_A(args.ref_A.layout()), + params_B(args.ref_B.layout()), + params_C(args.ref_C.layout()), + params_D(args.ref_D.layout()), + mode(args.mode), + batch_count(args.batch_count), + gemm_k_size(args.problem_size.k()), + ptr_A(args.ref_A.data()), + ptr_B(args.ref_B.data()), + ptr_C(args.ref_C.data()), + ptr_D(args.ref_D.data()), + ptr_Max(args.ptr_Max), + ptr_Sum(args.ptr_Sum), + batch_stride_A(args.batch_stride_A), + batch_stride_B(args.batch_stride_B), + epilogue_visitor(args.epilogue_visitor) + { + + ThreadblockSwizzle threadblock_swizzle; + + grid_tiled_shape = threadblock_swizzle.get_tiled_shape( + args.problem_size, + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, + args.batch_count); + + if (args.mode == GemmUniversalMode::kGemm || args.mode == GemmUniversalMode::kGemmSplitKParallel) { + + int const kAlignK = const_max(const_max(128 / sizeof_bits::value, 128 / sizeof_bits::value), 1); + + gemm_k_size = round_up(ceil_div(args.problem_size.k(), args.batch_count), kAlignK); + + if (gemm_k_size) { + grid_tiled_shape.k() = ceil_div(args.problem_size.k(), gemm_k_size); + } + } + + swizzle_log_tile = threadblock_swizzle.get_log_tile(grid_tiled_shape); + } + }; + + /// Shared memory storage structure + union SharedStorage { + + typename Mma::SharedStorage main_loop; + + struct { + typename Epilogue::SharedStorage epilogue; + typename EpilogueVisitor::SharedStorage visitor; + } epilogue; + }; + +public: + + // + // Methods + // + + CUTLASS_DEVICE + GemmWithEpilogueVisitor() { } + + /// Determines whether kernel satisfies alignment + static Status can_implement( + cutlass::gemm::GemmCoord const & problem_size) { + + CUTLASS_TRACE_HOST("GemmWithEpilogueVisitor::can_implement()"); + + static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; + static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; + + bool isAMisaligned = false; + bool isBMisaligned = false; + bool isCMisaligned = false; + + if (platform::is_same::value) { + isAMisaligned = problem_size.k() % kAlignmentA; + } else if (platform::is_same::value) { + isAMisaligned = problem_size.m() % kAlignmentA; + } else if (platform::is_same>::value + || platform::is_same>::value) { + isAMisaligned = problem_size.k() % kAlignmentA; + } + + if (platform::is_same::value) { + isBMisaligned = problem_size.n() % kAlignmentB; + } else if (platform::is_same::value) { + isBMisaligned = problem_size.k() % kAlignmentB; + } else if (platform::is_same>::value + || platform::is_same>::value) { + isBMisaligned = problem_size.k() % kAlignmentB; + } + + if (platform::is_same::value) { + isCMisaligned = 
problem_size.n() % kAlignmentC; + } else if (platform::is_same::value) { + isCMisaligned = problem_size.m() % kAlignmentC; + } else if (platform::is_same>::value + || platform::is_same>::value) { + isCMisaligned = problem_size.n() % kAlignmentC; + } + + if (isAMisaligned) { + CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for A operand"); + return Status::kErrorMisalignedOperand; + } + + if (isBMisaligned) { + CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for B operand"); + return Status::kErrorMisalignedOperand; + } + + if (isCMisaligned) { + CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for C operand"); + return Status::kErrorMisalignedOperand; + } + + CUTLASS_TRACE_HOST(" returning kSuccess"); + + return Status::kSuccess; + } + + static Status can_implement(Arguments const &args) { + return can_implement(args.problem_size); + } + + #define SPLIT_K_ENABLED 1 + + /// Executes one GEMM + CUTLASS_DEVICE + void operator()(Params const ¶ms, SharedStorage &shared_storage) { + + // Compute threadblock location + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // Early exit if CTA is out of range + if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || + params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { + + return; + } + + int offset_k = 0; + int problem_size_k = params.problem_size.k(); + + ElementA *ptr_A = static_cast(params.ptr_A); + ElementB *ptr_B = static_cast(params.ptr_B); + + + #if SPLIT_K_ENABLED + // + // Fetch pointers based on mode. + // + if (params.mode == GemmUniversalMode::kGemm || + params.mode == GemmUniversalMode::kGemmSplitKParallel) { + + if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) { + + problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size; + } + + offset_k = threadblock_tile_offset.k() * params.gemm_k_size; + } + else if (params.mode == GemmUniversalMode::kBatched) { + ptr_A += threadblock_tile_offset.k() * params.batch_stride_A; + ptr_B += threadblock_tile_offset.k() * params.batch_stride_B; + } + else if (params.mode == GemmUniversalMode::kArray) { + ptr_A = static_cast(params.ptr_A)[threadblock_tile_offset.k()]; + ptr_B = static_cast(params.ptr_B)[threadblock_tile_offset.k()]; + } + #endif + + // Compute initial location in logical coordinates + cutlass::MatrixCoord tb_offset_A{ + threadblock_tile_offset.m() * Mma::Shape::kM, + offset_k, + }; + + cutlass::MatrixCoord tb_offset_B{ + offset_k, + threadblock_tile_offset.n() * Mma::Shape::kN + }; + + // Compute position within threadblock + int thread_idx = threadIdx.x; + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A( + params.params_A, + ptr_A, + {params.problem_size.m(), problem_size_k}, + thread_idx, + tb_offset_A); + + typename Mma::IteratorB iterator_B( + params.params_B, + ptr_B, + {problem_size_k, params.problem_size.n()}, + thread_idx, + tb_offset_B); + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. 
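+    // (threadIdx.x / 32 already evaluates to the same value in every lane, but the compiler
+    // cannot prove that on its own; reading the result back through __shfl_sync from lane 0
+    // makes warp_idx warp-uniform by construction, so addressing that depends on it compiles
+    // to uniform rather than per-lane code.)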
+ int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + + int lane_idx = threadIdx.x % 32; + + // + // Main loop + // + + // Construct thread-scoped matrix multiply + Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); + + typename Mma::FragmentC accumulators; + + accumulators.clear(); + + // Compute threadblock-scoped matrix multiply-add + int gemm_k_iterations = (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + mma( + gemm_k_iterations, + accumulators, + iterator_A, + iterator_B, + accumulators); + + // + // Masked tile iterators constructed from members + // + + threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + //assume identity swizzle + MatrixCoord threadblock_offset( + threadblock_tile_offset.m() * Mma::Shape::kM, + threadblock_tile_offset.n() * Mma::Shape::kN + ); + + int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); + + // + // Construct the epilogue visitor + // + + EpilogueVisitor epilogue_visitor( + params.epilogue_visitor, + shared_storage.epilogue.visitor, + params.problem_size.mn(), + thread_idx, + warp_idx, + lane_idx, + params.params_C, + params.params_D, + params.ptr_C, + params.ptr_D, + params.ptr_Max, + params.ptr_Sum, + threadblock_offset, + blockIdx.y *params.problem_size.m() ); + + if (params.mode == GemmUniversalMode::kGemm) { + // Indicate which position in a serial reduction the output operator is currently updating + epilogue_visitor.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); + } + else if (params.mode == GemmUniversalMode::kBatched || params.mode == GemmUniversalMode::kArray) { + epilogue_visitor.set_batch_index(threadblock_tile_offset.k()); + } + + // Construct the epilogue + Epilogue epilogue( + shared_storage.epilogue.epilogue, + thread_idx, + warp_idx, + lane_idx); + + // Execute the epilogue operator to update the destination tensor. + epilogue(epilogue_visitor, accumulators); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/35_gemm_softmax/gemm_with_softmax.h b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/35_gemm_softmax/gemm_with_softmax.h new file mode 100644 index 0000000000000000000000000000000000000000..31b2b76940cd06b6ec54a647acc7943064cefe4b --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/35_gemm_softmax/gemm_with_softmax.h @@ -0,0 +1,663 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/** + +*/ + +#pragma once + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#include +#include +#include +#include + +#include "cutlass/cutlass.h" +#include "cutlass/arch/memory.h" +#include "cutlass/arch/memory_sm75.h" + +#include "cutlass/gemm/kernel/default_gemm.h" +#include "cutlass/gemm/kernel/default_gemm_complex.h" +#include "cutlass/gemm/device/default_gemm_configuration.h" +#include "cutlass/epilogue/threadblock/epilogue_visitor_with_softmax.h" +#include "cutlass/epilogue/threadblock/epilogue_with_visitor.h" +#include "cutlass/reduction/kernel/reduce_softmax_final.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#include "gemm_with_epilogue_visitor.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Kernel computes partial reduction +// +// +// 2. 
Sum[m, n'] = sum_n(exp(D[m, n] - N[m, 0])) +// +template < + typename ElementD_, + typename ElementNorm_, + typename ElementSum_, + typename ElementSoft_, + typename ElementSoftmaxCompute_, + int Alignment, + typename ApplyShape_ = MatrixShape<1, 1024> +> +class ApplySoftmax { +public: + + using ElementD = ElementD_; + using ElementNorm = ElementNorm_; + using ElementSum = ElementSum_; + using ElementSoft = ElementSoft_; + using ElementSoftmaxCompute = ElementSoftmaxCompute_; + + static int const kAlignment = Alignment; + using ApplyShape = ApplyShape_; + + using Layout = cutlass::layout::RowMajor; + + using TensorRefD = TensorRef; + using TensorRefN = TensorRef; + using TensorRefSum = TensorRef; + using TensorRefSoft = TensorRef; + + using FragmentSoftmax = Array; + + // + // Arguments + // + + struct Arguments { + + MatrixCoord extent; ///< Extent of D and Softmax matrices + int batch_count; ///< Batch count + TensorRefD ref_D; ///< D matrix computed by GEMM+Max (input) + TensorRefN ref_N; ///< Norm tensor (input) + TensorRefSum ref_S; ///< Sum tensor (input) + TensorRefSoft ref_Soft; ///< Softmax tensor (output) + int64_t batch_stride_D; ///< Batch stride for D tensor + int64_t batch_stride_N; ///< Batch stride for N tensor + int64_t batch_stride_S; ///< Batch stride for S tensor + int64_t batch_stride_Soft; ///< Batch stride for softmax tensor + + // + // Methods + // + Arguments(): + batch_count(1), + batch_stride_D(0), + batch_stride_N(0), + batch_stride_S(0), + batch_stride_Soft(0) + { } + + Arguments( + MatrixCoord extent_, ///< Extent of D and Softmax matrices + int batch_count_, ///< Batch count + TensorRefD ref_D_, ///< D matrix computed by GEMM+PartialReduce + TensorRefN ref_N_, ///< Output parameter for N + TensorRefSum ref_S_, ///< Output parameter for N + TensorRefSoft ref_Soft_, ///< Softmax + int64_t batch_stride_D_ = 0, + int64_t batch_stride_N_ = 0, + int64_t batch_stride_S_ = 0, + int64_t batch_stride_Soft_ = 0 + ): + extent(extent_), + batch_count(batch_count_), + ref_D(ref_D_), + ref_N(ref_N_), + ref_S(ref_S_), + ref_Soft(ref_Soft_), + batch_stride_D(batch_stride_D_), + batch_stride_N(batch_stride_N_), + batch_stride_S(batch_stride_S_), + batch_stride_Soft(batch_stride_Soft_) + { + + } + }; + + // + // Params struct + // + + struct Params { + Arguments args; + + // + // Methods + // + Params() { } + + Params(Arguments const &args_): args(args_) { } + }; + + // + // SharedStorage + // + + struct SharedStorage { + + }; + +private: + +public: + + CUTLASS_DEVICE + ApplySoftmax() { } + + CUTLASS_DEVICE + void operator()(Params const ¶ms, SharedStorage &shared_storage) { + apply(params, shared_storage); + } + +private: + + + /// Compute Softmax + CUTLASS_DEVICE + void apply(Params const ¶ms, SharedStorage &shared_storage) { + + using AccessTypeD = AlignedArray; + + int block_batch = blockIdx.z; + int block_m = blockIdx.x * ApplyShape::kRow; + int block_n = 0; + + int thread_m = threadIdx.y; + int thread_n = threadIdx.x * kAlignment; + + int idx_m = block_m + thread_m; + int idx_n = block_n + thread_n; + + int batch_offset_norm = block_batch * params.args.batch_stride_N; + int batch_offset_sum = block_batch * params.args.batch_stride_S; + + // Kill off thread if it is outside the row boundary + if (params.args.extent.row() <= idx_m) { + return; + } + + // + // Setup pointers to load D again + // + + using AccessTypeD = AlignedArray; + using AccessTypeSoft = AlignedArray; + using FragmentSoft = Array; + using ConvertSoftCompute = cutlass::NumericArrayConverter; + using 
ConvertSoftOutput = cutlass::NumericArrayConverter; + + using Mul = cutlass::multiplies; + using Minus = cutlass::minus; + using Exp = cutlass::fast_exp_op; + + ConvertSoftCompute convert_soft_compute; + ConvertSoftOutput convert_soft_output; + + Minus minus; + Mul mul; + Exp exponential; + + using ConvertSum = cutlass::NumericConverter; + using ConvertNorm = cutlass::NumericConverter; + + ConvertSum convert_sum; + ConvertNorm convert_norm; + + AccessTypeD *access_d = reinterpret_cast( + params.args.ref_D.data() + + params.args.batch_stride_D * block_batch + + params.args.ref_D.layout()({idx_m, idx_n})); + + AccessTypeSoft *access_soft = reinterpret_cast( + params.args.ref_Soft.data() + + params.args.batch_stride_Soft * block_batch + + params.args.ref_Soft.layout()({idx_m, idx_n})); + + ElementSum inv_sum = (params.args.ref_S.data())[idx_m + batch_offset_sum]; + ElementNorm norm = (params.args.ref_N.data())[idx_m + batch_offset_norm]; + + // + // Loop + // + CUTLASS_PRAGMA_UNROLL + for ( + int idx = 0; + idx < params.args.extent.column(); + idx += ApplyShape::kColumn * kAlignment) { + + if (idx_n < params.args.extent.column()) { + AccessTypeD fetch; + arch::global_load(fetch, access_d, true); + + FragmentSoftmax result = mul(exponential(minus(convert_soft_compute(fetch), convert_norm(norm))), convert_sum(inv_sum)); + FragmentSoft soft = convert_soft_output(result); + + arch::global_store(soft, access_soft, true); + } + + access_d += ApplyShape::kColumn; + access_soft += ApplyShape::kColumn; + idx_n += ApplyShape::kColumn * kAlignment; + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// +template < + typename ElementA_, + typename LayoutA_, + typename ElementB_, + typename LayoutB_, + typename ElementC_, + typename ElementCompute_, + typename OperatorClass_, + typename ArchTag_, + typename ThreadblockShape_, + typename WarpShape_, + typename InstructionShape_, + typename EpilogueFunctorOp_, + int kStages_, + typename ApplyShape_ = MatrixShape<1, 1024>, + int AlignmentA_ = 128 / cutlass::sizeof_bits::value, + int AlignmentB_ = 128 / cutlass::sizeof_bits::value, + int AlignmentSoftmax_ = 128 / cutlass::sizeof_bits::value, + typename ElementNorm_ = float, + typename ElementSum_ = float, + typename ElementSoftmax_ = ElementC_ +> +class GemmSoftmax { +public: + + /////////////////////////////////////////////////////////////////////////////////////////////// + + // + // Type definitions + // + + using ElementA = ElementA_; + using ElementB = ElementB_; + using ElementC = ElementC_; + using ElementCompute = ElementCompute_; + using ElementSum = ElementSum_; + using ElementSoft = ElementSoftmax_; + using ElementSoftmaxCompute = float; + + using LayoutA = LayoutA_; + using LayoutB = LayoutB_; + + using EpilogueFunctorOp = EpilogueFunctorOp_; + using ElementNorm = ElementNorm_; + + using ApplyShape = ApplyShape_; + + // These are mandatory layouts. 
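+  // (RowMajor is assumed throughout this path: the partial-reduction epilogue walks rows of D,
+  //  and the softmax apply / final-reduction kernels index the Norm and Sum tensors one entry
+  //  per row.)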
+ using LayoutC = cutlass::layout::RowMajor; + using LayoutN = cutlass::layout::RowMajor; + using LayoutS = cutlass::layout::RowMajor; + using LayoutSoft = cutlass::layout::RowMajor; + + using TensorRefA = TensorRef; + using TensorRefB = TensorRef; + using TensorRefC = TensorRef; + using TensorRefN = TensorRef; + using TensorRefSum = TensorRef; + using TensorRefSoft = TensorRef; + + using ThreadblockShape = ThreadblockShape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + + using OperatorClass = OperatorClass_; + using ArchTag = ArchTag_; + + static int const kStages = kStages_; + static int const AlignmentA = AlignmentA_; + static int const AlignmentB = AlignmentB_; + static int const AlignmentSoftmax = AlignmentSoftmax_; + + using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle; + + /////////////////////////////////////////////////////////////////////////////////////////////// + + // basic GEMM kernel + using DefaultGemmKernel = typename cutlass::gemm::kernel::DefaultGemm< + ElementA, + LayoutA, + AlignmentA, + ElementB, + LayoutB, + AlignmentB, + ElementC, + LayoutC, + ElementCompute, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueFunctorOp, + ThreadblockSwizzle, + kStages, + true, + typename cutlass::gemm::device::DefaultGemmConfiguration< + OperatorClass, ArchTag, ElementA, ElementB, ElementC, ElementCompute>::Operator, + cutlass::gemm::SharedMemoryClearOption::kNone + >::GemmKernel; + + /////////////////////////////////////////////////////////////////////////////////////////////// + + // Epilogue visitor + using EpilogueVisitor = typename cutlass::epilogue::threadblock::EpilogueVisitorSoftmax< + ThreadblockShape, + DefaultGemmKernel::kThreadCount, + typename DefaultGemmKernel::Epilogue::OutputTileIterator, + ElementCompute, + ElementNorm, + ElementSum, + ElementSoftmaxCompute, + EpilogueFunctorOp + >; + + /// Epilogue + using Epilogue = typename cutlass::epilogue::threadblock::EpilogueWithVisitorFromExistingEpilogue< + EpilogueVisitor, + typename DefaultGemmKernel::Epilogue + >::Epilogue; + + // GEMM + using GemmKernel = gemm::kernel::GemmWithEpilogueVisitor< + typename DefaultGemmKernel::Mma, + Epilogue, + ThreadblockSwizzle + >; + + // Softmax kernel + using SoftmaxApplyKernel = kernel::ApplySoftmax< + ElementC, + ElementNorm, + ElementSum, + ElementSoft, + ElementSoftmaxCompute, + AlignmentSoftmax, + ApplyShape + >; + + using ApplyFinalReductionKernel = cutlass::reduction::kernel::ApplySoftmaxFinalReduction< + ElementNorm, + ElementSum, + ElementSoftmaxCompute, + ThreadblockShape + >; + +public: + + /// Arguments class + struct Arguments { + + typename GemmKernel::Arguments gemm; + typename SoftmaxApplyKernel::Arguments softmax; + typename ApplyFinalReductionKernel::Arguments reduction; + cutlass::gemm::GemmCoord extend; + + // + // Methods + // + Arguments() { } + + Arguments( + cutlass::gemm::GemmCoord problem_size, + int32_t batch_count_, + TensorRefA ref_A_, + TensorRefB ref_B_, + TensorRefC ref_C_, + TensorRefC ref_D_, + typename EpilogueFunctorOp::Params linear_scaling, + TensorRefN ref_N_, + TensorRefSum ref_S_, + TensorRefSoft ref_Softmax_, + int64_t batch_stride_A_ = 0, + int64_t batch_stride_B_ = 0, + int64_t batch_stride_C_ = 0, + int64_t batch_stride_D_ = 0, + int64_t batch_stride_Max_ = 0, + int64_t batch_stride_Sum_ = 0, + int64_t batch_stride_Softmax_ = 0 + ): + gemm( + cutlass::gemm::GemmUniversalMode::kBatched, + problem_size, + batch_count_, + 
ref_A_, + ref_B_, + ref_C_, + ref_D_, + ref_N_.data(), + ref_S_.data(), + batch_stride_A_, + batch_stride_B_, + typename EpilogueVisitor::Arguments( + linear_scaling, + batch_stride_C_, + batch_stride_D_, + batch_stride_Max_, + batch_stride_Sum_ + ) + ), + reduction( + problem_size, + ref_N_.data(), + ref_S_.data(), + batch_stride_Max_, + batch_stride_Sum_ + ), + softmax( + MatrixCoord(problem_size.m(), problem_size.n()), + batch_count_, + ref_D_, + ref_N_, + ref_S_, + ref_Softmax_, + batch_stride_D_, + batch_stride_Max_, + batch_stride_Sum_, + batch_stride_Softmax_ + ), + extend(problem_size) + { + + } + }; + + struct Params { + + typename GemmKernel::Params gemm; + typename SoftmaxApplyKernel::Params softmax; + typename ApplyFinalReductionKernel::Params reduction; + MatrixCoord extend; + // + // Methods + // + Params() { } + + Params(Arguments const &args): + gemm(args.gemm), + reduction(args.reduction), + softmax(args.softmax), + extend(MatrixCoord(args.extend.m(), args.extend.n())) + { + + } + }; + +public: + + // Gemm + + + // + // Methods + // + +private: + + Params params_; + +public: + + /// Ctor + GemmSoftmax() { + + } + + /// Initialize + Status initialize(Arguments const &args) { + + params_ = Params(args); + + return cutlass::Status::kSuccess; + } + + /// Run + Status run(cudaStream_t stream) { + + // + // Launch the GEMM + max kernel + // + + dim3 gemm_grid = ThreadblockSwizzle().get_grid_shape(params_.gemm.grid_tiled_shape); + dim3 gemm_block(GemmKernel::kThreadCount, 1, 1); + + int gemm_smem_size = int(sizeof(typename GemmKernel::SharedStorage)); + + cudaError_t result; + + if (gemm_smem_size >= (48 << 10)) { + result = cudaFuncSetAttribute(cutlass::Kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + gemm_smem_size); + + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + } + + cutlass::Kernel<<>>(params_.gemm); + + result = cudaGetLastError(); + + if (result != cudaSuccess) { + return cutlass::Status::kErrorInternal; + } + + + // + // Launch the ApplyFinalReductionKernel + // + + int thread_per_block = 128; + int block_per_row = (params_.extend.row() + thread_per_block - 1) / thread_per_block; + if (block_per_row < 4) { + thread_per_block = 32; + block_per_row = (params_.extend.row() + thread_per_block - 1) / thread_per_block; + } + + dim3 final_reduction_grid(block_per_row, 1, params_.softmax.args.batch_count); + dim3 final_reduction_block(thread_per_block); + + Kernel<<< + final_reduction_grid, final_reduction_block, sizeof(typename ApplyFinalReductionKernel::SharedStorage), stream + >>>(params_.reduction); + + result = cudaGetLastError(); + + if (result != cudaSuccess) { + return cutlass::Status::kErrorInternal; + } + + // + // Launch the SoftmaxApplyKernel + // + + dim3 apply_block(SoftmaxApplyKernel::ApplyShape::kColumn, SoftmaxApplyKernel::ApplyShape::kRow); + + int threadblock_rows = SoftmaxApplyKernel::ApplyShape::kRow; + int threadblock_columns = SoftmaxApplyKernel::ApplyShape::kColumn * SoftmaxApplyKernel::kAlignment; + + dim3 apply_grid( + (params_.softmax.args.extent.row() + threadblock_rows - 1) / threadblock_rows, + (params_.softmax.args.extent.column() + threadblock_columns - 1) / threadblock_columns, + params_.softmax.args.batch_count); + + Kernel<<< + apply_grid, apply_block, sizeof(typename SoftmaxApplyKernel::SharedStorage), stream + >>>(params_.softmax); + + result = cudaGetLastError(); + + if (result != cudaSuccess) { + return cutlass::Status::kErrorInternal; + } + + return cutlass::Status::kSuccess; + } + + /// Function call 
operator + Status operator()(cudaStream_t stream = nullptr) { + return run(stream); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/37_gemm_layernorm_gemm_fusion/gemm_with_epilogue_visitor.h b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/37_gemm_layernorm_gemm_fusion/gemm_with_epilogue_visitor.h new file mode 100644 index 0000000000000000000000000000000000000000..5b36be05853718518851b4967f769b72697bccad --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/37_gemm_layernorm_gemm_fusion/gemm_with_epilogue_visitor.h @@ -0,0 +1,444 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief GEMM kernel to support the epilogue visitor model + for customized layernorm partial reduction epilogue fusion. + + This source file will likely be moved to `include/cutlass/gemm/kernel/` in the future once + its usage has been stabilized. For now, it is included in this example to demonstrate + some basic output fusion options. 
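+
+    Relative to the variant used in example 35_gemm_softmax, this kernel keeps no batch strides
+    or C/D tensor references of its own; everything the epilogue needs is routed through
+    EpilogueVisitor::Arguments, and the threadblock swizzle is constructed for a single batch.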
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/complex.h" +#include "cutlass/semaphore.h" + +#include "cutlass/trace.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate + typename Epilogue_, ///! Epilogue + typename ThreadblockSwizzle_ ///! Threadblock swizzling function +> +struct GemmWithEpilogueVisitor { +public: + + using Mma = Mma_; + using Epilogue = Epilogue_; + using EpilogueVisitor = typename Epilogue::Visitor; + using ThreadblockSwizzle = ThreadblockSwizzle_; + + using ElementA = typename Mma::IteratorA::Element; + using LayoutA = typename Mma::IteratorA::Layout; + using TensorRefA = TensorRef; + + using ElementB = typename Mma::IteratorB::Element; + using LayoutB = typename Mma::IteratorB::Layout; + using TensorRefB = TensorRef; + + using ElementC = typename EpilogueVisitor::ElementOutput; + using LayoutC = typename Epilogue::Layout; + + static ComplexTransform const kTransformA = Mma::kTransformA; + static ComplexTransform const kTransformB = Mma::kTransformB; + using Operator = typename Mma::Operator; + + using OperatorClass = typename Mma::Operator::OperatorClass; + using ThreadblockShape = typename Mma::Shape; + using WarpShape = typename Mma::Operator::Shape; + using InstructionShape = typename Mma::Policy::Operator::InstructionShape; + using ArchTag = typename Mma::ArchTag; + + static int const kStages = Mma::kStages; + static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; + static int const kAlignmentC = EpilogueVisitor::kElementsPerAccess; + + /// Warp count (concept: GemmShape) + using WarpCount = typename Mma::WarpCount; + static int const kThreadCount = 32 * WarpCount::kCount; + + /// Split-K preserves splits that are 128b aligned + static int const kSplitKAlignment = const_max( + 128 / sizeof_bits::value, + 128 / sizeof_bits::value + ); + + // + // Structures + // + + /// Argument structure + struct Arguments { + + // + // Data members + // + + GemmUniversalMode mode; + GemmCoord problem_size; + + TensorRefA ref_A; + TensorRefB ref_B; + + typename EpilogueVisitor::Arguments epilogue_visitor; + + // + // Methods + // + + Arguments(): + mode(GemmUniversalMode::kGemm) + { } + + + /// constructs an arguments structure + Arguments( + GemmUniversalMode mode_, + GemmCoord problem_size_, + TensorRefA ref_A_, + TensorRefB ref_B_, + typename EpilogueVisitor::Arguments epilogue_visitor_ + ): + mode(mode_), + problem_size(problem_size_), + ref_A(ref_A_), + ref_B(ref_B_), + epilogue_visitor(epilogue_visitor_) + { + + } + }; + + // + // Structure for precomputing values in host memory and passing to kernels + // + + /// Parameters structure + struct Params { + + cutlass::gemm::GemmCoord problem_size; + cutlass::gemm::GemmCoord grid_tiled_shape; + int swizzle_log_tile; + + typename Mma::IteratorA::Params params_A; + typename Mma::IteratorB::Params params_B; + + GemmUniversalMode mode; + int gemm_k_size; + + void * ptr_A; + void * ptr_B; + + typename EpilogueVisitor::Params epilogue_visitor; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params(): + swizzle_log_tile(0), + 
params_A(0), + params_B(0), + gemm_k_size(0), + mode(cutlass::gemm::GemmUniversalMode::kGemm), + ptr_A(nullptr), + ptr_B(nullptr) + { } + + + Params( + Arguments const &args + ): + problem_size(args.problem_size), + swizzle_log_tile(0), + params_A(args.ref_A.layout()), + params_B(args.ref_B.layout()), + mode(args.mode), + gemm_k_size(args.problem_size.k()), + ptr_A(args.ref_A.data()), + ptr_B(args.ref_B.data()), + epilogue_visitor(args.epilogue_visitor) + { + + ThreadblockSwizzle threadblock_swizzle; + + grid_tiled_shape = threadblock_swizzle.get_tiled_shape( + args.problem_size, + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, 1); + + if (args.mode == GemmUniversalMode::kGemm || args.mode == GemmUniversalMode::kGemmSplitKParallel) { + + int const kAlignK = const_max(const_max(128 / sizeof_bits::value, 128 / sizeof_bits::value), 1); + + gemm_k_size = round_up(args.problem_size.k(), kAlignK); + + if (gemm_k_size) { + grid_tiled_shape.k() = ceil_div(args.problem_size.k(), gemm_k_size); + } + } + + swizzle_log_tile = threadblock_swizzle.get_log_tile(grid_tiled_shape); + } + }; + + /// Shared memory storage structure + union SharedStorage { + + typename Mma::SharedStorage main_loop; + + struct { + typename Epilogue::SharedStorage epilogue; + typename EpilogueVisitor::SharedStorage visitor; + } epilogue; + }; + +public: + + // + // Methods + // + + CUTLASS_DEVICE + GemmWithEpilogueVisitor() { } + + /// Determines whether kernel satisfies alignment + static Status can_implement( + cutlass::gemm::GemmCoord const & problem_size) { + + CUTLASS_TRACE_HOST("GemmWithEpilogueVisitor::can_implement()"); + + static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; + static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; + + bool isAMisaligned = false; + bool isBMisaligned = false; + bool isCMisaligned = false; + + if (platform::is_same::value) { + isAMisaligned = problem_size.k() % kAlignmentA; + } else if (platform::is_same::value) { + isAMisaligned = problem_size.m() % kAlignmentA; + } else if (platform::is_same>::value + || platform::is_same>::value) { + isAMisaligned = problem_size.k() % kAlignmentA; + } + + if (platform::is_same::value) { + isBMisaligned = problem_size.n() % kAlignmentB; + } else if (platform::is_same::value) { + isBMisaligned = problem_size.k() % kAlignmentB; + } else if (platform::is_same>::value + || platform::is_same>::value) { + isBMisaligned = problem_size.k() % kAlignmentB; + } + + if (platform::is_same::value) { + isCMisaligned = problem_size.n() % kAlignmentC; + } else if (platform::is_same::value) { + isCMisaligned = problem_size.m() % kAlignmentC; + } else if (platform::is_same>::value + || platform::is_same>::value) { + isCMisaligned = problem_size.n() % kAlignmentC; + } + + if (isAMisaligned) { + CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for A operand"); + return Status::kErrorMisalignedOperand; + } + + if (isBMisaligned) { + CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for B operand"); + return Status::kErrorMisalignedOperand; + } + + if (isCMisaligned) { + CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for C operand"); + return Status::kErrorMisalignedOperand; + } + + CUTLASS_TRACE_HOST(" returning kSuccess"); + + return Status::kSuccess; + } + + static Status can_implement(Arguments const &args) { + return can_implement(args.problem_size); + } + + /// Executes one GEMM + CUTLASS_DEVICE + void 
operator()(Params const ¶ms, SharedStorage &shared_storage) { + + // Compute threadblock location + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // Early exit if CTA is out of range + if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || + params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { + + return; + } + + int offset_k = 0; + int problem_size_k = params.problem_size.k(); + + ElementA *ptr_A = static_cast(params.ptr_A); + ElementB *ptr_B = static_cast(params.ptr_B); + + // Compute initial location in logical coordinates + cutlass::MatrixCoord tb_offset_A{ + threadblock_tile_offset.m() * Mma::Shape::kM, + offset_k, + }; + + cutlass::MatrixCoord tb_offset_B{ + offset_k, + threadblock_tile_offset.n() * Mma::Shape::kN + }; + + // Compute position within threadblock + int thread_idx = threadIdx.x; + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A( + params.params_A, + ptr_A, + {params.problem_size.m(), problem_size_k}, + thread_idx, + tb_offset_A); + + typename Mma::IteratorB iterator_B( + params.params_B, + ptr_B, + {problem_size_k, params.problem_size.n()}, + thread_idx, + tb_offset_B); + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. + int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + + int lane_idx = threadIdx.x % 32; + + // + // Main loop + // + + // Construct thread-scoped matrix multiply + Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); + + typename Mma::FragmentC accumulators; + + accumulators.clear(); + + // Compute threadblock-scoped matrix multiply-add + int gemm_k_iterations = (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + mma( + gemm_k_iterations, + accumulators, + iterator_A, + iterator_B, + accumulators); + + // + // Masked tile iterators constructed from members + // + + threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + //assume identity swizzle + MatrixCoord threadblock_offset( + threadblock_tile_offset.m() * Mma::Shape::kM, + threadblock_tile_offset.n() * Mma::Shape::kN + ); + + int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); + + // + // Construct the epilogue visitor + // + + EpilogueVisitor epilogue_visitor( + params.epilogue_visitor, + shared_storage.epilogue.visitor, + params.problem_size.mn(), + thread_idx, + warp_idx, + lane_idx, + threadblock_offset); + + if (params.mode == GemmUniversalMode::kGemm) { + // Indicate which position in a serial reduction the output operator is currently updating + epilogue_visitor.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); + } + else if (params.mode == GemmUniversalMode::kBatched || params.mode == GemmUniversalMode::kArray) { + epilogue_visitor.set_batch_index(threadblock_tile_offset.k()); + } + + // Construct the epilogue + Epilogue epilogue( + shared_storage.epilogue.epilogue, + thread_idx, + warp_idx, + lane_idx); + + // Execute the epilogue operator to update the destination tensor. 
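+    // (EpilogueWithVisitor drives the visitor callbacks in roughly this order: begin_epilogue(),
+    //  then per output step begin_step(), visit() for each accumulator fragment, end_row() and
+    //  end_step(), and finally end_epilogue(); the EpilogueVisitorLayerNorm defined in
+    //  gemm_with_layernorm.h hooks these to accumulate the per-row partial sums.)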
+ epilogue(epilogue_visitor, accumulators); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/37_gemm_layernorm_gemm_fusion/gemm_with_layernorm.h b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/37_gemm_layernorm_gemm_fusion/gemm_with_layernorm.h new file mode 100644 index 0000000000000000000000000000000000000000..813f75604ba2bc3a3ce145e1e490d6a607de1981 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/37_gemm_layernorm_gemm_fusion/gemm_with_layernorm.h @@ -0,0 +1,1066 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief A file contains all functioning classes needed by GemmLayernorm. 
+ + GemmLayernorm example = GEMM0 with partial reduction fused in epilogue (EpilogueVisitorLayerNorm) + + lightweight full reduction kernel (ApplyFinalReduction) + + GEMM1 with elemenwise operations fused in mainloop (GemmLayernormMainloopFusion) + +*/ + +#pragma once + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#include +#include +#include +#include + +#include "cutlass/cutlass.h" +#include "cutlass/arch/memory.h" +#include "cutlass/arch/memory_sm75.h" +#include "cutlass/gemm/device/gemm_layernorm_mainloop_fusion.h" +#include "cutlass/gemm/kernel/gemm_transpose_operands.h" +#include "cutlass/gemm/kernel/default_gemm.h" +#include "cutlass/gemm/kernel/default_gemm_complex.h" +#include "cutlass/gemm/device/default_gemm_configuration.h" +#include "cutlass/epilogue/threadblock/epilogue_with_visitor.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#include "gemm_with_epilogue_visitor.h" +#include "helper.h" +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementVariance_, + typename ElementMean_, + typename ElementLayernormCompute_, + typename ElementOutput, + typename ThreadblockShape_, + bool IsShiftedVariance_ = false +> +class ApplyFinalReduction { +public: + + using ElementVariance = ElementVariance_; + using ElementMean = ElementMean_; + using ElementLayernormCompute = ElementLayernormCompute_; + using ThreadblockShape = ThreadblockShape_; + + // Pre-processing has ensured the layout equivalent to RowMajor + using Layout = cutlass::layout::RowMajor; + + using TensorVariance = TensorRef; + using TensorMean = TensorRef; + + static bool const kIsShiftedVariance = IsShiftedVariance_; + + // + // Arguments + // + + struct Arguments { + + MatrixCoord extent; ///< Extent of D and Layernorm matrices + TensorVariance ref_Variance; ///< Sum Square or Variance tensor (input / output) + TensorMean ref_Mean; ///< Sum or Mean tensor (input / output) + ElementOutput *ptr_Shifted_K; ///< Shifted K tensor pointer + + // + // Methods + // + Arguments(){ } + + Arguments( + MatrixCoord extent_, + TensorVariance ref_Variance_, + TensorMean ref_Mean_, + ElementOutput *ptr_Shifted_K_ + ): + extent(extent_), + ref_Variance(ref_Variance_), + ref_Mean(ref_Mean_), + ptr_Shifted_K(ptr_Shifted_K_) + { + + } + }; + + struct SharedStorage { + + + }; + + // + // Params struct + // + + struct Params { + Arguments args; + + // + // Methods + // + Params() { } + + Params(Arguments const &args_): args(args_) { } + }; + +private: + +public: + + CUTLASS_DEVICE + ApplyFinalReduction() { } + + CUTLASS_DEVICE + void operator()(Params const ¶ms, SharedStorage &shared_storage) { + + apply(params, shared_storage); + } + +private: + + /// Partial reduction + CUTLASS_DEVICE + void apply(Params const ¶ms, SharedStorage &shared_storage) { + + int threadblock_num = (params.args.extent.column() + ThreadblockShape::kM - 1) / ThreadblockShape::kM; + + int block_n = blockIdx.x * blockDim.x; + + int thread_n = threadIdx.x; + + int idx_n = block_n + thread_n; + + if (idx_n >= params.args.extent.row()) { + return; + } + + using ConvertVarianceOutput = cutlass::NumericConverter; + using ConvertMeanOutput = 
cutlass::NumericConverter; + + using ConvertVariance = cutlass::NumericConverter; + using ConvertMean = cutlass::NumericConverter; + + using ConvertShiftK = cutlass::NumericConverter; + + ConvertVariance convert_variance; + ConvertMean convert_mean; + + ConvertVarianceOutput convert_variance_output; + ConvertMeanOutput convert_mean_output; + + ElementVariance *access_square = params.args.ref_Variance.data() + idx_n; + ElementMean *access_mean = params.args.ref_Mean.data() + idx_n; + + ElementVariance *access_square_bak = access_square; + ElementMean *access_mean_bak = access_mean; + + ElementLayernormCompute frag_square_sum = ElementLayernormCompute(0); + ElementLayernormCompute frag_element_sum = ElementLayernormCompute(0); + ElementVariance fetch_square; + ElementMean fetch_mean; + + CUTLASS_PRAGMA_UNROLL + for (int idx_m = 0; idx_m < threadblock_num; idx_m++) { + arch::global_load(fetch_square, access_square, true); + arch::global_load(fetch_mean, access_mean, true); + frag_element_sum += convert_mean(fetch_mean); + frag_square_sum += convert_variance(fetch_square); + access_square += params.args.extent.row(); + access_mean += params.args.extent.row(); + } + + ElementLayernormCompute mean = frag_element_sum; + ElementLayernormCompute square_mean = frag_square_sum; + + ElementLayernormCompute variance; + + if (kIsShiftedVariance && params.args.ptr_Shifted_K != nullptr) { + ElementOutput *access_shift_k = params.args.ptr_Shifted_K + idx_n; + ElementOutput fetch_shift_k; + ConvertShiftK convert_shift_k; + arch::global_load(fetch_shift_k, access_shift_k, true); + ElementLayernormCompute shifted_mean = mean - convert_shift_k(fetch_shift_k); + variance = cutlass::constants::one() / cutlass::fast_sqrt(square_mean - shifted_mean * shifted_mean + ElementLayernormCompute(1e-6)); + }else{ + variance = cutlass::constants::one() / cutlass::fast_sqrt(square_mean - mean * mean + ElementLayernormCompute(1e-6)); + } + + mean = -mean * variance; + + access_square = access_square_bak; + access_mean = access_mean_bak; + + access_square[0] = convert_variance_output(variance); + access_mean[0] = convert_mean_output(mean); + + } + +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ThreadblockShape_, + int ThreadCount, + typename OutputTileIterator_, + typename AccumulatorTile_, + typename ElementAccumulator_, + typename ElementVariance_, + typename ElementMean_, + typename ElementLayernormCompute_, + typename ElementwiseFunctor_, + bool IsShiftedVariance_ = false +> +class EpilogueVisitorLayerNorm { +public: + + using ElementVariance = ElementVariance_; + using ElementMean = ElementMean_; + using ElementLayernormCompute = ElementLayernormCompute_; + + using AccumulatorTile = AccumulatorTile_; + + using ThreadblockShape = ThreadblockShape_; + static int const kThreadCount = ThreadCount; + + using OutputTileIterator = OutputTileIterator_; + using ElementwiseFunctor = ElementwiseFunctor_; + + static int const kIterations = OutputTileIterator::kIterations; + static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess; + static int const kRowIterations = OutputTileIterator::ThreadMap::Iterations::kRow; + + static int const kThreads = OutputTileIterator::ThreadMap::kThreads; + + static bool const kIsShiftedVariance = IsShiftedVariance_; + + using ElementOutput = typename OutputTileIterator::Element; + + static int const kDeltaRow = OutputTileIterator::ThreadMap::Delta::kRow; + + /// Array type used in Shift-K Layernorm 
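+  /// (The shift-K path computes the variance as E[(x - K)^2] - (E[x] - K)^2, with K read from
+  ///  ptr_Shifted_K per row; subtracting K before squaring avoids catastrophic cancellation
+  ///  when the magnitudes of x are large.)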
+ static int const kRowAccessCount = kIterations * kRowIterations; + + using ConvertedShiftFragment = Array; + + // Conducts manual transpose externally (already supported) for column major + using LayoutOutput = cutlass::layout::RowMajor; + + using ElementAccumulator = ElementAccumulator_; + + using AccumulatorFragment = Array; + using LayernormFragment = Array; + using OutputVector = Array; + using TensorRefD = TensorRef; + + static int const kThreadsPerRow = OutputTileIterator::ThreadMap::Detail::RowArrangement::Detail::kShapeWidth; + static int const kThreadsInColumn = kThreads / kThreadsPerRow; + static int const kHalfThreadsPerRow = (kThreadsPerRow >> 1); + + /// Argument structure + struct Arguments { + + typename ElementwiseFunctor::Params elementwise; + TensorRefD ref_C; + TensorRefD ref_D; + ElementVariance *ptr_Variance; + ElementMean *ptr_Mean; + ElementOutput *ptr_Shifted_K; + + // + // Methods + // + Arguments(): + ptr_Variance(nullptr), + ptr_Mean(nullptr), + ptr_Shifted_K(nullptr) + { + + } + + Arguments( + typename ElementwiseFunctor::Params elementwise_, + TensorRefD ref_C_, + TensorRefD ref_D_, + ElementVariance *ptr_Variance, + ElementMean *ptr_Mean_, + ElementOutput *ptr_Shifted_K_ = nullptr + ): + elementwise(elementwise_), + ref_C(ref_C_), + ref_D(ref_D_), + ptr_Variance(ptr_Variance), + ptr_Mean(ptr_Mean_), + ptr_Shifted_K(ptr_Shifted_K_) + { + + } + }; + + struct Params { + + typename ElementwiseFunctor::Params elementwise; + typename OutputTileIterator::Params params_C; + typename OutputTileIterator::Params params_D; + typename OutputTileIterator::Element *ptr_C; + typename OutputTileIterator::Element *ptr_D; + ElementVariance *ptr_Variance; + ElementMean *ptr_Mean; + ElementOutput *ptr_Shifted_K; + + // + // Methods + // + CUTLASS_HOST_DEVICE + Params(): + ptr_D(nullptr), + ptr_Variance(nullptr), + ptr_Mean(nullptr) + { + + } + + CUTLASS_HOST_DEVICE + Params(Arguments const &args): + elementwise(args.elementwise), + params_C(args.ref_C.layout()), + params_D(args.ref_D.layout()), + ptr_C(args.ref_C.data()), + ptr_D(args.ref_D.data()), + ptr_Variance(args.ptr_Variance), + ptr_Mean(args.ptr_Mean), + ptr_Shifted_K(args.ptr_Shifted_K) + { + + } + }; + + /// Shared storage + struct SharedStorage { + + }; + +private: + + Params const & params_; + SharedStorage & shared_storage_; + MatrixCoord extent_; + ElementwiseFunctor elementwise_; + + OutputTileIterator iterator_C_; + OutputTileIterator iterator_D_; + typename OutputTileIterator::Fragment fragment_C_; + typename OutputTileIterator::Fragment fragment_D_; + + ElementAccumulator alpha_; + ElementAccumulator beta_; + ConvertedShiftFragment shift_k_frag_; + + ElementLayernormCompute accum_sum_square_; + ElementLayernormCompute accum_sum_element_; + + MatrixCoord thread_offset_; + +public: + + CUTLASS_DEVICE + EpilogueVisitorLayerNorm( + Params const ¶ms, ///< Parameters routed to the epilogue + SharedStorage &shared_storage, ///< Shared storage needed by the functors here + MatrixCoord const &problem_size0, ///< Problem size of the output + int thread_idx, ///< Thread index within the threadblock + int warp_idx, ///< Warp index within the threadblock + int lane_idx, ///< Lane index within the warp + MatrixCoord const &threadblock_offset = MatrixCoord(0, 0) + ): + params_(params), + shared_storage_(shared_storage), + extent_(problem_size0), + elementwise_(params.elementwise), + iterator_C_(params.params_C, params.ptr_C, problem_size0, thread_idx, threadblock_offset), + iterator_D_(params.params_D, params.ptr_D, 
problem_size0, thread_idx, threadblock_offset) + { + alpha_ = (params.elementwise.alpha_ptr ? *params.elementwise.alpha_ptr : params.elementwise.alpha); + beta_ = (params.elementwise.beta_ptr ? *params.elementwise.beta_ptr : params.elementwise.beta); + + if (beta_ == ElementAccumulator()) { + iterator_C_.clear_mask(); + } + } + + /// Helper to indicate split-K behavior + CUTLASS_DEVICE + void set_k_partition( + int split_k_index, ///< Index of this threadblock within split-K partitioned scheme + int split_k_slices) { ///< Total number of split-K slices + + } + + /// Called to set the batch index + CUTLASS_DEVICE + void set_batch_index(int batch_idx) { + + } + + /// Called at the start of the epilogue just before iterating over accumulator slices + CUTLASS_DEVICE + void begin_epilogue() { + + // If shift-K feature is enabled, we load shift-k fragment + // at the very beginning of an epilogue + if (kIsShiftedVariance && params_.ptr_Shifted_K != nullptr) { + shift_k_frag_.clear(); + int thread_offset_row_base = iterator_D_.thread_start_row(); + + CUTLASS_PRAGMA_UNROLL + for (int iter_idx = 0; iter_idx < kIterations; ++iter_idx) { + int step_offset = iter_idx * OutputTileIterator::Shape::kRow; + CUTLASS_PRAGMA_UNROLL + for (int rid = 0; rid < kRowIterations; ++rid) { + int row_step_offset = rid * kDeltaRow; + int row_offset = thread_offset_row_base + step_offset + row_step_offset; + bool is_load = (row_offset < extent_.row()); + shift_k_frag_[iter_idx * kRowIterations + rid] = load_shift_k_(row_offset, is_load); + } + + } + + } + + } + + /// Called at the start of one step before starting accumulator exchange + CUTLASS_DEVICE + void begin_step(int step_idx) { + fragment_D_.clear(); + + if (elementwise_.kScale != cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling) { + fragment_C_.clear(); + iterator_C_.load(fragment_C_); + ++iterator_C_; + } + } + + /// Called at the start of a row + CUTLASS_DEVICE + void begin_row(int row_idx) { + + } + + /// Called after accumulators have been exchanged for each accumulator vector + CUTLASS_DEVICE + void visit( + int iter_idx, + int row_idx, + int column_idx, + int frag_idx, + AccumulatorFragment const &accum) { + + using Mul = cutlass::multiplies; + using Minus = cutlass::minus; + using Exp = cutlass::fast_exp_op; + + [[maybe_unused]] Minus minus; + [[maybe_unused]] Mul mul; + [[maybe_unused]] Exp exponential; + + LayernormFragment result; + + thread_offset_ = + iterator_D_.thread_start() + + OutputTileIterator::ThreadMap::iteration_offset(frag_idx); + + NumericArrayConverter source_converter; + OutputVector &source_vector = reinterpret_cast(&fragment_C_)[frag_idx]; + + bool column_guard = (thread_offset_.column() < extent_.column()); + + if (elementwise_.kScale == cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling) { + result = source_converter(elementwise_(accum)); + }else{ + result = source_converter(elementwise_(accum, source_vector)); + } + + + ElementLayernormCompute inv_scalar = cutlass::constants::one() / ElementLayernormCompute(extent_.column()); + + // Fragment is cleared for non-reachable columns so no need to check against column guard + accum_sum_element_ = element_sum_accumulator_(result); + + // Square sum is different. Non-reachable columns should've been computed for shift-k + // Otherwise we will incorrectly have some extra k^2 added into square sum. + if (column_guard) { + accum_sum_square_ = (kIsShiftedVariance) ? 
\ + square_sum_accumulator_(result, shift_k_frag_[iter_idx * kRowIterations + row_idx]) : \ + square_sum_accumulator_(result); + } + else { + accum_sum_square_ = ElementLayernormCompute(0); + } + + accum_sum_element_ *= inv_scalar; + accum_sum_square_ *= inv_scalar; + + // After performing the in-thread reduction, we then perform cross-thread / in-warp reduction + CUTLASS_PRAGMA_UNROLL + for (int i = kHalfThreadsPerRow; i > 0; i >>= 1) { + accum_sum_element_ += __shfl_xor_sync(0xFFFFFFFF, accum_sum_element_, i); + accum_sum_square_ += __shfl_xor_sync(0xFFFFFFFF, accum_sum_square_, i); + } + + // Convert to the output + NumericArrayConverter output_converter; + OutputVector &output = reinterpret_cast(&fragment_D_)[frag_idx]; + output = output_converter(result); + } + + /// Called at the start of a row + CUTLASS_DEVICE + void end_row(int row_idx) { + + using ConvertVarianceOutput = cutlass::NumericConverter; + using ConvertMeanOutput = cutlass::NumericConverter; + + ConvertVarianceOutput convert_variance_output; + ConvertMeanOutput convert_mean_output; + + bool is_write_thread = (thread_offset_.row() < extent_.row() && (threadIdx.x % kThreadsPerRow) == 0); + int row_offset = thread_offset_.row() + blockIdx.y * extent_.row(); + + ElementVariance *curr_ptr_sum_square = params_.ptr_Variance + row_offset; + ElementMean *curr_ptr_element_sum = params_.ptr_Mean + row_offset; + + arch::global_store( + convert_variance_output(accum_sum_square_), + (void *)curr_ptr_sum_square, + is_write_thread); + + arch::global_store( + convert_mean_output(accum_sum_element_), + (void *)curr_ptr_element_sum, + is_write_thread); + + } + + /// Called after all accumulator elements have been visited + CUTLASS_DEVICE + void end_step(int step_idx) { + + iterator_D_.store(fragment_D_); + ++iterator_D_; + } + + /// Called after all steps have been completed + CUTLASS_DEVICE + void end_epilogue() { + + } + +private: + + CUTLASS_DEVICE + ElementLayernormCompute load_shift_k_(int row_offset, bool is_load) { + using ConvertShiftK = cutlass::NumericConverter; + ConvertShiftK convert_shift_k; + ElementOutput shift_k_val; + + // Computes the address to load shift_k element + ElementOutput *curr_ptr_shift_k = params_.ptr_Shifted_K + row_offset; + // Conditionally loads from global memory + arch::global_load(shift_k_val, (void *)curr_ptr_shift_k, is_load); + // Converts data type to return + ElementLayernormCompute converted_shift_k_val = convert_shift_k(shift_k_val); + + return converted_shift_k_val; + } + + CUTLASS_DEVICE + ElementLayernormCompute square_sum_accumulator_(LayernormFragment const &accum) { + ElementLayernormCompute sum_ = ElementLayernormCompute(0); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < LayernormFragment::kElements; ++i) { + auto accum_ = accum[i]; + sum_ += accum_ * accum_; + } + + return sum_; + } + + CUTLASS_DEVICE + ElementLayernormCompute square_sum_accumulator_(LayernormFragment const &accum, ElementLayernormCompute shift_k_val) { + ElementLayernormCompute sum_ = ElementLayernormCompute(0); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < LayernormFragment::kElements; ++i) { + auto accum_ = accum[i] - shift_k_val; + sum_ += accum_ * accum_; + } + + return sum_; + } + + CUTLASS_DEVICE + ElementLayernormCompute element_sum_accumulator_(LayernormFragment const &accum) { + ElementLayernormCompute sum_ = ElementLayernormCompute(0); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < LayernormFragment::kElements; ++i) { + sum_ += accum[i]; + } + + return sum_; + } + +}; + 
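+// For reference only, a minimal scalar sketch (illustrative, not part of CUTLASS) of the
+// per-row statistics that EpilogueVisitorLayerNorm and ApplyFinalReduction compute together,
+// ignoring the optional shift-K path: the epilogue accumulates sum(x)/N and sum(x*x)/N for
+// each row of GEMM0's output, and the final reduction folds them into a (scale, shift) pair so
+// that LayerNorm(x)[n] = x[n] * scale + shift, with gamma/beta applied later by the fused GEMM1.
+CUTLASS_HOST_DEVICE
+void layernorm_row_stats_sketch(float const *x, int N, float &scale, float &shift) {
+  float mean = 0.0f;
+  float mean_sq = 0.0f;
+  for (int n = 0; n < N; ++n) {
+    mean    += x[n] / float(N);          // corresponds to accum_sum_element_ * inv_scalar
+    mean_sq += x[n] * x[n] / float(N);   // corresponds to accum_sum_square_ * inv_scalar
+  }
+  scale = 1.0f / cutlass::fast_sqrt(mean_sq - mean * mean + 1e-6f);  // stored as "variance"
+  shift = -mean * scale;                                             // stored as "mean"
+}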
+///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// +template < + typename ElementInputA0_, + typename LayoutInputA0_, + typename ElementInputB0_, + typename LayoutInputB0_, + typename ElementOutput_, + typename LayoutOutput_, + typename ElementCompute_, + typename EpilogueFunctorOp_, + typename ThreadblockShape_, + typename WarpShape_, + typename InstructionShape_, + int Stages0, + int Stages1, + bool IsShiftedVariance_ = false +> +class GemmLayernorm { +public: + + /////////////////////////////////////////////////////////////////////////////////////////////// + + // + // Type definitions + // + + static bool const kInternalTranspose = cutlass::platform::is_same::value; + static bool const kIsShiftedVariance = IsShiftedVariance_; + + // These is mandatory layout. + using LayoutInputScaleBias = cutlass::layout::RowMajor; + + // These are mandatory data types. + using ElementLayernormCompute = float; + using ElementInputScaleBias = cutlass::half_t; + + // These are mandatory params required by mainloop fusion + using OperatorClass = cutlass::arch::OpClassTensorOp; + using ArchTag = cutlass::arch::Sm80; + + // These are mandatory layouts and data types + // that are inheritated from pre-defined params + + using LayoutSumSqr = LayoutInputScaleBias; + using LayoutSum = LayoutInputScaleBias; + + using ElementMean = ElementInputScaleBias; + using ElementVariance = ElementInputScaleBias; + + /////////////////////////////////////////////////////////////////////////////////////////////// + + using LayoutInputA0 = LayoutInputA0_; + using LayoutInputB0 = LayoutInputB0_; + using LayoutInputA1 = LayoutOutput_; + using LayoutInputB1 = LayoutOutput_; + using LayoutOutputC0 = LayoutOutput_; + using LayoutOutputC1 = LayoutOutput_; + + using ElementInputA0 = ElementInputA0_; + using ElementInputB0 = ElementInputB0_; + using ElementOutputC0 = ElementOutput_; + using ElementCompute = ElementCompute_; + using ElementInputB1 = ElementInputB0_; + + using ElementInputA1 = ElementOutputC0; + using ElementOutputC1 = ElementOutputC0; + + using EpilogueFunctorOp = EpilogueFunctorOp_; + + using TensorRefA = TensorRef; + using TensorRefB = TensorRef; + using TensorRefC = TensorRef; + using TensorVariance = TensorRef; + using TensorMean = TensorRef; + + using ThreadblockShape = ThreadblockShape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + + static int const kStages0 = Stages0; + static int const kStages1 = Stages1; + + using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; + + /////////////////////////////////////////////////////////////////////////////////////////////// + + using MapArguments = cutlass::gemm::kernel::detail::MapArguments< + ElementInputA0, + LayoutInputA0, + cutlass::ComplexTransform::kNone, + 128 / cutlass::sizeof_bits::value, + ElementInputB0, + LayoutInputB0, + cutlass::ComplexTransform::kNone, + 128 / cutlass::sizeof_bits::value, + LayoutOutputC0, + kInternalTranspose + >; + + using DefaultGemmKernel = typename cutlass::gemm::kernel::DefaultGemm< + typename MapArguments::ElementA, + typename MapArguments::LayoutA, + MapArguments::kAlignmentA, + typename MapArguments::ElementB, + typename MapArguments::LayoutB, + MapArguments::kAlignmentB, + ElementOutputC0, + typename MapArguments::LayoutC, + ElementCompute, + OperatorClass, + ArchTag, + 
ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueFunctorOp, + SwizzleThreadBlock, + kStages0, + true, + typename cutlass::gemm::device::DefaultGemmConfiguration< + OperatorClass, ArchTag, ElementInputA0, ElementInputB0, ElementOutputC0, ElementCompute>::Operator, + cutlass::gemm::SharedMemoryClearOption::kNone + >::GemmKernel; + + /////////////////////////////////////////////////////////////////////////////////////////////// + + // Epilogue visitor + using EpilogueVisitor = kernel::EpilogueVisitorLayerNorm< + ThreadblockShape, + DefaultGemmKernel::kThreadCount, + typename DefaultGemmKernel::Epilogue::OutputTileIterator, + typename DefaultGemmKernel::Epilogue::AccumulatorFragmentIterator::AccumulatorTile, + ElementCompute, + ElementVariance, + ElementMean, + ElementLayernormCompute, + EpilogueFunctorOp, + kIsShiftedVariance + >; + + /// Epilogue + using Epilogue = typename cutlass::epilogue::threadblock::EpilogueWithVisitorFromExistingEpilogue< + EpilogueVisitor, + typename DefaultGemmKernel::Epilogue + >::Epilogue; + + // GEMM + using GemmEpilogueFusion = gemm::kernel::GemmWithEpilogueVisitor< + typename DefaultGemmKernel::Mma, + Epilogue, + SwizzleThreadBlock + >; + + using ApplyFinalReductionKernel = kernel::ApplyFinalReduction< + ElementVariance, + ElementMean, + ElementLayernormCompute, + ElementOutputC0, + ThreadblockShape, + kIsShiftedVariance + >; + +using GemmMainloopFusion = typename cutlass::gemm::device::GemmLayernormMainloopFusion< + ElementInputA1, LayoutInputA1, + ElementInputB1, LayoutInputB1, + ElementInputScaleBias, LayoutInputScaleBias, + ElementOutputC1, LayoutOutputC1, + ElementCompute, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueFunctorOp, + SwizzleThreadBlock, + kStages1 +>; + +public: + + /// Arguments class + struct Arguments { + + typename GemmEpilogueFusion::Arguments gemm0; + typename GemmMainloopFusion::Arguments gemm1; + typename ApplyFinalReductionKernel::Arguments reduction; + cutlass::gemm::GemmCoord extend; + + // + // Methods + // + Arguments() { } + + Arguments( + cutlass::gemm::GemmCoord problem_size0, + cutlass::gemm::GemmCoord problem_size1, + ElementInputA0 * ptr_A, + ElementInputB0 * ptr_B, + ElementOutputC0 * ptr_C, + ElementOutputC0 * ptr_D, + ElementOutputC0 * ptr_E, + ElementOutputC0 * ptr_O, + int64_t ldm_A, + int64_t ldm_B, + int64_t ldm_C, + int64_t ldm_D, + int64_t ldm_E, + int64_t ldm_O, + typename EpilogueFunctorOp::Params linear_scaling, + TensorVariance ref_Variance_, + TensorMean ref_Mean_, + TensorVariance ref_Gamma_, + TensorMean ref_Beta_, + ElementOutputC0 *ptr_Shifted_K = nullptr + ): + gemm0( + cutlass::gemm::GemmUniversalMode::kGemm, + {kInternalTranspose ? problem_size0.n() : problem_size0.m(),\ + kInternalTranspose ? problem_size0.m() : problem_size0.n(),\ + problem_size0.k()}, + {kInternalTranspose ? ptr_B : ptr_A, \ + kInternalTranspose ? ldm_B : ldm_A}, + {kInternalTranspose ? ptr_A : ptr_B, \ + kInternalTranspose ? ldm_A : ldm_B}, + typename EpilogueVisitor::Arguments( + linear_scaling, + {ptr_C, ldm_C}, + {ptr_D, ldm_D}, + ref_Variance_.data(), + ref_Mean_.data(), + ptr_Shifted_K + ) + ), + reduction( + MatrixCoord(kInternalTranspose ? problem_size0.n() : problem_size0.m(),\ + kInternalTranspose ? problem_size0.m() : problem_size0.n()), + ref_Variance_, + ref_Mean_, + ptr_Shifted_K + ), + gemm1( + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size1, + 1, + linear_scaling, + kInternalTranspose ? ptr_E : ptr_D, + kInternalTranspose ? 
ptr_D : ptr_E, + ref_Variance_.data(), + ref_Mean_.data(), + ref_Gamma_.data(), + ref_Beta_.data(), + ptr_O, + ptr_O, + problem_size1.m() * problem_size1.k(), + problem_size1.n() * problem_size1.k(), + problem_size1.n(), + problem_size1.n(), + problem_size1.k(), + problem_size1.k(), + problem_size1.m() * problem_size1.n(), + problem_size1.m() * problem_size1.n(), + kInternalTranspose ? ldm_E : ldm_D, + kInternalTranspose ? ldm_D : ldm_D, + ref_Variance_.layout().stride(0), + ref_Mean_.layout().stride(0), + ref_Gamma_.layout().stride(0), + ref_Beta_.layout().stride(0), + ldm_O, + ldm_O + ), + extend(problem_size0) + { + + } + }; + + struct Params { + + typename GemmEpilogueFusion::Params gemm0; + typename ApplyFinalReductionKernel::Params reduction; + MatrixCoord extend; + // + // Methods + // + Params() { } + + Params(Arguments const &args): + gemm0(args.gemm0), + reduction(args.reduction), + extend(MatrixCoord(args.extend.m(), args.extend.n())) + { + + } + }; + +public: + + // Gemm + + + // + // Methods + // + +private: + + Params params_; + GemmMainloopFusion gemm_fusion_op; + +public: + + /// Ctor + GemmLayernorm() { + + } + + /// Initialize + Status initialize(Arguments const &args) { + + params_ = Params(args); + cutlass::Status status; + size_t workspace_size = gemm_fusion_op.get_workspace_size(args.gemm1); + cutlass::device_memory::allocation workspace(workspace_size); + status = gemm_fusion_op.can_implement(args.gemm1); + CUTLASS_CHECK(status); + + status = gemm_fusion_op.initialize(args.gemm1, workspace.get()); + CUTLASS_CHECK(status); + + return cutlass::Status::kSuccess; + } + + /// Run + Status run(cudaStream_t stream) { + + // + // Launch the GEMM + layernorm kernel + // + + dim3 gemm_grid = SwizzleThreadBlock().get_grid_shape(params_.gemm0.grid_tiled_shape); + dim3 gemm_block(GemmEpilogueFusion::kThreadCount, 1, 1); + + int gemm_smem_size = int(sizeof(typename GemmEpilogueFusion::SharedStorage)); + + cutlass::Kernel<<>>(params_.gemm0); + + cudaError_t result = cudaGetLastError(); + + if (result != cudaSuccess) { + return cutlass::Status::kErrorInternal; + } + + // + // Launch the ApplyFinalReductionKernel + // + + // always performs reduction from leading dimension + int leading_dim_0 = kInternalTranspose ? params_.extend.row() : params_.extend.column(); + int leading_dim_1 = kInternalTranspose ? 
params_.extend.column() : params_.extend.row(); + + int thread_per_block = 128; + int block_per_row = (leading_dim_1 + thread_per_block - 1) / thread_per_block; + if (block_per_row < 4) { + thread_per_block = 32; + block_per_row = (leading_dim_1 + thread_per_block - 1) / thread_per_block; + } + + dim3 final_reduction_block(thread_per_block); + dim3 final_reduction_grid(block_per_row); + + Kernel<<< + final_reduction_grid, final_reduction_block, sizeof(typename ApplyFinalReductionKernel::SharedStorage), stream + >>>(params_.reduction); + + result = cudaGetLastError(); + + if (result != cudaSuccess) { + return cutlass::Status::kErrorInternal; + } + + // + // Launch the GEMM + mainloop fusion kernel + // + + cutlass::Status status = gemm_fusion_op(); + CUTLASS_CHECK(status); + + return cutlass::Status::kSuccess; + } + + /// Function call operator + Status operator()(cudaStream_t stream = nullptr) { + return run(stream); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/39_gemm_permute/layouts.h b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/39_gemm_permute/layouts.h new file mode 100644 index 0000000000000000000000000000000000000000..d4d9ed316667496d03fca7cbe460f413cba50290 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/39_gemm_permute/layouts.h @@ -0,0 +1,506 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Defines additional layout functions used in Permute GEMM example to simplify + computing reference permutations of 4/5D tensors when source data is column-major. +*/ +#pragma once +#include "cutlass/cutlass.h" +#include CUDA_STD_HEADER(cassert) +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/coord.h" +#include "cutlass/tensor_coord.h" + +namespace cutlass { +namespace layout { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Mapping function for 4-D CWHN tensors. +class TensorCWHN { +public: + /// Logical rank of tensor + static int const kRank = 4; + + /// Rank of stride vector + static int const kStrideRank = 3; + + /// Index type used for coordinates + using Index = int32_t; + + /// Long index type used for offsets + using LongIndex = int64_t; + + /// Logical coordinate (n, h, w, c) + using TensorCoord = Tensor4DCoord; + + /// Stride vector + using Stride = Coord; + +private: + // + // Data members + // + + /// Stride data member - [n, hn, whn] + Stride stride_; + +public: + // + // Methods + // + + /// Constructor + CUTLASS_HOST_DEVICE + TensorCWHN(Stride const &stride = Stride(0)): stride_(stride) { } + + /// Constructor + CUTLASS_HOST_DEVICE + TensorCWHN( + typename Stride::Index stride_h, ///< number of elements between adjacent N coordinates + typename Stride::Index stride_w, ///< number of elements between adjacent C coordinates + typename Stride::Index stride_c ///< number of elements between adjacent W coordinates + ): + stride_(make_Coord(stride_h, stride_w, stride_c)) { } + + /// Constructor + // Once convolutions implement 64b stride this ctor can be deleted + CUTLASS_HOST_DEVICE + TensorCWHN(Coord const &stride): + stride_(make_Coord( + static_cast(stride[0]), + static_cast(stride[1]), + static_cast(stride[2])) + ) { } + + /// Helper returns a layout to a tightly packed WCNH tensor. + CUTLASS_HOST_DEVICE + static TensorCWHN packed(TensorCoord const &extent) { + return TensorCWHN( + make_Coord( + extent.n(), + extent.h() * extent.n(), + extent.w() * extent.h() * extent.n() + ) + ); + } + + /// Returns the offset of a coordinate (n, h, w, c) in linear memory. + CUTLASS_HOST_DEVICE + LongIndex operator()(TensorCoord const &coord) const { + return coord.n() + + LongIndex(stride_[0] * coord.h()) + + LongIndex(stride_[1] * coord.w()) + + LongIndex(stride_[2] * coord.c()); + } + + /// Returns the offset of a pitchlinear coordinate in linear memory. 
+ CUTLASS_HOST_DEVICE + LongIndex operator()(PitchLinearCoord coord) const { + return coord.contiguous() + LongIndex(coord.strided() * stride_[2]); + } + + /// Returns the stride of the layout + CUTLASS_HOST_DEVICE + Stride stride() const { + return stride_; + } + + /// Returns the stride of the layout + CUTLASS_HOST_DEVICE + Stride & stride() { + return stride_; + } + + /// Compute the number of contiguous elements needed to store a tensor with the given size + CUTLASS_HOST_DEVICE + LongIndex capacity(TensorCoord const &extent) const { + // it does not make sense if the extent is larger than stride + // and we could not rely on the capacity calculation in such cases + // we could move this checkers to debug code only + if ((extent.n() > stride_[0]) + || (extent.h() * stride_[0] > stride_[1]) + || (extent.w() * stride_[1] > stride_[2])) { + assert(0); + } + return extent.c() * stride_[2]; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Mapping function for 4-D NHCW tensors. +class TensorNHCW { +public: + /// Logical rank of tensor + static int const kRank = 4; + + /// Rank of stride vector + static int const kStrideRank = 3; + + /// Index type used for coordinates + using Index = int32_t; + + /// Long index type used for offsets + using LongIndex = int64_t; + + /// Logical coordinate (n, h, w, c) + using TensorCoord = Tensor4DCoord; + + /// Stride vector + using Stride = Coord; + +private: + // + // Data members + // + + /// Stride data member - [w, cw, hcw] + Stride stride_; + +public: + // + // Methods + // + + /// Constructor + CUTLASS_HOST_DEVICE + TensorNHCW(Stride const &stride = Stride(0)): stride_(stride) { } + + /// Constructor + CUTLASS_HOST_DEVICE + TensorNHCW( + typename Stride::Index stride_c, ///< number of elements between adjacent C coordinates + typename Stride::Index stride_h, ///< number of elements between adjacent H coordinates + typename Stride::Index stride_n ///< number of elements between adjacent N coordinates + ): + stride_(make_Coord(stride_c, stride_h, stride_n)) { } + + /// Constructor + // Once convolutions implement 64b stride this ctor can be deleted + CUTLASS_HOST_DEVICE + TensorNHCW(Coord const &stride): + stride_(make_Coord( + static_cast(stride[0]), + static_cast(stride[1]), + static_cast(stride[2])) + ) { } + + /// Helper returns a layout to a tightly packed WCNH tensor. + CUTLASS_HOST_DEVICE + static TensorNHCW packed(TensorCoord const &extent) { + return TensorNHCW( + make_Coord( + extent.w(), + extent.c() * extent.w(), + extent.h() * extent.c() * extent.w() + ) + ); + } + + /// Returns the offset of a coordinate (n, h, w, c) in linear memory. + CUTLASS_HOST_DEVICE + LongIndex operator()(TensorCoord const &coord) const { + return coord.w() + + LongIndex(stride_[0] * coord.c()) + + LongIndex(stride_[1] * coord.h()) + + LongIndex(stride_[2] * coord.n()); + } + + /// Returns the offset of a pitchlinear coordinate in linear memory. 
+ CUTLASS_HOST_DEVICE + LongIndex operator()(PitchLinearCoord coord) const { + return coord.contiguous() + LongIndex(coord.strided() * stride_[2]); + } + + /// Returns the stride of the layout + CUTLASS_HOST_DEVICE + Stride stride() const { + return stride_; + } + + /// Returns the stride of the layout + CUTLASS_HOST_DEVICE + Stride & stride() { + return stride_; + } + + /// Compute the number of contiguous elements needed to store a tensor with the given size + CUTLASS_HOST_DEVICE + LongIndex capacity(TensorCoord const &extent) const { + // it does not make sense if the extent is larger than stride + // and we could not rely on the capacity calculation in such cases + // we could move this checkers to debug code only + if ((extent.w() > stride_[0]) + || (extent.c() * stride_[0] > stride_[1]) + || (extent.h() * stride_[1] > stride_[2])) { + assert(0); + } + return extent.n() * stride_[2]; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Mapping function for 4-D NHCW tensors. +class TensorNCWH { +public: + /// Logical rank of tensor + static int const kRank = 4; + + /// Rank of stride vector + static int const kStrideRank = 3; + + /// Index type used for coordinates + using Index = int32_t; + + /// Long index type used for offsets + using LongIndex = int64_t; + + /// Logical coordinate (n, h, w, c) + using TensorCoord = Tensor4DCoord; + + /// Stride vector + using Stride = Coord; + +private: + // + // Data members + // + + /// Stride data member - [h, wh, cwh] + Stride stride_; + +public: + // + // Methods + // + + /// Constructor + CUTLASS_HOST_DEVICE + TensorNCWH(Stride const &stride = Stride(0)): stride_(stride) { } + + /// Constructor + CUTLASS_HOST_DEVICE + TensorNCWH( + typename Stride::Index stride_w, ///< number of elements between adjacent C coordinates + typename Stride::Index stride_c, ///< number of elements between adjacent H coordinates + typename Stride::Index stride_n ///< number of elements between adjacent N coordinates + ): + stride_(make_Coord(stride_w, stride_c, stride_n)) { } + + /// Constructor + // Once convolutions implement 64b stride this ctor can be deleted + CUTLASS_HOST_DEVICE + TensorNCWH(Coord const &stride): + stride_(make_Coord( + static_cast(stride[0]), + static_cast(stride[1]), + static_cast(stride[2])) + ) { } + + /// Helper returns a layout to a tightly packed WCNH tensor. + CUTLASS_HOST_DEVICE + static TensorNCWH packed(TensorCoord const &extent) { + return TensorNCWH( + make_Coord( + extent.h(), + extent.w() * extent.h(), + extent.c() * extent.w() * extent.h() + ) + ); + } + + /// Returns the offset of a coordinate (n, h, w, c) in linear memory. + CUTLASS_HOST_DEVICE + LongIndex operator()(TensorCoord const &coord) const { + return coord.h() + + LongIndex(stride_[0] * coord.w()) + + LongIndex(stride_[1] * coord.c()) + + LongIndex(stride_[2] * coord.n()); + } + + /// Returns the offset of a pitchlinear coordinate in linear memory. 
+ CUTLASS_HOST_DEVICE + LongIndex operator()(PitchLinearCoord coord) const { + return coord.contiguous() + LongIndex(coord.strided() * stride_[2]); + } + + /// Returns the stride of the layout + CUTLASS_HOST_DEVICE + Stride stride() const { + return stride_; + } + + /// Returns the stride of the layout + CUTLASS_HOST_DEVICE + Stride & stride() { + return stride_; + } + + /// Compute the number of contiguous elements needed to store a tensor with the given size + CUTLASS_HOST_DEVICE + LongIndex capacity(TensorCoord const &extent) const { + // it does not make sense if the extent is larger than stride + // and we could not rely on the capacity calculation in such cases + // we could move this checkers to debug code only + if ((extent.h() > stride_[0]) + || (extent.w() * stride_[0] > stride_[1]) + || (extent.c() * stride_[1] > stride_[2])) { + assert(0); + } + return extent.n() * stride_[2]; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Mapping function for 5-D CWHDN tensors. +class TensorCWHDN { +public: + /// Logical rank of tensor + static int const kRank = 5; + + /// Rank of stride vector + static int const kStrideRank = 4; + + /// Index type used for coordinates + using Index = int32_t; + + /// Long index type used for offsets + using LongIndex = int64_t; + + /// Logical coordinate (n, d, h, w, c) + using TensorCoord = Tensor5DCoord; + + /// Stride vector + using Stride = Coord; + +private: + // + // Data members + // + + /// Stride data member - [n, dn, hdn, whdn] + Stride stride_; + +public: + // + // Methods + // + + /// Constructor + CUTLASS_HOST_DEVICE + TensorCWHDN(Stride const &stride = Stride(0)): stride_(stride) { } + + /// Constructor + CUTLASS_HOST_DEVICE + TensorCWHDN( + typename Stride::Index n, + typename Stride::Index dn, + typename Stride::Index hdn, + typename Stride::Index whdn): + stride_(make_Coord(n, dn, hdn, whdn)) { } + + /// Constructor + // Once convolutions implement 64b stride this ctor can be deleted + CUTLASS_HOST_DEVICE + TensorCWHDN(Coord const &stride): + stride_(make_Coord( + static_cast(stride[0]), + static_cast(stride[1]), + static_cast(stride[2]), + static_cast(stride[3])) + ) { } + + /// Helper returns a layout to a tightly packed CWHDN tensor. + CUTLASS_HOST_DEVICE + static TensorCWHDN packed(TensorCoord const &extent) { + return TensorCWHDN( + make_Coord( + extent.n(), + extent.d() * extent.n(), + extent.h() * extent.d() * extent.n(), + extent.w() * extent.h() * extent.d() * extent.n() + ) + ); + } + + /// Returns the offset of a coordinate (n, d, h, w, c) in linear memory. + CUTLASS_HOST_DEVICE + LongIndex operator()(TensorCoord const &coord) const { + return coord.n() + + LongIndex(stride_[0] * coord.d()) + + LongIndex(stride_[1] * coord.h()) + + LongIndex(stride_[2] * coord.w()) + + LongIndex(stride_[3] * coord.c()); + } + + /// Returns the offset of a pitchlinear coordinate in linear memory. 
+ CUTLASS_HOST_DEVICE + LongIndex operator()(PitchLinearCoord coord) const { + return coord.contiguous() + LongIndex(coord.strided() * stride_[3]); + } + + /// Returns the stride of the layout + CUTLASS_HOST_DEVICE + Stride stride() const { + return stride_; + } + + /// Returns the stride of the layout + CUTLASS_HOST_DEVICE + Stride & stride() { + return stride_; + } + + /// Compute the number of contiguous elements needed to store a tensor with the given size + CUTLASS_HOST_DEVICE + LongIndex capacity(TensorCoord const &extent) const { + // it does not make sense if the extent is larger than stride + // and we could not rely on the capacity calculation in such cases + // we could move this checkers to debug code only + if ((extent.n() > stride_[0]) + || (extent.d() * stride_[0] > stride_[1]) + || (extent.h() * stride_[1] > stride_[2]) + || (extent.w() * stride_[2] > stride_[3])) { + assert(0); + } + return extent.c() * stride_[3]; + } +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace layout +} // namespace cutlass diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/39_gemm_permute/permute_info.h b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/39_gemm_permute/permute_info.h new file mode 100644 index 0000000000000000000000000000000000000000..6baf635a8d5420ec736ea01eceba672b9df2e954 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/39_gemm_permute/permute_info.h @@ -0,0 +1,344 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Contains additional metadata about layout permute functions used in the example. 
+*/ + +#include "cutlass/tensor_coord.h" +#include "cutlass/layout/permute.h" + +/// Additional permutation metadata to facilitate testing/printing +template +struct PermuteInfo; + +/// Specialization for default case (no permute). Other specializations must follow this template. +template<> +struct PermuteInfo { + + /// Whether this is a BMM or GEMM permutation (NoPermute can actually be either) + static bool constexpr kBatched = false; + + /// Minimal divisor for row extent + static int constexpr kRowFactor = 1; + + /// Minimum divisor for column extent + static int constexpr kColumnFactor = 1; + + /// Minimum divisor for batch size dimension + static int constexpr kBatchFactor = 1; + + /// Tensor layout used in permutation operation + using Layout = cutlass::layout::PackedVectorLayout; + + static std::string name() { + return "NoPermute"; + } + + /// User-friendly description of the permute operation + static std::string desc() { + return "no permutation"; + } + + /// Infer original higher-rank tensor shape from GEMM/BMM matrix extents. + /// For direct (output) permutations, must be a simple reshape of extent. + /// For inverse (input) permutations, must return shape *before* permute operation. + /// In case of NoPermute, simply use a linear (rank 1) view of the memory + static Layout::TensorCoord original_shape(cutlass::MatrixCoord extent, int batch_count) { + return Layout::TensorCoord(extent.row() * extent.column() * batch_count); + } + + /// Compute the permuted higher-rank tensor shape from the original shape. + static Layout::TensorCoord permute(Layout::TensorCoord const &s) { + return s; + } +}; + +template +struct PermuteInfo> { + + static bool constexpr kBatched = true; + static int constexpr kRowFactor = 1; + static int constexpr kColumnFactor = 1; + static int constexpr kBatchFactor = D1; + + using Layout = cutlass::layout::TensorNHWC; + + static std::string name() { + return "Tensor4DPermuteBMM0213<" + std::to_string(D1) + ">"; + } + + static std::string desc() { + return "batched GEMM permutation [0, 2, 1, 3]"; + } + + static Layout::TensorCoord original_shape(cutlass::MatrixCoord extent, int batch_count) { + int D0 = batch_count / D1; + int D2 = extent.row(); + int D3 = extent.column(); + return {D0, D1, D2, D3}; + } + + static Layout::TensorCoord permute(Layout::TensorCoord const &s) { + return {s[0], s[2], s[1], s[3]}; + } +}; + +template +struct PermuteInfo> +: public PermuteInfo> { + + static bool constexpr kBatched = true; + static int constexpr kRowFactor = 1; + static int constexpr kColumnFactor = D1; + static int constexpr kBatchFactor = 1; + + using Base = PermuteInfo>; + using Layout = typename Base::Layout; + + static typename Layout::TensorCoord original_shape(cutlass::MatrixCoord extent, int batch_count) { + int D0 = batch_count; + int D2 = extent.row(); + int D3 = extent.column() / D1; + return {D0, D1, D2, D3}; + } +}; + +template +struct PermuteInfo> { + + static bool constexpr kBatched = true; + static int constexpr kRowFactor = 1; + static int constexpr kColumnFactor = 1; + static int constexpr kBatchFactor = D1; + + using Layout = cutlass::layout::TensorNHCW; + + static std::string name() { + return "Tensor4DPermuteBMM0321<" + std::to_string(D1) + ">"; + } + + static std::string desc() { + return "batched GEMM permutation [0, 3, 2, 1]"; + } + + static Layout::TensorCoord original_shape(cutlass::MatrixCoord extent, int batch_count) { + int D0 = batch_count / D1; + int D2 = extent.row(); + int D3 = extent.column(); + return {D0, D1, D2, D3}; + } + + 
static Layout::TensorCoord permute(Layout::TensorCoord const &s) { + return {s[0], s[3], s[2], s[1]}; + } +}; + +template +struct PermuteInfo> +: public PermuteInfo> { + + static bool constexpr kBatched = true; + static int constexpr kRowFactor = D1; + static int constexpr kColumnFactor = 1; + static int constexpr kBatchFactor = 1; + + using Base = PermuteInfo>; + using Layout = typename Base::Layout; + + static typename Layout::TensorCoord original_shape(cutlass::MatrixCoord extent, int batch_count) { + int D0 = batch_count; + int D2 = extent.row() / D1; + int D3 = extent.column(); + return {D0, D1, D2, D3}; + } +}; + +template +struct PermuteInfo> { + + static bool constexpr kBatched = false; + static int constexpr kRowFactor = D1; + static int constexpr kColumnFactor = D2; + static int constexpr kBatchFactor = 1; + + using Layout = cutlass::layout::TensorNHWC; + + static std::string name() { + return "Tensor4DPermute0213<" + std::to_string(D1) + "," + std::to_string(D2) + ">"; + } + + static std::string desc() { + return "normal GEMM permutation [0, 2, 1, 3]"; + } + + static Layout::TensorCoord original_shape(cutlass::MatrixCoord extent, int batch_count) { + int D0 = extent.row() / D1; + int D3 = extent.column() / D2; + return {D0, D1, D2, D3}; + } + + static Layout::TensorCoord permute(Layout::TensorCoord const &s) { + return {s[0], s[2], s[1], s[3]}; + } +}; + +template +struct PermuteInfo> +: public PermuteInfo> { + + static bool constexpr kBatched = false; + static int constexpr kRowFactor = D2; + static int constexpr kColumnFactor = D1; + static int constexpr kBatchFactor = 1; + + using Base = PermuteInfo>; + using Layout = typename Base::Layout; + + static typename Layout::TensorCoord original_shape(cutlass::MatrixCoord extent, int batch_count) { + int D0 = extent.row() / D2; + int D3 = extent.column() / D1; + return {D0, D1, D2, D3}; + } +}; + +template +struct PermuteInfo> +: public PermuteInfo> { + using Layout = cutlass::layout::TensorCWHN; +}; + +template +struct PermuteInfo> +: public PermuteInfo> { + using Layout = cutlass::layout::TensorCWHN; +}; + +template +struct PermuteInfo> { + + static bool constexpr kBatched = false; + static int constexpr kRowFactor = T1; + static int constexpr kColumnFactor = T2 * T3; + static int constexpr kBatchFactor = 1; + + using Layout = cutlass::layout::TensorNDHWC; + + static std::string name() { + return "Tensor5DPermute20314<" + std::to_string(T1) + "," + std::to_string(T2) + "," + std::to_string(T3) + ">"; + } + + static std::string desc() { + return "normal GEMM permutation [2, 0, 3, 1, 4]"; + } + + static Layout::TensorCoord original_shape(cutlass::MatrixCoord extent, int batch_count) + { + int const T0 = extent.row() / T1; + int const T4 = extent.column() / (T2 * T3); + return {T0, T1, T2, T3, T4}; + } + + static Layout::TensorCoord permute(Layout::TensorCoord const &s) + { + return {s[2], s[0], s[3], s[1], s[4]}; + } +}; + +template +struct PermuteInfo> +: public PermuteInfo> { + + static bool constexpr kBatched = false; + static int constexpr kRowFactor = T2; + static int constexpr kColumnFactor = T1 * T3; + static int constexpr kBatchFactor = 1; + + using Base = PermuteInfo>; + using Layout = typename Base::Layout; + + static typename Layout::TensorCoord original_shape(cutlass::MatrixCoord extent, int batch_count) { + int const T0 = extent.row() / T2; + int const T4 = extent.column() / (T1 * T3); + return {T0, T1, T2, T3, T4}; + } +}; + +template +struct PermuteInfo> { + + static bool constexpr kBatched = false; + static int 
constexpr kRowFactor = T1; + static int constexpr kColumnFactor = T2 * T3; + static int constexpr kBatchFactor = 1; + + using Layout = cutlass::layout::TensorCWHDN; + + static std::string name() { + return "Tensor5DPermute02413<" + std::to_string(T1) + "," + std::to_string(T2) + "," + std::to_string(T3) + ">"; + } + + static std::string desc() { + return "normal GEMM permutation [0, 2, 4, 1, 3]"; + } + + using Coord = cutlass::Tensor5DCoord; + + static Layout::TensorCoord original_shape(cutlass::MatrixCoord extent, int batch_count) + { + int const T0 = extent.row() / T1; + int const T4 = extent.column() / (T2 * T3); + return {T0, T1, T2, T3, T4}; + } + + static Layout::TensorCoord permute(Layout::TensorCoord const &s) + { + return {s[0], s[2], s[4], s[1], s[3]}; + } +}; + +template +struct PermuteInfo> +: public PermuteInfo> { + + static bool constexpr kBatched = false; + static int constexpr kRowFactor = T2; + static int constexpr kColumnFactor = T1 * T3; + static int constexpr kBatchFactor = 1; + + using Base = PermuteInfo>; + using Layout = typename Base::Layout; + + static typename Layout::TensorCoord original_shape(cutlass::MatrixCoord extent, int batch_count) { + int const T0 = extent.row() / T2; + int const T4 = extent.column() / (T1 * T3); + return {T0, T1, T2, T3, T4}; + } +}; diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/40_cutlass_py/conv2d.py b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/40_cutlass_py/conv2d.py new file mode 100644 index 0000000000000000000000000000000000000000..cd11c74b44e899cec472d4cfb9b3a73a1e2dcd20 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/40_cutlass_py/conv2d.py @@ -0,0 +1,177 @@ +################################################################################ +# +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+# +################################################################################ +""" +Basic example of using the CUTLASS Python interface to run a 2d convolution +""" + +import sys +print("This example is deprecated. Please see examples/python for examples of using " + "the CUTLASS Python interface.") +sys.exit(0) + +import argparse +import numpy as np +import torch + +import cutlass_bindings +import cutlass.backend as pycutlass +from cutlass.backend import * +from cutlass.backend.utils.reference_model import Conv2dReferenceModule +from cutlass.backend.utils.device import device_cc + + +parser = argparse.ArgumentParser( + description=("Launch a 2d convolution kernel from Python. " + "See https://docs.nvidia.com/deeplearning/performance/dl-performance-convolutional/index.html#convo-intro for notation.")) +parser.add_argument("--n", default=1, type=int, help="N dimension of the convolution") +parser.add_argument("--c", default=64, type=int, help="C dimension of the convolution") +parser.add_argument("--h", default=32, type=int, help="H dimension of the convolution") +parser.add_argument("--w", default=32, type=int, help="W dimension of the convolution") +parser.add_argument("--k", default=32, type=int, help="K dimension of the convolution") +parser.add_argument("--r", default=3, type=int, help="R dimension of the convolution") +parser.add_argument("--s", default=3, type=int, help="S dimension of the convolution") +parser.add_argument('--print_cuda', action="store_true", help="Print the underlying CUDA kernel") + +try: + args = parser.parse_args() +except: + sys.exit(0) + +# Check that the device is of a sufficient compute capability +cc = device_cc() +assert cc >= 70, "The CUTLASS Python Conv2d example requires compute capability greater than or equal to 70."
+ +alignment = 1 + +np.random.seed(0) + +# Allocate a pool of device memory to be used by the kernel +pycutlass.get_memory_pool(init_pool_size=2**30, max_pool_size=2**32) + +# Set the compiler to use to NVCC +pycutlass.compiler.nvcc() + +# Set up A, B, C and accumulator +A = TensorDescription(cutlass_bindings.float16, cutlass_bindings.TensorNHWC, alignment) +B = TensorDescription(cutlass_bindings.float16, cutlass_bindings.TensorNHWC, alignment) +C = TensorDescription(cutlass_bindings.float32, cutlass_bindings.TensorNHWC, alignment) +element_acc = cutlass_bindings.float32 +element_epilogue = cutlass_bindings.float32 + +# Select instruction shape based on the Tensor Core instructions supported +# by the device on which we are running +if cc == 70: + instruction_shape = [8, 8, 4] +elif cc == 75: + instruction_shape = [16, 8, 8] +else: + # Use CUTLASS kernels for CC 80 by default (e.g., for cases in which SM86 is used) + cc = 80 + instruction_shape = [16, 8, 16] + +math_inst = MathInstruction( + instruction_shape, + A.element, B.element, element_acc, + cutlass_bindings.OpClass.TensorOp, + MathOperation.multiply_add +) + +tile_description = TileDescription( + [128, 128, 32], # Threadblock shape + 2, # Number of stages + [2, 2, 1], # Number of warps within each dimension of the threadblock shape + math_inst +) + +epilogue_functor = pycutlass.LinearCombination(C.element, C.alignment, element_acc, element_epilogue) + +operation = Conv2dOperation( + conv_kind=cutlass_bindings.conv.Operator.fprop, + iterator_algorithm=cutlass_bindings.conv.IteratorAlgorithm.optimized, + arch=cc, tile_description=tile_description, + A=A, B=B, C=C, stride_support=StrideSupport.Unity, + epilogue_functor=epilogue_functor +) + +if args.print_cuda: + print(operation.rt_module.emit()) + +operations = [operation, ] + +# Compile the operation +pycutlass.compiler.add_module(operations) + +# Randomly initialize tensors + +problem_size = cutlass_bindings.conv.Conv2dProblemSize( + cutlass_bindings.Tensor4DCoord(args.n, args.h, args.c, args.w), + cutlass_bindings.Tensor4DCoord(args.k, args.r, args.s, args.c), + cutlass_bindings.Tensor4DCoord(0, 0, 0, 0), # Padding + cutlass_bindings.MatrixCoord(1, 1), # Strides + cutlass_bindings.MatrixCoord(1, 1), # Dilation + cutlass_bindings.conv.Mode.cross_correlation, + 1, # Split k slices + 1 # Groups +) + +tensor_A_size = cutlass_bindings.conv.implicit_gemm_tensor_a_size(operation.conv_kind, problem_size) +tensor_B_size = cutlass_bindings.conv.implicit_gemm_tensor_b_size(operation.conv_kind, problem_size) +tensor_C_size = cutlass_bindings.conv.implicit_gemm_tensor_c_size(operation.conv_kind, problem_size) + +tensor_A = torch.ceil(torch.empty(size=(tensor_A_size,), dtype=torch.float16, device="cuda").uniform_(-8.5, 7.5)) +tensor_B = torch.ceil(torch.empty(size=(tensor_B_size,), dtype=torch.float16, device="cuda").uniform_(-8.5, 7.5)) +tensor_C = torch.ceil(torch.empty(size=(tensor_C_size,), dtype=torch.float32, device="cuda").uniform_(-8.5, 7.5)) +tensor_D = torch.ones(size=(tensor_C_size,), dtype=torch.float32, device="cuda") + +alpha = 1. +beta = 0. 
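Before the kernel arguments are assembled below, it may help to see how the flat tensor sizes queried above through implicit_gemm_tensor_a/b/c_size relate to the convolution problem. The helper below is hypothetical, a sketch of standard fprop convolution arithmetic and the usual implicit-GEMM mapping, not the cutlass_bindings implementation:

    def fprop_extents(n, h, w, c, k, r, s, pad=(0, 0), stride=(1, 1), dilation=(1, 1)):
        # Output spatial extent of a cross-correlation (same convention as the problem size above).
        p = (h + 2 * pad[0] - dilation[0] * (r - 1) - 1) // stride[0] + 1
        q = (w + 2 * pad[1] - dilation[1] * (s - 1) - 1) // stride[1] + 1
        size_a = n * h * w * c                 # activations, NHWC
        size_b = k * r * s * c                 # filters, KRSC
        size_c = n * p * q * k                 # output, NPQK
        gemm_mnk = (n * p * q, k, r * s * c)   # fprop viewed as an implicit GEMM
        return size_a, size_b, size_c, gemm_mnk

    # With the defaults used here (N=1, C=64, H=W=32, K=32, R=S=3, zero padding,
    # unit stride and dilation) this is a 900 x 32 x 576 implicit GEMM.
    print(fprop_extents(1, 32, 32, 64, 32, 3, 3))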
+ +arguments = Conv2dArguments( + operation=operation, problem_size=problem_size, + A=tensor_A, B=tensor_B, C=tensor_C, D=tensor_D, + output_op=operation.epilogue_type(alpha, beta) +) + +# Run the operation +operation.run(arguments) +arguments.sync() + +# Run the host reference module and compare to the CUTLASS result +reference = Conv2dReferenceModule(A, B, C, operation.conv_kind) +tensor_D_ref = reference.run(tensor_A, tensor_B, tensor_C, problem_size, alpha, beta) + +try: + assert torch.equal(tensor_D, tensor_D_ref) +except: + assert torch.allclose(tensor_D, tensor_D_ref, rtol=1e-2) + +print("Passed.") diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/40_cutlass_py/customizable/conv2d.py b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/40_cutlass_py/customizable/conv2d.py new file mode 100644 index 0000000000000000000000000000000000000000..d2df3ed8d1f43cebecdd14240e5b819e184885f7 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/40_cutlass_py/customizable/conv2d.py @@ -0,0 +1,331 @@ +################################################################################ +# +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################ + +import sys +print("This example is deprecated. 
Please see examples/python for examples of using " + "the CUTLASS Python interface.") +sys.exit(0) + +import numpy as np +import cutlass.backend as pycutlass +from cutlass.backend import * +from cutlass.backend.utils.device import device_cc +from cutlass.backend.conv2d_operation import * +from cutlass.backend.utils.reference_model import Conv2dReferenceModule +import torch.nn.functional as F + +import argparse + +# parse the arguments +parser = argparse.ArgumentParser(description="Launch CUTLASS convolution 2d kernels from Python") + +# Operation description +# math instruction description +parser.add_argument("-i", "--instruction_shape", + default=[1, 1, 1], nargs=3, type=int, + help="This option describes the size of MMA op") +parser.add_argument("-ta", "--element_a", default="float32", type=str, + choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'], + help='Data type of elements in input tensor A') +parser.add_argument("-tb", "--element_b", default="float32", type=str, + choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'], + help='Data type of elements in input tensor B') +parser.add_argument("-tc", "--element_c", default="float32", type=str, + choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'], + help='Data type of elements in input tensor C and output tensor D') +parser.add_argument("-tacc", "--element_acc", default="float32", type=str, + choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'], + help='Data type of accumulator') +parser.add_argument('-m', "--math", default="multiply_add", + type=str, choices=["multiply_add", "multiply_add_fast_bf16", "multiply_add_fast_f32"], help="math instruction") +parser.add_argument('-op', "--opcode", default="Simt", type=str, + choices=["Simt", 'TensorOp'], + help='This option describes whether you want to use tensor \ + cores (TensorOp) or regular SIMT cores (Simt) on GPU SM') +# tile description +parser.add_argument("-b", "--threadblock_shape", + default=[128, 128, 8], nargs=3, type=int, + help="This option describes the tile size a thread block with compute") +parser.add_argument("-s", "--stages", default=4, + type=int, help="Number of pipelines you want to use") +parser.add_argument("-w", "--warp_count", default=[ + 4, 2, 1], nargs=3, type=int, + help="This option describes the number of warps along M, N, and K of the threadblock") +parser.add_argument("-cc", "--compute_capability", default=80, + type=int, help="This option describes CUDA SM architecture number") +# A +parser.add_argument('-la', "--layout_a", default="TensorNHWC", type=str, choices=[ + "TensorNHWC", "TensorNC32HW32"], + help="Memory layout of input tensor A") +parser.add_argument('-aa', '--alignment_a', default=1, + type=int, help="Memory alignment of input tensor A") +# B +parser.add_argument('-lb', "--layout_b", default="TensorNHWC", type=str, choices=[ + "TensorNHWC", "TensorC32RSK32"], + help="Memory layout of input tensor B") +parser.add_argument('-ab', '--alignment_b', default=1, + type=int, help="Memory alignment of input tensor B") +# C +parser.add_argument('-lc', "--layout_c", default="TensorNHWC", type=str, choices=[ + "TensorNHWC", "TensorNC32HW32"], + help="Memory layout of input tensor C and output tensor D") +parser.add_argument('-ac', '--alignment_c', default=1, + type=int, help="Memory alignment of input tensor C and output tensor D") +# epilogue +parser.add_argument("-te", "--element_epilogue", default="float32", type=str, + choices=['float64', 'float32', 'float16', 'bfloat16'], + 
help='Data type of computation in the epilogue') +parser.add_argument("-ep", "--epilogue_functor", default="LinearCombination", + type=str, choices=['LinearCombination', 'FastLinearCombinationClamp', 'LinearCombinationClamp'], + help="This option describes the epilogue part of the kernel") +# swizzling +parser.add_argument("-sw", "--swizzling_functor", default="IdentitySwizzle1", type=str, choices=[ + "IdentitySwizzle1", "IdentitySwizzle2", "IdentitySwizzle4", "IdentitySwizzle8", + "HorizontalSwizzle", "StridedDgradIdentitySwizzle1", "StridedDgradIdentitySwizzle4", + "StridedDgradHorizontalSwizzle"], + help="This option describes how thread blocks are scheduled on GPU") +# conv related +parser.add_argument("-co", "--conv_kind", default="fprop", type=str, choices=['fprop', 'dgrad', 'wgrad'], + help="The type of convolution: forward propagation (fprop), \ + gradient of activation (dgrad), gradient of weight (wgrad)") +parser.add_argument("-st", "--stride_support", default="Strided", type=str, choices=["Strided", "Unity"], + ) +parser.add_argument("-ia", "--iterator_algorithm", default="analytic", type=str, + choices=["analytic", "optimized", "fixed_channels", "few_channels"], + help="This option describes iterator algorithm") + +# arguments +parser.add_argument("-sm", "--split_k_mode", default="Serial", type=str, choices=["Serial", "Parallel"], + help="Split K Mode. Serial is used for non-splitK or serial-splitK.\ + Parallel is used for parallel splitK.") +parser.add_argument('-k', '--split_k_slices', default=1, + type=int, help="Number of split-k partitions. (default 1)") +parser.add_argument("-nhwc", "--nhwc", nargs=4, type=int, help="input size (NHWC)") +parser.add_argument("-krsc", "--krsc", nargs=4, type=int, help="filter size (KRSC)") +parser.add_argument("-pad", "--pad", nargs=4, type=int, help="padding (pad_h, _, pad_w, _)") +parser.add_argument("-stride", "--stride", nargs=2, type=int, help="stride (stride_h, stride_w)") +parser.add_argument("-dilation", "--dilation", nargs=2, type=int, help="dilation (dilation_h, dilation_w)") +parser.add_argument("-alpha", "--alpha", default=1.0, type=float, help="alpha") +parser.add_argument("-beta", "--beta", default=0.0, type=float, help="beta") +parser.add_argument('-bias', '--bias', action='store_true', help="C is bias vector") +# Activation function +parser.add_argument("-activ", "--activation_function", default="identity", + choices=["identity", "relu", "leaky_relu", "tanh", "sigmoid", "silu", "hardswish", "gelu"], help="activation function") +parser.add_argument("-activ_arg", "--activation_args", default=[], nargs="+", type=float, + help="addition arguments for activation") + + +parser.add_argument('--print_cuda', action="store_true", + help="print the underlying CUDA kernel") + +try: + args = parser.parse_args() +except: + sys.exit(0) + +cc = device_cc() +if args.compute_capability != cc: + raise Exception(("Parameter --compute-capability of {} " + "does not match that of the device of {}.").format(args.compute_capability, cc)) + +pycutlass.get_memory_pool(init_pool_size=2**30, max_pool_size=2**32) + +np.random.seed(0) + +element_a = getattr(cutlass_bindings, args.element_a) +element_b = getattr(cutlass_bindings, args.element_b) +element_c = getattr(cutlass_bindings, args.element_c) +element_acc = getattr(cutlass_bindings, args.element_acc) +math_operation = getattr(MathOperation, args.math) +opclass = getattr(cutlass_bindings.OpClass, args.opcode) + +math_inst = MathInstruction( + args.instruction_shape, element_a, element_b, + 
element_acc, opclass, math_operation +) + +tile_description = TileDescription( + args.threadblock_shape, args.stages, args.warp_count, + math_inst +) + +layout_a = getattr(cutlass_bindings, args.layout_a) +layout_b = getattr(cutlass_bindings, args.layout_b) +layout_c = getattr(cutlass_bindings, args.layout_c) + +A = TensorDescription( + element_a, layout_a, args.alignment_a +) + +B = TensorDescription( + element_b, layout_b, args.alignment_b +) + +C = TensorDescription( + element_c, layout_c, args.alignment_c +) + +element_epilogue = getattr(cutlass_bindings, args.element_epilogue) +if (args.activation_function == "identity" + or (args.split_k_mode == "Parallel" and args.split_k_slices > 1)): + # + epilogue_functor = getattr(pycutlass, args.epilogue_functor)( + C.element, C.alignment, math_inst.element_accumulator, element_epilogue) +else: + epilogue_functor = getattr(pycutlass, "LinearCombinationGeneric")( + getattr(pycutlass, args.activation_function)(element_epilogue), + C.element, C.alignment, math_inst.element_accumulator, element_epilogue) + +iterator_algorithm = getattr(cutlass_bindings.conv.IteratorAlgorithm, args.iterator_algorithm) +swizzling_functor = getattr(cutlass_bindings, args.swizzling_functor) +stride_support = getattr(StrideSupport, args.stride_support) +conv_kind = getattr(cutlass_bindings.conv.Operator, args.conv_kind) + +operation = Conv2dOperation( + conv_kind=conv_kind, iterator_algorithm=iterator_algorithm, + arch=args.compute_capability, tile_description=tile_description, + A=A, B=B, C=C, stride_support=stride_support, + epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor +) + +if args.print_cuda: + print(operation.rt_module.emit()) + +operations = [operation,] + +if args.split_k_mode == "Parallel" and args.split_k_slices > 1: + if (args.activation_function == "identity"): + epilogue_functor_reduction = getattr(pycutlass, args.epilogue_functor)( + C.element, C.alignment, math_inst.element_accumulator, element_epilogue) + else: + epilogue_functor_reduction = getattr(pycutlass, "LinearCombinationGeneric")( + getattr(pycutlass, args.activation_function)(element_epilogue), + C.element, C.alignment, math_inst.element_accumulator, element_epilogue) + reduction_operation = ReductionOperation( + shape=cutlass_bindings.MatrixCoord(4, 32 * C.alignment), + C=C, element_accumulator=element_acc, + element_compute=element_epilogue, + epilogue_functor=epilogue_functor_reduction, + count=C.alignment + ) + operations.append(reduction_operation) + +pycutlass.compiler.add_module(operations) + +problem_size = cutlass_bindings.conv.Conv2dProblemSize( + cutlass_bindings.Tensor4DCoord(args.nhwc[0], args.nhwc[1], args.nhwc[2], args.nhwc[3]), + cutlass_bindings.Tensor4DCoord(args.krsc[0], args.krsc[1], args.krsc[2], args.krsc[3]), + cutlass_bindings.Tensor4DCoord(args.pad[0], args.pad[1], args.pad[2], args.pad[3]), + cutlass_bindings.MatrixCoord(args.stride[0], args.stride[1]), + cutlass_bindings.MatrixCoord(args.dilation[0], args.dilation[1]), + cutlass_bindings.conv.Mode.cross_correlation, + args.split_k_slices, 1 +) + + +# User-provide inputs +tensor_A_size = cutlass_bindings.conv.implicit_gemm_tensor_a_size( + conv_kind, problem_size +) +tensor_B_size = cutlass_bindings.conv.implicit_gemm_tensor_b_size( + conv_kind, problem_size +) +if args.bias: + tensor_C_size = cutlass_bindings.conv.implicit_gemm_tensor_c_extent( + conv_kind, problem_size + ).at(3) +else: + tensor_C_size = cutlass_bindings.conv.implicit_gemm_tensor_c_size( + conv_kind, problem_size + ) + 
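Before the remaining tensors are allocated, a note on the parallel split-K path configured above: each of the split_k_slices partial GEMMs accumulates over a disjoint slice of the implicit-GEMM K extent, and the ReductionOperation then sums those partials and applies the epilogue in a separate pass, which is why the activation function is attached to the reduction's epilogue rather than to the partial GEMMs when split_k_mode is Parallel. A minimal NumPy sketch of the idea, not the CUTLASS kernels themselves:

    import numpy as np

    def parallel_split_k(A, B, C, alpha, beta, slices, activation=lambda x: x):
        # Each slice accumulates a disjoint chunk of the K dimension into its own partial.
        K = A.shape[1]
        bounds = np.linspace(0, K, slices + 1, dtype=int)
        partials = [A[:, lo:hi] @ B[lo:hi, :] for lo, hi in zip(bounds[:-1], bounds[1:])]
        accum = np.add.reduce(partials)              # the separate reduction pass
        return activation(alpha * accum + beta * C)  # epilogue applied once, after reduction

    A, B, C = np.random.rand(64, 96), np.random.rand(96, 32), np.random.rand(64, 32)
    assert np.allclose(parallel_split_k(A, B, C, 1.0, 0.0, slices=4), A @ B)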
+tensor_D_size = cutlass_bindings.conv.implicit_gemm_tensor_c_size( + conv_kind, problem_size + ) + +if args.element_a != "int8": + tensor_A = torch.ceil(torch.empty(size=(tensor_A_size,), dtype=getattr(torch, args.element_a), device="cuda").uniform_(-8.5, 7.5)) +else: + tensor_A = torch.empty(size=(tensor_A_size,), dtype=getattr(torch, args.element_a), device="cuda").uniform_(-2, 2) + +if args.element_b != "int8": + tensor_B = torch.ceil(torch.empty(size=(tensor_B_size,), dtype=getattr(torch, args.element_b), device="cuda").uniform_(-8.5, 7.5)) +else: + tensor_B = torch.empty(size=(tensor_B_size,), dtype=getattr(torch, args.element_b), device="cuda").uniform_(-2, 2) + +if args.element_c != "int8": + tensor_C = torch.ceil(torch.empty(size=(tensor_C_size,), dtype=getattr(torch, args.element_c), device="cuda").uniform_(-8.5, 7.5)) +else: + tensor_C = torch.empty(size=(tensor_C_size,), dtype=getattr(torch, args.element_c), device="cuda").uniform_(-2, 2) + +tensor_D = torch.ones(size=(tensor_D_size,), dtype=getattr(torch, args.element_c), device="cuda") + +arguments = Conv2dArguments( + operation=operation, problem_size=problem_size, A=tensor_A, + B=tensor_B, C=tensor_C, D=tensor_D, + output_op = operation.epilogue_type(*([args.alpha, args.beta] + args.activation_args)), + split_k_mode=getattr(cutlass_bindings.conv.SplitKMode, args.split_k_mode), + split_k_slices=problem_size.split_k_slices +) + +if args.split_k_mode == "Parallel" and args.split_k_slices > 1: + implicit_gemm_size = cutlass_bindings.conv.implicit_gemm_problem_size(conv_kind, arguments.problem_size) + reduction_arguments = ReductionArguments( + reduction_operation, + problem_size=[implicit_gemm_size.m(), implicit_gemm_size.n()], + partitions=problem_size.split_k_slices, + workspace=arguments.ptr_D, + destination=tensor_D, + source=tensor_C, + output_op = reduction_operation.epilogue_type(*([args.alpha, args.beta] + args.activation_args)), + bias = arguments.bias + ) + +operation.run(arguments) + +if args.split_k_mode == "Parallel" and args.split_k_slices > 1: + reduction_operation.run(reduction_arguments) + reduction_arguments.sync() +else: + arguments.sync() + +reference_model = Conv2dReferenceModule(A, B, C, conv_kind) + +tensor_D_ref = reference_model.run(tensor_A, tensor_B, tensor_C, arguments.problem_size, args.alpha, args.beta, args.bias) +if (args.activation_function != "identity"): + tensor_D_ref = getattr(F, args.activation_function)(*([tensor_D_ref,] + args.activation_args)) + +try: + assert torch.equal(tensor_D, tensor_D_ref) +except: + assert torch.allclose(tensor_D, tensor_D_ref, rtol=1e-2) +print("Passed.") diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/40_cutlass_py/customizable/gemm.py b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/40_cutlass_py/customizable/gemm.py new file mode 100644 index 0000000000000000000000000000000000000000..3494fe53f5fe1066f15e73d7b1190ad0b4fe3fdb --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/40_cutlass_py/customizable/gemm.py @@ -0,0 +1,331 @@ +################################################################################ +# +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. 
Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################ + +import sys +print("This example is deprecated. Please see examples/python for examples of using " + "the CUTLASS Python interface.") +sys.exit(0) + +import numpy as np +import cutlass.backend as pycutlass +from cutlass.backend import * +from cutlass.backend.utils.device import device_cc +import cutlass_bindings +from bfloat16 import bfloat16 + +import argparse + + +# parse the arguments +parser = argparse.ArgumentParser(description="Launch CUTLASS GEMM kernels from Python: 'D = alpha * A * B + beta * C'") + +# Operation description +# math instruction description +parser.add_argument("-i", "--instruction_shape", + default=[1, 1, 1], nargs=3, type=int, + help="This option describes the size of MMA op") +parser.add_argument("-ta", "--element_a", default="float32", type=str, + choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'], + help='Data type of elements in input tensor A') +parser.add_argument("-tb", "--element_b", default="float32", type=str, + choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'], + help='Data type of elements in input tensor B') +parser.add_argument("-tc", "--element_c", default="float32", type=str, + choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'], + help='Data type of elements in input tensor C and output tensor D') +parser.add_argument("-tacc", "--element_acc", default="float32", type=str, + choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'], + help='Data type of accumulator') +parser.add_argument('-m', "--math", default="multiply_add", + type=str, choices=["multiply_add", "multiply_add_fast_bf16", "multiply_add_fast_f32"], help="math instruction") +parser.add_argument('-op', "--opcode", default="Simt", type=str, + choices=["Simt", 'TensorOp'], + help="This option describes whether you want to use tensor \ + cores (TensorOp) or regular SIMT cores (Simt) on GPU SM") +# tile description +parser.add_argument("-b", "--threadblock_shape", + default=[128, 128, 8], nargs=3, type=int, + help="This option describes the tile size a thread block with compute") +parser.add_argument("-s", "--stages", 
default=4, + type=int, help="Number of pipelines you want to use") +parser.add_argument("-w", "--warp_count", default=[4, 2, 1], nargs=3, type=int, + help="This option describes the number of warps along M, N, and K of the threadblock") +parser.add_argument("-cc", "--compute_capability", default=80, + type=int, help="This option describes CUDA SM architecture number") +# A +parser.add_argument('-la', "--layout_a", default="RowMajor", type=str, choices=[ + "RowMajor", "ColumnMajor", "RowMajorInterleaved32", "ColumnMajorInterleaved32"], + help="Memory layout of input tensor A") +parser.add_argument('-aa', '--alignment_a', default=1, + type=int, help="Memory alignment of input tensor A") +# B +parser.add_argument('-lb', "--layout_b", default="RowMajor", type=str, choices=[ + "RowMajor", "ColumnMajor", "RowMajorInterleaved32", "ColumnMajorInterleaved32"], + help="Memory layout of input tensor B") +parser.add_argument('-ab', '--alignment_b', default=1, + type=int, help="Memory alignment of input tensor B") +# C +parser.add_argument('-lc', "--layout_c", default="RowMajor", type=str, choices=[ + "RowMajor", "ColumnMajor", "RowMajorInterleaved32", "ColumnMajorInterleaved32"], + help="Memory layout of input tensor C and output tensor D") +parser.add_argument('-ac', '--alignment_c', default=1, + type=int, help="Memory alignment of input tensor C and output tensor D") +# epilogue +parser.add_argument("-te", "--element_epilogue", default="float32", type=str, + choices=['float64', 'float32', 'float16', 'bfloat16'], help='Epilogue datatype') +parser.add_argument("-ep", "--epilogue_functor", default="LinearCombination", + type=str, choices=['LinearCombination', 'FastLinearCombinationClamp', 'LinearCombinationClamp'], + help="This option describes the epilogue part of the kernel") +# swizzling +parser.add_argument("-sw", "--swizzling_functor", default="IdentitySwizzle1", type=str, choices=[ + "IdentitySwizzle1", "IdentitySwizzle2", "IdentitySwizzle4", "IdentitySwizzle8", "HorizontalSwizzle", "BatchedIdentitySwizzle"], + help="This option describes how thread blocks are scheduled on GPU") + +# Argument +parser.add_argument("-p", "--problem_size", + default=[128, 128, 128], nargs=3, type=int, + help="GEMM problem size M, N, K") +parser.add_argument("-alpha", "--alpha", default=1.0, type=float, + help="Scaling factor of A * B") +parser.add_argument("-beta", "--beta", default=0.0, type=float, + help="Scaling factor of C") +parser.add_argument("-gm", "--gemm_mode", default="Gemm", type=str, + choices=["Gemm", "GemmSplitKParallel", "Batched", "Array"], + help="GEMM mode. Gemm is used for non-splitK or serial-splitK. \ + GemmSplitKParallel is used for parallel splitK") +parser.add_argument('-k', '--split_k_slices', default=1, + type=int, help="Number of split-k partitions. 
(default 1)") +parser.add_argument('-bias', '--bias', action='store_true', help="C is bias vector") +parser.add_argument('-batch', '--batch', default=1, type=int, help="batch size for batched GEMM") + +# Activation function +parser.add_argument("-activ", "--activation_function", default="identity", + choices=["identity", "relu", "leaky_relu", "tanh", "sigmoid", "silu", "hardswish", "gelu"], help="activation function") +parser.add_argument("-activ_arg", "--activation_args", default=[], nargs="+", type=float, + help="addition arguments for activation") +parser.add_argument('--print_cuda', action="store_true", + help="print the underlying CUDA kernel") + +try: + args = parser.parse_args() +except: + sys.exit(0) + +cc = device_cc() +if args.compute_capability != cc: + raise Exception(("Parameter --compute-capability of {} " + "does not match that of the device of {}.").format(args.compute_capability, cc)) + +pycutlass.get_memory_pool(init_pool_size=2**30, max_pool_size=2**32) +pycutlass.compiler.nvcc() + +np.random.seed(0) + +element_a = getattr(cutlass_bindings, args.element_a) +element_b = getattr(cutlass_bindings, args.element_b) +element_c = getattr(cutlass_bindings, args.element_c) +element_acc = getattr(cutlass_bindings, args.element_acc) +math_operation = getattr(MathOperation, args.math) +opclass = getattr(cutlass_bindings.OpClass, args.opcode) + +math_inst = MathInstruction( + args.instruction_shape, element_a, element_b, + element_acc, opclass, math_operation +) + +tile_description = TileDescription( + args.threadblock_shape, args.stages, args.warp_count, + math_inst +) + +layout_a = getattr(cutlass_bindings, args.layout_a) +layout_b = getattr(cutlass_bindings, args.layout_b) +layout_c = getattr(cutlass_bindings, args.layout_c) + +A = TensorDescription( + element_a, layout_a, args.alignment_a +) + +B = TensorDescription( + element_b, layout_b, args.alignment_b +) + +C = TensorDescription( + element_c, layout_c, args.alignment_c +) + +element_epilogue = getattr(cutlass_bindings, args.element_epilogue) +if (args.activation_function == "identity" + or (args.gemm_mode == "GemmSplitKParallel" and args.split_k_slices > 1)): + # + epilogue_functor = getattr(pycutlass, args.epilogue_functor)( + C.element, C.alignment, math_inst.element_accumulator, element_epilogue) +else: + epilogue_functor = getattr(pycutlass, "LinearCombinationGeneric")( + getattr(pycutlass, args.activation_function)(element_epilogue), + C.element, C.alignment, math_inst.element_accumulator, element_epilogue) + +swizzling_functor = getattr(cutlass_bindings, args.swizzling_functor) + +operation = GemmOperationUniversal( + arch=args.compute_capability, tile_description=tile_description, + A=A, B=B, C=C, + epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor +) + +if args.print_cuda: + print(operation.rt_module.emit()) + +operations = [operation, ] + +if args.gemm_mode == "GemmSplitKParallel": + if (args.activation_function == "identity"): + epilogue_functor_reduction = getattr(pycutlass, args.epilogue_functor)( + C.element, C.alignment, math_inst.element_accumulator, element_epilogue) + else: + epilogue_functor_reduction = getattr(pycutlass, "LinearCombinationGeneric")( + getattr(pycutlass, args.activation_function)(element_epilogue), + C.element, C.alignment, math_inst.element_accumulator, element_epilogue) + + reduction_operation = ReductionOperation( + shape=cutlass_bindings.MatrixCoord(4, 32 * C.alignment), + C=C, element_accumulator=element_acc, + element_compute=element_epilogue, + 
epilogue_functor=epilogue_functor_reduction, + count=C.alignment + ) + operations.append(reduction_operation) + +pycutlass.compiler.add_module(operations) + +# User-provide inputs + +problem_size = cutlass_bindings.gemm.GemmCoord( + args.problem_size[0], args.problem_size[1], args.problem_size[2]) + +tensor_a_size = args.batch * problem_size.m() * problem_size.k() +if args.element_a != "int8": + if args.element_a == "bfloat16": + tensor_A = np.ceil( + np.random.uniform(low=-8.5, high=7.5, size=(tensor_a_size,)) + ).astype(bfloat16) + else: + tensor_A = np.ceil( + np.random.uniform(low=-8.5, high=7.5, size=(tensor_a_size,)) + ).astype(getattr(np, args.element_a)) +else: + tensor_A = np.random.uniform( + low=-2, high=2,size=(tensor_a_size,) + ).astype(getattr(np, args.element_a)) + +tensor_b_size = args.batch * problem_size.k() * problem_size.n() +if args.element_b != "int8": + if args.element_b == "bfloat16": + tensor_B = np.ceil( + np.random.uniform(low=-8.5, high=7.5, size=(tensor_b_size,)) + ).astype(bfloat16) + else: + tensor_B = np.ceil( + np.random.uniform(low=-8.5, high=7.5, size=(tensor_b_size,)) + ).astype(getattr(np, args.element_b)) +else: + tensor_B = np.random.uniform( + low=-2, high=2, size=(tensor_b_size,) + ).astype(getattr(np, args.element_b)) + +if args.element_c != "int8": + if args.bias: + if args.layout_c == "RowMajor": + tensor_c_size = args.batch * problem_size.n() + elif args.layout_c == "ColumnMajor": + tensor_c_size = args.batch * problem_size.m() + else: + raise ValueError(args.layout_c) + else: + tensor_c_size = args.batch * problem_size.m() * problem_size.n() + if args.element_c == "bfloat16": + tensor_C = np.ceil( + np.random.uniform(low=-8.5, high=7.5, size=(tensor_c_size,)) + ).astype(bfloat16) + else: + tensor_C = np.ceil( + np.random.uniform(low=-8.5, high=7.5, size=(tensor_c_size,)) + ).astype(getattr(np, args.element_c)) +else: + tensor_C = np.random.uniform( + low=-2, high=2, size=(args.batch * problem_size.m() * problem_size.n(),) + ).astype(getattr(np, args.element_c)) + +tensor_D = np.zeros( + shape=(args.batch * problem_size.m() * problem_size.n(),) +).astype(getattr(np, args.element_c)) + +output_op = operation.epilogue_type(*([args.alpha, args.beta] + args.activation_args)) + +arguments = GemmArguments( + operation=operation, problem_size=problem_size, + A=tensor_A, B=tensor_B, C=tensor_C, D=tensor_D, + output_op=output_op, + gemm_mode=getattr(cutlass_bindings.gemm.Mode, args.gemm_mode), + split_k_slices=args.split_k_slices, batch=args.batch +) + +if args.gemm_mode == "GemmSplitKParallel": + reduction_arguments = ReductionArguments( + operation=reduction_operation, + problem_size=[problem_size.m(), problem_size.n()], + partitions=args.split_k_slices, workspace=arguments.ptr_D, + destination=tensor_D, source=tensor_C, + output_op=reduction_operation.epilogue_type(*([args.alpha, args.beta] + args.activation_args)), + bias = arguments.bias + ) + +operation.run(arguments) + +if args.gemm_mode == "GemmSplitKParallel": + reduction_operation.run(reduction_arguments) + reduction_arguments.sync() +else: + arguments.sync() + +# run the host reference module +reference = ReferenceModule(A, B, C) +tensor_D_ref = reference.run( + tensor_A, tensor_B, tensor_C, problem_size, args.alpha, args.beta, args.bias, args.batch) + +tensor_D_ref = getattr(pycutlass, args.activation_function).numpy(*([tensor_D_ref,] + args.activation_args)) + +try: + assert np.array_equal(tensor_D, tensor_D_ref) +except: + assert np.allclose(tensor_D, tensor_D_ref, atol=1e-5) 
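# --- Editor's note: illustrative sketch, not part of the upstream example. ---
# With --gemm_mode GemmSplitKParallel and more than one slice, the GEMM kernel writes one
# partial product per K-slice into the workspace, and the ReductionOperation defined
# earlier sums those partials and applies the epilogue. A rough NumPy picture of what that
# mode computes (hypothetical helper; a and b are 2-D matrices here rather than the flat
# device tensors used above):
def split_k_reference(a, b, c, alpha, beta, slices):
    k = a.shape[1]
    step = (k + slices - 1) // slices
    partials = [a[:, i:i + step] @ b[i:i + step, :] for i in range(0, k, step)]
    return alpha * sum(partials) + beta * c
# --- End editor's note. ---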
+print("Passed.") diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/40_cutlass_py/customizable/gemm_grouped.py b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/40_cutlass_py/customizable/gemm_grouped.py new file mode 100644 index 0000000000000000000000000000000000000000..a3319e60f9a688de49ca4b26a07692794ea71cb3 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/40_cutlass_py/customizable/gemm_grouped.py @@ -0,0 +1,298 @@ +################################################################################ +# +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################ + +import sys +print("This example is deprecated. 
Please see examples/python for examples of using " + "the CUTLASS Python interface.") +sys.exit(0) + +import numpy as np +import cutlass.backend as pycutlass +from cutlass.backend import * +from cutlass.backend.utils.device import device_cc +import csv + +import argparse + +# parse the arguments +parser = argparse.ArgumentParser( + description="Launch CUTLASS GEMM Grouped kernels from Python") + +# Operation description +# math instruction description +parser.add_argument("-i", "--instruction_shape", + default=[1, 1, 1], nargs=3, type=int, + help="This option describes the size of MMA op") +parser.add_argument("-ta", "--element_a", default="float32", type=str, + choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'], + help='Data type of elements in input tensor A') +parser.add_argument("-tb", "--element_b", default="float32", type=str, + choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'], + help='Data type of elements in input tensor B') +parser.add_argument("-tc", "--element_c", default="float32", type=str, + choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'], + help='Data type of elements in input tensor C and output tensor D') +parser.add_argument("-tacc", "--element_acc", default="float32", type=str, + choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'], + help='Data type of accumulator') +parser.add_argument('-m', "--math", default="multiply_add", + type=str, choices=["multiply_add", "multiply_add_fast_bf16", "multiply_add_fast_f32"], help="math instruction") +parser.add_argument('-op', "--opcode", default="Simt", type=str, + choices=["Simt", 'TensorOp'], help='This option describes whether you want to use tensor \ + cores (TensorOp) or regular SIMT cores (Simt) on GPU SM') +# tile description +parser.add_argument("-b", "--threadblock_shape", + default=[128, 128, 8], nargs=3, type=int, + help="This option describes the tile size a thread block with compute") +parser.add_argument("-s", "--stages", default=4, + type=int, help="Number of pipelines you want to use") +parser.add_argument("-w", "--warp_count", default=[ + 4, 2, 1], nargs=3, type=int, + help="This option describes the number of warps along M, N, and K of the threadblock") +parser.add_argument("-cc", "--compute_capability", default=80, + type=int, help="This option describes CUDA SM architecture number") +# A +parser.add_argument('-la', "--layout_a", default="RowMajor", type=str, choices=[ + "RowMajor", "ColumnMajor", "RowMajorInterleaved32", "ColumnMajorInterleaved32"], + help="Memory layout of input tensor A") +parser.add_argument('-aa', '--alignment_a', default=1, + type=int, help="Memory alignment of input tensor A") +# B +parser.add_argument('-lb', "--layout_b", default="RowMajor", type=str, choices=[ + "RowMajor", "ColumnMajor", "RowMajorInterleaved32", "ColumnMajorInterleaved32"], + help="Memory layout of input tensor B") +parser.add_argument('-ab', '--alignment_b', default=1, + type=int, help="Memory alignment of input tensor B") +# C +parser.add_argument('-lc', "--layout_c", default="RowMajor", type=str, choices=[ + "RowMajor", "ColumnMajor", "RowMajorInterleaved32", "ColumnMajorInterleaved32"], + help="Memory layout of input tensor C and output tensor D") +parser.add_argument('-ac', '--alignment_c', default=1, + type=int, help="Memory alignment of input tensor C and output tensor D") +# epilogue +parser.add_argument("-te", "--element_epilogue", default="float32", type=str, + choices=['float64', 'float32', 'float16', 'bfloat16'], 
help='Epilogue datatype') +parser.add_argument("-ep", "--epilogue_functor", default="LinearCombination", + type=str, choices=['LinearCombination', 'FastLinearCombinationClamp', 'LinearCombinationClamp'], + help="This option describes the epilogue part of the kernel") +# swizzling +parser.add_argument("-sw", "--swizzling_functor", default="IdentitySwizzle1", type=str, choices=[ + "IdentitySwizzle1", "IdentitySwizzle2", "IdentitySwizzle4", "IdentitySwizzle8", "HorizontalSwizzle"], + help="This option describes how thread blocks are scheduled on GPU. \ + NOTE: Threadblock swizzling is currently not supported by CUTLASS's grouped kernels. \ + This parameter is passed in at present to match the APIs of other kernels. The parameter \ + is unused within the kernel") +# precompute mode +parser.add_argument("-pm", "--precompute_mode", + default="Device", type=str, choices=["Host", "Device"], + help="Grouped GEMM scheduling on device only (Device) or using host precompute (Host)") +# arguments +parser.add_argument("-p", "--problem_size_dir", type=str, default="grouped_gemm_problem_size.csv", + help="path to the csv file that contains the problem sizes") +parser.add_argument("-alpha", "--alpha", default=1.0, type=float, help="alpha") +parser.add_argument("-beta", "--beta", default=0.0, type=float, help="beta") +parser.add_argument('-bias', '--bias', action='store_true', help="C is bias vector") + +# Activation function +parser.add_argument("-activ", "--activation_function", default="identity", + choices=["identity", "relu", "leaky_relu", "tanh", "sigmoid", "silu", "hardswish", "gelu"], help="activation function") +parser.add_argument("-activ_arg", "--activation_args", default=[], nargs="+", type=float, + help="additional arguments for activation") +parser.add_argument('--print_cuda', action="store_true", + help="print the underlying CUDA kernel") + +try: + args = parser.parse_args() +except: + sys.exit(0) + +cc = device_cc() +if args.compute_capability != cc: + raise Exception(("Parameter --compute-capability of {} " + "does not match that of the device of {}.").format(args.compute_capability, cc)) + +pycutlass.get_memory_pool(init_pool_size=2**30, max_pool_size=2**32) + +np.random.seed(0) + +element_a = getattr(cutlass_bindings, args.element_a) +element_b = getattr(cutlass_bindings, args.element_b) +element_c = getattr(cutlass_bindings, args.element_c) +element_acc = getattr(cutlass_bindings, args.element_acc) +math_operation = getattr(MathOperation, args.math) +opclass = getattr(cutlass_bindings.OpClass, args.opcode) + +math_inst = MathInstruction( + args.instruction_shape, element_a, element_b, + element_acc, opclass, math_operation +) + +tile_description = TileDescription( + args.threadblock_shape, args.stages, args.warp_count, + math_inst +) + +layout_a = getattr(cutlass_bindings, args.layout_a) +layout_b = getattr(cutlass_bindings, args.layout_b) +layout_c = getattr(cutlass_bindings, args.layout_c) + +A = TensorDescription( + element_a, layout_a, args.alignment_a +) + +B = TensorDescription( + element_b, layout_b, args.alignment_b +) + +C = TensorDescription( + element_c, layout_c, args.alignment_c +) + +element_epilogue = getattr(cutlass_bindings, args.element_epilogue) +if args.activation_function == "identity": + epilogue_functor = getattr(pycutlass, args.epilogue_functor)( + C.element, C.alignment, math_inst.element_accumulator, element_epilogue) +else: + epilogue_functor = getattr(pycutlass, "LinearCombinationGeneric")( + getattr(pycutlass, args.activation_function)(element_epilogue), + C.element, 
C.alignment, math_inst.element_accumulator, element_epilogue) +swizzling_functor = getattr(cutlass_bindings, args.swizzling_functor) +precompute_mode = getattr(SchedulerMode, args.precompute_mode) + +operation = GemmOperationGrouped( + arch=args.compute_capability, tile_description=tile_description, + A=A, B=B, C=C, + epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor, + precompute_mode=precompute_mode +) + +if args.print_cuda: + print(operation.rt_module.emit()) + +pycutlass.compiler.add_module([operation, ]) + +reference_module = ReferenceModule(A, B, C) + +# get problems +problem_sizes = [] +with open(args.problem_size_dir) as csv_file: + reader = csv.reader(csv_file) + for row in reader: + problem_sizes.append( + cutlass_bindings.gemm.GemmCoord(int(row[0]), int(row[1]), int(row[2])) + ) + +problem_count = len(problem_sizes) + +tensor_As = [] +tensor_Bs = [] +tensor_Cs = [] +tensor_Ds = [] +problem_sizes_coord = [] +tensor_D_refs = [] + +for problem_size in problem_sizes: + if args.element_a != "int8": + if args.element_a == "bfloat16": + tensor_A = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(problem_size.m() + * problem_size.k(),))).astype(bfloat16) + else: + tensor_A = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(problem_size.m() + * problem_size.k(),))).astype(getattr(np, args.element_a)) + else: + tensor_A = np.random.uniform(low=-2, high=2, size=(problem_size.m() + * problem_size.k(),)).astype(getattr(np, args.element_a)) + + if args.element_b != "int8": + if args.element_b == "bfloat16": + tensor_B = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(problem_size.k() + * problem_size.n(),))).astype(bfloat16) + else: + tensor_B = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(problem_size.k() + * problem_size.n(),))).astype(getattr(np, args.element_b)) + else: + tensor_B = np.random.uniform(low=-2, high=2, size=(problem_size.k() + * problem_size.n(),)).astype(getattr(np, args.element_b)) + + if args.element_c != "int8": + if args.bias: + if args.layout_c == "RowMajor": + c_size = problem_size.n() + elif args.layout_c == "ColumnMajor": + c_size = problem_size.m() + else: + raise ValueError(args.layout_c) + else: + c_size = problem_size.m() * problem_size.n() + if args.element_c == "bfloat16": + tensor_C = np.ceil( + np.random.uniform(low=-8.5, high=7.5, size=(c_size,)) + ).astype(bfloat16) + else: + tensor_C = np.ceil( + np.random.uniform(low=-8.5, high=7.5, size=(c_size,)) + ).astype(getattr(np, args.element_c)) + else: + tensor_C = np.random.uniform( + low=-2, high=2, size=(problem_size.m() * problem_size.n(),) + ).astype(getattr(np, args.element_c)) + tensor_D = np.zeros( + shape=(problem_size.m() * problem_size.n(),) + ).astype(getattr(np, args.element_c)) + + tensor_As.append(tensor_A) + tensor_Bs.append(tensor_B) + tensor_Cs.append(tensor_C) + tensor_Ds.append(tensor_D) + tensor_D_ref = reference_module.run( + tensor_A, tensor_B, tensor_C, problem_size, + args.alpha, args.beta, args.bias) + tensor_D_ref = getattr(pycutlass, args.activation_function).numpy(*([tensor_D_ref,] + args.activation_args)) + tensor_D_refs.append(tensor_D_ref) + problem_sizes_coord.append(problem_size) + +arguments = GemmGroupedArguments( + operation, problem_sizes_coord, tensor_As, tensor_Bs, tensor_Cs, tensor_Ds, + output_op=operation.epilogue_type(*([args.alpha, args.beta] + args.activation_args)) +) + +operation.run(arguments) + +arguments.sync() + +for tensor_d, tensor_d_ref in zip(tensor_Ds, tensor_D_refs): + try: + assert np.array_equal(tensor_d, 
tensor_d_ref) + except: + assert np.allclose(tensor_d, tensor_d_ref, rtol=1e-5) + +print("Passed.") diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/40_cutlass_py/gemm.py b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/40_cutlass_py/gemm.py new file mode 100644 index 0000000000000000000000000000000000000000..dfd4113a6e9a8e524927502f5bbb9b2d1c30820e --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/40_cutlass_py/gemm.py @@ -0,0 +1,153 @@ +################################################################################ +# +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################ +""" +Basic example of using the CUTLASS Python interface to run a GEMM +""" + +import sys +print("This example is deprecated. Please see examples/python for examples of using " + "the CUTLASS Python interface.") +sys.exit(0) + +import argparse +import numpy as np + +import cutlass_bindings +import cutlass.backend as pycutlass +from cutlass.backend import * +from cutlass.backend.utils.device import device_cc + + +parser = argparse.ArgumentParser(description="Launch a GEMM kernel from Python: 'D = alpha * A * B + beta * C'") +parser.add_argument("--m", default=128, type=int, help="M dimension of the GEMM") +parser.add_argument("--n", default=128, type=int, help="N dimension of the GEMM") +parser.add_argument("--k", default=128, type=int, help="K dimension of the GEMM") +parser.add_argument('--print_cuda', action="store_true", help="Print the underlying CUDA kernel") + +try: + args = parser.parse_args() +except: + sys.exit(0) + +# Check that the device is of a sufficient compute capability +cc = device_cc() +assert cc >= 70, "The CUTLASS Python GEMM example requires compute capability greater than or equal to 70." 
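# --- Editor's note: illustrative sketch, not part of the upstream example. ---
# The alignment of 8 chosen just below corresponds to 128-bit vectorized memory accesses
# for the float16 operands: 128 bits / 16 bits per element = 8 elements. The example then
# requires M, N and K to be multiples of that alignment so every row/column of the flat
# tensors starts on such a boundary. Quick arithmetic check (the 128-bit figure is an
# assumption about the widest access the kernel uses):
VECTOR_ACCESS_BITS = 128
FP16_BITS = 16
assert VECTOR_ACCESS_BITS // FP16_BITS == 8
# --- End editor's note. ---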
+ +alignment = 8 +assert args.m % alignment == 0, "M dimension of size {} is not divisible by alignment of {}".format(args.m, alignment) +assert args.n % alignment == 0, "N dimension of size {} is not divisible by alignment of {}".format(args.n, alignment) +assert args.k % alignment == 0, "K dimension of size {} is not divisible by alignment of {}".format(args.k, alignment) + +np.random.seed(0) + +# Allocate a pool of device memory to be used by the kernel +pycutlass.get_memory_pool(init_pool_size=2**30, max_pool_size=2**32) + +# Set the compiler to use to NVCC +pycutlass.compiler.nvcc() + +# Set up A, B, C and accumulator +A = TensorDescription(cutlass_bindings.float16, cutlass_bindings.ColumnMajor, alignment) +B = TensorDescription(cutlass_bindings.float16, cutlass_bindings.RowMajor, alignment) +C = TensorDescription(cutlass_bindings.float32, cutlass_bindings.ColumnMajor, alignment) +element_acc = cutlass_bindings.float32 +element_epilogue = cutlass_bindings.float32 + +# Select instruction shape based on the Tensor Core instructions supported +# by the device on which we are running +if cc == 70: + instruction_shape = [8, 8, 4] +elif cc == 75: + instruction_shape = [16, 8, 8] +else: + # Use CUTLASS kernels for CC 80 by default (e.g., for cases in which SM86 is used) + cc = 80 + instruction_shape = [16, 8, 16] + +math_inst = MathInstruction( + instruction_shape, + A.element, B.element, element_acc, + cutlass_bindings.OpClass.TensorOp, + MathOperation.multiply_add +) + +tile_description = TileDescription( + [128, 128, 32], # Threadblock shape + 2, # Number of stages + [2, 2, 1], # Number of warps within each dimension of the threadblock shape + math_inst +) + +epilogue_functor = pycutlass.LinearCombination(C.element, C.alignment, element_acc, element_epilogue) + +operation = GemmOperationUniversal( + arch=cc, tile_description=tile_description, + A=A, B=B, C=C, + epilogue_functor=epilogue_functor) + +if args.print_cuda: + print(operation.rt_module.emit()) + +operations = [operation, ] + +# Compile the operation +pycutlass.compiler.add_module(operations) + +# Randomly initialize tensors +tensor_A = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(args.m * args.k,))).astype(np.float16) +tensor_B = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(args.k * args.n,))).astype(np.float16) +tensor_C = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(args.m * args.n,))).astype(np.float32) +tensor_D = np.zeros(shape=(args.m * args.n,)).astype(np.float32) + +problem_size = cutlass_bindings.gemm.GemmCoord(args.m, args.n, args.k) +alpha = 1. +beta = 0. 
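# --- Editor's note: illustrative sketch, not part of the upstream example. ---
# The tensors above are flat 1-D arrays; the TensorDescriptions declare how they are
# interpreted (A column-major M x K, B row-major K x N, C/D column-major M x N). A rough
# NumPy equivalent of the host reference check used below, spelled out for these
# particular layouts (hypothetical helper; assumes numpy imported as np, as above):
def numpy_reference(a_flat, b_flat, c_flat, m, n, k, alpha, beta):
    a = a_flat.reshape((k, m)).T            # ColumnMajor M x K
    b = b_flat.reshape((k, n))              # RowMajor    K x N
    c = c_flat.reshape((n, m)).T            # ColumnMajor M x N
    d = alpha * (a.astype(np.float32) @ b.astype(np.float32)) + beta * c
    return d.T.reshape(-1)                  # flatten back to column-major order
# --- End editor's note. ---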
+ +arguments = GemmArguments( + operation=operation, problem_size=problem_size, + A=tensor_A, B=tensor_B, C=tensor_C, D=tensor_D, + output_op=operation.epilogue_type(alpha, beta)) + +# Run the operation +operation.run(arguments) +arguments.sync() + +# Run the host reference module and compare to the CUTLASS result +reference = ReferenceModule(A, B, C) +tensor_D_ref = reference.run(tensor_A, tensor_B, tensor_C, problem_size, alpha, beta) + +try: + assert np.array_equal(tensor_D, tensor_D_ref) +except: + assert np.allclose(tensor_D, tensor_D_ref, atol=1e-5) + +print("Passed.") diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/40_cutlass_py/gemm_grouped.py b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/40_cutlass_py/gemm_grouped.py new file mode 100644 index 0000000000000000000000000000000000000000..508b0894827315003e46650a55371b4c78d83812 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/40_cutlass_py/gemm_grouped.py @@ -0,0 +1,172 @@ +################################################################################ +# +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################ +""" +Basic example of using the CUTLASS Python interface to run a grouped GEMM +""" + +import sys +print("This example is deprecated. 
Please see examples/python for examples of using " + "the CUTLASS Python interface.") +sys.exit(0) + +import argparse +import numpy as np + +import cutlass_bindings +import cutlass.backend as pycutlass +from cutlass.backend import * +from cutlass.backend.utils.device import device_cc + + +parser = argparse.ArgumentParser(description="Launch a grouped GEMM kernel from Python") +parser.add_argument('--print_cuda', action="store_true", help="Print the underlying CUDA kernel") + +try: + args = parser.parse_args() +except: + sys.exit(0) + +# Check that the device is of a sufficient compute capability +cc = device_cc() +assert cc >= 70, "The CUTLASS Python grouped GEMM example requires compute capability greater than or equal to 70." + +np.random.seed(0) + +# Allocate a pool of device memory to be used by the kernel +pycutlass.get_memory_pool(init_pool_size=2**30, max_pool_size=2**32) + +# Set the compiler to use to NVCC +pycutlass.compiler.nvcc() + +# Set up A, B, C and accumulator +alignment = 1 +A = TensorDescription(cutlass_bindings.float16, cutlass_bindings.ColumnMajor, alignment) +B = TensorDescription(cutlass_bindings.float16, cutlass_bindings.RowMajor, alignment) +C = TensorDescription(cutlass_bindings.float32, cutlass_bindings.ColumnMajor, alignment) +element_acc = cutlass_bindings.float32 +element_epilogue = cutlass_bindings.float32 + +# Select instruction shape based on the Tensor Core instructions supported +# by the device on which we are running +if cc == 70: + instruction_shape = [8, 8, 4] +elif cc == 75: + instruction_shape = [16, 8, 8] +else: + # Use CUTLASS kernels for CC 80 by default (e.g., for cases in which SM86 is used) + cc = 80 + instruction_shape = [16, 8, 16] + +math_inst = MathInstruction( + instruction_shape, + A.element, B.element, element_acc, + cutlass_bindings.OpClass.TensorOp, + MathOperation.multiply_add +) + +tile_description = TileDescription( + [128, 128, 32], # Threadblock shape + 2, # Number of stages + [2, 2, 1], # Number of warps within each dimension of the threadblock shape + math_inst +) + +epilogue_functor = pycutlass.LinearCombination(C.element, C.alignment, element_acc, element_epilogue) + +operation = GemmOperationGrouped( + arch=cc, tile_description=tile_description, + A=A, B=B, C=C, + epilogue_functor=epilogue_functor, + precompute_mode=SchedulerMode.Device) + +if args.print_cuda: + print(operation.rt_module.emit()) + +operations = [operation, ] + +# Compile the operation +pycutlass.compiler.add_module(operations) + +# Initialize tensors for each problem in the group +problem_sizes = [ + cutlass_bindings.gemm.GemmCoord(128, 128, 64), + cutlass_bindings.gemm.GemmCoord(512, 256, 128) +] +problem_count = len(problem_sizes) + +alpha = 1. +beta = 0. 
+ +tensor_As = [] +tensor_Bs = [] +tensor_Cs = [] +tensor_Ds = [] +tensor_D_refs = [] + +reference = ReferenceModule(A, B, C) + +for problem_size in problem_sizes: + # Randomly initialize tensors + m = problem_size.m() + n = problem_size.n() + k = problem_size.k() + tensor_A = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(m * k,))).astype(np.float16) + tensor_B = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(k * n,))).astype(np.float16) + tensor_C = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(m * n,))).astype(np.float32) + tensor_D = np.zeros(shape=(m * n,)).astype(np.float32) + + tensor_As.append(tensor_A) + tensor_Bs.append(tensor_B) + tensor_Cs.append(tensor_C) + tensor_Ds.append(tensor_D) + + # Run the reference GEMM + tensor_D_ref = reference.run(tensor_A, tensor_B, tensor_C, problem_size, alpha, beta) + tensor_D_refs.append(tensor_D_ref) + +arguments = GemmGroupedArguments( + operation, problem_sizes, tensor_As, tensor_Bs, tensor_Cs, tensor_Ds, + output_op=operation.epilogue_type(alpha, beta) +) + +# Run the operation +operation.run(arguments) +arguments.sync() + +# Compare the CUTLASS result to the host reference result +for tensor_d, tensor_d_ref in zip(tensor_Ds, tensor_D_refs): + try: + assert np.array_equal(tensor_d, tensor_d_ref) + except: + assert np.allclose(tensor_d, tensor_d_ref, rtol=1e-5) + +print("Passed.") diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/debug_utils.h b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/debug_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..a22f12b711bbc9a6dc1a6cc6020dd6df6e3ffe0b --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/debug_utils.h @@ -0,0 +1,234 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////// +// Debugging functions +//////////////////////////////////////////////////////////////////////////////// +// Nans & inf detection +#define NANCHECK(frag) \ + { \ + for (size_t _i = 0; _i < frag.size(); ++_i) { \ + assert(std::isfinite(float(frag[_i]))); \ + assert(!std::isnan(float(frag[_i]))); \ + } \ + } + +// Print on the first thread of the first block +#if 1 +#define PRINT_WARP_ID 0 +#define PRINT_LANE_ID 0 +#define PRINT_B0_T0(msg, ...) \ + if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && \ + threadIdx.x == PRINT_LANE_ID && threadIdx.y == PRINT_WARP_ID && \ + threadIdx.z == 0) { \ + printf(msg "\n", ##__VA_ARGS__); \ + } +#define PRINT_T0(msg, ...) \ + if (threadIdx.x == PRINT_LANE_ID && threadIdx.y == PRINT_WARP_ID && \ + threadIdx.z == 0) { \ + printf(msg "\n", ##__VA_ARGS__); \ + } +#define PRINT_TX_LX(msg, ...) \ + for (int bx = 0; bx < gridDim.x; ++bx) { \ + for (int by = 0; by < gridDim.y; ++by) { \ + for (int bz = 0; bz < gridDim.z; ++bz) { \ + for (int tx = 0; tx < blockDim.x; ++tx) { \ + for (int ty = 0; ty < blockDim.y; ++ty) { \ + for (int tz = 0; tz < blockDim.z; ++tz) { \ + __syncthreads(); \ + if (blockIdx.x == bx && blockIdx.y == by && blockIdx.z == bz && \ + threadIdx.x == tx && threadIdx.y == ty && \ + threadIdx.z == tz) { \ + printf( \ + "[%d,%d,%d][%d,%d,%d]" msg "\n", \ + bx, \ + by, \ + bz, \ + tx, \ + ty, \ + tz, \ + ##__VA_ARGS__); \ + } \ + } \ + } \ + } \ + } \ + } \ + } +#else +#define PRINT_B0_T0 +#define PRINT_TX_LX +#endif + +struct __string_view { + char const* data; + std::size_t size; +}; +#if __cplusplus >= 201402L +template +constexpr __string_view __get_type_name() { + char const* p = __PRETTY_FUNCTION__; + while (*p++ != '=') + ; + for (; *p == ' '; ++p) + ; + char const* p2 = p; + int count = 1; + for (;; ++p2) { + switch (*p2) { + case '[': + ++count; + break; + case ']': + --count; + if (!count) + return {p, std::size_t(p2 - p)}; + } + } + return {}; +} +#else +template +constexpr __string_view __get_type_name() { + return {"unsupported", 11}; +} +#endif + +// Print a given array +#define PRINT_ACCUM8_T0_L0_START(name, accum, start) \ + PRINT_B0_T0( \ + "%s[%d:%d] - {%f, %f, %f, %f, %f, %f, %f, %f}", \ + name, \ + int(start), \ + int(start + 8), \ + float(accum[start + 0]), \ + float(accum[start + 1]), \ + float(accum[start + 2]), \ + float(accum[start + 3]), \ + float(accum[start + 4]), \ + float(accum[start + 5]), \ + float(accum[start + 6]), \ + float(accum[start + 7])); +#define PRINT_ACCUM8_T0_L0(name, accum) PRINT_ACCUM8_T0_L0_START(name, accum, 0) +#define PRINT_FRAG_T0_L0(name, frag) \ + { \ + auto typeStr = __get_type_name(); \ + PRINT_B0_T0("printing %s (%s)", name, typeStr.data); \ + for (size_t _start = 0; _start < frag.size(); _start += 8) { \ + 
PRINT_ACCUM8_T0_L0_START(" ", frag, _start); \ + } \ + /*__syncthreads(); \ + NANCHECK(frag); */ \ + } +#define PRINT_ARRAY_T0_L0_INCR(name, array, length, incr) \ + { \ + PRINT_B0_T0("printing %s (len=%d)", name, int(length)); \ + for (int _start = 0; _start < length; _start += incr) { \ + PRINT_ACCUM8_T0_L0_START(" ", array, _start); \ + } \ + } +#define PRINT_ARRAY_T0_L0(name, array, length) \ + PRINT_ARRAY_T0_L0_INCR(name, array, length, 8) + +// Print a 4x4 matrix +#define PRINT_TENSOR4x4_T0_L0_START(name, ref, start_x, start_y) \ + PRINT_B0_T0( \ + "%s[%d:%d, %d:%d]:\n %f, %f, %f, %f\n %f, %f, %f, %f\n %f, %f, %f, %f\n %f, %f, %f, %f", \ + name, \ + int(start_x), \ + int(start_x + 4), \ + int(start_y), \ + int(start_y + 4), \ + float(ref.at({start_x + 0, start_y + 0})), \ + float(ref.at({start_x + 0, start_y + 1})), \ + float(ref.at({start_x + 0, start_y + 2})), \ + float(ref.at({start_x + 0, start_y + 3})), \ + float(ref.at({start_x + 1, start_y + 0})), \ + float(ref.at({start_x + 1, start_y + 1})), \ + float(ref.at({start_x + 1, start_y + 2})), \ + float(ref.at({start_x + 1, start_y + 3})), \ + float(ref.at({start_x + 2, start_y + 0})), \ + float(ref.at({start_x + 2, start_y + 1})), \ + float(ref.at({start_x + 2, start_y + 2})), \ + float(ref.at({start_x + 2, start_y + 3})), \ + float(ref.at({start_x + 3, start_y + 0})), \ + float(ref.at({start_x + 3, start_y + 1})), \ + float(ref.at({start_x + 3, start_y + 2})), \ + float(ref.at({start_x + 3, start_y + 3}))); +#define PRINT_TENSOR4x4_T0_L0(name, ref) \ + PRINT_TENSOR4x4_T0_L0_START(name, ref, 0, 0) + +#define PRINT_PROBLEM_SIZE(name, ps) \ + PRINT_B0_T0( \ + "%s.problem_size: {.m=%d, .n=%d, .k=%d}", \ + name, \ + int(ps.m()), \ + int(ps.n()), \ + int(ps.k())) + +template +CUTLASS_DEVICE void print_warp_accum( + AccumT accum, + LaneOffsetT lane_offset, + int32_t num_rows, + int32_t num_cols) { + bool is_main = blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && + threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0; + for (int row = 0; row < num_rows; ++row) { + for (int col = 0; col < num_cols; ++col) { + if (col % 32 == 0) { + if (is_main) { + printf("\nmat[%3d, %3d:%3d]", row, col, col + 32); + } + __syncthreads(); + } + LambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) {}, + [&](int accum_m, int accum_n, int idx) { + if (row == accum_m && col == accum_n && + (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)) { + printf(" %6.1f", float(accum[idx])); + } + }, + [&](int accum_m) {}); + __syncthreads(); + } + if (is_main) { + printf("\n"); + } + } +} diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/default_fmha_grouped.h b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/default_fmha_grouped.h new file mode 100644 index 0000000000000000000000000000000000000000..14604f10c368ba1132100f20dd073e45d7afcf2d --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/default_fmha_grouped.h @@ -0,0 +1,299 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief + Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with + the appropriate threadblock-scoped epilogue. + + Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are + accommodated by exchanging A and B operands and assuming transposed layouts. Partial + specializations here choose 'device::GemmTransposed' to implement this functionality. 
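  (Editor's note, illustrative: the operand exchange described above relies on the identity
      D = alpha * A * B + beta * C   <=>   D^T = alpha * B^T * A^T + beta * C^T,
  so a column-major D can be produced by a row-major-output kernel that swaps the A and B
  operands and reinterprets each operand's layout as its transpose.)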
+ +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/complex.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/numeric_types.h" + +#include "fmha_grouped.h" +#include "gemm_kernel_utils.h" +#include "gemm/custom_mma.h" +#include "gemm/find_default_mma.h" +#include "gemm/mma_from_smem.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The datatype of Q/K/V + typename scalar_t_, + // Architecture we are targeting (eg `cutlass::arch::Sm80`) + typename ArchTag_, + // If Q/K/V are correctly aligned in memory and we can run a fast kernel + bool isAligned_, + int kQueriesPerBlock, + int kKeysPerBlock, + int kMaxK = (int)cutlass::platform::numeric_limits::max(), + GroupScheduleMode GroupScheduleMode_ = GroupScheduleMode::kDeviceOnly + > +struct DefaultFMHAGrouped { + using scalar_t = scalar_t_; + using accum_t = float; + using output_t = scalar_t; + + // Accumulator between 2 iterations + // Using `accum_t` improves perf on f16 at the cost of + // numerical errors + using output_accum_t = accum_t; + + using ArchTag = ArchTag_; + static bool const kIsAligned = isAligned_; + static bool const kSingleValueIteration = kMaxK <= kKeysPerBlock; + static constexpr bool kIsHalf = cutlass::sizeof_bits::value == 16; + static int const kWarpSize = 32; + static int const kNumWarpsPerBlock = kQueriesPerBlock * kKeysPerBlock / (kWarpSize * kWarpSize); + + struct MM0 { + /* + In this first matmul, we compute a block of `Q @ K.T`. + While the calculation result is still hot in registers, we update + `mi`, `m_prime`, `s_prime` in shared-memory, and then store this value + into a shared-memory ("AccumulatorSharedStorage") that is used later as + operand A for the second matmul (see MM1) + */ + + using GemmType = gemm_kernel_utils::DefaultGemmType; + using OpClass = typename GemmType::OpClass; + + using ElementA = scalar_t; + using ElementB = scalar_t; + using ElementC = scalar_t; + using ElementAccumulator = accum_t; + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + + using DefaultConfig = + typename cutlass::gemm::device::DefaultGemmConfiguration< + OpClass, + ArchTag, + ElementA, + ElementB, + ElementC, + ElementAccumulator + >; + + static int const kAlignmentA = + kIsAligned ? DefaultConfig::kAlignmentA : GemmType::kMinimumAlignment; + static int const kAlignmentB = + kIsAligned ? DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment; + + using ThreadblockShape = cutlass::gemm::GemmShape; + using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>; + using InstructionShape = typename GemmType::InstructionShape; + + static int const kStages = DefaultConfig::kStages; + using Operator = typename GemmType::Operator; + + using DefaultMma = typename cutlass::gemm::threadblock::FindDefaultMma< + ElementA, + LayoutA, + kAlignmentA, + ElementB, + LayoutB, + kAlignmentB, + ElementAccumulator, + LayoutC, + OpClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + ArchTag::kMinComputeCapability >= 80 && kIsHalf + ? 
4 + : DefaultConfig::kStages, + Operator + >::DefaultMma; + + using MmaCore = typename DefaultMma::MmaCore; + using IteratorA = typename DefaultMma::IteratorA; + using IteratorB = typename DefaultMma::IteratorB; + using DefaultThreadblockMma = typename DefaultMma::ThreadblockMma; + using Mma = typename cutlass::platform::conditional< + kSingleValueIteration, + typename MakeCustomMma::Mma, + DefaultThreadblockMma>::type; + using AccumLambdaIterator = typename DefaultMmaAccumLambdaIterator< + typename Mma::Operator::IteratorC, + ElementAccumulator, + kWarpSize>::Iterator; + + static_assert(MmaCore::WarpCount::kCount == kNumWarpsPerBlock, ""); + + // Epilogue to store to shared-memory in a format that we can use later for + // the second matmul + using B2bGemm = typename cutlass::gemm::threadblock::B2bGemm< + typename Mma::Operator::IteratorC, + typename Mma::Operator, + scalar_t, + WarpShape, + ThreadblockShape>; + using AccumulatorSharedStorage = typename B2bGemm::AccumulatorSharedStorage; + }; + + struct MM1 { + /* + Second matmul: perform `attn @ V` where `attn` is the attention (not + normalized) and stored in shared memory + */ + + using GemmType = typename MM0::GemmType; + using OpClass = typename GemmType::OpClass; + + using ElementA = scalar_t; + using ElementB = scalar_t; + using ElementC = output_accum_t; + using ElementAccumulator = accum_t; + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::RowMajor; + + using DefaultConfig = + typename cutlass::gemm::device::DefaultGemmConfiguration< + OpClass, + ArchTag, + ElementA, + ElementB, + ElementC, + ElementAccumulator + >; + + static int const kAlignmentA = DefaultConfig::kAlignmentA; + static int const kAlignmentB = + kIsAligned ? DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment; + + using ThreadblockShape = typename MM0::ThreadblockShape; + using WarpShape = typename MM0::WarpShape; + using InstructionShape = typename MM0::InstructionShape; + + using EpilogueOutputOp = typename DefaultConfig::EpilogueOutputOp; + + static int const kStages = DefaultConfig::kStages; + using Operator = typename GemmType::Operator; + + using ThreadblockSwizzle = void; // Swizzling is unused + static bool const kSplitKSerial = false; + + using DefaultGemm = cutlass::gemm::kernel::DefaultGemm< + ElementA, + LayoutA, + kAlignmentA, + ElementB, + LayoutB, + kAlignmentB, + ElementC, + LayoutC, + ElementAccumulator, + OpClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + ArchTag::kMinComputeCapability >= 80 && kIsHalf + ? 
4 + : DefaultConfig::kStages, + kSplitKSerial, + Operator>; + + using WarpIteratorA = typename cutlass::gemm::threadblock:: + DefaultWarpIteratorAFromSharedMemory< + typename DefaultGemm::Mma::Policy::Operator::Shape, // WarpShape + typename DefaultGemm::Mma::Policy::Operator::InstructionShape, + typename DefaultGemm::Mma::Policy::Operator::IteratorA, + typename DefaultGemm::Mma::Policy>::WarpIterator; + + using DefaultMmaFromSmem = + typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory< + typename DefaultGemm::Mma, + MM0::AccumulatorSharedStorage::Shape::kN, // kMaxK + WarpIteratorA, + false>; // kScaleOperandA + + using Mma = typename DefaultMmaFromSmem::Mma; + using IteratorB = typename Mma::IteratorB; + using WarpCount = typename Mma::WarpCount; + static_assert(WarpCount::kCount == kNumWarpsPerBlock, ""); + + using DefaultEpilogue = typename DefaultGemm::Epilogue; + using OutputTileIterator = + typename cutlass::epilogue::threadblock::PredicatedTileIterator< + typename DefaultEpilogue::OutputTileIterator::ThreadMap, + output_t>; + using OutputTileIteratorAccum = + typename cutlass::epilogue::threadblock::PredicatedTileIterator< + typename DefaultEpilogue::OutputTileIterator::ThreadMap, + output_accum_t>; + }; + +/// Define the kernel in terms of the default kernel + using FMHAKernel = kernel::FMHAGrouped< + MM0, + MM1, + scalar_t, + accum_t, + output_t, + output_accum_t, + kSingleValueIteration, + GroupScheduleMode_ + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/epilogue/epilogue_pipelined.h b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/epilogue/epilogue_pipelined.h new file mode 100644 index 0000000000000000000000000000000000000000..84f89a5ff4d7885d9d7bbec70cbf1913bc9890d1 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/epilogue/epilogue_pipelined.h @@ -0,0 +1,622 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. + + File copied from "cutlass/epilogue/threadblock/epilogue.h" + then modified to: + (1) load 2 source fragments at the same time (pipelining) + (2) support reading from a different dtype + (3) pass the row id to the OutputOp if it takes it + (see MemoryEfficientAttentionNormalize) + Note that in general the fragment passed to the OutputOp could + span multiple rows but it does not happen with the configurations we have +*/ +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include CUDA_STD_HEADER(cassert) +#include "cutlass/functional.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/layout/vector.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_coord.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/transform/pitch_linear_thread_map.h" +#include "cutlass/transform/threadblock/regular_tile_iterator.h" +#include "cutlass/epilogue/threadblock/epilogue_base.h" +#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" +#include "cutlass/numeric_types.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace threadblock { + +template +struct ApplyEpilogueOp { + static CUTLASS_DEVICE typename Op::FragmentOutput apply( + Op const& output_op, + int row_id, + typename Op::FragmentAccumulator const& accum, + typename Op::FragmentOutput const& source) { + return output_op(accum, source); + } + static CUTLASS_DEVICE typename Op::FragmentOutput apply( + Op const& output_op, + int row_id, + typename Op::FragmentAccumulator const& accum) { + return output_op(accum); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Epilogue operator +template < + typename Shape_, ///< Shape of threadblock tile (concept: GemmShape) + typename WarpMmaOperator_, ///< Warp-level MMA operator (concept: + ///< gemm::warp::MmaTensorOp) + int PartitionsK, ///< Number of partitions of the K dimension + typename OutputTileIterator_, ///< Tile iterator writing output tensors + typename AccumulatorFragmentIterator_, ///< Fragment iterator selecting + ///< accumulators + typename WarpTileIterator_, ///< Warp-scoped tile iterator writing + ///< accumulators to SMEM + typename SharedLoadIterator_, ///< Threadblock-scoped tile iterator loading + ///< from SMEM + typename OutputOp_, ///< Output operator + typename Padding_, ///< Padding added to SMEM allocation to avoid bank + ///< conflicts (concept: MatrixShape) + int FragmentsPerPartition = + 1, ///< Used to coarsten the epilogue granularity + int IterationsUnroll = ///< Used to reduce binary size when epilogue op is + ///< large + (!IsEpilogueFunctorHeavy::value), + typename OutputTileSourceIterator_ = + OutputTileIterator_ ///< Tile iterator 
reading tensors + > +class EpiloguePipelined : public EpilogueBase< + Shape_, + typename WarpMmaOperator_::Shape, + PartitionsK, + AccumulatorFragmentIterator_, + WarpTileIterator_, + Padding_, + FragmentsPerPartition> { + public: + using Base = EpilogueBase< + Shape_, + typename WarpMmaOperator_::Shape, + PartitionsK, + AccumulatorFragmentIterator_, + WarpTileIterator_, + Padding_, + FragmentsPerPartition>; + + using Shape = Shape_; + using WarpMmaOperator = WarpMmaOperator_; + static int const kPartitionsK = PartitionsK; + using OutputTileIterator = OutputTileIterator_; + using OutputTileSourceIterator = OutputTileSourceIterator_; + using AccumulatorFragmentIterator = AccumulatorFragmentIterator_; + using WarpTileIterator = WarpTileIterator_; + using SharedLoadIterator = SharedLoadIterator_; + using OutputOp = OutputOp_; + using Padding = Padding_; + + using Layout = layout::RowMajor; + using LongIndex = typename Layout::LongIndex; + + /// The complete warp-level accumulator tile + using AccumulatorTile = typename Base::AccumulatorTile; + + /// Accumulator element + using ElementAccumulator = typename WarpTileIterator::Element; + + /// Output element + using ElementOutput = typename OutputTileIterator::Element; + using ElementSource = typename OutputTileSourceIterator::Element; + + /// Output access size + static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess; + + /// Tensor reference to destination tensor + using TensorRef = typename OutputTileIterator::TensorRef; + + /// Tensor reference to sync tensor + using SyncTensorRef = + typename cutlass::TensorRef; + + /// Const tensor reference to source tensor + using ConstTensorRef = typename OutputTileIterator::ConstTensorRef; + + /// Array type used to output + using OutputAccessType = Array< + typename OutputTileIterator::Element, + OutputTileIterator::kElementsPerAccess>; + using SourceAccessType = Array< + typename OutputTileSourceIterator::Element, + OutputTileSourceIterator::kElementsPerAccess>; + + /// Array type used by output functor + using AccumulatorAccessType = Array< + typename WarpTileIterator::Element, + OutputTileIterator::kElementsPerAccess>; + + /// Number of warps + using WarpCount = typename Base::WarpCount; + + static int constexpr kSmemTiles = Base::kFragmentsPerIteration > 1 + ? 
Base::kFragmentsPerIteration + : kPartitionsK; + static int constexpr kSmemPointerOffset = + Base::SharedStorage::StorageShape::kCount / kSmemTiles; + + public: + static_assert( + OutputTileSourceIterator::Fragment::kElements == + OutputTileIterator::Fragment::kElements, + "Mismatch between input tile and output tile iterator (kElements)"); + static_assert( + OutputTileSourceIterator::kIterations == OutputTileIterator::kIterations, + "Mismatch between input tile and output tile iterator (kIterations)"); + static_assert( + SharedLoadIterator::Fragment::kElements == + OutputTileIterator::Fragment::kElements, + "Mismatch between shared load iterator and output tile iterator."); + + static_assert( + OutputTileIterator::kElementsPerAccess, + "OutputTileIterator::kElementsPerAccess must not be zero."); + + static_assert( + !(OutputTileIterator::Fragment::kElements % + OutputTileIterator::kElementsPerAccess), + "Divisibility"); + + private: + /// Loads fragment from shared memory aligned with output tensor + SharedLoadIterator shared_load_iterator_; + + public: + /// Constructor + CUTLASS_DEVICE + EpiloguePipelined( + typename Base::SharedStorage& shared_storage, ///< Shared storage object + int thread_idx, ///< ID of a thread within the threadblock + int warp_idx, ///< ID of warp within threadblock + int lane_idx ///< Id of thread within warp + ) + : Base(shared_storage, thread_idx, warp_idx, lane_idx), + shared_load_iterator_(shared_storage.reference(), thread_idx) {} + + /// Streams the result to global memory + CUTLASS_DEVICE + void operator()( + OutputOp const& output_op, ///< Output operator + OutputTileIterator + destination_iterator, ///< Tile iterator for destination + AccumulatorTile const& + accumulators, ///< Complete warp-level accumulator tile + OutputTileSourceIterator + source_iterator) { ///< Threadblock tile coordinate in GEMM (in units + ///< of threadblock tiles) + + if (!output_op.is_source_needed()) { + compute_source_not_needed_(output_op, destination_iterator, accumulators); + } else { + compute_source_needed_( + output_op, destination_iterator, accumulators, source_iterator); + } + } + CUTLASS_DEVICE + void operator()( + OutputOp const& output_op, ///< Output operator + OutputTileIterator + destination_iterator, ///< Tile iterator for destination + AccumulatorTile const& + accumulators) { ///< Complete warp-level accumulator tile + compute_source_not_needed_(output_op, destination_iterator, accumulators); + } + + private: + template + struct acc2smem_source_not_needed; + + template + struct acc2smem_source_not_needed> { + template + CUTLASS_DEVICE static void helper( + AccumulatorFragmentIterator accum_fragment_iterator, + WarpTileIterator& warp_tile_iterator) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Advance; i++) { + ++accum_fragment_iterator; + } + + CUTLASS_PRAGMA_UNROLL + for (int p = 0; p < Base::kFragmentsPerIteration; ++p) { + typename AccumulatorFragmentIterator::Fragment accum_fragment; + + accum_fragment_iterator.load(accum_fragment); + ++accum_fragment_iterator; + + warp_tile_iterator.store(accum_fragment); + if (p < Base::kFragmentsPerIteration - 1) { + warp_tile_iterator.add_pointer_offset(kSmemPointerOffset); + } + } + + if (Base::kFragmentsPerIteration > 1) { + warp_tile_iterator.add_pointer_offset( + kSmemPointerOffset * (1 - Base::kFragmentsPerIteration)); + } + } + + CUTLASS_DEVICE + static void push( + size_t pos, + AccumulatorFragmentIterator const& iterator_begin, + WarpTileIterator& warp_tile_iterator) { + int dummy[] = { + (pos == (Seq * 
Base::kFragmentsPerIteration)) && + (helper( + iterator_begin, warp_tile_iterator), + 0)...}; + + CUTLASS_UNUSED(dummy[0]); + } + }; + + static_assert( + kPartitionsK == 1 || Base::kFragmentsPerIteration == 1, + "One of these must be exactly 1."); + + /// Streams the result to global memory + CUTLASS_DEVICE + void compute_source_not_needed_( + OutputOp const& output_op, ///< Output operator + OutputTileIterator + destination_iterator, ///< Tile iterator for destination + AccumulatorTile const& + accumulators ///< Complete warp-level accumulator tile + ) { + // + // Iterator over warp-level accumulator fragment + // + + AccumulatorFragmentIterator accum_fragment_iterator(accumulators); + + // + // Iterate over accumulator tile + // + +#pragma unroll( \ + IterationsUnroll \ + ? OutputTileIterator::kIterations / Base::kFragmentsPerIteration \ + : 1) + for (int iter = 0; iter < OutputTileIterator::kIterations; + iter += Base::kFragmentsPerIteration) { + // + // Convert and store fragment + // + + __syncthreads(); + + acc2smem_source_not_needed>:: + push(iter, accum_fragment_iterator, this->warp_tile_iterator_); + + __syncthreads(); + + // + // Load fragments from shared memory + // + + CUTLASS_PRAGMA_UNROLL + for (int p = 0; p < Base::kFragmentsPerIteration; ++p) { + typename SharedLoadIterator::Fragment + aligned_accum_fragment[kPartitionsK]; + + shared_load_iterator_.load(aligned_accum_fragment[0]); + + if (p < Base::kFragmentsPerIteration - 1) { + shared_load_iterator_.add_pointer_offset(kSmemPointerOffset); + } else if (kPartitionsK > 1) { + plus add_fragments; + + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < kPartitionsK; ++i) { + shared_load_iterator_.add_pointer_offset(kSmemPointerOffset); + shared_load_iterator_.load(aligned_accum_fragment[i]); + aligned_accum_fragment[0] = add_fragments( + aligned_accum_fragment[0], aligned_accum_fragment[i]); + } + + shared_load_iterator_.add_pointer_offset( + (1 - kPartitionsK) * kSmemPointerOffset); + } + + // + // Compute the output result + // + + typename OutputTileIterator::Fragment output_fragment; + + apply_output_operator_source_not_needed_( + destination_iterator.thread_start_row(), + output_fragment, + output_op, + aligned_accum_fragment[0]); + + // + // Store the final result + // + + destination_iterator.store(output_fragment); + ++destination_iterator; + } + + if (Base::kFragmentsPerIteration > 1) { + shared_load_iterator_.add_pointer_offset( + kSmemPointerOffset * (1 - Base::kFragmentsPerIteration)); + } + } + } + + template + struct acc2smem_source_needed; + + template + struct acc2smem_source_needed> { + template + CUTLASS_DEVICE static void helper( + AccumulatorFragmentIterator accum_fragment_iterator, + WarpTileIterator& warp_tile_iterator) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Advance; i++) { + ++accum_fragment_iterator; + } + + typename AccumulatorFragmentIterator::Fragment accum_fragment; + accum_fragment_iterator.load(accum_fragment); + warp_tile_iterator.store(accum_fragment); + } + + CUTLASS_DEVICE + static void push( + size_t pos, + AccumulatorFragmentIterator const& iterator_begin, + WarpTileIterator& warp_tile_iterator) { + int dummy[] = { + (pos == Seq) && + (helper(iterator_begin, warp_tile_iterator), 0)...}; + } + }; + + /// Streams the result to global memory + CUTLASS_DEVICE + void compute_source_needed_( + OutputOp const& output_op, ///< Output operator + OutputTileIterator + destination_iterator, ///< Tile iterator for destination + AccumulatorTile const& + accumulators, ///< Complete warp-level accumulator 
tile + OutputTileSourceIterator + source_iterator ///< Threadblock tile coordinate in GEMM (in units of + ///< threadblock tiles) + ) { + typename OutputTileSourceIterator::Fragment source_fragment[2]; + + source_fragment[0].clear(); + source_iterator.load(source_fragment[0]); + ++source_iterator; + source_fragment[1].clear(); + + // + // Iterator over warp-level accumulator fragment + // + + AccumulatorFragmentIterator accum_fragment_iterator(accumulators); + + // + // Iterate over accumulator tile + // + +#pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations : 1) + for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) { + if (iter > 0) { + __syncthreads(); + } + // + // Load the source for next iteration (pipelining) + // + + if (iter + 1 < OutputTileIterator::kIterations) { + source_iterator.load(source_fragment[(iter + 1) % 2]); + } + ++source_iterator; + acc2smem_source_needed< + cutlass::make_index_sequence>:: + push(iter, accum_fragment_iterator, this->warp_tile_iterator_); + + __syncthreads(); + + // + // Load fragments from shared memory + // + + typename SharedLoadIterator::Fragment + aligned_accum_fragment[kPartitionsK]; + + shared_load_iterator_.load(aligned_accum_fragment[0]); + + // If the number of k-slices is > 1 - perform a reduction amongst the + // k-slices + if (kPartitionsK > 1) { + plus add_fragments; + + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < kPartitionsK; ++i) { + shared_load_iterator_.add_pointer_offset(kSmemPointerOffset); + shared_load_iterator_.load(aligned_accum_fragment[i]); + aligned_accum_fragment[0] = add_fragments( + aligned_accum_fragment[0], aligned_accum_fragment[i]); + } + + shared_load_iterator_.add_pointer_offset( + (1 - kPartitionsK) * kSmemPointerOffset); + } + + // + // Compute the output result + // + + typename OutputTileIterator::Fragment output_fragment; + + apply_output_operator_( + destination_iterator.thread_start_row(), + output_fragment, + output_op, + aligned_accum_fragment[0], + source_fragment[iter % 2]); + + // + // Store the final result + // + + destination_iterator.store(output_fragment); + ++destination_iterator; + } + } + + /// Helper to invoke the output functor over each vector of output + CUTLASS_DEVICE + void apply_output_operator_( + int begin_row, + typename OutputTileIterator::Fragment& output_fragment, + OutputOp const& output_op, ///< Output operator + typename SharedLoadIterator::Fragment const& aligned_accum_fragment, + typename OutputTileSourceIterator::Fragment const& source_fragment) { + OutputAccessType* output_frag_ptr = + reinterpret_cast(&output_fragment); + + AccumulatorAccessType const* compute_frag_ptr = + reinterpret_cast(&aligned_accum_fragment); + + SourceAccessType const* source_frag_ptr = + reinterpret_cast(&source_fragment); + + int const kOutputOpIterations = OutputTileIterator::Fragment::kElements / + OutputTileIterator::kElementsPerAccess; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kOutputOpIterations; ++i) { + // Call the output operator + output_frag_ptr[i] = ApplyEpilogueOp::apply( + output_op, + begin_row + getRowOffset(i * OutputTileIterator::kElementsPerAccess), + compute_frag_ptr[i], + source_frag_ptr[i]); + } + } + + /// Helper to invoke the output functor over each vector of output + CUTLASS_DEVICE + void apply_output_operator_source_not_needed_( + int begin_row, + typename OutputTileIterator::Fragment& output_fragment, + OutputOp const& output_op, ///< Output operator + typename SharedLoadIterator::Fragment const& aligned_accum_fragment) { + 
OutputAccessType* output_frag_ptr = + reinterpret_cast(&output_fragment); + + AccumulatorAccessType const* compute_frag_ptr = + reinterpret_cast(&aligned_accum_fragment); + + int const kOutputOpIterations = OutputTileIterator::Fragment::kElements / + OutputTileIterator::kElementsPerAccess; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kOutputOpIterations; ++i) { + // Call the output operator + output_frag_ptr[i] = ApplyEpilogueOp::apply( + output_op, + begin_row + getRowOffset(i * OutputTileIterator::kElementsPerAccess), + compute_frag_ptr[i]); + } + } + + // This should be constexpr, but it's only supported on c++14 + static int CUTLASS_HOST_DEVICE getRowOffset(int i) { + using ThreadMap = typename OutputTileIterator::ThreadMap; + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; + ++cluster) { + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + int row_offset = row * ThreadMap::Delta::kRow + + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + int frag_row_idx = + (row + + ThreadMap::Iterations::kRow * + (group + ThreadMap::Iterations::kGroup * cluster)); + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; + ++column) { + int frag_idx = ThreadMap::kElementsPerAccess * + (frag_row_idx * ThreadMap::Iterations::kColumn + column); + if (i < frag_idx + ThreadMap::kElementsPerAccess) { + return row_offset; + } + } + } + } + } + return -1; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/epilogue/epilogue_rescale_output.h b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/epilogue/epilogue_rescale_output.h new file mode 100644 index 0000000000000000000000000000000000000000..411b5574ec715b053a341e8d5f5dfadff3511555 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/epilogue/epilogue_rescale_output.h @@ -0,0 +1,252 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. + + The epilogue rearranges the result of a matrix product through shared memory + to match canonical tensor layouts in global memory. Epilogues support + conversion and reduction operations. + + This is a copy of cutlass/epilogue/threadblock/epilogue.h that can + handle "row_id" as a first argument, as uses it to get the corresponding + `m_prime` / `s_prime` to rescale the output. +*/ + +#pragma once +#include "cutlass/cutlass.h" +#include CUDA_STD_HEADER(cassert) +#include "cutlass/aligned_buffer.h" +#include "cutlass/array.h" +#include "cutlass/functional.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/layout/vector.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_coord.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/transform/pitch_linear_thread_map.h" +#include "cutlass/transform/threadblock/regular_tile_iterator.h" +#include "cutlass/epilogue/threadblock/epilogue_base.h" +#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" +#include "cutlass/numeric_types.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/thread/scale_type.h" +#include "cutlass/functional.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/numeric_types.h" +#include "epilogue_pipelined.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace thread { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Applies a linear combination operator to an array of elements. +// output <- alpha * accumulator + beta * source +// with: +// alpha = 1 / s_prime (to normalize when isLast=True, 1 otherwise) +// beta = alpha / m_prime (renormalize the output when the max changes) +// source is the current output +template < + typename ElementOutput_, ///< Data type used to store tensors + typename ElementSource_, //< Data type for source (usually matches + //`ElementOutput`) + int Count, ///< Number of elements computed per operation. 
+ ///< Usually it is 128/sizeof_bits, + ///< but we use 64 or 32 sometimes when there are not enough data + ///< to store + typename ElementAccumulator_, ///< Accumulator data type + typename ElementCompute_, ///< Data type used to compute linear combination + bool isFirst, + bool isLast, + typename FragmentAlphaBeta_, + FloatRoundStyle Round = FloatRoundStyle::round_to_nearest> +class MemoryEfficientAttentionNormalize { + public: + using ElementOutput = ElementOutput_; + using ElementSource = ElementSource_; + using ElementAccumulator = ElementAccumulator_; + using ElementCompute = ElementCompute_; + + static int const kCount = Count; + + using FragmentOutput = Array; + using FragmentSource = Array; + using FragmentAccumulator = Array; + using ComputeFragment = Array; + using FragmentAlphaBeta = FragmentAlphaBeta_; + + static FloatRoundStyle const kRound = Round; + + private: + // + // Data members + // + + FragmentAlphaBeta const& s_prime_; + FragmentAlphaBeta const& m_prime_; + + public: + /// Constructs the function object, possibly loading from pointers in host + /// memory + CUTLASS_HOST_DEVICE + MemoryEfficientAttentionNormalize( + FragmentAlphaBeta const& s_prime, + FragmentAlphaBeta const& m_prime) + : s_prime_(s_prime), m_prime_(m_prime) {} + + /// Returns true if source is needed + CUTLASS_HOST_DEVICE + bool is_source_needed() const { + return !isFirst; + } + + /// Functionally required for serial reduction in the epilogue + CUTLASS_HOST_DEVICE + void set_k_partition(int k_partition, int k_partition_count) {} + + /// Computes linear scaling: D = alpha * accumulator + beta * source + CUTLASS_HOST_DEVICE + FragmentOutput operator()( + int row, + FragmentAccumulator const& accumulator, + FragmentSource const& source) const { + assert(!isFirst); + + // Convert source to interal compute numeric type + NumericArrayConverter + source_converter; + NumericArrayConverter + accumulator_converter; + + // Convert to destination numeric type + NumericArrayConverter + destination_converter; + + ComputeFragment converted_source = source_converter(source); + ComputeFragment converted_accumulator = accumulator_converter(accumulator); + + // Perform binary operations + ComputeFragment intermediate; + + multiplies mul_add_source; + multiply_add mul_add_accumulator; + + ElementCompute alpha = isLast ? (1 / s_prime_[row]) : 1; + ElementCompute beta = alpha * m_prime_[row]; + + intermediate = mul_add_source(beta, converted_source); // X = beta * C + + intermediate = mul_add_accumulator( + alpha, converted_accumulator, intermediate); // D = alpha * Accum + X + + return destination_converter(intermediate); + } + + /// Computes linear scaling: D = alpha * accumulator + CUTLASS_HOST_DEVICE + FragmentOutput operator()(int row, FragmentAccumulator const& accumulator) + const { + assert(isFirst); + + // Convert source to interal compute numeric type + NumericArrayConverter + accumulator_converter; + + // Convert to destination numeric type + NumericArrayConverter + destination_converter; + + ComputeFragment converted_accumulator = accumulator_converter(accumulator); + + ComputeFragment intermediate; + multiplies mul_accumulator; + + ElementCompute alpha = isLast ? 
(1 / s_prime_[row]) : 1; + + intermediate = mul_accumulator( + alpha, converted_accumulator); // X = alpha * C + uniform + + return destination_converter(intermediate); + } +}; + +} // namespace thread + +namespace threadblock { +template < + typename EO, + typename ES, + int Count, + typename EA, + typename EC, + bool F, + bool L, + typename FAB, + FloatRoundStyle R> +struct ApplyEpilogueOp> { + using Op = thread:: + MemoryEfficientAttentionNormalize; + static CUTLASS_DEVICE typename Op::FragmentOutput apply( + Op const& output_op, + int row_id, + typename Op::FragmentAccumulator const& accum, + typename Op::FragmentSource const& source) { + return output_op(row_id, accum, source); + } + static CUTLASS_DEVICE typename Op::FragmentOutput apply( + Op const& output_op, + int row_id, + typename Op::FragmentAccumulator const& accum) { + return output_op(row_id, accum); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/epilogue/epilogue_thread_apply_logsumexp.h b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/epilogue/epilogue_thread_apply_logsumexp.h new file mode 100644 index 0000000000000000000000000000000000000000..b110abecedaa3324fe360bc919da2ed7a07baba3 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/epilogue/epilogue_thread_apply_logsumexp.h @@ -0,0 +1,174 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +/*! \file + \brief Functor performing linear combination operations used by epilogues. +*/ + +#pragma once + +#include + +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/thread/activation.h" +#include "cutlass/functional.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/numeric_types.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace thread { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template +struct ArrayExponential { + CUTLASS_HOST_DEVICE + Array operator()( + Array const& input) const { + Array result; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < ElementsPerAccess; ++i) { + result[i] = expf(input[i]); + } + + return result; + } +}; + +template +struct ArrayExponential { + CUTLASS_DEVICE + Array operator()( + Array const& input) const { + Array result; + + int const kVectorCount = ElementsPerAccess / 2; + + __half2 const* input_ptr = + reinterpret_cast<__half2 const*>(input.raw_data()); + __half2* res_ptr = reinterpret_cast<__half2*>(result.raw_data()); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kVectorCount; ++i) { + res_ptr[i] = h2exp(input_ptr[i]); + } + + return result; + } +}; +} // namespace detail + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Applies: +/// output <- (input - lse).exp() +template < + typename ElementOutput_, // output + typename ElementLSE_, // accumulator from LSE + typename ElementAccumulator_, // accumulator from matmul + typename ElementCompute_, // intermediate compute (and exp calculation) + int ElementsPerAccess> +class ApplyLogSumExp { + public: + using ElementOutput = ElementOutput_; + using ElementAccumulator = ElementAccumulator_; + using ElementCompute = ElementCompute_; + using ElementLSE = ElementLSE_; + + static int const kElementsPerAccess = ElementsPerAccess; + static int const kCount = kElementsPerAccess; + static const ScaleType::Kind kScale = + cutlass::epilogue::thread::ScaleType::NoBetaScaling; + + using FragmentOutput = Array; + using FragmentAccumulator = Array; + using FragmentCompute = Array; + using FragmentLSE = Array; + using FragmentScaleBias = FragmentLSE; // Used by epilogue_smem_accumulator.h + + public: + // + // Methods + // + + CUTLASS_HOST_DEVICE + ApplyLogSumExp() {} + + /// Returns true if source is needed + CUTLASS_HOST_DEVICE + bool is_source_needed() const { + return true; + } + + /// Functionally required for serial reduction in the epilogue + CUTLASS_HOST_DEVICE + void set_k_partition(int k_partition, int k_partition_count) {} + + CUTLASS_HOST_DEVICE + FragmentOutput operator()( + FragmentAccumulator const& AB, + FragmentLSE const& scale_unused, + // bias used as LSE + FragmentLSE const& bias) const { + FragmentCompute frag_AB = NumericArrayConverter< + ElementCompute, + ElementAccumulator, + kElementsPerAccess>()(AB); + FragmentCompute frag_lse_compute = + NumericArrayConverter()( + bias); + FragmentCompute frag_compute; + + minus minus_lse; + detail::ArrayExponential apply_exp; + frag_compute = minus_lse(frag_AB, frag_lse_compute); + frag_compute = apply_exp(frag_compute); + + return NumericArrayConverter< + ElementOutput, + ElementCompute, + kElementsPerAccess>()(frag_compute); + } +}; + 
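The ApplyLogSumExp functor defined above applies `output <- exp(input - lse)` element-wise, turning raw attention scores back into normalized probabilities given a precomputed log-sum-exp. As a reference point only, a minimal Python sketch of the same per-element math follows; the names `accum` and `lse` are illustrative and are not part of the CUTLASS API.

import math

def apply_logsumexp(accum, lse):
    # Subtract the row's log-sum-exp from each accumulator value and
    # exponentiate, yielding values that sum to 1 across the row.
    return [math.exp(a - lse) for a in accum]

# Example: one row of raw scores and its log-sum-exp
row = [1.0, 2.0, 3.0]
row_lse = math.log(sum(math.exp(x) for x in row))
probs = apply_logsumexp(row, row_lse)  # sums to ~1.0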
+///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace thread +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/fmha_backward_test.py b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/fmha_backward_test.py new file mode 100644 index 0000000000000000000000000000000000000000..8bc25462ac4251158d2262fb3f6b61ea0228d09f --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/fmha_backward_test.py @@ -0,0 +1,232 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+# +################################################################################################# + +import argparse +import torch +import sys +import os +from piped_subprocess import PipedSubprocess, TORCH_DTYPE_NAME +import math + + +parser = argparse.ArgumentParser() +parser.add_argument("example_exe", type=str, help="Path to the 41_fused_multi_head_attention_backward executable") +args = parser.parse_args() + +torch.manual_seed(0) +dtype = torch.float16 +B, Mq, Mkv, H, K, Kv = 2, 1024, 1024, 5, 128, 128 +causal = True +repeat_count = 100 + +ATOL = { + torch.float: 5e-4, + torch.half: 9.5e-2, + torch.bfloat16: 7e-1, +}[dtype] + +RTOL = { + torch.float: 1e-4, + torch.half: 2e-2, + torch.bfloat16: 1e-1, +}[dtype] + + +assert not (causal and Mq < Mkv), "causal only supports seqlenK <= seqlenQ" + +fmha_bw_binary = args.example_exe +if not os.path.isfile(fmha_bw_binary): + print(f"""No such file: `{fmha_bw_binary}`\nDid you forget to run "make 41_fused_multi_head_attention"?""") + sys.exit(1) + +def create_lower_triangular_mask(): + return torch.triu(torch.full( # type: ignore + [1, Mq, Mkv], + dtype=dtype, + fill_value=float("-inf"), + ), diagonal=1) + +def ref_mha_bmk(q, k, v, mask): + # Multi-head attention with inputs/outputs in BMK format + q = q.float() + k = k.float() + v = v.float() + + q = q * (1 / q.shape[-1] ** 0.5) + attn = q @ k.transpose(-2, -1) + if mask is not None: + attn += mask + attn_max = attn.max(-1, True).values + attn_norm = (attn - attn_max).exp().sum(-1, True) + attn = attn.softmax(-1) + lse = attn_max + attn_norm.log() + lse = lse.squeeze(2) + return attn @ v, lse + + +def bmhk2bmk(t): + return t.permute((0, 2, 1, 3)).reshape( + [t.shape[0] * t.shape[2], t.shape[1], t.shape[3]] + ) + +def ref_mha_bmhk(q, k, v, mask): + # Multi-head attention with inputs/outputs in BMHK format + assert q.ndim == 4 + + out, lse = ref_mha_bmk(bmhk2bmk(q), bmhk2bmk(k), bmhk2bmk(v), mask=mask) + out = out.reshape([q.shape[0], q.shape[2], q.shape[1], v.shape[3]]) + return out.permute((0, 2, 1, 3)), lse.reshape([q.shape[0], q.shape[2], q.shape[1]]) + +def ref_mha_bw_bmhk(q, k, v, mask, lse, out, grad_out, delta): + lse = lse[:, :, :q.shape[1]] #BMH, unpad Q dimension + delta = delta.reshape([-1, delta.shape[-1], 1]) + + # bmhk -> bmk + q, k, v, out, grad_out = [bmhk2bmk(x).float() for x in (q, k, v, out, grad_out)] + + attn_T = k @ q.transpose(-2, -1) + if mask is not None: + attn_T += mask.transpose(-2, -1) + attn_T = attn_T * (1 / q.shape[-1] ** 0.5) + attn_T = attn_T - lse.reshape([-1, 1, lse.shape[-1]]) + attn_T = attn_T.exp() + + grad_v = attn_T @ grad_out + + dov = grad_out @ v.transpose(-2, -1) + tmp = (dov - delta) * attn_T.transpose(-2, -1) + tmp = tmp / (q.shape[-1] ** 0.5) + + grad_q = tmp @ k + grad_k = tmp.transpose(-2, -1) @ q + + return [x.reshape([B, H, x.shape[1], x.shape[-1]]).permute([0, 2, 1, 3]) for x in [grad_q, grad_k, grad_v]] + + +print("initializing tensors...") +query = torch.randn([B, Mq, H, K], dtype=dtype) +key = 3 * torch.randn([B, Mkv, H, K], dtype=dtype) +value = 3 * torch.randn([B, Mkv, H, Kv], dtype=dtype) +mask = create_lower_triangular_mask() if causal else None + +# let PyTorch compute gradients +query.requires_grad_(True) +key.requires_grad_(True) +value.requires_grad_(True) + +print("computing fw...") +out, lse = ref_mha_bmhk(query, key, value, mask=mask) +out = out.to(dtype).contiguous() +grad_out = 3 * torch.randn([B, Mq, H, Kv], dtype=dtype) + +print("computing bw with autograd...") +out.backward(grad_out) +scale = (1 / query.shape[-1] ** 
0.5) + + +# Additional data needed by the kernel +delta = (grad_out.float() * out.float()).sum(-1).transpose(-2, -1).contiguous() +pad_amount = (32 - (lse.shape[2] % 32)) % 32 +lse = torch.nn.functional.pad(lse, [0, pad_amount], value=math.inf) + +print("computing bw with reference implem...") +gQr, gKr, gVr = ref_mha_bw_bmhk(query, key, value, mask, lse, out, grad_out, delta) + +with PipedSubprocess(fmha_bw_binary) as bw_kernel: + # Send kernel arguments + bw_kernel.write( + TORCH_DTYPE_NAME[query.dtype], + "scale", scale, + "head_dim", K, + "head_dim_value", Kv, + "num_queries", Mq, + "num_keys", Mkv, + "num_heads", H, + "custom_mask_type", (1 if causal else 0), + "num_batches", B, + "repeat_count", repeat_count, + "num_splits_key", (Mkv // 128), + ) + bw_kernel.writeTensor(query, "query", ["q_strideB", "q_strideM", "q_strideH"]) + bw_kernel.writeTensor(key, "key", ["k_strideB", "k_strideM", "k_strideH"]) + bw_kernel.writeTensor(value, "value", ["v_strideB", "v_strideM", "v_strideH"]) + bw_kernel.writeTensor(lse, "logsumexp", ["lse_strideB", "lse_strideH"]) + bw_kernel.writeTensor(out, "output", ["o_strideB", "o_strideM", "o_strideH"]) + bw_kernel.writeTensor(grad_out, "grad_output", ["gO_strideB", "gO_strideM", "gO_strideH"]) + bw_kernel.writeTensor(delta, "delta", ["delta_strideB", "delta_strideH"]) + + if bw_kernel.read() != "OK": + print("Got unexpected output") + print(bw_kernel.subp.communicate()[0]) + sys.exit(0) + + # Read kernel output + gQ = bw_kernel.readTensor("grad_query", ["gQ_strideB", "gQ_strideM", "gQ_strideH"], query.shape).float() + gK = bw_kernel.readTensor("grad_key", ["gK_strideB", "gK_strideM", "gK_strideH"], key.shape).float() + gV = bw_kernel.readTensor("grad_value", ["gV_strideB", "gV_strideM", "gV_strideH"], value.shape).float() + runtime_ms = float(bw_kernel.readNamed("runtime_ms")) + +float_ops = B * H * sum([ + # att = Q @ K.transpose + Mq * Mkv * K * 2, + # att @ dO + Mkv * Mq * Kv * 2, + # dov = dO @ V + Mq * Kv * Mkv * 2, + # dov @ K + Mq * K * Mkv * 2, + # dov @ Q + Mq * K * Mkv * 2, +]) +if causal: + float_ops //= 2 + +print(f""" +Fused multi-head attention - backward + batch_size={B} + num_queries={Mq} + num_keys={Mkv} + num_heads={H} + head_dim={K} + head_dim_value={Kv} + + Correctness: + grad_query: {"PASS" if torch.allclose(gQ, gQr, rtol=RTOL, atol=ATOL) else "FAIL"} (delta: {(gQ - gQr).abs().max()}) + grad_key: {"PASS" if torch.allclose(gK, gKr, rtol=RTOL, atol=ATOL) else "FAIL"} (delta: {(gK - gKr).abs().max()}) + grad_value: {"PASS" if torch.allclose(gV, gVr, rtol=RTOL, atol=ATOL) else "FAIL"} (delta: {(gV - gVr).abs().max()}) + (atol={ATOL} / rtol={RTOL}) + Runtime: {runtime_ms}ms ({(float_ops / (1024 ** 4)) / (runtime_ms / 1000):.4f} TFlops) +""") + +assert torch.allclose(query.grad.float(), gQr, rtol=RTOL, atol=ATOL), "Reference implementation does not match PyTorch autograd!" +assert torch.allclose(key.grad.float(), gKr, rtol=RTOL, atol=ATOL), "Reference implementation does not match PyTorch autograd!" +assert torch.allclose(value.grad.float(), gVr, rtol=RTOL, atol=ATOL), "Reference implementation does not match PyTorch autograd!" 
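For reference, the TFlops figure printed by the test above follows directly from the `float_ops` count (2*M*N*K per matmul of the backward pass, halved under causal masking) and the measured runtime. A minimal Python sketch of that arithmetic, with an assumed runtime since no measurement is available here:

def tflops(float_ops: int, runtime_ms: float) -> float:
    # The script's convention: 1 TFlop = 1024**4 floating-point operations.
    return (float_ops / (1024 ** 4)) / (runtime_ms / 1000)

B, H, Mq, Mkv, K, Kv = 2, 5, 1024, 1024, 128, 128
ops = B * H * (
    Mq * Mkv * K * 2      # attn = Q @ K^T
    + Mkv * Mq * Kv * 2   # attn^T @ dO   (grad_value)
    + Mq * Kv * Mkv * 2   # dov = dO @ V^T
    + Mq * K * Mkv * 2    # tmp @ K       (grad_query)
    + Mq * K * Mkv * 2    # tmp^T @ Q     (grad_key)
)
ops //= 2  # causal masking skips roughly half of the query/key pairs
print(f"{tflops(ops, runtime_ms=10.0):.4f} TFlops at an assumed 10 ms runtime")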
diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/fmha_grouped.h b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/fmha_grouped.h new file mode 100644 index 0000000000000000000000000000000000000000..afc25e43401a7bd75bba62f1e06b33208e11a181 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/fmha_grouped.h @@ -0,0 +1,1023 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Grouped FMHA kernel +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/complex.h" +#include "cutlass/semaphore.h" + +#include "cutlass/layout/matrix.h" +#include "cutlass/trace.h" +#include "cutlass/gemm/kernel/gemm_transpose_operands.h" + +#include "fmha_grouped_problem_visitor.h" +#include "gemm_kernel_utils.h" +#include "gemm/mma_accum_lambda_iterator.h" +#include "epilogue/epilogue_rescale_output.h" + + +namespace { + static CUTLASS_DEVICE float atomicMaxFloat(float* addr, float value) { + // source: https://stackoverflow.com/a/51549250 + return (value >= 0) + ? __int_as_float(atomicMax((int*)addr, __float_as_int(value))) + : __uint_as_float(atomicMin((unsigned int*)addr, __float_as_uint(value))); +} +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename MM0_, ///! 
Structure for computing P = Q @ K + typename MM1_, ///! Structure for computing O = P @ V + typename scalar_t_, + typename accum_t_, + typename output_t_, + typename output_accum_t_, + bool kKeepOutputInRF, ///! Whether the intermediate output from MM0_ should be kept in the register file + GroupScheduleMode GroupScheduleMode_ ///! Type of scheduling to perform +> +struct FMHAGrouped { +public: + using MM0 = MM0_; + using MM1 = MM1_; + + using scalar_t = scalar_t_; + using accum_t = accum_t_; + using output_t = output_t_; + using output_accum_t = output_accum_t_; + + static GroupScheduleMode const kGroupScheduleMode = GroupScheduleMode_; + + static constexpr bool kNeedsOutputAccumulatorBuffer = !kKeepOutputInRF && + !cutlass::platform::is_same::value; + + // Parameters to satisfy BaseGrouped + using ElementA = scalar_t; + using ElementB = scalar_t; + using ElementC = accum_t; + using LayoutA = typename MM0::LayoutA; + using LayoutB = typename MM0::ElementB; + using LayoutC = typename MM1::ElementC; + static ComplexTransform const kTransformA = ComplexTransform::kNone; + static ComplexTransform const kTransformB = ComplexTransform::kNone; + static int const kAlignmentA = MM0::kAlignmentA; + static int const kAlignmentB = MM0::kAlignmentB; + static int const kAlignmentC = 1; + using Mma = typename MM1::Mma; + using EpilogueOutputOp = typename MM1::EpilogueOutputOp; + using ThreadblockSwizzle = void; + using Operator = typename MM1::Operator; + using WarpShape = typename MM1::WarpShape; + using InstructionShape = typename MM1::InstructionShape; + + using ElementQ = scalar_t; + using ElementK = scalar_t; + using ElementP = accum_t; + using ElementV = scalar_t; + using ElementO = output_t; + using ElementOAccum = output_accum_t; + using ElementAccumulator = accum_t; + + using LayoutQ = typename MM0::LayoutA; + using LayoutK = typename MM0::LayoutB; + using LayoutP = typename MM0::LayoutC; + using LayoutV = typename MM1::LayoutB; + using LayoutO = typename MM1::LayoutC; + + static bool const kPreloadV = (MM1::Mma::ArchTag::kMinComputeCapability >= 80 && + cutlass::sizeof_bits::value == 16); + + static int const kAlignmentQ = MM0::kAlignmentA; + static int const kAlignmentK = MM0::kAlignmentB; + static int const kAlignmentV = 1; + + using ThreadblockShape = typename MM0::ThreadblockShape; + + static int const kQueriesPerBlock = ThreadblockShape::kM; + static int const kKeysPerBlock = ThreadblockShape::kN; + + static constexpr bool kSupportsDropout = false; + static constexpr bool kSupportsBias = false; + + /// Warp count (concept: GemmShape) + using WarpCount = typename MM1::WarpCount; + static int const kThreadsPerWarp = 32; + static int const kThreadCount = kThreadsPerWarp * WarpCount::kCount; + + static constexpr int kNumWarpsPerBlock = + kQueriesPerBlock * kKeysPerBlock / (kThreadsPerWarp * kThreadsPerWarp); + + using ProblemVisitor = FMHAGroupedProblemVisitor< + ThreadblockShape, + kGroupScheduleMode, + kThreadCount, + kThreadCount>; + + // + // Structures + // + + /// Argument structure + struct Arguments { + + // + // Data members + // + + GemmCoord *problem_sizes0{nullptr}; + GemmCoord *problem_sizes1{nullptr}; + + int problem_count{0}; + int threadblock_count{0}; + + ElementQ ** ptr_Q{nullptr}; + ElementK ** ptr_K{nullptr}; + ElementP ** ptr_P{nullptr}; + ElementV ** ptr_V{nullptr}; + ElementO ** ptr_O{nullptr}; + ElementOAccum ** ptr_O_accum{nullptr}; + + typename LayoutQ::Stride::LongIndex *ldq{nullptr}; + typename LayoutK::Stride::LongIndex *ldk{nullptr}; + typename 
LayoutP::Stride::LongIndex *ldv{nullptr}; + typename LayoutO::Stride::LongIndex *ldo{nullptr}; + + // Whether causal masking is to be performed + bool causal{false}; + + // Scale + ElementAccumulator scale{0}; + + // Only used by device-level operator + GemmCoord *host_problem_sizes{nullptr}; + + // + // Methods + // + + /// Default ctor + Arguments() = default; + + /// Ctor + CUTLASS_HOST_DEVICE + Arguments( + GemmCoord *problem_sizes0, + GemmCoord *problem_sizes1, + int problem_count, + int threadblock_count, + ElementQ ** ptr_Q, + ElementK ** ptr_K, + ElementP ** ptr_P, + ElementV ** ptr_V, + ElementO ** ptr_O, + ElementOAccum ** ptr_O_accum, + typename LayoutQ::Stride::LongIndex *ldq, + typename LayoutK::Stride::LongIndex *ldk, + typename LayoutP::Stride::LongIndex *ldp, + typename LayoutV::Stride::LongIndex *ldv, + typename LayoutO::Stride::LongIndex *ldo, + bool causal, + ElementAccumulator scale, + GemmCoord *host_problem_sizes=nullptr + ): + problem_sizes0(problem_sizes0), + problem_sizes1(problem_sizes1), + problem_count(problem_count), + threadblock_count(threadblock_count), + ptr_Q(ptr_Q), + ptr_K(ptr_K), + ptr_P(ptr_P), + ptr_V(ptr_V), + ptr_O(ptr_O), + ptr_O_accum(kNeedsOutputAccumulatorBuffer ? ptr_O_accum : (accum_t**)ptr_O), + ldq(ldq), + ldk(ldk), + ldv(ldv), + ldo(ldo), + causal(causal), + scale(scale), + host_problem_sizes(host_problem_sizes) + { + + } + + bool __host__ check_supported() { + CHECK_ALIGNED_PTR(ptr_Q, kAlignmentQ); + CHECK_ALIGNED_PTR(ptr_K, kAlignmentK); + CHECK_ALIGNED_PTR(ptr_V, kAlignmentV); + XFORMERS_CHECK(ldq % kAlignmentQ == 0, "query is not correctly aligned"); + XFORMERS_CHECK(ldk % kAlignmentK == 0, "key is not correctly aligned"); + XFORMERS_CHECK(ldv % kAlignmentV == 0, "value is not correctly aligned"); + return true; + } + }; + + // + // Structure for precomputing values in host memory and passing to kernels + // + + /// Parameters structure + struct Params { + + typename ProblemVisitor::Params problem_visitor; + int threadblock_count; + + ElementQ ** ptr_Q; + ElementK ** ptr_K; + ElementP ** ptr_P; + ElementV ** ptr_V; + ElementO ** ptr_O; + ElementOAccum ** ptr_O_accum; + + typename LayoutQ::Stride::LongIndex *ldq; + typename LayoutK::Stride::LongIndex *ldk; + typename LayoutP::Stride::LongIndex *ldv; + typename LayoutO::Stride::LongIndex *ldo; + + ElementAccumulator scale; + bool causal; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params(): + ptr_Q(nullptr), + ptr_K(nullptr), + ptr_P(nullptr), + ptr_V(nullptr), + ptr_O(nullptr), + ptr_O_accum(nullptr), + ldq(nullptr), + ldk(nullptr), + ldv(nullptr), + ldo(nullptr), + causal(false), + scale(0) + { } + + CUTLASS_HOST_DEVICE + Params(Arguments const &args, + void *workspace = nullptr, + int tile_count = 0): + problem_visitor(args.problem_sizes0, args.problem_sizes1, args.problem_count, workspace, tile_count), + threadblock_count(args.threadblock_count), + ptr_Q(args.ptr_Q), + ptr_K(args.ptr_K), + ptr_P(args.ptr_P), + ptr_V(args.ptr_V), + ptr_O(args.ptr_O), + ptr_O_accum(kNeedsOutputAccumulatorBuffer ? 
args.ptr_O_accum : (accum_t**)args.ptr_O), + ldq(args.ldq), + ldk(args.ldk), + ldv(args.ldv), + ldo(args.ldo), + causal(args.causal), + scale(args.scale) + { + + } + + CUTLASS_HOST_DEVICE + void update( + Arguments const &args, + void *workspace = nullptr, + int tile_count = 0) { + + problem_visitor = typename ProblemVisitor::Params(args.problem_sizes0, + args.problem_sizes1, + args.problem_count, + workspace, tile_count); + threadblock_count = args.threadblock_count; + ptr_Q = args.ptr_Q; + ptr_K = args.ptr_K; + ptr_P = args.ptr_P; + ptr_V = args.ptr_V; + ptr_O = args.ptr_O; + ptr_O_accum = kNeedsOutputAccumulatorBuffer ? args.ptr_O_accum : (accum_t**)args.ptr_O; + ldq = args.ldq; + ldk = args.ldk; + ldv = args.ldv; + ldo = args.ldo; + causal = args.causal; + scale = args.scale; + } + }; + + // Shared storage - depends on kernel params + struct ScalingCoefs { + cutlass::Array m_prime; + cutlass::Array s_prime; + cutlass::Array mi; + cutlass::Array out_rescale; + cutlass::Array + addition_storage; + }; + + struct SharedStorageEpilogueAtEnd : ScalingCoefs { + struct SharedStorageAfterMM0 { + // Everything here might be overwritten during MM0 + typename MM0::AccumulatorSharedStorage si; + typename MM1::Mma::SharedStorage mm1; + }; + + union { + typename MM0::Mma::SharedStorage mm0; + SharedStorageAfterMM0 after_mm0; + typename MM1::DefaultEpilogue::SharedStorage epilogue; + }; + + CUTLASS_DEVICE typename MM1::DefaultEpilogue::SharedStorage& + epilogue_shared_storage() { + return epilogue; + } + + // ProblemVisitor shared storage can't be overlapped with others + typename ProblemVisitor::SharedStorage problem_visitor; + }; + + struct SharedStorageEpilogueInLoop : ScalingCoefs { + struct SharedStorageAfterMM0 { + // Everything here might be overwritten during MM0 + typename MM0::AccumulatorSharedStorage si; + typename MM1::Mma::SharedStorage mm1; + typename MM1::DefaultEpilogue::SharedStorage epilogue; + }; + + union { + typename MM0::Mma::SharedStorage mm0; + SharedStorageAfterMM0 after_mm0; + }; + + CUTLASS_DEVICE typename MM1::DefaultEpilogue::SharedStorage& + epilogue_shared_storage() { + return after_mm0.epilogue; + } + + // ProblemVisitor shared storage can't be overlapped with others + typename ProblemVisitor::SharedStorage problem_visitor; + }; + + using SharedStorage = typename cutlass::platform::conditional< + kKeepOutputInRF, + SharedStorageEpilogueAtEnd, + SharedStorageEpilogueInLoop>::type; + +private: + + // Parameters to be used by an individual tile + struct TileParams { + + CUTLASS_HOST_DEVICE + static int query_start(int threadblock_idx) { + return threadblock_idx * kQueriesPerBlock; + } + + // Returns whether this threadblock computes within the number of queries, + // which is determined by the M dimension of problem 0 + CUTLASS_HOST_DEVICE + static bool can_compute(int threadblock_idx, const GemmCoord& problem_size0) { + return query_start(threadblock_idx) < problem_size0.m(); + } + + CUTLASS_HOST_DEVICE + static int num_queries(int threadblock_idx, const GemmCoord& problem_size0) { + return problem_size0.m() - query_start(threadblock_idx); + } + + CUTLASS_HOST_DEVICE + static int num_keys(int threadblock_idx, const GemmCoord& problem_size0, bool causal) { + int nk = problem_size0.n(); + if (causal) { + nk = cutlass::fast_min(int32_t(query_start(threadblock_idx) + kQueriesPerBlock), nk); + } + return nk; + } + + }; + +public: + + // + // Methods + // + + CUTLASS_DEVICE + FMHAGrouped() { } + + /// Determines whether kernel satisfies alignment + static Status 
can_implement(cutlass::gemm::GemmCoord const & problem_size) { + return Status::kSuccess; + } + + static Status can_implement(Arguments const &args) { + return Status::kSuccess; + } + + static CUTLASS_DEVICE int16_t thread_id() { + return threadIdx.x; + } + + static CUTLASS_DEVICE int8_t warp_id() { + return threadIdx.x / kThreadsPerWarp; + } + + static CUTLASS_DEVICE int8_t lane_id() { + return threadIdx.x % kThreadsPerWarp; + } + + /// Executes one GEMM + CUTLASS_DEVICE + void operator()(Params const ¶ms, SharedStorage &shared_storage) { + auto& m_prime = shared_storage.m_prime; + auto& s_prime = shared_storage.s_prime; + [[maybe_unused]] auto& si = shared_storage.after_mm0.si; + auto& mi = shared_storage.mi; + auto& out_rescale = shared_storage.out_rescale; + + ProblemVisitor problem_visitor( + params.problem_visitor, + shared_storage.problem_visitor, + blockIdx.x); + + // Outer 'persistent' loop to iterate over tiles + while (problem_visitor.next_tile()) { + + GemmCoord problem_size0 = problem_visitor.problem_size0(); + GemmCoord problem_size1 = problem_visitor.problem_size1(); + const int32_t threadblock_idx = int32_t(problem_visitor.threadblock_idx()); + + if (!TileParams::can_compute(threadblock_idx, problem_size0)) { + problem_visitor.advance(gridDim.x); + continue; + } + + const int32_t problem_idx = problem_visitor.problem_index(); + + if (thread_id() < kQueriesPerBlock) { + s_prime[thread_id()] = ElementAccumulator(0); + out_rescale[thread_id()] = accum_t(1.0); + m_prime[thread_id()] = + -cutlass::platform::numeric_limits::infinity(); + mi[thread_id()] = -cutlass::platform::numeric_limits::infinity(); + } + + ElementO *ptr_O = params.ptr_O[problem_idx] + TileParams::query_start(threadblock_idx) * params.ldo[problem_idx]; + ElementOAccum *ptr_O_accum = params.ptr_O_accum[problem_idx] + TileParams::query_start(threadblock_idx) * params.ldo[problem_idx]; + const int num_queries = TileParams::num_queries(threadblock_idx, problem_size0); + + auto createOutputIter = [&](int col) -> typename MM1::OutputTileIterator { + using OutputTileIterator = typename MM1::OutputTileIterator; + return OutputTileIterator( + typename OutputTileIterator::Params{(int32_t)params.ldo[problem_idx]}, + ptr_O, + typename OutputTileIterator::TensorCoord{ + num_queries, problem_size1.n()}, + thread_id(), + {0, col}); + }; + + auto createOutputAccumIter = [&](int col) -> + typename MM1::OutputTileIteratorAccum { + using OutputTileIteratorAccum = typename MM1::OutputTileIteratorAccum; + return OutputTileIteratorAccum( + typename OutputTileIteratorAccum::Params{(int32_t)params.ldo[problem_idx]}, + ptr_O_accum, + typename OutputTileIteratorAccum::TensorCoord{ + num_queries, problem_size1.n()}, + thread_id(), + {0, col}); + }; + + typename MM1::Mma::FragmentC accum_o; + accum_o.clear(); + + const int num_keys = TileParams::num_keys(threadblock_idx, problem_size0, params.causal); + + for (int32_t iter_key_start = 0; iter_key_start < num_keys; + iter_key_start += kKeysPerBlock) { + int32_t problem_size_0_m = + cutlass::fast_min((int32_t)kQueriesPerBlock, num_queries); + int32_t problem_size_0_n = cutlass::fast_min( + (int32_t)kKeysPerBlock, num_keys - iter_key_start); + int32_t const& problem_size_0_k = problem_size0.k(); + int32_t const& problem_size_1_n = problem_size1.n(); + int32_t const& problem_size_1_k = problem_size_0_n; + + auto prologueV = [&](int blockN) { + typename MM1::Mma::IteratorB iterator_V( + typename MM1::IteratorB::Params{typename MM1::LayoutB(params.ldv[problem_idx])}, + 
params.ptr_V[problem_idx] + iter_key_start * params.ldv[problem_idx], + {problem_size_1_k, problem_size_1_n}, + thread_id(), + cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN}); + + MM1::Mma::prologue( + shared_storage.after_mm0.mm1, + iterator_V, + thread_id(), + problem_size_1_k); + }; + + __syncthreads(); // Need to have shared memory initialized, and `m_prime` + // updated from end of prev iter + + // + // MATMUL: Q.K_t + // + // Computes the block-matrix product of: + // (a) query[query_start:query_end, :] + // with + // (b) key[iter_key_start:iter_key_start + kKeysPerBlock] + // and stores that into `shared_storage.si` + // + + ElementQ *ptr_Q = params.ptr_Q[problem_idx] + TileParams::query_start(threadblock_idx) * params.ldq[problem_idx]; + + // Construct iterators to A and B operands + typename MM0::IteratorA iterator_A( + typename MM0::IteratorA::Params( + typename MM0::MmaCore::LayoutA(params.ldq[problem_idx])), + ptr_Q, + {problem_size_0_m, problem_size_0_k}, + thread_id(), + {0, 0}); + + typename MM0::IteratorB iterator_B( + typename MM0::IteratorB::Params( + typename MM0::MmaCore::LayoutB(params.ldk[problem_idx])), + params.ptr_K[problem_idx] + iter_key_start * params.ldk[problem_idx], + {problem_size_0_k, problem_size_0_n}, + thread_id(), + {0, 0}); + + // Construct thread-scoped matrix multiply + typename MM0::Mma mma( + shared_storage.mm0, thread_id(), warp_id(), lane_id()); + + typename MM0::Mma::FragmentC accum; + + accum.clear(); + + auto gemm_k_iterations = + (problem_size_0_k + MM0::Mma::Shape::kK - 1) / MM0::Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum); + __syncthreads(); + + if (kPreloadV) { + prologueV(0); + } else { + MM1::Mma::drain_cp_asyncs(); + } + + typename MM0::Mma::Operator::IteratorC::TensorCoord + iteratorC_tile_offset = { + (warp_id() % MM0::Mma::WarpCount::kM), + (warp_id() / MM0::Mma::WarpCount::kM) + }; + + // Mask out last if causal + if (params.causal && num_keys - iter_key_start <= kKeysPerBlock) { + auto lane_offset = MM0::AccumLambdaIterator::get_lane_offset( + lane_id(), warp_id(), iteratorC_tile_offset); + int32_t last_col; + MM0::AccumLambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) { + last_col = TileParams::query_start(threadblock_idx) + accum_m - iter_key_start; + }, + [&](int accum_m, int accum_n, int idx) { + if (accum_n > last_col) { + accum[idx] = + -cutlass::platform::numeric_limits::infinity(); + } + }, + [&](int accum_m) {}); + } + // DISPATCH_BOOL(iter_key_start == 0, kIsFirst, ([&] { + // DISPATCH_BOOL( + // num_keys - iter_key_start >= kKeysPerBlock, + // kFullColumns, + // ([&] { + // // Update `mi` from accum stored in registers + // // Also does accum[i] <- exp(accum[i] - mi) + // iterative_softmax< + // typename MM0::Mma::Operator::IteratorC, + // kFullColumns, + // kIsFirst>( + // accum_o, + // accum, + // mi, + // m_prime, + // s_prime, + // lane_id(), + // thread_id(), + // warp_id(), + // num_keys - iter_key_start, + // iteratorC_tile_offset, + // kSupportsBias ? 1.0f : params.scale); + // })); + // })); + + // Update `mi` from accum stored in registers + // Also does accum[i] <- exp(accum[i] - mi) + iterative_softmax( + accum_o, + accum, + mi, + m_prime, + s_prime, + out_rescale, + shared_storage.addition_storage, + lane_id(), + thread_id(), + warp_id(), + num_keys - iter_key_start, + iter_key_start == 0, + iteratorC_tile_offset, + kSupportsBias ? 
1.0f : params.scale); + + // Output results to shared-memory + int warp_idx_mn_0 = warp_id() % + (MM0::Mma::Base::WarpCount::kM * MM0::Mma::Base::WarpCount::kN); + auto output_tile_coords = cutlass::MatrixCoord{ + warp_idx_mn_0 % MM0::Mma::Base::WarpCount::kM, + warp_idx_mn_0 / MM0::Mma::Base::WarpCount::kM}; + + MM0::B2bGemm::accumToSmem( + shared_storage.after_mm0.si, accum, lane_id(), output_tile_coords); + + __syncthreads(); + + // + // MATMUL: Attn . V + // Run the matmul `attn @ V` for a block of attn and V. + // `attn` is read from shared memory (in `shared_storage_si`) + // `V` is read from global memory (with iterator_B) + // + + const int64_t nBlockN = kKeepOutputInRF ? 1 + : ceil_div( + (int64_t)problem_size_1_n, + int64_t(MM1::ThreadblockShape::kN)); + + // Iterate over the N dimension of GEMM1 + for (int blockN = 0; blockN < nBlockN; ++blockN) { + int gemm_k_iterations = + (problem_size_1_k + MM1::Mma::Shape::kK - 1) / MM1::Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add and store it in accum + // (in registers) + if (!kPreloadV) { + __syncthreads(); // we share shmem between mma and epilogue + } + + typename MM1::Mma::IteratorB iterator_V( + typename MM1::IteratorB::Params{typename MM1::LayoutB(params.ldv[problem_idx])}, + params.ptr_V[problem_idx] + iter_key_start * params.ldv[problem_idx], + {problem_size_1_k, problem_size_1_n}, + thread_id(), + cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN}); + + typename MM1::Mma mma_pv( + // operand A: Pij_dropped in shared memory + shared_storage.after_mm0.si.accum_ref(), + // operand B: shared memory staging area for Vj, which is loaded + // from global memory + shared_storage.after_mm0.mm1.operand_B_ref(), + (int)thread_id(), + (int)warp_id(), + (int)lane_id()); + + mma_pv.set_prologue_done(kPreloadV); + if (!kKeepOutputInRF) { + accum_o.clear(); + } + + mma_pv(gemm_k_iterations, accum_o, iterator_V, accum_o); + __syncthreads(); + + if (kPreloadV && !kKeepOutputInRF && blockN + 1 < nBlockN) { + prologueV(blockN + 1); + } + + if (!kKeepOutputInRF) { + MM1::Mma::drain_cp_asyncs(); + DISPATCH_BOOL( + iter_key_start == 0, kIsFirst, ([&] { + DISPATCH_BOOL( + (iter_key_start + kKeysPerBlock) >= num_keys, + kIsLast, + ([&] { + using DefaultEpilogue = typename MM1::DefaultEpilogue; + using DefaultOp = typename MM1::DefaultConfig::EpilogueOutputOp; + using ElementCompute = typename DefaultOp::ElementCompute; + using EpilogueOutputOp = typename cutlass::epilogue:: + thread::MemoryEfficientAttentionNormalize< + typename cutlass::platform::conditional< + kIsLast::value, + output_t, + output_accum_t>::type, + output_accum_t, + DefaultOp::kCount, + typename DefaultOp::ElementAccumulator, + output_accum_t, + kIsFirst::value, + kIsLast::value, + cutlass::Array>; + using Epilogue = typename cutlass::epilogue::threadblock:: + EpiloguePipelined< + typename DefaultEpilogue::Shape, + typename MM1::Mma::Operator, + DefaultEpilogue::kPartitionsK, + typename cutlass::platform::conditional< + kIsLast::value, + typename MM1::OutputTileIterator, + typename MM1::OutputTileIteratorAccum>::type, + typename DefaultEpilogue:: + AccumulatorFragmentIterator, + typename DefaultEpilogue::WarpTileIterator, + typename DefaultEpilogue::SharedLoadIterator, + EpilogueOutputOp, + typename DefaultEpilogue::Padding, + DefaultEpilogue::kFragmentsPerIteration, + true, // IterationsUnroll + typename MM1::OutputTileIteratorAccum // Read + // iterator + >; + + int col = blockN * MM1::Mma::Shape::kN; + auto source_iter = createOutputAccumIter(col); + auto 
dest_iter = gemm_kernel_utils::call_conditional< + kIsLast::value, + decltype(createOutputIter), + decltype(createOutputAccumIter)>:: + apply(createOutputIter, createOutputAccumIter, col); + EpilogueOutputOp rescale(s_prime, out_rescale); + Epilogue epilogue( + shared_storage.epilogue_shared_storage(), + thread_id(), + warp_id(), + lane_id()); + epilogue(rescale, dest_iter, accum_o, source_iter); + })); + })); + if (!kKeepOutputInRF) { + __syncthreads(); + } + } + } + __syncthreads(); // we modify `m_prime` after + } + + if (kKeepOutputInRF) { + constexpr bool kIsFirst = true; + constexpr bool kIsLast = true; + using DefaultEpilogue = typename MM1::DefaultEpilogue; + using DefaultOp = typename MM1::DefaultConfig::EpilogueOutputOp; + using ElementCompute = typename DefaultOp::ElementCompute; + using EpilogueOutputOp = + typename cutlass::epilogue::thread::MemoryEfficientAttentionNormalize< + output_t, // output + output_accum_t, // source + DefaultOp::kCount, + typename DefaultOp::ElementAccumulator, // accum + output_accum_t, // compute + kIsFirst, + kIsLast, + cutlass::Array>; + using Epilogue = + typename cutlass::epilogue::threadblock::EpiloguePipelined< + typename DefaultEpilogue::Shape, + typename MM1::Mma::Operator, + DefaultEpilogue::kPartitionsK, + typename MM1::OutputTileIterator, // destination + typename DefaultEpilogue::AccumulatorFragmentIterator, + typename DefaultEpilogue::WarpTileIterator, + typename DefaultEpilogue::SharedLoadIterator, + EpilogueOutputOp, + typename DefaultEpilogue::Padding, + DefaultEpilogue::kFragmentsPerIteration, + true, // IterationsUnroll + typename MM1::OutputTileIteratorAccum // source tile + >; + auto dest_iter = createOutputIter(0); + EpilogueOutputOp rescale(s_prime, out_rescale); + Epilogue epilogue( + shared_storage.epilogue_shared_storage(), + thread_id(), + warp_id(), + lane_id()); + MM1::Mma::drain_cp_asyncs(); + epilogue(rescale, dest_iter, accum_o); + } + + // Next tile + problem_visitor.advance(gridDim.x); + __syncthreads(); // Don't start the next iteration until all threads are done using shared memory. + } + } + + template + CUTLASS_DEVICE static void iterative_softmax( + typename WarpIteratorC::Fragment& frag_o, // output so far + typename WarpIteratorC::Fragment& frag, + cutlass::Array& mi, + cutlass::Array& m_prime, + cutlass::Array& s_prime, + cutlass::Array& out_rescale, + cutlass::Array& + addition_storage, + int8_t lane_id, + int8_t thread_id, + int8_t warp_id, + int max_col, + bool is_first, + typename WarpIteratorC::TensorCoord const& tile_offset, + float scaling) { + /* Iterates on the accumulator and corresponding position on result matrix + + (1) Update `mi[r]` to the max value of the row `r` + (2) In a second iteration do the following: + (a) accum <- exp(accum - mi) + (b) m_prime <- exp(m_prime - mi) + (c) s_prime <- s_prime * m_prime + sum(accum) + + All of this is done on registers, before we store all of this + on shared memory for the next matmul with Value. 
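+
+    As a plain scalar reference (illustrative only: the kernel below works on
+    warp-distributed fragments and pre-scales by log2(e) so it can use exp2f;
+    row_max, kKeysInBlock and accum are placeholder names), the per-row update
+    amounts to:
+
+      float new_max = std::max(m_prime[r], row_max);     // (1)
+      float rescale = std::exp(m_prime[r] - new_max);    // (2b)
+      float row_sum = 0.f;
+      for (int j = 0; j < kKeysInBlock; ++j) {
+        accum[r][j] = std::exp(accum[r][j] - new_max);   // (2a)
+        row_sum += accum[r][j];
+      }
+      s_prime[r] = s_prime[r] * rescale + row_sum;       // (2c)
+      m_prime[r] = new_max;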
+ */ + using Fragment = typename WarpIteratorC::Fragment; + using LambdaIterator = typename DefaultMmaAccumLambdaIterator< + WarpIteratorC, + accum_t, + kThreadsPerWarp>::Iterator; + // Convert to `accum_t` (rather than double) + constexpr float kLog2e = 1.4426950408889634074; // log_2(e) = M_LOG2E + + static_assert(kQueriesPerBlock % kNumWarpsPerBlock == 0, ""); + static constexpr int kLinesPerWarp = kQueriesPerBlock / kNumWarpsPerBlock; + + frag = cutlass::multiplies()(scaling * kLog2e, frag); + + auto lane_offset = + LambdaIterator::get_lane_offset(lane_id, warp_id, tile_offset); + + // First update `mi` to the max per-row + { + accum_t max; + LambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) { + max = -cutlass::platform::numeric_limits::infinity(); + }, + [&](int accum_m, int accum_n, int idx) { + if (accum_n < max_col) { + max = cutlass::fast_max(max, frag[idx]); + } + }, + [&](int accum_m) { + // Having 4x atomicMax seems faster than reduce within warp + // first... + atomicMaxFloat(&mi[accum_m], max); + }); + } + + // Make sure we all share the update values for `mi` + __syncthreads(); + + // Doing this `exp` is quite expensive. Let's + // split it across the warps + bool restore_mi_to_minus_inf = false; + if (lane_id < kLinesPerWarp) { + int id = warp_id * kLinesPerWarp + lane_id; + auto m_prime_id = m_prime[id]; + auto mi_id = mi[id]; + bool changed = m_prime_id < mi_id; // `false` if both are -inf + if (changed) { + auto m_prime_exp = exp2f(m_prime_id - mi_id); + out_rescale[id] = m_prime_exp; + s_prime[id] *= m_prime_exp; + } else { + // Only when bias is enabled, it's possible that all the first values + // of attention are masked to `-inf`. In that case we want to avoid + // `nan = exp2f(-inf - (-inf))` so we temporarily set `mi` to 0 + if (kSupportsBias && + mi_id == -cutlass::platform::numeric_limits::infinity()) { + restore_mi_to_minus_inf = true; + mi[id] = 0.0f; + } + out_rescale[id] = 1.0f; + } + } + __syncthreads(); // Update output fragments + if (kKeepOutputInRF && !is_first) { + accum_t line_rescale; + LambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) { line_rescale = out_rescale[accum_m]; }, + [&](int accum_m, int accum_n, int idx) { + frag_o[idx] = frag_o[idx] * line_rescale; + }, + [&](int accum_m) {}); + } + // Update accum_m, accum_n, ... + { + accum_t mi_row, total_row; + LambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) { mi_row = mi[accum_m]; }, + [&](int accum_m, int accum_n, int idx) { + frag[idx] = + (accum_n < max_col) ? 
exp2f(frag[idx] - mi_row) : accum_t(0.0); + }, + [&](int accum_m) {}); + LambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) { total_row = 0.0; }, + [&](int accum_m, int accum_n, int idx) { total_row += frag[idx]; }, + [&](int accum_m) { + if (LambdaIterator::reduceSameRow( + lane_id, total_row, [](accum_t a, accum_t b) { + return a + b; + })) { + // NOTE: we could atomically add `total_row` to `s_prime`, but + // it's faster (and deterministic) to avoid atomics here + addition_storage + [accum_m + kQueriesPerBlock * tile_offset.column()] = + total_row; + } + }); + } + + __syncthreads(); + if (lane_id < kLinesPerWarp) { + int id = warp_id * kLinesPerWarp + lane_id; + accum_t total_row = s_prime[id]; + if (restore_mi_to_minus_inf) { + // Restore `mi`, see above when we set `restore_mi_to_minus_inf=true` + mi[id] = -cutlass::platform::numeric_limits::infinity(); + } else { + m_prime[id] = mi[id]; + } + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < MM0::MmaCore::WarpCount::kN; ++i) { + total_row += addition_storage[id + kQueriesPerBlock * i]; + } + s_prime[id] = total_row; + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/fmha_grouped_problem_visitor.h b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/fmha_grouped_problem_visitor.h new file mode 100644 index 0000000000000000000000000000000000000000..f88219304b8e239cbb656004ad17636ac98d4e04 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/fmha_grouped_problem_visitor.h @@ -0,0 +1,178 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Scheduler for grouped FMHA +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/gemm/kernel/grouped_problem_visitor.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { +// Helper for correctly representing problem sizes in grouped kernels +template +struct FMHAGroupedProblemSizeHelper { + + CUTLASS_HOST_DEVICE + static cutlass::gemm::GemmCoord grid_shape(const cutlass::gemm::GemmCoord& problem) { + // FMHA only partitions tiles across the M dimension. + return cutlass::gemm::GemmCoord( + ((problem.m() - 1 + ThreadblockShape::kM) / ThreadblockShape::kM), 1, 1); + } + + CUTLASS_HOST_DEVICE + static void possibly_transpose_problem(cutlass::gemm::GemmCoord& problem) {} + + CUTLASS_HOST_DEVICE + static int32_t tile_count(const cutlass::gemm::GemmCoord& grid) { + return grid.m() * grid.n(); + } +}; + +} // namespace detail + +/// Visitor class to abstract away the algorithm for iterating over tiles +template +struct FMHAGroupedProblemVisitor : public GroupedProblemVisitor< + detail::FMHAGroupedProblemSizeHelper, + ThreadblockShape, + GroupScheduleMode_, + PrefetchTileCount, + ThreadCount> { + + using ProblemSizeHelper = detail::FMHAGroupedProblemSizeHelper; + using Base = GroupedProblemVisitor; + using BaseParams = typename Base::Params; + using SharedStorage = typename Base::SharedStorage; + + cutlass::gemm::GemmCoord const *problem_sizes0; + cutlass::gemm::GemmCoord const *problem_sizes1; + + struct Params { + cutlass::gemm::GemmCoord const *problem_sizes0; + cutlass::gemm::GemmCoord const *problem_sizes1; + int32_t problem_count; + void const *workspace; + int32_t tile_count; + + // + // Methods + // + + /// Ctor + CUTLASS_HOST_DEVICE + Params(): problem_sizes0(nullptr), problem_sizes1(nullptr), + problem_count(0), workspace(nullptr), tile_count(0) { } + + /// Ctor + CUTLASS_HOST_DEVICE + Params( + cutlass::gemm::GemmCoord const *problem_sizes0, + cutlass::gemm::GemmCoord const *problem_sizes1, + int32_t problem_count, + void const *workspace = nullptr, + int32_t tile_count = 0 + ): + problem_sizes0(problem_sizes0), + problem_sizes1(problem_sizes1), + problem_count(problem_count), + workspace(workspace), + tile_count(tile_count) + {} + + /// Convert the FMHA-specific parameters to those used by the base class + CUTLASS_HOST_DEVICE + BaseParams to_base() const { + return BaseParams(// Set problem_sizes as problem_sizes1 because these determine + // shape of the final output of FMHA + problem_sizes1, + problem_count, + workspace, + tile_count); + } + + }; + + // + // Methods + // + CUTLASS_DEVICE + FMHAGroupedProblemVisitor( 
+ Params const ¶ms_, + SharedStorage &shared_storage_, + int32_t block_idx + ): Base ( + params_.to_base(), + shared_storage_, block_idx), + problem_sizes0(params_.problem_sizes0), + problem_sizes1(params_.problem_sizes1) + {} + + /// Returns the problem size 0 for the current problem + CUTLASS_HOST_DEVICE + cutlass::gemm::GemmCoord problem_size0() const { + GemmCoord problem = problem_sizes0[this->problem_idx]; + ProblemSizeHelper::possibly_transpose_problem(problem); + return problem; + } + + /// Returns the problem size 1 for the current problem + CUTLASS_HOST_DEVICE + cutlass::gemm::GemmCoord problem_size1() const { + GemmCoord problem = problem_sizes1[this->problem_idx]; + ProblemSizeHelper::possibly_transpose_problem(problem); + return problem; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/gemm/custom_mma.h b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/gemm/custom_mma.h new file mode 100644 index 0000000000000000000000000000000000000000..f3a1d4cbc275a166e0575b365413864ffb93c9f3 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/gemm/custom_mma.h @@ -0,0 +1,124 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once + +#include "custom_mma_multistage.h" +#include "custom_mma_pipelined.h" +#include "cutlass/gemm/threadblock/mma_multistage.h" +#include "cutlass/gemm/threadblock/mma_pipelined.h" + +template +struct MakeCustomMma; + +template < + typename Shape, + typename IteratorA, + typename SmemIteratorA, + cutlass::arch::CacheOperation::Kind CacheOpA, + typename IteratorB, + typename SmemIteratorB, + cutlass::arch::CacheOperation::Kind CacheOpB, + typename ElementC, + typename LayoutC, + typename Policy, + int Stages, + cutlass::gemm::SharedMemoryClearOption SharedMemoryClear, + int kMaxK> +struct MakeCustomMma< + cutlass::gemm::threadblock::MmaMultistage< + Shape, + IteratorA, + SmemIteratorA, + CacheOpA, + IteratorB, + SmemIteratorB, + CacheOpB, + ElementC, + LayoutC, + Policy, + Stages, + SharedMemoryClear>, + kMaxK> { + // Reduce the number of stages if we don't need that many + static int constexpr kStages = + kMaxK == cutlass::platform::numeric_limits::max() + ? Stages + : cutlass::const_min( + Stages, + (kMaxK + int(Shape::kK) - 1) / int(Shape::kK)); + using Mma = cutlass::gemm::threadblock::CustomMmaMultistage< + Shape, + IteratorA, + SmemIteratorA, + CacheOpA, + IteratorB, + SmemIteratorB, + CacheOpB, + ElementC, + LayoutC, + Policy, + kStages, + SharedMemoryClear, + kMaxK>; +}; + +template < + typename Shape, + typename IteratorA, + typename SmemIteratorA, + typename IteratorB, + typename SmemIteratorB, + typename ElementC, + typename LayoutC, + typename Policy, + int kMaxK> +struct MakeCustomMma< + cutlass::gemm::threadblock::MmaPipelined< + Shape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + Policy>, + kMaxK> { + using Mma = cutlass::gemm::threadblock::CustomMmaPipelined< + Shape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + Policy>; +}; diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/gemm/custom_mma_base.h b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/gemm/custom_mma_base.h new file mode 100644 index 0000000000000000000000000000000000000000..66c099d15bf038fd06fd2b3ca6c03d1de3f7ffe9 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/gemm/custom_mma_base.h @@ -0,0 +1,182 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/threadblock/mma_base.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Used for partial specialization + typename Enable = bool> +class CustomMmaBase { + public: + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + + ///< Policy describing tuning details + using Policy = Policy_; + + // + // Dependent types + // + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Shape describing the overall GEMM computed from shared memory + /// by each warp. 
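+  /// (WarpCount below divides the threadblock tile by this warp tile; for
+  /// example, with illustrative numbers, a 128x128x32 threadblock tile and
+  /// 64x64x32 warp tiles give WarpCount = GemmShape<2, 2, 1>, i.e. four warps
+  /// cooperating on one threadblock tile.)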
+ using WarpGemm = typename Policy::Operator::Shape; + + /// Shape describing the number of warps filling the CTA + using WarpCount = GemmShape< + Shape::kM / WarpGemm::kM, + Shape::kN / WarpGemm::kN, + Shape::kK / WarpGemm::kK>; + + /// Number of warp-level GEMM oeprations + static int const kWarpGemmIterations = + (WarpGemm::kK / Operator::Policy::MmaShape::kK); + + /// Number of stages + static int const kStages = Stages; + + // + // Nested structs + // + + /// Shared storage object needed by threadblock-scoped GEMM + template + struct OperandSharedStorage { + AlignedBuffer buffer; + using TensorRef = TensorRef; + + CUTLASS_DEVICE + static OperandLayout Layout() { + return OperandLayout::packed({OperandShape::kRow, OperandShape::kColumn}); + } + + /// Returns a TensorRef to the operand + CUTLASS_HOST_DEVICE + TensorRef ref() { + return TensorRef{buffer.data(), Layout()}; + } + }; + + /// Shape of the A matrix operand in shared memory + using ShapeA = MatrixShape< + Shape::kM + Policy::SmemPaddingA::kRow, + Shape::kK * kStages + Policy::SmemPaddingA::kColumn>; + + /// Shape of the B matrix operand in shared memory + using ShapeB = MatrixShape< + Shape::kK * kStages + Policy::SmemPaddingB::kRow, + Shape::kN + Policy::SmemPaddingB::kColumn>; + + using SharedStorageA = OperandSharedStorage< + typename Operator::ElementA, + ShapeA, + typename Operator::LayoutA>; + using SharedStorageB = OperandSharedStorage< + typename Operator::ElementB, + ShapeB, + typename Operator::LayoutB>; + using TensorRefA = typename SharedStorageA::TensorRef; + using TensorRefB = typename SharedStorageB::TensorRef; + + struct SharedStorage { + /// Buffer for A operand + SharedStorageA operand_A; + + /// Buffer for B operand + SharedStorageB operand_B; + }; + + protected: + // + // Data members + // + + /// Iterator to load a warp-scoped tile of A operand from shared memory + typename Operator::IteratorA warp_tile_iterator_A_; + + /// Iterator to load a warp-scoped tile of B operand from shared memory + typename Operator::IteratorB warp_tile_iterator_B_; + + public: + /// Construct from tensor references + CUTLASS_DEVICE + CustomMmaBase( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + SharedStorageA& shared_storageA, + SharedStorageB& shared_storageB, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : warp_tile_iterator_A_(shared_storageA.ref(), lane_idx), + warp_tile_iterator_B_(shared_storageB.ref(), lane_idx) {} +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/gemm/custom_mma_multistage.h b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/gemm/custom_mma_multistage.h new file mode 100644 index 0000000000000000000000000000000000000000..145315e413172e1f29a10b13a9d9d7fa19537006 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/gemm/custom_mma_multistage.h @@ -0,0 +1,760 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION 
& AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/cache_operation.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +#include "custom_mma_base.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. 
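+///
+/// Compared to cutlass::gemm::threadblock::MmaMultistage, this variant adds a
+/// compile-time upper bound kMaxK on the K extent and an externally driven
+/// prologue (prologue() / set_prologue_done()). When the whole K range fits in
+/// shared memory (kMaxK <= Shape::kK * Stages, see kSmemContainsEntireMat),
+/// the mainloop issues no further global->shared copies and never wraps the
+/// shared-memory read iterators.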
+template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Cache operation for operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Cache operation for operand B + cutlass::arch::CacheOperation::Kind CacheOpB, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone, + /// Upper boundon the K dimension + int kMaxK = cutlass::platform::numeric_limits::max(), + /// Used for partial specialization + typename Enable = bool> +class CustomMmaMultistage : public CustomMmaBase { + public: + ///< Base class + using Base = CustomMmaBase; + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + ///< Iterates over tiles of A operand in global memory + using IteratorA = IteratorA_; + ///< Iterates over tiles of B operand in global memory + using IteratorB = IteratorB_; + ///< Data type of accumulator matrix + using ElementC = ElementC_; + ///< Layout of accumulator matrix + using LayoutC = LayoutC_; + ///< Policy describing tuning details + using Policy = Policy_; + + using SmemIteratorA = SmemIteratorA_; + using SmemIteratorB = SmemIteratorB_; + + static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; + + // + // Dependent types + // + + /// Fragment of accumulator tile + using FragmentC = typename Policy::Operator::FragmentC; + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Minimum architecture is Sm80 to support cp.async + using ArchTag = arch::Sm80; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = Operator::kTransformA; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = Operator::kTransformB; + + /// Internal structure exposed for introspection. 
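+  ///
+  /// The kAccessesPerGroup* members below split one stage worth of cp.async
+  /// instructions evenly across the warp-level MMA iterations. For example,
+  /// with illustrative numbers, 8 async-copy iterations per stage spread over
+  /// 4 warp GEMM iterations give kAccessesPerGroupA = (8 + 4 - 1) / 4 = 2,
+  /// i.e. two copies are issued alongside each warp-level MMA.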
+ struct Detail { + static_assert( + Base::kWarpGemmIterations > 1, + "The pipelined structure requires at least two warp-level " + "GEMM operations."); + + /// Number of cp.async instructions to load one stage of operand A + static int const AsyncCopyIterationsPerStageA = + IteratorA::ThreadMap::Iterations::kCount; + + /// Number of cp.async instructions to load one stage of operand B + static int const AsyncCopyIterationsPerStageB = + IteratorB::ThreadMap::Iterations::kCount; + + /// Number of stages + static int const kStages = Stages; + + /// Number of cp.async instructions to load on group of operand A + static int const kAccessesPerGroupA = + (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / + Base::kWarpGemmIterations; + + /// Number of cp.async instructions to load on group of operand B + static int const kAccessesPerGroupB = + (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / + Base::kWarpGemmIterations; + }; + + static bool const kSmemContainsEntireMat = kMaxK <= Shape::kK * Stages; + static constexpr int kNumStagesConcurrentLoad = + kSmemContainsEntireMat ? Stages : Stages - 1; + + private: + using WarpLoadedFragmentA = typename Operator::FragmentA; + using WarpLoadedFragmentB = typename Operator::FragmentB; + using WarpTransformedFragmentA = typename Operator::TransformedFragmentA; + using WarpTransformedFragmentB = typename Operator::TransformedFragmentB; + + private: + // + // Data members + // + + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; + + bool prologue_done_; + + // Set to `True` to ensure the accumulator will be zero outside the GEMM + // footprint + bool zero_outside_bounds_; + + public: + /// Construct from tensor references + CUTLASS_DEVICE + CustomMmaMultistage( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + typename Base::SharedStorageA& shared_storageA, + typename Base::SharedStorageB& shared_storageB, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : Base(shared_storageA, shared_storageB, thread_idx, warp_idx, lane_idx), + smem_iterator_A_(shared_storageA.ref(), thread_idx), + smem_iterator_B_(shared_storageB.ref(), thread_idx), + prologue_done_(false), + zero_outside_bounds_(false) { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset( + {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset( + {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); + } + CUTLASS_DEVICE + CustomMmaMultistage( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + typename Base::SharedStorage& st, + ///< ID within the threadblock + int thread_idx, + ///< ID of 
warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : CustomMmaMultistage( + st.operand_A, + st.operand_B, + thread_idx, + warp_idx, + lane_idx) {} + + CUTLASS_DEVICE + bool set_prologue_done(bool value) { + prologue_done_ = value; + return true; + } + + CUTLASS_DEVICE + bool set_zero_outside_bounds(bool value) { + zero_outside_bounds_ = value; + return true; + } + + template + CUTLASS_DEVICE static void prologue( + typename Base::SharedStorage& shared_storage, + ///< iterator over A operand in global memory + IteratorA iterator_A, + ///< iterator over B operand in global memory + IteratorB iterator_B, + int thread_idx, + int problem_size_k) { + prologue( + shared_storage.operand_A, + shared_storage.operand_B, + iterator_A, + iterator_B, + thread_idx, + problem_size_k); + } + + template + CUTLASS_DEVICE static void prologue( + typename Base::SharedStorageA& shared_storageA, + typename Base::SharedStorageB& shared_storageB, + ///< iterator over A operand in global memory + IteratorA iterator_A, + ///< iterator over B operand in global memory + IteratorB iterator_B, + int thread_idx, + int problem_size_k) { + SmemIteratorA smem_iterator_A(shared_storageA.ref(), thread_idx); + SmemIteratorB smem_iterator_B(shared_storageB.ref(), thread_idx); + int32_t iter = (problem_size_k + Base::Shape::kK - 1) / Base::Shape::kK; + _prologue( + iterator_A, iterator_B, iter, smem_iterator_A, smem_iterator_B); + } + + CUTLASS_DEVICE + void copy_tiles_and_advance( + IteratorA& iterator_A, + IteratorB& iterator_B, + int group_start_A = 0, + int group_start_B = 0) { + iterator_A.set_iteration_index( + group_start_A * IteratorA::kAccessesPerVector); + this->smem_iterator_A_.set_iteration_index(group_start_A); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) { + if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) { + typename IteratorA::AccessType* dst_ptr = + reinterpret_cast( + this->smem_iterator_A_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_A.get(); + + if (zero_outside_bounds_ || + SharedMemoryClear == SharedMemoryClearOption::kZfill) { + cutlass::arch::cp_async_zfill( + dst_ptr + v, gmem_ptr, iterator_A.valid()); + } else { + cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, iterator_A.valid()); + } + + ++iterator_A; + } + + ++this->smem_iterator_A_; + } + } + + iterator_B.set_iteration_index( + group_start_B * IteratorB::kAccessesPerVector); + this->smem_iterator_B_.set_iteration_index(group_start_B); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { + if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) { + typename IteratorB::AccessType* dst_ptr = + reinterpret_cast( + this->smem_iterator_B_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / + IteratorB::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_B.get(); + + if (zero_outside_bounds_ || + SharedMemoryClear == SharedMemoryClearOption::kZfill) { + cutlass::arch::cp_async_zfill( + dst_ptr + v, gmem_ptr, iterator_B.valid()); + } else { + cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, iterator_B.valid()); + } + + ++iterator_B; + } + 
++this->smem_iterator_B_; + } + } + } + + template + CUTLASS_DEVICE static void _prologue( + IteratorA& iterator_A, + IteratorB& iterator_B, + int32_t& gemm_k_iterations, + SmemIteratorA& smem_iterator_A_, + SmemIteratorB& smem_iterator_B_) { + // Issue several complete stages + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < kNumStagesConcurrentLoad; + ++stage, --gemm_k_iterations) { + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + + iterator_A.set_iteration_index(0); + smem_iterator_A_.set_iteration_index(0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { + typename IteratorA::AccessType* dst_ptr = + reinterpret_cast( + smem_iterator_A_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; + + int src_bytes = (iterator_A.valid() ? kSrcBytes : 0); + + if (kLoadA) { + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_A.get(), iterator_A.valid()); + } + + ++iterator_A; + } + + ++smem_iterator_A_; + } + + iterator_B.set_iteration_index(0); + smem_iterator_B_.set_iteration_index(0); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { + typename IteratorB::AccessType* dst_ptr = + reinterpret_cast( + smem_iterator_B_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / + IteratorB::kAccessesPerVector / 8; + + if (kLoadB) { + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_B.get(), iterator_B.valid()); + } + + ++iterator_B; + } + + ++smem_iterator_B_; + } + + // Move to the next stage + iterator_A.add_tile_offset({0, 1}); + iterator_B.add_tile_offset({1, 0}); + + smem_iterator_A_.add_tile_offset({0, 1}); + smem_iterator_B_.add_tile_offset({1, 0}); + + // Defines the boundary of a stage of cp.async. + cutlass::arch::cp_async_fence(); + } + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + ///< problem size of GEMM + int gemm_k_iterations, + ///< destination accumulator tile + FragmentC& accum, + ///< iterator over A operand in global memory + IteratorA iterator_A, + ///< iterator over B operand in global memory + IteratorB iterator_B, + ///< initial value of accumulator + FragmentC const& src_accum) { + // + // Prologue + // + + if (!prologue_done_) { + _prologue( + iterator_A, + iterator_B, + gemm_k_iterations, + smem_iterator_A_, + smem_iterator_B_); + } else if (!kSmemContainsEntireMat) { + _prologue( + iterator_A, + iterator_B, + gemm_k_iterations, + smem_iterator_A_, + smem_iterator_B_); + } else { + gemm_k_iterations -= kNumStagesConcurrentLoad; + } + + // Perform accumulation in the 'd' output operand + accum = src_accum; + + // + // Clear the remaining tiles of SMEM. This is a functional requirement for + // some kernels so that all accumulator elements outside the GEMM footprint + // are zero. 
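+    // (With SharedMemoryClearOption::kClearLastStage this is achieved below by
+    // overwriting the last stage's A and B tiles with a zeroed fragment, so
+    // elements skipped by predicated cp.async read back as zero.)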
+ // + + if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage) { + /// Iterator to write threadblock-scoped tile of A operand to shared + /// memory + SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_); + + typename IteratorA::AccessType zero_A; + zero_A.clear(); + + last_smem_iterator_A.set_iteration_index(0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { + typename IteratorA::AccessType* dst_ptr = + reinterpret_cast( + last_smem_iterator_A.get()); + + *dst_ptr = zero_A; + + ++last_smem_iterator_A; + } + + /// Iterator to write threadblock-scoped tile of B operand to shared + /// memory + SmemIteratorB last_smem_iterator_B(this->smem_iterator_B_); + typename IteratorB::AccessType zero_B; + + zero_B.clear(); + last_smem_iterator_B.set_iteration_index(0); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { + typename IteratorB::AccessType* dst_ptr = + reinterpret_cast( + last_smem_iterator_B.get()); + + *dst_ptr = zero_B; + + ++last_smem_iterator_B; + } + } + + // Waits until kStages-2 stages have committed. + cutlass::arch::cp_async_wait(); + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpLoadedFragmentA warp_loaded_frag_A[2]; + WarpLoadedFragmentB warp_loaded_frag_B[2]; + WarpTransformedFragmentA warp_transformed_frag_A[2]; + WarpTransformedFragmentB warp_transformed_frag_B[2]; + + Operator warp_mma; + + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_B_.set_kgroup_index(0); + + this->warp_tile_iterator_A_.load(warp_loaded_frag_A[0]); + this->warp_tile_iterator_B_.load(warp_loaded_frag_B[0]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + + int smem_write_stage_idx = Base::kStages - 1; + int smem_read_stage_idx = 0; + + warp_mma.transform( + warp_transformed_frag_A[0], + warp_transformed_frag_B[0], + warp_loaded_frag_A[0], + warp_loaded_frag_B[0]); + + // tf32x3 kernels use staging accumulation. warp_mma uses a temporary + // accumulator and this temporary accumulator is added to the final + // accumulator once in every mainloop iteration. + plus plus_accum; + + FragmentC tmp_accum; + + if (platform::is_same< + typename Operator::MathOperator, + arch::OpMultiplyAddFastF32>::value || + platform::is_same< + typename Operator::MathOperator, + arch::OpMultiplyAddComplexFastF32>::value) { + tmp_accum.clear(); + } + + // + // Mainloop + // + + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > (-kNumStagesConcurrentLoad);) { + // + // Loop over GEMM K dimension + // + + // Computes a warp-level GEMM on data held in shared memory + // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; + ++warp_mma_k) { + // Load warp-level tiles from shared memory, wrapping to k offset if + // this is the last group as the case may be. + + this->warp_tile_iterator_A_.set_kgroup_index( + (warp_mma_k + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_B_.set_kgroup_index( + (warp_mma_k + 1) % Base::kWarpGemmIterations); + + // In case of a non-circular buffer ("kSmemContainsEntireMat") + // make sure we don't load out of bounds data. 
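+        // (Concretely, in the non-circular case the final prefetch of the next
+        // warp fragment would step past the K extent staged in shared memory,
+        // so that single load is skipped on the last mainloop iteration.)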
+ if (!kSmemContainsEntireMat || + gemm_k_iterations > (-kNumStagesConcurrentLoad) || + warp_mma_k < Base::kWarpGemmIterations - 1) { + this->warp_tile_iterator_A_.load( + warp_loaded_frag_A[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_B_.load( + warp_loaded_frag_B[(warp_mma_k + 1) % 2]); + } + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + + if (warp_mma_k > 0) + warp_mma.transform( + warp_transformed_frag_A[warp_mma_k % 2], + warp_transformed_frag_B[warp_mma_k % 2], + warp_loaded_frag_A[warp_mma_k % 2], + warp_loaded_frag_B[warp_mma_k % 2]); + + if (platform::is_same< + typename Operator::MathOperator, + arch::OpMultiplyAddFastF32>::value || + platform::is_same< + typename Operator::MathOperator, + arch::OpMultiplyAddComplexFastF32>::value) { + warp_mma( + tmp_accum, + warp_transformed_frag_A[warp_mma_k % 2], + warp_transformed_frag_B[warp_mma_k % 2], + tmp_accum); + + if (warp_mma_k == 0) { + accum = plus_accum(accum, tmp_accum); + tmp_accum.clear(); + } + } else { + warp_mma( + accum, + warp_transformed_frag_A[warp_mma_k % 2], + warp_transformed_frag_B[warp_mma_k % 2], + accum); + } + + // Issue global->shared copies for the this stage + if (!kSmemContainsEntireMat && + warp_mma_k < Base::kWarpGemmIterations - 1) { + int group_start_iteration_A, group_start_iteration_B; + + group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA; + group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB; + + copy_tiles_and_advance( + iterator_A, + iterator_B, + group_start_iteration_A, + group_start_iteration_B); + } + + if (warp_mma_k + 2 == Base::kWarpGemmIterations) { + if (!kSmemContainsEntireMat) { + int group_start_iteration_A, group_start_iteration_B; + group_start_iteration_A = + (warp_mma_k + 1) * Detail::kAccessesPerGroupA; + group_start_iteration_B = + (warp_mma_k + 1) * Detail::kAccessesPerGroupB; + + copy_tiles_and_advance( + iterator_A, + iterator_B, + group_start_iteration_A, + group_start_iteration_B); + } + + // Inserts a memory fence between stages of cp.async instructions. + cutlass::arch::cp_async_fence(); + + // Waits until kStages-2 stages have committed. 
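+          // The cp_async_fence() above closed the cp.async group for the
+          // stage that was just issued; the wait below (on Base::kStages - 2
+          // outstanding groups) blocks until the oldest stage has fully
+          // landed in shared memory, and the __syncthreads() that follows
+          // makes that stage visible to every warp in the threadblock.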
+ cutlass::arch::cp_async_wait(); + __syncthreads(); + + // Move to the next stage + iterator_A.add_tile_offset({0, 1}); + iterator_B.add_tile_offset({1, 0}); + + this->smem_iterator_A_.add_tile_offset({0, 1}); + this->smem_iterator_B_.add_tile_offset({1, 0}); + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory + if (smem_write_stage_idx == (Base::kStages - 1)) { + this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + smem_write_stage_idx = 0; + } else { + ++smem_write_stage_idx; + } + + if (!kSmemContainsEntireMat && + smem_read_stage_idx == (Base::kStages - 1)) { + this->warp_tile_iterator_A_.add_tile_offset( + {0, + -Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterations}); + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterations, + 0}); + smem_read_stage_idx = 0; + } else { + ++smem_read_stage_idx; + } + + --gemm_k_iterations; + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + } + + // Do any conversions feeding the first stage at the end of the loop so + // we can start right away on mma instructions + if (warp_mma_k + 1 == Base::kWarpGemmIterations) + warp_mma.transform( + warp_transformed_frag_A[(warp_mma_k + 1) % 2], + warp_transformed_frag_B[(warp_mma_k + 1) % 2], + warp_loaded_frag_A[(warp_mma_k + 1) % 2], + warp_loaded_frag_B[(warp_mma_k + 1) % 2]); + } + } + + if (platform::is_same< + typename Operator::MathOperator, + arch::OpMultiplyAddFastF32>::value || + platform::is_same< + typename Operator::MathOperator, + arch::OpMultiplyAddComplexFastF32>::value) { + accum = plus_accum(accum, tmp_accum); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/gemm/custom_mma_pipelined.h b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/gemm/custom_mma_pipelined.h new file mode 100644 index 0000000000000000000000000000000000000000..b967b86c0131ada37a19be623486c61562588246 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/gemm/custom_mma_pipelined.h @@ -0,0 +1,401 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/numeric_conversion.h" + +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +#include "custom_mma_base.h" +#include "cutlass/gemm/gemm.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. 
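+///
+/// This is the classic two-stage (double-buffered) mainloop: while the warps
+/// compute on the tile resident in one shared-memory buffer, the next
+/// threadblock tile of A and B is loaded from global memory into the other
+/// buffer. kStages is statically asserted to be 2 below.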
+template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Transformation applied to A operand + typename TransformA_ = NumericArrayConverter< + typename SmemIteratorA_::Element, + typename IteratorA_::Element, + IteratorA_::Fragment::kElements>, + /// + /// Transformation applied to B operand + typename TransformB_ = NumericArrayConverter< + typename SmemIteratorB_::Element, + typename IteratorB_::Element, + IteratorB_::Fragment::kElements>, + /// Used for partial specialization + typename Enable = bool> +class CustomMmaPipelined : public CustomMmaBase { + public: + ///< Base class + using Base = CustomMmaBase; + + using Shape = + Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using IteratorA = + IteratorA_; ///< Iterates over tiles of A operand in global memory + using IteratorB = + IteratorB_; ///< Iterates over tiles of B operand in global memory + using ElementC = ElementC_; ///< Data type of accumulator matrix + using LayoutC = LayoutC_; ///< Layout of accumulator matrix + using Policy = Policy_; ///< Policy describing tuning details + + using SmemIteratorA = SmemIteratorA_; + using SmemIteratorB = SmemIteratorB_; + + using TransformA = TransformA_; + using TransformB = TransformB_; + + // + // Dependent types + // + + /// Fragment of operand A loaded from global memory + using FragmentA = typename IteratorA::Fragment; + + /// Fragment of operand B loaded from global memory + using FragmentB = typename IteratorB::Fragment; + + /// Fragment of accumulator tile + using FragmentC = typename Policy::Operator::FragmentC; + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Obtain the arch tag from the warp-level operator + using ArchTag = typename Policy::Operator::ArchTag; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = Operator::kTransformA; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = Operator::kTransformB; + + // staticaly assert kStages for MmaPipelined is two (Double-buffered pipeline) + static_assert( + (Base::kStages == 2), + "MmaPipelined requires kStages set to value 2"); + + static bool const kSmemContainsEntireMat = false; + + private: + using WarpFragmentA = typename Operator::FragmentA; + using WarpFragmentB = typename Operator::FragmentB; + + protected: + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; + + public: + /// Construct from tensor references + CUTLASS_DEVICE + CustomMmaPipelined( + typename Base::SharedStorageA& 
shared_storageA, + typename Base::SharedStorageB& shared_storageB, + int thread_idx, ///< ID within the threadblock + int warp_idx, ///< ID of warp + int lane_idx ///< ID of each thread within a warp + ) + : Base(shared_storageA, shared_storageB, thread_idx, warp_idx, lane_idx), + smem_iterator_A_(shared_storageA.ref(), thread_idx), + smem_iterator_B_(shared_storageB.ref(), thread_idx) { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset( + {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset( + {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); + } + CUTLASS_DEVICE + CustomMmaPipelined( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + typename Base::SharedStorage& st, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : CustomMmaPipelined( + st.operand_A, + st.operand_B, + thread_idx, + warp_idx, + lane_idx) {} + + CUTLASS_DEVICE + bool set_prologue_done(bool value) { + // NOT IMPLEMENTED FOR PIPELINED + } + + CUTLASS_DEVICE + bool set_zero_outside_bounds(bool value) { + // NOT NEEDED FOR PIPELINED + // shared memory will always be zero-filled + } + + template + CUTLASS_DEVICE static void prologue( + typename Base::SharedStorage& shared_storage, + ///< iterator over A operand in global memory + IteratorA iterator_A, + ///< iterator over B operand in global memory + IteratorB iterator_B, + int thread_idx, + int problem_size_k) { + prologue( + shared_storage.operand_A, + shared_storage.operand_B, + iterator_A, + iterator_B, + thread_idx, + problem_size_k); + } + + template + CUTLASS_DEVICE static void prologue( + typename Base::SharedStorageA& shared_storageA, + typename Base::SharedStorageB& shared_storageB, + ///< iterator over A operand in global memory + IteratorA iterator_A, + ///< iterator over B operand in global memory + IteratorB iterator_B, + int thread_idx, + int problem_size_k) { + // NOT IMPLEMENTED FOR PIPELINED + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + int gemm_k_iterations, ///< number of iterations of the mainloop + FragmentC& accum, ///< destination accumulator tile + IteratorA iterator_A, ///< iterator over A operand in global memory + IteratorB iterator_B, ///< iterator over B operand in global memory + FragmentC const& src_accum, ///< source accumulator tile + TransformA transform_A = + TransformA(), ///< transformation applied to A fragment + TransformB transform_B = + TransformB()) { ///< transformation applied to B fragment + + // + // Prologue + // + + // Perform accumulation in the 'd' output operand + accum = src_accum; + + FragmentA tb_frag_A; + FragmentB tb_frag_B; + + tb_frag_A.clear(); + tb_frag_B.clear(); + + // The last kblock is loaded in the prolog + iterator_A.load(tb_frag_A); + 
iterator_B.load(tb_frag_B); + + ++iterator_A; + ++iterator_B; + + this->smem_iterator_A_.store(transform_A(tb_frag_A)); + this->smem_iterator_B_.store(transform_B(tb_frag_B)); + + ++this->smem_iterator_A_; + ++this->smem_iterator_B_; + + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpFragmentA warp_frag_A[2]; + WarpFragmentB warp_frag_B[2]; + + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_B_.set_kgroup_index(0); + + this->warp_tile_iterator_A_.load(warp_frag_A[0]); + this->warp_tile_iterator_B_.load(warp_frag_B[0]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + + Operator warp_mma; + + int smem_write_stage_idx = 1; + + // Avoid reading out of bounds + iterator_A.clear_mask(gemm_k_iterations <= 1); + iterator_B.clear_mask(gemm_k_iterations <= 1); + + // Issue loads during the first warp-level matrix multiply-add *AFTER* + // issuing shared memory loads (which have the tightest latency + // requirement). + + // + // Mainloop + // + + // Note: The main loop does not support Base::kWarpGemmIterations == 2. + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > 0; --gemm_k_iterations) { + // + // Loop over GEMM K dimension + // + + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; + ++warp_mma_k) { + // Load warp-level tiles from shared memory, wrapping to k offset if + // this is the last group as the case may be. + + if (warp_mma_k == Base::kWarpGemmIterations - 1) { + // Write fragments to shared memory + this->smem_iterator_A_.store(transform_A(tb_frag_A)); + + this->smem_iterator_B_.store(transform_B(tb_frag_B)); + + __syncthreads(); + + ++this->smem_iterator_A_; + ++this->smem_iterator_B_; + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory + if (smem_write_stage_idx == 1) { + this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + } else { + this->warp_tile_iterator_A_.add_tile_offset( + {0, + -Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterations}); + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterations, + 0}); + } + + smem_write_stage_idx ^= 1; + } + + this->warp_tile_iterator_A_.set_kgroup_index( + (warp_mma_k + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_B_.set_kgroup_index( + (warp_mma_k + 1) % Base::kWarpGemmIterations); + + this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_B_.load(warp_frag_B[(warp_mma_k + 1) % 2]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + + if (warp_mma_k == 0) { + iterator_A.load(tb_frag_A); + iterator_B.load(tb_frag_B); + + ++iterator_A; + ++iterator_B; + + // Avoid reading out of bounds if this was the last loop iteration + iterator_A.clear_mask(gemm_k_iterations <= 2); + iterator_B.clear_mask(gemm_k_iterations <= 2); + } + + warp_mma( + accum, + warp_frag_A[warp_mma_k % 2], + warp_frag_B[warp_mma_k % 2], + accum); + } + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git 
a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/gemm/find_default_mma.h b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/gemm/find_default_mma.h new file mode 100644 index 0000000000000000000000000000000000000000..560da450ff54a86e723e11b9aa00597095b73127 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/gemm/find_default_mma.h @@ -0,0 +1,191 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Cutlass provides helper template functions to figure out the right + datastructures to instantiate to run a GEMM with various parameters (see + `cutlass/gemm/threadblock/default_mma.h`). However, due to template + instantiation priority rules, it will only create an MmaMultiStage with + kStages=3 (otherwise creates an MmePipelined - which is not compatible with + FastF32). kStages=3 uses too much shared memory and we want to use kStages=2, + so we just copy-pasted some code from `default_mma.h` and + `default_mma_core.h` files and wrapped this template to allow our usecase. + + This is really only for the FastF32 case - aka using TensorCores with fp32. 
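+
+ A rough usage sketch (editorial illustration only -- the element types,
+ alignments and tile shapes below are hypothetical, not taken from this
+ repository):
+
+   using Mma = typename cutlass::gemm::threadblock::FindDefaultMma<
+       float, cutlass::layout::RowMajor, 4,     // A: element, layout, alignment
+       float, cutlass::layout::RowMajor, 4,     // B: element, layout, alignment
+       float, cutlass::layout::RowMajor,        // accumulator element, C/D layout
+       cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80,
+       cutlass::gemm::GemmShape<128, 128, 32>,  // threadblock tile
+       cutlass::gemm::GemmShape<64, 64, 32>,    // warp tile
+       cutlass::gemm::GemmShape<16, 8, 8>,      // instruction tile
+       2,                                       // kStages = 2 (the point of this wrapper)
+       cutlass::arch::OpMultiplyAddFastF32>::DefaultMma;
+   using ThreadblockMma = typename Mma::ThreadblockMma;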
+*/ + +#pragma once + +#include "cutlass/gemm/threadblock/default_mma.h" +#include "cutlass/gemm/threadblock/default_mma_core_simt.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm70.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm75.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" + +namespace cutlass { +namespace gemm { +namespace threadblock { + +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Layout type for C and D matrix operand + typename LayoutC, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Operation performed by GEMM + typename Operator, + typename Enable_ = void> +struct FindDefaultMma { + static constexpr bool AccumulatorsInRowMajor = false; + static constexpr SharedMemoryClearOption SharedMemoryClear = + SharedMemoryClearOption::kNone; + using DefaultMma = cutlass::gemm::threadblock::DefaultMma< + ElementA, + LayoutA, + kAlignmentA, + ElementB, + LayoutB, + kAlignmentB, + ElementAccumulator, + LayoutC, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + Stages, + Operator, + AccumulatorsInRowMajor, + SharedMemoryClear>; +}; + +/// Specialization for sm80 / FastF32 / multistage with kStages=2 +template < + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + typename ElementB_, + /// Layout type for B matrix operand + typename LayoutB_, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + typename ElementAccumulator, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + int kStages, + typename Operator> +struct FindDefaultMma< + ElementA_, + LayoutA_, + kAlignmentA, + ElementB_, + LayoutB_, + kAlignmentB, + ElementAccumulator, + layout::RowMajor, + arch::OpClassTensorOp, + arch::Sm80, + ThreadblockShape, + WarpShape, + InstructionShape, + kStages, + Operator, + typename cutlass::platform::enable_if<(kAlignmentA > 1)>::type> { + using LayoutC = layout::RowMajor; + using OperatorClass = arch::OpClassTensorOp; + using ArchTag = arch::Sm80; + + using DefaultMma_ = cutlass::gemm::threadblock::DefaultMma< + ElementA_, + LayoutA_, + kAlignmentA, + ElementB_, + LayoutB_, + kAlignmentB, + ElementAccumulator, + LayoutC, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + 3, + Operator>; + struct DefaultMma : DefaultMma_ { + using MmaCore_ = typename DefaultMma_::MmaCore; + // Define the threadblock-scoped multistage matrix multiply + using ThreadblockMma = 
cutlass::gemm::threadblock::MmaMultistage< + typename MmaCore_::Shape, + typename DefaultMma_::IteratorA, + typename MmaCore_::SmemIteratorA, + MmaCore_::kCacheOpA, + typename DefaultMma_::IteratorB, + typename MmaCore_::SmemIteratorB, + MmaCore_::kCacheOpB, + ElementAccumulator, + LayoutC, + typename MmaCore_::MmaPolicy, + kStages>; + }; +}; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/gemm/mma_accum_lambda_iterator.h b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/gemm/mma_accum_lambda_iterator.h new file mode 100644 index 0000000000000000000000000000000000000000..7692389c5c9bd77e68cc622d3cc774a7c605f1b5 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/gemm/mma_accum_lambda_iterator.h @@ -0,0 +1,378 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/functional.h" +#include "cutlass/gemm/warp/mma_simt_tile_iterator.h" +#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm70.h" +#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h" +#include "cutlass/matrix_shape.h" + +/* +TensorCores have different accumulator layouts. +This file provides a class to easily map the accumulator +i-th element with the corresponding matrix row/col. 
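+
+Editorial sketch of the intended usage (the iterator, fragment, lane and
+offset names below are placeholders, not symbols defined in this file):
+
+  using LambdaIterator = typename DefaultMmaAccumLambdaIterator<
+      WarpIteratorC, accum_t, kWarpSize>::Iterator;
+  auto lane_offset =
+      LambdaIterator::get_lane_offset(lane_id, warp_id, tile_offset);
+  LambdaIterator::iterateRows(
+      lane_offset,
+      [&](int row) { /* called once before each accumulator row */ },
+      [&](int row, int col, int idx) { /* accum_frag[idx] maps to (row, col) */ },
+      [&](int row) { /* called once after each accumulator row */ });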
+*/ + +template +struct AccumLambdaIteratorSm80 { + static_assert( + cutlass::platform:: + is_same::value, + "only RowMajor is supported"); + + using Policy = typename T::Policy; + using InstructionShape = typename T::InstructionShape; + using OpDelta = typename T::OpDelta; + using Shape = typename T::Shape; + static int const kElementsPerAccess = InstructionShape::kN / 4; + static int const kRowsPerTile = 8; + static int const kAccumulatorRows = InstructionShape::kM / kRowsPerTile; + + static cutlass::MatrixCoord CUTLASS_DEVICE get_lane_offset( + int8_t lane_id, + int8_t warp_id, + typename T::TensorCoord const& tile_offset) { + int quad = (lane_id >> 2); + int lane_in_quad = (lane_id & 3); + return cutlass::MatrixCoord( + quad + tile_offset.row() * Shape::kRow, + lane_in_quad * kElementsPerAccess + + tile_offset.column() * Shape::kColumn); + } + + template + CUTLASS_DEVICE static void iterateRows( + cutlass::MatrixCoord& lane_offset, + FA beginRow, + FB op, + FC endRow) { + // See cutlass/gemm/warp/mma_tensor_op_tile_iterator.h + CUTLASS_PRAGMA_UNROLL + for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) { + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < kAccumulatorRows; ++row) { + int accum_m = mma_m * InstructionShape::kM * OpDelta::kRow + + row * kRowsPerTile + lane_offset.row(); + beginRow(accum_m); + + CUTLASS_PRAGMA_UNROLL + for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) { + int mma_accum_start = kAccumulatorRows * kElementsPerAccess * + (mma_n * Policy::MmaIterations::kRow + mma_m); + CUTLASS_PRAGMA_UNROLL + for (int col = 0; col < kElementsPerAccess; ++col) { + int accum_n = mma_n * InstructionShape::kN * OpDelta::kColumn + + col + lane_offset.column(); + int idx = mma_accum_start + row * kElementsPerAccess + col; + op(accum_m, accum_n, idx); + } + } + + endRow(accum_m); + } + } + } + + template + CUTLASS_DEVICE static bool reduceSameRow(int lane_id, DT& myValue, F fn) { + // In each warp, 4 threads will work on the same row + // - the ones with the same `quad` + auto otherV = __shfl_xor_sync(0xffffffff, myValue, 1); + myValue = fn(myValue, otherV); + otherV = __shfl_xor_sync(0xffffffff, myValue, 2); + myValue = fn(myValue, otherV); + int lane_in_quad = (lane_id & 3); + return lane_in_quad == 0; + } +}; + +template +struct AccumLambdaIteratorSm70 { + static_assert( + cutlass::platform:: + is_same::value, + "only RowMajor is supported"); + + using Policy = typename T::Policy; + using InstructionShape = typename T::InstructionShape; + using OpDelta = typename T::OpDelta; + using Shape = typename T::Shape; + using Element = accum_t; + + static int const kElementsPerPartial = 4; + using EleShapePerPatial = typename cutlass::platform::conditional< + cutlass::platform::is_same::value, + cutlass::MatrixShape<2, 2>, + cutlass::MatrixShape<1, 4>>::type; + static int const kElementsPerMma = 8; + static int const kAccumulatorPatials = 2; + using QuadShapePerPatialMma = cutlass::MatrixShape<4, 4>; + + static cutlass::MatrixCoord CUTLASS_DEVICE get_lane_offset( + int8_t lane_id, + int8_t warp_id, + typename T::TensorCoord const& tile_offset) { + int quad = (lane_id >> 2); + int lane_in_quad = (lane_id & 3); + int accum_m, accum_n; + + if (cutlass::platform::is_same::value) { + // (quad[2],quad[0])+lane_in_quad[0] + accum_m = (((quad & 0x4) >> 1) + (quad & 0x1)) * 8 + (lane_in_quad & 1); + // (quad[1])+lane_in_quad[1] + accum_n = + ((quad >> 1) & 0x1) * kElementsPerPartial * kAccumulatorPatials + + (lane_in_quad & 2); + } else { + accum_m = (((quad & 0x4) 
>> 1) + (quad & 0x1)) * 8 + + lane_in_quad; // (quad[2],quad[0]) + accum_n = ((quad >> 1) & 0x1) * kElementsPerPartial * kAccumulatorPatials; + } + return cutlass::MatrixCoord( + accum_m + tile_offset.row() * Shape::kRow, + accum_n + tile_offset.column() * Shape::kColumn); + } + + template + CUTLASS_DEVICE static bool reduceSameRow(int lane_id, DT& myValue, F fn) { + static_assert( + cutlass::platform::is_same::value, + "update to support non-float accum"); + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-fragment-mma-884-f16 + // T0 & T2 share same line within a quad + auto otherV = __shfl_xor_sync(0xffffffff, myValue, 1 << 1); + myValue = fn(myValue, otherV); + // quad 0 and quad 2 are on the same lines + otherV = __shfl_xor_sync(0xffffffff, myValue, 1 << 3); + myValue = fn(myValue, otherV); + return (lane_id & ((1 << 1) | (1 << 3))) == 0; + } + + template + CUTLASS_DEVICE static void iterateRows( + cutlass::MatrixCoord& lane_offset, + FA beginRow, + FB op, + FC endRow) { + CUTLASS_PRAGMA_UNROLL + for (int tile_m = 0; tile_m < Policy::TileIterations::kRow; ++tile_m) { + CUTLASS_PRAGMA_UNROLL + for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) { + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < EleShapePerPatial::kRow; ++m) { + int accum_m = tile_m * Policy::InterleavedTile::kRow + + mma_m * QuadShapePerPatialMma::kRow + m * 2 + lane_offset.row(); + beginRow(accum_m); + + CUTLASS_PRAGMA_UNROLL + for (int tile_n = 0; tile_n < Policy::TileIterations::kColumn; + ++tile_n) { + CUTLASS_PRAGMA_UNROLL + for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; + ++mma_n) { + CUTLASS_PRAGMA_UNROLL + for (int p = 0; p < kAccumulatorPatials; ++p) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < EleShapePerPatial::kColumn; ++n) { + int mma_accum_start = + (((tile_n * Policy::TileIterations::kRow + tile_m) * + Policy::MmaIterations::kColumn + + mma_n) * + Policy::MmaIterations::kRow + + mma_m) * + kElementsPerMma; + int accum_n = tile_n * Policy::InterleavedTile::kColumn + + mma_n * QuadShapePerPatialMma::kColumn + + p * Policy::InterleavedTile::kColumn / 2 + n + + lane_offset.column(); + int idx = mma_accum_start + p * kElementsPerPartial + + m * EleShapePerPatial::kColumn + n; + op(accum_m, accum_n, idx); + } + } + } + } + endRow(accum_m); + } + } + } + } +}; + +template +struct AccumLambdaIteratorSimt { + using Policy = typename T::Policy; + using Iterations = typename T::Iterations; + using Element = typename T::Element; + using Delta = typename T::Delta; + using Shape = typename T::Shape; + static_assert( + cutlass::platform:: + is_same::value, + "only RowMajor is supported"); + + template + CUTLASS_DEVICE static bool reduceSameRow(int lane_id, DT& myValue, F fn) { + CUTLASS_PRAGMA_UNROLL + for (int bit = 1; bit < Policy::WarpShape::kColumn; bit *= 2) { + auto otherV = __shfl_xor_sync(0xffffffff, myValue, bit); + myValue = fn(myValue, otherV); + } + return (lane_id & (Policy::WarpShape::kColumn - 1)) == 0; + } + + template + CUTLASS_DEVICE static void iterateRows( + cutlass::MatrixCoord& lane_offset, + FA beginRow, + FB op, + FC endRow) { + CUTLASS_PRAGMA_UNROLL + for (int mma_m = 0; mma_m < Iterations::kRow; ++mma_m) { + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < Policy::LaneMmaShape::kM; ++m) { + int accum_m = mma_m * Delta::kRow + m + lane_offset.row(); + beginRow(accum_m); + + CUTLASS_PRAGMA_UNROLL + for (int mma_n = 0; mma_n < Iterations::kColumn; ++mma_n) { + int accum_n = + mma_n * Policy::WarpShape::kColumn * Policy::LaneMmaShape::kN 
+ + lane_offset.column(); + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < Policy::LaneMmaShape::kN; ++n) { + int idx = n + + Policy::LaneMmaShape::kN * + (mma_n + + Iterations::kColumn * + (m + mma_m * Policy::LaneMmaShape::kM)); + op(accum_m, accum_n + n, idx); + } + } + endRow(accum_m); + } + } + } + + static cutlass::MatrixCoord CUTLASS_DEVICE get_lane_offset( + int8_t lane_id, + int8_t warp_id, + typename T::TensorCoord const& tile_offset) { + static_assert( + cutlass::platform::is_same< + typename Policy::LaneLayout, + cutlass::layout::RowMajorInterleaved<1>>::value, + ""); + typename Policy::LaneLayout lane_layout = Policy::get_lane_layout(); + + cutlass::MatrixCoord lane_offset = lane_layout.inverse(lane_id) * + cutlass::MatrixCoord(Policy::LaneMmaShape::kM, + Policy::LaneMmaShape::kN); + return lane_offset + + tile_offset * cutlass::MatrixCoord(Shape::kRow, Shape::kColumn); + } +}; + +template +struct DefaultMmaAccumLambdaIterator; + +// Simt +template +struct DefaultMmaAccumLambdaIterator< + cutlass::gemm::warp::MmaSimtTileIterator< + S, + cutlass::gemm::Operand::kC, + accum_t, + cutlass::layout::RowMajor, + P, + 1, + 1>, + accum_t, + kWarpSize> { + using WarpIterator = typename cutlass::gemm::warp::MmaSimtTileIterator< + S, + cutlass::gemm::Operand::kC, + accum_t, + cutlass::layout::RowMajor, + P, + 1, + 1>; + using Iterator = AccumLambdaIteratorSimt; +}; + +// TensorOp - Volta +template +struct DefaultMmaAccumLambdaIterator< + cutlass::gemm::warp::MmaVoltaTensorOpAccumulatorTileIterator< + S1, + accum_t, + cutlass::layout::RowMajor, + S2, + cutlass::MatrixShape<1, 1>>, + accum_t, + kWarpSize> { + using WarpIterator = + typename cutlass::gemm::warp::MmaVoltaTensorOpAccumulatorTileIterator< + S1, + accum_t, + cutlass::layout::RowMajor, + S2, + cutlass::MatrixShape<1, 1>>; + using Iterator = AccumLambdaIteratorSm70; +}; + +// TensorOp - Sm75+ +template < + typename S1, + typename S2, + typename S3, + typename accum_t, + int kWarpSize> +struct DefaultMmaAccumLambdaIterator< + cutlass::gemm::warp::MmaTensorOpAccumulatorTileIterator< + S1, + accum_t, + cutlass::layout::RowMajor, + S2, + S3>, + accum_t, + kWarpSize> { + using WarpIterator = + typename cutlass::gemm::warp::MmaTensorOpAccumulatorTileIterator< + S1, + accum_t, + cutlass::layout::RowMajor, + S2, + S3>; + using Iterator = AccumLambdaIteratorSm80; +}; diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/gemm/mma_from_smem.h b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/gemm/mma_from_smem.h new file mode 100644 index 0000000000000000000000000000000000000000..f2b94d003119b144cea632b1302bfe024d15806d --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/gemm/mma_from_smem.h @@ -0,0 +1,1955 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tools and utils to store a GEMM output in shmem, and to use that + output as operandA for another GEMM back-to-back +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/epilogue/threadblock/default_epilogue_simt.h" +#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h" +#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h" +#include "cutlass/functional.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/numeric_types.h" +#include "cutlass/platform/platform.h" +#include "cutlass/transform/threadblock/vector_iterator.h" + +#include "../epilogue/epilogue_thread_apply_logsumexp.h" +#include "../gemm/mma_accum_lambda_iterator.h" +#include "../gemm_kernel_utils.h" +#include "../iterators/default_warp_iterator_from_smem.h" +#include "../iterators/make_residual_last.h" +#include "../iterators/transpose_warp_iterator.h" +#include "../iterators/warp_iterator_from_smem.h" +#include "cutlass/epilogue/threadblock/epilogue_smem_accumulator.h" +#include "cutlass/gemm/threadblock/mma_base.h" +#include "cutlass/gemm/threadblock/mma_multistage.h" +#include "cutlass/gemm/threadblock/mma_pipelined.h" +#include "cutlass/gemm/warp/mma_tensor_op_tile_access_iterator.h" + +namespace cutlass { +namespace gemm { +namespace threadblock { + +/// Shared storage object needed by accumulator +/// From 13_two_tensor_op_fusion/threadblock/b2b_mma_base_smem_accumulator.h +template < + typename Shape_, + typename Element_, + typename Layout_, + typename Padding_> +class AccumulatorSharedStorage { + public: + // + // Type definitions + // + using Shape = Shape_; + using Element = Element_; + using Layout = Layout_; + using Padding = Padding_; + + /// Tensor reference to the accumulator + using TensorRefAccum = cutlass::TensorRef; + + /// Shape of the accumulator matrix in shared memory + using ShapeAccum = cutlass:: + MatrixShape; + + public: + // + // Data 
members + // + + /// Buffer for accumulator + cutlass::AlignedBuffer accum; + + public: + // + // Methods + // + + /// Returns a layout object for the Accum matrix + CUTLASS_DEVICE + static Layout LayoutAccum() { + return Layout::packed({ShapeAccum::kRow, ShapeAccum::kColumn}); + } + + /// Returns a TensorRef to the Accumulator + CUTLASS_HOST_DEVICE + TensorRefAccum accum_ref() { + return TensorRefAccum{accum.data(), LayoutAccum()}; + } +}; + +//////////////////////////////////////////////////////////////////////////////// +// Taken from +// https://github.com/NVIDIA/cutlass/blob/master/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_base_smem_accumulator.h +//////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + // Maximum K dimension - also the dimension of the shared-memory + // holding `OperandA` + int kMaxK_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Layout in shared-memory of operand A + typename SmemLayoutA, + /// Used for partial specialization + typename Enable = bool> +class MmaBaseFromSharedMemory { + public: + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + static constexpr int kMaxK = kMaxK_; + + ///< Policy describing tuning details + using Policy = Policy_; + + // + // Dependent types + // + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Shape describing the overall GEMM computed from shared memory + /// by each warp. + using WarpGemm = typename Policy::Operator::Shape; + + /// Shape describing the number of warps filling the CTA + using WarpCount = GemmShape< + Shape::kM / WarpGemm::kM, + Shape::kN / WarpGemm::kN, + Shape::kK / WarpGemm::kK>; + using WarpCount1 = WarpCount; + + /// Number of warp-level GEMM oeprations + static int const kWarpGemmIterations = + (WarpGemm::kK / Operator::Policy::MmaShape::kK); + static int const kWarpGemmIterations1 = kWarpGemmIterations; + + /// Number of stages + static int const kStages = Stages; + + /// If this is true, we fill the entire shmem buffer at start + /// and don't need to iterate through it in a circular fashion + static bool const kSmemContainsEntireB = kMaxK <= Shape::kK * kStages; + + /// Tensor reference to the A operand + using TensorRefA = TensorRef; + + /// Tensor reference to the B operand + using TensorRefB = + TensorRef; + + // + // Nested structs + // + + /// Shared storage object needed by threadblock-scoped GEMM + class SharedStorage { + public: + // + // Type definitions + // + + /// Shape of the B matrix operand in shared memory + using ShapeB = MatrixShape< + Shape::kK * kStages + Policy::SmemPaddingB::kRow, + Shape::kN + Policy::SmemPaddingB::kColumn>; + + public: + // + // Data members + // + + /// Buffer for B operand + AlignedBuffer operand_B; + + public: + // + // Methods + // + + /// Returns a layout object for the B matrix + CUTLASS_HOST_DEVICE + static typename Operator::LayoutB LayoutB() { + return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn}); + } + + /// Returns a TensorRef to the B operand + CUTLASS_HOST_DEVICE + TensorRefB operand_B_ref() { + return TensorRefB{operand_B.data(), LayoutB()}; + } + }; + + protected: + // + // Data members + // + + // /// Iterator to load a warp-scoped tile of A operand from shared memory + 
// typename Operator::IteratorA warp_tile_iterator_A_; + + /// Iterator to load a warp-scoped tile of B operand from shared memory + typename Operator::IteratorB warp_tile_iterator_B_; + + public: + /// Construct from tensor references + CUTLASS_DEVICE + MmaBaseFromSharedMemory( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + TensorRefB& b_tile, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : warp_tile_iterator_B_(b_tile, lane_idx) {} +}; + +namespace { + +// has necessary trait compliance with WarpIteratorFromSmem but doesn't do +// anything, can be default initialized, and uses fragment that takes up +// (almost) no space. this warp iterator is selected at compile time when +// elementwise on-the-fly scaling for operand A is disabled, in which case +// operations related to loading scale factors for operand A get wiped out by +// the compiler. +template +class NoOpWarpIteratorScale { + public: + // in pipelined+multistage MMA implementations we keep an array of fragments. + // if we aren't using scaling we don't want to waste registers on fragments + // of scale elements, so ideally this would be sized 0. + // Since arrays of zero-sized objects are not allowed, using size as 1. + // The compiler will most likely wipe it out anyways. + using Fragment = cutlass::Array; + + CUTLASS_HOST_DEVICE + NoOpWarpIteratorScale() {} + + CUTLASS_HOST_DEVICE + NoOpWarpIteratorScale(TensorRef const&, int) {} + + CUTLASS_HOST_DEVICE + NoOpWarpIteratorScale& add_tile_offset( + typename TensorRef::TensorCoord const&) { + return *this; + } + + CUTLASS_HOST_DEVICE + NoOpWarpIteratorScale& operator++() { + return *this; + } + + CUTLASS_DEVICE + void load(Fragment&) const {} +}; + +// if scaling is enabled, performs fragment elementwise multiplication between +// fragment and its scaling factor. +template +class FragmentElementwiseScaler; + +// specialization for scaling being enabled. +template +class FragmentElementwiseScaler { + public: + // cast scale_frag to correct type then apply elementwise to fragment + CUTLASS_DEVICE + static Fragment apply(Fragment frag, FragmentScale const& scale_frag) { + Fragment converted_scale_frag = cutlass::NumericArrayConverter< + typename Fragment::Element, + typename FragmentScale::Element, + FragmentScale::kElements>()(scale_frag); + return cutlass::multiplies()(frag, converted_scale_frag); + } +}; + +// specialization for scaling being disabled. doesn't do anything and should +// just get wiped out by the compiler. +template +class FragmentElementwiseScaler { + public: + CUTLASS_DEVICE + static Fragment apply(Fragment frag, FragmentScale const&) { + return frag; + } +}; +} // namespace + +//////////////////////////////////////////////////////////////////////////////// +// Taken from +// https://github.com/NVIDIA/cutlass/blob/master/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_pipelined_smem_accumulator.h +//////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. 
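+///
+/// Variant of the pipelined mainloop in which operand A is *not* staged from
+/// global memory: warps read it directly from an accumulator tile that is
+/// already resident in shared memory (WarpIteratorA), optionally multiplying
+/// each A fragment elementwise by a scale fragment from a second
+/// shared-memory tile when ScaleOperandA is enabled. Only operand B is
+/// streamed from global memory through the double-buffered staging area.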
+template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + // BEGIN smem + /// Iterates over the intermediate accumulator tile in shared memory + typename WarpIteratorA_, + /// whether or not to perform elementwise multiplication of A + // by another matrix (A_scale) that is also kept in shared memory prior + // to matmul A @ B + bool ScaleOperandA_, + /// Max GEMM problem size in K dimension + int MaxK, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Transformation applied to B operand + typename TransformB_ = NumericArrayConverter< + typename SmemIteratorB_::Element, + typename IteratorB_::Element, + IteratorB_::Fragment::kElements>, + /// Used for partial specialization + typename Enable = bool> +class MmaPipelinedFromSharedMemory : public MmaBaseFromSharedMemory< + Shape_, + MaxK, + Policy_, + 2, + typename WarpIteratorA_::Layout> { + public: + ///< Base class + using Base = MmaBaseFromSharedMemory< + Shape_, + MaxK, + Policy_, + 2, + typename WarpIteratorA_::Layout>; + + using Shape = + Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> + static constexpr bool ScaleOperandA = ScaleOperandA_; + + using WarpIteratorA = WarpIteratorA_; + ///< loads fragments of A_scale from shared memory if operand A scaling is + ///< enabled. otherwise no-op. + using WarpIteratorAScale = typename cutlass::platform::conditional< + ScaleOperandA, + WarpIteratorA, + NoOpWarpIteratorScale>::type; + + using IteratorB = + IteratorB_; ///< Iterates over tiles of B operand in global memory + using ElementC = ElementC_; ///< Data type of accumulator matrix + using LayoutC = LayoutC_; ///< Layout of accumulator matrix + using Policy = Policy_; ///< Policy describing tuning details + + using SmemIteratorB = SmemIteratorB_; + + using TransformB = TransformB_; + + // + // Dependent types + // + + /// Fragment of operand B loaded from global memory + using FragmentB = typename IteratorB::Fragment; + + /// Fragment of accumulator tile + using FragmentC = typename Policy::Operator::FragmentC; + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Obtain the arch tag from the warp-level operator + using ArchTag = typename Policy::Operator::ArchTag; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = Operator::kTransformB; + + // staticaly assert kStages for MmaPipelined is two (Double-buffered pipeline) + static_assert( + (Base::kStages == 2), + "MmaPipelined requires kStages set to value 2"); + + private: + using WarpFragmentA = typename Operator::FragmentA; + + /// fragment type of OperandA elementwise scaling matrix. (almost) empty + /// if operand A scaling is disabled. + using WarpFragmentAScale = typename WarpIteratorAScale::Fragment; + + using WarpFragmentB = typename Operator::FragmentB; + + /// applies scaling factor to operand A fragment if operand A scaling is + /// enabled. otherwise no-op. 
+ using FragmentAScaler = FragmentElementwiseScaler< + WarpFragmentA, + WarpFragmentAScale, + ScaleOperandA>; + + protected: + // /// Iterator to write threadblock-scoped tile of A operand to shared memory + // SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; + + /// Iterator to load a warp-scoped tile of A operand from intermediate + /// accumulator tile + WarpIteratorA warp_tile_iterator_A_; + + /// Iterator to load a warp-scoped tile of A_scale from intermediate + /// accumulator tile (only used if ScaleOperandA_ is true) + WarpIteratorAScale warp_tile_iterator_A_scale_; + + public: + /// constructor for MMA with operand A scaling enabled. + CUTLASS_DEVICE + MmaPipelinedFromSharedMemory( + typename Base::TensorRefA a, // Operand A in shared memory + typename Base::TensorRefA a_scale, // Operand A_scale in shared memory + typename Base::TensorRefB + b_staging, // staging memory for loading tiles of B + int thread_idx, + int warp_idx, + int lane_idx) + : Base(b_staging, thread_idx, warp_idx, lane_idx), + warp_tile_iterator_A_(a, lane_idx), + warp_tile_iterator_A_scale_(a_scale, lane_idx), + smem_iterator_B_(b_staging, thread_idx) { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset( + {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_A_scale_.add_tile_offset( + {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset( + {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); + } + + /// Construct from tensor references + CUTLASS_DEVICE + MmaPipelinedFromSharedMemory( + typename Base::TensorRefA a, ///< Operand A in shared memory + typename Base::TensorRefB b_staging, ///< staging memory for loading B + int thread_idx, ///< ID within the threadblock + int warp_idx, ///< ID of warp + int lane_idx) ///< ID of each thread within a warp + : Base(b_staging, thread_idx, warp_idx, lane_idx), + warp_tile_iterator_A_(a, lane_idx), + smem_iterator_B_(b_staging, thread_idx) { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset( + {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset( + {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); 
+ } + + // For API compatibility with MmaMultistageFromSharedMemory + // but not supported as it worsens perf: older gpus < sm80 don't + // support async transfers and have to waste registers + CUTLASS_DEVICE + void set_prologue_done(bool value) {} + CUTLASS_DEVICE + static void prologue( + typename Base::SharedStorage& shared_storage, + IteratorB iterator_B1, + int thread_idx, + int problem_size_0_n) {} + + CUTLASS_DEVICE + static void drain_cp_asyncs() {} + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + int gemm_k_iterations, ///< number of iterations of the mainloop + FragmentC& accum, ///< destination accumulator tile + // IteratorA iterator_A, ///< iterator over A + // operand in global memory + IteratorB iterator_B, ///< iterator over B operand in global memory + FragmentC const& src_accum, ///< source accumulator tile + // TransformA transform_A = TransformA(), ///< transformation + // applied to A fragment + TransformB transform_B = + TransformB()) { ///< transformation applied to B fragment + + // + // Prologue + // + + // Perform accumulation in the 'd' output operand + accum = src_accum; + + FragmentB tb_frag_B; + + tb_frag_B.clear(); + + // The last kblock is loaded in the prolog + iterator_B.set_residual_tile(gemm_k_iterations == 1); + iterator_B.load(tb_frag_B); + + ++iterator_B; + + this->smem_iterator_B_.store(transform_B(tb_frag_B)); + + ++this->smem_iterator_B_; + + __syncthreads(); + + // remember that WarpFragmentAScale and WarpIteratorAScale are empty/no-op + // if scaling is disabled. + + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpFragmentA warp_frag_A[2]; + WarpFragmentAScale warp_frag_A_scale[2]; + WarpFragmentB warp_frag_B[2]; + warp_frag_A[0].clear(); + warp_frag_A_scale[0].clear(); + warp_frag_B[0].clear(); + + this->warp_tile_iterator_B_.set_kgroup_index(0); + + this->warp_tile_iterator_A_.load(warp_frag_A[0]); + this->warp_tile_iterator_A_scale_.load(warp_frag_A_scale[0]); + this->warp_tile_iterator_B_.load(warp_frag_B[0]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_A_scale_; + ++this->warp_tile_iterator_B_; + + Operator warp_mma; + + int smem_write_stage_idx = 1; + + // Avoid reading out of bounds + iterator_B.set_residual_tile(gemm_k_iterations == 2); + iterator_B.clear_mask(gemm_k_iterations <= 1); + + // Issue loads during the first warp-level matrix multiply-add *AFTER* + // issuing shared memory loads (which have the tightest latency + // requirement). + + // + // Mainloop + // + + // Note: The main loop does not support Base::kWarpGemmIterations == 2. + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > 0; --gemm_k_iterations) { + // + // Loop over GEMM K dimension + // + + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; + ++warp_mma_k) { + // Load warp-level tiles from shared memory, wrapping to k offset if + // this is the last group as the case may be. 
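+        // 'hasNext' tracks whether another k-group still needs to be read:
+        // on the last warp-level k-group of the final mainloop iteration
+        // there is no next stage, so the shared-memory fragment loads and the
+        // global-memory load of the next B tile below are skipped.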
+ bool hasNext = true; + + if (warp_mma_k == Base::kWarpGemmIterations - 1) { + if (gemm_k_iterations > 1) { + // Write fragments to shared memory + this->smem_iterator_B_.store(transform_B(tb_frag_B)); + } + + __syncthreads(); + + ++this->smem_iterator_B_; + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory SMEM: Don't reset iterator A, as + // we are continuing our iteration at this point + if (smem_write_stage_idx == 1) { + this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + } else { + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterations, + 0}); + } + + smem_write_stage_idx ^= 1; + hasNext = gemm_k_iterations > 1; + } + + // Only read the next if we need to + if (hasNext) { + this->warp_tile_iterator_B_.set_kgroup_index( + (warp_mma_k + 1) % Base::kWarpGemmIterations); + + this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_A_scale_.load( + warp_frag_A_scale[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_B_.load(warp_frag_B[(warp_mma_k + 1) % 2]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_A_scale_; + ++this->warp_tile_iterator_B_; + + if (warp_mma_k == 0) { + iterator_B.load(tb_frag_B); + + ++iterator_B; + + // Avoid reading out of bounds if this was the last loop iteration + iterator_B.set_residual_tile(gemm_k_iterations == 3); + iterator_B.clear_mask(gemm_k_iterations <= 2); + } + } + + warp_mma( + accum, + FragmentAScaler::apply( + warp_frag_A[warp_mma_k % 2], warp_frag_A_scale[warp_mma_k % 2]), + warp_frag_B[warp_mma_k % 2], + accum); + } + } + } +}; + +//////////////////////////////////////////////////////////////////////////////// +// Taken from +// https://github.com/NVIDIA/cutlass/blob/master/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_multistage_smem_accumulator.h +//////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. 
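+/// Note (editorial): unlike the pipelined variant above, this multistage
+/// variant relies on cp.async (hence the Sm80 minimum architecture declared
+/// below) to keep several stages of the B operand in flight in shared memory
+/// while the math proceeds.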
+template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape1_, + /// Iterates over the intermediate accumulator tile in shared memory + typename WarpIteratorA1_, + /// whether or not to perform elementwise multiplication of A + // by another matrix (A_scale) that is also kept in shared memory prior + // to matmul A @ B + bool ScaleOperandA_, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB1_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB1_, + /// Cache operation for operand B + cutlass::arch::CacheOperation::Kind CacheOpB1, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy1_, + /// Number of stages, + int Stages_, + int kMaxK_, + /// Used for partial specialization + typename Enable = bool> +class MmaMultistageFromSharedMemory : public MmaBaseFromSharedMemory< + Shape1_, + kMaxK_, + Policy1_, + Stages_, + typename WarpIteratorA1_::Layout> { + public: + ///< Base class + using Base = MmaBaseFromSharedMemory< + Shape1_, + kMaxK_, + Policy1_, + Stages_, + typename WarpIteratorA1_::Layout>; + + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape1 = Shape1_; + ///< Iterates over tiles of B operand in global memory + using IteratorB1 = IteratorB1_; + using IteratorB = IteratorB1; + ///< Policy describing tuning details + using Policy1 = Policy1_; + + using SmemIteratorB1 = SmemIteratorB1_; + using WarpIteratorA1 = WarpIteratorA1_; ///< Iterates over the intermediate + ///< accumulator tile in shared memory + static constexpr bool ScaleOperandA = ScaleOperandA_; + + ///< warp level iterator over A_scale matrix tile kept in shared memory. + ///< if elementwise A scaling is disabled then everything this does is no-op. + using WarpIteratorAScale = typename cutlass::platform::conditional< + ScaleOperandA, + WarpIteratorA1, + NoOpWarpIteratorScale>::type; + ///< Data type of accumulator matrix + using ElementC = ElementC_; + ///< Layout of accumulator matrix + using LayoutC = LayoutC_; + + static cutlass::arch::CacheOperation::Kind const kCacheOpB1 = CacheOpB1; + static constexpr bool kSmemContainsEntireB = Base::kSmemContainsEntireB; + + // + // Dependent types + // + + /// Fragment of accumulator tile + using FragmentC1 = typename Policy1::Operator::FragmentC; + using FragmentC = FragmentC1; + + /// Warp-level Mma + using Operator1 = typename Policy1::Operator; + + /// Minimum architecture is Sm80 to support cp.async + using ArchTag = arch::Sm80; + + /// Complex transform on B operand + static ComplexTransform const kTransformB1 = Operator1::kTransformB; + + /// Internal structure exposed for introspection. + struct Detail { + static_assert( + Base::kWarpGemmIterations1 > 1, + "The pipelined structure requires at least two warp-level " + "GEMM operations."); + + /// Number of cp.async instructions to load one stage of operand B + static int const TBLoadIterationsB1 = + IteratorB1::ThreadMap::Iterations::kCount; + + /// Number of cp.async instructions to load on group of operand B + static int const kAccessesPerGroupB1 = + (TBLoadIterationsB1 + Base::kWarpGemmIterations1 - 1) / + Base::kWarpGemmIterations1; + }; + + static constexpr int kNumStagesConcurrentLoad = + kSmemContainsEntireB ? 
Base::kStages : Base::kStages - 1; + + private: + using WarpLoadedFragmentA1 = typename Operator1::FragmentA; + /// fragment of OperandA scale matrix. if operand A scaling is disabled this + /// is (almost) empty. + using WarpLoadedFragmentA1Scale = typename WarpIteratorAScale::Fragment; + using WarpLoadedFragmentB1 = typename Operator1::FragmentB; + using WarpTransformedFragmentA1 = typename Operator1::TransformedFragmentA; + using WarpTransformedFragmentB1 = typename Operator1::TransformedFragmentB; + + /// applies elementwise scaling to fragment of A. if operand A scaling is + /// disabled this is a no-op. + using FragmentAScaler = FragmentElementwiseScaler< + WarpLoadedFragmentA1, + WarpLoadedFragmentA1Scale, + ScaleOperandA>; + + private: + // + // Data members + // + + /// Iterator to load a warp-scoped tile of A1 operand from intermediate + /// accumulator tile + WarpIteratorA1 warp_tile_iterator_A1_; + + /// Iterator to load a warp-scoped tile of A1_scale operand from shared memory + /// if operand A scaling is disabled everything this does is a no-op. + WarpIteratorAScale warp_tile_iterator_A1_scale_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB1 smem_iterator_B1_; + + bool prologue_done_; + + public: + /// constructor for MMA with operand A scaling enabled. + CUTLASS_DEVICE + MmaMultistageFromSharedMemory( + typename Base::TensorRefA a, + typename Base::TensorRefA a_scale, + typename Base::TensorRefB b_tile, + int thread_idx, + int warp_idx, + int lane_idx) + : Base(b_tile, thread_idx, warp_idx, lane_idx), + warp_tile_iterator_A1_(a, lane_idx), + warp_tile_iterator_A1_scale_(a_scale, lane_idx), + smem_iterator_B1_(b_tile, thread_idx), + prologue_done_(false) { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + int warp_idx_mn_1 = + warp_idx % (Base::WarpCount1::kM * Base::WarpCount1::kN); + int warp_idx_k_1 = warp_idx / (Base::WarpCount1::kM * Base::WarpCount1::kN); + int warp_idx_m_1 = warp_idx_mn_1 % Base::WarpCount1::kM; + int warp_idx_n_1 = warp_idx_mn_1 / Base::WarpCount1::kM; + + // Add per-warp offsets in units of warp-level tiles + warp_tile_iterator_A1_.add_tile_offset( + {warp_idx_m_1, Base::kWarpGemmIterations1 * warp_idx_k_1}); + warp_tile_iterator_A1_scale_.add_tile_offset( + {warp_idx_m_1, Base::kWarpGemmIterations1 * warp_idx_k_1}); + this->warp_tile_iterator_B_.add_tile_offset( + {Base::kWarpGemmIterations1 * warp_idx_k_1, warp_idx_n_1}); + } + + /// Construct from tensor references + CUTLASS_DEVICE + MmaMultistageFromSharedMemory( + typename Base::TensorRefA a, + typename Base::TensorRefB b_tile, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : Base(b_tile, thread_idx, warp_idx, lane_idx), + warp_tile_iterator_A1_(a, lane_idx), + smem_iterator_B1_(b_tile, thread_idx), + prologue_done_(false) { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn_1 = + warp_idx % 
(Base::WarpCount1::kM * Base::WarpCount1::kN); + int warp_idx_k_1 = warp_idx / (Base::WarpCount1::kM * Base::WarpCount1::kN); + + int warp_idx_m_1 = warp_idx_mn_1 % Base::WarpCount1::kM; + int warp_idx_n_1 = warp_idx_mn_1 / Base::WarpCount1::kM; + + // Add per-warp offsets in units of warp-level tiles + warp_tile_iterator_A1_.add_tile_offset( + {warp_idx_m_1, Base::kWarpGemmIterations1 * warp_idx_k_1}); + this->warp_tile_iterator_B_.add_tile_offset( + {Base::kWarpGemmIterations1 * warp_idx_k_1, warp_idx_n_1}); + } + + CUTLASS_DEVICE + void set_prologue_done(bool value) { + prologue_done_ = value; + } + + CUTLASS_DEVICE + static void prologue( + typename Base::SharedStorage& shared_storage, + IteratorB iterator_B1, + int thread_idx, + int problem_size_0_n) { + SmemIteratorB1 smem_iterator_B1(shared_storage.operand_B_ref(), thread_idx); + _prologue( + iterator_B1, + (problem_size_0_n + Base::Shape::kK - 1) / Base::Shape::kK, + smem_iterator_B1); + } + + CUTLASS_DEVICE + static void drain_cp_asyncs() { + // commit and drain all pending and predicated cp.async pnz from the GEMM + // mainloop + cutlass::arch::cp_async_fence(); + cutlass::arch::cp_async_wait<0>(); + __syncthreads(); + } + + CUTLASS_DEVICE + void copy_tiles_and_advance_1( + IteratorB1& iterator_B1, + int group_start_B1 = 0) { + iterator_B1.set_iteration_index( + group_start_B1 * IteratorB1::kAccessesPerVector); + this->smem_iterator_B1_.set_iteration_index(group_start_B1); + + // Load for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupB1; ++j) { + if (group_start_B1 + j < Detail::TBLoadIterationsB1) { + typename IteratorB1::AccessType* dst_ptr = + reinterpret_cast( + this->smem_iterator_B1_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorB1::ThreadMap::kElementsPerAccess / + IteratorB1::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB1::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_B1.get(); + + cutlass::arch::cp_async_zfill( + dst_ptr + v, gmem_ptr, iterator_B1.valid()); + + ++iterator_B1; + } + ++this->smem_iterator_B1_; + } + } + } + + CUTLASS_DEVICE + static void _prologue( + IteratorB& iterator_B1, + int32_t gemm_k_iterations_1, + SmemIteratorB1& smem_iterator_B1_) { + // Issue several complete stages + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < kNumStagesConcurrentLoad; + ++stage, --gemm_k_iterations_1) { + iterator_B1.set_residual_tile(gemm_k_iterations_1 == 1); + iterator_B1.clear_mask(gemm_k_iterations_1 == 0); + + iterator_B1.set_iteration_index(0); + smem_iterator_B1_.set_iteration_index(0); + + // Load for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::TBLoadIterationsB1; ++j) { + typename IteratorB1::AccessType* dst_ptr = + reinterpret_cast( + smem_iterator_B1_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB1::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorB1::ThreadMap::kElementsPerAccess / + IteratorB1::kAccessesPerVector / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_B1.get(), iterator_B1.valid()); + + ++iterator_B1; + } + + ++smem_iterator_B1_; + } + + // Move to the next stage + iterator_B1.add_tile_offset({1, 0}); + + smem_iterator_B1_.add_tile_offset({1, 0}); + + // Defines the boundary of a stage of cp.async. 
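+      // Editorial note: cp_async_fence() commits the cp.async copies issued
+      // above as one group, so a later cp_async_wait<N>() can wait on whole
+      // stages rather than on individual copies.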
+ cutlass::arch::cp_async_fence(); + } + iterator_B1.set_residual_tile(gemm_k_iterations_1 == 1); + iterator_B1.clear_mask(gemm_k_iterations_1 == 0); + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + ///< problem size of GEMM + int gemm_k_iterations_1_, + ///< destination accumulator tile + FragmentC1& accum, + ///< iterator over B1 operand in global memory + IteratorB1 iterator_B1, + ///< initial value of accumulator + FragmentC1 const& src_accum) { + // 2nd Gemm + + // + // Prologue + // + // Perform accumulation in the 'd' output operand + accum = src_accum; + + if (!prologue_done_) { + _prologue(iterator_B1, gemm_k_iterations_1_, smem_iterator_B1_); + } else if (!kSmemContainsEntireB) { + // Restore the iterators increments + + int gemm_k_iterations_1 = gemm_k_iterations_1_; + // Issue several complete stages + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < kNumStagesConcurrentLoad; + ++stage, --gemm_k_iterations_1) { + iterator_B1.set_iteration_index(0); + this->smem_iterator_B1_.set_iteration_index(0); + + // Load for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::TBLoadIterationsB1; ++j) { + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB1::kAccessesPerVector; ++v) { + ++iterator_B1; + } + ++this->smem_iterator_B1_; + } + iterator_B1.add_tile_offset({1, 0}); + this->smem_iterator_B1_.add_tile_offset({1, 0}); + } + iterator_B1.set_residual_tile(gemm_k_iterations_1 <= 1); + iterator_B1.clear_mask(gemm_k_iterations_1 <= 0); + } + + // DEPBAR+SYNC + cutlass::arch::cp_async_wait(); + __syncthreads(); + + // remember that WarpFragmentAScale and WarpIteratorAScale are no-op/empty + // if scaling is disabled. + + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpLoadedFragmentA1 warp_loaded_frag_A1[2]; + WarpLoadedFragmentA1Scale warp_loaded_frag_A1_scale[2]; + WarpLoadedFragmentB1 warp_loaded_frag_B1[2]; + WarpTransformedFragmentA1 warp_transformed_frag_A1[2]; + WarpTransformedFragmentB1 warp_transformed_frag_B1[2]; + + Operator1 warp_mma1; + + warp_tile_iterator_A1_.load(warp_loaded_frag_A1[0]); + ++warp_tile_iterator_A1_; + + warp_tile_iterator_A1_scale_.load(warp_loaded_frag_A1_scale[0]); + ++warp_tile_iterator_A1_scale_; + + this->warp_tile_iterator_B_.set_kgroup_index(0); + this->warp_tile_iterator_B_.load(warp_loaded_frag_B1[0]); + ++this->warp_tile_iterator_B_; + + int smem_write_stage_idx = Base::kStages - 1; + int smem_read_stage_idx = 0; + + warp_mma1.transform( + warp_transformed_frag_A1[0], + warp_transformed_frag_B1[0], + FragmentAScaler::apply( + warp_loaded_frag_A1[0], warp_loaded_frag_A1_scale[0]), + warp_loaded_frag_B1[0]); + + // tf32x3 kernels use staging accumulation. warp_mma uses a temporary + // accumulator and this temporary accumulator is added to the final + // accumulator once in every mainloop iteration. 
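+    // Editorial note: for those fast-accumulation math operators, the
+    // partial products below therefore go into tmp_accum and are folded into
+    // accum via plus_accum once per mainloop iteration; all other operators
+    // accumulate directly into accum.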
+ plus plus_accum; + + FragmentC1 tmp_accum; + + if (platform::is_same< + typename Operator1::MathOperator, + arch::OpMultiplyAddFastF32>::value || + platform::is_same< + typename Operator1::MathOperator, + arch::OpMultiplyAddComplexFastF32>::value) { + tmp_accum.clear(); + } + + // + // Mainloop + // + + CUTLASS_PRAGMA_UNROLL + for (int gemm_k_iterations_1 = gemm_k_iterations_1_ - (Base::kStages - 1); + gemm_k_iterations_1 > (-Base::kStages + 1); + gemm_k_iterations_1--) { + // + // Loop over GEMM K dimension + // + + // Computes a warp-level GEMM on data held in shared memory + // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations1; + ++warp_mma_k) { + // Load warp-level tile from accumulator fragment (A) + // or shared memory (operand B) + this->warp_tile_iterator_B_.set_kgroup_index( + (warp_mma_k + 1) % Base::kWarpGemmIterations1); + // skip warp tile loading for the last kgroup (we are out of the buf) + if (gemm_k_iterations_1 > (-Base::kStages + 2) || + warp_mma_k < Base::kWarpGemmIterations1 - 1) { + warp_tile_iterator_A1_.load( + warp_loaded_frag_A1[(warp_mma_k + 1) % 2]); + warp_tile_iterator_A1_scale_.load( + warp_loaded_frag_A1_scale[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_B_.load( + warp_loaded_frag_B1[(warp_mma_k + 1) % 2]); + } + ++warp_tile_iterator_A1_; + ++warp_tile_iterator_A1_scale_; + ++this->warp_tile_iterator_B_; + + if (warp_mma_k > 0) + warp_mma1.transform( + warp_transformed_frag_A1[warp_mma_k % 2], + warp_transformed_frag_B1[warp_mma_k % 2], + FragmentAScaler::apply( + warp_loaded_frag_A1[warp_mma_k % 2], + warp_loaded_frag_A1_scale[warp_mma_k % 2]), + warp_loaded_frag_B1[warp_mma_k % 2]); + + if (platform::is_same< + typename Operator1::MathOperator, + arch::OpMultiplyAddFastF32>::value || + platform::is_same< + typename Operator1::MathOperator, + arch::OpMultiplyAddComplexFastF32>::value) { + warp_mma1( + tmp_accum, + warp_transformed_frag_A1[warp_mma_k % 2], + warp_transformed_frag_B1[warp_mma_k % 2], + tmp_accum); + + if (warp_mma_k == 0) { + accum = plus_accum(accum, tmp_accum); + tmp_accum.clear(); + } + } else { + warp_mma1( + accum, + warp_transformed_frag_A1[warp_mma_k % 2], + warp_transformed_frag_B1[warp_mma_k % 2], + accum); + } + + // Issue global->shared copies for the this stage + if (warp_mma_k < Base::kWarpGemmIterations1 - 1) { + int group_start_iteration_B1; + + group_start_iteration_B1 = warp_mma_k * Detail::kAccessesPerGroupB1; + + if (!kSmemContainsEntireB) { + copy_tiles_and_advance_1(iterator_B1, group_start_iteration_B1); + } + } + + if (warp_mma_k + 2 == Base::kWarpGemmIterations1) { + int group_start_iteration_B1; + group_start_iteration_B1 = + (warp_mma_k + 1) * Detail::kAccessesPerGroupB1; + + if (!kSmemContainsEntireB) { + copy_tiles_and_advance_1(iterator_B1, group_start_iteration_B1); + } + + // Inserts a memory fence between stages of cp.async instructions. + cutlass::arch::cp_async_fence(); + + // Waits until kStages-2 stages have committed. 
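+          // Editorial note: cp_async_wait<N>() blocks until at most N
+          // committed cp.async groups remain outstanding, which bounds how
+          // far compute can run ahead of the asynchronous copies.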
+ arch::cp_async_wait(); + __syncthreads(); + + // Move to the next stage + iterator_B1.add_tile_offset({1, 0}); + + this->smem_iterator_B1_.add_tile_offset({1, 0}); + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory + if (!kSmemContainsEntireB) { + if (smem_write_stage_idx == (Base::kStages - 1)) { + this->smem_iterator_B1_.add_tile_offset({-Base::kStages, 0}); + smem_write_stage_idx = 0; + } else { + ++smem_write_stage_idx; + } + + if (smem_read_stage_idx == (Base::kStages - 1)) { + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy1::kPartitionsK * + Base::kWarpGemmIterations1, + 0}); + smem_read_stage_idx = 0; + } else { + ++smem_read_stage_idx; + } + } + + iterator_B1.set_residual_tile(gemm_k_iterations_1 == 2); + iterator_B1.clear_mask(gemm_k_iterations_1 == 1); + } + + // Do any conversions feeding the first stage at the end of the loop so + // we can start right away on mma instructions + if (warp_mma_k + 1 == Base::kWarpGemmIterations1) + warp_mma1.transform( + warp_transformed_frag_A1[(warp_mma_k + 1) % 2], + warp_transformed_frag_B1[(warp_mma_k + 1) % 2], + FragmentAScaler::apply( + warp_loaded_frag_A1[(warp_mma_k + 1) % 2], + warp_loaded_frag_A1_scale[(warp_mma_k + 1) % 2]), + warp_loaded_frag_B1[(warp_mma_k + 1) % 2]); + } + } + + if (platform::is_same< + typename Operator1::MathOperator, + arch::OpMultiplyAddFastF32>::value || + platform::is_same< + typename Operator1::MathOperator, + arch::OpMultiplyAddComplexFastF32>::value) { + accum = plus_accum(accum, tmp_accum); + } + } +}; + +// Converts a "regular" Mma into their counterpart from shared memory +template < + typename Mma_, + int kMaxK, + typename WarpIteratorA_, + /// whether or not to apply elementwise multiplication of operand A by + /// another matrix in shared memory before usage in A @ B + bool kScaleOperandA, + bool kTransposeA = false> +struct DefaultMmaFromSharedMemory; + +// Mma pipelined +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + typename WarpIteratorA_, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Transformation applied to A operand + typename TransformA_, + /// Transformation applied to B operand + typename TransformB_, + // Max MMA problem size K + int kMaxK, + /// whether or not to apply elementwise multiplication of operand A by + /// another matrix in shared memory before usage in A @ B + bool kScaleOperandA, + bool kTransposeA> +struct DefaultMmaFromSharedMemory< + MmaPipelined< + Shape_, + IteratorA_, + SmemIteratorA_, + IteratorB_, + SmemIteratorB_, + ElementC_, + LayoutC_, + Policy_, + TransformA_, + TransformB_>, + kMaxK, + WarpIteratorA_, + kScaleOperandA, + kTransposeA> { + using RegularMma = 
MmaPipelined< + Shape_, + IteratorA_, + SmemIteratorA_, + IteratorB_, + SmemIteratorB_, + ElementC_, + LayoutC_, + Policy_, + TransformA_, + TransformB_>; + + using WarpShape = typename Policy_::Operator::Shape; + using InstructionShape = typename Policy_::Operator::InstructionShape; + using ArchMmaOperator = typename Policy_::Operator; + + static constexpr bool kIsTransposedA = false; + using WarpIteratorA = WarpIteratorA_; + using IteratorB = + typename cutlass::transform::threadblock::MakeIteratorResidualLast< + IteratorB_>::Iterator; + + using Mma = typename cutlass::gemm::threadblock::MmaPipelinedFromSharedMemory< + Shape_, + WarpIteratorA, + kScaleOperandA, + kMaxK, + IteratorB, + SmemIteratorB_, + ElementC_, + LayoutC_, + Policy_>; +}; + +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + typename WarpIteratorA_, + /// Cache operation for operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Cache operation for operand B + cutlass::arch::CacheOperation::Kind CacheOpB, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear, + int kMaxK, + /// whether or not to apply elementwise multiplication of operand A by + /// another matrix in shared memory before usage in A @ B + bool kScaleOperandA, + bool kTransposeA> +struct DefaultMmaFromSharedMemory< + MmaMultistage< + Shape_, + IteratorA_, + SmemIteratorA_, + CacheOpA, + IteratorB_, + SmemIteratorB_, + CacheOpB, + ElementC_, + LayoutC_, + Policy_, + Stages, + SharedMemoryClear>, + kMaxK, + WarpIteratorA_, + kScaleOperandA, + kTransposeA> { + using RegularMma = MmaMultistage< + Shape_, + IteratorA_, + SmemIteratorA_, + CacheOpA, + IteratorB_, + SmemIteratorB_, + CacheOpB, + ElementC_, + LayoutC_, + Policy_, + Stages, + SharedMemoryClear>; + + using WarpShape = typename Policy_::Operator::Shape; + using InstructionShape = typename Policy_::Operator::InstructionShape; + using WarpIteratorTranspose = TransposeWarpIterator; + static constexpr bool kIsTransposedA = + WarpIteratorTranspose::kSupportsTranspose && kTransposeA; + using WarpIteratorA = typename platform::conditional< + kIsTransposedA, + typename WarpIteratorTranspose::Iterator, + WarpIteratorA_>::type; + + // Reduce the number of stages if we don't need that many + static int constexpr kStagesMax = + (kMaxK + int(Shape_::kK) - 1) / int(Shape_::kK); + static int constexpr kStages = cutlass::const_min(Stages, kStagesMax); + + using IteratorB = + typename cutlass::transform::threadblock::MakeIteratorResidualLast< + IteratorB_>::Iterator; + using Mma = + typename cutlass::gemm::threadblock::MmaMultistageFromSharedMemory< + Shape_, + WarpIteratorA, 
+ kScaleOperandA, + IteratorB, + SmemIteratorB_, + RegularMma::kCacheOpB, + ElementC_, + LayoutC_, + Policy_, + kStages, + kMaxK>; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename IteratorC, + typename Operator, + typename scalar_t, + typename WarpShape_, + typename ThreadblockShape_> +struct B2bGemm; + +// Tensor Cores >= Sm75 specialization (Ampere ...) +template < /// Size of the matrix to load (concept: MatrixShape) + typename Shape_, + /// Element type + typename Element_, + /// Layout of operand in memory + typename Layout_, + /// Shape of one matrix product operation (concept: MatrixShape) + typename InstructionShape_, + /// Interval between adjacent *MMA instructions (in units of MMA + /// instructions, concept: MatrixShape) + typename OpDelta_, + typename Operator, + typename scalar_t, + typename WarpShape_, + typename ThreadblockShape_> +struct B2bGemm< + cutlass::gemm::warp::MmaTensorOpAccumulatorTileIterator< + Shape_, + Element_, + Layout_, + InstructionShape_, + OpDelta_>, + Operator, + scalar_t, + WarpShape_, + ThreadblockShape_> { + using IteratorC = + typename cutlass::gemm::warp::MmaTensorOpAccumulatorTileIterator< + Shape_, + Element_, + Layout_, + InstructionShape_, + OpDelta_>; + using FragmentC = typename IteratorC::Fragment; + using InstructionShape = InstructionShape_; + using WarpShape = WarpShape_; + using ThreadblockShape = ThreadblockShape_; + using accum_t = Element_; + using lse_scalar_t = float; + + using SmemAccumulatorLayout = cutlass::layout::RowMajor; + + // Iterator to load accumulators (results of matmul in registers) + using FragmentIteratorAccumulator = + cutlass::epilogue::warp::FragmentIteratorTensorOp< + WarpShape, + InstructionShape, + accum_t, + typename Operator::Policy::Operator::FragmentC, + cutlass::layout::RowMajor>; + + // Iterator to store to shared-memory + using SmemIteratorD0 = typename cutlass::epilogue::warp::TileIteratorTensorOp< + WarpShape, + InstructionShape, + scalar_t, // accum_t, + SmemAccumulatorLayout>; + using AccumulatorSharedStorage = + cutlass::gemm::threadblock::AccumulatorSharedStorage< + ThreadblockShape, + typename SmemIteratorD0::Element, + typename SmemIteratorD0::TensorLayout, + typename SmemIteratorD0::Padding>; + // We need to provide an operation for the epilogue. Let's create an + // operation that does nothing (ScaleType::Nothing), just converts + // from accum_t (float) -> scalar_t (can be half) + using OutputOpNoOp = cutlass::epilogue::thread::LinearCombination< + typename SmemIteratorD0::Element, // ElementOutput + FragmentIteratorAccumulator::Fragment::kElements, + accum_t, // ElementAccumulator + typename SmemIteratorD0::Element, // ElementCompute + cutlass::epilogue::thread::ScaleType::Nothing>; + using Epilogue = cutlass::epilogue::threadblock::EpilogueSmemAccumulator< + SmemIteratorD0, + FragmentIteratorAccumulator, + SmemIteratorD0, // ScaleBiasIterator - not used + OutputOpNoOp>; + + // Epilogue 2: with LSE (for backwards pass) + static int const kElementsPerAccess = 2; // TODO: Why 2? 
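+  // Editorial note: this LSE epilogue rescales the raw matmul accumulators
+  // into softmax probabilities, computing roughly exp(accum - lse[col]) per
+  // element (compare the non-optimized accumApplyLSEToSmem fallbacks further
+  // below, which do exactly that with expf).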
+ using IteratorAccumulatorLSE = + cutlass::transform::threadblock::VectorIterator< + cutlass::transform::threadblock::PredicatedVectorAccessIterator< + // Shape + cutlass::MatrixShape, + // WarpShape + cutlass::MatrixShape, + lse_scalar_t, + cutlass::layout::RowMajor, + kElementsPerAccess>>; + using EpilogueOpApplyLSE = cutlass::epilogue::thread::ApplyLogSumExp< + scalar_t, // ElementOutput_ + lse_scalar_t, // ElementLSE_ + accum_t, // ElementAccumulator_ + accum_t, // ElementCompute_ + 128 / cutlass::sizeof_bits::value + // FragmentIteratorAccumulator::Fragment::kElements + // InstructionShape::kM * InstructionShape::kN / 32 + >; + using EpilogueWithLSE = + cutlass::epilogue::threadblock::EpilogueSmemAccumulator< + SmemIteratorD0, + FragmentIteratorAccumulator, + IteratorAccumulatorLSE, + EpilogueOpApplyLSE>; + + static void CUTLASS_DEVICE accumToSmem( + AccumulatorSharedStorage& shared_storage, + FragmentC const& accum, + int lane_id, + cutlass::MatrixCoord const& tile_coords) { + SmemIteratorD0 smem_iterator_attn(shared_storage.accum_ref(), lane_id); + smem_iterator_attn.add_tile_offset( + tile_coords * + cutlass::MatrixCoord{ + SmemIteratorD0::TileIterations::kRow, + SmemIteratorD0::TileIterations::kColumn}); + Epilogue epilogue; + epilogue(OutputOpNoOp({}), smem_iterator_attn, accum); + } + + static void CUTLASS_DEVICE accumApplyLSEToSmem( + AccumulatorSharedStorage& shared_storage, + FragmentC& accum, + lse_scalar_t const* lse, + int32_t lse_extents, + int thread_id, + int warp_id, + int lane_id, + cutlass::MatrixCoord const& tile_coords) { + constexpr int32_t kAlignLSE = 32; + IteratorAccumulatorLSE iterator_lse( + lse, + {(int32_t)0, (int32_t)ceil_div(lse_extents, kAlignLSE) * kAlignLSE}, + thread_id, + warp_id, + cutlass::MatrixCoord{0, 0} // offset + ); + + SmemIteratorD0 smem_iterator_attn(shared_storage.accum_ref(), lane_id); + smem_iterator_attn.add_tile_offset( + tile_coords * + cutlass::MatrixCoord{ + SmemIteratorD0::TileIterations::kRow, + SmemIteratorD0::TileIterations::kColumn}); + EpilogueWithLSE epilogue; + EpilogueOpApplyLSE minus_lse_exp({}); + epilogue( + minus_lse_exp, + smem_iterator_attn, + accum, + // scale - unused + iterator_lse, + // bias + iterator_lse); + } +}; + +// Volta Specialization +// only supported for f16 +template +struct B2bGemm< + cutlass::gemm::warp::MmaVoltaTensorOpAccumulatorTileIterator< + cutlass::MatrixShape<32, 32>, + float, + cutlass::layout::RowMajor, + cutlass::gemm::GemmShape<16, 16, 4>, + cutlass::MatrixShape<1, 1>>, + Operator, + cutlass::half_t, + WarpShape_, + ThreadblockShape_> { + using IteratorC = + cutlass::gemm::warp::MmaVoltaTensorOpAccumulatorTileIterator< + cutlass::MatrixShape<32, 32>, + float, + cutlass::layout::RowMajor, + cutlass::gemm::GemmShape<16, 16, 4>, + cutlass::MatrixShape<1, 1>>; + using scalar_t = cutlass::half_t; + using accum_t = IteratorC::Element; + using WarpShape = WarpShape_; + using ThreadblockShape = ThreadblockShape_; + using FragmentC = IteratorC::Fragment; + using lse_scalar_t = float; + + // Storage in shared-memory for Q.Kt + using SmemAccumulatorLayout = + cutlass::layout::RowMajorVoltaTensorOpMultiplicandCrosswise<16, 32>; + using AccumulatorSharedStorage = + cutlass::gemm::threadblock::AccumulatorSharedStorage< + ThreadblockShape, + scalar_t, + SmemAccumulatorLayout, + cutlass::MatrixShape<0, 0> // Padding + >; + using TensorRef = cutlass::TensorRef; + using Policy = typename IteratorC::Policy; + using Element = accum_t; + // Those are MmaVoltaTensorOpAccumulatorTileIterator private fields + 
// Let's copy their values + static int const kElementsPerPartial = 4; + using EleShapePerPatial = typename cutlass::platform::conditional< + cutlass::platform::is_same::value, + cutlass::MatrixShape<2, 2>, + cutlass::MatrixShape<1, 4>>::type; + static int const kElementsPerMma = 8; + static int const kAccumulatorPatials = 2; + using QuadShapePerPatialMma = cutlass::MatrixShape<4, 4>; + + static void CUTLASS_DEVICE accumToSmem( + AccumulatorSharedStorage& shared_storage, + FragmentC const& accum, + int lane_id, + cutlass::MatrixCoord const& tile_coords) { + // ctor - from MmaVoltaTensorOpAccumulatorTileIterator + TensorRef ref_(shared_storage.accum_ref()); + int quad = (lane_id >> 2); + int lane_in_quad = (lane_id & 3); + int accum_m, accum_n; + + if (cutlass::platform::is_same::value) { + // (quad[2],quad[0])+lane_in_quad[0] + accum_m = (((quad & 0x4) >> 1) + (quad & 0x1)) * 8 + (lane_in_quad & 1); + // (quad[1])+lane_in_quad[1] + accum_n = + ((quad >> 1) & 0x1) * kElementsPerPartial * kAccumulatorPatials + + (lane_in_quad & 2); + } else { + accum_m = (((quad & 0x4) >> 1) + (quad & 0x1)) * 8 + + lane_in_quad; // (quad[2],quad[0]) + accum_n = ((quad >> 1) & 0x1) * kElementsPerPartial * kAccumulatorPatials; + } + cutlass::MatrixCoord lane_offset(accum_m, accum_n); + + // Tile offset + ref_.add_coord_offset( + tile_coords * + cutlass::MatrixCoord( + {IteratorC::Shape::kRow, IteratorC::Shape::kColumn})); + + using AccessType = cutlass::Array; + + // store - from MmaVoltaTensorOpAccumulatorTileIterator + CUTLASS_PRAGMA_UNROLL + for (int tile_n = 0; tile_n < Policy::TileIterations::kColumn; ++tile_n) { + CUTLASS_PRAGMA_UNROLL + for (int tile_m = 0; tile_m < Policy::TileIterations::kRow; ++tile_m) { + CUTLASS_PRAGMA_UNROLL + for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) { + CUTLASS_PRAGMA_UNROLL + for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) { + int mma_accum_start = + (((tile_n * Policy::TileIterations::kRow + tile_m) * + Policy::MmaIterations::kColumn + + mma_n) * + Policy::MmaIterations::kRow + + mma_m) * + kElementsPerMma; + + CUTLASS_PRAGMA_UNROLL + for (int p = 0; p < kAccumulatorPatials; ++p) { + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < EleShapePerPatial::kRow; ++m) { + int accum_m = tile_m * Policy::InterleavedTile::kRow + + mma_m * QuadShapePerPatialMma::kRow + m * 2; + int accum_n = tile_n * Policy::InterleavedTile::kColumn + + mma_n * QuadShapePerPatialMma::kColumn + + p * Policy::InterleavedTile::kColumn / 2; + int r = (accum_m + lane_offset.row()); + AccessType to_store; + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < EleShapePerPatial::kColumn; ++n) { + int idx = mma_accum_start + p * kElementsPerPartial + + m * EleShapePerPatial::kColumn + n; + int c = (accum_n + n + lane_offset.column()); + to_store[n] = scalar_t(accum[idx]); + } + int c = (accum_n + lane_offset.column()); + assert(r < 32); + assert(c < 32); + *reinterpret_cast( + ref_.data() + ref_.offset({r, c})) = to_store; + } + } + } + } + } + } + } + + static void CUTLASS_DEVICE accumApplyLSEToSmem( + AccumulatorSharedStorage& shared_storage, + typename IteratorC::Fragment& accum, + lse_scalar_t const* lse, + int lse_extent, + int thread_id, + int warp_id, + int lane_id, + cutlass::MatrixCoord const& tile_coords) { + // Non-optimized way to apply LSE to registers + // NOTE: accum is attn.T + // TODO: Optimize for each architecture + static constexpr int WarpSize = 32; + using AccumLambdaIterator = + typename DefaultMmaAccumLambdaIterator:: + Iterator; + auto lane_offset = + 
AccumLambdaIterator::get_lane_offset(lane_id, warp_id, tile_coords); + + cutlass::Array lse_prefetched; + lse_prefetched.clear(); + int rowIdx = 0; + int colIdx = 0; + AccumLambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) { + ++rowIdx; + colIdx = 0; + }, + [&](int accum_m, int accum_n, int idx) { + if (rowIdx == 1) { + lse_prefetched[colIdx] = accum_n < lse_extent + ? lse[accum_n] + : cutlass::platform::numeric_limits::infinity(); + } + accum[idx] = expf(accum[idx] - lse_prefetched[colIdx]); + ++colIdx; + }, + [&](int accum_m) {}); + accumToSmem(shared_storage, accum, lane_id, tile_coords); + } +}; + +// Simt Specialization +// for f32 on Sm70-Sm75 and f16/f32 below + +template < + typename Operator, + typename OperatorPolicy, + typename scalar_t, + typename WarpShape_, + typename ThreadblockShape_> +struct B2bGemm< + cutlass::gemm::warp::MmaSimtTileIterator< + cutlass::MatrixShape<32, 32>, + cutlass::gemm::Operand::kC, + float, + cutlass::layout::RowMajor, + OperatorPolicy, + 1, + 1>, + Operator, + scalar_t, + WarpShape_, + ThreadblockShape_> { + using IteratorC = cutlass::gemm::warp::MmaSimtTileIterator< + cutlass::MatrixShape<32, 32>, + cutlass::gemm::Operand::kC, + float, + cutlass::layout::RowMajor, + OperatorPolicy, + 1, + 1>; + using accum_t = typename IteratorC::Element; + using WarpShape = WarpShape_; + using ThreadblockShape = ThreadblockShape_; + using FragmentC = typename IteratorC::Fragment; + using lse_scalar_t = float; + + // Storage in shared-memory for Q.Kt + using AccumulatorSharedStorage = + cutlass::gemm::threadblock::AccumulatorSharedStorage< + ThreadblockShape, + scalar_t, + cutlass::layout::ColumnMajor, + cutlass::MatrixShape<0, 0> // Padding + >; + + static void CUTLASS_DEVICE accumToSmem( + AccumulatorSharedStorage& shared_storage, + FragmentC const& accum, + int lane_id, + cutlass::MatrixCoord const& tile_coords) { + using Policy = typename IteratorC::Policy; + using Element = typename IteratorC::Element; + using Iterations = typename IteratorC::Iterations; + using Delta = typename IteratorC::Delta; + + auto ref_ = shared_storage.accum_ref(); + // ctor - MmaSimtTileIterator + // compute offset based on thread ID and lane layout + typename Policy::LaneLayout lane_layout = Policy::get_lane_layout(); + + MatrixCoord lane_offset = lane_layout.inverse(lane_id) * + MatrixCoord(Policy::LaneMmaShape::kM, Policy::LaneMmaShape::kN); + + ref_.add_coord_offset(lane_offset); + + // Tile offset + ref_.add_coord_offset( + tile_coords * + cutlass::MatrixCoord( + {IteratorC::Shape::kRow, IteratorC::Shape::kColumn})); + + // store - MmaSimtTileIterator + CUTLASS_PRAGMA_UNROLL + for (int mma_n = 0; mma_n < Iterations::kColumn; ++mma_n) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < Policy::LaneMmaShape::kN; ++n) { + CUTLASS_PRAGMA_UNROLL + for (int mma_m = 0; mma_m < Iterations::kRow; ++mma_m) { + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < Policy::LaneMmaShape::kM; ++m) { + int r = + Policy::LaneMmaShape::kM * (mma_m * Policy::WarpShape::kRow) + + m; + int c = mma_n * Delta::kColumn + n; + int idx = n + + Policy::LaneMmaShape::kN * + (mma_n + + Iterations::kColumn * + (m + mma_m * Policy::LaneMmaShape::kM)); + ref_.at({r, c}) = scalar_t(accum[idx]); + } + } + } + } + } + + static void CUTLASS_DEVICE accumApplyLSEToSmem( + AccumulatorSharedStorage& shared_storage, + typename IteratorC::Fragment& accum, + lse_scalar_t const* lse, + int lse_extent, + int thread_id, + int warp_id, + int lane_id, + cutlass::MatrixCoord const& tile_coords) { + // Non-optimized way to apply 
LSE to registers + // NOTE: accum is attn.T + // TODO: Optimize for each architecture + static constexpr int WarpSize = 32; + using AccumLambdaIterator = + typename DefaultMmaAccumLambdaIterator:: + Iterator; + auto lane_offset = + AccumLambdaIterator::get_lane_offset(lane_id, warp_id, tile_coords); + + cutlass::Array lse_prefetched; + lse_prefetched.clear(); + int rowIdx = 0; + int colIdx = 0; + AccumLambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) { + ++rowIdx; + colIdx = 0; + }, + [&](int accum_m, int accum_n, int idx) { + if (rowIdx == 1) { + lse_prefetched[colIdx] = accum_n < lse_extent + ? lse[accum_n] + : cutlass::platform::numeric_limits::infinity(); + } + accum[idx] = expf(accum[idx] - lse_prefetched[colIdx]); + ++colIdx; + }, + [&](int accum_m) {}); + accumToSmem(shared_storage, accum, lane_id, tile_coords); + } +}; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/gemm_kernel_utils.h b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/gemm_kernel_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..3703257a1730852809e3f5597d134f8146e1f9e2 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/gemm_kernel_utils.h @@ -0,0 +1,258 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/arch/mma.h" + +//////////////////////////////////////////////////////////////////////////////// +// Some helper functions +//////////////////////////////////////////////////////////////////////////////// +#define DISPATCH_TYPES(tensor, func) \ + { \ + if (query.scalar_type() == at::ScalarType::Float) { \ + using scalar_t = float; \ + func(); \ + } else if (query.scalar_type() == at::ScalarType::Half) { \ + using scalar_t = cutlass::half_t; \ + func(); \ + } else if (query.scalar_type() == at::ScalarType::BFloat16) { \ + using scalar_t = cutlass::bfloat16_t; \ + func(); \ + } else { \ + XFORMERS_CHECK(false, "Only fp32, half & bf16 supported at the moment"); \ + } \ + } + +#define DISPATCH_BOOL(BOOL_V, BOOL_NAME, F) \ + { \ + if (BOOL_V) { \ + using BOOL_NAME = std::true_type; \ + F(); \ + } else { \ + using BOOL_NAME = std::false_type; \ + F(); \ + } \ + } + +#define DISPATCH_ARCHTAG(CC, func) \ + { \ + if (CC >= 80) { \ + using ArchTag = cutlass::arch::Sm80; \ + func(); \ + } else if (CC >= 75) { \ + using ArchTag = cutlass::arch::Sm75; \ + func(); \ + } else if (CC >= 70) { \ + using ArchTag = cutlass::arch::Sm70; \ + func(); \ + } else if (CC >= 50) { \ + using ArchTag = cutlass::arch::Sm50; \ + func(); \ + } else { \ + XFORMERS_CHECK( \ + false, \ + "Your device is too old. We require compute capability >= 50"); \ + } \ + } + +#define CHECK_NOSPARSE_CONTIGUOUS_CUDA(TENSOR) \ + XFORMERS_CHECK(TENSOR.is_cuda(), #TENSOR " must be a CUDA tensor"); \ + XFORMERS_CHECK(!TENSOR.is_sparse(), #TENSOR " must be a dense tensor"); \ + XFORMERS_CHECK(TENSOR.is_contiguous()); + +#define CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA(TENSOR) \ + XFORMERS_CHECK(TENSOR.is_cuda(), #TENSOR " must be a CUDA tensor"); \ + XFORMERS_CHECK(!TENSOR.is_sparse(), #TENSOR " must be a dense tensor"); \ + XFORMERS_CHECK( \ + TENSOR.stride(-1) == 1, #TENSOR ": last dimension must be contiguous"); + +#ifdef TORCH_CHECK +#define CHECK_ALIGNED_PTR(PTR, ALIGNMENT) \ + XFORMERS_CHECK( \ + uint64_t(PTR) % ALIGNMENT == 0, #PTR " is not correctly aligned") +#define XFORMERS_CHECK TORCH_CHECK +#elif defined(__CUDACC_RTC__) +#define CHECK_ALIGNED_PTR(PTR, ALIGNMENT) \ + if (!(uint64_t(PTR) % ALIGNMENT == 0)) { \ + return false; \ + } +#define XFORMERS_CHECK(COND, ERR) \ + if (!(COND)) { \ + return false; \ + } +#else +#include +#define CHECK_ALIGNED_PTR(PTR, ALIGNMENT) \ + if (!(uint64_t(PTR) % ALIGNMENT == 0)) { \ + std::cerr << #PTR " is not correctly aligned\n"; \ + return false; \ + } +#define XFORMERS_CHECK(COND, ERR) \ + if (!(COND)) { \ + std::cerr << "'" #COND "' failed: " << ERR << "\n"; \ + return false; \ + } +#endif + +#define ASSIGN_CHECK_OVERFLOW(A, B) \ + { \ + A = B; \ + XFORMERS_CHECK( \ + B < std::numeric_limits::max(), #B " overflows"); \ + } + +namespace gemm_kernel_utils { + +template +constexpr CUTLASS_HOST_DEVICE integer ceil_div(integer n, integer m) { + return (n + m - 1) / m; +} + +template +constexpr CUTLASS_HOST_DEVICE integer align_up(integer n, integer m) { + return ((n + m - 1) / m) * m; +} + +//////////////////////////////////////////////////////////////////////////////// +// Determine the type of GEMM we do (TensorCores or not, Shapes ...) 
+// TODO: Maybe we could rely on Cutlass's DefaultGemm templates
+////////////////////////////////////////////////////////////////////////////////
+
+// Fallback to Simt (FMA on cuda cores) if not in a special case below
+template <typename ArchTag, typename scalar_t_, typename Enable = void>
+struct DefaultGemmType {
+  static constexpr int ThreadK = 8;
+  static constexpr int WarpK = 8;
+  static constexpr int kMinimumAlignment = 1;
+  using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;
+  using OpClass = cutlass::arch::OpClassSimt;
+  using Operator = cutlass::arch::OpMultiplyAdd;
+};
+
+// Specialization for tensorcores with f32
+template <typename ArchTag>
+struct DefaultGemmType<
+    ArchTag,
+    float,
+    typename cutlass::platform::enable_if<
+        ArchTag::kMinComputeCapability >= 80>::type> {
+  static constexpr int ThreadK = 32;
+  static constexpr int WarpK = 32;
+  static constexpr int kMinimumAlignment = 4;
+  using OpClass = cutlass::arch::OpClassTensorOp;
+  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
+  using Operator = cutlass::arch::OpMultiplyAddFastF32;
+};
+
+// Specialization for tensorcores with f16/bf16 - Sm75+
+template <typename ArchTag, typename scalar_t>
+struct DefaultGemmType<
+    ArchTag,
+    scalar_t,
+    typename cutlass::platform::enable_if<
+        ArchTag::kMinComputeCapability >= 75 &&
+        cutlass::sizeof_bits<scalar_t>::value == 16>::type> {
+  static constexpr int ThreadK = 32;
+  static constexpr int WarpK = 32;
+  static constexpr int kMinimumAlignment = 4;
+  using OpClass = cutlass::arch::OpClassTensorOp;
+  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
+  using Operator = cutlass::arch::OpMultiplyAdd;
+};
+
+// Specialization for tensorcores with f16 - Volta
+template <>
+struct DefaultGemmType<cutlass::arch::Sm70, cutlass::half_t> {
+  static constexpr int ThreadK = 32;
+  static constexpr int WarpK = 32;
+  static constexpr int kMinimumAlignment = 2;
+  using OpClass = cutlass::arch::OpClassTensorOp;
+  using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>;
+  using Operator = cutlass::arch::OpMultiplyAdd;
+};
+
+// Enables to do
+// `auto x = kCondition ?
fa(arg) : fb(arg)` +// when `fa` and `fb` have different types +template +struct call_conditional; + +template +struct call_conditional { + template + static CUTLASS_HOST_DEVICE auto apply(TA ta, TB tb, Arg arg) + -> decltype(ta(arg)) { + return ta(arg); + } +}; + +template +struct call_conditional { + template + static CUTLASS_HOST_DEVICE auto apply(TA ta, TB tb, Arg arg) + -> decltype(tb(arg)) { + return tb(arg); + } +}; + +//////////////////////////////////////////////////////////////////////////////// +// Mark a variable as warp-uniform - enables some compiler optimizations +// The cheapest way to do it is just to broadcast it from lane 0 +//////////////////////////////////////////////////////////////////////////////// + +template +CUTLASS_DEVICE T warp_uniform(T value) { + struct { + union { + T value; + uint32_t asInt; + }; + } p; + p.value = value; + p.asInt = __shfl_sync(0xffffffff, (unsigned)p.asInt, 0); + return p.value; +} + +template +CUTLASS_DEVICE T* warp_uniform(T* ptr) { + struct { + union { + T* ptr; + uint32_t asInt[2]; + }; + } p; + p.ptr = ptr; + p.asInt[0] = warp_uniform(p.asInt[0]); + p.asInt[1] = warp_uniform(p.asInt[1]); + return p.ptr; +} +} // namespace gemm_kernel_utils diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/iterators/default_warp_iterator_from_smem.h b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/iterators/default_warp_iterator_from_smem.h new file mode 100644 index 0000000000000000000000000000000000000000..930ee46dfe6d1ab8e9482145b0379de8a7eddc6f --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/iterators/default_warp_iterator_from_smem.h @@ -0,0 +1,142 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Instantiates the right WarpIterator to read from shared memory + The class `DefaultWarpIteratorAFromSharedMemory` is useful when reading + data dumped with `B2bGemm::accumToSmem`. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/warp/mma_tensor_op_tile_access_iterator.h" +#include "cutlass/platform/platform.h" + +#include "warp_iterator_from_smem.h" + +namespace cutlass { +namespace gemm { +namespace threadblock { + +template < + typename WarpShape, + typename InstructionShape, + typename RegularWarpIterator, + typename Policy, + typename Enable = void> +struct DefaultWarpIteratorAFromSharedMemory {}; + +// TensorOp - Ampere half +template +struct DefaultWarpIteratorAFromSharedMemory< + cutlass::gemm::GemmShape<32, 32, 32>, + cutlass::gemm::GemmShape<16, 8, kInstrK>, + RegularWarpIterator, + Policy, + typename platform::enable_if<( + sizeof_bits::value == 16 && + Policy::Operator::Policy::OpDelta::kRow == 1)>::type> { + using OpDelta = typename Policy::Operator::Policy::OpDelta; + using WarpShape = cutlass::MatrixShape<32, 32>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, kInstrK>; + + using WarpIterator = cutlass::gemm::warp::WarpIteratorFromSmem< + cutlass::gemm::Operand::kA, + typename RegularWarpIterator::Element, + cutlass::MatrixShape>; +}; + +// TensorOp - Ampere f32 +template +struct DefaultWarpIteratorAFromSharedMemory< + WarpShape, + cutlass::gemm::GemmShape<16, 8, 8>, + RegularWarpIterator, + Policy, + typename platform::enable_if<( + sizeof_bits::value != 16 || + Policy::Operator::Policy::OpDelta::kRow != 1)>::type> { + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + static constexpr auto kWarpSize = 32; + using OpDelta = typename Policy::Operator::Policy::OpDelta; + + using WarpIterator = + cutlass::gemm::warp::MmaTensorOpMultiplicandTileAccessIterator< + cutlass::MatrixShape, + cutlass::gemm::Operand::kA, + typename RegularWarpIterator::Element, + cutlass::layout::RowMajor, + cutlass::MatrixShape, + OpDelta::kRow, + kWarpSize>; +}; + +// TensorOp - Volta +template +struct DefaultWarpIteratorAFromSharedMemory< + WarpShape, + cutlass::gemm::GemmShape<16, 16, 4>, + RegularWarpIterator, + Policy> { + using InstructionShape = cutlass::gemm::GemmShape<16, 16, 4>; + static constexpr auto kWarpSize = 32; + using OpDelta = typename Policy::Operator::Policy::OpDelta; + + using WarpIterator = + cutlass::gemm::warp::MmaVoltaTensorOpMultiplicandTileIterator< + cutlass::MatrixShape<32, 32>, // MatrixShape, + cutlass::gemm::Operand::kA, + typename RegularWarpIterator::Element, + cutlass::layout::RowMajorVoltaTensorOpMultiplicandCrosswise<16, 32>, + cutlass::MatrixShape<16, 4>, + OpDelta::kRow, + kWarpSize>; +}; + +// Simt +template +struct DefaultWarpIteratorAFromSharedMemory< + WarpShape, + cutlass::gemm::GemmShape<1, 1, 1>, + RegularWarpIterator, + Policy> { + using 
InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + static constexpr auto kWarpSize = 32; + + // We just use the same iterator, as we reproduced the same shared-memory + // schema. Just modify it to handle non-complete tiles. + using WarpIterator = RegularWarpIterator; +}; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/iterators/epilogue_predicated_tile_iterator.h b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/iterators/epilogue_predicated_tile_iterator.h new file mode 100644 index 0000000000000000000000000000000000000000..7a52e96a36bd5cddf5db65fba130b4b45ff2ff66 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/iterators/epilogue_predicated_tile_iterator.h @@ -0,0 +1,751 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Epilogue iterator that supports prefetching + + Mostly copied from "cutlass/epilogue/threadblock/predicated_tile_iterator.h" +*/ + +#pragma once + +#include "cutlass/arch/arch.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/threadblock/output_tile_thread_map.h" +#include "cutlass/epilogue/threadblock/predicated_tile_iterator_params.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/transform/pitch_linear_thread_map.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { + +//////////////////////////////////////////////////////////////////////////////// + +namespace epilogue { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile iterator used to load and store output tile from global memory in +/// epilogue. +/// +/// Satisfies: ReadableTileIterator | PredicatedTileIterator | +/// ForwardTileIterator +/// +template < + typename ThreadMap_, ///< Thread map (conept: OutputTileThreadMap) + typename Element_, ///< Element data type + bool ScatterD = false, ///< Scatter D operand or not + bool UseCUDAStore = false> +class PredicatedTileIteratorPrefetch { + public: + using ThreadMap = ThreadMap_; + using Shape = typename ThreadMap::Shape; + + using Element = Element_; + + using Layout = layout::RowMajor; + using TensorRef = TensorRef; + using ConstTensorRef = typename TensorRef::ConstTensorRef; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + using TensorCoord = MatrixCoord; + + static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; + static int const kThreads = ThreadMap::kThreads; + static int const kIterations = ThreadMap::Count::kTile; + + static_assert( + ThreadMap::Iterations::kRow > 0, + "ThreadMap::Iterations::kRow must be > 0"); + static_assert( + ThreadMap::Iterations::kGroup > 0, + "ThreadMap::Iterations::kGroup must be > 0"); + static_assert( + ThreadMap::Iterations::kCluster > 0, + "ThreadMap::Iterations::kCluster must be > 0"); + static_assert( + ThreadMap::Iterations::kColumn > 0, + "ThreadMap::Iterations::kColumn must be > 0"); + + /// Fragment object + using Fragment = Array< + Element, + ThreadMap::Iterations::kColumn * ThreadMap::Iterations::kRow * + ThreadMap::Iterations::kGroup * ThreadMap::Iterations::kCluster * + ThreadMap::kElementsPerAccess>; + + /// Memory access size + using AccessType = AlignedArray; + + // + // Parameters struct + // + + /// Uses a non-template class + struct Params : PredicatedTileIteratorParams { + using Base = PredicatedTileIteratorParams; + + CUTLASS_HOST_DEVICE + Params() {} + + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : PredicatedTileIteratorParams( + layout.stride(0) * int(sizeof(AccessType)) / kElementsPerAccess, + make_OutputTileThreadMapDesc()) {} + + CUTLASS_HOST_DEVICE + Params(Base const& base) : Base(base) {} + }; + + /// Mask object + struct Mask { + static int const kCount = ThreadMap::Iterations::kColumn; + + /// Predicate state + bool predicates[kCount]; + + // + // Mask + // + CUTLASS_HOST_DEVICE + Mask() { + enable(); + } + + ///< Efficiently disables all accesses guarded by mask + CUTLASS_HOST_DEVICE void clear() { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kCount; ++i) { + predicates[i] = false; + } + } + + 
///< CUTLASS_HOST_DEVICE enables all accesses guarded by mask + CUTLASS_DEVICE void enable() { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kCount; ++i) { + predicates[i] = true; + } + } + }; + + private: + // + // Data members + // + + /// Parameters structure containing reference and precomputed state. + PredicatedTileIteratorParams params_; + + /// Byte-level pointer + uint8_t* byte_pointer_; + + /// Array of boolean values to contain steady-state predicates + Mask mask_; + + /// Extent of the matrix tile in rows + Index extent_row_; + + /// Extent of the matrix tile in rows + Index extent_column_; + + /// A thread's starting row position (assuming steady-state predicates have + /// been computed) + Index thread_start_row_; + + /// A thread's starting column + Index thread_start_column_; + + /// Internal state counter + int state_[3]; + + /// Scatter indices + int const* indices_; + + // + // Static asserts about internal strides + // + + static_assert(sizeof(extent_row_) == 4, "Expected 32b extents"); + static_assert(sizeof(thread_start_row_) == 4, "Expected 32b extents"); + static_assert( + sizeof(PredicatedTileIteratorParams::stride) == 8, + "Expected 64b strides"); + + private: + // + // Methods + // + + public: + // + // Methods + // + + /// Constructor + CUTLASS_DEVICE + PredicatedTileIteratorPrefetch( + PredicatedTileIteratorParams const& params, + Element* pointer, + TensorCoord extent, + int thread_idx, + TensorCoord threadblock_offset = TensorCoord(), + int const* indices = nullptr) + : params_(params), indices_(indices) { + TensorCoord thread_offset = + ThreadMap::initial_offset(thread_idx) + threadblock_offset; + + extent_row_ = extent.row(); + extent_column_ = extent.column(); + + thread_start_row_ = thread_offset.row(); + thread_start_column_ = thread_offset.column(); + + // Initialize predicates + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kColumn; ++c) { + mask_.predicates[c] = + ((thread_offset.column() + ThreadMap::Delta::kColumn * c) < + extent.column()); + } + + // Null pointer performs no accesses + if (!pointer) { + mask_.clear(); + } + + if (ScatterD && !indices) { + mask_.clear(); + } + + // Initialize pointer + byte_pointer_ = reinterpret_cast(pointer) + + LongIndex(thread_offset.row()) * LongIndex(params_.stride) + + LongIndex(thread_offset.column()) * sizeof(AccessType) / + kElementsPerAccess; + + if (ScatterD) { + byte_pointer_ = reinterpret_cast(pointer) + + LongIndex(thread_offset.column()) * sizeof(AccessType) / + kElementsPerAccess; + } + + // Initialize internal state counter + state_[0] = state_[1] = state_[2] = 0; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + byte_pointer_ += pointer_offset * sizeof_bits::value / 8; + } + + CUTLASS_DEVICE + void prefetch_all() { + CUTLASS_PRAGMA_UNROLL + for (int iter = 0; iter < kIterations; ++iter) { + prefetch(); + ++(*this); + } + } + + CUTLASS_DEVICE + void prefetch() { + uint8_t* byte_pointer = byte_pointer_; + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; + ++cluster) { + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + int row_offset = row * ThreadMap::Delta::kRow + + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + + AccessType* memory_pointer = + reinterpret_cast(byte_pointer); + + 
CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; + ++column) { + // on windows using unsigned long here gives the error + // error: asm operand type size(4) does not match + // type/size implied by constraint 'l' + uint64_t addr = (uint64_t)((void*)&memory_pointer + [column * ThreadMap::Delta::kColumn / + kElementsPerAccess]); + asm volatile("prefetch.global.L1 [ %1 ];" : "=l"(addr) : "l"(addr)); + } + + if (row + 1 < ThreadMap::Iterations::kRow) { + if (!ScatterD) { + byte_pointer += params_.increment_row; + } + } + } + + if (group + 1 < ThreadMap::Iterations::kGroup) { + byte_pointer += params_.increment_group; + } + } + + if (cluster + 1 < ThreadMap::Iterations::kCluster) { + byte_pointer += params_.increment_cluster; + } + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_byte_offset(Fragment& frag, int64_t byte_offset) const { + uint8_t* byte_pointer = byte_pointer_; + AccessType* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; + ++cluster) { + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + int frag_row_idx = + (row + + ThreadMap::Iterations::kRow * + (group + ThreadMap::Iterations::kGroup * cluster)); + + int row_offset = row * ThreadMap::Delta::kRow + + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + + bool row_guard = ((row_offset + thread_start_row_) < extent_row_); + + AccessType* memory_pointer = + reinterpret_cast(byte_pointer + byte_offset); + + if (ScatterD && row_guard) { + assert(indices_); + + memory_pointer = reinterpret_cast( + byte_pointer + byte_offset + + LongIndex(indices_[row_offset + thread_start_row_]) * + LongIndex(params_.stride)); + } + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; + ++column) { + bool guard = row_guard && mask_.predicates[column]; + + cutlass::arch::global_load( + frag_ptr + [frag_row_idx * ThreadMap::Iterations::kColumn + column], + (void*)&memory_pointer + [column * ThreadMap::Delta::kColumn / kElementsPerAccess], + guard); + } + + if (row + 1 < ThreadMap::Iterations::kRow) { + if (!ScatterD) { + byte_pointer += params_.increment_row; + } + } + } + + if (group + 1 < ThreadMap::Iterations::kGroup) { + byte_pointer += params_.increment_group; + } + } + + if (cluster + 1 < ThreadMap::Iterations::kCluster) { + byte_pointer += params_.increment_cluster; + } + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment& frag) const { + load_with_byte_offset(frag, 0); + } + + /// Stores a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const& frag, int64_t byte_offset) const { + uint8_t* byte_pointer = byte_pointer_; + AccessType const* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; + ++cluster) { + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + int frag_row_idx = + (row + + ThreadMap::Iterations::kRow * + (group + ThreadMap::Iterations::kGroup * cluster)); + + int row_offset = row * ThreadMap::Delta::kRow + + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + + bool row_guard = ((row_offset + thread_start_row_) < 
extent_row_); + + AccessType* memory_pointer = + reinterpret_cast(byte_pointer + byte_offset); + + if (ScatterD && row_guard) { + assert(indices_); + + memory_pointer = reinterpret_cast( + byte_pointer + byte_offset + + LongIndex(indices_[row_offset + thread_start_row_]) * + LongIndex(params_.stride)); + } + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; + ++column) { + bool guard = row_guard && mask_.predicates[column]; + + if (UseCUDAStore) { + if (guard) { + memory_pointer + [column * ThreadMap::Delta::kColumn / kElementsPerAccess] = + frag_ptr + [frag_row_idx * ThreadMap::Iterations::kColumn + + column]; + } + } else { + cutlass::arch::global_store( + frag_ptr + [frag_row_idx * ThreadMap::Iterations::kColumn + column], + (void*)&memory_pointer + [column * ThreadMap::Delta::kColumn / kElementsPerAccess], + guard); + } + } + + if (row + 1 < ThreadMap::Iterations::kRow) { + if (!ScatterD) { + byte_pointer += params_.increment_row; + } + } + } + + if (group + 1 < ThreadMap::Iterations::kGroup) { + byte_pointer += params_.increment_group; + } + } + + if (cluster + 1 < ThreadMap::Iterations::kCluster) { + byte_pointer += params_.increment_cluster; + } + } + } + + /// Stores a fragment to memory + CUTLASS_DEVICE + void store(Fragment const& frag) const { + store_with_byte_offset(frag, 0); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void downsample_load_with_byte_offset( + Fragment& frag, + int64_t byte_offset, + int convolution_P, + int convolution_Q, + int add_P, + int add_Q, + int problem_N) const { + uint8_t* byte_pointer = byte_pointer_; + AccessType* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; + ++cluster) { + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + int frag_row_idx = + (row + + ThreadMap::Iterations::kRow * + (group + ThreadMap::Iterations::kGroup * cluster)); + + int row_offset = row * ThreadMap::Delta::kRow + + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + + bool row_guard = ((row_offset + thread_start_row_) < extent_row_); + + int output_row = row_offset + thread_start_row_; + int output_N = output_row / (convolution_P * convolution_Q); + int output_PQ = output_row % (convolution_P * convolution_Q); + int output_P = output_PQ / convolution_Q; + int output_Q = output_PQ % convolution_Q; + + int input_row = output_N * 2 * convolution_P * 2 * convolution_Q + + (2 * output_P + add_P) * 2 * convolution_Q + 2 * output_Q + add_Q; + + int64_t byte_offset = + (input_row - output_row) * problem_N * sizeof(float); + + AccessType* memory_pointer = + reinterpret_cast(byte_pointer + byte_offset); + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; + ++column) { + bool guard = row_guard && mask_.predicates[column]; + + cutlass::arch::global_load( + frag_ptr + [frag_row_idx * ThreadMap::Iterations::kColumn + column], + (void*)&memory_pointer + [column * ThreadMap::Delta::kColumn / kElementsPerAccess], + guard); + } + + if (row + 1 < ThreadMap::Iterations::kRow) { + byte_pointer += params_.increment_row; + } + } + + if (group + 1 < ThreadMap::Iterations::kGroup) { + byte_pointer += params_.increment_group; + } + } + + if (cluster + 1 < ThreadMap::Iterations::kCluster) { + byte_pointer += params_.increment_cluster; + } + } + } + + /// Loads a fragment 
from memory + CUTLASS_DEVICE + void upsample_load_with_byte_offset( + Fragment& frag, + int64_t byte_offset, + int convolution_P, + int convolution_Q, + int add_P, + int add_Q, + int problem_N) const { + uint8_t* byte_pointer = byte_pointer_; + AccessType* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; + ++cluster) { + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + int frag_row_idx = + (row + + ThreadMap::Iterations::kRow * + (group + ThreadMap::Iterations::kGroup * cluster)); + + int row_offset = row * ThreadMap::Delta::kRow + + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + + bool row_guard = ((row_offset + thread_start_row_) < extent_row_); + + int output_row = row_offset + thread_start_row_; + int output_N = output_row / (convolution_P * convolution_Q); + int output_PQ = output_row % (convolution_P * convolution_Q); + int output_P = output_PQ / convolution_Q; + int output_Q = output_PQ % convolution_Q; + int row_add_P = add_P; + int row_add_Q = add_Q; + if (output_P > convolution_P - 2) + row_add_P = 0; + if (output_Q > convolution_Q - 2) + row_add_Q = 0; + + int input_row = output_N * (convolution_P / 2) * (convolution_Q / 2) + + ((output_P + row_add_P) / 2) * (convolution_Q / 2) + + (output_Q + row_add_Q) / 2; + + int64_t byte_offset = + (input_row - output_row) * problem_N * sizeof(float); + + AccessType* memory_pointer = + reinterpret_cast(byte_pointer + byte_offset); + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; + ++column) { + bool guard = row_guard && mask_.predicates[column]; + + cutlass::arch::global_load( + frag_ptr + [frag_row_idx * ThreadMap::Iterations::kColumn + column], + (void*)&memory_pointer + [column * ThreadMap::Delta::kColumn / kElementsPerAccess], + guard); + } + + if (row + 1 < ThreadMap::Iterations::kRow) { + byte_pointer += params_.increment_row; + } + } + + if (group + 1 < ThreadMap::Iterations::kGroup) { + byte_pointer += params_.increment_group; + } + } + + if (cluster + 1 < ThreadMap::Iterations::kCluster) { + byte_pointer += params_.increment_cluster; + } + } + } + + CUTLASS_DEVICE + MatrixCoord thread_start() const { + return MatrixCoord(thread_start_row_, thread_start_column_); + } + + /// Need to get the thread start row from the tile iterator + CUTLASS_DEVICE + int32_t thread_start_row() const { + return thread_start_row_; + } + + /// Need to get the thread start row from the tile iterator + CUTLASS_DEVICE + int32_t thread_start_column() const { + return thread_start_column_; + } + + /// Extent of the matrix in rows + CUTLASS_DEVICE + Index extent_row() const { + return extent_row_; + } + + /// Extent of the matrix in columns + CUTLASS_DEVICE + Index extent_column() const { + return extent_column_; + } + + /// Advances to the next position to load or store + CUTLASS_HOST_DEVICE + PredicatedTileIteratorPrefetch& operator++() { + ++state_[0]; + + if (!ScatterD) { + byte_pointer_ += params_.advance_row; + } + + thread_start_row_ += ThreadMap::Shape::kRow; + + if (state_[0] == ThreadMap::Count::kRow) { + state_[0] = 0; + ++state_[1]; + byte_pointer_ += params_.advance_group; + + thread_start_row_ += (ThreadMap::Shape::kGroup - 1) * + ThreadMap::Shape::kRow * ThreadMap::Count::kRow; + + if (state_[1] == ThreadMap::Count::kGroup) { + state_[1] = 0; + ++state_[2]; + 
byte_pointer_ += params_.advance_cluster; + + thread_start_row_ += ThreadMap::Count::kGroup * + ThreadMap::Shape::kGroup * ThreadMap::Count::kRow * + ThreadMap::Shape::kRow; + + if (state_[2] == ThreadMap::Count::kCluster) { + state_[2] = 0; + byte_pointer_ += params_.advance_tile; + } + } + } + + return *this; + } + + ///< Efficiently disables all accesses guarded by mask + CUTLASS_DEVICE void clear_mask() { + mask_.clear(); + } + + ///< Efficiently enables all accesses guarded by mask + CUTLASS_DEVICE void enable_mask() { + mask_.enable(); + } + + ///< Sets the mask + CUTLASS_DEVICE void get_mask(Mask& mask) const { + mask = mask_; + } + + ///< Sets the mask + CUTLASS_DEVICE void set_mask(Mask const& mask) { + mask_ = mask; + } +}; + +template +struct MakePrefetchableIterator { + using Iterator = PredicatedTileIteratorPrefetch< + typename IT::ThreadMap, + typename IT::Element>; +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/iterators/make_residual_last.h b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/iterators/make_residual_last.h new file mode 100644 index 0000000000000000000000000000000000000000..a667d675276111d074abd868e96c141d2f2f3829 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/iterators/make_residual_last.h @@ -0,0 +1,97 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once + +#include "predicated_tile_access_iterator_residual_last.h" +#include "predicated_tile_iterator_residual_last.h" + +namespace cutlass { +namespace transform { +namespace threadblock { + +template +struct MakeIteratorResidualLast; + +template < + typename Shape, + typename Element, + typename Layout, + int AdvanceRank, + typename ThreadMap, + int AccessSize, + bool Gather> +struct MakeIteratorResidualLast> { + using Iterator = PredicatedTileIteratorResidualLast< + Shape, + Element, + Layout, + AdvanceRank, + ThreadMap, + AccessSize, + Gather>; +}; + +template < + typename Shape, + typename Element, + typename Layout, + int AdvanceRank, + typename ThreadMap, + typename AccessType, + bool Gather> +struct MakeIteratorResidualLast> { + using Iterator = PredicatedTileAccessIteratorResidualLast< + Shape, + Element, + Layout, + AdvanceRank, + ThreadMap, + AccessType, + Gather>; +}; +} // namespace threadblock +} // namespace transform +} // namespace cutlass diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/iterators/predicated_tile_access_iterator_residual_last.h b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/iterators/predicated_tile_access_iterator_residual_last.h new file mode 100644 index 0000000000000000000000000000000000000000..d007f0445b8301a16e371c6f1d9911dfda5ad915 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/iterators/predicated_tile_access_iterator_residual_last.h @@ -0,0 +1,2114 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +/*! \file + \brief Templates calculating the address and predicates to the load of tiles + from pitch-linear rank=2 tensors. + + This iterator uses masks to guard out-of-bounds accesses. The first tile + this iterator visits maybe partial, then the remaining tiles are complete. + So, we only need to compute the predicates twice, once before the first tile + and once for the remaining full tiles which can share the same predicates. + + A precomputed "Params" object minimizes the amount of state that must be + stored in registers, and integer addition is used to advance the pointer + through memory. +*/ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/coord.h" +#include "cutlass/cutlass.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" +#include "cutlass/transform/threadblock/predicated_tile_access_iterator_params.h" + +//////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace transform { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// PredicatedTileAccessIteratorResidualLast +/// +template < + typename Shape, + typename Element, + typename Layout, + int AdvanceRank, + typename ThreadMap, + typename AccessType, + bool Gather = false> +class PredicatedTileAccessIteratorResidualLast; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIteratorResidualLast for pitch-linear +/// data. 
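// A minimal host-side sketch of the predicate-reuse strategy described in the
// file comment above: the guard predicates are computed exactly twice, once
// for the partial (residual) tile and once for the full tiles, and then reused
// on every visit instead of being recomputed per access. This is illustrative
// only; the names below are made up and none of them are CUTLASS types.
#include <cassert>
#include <vector>

int main() {
  constexpr int kTile = 8;                 // accesses per tile
  const int extent = 21;                   // problem size, not a multiple of kTile
  std::vector<float> src(extent, 1.0f), dst(extent, 0.0f);

  const int full_tiles = extent / kTile;   // complete tiles
  const int residual   = extent % kTile;   // size of the partial tile

  // Computed once each, as in the iterator's constructor.
  bool full_mask[kTile], residual_mask[kTile];
  for (int i = 0; i < kTile; ++i) {
    full_mask[i]     = true;               // steady-state tiles need no guard
    residual_mask[i] = (i < residual);     // the partial tile guards its tail
  }

  // Partial tile first (as in the comment above), then the steady-state tiles.
  int offset = 0;
  for (int i = 0; i < kTile; ++i)
    if (residual_mask[i]) dst[offset + i] = src[offset + i];
  offset += residual;
  for (int t = 0; t < full_tiles; ++t) {
    for (int i = 0; i < kTile; ++i)
      if (full_mask[i]) dst[offset + i] = src[offset + i];
    offset += kTile;
  }

  for (int i = 0; i < extent; ++i) assert(dst[i] == 1.0f);
  return 0;
}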
+/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + typename AccessType_, + bool Gather> +class PredicatedTileAccessIteratorResidualLast< + Shape_, + Element_, + layout::PitchLinear, + AdvanceRank, + ThreadMap_, + AccessType_, + Gather> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::PitchLinear; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using UnderlyingPredicates = PredicatedTileAccessIteratorPredicates< + Shape, + Element, + Layout, + AdvanceRank, + ThreadMap, + AccessType>; + + static int const kAccessesPerVector = + ThreadMap::kElementsPerAccess / AccessType::kElements; + + static_assert( + !(ThreadMap::kElementsPerAccess % AccessType::kElements), + "Vectors implied by the thread map must be divisible by the access type."); + + using Mask = typename UnderlyingPredicates::Mask; + + /// Uses a non-template class + struct Params : PredicatedTileAccessIteratorParams { + using Base = PredicatedTileAccessIteratorParams; + + // Default ctor + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : Base( + layout.stride(0), + MakePredicatedTileAccessIteratorDesc< + Shape, + Element, + Layout, + kAdvanceRank, + ThreadMap>()()) {} + + CUTLASS_HOST_DEVICE + Params(Base const& base) : Base(base) {} + }; + + private: + /// Internal pointer type permits fast address arithmetic + using BytePointer = char*; + + private: + // + // Data members + // + + UnderlyingPredicates the_predicates; + Mask residual_tile_mask; + + /// Parameters object with precomputed internal state + Params params_; + + /// Internal pointer to first access of tile + BytePointer pointer_; + + /// Below is used when Gather is turned on. We need to record strided_offset + /// and contiguous_offset separated to compute the offset by using + /// + /// offset = contiguous_offset + indices[strided_offset] + /// + + /// Gather indices + int const* indices_; + + Index gather_offset_strided; + + private: + /// Computes predicates based on internally tracked per-thread offset. 
+ CUTLASS_DEVICE + void compute_predicates_( + /// Extent of the matrix window + TensorCoord extent, + /// optionally, simplify predicate calculation during 'steady state' phase + bool is_steady_state = false) { + the_predicates.compute_predicates_(extent, is_steady_state); + } + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + /// Precomputed parameters object + Params const& params, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const& threadblock_offset, + /// Gather indices + int const* indices = nullptr) + : params_(params), + pointer_(reinterpret_cast( + const_cast(pointer))), + the_predicates(extent), + indices_(indices) { + the_predicates.set_predicates(thread_id, threadblock_offset); + the_predicates.get_mask(residual_tile_mask); + + // Working around a weird compiler bug happening on P100 for the backward. + // I've seen together: the_predicates.predicates_[0] = 14 (instead of 15) + // residual_tile_mask[0] = 15 (correct) + // + // Adding prints when the value is calculated (in `compute_predicates_`) + // sometimes removes the bug. The consequence is that we skip some + // element of a tensor, leading to wrong results + // Setting `compute_predicates_`'s second argument (`is_steady_state`) to + // true also seems to get rid of the bug - at the cost of twice as many + // comparisons. +#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 700) + constexpr bool kWorkAroundCompilerBug = false; +#else + constexpr bool kWorkAroundCompilerBug = true; +#endif + the_predicates.compute_predicates_(extent, true && !kWorkAroundCompilerBug); + + // update internal pointers + Layout layout(params_.stride_); + + if (!Gather) { + add_pointer_offset(layout(the_predicates.thread_offset_)); + } else { + gather_offset_strided = the_predicates.thread_offset_.strided(); + add_pointer_offset( + layout(make_Coord(the_predicates.thread_offset_.contiguous(), 0))); + } + } + + /// Construct a PredicatedTileAccessIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + /// Precomputed parameters object + Params const& params, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + ///< ID of each participating thread + int thread_id) + : PredicatedTileAccessIteratorResidualLast( + params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + the_predicates.set_iteration_index(index); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool is_residual_tile) { + if (is_residual_tile) { + the_predicates.set_mask(residual_tile_mask); + } + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + pointer_ += sizeof_bits::value * pointer_offset / 8; + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const& tile_offset) { + if (!Gather) { + if (kAdvanceRank) { + pointer_ += params_.inc_advance_ * LongIndex(tile_offset.strided()); + pointer_ += Shape::kContiguous * tile_offset.contiguous(); + } else { + pointer_ += params_.inc_advance_ * 
LongIndex(tile_offset.contiguous()); + pointer_ += Shape::kStrided * tile_offset.strided(); + } + } else { + add_pointer_offset(Shape::kContiguous * tile_offset.contiguous()); + gather_offset_strided += Shape::kStrided * tile_offset.strided(); + } + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType* get() const { + if (Gather) { + assert(indices_); + + if (!valid()) { + return nullptr; + } + + LongIndex contiguous_offset = the_predicates.iteration_contiguous_ * + (ThreadMap::Delta::kContiguous * sizeof_bits::value / + 8) + + the_predicates.iteration_vector_; + int strided_index = gather_offset_strided + + the_predicates.iteration_strided_ * ThreadMap::Delta::kStrided; + + LongIndex strided_offset = indices_[strided_index] * + LongIndex(params_.stride_) * sizeof_bits::value / 8; + + return reinterpret_cast( + pointer_ + contiguous_offset + strided_offset); + } + + return reinterpret_cast( + pointer_ + + the_predicates.iteration_contiguous_ * + (ThreadMap::Delta::kContiguous * + sizeof_bits::value) / + 8) + + the_predicates.iteration_vector_; + } + + /// Increment and return an instance to self. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast& operator++() { + the_predicates.operator++(); + + ++the_predicates.iteration_vector_; + if (the_predicates.iteration_vector_ < kAccessesPerVector) { + return *this; + } + + the_predicates.iteration_vector_ = 0; + ++the_predicates.iteration_contiguous_; + + if (the_predicates.iteration_contiguous_ < + ThreadMap::Iterations::kContiguous) { + return *this; + } + + // Enter here only if (iteration_contiguous_ == + // ThreadMap::Iteration::kContiguous) + the_predicates.iteration_contiguous_ = 0; + ++the_predicates.iteration_strided_; + + if (the_predicates.iteration_strided_ < ThreadMap::Iterations::kStrided) { + if (!Gather) { + pointer_ += params_.inc_strided_; + } + + return *this; + } + + // Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided) + // which means we enter the next tile. + the_predicates.iteration_strided_ = 0; + + if (!Gather) { + // advance to next tile + pointer_ += params_.inc_next_; + + // now return to start tile - if the iterator is subsequently advanced, + // this subtraction as well as the subsequent integer addition are both + // elided by the compiler. + pointer_ -= params_.inc_advance_; + } + + return *this; + } + + /// Increment and return an instance to self. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast operator++(int) { + PredicatedTileAccessIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + the_predicates.clear_mask(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + the_predicates.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + the_predicates.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + the_predicates.get_mask(mask); + } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() const { + return the_predicates.valid(); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIteratorResidualLast for column-major +/// data. 
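// Host-side sketch of the gather addressing used by the pitch-linear
// specialization above when Gather is enabled: the contiguous offset is added
// directly, while the strided coordinate is first translated through an index
// array, i.e. byte_offset = contiguous_bytes + indices[strided] * stride_bytes.
// Names and sizes are illustrative only, not CUTLASS code.
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  constexpr int kCols = 8, kRows = 4;
  std::vector<float> tensor(kRows * kCols);
  for (int i = 0; i < kRows * kCols; ++i) tensor[i] = static_cast<float>(i);

  const int indices[kRows] = {2, 0, 3, 1};             // gathered row order
  const std::int64_t stride_bytes = kCols * sizeof(float);
  const char* base = reinterpret_cast<const char*>(tensor.data());

  for (int r = 0; r < kRows; ++r) {
    for (int c = 0; c < kCols; ++c) {
      const char* p = base + c * sizeof(float) + indices[r] * stride_bytes;
      assert(*reinterpret_cast<const float*>(p) ==
             tensor[indices[r] * kCols + c]);          // same element, gathered
    }
  }
  return 0;
}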
+/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + typename AccessType_, + bool Gather> +class PredicatedTileAccessIteratorResidualLast< + Shape_, + Element_, + layout::ColumnMajor, + AdvanceRank, + ThreadMap_, + AccessType_, + Gather> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::ColumnMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using UnderlyingIterator = PredicatedTileAccessIteratorResidualLast< + layout::PitchLinearShape, + Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 0 : 1), + ThreadMap, + AccessType, + Gather>; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileAccessIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + /// Default ctor + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : params_(layout::PitchLinear(layout.stride(0))){}; + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(typename UnderlyingIterator::Params::Base const& base) + : params_(base) {} + }; + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + ///< Precomputed parameters object + Params const& params, + ///< Pointer to start of tensor + Pointer pointer, + ///< Extent of tensor + TensorCoord extent, + ///< ID of each participating thread + int thread_id, + ///< Initial offset of threadblock + TensorCoord const& threadblock_offset, + int const* indices = + nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : iterator_( + params.params_, + pointer, + layout::PitchLinearCoord(extent.row(), extent.column()), + thread_id, + layout::PitchLinearCoord( + threadblock_offset.row(), + threadblock_offset.column()), + indices) {} + + /// Construct a PredicatedTileAccessIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : 
PredicatedTileAccessIteratorResidualLast( + params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + iterator_.set_iteration_index(index); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { + iterator_.set_residual_tile(enable); + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const& tile_offset) { + iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()}); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType* get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast& operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast operator++(int) { + PredicatedTileAccessIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + iterator_.clear_mask(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + iterator_.get_mask(mask); + } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { + return iterator_.valid(); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIteratorResidualLast for row-major +/// data. 
+/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + typename AccessType_, + bool Gather> +class PredicatedTileAccessIteratorResidualLast< + Shape_, + Element_, + layout::RowMajor, + AdvanceRank, + ThreadMap_, + AccessType_, + Gather> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::RowMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using UnderlyingIterator = PredicatedTileAccessIteratorResidualLast< + layout::PitchLinearShape, + Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 1 : 0), + ThreadMap, + AccessType, + Gather>; + + static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileAccessIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + /// Default ctor + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : params_(layout::PitchLinear(layout.stride(0))){}; + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(typename UnderlyingIterator::Params::Base const& base) + : params_(base) {} + }; + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + ///< Precomputed parameters object + Params const& params, + ///< Pointer to start of tensor + Pointer pointer, + ///< Extent of tensor + TensorCoord extent, + ///< ID of each participating thread + int thread_id, + ///< Initial offset of threadblock + TensorCoord const& threadblock_offset, + /// Gather indices + int const* indices = nullptr) + : iterator_( + params.params_, + pointer, + layout::PitchLinearCoord(extent.column(), extent.row()), + thread_id, + layout::PitchLinearCoord( + threadblock_offset.column(), + threadblock_offset.row()), + indices) {} + + /// Construct a PredicatedTileAccessIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileAccessIteratorResidualLast( + params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} 
+ + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + iterator_.set_iteration_index(index); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { + iterator_.set_residual_tile(enable); + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const& tile_offset) { + iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType* get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast& operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast operator++(int) { + PredicatedTileAccessIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + iterator_.clear_mask(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + iterator_.get_mask(mask); + } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { + return iterator_.valid(); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIteratorResidualLast for affine rank 2 +/// data. 
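// Sketch of the wrapping pattern used by the column-major and row-major
// specializations above: each one only remaps its (row, column) coordinates
// onto an underlying pitch-linear worker, which holds all the real logic.
// The types below are illustrative stand-ins, not CUTLASS classes.
#include <cstdio>

struct PitchLinearWorker {
  void touch(int contiguous, int strided) const {
    std::printf("contiguous=%d strided=%d\n", contiguous, strided);
  }
};

struct ColumnMajorAdapter {            // rows are the contiguous dimension
  PitchLinearWorker worker;
  void touch(int row, int column) const { worker.touch(row, column); }
};

struct RowMajorAdapter {               // columns are the contiguous dimension
  PitchLinearWorker worker;
  void touch(int row, int column) const { worker.touch(column, row); }
};

int main() {
  ColumnMajorAdapter cm;
  RowMajorAdapter rm;
  cm.touch(3, 5);   // prints contiguous=3 strided=5
  rm.touch(3, 5);   // prints contiguous=5 strided=3
  return 0;
}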
+/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + typename AccessType_> +class PredicatedTileAccessIteratorResidualLast< + Shape_, + Element_, + layout::AffineRankN<2>, + AdvanceRank, + ThreadMap_, + AccessType_, + false> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::AffineRankN<2>; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using UnderlyingPredicates = PredicatedTileAccessIteratorPredicates< + Shape, + Element, + layout::PitchLinear, + AdvanceRank, + ThreadMap, + AccessType>; + + static int const kAccessesPerVector = + ThreadMap::kElementsPerAccess / AccessType::kElements; + + static_assert( + !(ThreadMap::kElementsPerAccess % AccessType::kElements), + "Vectors implied by the thread map must be divisible by the access type."); + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingPredicates::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + public: + friend PredicatedTileAccessIteratorResidualLast; + + private: + /// stride of pitch-linear layout (units of Element) + Coord stride_; + /// amount (in byte) to increment pointer to move to next access along + /// contiguous dimension + LongIndex inc_contiguous_; + /// amount (in byte) to increment pointer from first access of current + /// contiguous dimension to first access of next one. + LongIndex inc_strided_; + /// amount (in byte) to increment pointer from last access of current + /// contiguous dimension to first access of next one. 
+ LongIndex inc_next_strided_; + /// amount (in byte) to increment pointer from last access to first access + /// of next tile + LongIndex inc_next_; + /// amount (in byte) to increment pointer from first access of current tile + /// to first access of next tile + LongIndex inc_advance_; + + public: + // Default ctor + CUTLASS_HOST_DEVICE + Params() + : stride_(0), + inc_contiguous_(0), + inc_strided_(0), + inc_next_(0), + inc_advance_(0) {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : stride_({layout.stride(0), layout.stride(1)}) { + inc_contiguous_ = + (LongIndex(stride_[0]) * ThreadMap::Delta::kContiguous) * + sizeof_bits::value / 8; + + inc_strided_ = (LongIndex(stride_[1]) * ThreadMap::Delta::kStrided) * + sizeof_bits::value / 8; + + inc_next_strided_ = inc_strided_ - + LongIndex(ThreadMap::Iterations::kContiguous - 1) * inc_contiguous_; + + if (kAdvanceRank) { + // advance along strided dimension + inc_advance_ = Shape::kStrided * LongIndex(stride_[1]) * + sizeof_bits::value / 8; + } else { + // advance along contiguous dimension + inc_advance_ = + Shape::kContiguous * stride_[0] * sizeof_bits::value / 8; + } + + inc_next_ = inc_advance_ - + LongIndex(ThreadMap::Iterations::kContiguous - 1) * inc_contiguous_ - + LongIndex(ThreadMap::Iterations::kStrided - 1) * inc_strided_; + }; + }; + + private: + /// Internal pointer type permits fast address arithmetic + using BytePointer = char*; + + // + // Data members + // + + /// Parameters object with precomputed internal state + Params params_; + + /// Internal pointer to first access of tile + BytePointer pointer_; + + UnderlyingPredicates the_predicates; + Mask residual_tile_mask; + + private: + /// Computes predicates based on internally tracked per-thread offset. 
+ CUTLASS_DEVICE + void compute_predicates_( + /// Extent of the matrix window + TensorCoord extent, + /// optionally, simplify predicate calculation during 'steady state' phase + bool is_steady_state = false) { + the_predicates.compute_predicates_(extent, is_steady_state); + } + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + ///< Precomputed parameters object + Params const& params, + ///< Pointer to start of tensor + Pointer pointer, + ///< Extent of tensor + TensorCoord extent, + ///< ID of each participating thread + int thread_id, + ///< Initial offset of threadblock + TensorCoord const& threadblock_offset, + int const* indices = + nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : params_(params), + pointer_(reinterpret_cast( + const_cast(pointer))), + the_predicates(extent) { + the_predicates.set_predicates(thread_id, threadblock_offset); + + // update internal pointers + Layout layout(params_.stride_); + add_pointer_offset(layout(the_predicates.thread_offset_)); + } + + /// Construct a PredicatedTileAccessIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileAccessIteratorResidualLast( + params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + the_predicates.set_iteration_index(index); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool is_residual_tile) { + if (is_residual_tile) { + the_predicates.set_mask(residual_tile_mask); + } + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + pointer_ += sizeof_bits::value * pointer_offset / 8; + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const& tile_offset) { + if (kAdvanceRank) { + pointer_ += params_.inc_advance_ * LongIndex(tile_offset[1]); + pointer_ += Shape::kContiguous * tile_offset[0]; + } else { + pointer_ += params_.inc_advance_ * LongIndex(tile_offset[0]); + pointer_ += Shape::kStrided * tile_offset[1]; + } + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType* get() const { + return reinterpret_cast(pointer_) + + the_predicates.iteration_vector_; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. 
+ CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast& operator++() { + the_predicates.operator++(); + ++the_predicates.iteration_vector_; + if (the_predicates.iteration_vector_ < kAccessesPerVector) { + return *this; + } + + the_predicates.iteration_vector_ = 0; + ++the_predicates.iteration_contiguous_; + + if (the_predicates.iteration_contiguous_ < + ThreadMap::Iterations::kContiguous) { + pointer_ += params_.inc_contiguous_; + return *this; + } + + // Enter here only if (iteration_contiguous_ == + // ThreadMap::Iteration::kContiguous) + the_predicates.iteration_contiguous_ = 0; + ++the_predicates.iteration_strided_; + + if (the_predicates.iteration_strided_ < ThreadMap::Iterations::kStrided) { + pointer_ += params_.inc_next_strided_; + return *this; + } + + // Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided) + // which means we enter the next tile. + the_predicates.iteration_strided_ = 0; + + // advance to next tile + pointer_ += params_.inc_next_; + + // now return to start tile - if the iterator is subsequently advanced, this + // subtraction as well as the subsequent integer addition are both elided by + // the compiler. + pointer_ -= params_.inc_advance_; + + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast operator++(int) { + PredicatedTileAccessIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + the_predicates.clear_mask(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + the_predicates.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + the_predicates.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + the_predicates.get_mask(mask); + } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { + return the_predicates.valid(); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIteratorResidualLast for affine rank 2 +/// column-major data. 
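// Host-side sketch of the precomputed-increment traversal that the Params
// object above sets up: the tensor strides are folded into byte deltas once,
// and the walk itself only ever adds those deltas to a byte pointer (no
// multiplications in the inner loop). Purely illustrative, not CUTLASS code.
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  constexpr int kContiguous = 4, kStrided = 3;   // accesses per tile
  const int ld = 10;                             // leading dimension, in elements
  std::vector<float> data(ld * kStrided);
  for (int i = 0; i < ld * kStrided; ++i) data[i] = static_cast<float>(i);

  // Computed once, as in the Params constructor above.
  const std::int64_t inc_contiguous = sizeof(float);
  const std::int64_t inc_strided    = ld * static_cast<std::int64_t>(sizeof(float));
  const std::int64_t inc_next_strided =
      inc_strided - (kContiguous - 1) * inc_contiguous;  // end of line -> next line

  const char* ptr = reinterpret_cast<const char*>(data.data());
  for (int s = 0; s < kStrided; ++s) {
    for (int c = 0; c < kContiguous; ++c) {
      assert(*reinterpret_cast<const float*>(ptr) == data[s * ld + c]);
      ptr += (c + 1 < kContiguous) ? inc_contiguous : inc_next_strided;
    }
  }
  return 0;
}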
+/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + typename AccessType_> +class PredicatedTileAccessIteratorResidualLast< + Shape_, + Element_, + layout::AffineRank2ColumnMajor, + AdvanceRank, + ThreadMap_, + AccessType_, + false> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::AffineRank2ColumnMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + // Map to the underlying AffineRankN<2> layout + using UnderlyingIterator = PredicatedTileAccessIteratorResidualLast< + layout::PitchLinearShape, + Element, + layout::AffineRankN<2>, + (kAdvanceRank == 0 ? 0 : 1), + ThreadMap, + AccessType>; + + static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileAccessIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + /// Default ctor + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given an AffineRankN<2> tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : params_(layout::AffineRankN<2>(layout.stride(0), layout.stride(1))){}; + }; + + private: + // + // Data members + // + + /// Underlying AffineRankN<2> tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + ///< Precomputed parameters object + Params const& params, + ///< Pointer to start of tensor + Pointer pointer, + ///< Extent of tensor + TensorCoord extent, + ///< ID of each participating thread + int thread_id, + ///< Initial offset of threadblock + TensorCoord const& threadblock_offset, + int const* indices = + nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : iterator_( + params.params_, + pointer, + layout::PitchLinearCoord(extent.row(), extent.column()), + thread_id, + layout::PitchLinearCoord( + threadblock_offset.row(), + threadblock_offset.column())) {} + + /// Construct a PredicatedTileAccessIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileAccessIteratorResidualLast( + params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} + + /// Overrides the internal 
iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + iterator_.set_iteration_index(index); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { + iterator_.set_residual_tile(enable); + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const& tile_offset) { + iterator_.add_tile_offset( + make_Coord(tile_offset.row(), tile_offset.column())); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType* get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast& operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast operator++(int) { + PredicatedTileAccessIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + iterator_.clear_mask(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + iterator_.get_mask(mask); + } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { + return iterator_.valid(); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIteratorResidualLast for affine rank-2 +/// row-major data. 
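+///
+/// As in the column-major case, this specialization forwards to the generic
+/// AffineRankN<2> iterator, but with the two ranks exchanged (sketch of the
+/// mapping used by the Params object and constructor below):
+///
+///   strides            : (layout.stride(1), layout.stride(0))
+///   extent             : (extent.column(),  extent.row())
+///   threadblock offset : (threadblock_offset.column(), threadblock_offset.row())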
+/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + typename AccessType_> +class PredicatedTileAccessIteratorResidualLast< + Shape_, + Element_, + layout::AffineRank2RowMajor, + AdvanceRank, + ThreadMap_, + AccessType_, + false> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::AffineRank2RowMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + // Map to the underlying AffineRankN<2> layout + using UnderlyingIterator = PredicatedTileAccessIteratorResidualLast< + layout::PitchLinearShape, + Element, + layout::AffineRankN<2>, + (kAdvanceRank == 0 ? 1 : 0), + ThreadMap, + AccessType>; + + static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileAccessIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + /// Default ctor + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given an AffineRankN<2> tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : params_(layout::AffineRankN<2>(layout.stride(1), layout.stride(0))){}; + }; + + private: + // + // Data members + // + + /// Underlying AffineRankN<2> tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + ///< Precomputed parameters object + Params const& params, + ///< Pointer to start of tensor + Pointer pointer, + ///< Extent of tensor + TensorCoord extent, + ///< ID of each participating thread + int thread_id, + ///< Initial offset of threadblock + TensorCoord const& threadblock_offset, + int const* indices = + nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : iterator_( + params.params_, + pointer, + layout::PitchLinearCoord(extent.column(), extent.row()), + thread_id, + layout::PitchLinearCoord( + threadblock_offset.column(), + threadblock_offset.row())) {} + + /// Construct a PredicatedTileAccessIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileAccessIteratorResidualLast( + params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} + + /// Overrides the internal 
iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + iterator_.set_iteration_index(index); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { + iterator_.set_residual_tile(enable); + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const& tile_offset) { + iterator_.add_tile_offset( + make_Coord(tile_offset.column(), tile_offset.row())); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType* get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast& operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast operator++(int) { + PredicatedTileAccessIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + iterator_.clear_mask(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + iterator_.get_mask(mask); + } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { + return iterator_.valid(); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIteratorResidualLast for column-major +/// interleaved data. It is mapped to the congruous layout. 
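+///
+/// The interleaving factor is folded into an ordinary pitch-linear problem.
+/// As a sketch of the mapping used by the constructor below (kInterleavedK is
+/// the interleaving factor):
+///
+///   contiguous extent = extent.row()    * kInterleavedK
+///   strided extent    = extent.column() / kInterleavedK
+///
+/// so a ColumnMajorInterleaved<K> tile over (rows, columns) is traversed as a
+/// PitchLinear tile over (rows * K, columns / K); the threadblock offset is
+/// folded the same way.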
+/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// + +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + typename AccessType_, + int InterleavedK> +class PredicatedTileAccessIteratorResidualLast< + Shape_, + Element_, + layout::ColumnMajorInterleaved, + AdvanceRank, + ThreadMap_, + AccessType_, + false> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + static int const kInterleavedK = InterleavedK; + using Layout = layout::ColumnMajorInterleaved; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using UnderlyingIterator = PredicatedTileAccessIteratorResidualLast< + layout::PitchLinearShape< + Shape::kRow * kInterleavedK, + Shape::kColumn / kInterleavedK>, + Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 0 : 1), + ThreadMap, + AccessType>; + + static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileAccessIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : params_(layout::PitchLinear(layout.stride(0))) {} + + CUTLASS_HOST_DEVICE + Params(typename UnderlyingIterator::Params::Base const& base) + : params_(base) {} + }; + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + /// Precomputed parameters object + Params const& params, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const& threadblock_offset, + int const* indices = + nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : iterator_( + params.params_, + pointer, + layout::PitchLinearCoord( + extent.row() * kInterleavedK, + extent.column() / kInterleavedK), + thread_id, + layout::PitchLinearCoord( + threadblock_offset.row() * kInterleavedK, + threadblock_offset.column() / kInterleavedK)) {} + + /// Construct a PredicatedTileAccessIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent 
of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileAccessIteratorResidualLast( + params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + iterator_.set_iteration_index(index); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { + iterator_.set_residual_tile(enable); + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const& tile_offset) { + iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()}); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType* get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast& operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast operator++(int) { + PredicatedTileAccessIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + iterator_.clear_mask(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + iterator_.get_mask(mask); + } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { + return iterator_.valid(); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIteratorResidualLast for row-major +/// interleaved data. +// It is mapped to the congruous layout. 
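+///
+/// Analogous to the column-major interleaved case, with the roles of rows and
+/// columns exchanged (sketch of the mapping used by the constructor below):
+///
+///   contiguous extent = extent.column() * kInterleavedK
+///   strided extent    = extent.row()    / kInterleavedK
+///
+/// i.e. a RowMajorInterleaved<K> tile over (rows, columns) is traversed as a
+/// PitchLinear tile over (columns * K, rows / K).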
+/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + typename AccessType_, + int InterleavedK> +class PredicatedTileAccessIteratorResidualLast< + Shape_, + Element_, + layout::RowMajorInterleaved, + AdvanceRank, + ThreadMap_, + AccessType_, + false> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + static int const kInterleavedK = InterleavedK; + using Layout = layout::RowMajorInterleaved; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using UnderlyingIterator = PredicatedTileAccessIteratorResidualLast< + layout::PitchLinearShape< + Shape::kColumn * kInterleavedK, + Shape::kRow / kInterleavedK>, + Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 1 : 0), + ThreadMap, + AccessType>; + + static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileAccessIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : params_(layout::PitchLinear(layout.stride(0))) {} + + CUTLASS_HOST_DEVICE + Params(typename UnderlyingIterator::Params::Base const& base) + : params_(base) {} + }; + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + /// Precomputed parameters object + Params const& params, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const& threadblock_offset, + int const* indices = + nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : iterator_( + params.params_, + pointer, + layout::PitchLinearCoord( + extent.column() * kInterleavedK, + extent.row() / kInterleavedK), + thread_id, + layout::PitchLinearCoord( + threadblock_offset.column() * kInterleavedK, + threadblock_offset.row() / kInterleavedK)) {} + + /// Construct a PredicatedTileAccessIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of 
tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileAccessIteratorResidualLast( + params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + iterator_.set_iteration_index(index); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { + iterator_.set_residual_tile(enable); + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const& tile_offset) { + iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType* get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast& operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast operator++(int) { + PredicatedTileAccessIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + iterator_.clear_mask(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + iterator_.get_mask(mask); + } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { + return iterator_.valid(); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace transform +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/iterators/predicated_tile_iterator_residual_last.h b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/iterators/predicated_tile_iterator_residual_last.h new file mode 100644 index 0000000000000000000000000000000000000000..fa40d850c84528c7cf2e5d9eaa5961774843405d --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/iterators/predicated_tile_iterator_residual_last.h @@ -0,0 +1,2119 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA 
CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing loading of tiles from pitch-linear rank=2 + tensors. + + This iterator uses masks to guard out-of-bounds accesses. The first tile + this iterator visits maybe partial, then the remaining tiles are complete. + So, we only need to compute the predicates twice, once before the first tile + and once for the remaining full tiles which can share the same predicates. + + A precomputed "Params" object minimizes the amount of state that must be + stored in registers, and integer addition is used to advance the pointer + through memory. +*/ + +#pragma once + +#include "cutlass/arch/memory.h" +#include "cutlass/transform/threadblock/predicated_tile_access_iterator.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace transform { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// PredicatedTileIteratorResidualLast +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +/// Regular tile iterator using a precomputed control structure to minimize +/// register liveness and integer arithmetic. +/// +/// Layout is assumed to be invariant at the time the precomputed "Params" +/// object is constructed. +/// +/// Base pointer and tensor extents may be specified at the time the iterator is +/// constructed. Subsequently, they are assumed to be immutable. +/// +/// Adding a logical coordinate offset may be performed at the time the iterator +/// is constructed. Subsequent additions to logical coordinate offset may be +/// performed but are relatively expensive. 
+/// +/// Visitation order is intended to first visit a "residual" tile that may be +/// partially full in both the advance dimension and the steady-state dimension. +/// This is assumed to be the last tile in the iteration sequence. Advancing an +/// iterator that has just been constructed moves to the first tile that is full +/// in the advance dimension and recomputes predicates. Subsequent accesses may +/// be performed without updating internal predicates and are efficient in terms +/// of live register state and pointer arithmetic instructions. +/// +/// To be efficient, this assumes the iterator will be dereferenced and advanced +/// at least once outside any looping structure to minimize integer arithmetic. +/// +/// Accesses out of bounds are safe so long as `clear_mask()` is called prior to +/// dereferencing the iterator. +/// +/// +/// Example: +/// +/// An efficient pipeline structure may be constructed as follows: +/// +// template +// __global__ void kernel( +// typename Iterator::Params params, +// typename Iterator::Element *ptr, +// TensorCoord extent) { +// +// typename Iterator::Fragment fragment; +// +// TensorCoord threadblock_offset(0, 0); +// +// Iterator iter(params, ptr, extent, threadIdx.x, threadblock_offsets); +// +// +// fragment = *iter; // load "residue" tile first +// ++iter; // advance to first "steady state" tile and update +// internal masks +// +// +// #pragma unroll +// for (int i = Remaining - 1; i >= 0; --i) { +// +// f(fragment); +// +// if (!i) { +// iter.clear_mask(); // light-weight operation to clear masks - +// subsequent loads become NO-OPs. +// } +// +// fragment = *iter; // load tile during "steady state" phase +// ++iter; // advance to next tile - lightweight due to +// steady-state masks +// } +// } +// +// void host(TensorView view) { +// +// using Iterator = +// transform::threadblock::PredicatedTileIteratorResidualLast; +// +// typename Iterator::Params params(view.layout()); +// +// kernel(params, view.data()); +// } +/// +/// +template < + typename Shape, + typename Element, + typename Layout, + int AdvanceRank, + typename ThreadMap, + int AccessSize = ThreadMap::kElementsPerAccess, + bool Gather = false> +class PredicatedTileIteratorResidualLast; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIteratorResidualLast for pitch-linear data. 
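+///
+/// A minimal usage sketch (the tile shape, element type, thread map, loop
+/// bound, and the consume() helper are illustrative assumptions, not taken
+/// from a particular kernel configuration):
+///
+///   using Iterator = PredicatedTileIteratorResidualLast<
+///       cutlass::layout::PitchLinearShape<128, 8>,   // tile shape (assumed)
+///       cutlass::half_t,                             // element type (assumed)
+///       cutlass::layout::PitchLinear,                // layout
+///       1,                                           // advance along the strided rank
+///       ThreadMap>;                                  // thread map (assumed defined)
+///
+///   typename Iterator::Params params(layout);
+///   Iterator it(params, ptr, extent, thread_id, threadblock_offset);
+///   typename Iterator::Fragment frag;
+///
+///   for (int k = tiles; k > 0; --k) {
+///     it.set_residual_tile(k == 1);   // flag the (possibly partial) residual
+///                                     // tile, assumed here to be visited last
+///     it.load(frag);                  // predicated load
+///     ++it;                           // advance to the next tile
+///     consume(frag);                  // hypothetical consumer
+///   }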
+/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + int AccessSize, + bool Gather> +class PredicatedTileIteratorResidualLast< + Shape_, + Element_, + layout::PitchLinear, + AdvanceRank, + ThreadMap_, + AccessSize, + Gather> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::PitchLinear; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + /// Type used for internal memory accesses + using AccessType = AlignedArray< + Element, + AccessSize, + (AccessSize * sizeof_bits::value / 8)>; + + /// Underlying iterator to compute the addresses + using TileAccessIterator = PredicatedTileAccessIteratorResidualLast< + Shape, + Element, + Layout, + kAdvanceRank, + ThreadMap, + AccessType, + Gather>; + + static int const kAccessesPerVector = TileAccessIterator::kAccessesPerVector; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array< + Element, + ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>; + + /// Predicate vector stores mask to guard accesses + using Mask = typename TileAccessIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + public: + using Base = typename TileAccessIterator::Params::Base; + + friend PredicatedTileIteratorResidualLast; + + private: + /// Parameters object + typename TileAccessIterator::Params params_; + + public: + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) : params_(layout) {} + + CUTLASS_HOST_DEVICE + Params() {} + + CUTLASS_HOST_DEVICE + Params(Base const& base) : params_(base) {} + }; + + private: + /// Internal pointer type permits fast address arithmetic + using BytePointer = char*; + + private: + // + // Data members + // + + /// Data member to the tile access iterator + TileAccessIterator address_iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + /// Precomputed parameters object + Params const& params, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const& threadblock_offset, + /// Gather indices + int const* indices = nullptr) + : address_iterator_( + params.params_, + pointer, + extent, + thread_id, + threadblock_offset, + indices) {} + + /// Construct a PredicatedTileIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< 
ID of each participating thread + ) + : PredicatedTileIteratorResidualLast( + params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + address_iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast& operator++() { + if (kAdvanceRank) + address_iterator_.add_tile_offset({0, 1}); + else + address_iterator_.add_tile_offset({1, 0}); + + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast operator++(int) { + PredicatedTileIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + address_iterator_.clear_mask(enable); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { + address_iterator_.set_residual_tile(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + address_iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + address_iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + address_iterator_.get_mask(mask); + } + + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment& frag, Index pointer_offset) { + load_with_byte_offset( + frag, pointer_offset * sizeof_bits::value / 8); + } + + CUTLASS_DEVICE + void load_with_byte_offset(Fragment& frag, LongIndex byte_offset) { + AccessType* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < kAccessesPerVector; ++v) { + int idx = v + + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous); + + address_iterator_.set_iteration_index(idx); + char const* byte_ptr = + reinterpret_cast(address_iterator_.get()) + + byte_offset; + + AccessType const* access_ptr = + reinterpret_cast(byte_ptr); + + cutlass::arch::global_load( + frag_ptr[idx], access_ptr, address_iterator_.valid()); + + ++address_iterator_; + } + } + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment& frag) { + load_with_byte_offset(frag, 0); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) { + store_with_byte_offset( + frag, pointer_offset * sizeof_bits::value / 8); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const& frag, LongIndex byte_offset) { + address_iterator_.set_iteration_index(0); + AccessType const* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int s 
= 0; s < ThreadMap::Iterations::kStrided; ++s) { + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < kAccessesPerVector; ++v) { + int idx = v + + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous); + + char* byte_ptr = + reinterpret_cast(address_iterator_.get()) + byte_offset; + AccessType* access_ptr = reinterpret_cast(byte_ptr); + + if (address_iterator_.valid()) { + *access_ptr = frag_ptr[idx]; + } + ++address_iterator_; + } + } + } + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const& frag) { + store_with_byte_offset(frag, 0); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIteratorResidualLast for pitch-linear data. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + int AccessSize, + bool Gather> +class PredicatedTileIteratorResidualLast< + Shape_, + Element_, + layout::ColumnMajor, + AdvanceRank, + ThreadMap_, + AccessSize, + Gather> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::ColumnMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using UnderlyingIterator = PredicatedTileIteratorResidualLast< + layout::PitchLinearShape, + Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 
0 : 1), + ThreadMap, + AccessSize, + Gather>; + + using AccessType = typename UnderlyingIterator::AccessType; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array< + Element, + ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : params_(layout::PitchLinear(layout.stride(0))) {} + + CUTLASS_HOST_DEVICE + Params(typename UnderlyingIterator::Params::Base const& base) + : params_(base) {} + }; + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id, ///< ID of each participating thread + TensorCoord const& threadblock_offset, ///< Initial offset of threadblock + int const* indices = + nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : iterator_( + params.params_, + pointer, + layout::PitchLinearCoord(extent.row(), extent.column()), + thread_id, + layout::PitchLinearCoord( + threadblock_offset.row(), + threadblock_offset.column()), + indices) {} + + /// Construct a PredicatedTileIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileIteratorResidualLast( + params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast& operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. 
+ CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast operator++(int) { + PredicatedTileIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + iterator_.clear_mask(enable); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { + iterator_.set_residual_tile(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + iterator_.get_mask(mask); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment& frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_byte_offset(Fragment& frag, LongIndex byte_offset) { + iterator_.load_with_byte_offset(frag, byte_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment& frag) { + load_with_pointer_offset(frag, 0); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) { + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const& frag, LongIndex byte_offset) { + iterator_.store_with_byte_offset(frag, byte_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const& frag) { + store_with_pointer_offset(frag, 0); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIteratorResidualLast for pitch-linear data. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + int AccessSize, + bool Gather> +class PredicatedTileIteratorResidualLast< + Shape_, + Element_, + layout::RowMajor, + AdvanceRank, + ThreadMap_, + AccessSize, + Gather> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::RowMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using UnderlyingIterator = PredicatedTileIteratorResidualLast< + layout::PitchLinearShape, + Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 
1 : 0), + ThreadMap, + AccessSize, + Gather>; + + using AccessType = typename UnderlyingIterator::AccessType; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array< + Element, + ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : params_(layout::PitchLinear(layout.stride(0))) {} + + CUTLASS_HOST_DEVICE + Params(typename UnderlyingIterator::Params::Base const& base) + : params_(base) {} + }; + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id, ///< ID of each participating thread + TensorCoord const& threadblock_offset, ///< Initial offset of threadblock + int const* indices = nullptr ///< Gather indices + ) + : iterator_( + params.params_, + pointer, + layout::PitchLinearCoord(extent.column(), extent.row()), + thread_id, + layout::PitchLinearCoord( + threadblock_offset.column(), + threadblock_offset.row()), + indices) {} + + /// Construct a PredicatedTileIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileIteratorResidualLast( + params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast& operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. 
+ CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast operator++(int) { + PredicatedTileIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + iterator_.clear_mask(enable); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { + iterator_.set_residual_tile(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + iterator_.get_mask(mask); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment& frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_byte_offset(Fragment& frag, LongIndex byte_offset) { + iterator_.load_with_byte_offset(frag, byte_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment& frag) { + load_with_pointer_offset(frag, 0); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) { + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const& frag, LongIndex byte_offset) { + iterator_.store_with_byte_offset(frag, byte_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const& frag) { + store_with_pointer_offset(frag, 0); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIteratorResidualLast for affine rank-2 data. 
+/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + int AccessSize> +class PredicatedTileIteratorResidualLast< + Shape_, + Element_, + layout::AffineRankN<2>, + AdvanceRank, + ThreadMap_, + AccessSize, + false> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::AffineRankN<2>; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + /// Type used for internal memory accesses + using AccessType = AlignedArray< + Element, + AccessSize, + (AccessSize * sizeof_bits::value / 8)>; + + /// Underlying iterator to compute the addresses + using TileAccessIterator = PredicatedTileAccessIteratorResidualLast< + Shape, + Element, + Layout, + kAdvanceRank, + ThreadMap, + AccessType>; + + static int const kAccessesPerVector = TileAccessIterator::kAccessesPerVector; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array< + Element, + ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>; + + /// Predicate vector stores mask to guard accesses + using Mask = typename TileAccessIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + public: + friend PredicatedTileIteratorResidualLast; + + private: + /// Parameters object + typename TileAccessIterator::Params params_; + + public: + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) : params_(layout) {} + + CUTLASS_HOST_DEVICE + Params() {} + }; + + private: + /// Internal pointer type permits fast address arithmetic + using BytePointer = char*; + + private: + // + // Data members + // + + /// Data member to the tile access iterator + TileAccessIterator address_iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + /// Precomputed parameters object + Params const& params, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const& threadblock_offset, + int const* indices = + nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : address_iterator_( + params.params_, + pointer, + extent, + thread_id, + threadblock_offset) {} + + /// Construct a PredicatedTileIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileIteratorResidualLast( + 
params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + address_iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast& operator++() { + if (kAdvanceRank) + address_iterator_.add_tile_offset(make_Coord(0, 1)); + else + address_iterator_.add_tile_offset(make_Coord(1, 0)); + + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast operator++(int) { + PredicatedTileIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + address_iterator_.clear_mask(enable); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { + address_iterator_.set_residual_tile(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + address_iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + address_iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + address_iterator_.get_mask(mask); + } + + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment& frag, Index pointer_offset) { + load_with_byte_offset( + frag, pointer_offset * sizeof_bits::value / 8); + } + + CUTLASS_DEVICE + void load_with_byte_offset(Fragment& frag, LongIndex byte_offset) { + AccessType* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < kAccessesPerVector; ++v) { + int idx = v + + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous); + + address_iterator_.set_iteration_index(idx); + char const* byte_ptr = + reinterpret_cast(address_iterator_.get()) + + byte_offset; + + AccessType const* access_ptr = + reinterpret_cast(byte_ptr); + + cutlass::arch::global_load( + frag_ptr[idx], access_ptr, address_iterator_.valid()); + + ++address_iterator_; + } + } + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment& frag) { + load_with_byte_offset(frag, 0); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) { + store_with_byte_offset( + frag, pointer_offset * sizeof_bits::value / 8); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const& frag, LongIndex byte_offset) { + address_iterator_.set_iteration_index(0); + AccessType const* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + 
CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < kAccessesPerVector; ++v) { + int idx = v + + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous); + + char* byte_ptr = + reinterpret_cast(address_iterator_.get()) + byte_offset; + AccessType* access_ptr = reinterpret_cast(byte_ptr); + + if (address_iterator_.valid()) { + *access_ptr = frag_ptr[idx]; + } + ++address_iterator_; + } + } + } + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const& frag) { + store_with_byte_offset(frag, 0); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIteratorResidualLast for affine rank 2 +/// column-major data. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + int AccessSize> +class PredicatedTileIteratorResidualLast< + Shape_, + Element_, + layout::AffineRank2ColumnMajor, + AdvanceRank, + ThreadMap_, + AccessSize, + false> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::AffineRank2ColumnMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + // Map to the underlying AffineRankN<2> layout + using UnderlyingIterator = PredicatedTileIteratorResidualLast< + layout::PitchLinearShape, + Element, + layout::AffineRankN<2>, + (kAdvanceRank == 0 ? 
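+          // AffineRank2ColumnMajor keeps rows as the contiguous dimension, so an
+          // advance along rank 0 maps straight onto the contiguous rank of the
+          // underlying AffineRankN<2> iterator (and rank 1 onto the strided rank).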
0 : 1), + ThreadMap, + AccessSize>; + + using AccessType = typename UnderlyingIterator::AccessType; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array< + Element, + ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given an AffineRankN<2> tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : params_(layout::AffineRankN<2>(layout.stride(0), layout.stride(1))) {} + }; + + private: + // + // Data members + // + + /// Underlying AffineRankN<2> tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id, ///< ID of each participating thread + TensorCoord const& threadblock_offset, ///< Initial offset of threadblock + int const* indices = + nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : iterator_( + params.params_, + pointer, + layout::PitchLinearCoord(extent.row(), extent.column()), + thread_id, + layout::PitchLinearCoord( + threadblock_offset.row(), + threadblock_offset.column())) {} + + /// Construct a PredicatedTileIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileIteratorResidualLast( + params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast& operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. 
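+  ///
+  /// Illustrative sketch only (names such as `k_tiles` are hypothetical; the
+  /// real driving protocol is defined by the MMA pipeline that consumes this
+  /// iterator):
+  ///
+  ///   Iterator::Params params(layout);
+  ///   Iterator it(params, ptr, extent, thread_id, tb_offset);
+  ///   typename Iterator::Fragment frag;
+  ///   it.set_residual_tile(k_tiles == 1);       // next load may be partial
+  ///   for (int k = 0; k < k_tiles; ++k) {
+  ///     it.load(frag);
+  ///     ++it;
+  ///     it.set_residual_tile(k == k_tiles - 2); // mark the last (residual) tile
+  ///     // ... consume frag ...
+  ///   }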
+ CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast operator++(int) { + PredicatedTileIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + iterator_.clear_mask(enable); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { + iterator_.set_residual_tile(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + iterator_.get_mask(mask); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment& frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_byte_offset(Fragment& frag, LongIndex byte_offset) { + iterator_.load_with_byte_offset(frag, byte_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment& frag) { + load_with_pointer_offset(frag, 0); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) { + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const& frag, LongIndex byte_offset) { + iterator_.store_with_byte_offset(frag, byte_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const& frag) { + store_with_pointer_offset(frag, 0); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIteratorResidualLast for affine rank 2 +/// row-major data. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + int AccessSize> +class PredicatedTileIteratorResidualLast< + Shape_, + Element_, + layout::AffineRank2RowMajor, + AdvanceRank, + ThreadMap_, + AccessSize, + false> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::AffineRank2RowMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + // Map to the underlying AffineRankN<2> layout + using UnderlyingIterator = PredicatedTileIteratorResidualLast< + layout::PitchLinearShape, + Element, + layout::AffineRankN<2>, + (kAdvanceRank == 0 ? 
1 : 0), + ThreadMap, + AccessSize>; + + using AccessType = typename UnderlyingIterator::AccessType; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array< + Element, + ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given an AffineRankN<2> tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : params_(layout::AffineRankN<2>(layout.stride(1), layout.stride(0))) {} + }; + + private: + // + // Data members + // + + /// Underlying AffineRankN<2> tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id, ///< ID of each participating thread + TensorCoord const& threadblock_offset, ///< Initial offset of threadblock + int const* indices = + nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : iterator_( + params.params_, + pointer, + layout::PitchLinearCoord(extent.column(), extent.row()), + thread_id, + layout::PitchLinearCoord( + threadblock_offset.column(), + threadblock_offset.row())) {} + + /// Construct a PredicatedTileIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileIteratorResidualLast( + params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast& operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. 
+ CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast operator++(int) { + PredicatedTileIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + iterator_.clear_mask(enable); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { + iterator_.set_residual_tile(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + iterator_.get_mask(mask); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment& frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_byte_offset(Fragment& frag, LongIndex byte_offset) { + iterator_.load_with_byte_offset(frag, byte_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment& frag) { + load_with_pointer_offset(frag, 0); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) { + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const& frag, LongIndex byte_offset) { + iterator_.store_with_byte_offset(frag, byte_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const& frag) { + store_with_pointer_offset(frag, 0); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIteratorResidualLast for interleaved data. +/// It is mapped to the congruous layout. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// + +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + int AccessSize, + int InterleavedK> +class PredicatedTileIteratorResidualLast< + Shape_, + Element_, + layout::ColumnMajorInterleaved, + AdvanceRank, + ThreadMap_, + AccessSize, + false> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + static int const kInterleavedK = InterleavedK; + using Layout = layout::ColumnMajorInterleaved; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using UnderlyingIterator = PredicatedTileIteratorResidualLast< + layout::PitchLinearShape< + Shape::kRow * kInterleavedK, + Shape::kColumn / kInterleavedK>, + Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 
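+          // For the interleaved layout the underlying iterator is plain pitch-linear:
+          // the contiguous extent becomes rows * kInterleavedK and the strided extent
+          // columns / kInterleavedK (see the constructor below), so the advance rank
+          // carries over unchanged.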
0 : 1), + ThreadMap, + AccessSize>; + + using AccessType = typename UnderlyingIterator::AccessType; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array< + Element, + ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : params_(layout::PitchLinear(layout.stride(0))) {} + + CUTLASS_HOST_DEVICE + Params(typename UnderlyingIterator::Params::Base const& base) + : params_(base) {} + }; + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + /// Precomputed parameters object + Params const& params, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const& threadblock_offset, + int const* indices = + nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : iterator_( + params.params_, + pointer, + layout::PitchLinearCoord( + extent.row() * kInterleavedK, + extent.column() / kInterleavedK), + thread_id, + layout::PitchLinearCoord( + threadblock_offset.row() * kInterleavedK, + threadblock_offset.column() / kInterleavedK)) {} + + /// Construct a PredicatedTileIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileIteratorResidualLast( + params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast& operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. 
+ CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast operator++(int) { + PredicatedTileIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + iterator_.clear_mask(enable); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { + iterator_.set_residual_tile(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + iterator_.get_mask(mask); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment& frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment& frag) { + load_with_pointer_offset(frag, 0); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) { + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const& frag) { + store_with_pointer_offset(frag, 0); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIteratorResidualLast for interleaved-32 +/// data. It is mapped to the congruous layout. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + int AccessSize, + int InterleavedK> +class PredicatedTileIteratorResidualLast< + Shape_, + Element_, + layout::RowMajorInterleaved, + AdvanceRank, + ThreadMap_, + AccessSize, + false> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + static int const kInterleavedK = InterleavedK; + using Layout = layout::RowMajorInterleaved; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using UnderlyingIterator = PredicatedTileIteratorResidualLast< + layout::PitchLinearShape< + Shape::kColumn * kInterleavedK, + Shape::kRow / kInterleavedK>, + Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 
1 : 0), + ThreadMap, + AccessSize>; + + using AccessType = typename UnderlyingIterator::AccessType; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array< + Element, + ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : params_(layout::PitchLinear(layout.stride(0))) {} + + CUTLASS_HOST_DEVICE + Params(typename UnderlyingIterator::Params::Base const& base) + : params_(base) {} + }; + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + /// Precomputed parameters object + Params const& params, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const& threadblock_offset, + int const* indices = + nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : iterator_( + params.params_, + pointer, + layout::PitchLinearCoord( + extent.column() * kInterleavedK, + extent.row() / kInterleavedK), + thread_id, + layout::PitchLinearCoord( + threadblock_offset.column() * kInterleavedK, + threadblock_offset.row() / kInterleavedK)) {} + + /// Construct a PredicatedTileIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileIteratorResidualLast( + params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast& operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. 
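+  ///
+  /// As with the other layout specializations in this file, every mask,
+  /// residual-tile and load/store call is forwarded to the underlying
+  /// pitch-linear iterator; only the coordinate mapping above differs.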
+ CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast operator++(int) { + PredicatedTileIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + iterator_.clear_mask(enable); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { + iterator_.set_residual_tile(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + iterator_.get_mask(mask); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment& frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment& frag) { + load_with_pointer_offset(frag, 0); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) { + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const& frag) { + store_with_pointer_offset(frag, 0); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace transform +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/iterators/transpose_warp_iterator.h b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/iterators/transpose_warp_iterator.h new file mode 100644 index 0000000000000000000000000000000000000000..18858ab7327a351d5bb742acb09da72bfccedf30 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/iterators/transpose_warp_iterator.h @@ -0,0 +1,55 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "warp_iterator_from_smem.h" + +template +struct TransposeWarpIterator { + using Iterator = char; + static bool constexpr kSupportsTranspose = false; +}; + +template < + /// Operand identity + cutlass::gemm::Operand Operand, + /// Data type of A elements + typename Element, + typename InstructionShape, + bool kTranspose> +struct TransposeWarpIterator< + cutlass::gemm::warp:: + WarpIteratorFromSmem> { + using Iterator = cutlass::gemm::warp:: + WarpIteratorFromSmem; + static bool constexpr kSupportsTranspose = true; +}; diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/iterators/warp_iterator_from_smem.h b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/iterators/warp_iterator_from_smem.h new file mode 100644 index 0000000000000000000000000000000000000000..3f4ebec698b23f85f7d0a0a5450aa21f3a9fc17e --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/iterators/warp_iterator_from_smem.h @@ -0,0 +1,283 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +/*! \file + \brief Inspired from + "cutlass/gemm/warp/mma_tensor_op_tile_access_iterator.h" Loads tiles of GEMM + operands from a RowMajor shared-memory layout into registers to use by A100 + TensorCores. + + The difference with "mma_tensor_op_tile_access_iterator.h" is that: + (1) We use "ldmatrix" to load tiles, rather than manual loads (slightly + faster) (2) We support to transpose the operand (eg read `A.transpose()` when + the shared memory holds `A`) + + This is only implemented for the specific shapes. +*/ +#pragma once + +#include + +//////////////////////////////////////////////////////////////////////////////// +namespace cutlass { +namespace gemm { +namespace warp { + +template < + /// Operand identity + Operand Operand_, + /// Data type of A elements + typename Element_, + typename InstructionShape_, + bool kTranspose = false> +class WarpIteratorFromSmem { + public: + /// Shape of tile to load (concept: MatrixShape) + using Shape = cutlass::MatrixShape<32, 32>; + + /// Operand tag + static Operand const kOperand = Operand_; + static_assert( + kOperand == Operand::kA, + "No support for OperandB at the moment"); + + /// Basic check + static_assert( + kOperand == Operand::kA || kOperand == Operand::kB, + "WarpIteratorFromSmem may only be instantiated for A or B operands to warp-level Mma."); + + /// Element type + using Element = Element_; + static_assert(sizeof_bits::value == 16, "Only supported for half"); + + /// Layout of source tile + using Layout = cutlass::layout::RowMajor; + + /// Shape of one matrix product operation (concept: MatrixShape) + using InstructionShape = InstructionShape_; + static_assert(InstructionShape::kRow == 16, "Only supports 16x8x8 / 16x8x16"); + static_assert( + InstructionShape::kColumn == 8 || InstructionShape::kColumn == 16, + "Only supports 16x8x8 / 16x8x16"); + + /// Delta between *MMA operations (in units of *MMA operations, concept: + /// MatrixShape) + static int const kOpDelta = 1; + + /// Number of participating threads + static int const kThreads = 32; + + /// TensorRef type for loading element from a tensor + using TensorRef = TensorRef; + + /// Index type + using Index = typename TensorRef::Index; + + /// Long Index type + using LongIndex = typename TensorRef::LongIndex; + + /// Coordinate for an element in the tensor + using TensorCoord = typename TensorRef::TensorCoord; + + /// Number of elements accessed per Shared Memory load + static int const kElementsPerAccess = + (sizeof_bits::value >= 32 ? 1 + : 32 / sizeof_bits::value); + + using InstructionCount = MatrixShape< + Shape::kRow / InstructionShape::kRow, + Shape::kColumn / InstructionShape::kColumn>; + + static int const kIterations = (kOperand == Operand::kA) + ? InstructionCount::kColumn + : InstructionCount::kRow; + + public: + // + // Derived quantities + // + + /// Fragment object holding a thread's part of a tile + using Fragment = Array< + Element, + (kOperand == Operand::kA) + ? (Shape::kRow* InstructionShape::kColumn / kThreads) + : (Shape::kColumn* InstructionShape::kRow / kThreads)>; + + /// Memory access type + // using AccessType = AlignedArray; + using AccessType = Array; + + static int constexpr kWarpShapeDivisibleInner = + (kOperand == Operand::kA ? 
InstructionShape::kColumn + : InstructionShape::kRow); + static int constexpr kAccessesInner = + (kWarpShapeDivisibleInner / kElementsPerAccess) / 4; + // Number of 32bits tiles to load per `ldmatrix` + static int const kTilesPerInstruction = InstructionShape::kRow / 8; + static_assert(kTilesPerInstruction == 2, "Only supports 16x8x16 and 16x8x8"); + + private: + /// Underlying tensor reference + TensorRef ref_; + + /// Origin + MatrixCoord origin_; + + /// Iterations in a tile + int iterations_; + + public: + /// Constructor from TensorRef + CUTLASS_HOST_DEVICE + WarpIteratorFromSmem(TensorRef const& ref, int lane_id) + : WarpIteratorFromSmem(ref, {Shape::kRow, Shape::kColumn}, lane_id) {} + CUTLASS_HOST_DEVICE + WarpIteratorFromSmem(TensorRef const& ref, TensorCoord extent, int lane_id) + : ref_(ref), iterations_(0) { + // See also: + // https://docs.nvidia.com/cuda/archive/11.7.1/parallel-thread-execution/index.html#warp-level-matrix-fragment-mma-1688 + // 16x8x8: kAccessesInner = 1 (1 ldmatrix.x4) + // 16x8x16: kAccessesInner = 2 (2 ldmatrix.x4) + int ldsm_vec_num = (lane_id >> 3); + if (kOperand == Operand::kA) { + origin_ = MatrixCoord(lane_id % 8, 0); + static_assert( + InstructionCount::kRow * kTilesPerInstruction == 4, + "can't use ldmatrix.x4"); + int access_m_idx = ldsm_vec_num % kTilesPerInstruction; + int inner_idx = (ldsm_vec_num / kTilesPerInstruction) % kAccessesInner; + int inst_m_idx = ldsm_vec_num / (kTilesPerInstruction * kAccessesInner); + MatrixCoord offset( + access_m_idx * 8 + inst_m_idx * InstructionShape::kRow, + inner_idx * 4 * kElementsPerAccess); + if (kTranspose) { + offset = MatrixCoord(offset.column(), offset.row()); + } + origin_ += offset; + } else { + // Note: This is not tested or used + origin_ = MatrixCoord(0, lane_id % 8); + static_assert(InstructionCount::kColumn * kAccessesInner == 4, ""); + CUTLASS_PRAGMA_UNROLL + for (int inst_n_idx = 0; inst_n_idx < InstructionCount::kColumn; + ++inst_n_idx) { + CUTLASS_PRAGMA_UNROLL + for (int inner_idx = 0; inner_idx < kAccessesInner; ++inner_idx) { + int access_idx = inner_idx + kAccessesInner * inst_n_idx; + + MatrixCoord offset( + inner_idx * 4 * kElementsPerAccess, inst_n_idx * 8); + + if (access_idx == ldsm_vec_num) { + if (kTranspose) { + offset = MatrixCoord(offset.column(), offset.row()); + } + origin_ += offset; + } + } + } + } + + ref_.add_coord_offset(origin_); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + WarpIteratorFromSmem& add_tile_offset(TensorCoord const& tile_offset) { + TensorCoord coord_offset( + tile_offset.row() * Shape::kRow, tile_offset.column() * Shape::kColumn); + if (kTranspose) { + coord_offset = TensorCoord{coord_offset.column(), coord_offset.row()}; + } + origin_ += coord_offset; + + ref_.add_coord_offset(coord_offset); + + return *this; + } + + /// Advances the iterator along the advance dimension + CUTLASS_DEVICE + void advance() { + if (kOperand == Operand::kA) { + add_tile_offset({0, 1}); + } else { + add_tile_offset({1, 0}); + } + + iterations_ = 0; + } + + /// increase iterations in a tile + CUTLASS_HOST_DEVICE + WarpIteratorFromSmem& operator++() { + iterations_++; + + if (iterations_ >= kIterations) + advance(); + + return *this; + } + + /// Loads a fragment from memory at the location pointed to by the iterator. 
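+  ///
+  /// Each loop iteration issues one ldmatrix.x4 via cutlass::arch::ldsm, pulling
+  /// four 8x8 tiles from shared memory into the fragment. When kTranspose is set
+  /// the transposing load layout is selected and the row/column offsets are
+  /// swapped, so the tile held in shared memory is consumed as its transpose.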
+ CUTLASS_DEVICE + void load(Fragment& frag) const { + AccessType* access_ptr = reinterpret_cast(&frag); + using LoadLayout = typename platform:: + conditional::type; + + CUTLASS_PRAGMA_UNROLL + for (int access_m_idx = 0; access_m_idx < + (InstructionCount::kRow * kTilesPerInstruction * kAccessesInner) / 4; + ++access_m_idx) { + MatrixCoord offset; + if (kOperand == Operand::kA) { + offset = MatrixCoord( + access_m_idx * 16, iterations_ * InstructionShape::kColumn); + } else { + offset = MatrixCoord(iterations_ * InstructionShape::kRow, 0); + } + if (kTranspose) { + offset = MatrixCoord(offset.column(), offset.row()); + } + cutlass::arch::ldsm( + access_ptr[access_m_idx], ref_.data() + ref_.offset(offset)); + } + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace gemm +} // namespace cutlass +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/kernel_backward.h b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/kernel_backward.h new file mode 100644 index 0000000000000000000000000000000000000000..5cdb7c2145488c585a6665bb6170b3638e28692a --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/kernel_backward.h @@ -0,0 +1,2554 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once + +#include +#include +#include +#include + +#include +#include + +#ifdef HAS_PYTORCH +#include +#include +#include +#include +#endif + +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/epilogue/thread/scale_type.h" +#include "cutlass/fast_math.h" +#include "cutlass/functional.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/vector.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_ref.h" + +#include "debug_utils.h" +#include "gemm_kernel_utils.h" + +#include "cutlass/epilogue/thread/linear_combination_relu.h" +#include "cutlass/epilogue/threadblock/epilogue_smem_accumulator.h" +#include "cutlass/epilogue/warp/fragment_iterator_tensor_op.h" +#include "cutlass/epilogue/warp/tile_iterator_tensor_op.h" +#include "cutlass/gemm/device/default_gemm_configuration.h" +#include "cutlass/gemm/kernel/default_gemm.h" +#include "cutlass/gemm/threadblock/default_mma.h" +#include "cutlass/gemm/threadblock/default_mma_core_simt.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm70.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm75.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" +#include "cutlass/integer_subbyte.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/platform/platform.h" +#include "cutlass/transform/threadblock/predicated_tile_iterator.h" +#include "cutlass/transform/threadblock/vector_iterator.h" +#include "epilogue/epilogue_pipelined.h" +#include "iterators/epilogue_predicated_tile_iterator.h" + +#include "gemm/custom_mma.h" +#include "gemm/find_default_mma.h" +#include "gemm/mma_accum_lambda_iterator.h" +#include "gemm/mma_from_smem.h" +#include "transform/tile_smem_loader.h" + +using namespace gemm_kernel_utils; + +namespace { + +template +struct GmemTile { + /* + Helper functions to efficient store/load RF to gmem + + GEMM accumulators have a particular format on A100, and + it takes some compute/shared-memory to rearrange them to + a RowMajor or ColumnMajor format in global memory through + an Epilogue. The same complexity goes for loading into RF. + + This class loads/stores RF as they are, and can be used for + efficient accumulation across gemms for instance: + + ``` + GmemTile tile; + for (int i = 0; i < N; ++i) { + // ... + + Fragment accum; + if (i == 0) { + accum.clear(); + } else { + tile.load(accum); + } + mma(accum, ...); + if (i < N-1) { + // Store for next GEMM + tile.store(accum); + } else { + // Store in tensor (eg RowMajor) + epilogue(accum); + } + + // ... 
+ } + ``` + */ + + // 128bits per thread + using AccessType = cutlass::Array; + static constexpr int32_t kBytes = sizeof(AccessType); + static constexpr int32_t kStride = kNumThreads * AccessType::kElements; + static constexpr int32_t kNumIters = + FragmentType::kElements / AccessType::kElements; + static constexpr int32_t kElementsStored = + kNumThreads * FragmentType::kElements; + static_assert( + FragmentType::kElements % AccessType::kElements == 0, + "fragment not aligned on 128 bits"); + + float* ptr; + + CUTLASS_DEVICE void load(FragmentType& fragment, int thread_id) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kNumIters; ++i) { + AccessType* __restrict__ gmem_ptr = reinterpret_cast( + ptr + thread_id * AccessType::kElements + i * kStride); + AccessType sub_fragment; + cutlass::arch::global_load( + sub_fragment, gmem_ptr, true); + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < AccessType::kElements; ++j) { + fragment[i * AccessType::kElements + j] = sub_fragment[j]; + } + } + } + + CUTLASS_DEVICE void store(FragmentType const& fragment, int thread_id) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kNumIters; ++i) { + AccessType* __restrict__ gmem_ptr = reinterpret_cast( + ptr + thread_id * AccessType::kElements + i * kStride); + AccessType sub_fragment; + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < AccessType::kElements; ++j) { + sub_fragment[j] = fragment[i * AccessType::kElements + j]; + } + cutlass::arch::global_store( + sub_fragment, gmem_ptr, true); + } + } + + CUTLASS_DEVICE void storeAtomicAdd( + FragmentType const& fragment, + int thread_id) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kNumIters; ++i) { + float* gmem_ptr = ptr + thread_id * AccessType::kElements + i * kStride; + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < AccessType::kElements; ++j) { + float val = fragment[i * AccessType::kElements + j]; + float* ptr = gmem_ptr + j; + atomicAdd(ptr, val); + } + } + } +}; + +struct AtomicLock { + CUTLASS_DEVICE static void acquire( + int32_t* lock, + int set_val, + int thread_id) { + if (thread_id == 0) { + while (atomicCAS(lock, 0 /*cmp*/, set_val /*setval*/) != set_val) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700 + __nanosleep(40); +#endif + } + } + __syncthreads(); + } + CUTLASS_DEVICE static void release(int32_t* lock, int thread_id) { + if (thread_id == 0) { + int status = 0; +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700 + asm volatile("st.global.release.gpu.b32 [%0], %1;\n" + : + : "l"(lock), "r"(status)); +#else + asm volatile("st.global.cg.b32 [%0], %1;\n" : : "l"(lock), "r"(status)); +#endif + } + } +}; + +template +constexpr int getWarpsPerSmBw() { + bool is_half = !cutlass::platform::is_same::value; + if (Arch::kMinComputeCapability >= 80) { + return is_half ? 
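+        // Occupancy heuristic: on SM80+ the half-precision backward kernel targets
+        // 12 warps per SM, while fp32 (presumably limited by register pressure) and
+        // older architectures stay at 8.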
12 : 8; + } + return 8; +} +} // namespace + +template < + // which arch we target (eg `cutlass::arch::Sm80`) + typename ArchTag_, + // input/output type + typename scalar_t_, + // run optimized kernel because memory accesses will be aligned + bool kIsAligned_, + // use dropout if enabled + bool kApplyDropout_, + // when doing a GEMM, preload the next one (uses more shmem) + bool kPreload_, + // block dimensions + int kBlockSizeI_, + int kBlockSizeJ_, + // upperbound on `max(value.shape[-1], query.shape[-1])` + int kMaxK_ = (int)cutlass::platform::numeric_limits::max(), + // assumes that `cu_seqlen` is None, and + // (1) `num_queries % kBlockSizeI == 0` + // (2) `num_keys % kBlockSizeJ == 0` + bool kKeysQueriesAlignedToBlockSize_ = false, + // Allows to parallelize across keys + bool kEnableSplitKeys_ = true> +struct AttentionBackwardKernel { + enum CustomMaskType { + NoCustomMask = 0, + CausalFromTopLeft = 1, + CausalFromBottomRight = 2, + NumCustomMaskTypes, + }; + using scalar_t = scalar_t_; + using output_t = scalar_t; + using output_accum_t = float; + using lse_scalar_t = float; + using accum_t = float; + using ArchTag = ArchTag_; + static constexpr bool kIsAligned = kIsAligned_; + static constexpr bool kApplyDropout = kApplyDropout_; + static constexpr bool kPreload = kPreload_; + static constexpr int kBlockSizeI = kBlockSizeI_; + static constexpr int kBlockSizeJ = kBlockSizeJ_; + static constexpr int kMaxK = kMaxK_; + static constexpr bool kKeysQueriesAlignedToBlockSize = + kKeysQueriesAlignedToBlockSize_; + + static constexpr int64_t kWarpSize = 32; + + // If this is true, we store and accumulate dK/dV in RF + // rather than going back to gmem everytime + static constexpr bool kIsHalf = cutlass::sizeof_bits::value <= 16; + static constexpr bool kOutputInRF = kIsHalf && kMaxK <= kBlockSizeI; + static_assert( + !kPreload || + (kIsHalf && ArchTag::kMinComputeCapability >= 80 && kOutputInRF), + "preload MMA not supported"); + static constexpr bool kPrologueQK = kPreload; + static constexpr bool kPrologueGV = kPreload; + static constexpr bool kPrologueDOV = kPreload; + static constexpr bool kPrologueGQ = kPreload; + static constexpr bool kPrologueGK = kPreload; + + static constexpr int64_t kNumWarpsPerBlock = + (kBlockSizeI * kBlockSizeJ) / (32 * 32); + + // Compute delta for the f16 kernels + // TODO: Figure out why it's slower on the f32 kernels + // (something due to RF pressure?) + // TODO: Remove condition on `kOutputInRF` - this is needed to work + // around a compiler bug on V100, not exactly sure why but I spent + // too much time on this already. 
Reproducible with + // (B, Mq, Mkv, K) = (1, 1, 1, 136) for instance + static constexpr bool kKernelComputesDelta = + kIsHalf && (kOutputInRF || ArchTag::kMinComputeCapability != 70); + + // Launch bounds + static constexpr int64_t kNumThreads = kWarpSize * kNumWarpsPerBlock; + static constexpr int64_t kMinBlocksPerSm = + getWarpsPerSmBw() / kNumWarpsPerBlock; + + using GemmType = DefaultGemmType; + using DefaultConfig = + typename cutlass::gemm::device::DefaultGemmConfiguration< + typename GemmType::OpClass, + ArchTag, + scalar_t, + scalar_t, + scalar_t, // ElementC + accum_t // ElementAccumulator + >; + static constexpr auto kOptimalAlignement = cutlass::platform::max( + DefaultConfig::kAlignmentA, + DefaultConfig::kAlignmentB); + static constexpr auto kMinimumAlignment = GemmType::kMinimumAlignment; + + struct MatmulQK { + /* + attn_T = k_j @ q_i.transpose(-2, -1) # matmul + attn_T = (attn_T - logsumexp[i_start:i_end].unsqueeze(1).transpose(-2, + -1)).exp() # epilogue + + with attn_T.shape = (kBlockSizeJ, kBlockSizeI) + */ + using ThreadblockShape = + cutlass::gemm::GemmShape; + using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>; + using DefaultMma = typename cutlass::gemm::threadblock::DefaultMma< + scalar_t, // ElementA + cutlass::layout::RowMajor, // LayoutA + kIsAligned ? DefaultConfig::kAlignmentA : GemmType::kMinimumAlignment, + scalar_t, // ElementB + cutlass::layout::ColumnMajor, // LayoutB + kIsAligned ? DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment, + accum_t, // ElementC + cutlass::layout::RowMajor, // LayoutC + typename GemmType::OpClass, + ArchTag, + ThreadblockShape, + WarpShape, + typename GemmType::InstructionShape, + DefaultConfig::kStages, + typename GemmType::Operator, + false, // AccumulatorsInRowMajor = false, + cutlass::gemm::SharedMemoryClearOption::kNone>; + using MmaCore = typename DefaultMma::MmaCore; + using Mma = + typename MakeCustomMma::Mma; + + // used for efficient load of bias tile (Bij) from global memory to shared + // memory + using BiasLoader = TileSmemLoader< + scalar_t, + // Bij is applied to transposed attn matrix tile (Pij.T). Bij is loaded + // row-major but needs to have transposed shape so we get the same + // elements. + cutlass::MatrixShape, + MmaCore::kThreads, + // input restriction: kv_len has to be a multiple of this value + 128 / cutlass::sizeof_bits::value>; + + // Epilogue to store to shared-memory in a format that we can use later for + // the second matmul + using B2bGemm = typename cutlass::gemm::threadblock::B2bGemm< + typename Mma::Operator::IteratorC, + typename Mma::Operator, + scalar_t, + WarpShape, + ThreadblockShape>; + using AccumLambdaIterator = typename DefaultMmaAccumLambdaIterator< + typename Mma::Operator::IteratorC, + accum_t, + kWarpSize>::Iterator; + using AccumulatorSharedStorage = typename B2bGemm::AccumulatorSharedStorage; + }; + + struct MatmulGradV { + /* + grad_v[j_start:j_end] += attn_T @ do_i # matmul + + Dimensions: (kBlockSizeJ * kNumWarpsPerBlock, kBlockSizeI, K) + (we might need to iterate multiple times on K) + */ + using ThreadblockShape = + cutlass::gemm::GemmShape; + using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>; + using InstructionShape = typename GemmType::InstructionShape; + + using DefaultGemm = cutlass::gemm::kernel::DefaultGemm< + scalar_t, // ElementA, + cutlass::layout::RowMajor, // LayoutA, + DefaultConfig::kAlignmentA, + scalar_t, // ElementB, + cutlass::layout::RowMajor, // LayoutB, + kIsAligned ? 
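+        // Operands with unknown alignment fall back to the architecture's minimum
+        // supported alignment so that vectorized global-memory accesses remain valid.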
DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment, + output_t, + cutlass::layout::RowMajor, // LayoutC, + accum_t, + typename GemmType::OpClass, + ArchTag, + ThreadblockShape, + WarpShape, + typename GemmType::InstructionShape, + typename DefaultConfig::EpilogueOutputOp, + void, // ThreadblockSwizzle - not used + DefaultConfig::kStages, + false, // SplitKSerial + typename GemmType::Operator>; + + // if dropout: + // for computing dVj += (Pij.T * Zij) @ dOi + // Pij_dropped.T = Pij.T * Zij is computed on the fly as fragments of + // Pij.T are loaded in. The reason we do it this way is because Pij.T and + // Zij are reused in later steps, while Pij_dropped.T is only needed in + // this step. computing Pij_dropped.T on the fly allows us to avoid + // keeping all 3 of Pij_dropped.T, Pij.T, and Zij in shared memory at the + // same time. + // if no dropout: + // for computing dVj += Pij.T @ dOi + using WarpIteratorA = typename cutlass::gemm::threadblock:: + DefaultWarpIteratorAFromSharedMemory< + typename DefaultGemm::Mma::Operator::Shape, // WarpShape + typename DefaultGemm::Mma::Operator:: + InstructionShape, // InstructionShape + typename DefaultGemm::Mma::Operator:: + IteratorA, // RegularWarpIterator + typename DefaultGemm::Mma::Policy // Policy + >::WarpIterator; + using DefaultMmaFromSmem = + typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory< + typename DefaultGemm::Mma, + MatmulQK::AccumulatorSharedStorage::Shape::kN, + WarpIteratorA, + kApplyDropout>; // kScaleOperandA + + using Mma = typename DefaultMmaFromSmem::Mma; + using IteratorB = typename Mma::IteratorB; + using WarpCount = typename Mma::WarpCount; + + // Epilogue + using DefaultOutputOp = typename DefaultConfig::EpilogueOutputOp; + using DefaultEpilogue = typename DefaultGemm::Epilogue; + using OutputTileIterator = + typename cutlass::epilogue::threadblock::MakePrefetchableIterator< + typename DefaultEpilogue::OutputTileIterator>::Iterator; + using AccumTileGmem = GmemTile; + }; + + struct MatmulDOIVJ { + /* + doi_t_vj = do_i @ v_j.transpose(-2, -1) # matmul + tmp = (doi_t_vj - Di.unsqueeze(1)) * attn # inplace / epilogue? + */ + using ThreadblockShape = + cutlass::gemm::GemmShape; + using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>; + + using ElementC = output_t; + using ElementAccum = accum_t; + + // no-op output op - epilogue just stores result to global memory + using BiasGradEpilogueOutputOp = + typename cutlass::epilogue::thread::LinearCombination< + ElementC, + DefaultConfig::EpilogueOutputOp::kCount, + typename DefaultConfig::EpilogueOutputOp::ElementAccumulator, + typename DefaultConfig::EpilogueOutputOp::ElementCompute, + cutlass::epilogue::thread::ScaleType::Nothing>; + + using DefaultGemm = typename cutlass::gemm::kernel::DefaultGemm< + scalar_t, // ElementA + cutlass::layout::RowMajor, // LayoutA + kIsAligned ? DefaultConfig::kAlignmentA : GemmType::kMinimumAlignment, + scalar_t, // ElementB + cutlass::layout::ColumnMajor, // LayoutB + kIsAligned ? DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment, + ElementC, // ElementC + cutlass::layout::RowMajor, // LayoutC + ElementAccum, // ElementAccumulator + typename GemmType::OpClass, + ArchTag, + ThreadblockShape, + WarpShape, + typename GemmType::InstructionShape, + BiasGradEpilogueOutputOp, // EpilogueOutputOp + void, // ThreadblockSwizzle (not used) + // multiple preloads, dropout Zij tile, and 3 stages push us over shared + // memory capacity on A100. 
set a ceiling on number of stages to save + // shared memory if dropout is in use. + kPreload && kApplyDropout && (kBlockSizeI * kBlockSizeJ > 64 * 64) + ? cutlass::const_min(2, DefaultConfig::kStages) + : DefaultConfig::kStages, // Stages + false, // SplitKSerial + typename GemmType::Operator, + cutlass::gemm::SharedMemoryClearOption::kNone>; + using Mma = typename MakeCustomMma::Mma; + using AccumLambdaIterator = typename DefaultMmaAccumLambdaIterator< + typename Mma::Operator::IteratorC, + ElementAccum, + kWarpSize>::Iterator; + + // epilogue used to write bias gradient, which is just the output of this + // matmul with some operations applied to the fragment + using BiasGradEpilogue = typename DefaultGemm::Epilogue; + + // Epilogue to store to shared-memory in a format that we can use later for + // the second matmul + using B2bGemm = typename cutlass::gemm::threadblock::B2bGemm< + typename DefaultGemm::Mma::Operator::IteratorC, + typename DefaultGemm::Mma::Operator, + scalar_t, + WarpShape, + ThreadblockShape>; + using AccumulatorSharedStorage = typename B2bGemm::AccumulatorSharedStorage; + }; + + struct MatmulGradQ { + // grad_q <- tmp @ k_j + using ThreadblockShape = + cutlass::gemm::GemmShape; + using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>; + using InstructionShape = typename GemmType::InstructionShape; + + using DefaultGemm = cutlass::gemm::kernel::DefaultGemm< + scalar_t, // ElementA, + cutlass::layout::RowMajor, // LayoutA, + DefaultConfig::kAlignmentA, + scalar_t, // ElementB, + cutlass::layout::RowMajor, // LayoutB, + kIsAligned ? DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment, + output_t, + cutlass::layout::RowMajor, // LayoutC, + accum_t, + typename GemmType::OpClass, + ArchTag, + ThreadblockShape, + WarpShape, + typename GemmType::InstructionShape, + typename DefaultConfig::EpilogueOutputOp, + void, // ThreadblockSwizzle - not used + DefaultConfig::kStages, + false, // SplitKSerial + typename GemmType::Operator>; + + using WarpIteratorA = typename cutlass::gemm::threadblock:: + DefaultWarpIteratorAFromSharedMemory< + typename DefaultGemm::Mma::Operator::Shape, + typename DefaultGemm::Mma::Operator::InstructionShape, + typename DefaultGemm::Mma::Operator::IteratorA, + typename DefaultGemm::Mma::Policy>::WarpIterator; + using DefaultMmaFromSmem = + typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory< + typename DefaultGemm::Mma, + MatmulDOIVJ::AccumulatorSharedStorage::Shape::kN, + WarpIteratorA, + false>; // kScaleOperandA + using Mma = typename DefaultMmaFromSmem::Mma; + using IteratorB = typename Mma::IteratorB; + using WarpCount = typename Mma::WarpCount; + + // Epilogue + using DefaultOutputOp = typename DefaultConfig::EpilogueOutputOp; + using DefaultEpilogue = typename DefaultGemm::Epilogue; + using OutputTileIterator = + typename cutlass::epilogue::threadblock::MakePrefetchableIterator< + typename DefaultEpilogue::OutputTileIterator>::Iterator; + using AccumTileGmem = GmemTile; + }; + struct MatmulGradK { + // grad_k <- tmp.transpose(-2, -1) @ q_i + using ThreadblockShape = + cutlass::gemm::GemmShape; + using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>; + using InstructionShape = typename GemmType::InstructionShape; + + using DefaultGemm = cutlass::gemm::kernel::DefaultGemm< + scalar_t, // ElementA, + cutlass::layout::RowMajor, // LayoutA, + DefaultConfig::kAlignmentA, + scalar_t, // ElementB, + cutlass::layout::RowMajor, // LayoutB, + kIsAligned ? 
DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment, + output_t, + cutlass::layout::RowMajor, // LayoutC, + accum_t, + typename GemmType::OpClass, + ArchTag, + ThreadblockShape, + WarpShape, + typename GemmType::InstructionShape, + typename DefaultConfig::EpilogueOutputOp, + void, // ThreadblockSwizzle - not used + DefaultConfig::kStages, + false, // SplitKSerial + typename GemmType::Operator>; + + using WarpIteratorA = typename cutlass::gemm::threadblock:: + DefaultWarpIteratorAFromSharedMemory< + typename DefaultGemm::Mma::Operator::Shape, + typename DefaultGemm::Mma::Operator::InstructionShape, + typename DefaultGemm::Mma::Operator::IteratorA, + typename DefaultGemm::Mma::Policy>::WarpIterator; + using DefaultMmaFromSmemN = + typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory< + typename DefaultGemm::Mma, + MatmulQK::AccumulatorSharedStorage::Shape::kN, // kMaxK + WarpIteratorA, + false>; // kScaleOperandA + using DefaultMmaFromSmemT = + typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory< + typename DefaultGemm::Mma, + MatmulDOIVJ::AccumulatorSharedStorage::Shape::kM, // kMaxK + WarpIteratorA, + false, // kScaleOperandA + kPreload>; // kTransposeA + using DefaultMmaFromSmem = typename cutlass::platform::conditional< + DefaultMmaFromSmemT::kIsTransposedA, + DefaultMmaFromSmemT, + DefaultMmaFromSmemN>::type; + using Mma = typename DefaultMmaFromSmem::Mma; + using IteratorB = typename Mma::IteratorB; + using WarpCount = typename Mma::WarpCount; + + // Epilogue + using DefaultOutputOp = typename DefaultConfig::EpilogueOutputOp; + using DefaultEpilogue = typename DefaultGemm::Epilogue; + using OutputTileIterator = + typename cutlass::epilogue::threadblock::MakePrefetchableIterator< + typename DefaultEpilogue::OutputTileIterator>::Iterator; + using AccumTileGmem = GmemTile; + }; + + static constexpr bool kEnableSplitKeys = kEnableSplitKeys_; + + static constexpr bool kNeedsAccumGradQ = kEnableSplitKeys || + !cutlass::platform::is_same::value; + static constexpr bool kNeedsAccumGradK = !kOutputInRF && + !cutlass::platform::is_same::value; + static constexpr bool kNeedsAccumGradV = !kOutputInRF && + !cutlass::platform::is_same::value; + + struct GradQTempStorage { + int32_t lock; + int32_t counter; + int32_t pad[2]; // pad to 128bits + output_accum_t buffer[MatmulGradQ::AccumTileGmem::kElementsStored]; + }; + + struct Params { + // Input tensors + scalar_t* query_ptr = nullptr; // [Mq, nH, K] + scalar_t* key_ptr = nullptr; // [Mk, nH, K] + scalar_t* value_ptr = nullptr; // [Mk, nH, Kv] + scalar_t* bias_ptr = nullptr; + lse_scalar_t* logsumexp_ptr = nullptr; // [nH, Mq] + scalar_t* output_ptr = nullptr; // [Mq, nH, Kv] + scalar_t* grad_output_ptr = nullptr; // [Mq, nH, Kv] + accum_t* delta_ptr = nullptr; // [nH, Mq] + int32_t* cu_seqlens_q_ptr = nullptr; + int32_t* cu_seqlens_k_ptr = nullptr; + + // Output tensors + output_t* grad_query_ptr = nullptr; // [Mq, nH, K] + output_t* grad_key_ptr = nullptr; // [Mk, nH, K] + output_t* grad_value_ptr = nullptr; // [Mk, nH, Kv] + output_t* grad_bias_ptr = nullptr; + + // Accumulators + output_accum_t* workspace = nullptr; // [Mq, Kq] + [Mkv, Kq] + [Mkv, Kv] + output_accum_t* workspace_gv = + nullptr; // (will be calculated by the kernel) + GradQTempStorage* workspace_gq = + nullptr; // (will be calculated by the kernel) + + // Scale + accum_t scale = 1.0f; + + // Dimensions/strides + int32_t head_dim = -1; + int32_t head_dim_value = -1; + int32_t num_queries = -1; + int32_t num_keys = -1; + int32_t num_heads = -1; + uint8_t 
custom_mask_type = NoCustomMask; + + int32_t q_strideM = -1; + int32_t k_strideM = -1; + int32_t v_strideM = -1; + int32_t bias_strideM = 0; + int32_t gO_strideM = -1; + int32_t gB_strideM = -1; + int8_t gQKV_strideM_multiplier = 1; // 3 for packed, 1 otherwise + +#ifdef HAS_PYTORCH + // dropout + at::PhiloxCudaState rng_engine_inputs = {0, 0}; +#endif + // RNG sequence offset based on batch_id and head_id + unsigned long long dropout_batch_head_rng_offset = 0; + float dropout_prob = 0.0f; + + CUTLASS_HOST_DEVICE int32_t o_strideM() const { + return head_dim_value * num_heads; + } + CUTLASS_HOST_DEVICE int32_t gQ_strideM() const { + return gQKV_strideM_multiplier * num_heads * head_dim; + } + CUTLASS_HOST_DEVICE int32_t gK_strideM() const { + return gQKV_strideM_multiplier * num_heads * head_dim; + } + CUTLASS_HOST_DEVICE int32_t gV_strideM() const { + return gQKV_strideM_multiplier * num_heads * head_dim_value; + } + + // Everything below is only used in `advance_to_block` + // and shouldn't use registers + int64_t o_strideH = -1; + int32_t q_strideH = -1; + int32_t k_strideH = -1; + int32_t v_strideH = -1; + int64_t bias_strideH = 0; + int64_t o_strideB = -1; + int64_t q_strideB = -1; + int64_t k_strideB = -1; + int64_t v_strideB = -1; + int64_t bias_strideB = 0; + int64_t lse_strideB = -1; + int64_t lse_strideH = -1; + int64_t delta_strideB = -1; + int64_t delta_strideH = -1; + int32_t num_batches = -1; + int16_t num_splits_key = 1; // We use `gridDim.x` inside kernel + + int64_t gO_strideB = 0; + int64_t gQ_strideB = 0; + int64_t gK_strideB = 0; + int64_t gV_strideB = 0; + int64_t gB_strideB = 0; + int64_t gO_strideH = 0; + int64_t gQ_strideH = 0; + int64_t gK_strideH = 0; + int64_t gV_strideH = 0; + int64_t gB_strideH = 0; + + CUTLASS_DEVICE int16_t num_splits_key_device() const { + return kEnableSplitKeys ? gridDim.x : 1; + } + CUTLASS_DEVICE int16_t split_key_device() const { + return kEnableSplitKeys ? 
blockIdx.x : 0; + } + + CUTLASS_DEVICE bool advance_to_block() { + int64_t batch_id = blockIdx.z; + int32_t head_id = blockIdx.y; + + if (kNeedsAccumGradQ || kNeedsAccumGradK || kNeedsAccumGradV) { + assert(workspace_size() == 0 || workspace != nullptr); + + workspace += (batch_id * num_heads + head_id) * workspace_strideBH(); + workspace = warp_uniform(workspace); + workspace_gv = workspace + workspace_elements_gk(); + workspace_gq = + (GradQTempStorage*)(workspace_gv + workspace_elements_gv()); + if (kEnableSplitKeys) { + workspace_gv += workspace_elements_gv() * split_key_device() / + num_splits_key_device(); + workspace += workspace_elements_gk() * split_key_device() / + num_splits_key_device(); + } + } else { + workspace = nullptr; + } + + // Advance pointers that depend on the total concatenated + // number of queries, as `num_queries` is modified in the block + // below + dropout_batch_head_rng_offset = + batch_id * (num_heads * num_queries * num_keys) + + head_id * (num_queries * num_keys); + logsumexp_ptr += batch_id * lse_strideB + head_id * lse_strideH; + + if (cu_seqlens_q_ptr != nullptr) { + assert(cu_seqlens_k_ptr != nullptr); + cu_seqlens_q_ptr += batch_id; + cu_seqlens_k_ptr += batch_id; + int32_t q_start = cu_seqlens_q_ptr[0]; + int32_t k_start = cu_seqlens_k_ptr[0]; + int64_t q_next_start = cu_seqlens_q_ptr[1]; + int64_t k_next_start = cu_seqlens_k_ptr[1]; + assert(q_next_start - q_start <= num_queries); + assert(k_next_start - k_start <= num_keys); + num_queries = q_next_start - q_start; + num_keys = k_next_start - k_start; + + // Jump manually + batch_id = 0; + + query_ptr += q_start * q_strideM; + key_ptr += k_start * k_strideM; + value_ptr += k_start * v_strideM; + assert(bias_ptr == nullptr); + assert(grad_bias_ptr == nullptr); + output_ptr += q_start * o_strideM(); + grad_output_ptr += q_start * gO_strideM; + delta_ptr += q_start; + + grad_query_ptr += q_start * gQ_strideM(); + grad_key_ptr += k_start * gK_strideM(); + grad_value_ptr += k_start * gV_strideM(); + } + + query_ptr += batch_id * q_strideB + head_id * q_strideH; + key_ptr += batch_id * k_strideB + head_id * k_strideH; + value_ptr += batch_id * v_strideB + head_id * v_strideH; + if (bias_ptr != nullptr) { + bias_ptr += batch_id * bias_strideB + head_id * bias_strideH; + } + output_ptr += batch_id * o_strideB + head_id * o_strideH; + grad_output_ptr += batch_id * gO_strideB + head_id * gO_strideH; + delta_ptr += batch_id * delta_strideB + head_id * delta_strideH; + + grad_query_ptr += batch_id * gQ_strideB + head_id * gQ_strideH; + grad_key_ptr += batch_id * gK_strideB + head_id * gK_strideH; + grad_value_ptr += batch_id * gV_strideB + head_id * gV_strideH; + if (grad_bias_ptr != nullptr) { + grad_bias_ptr += batch_id * gB_strideB + head_id * gB_strideH; + } + + // Some values are modified above + // Signal to the compiler that they are the same in all threads + // and can be stored in warp-uniform registers (Sm75+) + num_queries = warp_uniform(num_queries); + num_keys = warp_uniform(num_keys); + custom_mask_type = warp_uniform(custom_mask_type); + + query_ptr = warp_uniform(query_ptr); + key_ptr = warp_uniform(key_ptr); + value_ptr = warp_uniform(value_ptr); + bias_ptr = warp_uniform(bias_ptr); + logsumexp_ptr = warp_uniform(logsumexp_ptr); + output_ptr = warp_uniform(output_ptr); + grad_output_ptr = warp_uniform(grad_output_ptr); + delta_ptr = warp_uniform(delta_ptr); + + grad_query_ptr = warp_uniform(grad_query_ptr); + grad_key_ptr = warp_uniform(grad_key_ptr); + grad_value_ptr = 
warp_uniform(grad_value_ptr); + grad_bias_ptr = warp_uniform(grad_bias_ptr); + +#if 0 + PRINT_T0("[b:%d h:%d] dp[0]:%f Q:%f K:%f V:%f LSE:%f", + int(blockIdx.z), int(blockIdx.y), + float(delta_ptr[0]), + float(query_ptr[0]), float(key_ptr[0]), float(value_ptr[0]), + float(logsumexp_ptr[0]) + ) +#endif + return true; + } + + __host__ dim3 getBlocksGrid() const { + return dim3(num_splits_key, num_heads, num_batches); + } + __host__ dim3 getThreadsGrid() const { + return dim3(kWarpSize * kNumWarpsPerBlock, 1, 1); + } + CUTLASS_HOST_DEVICE int64_t workspace_elements_gk() const { + if (!kNeedsAccumGradK) { + return 0; + } + return num_splits_key * align_up(num_keys, (int32_t)kBlockSizeJ) * + align_up(head_dim, (int32_t)kBlockSizeI); + } + CUTLASS_HOST_DEVICE int64_t workspace_elements_gv() const { + if (!kNeedsAccumGradV) { + return 0; + } + return num_splits_key * align_up(num_keys, (int32_t)kBlockSizeJ) * + align_up(head_dim_value, (int32_t)kBlockSizeI); + } + CUTLASS_HOST_DEVICE int64_t workspace_elements_gq() const { + if (!kNeedsAccumGradQ) { + return 0; + } + int num_blocks = ceil_div(num_queries, kBlockSizeI); + int num_cols = ceil_div(head_dim, MatmulGradQ::ThreadblockShape::kN); + return num_blocks * num_cols * sizeof(GradQTempStorage) / + sizeof(output_accum_t); + } + CUTLASS_HOST_DEVICE int64_t workspace_strideBH() const { + // Aligned on 128bits + return align_up( + workspace_elements_gk() + workspace_elements_gv() + + workspace_elements_gq(), + int64_t(4)); + } + CUTLASS_HOST_DEVICE int64_t workspace_size() const { + // Returns size of buffer we need to run this kernel + return num_batches * num_heads * workspace_strideBH() * sizeof(float); + } + CUTLASS_HOST_DEVICE bool should_zero_workspace() const { + return num_splits_key > 1; + } + }; + + // shared storage for keeping Zij matrix. not needed if we aren't using + // dropout, in which case we use an empty array to save shared memory + using ZijSharedStorage = typename cutlass::platform::conditional< + kApplyDropout, + typename MatmulQK::AccumulatorSharedStorage, + // dummy shared storage object that takes up no space. + typename cutlass::gemm::threadblock::AccumulatorSharedStorage< +#ifdef _WIN32 + // windows builds throw the error: + // "type containing an unknown-size array is not allowed" + // if we try to make Zij shared storage zero-sized. + // To get around this just make it sized 1 on windows. + typename cutlass::gemm::GemmShape<1, 1, 0>, +#else + typename cutlass::gemm::GemmShape<0, 0, 0>, +#endif + typename MatmulQK::AccumulatorSharedStorage::Element, + typename MatmulQK::AccumulatorSharedStorage::Layout, + typename cutlass::MatrixShape<0, 0>>>::type; + + struct SharedStoragePrologue { + struct { + cutlass::Array di; // (do_i * o_i).sum(-1) + typename MatmulQK::Mma::SharedStorageA mm_qk_k; + } persistent; + union { + struct { + // part1 - after Q.K / dV / dO.V + union { + // 1. efficient load of bias tile Bij, which is then applied to Pij + typename MatmulQK::BiasLoader::SmemTile bias; + // 4. store Pij. it is needed: + // - in dVj += (Pij.T * Zij) @ dOi + // - in dSij = Pij * (dPij - Di) + // 6. dVj += (Pij.T * Zij) @ dOi + // 10. write to fragment + typename MatmulQK::AccumulatorSharedStorage attn_shared_storage; + }; + // 5. store Zij. it is needed in dVj += (Pij.T * Zij) @ dOi + ZijSharedStorage zij; + + union { + // 2. prologue for dVj + // 6. workspace for dVj += (Pij.T * Zij) @ dOi + typename MatmulGradV::Mma::SharedStorage mm_gradV; + // 7. 
dVj epilogue + typename MatmulGradV::DefaultEpilogue::SharedStorage gradV_epilogue; + }; + + // 3. prologue for dPij_dropped + // 8. used in dPij_dropped = dOi @ Vj.T + typename MatmulDOIVJ::Mma::SharedStorage mm_doivj; + } part1; + + struct { + // part2 - dQ + union { + typename MatmulQK::AccumulatorSharedStorage + tmpT_shared_storage; // (from part1) + typename MatmulDOIVJ::AccumulatorSharedStorage tmp_shared_storage; + }; + typename MatmulGradK::Mma::SharedStorage mm_gradK; // (preload) + typename MatmulGradQ::Mma::SharedStorage mm_gradQ; // (preload) + union { + // store dB = dSij to global memory + typename MatmulDOIVJ::BiasGradEpilogue::SharedStorage gradB_epilogue; + typename MatmulGradQ::DefaultEpilogue::SharedStorage gradQ_epilogue; + }; + + } part2; + + struct { + // part3 - after last iteration on dQ's epilogue / dK + union { + typename MatmulQK::AccumulatorSharedStorage + tmpT_shared_storage; // (from part1) + typename MatmulDOIVJ::AccumulatorSharedStorage tmp_shared_storage; + }; + typename MatmulGradK::Mma::SharedStorage mm_gradK; // (preload) + typename MatmulGradQ::DefaultEpilogue::SharedStorage + gradQ_epilogue_lastIter; + + typename MatmulGradK::DefaultEpilogue::SharedStorage gradK_epilogue; + } part3; + + struct { + // part4 - after last iteration on dK's epilogue / preload next K.Q_t + typename MatmulQK::Mma::SharedStorageB mm_qk_q; + + // If we reach end of current key, dump RF->gmem with "final" epilogues + typename MatmulGradK::DefaultEpilogue::SharedStorage + gradK_epilogue_final; + typename MatmulGradV::DefaultEpilogue::SharedStorage + gradV_epilogue_final; + } part4; + }; + static void print_size() { + // Field size +#define FSZ(f) int((sizeof(((SharedStoragePrologue*)0)->f))) + + printf("Total smem: %d bytes\n", int(sizeof(SharedStoragePrologue))); + printf(" persistent: %db\n", FSZ(persistent)); + printf(" mm_qk_k: %db\n", FSZ(persistent.mm_qk_k)); + printf(" part1: %db\n", FSZ(part1)); + printf(" bias: %db\n", FSZ(part1.bias)); + printf(" attn_shared_storage: %db\n", FSZ(part1.attn_shared_storage)); + printf(" zij: %db\n", FSZ(part1.zij)); + printf(" mm_gradV: %db\n", FSZ(part1.mm_gradV)); + printf(" gradV_epilogue: %db\n", FSZ(part1.gradV_epilogue)); + printf(" mm_doivj: %db\n", FSZ(part1.mm_doivj)); + printf(" part2: %db\n", FSZ(part2)); + printf(" tmpT_shared_storage: %db\n", FSZ(part2.tmpT_shared_storage)); + printf(" tmp_shared_storage: %db\n", FSZ(part2.tmp_shared_storage)); + printf(" mm_gradK: %db\n", FSZ(part2.mm_gradK)); + printf(" mm_gradQ: %db\n", FSZ(part2.mm_gradQ)); + printf(" gradB_epilogue: %db\n", FSZ(part2.gradB_epilogue)); + printf(" gradQ_epilogue: %db\n", FSZ(part2.gradQ_epilogue)); + printf(" part3: %db\n", FSZ(part3)); + printf(" tmpT_shared_storage: %db\n", FSZ(part3.tmpT_shared_storage)); + printf(" part4: %db\n", FSZ(part4)); + printf(" mm_qk_q: %db\n", FSZ(part4.mm_qk_q)); + printf( + " gradK_epilogue_final: %db\n", FSZ(part4.gradK_epilogue_final)); + printf( + " gradV_epilogue_final: %db\n", FSZ(part4.gradV_epilogue_final)); + } +// =========================================== +#define FIELD(INSIDE_STRUCT, FIELDNAME) \ + CUTLASS_DEVICE auto& FIELDNAME() { \ + return INSIDE_STRUCT.FIELDNAME; \ + } + + FIELD(persistent, di) + FIELD(persistent, mm_qk_k) + FIELD(part1, bias) + FIELD(part1, attn_shared_storage) + FIELD(part1, zij) + FIELD(part1, mm_gradV) + FIELD(part1, gradV_epilogue) + FIELD(part1, mm_doivj) + FIELD(part2, mm_gradK) + FIELD(part2, mm_gradQ) + FIELD(part2, gradB_epilogue) + FIELD(part2, gradQ_epilogue) + FIELD(part2, 
tmp_shared_storage) + FIELD(part3, tmpT_shared_storage) + FIELD(part3, gradQ_epilogue_lastIter) + FIELD(part3, gradK_epilogue) + FIELD(part4, mm_qk_q) + FIELD(part4, gradK_epilogue_final) + FIELD(part4, gradV_epilogue_final) + }; + + struct SharedStorageNoPrologue { + struct { + cutlass::Array di; // (do_i * o_i).sum(-1) + } persistent; + union { + struct { + // part1 - Q.K matmul + typename MatmulQK::Mma::SharedStorageA mm_qk_k; + typename MatmulQK::Mma::SharedStorageB mm_qk_q; + } part1; + + struct { + // part2 - compute gradV + union { + // 1. efficient load of bias tile Bij, which is then applied to Pij + typename MatmulQK::BiasLoader::SmemTile bias; + // 2. store Pij to shared memory. it is needed: + // - in this step, where it is used in dVj += (Pij.T * Zij) @ dOi + // - in next step where it is used in dSij = Pij * (dPij - Di) + typename MatmulQK::AccumulatorSharedStorage attn_shared_storage; + }; + // 3. store Zij. it is needed in this step, where it is used + // to compute Pij_dropped = Pij * Zij on the fly as fragments of Pij are + // loaded for the computation of dVj. + ZijSharedStorage zij; + + union { + typename MatmulGradV::Mma::SharedStorage mm_gradV; + typename MatmulGradV::DefaultEpilogue::SharedStorage gradV_epilogue; + }; + } part2; + + struct { + // part3 - DO.V matmul + union { + // first compute dPij = (dOi @ Vj.T) * Zij + // and dSij = Pij * (dPij - Di) + struct { + // (from part2) - Pij for computing dSij = Pij * (dPij - Di) + typename MatmulQK::AccumulatorSharedStorage attn_shared_storage; + // matmul to compute dOiVj + typename MatmulDOIVJ::Mma::SharedStorage mm_doivj; + }; + // then store dB = dSij to global memory + typename MatmulDOIVJ::BiasGradEpilogue::SharedStorage gradB_epilogue; + }; + } part3; + + struct { + // part4 - compute gradQ + typename MatmulQK::AccumulatorSharedStorage + tmpT_shared_storage; // (from part2) + typename MatmulDOIVJ::AccumulatorSharedStorage tmp_shared_storage; + union { + typename MatmulGradQ::Mma::SharedStorage mm_gradQ; + typename MatmulGradQ::DefaultEpilogue::SharedStorage gradQ_epilogue; + typename MatmulGradQ::DefaultEpilogue::SharedStorage + gradQ_epilogue_lastIter; + }; + } part4; + + struct { + // part5 - compute gradK + typename MatmulQK::AccumulatorSharedStorage + tmpT_shared_storage; // (from part2) + typename MatmulDOIVJ::AccumulatorSharedStorage tmp_shared_storage; + union { + typename MatmulGradK::Mma::SharedStorage mm_gradK; + typename MatmulGradK::DefaultEpilogue::SharedStorage gradK_epilogue; + }; + } part5; + + struct { + // part6 - store RF accumulated into gmem + typename MatmulGradK::DefaultEpilogue::SharedStorage + gradK_epilogue_final; + typename MatmulGradV::DefaultEpilogue::SharedStorage + gradV_epilogue_final; + } part6; + }; + static void print_size() { +#define FIELD_SIZEOF(f) int((sizeof(((SharedStorageNoPrologue*)0)->f))) + printf("Total smem: %d bytes\n", int(sizeof(SharedStorageNoPrologue))); + printf(" persistent: %db\n", FIELD_SIZEOF(persistent)); + printf(" part1: %db\n", FIELD_SIZEOF(part1)); + printf(" part2: %db\n", FIELD_SIZEOF(part2)); + printf(" part3: %db\n", FIELD_SIZEOF(part3)); + printf(" part4: %db\n", FIELD_SIZEOF(part4)); + printf(" part5: %db\n", FIELD_SIZEOF(part5)); + printf(" part6: %db\n", FIELD_SIZEOF(part6)); + } +// =========================================== +#define FIELD(INSIDE_STRUCT, FIELDNAME) \ + CUTLASS_DEVICE auto& FIELDNAME() { \ + return INSIDE_STRUCT.FIELDNAME; \ + } + + FIELD(persistent, di) + FIELD(part1, mm_qk_k) + FIELD(part1, mm_qk_q) + FIELD(part2, bias) + 
FIELD(part2, attn_shared_storage) + FIELD(part2, zij) + FIELD(part2, mm_gradV) + FIELD(part2, gradV_epilogue) + FIELD(part3, mm_doivj) + FIELD(part3, gradB_epilogue) + FIELD(part4, tmpT_shared_storage) + FIELD(part4, tmp_shared_storage) + FIELD(part4, mm_gradQ) + FIELD(part4, gradQ_epilogue) + FIELD(part4, gradQ_epilogue_lastIter) + FIELD(part5, mm_gradK) + FIELD(part5, gradK_epilogue) + FIELD(part6, gradK_epilogue_final) + FIELD(part6, gradV_epilogue_final) + }; + + using SharedStorage = typename cutlass::platform::conditional< + kPreload, + SharedStoragePrologue, + SharedStorageNoPrologue>::type; + + struct OutputFragments { + typename MatmulGradV::Mma::FragmentC gradV; + typename MatmulGradK::Mma::FragmentC gradK; + + CUTLASS_DEVICE void clear() { + gradV.clear(); + gradK.clear(); + } + }; + + static bool __host__ check_supported(Params const& p) { + CHECK_ALIGNED_PTR(p.query_ptr, kMinimumAlignment); + CHECK_ALIGNED_PTR(p.key_ptr, kMinimumAlignment); + CHECK_ALIGNED_PTR(p.value_ptr, kMinimumAlignment); + CHECK_ALIGNED_PTR(p.output_ptr, kMinimumAlignment); + CHECK_ALIGNED_PTR(p.grad_output_ptr, kMinimumAlignment); + CHECK_ALIGNED_PTR(p.bias_ptr, kMinimumAlignment); + XFORMERS_CHECK(p.lse_strideH % 8 == 0, "LSE is not correctly aligned"); + XFORMERS_CHECK(p.lse_strideB % 8 == 0, "LSE is not correctly aligned"); + XFORMERS_CHECK( + p.num_heads <= 1 || p.q_strideH % kMinimumAlignment == 0, + "query is not correctly aligned (strideH)"); + XFORMERS_CHECK( + p.num_heads <= 1 || p.k_strideH % kMinimumAlignment == 0, + "key is not correctly aligned (strideH)"); + XFORMERS_CHECK( + p.num_heads <= 1 || p.v_strideH % kMinimumAlignment == 0, + "value is not correctly aligned (strideH)"); + XFORMERS_CHECK( + p.num_batches <= 1 || p.q_strideB % kMinimumAlignment == 0, + "query is not correctly aligned (strideB)"); + XFORMERS_CHECK( + p.num_batches <= 1 || p.k_strideB % kMinimumAlignment == 0, + "key is not correctly aligned (strideB)"); + XFORMERS_CHECK( + p.num_batches <= 1 || p.v_strideB % kMinimumAlignment == 0, + "value is not correctly aligned (strideB)"); + XFORMERS_CHECK( + p.q_strideM % kMinimumAlignment == 0, + "query is not correctly aligned (strideM)"); + XFORMERS_CHECK( + p.k_strideM % kMinimumAlignment == 0, + "key is not correctly aligned (strideM)"); + XFORMERS_CHECK( + p.v_strideM % kMinimumAlignment == 0, + "value is not correctly aligned (strideM)"); + if (p.bias_ptr) { + XFORMERS_CHECK( + p.num_batches <= 1 || p.bias_strideB % kMinimumAlignment == 0, + "attn_bias is not correctly aligned (strideB)"); + XFORMERS_CHECK( + p.num_heads <= 1 || p.bias_strideH % kMinimumAlignment == 0, + "attn_bias is not correctly aligned (strideH)"); + XFORMERS_CHECK( + p.bias_strideM % kMinimumAlignment == 0, + "attn_bias is not correctly aligned (strideM)"); + } + if (p.grad_bias_ptr) { + XFORMERS_CHECK( + p.num_batches <= 1 || p.gB_strideB % kMinimumAlignment == 0, + "attn_bias.grad is not correctly aligned (strideB)"); + XFORMERS_CHECK( + p.num_heads <= 1 || p.gB_strideH % kMinimumAlignment == 0, + "attn_bias.grad is not correctly aligned (strideH)"); + XFORMERS_CHECK( + p.gB_strideM % kMinimumAlignment == 0, + "attn_bias.grad is not correctly aligned (strideM)"); + } + XFORMERS_CHECK( + !(p.cu_seqlens_q_ptr && p.bias_ptr), + "CuSeqlen + bias not implemented yet"); + XFORMERS_CHECK( + p.custom_mask_type < NumCustomMaskTypes, + "Invalid value for `custom_mask_type`"); + XFORMERS_CHECK( + p.dropout_prob <= 1.0f && p.dropout_prob >= 0.0f, + "Invalid value for `dropout_prob`"); + XFORMERS_CHECK( + 
kApplyDropout || p.dropout_prob == 0.0f, + "Set `kApplyDropout`=True to support `dropout_prob > 0`"); + XFORMERS_CHECK(p.head_dim > 0, "Invalid value for `head_dim`"); + XFORMERS_CHECK(p.head_dim_value > 0, "Invalid value for `head_dim_value`"); + XFORMERS_CHECK(p.num_queries > 0, "Invalid value for `num_queries`"); + XFORMERS_CHECK(p.num_keys > 0, "Invalid value for `num_keys`"); + XFORMERS_CHECK(p.num_heads > 0, "Invalid value for `num_heads`"); + XFORMERS_CHECK(p.num_batches > 0, "Invalid value for `num_batches`"); + XFORMERS_CHECK(p.head_dim <= kMaxK, "kMaxK: Expected `head_dim < kMaxK`"); + XFORMERS_CHECK( + p.head_dim_value <= kMaxK, "kMaxK: Expected `head_dim_value < kMaxK`"); + if (kKeysQueriesAlignedToBlockSize) { + XFORMERS_CHECK( + p.cu_seqlens_k_ptr == nullptr, + "This kernel does not support cu_seqlen"); + XFORMERS_CHECK( + p.cu_seqlens_q_ptr == nullptr, + "This kernel does not support cu_seqlen"); + XFORMERS_CHECK( + p.num_queries % kBlockSizeI == 0, + "kKeysQueriesAlignedToBlockSize condition not respected"); + XFORMERS_CHECK( + p.num_keys % kBlockSizeJ == 0, + "kKeysQueriesAlignedToBlockSize condition not respected"); + } + XFORMERS_CHECK( + kEnableSplitKeys || p.num_splits_key == 1, "SplitKeys is disabled"); + XFORMERS_CHECK( + p.num_splits_key > 0, "Invalid `num_splits_key` (expected >0)"); + XFORMERS_CHECK( + p.num_splits_key <= cutlass::ceil_div(p.num_keys, kBlockSizeJ), + "Invalid `num_splits_key` (too large)"); + return true; + } + + static CUTLASS_DEVICE void attention_kernel(Params p) { + extern __shared__ char smem_buffer[]; + SharedStorage& shared_storage = *((SharedStorage*)smem_buffer); + + uint16_t thread_id = threadIdx.x; + uint8_t warp_id = warp_uniform(thread_id / 32); + uint8_t lane_id = thread_id % 32; + + int32_t key_start = p.split_key_device() * kBlockSizeJ; + if (key_start >= p.num_keys) { + return; + } + if (kPrologueQK) { + int32_t query_start = getQueryStart(p, key_start); + prologueQkNextIteration( + shared_storage, p, query_start, key_start, warp_id, lane_id); + } + + // Computes (dO*out).sum(-1) and writes it to `p.delta_ptr` + if (kKernelComputesDelta) { + constexpr int kOptimalElements = + 128 / cutlass::sizeof_bits::value; + if (p.head_dim_value % kOptimalElements == 0) { + for (int query_start = 0; query_start < p.num_queries; + query_start += kBlockSizeI) { + computeDelta(p, query_start, warp_id, lane_id); + } + } else { + for (int query_start = 0; query_start < p.num_queries; + query_start += kBlockSizeI) { + computeDelta<1>(p, query_start, warp_id, lane_id); + } + } + __syncthreads(); + } + + OutputFragments output_frags; + + curandStatePhilox4_32_10_t rng_state_init; +#ifdef HAS_PYTORCH + if (kApplyDropout) { + auto seeds = at::cuda::philox::unpack(p.rng_engine_inputs); + // each element of the attention matrix P with shape + // (batch_sz, n_heads, n_queries, n_keys) is associated with a single + // offset in RNG sequence. we initialize the RNG state with offset that + // starts at the beginning of a (n_queries, n_keys) matrix for this + // block's batch_id and head_id + // initializing rng state is very expensive, so we run once per kernel, + // rather than once per iteration. each iteration takes a copy of the + // initialized RNG state and offsets it as needed. 
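+      // Illustrative note (added commentary, not part of the upstream kernel):
+      // with this per-(batch_id, head_id) base offset, element (q, k) of the
+      // attention matrix draws its dropout sample at Philox offset
+      //   dropout_batch_head_rng_offset + q * num_keys + k
+      // which is exactly the amount the `skipahead` call in `processBlockIJ`
+      // advances by (relative to the state initialized below) before
+      // generating its tile of Zij.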
+ curand_init( + std::get<0>(seeds), + 0, + std::get<1>(seeds) + p.dropout_batch_head_rng_offset, + &rng_state_init); + } +#endif + CUTLASS_PRAGMA_UNROLL + for (; key_start < p.num_keys; + key_start += p.num_splits_key_device() * kBlockSizeJ) { + output_frags.clear(); + + CUTLASS_PRAGMA_UNROLL + for (int32_t query_start_shifted = getQueryStart(p, key_start); + query_start_shifted < getQueryStartShift(p) + getQueryEnd(p); + query_start_shifted += kBlockSizeI) { + // This line here + // vvvvvvvvvvvvvv + warp_id = warp_uniform(warp_id); + // ^^^^^^^^^^^^^^ + // ... makes everything use less RF and be 10% faster. Why? + // I don't know. My theory is that it forces `nvcc` to + // re-compute indices, offsets etc... and not keep them + // from the previous iteration, which prevents MASSIVE + // register spilling. + + int32_t query_start = query_start_shifted; + if (query_start >= p.num_queries) { + query_start = query_start % getQueryEnd(p); + } + + processBlockIJ( + shared_storage, + output_frags, + p, + query_start, + key_start, + rng_state_init, + warp_id, + lane_id); + } + if (kOutputInRF) { + writeFragsToGmem( + shared_storage, output_frags, p, key_start, warp_id, lane_id); + } else if (getQueryStart(p, key_start) >= p.num_queries) { + zfillGradKV( + p, key_start, warp_id, lane_id); + } + __syncthreads(); + } + } + + template + static CUTLASS_DEVICE void zfillGradKV( + Params const& p, + int32_t key_start, + uint8_t warp_id, + uint8_t lane_id) { + constexpr int kThreadsPerKey = 8; + constexpr int kParallelKeys = kNumThreads / kThreadsPerKey; + static_assert(kBlockSizeJ % kParallelKeys == 0, ""); + // This function is not really optimized, but should rarely be used + // It's only used when some keys are "useless" and don't attend to + // any query, due to causal masking + + int thread_id = 32 * warp_id + lane_id; + int k_shift = lane_id % kThreadsPerKey; + + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < kBlockSizeJ; j += kParallelKeys) { + int key = key_start + j + (thread_id / kThreadsPerKey); + if (!skipBoundsChecks && key >= p.num_keys) { + continue; + } + auto gv_ptr = p.grad_value_ptr + key * p.gV_strideM(); + auto gk_ptr = p.grad_key_ptr + key * p.gK_strideM(); + + for (int k = k_shift; k < p.head_dim_value; k += kThreadsPerKey) { + gv_ptr[k] = scalar_t(0); + } + for (int k = k_shift; k < p.head_dim; k += kThreadsPerKey) { + gk_ptr[k] = scalar_t(0); + } + } + } + + template + static CUTLASS_DEVICE void processBlockIJ( + SharedStorage& shared_storage, + OutputFragments& output_frags, + Params& p, + int32_t query_start, + int32_t key_start, + const curandStatePhilox4_32_10_t& curand_state_init, + uint8_t warp_id, + uint8_t lane_id) { + cutlass::Array + dropout_keep_mask_doivj; + dropout_keep_mask_doivj.fill(cutlass::uint1b_t{1}); + const float dropout_scale = + kApplyDropout ? 1.0 / (1.0 - p.dropout_prob) : 1.0f; + + cutlass::MatrixCoord no_offset{0, 0}; + accum_t scale = p.scale; + int16_t thread_id = 32 * warp_id + lane_id; + + auto rematerializeThreadIds = [&]() { + // Prevents `nvcc` from keeping values deduced from + // `thread_id`, `warp_id`, ... 
in RF - to reduce register pressure + warp_id = warp_uniform(thread_id / 32); + lane_id = thread_id % 32; + thread_id = 32 * warp_id + lane_id; + }; + + bool isFirstQuery = (query_start == getQueryStart(p, key_start)); + int32_t next_query, next_key; + incrIteration(p, query_start, key_start, next_query, next_key); + bool isLastQuery = next_key != key_start; + + accum_t di_rf = accum_t(0); + if (thread_id < kBlockSizeI) { + if (query_start + thread_id < p.num_queries) { + di_rf = p.delta_ptr[query_start + thread_id]; + } + shared_storage.di()[thread_id] = di_rf; + } + + int32_t num_queries_in_block = skipBoundsChecks + ? MatmulQK::Mma::Shape::kN + : warp_uniform(cutlass::fast_min( + (int32_t)MatmulQK::Mma::Shape::kN, p.num_queries - query_start)); + int32_t num_keys_in_block = skipBoundsChecks + ? MatmulQK::Mma::Shape::kM + : warp_uniform(cutlass::fast_min( + (int32_t)MatmulQK::Mma::Shape::kM, p.num_keys - key_start)); + + auto prologueGradV = [&](int col) { + typename MatmulGradV::Mma::IteratorB iterator_dO( + {int32_t(p.gO_strideM)}, + p.grad_output_ptr + query_start * p.gO_strideM + col, + {num_queries_in_block, p.head_dim_value - col}, + thread_id, + no_offset); + MatmulGradV::Mma::prologue( + shared_storage.mm_gradV(), + iterator_dO, + thread_id, + num_queries_in_block); + }; + auto prologueGradQ = [&](int col) { + typename MatmulGradQ::Mma::IteratorB iterator_K( + {int32_t(p.k_strideM)}, + p.key_ptr + key_start * p.k_strideM + col, + {num_keys_in_block, p.head_dim - col}, + thread_id, + no_offset); + MatmulGradQ::Mma::prologue( + shared_storage.mm_gradQ(), iterator_K, thread_id, num_keys_in_block); + }; + auto prologueGradK = [&](int col) { + typename MatmulGradK::Mma::IteratorB iterator_Q( + {int32_t(p.q_strideM)}, + p.query_ptr + query_start * p.q_strideM + col, + {num_queries_in_block, p.head_dim - col}, + thread_id, + no_offset); + MatmulGradK::Mma::prologue( + shared_storage.mm_gradK(), + iterator_Q, + thread_id, + num_queries_in_block); + }; + auto prologueDOV = [&]() { + typename MatmulDOIVJ::Mma::IteratorA iterator_A( + {int32_t(p.gO_strideM)}, + p.grad_output_ptr + query_start * p.gO_strideM, + {num_queries_in_block, p.head_dim_value}, + thread_id, + no_offset); + typename MatmulDOIVJ::Mma::IteratorB iterator_B( + {int32_t(p.v_strideM)}, + p.value_ptr + key_start * p.v_strideM, + {p.head_dim_value, num_keys_in_block}, + thread_id, + no_offset); + MatmulDOIVJ::Mma::prologue( + shared_storage.mm_doivj(), + iterator_A, + iterator_B, + thread_id, + p.head_dim_value); + }; + + ///////////////////////////////////////////////////////////////////////////////////////////////// + // MatmulQK + ///////////////////////////////////////////////////////////////////////////////////////////////// + { + using Mma = typename MatmulQK::Mma; + + cutlass::gemm::GemmCoord problem_size( + num_keys_in_block, + num_queries_in_block, + p.head_dim // k + ); + + // k_j + typename Mma::IteratorA iterator_A( + {int32_t(p.k_strideM)}, + p.key_ptr + key_start * p.k_strideM, + {problem_size.m(), problem_size.k()}, + thread_id, + no_offset); + + // q_i.transpose(-2, -1) + typename Mma::IteratorB iterator_B( + {int32_t(p.q_strideM)}, + p.query_ptr + query_start * p.q_strideM, + {problem_size.k(), problem_size.n()}, + thread_id, + no_offset); + + Mma mma( + shared_storage.mm_qk_k(), + shared_storage.mm_qk_q(), + thread_id, + warp_id, + lane_id); + + typename Mma::FragmentC accum; + + accum.clear(); + + auto gemm_k_iterations = + (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute 
threadblock-scoped matrix multiply-add + mma.set_prologue_done(kPrologueQK); + mma.set_zero_outside_bounds(!skipBoundsChecks); + mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum); + accum = cutlass::multiplies()(scale, accum); + + // Epilogue: add LSE + exp and store that to our shared memory buffer + // shmem <- (matmul_result - + // logsumexp[i_start:i_end].unsqueeze(1)).exp() + int warp_idx_mn_0 = + warp_id % (Mma::Base::WarpCount::kM * Mma::Base::WarpCount::kN); + auto output_tile_coords = cutlass::MatrixCoord{ + warp_idx_mn_0 % Mma::Base::WarpCount::kM, + warp_idx_mn_0 / Mma::Base::WarpCount::kM}; + + // apply bias if applicable + if (p.bias_ptr != nullptr) { + // load bias tile Bij into shared memory + typename MatmulQK::BiasLoader::GmemTileIterator bias_iter( + {cutlass::layout::RowMajor(p.bias_strideM)}, + p.bias_ptr + query_start * p.bias_strideM + key_start, + {num_queries_in_block, num_keys_in_block}, + thread_id); + cutlass::TensorRef bias_tensor_ref( + shared_storage.bias().data(), + cutlass::layout::RowMajor(MatmulQK::ThreadblockShape::kM)); + typename MatmulQK::BiasLoader::SmemTileIterator smem_tile_iter( + bias_tensor_ref, thread_id); + MatmulQK::BiasLoader::load(bias_iter, smem_tile_iter); + + // Pij += Bij, where Pij is in register fragment and Bij is in shmem + auto lane_offset = MatmulQK::AccumLambdaIterator::get_lane_offset( + lane_id, warp_id, output_tile_coords); + MatmulQK::AccumLambdaIterator::iterateRows( + lane_offset, + [&](int accum_n) {}, + [&](int accum_m, int accum_n, int idx) { + // remember we are transposed + accum[idx] += bias_tensor_ref.at({accum_n, accum_m}); + }, + [&](int accum_n) {}); + } + + // Apply mask + if (p.custom_mask_type == CausalFromTopLeft || + p.custom_mask_type == CausalFromBottomRight) { + auto lane_offset = MatmulQK::AccumLambdaIterator::get_lane_offset( + lane_id, warp_id, output_tile_coords); + int shift = query_start - key_start; + if (p.custom_mask_type == CausalFromBottomRight) { + shift += p.num_keys - p.num_queries; + } + // current_key = key_start + accum_m + // current_query = query_start + accum_n + // mask if: `current_key > current_query` + MatmulQK::AccumLambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) {}, + [&](int accum_m, int accum_n, int idx) { + if (accum_m > accum_n + shift) { + accum[idx] = + -cutlass::platform::numeric_limits::infinity(); + } + }, + [&](int accum_m) {}); + } + + __syncthreads(); + if (kPrologueGV) { + prologueGradV(0); + } + if (kPrologueDOV) { + prologueDOV(); + } + + MatmulQK::B2bGemm::accumApplyLSEToSmem( + shared_storage.attn_shared_storage(), + accum, + p.logsumexp_ptr + query_start, + problem_size.n(), + thread_id, + warp_id, + lane_id, + output_tile_coords); +#if 0 + auto accum_ref_attnT = shared_storage.attn_shared_storage().accum_ref(); + PRINT_TENSOR4x4_T0_L0("attn_T", accum_ref_attnT); +#endif + + // if we are using dropout, compute Zij, writing it to shared memory. + // each element of Zij is: + // - 0 with probability dropout_p + // - 1 / (1 - dropout_p) with probability 1 - dropout_p + if (kApplyDropout) { + auto zij = shared_storage.zij().accum_ref(); + // each thread generates a contiguous sequence of elements in Zij, all + // in the same row. the reason they have to come from the same row is + // that sampling random numbers from a contiguous random number sequence + // is much more efficient than jumping around, and the linear offset of + // each element of Z (the global matrix) maps to an offset in a random + // number sequence. 
for Z, the end of a row and the beginning of the + // next have adjacent offsets, but for Zij (tile of global matrix), this + // is not necessarily the case. + // We must fill the entire `zij` shmem with values (even out of bounds + // on the K-dimension) otherwise we can get NaNs during the GEMM + const int kQueriesPerBlock = kBlockSizeI; + const int threads_per_row = cutlass::fast_min( + int32_t(kNumThreads / kQueriesPerBlock), num_keys_in_block); + const int elts_per_thread = cutlass::round_nearest( + cutlass::ceil_div(num_keys_in_block, threads_per_row), 4); + + const int thread_i = thread_id / threads_per_row; + const int thread_start_j = + (thread_id % threads_per_row) * elts_per_thread; + + if (thread_i < kQueriesPerBlock && thread_start_j < num_keys_in_block) { + curandStatePhilox4_32_10_t curand_state = curand_state_init; + skipahead( + (query_start + thread_i) * p.num_keys + + (key_start + thread_start_j), + &curand_state); + + // generate elements of Zij, 4 elements at a time + for (int zij_start_col_idx = thread_start_j; zij_start_col_idx < + cutlass::fast_min(thread_start_j + elts_per_thread, + num_keys_in_block); + zij_start_col_idx += 4) { + const float4 rand_uniform_quad = curand_uniform4(&curand_state); + + CUTLASS_PRAGMA_UNROLL + for (int quad_idx = 0; quad_idx < 4; ++quad_idx) { + // we'll write Zij transposed since attention is also transposed + // during the matmul to compute dV. + zij.at({zij_start_col_idx + quad_idx /*k*/, thread_i /*q*/}) = + (&rand_uniform_quad.x)[quad_idx] > p.dropout_prob + ? scalar_t(dropout_scale) + : scalar_t(0); + } + } + } + __syncthreads(); +#if 0 + PRINT_TENSOR4x4_T0_L0("zij", zij); + PRINT_TENSOR4x4_T0_L0_START("zij", zij, kBlockSizeJ - 4, kBlockSizeI - 4); +#endif + + // Save mask for later DOIVJ matmul + + int warp_idx_mn_0 = warp_id % + (MatmulDOIVJ::Mma::Base::WarpCount::kM * + MatmulDOIVJ::Mma::Base::WarpCount::kN); + auto output_tile_coords_doivj = cutlass::MatrixCoord{ + warp_idx_mn_0 % MatmulDOIVJ::Mma::Base::WarpCount::kM, + warp_idx_mn_0 / MatmulDOIVJ::Mma::Base::WarpCount::kM}; + auto lane_offset = MatmulDOIVJ::AccumLambdaIterator::get_lane_offset( + lane_id, warp_id, output_tile_coords_doivj); + MatmulDOIVJ::AccumLambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) {}, + [&](int accum_m /*q*/, int accum_n /*k*/, int idx) { + if (zij.at({accum_n, accum_m}) == scalar_t(0)) { + dropout_keep_mask_doivj[idx] = cutlass::uint1b_t{0}; + } + }, + [&](int accum_m) {}); + } + __syncthreads(); + } + rematerializeThreadIds(); + + ///////////////////////////////////////////////////////////////////////////////////////////////// + // GradV matmul + // + // grad_v[j_start:j_end] += attn_T @ do_i + ///////////////////////////////////////////////////////////////////////////////////////////////// + constexpr bool kSingleIterationGradV = + kMaxK <= MatmulGradV::ThreadblockShape::kN; + for (int col = 0; col < (kSingleIterationGradV ? 
1 : p.head_dim_value); + col += MatmulGradV::ThreadblockShape::kN) { + using Mma = typename MatmulGradV::Mma; + using AccumTileGmem = typename MatmulGradQ::AccumTileGmem; + + cutlass::gemm::GemmCoord problem_size( + num_keys_in_block, p.head_dim_value - col, num_queries_in_block); + auto createEpilogueIter = [&]() { + return typename MatmulGradV::OutputTileIterator( + typename MatmulGradV::OutputTileIterator::Params{p.gV_strideM()}, + p.grad_value_ptr + key_start * p.gV_strideM() + col, + {num_keys_in_block, p.head_dim_value - col}, + thread_id); + }; + typename Mma::IteratorB iterator_B( + {int32_t(p.gO_strideM)}, + p.grad_output_ptr + query_start * p.gO_strideM + col, + {num_queries_in_block, p.head_dim_value - col}, + thread_id, + no_offset); + + // if dropout: dVj += (Pij.T * Zij) @ dOi + // otherwise: dVj += Pij.T @ dOi + Mma mma( + // operand A: Pij.T + shared_storage.attn_shared_storage().accum_ref(), + // operand A_scale Zij.T: + // if we're using dropout, operand A is Pij_dropped.T = Pij.T * Zij.T + // which is computed on the fly as fragments of Pij.T are loaded in + shared_storage.zij().accum_ref(), + // operand B: dOi - which was loaded into shared memory previously + // when we computed dVj + shared_storage.mm_gradV().operand_B_ref(), + thread_id, + warp_id, + lane_id); + + int storage_id = col / MatmulGradV::ThreadblockShape::kN; + AccumTileGmem gmem_tile{ + p.workspace_gv + storage_id * AccumTileGmem::kElementsStored}; + if (!kOutputInRF) { + if (isFirstQuery || !kNeedsAccumGradV) { + output_frags.gradV.clear(); + } else { + gmem_tile.load(output_frags.gradV, thread_id); + } + } + mma.set_prologue_done(kPrologueGV); + + auto gemm_k_iterations = + (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + __syncthreads(); + + mma(gemm_k_iterations, + output_frags.gradV, + iterator_B, + output_frags.gradV); + __syncthreads(); + if (kPrologueGV && !kSingleIterationGradV && + col + MatmulGradV::ThreadblockShape::kN < p.head_dim_value) { + prologueGradV(col + MatmulGradV::ThreadblockShape::kN); + } + + if (!kOutputInRF) { + if (kNeedsAccumGradV && !isLastQuery) { + gmem_tile.store(output_frags.gradV, thread_id); + } else { + accumulateInGmem( + shared_storage.gradV_epilogue(), + output_frags.gradV, + createEpilogueIter(), + isFirstQuery || kNeedsAccumGradV, + warp_id, + lane_id); + } + } + } + __syncthreads(); + + ///////////////////////////////////////////////////////////////////////////////////////////////// + // MatmulDOIVJ + ///////////////////////////////////////////////////////////////////////////////////////////////// + { + using Mma = typename MatmulDOIVJ::Mma; + // do_i + typename Mma::IteratorA iterator_A( + {int32_t(p.gO_strideM)}, + p.grad_output_ptr + query_start * p.gO_strideM, + {num_queries_in_block, p.head_dim_value}, + thread_id, + no_offset); + + // v_j.transpose(-2, -1) + typename Mma::IteratorB iterator_B( + {int32_t(p.v_strideM)}, + p.value_ptr + key_start * p.v_strideM, + {p.head_dim_value, num_keys_in_block}, + thread_id, + no_offset); + + Mma mma(shared_storage.mm_doivj(), thread_id, warp_id, lane_id); + mma.set_prologue_done(kPrologueDOV); + mma.set_zero_outside_bounds(!skipBoundsChecks); + + typename Mma::FragmentC accum; + + accum.clear(); + + auto gemm_k_iterations = + (p.head_dim_value + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum); + __syncthreads(); + if (kPrologueGQ) { + 
prologueGradQ(0); + } + if (kPrologueGK) { + prologueGradK(0); + } + + int warp_idx_mn_0 = + warp_id % (Mma::Base::WarpCount::kM * Mma::Base::WarpCount::kN); + auto output_tile_coords = cutlass::MatrixCoord{ + warp_idx_mn_0 % Mma::Base::WarpCount::kM, + warp_idx_mn_0 / Mma::Base::WarpCount::kM}; + // TODO: This must be terribly inefficient. There must be a better way + // tmp [RF] <- (accum [RF] - Di [smem] ) * attn_T.T [smem] + // attn_shared_storage [smem] <- tmp.T + // tmp_shared_storage [smem] <- tmp + { + using LambdaIterator = typename MatmulDOIVJ::AccumLambdaIterator; + auto lane_offset = LambdaIterator::get_lane_offset( + lane_id, warp_id, output_tile_coords); + // if dropout was used, compute dPij = dPij_dropped * Zij + if (kApplyDropout) { + LambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) {}, + [&](int accum_m, int accum_n, int idx) { + if (dropout_keep_mask_doivj[idx].get()) { + accum[idx] *= dropout_scale; + } else { + accum[idx] = 0; + } + }, + [&](int accum_m) {}); + } + + auto attn_T = shared_storage.attn_shared_storage().accum_ref(); +#if 0 + PRINT_B0_T0("doivj_dropped"); + print_warp_accum(accum, lane_offset, 4, 4); + PRINT_TENSOR4x4_T0_L0("attn_T", attn_T) +#endif + accum_t current_di; + // dSij = (dPij - Di) * Pij + LambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) { current_di = shared_storage.di()[accum_m]; }, + [&](int accum_m, int accum_n, int idx) { + // TODO: Otherwise we can get nans as we + // might have infs here (only seen on f16 tho) + if (skipBoundsChecks || + (accum_m < num_queries_in_block && + accum_n < num_keys_in_block)) { + accum_t attn = attn_T.at({accum_n, accum_m}); + accum[idx] = (accum[idx] - current_di) * attn; + } else { + accum[idx] = 0; + } + }, + [&](int accum_m) { + + }); + + // store bias gradient tile dBij to global memory, + // where dBij = dSij = Pij * (dPij - Di) + if (p.grad_bias_ptr != nullptr) { + typename MatmulDOIVJ::BiasGradEpilogue::OutputTileIterator + output_iter( + typename MatmulDOIVJ::BiasGradEpilogue::OutputTileIterator:: + Params{p.gB_strideM}, + // grad_bias_ptr is offset to point at beginning of + // matrix of shape (queries, keys) for a given + // (batch_id, head_id) the pointer arithmetic here produces + // a pointer to the start of the current tile within that + // matrix + p.grad_bias_ptr + query_start * p.gB_strideM + key_start, + {num_queries_in_block, num_keys_in_block}, + thread_id); + + // no-op epilogue operator - just casting and storing contents of + // accum to global memory + typename MatmulDOIVJ::BiasGradEpilogue::OutputOp output_op( + typename MatmulDOIVJ::BiasGradEpilogue::OutputOp::Params{1, 1}); + typename MatmulDOIVJ::BiasGradEpilogue epilogue( + shared_storage.gradB_epilogue(), thread_id, warp_id, lane_id); + epilogue(output_op, output_iter, accum, output_iter); + } + + accum = accum * scale; + +#if 0 + PRINT_B0_T0("(doivj - di) * attn * scale"); + print_warp_accum(accum, lane_offset, 4, 4); +#endif + + __syncthreads(); + if (!MatmulGradK::DefaultMmaFromSmem::kIsTransposedA) { + auto tmpT = shared_storage.tmpT_shared_storage().accum_ref(); + // attn <- attn_T.T + LambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) {}, + [&](int accum_m, int accum_n, int idx) { + tmpT.at({accum_n, accum_m}) = scalar_t(accum[idx]); + }, + [&](int accum_m) {}); + } + } + + MatmulDOIVJ::B2bGemm::accumToSmem( + shared_storage.tmp_shared_storage(), + accum, + lane_id, + output_tile_coords); + __syncthreads(); + } + // Force `nvcc` to recompute values that depend on the variables just 
below + // to use less RF and prevent some spilling + p.head_dim = warp_uniform(p.head_dim); + p.k_strideM = warp_uniform(p.k_strideM); + rematerializeThreadIds(); + + ///////////////////////////////////////////////////////////////////////////////////////////////// + // GradQ matmul + // + // grad_q[i_start:i_end] += tmp @ k_j + ///////////////////////////////////////////////////////////////////////////////////////////////// + // Skip the loop & associated branches if we know at compile time the number + // of iterations + constexpr bool kSingleIterationGradQ = + kMaxK <= MatmulGradQ::ThreadblockShape::kN; + for (int col = 0; col < (kSingleIterationGradQ ? 1 : p.head_dim); + col += MatmulGradQ::ThreadblockShape::kN) { + using Mma = typename MatmulGradQ::Mma; + using AccumTileGmem = typename MatmulGradQ::AccumTileGmem; + + cutlass::gemm::GemmCoord problem_size( + num_queries_in_block, + false ? MatmulGradQ::ThreadblockShape::kN : p.head_dim - col, + num_keys_in_block); + + // k_j + typename Mma::IteratorB iterator_B( + {int32_t(p.k_strideM)}, + p.key_ptr + key_start * p.k_strideM + col, + {problem_size.k(), problem_size.n()}, + thread_id, + no_offset); + + auto a = shared_storage.tmp_shared_storage().accum_ref(); + Mma mma( + // operand A: dSij + shared_storage.tmp_shared_storage().accum_ref(), + // operand B: Kj + shared_storage.mm_gradQ().operand_B_ref(), + thread_id, + warp_id, + lane_id); + + typename Mma::FragmentC accum; + + int col_id = col / MatmulGradQ::ThreadblockShape::kN; + int num_cols = kSingleIterationGradQ + ? 1 + : ceil_div(p.head_dim, MatmulGradQ::ThreadblockShape::kN); + int storage_id = (col_id + query_start / kBlockSizeI * num_cols); + + if (p.num_splits_key_device() > 1) { + AtomicLock::acquire( + &p.workspace_gq[storage_id].lock, + p.split_key_device() + 1, + thread_id); + // Make sure we can see other block's output + __threadfence(); + } + + AccumTileGmem gmem_tile{&p.workspace_gq[storage_id].buffer[0]}; + if (!kNeedsAccumGradQ || + (p.num_splits_key_device() == 1 && key_start == 0)) { + // if we know we are the first to access it, we know it's only zeros. 
+ // Avoids a load from gmem (and gmem init as well) + accum.clear(); + } else { + gmem_tile.load(accum, thread_id); + } + + auto gemm_k_iterations = + (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + __syncthreads(); + mma.set_prologue_done(kPrologueGQ); + mma(gemm_k_iterations, accum, iterator_B, accum); + __syncthreads(); + bool isLastColumn = kSingleIterationGradQ || + (col + MatmulGradQ::ThreadblockShape::kN >= p.head_dim); + if (kPrologueGQ && !isLastColumn) { + prologueGradQ(col + MatmulGradQ::ThreadblockShape::kN); + } + + bool isLast = [&]() { + int32_t next_key = key_start + p.num_splits_key_device() * kBlockSizeJ; + if (p.num_keys <= next_key) { + return true; + } + if (query_start < getSmallestQueryForKey(p, next_key)) { + return true; + } + return false; + }(); + // Output results + if (p.num_splits_key_device() > 1) { + int32_t numAddsSoFar = -1; + if (isLast && thread_id == 0) { + numAddsSoFar = atomicAdd(&p.workspace_gq[storage_id].counter, 1) + + 1; // `atomicAdd` returns the old value + } + isLast = __syncthreads_or( + numAddsSoFar == getNumParallelBlocksForQuery(p, query_start)); + assert(numAddsSoFar <= getNumParallelBlocksForQuery(p, query_start)); + } + if (kNeedsAccumGradQ && !isLast) { + gmem_tile.store(accum, thread_id); + if (p.num_splits_key_device() > 1) { + // Make sure everyone wrote before we release the lock + __threadfence(); + __syncthreads(); + AtomicLock::release(&p.workspace_gq[storage_id].lock, thread_id); + } + } else { + // NOTE: We're not releasing the lock because no one is expected + // to come after us (we're the last one to write) + typename MatmulGradQ::OutputTileIterator output_it( + typename MatmulGradQ::OutputTileIterator::Params{p.gQ_strideM()}, + p.grad_query_ptr + query_start * p.gQ_strideM() + col, + {problem_size.m(), problem_size.n()}, + thread_id); + bool storage_contains_zeros = kNeedsAccumGradQ || key_start == 0 || + (p.num_splits_key_device() > 1); + accumulateInGmem( + isLastColumn ? shared_storage.gradQ_epilogue_lastIter() + : shared_storage.gradQ_epilogue(), + accum, + output_it, + storage_contains_zeros, + warp_id, + lane_id); + } + } + ///////////////////////////////////////////////////////////////////////////////////////////////// + // GradK matmul + // + // grad_k[i_start:i_end] += tmp.transpose(-2, -1) @ q_i + ///////////////////////////////////////////////////////////////////////////////////////////////// + rematerializeThreadIds(); + + constexpr bool kSingleIterationGradK = + kMaxK <= MatmulGradK::ThreadblockShape::kN; + for (int col = 0; col < (kSingleIterationGradK ? 1 : p.head_dim); + col += MatmulGradK::ThreadblockShape::kN) { + using Mma = typename MatmulGradK::Mma; + using AccumTileGmem = typename MatmulGradQ::AccumTileGmem; + + cutlass::gemm::GemmCoord problem_size( + num_keys_in_block, + false ? MatmulGradK::ThreadblockShape::kN : p.head_dim - col, + num_queries_in_block); + auto createEpilogueIter = [&]() { + return typename MatmulGradK::OutputTileIterator( + typename MatmulGradK::OutputTileIterator::Params{p.gK_strideM()}, + p.grad_key_ptr + key_start * p.gK_strideM() + col, + {num_keys_in_block, + false ? 
MatmulGradK::ThreadblockShape::kN : p.head_dim - col}, + thread_id); + }; + + // q_i + typename Mma::IteratorB iterator_B( + {int32_t(p.q_strideM)}, + p.query_ptr + query_start * p.q_strideM + col, + {problem_size.k(), problem_size.n()}, + thread_id, + no_offset); + + auto getTmp = [&](int) { return &shared_storage.tmp_shared_storage(); }; + auto getTmpT = [&](int) { return &shared_storage.tmpT_shared_storage(); }; + // this is basically: + // opA = kIsTransposedA ? getTmp() : getTmpT(); + bool constexpr kIsTransposedA = + MatmulGradK::DefaultMmaFromSmem::kIsTransposedA; + auto& opA = *call_conditional< + kIsTransposedA, + decltype(getTmp), + decltype(getTmpT)>::apply(getTmp, getTmpT, 0); + Mma mma( + // operand A: dSij.T + opA.accum_ref(), + // operand B: Qi + shared_storage.mm_gradK().operand_B_ref(), + thread_id, + warp_id, + lane_id); + + int storage_id = col / MatmulGradK::ThreadblockShape::kN; + AccumTileGmem gmem_tile{ + p.workspace + storage_id * AccumTileGmem::kElementsStored}; + if (!kOutputInRF) { + if (isFirstQuery || !kNeedsAccumGradK) { + output_frags.gradK.clear(); + } else { + gmem_tile.load(output_frags.gradK, thread_id); + } + } + mma.set_prologue_done(kPrologueGK); + + auto gemm_k_iterations = + (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + __syncthreads(); + + mma(gemm_k_iterations, + output_frags.gradK, + iterator_B, + output_frags.gradK); + __syncthreads(); + bool isLastColumn = kSingleIterationGradK || + col + MatmulGradK::ThreadblockShape::kN >= p.head_dim; + if (kPrologueGK && !isLastColumn) { + prologueGradK(col + MatmulGradK::ThreadblockShape::kN); + } + + if (kPrologueQK && isLastColumn) { + int32_t next_query, next_key; + incrIteration(p, query_start, key_start, next_query, next_key); + DISPATCH_BOOL( + next_key != key_start, kForceReloadK, ([&]() { + prologueQkNextIteration( + shared_storage, p, next_query, next_key, warp_id, lane_id); + })); + } + + // Output results + if (!kOutputInRF) { + if (kNeedsAccumGradK && !isLastQuery) { + gmem_tile.store(output_frags.gradK, thread_id); + } else { + accumulateInGmem( + isLastColumn ? 
shared_storage.gradK_epilogue_final() + : shared_storage.gradK_epilogue(), + output_frags.gradK, + createEpilogueIter(), + isFirstQuery || kNeedsAccumGradK, + warp_id, + lane_id); + __syncthreads(); + } + } + } + } + + static CUTLASS_DEVICE int32_t getQueryStartShift(Params const& p) { + if (p.custom_mask_type == NoCustomMask && p.num_splits_key_device() > 1) { + return (p.split_key_device() * kBlockSizeI) % getQueryEnd(p); + } + return 0; + } + + // Iteration order logic + static CUTLASS_DEVICE int32_t + getQueryStart(Params const& p, int32_t key_start) { + return getSmallestQueryForKey(p, key_start) + getQueryStartShift(p); + }; + static CUTLASS_DEVICE int32_t getQueryEnd(Params const& p) { + return align_up(p.num_queries, kBlockSizeI); + }; + + static CUTLASS_DEVICE int32_t + getSmallestQueryForKey(Params const& p, int32_t key_start) { + if (p.custom_mask_type == CausalFromTopLeft) { + return (key_start / kBlockSizeI) * kBlockSizeI; + } else if (p.custom_mask_type == CausalFromBottomRight) { + int first_query = + cutlass::fast_max(0, key_start - p.num_keys + p.num_queries); + return (first_query / kBlockSizeI) * kBlockSizeI; + } + return 0; + }; + + // Returns how many kernel blocks will write to a given block in `grad_query` + // This is usually equal to the number of key splits, but can be different + // for instance in the causal case, or varying seqlen + static CUTLASS_DEVICE int32_t + getNumParallelBlocksForQuery(Params const& p, int32_t query_start) { + int16_t num_key_blocks = ceil_div(p.num_keys, kBlockSizeJ); + if (p.custom_mask_type == CausalFromTopLeft) { + int32_t last_key_for_block = query_start + kBlockSizeI - 1; + last_key_for_block = cutlass::fast_min(last_key_for_block, p.num_keys); + num_key_blocks = ceil_div(last_key_for_block, kBlockSizeJ); + } else if (p.custom_mask_type == CausalFromBottomRight) { + int32_t last_key_for_block = + query_start + (kBlockSizeI - 1) + (1 + p.num_keys - p.num_queries); + last_key_for_block = cutlass::fast_min(last_key_for_block, p.num_keys); + num_key_blocks = ceil_div(last_key_for_block, kBlockSizeJ); + } + return cutlass::fast_min(p.num_splits_key_device(), num_key_blocks); + }; + + // Returns the next block to process + static CUTLASS_DEVICE void incrIteration( + Params const& p, + int32_t query_start, + int32_t key_start, + int32_t& next_query, + int32_t& next_key) { + next_query = query_start + kBlockSizeI; + next_key = key_start; + auto query_shift = getQueryStartShift(p); + // Wrap around + if (query_shift) { + if (next_query >= p.num_queries) { + next_query = getSmallestQueryForKey(p, key_start); + return; + } else if (query_start < query_shift && query_shift <= next_query) { + // jump to next key + } else { + return; + } + } else { + if (next_query < p.num_queries) { + return; + } + // jump to next key + } + // Next key + next_key = key_start + p.num_splits_key_device() * kBlockSizeJ; + next_query = getQueryStart(p, next_key); + } + + template + static CUTLASS_DEVICE void prologueQkNextIteration( + SharedStorage& shared_storage, + Params const& p, + int32_t query_start, + int32_t key_start, + uint8_t warp_id, + uint8_t lane_id) { + if (query_start >= p.num_queries || key_start >= p.num_keys) { + return; + } + + static constexpr bool kReloadK = + kForceReloadK || !MatmulQK::Mma::kSmemContainsEntireMat; + int thread_id = 32 * warp_id + lane_id; + typename MatmulQK::Mma::IteratorA iterator_A( + {int32_t(p.k_strideM)}, + p.key_ptr + key_start * p.k_strideM, + {p.num_keys - key_start, p.head_dim}, + thread_id, + 
cutlass::MatrixCoord{0, 0}); + + typename MatmulQK::Mma::IteratorB iterator_B( + {int32_t(p.q_strideM)}, + p.query_ptr + query_start * p.q_strideM, + {p.head_dim, p.num_queries - query_start}, + thread_id, + cutlass::MatrixCoord{0, 0}); + + MatmulQK::Mma::template prologue( + shared_storage.mm_qk_k(), + shared_storage.mm_qk_q(), + iterator_A, + iterator_B, + thread_id, + p.head_dim); + } + + template + static CUTLASS_DEVICE void writeFragsToGmem( + SharedStorage& shared_storage, + OutputFragments& output_frags, + Params const& p, + int32_t key_start, + uint8_t warp_id, + uint8_t lane_id) { + uint16_t thread_id = 32 * warp_id + lane_id; + int32_t num_keys_in_block = skipBoundsChecks + ? MatmulQK::Mma::Shape::kM + : cutlass::fast_min( + (int32_t)MatmulQK::Mma::Shape::kM, p.num_keys - key_start); + typename MatmulGradV::OutputTileIterator outputV_it( + typename MatmulGradV::OutputTileIterator::Params{p.gV_strideM()}, + p.grad_value_ptr + key_start * p.gV_strideM(), + {num_keys_in_block, p.head_dim_value}, + thread_id); + + accumulateInGmem( + shared_storage.gradV_epilogue_final(), + output_frags.gradV, + outputV_it, + true, + warp_id, + lane_id); + + typename MatmulGradK::OutputTileIterator outputK_it( + typename MatmulGradK::OutputTileIterator::Params{p.gK_strideM()}, + p.grad_key_ptr + key_start * p.gK_strideM(), + {num_keys_in_block, + false ? MatmulGradK::ThreadblockShape::kN : p.head_dim}, + thread_id); + accumulateInGmem( + shared_storage.gradK_epilogue_final(), + output_frags.gradK, + outputK_it, + true, + warp_id, + lane_id); + } + + template + static CUTLASS_DEVICE void accumulateInGmem( + typename MatmulT::DefaultEpilogue::SharedStorage& epilogue_smem, + typename MatmulT::Mma::FragmentC const& accum, + typename MatmulT::OutputTileIterator output_it, + bool first, + uint8_t warp_id, + uint8_t lane_id) { + using DefaultEpilogue = typename MatmulT::DefaultEpilogue; + using DefaultOutputOp = typename MatmulT::DefaultOutputOp; + using Mma = typename MatmulT::Mma; + int thread_id = 32 * warp_id + lane_id; + DISPATCH_BOOL( + first, kIsFirst, ([&]() { + static constexpr auto ScaleType = kIsFirst::value + ? 
cutlass::epilogue::thread::ScaleType::Nothing + : cutlass::epilogue::thread::ScaleType::NoBetaScaling; + using EpilogueOutputOp = + typename cutlass::epilogue::thread::LinearCombination< + typename DefaultOutputOp::ElementOutput, + DefaultOutputOp::kCount, + typename DefaultOutputOp::ElementAccumulator, + typename DefaultOutputOp::ElementCompute, + ScaleType>; + using Epilogue = + typename cutlass::epilogue::threadblock::EpiloguePipelined< + typename DefaultEpilogue::Shape, + typename Mma::Operator, + DefaultEpilogue::kPartitionsK, + typename MatmulT::OutputTileIterator, + typename DefaultEpilogue::AccumulatorFragmentIterator, + typename DefaultEpilogue::WarpTileIterator, + typename DefaultEpilogue::SharedLoadIterator, + EpilogueOutputOp, + typename DefaultEpilogue::Padding, + DefaultEpilogue::kFragmentsPerIteration, + true // IterationsUnroll + >; + EpilogueOutputOp rescale({1, 1}); + Epilogue epilogue(epilogue_smem, thread_id, warp_id, lane_id); + epilogue(rescale, output_it, accum, output_it); + })); + } + + template + static CUTLASS_DEVICE void computeDelta( + Params const& p, + int32_t query_start, + uint8_t warp_id, + uint8_t lane_id) { + // Each thread computes one value for Delta + // Depending on warp configuration, we might have multiple + // threads of the same warp working on the same row + using AccessType = cutlass::Array; + static_assert(kNumThreads >= kBlockSizeI, ""); + static constexpr int kNumThreadsPerLine = kNumThreads / kBlockSizeI; + int16_t thread_id = 32 * warp_id + lane_id; + + int16_t laneFirstCol = kElementsPerAccess * (lane_id % kNumThreadsPerLine); + int16_t laneRow = thread_id / kNumThreadsPerLine; + bool rowPred = (query_start + laneRow) < p.num_queries; + bool pred = rowPred; + + // on windows, previous syntax __restrict__ AccessType* + // resulted in error: "restrict" is not allowed + const AccessType* __restrict__ grad_output_ptr = + reinterpret_cast( + p.grad_output_ptr + (query_start + laneRow) * p.gO_strideM + + laneFirstCol); + const AccessType* __restrict__ output_ptr = + reinterpret_cast( + p.output_ptr + (query_start + laneRow) * p.o_strideM() + + laneFirstCol); + + static constexpr int64_t kMaxIters = + kMaxK / (kElementsPerAccess * kNumThreadsPerLine); + constexpr int kPipelineStages = 2; + accum_t delta_value = accum_t(0); + using GlobalLoad = + cutlass::arch::global_load; + AccessType frag_grad_output[kPipelineStages]; + AccessType frag_output[kPipelineStages]; + + auto loadAndIncrement = [&](int ld_pos, bool is_valid) { + frag_grad_output[ld_pos].clear(); + frag_output[ld_pos].clear(); + GlobalLoad(frag_grad_output[ld_pos], grad_output_ptr, is_valid); + GlobalLoad(frag_output[ld_pos], output_ptr, is_valid); + grad_output_ptr += kNumThreadsPerLine; + output_ptr += kNumThreadsPerLine; + }; + + CUTLASS_PRAGMA_UNROLL + for (int iter = 0; iter < kPipelineStages - 1; ++iter) { + int ld_pos = iter % kPipelineStages; + pred = pred && + (laneFirstCol + iter * kElementsPerAccess * kNumThreadsPerLine) < + p.head_dim_value; + loadAndIncrement(ld_pos, pred); + } + auto columnIteration = [&](int iter) { + // Load for next iter + int ld_pos = (iter + kPipelineStages - 1) % kPipelineStages; + pred = pred && + (laneFirstCol + + (iter + kPipelineStages - 1) * kElementsPerAccess * + kNumThreadsPerLine) < p.head_dim_value; + loadAndIncrement(ld_pos, pred); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < AccessType::kElements; ++i) { + delta_value += accum_t(frag_output[iter % kPipelineStages][i]) * + accum_t(frag_grad_output[iter % kPipelineStages][i]); + } + }; + 
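+    // Descriptive note on the loop below (added for readability): the column loop is
+    // software-pipelined over kPipelineStages (= 2) stages. Each columnIteration(iter)
+    // first issues the predicated global loads for iteration (iter + kPipelineStages - 1)
+    // and then accumulates dO[i] * O[i] from the fragments loaded for iteration `iter`,
+    // so global-memory latency overlaps with the multiply-add work for Delta.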
+ // If we have a small lower-bound for K, we can unroll the loop + if (kMaxK <= 256) { + CUTLASS_PRAGMA_UNROLL + for (int iter = 0; iter < kMaxIters; ++iter) { + columnIteration(iter); + } + } else { + int num_iters = + ceil_div(p.head_dim_value, kElementsPerAccess * kNumThreadsPerLine) * + (kElementsPerAccess * kNumThreadsPerLine); + for (int iter = 0; iter < num_iters; ++iter) { + columnIteration(iter); + } + } + + // Reduce between workers + static_assert( + kNumThreadsPerLine == 1 || kNumThreadsPerLine == 2 || + kNumThreadsPerLine == 4, + ""); + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < kNumThreadsPerLine; i *= 2) { + delta_value = delta_value + __shfl_xor_sync(0xffffffff, delta_value, i); + } + + // Store in gmem + if (rowPred) { + p.delta_ptr[query_start + laneRow] = delta_value; + } + } +}; + +template +__global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm) + attention_kernel_backward_batched_impl(typename AK::Params p) { + if (!p.advance_to_block()) { + return; + } + AK::attention_kernel(p); +} + +template +__global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm) + attention_kernel_backward_batched(typename AK::Params params); diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/kernel_forward.h b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/kernel_forward.h new file mode 100644 index 0000000000000000000000000000000000000000..ed4e16775951a0062fdc7e2fa435bae13610b470 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/kernel_forward.h @@ -0,0 +1,1322 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once + +#ifdef HAS_PYTORCH +#include +#include +#endif + +#include +#include +#include +#include + +#include "cutlass/fast_math.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/vector.h" +#include "cutlass/matrix.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_ref.h" + +#include "cutlass/epilogue/threadblock/default_epilogue_simt.h" +#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h" +#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h" +#include "cutlass/gemm/device/default_gemm_configuration.h" +#include "cutlass/gemm/kernel/default_gemm.h" +#include "cutlass/gemm/threadblock/default_mma.h" +#include "cutlass/gemm/threadblock/default_mma_core_simt.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm70.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm75.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/platform/platform.h" +#include "cutlass/transform/threadblock/predicated_tile_iterator.h" +#include "debug_utils.h" +#include "epilogue/epilogue_pipelined.h" +#include "epilogue/epilogue_rescale_output.h" +#include "gemm/custom_mma.h" +#include "gemm/find_default_mma.h" +#include "gemm/mma_from_smem.h" +#include "gemm_kernel_utils.h" +#include "transform/tile_smem_loader.h" + +using namespace gemm_kernel_utils; + +namespace { +template +constexpr int getWarpsPerSmFw() { + return ( + Arch::kMinComputeCapability >= 80 && + !cutlass::platform::is_same::value + ? 16 + : 12); +} +static CUTLASS_DEVICE float atomicMaxFloat(float* addr, float value) { + // source: https://stackoverflow.com/a/51549250 + return (value >= 0) + ? __int_as_float(atomicMax((int*)addr, __float_as_int(value))) + : __uint_as_float(atomicMin((unsigned int*)addr, __float_as_uint(value))); +} +} // namespace + +// If ToBatchHookType_ is supplied other than this default (which is +// never the case in the xformers library) then the user is +// defining the logic which each block uses to find its data to work on, +// with the advance_to_batch function with the following signature. +// It should return false if there is no work to do for this block. +// In general this will not work with saving for backward due to fixed layout +// for logsumexp and incompatible rngs for dropout, so is likely only useful for +// custom inference. 
+struct DefaultToBatchHook { + template + CUTLASS_DEVICE static bool advance_to_batch( + Params&, + int64_t& /* q_start */, + int64_t& /* k_start */) { + return true; + } +}; + +template < + // The datatype of Q/K/V + typename scalar_t_, + // Architecture we are targeting (eg `cutlass::arch::Sm80`) + typename ArchTag, + // If Q/K/V are correctly aligned in memory and we can run a fast kernel + bool isAligned_, + int kQueriesPerBlock_, + int kKeysPerBlock_, + // upperbound on `max(value.shape[-1], query.shape[-1])` + int kMaxK_ = (int)cutlass::platform::numeric_limits::max(), + // This is quite slower on V100 for some reason + // Set to false if you know at compile-time you will never need dropout + bool kSupportsDropout_ = true, + bool kSupportsBias_ = true, + typename ToBatchHookType_ = DefaultToBatchHook> +struct AttentionKernel { + enum CustomMaskType { + NoCustomMask = 0, + CausalFromTopLeft = 1, + CausalFromBottomRight = 2, + NumCustomMaskTypes, + }; + + using scalar_t = scalar_t_; + using accum_t = float; + using lse_scalar_t = float; + using output_t = scalar_t; + // Accumulator between 2 iterations + // Using `accum_t` improves perf on f16 at the cost of + // numerical errors + using output_accum_t = accum_t; + static constexpr bool kSupportsDropout = kSupportsDropout_; + static constexpr bool kSupportsBias = kSupportsBias_; + static constexpr int kKeysPerBlock = kKeysPerBlock_; + static constexpr int kQueriesPerBlock = kQueriesPerBlock_; + static constexpr int kMaxK = kMaxK_; + static constexpr bool kIsAligned = isAligned_; + static constexpr bool kSingleValueIteration = kMaxK <= kKeysPerBlock; + static constexpr int32_t kAlignLSE = 32; // block size of backward + static constexpr bool kIsHalf = cutlass::sizeof_bits::value == 16; + static constexpr bool kPreloadV = + ArchTag::kMinComputeCapability >= 80 && kIsHalf; + static constexpr bool kKeepOutputInRF = kSingleValueIteration; + static constexpr bool kNeedsOutputAccumulatorBuffer = !kKeepOutputInRF && + !cutlass::platform::is_same::value; + + static_assert(kQueriesPerBlock % 32 == 0, ""); + static_assert(kKeysPerBlock % 32 == 0, ""); + static constexpr int kNumWarpsPerBlock = + kQueriesPerBlock * kKeysPerBlock / (32 * 32); + static constexpr int kWarpSize = 32; + + // Launch bounds + static constexpr int kNumThreads = kWarpSize * kNumWarpsPerBlock; + static constexpr int kMinBlocksPerSm = + getWarpsPerSmFw() / kNumWarpsPerBlock; + + struct Params { + // Input tensors + scalar_t* query_ptr = nullptr; // [num_queries, num_heads, head_dim] + scalar_t* key_ptr = nullptr; // [num_keys, num_heads, head_dim] + scalar_t* value_ptr = nullptr; // [num_keys, num_heads, head_dim_value] + scalar_t* attn_bias_ptr = nullptr; // [num_heads, num_queries, num_keys] + int32_t* seqstart_q_ptr = nullptr; + int32_t* seqstart_k_ptr = nullptr; + + int32_t* seqlen_k_ptr = nullptr; + uint32_t causal_diagonal_offset = 0; + + // Output tensors + output_t* output_ptr = nullptr; // [num_queries, num_heads, head_dim_value] + // [num_queries, num_heads, head_dim_value] + output_accum_t* output_accum_ptr = nullptr; + // [num_heads, num_queries] - can be null + lse_scalar_t* logsumexp_ptr = nullptr; + + // Scale + accum_t scale = 0.0; + + // Dimensions/strides + int32_t head_dim = 0; + int32_t head_dim_value = 0; + int32_t num_queries = 0; + int32_t num_keys = 0; + int32_t num_keys_absolute = 0; + + uint8_t custom_mask_type = NoCustomMask; + + int32_t q_strideM = 0; + int32_t k_strideM = 0; + int32_t v_strideM = 0; + int32_t bias_strideM = 0; + + int32_t 
o_strideM = 0; + + // Everything below is only used in `advance_to_block` + // and shouldn't use registers + int32_t q_strideH = 0; + int32_t k_strideH = 0; + int32_t v_strideH = 0; + int64_t bias_strideH = 0; + + int64_t q_strideB = 0; + int64_t k_strideB = 0; + int64_t v_strideB = 0; + int64_t bias_strideB = 0; + + int32_t num_batches = 0; + int32_t num_heads = 0; + + // dropout + bool use_dropout = false; + unsigned long long dropout_batch_head_rng_offset = 0; + float dropout_prob = 0.0f; +#ifdef HAS_PYTORCH + at::PhiloxCudaState rng_engine_inputs = at::PhiloxCudaState(0, 0); +#endif + + // Moves pointers to what we should process + // Returns "false" if there is no work to do + CUTLASS_DEVICE bool advance_to_block() { + auto batch_id = blockIdx.z; + auto head_id = blockIdx.y; + auto query_start = blockIdx.x * kQueriesPerBlock; + + auto lse_dim = ceil_div((int32_t)num_queries, kAlignLSE) * kAlignLSE; + + if (kSupportsDropout) { + dropout_batch_head_rng_offset = + batch_id * num_heads * num_queries * num_keys + + head_id * num_queries * num_keys; + } + + int64_t q_start = 0, k_start = 0; + // Advance to current batch - in case of different sequence lengths + constexpr bool kToBatchHook = + !cutlass::platform::is_same:: + value; + if (kToBatchHook) { + // Call out to a custom implementation. + if (!ToBatchHookType_::advance_to_batch(*this, q_start, k_start)) { + return false; + } + } else if (seqstart_q_ptr != nullptr) { + assert(seqstart_k_ptr != nullptr); + seqstart_q_ptr += batch_id; + + q_start = seqstart_q_ptr[0]; + int64_t q_next_start = seqstart_q_ptr[1]; + int64_t k_end; + seqstart_k_ptr += batch_id; + + if (seqlen_k_ptr) { + k_start = seqstart_k_ptr[0]; + k_end = k_start + seqlen_k_ptr[batch_id]; + } else { + k_start = seqstart_k_ptr[0]; + k_end = seqstart_k_ptr[1]; + } + + num_queries = q_next_start - q_start; + num_keys = k_end - k_start; + + if (query_start >= num_queries) { + return false; + } + } else { + query_ptr += batch_id * q_strideB; + key_ptr += batch_id * k_strideB; + value_ptr += batch_id * v_strideB; + output_ptr += int64_t(batch_id * num_queries) * o_strideM; + if (output_accum_ptr != nullptr) { + output_accum_ptr += + int64_t(batch_id * num_queries) * (head_dim_value * num_heads); + } + q_start = 0; + k_start = 0; + } + + // Advance to the current batch / head / query_start + query_ptr += (q_start + query_start) * q_strideM + head_id * q_strideH; + key_ptr += k_start * k_strideM + head_id * k_strideH; + + value_ptr += k_start * v_strideM + head_id * v_strideH; + output_ptr += + int64_t(q_start + query_start) * o_strideM + head_id * head_dim_value; + + if (kSupportsBias && attn_bias_ptr != nullptr) { + attn_bias_ptr += (batch_id * bias_strideB) + (head_id * bias_strideH); + } + if (output_accum_ptr != nullptr) { + output_accum_ptr += + int64_t(q_start + query_start) * (head_dim_value * num_heads) + + head_id * head_dim_value; + } else { + // Accumulate directly in the destination buffer (eg for f32) + output_accum_ptr = (accum_t*)output_ptr; + } + + if (logsumexp_ptr != nullptr) { + // lse[batch_id, head_id, query_start] + logsumexp_ptr += + batch_id * lse_dim * num_heads + head_id * lse_dim + query_start; + } + + // Custom masking + if (custom_mask_type == CausalFromBottomRight) { + causal_diagonal_offset = num_keys - num_queries; + } + // We use num_keys_absolute to index into the rng_state + // We need this index to match between forward and backwards + num_keys_absolute = num_keys; + if (custom_mask_type == CausalFromTopLeft || + custom_mask_type == 
CausalFromBottomRight) { + // the bottom row of the current block is query_start + kQueriesPerBlock + // the last active key is then query_start + causal_diagonal_offset + + // kQueriesPerBlock so num_keys is the min between actual num_keys and + // this to avoid extra computations + num_keys = cutlass::fast_min( + int32_t(query_start + causal_diagonal_offset + kQueriesPerBlock), + num_keys); + } + + num_queries -= query_start; + num_batches = 0; // no longer used after + + // If num_queries == 1, and there is only one key head we're wasting + // 15/16th of tensor core compute In that case : + // - we only launch kernels for head_id % kQueriesPerBlock == 0 + // - we iterate over heads instead of queries (strideM = strideH) + if (num_queries == 1 && k_strideH == 0 && v_strideH == 0) { + if (head_id % kQueriesPerBlock != 0) + return false; + q_strideM = q_strideH; + num_queries = num_heads; + num_heads = 1; // unused but here for intent + // remove causal since n_query = 1 + // otherwise, offset would change with head ! + custom_mask_type = NoCustomMask; + o_strideM = head_dim_value; + } + + // Make sure the compiler knows these variables are the same on all + // the threads of the warp. + // Only worth doing if they could have been modified above. + query_ptr = warp_uniform(query_ptr); + key_ptr = warp_uniform(key_ptr); + value_ptr = warp_uniform(value_ptr); + if (kSupportsBias) { + attn_bias_ptr = warp_uniform(attn_bias_ptr); + } + output_ptr = warp_uniform(output_ptr); + output_accum_ptr = warp_uniform(output_accum_ptr); + logsumexp_ptr = warp_uniform(logsumexp_ptr); + num_queries = warp_uniform(num_queries); + num_keys = warp_uniform(num_keys); + num_heads = warp_uniform(num_heads); + o_strideM = warp_uniform(o_strideM); + custom_mask_type = warp_uniform(custom_mask_type); + return true; + } + + __host__ dim3 getBlocksGrid() const { + return dim3( + ceil_div(num_queries, (int32_t)kQueriesPerBlock), + num_heads, + num_batches); + } + + __host__ dim3 getThreadsGrid() const { + return dim3(kWarpSize, kNumWarpsPerBlock, 1); + } + }; + + struct MM0 { + /* + In this first matmul, we compute a block of `Q @ K.T`. + While the calculation result is still hot in registers, we update + `mi`, `m_prime`, `s_prime` in shared-memory, and then store this value + into a shared-memory ("AccumulatorSharedStorage") that is used later as + operand A for the second matmul (see MM1) + */ + using GemmType = DefaultGemmType; + + using OpClass = typename GemmType::OpClass; + using DefaultConfig = + typename cutlass::gemm::device::DefaultGemmConfiguration< + OpClass, + ArchTag, + scalar_t, + scalar_t, + scalar_t, // ElementC + accum_t // ElementAccumulator + >; + static constexpr int kAlignmentA = + kIsAligned ? DefaultConfig::kAlignmentA : GemmType::kMinimumAlignment; + static constexpr int kAlignmentB = + kIsAligned ? DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment; + using ThreadblockShape = cutlass::gemm:: + GemmShape; + using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>; + using DefaultMma = typename cutlass::gemm::threadblock::FindDefaultMma< + scalar_t, // ElementA, + cutlass::layout::RowMajor, // LayoutA, + kAlignmentA, + scalar_t, // ElementB, + cutlass::layout::ColumnMajor, // LayoutB, + kAlignmentB, + accum_t, + cutlass::layout::RowMajor, // LayoutC, + OpClass, + ArchTag, // ArchTag + ThreadblockShape, // ThreadblockShape + WarpShape, // WarpShape + typename GemmType::InstructionShape, // InstructionShape + ArchTag::kMinComputeCapability >= 80 && kIsHalf + ? 
4 + : DefaultConfig::kStages, + typename GemmType::Operator // Operator + >::DefaultMma; + using MmaCore = typename DefaultMma::MmaCore; + using IteratorA = typename DefaultMma::IteratorA; + using IteratorB = typename DefaultMma::IteratorB; + using DefaultThreadblockMma = typename DefaultMma::ThreadblockMma; + using Mma = typename cutlass::platform::conditional< + kSingleValueIteration, + typename MakeCustomMma::Mma, + DefaultThreadblockMma>::type; + using AccumLambdaIterator = typename DefaultMmaAccumLambdaIterator< + typename Mma::Operator::IteratorC, + accum_t, + kWarpSize>::Iterator; + static_assert( + MmaCore::WarpCount::kM * MmaCore::WarpCount::kN * + MmaCore::WarpCount::kK == + kNumWarpsPerBlock, + ""); + + // used for efficient load of bias tile Bij from global to shared memory + using BiasLoader = TileSmemLoader< + scalar_t, + cutlass::MatrixShape, + MmaCore::kThreads, + // input restriction: kv_len has to be a multiple of this value + 128 / cutlass::sizeof_bits::value>; + + // Epilogue to store to shared-memory in a format that we can use later for + // the second matmul + using B2bGemm = typename cutlass::gemm::threadblock::B2bGemm< + typename Mma::Operator::IteratorC, + typename Mma::Operator, + scalar_t, + WarpShape, + ThreadblockShape>; + using AccumulatorSharedStorage = typename B2bGemm::AccumulatorSharedStorage; + }; + + struct MM1 { + /** + Second matmul: perform `attn @ V` where `attn` is the attention (not + normalized) and stored in shared memory + */ + using GemmType = DefaultGemmType; + + using OpClass = typename GemmType::OpClass; + using DefaultConfig = + typename cutlass::gemm::device::DefaultGemmConfiguration< + OpClass, + ArchTag, + scalar_t, + scalar_t, + output_accum_t, // ElementC + accum_t // ElementAccumulator + >; + static constexpr int kAlignmentA = DefaultConfig::kAlignmentA; // from smem + static constexpr int kAlignmentB = + kIsAligned ? DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment; + using ThreadblockShape = cutlass::gemm:: + GemmShape; + using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>; + using InstructionShape = typename GemmType::InstructionShape; + + using LayoutB = cutlass::layout::RowMajor; + using DefaultGemm = cutlass::gemm::kernel::DefaultGemm< + scalar_t, // ElementA, + cutlass::layout::RowMajor, // LayoutA, + kAlignmentA, + scalar_t, // ElementB, + LayoutB, // LayoutB, + kAlignmentB, + output_accum_t, + cutlass::layout::RowMajor, // LayoutC, + accum_t, + OpClass, + ArchTag, + ThreadblockShape, + WarpShape, + typename GemmType::InstructionShape, + typename DefaultConfig::EpilogueOutputOp, + void, // ThreadblockSwizzle - not used + ArchTag::kMinComputeCapability >= 80 && kIsHalf + ? 
4 + : DefaultConfig::kStages, + false, // SplitKSerial + typename GemmType::Operator>; + + using WarpIteratorA = typename cutlass::gemm::threadblock:: + DefaultWarpIteratorAFromSharedMemory< + typename DefaultGemm::Mma::Policy::Operator::Shape, // WarpShape + typename DefaultGemm::Mma::Policy::Operator::InstructionShape, + typename DefaultGemm::Mma::Policy::Operator::IteratorA, + typename DefaultGemm::Mma::Policy>::WarpIterator; + using DefaultMmaFromSmem = + typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory< + typename DefaultGemm::Mma, + MM0::AccumulatorSharedStorage::Shape::kN, // kMaxK + WarpIteratorA, + false>; // kScaleOperandA + using Mma = typename DefaultMmaFromSmem::Mma; + using IteratorB = typename Mma::IteratorB; + using WarpCount = typename Mma::WarpCount; + static_assert( + WarpCount::kM * WarpCount::kN * WarpCount::kK == kNumWarpsPerBlock, + ""); + + using DefaultEpilogue = typename DefaultGemm::Epilogue; + using OutputTileIterator = + typename cutlass::epilogue::threadblock::PredicatedTileIterator< + typename DefaultEpilogue::OutputTileIterator::ThreadMap, + output_t>; + using OutputTileIteratorAccum = + typename cutlass::epilogue::threadblock::PredicatedTileIterator< + typename DefaultEpilogue::OutputTileIterator::ThreadMap, + output_accum_t>; + }; + + static constexpr int64_t kAlignmentQ = MM0::kAlignmentA; + static constexpr int64_t kAlignmentK = MM0::kAlignmentB; + static constexpr int64_t kAlignmentV = 1; + + // Shared storage - depends on kernel params + struct ScalingCoefs { + cutlass::Array m_prime; + cutlass::Array s_prime; + cutlass::Array mi; + cutlass::Array out_rescale; + cutlass::Array + addition_storage; + }; + + struct SharedStorageEpilogueAtEnd : ScalingCoefs { + struct SharedStorageAfterMM0 { + // Everything here might be overwritten during MM0 + union { + typename MM0::BiasLoader::SmemTile bias; + typename MM0::AccumulatorSharedStorage si; + }; + typename MM1::Mma::SharedStorage mm1; + }; + + union { + typename MM0::Mma::SharedStorage mm0; + SharedStorageAfterMM0 after_mm0; + typename MM1::DefaultEpilogue::SharedStorage epilogue; + }; + + CUTLASS_DEVICE typename MM1::DefaultEpilogue::SharedStorage& + epilogue_shared_storage() { + return epilogue; + } + }; + + struct SharedStorageEpilogueInLoop : ScalingCoefs { + struct SharedStorageAfterMM0 { + // Everything here might be overwritten during MM0 + union { + typename MM0::BiasLoader::SmemTile bias; + typename MM0::AccumulatorSharedStorage si; + }; + typename MM1::Mma::SharedStorage mm1; + typename MM1::DefaultEpilogue::SharedStorage epilogue; + }; + + union { + typename MM0::Mma::SharedStorage mm0; + SharedStorageAfterMM0 after_mm0; + }; + + CUTLASS_DEVICE typename MM1::DefaultEpilogue::SharedStorage& + epilogue_shared_storage() { + return after_mm0.epilogue; + } + }; + + using SharedStorage = typename cutlass::platform::conditional< + kSingleValueIteration || kKeepOutputInRF, + SharedStorageEpilogueAtEnd, + SharedStorageEpilogueInLoop>::type; + + static bool __host__ check_supported(Params const& p) { + CHECK_ALIGNED_PTR(p.query_ptr, kAlignmentQ); + CHECK_ALIGNED_PTR(p.key_ptr, kAlignmentK); + CHECK_ALIGNED_PTR(p.value_ptr, kAlignmentV); + if (kSupportsBias) { + CHECK_ALIGNED_PTR(p.attn_bias_ptr, kAlignmentQ); + XFORMERS_CHECK( + p.num_batches <= 1 || p.bias_strideB % kAlignmentQ == 0, + "attn_bias is not correctly aligned (strideB)"); + XFORMERS_CHECK( + p.num_heads <= 1 || p.bias_strideH % kAlignmentQ == 0, + "attn_bias is not correctly aligned (strideH)"); + XFORMERS_CHECK( + 
p.bias_strideM % kAlignmentQ == 0, + "attn_bias is not correctly aligned"); + } + XFORMERS_CHECK( + p.q_strideM % kAlignmentQ == 0, + "query is not correctly aligned (strideM)"); + XFORMERS_CHECK( + p.k_strideM % kAlignmentK == 0, + "key is not correctly aligned (strideM)"); + XFORMERS_CHECK( + p.v_strideM % kAlignmentV == 0, + "value is not correctly aligned (strideM)"); + XFORMERS_CHECK( + p.num_heads <= 1 || p.q_strideH % kAlignmentQ == 0, + "query is not correctly aligned (strideH)"); + XFORMERS_CHECK( + p.num_heads <= 1 || p.k_strideH % kAlignmentK == 0, + "key is not correctly aligned (strideH)"); + XFORMERS_CHECK( + p.num_heads <= 1 || p.v_strideH % kAlignmentV == 0, + "value is not correctly aligned (strideH)"); + XFORMERS_CHECK( + p.custom_mask_type < NumCustomMaskTypes, + "invalid value for `custom_mask_type`"); + return true; + } + + static void CUTLASS_DEVICE attention_kernel(Params& p) { + // In this block, we will only ever: + // - read query[query_start:query_end, :] + // - write to output[query_start:query_end, :] + + extern __shared__ char smem_buffer[]; + SharedStorage& shared_storage = *((SharedStorage*)smem_buffer); + auto& m_prime = shared_storage.m_prime; + auto& s_prime = shared_storage.s_prime; + auto& mi = shared_storage.mi; + auto& out_rescale = shared_storage.out_rescale; + const uint32_t query_start = blockIdx.x * kQueriesPerBlock; + + static_assert(kQueriesPerBlock < kNumWarpsPerBlock * kWarpSize, ""); + if (thread_id() < kQueriesPerBlock) { + s_prime[thread_id()] = accum_t(0); + out_rescale[thread_id()] = accum_t(1.0); + m_prime[thread_id()] = + -cutlass::platform::numeric_limits::infinity(); + mi[thread_id()] = -cutlass::platform::numeric_limits::infinity(); + } + typename MM1::Mma::FragmentC accum_o; + accum_o.clear(); + + auto createOutputIter = [&](int col) -> typename MM1::OutputTileIterator { + using OutputTileIterator = typename MM1::OutputTileIterator; + return OutputTileIterator( + typename OutputTileIterator::Params{(int32_t)p.o_strideM}, + p.output_ptr, + typename OutputTileIterator::TensorCoord{ + p.num_queries, p.head_dim_value}, + thread_id(), + {0, col}); + }; + + auto createOutputAccumIter = [&](int col) -> + typename MM1::OutputTileIteratorAccum { + using OutputTileIteratorAccum = typename MM1::OutputTileIteratorAccum; + return OutputTileIteratorAccum( + typename OutputTileIteratorAccum::Params{ + (int32_t)(p.head_dim_value * p.num_heads)}, + p.output_accum_ptr, + typename OutputTileIteratorAccum::TensorCoord{ + p.num_queries, p.head_dim_value}, + thread_id(), + {0, col}); + }; + +#ifdef HAS_PYTORCH + curandStatePhilox4_32_10_t curand_state_init; + if (kSupportsDropout && p.use_dropout) { + const auto seeds = at::cuda::philox::unpack(p.rng_engine_inputs); + + // each element of the attention matrix P with shape + // (batch_sz, n_heads, n_queries, n_keys) is associated with a single + // offset in RNG sequence. we initialize the RNG state with offset that + // starts at the beginning of a (n_queries, n_keys) matrix for this + // block's batch_id and head_id + // initializing rng state is very expensive, so we run once per kernel, + // rather than once per iteration. each iteration takes a copy of the + // initialized RNG state and offsets it as needed. 
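+    // Worked example (illustrative values, not from the original comment): with
+    // num_heads = 4, num_queries = 128, num_keys = 256, the block handling
+    // batch_id = 1, head_id = 2 uses
+    //   dropout_batch_head_rng_offset = (1 * 4 + 2) * 128 * 256 = 196608,
+    // i.e. the start of that (batch, head) pair's own (n_queries, n_keys) slice
+    // of the Philox offset space.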
+ curand_init( + std::get<0>(seeds), + 0, + std::get<1>(seeds) + p.dropout_batch_head_rng_offset, + &curand_state_init); + } +#endif + + // Iterate through keys + for (int32_t iter_key_start = 0; iter_key_start < p.num_keys; + iter_key_start += kKeysPerBlock) { + int32_t problem_size_0_m = + cutlass::fast_min((int32_t)kQueriesPerBlock, p.num_queries); + int32_t problem_size_0_n = cutlass::fast_min( + int32_t(kKeysPerBlock), p.num_keys - iter_key_start); + int32_t const& problem_size_0_k = p.head_dim; + int32_t const& problem_size_1_n = p.head_dim_value; + int32_t const& problem_size_1_k = problem_size_0_n; + + auto prologueV = [&](int blockN) { + typename MM1::Mma::IteratorB iterator_V( + typename MM1::IteratorB::Params{typename MM1::LayoutB(p.v_strideM)}, + p.value_ptr + iter_key_start * p.v_strideM, + {problem_size_1_k, problem_size_1_n}, + thread_id(), + cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN}); + MM1::Mma::prologue( + shared_storage.after_mm0.mm1, + iterator_V, + thread_id(), + problem_size_1_k); + }; + + __syncthreads(); // Need to have shared memory initialized, and `m_prime` + // updated from end of prev iter + // + // MATMUL: Q.K_t + // + // Computes the block-matrix product of: + // (a) query[query_start:query_end, :] + // with + // (b) key[iter_key_start:iter_key_start + kKeysPerBlock] + // and stores that into `shared_storage.si` + // + + // Compute threadblock location + cutlass::gemm::GemmCoord tb_tile_offset = {0, 0, 0}; + + cutlass::MatrixCoord tb_offset_A{ + tb_tile_offset.m() * MM0::Mma::Shape::kM, tb_tile_offset.k()}; + + cutlass::MatrixCoord tb_offset_B{ + tb_tile_offset.k(), tb_tile_offset.n() * MM0::Mma::Shape::kN}; + + // Construct iterators to A and B operands + typename MM0::IteratorA iterator_A( + typename MM0::IteratorA::Params( + typename MM0::MmaCore::LayoutA(p.q_strideM)), + p.query_ptr, + {problem_size_0_m, problem_size_0_k}, + thread_id(), + tb_offset_A); + + typename MM0::IteratorB iterator_B( + typename MM0::IteratorB::Params( + typename MM0::MmaCore::LayoutB(p.k_strideM)), + p.key_ptr + iter_key_start * p.k_strideM, + {problem_size_0_k, problem_size_0_n}, + thread_id(), + tb_offset_B); + + auto my_warp_id = warp_uniform(warp_id()); + auto my_lane_id = lane_id(); + + // Construct thread-scoped matrix multiply + typename MM0::Mma mma( + shared_storage.mm0, thread_id(), my_warp_id, my_lane_id); + + typename MM0::Mma::FragmentC accum; + + accum.clear(); + + auto gemm_k_iterations = + (problem_size_0_k + MM0::Mma::Shape::kK - 1) / MM0::Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum); + __syncthreads(); + + if (kPreloadV) { + prologueV(0); + } else { + MM1::Mma::drain_cp_asyncs(); + } + + typename MM0::Mma::Operator::IteratorC::TensorCoord + iteratorC_tile_offset = { + (tb_tile_offset.m() * MM0::Mma::WarpCount::kM) + + (my_warp_id % MM0::Mma::WarpCount::kM), + (tb_tile_offset.n() * MM0::Mma::WarpCount::kN) + + (my_warp_id / MM0::Mma::WarpCount::kM)}; + + // multiply by scaling factor + if (kSupportsBias) { + accum = + cutlass::multiplies()(p.scale, accum); + } + + // apply attention bias if applicable + if (kSupportsBias && p.attn_bias_ptr != nullptr) { + // load bias tile Bij into shared memory + typename MM0::BiasLoader::GmemTileIterator bias_iter( + {cutlass::layout::RowMajor(p.bias_strideM)}, + // attn_bias_pointer points to matrix of size (n_queries, n_keys) + // for the relevant batch_id and head_id + p.attn_bias_ptr + query_start * p.bias_strideM + 
iter_key_start, + {problem_size_0_m, problem_size_0_n}, + thread_id()); + cutlass::TensorRef bias_tensor_ref( + shared_storage.after_mm0.bias.data(), + cutlass::layout::RowMajor(MM0::ThreadblockShape::kN)); + typename MM0::BiasLoader::SmemTileIterator smem_tile_iter( + bias_tensor_ref, thread_id()); + MM0::BiasLoader::load(bias_iter, smem_tile_iter); + + // Pij += Bij, Pij is in register fragment and Bij is in shared memory + auto lane_offset = MM0::AccumLambdaIterator::get_lane_offset( + my_lane_id, my_warp_id, iteratorC_tile_offset); + MM0::AccumLambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) {}, + [&](int accum_m, int accum_n, int idx) { + if (accum_m < problem_size_0_m && accum_n < problem_size_0_n) { + accum[idx] += bias_tensor_ref.at({accum_m, accum_n}); + } + }, + [&](int accum_m) {}); + } + + // Mask out last if causal + // This is only needed if upper-right corner of current query / key block + // intersects the mask Coordinates of upper-right corner of current block + // is y=query_start x=min(iter_key_start + kKeysPerBlock, num_keys)) The + // first masked element is x = y + offset -> query_start + offset There is + // intersection (and we need to mask) if min(iter_key_start + + // kKeysPerBlock, num_keys)) >= query_start + offset + if (p.custom_mask_type && + cutlass::fast_min(iter_key_start + kKeysPerBlock, p.num_keys) >= + (query_start + p.causal_diagonal_offset)) { + auto query_start = blockIdx.x * kQueriesPerBlock; + auto lane_offset = MM0::AccumLambdaIterator::get_lane_offset( + my_lane_id, my_warp_id, iteratorC_tile_offset); + int32_t last_col; + MM0::AccumLambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) { + // last absolute col is (last absolute query + offset) + // last local col is (last absolute query + offset - + // iter_key_start) + last_col = query_start + accum_m + p.causal_diagonal_offset - + iter_key_start; + }, + [&](int accum_m, int accum_n, int idx) { + if (accum_n > last_col) { + accum[idx] = + -cutlass::platform::numeric_limits::infinity(); + } + }, + [&](int accum_m) {}); + } + // Update `mi` from accum stored in registers + // Also does accum[i] <- exp(accum[i] - mi) + iterative_softmax( + accum_o, + accum, + mi, + m_prime, + s_prime, + out_rescale, + shared_storage.addition_storage, + my_lane_id, + thread_id(), + my_warp_id, + p.num_keys - iter_key_start, + iter_key_start == 0, + iteratorC_tile_offset, + kSupportsBias ? 1.0f : p.scale); + + // Output results to shared-memory + int warp_idx_mn_0 = my_warp_id % + (MM0::Mma::Base::WarpCount::kM * MM0::Mma::Base::WarpCount::kN); + auto output_tile_coords = cutlass::MatrixCoord{ + warp_idx_mn_0 % MM0::Mma::Base::WarpCount::kM, + warp_idx_mn_0 / MM0::Mma::Base::WarpCount::kM}; + + MM0::B2bGemm::accumToSmem( + shared_storage.after_mm0.si, accum, my_lane_id, output_tile_coords); + + __syncthreads(); + +#ifdef HAS_PYTORCH + // apply dropout (if applicable) after we've written Pij to smem. + // dropout is applied by multiplying each element of Pij by: + // - 0 with probability dropout_p + // - 1 / (1 - dropout_p) with probability 1 - dropout_p + // + // for backward purposes we want to be able to map each element of the + // attention matrix to the same random uniform number as the one we used + // in forward, without needing to use the same iteration order or having + // to store the dropout matrix. 
its possible to do this in registers but + // it ends up being very slow because each thread having noncontiguous + // strips of the Pij tile means we have to skip around a lot, and also + // have to generate a single random number at a time + if (kSupportsDropout && p.use_dropout) { + auto si = shared_storage.after_mm0.si.accum_ref(); + // each thread handles a contiguous sequence of elements from Sij, all + // coming from the same row. the reason they have to come from the same + // row is that the sampling random numbers from a contiguous random + // number sequence is much more efficient than jumping around, and the + // linear offset of each element of S (the global matrix) maps to an + // offset in a random number sequence. for S, the end of a row and the + // beginning of the next have adjacent offsets, but for Sij, this is not + // necessarily the case. + const int num_threads = blockDim.x * blockDim.y * blockDim.z; + const int threads_per_row = + cutlass::fast_min(num_threads / problem_size_0_m, problem_size_0_n); + const int elts_per_thread = cutlass::round_nearest( + cutlass::ceil_div(problem_size_0_n, threads_per_row), 4); + + const int thread_i = thread_id() / threads_per_row; + const int thread_start_j = + (thread_id() % threads_per_row) * elts_per_thread; + + if (thread_i < problem_size_0_m && thread_start_j < problem_size_0_n) { + curandStatePhilox4_32_10_t curand_state = curand_state_init; + skipahead( + static_cast( + (query_start + thread_i) * p.num_keys_absolute + + (iter_key_start + thread_start_j)), + &curand_state); + const float dropout_scale = 1.0 / (1.0 - p.dropout_prob); + + // apply dropout scaling to elements this thread is responsible for, + // in chunks of 4 + for (int sij_start_col_idx = thread_start_j; sij_start_col_idx < + cutlass::fast_min(thread_start_j + elts_per_thread, + problem_size_0_n); + sij_start_col_idx += 4) { + const float4 rand_uniform_quad = curand_uniform4(&curand_state); + + CUTLASS_PRAGMA_UNROLL + for (int quad_idx = 0; quad_idx < 4; ++quad_idx) { + si.at({thread_i, sij_start_col_idx + quad_idx}) *= + static_cast( + dropout_scale * + ((&rand_uniform_quad.x)[quad_idx] > p.dropout_prob)); + } + } + } + __syncthreads(); // p.use_dropout should have same value kernel-wide + } +#endif + + // + // MATMUL: Attn . V + // Run the matmul `attn @ V` for a block of attn and V. + // `attn` is read from shared memory (in `shared_storage_si`) + // `V` is read from global memory (with iterator_B) + // + + const int64_t nBlockN = kSingleValueIteration + ? 
1 + : ceil_div( + (int64_t)problem_size_1_n, int64_t(MM1::ThreadblockShape::kN)); + for (int blockN = 0; blockN < nBlockN; ++blockN) { + int gemm_k_iterations = + (problem_size_1_k + MM1::Mma::Shape::kK - 1) / MM1::Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add and store it in accum + // (in registers) + if (!kPreloadV) { + __syncthreads(); // we share shmem between mma and epilogue + } + + typename MM1::Mma::IteratorB iterator_V( + typename MM1::IteratorB::Params{typename MM1::LayoutB(p.v_strideM)}, + p.value_ptr + iter_key_start * p.v_strideM, + {problem_size_1_k, problem_size_1_n}, + thread_id(), + cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN}); + typename MM1::Mma mma_pv( + // operand A: Pij_dropped in shared memory + shared_storage.after_mm0.si.accum_ref(), + // operand B: shared memory staging area for Vj, which is loaded + // from global memory + shared_storage.after_mm0.mm1.operand_B_ref(), + (int)thread_id(), + (int)my_warp_id, + (int)my_lane_id); + mma_pv.set_prologue_done(kPreloadV); + if (!kKeepOutputInRF) { + accum_o.clear(); + } + mma_pv(gemm_k_iterations, accum_o, iterator_V, accum_o); + __syncthreads(); + + if (kPreloadV && !kSingleValueIteration && blockN + 1 < nBlockN) { + prologueV(blockN + 1); + } + + if (!kKeepOutputInRF) { + MM1::Mma::drain_cp_asyncs(); + DISPATCH_BOOL( + iter_key_start == 0, kIsFirst, ([&] { + DISPATCH_BOOL( + (iter_key_start + kKeysPerBlock) >= p.num_keys, + kIsLast, + ([&] { + using DefaultEpilogue = typename MM1::DefaultEpilogue; + using DefaultOp = + typename MM1::DefaultConfig::EpilogueOutputOp; + using ElementCompute = typename DefaultOp::ElementCompute; + using EpilogueOutputOp = typename cutlass::epilogue:: + thread::MemoryEfficientAttentionNormalize< + typename cutlass::platform::conditional< + kIsLast::value, + output_t, + output_accum_t>::type, + output_accum_t, + DefaultOp::kCount, + typename DefaultOp::ElementAccumulator, + ElementCompute, + kIsFirst::value, + kIsLast::value, + cutlass::Array>; + using Epilogue = typename cutlass::epilogue::threadblock:: + EpiloguePipelined< + typename DefaultEpilogue::Shape, + typename MM1::Mma::Operator, + DefaultEpilogue::kPartitionsK, + typename cutlass::platform::conditional< + kIsLast::value, + typename MM1::OutputTileIterator, + typename MM1::OutputTileIteratorAccum>::type, + typename DefaultEpilogue:: + AccumulatorFragmentIterator, + typename DefaultEpilogue::WarpTileIterator, + typename DefaultEpilogue::SharedLoadIterator, + EpilogueOutputOp, + typename DefaultEpilogue::Padding, + DefaultEpilogue::kFragmentsPerIteration, + true, // IterationsUnroll + typename MM1::OutputTileIteratorAccum // Read + // iterator + >; + + int col = blockN * MM1::Mma::Shape::kN; + auto source_iter = createOutputAccumIter(col); + auto dest_iter = call_conditional< + kIsLast::value, + decltype(createOutputIter), + decltype(createOutputAccumIter)>:: + apply(createOutputIter, createOutputAccumIter, col); + EpilogueOutputOp rescale(s_prime, out_rescale); + Epilogue epilogue( + shared_storage.epilogue_shared_storage(), + thread_id(), + my_warp_id, + my_lane_id); + epilogue(rescale, dest_iter, accum_o, source_iter); + })); + })); + if (!kSingleValueIteration) { + __syncthreads(); + } + } + } + __syncthreads(); // we modify `m_prime` after + } + + if (kKeepOutputInRF) { + constexpr bool kIsFirst = true; + constexpr bool kIsLast = true; + using DefaultEpilogue = typename MM1::DefaultEpilogue; + using DefaultOp = typename MM1::DefaultConfig::EpilogueOutputOp; + using ElementCompute = typename 
DefaultOp::ElementCompute; + using EpilogueOutputOp = + typename cutlass::epilogue::thread::MemoryEfficientAttentionNormalize< + output_t, // output + output_accum_t, // source + DefaultOp::kCount, + typename DefaultOp::ElementAccumulator, // accum + output_accum_t, // compute + kIsFirst, + kIsLast, + cutlass::Array>; + using Epilogue = + typename cutlass::epilogue::threadblock::EpiloguePipelined< + typename DefaultEpilogue::Shape, + typename MM1::Mma::Operator, + DefaultEpilogue::kPartitionsK, + typename MM1::OutputTileIterator, // destination + typename DefaultEpilogue::AccumulatorFragmentIterator, + typename DefaultEpilogue::WarpTileIterator, + typename DefaultEpilogue::SharedLoadIterator, + EpilogueOutputOp, + typename DefaultEpilogue::Padding, + DefaultEpilogue::kFragmentsPerIteration, + true, // IterationsUnroll + typename MM1::OutputTileIteratorAccum // source tile + >; + auto dest_iter = createOutputIter(0); + EpilogueOutputOp rescale(s_prime, out_rescale); + Epilogue epilogue( + shared_storage.epilogue_shared_storage(), + thread_id(), + warp_id(), + lane_id()); + MM1::Mma::drain_cp_asyncs(); + epilogue(rescale, dest_iter, accum_o); + } + + // 7. Calculate logsumexp + // To make the backward easier, we pad logsumexp with `inf` + // this avoids a few bound checks, and is not more expensive during fwd + static_assert(kQueriesPerBlock < kNumWarpsPerBlock * kWarpSize, ""); + if (p.logsumexp_ptr && thread_id() < kQueriesPerBlock) { + auto lse_dim = ceil_div((int32_t)p.num_queries, kAlignLSE) * kAlignLSE; + constexpr float kLog2e = 1.4426950408889634074; // log_2(e) = M_LOG2E + if (thread_id() < p.num_queries) { + p.logsumexp_ptr[thread_id()] = accum_t(mi[thread_id()] / kLog2e) + + cutlass::fast_log(accum_t(s_prime[thread_id()])); + } else if (thread_id() < lse_dim) { + p.logsumexp_ptr[thread_id()] = + cutlass::platform::numeric_limits::infinity(); + } + } + } + + template + CUTLASS_DEVICE static void iterative_softmax( + typename WarpIteratorC::Fragment& frag_o, // output so far + typename WarpIteratorC::Fragment& frag, + cutlass::Array& mi, + cutlass::Array& m_prime, + cutlass::Array& s_prime, + cutlass::Array& out_rescale, + cutlass::Array& + addition_storage, + int8_t lane_id, + int8_t thread_id, + int8_t warp_id, + int max_col, + bool is_first, + typename WarpIteratorC::TensorCoord const& tile_offset, + float scaling) { + /* Iterates on the accumulator and corresponding position on result matrix + + (1) Update `mi[r]` to the max value of the row `r` + (2) In a second iteration do the following: + (a) accum <- exp(accum - mi) + (b) m_prime <- exp(m_prime - mi) + (c) s_prime <- s_prime * m_prime + sum(accum) + + All of this is done on registers, before we store all of this + on shared memory for the next matmul with Value. 
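+    Note: `frag` is pre-scaled by `scaling * log2(e)` and the exponentials below use
+    exp2f, so `mi` and `m_prime` are tracked in base-2 units; the logsumexp written at
+    the end of the kernel converts back by dividing `mi` by log2(e).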
+ */ + using Fragment = typename WarpIteratorC::Fragment; + using LambdaIterator = typename DefaultMmaAccumLambdaIterator< + WarpIteratorC, + accum_t, + kWarpSize>::Iterator; + // Convert to `accum_t` (rather than double) + constexpr float kLog2e = 1.4426950408889634074; // log_2(e) = M_LOG2E + + static_assert(kQueriesPerBlock % kNumWarpsPerBlock == 0, ""); + static constexpr int kLinesPerWarp = kQueriesPerBlock / kNumWarpsPerBlock; + + frag = cutlass::multiplies()(scaling * kLog2e, frag); + + auto lane_offset = + LambdaIterator::get_lane_offset(lane_id, warp_id, tile_offset); + + // First update `mi` to the max per-row + { + accum_t max; + LambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) { + max = -cutlass::platform::numeric_limits::infinity(); + }, + [&](int accum_m, int accum_n, int idx) { + if (accum_n < max_col) { + max = cutlass::fast_max(max, frag[idx]); + } + }, + [&](int accum_m) { + // Having 4x atomicMax seems faster than reduce within warp + // first... + atomicMaxFloat(&mi[accum_m], max); + }); + } + + // Make sure we all share the update values for `mi` + __syncthreads(); + + // Doing this `exp` is quite expensive. Let's + // split it across the warps + bool restore_mi_to_minus_inf = false; + if (lane_id < kLinesPerWarp) { + int id = warp_id * kLinesPerWarp + lane_id; + auto m_prime_id = m_prime[id]; + auto mi_id = mi[id]; + bool changed = m_prime_id < mi_id; // `false` if both are -inf + if (changed) { + auto m_prime_exp = exp2f(m_prime_id - mi_id); + out_rescale[id] = m_prime_exp; + s_prime[id] *= m_prime_exp; + } else { + // Only when bias is enabled, it's possible that all the first values + // of attention are masked to `-inf`. In that case we want to avoid + // `nan = exp2f(-inf - (-inf))` so we temporarily set `mi` to 0 + if (kSupportsBias && + mi_id == -cutlass::platform::numeric_limits::infinity()) { + restore_mi_to_minus_inf = true; + mi[id] = 0.0f; + } + out_rescale[id] = 1.0f; + } + } + __syncthreads(); // Update output fragments + if (kKeepOutputInRF && !is_first) { + accum_t line_rescale; + LambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) { line_rescale = out_rescale[accum_m]; }, + [&](int accum_m, int accum_n, int idx) { + frag_o[idx] = frag_o[idx] * line_rescale; + }, + [&](int accum_m) {}); + } + // Update accum_m, accum_n, ... + { + accum_t mi_row, total_row; + LambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) { mi_row = mi[accum_m]; }, + [&](int accum_m, int accum_n, int idx) { + frag[idx] = + (accum_n < max_col) ? 
exp2f(frag[idx] - mi_row) : accum_t(0.0); + }, + [&](int accum_m) {}); + LambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) { total_row = 0.0; }, + [&](int accum_m, int accum_n, int idx) { total_row += frag[idx]; }, + [&](int accum_m) { + if (LambdaIterator::reduceSameRow( + lane_id, total_row, [](accum_t a, accum_t b) { + return a + b; + })) { + // NOTE: we could atomically add `total_row` to `s_prime`, but + // it's faster (and deterministic) to avoid atomics here + addition_storage + [accum_m + kQueriesPerBlock * tile_offset.column()] = + total_row; + } + }); + } + __syncthreads(); + if (lane_id < kLinesPerWarp) { + int id = warp_id * kLinesPerWarp + lane_id; + accum_t total_row = s_prime[id]; + if (restore_mi_to_minus_inf) { + // Restore `mi`, see above when we set `restore_mi_to_minus_inf=true` + mi[id] = -cutlass::platform::numeric_limits::infinity(); + } else { + m_prime[id] = mi[id]; + } + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < MM0::MmaCore::WarpCount::kN; ++i) { + total_row += addition_storage[id + kQueriesPerBlock * i]; + } + s_prime[id] = total_row; + } + } + + static CUTLASS_DEVICE int8_t lane_id() { + return threadIdx.x; + } + static CUTLASS_DEVICE int8_t warp_id() { + return threadIdx.y; + } + static CUTLASS_DEVICE int16_t thread_id() { + return threadIdx.x + threadIdx.y * blockDim.x; + } +}; + +template +__global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm) + attention_kernel_batched_impl(typename AK::Params p) { + if (!p.advance_to_block()) { + return; + } + AK::attention_kernel(p); +} + +template +__global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm) + attention_kernel_batched(typename AK::Params params); diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/piped_subprocess.py b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/piped_subprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..536bdb4305d112343f5bf348ff6a93e68a520d0b --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/piped_subprocess.py @@ -0,0 +1,144 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +from typing import List +import torch +import subprocess +import sys +import tempfile +import os +import numpy as np + + +TORCH_DTYPE_NAME = { + torch.float32: "f32", + torch.float16: "f16", + torch.bfloat16: "b16" +} +NAME_TORCH_DTYPE = {v: k for k, v in TORCH_DTYPE_NAME.items()} + +def _tensor_from_storage(tensor: torch.Tensor, dtype) -> torch.Tensor: + # PyTorch >= 2.0 + if hasattr(tensor, 'untyped_storage'): + return torch.tensor([], dtype=dtype).set_(tensor.untyped_storage()) + return torch.tensor([], dtype=dtype).set_(tensor.storage().untyped()) + +class PipedSubprocess: + def __init__(self, binary: str) -> None: + self.binary = binary + self.tempdir_ctx = tempfile.TemporaryDirectory() + + def __enter__(self) -> "PipedSubprocess": + self.subp = subprocess.Popen(self.binary, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=sys.stderr, text=True, bufsize=0) + self.tempdir = self.tempdir_ctx.__enter__() + self.file_counter = 0 + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + self.tempdir_ctx.__exit__(exc_type, exc_val, exc_tb) + + def temp_filename(self, suffix: str) -> str: + self.file_counter += 1 + return os.path.join(self.tempdir, f"{self.file_counter}{suffix}") + + def write(self, *args) -> None: + for a in args: + self.subp.stdin.write(str(a) + " ") + + def writeTensor(self, tensor: torch.Tensor, name: str, stride_names: List[str]) -> None: + print(f"Py ->C++: {TORCH_DTYPE_NAME[tensor.dtype]}:{name}") + tensor_u8 = _tensor_from_storage(tensor, torch.uint8) + self.write("tensor_begin", f"{TORCH_DTYPE_NAME[tensor.dtype]}:{name}", tensor_u8.shape[0]) + filename = self.temp_filename(f"{name}.tensor") + assert tensor.storage_offset() == 0 + with open(filename, "wb+") as fd: + fd.write(bytes(tensor_u8.numpy())) + self.write("file", filename) + self.write("tensor_end") + + for stride_name, stride_value in zip(stride_names, tensor.stride()): + self.write(stride_name, stride_value) + + def readTensor(self, name, stride_name, shape) -> torch.Tensor: + tmpfile = self.temp_filename(f"{name}.tensor") + self.write("tmpfile", tmpfile) + + self.readExpect("tensor_begin") + dtype_str, name = self.read().split(":") + print(f"C++->Py : {dtype_str}:{name}") + u8len = int(self.read()) + dtype = NAME_TORCH_DTYPE[dtype_str] + + self.readExpect("file") + self.readExpect(tmpfile) + + with open(tmpfile, "rb") as fd: + data = fd.read(u8len) + # `np.array` is not strictly needed, but avoids a torch warning + tensor_u8 = torch.frombuffer(np.array(data), dtype=torch.uint8, count=u8len) + self.readExpect("tensor_end") + + tensor = _tensor_from_storage(tensor_u8, dtype) + strides = [] + for sn in stride_name: + self.readExpect(sn) + strides.append(int(self.read())) + if len(strides) != shape: + strides.append(1) + assert len(strides) == len(shape), name + return torch.as_strided(tensor, shape, strides) + + def readNamed(self, name: str): + 
self.readExpect(name) + return self.read() + + def readExpect(self, what: str) -> None: + r = self.read() + if r != what: + raise ValueError(f"Read {r} but expected {what}") + + def read(self): + read_all = [] + # Skip initial whitespace + while True: + r = self.subp.stdout.read(1) + if r not in [' ', "\n"]: + read_all.append(r) + break + # Read data + while True: + r = self.subp.stdout.read(1) + if r in [' ', "\n"]: + break + read_all.append(r) + return ''.join(read_all) + diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/transform/tile_smem_loader.h b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/transform/tile_smem_loader.h new file mode 100644 index 0000000000000000000000000000000000000000..048c1e019b07495a9f529e9e3c678471a98f5e98 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/41_fused_multi_head_attention/transform/tile_smem_loader.h @@ -0,0 +1,90 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once + +#include +#include "cutlass/aligned_buffer.h" +#include "cutlass/array.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/numeric_types.h" +#include "cutlass/transform/pitch_linear_thread_map.h" +#include "cutlass/transform/threadblock/predicated_tile_iterator.h" +#include "cutlass/transform/threadblock/regular_tile_iterator.h" + +template < + typename scalar_t, // scalar type + typename ThreadblockTileShape, // size of tile to load + int Threads, // number of participating threads + int ElementsPerAccess> // thread access width in elements +class TileSmemLoader { + public: + using SmemTile = + cutlass::AlignedBuffer; + + using ThreadMap = cutlass::transform::PitchLinearStripminedThreadMap< + cutlass::layout::PitchLinearShape< + ThreadblockTileShape::kColumn, // contiguous + ThreadblockTileShape::kRow>, // strided + Threads, // Threads + ElementsPerAccess>; // ElementsPerAccess + + using GmemTileIterator = + cutlass::transform::threadblock::PredicatedTileIterator< + ThreadblockTileShape, // Shape + scalar_t, // Element + cutlass::layout::RowMajor, // Layout + 0, // AdvanceRank + ThreadMap>; // ThreadMap + + using SmemTileIterator = cutlass::transform::threadblock::RegularTileIterator< + ThreadblockTileShape, // Shape + scalar_t, // Element + cutlass::layout::RowMajor, // Layout + 0, // AdvanceRank + ThreadMap>; // ThreadMap + + using Fragment = typename GmemTileIterator::Fragment; + + /// load a tile from global memory into shared memory + CUTLASS_DEVICE + static void load( + GmemTileIterator tile_load_iter, + SmemTileIterator tile_store_iter) { + Fragment tb_frag; + tb_frag.clear(); + tile_load_iter.load(tb_frag); + tile_store_iter.store(tb_frag); + + __syncthreads(); + } +}; diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/threadblock/default_bias_act_epilogue_tensor_op.h b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/threadblock/default_bias_act_epilogue_tensor_op.h new file mode 100644 index 0000000000000000000000000000000000000000..8f76fe1f30a2987061571fc948723ff6d22a4c89 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/threadblock/default_bias_act_epilogue_tensor_op.h @@ -0,0 +1,154 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. + + The epilogue rearranges the result of a matrix product through shared memory to match canonical + tensor layouts in global memory. Epilogues support conversion and reduction operations. + +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/array.h" + +#include "cutlass/gemm/gemm.h" + +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/epilogue/thread/linear_combination_clamp.h" +#include "cutlass/epilogue/thread/conversion_op.h" +#include "cutlass/epilogue/thread/reduction_op.h" + +#include "cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h" + +#include "cutlass/epilogue/warp/fragment_iterator_tensor_op.h" +#include "cutlass/epilogue/warp/fragment_iterator_complex_tensor_op.h" +#include "cutlass/epilogue/warp/tile_iterator_tensor_op.h" +#include "cutlass/epilogue/warp/tile_iterator_tensor_op_mixed.h" +#include "cutlass/epilogue/threadblock/default_thread_map_tensor_op.h" +#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" +#include "cutlass/epilogue/threadblock/shared_load_iterator.h" +#include "cutlass/epilogue/threadblock/shared_load_iterator_mixed.h" + +// #include "cutlass/epilogue/threadblock/epilogue.h" +#include "cutlass/epilogue/threadblock/interleaved_epilogue.h" + +#include "fused_bias_act_epilogue.h" +#include "../warp/fused_bias_act_fragment_iterator_tensor_op.h" +#include "output_tile_thread_map_for_fused_bias.h" +#include "default_thread_map_tensor_op_for_fused_bias.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + + +//////////////////////////////////////////////////////////////////////////////// + +/// Defines sensible defaults for epilogues for TensorOps. 
+template < + typename Shape_, + typename WarpMmaTensorOp_, + int PartitionsK, + typename OutputOp_, + int ElementsPerAccess +> +struct DefaultFusedBiasActEpilogueTensorOp { + + using Shape = Shape_; + using WarpMmaTensorOp = WarpMmaTensorOp_; + static int const kPartitionsK = PartitionsK; + using OutputOp = OutputOp_; + static int const kElementsPerAccess = ElementsPerAccess; + using ElementOutput = typename OutputOp::ElementOutput; + using LayoutC = typename WarpMmaTensorOp::LayoutC; + using ElementAccumulator = typename WarpMmaTensorOp::ElementC; + + // + // Thread map + // + + using OutputTileThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapTensorOpForFusedBias< + Shape, + typename WarpMmaTensorOp::Shape, + kPartitionsK, + ElementOutput, + kElementsPerAccess + >::Type; + + using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator< + OutputTileThreadMap, + ElementOutput + >; + + using AccumulatorFragmentIterator = typename std::conditional::value, + cutlass::epilogue::warp::FragmentIteratorComplexTensorOp< + typename WarpMmaTensorOp::Shape, + typename WarpMmaTensorOp::Policy::Operator::Shape, + typename WarpMmaTensorOp::Policy::Operator::ElementC, + typename WarpMmaTensorOp::Policy::Operator::FragmentC, + LayoutC>, + cutlass::epilogue::warp::FusedBiasActFragmentIteratorTensorOp< + typename WarpMmaTensorOp::Shape, + typename WarpMmaTensorOp::Policy::Operator::Shape, + typename WarpMmaTensorOp::Policy::Operator::ElementC, + typename WarpMmaTensorOp::Policy::Operator::FragmentC, + LayoutC> >::type; + + // + // Define the epilogue + // + using Epilogue = cutlass::epilogue::threadblock::FusedBiasActEpilogue< + Shape, + WarpMmaTensorOp, + kPartitionsK, + OutputTileIterator, + AccumulatorFragmentIterator, + OutputOp + >; +}; + +//////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/threadblock/default_thread_map_tensor_op_for_fused_bias.h b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/threadblock/default_thread_map_tensor_op_for_fused_bias.h new file mode 100644 index 0000000000000000000000000000000000000000..35fe758ae707180456cc695cbe0ffabb6cd065c5 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/threadblock/default_thread_map_tensor_op_for_fused_bias.h @@ -0,0 +1,113 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief + +*/ + +#pragma once + +#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/layout/pitch_linear.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Defines the optimal thread map for TensorOp accumulator layouts +template < + typename ThreadblockShape_, + typename WarpShape_, + int PartitionsK, + typename Element_, + int ElementsPerAccess +> +struct DefaultThreadMapTensorOpForFusedBias { + + using ThreadblockShape = ThreadblockShape_; + using WarpShape = WarpShape_; + static int const kPartitionsK = PartitionsK; + using Element = Element_; + static int const kElementsPerAccess = ElementsPerAccess; + + // + // Definitions + // + + struct Detail { + + /// Tensor Operations fundamentally perform operations on 8 rows + static int const kTensorOpRows = 8; + static int const kWarpSize = 32; + + static_assert( + !(ThreadblockShape::kM % WarpShape::kM) && + !(ThreadblockShape::kM % WarpShape::kM), "Divisibility"); + + /// Number of warps + using WarpCount = gemm::GemmShape< + ThreadblockShape::kM / WarpShape::kM, + ThreadblockShape::kN / WarpShape::kN, + kPartitionsK + >; + + /// Number of participating threads + static int const kThreads = WarpCount::kCount * kWarpSize; + }; + + // + // ThreadMap + // + + /// ThreadMap to be used by epilogue::PredicatedTileIterator satisfying concept OutputTileThreadMap + using Type = OutputTileOptimalThreadMapBiasAct < + OutputTileShape, + OutputTileShape<1, WarpShape::kM / Detail::kTensorOpRows, 1, 1, WarpShape::kM / Detail::kTensorOpRows>, + Detail::kThreads, + kElementsPerAccess, + sizeof_bits::value + >; +}; + +/////////////////////////////////////////////////////////////////////////////// +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git 
a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/threadblock/fused_bias_act_epilogue.h b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/threadblock/fused_bias_act_epilogue.h new file mode 100644 index 0000000000000000000000000000000000000000..04d988cca4aab92c67bab562a8120bb32a68ff43 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/threadblock/fused_bias_act_epilogue.h @@ -0,0 +1,213 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. + + The epilogue rearranges the result of a matrix product through shared memory to match canonical + tensor layouts in global memory. Epilogues support conversion and reduction operations. 
+ +*/ + +#pragma once +#include "cutlass/cutlass.h" +#include CUDA_STD_HEADER(cassert) +#include "cutlass/numeric_types.h" +#include "cutlass/array.h" +#include "cutlass/layout/vector.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/tensor_coord.h" +#include "cutlass/aligned_buffer.h" +#include "cutlass/functional.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/transform/pitch_linear_thread_map.h" +#include "cutlass/transform/threadblock/regular_tile_iterator.h" +#include "cutlass/epilogue/threadblock/epilogue_base.h" +#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Epilogue operator without splitk +template < + typename Shape_, ///< Shape of threadblock tile (concept: GemmShape) + typename WarpMmaOperator_, ///< Warp-level MMA operator (concept: gemm::warp::MmaTensorOp) + int PartitionsK, ///< Number of partitions of the K dimension + typename OutputTileIterator_, ///< Tile iterator reading and writing output tensors + typename AccumulatorFragmentIterator_, ///< Fragment iterator selecting accumulators + typename OutputOp_ ///< Output operator +> +class FusedBiasActEpilogue { + +public: + + using Shape = Shape_; + using WarpMmaOperator = WarpMmaOperator_; + static int const kPartitionsK = PartitionsK; + using OutputTileIterator = OutputTileIterator_; + using AccumulatorFragmentIterator = AccumulatorFragmentIterator_; + using OutputOp = OutputOp_; + + /// Output layout is always row-major + using Layout = layout::RowMajor; + using LongIndex = typename Layout::LongIndex; + + /// The complete warp-level accumulator tile + using AccumulatorTile = typename AccumulatorFragmentIterator::AccumulatorTile; + + /// Output element + using ElementOutput = typename OutputTileIterator::Element; + + /// Output access size + static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess; + + +public: + + + static_assert(OutputTileIterator::kElementsPerAccess, "OutputTileIterator::kElementsPerAccess must not be zero."); + + static_assert(!(OutputTileIterator::Fragment::kElements % OutputTileIterator::kElementsPerAccess), + "Divisibility"); + +public: + + /// Constructor + CUTLASS_DEVICE + FusedBiasActEpilogue( + ){ } + + /// Streams the result to global memory + CUTLASS_DEVICE + void operator()( + OutputOp const &output_op, ///< Output operator + AccumulatorTile &accumulators, ///< Complete warp-level accumulator tile + AccumulatorTile & fused_bias_act_accumlators, + OutputTileIterator source_iterator) { ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles) + + bool need_bias = output_op.is_source_needed(); + + if (need_bias) + compute_source_needed_(output_op, accumulators, fused_bias_act_accumlators, source_iterator); + else + compute_source_no_needed_(output_op, accumulators, fused_bias_act_accumlators); + + + } + + CUTLASS_DEVICE + void operator()( + OutputOp const &output_op, ///< Output operator + AccumulatorTile &accumulators, ///< Complete warp-level accumulator tile + AccumulatorTile & fused_bias_act_accumlators) { ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles) + + compute_source_no_needed_(output_op, accumulators, fused_bias_act_accumlators); + } + + CUTLASS_DEVICE + void compute_source_needed_( + OutputOp const &output_op, ///< Output operator + AccumulatorTile 
&accumulators, ///< Complete warp-level accumulator tile + AccumulatorTile & fused_bias_act_accumlators, + OutputTileIterator source_iterator) { ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles) + + typename OutputTileIterator::Fragment source_fragment; + + + source_fragment.clear(); + + AccumulatorFragmentIterator accum_fragment_iterator(accumulators); + AccumulatorFragmentIterator fused_bias_act_fragment_iterator(fused_bias_act_accumlators); + + CUTLASS_PRAGMA_UNROLL + for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) { + + source_iterator.load(source_fragment); + ++source_iterator; + + typename AccumulatorFragmentIterator::Fragment accum_fragment; + + accum_fragment_iterator.load(accum_fragment); + ++accum_fragment_iterator; + + typename AccumulatorFragmentIterator::Fragment fused_bias_act_fragment; + fused_bias_act_fragment = output_op(accum_fragment, source_fragment); + + fused_bias_act_fragment_iterator.store(fused_bias_act_fragment); + ++fused_bias_act_fragment_iterator; + } + } + + CUTLASS_DEVICE + void compute_source_no_needed_( + OutputOp const &output_op, ///< Output operator + AccumulatorTile &accumulators, ///< Complete warp-level accumulator tile + AccumulatorTile & fused_bias_act_accumlators) { ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles) + + + AccumulatorFragmentIterator accum_fragment_iterator(accumulators); + AccumulatorFragmentIterator fused_bias_act_fragment_iterator(fused_bias_act_accumlators); + + + + CUTLASS_PRAGMA_UNROLL + for (int iter = 0; iter < AccumulatorFragmentIterator::kIterations; ++iter) { + + typename AccumulatorFragmentIterator::Fragment accum_fragment; + + accum_fragment_iterator.load(accum_fragment); + ++accum_fragment_iterator; + + typename AccumulatorFragmentIterator::Fragment fused_bias_act_fragment; + fused_bias_act_fragment = output_op(accum_fragment); + + fused_bias_act_fragment_iterator.store(fused_bias_act_fragment); + ++fused_bias_act_fragment_iterator; + } + } + +}; + + + + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/threadblock/output_tile_thread_map_for_fused_bias.h b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/threadblock/output_tile_thread_map_for_fused_bias.h new file mode 100644 index 0000000000000000000000000000000000000000..c92307f5774652b54141970f7bf8ee878e01a060 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/threadblock/output_tile_thread_map_for_fused_bias.h @@ -0,0 +1,311 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Metaprogram for determining the mapping of output elements to threads for epilogue tiles. + + +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/array.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/fast_math.h" + +#include "cutlass/epilogue/threadblock/output_tile_thread_map.h" +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +/// RowArrangement determines how one or more warps cover a region of consecutive rows. +template < + typename Shape, + int WarpsRemaining, + int ElementsPerAccess, + int ElementSize, + bool Is2dTile +> +struct RowArrangementBiasAct; + +/// RowArrangement in which each warp's access is a 1D tiled arrangement. +template < + typename Shape, + int WarpsRemaining, + int ElementsPerAccess, + int ElementSize +> +struct RowArrangementBiasAct { + static int const kWarpSize = 32; + static int const kElementsPerAccess = ElementsPerAccess; + static int const kElementSize = ElementSize; + + static int const kIterationsRow = 1; + static int const kDeltaRow = 1; + static int const kIterationsColumn = Shape::kColumn / kElementsPerAccess / kWarpSize; + static int const kDeltaColumn = kWarpSize * kElementsPerAccess; + + static int const kAccessWidth = kWarpSize; + static int const kAccessRows = 1; + static int const kWarpPartitionsRow = 1; + static int const kWarpPartitionsColumn = WarpsRemaining; +}; + +/// RowArrangement in which each warp's access is a 2D tiled arrangement. 
+template < + typename Shape, + int WarpsRemaining, + int ElementsPerAccess, + int ElementSize +> +struct RowArrangementBiasAct { + + static int const kMemoryAccessSize = 4;//128; + static int const kWarpSize = 32; + + static int const kElementsPerAccess = ElementsPerAccess; + static int const kElementSize = ElementSize; + + struct Detail { + static int const kShapeRow = Shape::kRow / WarpsRemaining; + static int const kShapeWidth = Shape::kColumn / kElementsPerAccess; + + static int const kTargetMemoryAccessWidth = + kMemoryAccessSize / (kElementsPerAccess * kElementSize / 8); + + static int const kTargetAccessRows = kWarpSize / kTargetMemoryAccessWidth; + }; + + static int const kAccessWidth = + (Detail::kTargetAccessRows > Detail::kShapeRow ? + kWarpSize / Detail::kShapeRow + : const_min( + Detail::kShapeWidth, + const_min(kWarpSize, kMemoryAccessSize / (kElementsPerAccess * kElementSize / 8)) + )); + + static int const kAccessRows = + (Detail::kTargetAccessRows > Detail::kShapeRow ? + Detail::kShapeRow + : const_min(Shape::kRow, kWarpSize / kAccessWidth)); + + static int const kIterationsRow = Detail::kShapeRow / kAccessRows; + static int const kDeltaRow = kAccessRows; + + static int const kIterationsColumn = Detail::kShapeWidth / kAccessWidth; + static int const kDeltaColumn = kAccessWidth * kElementsPerAccess; + + static_assert( kAccessWidth * kElementsPerAccess <= Shape::kColumn, "Accessing too many elements per access"); + static_assert( kIterationsColumn > 0, "Iteration Count Column must be > 0" ); + static_assert( kIterationsRow > 0, "Iteration Count Row must be > 0" ); + + static int const kWarpPartitionsRow = 1; + static int const kWarpPartitionsColumn = 1; +}; + +} + +//////////////////////////////////////////////////////////////////////////////// + +/// Template metaprogram for partitioning a 4D space across warps to achieve several performance +/// objectives: +/// +/// - coalesced memory accesses in units of 16 Byte lines +/// - minimal address arithmetic +/// - minimal predicate calculations +/// +template < + typename Shape_, + typename Count_, + int Threads, + int ElementsPerAccess, + int ElementSize +> +struct OutputTileOptimalThreadMapBiasAct { + + using Shape = Shape_; + using Count = Count_; + + static int const kWarpSize = 32; + static int const kThreads = Threads; + static int const kWarpCount = kThreads / kWarpSize; + + static int const kElementsPerAccess = ElementsPerAccess; + static int const kElementSize = ElementSize; + + // + // Metaprogram computation + // + + struct Detail { + + // Clusters + static int const kIterationsCluster = + ((Shape::kCluster > kWarpCount) ? + Shape::kCluster / kWarpCount + : 1); + + static int const kDeltaCluster = + ((Shape::kCluster > kWarpCount) ? + Shape::kRow * Count::kRow * Shape::kGroup * Count::kGroup * Shape::kCluster / kIterationsCluster + : 1); + + static int const kCompactedDeltaCluster = + ((Shape::kCluster > kWarpCount) ? + Shape::kRow * Shape::kGroup * Shape::kCluster / kIterationsCluster + : 1); + + static int const kWarpPartitionsCluster = + ((Shape::kCluster > kWarpCount) ? + kWarpCount + : kWarpCount / Shape::kCluster); + + static int const kWarpsRemainingForGroups = + ((Shape::kCluster > kWarpCount) ? 1 : kWarpCount / Shape::kCluster); + + // Groups + static int const kIterationsGroup = + ((Shape::kGroup > kWarpsRemainingForGroups) ? + Shape::kGroup / kWarpsRemainingForGroups + : 1); + + static int const kDeltaGroup = + ((Shape::kGroup > kWarpsRemainingForGroups) ? 
+ Shape::kRow * Count::kRow * Shape::kGroup / kIterationsGroup + : 1); + + static int const kCompactedDeltaGroup = + ((Shape::kGroup > kWarpsRemainingForGroups) ? + Shape::kRow * Shape::kGroup / kIterationsGroup + : 1); + + static int const kWarpPartitionsGroup = + ((Shape::kGroup > kWarpsRemainingForGroups) ? + 1 + : kWarpsRemainingForGroups / Shape::kGroup); + + static int const kWarpsRemainingForRows = + ((Shape::kGroup > kWarpsRemainingForGroups) ? + 1 + : kWarpsRemainingForGroups / Shape::kGroup); + + // Rows + using RowArrangement = detail::RowArrangementBiasAct< + Shape, + kWarpsRemainingForRows, + kElementsPerAccess, + kElementSize, + (Shape::kRow > kWarpsRemainingForRows) + >; + + // Warp partitions + using WarpPartitions = OutputTileShape< + RowArrangement::kWarpPartitionsColumn, + RowArrangement::kWarpPartitionsRow, + kWarpPartitionsGroup, + kWarpPartitionsCluster, + 1>; + + static int const kAccessWidth = RowArrangement::kAccessWidth; + static int const kAccessRows = RowArrangement::kAccessRows; + }; + + // + // Output + // + + using Iterations = OutputTileShape< + Detail::RowArrangement::kIterationsColumn, + Detail::RowArrangement::kIterationsRow, + Detail::kIterationsGroup, + Detail::kIterationsCluster, + 1>; + + using Delta = OutputTileShape< + Detail::RowArrangement::kDeltaColumn, + Detail::RowArrangement::kDeltaRow, + Detail::kDeltaGroup, + Detail::kDeltaCluster, + 1>; + + /// Initial offset function + CUTLASS_HOST_DEVICE + static MatrixCoord initial_offset(int thread_idx) { + + int warp_idx = thread_idx / kWarpSize; + int lane_idx = thread_idx % kWarpSize; + + // Compute warp location + int cluster_idx = warp_idx / Detail::WarpPartitions::kCluster; + int residual_cluster = warp_idx % Detail::WarpPartitions::kCluster; + + int group_idx = residual_cluster / Detail::WarpPartitions::kGroup; + int residual_group = residual_cluster % Detail::WarpPartitions::kGroup; + + int row_idx = residual_group / Detail::WarpPartitions::kRow; + int col_idx = residual_group % Detail::WarpPartitions::kRow; + + // Compute per-lane offset + int lane_row_offset = lane_idx / Detail::kAccessWidth; + int lane_col_offset = lane_idx % Detail::kAccessWidth; + + // Compute coordinate in output space + int cluster_offset = cluster_idx * Shape::kRow * Count::kRow * Shape::kGroup * Count::kGroup; + int group_offset = group_idx * Shape::kRow * Count::kRow; + int row_offset = row_idx * Iterations::kRow * Detail::kAccessRows; + int column_offset = col_idx * Iterations::kColumn * Detail::kAccessWidth * kElementsPerAccess; + + return MatrixCoord( + cluster_offset + group_offset + row_offset + lane_row_offset, + (column_offset + lane_col_offset) * kElementsPerAccess + ); + } + +}; + + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/warp/fused_bias_act_fragment_iterator_tensor_op.h b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/warp/fused_bias_act_fragment_iterator_tensor_op.h new file mode 100644 index 0000000000000000000000000000000000000000..88259d102e3967905e9a6189a8fd3d6b56ba67ee --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/warp/fused_bias_act_fragment_iterator_tensor_op.h @@ -0,0 
+1,189 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief This defines a "fragment" iterator for visiting the fragments of an accumulator tile + that participate in one warp-level store operation. + + Typically, the accumulator tile is the largest single block of register-backed storage + within the kernel. Storing it to memory is best accomplished by partitioning it into + smaller tiles and storing these sequentially. + + Round trips through shared memory during the Epilogue phase require partitioning, as + shared memory capacity is typically insufficient for a threadblock's total accumulator + size. 
+*/ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/layout/matrix.h" + +#include "cutlass/epilogue/warp/tensor_op_policy.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace warp { + +//////////////////////////////////////////////////////////////////////////////// + +/// +template < + typename WarpShape, ///< shape of warp-level GEMM (concept: MatrixShape) + typename OperatorShape, ///< matrix multiply operation shape (concept: gemm::GemmShape) + typename OperatorElementC, ///< matrix multiply operation data type (concept: data type) + typename OperatorFragmentC, ///< matrix multiply operation fragment (concept: Array) + typename Layout ///< target shared memory layout +> +class FusedBiasActFragmentIteratorTensorOp; + +//////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for row-major shared memory +template < + typename WarpShape_, ///< shape of the warp-level GEMM tile + typename OperatorShape_, ///< matrix multiply operation shape (concept: gemm::GemmShape) + typename OperatorElementC_, ///< matrix multiply operation data type (concept: data type) + typename OperatorFragmentC_ ///< matrix multiply operation fragment (concept: Array) +> +class FusedBiasActFragmentIteratorTensorOp { +public: + + using WarpShape = WarpShape_; + using OperatorShape = OperatorShape_; + using OperatorElementC = OperatorElementC_; + using OperatorFragmentC = OperatorFragmentC_; + using Layout = layout::RowMajor; + + using Policy = TensorOpPolicy; + + /// This is the fragment size produced by one access of the iterator. + using Fragment = Array< + OperatorElementC, + Policy::OperatorCount::kColumn * Policy::kElementsPerAccess>; + + /// This is the complete warp-level accumulator tile. 
+ using AccumulatorTile = Array< + OperatorElementC, + OperatorFragmentC::kElements * Policy::OperatorCount::kRow * Policy::OperatorCount::kColumn>; + + using OutputAccumulatorTile = AccumulatorTile; + + /// Number of times this iterator can be incremented + static int const kIterations = Policy::kIterations; + +private: + + /// Internal access type + using AccessType = Array; + +private: + + // + // Data members + // + + /// Accumulator tile + AccessType *accumulators_; + + /// Internal index + int index_; + +public: + + /// Constructs an iterator + CUTLASS_HOST_DEVICE + FusedBiasActFragmentIteratorTensorOp(AccumulatorTile &accum): + accumulators_(reinterpret_cast(&accum)), + index_(0) { + } + + /// Increments + CUTLASS_HOST_DEVICE + FusedBiasActFragmentIteratorTensorOp &operator++() { + ++index_; + return *this; + } + + /// Decrements + CUTLASS_HOST_DEVICE + FusedBiasActFragmentIteratorTensorOp &operator--() { + --index_; + return *this; + } + + /// Loads a fragment from the referenced part of the accumulator tile + CUTLASS_HOST_DEVICE + void load(Fragment &frag, int index_offset = 0) const { + + int index = index_ + index_offset; + + AccessType *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < Policy::OperatorCount::kColumn; ++n) { + + int accumulator_access_offset = + index + n * Policy::kAccumulatorColumnStride / Policy::kElementsPerAccess; + + frag_ptr[n] = accumulators_[accumulator_access_offset]; + } + } + /// Stores a fragment from the referenced part of the accumulator tile + CUTLASS_HOST_DEVICE + void store(Fragment &frag, int index_offset = 0) const { + + int index = index_ + index_offset; + + AccessType *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < Policy::OperatorCount::kColumn; ++n) { + + int accumulator_access_offset = + index + n * Policy::kAccumulatorColumnStride / Policy::kElementsPerAccess; + + accumulators_[accumulator_access_offset] = frag_ptr[n]; + } + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace epilogue +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/fixed_impl/gemm/warp/mma_tensor_op_fragment_iterator_without_output_op.h b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/fixed_impl/gemm/warp/mma_tensor_op_fragment_iterator_without_output_op.h new file mode 100644 index 0000000000000000000000000000000000000000..cd4417bbb56834835d1dfc9f812a8e2ff628e9df --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/fixed_impl/gemm/warp/mma_tensor_op_fragment_iterator_without_output_op.h @@ -0,0 +1,427 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/array.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/numeric_conversion.h" + +namespace cutlass { +namespace gemm { +namespace warp { + + +//////////////////////////////////////////////////////////////////////////////// + +template < + /// Size of the matrix to load (concept: MatrixShape) + typename Shape_, + /// Size of the accumulation tile shape (concept: MatrixShape) + typename AccumulatorShape_, + /// KBlocks columns to compute residual + int KBlocksColumn_, + /// Accumulator Element type + typename ElementAccumulator_, + /// Element type + typename Element_, + /// Layout of operand in memory + typename Layout_, + /// Shape of one matrix product operation (concept: MatrixShape) + typename InstructionShape_, + /// Whether beta is zero + bool IsBetaZero_ > +class MmaTensorOpPureFragmentIterator; + + +// Partial specialization for col-major accumulator tile +// And Element type is the same as Accumulator Element type + +template < + /// Shape of warp tile to load (concept: MatrixShape) + typename Shape_, + /// Shape of the warp accumulation tile (concept: MatrixShape) + typename AccumulatorShape_, + /// KBlocks columns to compute residual + int KBlocksColumn_, + /// Element type + typename Element_, + /// Shape of one matrix product operation (concept: MatrixShape) + typename InstructionShape_> +class MmaTensorOpPureFragmentIterator { + public: + + /// Shape of warp tile to load (concept: MatrixShape) + using Shape = Shape_; + + /// Shape of the warp accumulation tile (concept: MatrixShape) + using AccumulatorShape = AccumulatorShape_; + + /// KBlocks columns to compute residual + static int const kKBlockColumn = KBlocksColumn_; + + /// Element type + using Element = Element_; + + /// Layout of source tile + using Layout = cutlass::layout::ColumnMajor; + + /// Shape of one matrix product operation (concept: MatrixShape) + using InstructionShape = InstructionShape_; + + /// Whether beta is zero + static bool const IsBetaZero = true; + + /// Number of participating threads + static int const kThreads = 32; + + /// Internal structure of 
iterator - made public to enable introspection + struct Policy { + static_assert( + !(Shape::kRow % InstructionShape::kM) && + !(Shape::kColumn % InstructionShape::kN), + "Shape of warp-level Mma must be divisible by operator shape."); + static_assert( + !(AccumulatorShape::kRow % Shape::kRow) && + !(AccumulatorShape::kColumn % Shape::kColumn), + "Shape of Warp Accumulator must be divisible by warp shape."); + static_assert( + !(kKBlockColumn % Shape::kColumn), + "KBlock size must be divisible by warp shape."); + + /// Number of times this iterator can be incremented + static int const kIterations = AccumulatorShape::kCount / Shape::kCount; + }; + +private: + + static int const kElementsPerAccess = InstructionShape::kM * InstructionShape::kN / kThreads; + + /// Number of mma operations performed by a warp + using MmaIterations = MatrixShape; + /// Number of mma operations performed by the entire accumulator + using AccumulatorIterations = MatrixShape; + + /// Number of K iterations + static int const kKBlockIterations = (AccumulatorShape::kColumn + kKBlockColumn - 1) / kKBlockColumn; + static int const kResidualColumn = AccumulatorShape::kColumn - (kKBlockIterations - 1) * kKBlockColumn; + static int const kKBlockColumnIterations = kKBlockColumn / Shape::kColumn + * (AccumulatorShape::kRow / Shape::kRow); + static int const kResidualIndex = kResidualColumn / Shape::kColumn + * (AccumulatorShape::kRow / Shape::kRow); + +public: + + // + // Derived quantities + // + + /// Fragment object holding a thread's part of a tile + /// This is the fragment size produced by one access of the iterator. + using Fragment = Array; + + /// Accumulator Fragment object + using AccumulatorFragment = Array; + + +private: + + /// Internal access type + using AccessType = Array; + +private: + // + // Data members + // + + /// Accumulator tile + AccessType const *accumulators_; + + /// Internal index + int index_; + + /// Used to access residual tile first + bool is_residual_tile_; + +public: + /// Constructs an iterator + CUTLASS_HOST_DEVICE + MmaTensorOpPureFragmentIterator(AccumulatorFragment const &accum) + : accumulators_(reinterpret_cast(&accum)), + index_(0), is_residual_tile_(true) {} + + /// Add offset + CUTLASS_HOST_DEVICE + void add_offset(int index_offset) { + index_ += index_offset; + if(is_residual_tile_ && index_ >= kKBlockColumnIterations) { + index_ = index_ - kKBlockColumnIterations + kResidualIndex; + is_residual_tile_ = false; + } + } + + /// Increments + CUTLASS_HOST_DEVICE + MmaTensorOpPureFragmentIterator &operator++() { + add_offset(1); + return *this; + } + + /// Decrements + CUTLASS_HOST_DEVICE + MmaTensorOpPureFragmentIterator &operator--() { + add_offset(-1); + return *this; + } + + /// Loads a fragment from the referenced part of the accumulator tile + CUTLASS_HOST_DEVICE + void load(Fragment &frag) const { + + AccessType src_fragment; + src_fragment.clear(); + + + AccessType *frag_ptr = reinterpret_cast(&frag); + + int index_m = (index_ * MmaIterations::kRow) % AccumulatorIterations::kRow; + int index_n = (index_ * MmaIterations::kRow) / AccumulatorIterations::kRow + * MmaIterations::kColumn; + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < MmaIterations::kColumn; n++) { + for (int m = 0; m < MmaIterations::kRow; m++) { + int accumulator_access_offset = + (n + index_n) * AccumulatorIterations::kRow + m + index_m; + + frag_ptr[n * MmaIterations::kRow + m].clear(); + if(!(is_residual_tile_ && index_ >= kResidualIndex)) + frag_ptr[n * MmaIterations::kRow + m] = 
accumulators_[accumulator_access_offset]; + // frag_ptr[n * MmaIterations::kRow + m] = output_op(accumulators_[accumulator_access_offset], src_fragment); + } + } + } + +}; + +// Partial specialization for row-major accumulator tile + +template < + /// Shape of warp tile to load (concept: MatrixShape) + typename Shape_, + /// Shape of the warp accumulation tile (concept: MatrixShape) + typename AccumulatorShape_, + /// KBlocks columns to compute residual + int KBlocksColumn_, + /// Accumulator Element type + typename ElementAccumulator_, + /// Element type + typename Element_, + /// Shape of one matrix product operation (concept: MatrixShape) + typename InstructionShape_> +class MmaTensorOpPureFragmentIterator { + public: + + /// Shape of warp tile to load (concept: MatrixShape) + using Shape = Shape_; + + /// Shape of the warp accumulation tile (concept: MatrixShape) + using AccumulatorShape = AccumulatorShape_; + + /// KBlocks columns to compute residual + static int const kKBlockColumn = KBlocksColumn_; + + /// Accumulator Element type + using ElementAccumulator = ElementAccumulator_; + + /// Element type + using Element = Element_; + + /// Layout of source tile + using Layout = cutlass::layout::RowMajor; + + /// Shape of one matrix product operation (concept: MatrixShape) + using InstructionShape = InstructionShape_; + + /// Whether beta is zero + static bool const IsBetaZero = true; + + /// Number of participating threads + static int const kThreads = 32; + + /// Internal structure of iterator - made public to enable introspection + struct Policy { + static_assert( + !(Shape::kRow % InstructionShape::kM) && + !(Shape::kColumn % InstructionShape::kN), + "Shape of warp-level Mma must be divisible by operator shape."); + static_assert( + !(AccumulatorShape::kRow % Shape::kRow) && + !(AccumulatorShape::kColumn % Shape::kColumn), + "Shape of Warp Accumulator must be divisible by warp shape."); + static_assert( + !(kKBlockColumn % Shape::kColumn), + "KBlock size must be divisible by warp shape."); + + /// Number of times this iterator can be incremented + static int const kIterations = AccumulatorShape::kCount / Shape::kCount; + }; + +private: + + static int const kElementsPerAccess = InstructionShape::kM * InstructionShape::kN / kThreads; + + /// Number of mma operations performed by a warp + using MmaIterations = MatrixShape; + /// Number of mma operations performed by the entire accumulator + using AccumulatorIterations = MatrixShape; + + /// Number of K iterations + static int const kKBlockIterations = (AccumulatorShape::kColumn + kKBlockColumn - 1) / kKBlockColumn; + static int const kResidualColumn = AccumulatorShape::kColumn - (kKBlockIterations - 1) * kKBlockColumn; + static int const kKBlockColumnIterations = kKBlockColumn / Shape::kColumn + * (AccumulatorShape::kRow / Shape::kRow); + static int const kResidualIndex = kResidualColumn / Shape::kColumn + * (AccumulatorShape::kRow / Shape::kRow); + +public: + + // + // Derived quantities + // + + /// Fragment object holding a thread's part of a tile + /// This is the fragment size produced by one access of the iterator. 
+ using Fragment = Array; + + /// Accumulator Fragment object + using AccumulatorFragment = Array; + + +private: + + /// Internal access type + using AccessType = Array; + using FragmentAccessType = Array; + +private: + // + // Data members + // + + /// Accumulator tile + AccessType const *accumulators_; + + /// Internal index + int index_; + + /// Used to access residual tile first + bool is_residual_tile_; + +public: + /// Constructs an iterator + CUTLASS_HOST_DEVICE + MmaTensorOpPureFragmentIterator(AccumulatorFragment const &accum) + : accumulators_(reinterpret_cast(&accum)), + index_(0), is_residual_tile_(true) {} + + /// Add offset + CUTLASS_HOST_DEVICE + void add_offset(int index_offset) { + index_ += index_offset; + if(is_residual_tile_ && index_ >= kKBlockColumnIterations) { + index_ = index_ - kKBlockColumnIterations + kResidualIndex; + is_residual_tile_ = false; + } + } + + /// Increments + CUTLASS_HOST_DEVICE + MmaTensorOpPureFragmentIterator &operator++() { + add_offset(1); + return *this; + } + + /// Decrements + CUTLASS_HOST_DEVICE + MmaTensorOpPureFragmentIterator &operator--() { + add_offset(-1); + return *this; + } + + /// Loads a fragment from the referenced part of the accumulator tile + CUTLASS_HOST_DEVICE + void load(Fragment &frag) const { + + + FragmentAccessType src_fragment; + src_fragment.clear(); + + FragmentAccessType *frag_ptr = reinterpret_cast(&frag); + + int index_m = (index_ * MmaIterations::kRow) % AccumulatorIterations::kRow; + int index_n = (index_ * MmaIterations::kRow) / AccumulatorIterations::kRow + * MmaIterations::kColumn; + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < MmaIterations::kRow; m++) { + for (int n = 0; n < MmaIterations::kColumn; n++) { + int accumulator_access_offset = + (m + index_m) * AccumulatorIterations::kColumn + n + index_n; + + frag_ptr[m * MmaIterations::kColumn + n].clear(); + if(!(is_residual_tile_ && index_ >= kResidualIndex)) + frag_ptr[m * MmaIterations::kColumn + n] = (accumulators_[accumulator_access_offset]); + } + } + } + +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace gemm +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_all_code.py b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_all_code.py new file mode 100644 index 0000000000000000000000000000000000000000..15b10296d1bdf9c59b21fd91497881601bb91617 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_all_code.py @@ -0,0 +1,129 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. 
+# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +import gen_turing_and_volta as api_generator +import gen_sample as sample_creater +import gen_cmake as cmake_creater +import gen_verify as verify_creater +import gen_device as b2b_fused_generator +import replace_fix_impl_header + +import argparse +import os +import json + + +parser = argparse.ArgumentParser(description="Generates Fused Multi-GEMM CUTLASS Kernels") +parser.add_argument("--config-file", default="config.json", help="JSON file containing configuration to generate") +parser.add_argument("--gen-name", default="FusedMultiGemmForward", help="Specific the output name") +parser.add_argument("--output-dir", default="", help="Specifies the output dir") +parser.add_argument("--cutlass-dir", default="", help="Specifies the dependent CUTLASS repo dir") +parser.add_argument("--gen-include-cutlass-dir", default="", help="Specifies the generated CUTLASS code include dir, if needed.") +args = parser.parse_args() + +gen_name = args.gen_name + +cutlass_deps_dir = args.cutlass_dir + +output_dir = args.output_dir +output_dir += "/" + +cutlass_deps_root = args.gen_include_cutlass_dir +if cutlass_deps_root == '': + cutlass_deps_root = cutlass_deps_dir + "/include/" +cutlass_deps_root +='/' + + +if not os.path.exists(output_dir): + os.makedirs(output_dir) + +if not os.path.exists(output_dir + "/" + "auto_gen"): + os.mkdir(output_dir + "/" + "auto_gen") + +if not os.path.exists(output_dir + "/" + "fixed_impl"): + os.mkdir(output_dir + "/" + "fixed_impl" ) + +if not os.path.exists(output_dir + "/" + "sample"): + os.mkdir(output_dir + "/" + "sample" ) + +if not os.path.exists(output_dir + "/" + "auto_gen" + "/" + "device"): + os.mkdir(output_dir + "/" + "auto_gen" + "/" + "device") +if not os.path.exists(output_dir + "/" + "auto_gen" + "/" + "kernel"): + os.mkdir(output_dir + "/" + "auto_gen" + "/" + "kernel") +if not os.path.exists(output_dir + "/" + "auto_gen" + "/" + "threadblock"): + os.mkdir(output_dir + "/" + "auto_gen" + "/" + "threadblock") + +with open(args.config_file, 'r') as infile: + gemm_info_dict = json.load(infile) + +keys = sorted(gemm_info_dict.keys()) +fuse_gemm_info = [gemm_info_dict[k] for k in keys] + + +for_cutlass_gen_user_include_header_file = [ + cutlass_deps_root + "cutlass/epilogue/thread/linear_combination_leaky_relu.h", + cutlass_deps_root + "cutlass/epilogue/thread/linear_combination.h", +] + +for_fused_wrapper = [ + cutlass_deps_root + "cutlass/epilogue/thread/linear_combination_leaky_relu.h", 
+ cutlass_deps_root + "cutlass/epilogue/thread/linear_combination.h", + "auto_gen/device/" + gen_name + ".h", + cutlass_deps_root + "cutlass/gemm/device/gemm_batched.h", + cutlass_deps_root + "cutlass/cutlass.h", +] + +# Copy fixed implementation to the output directory +fix_impl = replace_fix_impl_header.replace_fix_impl("../fixed_impl/", output_dir +"/fixed_impl/", cutlass_deps_root) +fix_impl.gen_code() + +auto_gen_output_dir = output_dir + "/auto_gen/" +project_root = "" +turing_plus = b2b_fused_generator.gen_device(fuse_gemm_info, gen_name, for_cutlass_gen_user_include_header_file, cutlass_deps_root, project_root, auto_gen_output_dir) +turing_plus.gen_code(75, 'hmma1688', False) + +api = api_generator.gen_one_API(fuse_gemm_info, gen_name, for_fused_wrapper, output_dir) +api.gen_code() + +# Generate C++ sample +os.system("cp ../leaky_bias.h " + output_dir + "/sample/") +os.system("cp ../utils.h " + output_dir + "/sample/") + +sample_dir = output_dir + "/sample/" +sample = sample_creater.gen_test(fuse_gemm_info, gen_name, for_cutlass_gen_user_include_header_file, sample_dir) +sample.gen_cpp_sample() + +cmake_gen = cmake_creater.gen_build_sys(cutlass_deps_dir, output_dir) +cmake_gen.gen_code() + +verify = verify_creater.gen_verify(fuse_gemm_info, gen_name, for_fused_wrapper, output_dir) +verify.gen_code() diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_cmake.py b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_cmake.py new file mode 100644 index 0000000000000000000000000000000000000000..c99ac9ae2c383ae030621280eac899f453ab517b --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_cmake.py @@ -0,0 +1,131 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +class gen_build_sys: + def __init__(self, cutlass_deps_dir, output_dir = "../"): + self.output_dir = output_dir + self.cutlass_deps_dir = cutlass_deps_dir + + def gen_top(self): + code = "" + code += '''\ +# Auto Generated code - Do not edit. + +cmake_minimum_required(VERSION 3.8) +project(CUTLASS_MULTI_GEMMS LANGUAGES CXX CUDA) +find_package(CUDAToolkit) +set(CUDA_PATH ${{CUDA_TOOLKIT_ROOT_DIR}}) +set(CUTLASS_PATH \"{cutlass_deps_dir}/include\") +set(CUTLASS_UTIL_PATH \"{cutlass_deps_dir}/tools/util/include\") +list(APPEND CMAKE_MODULE_PATH ${{CUDAToolkit_LIBRARY_DIR}}) +'''.format(cutlass_deps_dir=self.cutlass_deps_dir) + + code += '''\ +set(GPU_ARCHS \"\" CACHE STRING + \"List of GPU architectures (semicolon-separated) to be compiled for.\") + +if(\"${GPU_ARCHS}\" STREQUAL \"\") + set(GPU_ARCHS \"70\") +endif() + +foreach(arch ${GPU_ARCHS}) + set(CMAKE_CUDA_FLAGS \"${CMAKE_CUDA_FLAGS} -gencode arch=compute_${arch},code=sm_${arch}\") + if(SM STREQUAL 70 OR SM STREQUAL 75) + set(CMAKE_C_FLAGS \"${CMAKE_C_FLAGS} -DWMMA\") + set(CMAKE_CXX_FLAGS \"${CMAKE_CXX_FLAGS} -DWMMA\") + set(CMAKE_CUDA_FLAGS \"${CMAKE_CUDA_FLAGS} -DWMMA\") + endif() +endforeach() + +set(CMAKE_C_FLAGS \"${CMAKE_C_FLAGS}\") +set(CMAKE_CXX_FLAGS \"${CMAKE_CXX_FLAGS}\") +set(CMAKE_CUDA_FLAGS \"${CMAKE_CUDA_FLAGS} -Xcompiler -Wall\") + +set(CMAKE_C_FLAGS_DEBUG \"${CMAKE_C_FLAGS_DEBUG} -Wall -O0\") +set(CMAKE_CXX_FLAGS_DEBUG \"${CMAKE_CXX_FLAGS_DEBUG} -Wall -O0\") +set(CMAKE_CUDA_FLAGS_DEBUG \"${CMAKE_CUDA_FLAGS_DEBUG} -O0 -G -Xcompiler -Wall\") + +set(CMAKE_CXX_STANDARD 11) +set(CMAKE_CXX_STANDARD_REQUIRED ON) + +if(CMAKE_CXX_STANDARD STREQUAL \"11\") + set(CMAKE_CUDA_FLAGS \"${CMAKE_CUDA_FLAGS} --expt-extended-lambda\") + set(CMAKE_CUDA_FLAGS \"${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr\") +endif() + +set(CMAKE_CXX_FLAGS \"${CMAKE_CXX_FLAGS} -g -O3\") +set(CMAKE_CUDA_FLAGS \"${CMAKE_CUDA_FLAGS} -Xcompiler -O3\") +set(CMAKE_CUDA_FLAGS \"${CMAKE_CUDA_FLAGS} -Xcompiler=-fno-strict-aliasing\") + +set(COMMON_HEADER_DIRS + ${PROJECT_SOURCE_DIR} + ${CUDAToolkit_INCLUDE_DIRS} +) + +set(COMMON_LIB_DIRS + ${CUDAToolkit_LIBRARY_DIR} +) +list(APPEND COMMON_HEADER_DIRS ${CUTLASS_PATH}) +list(APPEND COMMON_HEADER_DIRS ${CUTLASS_UTIL_PATH}) +''' + code += '''\ +include_directories( + ${COMMON_HEADER_DIRS} +) + +link_directories( + ${COMMON_LIB_DIRS} +) + +add_definitions(-D_GLIBCXX_USE_CXX11_ABI=0) +add_definitions(-DGOOGLE_CUDA=1) + +add_executable(sample + sample/sample.cu + one_api.cu +) +target_link_libraries(sample PRIVATE + -lcudart + -lnvToolsExt + ${CMAKE_THREAD_LIBS_INIT} +) + +if(NOT DEFINED LIB_INSTALL_PATH) + set(LIB_INSTALL_PATH ${CMAKE_CURRENT_BINARY_DIR}) +endif() +''' + return code + + def gen_code(self): + top_code = self.gen_top() + with open(self.output_dir + "CMakeLists.txt", "w") as f: + f.write(top_code) diff --git 
a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_customized_epilogue.py b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_customized_epilogue.py new file mode 100644 index 0000000000000000000000000000000000000000..5abd22c81ee48ae5ce5bbbb88c8c80d3dafe5025 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_customized_epilogue.py @@ -0,0 +1,120 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+# +################################################################################################# + +import ast + +fuse_gemm_info = [ + { + 'epilogue': { + 'tp': 'LeakyRelu', #'CustomizedLeaky_RELU' + 'bias': {'addbias': False, 'bias_tp': 'mat'}, + 'args': [('float', 'leaky_alpha', 1.3), ], + 'func': ''' +y = max(leaky_alpha * x, x) +y = y * x + ''' + } + }, + +] +class AnalysisNodeVisitor(ast.NodeVisitor): + def visit_Import(self,node): + ast.NodeVisitor.generic_visit(self, node) + + def visit_ImportFrom(self,node): + ast.NodeVisitor.generic_visit(self, node) + + def visit_Assign(self,node): + print('Node type: Assign and fields: ', node._fields) + # print('Node type: Assign and targets value: ', node.targets, node.value) + + ast.NodeVisitor.generic_visit(self, node) + + def visit_BinOp(self, node): + print('Node type: BinOp and fields: ', node._fields) + print('node op: ', type(node.op).__name__) + ast.NodeVisitor.generic_visit(self, node) + + def visit_Expr(self, node): + print('Node type: Expr and fields: ', node._fields) + ast.NodeVisitor.generic_visit(self, node) + + def visit_Num(self,node): + print('Node type: Num and fields: ', node._fields) + print('Node type: Num: ', node.n) + + def visit_Name(self,node): + print('Node type: Name and fields: ', node._fields) + print('Node type: Name and fields: ', type(node.ctx).__name__, node.id) + + ast.NodeVisitor.generic_visit(self, node) + + def visit_Str(self, node): + print('Node type: Str and fields: ', node._fields) + +class CodeVisitor(ast.NodeVisitor): + def visit_BinOp(self, node): + if isinstance(node.op, ast.Add): + node.op = ast.Sub() + self.generic_visit(node) + + def visit_Assign(self, node): + print('Assign %s' % node.value) + self.generic_visit(node) + + def visit_Name(self, node): + print("Name:", node.id) + self.generic_visit(node) + + + def visit_FunctionDef(self, node): + print('Function Name:%s'% node.name.op) + self.generic_visit(node) + func_log_stmt = ast.Print( + dest = None, + values = [ast.Str(s = 'calling func: %s' % node.name, lineno = 0, col_offset = 0)], + nl = True, + lineno = 0, + col_offset = 0, + ) + node.body.insert(0, func_log_stmt) + +visitor = AnalysisNodeVisitor() + +code = \ +''' + +a=max(leaky_alpha * x, x +1) + +''' + +visitor.visit(ast.parse(code)) diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_device.py b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_device.py new file mode 100644 index 0000000000000000000000000000000000000000..6cd01ef16bb078e627cfaf910911b2c16baab76e --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_device.py @@ -0,0 +1,469 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. 
Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +from typing import * + +import helper +import gen_ir + +import gen_kernel as gen_ker + + +class gen_device: + def __init__(self, fuse_gemm_info, gen_class_name, user_header_file, cutlass_deps_root, project_root, output_dir = "../"): + self.fuse_gemm_info = fuse_gemm_info + self.raw_gemm_info = fuse_gemm_info + self.b2b_num = len(fuse_gemm_info) + self.user_header_file = user_header_file + self.args = {} + # device arg struct memebr + self.arg_member = [] + self.gen_class_name = gen_class_name + self.gen_kernel_name = gen_class_name + "Kernel" + self.template_args = [] + self.__tempalate_arg_list = {'Stages': int, 'SplitKSerial': bool, 'IsBetaZero': bool, 'AlignmentA': int, 'AlignmentB': int} + + self.file_name = output_dir + "/device/" +gen_class_name +".h" + self.sample_dir = output_dir + + + self.cutlass_deps_root = cutlass_deps_root + self.project_root = project_root + self.this_file_root = output_dir + "/device/" + + self.first_use_1stage = False + + ## gen kernel + self.gen_kernel = gen_ker.gen_kernel(self.template_args, self.gen_class_name, self.b2b_num, output_dir, cutlass_deps_root, project_root) + + + def __check_arg_type(self, temp_arg): + if temp_arg in self.__tempalate_arg_list.keys(): + return self.__tempalate_arg_list[temp_arg] + + find_sub = False + for candidate_arg in self.__tempalate_arg_list.keys(): + if (temp_arg.find(candidate_arg) != -1): + return self.__tempalate_arg_list[candidate_arg] + + return 'typename' + + # def gen_B2b2bGemm_class(): + def set_arch(self, sm_cap, mma_tp): + if sm_cap == 75 or sm_cap == 80 or sm_cap == 86: + self.arch = "cutlass::arch::Sm" + str(sm_cap) + + if mma_tp is 'hmma1688': + self.mma_shape = [16, 8, 8] + self.mma_tp = 'hmma' + elif mma_tp is 'imma8816': + self.mma_tp = 'imma' + self.mma_shape = [8, 8, 16] + else: + return 0 + + def gen_include_header(self): + code = '''\ +/* Auto Generated code - Do not edit.*/ + +#pragma once + +#include \"{cutlass_root}cutlass/cutlass.h\" +#include \"{cutlass_root}cutlass/numeric_types.h\" +#include \"{cutlass_root}cutlass/arch/arch.h\" +#include \"{cutlass_root}cutlass/device_kernel.h\" + +#include \"{cutlass_root}cutlass/gemm/threadblock/threadblock_swizzle.h\" + +#include 
\"{cutlass_root}cutlass/gemm/device/default_gemm_configuration.h\" +#include \"{cutlass_root}cutlass/epilogue/thread/linear_combination_relu.h\" +#include \"{cutlass_root}cutlass/epilogue/thread/linear_combination.h\" + +#include \"{project_root}../kernel/b2b_gemm.h\" +#include \"{project_root}../kernel/default_b2b_gemm.h\" +'''.format(cutlass_root=self.cutlass_deps_root, project_root=self.project_root, this_file_root=self.this_file_root) + include_user_header = "" + for header in self.user_header_file: + include_user_header += "#include \"" + header + "\"\n" + return code + include_user_header + + def gen_code(self, sm_cap, mma_tp, ifprint = True): + self.set_arch(sm_cap, mma_tp) + + self.update_b2b_args() + print(self.fuse_gemm_info) + self.update_b2b_class_template_args() + + func_code = self.gen_all_func() + member_var_code = "private:\n typename B2bGemmKernel::Params params_;\n" + + gen_code = gen_ir.gen_template_class(self.gen_class_name, self.template_args, func_code + member_var_code) + code = self.gen_include_header() + gen_ir.gen_namespace("cutlass", gen_ir.gen_namespace("gemm", gen_ir.gen_namespace("device", gen_code))) + + if ifprint: + print(code) + + print("[INFO]: Gen device code output Dir: is ", self.file_name) + with open(self.file_name, 'w+') as f: + f.write(code) + + + gen_kernel = self.gen_kernel.gen_code(self.first_use_1stage) + print(gen_kernel) + + def update_b2b_class_template_args(self): + for arg in self.args.keys(): + self.template_args.append([self.__check_arg_type(arg), arg, self.args[arg]]) + + def update_b2b_args(self): + + self.args['ElementA'] = helper.type_2_cutlass_type(self.fuse_gemm_info[0]['A_tp']) + self.args['LayoutA'] = helper.type_2_cutlass_type(self.fuse_gemm_info[0]['A_format']) + + cnt = 0 + + warp_M_tile = 32 + + # Determine maximum N_tile + Max_Ntile = 0 + for layer in self.fuse_gemm_info: + n_tile = layer['mnk'][1] + if n_tile > Max_Ntile: + Max_Ntile = n_tile + if Max_Ntile >= 256: + warp_M_tile = 16 + + stages_temp = [] + + for layer in self.fuse_gemm_info: + cnt_str = str(cnt) + B_tp_str= 'ElementB' + cnt_str + B_format_str = 'LayoutB' + cnt_str + C_tp_str= 'ElementC' + cnt_str + C_format_str = 'LayoutC' + cnt_str + Acc_str = 'ElementAccumulator' + cnt_str + + self.args[B_tp_str] = helper.type_2_cutlass_type(layer['B_tp']) + self.args[B_format_str] = helper.type_2_cutlass_type(layer['B_format']) + self.args[C_tp_str] = helper.type_2_cutlass_type(layer['C_tp']) + self.args[C_format_str] = helper.type_2_cutlass_type(layer['C_format']) + self.args[Acc_str] = helper.type_2_cutlass_type(layer['Acc_tp']) + + + mnk = layer['mnk'][:] + + tile_mnk = mnk[:] + + tile_mnk[2] = 32 # force the ktile is 32 + + #N tile gen + if mnk[1] > 1024: + assert(0) + elif mnk[1] > 512: + tile_mnk[1] = 1024 + elif mnk[1] > 256: + tile_mnk[1] = 512 + elif mnk[1] > 128: + tile_mnk[1] = 256 + elif mnk[1] > 64: + tile_mnk[1] = 128 + elif mnk[1] > 32: + tile_mnk[1] = 64 + else : + tile_mnk[1] = 32 + + if tile_mnk[1] == 512: + stages_temp.append(1) + else: + stages_temp.append(2) + + tile_mnk[0] = 4 * warp_M_tile + + + + epilogue_setted_type = helper.get_epilogue_tp(layer) + cutlass_epilogue_name = "LinearCombinationRelu" + if epilogue_setted_type.lower() == 'leakyrelu': + cutlass_epilogue_name = "LinearCombinationLeakyRelu" + elif epilogue_setted_type.lower() == 'identity': + cutlass_epilogue_name = "LinearCombination" + + epilogue_str = 'EpilogueOutputOp' + cnt_str + if cnt != len(self.fuse_gemm_info) - 1: + n = layer['mnk'][1] + Fragments = tile_mnk[1] // 8 * 2 + 
self.args[epilogue_str] = "cutlass::epilogue::thread::" + cutlass_epilogue_name + "" + else: + n = layer['mnk'][1] + n_mod_8 = n % 4 + N_align_elements = 1 + if n_mod_8 == 0: + N_align_elements = 8 + elif n_mod_8 == 4: + N_align_elements = 4 + elif n_mod_8 == 2 or n_mod_8 == 6: + N_align_elements = 2 + + self.args[epilogue_str] = "cutlass::epilogue::thread::" + cutlass_epilogue_name+ "" + + + + ThreadBlockShape_str = 'ThreadblockShape' + cnt_str + + self.args[ThreadBlockShape_str] = helper.cvt_2_cutlass_shape(tile_mnk) + + WarpShape_str = 'WarpShape' + cnt_str + tile_mnk[0] = warp_M_tile + self.args[WarpShape_str] = helper.cvt_2_cutlass_shape(tile_mnk) + cnt += 1 + + + self.args['ElementD'] = helper.type_2_cutlass_type(self.fuse_gemm_info[self.b2b_num - 1]['C_tp']) + self.args['LayoutD'] = helper.type_2_cutlass_type(self.fuse_gemm_info[self.b2b_num - 1]['C_format']) + + self.args['InstructionShape'] = helper.cvt_2_cutlass_shape(self.mma_shape) + self.args['OperatorClass'] = 'arch::OpClassTensorOp' + self.args['ArchTag'] = self.arch + self.args['ThreadblockSwizzle'] = 'threadblock::GemmBatchedIdentityThreadblockSwizzle' + + + for i in range(self.b2b_num): + self.args[helper.var_idx('Stages', i)] = "2" + + self.args['AlignmentA'] = str(8) + self.args['AlignmentB'] = str(8) + self.args['SplitKSerial'] = 'false' + self.args['Operator'] = 'typename DefaultGemmConfiguration::Operator' + self.args['IsBetaZero'] = 'false' + + + def gen_using_kernel(self): + code = "using B2bGemmKernel = typename kernel::DefaultB2bGemm<\n" + code += " " + "ElementA,\n" + code += " " + "LayoutA,\n" + + for i in range(self.b2b_num): + code += " " + helper.var_idx("ElementB", i) + ",\n" + code += " " + helper.var_idx("LayoutB", i) + ",\n" + code += " " + helper.var_idx("ElementC", i) + ",\n" + code += " " + helper.var_idx("LayoutC", i) + ",\n" + code += " " + helper.var_idx("ElementAccumulator", i) + ",\n" + code += " " + helper.var_idx("EpilogueOutputOp", i) + ",\n" + code += " " + helper.var_idx("ThreadblockShape", i) + ",\n" + code += " " + helper.var_idx("WarpShape", i) + ",\n" + + code += " " + "ElementD,\n" + code += " " + "LayoutD,\n" + code += " " + "InstructionShape,\n" + code += " " + "OperatorClass,\n" + code += " " + "ArchTag,\n" + code += " " + "ThreadblockSwizzle,\n" + + for i in range(self.b2b_num): + code += " " + helper.var_idx("Stages", i) + ",\n" + + + code += " " + "AlignmentA,\n" + code += " " + "AlignmentB,\n" + code += " " + "SplitKSerial,\n" + code += " " + "Operator,\n" + code += " " + "IsBetaZero_\n" + + code += ">::B2bGemmKernel;\n\n" + + return code + + def gen_args(self): + + def gen_arg_member(b2b_num): + data_members = [] + + for i in range(b2b_num): + member_type = "GemmCoord" + member_name = "problem_size_" + str(i) + data_members.append((member_type, member_name)) + + member_type = "TensorRef" + member_name = "ref_A0" + data_members.append((member_type, member_name)) + + for i in range(b2b_num): + member_type = "TensorRef" + member_name = "ref_B" + str(i) + data_members.append((member_type, member_name)) + member_type = "TensorRef" + member_name = "ref_C" + str(i) + data_members.append((member_type, member_name)) + + member_type = "TensorRef" + member_name = helper.var_idx("ref_D", b2b_num - 1) + data_members.append((member_type, member_name)) + + for i in range(b2b_num): + member_type = "typename EpilogueOutputOp" + str(i) + "::Params" + member_name = "epilogue" + str(i) + data_members.append((member_type, member_name)) + + data_members.append(('int', 'batch_count')) + + return 
data_members + + def gen_arg_struct_default_ctor(struct_name, data_members, inital_param_num, inital_value): + constructs_code = gen_ir.indentation + "CUTLASS_HOST_DEVICE\n" + \ + gen_ir.indentation + struct_name + " (): " + for i in range(inital_param_num): + final_param = ',' + if i == inital_param_num - 1: + final_param = '{ }' + constructs_code += data_members[i][1] + inital_value + final_param + + constructs_code += "\n" + return constructs_code + + def gen_arg_struct_ctor(struct_name, data_members): + constructs_code = gen_ir.indentation + "CUTLASS_HOST_DEVICE\n" + \ + gen_ir.indentation + struct_name + " (\n" + cnt = 0 + param_num = len(data_members) + for param in data_members: + final = ',\n' + if cnt == param_num - 1: + final = '\n):\n' + constructs_code += gen_ir.indentation + param[0] + " " + param[1] + "_" + final + cnt += 1 + + cnt = 0 + for param in data_members: + final = '),\n' + if cnt == param_num - 1: + final = ") { }\n" + constructs_code += gen_ir.indentation + param[1] + "(" + param[1] + "_" + final + cnt += 1 + + constructs_code += "\n" + return constructs_code + + # (variable type, variable name) + struct_member = gen_arg_member(self.b2b_num) + self.arg_member = struct_member + + codeBody = "" + for each_member in struct_member: + codeBody += gen_ir.indentation + each_member[0] + " " + each_member[1] + ";\n" + + codeBody += gen_arg_struct_default_ctor("Arguments", struct_member, self.b2b_num, "(0,0,0)") + "\n" + codeBody += gen_arg_struct_ctor("Arguments", struct_member) + "\n" + struct_code = gen_ir.gen_struct("Arguments", codeBody) + return struct_code + + def gen_func_constructs(self): + code = self.gen_class_name +"() {}" + return code + + def gen_func_initialize(self): + code = "Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) {\n" + \ + "// Determine grid shape\n" + \ + "ThreadblockSwizzle threadblock_swizzle;\n" + \ + "cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape(\n" + \ + " args.problem_size_0, \n" + \ + " { ThreadblockShape0::kM, ThreadblockShape0::kN, ThreadblockShape0::kK },\n" + \ + " args.batch_count);\n" + \ + "// Initialize the Params structure\n" + \ + "params_ = typename B2bGemmKernel::Params{\n" + for i in range(self.b2b_num): + code += helper.var_idx(" args.problem_size_", i) + ",\n" + code += " grid_shape,\n" + \ + " args.ref_A0.non_const_ref(),\n" + for i in range(self.b2b_num): + code += helper.var_idx(" args.ref_B", i) + ".non_const_ref(),\n" + code += helper.var_idx(" args.ref_C", i) + ".non_const_ref(),\n" + + code += helper.var_idx(" args.ref_D", self.b2b_num - 1) + ",\n" + for i in range(self.b2b_num): + code += helper.var_idx(" args.epilogue", i) + ",\n" + + code += " args.batch_count\n" + code += "};\n" + \ + "return Status::kSuccess;\n" + \ + "}\n" + return code + + def gen_func_run(self): + code = "Status run(cudaStream_t stream = nullptr) {\n" + \ + "\n" + \ + " ThreadblockSwizzle threadblock_swizzle;\n" + \ + "\n" + \ + " dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape);\n" + \ + " dim3 block(B2bGemmKernel::kThreadCount, 1, 1);\n" + \ + "\n" + \ + " cudaError_t result;\n" + \ + "\n" + \ + " int smem_size = int(sizeof(typename B2bGemmKernel::SharedStorage));\n" + \ + " if (smem_size >= (48 << 10)) {\n" + \ + " result = cudaFuncSetAttribute(Kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);\n" + \ + "\n" + \ + " if (result != cudaSuccess) {\n" + \ + " return Status::kErrorInternal;\n" + \ + " }\n" + \ + " }\n" + \ + " 
cutlass::Kernel<<>>(params_);\n" + \ + " result = cudaGetLastError();\n" + \ + " return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal;\n" + \ + " }\n" + + return code + def gen_func_operator(self): + opeartor_with_arg_code = "Status operator()(\n" + \ + " Arguments const &args,\n" + \ + " void *workspace = nullptr,\n" + \ + " cudaStream_t stream = nullptr) {\n" + \ + " Status status = initialize(args, workspace);\n" + \ + " \n" + \ + " if (status == Status::kSuccess) {\n" + \ + " status = run(stream);\n" + \ + " }\n" + \ + " return status;\n" + \ + "}\n" + operator_code = "Status operator()(\n" + \ + " cudaStream_t stream = nullptr) {\n" + \ + " Status status = run(stream);\n" + \ + " return status;\n" + \ + "}\n" + return opeartor_with_arg_code + "\n" + operator_code + + def gen_all_func(self): + return self.gen_using_kernel() + "\n" + \ + self.gen_args() + "\n" + \ + self.gen_func_constructs() + "\n" + \ + self.gen_func_initialize() + "\n" + \ + self.gen_func_run() + "\n" + \ + self.gen_func_operator() diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_ir.py b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_ir.py new file mode 100644 index 0000000000000000000000000000000000000000..28df5665cd25e1c490b10c51cb556e5397241ca2 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_ir.py @@ -0,0 +1,249 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+# +################################################################################################# + +import helper + + +indentation = " " + + +def append_word(word): + code = "" + code += word + code += " " + return code + + +def gen_namespace(namespace, codeBody): + code_gen = "namespace " + namespace + " {\n" + code_gen += codeBody + code_gen += "} // namespace " + namespace + "\n" + return code_gen + + +def gen_expression(type, lval, rval = None): + code_gen = "" + code_gen += append_word(type) + code_gen += append_word(lval) + if rval is not None: + code_gen += append_word("=") + code_gen += append_word(rval) + return code_gen + + +def gen_class(name, codeBody, inheritance_code = None): + code_gen = "" + if inheritance_code is None: + code_gen = "class " + name + "{\n" + else: + code_gen = "class " + name + " : "+ inheritance_code + "{\n" + code_gen += codeBody + code_gen += "}; // class " + name + "\n" + return code_gen + + +def gen_struct(name, codeBody, specialized = None): + specialized_code = "" + if specialized is not None: + specialized_code = "<" + specialized + ">" + code_gen = "struct " + name + specialized_code + "{\n" + code_gen += codeBody + code_gen += "}; // struct " + name + "\n" + return code_gen + + +def gen_template_arg(arg_type, arg_name, default_val = None): + rval = None + if default_val is not None: + rval = str(default_val) + + arg_typename = "" + if arg_type is int: + arg_typename = "int" + elif arg_type is bool: + arg_typename = "bool" + else: + arg_typename = "typename" + + internal_arg_name = arg_name + "_" + + code_gen = indentation + code_gen += gen_expression(arg_typename, internal_arg_name, rval) + + return code_gen + + +def gen_template_args(args, set_default = True): + arg_len = len(args) + cnt = 1 + code_gen = "" + for arg_tuple in args: + arg_type = arg_tuple[0] + arg_name = arg_tuple[1] + arg_default_val = None + if len(arg_tuple) == 3 and set_default: + arg_default_val = arg_tuple[2] + + code_gen += gen_template_arg(arg_type, arg_name, arg_default_val) + if cnt != arg_len: + code_gen += ",\n" + cnt += 1 + + return code_gen + + +def gen_template_head(args, set_default = True): + code_gen = "template <\n" + code_gen += gen_template_args(args, set_default) + code_gen += ">\n" + return code_gen + + +def export_template_args(args): + code_gen = "public:\n" + for arg_tuple in args: + code_gen += indentation + arg_type = arg_tuple[0] + arg_name = arg_tuple[1] + internal_arg_name = arg_name + "_" + + typename = "" + if arg_type is int: + typename = "static int const" + elif arg_type is bool: + typename = "static bool const" + else: + typename = "using" + + code_gen += gen_expression(typename, arg_name, internal_arg_name) + code_gen += ";\n" + return code_gen + + +def gen_template_class(class_name, args, codeBody, set_default = True, inheritance_code = None): + code_gen = "" + + code_gen += gen_template_head(args, set_default) + code_gen += gen_class(class_name, export_template_args(args) + codeBody, inheritance_code) + + return code_gen + + +def gen_template_struct(struct_name, args, codeBody, speicalized = None, set_default = True, export_args = True): + code_gen = "" + code_gen += gen_template_head(args, set_default) + code = export_template_args(args) + codeBody + if export_args is False: + code = codeBody + code_gen += gen_struct(struct_name, code , speicalized) + + return code_gen + + +def gen_declare_template_struct(name, *params): + code = name + "<" + cnt = 0 + param_num = len(params) + for param in params: + final = ", " + if cnt == param_num 
- 1: + final = "" + code += param + final + cnt += 1 + code += ">;\n" + return code + + +def filtered_param(params, name_and_value_pair, keep_ = False): + rtn_template_args = [] + speicalized_template_args = [] + + for param in params: + param_name = "" + if len(param) >= 1: + param_name = param[1] + else: + param_name = param[0] + + hit_flag = False + set_value = "" + for n_v_pair in name_and_value_pair: + + filter_name = n_v_pair[0] + set_value = n_v_pair[1] + + if param_name == (filter_name + "_") or param_name == filter_name : + hit_flag = True + break + + + if hit_flag is False: + rtn_template_args.append(param) + + if hit_flag is True: + speicalized_template_args.append(set_value) + else: + if keep_ is True: + speicalized_template_args.append(param_name + "_") + else: + speicalized_template_args.append(param_name) + + + specialized_template_arg_str = helper.list_2_string(speicalized_template_args) + + return rtn_template_args, specialized_template_arg_str + + +def gen_func(func_name, arg_lists, code_body, only_declare = False, with_cudaStream = True): + code = "void " + func_name + "(\n" + for arg in arg_lists: + arg_tp = arg[0] + arg_nm = arg[1] + code += " " + arg_tp + " " + arg_nm + ",\n" + code += "cudaStream_t stream)" + if only_declare : + return code + code += "{\n" + + code += code_body + "\n" + code += "}\n" + return code + + +def indent_level(code, level = 0): + rtn_code = "" + for i in range(level): + rtn_code += " " + + rtn_code += code + + return rtn_code diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_kernel.py b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_kernel.py new file mode 100644 index 0000000000000000000000000000000000000000..af901ac8523a93611763cdc25c7fd06355936afe --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_kernel.py @@ -0,0 +1,476 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +import gen_ir +import helper +import gen_threadblock as gen_tb + + +class gen_default_Gemm: + def __init__(self, template_param, gen_class_name, b2b_num, cutlass_deps_root, project_root): + self.gen_class_name = "B2bGemm" + self.template_param = template_param + self.b2b_num = b2b_num + + self.cutlass_deps_root = cutlass_deps_root + self.project_root = project_root + + def gen_B2bMma(self, specialized_template_args): + code = "using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma<\n" + code += specialized_template_args + code += ">::ThreadblockB2bMma;\n" + + # print(code) + return code + + def gen_epilogue(self): + epilogue_code = "" + epilogue_code += helper.var_idx("static const int kPartitionsK", self.b2b_num - 1) + helper.var_idx(" = ThreadblockShape", self.b2b_num - 1) + helper.var_idx("::kK / WarpShape", self.b2b_num - 1) + "::kK;\n" + + epilogue_code += "using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp<\n" + epilogue_code += " " + helper.var_idx("ThreadblockShape", self.b2b_num - 1) + ",\n" + epilogue_code += " " + helper.var_idx("typename B2bMma::Operator", self.b2b_num - 1) + ",\n" + epilogue_code += " " + helper.var_idx("kPartitionsK", self.b2b_num - 1) + ",\n" + epilogue_code += " " + helper.var_idx("EpilogueOutputOp", self.b2b_num - 1) + ",\n" + epilogue_code += " " + helper.var_idx("EpilogueOutputOp", self.b2b_num - 1) + "::kCount\n" + epilogue_code += ">::Epilogue;\n" + + epilogue_code += "using B2bGemmKernel = kernel::B2bGemm;\n\n" + + return epilogue_code + + + def gen_include_header(self): + code = ''' +/* Auto Generated code - Do not edit.*/ + +#pragma once +#include \"{cutlass_dir}cutlass/cutlass.h\" + +#include \"{cutlass_dir}cutlass/layout/matrix.h\" +#include \"{cutlass_dir}cutlass/numeric_types.h\" + +#include \"{cutlass_dir}cutlass/epilogue/threadblock/epilogue.h\" +#include \"{cutlass_dir}cutlass/epilogue/thread/linear_combination.h\" + +#include \"{cutlass_dir}cutlass/gemm/gemm.h\" +#include \"{cutlass_dir}cutlass/gemm/kernel/gemm_pipelined.h\" +#include \"{cutlass_dir}cutlass/gemm/threadblock/default_mma_core_sm75.h\" +#include \"{cutlass_dir}cutlass/gemm/threadblock/default_mma_core_sm70.h\" +#include \"{cutlass_dir}cutlass/gemm/threadblock/default_mma_core_sm80.h\" +#include \"{cutlass_dir}cutlass/gemm/threadblock/default_mma_core_simt.h\" +#include \"{cutlass_dir}cutlass/gemm/threadblock/threadblock_swizzle.h\" +#include \"{cutlass_dir}cutlass/epilogue/threadblock/default_epilogue_tensor_op.h\" +#include \"{cutlass_dir}cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h\" +#include \"{cutlass_dir}cutlass/epilogue/threadblock/default_epilogue_simt.h\" + +#include \"{cutlass_dir}cutlass/transform/threadblock/predicated_tile_iterator.h\" + +#include \"../kernel/b2b_gemm.h\" +#include \"../threadblock/default_b2b_mma.h\" +'''.format(cutlass_dir=self.cutlass_deps_root) + return code 
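Before the gen_code() method that follows, it helps to see how the gen_ir helpers defined earlier in this diff compose: filtered_param() splits the full template-parameter list into the parameters that stay generic and a specialization string for the pinned ones (Stages, OperatorClass, ArchTag, the row-major LayoutC's), and gen_template_struct(), built from gen_template_head() and gen_struct(), then emits the partially specialized DefaultB2bGemm. A small sketch using the lower-level pieces directly, assuming gen_ir.py and its helper module from this ir_gen directory are importable; the parameter names and pinned values here are illustrative only.

import gen_ir

# A trimmed-down parameter list of (type, name[, default]) tuples, in the form gen_device.py builds them.
params = [("typename", "ElementA"), (int, "Stages", 2), ("typename", "ArchTag")]

# Pin Stages and ArchTag; keep_=True keeps the surviving parameters as "Name_" in the
# specialization string, matching how gen_code() below calls filtered_param().
kept, spec = gen_ir.filtered_param(params, [("Stages", "2"), ("ArchTag", "arch::Sm75")], keep_=True)

# kept -> [("typename", "ElementA")]
# spec -> something like "ElementA_, 2, arch::Sm75" (exact formatting comes from helper.list_2_string)
print(gen_ir.gen_template_head(kept))
print(gen_ir.gen_struct("DefaultB2bGemm", "  /* B2bMma + Epilogue body */\n", spec))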
+ + def gen_code(self): + gen_using = '' + # Generate default template struct + gen_code = gen_ir.gen_template_struct("Default" + self.gen_class_name, self.template_param,"", speicalized = None, set_default=False) + + + filter_list = [] + filter_list.append(('Stages', 2)) + filter_list.append(("OperatorClass", "arch::OpClassTensorOp")) + filter_list.append(("ArchTag", "arch::Sm75")) + + for i in range(self.b2b_num): + filter_list.append((helper.var_idx("LayoutC", i), "layout::RowMajor")) + + + rtn_template_args, speicalized_template_args = gen_ir.filtered_param(self.template_param, filter_list, keep_= True) + + + B2bMma_code = self.gen_B2bMma(speicalized_template_args) + epilogue_and_rest_code = self.gen_epilogue() + + gen_special_code = gen_ir.gen_template_struct("Default" + self.gen_class_name, rtn_template_args, B2bMma_code + epilogue_and_rest_code, speicalized = speicalized_template_args, set_default=False) + + code = gen_ir.gen_namespace("cutlass", gen_ir.gen_namespace("gemm", gen_ir.gen_namespace("kernel", gen_code + gen_special_code))) + + return self.gen_include_header() + code + + +class gen_Kernel: + def __init__(self, template_param, gen_class_name, b2b_num, cutlass_deps_root, project_root): + self.gen_class_name = "B2bGemm" + self.template_param = template_param + self.b2bnum = b2b_num + + self.cutlass_deps_root = cutlass_deps_root + self.project_root = project_root + + def gen_include_header(self): + code = ''' +#pragma once + +#include \"{cutlass_dir}cutlass/cutlass.h\" +#include \"{cutlass_dir}cutlass/gemm/gemm.h\" +#include \"{cutlass_dir}cutlass/matrix_coord.h\"\n'''.format(cutlass_dir=self.cutlass_deps_root) + return code + + def gen_Params(self): + gen_param = "" + for i in range(self.b2bnum): + gen_param += " " + helper.var_idx("cutlass::gemm::GemmCoord problem_size_", i) + ";\n" + gen_param += " " + "cutlass::gemm::GemmCoord grid_tiled_shape;\n" + gen_param += " " + "typename B2bMma::IteratorA0::Params params_A0;\n" + gen_param += " " + "typename B2bMma::IteratorA0::TensorRef ref_A0;\n" + + for i in range(self.b2bnum): + gen_param += " " + helper.var_idx("typename B2bMma::IteratorB", i) + helper.var_idx("::Params params_B", i) + ";\n" + gen_param += " " + helper.var_idx("typename B2bMma::IteratorB", i) + helper.var_idx("::TensorRef ref_B", i) + ";\n" + if i == self.b2bnum - 1: + gen_param += " " + helper.var_idx("typename Epilogue::OutputTileIterator::Params params_C", i) + ";\n" + gen_param += " " + helper.var_idx("typename Epilogue::OutputTileIterator::TensorRef ref_C", i) + ";\n" + + else: + gen_param += " " + helper.var_idx("typename FusedAddBiasEpilogue", i) + helper.var_idx("::OutputTileIterator::Params params_C", i) + ";\n" + gen_param += " " + helper.var_idx("typename FusedAddBiasEpilogue", i) + helper.var_idx("::OutputTileIterator::TensorRef ref_C", i) + ";\n" + + + + + gen_param += " " + helper.var_idx("typename Epilogue::OutputTileIterator::Params params_D", self.b2bnum - 1) + ";\n" + gen_param += " " + helper.var_idx("typename Epilogue::OutputTileIterator::TensorRef ref_D", self.b2bnum - 1) + ";\n" + + for i in range(self.b2bnum): + gen_param += " " + helper.var_idx("typename OutputOp", i) + helper.var_idx("::Params output_op_", i) + ";\n" + + gen_param += " " + 'int batch_count' + ";\n" + gen_param += " " + 'int gemm_k_iterations_0' + ";\n" + + + return gen_param + + def gen_Memberfunc(self): + code_default = "\nCUTLASS_HOST_DEVICE\n" + code_default += "Params()" + + code_default += " { } \n\n" + + code_construct = "\nCUTLASS_HOST_DEVICE\n" + 
code_construct += "Params(\n" + + for i in range(self.b2bnum): + code_construct += " " + helper.var_idx("cutlass::gemm::GemmCoord const & problem_size_", i) + ",\n" + + code_construct += " " + "cutlass::gemm::GemmCoord const & grid_tiled_shape,\n" + + code_construct += " " + "typename B2bMma::IteratorA0::TensorRef ref_A0,\n" + + for i in range(self.b2bnum): + code_construct += " " + helper.var_idx("typename B2bMma::IteratorB", i) + helper.var_idx("::TensorRef ref_B", i) + ",\n" + if i == self.b2bnum - 1: + code_construct += " " + helper.var_idx("typename Epilogue::OutputTileIterator::TensorRef ref_C", i) + ",\n" + else: + code_construct += " " + helper.var_idx("typename FusedAddBiasEpilogue", i) + helper.var_idx("::OutputTileIterator::TensorRef ref_C", i) + ",\n" + + code_construct += " " + helper.var_idx("typename Epilogue::OutputTileIterator::TensorRef ref_D", self.b2bnum - 1) + ",\n" + for i in range(self.b2bnum): + code_construct += " " + helper.var_idx("typename OutputOp", i) + helper.var_idx("::Params output_op_", i) + helper.var_idx(" = typename OutputOp", i) + "::Params(),\n" + + code_construct += " " + "int batch_count = 1\n" + + code_construct += "):\n" + + for i in range(self.b2bnum): + code_construct += " " + helper.var_idx("problem_size_", i) + helper.var_idx("(problem_size_", i) + "),\n" + + code_construct += " " + "grid_tiled_shape(grid_tiled_shape),\n" + code_construct += " " + "params_A0(ref_A0.layout()),\n" + code_construct += " " + "ref_A0(ref_A0),\n" + + for i in range(self.b2bnum): + code_construct += " " + helper.var_idx("params_B", i) + helper.var_idx("(ref_B", i) + ".layout()),\n" + code_construct += " " + helper.var_idx("ref_B", i) + helper.var_idx("(ref_B", i) + "),\n" + code_construct += " " + helper.var_idx("params_C", i) + helper.var_idx("(ref_C", i) + ".layout()),\n" + code_construct += " " + helper.var_idx("ref_C", i) + helper.var_idx("(ref_C", i) + "),\n" + + code_construct += " " + helper.var_idx("params_D", self.b2bnum - 1) + helper.var_idx("(ref_D", self.b2bnum - 1) + ".layout()),\n" + code_construct += " " + helper.var_idx("ref_D", self.b2bnum - 1) + helper.var_idx("(ref_D", self.b2bnum - 1) + "),\n" + + for i in range(self.b2bnum): + code_construct += " " + helper.var_idx("output_op_", i) + helper.var_idx("(output_op_", i) + "), \n" + + code_construct += " " + "batch_count(batch_count) {\n" + code_construct += " " + helper.var_idx("gemm_k_iterations_", 0) + helper.var_idx(" = (problem_size_", 0) + helper.var_idx(".k() + B2bMma::Shape", 0) + helper.var_idx("::kK - 1) / B2bMma::Shape", 0) + "::kK;\n" + + code_construct += "}\n" + + return code_default + code_construct + + def gen_using(self): + code_using = "" + + for i in range(self.b2bnum - 1): + code_using += " " + helper.var_idx("using OutputOp", i) + helper.var_idx(" = typename B2bMma::OutputOp", i) + ";\n" + + code_using += " " + helper.var_idx("using OutputOp", self.b2bnum - 1) + " = typename Epilogue::OutputOp;\n" + + for i in range(self.b2bnum - 1): + code_using += " " + helper.var_idx("using FusedAddBiasEpilogue", i) + helper.var_idx(" = typename B2bMma::FusedAddBiasEpilogue", i) +";\n" + + + code_using += " " + "using WarpCount0 = typename B2bMma::WarpCount0;\n" + code_using += " " + "static int const kThreadCount = 32 * WarpCount0::kCount;\n" + + code_using += gen_ir.gen_struct("Params", self.gen_Params() + self.gen_Memberfunc()) + + code_using += "union SharedStorage {\n" + code_using += " " + "typename B2bMma::B2bMmaSharedStorage main_loop;\n" + code_using += " " + "typename 
Epilogue::SharedStorage epilogue;\n" + code_using += "};\n" + + return code_using + + def gen_can_implement(self): + gen_code = "" + return gen_code + + def gen_operator_and_constr(self): + ctr_code = "CUTLASS_HOST_DEVICE\n" + ctr_code += self.gen_class_name + "() { } \n\n" + operator_code = "CUTLASS_DEVICE\n" + operator_code += "void operator()(Params const ¶ms, SharedStorage &shared_storage) {\n" + operator_code += " " + "ThreadblockSwizzle threadblock_swizzle;\n" + operator_code += " " + "cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.grid_tiled_shape);\n" + operator_code += " " + "int batch_idx = threadblock_tile_offset.k();\n" + operator_code += " " + "if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() ||\n" + operator_code += " " + "params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) {\n" + operator_code += " " + " " + "return;\n" + operator_code += " " + "}\n" + + operator_code += " " + "cutlass::MatrixCoord tb_offset_A0{\n" + operator_code += " " + " " + "threadblock_tile_offset.m() * B2bMma::Shape0::kM,\n" + operator_code += " " + " " + "0\n" + operator_code += " " + "};\n" + + for i in range(self.b2bnum): + operator_code += " " + helper.var_idx("cutlass::MatrixCoord tb_offset_B", i) + "{\n" + operator_code += " " + " " + "0,\n" + operator_code += " " + " " + helper.var_idx("threadblock_tile_offset.n() * B2bMma::Shape", i) + "::kN\n" + operator_code += " " + "};\n" + + operator_code += " " + "int thread_idx = threadIdx.x;\n\n" + + operator_code += " " + "MatrixCoord threadblock_offset(\n" + operator_code += " " + " " + helper.var_idx("threadblock_tile_offset.m() * B2bMma::Shape", self.b2bnum - 1) + "::kM,\n" + operator_code += " " + " " + helper.var_idx("threadblock_tile_offset.n() * B2bMma::Shape", self.b2bnum - 1) + "::kN\n" + operator_code += " " + ");\n" + + operator_code += " " + "typename B2bMma::IteratorA0 iterator_A0(\n" + operator_code += " " + " " + "params.params_A0,\n" + operator_code += " " + " " + "params.ref_A0.data(),\n" + operator_code += " " + " " + "params.problem_size_0.mk(),\n" + operator_code += " " + " " + "thread_idx,\n" + operator_code += " " + " " + "tb_offset_A0);\n" + + operator_code += " " + "iterator_A0.add_pointer_offset(batch_idx * params.problem_size_0.m() * params.problem_size_0.k());\n\n" + + + for i in range (self.b2bnum): + operator_code += " " + helper.var_idx("typename B2bMma::IteratorB", i ) + helper.var_idx(" iterator_B", i) + "(\n" + operator_code += " " + " " + helper.var_idx("params.params_B", i) + ",\n" + operator_code += " " + " " + helper.var_idx("params.ref_B", i) + ".data(),\n" + operator_code += " " + " " + helper.var_idx("params.problem_size_", i) + ".kn(),\n" + operator_code += " " + " " + "thread_idx,\n" + operator_code += " " + " " + helper.var_idx("tb_offset_B", i) + ");\n" + operator_code += " " + helper.var_idx("iterator_B", i) + helper.var_idx(".add_pointer_offset(batch_idx * params.problem_size_", i) + helper.var_idx(".n() * params.problem_size_", i) + ".k());\n\n" + + + for i in range (self.b2bnum - 1): + operator_code += " " + helper.var_idx("typename FusedAddBiasEpilogue", i ) + helper.var_idx("::OutputTileIterator iterator_C", i) + "(\n" + operator_code += " " + " " + helper.var_idx("params.params_C", i) + ",\n" + operator_code += " " + " " + helper.var_idx("params.ref_C", i) + ".data(),\n" + operator_code += " " + " " + helper.var_idx("params.problem_size_" , i) + ".mn(),\n" + operator_code += " " + " " + "thread_idx,\n" + operator_code += " " + " " + 
"threadblock_offset" + ");\n" + operator_code += " " + helper.var_idx("int ref_C", i) + helper.var_idx("_stride = params.ref_C", i) + ".stride()[0];\n" + operator_code += " " + helper.var_idx("iterator_C", i) + helper.var_idx(".add_pointer_offset(batch_idx * params.problem_size_", i) + helper.var_idx(".n() * (ref_C", i) + helper.var_idx("_stride == 0 ? 1 : params.problem_size_", i) + ".m()));\n\n" + + + for i in range (self.b2bnum - 1): + operator_code += " " + helper.var_idx("FusedAddBiasEpilogue", i ) + helper.var_idx(" epilogue_", i ) + ";\n" + + + operator_code += " " + "int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);\n" + operator_code += " " + "int lane_idx = threadIdx.x % 32;\n" + + for i in range (self.b2bnum - 1): + operator_code += " " + helper.var_idx("OutputOp", i) + helper.var_idx(" output_op_", i) + helper.var_idx("(params.output_op_", i) + ");\n" + + operator_code += " " + "B2bMma b2bMma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);\n" + + operator_code += " " + "typename B2bMma::FragmentC0 src_accum;\n" + operator_code += " " + helper.var_idx("typename B2bMma::FragmentC", self.b2bnum - 1)+ " accumulators;\n" + + operator_code += " " + "src_accum.clear();\n" + operator_code += " " + "accumulators.clear();\n" + operator_code += " " + "b2bMma(params.gemm_k_iterations_0, accumulators, iterator_A0, " + + for i in range(self.b2bnum): + operator_code += helper.var_idx("iterator_B", i) + ", " + + operator_code += "src_accum" + if self.b2bnum != 1: + operator_code += ", " + for i in range(self.b2bnum - 1): + operator_code += helper.var_idx("output_op_", i) + ", " + + for i in range(self.b2bnum - 1): + operator_code += helper.var_idx("epilogue_", i) + ", " + + for i in range(self.b2bnum - 1): + final = ", " + if i == self.b2bnum - 2: + final ="" + operator_code += helper.var_idx("iterator_C", i) + final + operator_code += ");\n" + + operator_code += " " + helper.var_idx("OutputOp", self.b2bnum - 1) + helper.var_idx(" output_op_", self.b2bnum - 1) + helper.var_idx("(params.output_op_", self.b2bnum - 1) + ");\n" + operator_code += " " + "threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.grid_tiled_shape);\n" + + + + operator_code += " " + helper.var_idx("typename Epilogue::OutputTileIterator iterator_C", self.b2bnum - 1) + "(\n" + operator_code += " " + " " + helper.var_idx("params.params_C", self.b2bnum - 1) + ",\n" + operator_code += " " + " " + helper.var_idx("params.ref_C", self.b2bnum - 1) + ".data(),\n" + operator_code += " " + " " + helper.var_idx("params.problem_size_", self.b2bnum - 1) + ".mn(),\n" + operator_code += " " + " " + "thread_idx,\n" + operator_code += " " + " " + "threadblock_offset\n" + operator_code += " " + ");\n" + operator_code += " " + helper.var_idx("int ref_C", self.b2bnum - 1) + helper.var_idx("_stride = params.ref_C", self.b2bnum - 1) + ".stride()[0];\n" + + operator_code += " " + helper.var_idx("iterator_C", self.b2bnum - 1) + helper.var_idx(".add_pointer_offset(batch_idx * params.problem_size_", self.b2bnum - 1) + helper.var_idx(".n() * (ref_C", self.b2bnum - 1) + helper.var_idx("_stride == 0 ? 
1 : params.problem_size_", self.b2bnum - 1) + ".m()));\n\n" + + operator_code += " " + helper.var_idx("typename Epilogue::OutputTileIterator iterator_D", self.b2bnum - 1) + "(\n" + operator_code += " " + " " + helper.var_idx("params.params_D", self.b2bnum - 1) + ",\n" + operator_code += " " + " " + helper.var_idx("params.ref_D", self.b2bnum - 1) + ".data(),\n" + operator_code += " " + " " + helper.var_idx("params.problem_size_", self.b2bnum - 1) + ".mn(),\n" + operator_code += " " + " " + "thread_idx,\n" + operator_code += " " + " " + "threadblock_offset\n" + operator_code += " " + ");\n" + operator_code += " " + helper.var_idx("iterator_D", self.b2bnum - 1) + helper.var_idx(".add_pointer_offset(batch_idx * params.problem_size_", self.b2bnum - 1) + helper.var_idx(".n() * params.problem_size_", self.b2bnum - 1) + ".m());\n\n" + + + operator_code += " " + "Epilogue epilogue(\n" + operator_code += " " + " " + "shared_storage.epilogue,\n" + operator_code += " " + " " + "thread_idx,\n" + operator_code += " " + " " + "warp_idx,\n" + operator_code += " " + " " + "lane_idx\n" + operator_code += " " + ");\n" + + operator_code += " " + "epilogue(" + operator_code += helper.var_idx("output_op_", self.b2bnum - 1) + ", " + operator_code += helper.var_idx("iterator_D", self.b2bnum - 1) + ", " + operator_code += "accumulators, " + operator_code += helper.var_idx("iterator_C", self.b2bnum - 1) + ");\n" + operator_code += "}\n" + + return ctr_code + operator_code + + def gen_include_header(self): + code = ''' +#pragma once + +#include \"{cutlass_dir}cutlass/cutlass.h\" + +#include \"{cutlass_dir}cutlass/gemm/gemm.h\" +#include \"{cutlass_dir}cutlass/matrix_coord.h\" +#include \"{cutlass_dir}cutlass/semaphore.h\" +'''.format(cutlass_dir=self.cutlass_deps_root) + return code + def gen_code(self): + + template_param = [] + template_param.append(("typename", "B2bMma")) + template_param.append(("typename", "Epilogue")) + template_param.append(("typename", "ThreadblockSwizzle")) + template_param.append((bool, "SplitKSerial")) + + code_body = "" + code_body += self.gen_using() + code_body += self.gen_operator_and_constr() + + struct_code = gen_ir.gen_template_struct(self.gen_class_name, template_param, code_body) + code = self.gen_include_header() + code += gen_ir.gen_namespace("cutlass", gen_ir.gen_namespace("gemm", gen_ir.gen_namespace("kernel", struct_code))) + + return self.gen_include_header() + code + + + +class gen_kernel: + def __init__(self, template_param, gen_class_name, b2b_num, output_dir, cutlass_deps_root, project_root): + self.template_param = template_param + + self.gen_class_name = "B2bGemm" + self.gen_kernel_name = gen_class_name + "Kernel" + self.template_args = [] + + self.cutlass_deps_root = cutlass_deps_root + self.project_root = project_root + + self.gen_default_b2b_gemm = gen_default_Gemm(template_param, gen_class_name, b2b_num, cutlass_deps_root, project_root) + self.gen_Kerenl = gen_Kernel(template_param, gen_class_name, b2b_num, cutlass_deps_root, project_root) + + # Include gen_threadBlock + self.gen_threadBlock = gen_tb.gen_threadblock(template_param, gen_class_name, b2b_num, output_dir, cutlass_deps_root, project_root) + + self.file_dir = output_dir + "/kernel/" + + def gen_code(self, first_use_1stage): + + default_b2b_gemm = self.gen_default_b2b_gemm.gen_code() + + print("[INFO]: Gen kernel code [default_b2b_gemm.h]output Dir: is ", self.file_dir) + + with open(self.file_dir + "default_b2b_gemm.h", "w+") as f: + f.write(default_b2b_gemm) + + kernel = self.gen_Kerenl.gen_code() + 
print("[INFO]: Gen kernel code [b2b_gemm.h]output Dir: is ", self.file_dir) + + with open(self.file_dir + "b2b_gemm.h", "w+") as f: + f.write(kernel) + + # Call code to gen threadblock + self.gen_threadBlock.gen_code(first_use_1stage) diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_sample.py b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_sample.py new file mode 100644 index 0000000000000000000000000000000000000000..75c2b6869364a6d6f73215e470080ce307f4df55 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_sample.py @@ -0,0 +1,232 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+# +################################################################################################# + +import helper +import gen_ir as ir + +class gen_test: + def __init__(self, fuse_gemm_info, gen_class_name, user_header_file, output_dir = "../"): + self.fuse_gemm_info = fuse_gemm_info + self.gen_class_name = gen_class_name + self.user_header_file = user_header_file + self.sample_dir = output_dir + self.b2b_num = len(fuse_gemm_info) + + def gen_cpp_sample(self): + code = "/* Auto Generated code - Do not edit.*/\n" + code += "#include \n" + + code += "#include \"cutlass/gemm/device/gemm_batched.h\" \n" + code += "#include \"cutlass/cutlass.h\" \n" + + code += "#include \"../cutlass_irrelevant.h\" \n" + code += "#include \"../cutlass_verify.h\" \n" + + code += "#include \"leaky_bias.h\" \n" + + code += "#include \"utils.h\" \n" + + + + code += "int main(int args, char * argv[]) {\n" + code += " " + "int M = atoi(argv[1]);\n" + code += " " + "int K0 = " + str(self.fuse_gemm_info[0]['mnk'][0]) + ";\n" + code += " " + "if(args == 3);\n" + code += " " + " " + "K0 = atoi(argv[2]);\n" + code += " " + "int B = 1;\n" + code += " " + "if(args == 4);\n" + code += " " + " " + "B = atoi(argv[3]);\n" + + code += " " + "srand(1234UL);\n" + code += " " + "int device_id = 0;\n" + code += " " + "cudaGetDevice(&device_id);\n" + code += " " + "cudaDeviceProp prop;\n" + code += " " + "cudaGetDeviceProperties(&prop, device_id);\n" + code += " " + "int sm = prop.major *10 + prop.minor;\n" + code += "using ElementCompute = cutlass::half_t;\n" + + for i in range(self.b2b_num): + code += " " + helper.var_idx("ElementCompute alpha", i) + " = ElementCompute(1);\n" + addbias = helper.get_epilogue_add_bias_or_not( self.fuse_gemm_info[i]) + if addbias: + code += " " + helper.var_idx("ElementCompute beta", i) + " = ElementCompute(1);\n" + else: + code += " " + helper.var_idx("ElementCompute beta", i) + " = ElementCompute(0);\n" + + code += " " + "size_t flops = 0;\n" + + for i in range(self.b2b_num): + m = self.fuse_gemm_info[i]['mnk'][0] + n = self.fuse_gemm_info[i]['mnk'][1] + k = self.fuse_gemm_info[i]['mnk'][2] + + bias_shape = helper.get_epilogue_bias_shape(self.fuse_gemm_info[i]) + + this_k = "K0" + if (i > 0): + this_k = str(k) + + code += " " + "flops += size_t(2) * size_t(M) * size_t(B) * " + "size_t(" + str(n) + ") * size_t(" + this_k + ");\n" + + code += " " + helper.var_idx("cutlass::gemm::GemmCoord problem_size_", i) + "(" + "M" + ", " + str(n) + ", " + this_k + ");\n" + + code += " " + helper.var_idx("memory_unit Mat_A", i) + helper.var_idx("(B * problem_size_", i) + helper.var_idx(".m() * problem_size_", i) + ".k());\n" + code += " " + helper.var_idx("memory_unit Mat_B", i) + helper.var_idx("(B * problem_size_", i) + helper.var_idx(".n() * problem_size_", i) + ".k());\n" + code += " " + helper.var_idx("memory_unit Mat_C", i) + "(B * " + str(bias_shape[0]) + " * " + str(bias_shape[1]) + ");\n" + code += " " + helper.var_idx("memory_unit Mat_D_cutlass_ref", i) + helper.var_idx("(B * problem_size_", i) + helper.var_idx(".m() * problem_size_", i) + ".n());\n" + + code += " " + helper.var_idx("Mat_A", i) + ".init();\n" + code += " " + helper.var_idx("Mat_B", i) + ".init();\n" + code += " " + helper.var_idx("Mat_C", i) + ".init();\n" + + + + code += " " + helper.var_idx("memory_unit Mat_D", self.b2b_num - 1) + helper.var_idx("(B * problem_size_", i) + helper.var_idx(".m() * problem_size_",self.b2b_num - 1) + ".n());\n" + + params = [] + params.append("M") + params.append("B") + + 
params.append("Mat_A0.device_ptr") + for i in range(self.b2b_num): + params.append(helper.var_idx("Mat_B", i) + ".device_ptr") + params.append(helper.var_idx("Mat_C", i) + ".device_ptr") + if i != self.b2b_num-1: + params.append(helper.var_idx("Mat_D_cutlass_ref", i) + ".device_ptr") + params.append(helper.var_idx("Mat_D", self.b2b_num - 1) + ".device_ptr") + + code += " " + "Param arguments = {\n" + code += " " + " " + "M,\n" + code += " " + " " + "K0,\n" + code += " " + " " + "B,\n" + + code += " " + " " + "reinterpret_cast(Mat_A0.device_ptr),\n" + cnt = 1 + for i in range(self.b2b_num): + bias_flag = helper.get_epilogue_add_bias_or_not( self.fuse_gemm_info[i]) + code += " " + " " + "reinterpret_cast(" + helper.var_idx("Mat_B", i) + ".device_ptr" + "),\n" + cnt += 1 + if bias_flag: + code += " " + " " + "reinterpret_cast(" + helper.var_idx("Mat_C", i) + ".device_ptr" + "),\n" + cnt += 1 + else: + code += " " + " " + "reinterpret_cast(NULL),\n" + + epilogue_args = helper.get_epilogue_args(self.fuse_gemm_info[i]) + acc_tp = helper.get_epilogue_compute_tp(self.fuse_gemm_info[i]) + for arg in epilogue_args: + arg_value = str(arg[2]) + + code += " " + " " + helper.type_2_cutlass_type(acc_tp) + "(" + arg_value + "),\n" + + if i != self.b2b_num - 1: + code += " " + " " + "reinterpret_cast(" + helper.var_idx("Mat_D_cutlass_ref", i) + ".device_ptr" + "),\n" + else: + code += " " + " " + "reinterpret_cast(" + helper.var_idx("Mat_D", i) + ".device_ptr" + ")};\n" + + + + + code += " " + "TI(FUSED_CUTLASS);\n" + code += " " + "for(int i = 0; i < 100; i++){\n" + code += " " + " " + "one_api(arguments, sm, NULL);\n" + + code += " " + "}\n" + code += " " + "TO(FUSED_CUTLASS, \"FUSED_CUTLASS\", 100);\n" + + code += "\n" + + for i in range(self.b2b_num): + code_this = "" + + N_str = str(self.fuse_gemm_info[i]['mnk'][1]) + + code_this += " " + helper.var_idx("typename Gemm", i) + helper.var_idx("::Arguments arguments_", i) + "{\n" + code_this += " " + " " + helper.var_idx("problem_size_", i) + ",\n" + ldmA = str(self.fuse_gemm_info[i]['mnk'][2]) + if i == 0: + ldmA = "K0" + ldmB = str(self.fuse_gemm_info[i]['mnk'][2]) + if i == 0: + ldmB = "K0" + ldmC = str(self.fuse_gemm_info[i]['mnk'][1]) + + ldmBias = str(helper.get_epilogue_bias_ldm(self.fuse_gemm_info[i])) + + if self.fuse_gemm_info[i]['A_format'] is 'Col': + ldmA = "M" + if self.fuse_gemm_info[i]['B_format'] is 'Row': + ldmB = str(self.fuse_gemm_info[i]['mnk'][1]) + if self.fuse_gemm_info[i]['C_format'] is 'Col': + ldmC = "M" + + if i == 0: + code_this += " " + " " + "{reinterpret_cast<" + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['A_tp']) + "*>(" + helper.var_idx("Mat_A", i) + ".device_ptr), " + ldmA + "}, " + "M * " + ldmA + ",\n" + else: + code_this += " " + " " + "{reinterpret_cast<" + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['A_tp']) + "*>(" + helper.var_idx("Mat_D_cutlass_ref", i - 1) + ".device_ptr), " + ldmA + "}, " + "M * " + ldmA + ",\n" + + code_this += " " + " " + "{reinterpret_cast<" + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['B_tp']) + "*>(" + helper.var_idx("Mat_B", i) + ".device_ptr), " + ldmB + "}, " + N_str + " * " + ldmB + ",\n" + + M_bias = str(helper.get_epilogue_bias_shape(self.fuse_gemm_info[i])[0]) + + code_this += " " + " " + "{reinterpret_cast<" + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['C_tp']) + "*>(" + helper.var_idx("Mat_C", i) + ".device_ptr), " + ldmBias + "}, " + M_bias + " * " + N_str + ",\n" + code_this += " " + " " + "{reinterpret_cast<" + 
helper.type_2_cutlass_type(self.fuse_gemm_info[i]['C_tp']) + "*>(" + helper.var_idx("Mat_D_cutlass_ref", i) + ".device_ptr), " + ldmC + "}, " + "M * " + ldmC + ",\n" + code_this += " " + " " + "{ " + helper.var_idx("alpha", i) + ", " + helper.var_idx("beta", i) + for epilogue_arg in helper.get_epilogue_args(self.fuse_gemm_info[i]): + arg_value = str(epilogue_arg[2]) + code_this += ", " + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['Acc_tp']) + "(" + str(arg_value) + ")" + code_this += " " + " },\n" + code_this += " " + " " + "B};\n" + + code += code_this + + + + code += " " + "TI(UNFUSED_CUTLASS);\n" + code += " " + "for(int i = 0; i < 100; i++){\n" + code += " " + " " + self.gen_class_name + "_verify(\n" + for i in range(self.b2b_num): + code += " " + " " + " " + helper.var_idx("arguments_", i) + ",\n" + code += " " + " " + " " + "NULL);\n" + + code += " " + "}\n" + code += " " + "TO(UNFUSED_CUTLASS, \"UNFUSED_CUTLASS\", 100);\n" + + code += " " + helper.var_idx("Mat_D_cutlass_ref", self.b2b_num - 1) + ".d2h();\n" + code += " " + helper.var_idx("Mat_D", self.b2b_num - 1) + ".d2h();\n" + code += " " + helper.var_idx("check_result(Mat_D_cutlass_ref", self.b2b_num - 1) + helper.var_idx(".host_ptr, Mat_D", self.b2b_num - 1) \ + + helper.var_idx(".host_ptr, Mat_D", self.b2b_num - 1) + ".elements);\n" + + code += "\n\n}\n" + + with open(self.sample_dir + "sample.cu", "w+") as f: + f.write(code) diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_threadblock.py b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_threadblock.py new file mode 100644 index 0000000000000000000000000000000000000000..6dab42b29f499b85a6d7052c41442ad71ba25925 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_threadblock.py @@ -0,0 +1,1013 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +import gen_ir +import helper + + +class gen_default_b2b_mma: + def __init__(self, template_param, gen_class_name, b2b_num,cutlass_deps_root, project_root): + self.gen_class_name = "DefaultB2bMma" + self.template_param = template_param + self.b2b_num = b2b_num + + self.cutlass_deps_root = cutlass_deps_root + self.project_root = project_root + + def gen_include_header(self): + code = ''' +/* Auto Generated code - Do not edit.*/ + +#pragma once + +#include \"{cutlass_dir}cutlass/cutlass.h\" +#include \"{cutlass_dir}cutlass/numeric_types.h\" +#include \"{cutlass_dir}cutlass/arch/arch.h\" + +#include \"{cutlass_dir}cutlass/transform/threadblock/predicated_tile_iterator.h\" +#include \"{cutlass_dir}cutlass/transform/threadblock/predicated_tile_iterator_2dthreadtile.h\" +#include \"{cutlass_dir}cutlass/gemm/threadblock/default_mma_core_sm70.h\" +#include \"{cutlass_dir}cutlass/gemm/threadblock/default_mma_core_sm75.h\" +#include \"{cutlass_dir}cutlass/gemm/threadblock/default_mma_core_sm80.h\" + +#include \"../threadblock/b2b_mma_pipelined.h\" +#include \"../../fixed_impl/epilogue/threadblock/fused_bias_act_epilogue.h\" +#include \"../../fixed_impl/epilogue/threadblock/default_bias_act_epilogue_tensor_op.h\" +#include \"../../fixed_impl/gemm/warp/mma_tensor_op_fragment_iterator_without_output_op.h\" +'''.format(cutlass_dir=self.cutlass_deps_root) + return code + + + def gen_using_MmaCore(self, stage): + threadBlockShape = "ThreadblockShape" + warpShape = "WarpShape" + instrunctionShape = "InstructionShape" + Mma_typename = "typename cutlass::gemm::threadblock::DefaultMmaCore" + + + gen_code = "" + + for i in range(self.b2b_num): + code_using = "using MmaCore" + str(i) + gen_code += code_using + " = " + gen_ir.gen_declare_template_struct(Mma_typename, \ + helper.var_idx(threadBlockShape, i), helper.var_idx(warpShape, i), instrunctionShape, \ + "ElementA", "LayoutA", \ + helper.var_idx("ElementB", i), helper.var_idx("LayoutB", i), \ + helper.var_idx("ElementAccumulator", i), "layout::RowMajor", \ + "OperatorClass", str(stage), "Operator") + return gen_code + + def gen_using_FusedAddBiasEpilogue(self): + gen_code = "" + for i in range(self.b2b_num - 1): + code_using = helper.var_idx("using FusedAddBiasEpilogue", i) + epilogue_name = "typename cutlass::epilogue::threadblock::DefaultFusedBiasActEpilogueTensorOp" + template_args = helper.var_idx("::Epilogue" + + gen_code += code_using + " = " + epilogue_name + template_args + ";\n" + + return gen_code + + + def gen_using_Iterator(self): + code_using = "using IteratorA0" + iterator_typename = "cutlass::transform::threadblock::PredicatedTileIterator" + MmaCore = "MmaCore0" + matrix_shape = "cutlass::MatrixShape<" + MmaCore + "::Shape::kM, " + MmaCore + "::Shape::kK>" + iterator_map = "typename " + MmaCore + "::IteratorThreadMapA" + gen_code = code_using + " = " + gen_ir.gen_declare_template_struct(iterator_typename, \ + 
matrix_shape, "ElementA", "LayoutA", "1", iterator_map, "AlignmentA_") + + for i in range(self.b2b_num): + code_using = "using IteratorB" + str(i) + iterator_typename = "cutlass::transform::threadblock::PredicatedTileIterator" + MmaCore = "MmaCore" + str(i) + matrix_shape = "cutlass::MatrixShape<" + MmaCore + "::Shape::kK, " + MmaCore + "::Shape::kN>" + iterator_map = "typename " + MmaCore + "::IteratorThreadMapB" + + gen_code += code_using + " = " + gen_ir.gen_declare_template_struct(iterator_typename, \ + matrix_shape, helper.var_idx("ElementB", i), helper.var_idx("LayoutB", i), "0", iterator_map, "AlignmentB_") + + return gen_code + + def gen_fragment_iterator(self): + gen_code = "using AccumulatorLayout = cutlass::layout::ColumnMajor;\n" + + for i in range(1, self.b2b_num): + code_using = "using FragmentIteratorA" + str(i) + iterator_typename = "cutlass::gemm::warp::MmaTensorOpPureFragmentIterator" + curr_MmaCore = "MmaCore" + str(i) + prev_MmaCore = "MmaCore" + str(i - 1) + Matrix_shape_curr = "cutlass::MatrixShape<" + curr_MmaCore + "::WarpShape::kM, " + curr_MmaCore + "::InstructionShape::kK>" + Matrix_shape_prev = "cutlass::MatrixShape<" + prev_MmaCore + "::WarpShape::kM, " + prev_MmaCore + "::WarpShape::kN>" + Curr_shape_kK = curr_MmaCore + "::Shape::kK" + + gen_code += code_using + " = " + gen_ir.gen_declare_template_struct(iterator_typename, \ + Matrix_shape_curr, Matrix_shape_prev, Curr_shape_kK, \ + helper.var_idx("ElementAccumulator", i-1), "ElementA", \ + "AccumulatorLayout", "InstructionShape_", "true") + + return gen_code + + def gen_threadblockmma(self): + code_using = "using ThreadblockB2bMma" + iterator_typename = "cutlass::gemm::threadblock::B2bMmaPipelined" + + MmaPipelined_param_Mma0_shape = "typename MmaCore0::Shape" + MmaPipelined_param_Mma0_iteratorA = "IteratorA0" + MmaPipelined_param_Mma0_smemIteratorA = "typename MmaCore0::SmemIteratorA" + MmaPipelined_param_Mma0_iteratorB = "IteratorB0" + MmaPipelined_param_Mma0_smemIteratorB = "typename MmaCore0::SmemIteratorB" + + MmaPipelined_param_list = MmaPipelined_param_Mma0_shape + ", " + MmaPipelined_param_Mma0_iteratorA + ", " + MmaPipelined_param_Mma0_smemIteratorA + ", " + MmaPipelined_param_Mma0_iteratorB + ", " + MmaPipelined_param_Mma0_smemIteratorB + ", " + + for i in range(1, self.b2b_num): + MmaPipelined_param_Mma_shape = "typename MmaCore" + str(i) + "::Shape" + MmaPipelined_param_Mma_iteratorA = "FragmentIteratorA" + str(i) + MmaPipelined_param_Mma_iteratorB = "IteratorB" + str(i) + MmaPipelined_param_Mma_smemIteratorB = "typename MmaCore" + str(i) + "::SmemIteratorB" + + MmaPipelined_param_list += MmaPipelined_param_Mma_shape + ", " + MmaPipelined_param_Mma_iteratorA + ", " + MmaPipelined_param_Mma_iteratorB + ", " + MmaPipelined_param_Mma_smemIteratorB + ", " + + MmaPipelined_param_list += "ElementAccumulator0, layout::RowMajor, " + + for i in range(self.b2b_num - 1): + epilogue_name = "EpilogueOutputOp" + str(i) + MmaPipelined_param_list += epilogue_name + ", " + + for i in range(self.b2b_num - 1): + epilogue_name = "FusedAddBiasEpilogue" + str(i) + MmaPipelined_param_list += epilogue_name + ", " + + for i in range(self.b2b_num): + MmaPolicy = "typename MmaCore" + str(i) + "::MmaPolicy" + MmaPipelined_param_list += MmaPolicy + ", " + + + cnt = 0 + for i in range(self.b2b_num): + MmaStage = helper.var_idx("Stages", i) + final = ", " + if cnt == self.b2b_num - 1: + final = "" + MmaPipelined_param_list += MmaStage + final + cnt += 1 + + gen_code = code_using + " = " + 
gen_ir.gen_declare_template_struct(iterator_typename, MmaPipelined_param_list) + + return gen_code + + + + def gen_code(self): + gen_using = '' + # Generate default template struct + gen_code = gen_ir.gen_template_struct(self.gen_class_name, self.template_param, "", speicalized = None, set_default=False) + + # Generate specialized template struct + + mmacore_codebody = self.gen_using_MmaCore(2) + iterator_codebody = self.gen_using_Iterator() + fragment_iterator_codebody = self.gen_fragment_iterator() + epilogue_iterator_codebody = self.gen_using_FusedAddBiasEpilogue() + threadBlockMma = self.gen_threadblockmma() + specialized_code = mmacore_codebody + iterator_codebody + fragment_iterator_codebody + epilogue_iterator_codebody + threadBlockMma + + # Specialize layout C -> cutlass::layout::RowMajor + + rtn_template_args, speicalized_template_args = gen_ir.filtered_param(self.template_param, [ ('LayoutD', "cutlass::layout::RowMajor")], keep_= True) + + gen_speical_code = gen_ir.gen_template_struct(self.gen_class_name, rtn_template_args, specialized_code, speicalized = speicalized_template_args, set_default=False) + code = gen_ir.gen_namespace("cutlass", gen_ir.gen_namespace("gemm", gen_ir.gen_namespace("threadblock", gen_code + gen_speical_code))) + + return self.gen_include_header() + code + + +class gen_b2b_mme_pipelined: + def __init__(self, template_param, gen_class_name, b2b_num, cutlass_deps_root, project_root): + self.gen_class_name = "B2bMmaPipelined" + self.template_param = template_param + self.b2b_num = b2b_num + self.cutlass_deps_root = cutlass_deps_root + self.project_root = project_root + + + def gen_include_header(self): + code = ''' +#pragma once + +#include \"{cutlass_dir}cutlass/cutlass.h\" +#include \"{cutlass_dir}cutlass/array.h\" +#include \"{cutlass_dir}cutlass/aligned_buffer.h\" +#include \"{cutlass_dir}cutlass/numeric_conversion.h\" + +#include \"{cutlass_dir}cutlass/numeric_types.h\" +#include \"{cutlass_dir}cutlass/matrix_shape.h\" + +#include \"{cutlass_dir}cutlass/gemm/gemm.h\" +#include \"{cutlass_dir}cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h\" + +#include \"../threadblock/b2b_mma_base.h\"\n'''.format(cutlass_dir = self.cutlass_deps_root) + return code + + + def gen_using(self): + code_using = "using FragmentA0 = typename IteratorA0::Fragment;\n" + + code_using += "using Base = B2bMmaBase<" + for i in range(self.b2b_num): + code_using += helper.var_idx("Shape", i) + "_, " + for i in range(self.b2b_num): + code_using += helper.var_idx("Policy", i) + "_, " + for i in range(self.b2b_num): + code_using += helper.var_idx("Stage", i) + "_, " + code_using = code_using[: -2] + ">;\n" + + + for i in range(self.b2b_num): + code_using += helper.var_idx("using FragmentB", i) + helper.var_idx(" = typename IteratorB", i) + "::Fragment;\n" + code_using += helper.var_idx("using FragmentC", i) + helper.var_idx(" = typename Policy", i) + "::Operator::FragmentC;\n" + code_using += helper.var_idx("using Operator", i) + helper.var_idx(" = typename Policy", i) + "::Operator;\n" + + for i in range(self.b2b_num - 1): + code_using += helper.var_idx("using IteratorC", i) + helper.var_idx(" = typename FusedAddBiasEpilogue", i) + "::OutputTileIterator;\n" + + code_using += "using ArchTag = typename Policy0::Operator::ArchTag;\n" + code_using += "static ComplexTransform const kTransformA0 = Operator0::kTransformA;\n" + + for i in range(self.b2b_num): + code_using += helper.var_idx("static ComplexTransform const kTransformB", i) + helper.var_idx(" = Operator", i) + "::kTransformB;\n" 
+ + code_using += "private:\n" + code_using += "using WarpFragmentA0 = typename Operator0::FragmentA;\n" + code_using += "using WarpFragmentB0 = typename Operator0::FragmentB;\n" + + for i in range(1, self.b2b_num): + code_using += helper.var_idx("using WarpFragmentA", i) + helper.var_idx(" = typename FragmentIteratorA", i) + "::Fragment;\n" + code_using += helper.var_idx("using WarpFragmentB", i) + helper.var_idx(" = typename Operator", i) + "::FragmentB;\n" + + code_using += "protected:\n" + + code_using += "SmemIteratorA0 smem_iterator_A_;\n" + + for i in range(self.b2b_num): + code_using += helper.var_idx("SmemIteratorB", i) + helper.var_idx(" smem_iterator_B", i) + "_;\n" + + return code_using + + + def gen_operator(self, first_use_1stage = False): + code = "" + def gen_operator_param(b2b_num): + param_code = "" + param_code += "int gemm_k_iterations_0,\n" + param_code += helper.var_idx("FragmentC", b2b_num-1) + helper.var_idx(" &accum", b2b_num-1) + ",\n" + param_code += "IteratorA0 iterator_A,\n" + + for i in range(b2b_num): + param_code += helper.var_idx("IteratorB", i) + " " + helper.var_idx("iterator_B", i) + ",\n" + + param_code += "FragmentC0 const &src_accum, \n" + + for i in range(b2b_num - 1): + param_code += helper.var_idx("OutputOp", i) + " " + helper.var_idx("output_op_", i) + ",\n" + for i in range(b2b_num - 1): + param_code += helper.var_idx("FusedAddBiasEpilogue", i) + " " + helper.var_idx("epilogue_", i) + ",\n" + for i in range(b2b_num - 1): + param_code += helper.var_idx("IteratorC", i) + " " + helper.var_idx("iterator_C", i) + ",\n" + + + param_code += "TransformA0 transform_A0 = TransformA0(), \n" + + for i in range(b2b_num): + final = "(),\n" + if i == b2b_num - 1: + final = "()\n" + param_code += helper.var_idx("TransformB", i) + " " + helper.var_idx("transform_B", i) + " = " +helper.var_idx("TransformB", i) + final + + return param_code + + + + def gen_first_gemm_1stage(b2b_num): + accu_code = " FragmentC0 accum0 = src_accum;\n" + if b2b_num == 1: + accu_code = " accum0 = src_accum;\n" + + code ="\ +\n\ + FragmentA0 tb_frag_A;\n\ + FragmentB0 tb_frag_B0;\n\ +\n\ + int smem_write_stage_idx = 1;\n\ +\n\ + tb_frag_A.clear();\n\ + tb_frag_B0.clear();\n\ +\n\ + // The last kblock is loaded in the prolog\n\ + iterator_A.load(tb_frag_A);\n\ + iterator_B0.load(tb_frag_B0);\n\ +\n\ + ++iterator_A;\n\ + ++iterator_B0;\n\ +\n\ + WarpFragmentA0 warp_frag_A0;\n\ + WarpFragmentB0 warp_frag_B0;\n\ +\n\ + Operator0 warp_mma0;\n\ +\n\ + // Avoid reading out of bounds\n\ + if (gemm_k_iterations_0 <= 1) {\n\ + iterator_A.clear_mask();\n\ + iterator_B0.clear_mask();\n\ + }\n\ +\n\ + // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing \n\ + // shared memory loads (which have the tightest latency requirement).\n\ +\n\ + //\n\ + // Mainloop\n\ + //\n\ +\n\ + // Note: The main loop does not support Base::WarpGemmIterations == 2.\n\ + CUTLASS_GEMM_LOOP\n\ + for (; gemm_k_iterations_0 > 0; --gemm_k_iterations_0) {\n\ +\n\ + this->smem_iterator_A_.store(tb_frag_A);\n\ + this->smem_iterator_B0_.store(tb_frag_B0);\n\ +\n\ + __syncthreads();\n\ + //\n\ + // Loop over GEMM K dimension\n\ + //\n\ +\n\ + CUTLASS_PRAGMA_UNROLL\n\ + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations0; ++warp_mma_k) {\n\ +\n\ + // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group\n\ + // as the case may be.\n\ +\n\ + this->warp_tile_iterator_A0_.set_kgroup_index(warp_mma_k % Base::kWarpGemmIterations0);\n\ + 
this->warp_tile_iterator_B0_.set_kgroup_index(warp_mma_k % Base::kWarpGemmIterations0);\n\ +\n\ + this->warp_tile_iterator_A0_.load(warp_frag_A0);\n\ + this->warp_tile_iterator_B0_.load(warp_frag_B0);\n\ +\n\ + ++this->warp_tile_iterator_A0_;\n\ + ++this->warp_tile_iterator_B0_;\n\ +\n\ + warp_mma0(accum0, warp_frag_A0, warp_frag_B0, accum0);\n\ + }\n\ + this->warp_tile_iterator_A0_.add_tile_offset({0, -Policy0::kPartitionsK * Base::kWarpGemmIterations0});\n\ + this->warp_tile_iterator_B0_.add_tile_offset({-Policy0::kPartitionsK * Base::kWarpGemmIterations0, 0});\n\ +\n\ + __syncthreads();\n\ + iterator_A.load(tb_frag_A);\n\ + iterator_B0.load(tb_frag_B0);\n\ +\n\ + ++iterator_A;\n\ + ++iterator_B0;\n\ +\n\ + if(gemm_k_iterations_0 <= 2) {\n\ + iterator_A.clear_mask();\n\ + iterator_B0.clear_mask();\n\ + }\n\ + }\n" + + return accu_code + code + + + def gen_first_gemm_2stage(b2b_num): + + accu_code = " FragmentC0 accum0 = src_accum;\n" + if b2b_num == 1: + accu_code = " accum0 = src_accum;\n" + + code ="\ +\n\ + FragmentA0 tb_frag_A;\n\ + FragmentB0 tb_frag_B0;\n\ +\n\ + tb_frag_A.clear();\n\ + tb_frag_B0.clear();\n\ +\n\ + // The last kblock is loaded in the prolog\n\ + iterator_A.load(tb_frag_A);\n\ + iterator_B0.load(tb_frag_B0);\n\ +\n\ + ++iterator_A;\n\ + ++iterator_B0;\n\ +\n\ + this->smem_iterator_A_.store(tb_frag_A);\n\ + this->smem_iterator_B0_.store(tb_frag_B0);\n\ +\n\ + ++this->smem_iterator_A_;\n\ + ++this->smem_iterator_B0_;\n\ +\n\ + __syncthreads();\n\ +\n\ + // Pair of fragments used to overlap shared memory loads and math instructions\n\ + WarpFragmentA0 warp_frag_A0[2];\n\ + WarpFragmentB0 warp_frag_B0[2];\n\ +\n\ + this->warp_tile_iterator_A0_.set_kgroup_index(0);\n\ + this->warp_tile_iterator_B0_.set_kgroup_index(0);\n\ +\n\ + this->warp_tile_iterator_A0_.load(warp_frag_A0[0]);\n\ + this->warp_tile_iterator_B0_.load(warp_frag_B0[0]);\n\ +\n\ + ++this->warp_tile_iterator_A0_;\n\ + ++this->warp_tile_iterator_B0_;\n\ +\n\ + Operator0 warp_mma0;\n\ +\n\ + int smem_write_stage_idx = 1;\n\ +\n\ + // Avoid reading out of bounds\n\ + if (gemm_k_iterations_0 <= 1) {\n\ + iterator_A.clear_mask();\n\ + iterator_B0.clear_mask();\n\ + }\n\ +\n\ + // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing \n\ + // shared memory loads (which have the tightest latency requirement).\n\ + iterator_A.load(tb_frag_A);\n\ +\n\ + //\n\ + // Mainloop\n\ + //\n\ +\n\ + // Note: The main loop does not support Base::WarpGemmIterations == 2.\n\ + CUTLASS_GEMM_LOOP\n\ + for (; gemm_k_iterations_0 > 0; --gemm_k_iterations_0) {\n\ +\n\ + //\n\ + // Loop over GEMM K dimension\n\ + //\n\ +\n\ + CUTLASS_PRAGMA_UNROLL\n\ + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations0; ++warp_mma_k) {\n\ +\n\ + // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group\n\ + // as the case may be.\n\ +\n\ + if (warp_mma_k == Base::kWarpGemmIterations0 - 1) {\n\ +\n\ + // Write fragments to shared memory\n\ + this->smem_iterator_A_.store(tb_frag_A);\n\ +\n\ + this->smem_iterator_B0_.store(tb_frag_B0);\n\ +\n\ + __syncthreads();\n\ +\n\ + // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing \n\ + // shared memory loads (which have the tightest latency requirement).\n\ + iterator_A.load(tb_frag_A);\n\ + \n\ + ++this->smem_iterator_B0_;\n\ + ++this->smem_iterator_A_;\n\ + \n\ +\n\ + // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory\n\ + if (smem_write_stage_idx == 1) {\n\ + 
this->smem_iterator_A_.add_tile_offset({0, -Base::Stage0});\n\ + this->smem_iterator_B0_.add_tile_offset({-Base::Stage0, 0});\n\ + }\n\ + else {\n\ + this->warp_tile_iterator_A0_.add_tile_offset(\n\ + {0, -Base::Stage0 * Policy0::kPartitionsK * Base::kWarpGemmIterations0});\n\ + this->warp_tile_iterator_B0_.add_tile_offset(\n\ + {-Base::Stage0 * Policy0::kPartitionsK * Base::kWarpGemmIterations0,\n\ + 0});\n\ + }\n\ +\n\ + smem_write_stage_idx ^= 1;\n\ + }\n\ +\n\ + this->warp_tile_iterator_A0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0);\n\ + this->warp_tile_iterator_B0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0);\n\ + \n\ + this->warp_tile_iterator_A0_.load(warp_frag_A0[(warp_mma_k + 1) % 2]);\n\ + this->warp_tile_iterator_B0_.load(warp_frag_B0[(warp_mma_k + 1) % 2]);\n\ +\n\ + ++this->warp_tile_iterator_A0_;\n\ + ++this->warp_tile_iterator_B0_;\n\ +\n\ + if (warp_mma_k == 0) {\n\ +\n\ + iterator_B0.load(tb_frag_B0);\n\ +\n\ + ++iterator_A;\n\ + ++iterator_B0;\n\ +\n\ + // Avoid reading out of bounds if this was the last loop iteration\n\ + if (gemm_k_iterations_0 <= 2) {\n\ + iterator_A.clear_mask();\n\ + iterator_B0.clear_mask();\n\ + }\n\ + }\n\ +\n\ + warp_mma0(accum0, warp_frag_A0[warp_mma_k % 2], warp_frag_B0[warp_mma_k % 2], accum0);\n\ + }\n\ + }\n" + return accu_code + code + + def gen_other_gemms_2stage(b2b_num): + + code = "" + + def gemm_teamplate(id): + code = "// " + str(id + 1) + " Gemm" + code += " /// Iterator to load a warp-scoped tile of A1 operand from intermediate accumulator tile\n" + + code += " " + helper.var_idx("FragmentC", id - 1) + helper.var_idx(" after_epilogue_accu", id - 1) + ";\n" + code += " " + helper.var_idx("epilogue_", id - 1) + helper.var_idx("(output_op_", id - 1) + helper.var_idx(", accum", id - 1) \ + + helper.var_idx(", after_epilogue_accu", id - 1) + helper.var_idx(", iterator_C", id - 1) +");\n" + + # FragmentIteratorA1 warp_tile_iterator_A1_(accum0); + code += " " + helper.var_idx("FragmentIteratorA", id) + helper.var_idx(" warp_tile_iterator_A", id) +"_(" + helper.var_idx("after_epilogue_accu", id - 1) + ");\n" + # FragmentB1 tb_frag_B1; + code += " " + helper.var_idx("FragmentB", id) + " " + helper.var_idx("tb_frag_B", id) + ";\n" + # tb_frag_B1.clear(); + code += " " + helper.var_idx("tb_frag_B", id) + ".clear();\n" + # iterator_B1.load(tb_frag_B1); + code += " " + helper.var_idx("iterator_B", id) + ".load(" + helper.var_idx("tb_frag_B", id) + ");\n" + # ++iterator_B1; + code += " " + "++" + helper.var_idx("iterator_B", id) + ";\n" + # this->smem_iterator_B1_.store(tb_frag_B1); + code += " " + helper.var_idx("this->smem_iterator_B", id) + "_.store(" + helper.var_idx("tb_frag_B", id) + ");\n" + # ++this->smem_iterator_B1_; + code += " " + helper.var_idx("++this->smem_iterator_B", id) + "_;\n" + # __syncthreads(); + code += " " + "__syncthreads();\n" + # WarpFragmentA1 warp_frag_A1[2]; + code += " " + helper.var_idx("WarpFragmentA", id) + helper.var_idx(" warp_frag_A", id) + "[2];\n" + # WarpFragmentB1 warp_frag_B1[2]; + code += " " + helper.var_idx("WarpFragmentB", id) + helper.var_idx(" warp_frag_B", id) + "[2];\n" + # this->warp_tile_iterator_B1_.set_kgroup_index(0); + code += " " + helper.var_idx("this->warp_tile_iterator_B", id) + "_.set_kgroup_index(0);\n" + # warp_tile_iterator_A1_.load(warp_frag_A1[0], output_op_0); + code += " " + helper.var_idx("warp_tile_iterator_A", id) + helper.var_idx("_.load(warp_frag_A", id) + "[0]);\n" + # this->warp_tile_iterator_B1_.load(warp_frag_B1[0]); + code += 
" " + helper.var_idx("this->warp_tile_iterator_B", id) + helper.var_idx("_.load(warp_frag_B", id) + "[0]);\n" + # ++warp_tile_iterator_A1_; + code += " " + helper.var_idx("++warp_tile_iterator_A", id) + "_;\n" + # ++this->warp_tile_iterator_B1_; + code += " " + helper.var_idx("++this->warp_tile_iterator_B", id) + "_;\n" + # Operator1 warp_mma1; + code += " " + helper.var_idx("Operator", id) + " " + helper.var_idx("warp_mma", id) + ";\n" + # smem_write_stage_idx = 1; + code += " " + "smem_write_stage_idx = 1;\n" + # int gemm_k_iterations_1 = FragmentIteratorA1::Policy::kIterations / Base::kWarpGemmIterations1; + code += " " + helper.var_idx("int gemm_k_iterations_", id) + " = " + helper.var_idx("FragmentIteratorA", id) + helper.var_idx("::Policy::kIterations / Base::kWarpGemmIterations", id) +";\n" + # if (gemm_k_iterations_1 <= 1) { + # iterator_B1.clear_mask(); + # } + code += " " + "if (" + helper.var_idx("gemm_k_iterations_", id) + " <= 1 ){\n" \ + + " " + " " + helper.var_idx("iterator_B", id) + ".clear_mask();\n" \ + + " " +"}\n" + # CUTLASS_PRAGMA_UNROLL + code += " " + "CUTLASS_PRAGMA_UNROLL\n" + # for (; gemm_k_iterations_1 > 0; --gemm_k_iterations_1) { + code += " " + helper.var_idx("for (; gemm_k_iterations_", id) + helper.var_idx(" > 0; --gemm_k_iterations_", id) + ") {\n" + # CUTLASS_PRAGMA_UNROLL + code += " " + " " + "CUTLASS_PRAGMA_UNROLL\n" + # for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations1; ++warp_mma_k) { + code += " " + " " + helper.var_idx("for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations", id) + "; ++warp_mma_k) {\n" + # if (warp_mma_k == Base::kWarpGemmIterations1 - 1) { + code += " " + " " + " " + helper.var_idx("if (warp_mma_k == Base::kWarpGemmIterations", id) + " - 1) {\n" + # this->smem_iterator_B1_.store(tb_frag_B1); + code += " " + " " + " " + " " + helper.var_idx(" this->smem_iterator_B", id) + helper.var_idx("_.store(tb_frag_B", id) + ");\n" + # __syncthreads(); + code += " " + " " + " " + " " + "__syncthreads();\n" + # ++smem_iterator_B1_; + code += " " + " " + " " + " " + helper.var_idx(" ++smem_iterator_B", id) + "_;\n" + # if (smem_write_stage_idx == 1) { + # smem_iterator_B1_.add_tile_offset({-Base::Stage, 0}); + # } + code += " " + " " + " " + " " + "if ( smem_write_stage_idx == 1 ) {\n" \ + + " " + " " + " " + " " + " " + helper.var_idx("smem_iterator_B", id) + helper.var_idx("_.add_tile_offset({-Base::Stage", i) + ", 0});\n" \ + + " " + " " + " " + " " +"}\n" + # else { + # this->warp_tile_iterator_B1_.add_tile_offset( + # {-Base::Stage * Policy1::kPartitionsK * + # Base::kWarpGemmIterations1, + # 0}); + # } + code += " " + " " + " " + " " + "else {\n" \ + + " " + " " + " " + " " + " " + helper.var_idx("this->warp_tile_iterator_B", id) + "_.add_tile_offset(\n" \ + + " " + " " + " " + " " + " " + helper.var_idx("{-Base::Stage", id) + helper.var_idx(" * Policy", id) + "::kPartitionsK *\n" \ + + " " + " " + " " + " " + " " + helper.var_idx("Base::kWarpGemmIterations", id) + ",\n" \ + + " " + " " + " " + " " + " " + "0});\n" \ + + " " + " " + " " + " " + "}\n" + + # smem_write_stage_idx ^= 1; + # } + code += " " + " " + " " + " " + "smem_write_stage_idx ^= 1;\n" \ + + " " + " " + " " + "}\n" + + # this->warp_tile_iterator_B1_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations1); + code += " " + " " + " " + helper.var_idx("this->warp_tile_iterator_B", id) + helper.var_idx("_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations", id) + ");\n" + # warp_tile_iterator_A1_.load(warp_frag_A1[(warp_mma_k + 1) % 
2], output_op_0); + code += " " + " " + " " + helper.var_idx("warp_tile_iterator_A", id) + helper.var_idx("_.load(warp_frag_A", id) + "[(warp_mma_k + 1) % 2]);\n" + # this->warp_tile_iterator_B1_.load(warp_frag_B1[(warp_mma_k + 1) % 2]); + code += " " + " " + " " + helper.var_idx("this->warp_tile_iterator_B", id) + helper.var_idx("_.load(warp_frag_B", id) + "[(warp_mma_k + 1) % 2]);\n" + # ++warp_tile_iterator_A1_; + code += " " + " " + " " + helper.var_idx("++warp_tile_iterator_A", id) + "_;\n" + # ++this->warp_tile_iterator_B1_; + code += " " + " " + " " + helper.var_idx("++this->warp_tile_iterator_B", id) + "_;\n" + # if (warp_mma_k == 0) { + # iterator_B1.load(tb_frag_B1); + # ++iterator_B1; + # if (gemm_k_iterations_1 <= 2) { + # iterator_B1.clear_mask(); + # } + # } + code += " " + " " + " " + " if (warp_mma_k == 0) {\n" \ + + " " + " " + " " + " " + helper.var_idx("iterator_B", id) + helper.var_idx(".load(tb_frag_B", id) + ");\n" \ + + " " + " " + " " + " " + helper.var_idx("++iterator_B", id) +";\n" \ + + " " + " " + " " + " " + helper.var_idx("if (gemm_k_iterations_", id) +" <= 2) {\n" \ + + " " + " " + " " + " " + " " + helper.var_idx("iterator_B", id) + ".clear_mask();\n" \ + + " " + " " + " " + " " + "}\n" \ + + " " + " " + " " + "}\n" + # warp_mma1(accum, warp_frag_A1[warp_mma_k % 2], warp_frag_B1[warp_mma_k % 2], accum); + # } + # } + code += " " + " " + " " + helper.var_idx("warp_mma", id) + helper.var_idx("(accum", id) + helper.var_idx(", warp_frag_A", id) + helper.var_idx("[warp_mma_k % 2], warp_frag_B", id) + helper.var_idx("[warp_mma_k % 2], accum", id) + ");\n" \ + + " " + " " + "}\n" \ + + " " + "}\n\n\n" + + return code + + for i in range (1, b2b_num): + clear_accu = "" + if i != b2b_num - 1: + clear_accu = " " + helper.var_idx("FragmentC", i) + helper.var_idx(" accum", i) +";\n" + clear_accu += " " + helper.var_idx("accum", i) +".clear();\n" + code += clear_accu + gemm_teamplate(i) + + return code + + operator_code = " CUTLASS_DEVICE\n\ + void operator()(\n " + gen_operator_param(self.b2b_num) + ") {\n" + if first_use_1stage: + operator_code += gen_first_gemm_1stage(self.b2b_num) + else: + operator_code += gen_first_gemm_2stage(self.b2b_num) + operator_code += gen_other_gemms_2stage(self.b2b_num) + "}\n" + return operator_code + + def gen_construct_func(self): + name = self.gen_class_name + func_code = "CUTLASS_DEVICE\n" + func_code += name + "(\n" \ + + " " + "typename Base::B2bMmaSharedStorage &shared_storage,\n" \ + + " " + "int thread_idx,\n" \ + + " " + "int warp_idx,\n" \ + + " " + "int lane_idx\n" \ + + "):\n" + func_code += " " + "Base(shared_storage, thread_idx, warp_idx, lane_idx),\n" \ + + " " + "smem_iterator_A_(shared_storage.sharedStorage0.operand_A_ref(), thread_idx),\n" + + for i in range(self.b2b_num): + final = ",\n" + if i == self.b2b_num - 1: + final = " {\n" + func_code += helper.var_idx("smem_iterator_B", i) + helper.var_idx("_(shared_storage.sharedStorage", i) +".operand_B_ref(), thread_idx)" + final + + func_code += " " + "int warp_idx_mn = warp_idx % (Base::WarpCount0::kM * Base::WarpCount0::kN);\n" + func_code += " " + "int warp_idx_k = warp_idx / (Base::WarpCount0::kM * Base::WarpCount0::kN);\n" + + func_code += " " + "int warp_idx_m = warp_idx_mn % Base::WarpCount0::kM;\n" + func_code += " " + "int warp_idx_n = warp_idx_mn / Base::WarpCount0::kM;\n" + + for i in range(self.b2b_num): + func_code += " " + helper.var_idx("int tile_offset_k", i) + helper.var_idx(" = Base::kWarpGemmIterations", i) + " * warp_idx_k;\n" + + func_code += " " + 
"this->warp_tile_iterator_A0_.add_tile_offset({warp_idx_m, tile_offset_k0});\n" + + for i in range(self.b2b_num): + func_code += " " + helper.var_idx("this->warp_tile_iterator_B", i) + helper.var_idx("_.add_tile_offset({tile_offset_k", i) + ", warp_idx_n});\n" + + func_code += "}\n" + + return func_code + + def gen_member_func(self, first_use_1stage): + code = "public:\n" + code += self.gen_operator(first_use_1stage) + code += self.gen_construct_func() + + return code + + def gen_code(self, first_use_1stage): + + def gen_template_args(b2b_num): + template_param = [] + template_param.append(("typename", "Shape0")) + template_param.append(("typename", "IteratorA0")) + template_param.append(("typename", "SmemIteratorA0")) + template_param.append(("typename", "IteratorB0")) + template_param.append(("typename", "SmemIteratorB0")) + + for i in range(1, b2b_num): + template_param.append(("typename", helper.var_idx("Shape", i))) + template_param.append(("typename", helper.var_idx("FragmentIteratorA", i))) + template_param.append(("typename", helper.var_idx("IteratorB", i))) + template_param.append(("typename", helper.var_idx("SmemIteratorB", i))) + + template_param.append(("typename", "ElementC")) + template_param.append(("typename", "LayoutC")) + + for i in range(0, b2b_num - 1): + template_param.append(("typename", helper.var_idx("OutputOp", i))) + + for i in range(0, b2b_num - 1): + template_param.append(("typename", helper.var_idx("FusedAddBiasEpilogue", i))) + + for i in range(0, b2b_num): + template_param.append(("typename", helper.var_idx("Policy", i))) + for i in range(0, b2b_num): + template_param.append((int, helper.var_idx("Stage", i))) + + template_param.append(("typename","TransformA0", "NumericArrayConverter")) + + for i in range(0, b2b_num): + cvtr = helper.var_idx("NumericArrayConverter" + template_param.append(("typename", helper.var_idx("TransformB", i), cvtr)) + + template_param.append(("typename", "Enable", "bool")) + + return template_param + + template_param = gen_template_args(self.b2b_num) + inheritance_code = "public B2bMmaBase<" + for i in range(self.b2b_num): + inheritance_code += helper.var_idx("Shape", i) + "_, " + for i in range(self.b2b_num): + inheritance_code += helper.var_idx("Policy", i) + "_, " + for i in range(self.b2b_num - 1): + inheritance_code += helper.var_idx("Stage", i) + "_, " + inheritance_code += helper.var_idx("Stage", self.b2b_num - 1) + "_" + inheritance_code += ">" + + code_body = "" + using_code= self.gen_using() + func_code = self.gen_member_func(first_use_1stage) + + code_body = using_code + func_code + + class_code = gen_ir.gen_template_class(self.gen_class_name, template_param, code_body, inheritance_code = inheritance_code) + + code = self.gen_include_header() + code += gen_ir.gen_namespace("cutlass", gen_ir.gen_namespace("gemm", gen_ir.gen_namespace("threadblock", class_code))) + # print(code) + return code + + +class gen_b2b_mma_base: + def __init__(self, template_param, gen_class_name, b2b_num, cutlass_deps_root, project_root): + self.gen_class_name = gen_class_name + self.template_param = template_param + self.b2b_num = b2b_num + self.cutlass_deps_root = cutlass_deps_root + self.project_root = project_root + + def gen_include_header(self): + code = ''' +#pragma once + +#include \"{cutlass_dirs}cutlass/aligned_buffer.h\" +#include \"{cutlass_dirs}cutlass/arch/memory.h\" +#include \"{cutlass_dirs}cutlass/array.h\" +#include \"{cutlass_dirs}cutlass/cutlass.h\" +#include \"{cutlass_dirs}cutlass/gemm/gemm.h\" +#include 
\"{cutlass_dirs}cutlass/matrix_shape.h\" +#include \"{cutlass_dirs}cutlass/numeric_types.h\"\n'''.format(cutlass_dirs=self.cutlass_deps_root) + return code + + def gen_shared_storage(self): + code = \ +" template< \n\ + typename Shape_,\n\ + typename Policy_,\n\ + int ThisStage_\n\ +>\n\ +class SharedStorage {\n\ +public:\n\ + using Shape = Shape_;\n\ + using Policy = Policy_;\n\ + static int const ThisStage = ThisStage_;\n\ + using Operator = typename Policy::Operator;\n\ + \ + using TensorRefA = TensorRef;\n\ + \ + /// Tensor reference to the B operand \n\ + using TensorRefB = TensorRef;\n\ +\n\ + /// Shape of the A matrix operand in shared memory \n\ + using ShapeA = MatrixShape;\n\ +\n\ + /// Shape of the B matrix operand in shared memory\n\ + using ShapeB =\n\ + MatrixShape;\n\ +\n\ + public:\n\ +\n\ + /// Buffer for A operand\n\ + AlignedBuffer operand_A;\n\ +\n\ + /// Buffer for B operand\n\ + AlignedBuffer operand_B;\n\ +\n\ + public:\n\ +\n\ + /// Returns a layout object for the A matrix\n\ + CUTLASS_DEVICE\n\ + static typename Operator::LayoutA LayoutA() {\n\ + return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn});\n\ + }\n\ +\n\ + /// Returns a layout object for the B matrix\n\ + CUTLASS_HOST_DEVICE\n\ + static typename Operator::LayoutB LayoutB() {\n\ + return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn});\n\ + }\n\ +\n\ + /// Returns a TensorRef to the A operand\n\ + CUTLASS_HOST_DEVICE\n\ + TensorRefA operand_A_ref() {\n\ + return TensorRefA{operand_A.data(), LayoutA()};\n\ + }\n\ +\n\ + /// Returns a TensorRef to the B operand\n\ + CUTLASS_HOST_DEVICE\n\ + TensorRefB operand_B_ref() {\n\ + return TensorRefB{operand_B.data(), LayoutB()};\n\ + }\n\ + CUTLASS_HOST_DEVICE\n\ + void * get_B_Shared_ptr() {\n\ + return operand_B.data();\n\ + }\n\ + };\n" + return code + + def gen_using_and_misc(self, b2b_num): + code_using = "" + for i in range(b2b_num): + code_using += "using Operator" +str(i) + " = typename Policy" + str(i) +"::Operator;\n" + + for i in range(b2b_num): + code_using += "using WarpGemm" +str(i) + " = typename Policy" + str(i) +"::Operator::Shape;\n" + + for i in range(b2b_num): + code_using += "using WarpCount" +str(i) + " = GemmShape<" + helper.var_idx("Shape", i) +"::kM / " + helper.var_idx("WarpGemm", i) +"::kM, "\ + + helper.var_idx("Shape", i) +"::kN / " + helper.var_idx("WarpGemm", i) +"::kN, "\ + + helper.var_idx("Shape", i) +"::kK / " + helper.var_idx("WarpGemm", i) +"::kK>;\n" + + code_misc = "" + for i in range(b2b_num): + code_misc += "static int const " + helper.var_idx("kWarpGemmIterations", i) + " = (" + helper.var_idx("WarpGemm", i) + "::kK / " + helper.var_idx("Operator", i) +"::Policy::MmaShape::kK);\n" + + code = code_using + code_misc + self.gen_shared_storage() + + for i in range(b2b_num): + code += "using " + helper.var_idx("SharedStorage", i) + " = SharedStorage<" + helper.var_idx("Shape", i) + ", " + helper.var_idx("Policy", i) +", " + helper.var_idx("Stage", i) + ">;\n" + + def gen_union_shared_storage(b2b_num): + code = "" + for i in range(b2b_num): + code += " " +helper.var_idx("SharedStorage", i) + " " + helper.var_idx("sharedStorage", i) +";\n" + return code + + code += "union B2bMmaSharedStorage {\n" + gen_union_shared_storage(self.b2b_num) + "};\n" + + for i in range(b2b_num - 1): + code += helper.var_idx("void * C", i) + "_smm_ptr;\n" + + return code + + def gen_protected(self): + code = "\nprotected:\n" + code += "typename Operator0::IteratorA warp_tile_iterator_A0_;\n" + for i in range(self.b2b_num): + code 
+= "typename Operator" +str(i) + "::IteratorB" +" warp_tile_iterator_B" + str(i) + "_;\n" + return code + + def gen_public_member(self): + code = "\npublic:\n" + + code += "CUTLASS_DEVICE\n" + code += \ + "B2bMmaBase(\n" + \ + " B2bMmaSharedStorage & shared_storage,\n" + \ + " int thread_idx,\n" + \ + " int warp_idx,\n" + \ + " int lane_idx\n" + \ + "):\n" + \ + " warp_tile_iterator_A0_(shared_storage.sharedStorage0.operand_A_ref(), lane_idx),\n" + for i in range(self.b2b_num): + final = ",\n" + if i == self.b2b_num-1: + final = "\n" + + iterator = " warp_tile_iterator_B" + str(i) + "_" + shared_storage = "shared_storage.sharedStorage" + str(i) + ".operand_B_ref()" + code += iterator + "(" + shared_storage + ", lane_idx)" + final + + + code += "{\n" + for i in range(self.b2b_num - 1): + code += helper.var_idx(" C", i) + helper.var_idx("_smm_ptr = shared_storage.sharedStorage", i) + ".get_B_Shared_ptr();\n" + code += "}\n" + + return code + + def gen_code(self): + + template_arg = [] + for i in range(self.b2b_num): + template_arg.append(("typename", helper.var_idx("Shape", i))) + for i in range(self.b2b_num): + template_arg.append(("typename", helper.var_idx("Policy", i))) + for i in range(self.b2b_num): + template_arg.append((int, helper.var_idx("Stage", i))) + + + + code_body = self.gen_using_and_misc(self.b2b_num) + code_body += self.gen_protected() + code_body += self.gen_public_member() + + class_code = gen_ir.gen_template_class("B2bMmaBase", template_arg, code_body) + + code = self.gen_include_header() + gen_ir.gen_namespace("cutlass", gen_ir.gen_namespace("gemm", gen_ir.gen_namespace("threadblock", class_code))) + + return code + + +class gen_threadblock: + def __init__(self, template_param, gen_class_name, b2b_num, output_dir, cutlass_deps_root, project_root): + self.gen_class_name = gen_class_name + self.template_param = template_param + self.b2b_num = b2b_num + self.file_dir = output_dir + "/threadblock/" + + self.cutlass_deps_root = cutlass_deps_root + self.project_root = project_root + + + self.gen_b2b_mma_base = gen_b2b_mma_base(template_param, gen_class_name, b2b_num, cutlass_deps_root, project_root) + self.gen_b2b_mma_pipelined = gen_b2b_mme_pipelined(template_param, gen_class_name, b2b_num, cutlass_deps_root, project_root) + self.gen_default_b2b_mma = gen_default_b2b_mma(template_param, gen_class_name, b2b_num, cutlass_deps_root, project_root) + + + def gen_code(self, first_use_1stage): + + base_code = self.gen_b2b_mma_base.gen_code() + print("[INFO]: Gen kernel code [b2b_mma_base.h]output Dir: is ", self.file_dir) + + with open(self.file_dir + "b2b_mma_base.h", "w+") as f: + f.write(base_code) + pipeline_code = self.gen_b2b_mma_pipelined.gen_code(first_use_1stage = first_use_1stage) + print("[INFO]: Gen kernel code [b2b_mma_pipelined.h]output Dir: is ", self.file_dir) + + with open(self.file_dir + "b2b_mma_pipelined.h", "w+") as f: + f.write(pipeline_code) + default_code = self.gen_default_b2b_mma.gen_code() + print("[INFO]: Gen kernel code [default_b2b_mma.h]output Dir: is ", self.file_dir) + + with open(self.file_dir + "default_b2b_mma.h", "w+") as f: + f.write(default_code) diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_turing_and_volta.py b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_turing_and_volta.py new file mode 100644 index 0000000000000000000000000000000000000000..8697ca88ee4cb40056cca06b8e028abd6b491749 --- 
/dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_turing_and_volta.py @@ -0,0 +1,456 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
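# gen_volta_turing_fuse_act_impl below derives memory-access alignment for the
# A/B operands (from K) and for the epilogue output (from N) as the widest
# vector width, in elements, that evenly divides the dimension: 8, then 4,
# then 2, otherwise 1.  A compact sketch of that selection, assuming the
# "*_mod_8" naming in the generator reflects an intended modulo-8 test:
def vector_alignment(extent):
    """Widest of 8/4/2/1 elements that evenly divides `extent`."""
    for width in (8, 4, 2):
        if extent % width == 0:
            return width
    return 1

assert vector_alignment(64) == 8 and vector_alignment(20) == 4 and vector_alignment(6) == 2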
+# +################################################################################################# + +import helper +import gen_ir as ir + +class gen_turing_impl: + def __init__(self,fuse_gemm_info, gen_class_name, user_header_file, output_dir = "../"): + self.fuse_gemm_info = fuse_gemm_info + self.class_name = gen_class_name + self.gen_class_name = gen_class_name + "_turing_impl" + self.user_header_file = "" + for header in user_header_file: + self.user_header_file += "#include \"" + header + "\"\n" + self.output_dir = output_dir + self.b2b_num = len(fuse_gemm_info) + + self.gen_turing_unfused = gen_volta_turing_fuse_act_impl(fuse_gemm_info, gen_class_name, user_header_file, output_dir) + + def gen_using(self): + code_using = "using b2b_gemm = typename cutlass::gemm::device::" + self.class_name + ";" + + return code_using + "\n" + + def gen_initialize(self): + code = "" + for i in range(self.b2b_num): + code_this = "" + + code_this += helper.var_idx(helper.type_2_cutlass_type(self.fuse_gemm_info[i]['Acc_tp']) + " alpha", i) + " = " + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['Acc_tp']) + "(1);\n" + beta = "(1)" + + if helper.get_epilogue_add_bias_or_not(self.fuse_gemm_info[i]) is False: + beta = "(0)" + code_this += helper.var_idx(helper.type_2_cutlass_type(self.fuse_gemm_info[i]['Acc_tp']) + " beta", i) + " = " + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['Acc_tp']) + beta + ";\n" + k_str = str(self.fuse_gemm_info[i]['mnk'][2]) + if i == 0: + k_str = "K0" + code_this += helper.var_idx("cutlass::gemm::GemmCoord problem_size_", i) + "(M, " + str(self.fuse_gemm_info[i]['mnk'][1]) + ", " + k_str + ");\n" + code += code_this + code += "typename b2b_gemm::Arguments arguments{\n" + + for i in range(self.b2b_num): + code += " " + helper.var_idx("problem_size_", i) + ",\n" + + + code += " " + "{reinterpret_cast<" + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['A_tp']) + "*>(" + helper.var_idx("A", 0) + "), " + helper.var_idx("problem_size_", 0) + ".k()},\n" + + for i in range(self.b2b_num): + + ldmB = str(self.fuse_gemm_info[i]['mnk'][2]) + if i == 0: + ldmB = "K0" + + if self.fuse_gemm_info[i]['B_format'] is 'Row': + ldmB = str(self.fuse_gemm_info[i]['mnk'][1]) + + ldmC = str(helper.get_epilogue_bias_ldm(self.fuse_gemm_info[i])) + + code += " " + "{reinterpret_cast<" + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['B_tp']) + "*>(" + helper.var_idx("B", i) + "), " + ldmB + "},\n" + code += " " + "{reinterpret_cast<" + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['C_tp']) + "*>(" + helper.var_idx("C", i) + "), " + ldmC + "},\n" + code += " " + "{reinterpret_cast<" + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['C_tp']) + "*>(" + helper.var_idx("D", self.b2b_num -1) + "), " + helper.var_idx("problem_size_", self.b2b_num - 1) + ".n()},\n" + + + for i in range(self.b2b_num): + code += " " + "{ " + helper.var_idx("alpha", i) + ", " + helper.var_idx("beta", i) + for epilogue_arg in helper.get_epilogue_args(self.fuse_gemm_info[i]): + arg_name = helper.var_idx("Epilogue", i) + "_" + epilogue_arg[1] + code += ", " + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['Acc_tp']) + "(" + str(arg_name) + ")" + code += "},\n" + code += " " + "Batch};\n\n" + + code += " " "b2b_gemm gemm_op;\n" + code += " " + "gemm_op.initialize(arguments);\n" + return code + "\n" + + + + def gen_run(self): + code = " " + "gemm_op(stream);\n" + + return code + + def gen_wrapper(self): + code_body = "" + + arg_lists = [] + arg_lists.append(["int", "M"]) + arg_lists.append(["int", "K0"]) + 
arg_lists.append(["int", "Batch"]) + arg_lists.append(["void*", helper.var_idx("A", 0)]) + for i in range(self.b2b_num): + arg_lists.append(["void*", helper.var_idx("B", i)]) + arg_lists.append(["void*", helper.var_idx("C", i)]) + arg_lists.append(["void*", helper.var_idx("D", i)]) + epilogue_args = helper.get_epilogue_args(self.fuse_gemm_info[i]) + acc_tp = helper.get_epilogue_compute_tp(self.fuse_gemm_info[i]) + for arg in epilogue_args: + arg_tp = arg[0] + arg_name = helper.var_idx("Epilogue", i) + "_" + arg[1] + arg_lists.append([arg_tp, arg_name]) + + if self.b2b_num == 1: + code_body += self.gen_turing_unfused.gen_using(False) #False -> Turing, True -> Volta + code_body += self.gen_turing_unfused.gen_initialize() + code_body += self.gen_turing_unfused.gen_run() + else: + code_body += self.gen_using() + code_body += self.gen_initialize() + code_body += self.gen_run() + + code = ir.gen_func(self.gen_class_name, arg_lists, code_body) + + return code + + def gen_code(self): + + code = self.gen_wrapper() + helper.write_2_headfile("turing_impl.h", self.output_dir, self.user_header_file + "\n" + code) + +class gen_volta_turing_fuse_act_impl: + def __init__(self, fuse_gemm_info, gen_class_name, user_header_file, output_dir = "../"): + self.fuse_gemm_info = fuse_gemm_info + self.gen_class_name = gen_class_name + "_volta_impl" + self.user_header_file = "" + for header in user_header_file: + self.user_header_file += "#include \"" + header + "\"\n" + self.output_dir = output_dir + self.b2b_num = len(fuse_gemm_info) + + def perf_tiling(self, layer_mnk): + mnk = layer_mnk[:] + block_tile = mnk[:] + block_tile[2] = 32 # force the K tile to be 32 + + # M tile gen + block_tile[0] = 32 + + # N tile gen + if mnk[1] > 128: + block_tile[1] = 256 + elif mnk[1] > 64: + block_tile[1] = 128 + elif mnk[1] > 32: + block_tile[1] = 64 + else : + block_tile[1] = 32 + + warp_tile = block_tile[:] + if block_tile[1] == 256: + warp_tile[1] = 64 + elif block_tile[1] == 128: + warp_tile[1] = 32 + elif block_tile[1] == 64: + warp_tile[1] = 32 + else : + warp_tile[1] = 32 + + warp_tile[0] = 32 + + return block_tile, warp_tile + + + def process_epilogue(self, epilogue_tp, n, C_tp, Acc_tp): + epilogue_setted_type = epilogue_tp + cutlass_epilogue_name = "LinearCombinationRelu" + if epilogue_setted_type.lower() == 'leakyrelu': + cutlass_epilogue_name = "LinearCombinationLeakyRelu" + elif epilogue_setted_type.lower() == 'identity': + cutlass_epilogue_name = "LinearCombination" + + + n_mod_8 = n % 4 + N_align_elements = 1 + if n_mod_8 == 0: + N_align_elements = 8 + elif n_mod_8 == 4: + N_align_elements = 4 + elif n_mod_8 == 2 or n_mod_8 == 6: + N_align_elements = 2 + + epilogue_str = "cutlass::epilogue::thread::" + cutlass_epilogue_name+ "<" + C_tp + ", " + str(N_align_elements) + ", " + Acc_tp + ", " + Acc_tp + ">" + + return epilogue_str + + def gen_using(self, volta = True): + code_using = "" + volta_arch = "cutlass::arch::Sm70" + volta_tc = "cutlass::gemm::GemmShape<8, 8, 4>" + + turing_arch = "cutlass::arch::Sm75" + turing_tc = "cutlass::gemm::GemmShape<16, 8, 8>" + + arch = "" + tc = "" + if volta: + arch = volta_arch + tc = volta_tc + else: + arch = turing_arch + tc = turing_tc + + for i in range(self.b2b_num): + + k = self.fuse_gemm_info[i]['mnk'][2] + + k_mod_8 = k % 4 + ab_ldm = 1 + if k_mod_8 == 0: + ab_ldm = 8 + elif k_mod_8 == 4: + ab_ldm = 4 + elif k_mod_8 == 2 or k_mod_8 == 6: + ab_ldm = 2 + + block_tile, warp_tile = self.perf_tiling(self.fuse_gemm_info[i]['mnk']) + + this_gemm_config = helper.var_idx("using 
Gemm", i) + " = cutlass::gemm::device::GemmBatched<\n" + this_gemm_config += " " + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['A_tp']) + ",\n" + this_gemm_config += " " + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['A_format']) + ",\n" + this_gemm_config += " " + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['B_tp']) + ",\n" + this_gemm_config += " " + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['B_format']) + ",\n" + this_gemm_config += " " + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['C_tp']) + ",\n" + this_gemm_config += " " + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['C_format']) + ",\n" + this_gemm_config += " " + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['Acc_tp']) + ",\n" + this_gemm_config += " " + "cutlass::arch::OpClassTensorOp,\n" + this_gemm_config += " " + arch + ",\n" + this_gemm_config += " " + "cutlass::gemm::GemmShape<" + str(block_tile[0]) + ", " + str(block_tile[1]) + ", " + str(block_tile[2]) + ">,\n" + this_gemm_config += " " + "cutlass::gemm::GemmShape<" + str(warp_tile[0]) + ", " + str(warp_tile[1]) + ", " + str(warp_tile[2]) + ">,\n" + this_gemm_config += " " + tc + ",\n" + this_gemm_config += " " + self.process_epilogue(helper.get_epilogue_tp(self.fuse_gemm_info[i]), self.fuse_gemm_info[i]['mnk'][1], helper.type_2_cutlass_type(self.fuse_gemm_info[i]['C_tp']), helper.type_2_cutlass_type(self.fuse_gemm_info[i]['Acc_tp'])) + ",\n" + this_gemm_config += " " + "cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle,\n" + this_gemm_config += " " + "2,\n" + this_gemm_config += " " + str(ab_ldm) + ",\n" + this_gemm_config += " " + str(ab_ldm) + ">;\n" + + code_using += this_gemm_config + "\n" + + return code_using + "\n" + + def gen_initialize(self): + code = "" + for i in range(self.b2b_num): + code_this = "" + + N_str = str(self.fuse_gemm_info[i]['mnk'][1]) + + code_this += helper.var_idx(helper.type_2_cutlass_type(self.fuse_gemm_info[i]['Acc_tp']) + " alpha", i) + " = " + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['Acc_tp']) + "(1);\n" + beta = "(1)" + if helper.get_epilogue_add_bias_or_not( self.fuse_gemm_info[i]) is False: + beta = "(0)" + code_this += helper.var_idx(helper.type_2_cutlass_type(self.fuse_gemm_info[i]['Acc_tp']) + " beta", i) + " = " + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['Acc_tp']) + beta + ";\n" + + k_str = str(self.fuse_gemm_info[i]['mnk'][2]) + if i == 0: + k_str = "K0" + code_this += helper.var_idx("cutlass::gemm::GemmCoord problem_size_", i) + "(M, " + str(self.fuse_gemm_info[i]['mnk'][1]) + ", " + k_str + ");\n" + code_this += helper.var_idx("typename Gemm", i) + helper.var_idx("::Arguments arguments_", i) + "{\n" + code_this += " " + helper.var_idx("problem_size_", i) + ",\n" + ldmA = k_str + ldmB = k_str + ldmC = str(self.fuse_gemm_info[i]['mnk'][1]) + + ldmBias = str(helper.get_epilogue_bias_ldm(self.fuse_gemm_info[i])) + + if self.fuse_gemm_info[i]['A_format'] is 'Col': + ldmA = "M" + if self.fuse_gemm_info[i]['B_format'] is 'Row': + ldmB = str(self.fuse_gemm_info[i]['mnk'][1]) + if self.fuse_gemm_info[i]['C_format'] is 'Col': + ldmC = "M" + + if i == 0: + code_this += " " + "{reinterpret_cast<" + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['A_tp']) + "*>(" + helper.var_idx("A", i) + "), " + ldmA + "}, " + "M * " + ldmA + ",\n" + else: + code_this += " " + "{reinterpret_cast<" + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['A_tp']) + "*>(" + helper.var_idx("D", i - 1) + "), " + ldmA + "}, " + "M * " + ldmA + ",\n" + + code_this += " " + 
"{reinterpret_cast<" + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['B_tp']) + "*>(" + helper.var_idx("B", i) + "), " + ldmB + "}, " + N_str + " * " + ldmB + ",\n" + + M_bias = str(helper.get_epilogue_bias_shape(self.fuse_gemm_info[i])[0]) + + code_this += " " + "{reinterpret_cast<" + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['C_tp']) + "*>(" + helper.var_idx("C", i) + "), " + ldmBias + "}, " + M_bias + " * " + N_str + ",\n" + code_this += " " + "{reinterpret_cast<" + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['C_tp']) + "*>(" + helper.var_idx("D", i) + "), " + ldmC + "}, " + "M * " + ldmC + ",\n" + code_this += " " + "{ " + helper.var_idx("alpha", i) + ", " + helper.var_idx("beta", i) + for epilogue_arg in helper.get_epilogue_args(self.fuse_gemm_info[i]): + arg_name = helper.var_idx("Epilogue", i) + "_" + epilogue_arg[1] + code_this += ", " + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['Acc_tp']) + "(" + str(arg_name) + ")" + code_this += " },\n" + code_this += " " + "Batch};\n" + + code_this += " " + helper.var_idx("Gemm", i) + helper.var_idx(" gemm_op_", i) + ";\n" + code_this += " " + helper.var_idx("gemm_op_", i) + helper.var_idx(".initialize(arguments_", i) + ", nullptr);\n" + + code += code_this + "\n" + return code + "\n" + + + def gen_run(self): + code = "" + for i in range(self.b2b_num): + code_this = "" + code_this += " " + helper.var_idx("gemm_op_", i) + "(stream);\n" + + code += code_this + return code + + def gen_wrapper(self): + code_body = "" + + arg_lists = [] + arg_lists.append(["int", "M"]) + arg_lists.append(["int", "K0"]) + arg_lists.append(["int", "Batch"]) + arg_lists.append(["void*", helper.var_idx("A", 0)]) + for i in range(self.b2b_num): + arg_lists.append(["void*", helper.var_idx("B", i)]) + arg_lists.append(["void*", helper.var_idx("C", i)]) + arg_lists.append(["void*", helper.var_idx("D", i)]) + epilogue_args = helper.get_epilogue_args(self.fuse_gemm_info[i]) + acc_tp = helper.get_epilogue_compute_tp(self.fuse_gemm_info[i]) + for arg in epilogue_args: + arg_tp = arg[0] + arg_name = helper.var_idx("Epilogue", i) + "_" + arg[1] + arg_lists.append([arg_tp, arg_name]) + code_body += self.gen_using() + code_body += self.gen_initialize() + code_body += self.gen_run() + + code = ir.gen_func(self.gen_class_name, arg_lists, code_body) + + return code + + def gen_code(self): + code = self.gen_wrapper() + helper.write_2_headfile("volta_impl.h", self.output_dir, self.user_header_file + "\n" + code) + +class gen_one_API: + def __init__(self, fuse_gemm_info, gen_class_name, user_header_file, output_dir = "../"): + self.fuse_gemm_info = fuse_gemm_info + self.gen_class_name = gen_class_name + self.user_header_file = "" + for header in user_header_file: + self.user_header_file += "#include \"" + header + "\"\n" + self.output_dir = output_dir + self.b2b_num = len(fuse_gemm_info) + + self.gen_volta = gen_volta_turing_fuse_act_impl(fuse_gemm_info, gen_class_name, user_header_file, output_dir) + + self.gen_turing = gen_turing_impl(fuse_gemm_info, gen_class_name, user_header_file, output_dir) + + def gen_CUTLASS_irrelevant_API(self): + code = "" + code += "#include \n" + code += "#include \n" + + param_name = "Fused" + str(self.b2b_num) + "xGemm_" + for i in range(self.b2b_num): + param_name += str(self.fuse_gemm_info[i]['mnk'][1]) + "_" + param_name += "Params" + params = "" + params += " " + "int M;\n" + params += " " + "int K0;\n" + params += " " + "int Batch;\n" + params += " " + "const void* A0;\n" + for i in range(self.b2b_num): + params += " " + 
"const void* " + helper.var_idx("B", i) + ";\n" + params += " " + "const void* " + helper.var_idx("C", i) + ";\n" + epilogue_args = helper.get_epilogue_args(self.fuse_gemm_info[i]) + acc_tp = helper.get_epilogue_compute_tp(self.fuse_gemm_info[i]) + for arg in epilogue_args: + arg_tp = arg[0] + arg_name = helper.var_idx("Epilogue", i) + "_" + arg[1] + params += " " + arg_tp + " " + arg_name + ";\n" + params += " " + "void* " + helper.var_idx("D", i) + ";\n" + code += ir.gen_struct(param_name, params) + code += "using Param = " + param_name + ";\n" + code += "void one_api( const Param & param, int sm, cudaStream_t stream);\n" + + + return code + + def gen_one_api(self): + code = "" + code += "/* Auto Generated code - Do not edit.*/\n" + code += "#include \"cutlass_irrelevant.h\"\n" + code += "#include \"api.h\"\n" + code += "void one_api( const Param & param, int sm, cudaStream_t stream) {\n" + + code += " " + "if (sm == 70) \n" + code += " " + " " + self.gen_class_name + "_volta_impl(param.M, param.K0, param.Batch, const_cast(param.A0), " + for i in range(self.b2b_num): + code += helper.var_idx("const_cast(param.B", i) + "), " + code += helper.var_idx("const_cast(param.C", i) + "), " + code += helper.var_idx("param.D", i) + ", " + epilogue_args = helper.get_epilogue_args(self.fuse_gemm_info[i]) + for arg in epilogue_args: + arg_name = helper.var_idx("Epilogue", i) + "_" + arg[1] + code += "param." + arg_name + ", " + code += "stream);\n" + code += " " + "else if(sm >= 75) \n" + code += " " + " " + self.gen_class_name + "_turing_impl(param.M, param.K0, param.Batch, const_cast(param.A0), " + for i in range(self.b2b_num): + code += helper.var_idx("const_cast(param.B", i) + "), " + code += helper.var_idx("const_cast(param.C", i) + "), " + code += helper.var_idx("param.D", i) + ", " + epilogue_args = helper.get_epilogue_args(self.fuse_gemm_info[i]) + for arg in epilogue_args: + arg_name = helper.var_idx("Epilogue", i) + "_" + arg[1] + code += "param." + arg_name + ", " + code += "stream);\n" + code += " " + "else assert(0);\n" + code += "}\n" + return code + + def gen_code(self): + + turing_code = self.gen_turing.gen_wrapper() + volta_code = self.gen_volta.gen_wrapper() + cutlass_irrelevant_code = self.gen_CUTLASS_irrelevant_API() + + one_api_code = self.gen_one_api() + with open(self.output_dir + "one_api.cu", "w+") as f: + f.write(one_api_code) + + helper.write_2_headfile("cutlass_irrelevant.h", self.output_dir, cutlass_irrelevant_code) + + helper.write_2_headfile("api.h", self.output_dir, self.user_header_file + "\n" + turing_code + volta_code) diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_verify.py b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_verify.py new file mode 100644 index 0000000000000000000000000000000000000000..5bd465162b60954061ab7cd341281025dd55c0d0 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_verify.py @@ -0,0 +1,92 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. 
Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +import helper +import gen_ir as ir + +import gen_turing_and_volta as gen_basic + + +class gen_verify: + def __init__(self, fuse_gemm_info, gen_class_name, user_header_file, output_dir = "../"): + self.fuse_gemm_info = fuse_gemm_info + self.name = gen_class_name + "_verify" + self.b2b_num = len(fuse_gemm_info) + self.params = [] + self.user_header_file = "" + for header in user_header_file: + self.user_header_file += "#include \"" + header + "\"\n" + self.separate_cutlass = gen_basic.gen_volta_turing_fuse_act_impl(fuse_gemm_info, gen_class_name, user_header_file, output_dir) + self.gen_params() + self.output_dir = output_dir + + + def gen_code(self): + code = "" + code += self.user_header_file + code += self.separate_cutlass.gen_using(False) #False -> Turing, True -> Volta + + code_body = "" + for i in range(self.b2b_num): + code_body += " " + helper.var_idx("Gemm", i) + helper.var_idx(" gemm_op_", i) + ";\n" + code_body += " " + helper.var_idx("gemm_op_", i) + helper.var_idx(".initialize(Arguments_", i) + ", nullptr);\n" + + code_body += self.separate_cutlass.gen_run() + + code += ir.gen_func(self.name, self.params, code_body) + helper.write_2_headfile("cutlass_verify.h", self.output_dir, code) + + + def gen_params(self): + for i in range(self.b2b_num): + self.params.append( + ( + helper.var_idx("typename Gemm", i)+ "::Arguments", + helper.var_idx("Arguments_", i) + ) + ) + + + def get_params(self, declaration = True): + code = "" + if declaration: + for param in self.params: + code += param[0] + " " + param[1] + ";\n" + + return code + + + def gen_initialize(): + code = "" + initialize_code = self.separate_cutlass.gen_initialize() + + code = ir.gen_func("initialize", [[]]) diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/helper.py b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/helper.py new file mode 100644 index 0000000000000000000000000000000000000000..f757d9522476d992a6df0e55cfc961c66b92c19e --- /dev/null +++ 
b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/helper.py @@ -0,0 +1,135 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
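+#
+# Quick reference (illustrative; based on the helpers defined below): these
+# functions map the Python-side layer description onto CUTLASS C++ tokens, e.g.
+#
+#   type_2_cutlass_type("fp16")          # -> "cutlass::half_t"
+#   type_2_cutlass_type("Row")           # -> "cutlass::layout::RowMajor"
+#   cvt_2_cutlass_shape([128, 128, 32])  # -> "cutlass::gemm::GemmShape<128, 128, 32>"
+#   var_idx("problem_size_", 2)          # -> "problem_size_2"
+#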
+# +################################################################################################# + +def type_2_cutlass_type(input_type = "fp16"): + # float point type + if input_type == "fp32": + return "float" + if input_type == "bf16": + return "cutlass::bfloat16_t" + if input_type == "fp16": + return "cutlass::half_t" + + # integer type + if(input_type == "int32"): + return "int32_t" + if(input_type == "int8"): + return "int8_t" + + if input_type == 'Row': + return 'cutlass::layout::RowMajor' + if input_type == 'Col': + return 'cutlass::layout::ColumnMajor' + +def cvt_2_cutlass_shape(gemm_shape): + # gemm shape + if len(gemm_shape) == 3: + val = "cutlass::gemm::GemmShape<" \ + + str(gemm_shape[0]) + ", " \ + + str(gemm_shape[1]) + ", " \ + + str(gemm_shape[2]) + ">" + return val + + +def write_2_headfile(filename, file_dir, string): + with open(file_dir + filename, 'w') as f: + f.write("/* Auto Generated code - Do not edit.*/\n\n\n#pragma once\n" + string) + +def var_idx(variable, index): + return variable + str(index) + + +def list_2_string(input_list, ): + rtn_string = "" + + cnt = 0 + + for element in input_list: + final = ", \n" + if cnt == len(input_list) - 1: + final = "\n" + cnt += 1 + rtn_string += str(element) + final + + return rtn_string + + +def get_epilogue_info(layer_info): + return layer_info['epilogue'] + +def get_epilogue_tp(layer_info): + epilogue_info = get_epilogue_info(layer_info) + return epilogue_info['tp'] + +def get_epilogue_add_bias_or_not(layer_info): + epilogue_info = get_epilogue_info(layer_info) + return epilogue_info['bias']['addbias'] + +def get_epilogue_add_bias_tp(layer_info): + epilogue_info = get_epilogue_info(layer_info) + return epilogue_info['bias']['bias_tp'] + +def get_epilogue_args(layer_info): + epilogue_info = get_epilogue_info(layer_info) + return epilogue_info['args'] + +def get_epilogue_bias_shape(layer_info): + bias_tp = get_epilogue_add_bias_tp(layer_info).lower() + mn_shape = layer_info['mnk'][:-1] + + if bias_tp == 'mat': + mn_shape[0] = 'M' + return mn_shape + elif bias_tp == 'vec': + mn_shape[0] = 1 + return mn_shape + else: + assert(0) + +def get_epilogue_bias_ldm(layer_info): + bias_tp = get_epilogue_add_bias_tp(layer_info).lower() + mn_shape = layer_info['mnk'][:-1] + + c_layout = layer_info['C_format'].lower() + + if c_layout != 'row': + assert(0) + + if bias_tp == 'mat': + return mn_shape[1] + elif bias_tp == 'vec': + return 0 + else: + assert(0) + +def get_epilogue_compute_tp(layer_info): + return layer_info['Acc_tp'] diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/replace_fix_impl_header.py b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/replace_fix_impl_header.py new file mode 100644 index 0000000000000000000000000000000000000000..b9e79efc05f502bb95a10d1efc294453a7300db9 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/ir_gen/replace_fix_impl_header.py @@ -0,0 +1,67 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. 
Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +import os + +class replace_fix_impl: + def __init__(self, src_dir, dst_dir, cutlass_deps_root): + self.src_dir = src_dir + self.dst_dir = dst_dir + self.cutlass_deps_root = cutlass_deps_root + + + + def gen_code(self): + for sub_dir in os.walk(self.src_dir): + files_in_sub_dir = sub_dir[2] + + src_dirs = sub_dir[0] + output_dirs = self.dst_dir + sub_dir[0][len(self.src_dir):] + + if not os.path.exists(output_dirs): + os.mkdir(output_dirs) + + for f in files_in_sub_dir: + with open(src_dirs +"/" + f, 'r') as current_file: + output_lines = [] + lines = current_file.readlines() + + for line in lines: + if(len(line) >= len("#include \"cutlass") and line[:len("#include \"cutlass")] == "#include \"cutlass"): + new_line = "#include \"" + self.cutlass_deps_root + line[len("#include \""):] + # print(new_line) + output_lines.append(new_line) + else: + output_lines.append(line) + + with open(output_dirs + "/" + f, "w+") as dest_file: + dest_file.writelines(output_lines) diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/leaky_bias.h b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/leaky_bias.h new file mode 100644 index 0000000000000000000000000000000000000000..cc34a412a14aa363aa5a614ed8bd378b80d54568 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/leaky_bias.h @@ -0,0 +1,292 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once +#include + +template +__device__ +T add(T const & a, T const &b){ + return (a + b); +} + +template <> +__device__ +half2 add(half2 const & a, half2 const &b){ + return (__hadd2(a,b)); +} + +template +struct RELU{ + __device__ + T operator()(T const & a){ + return a > T(0) ? a : T(0); + } + __device__ + half2 operator()(half2 const & a){ + float2 a_fp32x2 = __half22float2(a); + a_fp32x2.x = a_fp32x2.x > 0.f ? a_fp32x2.x : 0.f; + a_fp32x2.y = a_fp32x2.y > 0.f ? a_fp32x2.y : 0.f; + if(a_fp32x2.x < 0.f || a_fp32x2.y < 0.f) + printf(" %f %f\n", a_fp32x2.x ,a_fp32x2.y); + return __float22half2_rn(a_fp32x2); + } +}; + +template +struct LEAKY_RELU{ + __device__ + T operator()(T const & a, T const & scale = half(1)){ + return a > T(0) ? a : scale * a; + } + __device__ + half2 operator()(half2 const & a, half const & scale = half(1)){ + half2 zero = __half2half2(half(0)); + half2 gt_zero = __hge2(a, zero); + half2 le_zero = __hle2(a, zero); + + + half2 scale_f16x2 = __half2half2(scale); + half2 mask_scale_f16x2 = __hfma2(le_zero, scale_f16x2, gt_zero); + return __hmul2(a, mask_scale_f16x2); + } +}; + +template +__global__ void leaky_and_activation(half* inout, half* bias, half scale, bool mat_bias){ + + constexpr bool N_MOD_2 = N & 1 ? false : true; + + using Access_tp = typename std::conditional::type; + + constexpr int Access_elements = sizeof(Access_tp) / sizeof(half); + + constexpr int iter = (N + (BLOCKDIM * Access_elements) - 1 ) / (BLOCKDIM * Access_elements); + + LEAKY_RELU Act; + Access_tp src_v[iter]; + Access_tp bias_v[iter]; + + int batch_id = blockIdx.y; + int batch_offset = batch_id * gridDim.x * N; + + for(int i = 0; i < iter; i++){ + int idx = (i * BLOCKDIM + threadIdx.x) * Access_elements; + if (idx < N){ + src_v[i] = *reinterpret_cast(inout + blockIdx.x * N + idx + batch_offset); + if (mat_bias) + bias_v[i] = *reinterpret_cast(bias + blockIdx.x * N + idx + batch_offset); + else + bias_v[i] = *reinterpret_cast(bias + idx + batch_id * N); + *reinterpret_cast(inout + blockIdx.x * N + idx + batch_offset) = Act(add(src_v[i],bias_v[i]),scale); + } + + } +} + + + +template +__global__ void leaky_and_activation(half* inout, half scale){ + + constexpr bool N_MOD_2 = N & 1 ? 
false : true; + + using Access_tp = typename std::conditional::type; + + constexpr int Access_elements = sizeof(Access_tp) / sizeof(half); + + constexpr int iter = (N + (BLOCKDIM * Access_elements) - 1 ) / (BLOCKDIM * Access_elements); + + int batch_id = blockIdx.y; + int batch_offset = batch_id * gridDim.x * N; + + LEAKY_RELU Act; + Access_tp src_v[iter]; + + for(int i = 0; i < iter; i++){ + int idx = (i * BLOCKDIM + threadIdx.x) * Access_elements; + if (idx < N){ + src_v[i] = *reinterpret_cast(inout + blockIdx.x * N + idx + batch_offset); + *reinterpret_cast(inout + blockIdx.x * N + idx + batch_offset) = Act(src_v[i], scale); + } + + } +} + + + +template +void leaky_and_activation(half* inout, half* bias, int m, int b, half scale, bool mat_bias){ + + dim3 grid(m, b); + if (bias == nullptr) + leaky_and_activation<<>>(inout, scale); + else + leaky_and_activation<<>>(inout, bias, scale, mat_bias); +} + +template +__global__ void relu_and_activation(half* inout, half* bias, bool mat_bias){ + + constexpr bool N_MOD_2 = N & 1 ? false : true; + + using Access_tp = typename std::conditional::type; + + constexpr int Access_elements = sizeof(Access_tp) / sizeof(half); + + constexpr int iter = (N + (BLOCKDIM * Access_elements) - 1 ) / (BLOCKDIM * Access_elements); + + RELU Act; + Access_tp src_v[iter]; + Access_tp bias_v[iter]; + + int batch_id = blockIdx.y; + int batch_offset = batch_id * gridDim.x * N; + + for(int i = 0; i < iter; i++){ + int idx = (i * BLOCKDIM + threadIdx.x) * Access_elements; + if (idx < N){ + src_v[i] = *reinterpret_cast(inout + blockIdx.x * N + idx + batch_offset); + if (mat_bias) + bias_v[i] = *reinterpret_cast(bias + blockIdx.x * N + idx + batch_offset); + else + bias_v[i] = *reinterpret_cast(bias + idx + batch_id * N); + *reinterpret_cast(inout + blockIdx.x * N + idx + batch_offset) = Act(add(src_v[i],bias_v[i])); + } + + } +} + + + +template +__global__ void relu_and_activation(half* inout){ + + constexpr bool N_MOD_2 = N & 1 ? false : true; + + using Access_tp = typename std::conditional::type; + + constexpr int Access_elements = sizeof(Access_tp) / sizeof(half); + + constexpr int iter = (N + (BLOCKDIM * Access_elements) - 1 ) / (BLOCKDIM * Access_elements); + + int batch_id = blockIdx.y; + int batch_offset = batch_id * gridDim.x * N; + + RELU Act; + Access_tp src_v[iter]; + + for(int i = 0; i < iter; i++){ + int idx = (i * BLOCKDIM + threadIdx.x) * Access_elements; + if (idx < N){ + src_v[i] = *reinterpret_cast(inout + blockIdx.x * N + idx + batch_offset); + *reinterpret_cast(inout + blockIdx.x * N + idx + batch_offset) = Act(src_v[i]); + } + + } +} + + + +template +void relu_and_activation(half* inout, half* bias, int m, int b, bool mat_bias){ + dim3 grid(m, b); + if (bias == nullptr) + relu_and_activation<<>>(inout); + else + relu_and_activation<<>>(inout, bias, mat_bias); +} + + +template +__global__ void identity_and_activation(half* inout, half* bias, bool mat_bias){ + + constexpr bool N_MOD_2 = N & 1 ? 
false : true; + + using Access_tp = typename std::conditional::type; + + constexpr int Access_elements = sizeof(Access_tp) / sizeof(half); + + constexpr int iter = (N + (BLOCKDIM * Access_elements) - 1 ) / (BLOCKDIM * Access_elements); + + int batch_id = blockIdx.y; + int batch_offset = batch_id * gridDim.x * N; + + Access_tp src_v[iter]; + Access_tp bias_v[iter]; + + for(int i = 0; i < iter; i++){ + int idx = (i * BLOCKDIM + threadIdx.x) * Access_elements; + if (idx < N){ + src_v[i] = *reinterpret_cast(inout + blockIdx.x * N + idx + batch_offset); + if (mat_bias) + bias_v[i] = *reinterpret_cast(bias + blockIdx.x * N + idx + batch_offset); + else + bias_v[i] = *reinterpret_cast(bias + idx + batch_id * N); + *reinterpret_cast(inout + blockIdx.x * N + idx + batch_offset) = (add(src_v[i],bias_v[i])); + } + + } +} + +template +__global__ void identity_and_activation(half* inout){ + + constexpr bool N_MOD_2 = N & 1 ? false : true; + + using Access_tp = typename std::conditional::type; + + constexpr int Access_elements = sizeof(Access_tp) / sizeof(half); + + constexpr int iter = (N + (BLOCKDIM * Access_elements) - 1 ) / (BLOCKDIM * Access_elements); + + int batch_id = blockIdx.y; + int batch_offset = batch_id * gridDim.x * N; + Access_tp src_v[iter]; + + for(int i = 0; i < iter; i++){ + int idx = (i * BLOCKDIM + threadIdx.x) * Access_elements; + if (idx < N){ + src_v[i] = *reinterpret_cast(inout + blockIdx.x * N + idx + batch_offset); + *reinterpret_cast(inout + blockIdx.x * N + idx + batch_offset) = (src_v[i]); + } + + } +} + +template +void identity_and_activation(half* inout, half* bias, int m, int b, bool mat_bias){ + dim3 grid(m, b); + if (bias == nullptr) + identity_and_activation<<>>(inout); + else + identity_and_activation<<>>(inout, bias, mat_bias); +} diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/utils.h b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/utils.h new file mode 100644 index 0000000000000000000000000000000000000000..a825033f3bb9d195daec3c7691d100693ed1ce44 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/44_multi_gemm_ir_and_codegen/utils.h @@ -0,0 +1,94 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once +#define TI(tag) \ + cudaEvent_t _event_start_ ##tag; \ + cudaEvent_t _event_end_ ##tag; \ + float _event_time_ ##tag; \ + cudaEventCreate(& _event_start_ ##tag); \ + cudaEventCreate(& _event_end_ ##tag); \ + cudaEventRecord(_event_start_ ##tag); + +#define TO(tag, str, times) \ + cudaEventRecord(_event_end_ ##tag); \ + cudaEventSynchronize(_event_end_ ##tag); \ + cudaEventElapsedTime(&_event_time_ ##tag, _event_start_ ##tag, _event_end_ ##tag); \ + float _event_time_once_ ##tag = _event_time_ ##tag / times; \ + printf("%20s:\t %10.3fus\t", str, _event_time_once_ ##tag * 1000); \ + cudaDeviceSynchronize(); \ + printf("%20s string: %s\n",str, cudaGetErrorString(cudaGetLastError())); + +template +struct memory_unit{ + T* host_ptr; + T* device_ptr; + int size_bytes; + int elements; + void h2d(){ + cudaMemcpy(device_ptr, host_ptr, size_bytes, cudaMemcpyHostToDevice); + } + void d2h(){ + cudaMemcpy(host_ptr, device_ptr, size_bytes, cudaMemcpyDeviceToHost); + } + void free_all(){ + free(host_ptr); + cudaFree(device_ptr); + } + memory_unit(int elements_): size_bytes(elements_ * sizeof(T)), elements(elements_){ + host_ptr = (T*) malloc(elements_ * sizeof(T)); + cudaMalloc((void**)&device_ptr, elements_ * sizeof(T)); + } + void init(int abs_range = 1){ + for(int i = 0; i < elements; i++){ + host_ptr[i] = T(rand() % 100 / float(100) * 2 * abs_range - abs_range); + } + h2d(); + } +}; + +template +int check_result(T * a, T * b, int N){ + int cnt = 0; + for(int i = 0; i < N; i ++){ + float std = float(a[i]); + float my = float(b[i]); + + if(abs(std - my) / abs(std) > 1e-2) + { + // printf("my: %f , std: %f\n", my, std); + cnt++; + } + + } + printf("total err: %d / %d\n", cnt, N); + return cnt; +} diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/45_dual_gemm/device/dual_gemm.h b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/45_dual_gemm/device/dual_gemm.h new file mode 100644 index 0000000000000000000000000000000000000000..c6b6f30d30db3b7f2aed1463373ad7d55b87eb94 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/45_dual_gemm/device/dual_gemm.h @@ -0,0 +1,499 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Performs a dual gemm in one fused kernel: +``` +D0 = epilogue0(X @ B0, C0) +D1 = epilogue1(X @ B1, C1) +D2 = element_wise(D0, D1) +``` +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/arch/arch.h" +#include "cutlass/device_kernel.h" + +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" + +#include "cutlass/gemm/device/default_gemm_configuration.h" +#include "cutlass/gemm/threadblock/default_mma.h" +#include "cutlass/epilogue/thread/linear_combination_relu.h" +#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h" + +#include "../kernel/dual_gemm.h" +#include "../dual_gemm_common.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B0 matrix operand + typename LayoutB0_, + /// Layout type for B1 matrix operand + typename LayoutB1_, + /// Element type for C and D matrix operands + typename ElementC_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Element type for internal accumulation + typename ElementAccumulator_, + /// Operator class tag + typename OperatorClass_, + /// Tag indicating architecture to tune for + typename ArchTag_, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape_, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape_, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape_, + /// Epilogue output operator + typename EpilogueOutputOp0_, + typename EpilogueOutputOp1_, + typename EpilogueOutputOp2_, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle_ = threadblock::GemmIdentityThreadblockSwizzle<>, + /// Number of stages used in the pipelined mainloop + int Stages = + DefaultGemmConfiguration::kStages, + bool StoreD0 = true, + bool StoreD1 = true, + 
/// If true, kernel supports split-K with serial reduction + bool SplitKSerial = false, + /// Access granularity of A matrix in units of elements + int AlignmentA = + DefaultGemmConfiguration::kAlignmentA, + /// Access granularity of B matrix in units of elements + int AlignmentB = + DefaultGemmConfiguration::kAlignmentB, + /// Operation performed by GEMM + typename Operator_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::Operator> +class DualGemm { + public: + + using ElementA = ElementA_; + using LayoutA = LayoutA_; + using TensorRefA = TensorRef; + using ElementB = ElementB_; + using LayoutB0 = LayoutB0_; + using LayoutB1 = LayoutB1_; + using TensorRefB0 = TensorRef; + using TensorRefB1 = TensorRef; + using ElementC = ElementC_; + using LayoutC = LayoutC_; + using TensorRefC = TensorRef; + using TensorRefD = TensorRef; + using ElementAccumulator = ElementAccumulator_; + using OperatorClass = OperatorClass_; + using ArchTag = ArchTag_; + using ThreadblockShape = ThreadblockShape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + using EpilogueOutputOp0 = EpilogueOutputOp0_; + using EpilogueOutputOp1 = EpilogueOutputOp1_; + using EpilogueOutputOp2 = EpilogueOutputOp2_; + using ThreadblockSwizzle = ThreadblockSwizzle_; + using Operator = Operator_; + static int const kStages = Stages; + static int const kAlignmentA = AlignmentA; + static int const kAlignmentB = AlignmentB; + static int const kAlignmentC = EpilogueOutputOp1::kCount; + static bool const kSplitKSerial = SplitKSerial; + static bool constexpr kStoreD0 = StoreD0; + static bool constexpr kStoreD1 = StoreD1; + static ComplexTransform const kTransformA = ComplexTransform::kNone; + static ComplexTransform const kTransformB = ComplexTransform::kNone; + + using LayoutScaleBias = layout::RowMajor; + /// Define the kernel + /// Define the threadblock-scoped matrix multiply-accumulate + static_assert(ArchTag::kMinComputeCapability >= 80, "Only multistage is implemented"); + static_assert(kStages >= 3, "Only multistage is implemented"); + using Mma0 = typename cutlass::gemm::threadblock::DefaultMma< + ElementA, LayoutA, kAlignmentA, ElementB, LayoutB0, kAlignmentB, + ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag, + ThreadblockShape, WarpShape, + InstructionShape, Stages, Operator>::ThreadblockMma; + using Mma1 = typename cutlass::gemm::threadblock::DefaultMma< + ElementA, LayoutA, kAlignmentA, ElementB, LayoutB1, kAlignmentB, + ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag, + ThreadblockShape, WarpShape, + InstructionShape, Stages, Operator>::ThreadblockMma; + using DualMma = threadblock::DualMmaMultistage< + typename Mma0::Shape, + typename Mma0::IteratorA, + typename Mma0::SmemIteratorA, + Mma0::kCacheOpA, + typename Mma0::IteratorB, + typename Mma0::SmemIteratorB, + Mma0::kCacheOpB, + typename Mma1::IteratorB, + typename Mma1::SmemIteratorB, + typename Mma0::ElementC, + typename Mma0::LayoutC, + typename Mma0::Policy, + typename Mma1::Policy, + Mma0::kStages, + SharedMemoryClearOption::kNone + >; + + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + + /// Define the epilogue + using Epilogue0 = + typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< + ThreadblockShape, typename DualMma::Operator0, kPartitionsK, EpilogueOutputOp0, + EpilogueOutputOp0::kCount>::Epilogue; + using Epilogue1 = + typename 
cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< + ThreadblockShape, typename DualMma::Operator1, kPartitionsK, EpilogueOutputOp1, + EpilogueOutputOp1::kCount>::Epilogue; + + /// Define the kernel-level GEMM operator. + using DualGemmKernel = kernel::DualGemm< + DualMma, + Epilogue0, Epilogue1, EpilogueOutputOp2, + ThreadblockSwizzle, kSplitKSerial, + kStoreD0, kStoreD1>; + + /// Argument structure + struct Arguments { + + // + // Data members + // + + DualGemmMode mode; + GemmCoord problem_size; + TensorRef ref_A0; + TensorRef ref_B0; + TensorRef ref_C0; + TensorRef ref_D0; + TensorRef ref_B1; + TensorRef ref_C1; + TensorRef ref_D1; + TensorRef ref_D2; + typename EpilogueOutputOp0::Params epilogue0; + typename EpilogueOutputOp1::Params epilogue1; + typename EpilogueOutputOp2::Params epilogue2; + int split_k_slices; + + int batch_count; + int64_t batch_stride_A; + int64_t batch_stride_B0; + int64_t batch_stride_B1; + int64_t batch_stride_C; + int64_t batch_stride_D; + + // + // Methods + // + + /// Default ctor + CUTLASS_HOST_DEVICE + Arguments(): problem_size(0, 0, 0), split_k_slices(1) { + + } + + /// Constructs an Arguments structure + CUTLASS_HOST_DEVICE + Arguments( + DualGemmMode mode, + GemmCoord problem_size_, + TensorRef ref_A0_, + TensorRef ref_B0_, + TensorRef ref_C0_, + TensorRef ref_D0_, + TensorRef ref_B1_, + TensorRef ref_C1_, + TensorRef ref_D1_, + TensorRef ref_D2_, + typename EpilogueOutputOp0::Params epilogue0_ = + typename EpilogueOutputOp0::Params(), + typename EpilogueOutputOp1::Params epilogue1_ = + typename EpilogueOutputOp1::Params(), + typename EpilogueOutputOp2::Params epilogue2_ = + typename EpilogueOutputOp2::Params(), + int split_k_slices_ = 1, + int batch_count = 1, + int64_t batch_stride_A = 0, + int64_t batch_stride_B0 = 0, + int64_t batch_stride_B1 = 0, + int64_t batch_stride_C = 0, + int64_t batch_stride_D = 0 + ): + mode(mode), + problem_size(problem_size_), + ref_A0(ref_A0_), + ref_B0(ref_B0_), + ref_C0(ref_C0_), + ref_D0(ref_D0_), + ref_B1(ref_B1_), + ref_C1(ref_C1_), + ref_D1(ref_D1_), + ref_D2(ref_D2_), + epilogue0(epilogue0_), + epilogue1(epilogue1_), + epilogue2(epilogue2_), + split_k_slices(split_k_slices_), + batch_count(batch_count), + batch_stride_A(batch_stride_A), + batch_stride_B0(batch_stride_B0), + batch_stride_B1(batch_stride_B1), + batch_stride_C(batch_stride_C), + batch_stride_D(batch_stride_D) { + + } + }; + +private: + + /// Kernel parameters object + typename DualGemmKernel::Params params_; + +public: + + /// Constructs the GEMM. + DualGemm() = default; + + /// Determines whether the GEMM can execute the given problem. 
+ static Status can_implement(Arguments const &args) { + + if (args.mode == DualGemmMode::kBatched && kSplitKSerial) { + return Status::kErrorInvalidProblem; + } + if (!kSplitKSerial && args.split_k_slices > 1) { + return Status::kErrorInvalidProblem; + } + if (kStoreD0 != (args.ref_D0.data() != nullptr)) { + return Status::kErrorInternal; + } + if (kStoreD1 != (args.ref_D1.data() != nullptr)) { + return Status::kErrorInternal; + } + + Status status = DualGemmKernel::can_implement( + args.problem_size, + args.ref_A0.non_const_ref(), + args.ref_B0.non_const_ref(), + args.ref_C0.non_const_ref(), + args.ref_D0, + args.ref_B1.non_const_ref(), + args.ref_C1.non_const_ref(), + args.ref_D1, + args.ref_D2 + ); + + if (status != Status::kSuccess) { + return status; + } + + return Status::kSuccess; + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const &args) { + + size_t bytes = 0; + + if (kSplitKSerial && args.split_k_slices > 1) { + // Determine grid shape + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape( + args.problem_size, + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, + args.split_k_slices); + + bytes += sizeof(int) * size_t(tiled_shape.m()) * size_t(tiled_shape.n()); + } + + return bytes; + } + + /// Initializes GEMM state from arguments. + Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { + + // Determine grid shape + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape( + args.problem_size, + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, + args.mode == DualGemmMode::kBatched ? args.batch_count : args.split_k_slices); + + if (kSplitKSerial) { + if (args.split_k_slices > 1) { + if (!workspace) { + return Status::kErrorWorkspaceNull; + } + + size_t bytes = get_workspace_size(args); + + cudaError_t result = cudaMemsetAsync(workspace, 0, bytes, stream); + + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + } + } + else { + + if (args.split_k_slices > 1) { + return Status::kErrorInvalidProblem; + } + } + + // Initialize the Params structure + params_ = typename DualGemmKernel::Params{ + args.mode, + args.problem_size, + grid_shape, + args.ref_A0.non_const_ref(), + args.ref_B0.non_const_ref(), + args.ref_C0.non_const_ref(), + args.ref_D0, + args.ref_B1.non_const_ref(), + args.ref_C1.non_const_ref(), + args.ref_D1, + args.ref_D2, + args.epilogue0, + args.epilogue1, + args.epilogue2, + reinterpret_cast(workspace), + args.batch_stride_A, + args.batch_stride_B0, + args.batch_stride_B1, + args.batch_stride_C, + args.batch_stride_D, + }; + + return Status::kSuccess; + } + + /// Lightweight update given a subset of arguments + Status update(Arguments const &args, void *workspace = nullptr) { + + if (kSplitKSerial && args.split_k_slices > 1) { + if (!workspace) { + return Status::kErrorWorkspaceNull; + } + } + + params_.ref_A0.reset(args.ref_A0.non_const_ref().data()); + params_.ref_B0.reset(args.ref_B0.non_const_ref().data()); + params_.ref_C0.reset(args.ref_C0.non_const_ref().data()); + params_.ref_D0.reset(args.ref_D0.data()); + params_.ref_B1.reset(args.ref_B1.non_const_ref().data()); + params_.ref_C1.reset(args.ref_C1.non_const_ref().data()); + params_.ref_D1.reset(args.ref_D1.data()); + params_.ref_D2.reset(args.ref_D2.data()); + params_.output_op_0 = args.epilogue0; + params_.output_op_1 = args.epilogue1; + 
params_.output_op_2 = args.epilogue2; + params_.semaphore = reinterpret_cast<int *>(workspace); + + return Status::kSuccess; + } + + /// Runs the kernel using initialized state. + Status run(cudaStream_t stream = nullptr) { + + ThreadblockSwizzle threadblock_swizzle; + + dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape); + dim3 block(DualGemmKernel::kThreadCount, 1, 1); + + cudaError_t result; + + int smem_size = int(sizeof(typename DualGemmKernel::SharedStorage)); + if (smem_size >= (48 << 10)) { + result = cudaFuncSetAttribute(Kernel<DualGemmKernel>, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + } + + cutlass::Kernel<DualGemmKernel><<<grid, block, smem_size, stream>>>(params_); + + result = cudaGetLastError(); + + return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal; + } + + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr) { + return run(stream); + } + + /// Runs the kernel using initialized state. + Status operator()( + Arguments const &args, + void *workspace = nullptr, + cudaStream_t stream = nullptr) { + + Status status = initialize(args, workspace, stream); + + if (status == Status::kSuccess) { + status = run(stream); + } + + return status; + } +}; + +} // namespace device +} // namespace gemm +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/45_dual_gemm/dual_gemm_common.h b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/45_dual_gemm/dual_gemm_common.h new file mode 100644 index 0000000000000000000000000000000000000000..12ebe05ea760c97dcdc895266bc82ccb97b5abbf --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/45_dual_gemm/dual_gemm_common.h @@ -0,0 +1,52 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Defines common types used for all DualGemm operators. +*/ +#pragma once + +namespace cutlass { +namespace gemm { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +enum class DualGemmMode { + kGemm, + kBatched, + kInvalid +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace gemm +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/45_dual_gemm/dual_gemm_run.h b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/45_dual_gemm/dual_gemm_run.h new file mode 100644 index 0000000000000000000000000000000000000000..d042f93cac7f9830c2ce780c482c7034c64b8ac9 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/45_dual_gemm/dual_gemm_run.h @@ -0,0 +1,938 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include +#include +#include +#include + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/device/gemm.h" +#include "cutlass/util/reference/device/tensor_relu.h" + +#include "cutlass/platform/platform.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/device/gemm_universal.h" + +#include "dual_gemm_common.h" +#include "helper.h" + +#define CHECK_GT(val1, val2) \ + if((val1) <= (val2)) \ + std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_GT failed\n"; +#define CHECK_TRUE(val) \ + if(!(val)) \ + std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_TRUE failed\n"; + +template < + typename OutputOp, + typename Element, + typename Layout> +struct TensorEpilogueForEachFunc { + /// View type + using TensorView = cutlass::TensorView; + + /// Coordinate in tensor's index space + using TensorCoord = typename TensorView::TensorCoord; + + /// Parameters structure + struct Params { + + // + // Data members + // + + TensorView view_x0; + TensorView view_x1; + TensorView view_y; + OutputOp output_op; + + + // + // Methods + // + + Params( + TensorView view_x0_ = TensorView(), + TensorView view_x1_ = TensorView(), + TensorView view_y_ = TensorView(), + OutputOp output_op_ = OutputOp(typename OutputOp::Params{}) + ): + view_x0(view_x0_), view_x1(view_x1_), view_y(view_y_), output_op(output_op_) { + } + }; + + Params params; + + CUTLASS_DEVICE + TensorEpilogueForEachFunc(Params const ¶ms): params(params) { + + } + + CUTLASS_DEVICE + void operator()(TensorCoord const &coord) { + Element const & x0 = params.view_x0.at(coord); + Element const & x1 = params.view_x1.at(coord); + Element& y = params.view_y.at(coord); + y = params.output_op(x0, x1); + } +}; + +template < + typename OutputOp, + typename Element, + typename Layout> +void TensorEpilogueForEach( + cutlass::TensorView x0, + cutlass::TensorView x1, + cutlass::TensorView y) { + + using Func = TensorEpilogueForEachFunc; + using Params = typename Func::Params; + + cutlass::reference::device::TensorForEach( + y.extent(), + Params(x0, x1, y) + ); +} + +//////////////////////////////////////////////////////////////////////////////// + +template +struct NonFusedDualGemmRun +{ + + using Gemm0 = Gemm0_; + using Gemm1 = Gemm1_; + using ElementAccumulator = typename Gemm0::ElementAccumulator; + using ElementCompute = typename Gemm0::GemmKernel::Epilogue::OutputOp::ElementCompute; + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + cutlass::Distribution::Kind init_Bias; + uint64_t seed; + + // + // Methods + // + + NonFusedDualGemmRun( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + init_A(init_A_), init_B(init_B_), init_C(init_C_), init_Bias(init_Bias_), seed(seed_) { } + + /// Helper to initialize a tensor view + template + bool initialize_tensor( + 
cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, 2, -2, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential( + view.data(), view.capacity()); + } + else if (dist_kind == cutlass::Distribution::AllZeros) { + cutlass::reference::host::TensorFill(view, Element(0)); + } + else if (dist_kind == cutlass::Distribution::AllOnes) { + cutlass::reference::host::TensorFill(view, Element(1)); + } + else { + std::cerr << "Not implemented\n"; + return false; + } + + return true; + } + + + + + /// Executes one test + bool run( + cutlass::gemm::GemmCoord problem_size, + ElementCompute alpha0 = ElementCompute(1), + ElementCompute beta0 = ElementCompute(0), + ElementCompute alpha1 = ElementCompute(1), + ElementCompute beta1 = ElementCompute(0), + bool is_profiling = true, + bool relu = false, + int warm_ups = 1, + int runs = 100) { + + // + // Allocate the GEMM workspace + // + + cutlass::HostTensor< + typename Gemm0::ElementA, + typename Gemm0::LayoutA> tensor_A0(problem_size.mk()); + + cutlass::HostTensor< + typename Gemm0::ElementB, + typename Gemm0::LayoutB> tensor_B0(problem_size.kn()); + + cutlass::HostTensor< + typename Gemm0::ElementC, + typename Gemm0::LayoutC> tensor_C0(problem_size.mn()); + + cutlass::HostTensor< + typename Gemm1::ElementC, + typename Gemm0::LayoutC> tensor_Bias0({1, problem_size.n()}); + + cutlass::HostTensor< + typename Gemm0::ElementC, + typename Gemm0::LayoutC> tensor_D0(problem_size.mn()); + + cutlass::HostTensor< + typename Gemm0::ElementC, + typename Gemm0::LayoutC> reference_D0(problem_size.mn()); + + cutlass::HostTensor< + typename Gemm1::ElementB, + typename Gemm1::LayoutB> tensor_B1(problem_size.kn()); + + cutlass::HostTensor< + typename Gemm1::ElementC, + typename Gemm1::LayoutC> tensor_C1(problem_size.mn()); + + cutlass::HostTensor< + typename Gemm1::ElementC, + typename Gemm1::LayoutC> tensor_Bias1({1, problem_size.n()}); + + cutlass::HostTensor< + typename Gemm1::ElementC, + typename Gemm1::LayoutC> tensor_D1(problem_size.mn()); + + cutlass::HostTensor< + typename Gemm1::ElementC, + typename Gemm1::LayoutC> reference_D1(problem_size.mn()); + + + CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019)); + CHECK_TRUE(initialize_tensor(tensor_B0.host_view(), init_B, seed + 2018)); + CHECK_TRUE(initialize_tensor(tensor_C0.host_view(), init_C, seed + 2017)); + CHECK_TRUE(initialize_tensor(tensor_Bias0.host_view(), init_Bias, seed + 2014)); + CHECK_TRUE(initialize_tensor(tensor_B1.host_view(), init_B, seed + 2016)); + CHECK_TRUE(initialize_tensor(tensor_C1.host_view(), init_C, seed + 2015)); + CHECK_TRUE(initialize_tensor(tensor_Bias1.host_view(), init_Bias, seed + 2013)); + + cutlass::reference::host::TensorFill( + tensor_D0.host_view()); + cutlass::reference::host::TensorFill( + tensor_D1.host_view()); + cutlass::reference::host::TensorFill( + reference_D0.host_view()); + cutlass::reference::host::TensorFill( + reference_D1.host_view()); + + tensor_A0.sync_device(); + tensor_B0.sync_device(); + tensor_C0.sync_device(); + tensor_Bias0.sync_device(); + tensor_D0.sync_device(); + 
reference_D0.sync_device(); + tensor_B1.sync_device(); + tensor_C1.sync_device(); + tensor_Bias1.sync_device(); + tensor_D1.sync_device(); + reference_D1.sync_device(); + + // + // Initialize the GEMM operator + // + + int split_k_slices = Gemm0::kSplitKSerial ? 2 : 1; + typename Gemm0::Arguments arguments_0{ + problem_size, + tensor_A0.device_ref(), + tensor_B0.device_ref(), + {tensor_Bias0.device_data(), typename Gemm0::LayoutC::Stride(0)}, + tensor_D0.device_ref(), + {alpha0, beta0}, + split_k_slices + }; + + split_k_slices = Gemm1::kSplitKSerial ? 2 : 1; + typename Gemm1::Arguments arguments_1{ + problem_size, + tensor_A0.device_ref(), + tensor_B1.device_ref(), + {tensor_Bias1.device_data(), typename Gemm1::LayoutC::Stride(0)}, + tensor_D1.device_ref(), + {alpha1, beta1}, + split_k_slices + }; + + + Gemm0 gemm_op_0; + Gemm1 gemm_op_1; + + // Allocate workspace memory + cutlass::device_memory::allocation workspace0(gemm_op_0.get_workspace_size(arguments_0)); + cutlass::device_memory::allocation workspace1(gemm_op_1.get_workspace_size(arguments_1)); + + cutlass::Status status = gemm_op_0.initialize(arguments_0, workspace0.get()); + + CUTLASS_CHECK(status); + + status = gemm_op_1.initialize(arguments_1, workspace1.get()); + + CUTLASS_CHECK(status); + + for(int i = 0; i < warm_ups; i++) { + status = gemm_op_0(); + CUTLASS_CHECK(status); + status = gemm_op_1(); + CUTLASS_CHECK(status); + } + + if (is_profiling) { + // + // Profile the GEMM + // + + cudaEvent_t start, stop1, stop2; + cudaEventCreate(&start); + cudaEventCreate(&stop1); + cudaEventCreate(&stop2); + + cudaEventRecord(start); + + for(int i = 0; i < runs; i++) { + status = gemm_op_0(); + + CUTLASS_CHECK(status); + } + cudaEventRecord(stop1); + for(int i = 0; i < runs; i++) { + status = gemm_op_1(); + + CUTLASS_CHECK(status); + } + + cudaEventRecord(stop2); + cudaDeviceSynchronize(); + float gemm0Time, gemm1Time, totalTime; + cudaEventElapsedTime(&gemm0Time, start, stop1); + cudaEventElapsedTime(&gemm1Time, stop1, stop2); + cudaEventElapsedTime(&totalTime, start, stop2); + std::cout << "gemm 0 time " << gemm0Time / (float)runs << " ms\n"; + std::cout << "gemm 1 time " << gemm1Time / (float)runs << " ms\n"; + std::cout << "Non-fusion GEMM only time " << totalTime / (float)runs << " ms\n"; + } + + tensor_D0.sync_host(); + tensor_D1.sync_host(); + + // + // Verify + // + cutlass::reference::device::Gemm< + typename Gemm0::ElementA, typename Gemm0::LayoutA, + typename Gemm0::ElementB, typename Gemm0::LayoutB, + typename Gemm0::ElementC, typename Gemm0::LayoutC, ElementCompute, + ElementAccumulator, typename Gemm0::Operator> + reference_gemm_0; + + cutlass::reference::device::Gemm< + typename Gemm1::ElementA, typename Gemm1::LayoutA, + typename Gemm1::ElementB, typename Gemm1::LayoutB, + typename Gemm1::ElementC, typename Gemm1::LayoutC, ElementCompute, + ElementAccumulator, typename Gemm1::Operator> + reference_gemm_1; + + reference_gemm_0( + problem_size, + alpha0, + tensor_A0.device_ref(), + tensor_B0.device_ref(), + beta0, + {tensor_Bias0.device_data(), typename Gemm0::LayoutC::Stride(0)}, + reference_D0.device_ref() + ); + + if(relu) { + cutlass::reference::device::TensorReLu(reference_D0.device_view()); + } + + reference_gemm_1( + problem_size, + alpha1, + tensor_A0.device_ref(), + tensor_B1.device_ref(), + beta1, + {tensor_Bias1.device_data(), typename Gemm1::LayoutC::Stride(0)}, + reference_D1.device_ref() + ); + + if(relu) { + cutlass::reference::device::TensorReLu(reference_D1.device_view()); + } + + // Wait for kernels to 
finish + cudaDeviceSynchronize(); + reference_D0.sync_host(); + reference_D1.sync_host(); + + CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D0.host_view()), 0); + CHECK_GT(cutlass::reference::host::TensorNorm(reference_D0.host_view()), 0); + CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1.host_view()), 0); + CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0); + + bool passed0 = cutlass::reference::host::TensorEquals( + reference_D1.host_view(), + tensor_D1.host_view()); + CHECK_TRUE(passed0); + + bool passed1 = cutlass::reference::host::TensorEquals( + reference_D1.host_view(), + tensor_D1.host_view()); + CHECK_TRUE(passed1); + if (!passed0 || !passed1) { + + std::stringstream fname; + + fname << "error_DualGemm_device_nonfused.txt"; + std::cerr << "Dumping results in " << fname.str() << "\n"; + + std::ofstream file(fname.str()); + + file + << "A0 =\n" << tensor_A0.host_view() + << "\nB0 =\n" << tensor_B0.host_view() + << "\nC0 =\n" << tensor_C0.host_view() + << "\nBias0:\n" << tensor_Bias0.host_view() << "\n" + << "\nD0 =\n" << tensor_D0.host_view() + << "\nB1 =\n" << tensor_B1.host_view() + << "\nC1 =\n" << tensor_C1.host_view() + << "\nBias1:\n" << tensor_Bias1.host_view() << "\n" + << "\n\nReference =\n" << reference_D1.host_view() + << "\nComputed =\n" << tensor_D1.host_view(); + } + return passed0 && passed1; + } +}; + +template +struct DualFusedGemmRun +{ + + using DualGemm = DualGemm_; + using ElementAccumulator = typename DualGemm::ElementAccumulator; + using ElementCompute = typename DualGemm::DualGemmKernel::Epilogue0::OutputOp::ElementCompute; + using EpilogueOutputOp2 = typename DualGemm::EpilogueOutputOp2; + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + cutlass::Distribution::Kind init_Scale; + cutlass::Distribution::Kind init_Bias; + uint64_t seed; + + // + // Methods + // + + DualFusedGemmRun( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_Scale_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + init_A(init_A_), init_B(init_B_), init_C(init_C_), + init_Scale(init_Scale_), init_Bias(init_Bias_), seed(seed_) { } + + /// Helper to initialize a tensor view + template + bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, 2, -2, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential( + view.data(), view.capacity()); + } + else if (dist_kind == cutlass::Distribution::AllZeros) { + cutlass::reference::host::TensorFill(view, Element(0)); + } + else if (dist_kind == cutlass::Distribution::AllOnes) { + cutlass::reference::host::TensorFill(view, Element(1)); + } + else { + std::cerr << "Not implemented\n"; + return false; + } + + return true; + } + + + + + /// Executes one 
test + bool run( + cutlass::gemm::GemmCoord problem_size, + ElementCompute alpha0 = ElementCompute(1), + ElementCompute beta0 = ElementCompute(1), + ElementCompute alpha1 = ElementCompute(1), + ElementCompute beta1 = ElementCompute(1), + int batch_count = 1, + bool broadcast_b1 = false, + bool is_profiling = true, + bool relu = false, + int warm_ups = 1, + int runs = 100) { + + // + // Allocate the GEMM workspace + // + + cutlass::HostTensor< + typename DualGemm::ElementA, + typename DualGemm::LayoutA> tensor_A0( + cutlass::platform::is_same::value ? + cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.k()) : + cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.k())); + + cutlass::HostTensor< + typename DualGemm::ElementB, + typename DualGemm::LayoutB0> tensor_B0( + cutlass::platform::is_same::value ? + cutlass::MatrixCoord(batch_count * problem_size.k(), problem_size.n()) : + cutlass::MatrixCoord(problem_size.k(), batch_count * problem_size.n())); + + cutlass::HostTensor< + typename DualGemm::ElementC, + typename DualGemm::LayoutC> tensor_C0( + cutlass::platform::is_same::value ? + cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) : + cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.n())); + + cutlass::HostTensor< + typename DualGemm::ElementC, + typename DualGemm::LayoutScaleBias> tensor_Bias0({batch_count, problem_size.n()}); + + cutlass::HostTensor< + typename DualGemm::ElementC, + typename DualGemm::LayoutC> tensor_D0( + cutlass::platform::is_same::value ? + cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) : + cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.n())); + + cutlass::HostTensor< + typename DualGemm::ElementC, + typename DualGemm::LayoutC> reference_D0( + cutlass::platform::is_same::value ? + cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) : + cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.n())); + + cutlass::HostTensor< + typename DualGemm::ElementB, + typename DualGemm::LayoutB1> tensor_B1( + cutlass::platform::is_same::value ? + cutlass::MatrixCoord(batch_count * problem_size.k(), problem_size.n()) : + cutlass::MatrixCoord(problem_size.k(), batch_count * problem_size.n())); + if (broadcast_b1) { + tensor_B1.resize({problem_size.k(), batch_count}); + } + + cutlass::HostTensor< + typename DualGemm::ElementC, + typename DualGemm::LayoutC> tensor_C1( + cutlass::platform::is_same::value ? + cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) : + cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.n())); + + cutlass::HostTensor< + typename DualGemm::ElementC, + typename DualGemm::LayoutScaleBias> tensor_Bias1({batch_count, problem_size.n()}); + + cutlass::HostTensor< + typename DualGemm::ElementC, + typename DualGemm::LayoutC> tensor_D1( + cutlass::platform::is_same::value ? + cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) : + cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.n())); + + cutlass::HostTensor< + typename DualGemm::ElementC, + typename DualGemm::LayoutC> tensor_D2( + cutlass::platform::is_same::value ? + cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) : + cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.n())); + + cutlass::HostTensor< + typename DualGemm::ElementC, + typename DualGemm::LayoutC> reference_D1( + cutlass::platform::is_same::value ? 
+ cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) : + cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.n())); + + cutlass::HostTensor< + typename DualGemm::ElementC, + typename DualGemm::LayoutC> reference_D2( + cutlass::platform::is_same::value ? + cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) : + cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.n())); + + CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019)); + CHECK_TRUE(initialize_tensor(tensor_B0.host_view(), init_B, seed + 2118)); + CHECK_TRUE(initialize_tensor(tensor_C0.host_view(), init_C, seed + 2017)); + CHECK_TRUE(initialize_tensor(tensor_Bias0.host_view(), init_Bias, seed + 2011)); + CHECK_TRUE(initialize_tensor(tensor_B1.host_view(), init_B, seed + 2113)); + CHECK_TRUE(initialize_tensor(tensor_C1.host_view(), init_C, seed + 2015)); + CHECK_TRUE(initialize_tensor(tensor_Bias1.host_view(), init_Bias, seed + 2012)); + + cutlass::reference::host::TensorFill( + tensor_D0.host_view()); + cutlass::reference::host::TensorFill( + tensor_D1.host_view()); + cutlass::reference::host::TensorFill( + tensor_D2.host_view()); + cutlass::reference::host::TensorFill( + reference_D0.host_view()); + cutlass::reference::host::TensorFill( + reference_D1.host_view()); + cutlass::reference::host::TensorFill( + reference_D2.host_view()); + + tensor_A0.sync_device(); + tensor_B0.sync_device(); + tensor_C0.sync_device(); + tensor_Bias0.sync_device(); + tensor_B1.sync_device(); + tensor_C1.sync_device(); + tensor_Bias1.sync_device(); + tensor_D0.sync_device(); + tensor_D1.sync_device(); + tensor_D2.sync_device(); + reference_D0.sync_device(); + reference_D1.sync_device(); + reference_D2.sync_device(); + + // + // Batch strides (irrelevant when batch_count == 1) + // + + int64_t batch_stride_A = problem_size.m() * problem_size.k(); + int64_t batch_stride_B0 = problem_size.k() * problem_size.n(); + int64_t batch_stride_B1 = problem_size.k() * problem_size.n(); + if (broadcast_b1) { + // B1 is a (column) vector + batch_stride_B1 = problem_size.k(); + } + int64_t batch_stride_Bias = problem_size.n(); + int64_t batch_stride_D = problem_size.m() * problem_size.n(); + + // + // Initialize the GEMM operator + // + + int split_k_slices = DualGemm::kSplitKSerial ? 2 : 1; + typename cutlass::TensorRef nullptr_ref{}; + decltype(nullptr_ref) ref_B0, ref_B1; + if (beta0 != ElementCompute(0)) { + ref_B0 = {tensor_Bias0.device_data(), typename DualGemm::LayoutC::Stride(0)}; + } + if (beta1 != ElementCompute(0)) { + ref_B1 = {tensor_Bias1.device_data(), typename DualGemm::LayoutC::Stride(0)}; + } + typename DualGemm::Arguments arguments{ + (batch_count > 1 ? + cutlass::gemm::DualGemmMode::kBatched : + cutlass::gemm::DualGemmMode::kGemm), + problem_size, + tensor_A0.device_ref(), + tensor_B0.device_ref(), + ref_B0, + DualGemm::kStoreD0 ? tensor_D0.device_ref() : nullptr_ref, + (broadcast_b1 ? + typename DualGemm::TensorRefB1(tensor_B1.device_data(), 0) : + tensor_B1.device_ref()), + ref_B1, + DualGemm::kStoreD1 ? 
tensor_D1.device_ref() : nullptr_ref, + tensor_D2.device_ref(), + {alpha0, beta0}, + {alpha1, beta1}, + {}, + split_k_slices, + batch_count, + batch_stride_A, + batch_stride_B0, + batch_stride_B1, + batch_stride_Bias, + batch_stride_D, + }; + + // + // Run the GEMM + // + + DualGemm b2b_gemm_op; + + cutlass::device_memory::allocation workspace(b2b_gemm_op.get_workspace_size(arguments)); + + cutlass::Status status = b2b_gemm_op.can_implement(arguments); + + CUTLASS_CHECK(status); + + status = b2b_gemm_op.initialize(arguments, workspace.get()); + + CUTLASS_CHECK(status); + + for(int i = 0; i < warm_ups; i++) { + status = b2b_gemm_op(); + CUTLASS_CHECK(status); + } + + if (is_profiling) { + // + // Profile the GEMM + // + + cudaEvent_t start, stop; + cudaEventCreate(&start); + cudaEventCreate(&stop); + + cudaEventRecord(start); + + for(int i = 0; i < runs; i++) { + status = b2b_gemm_op(); + CUTLASS_CHECK(status); + } + + cudaEventRecord(stop); + cudaDeviceSynchronize(); + float gemmTime; + cudaEventElapsedTime(&gemmTime, start, stop); + std::cout << "Fusion time " << gemmTime / (float)runs << " ms\n"; + } + + tensor_D0.sync_host(); + tensor_D1.sync_host(); + tensor_D2.sync_host(); + + // + // Verify + // + + using GemmUniversal0 = cutlass::gemm::device::GemmUniversal< + typename DualGemm::ElementA, typename DualGemm::LayoutA, + typename DualGemm::ElementB, typename DualGemm::LayoutB0, + typename DualGemm::ElementC, typename DualGemm::LayoutC, + ElementAccumulator + >; + + GemmUniversal0 reference_gemm0; + + typename GemmUniversal0::Arguments args0 { + (batch_count > 1 ? + cutlass::gemm::GemmUniversalMode::kBatched : + cutlass::gemm::GemmUniversalMode::kGemm), + problem_size, + batch_count, + {alpha0, beta0}, + tensor_A0.device_data(), + tensor_B0.device_data(), + tensor_Bias0.device_data(), + reference_D0.device_data(), + batch_stride_A, + batch_stride_B0, + batch_stride_Bias, + batch_stride_D, + tensor_A0.stride(0), + tensor_B0.stride(0), + 0, // zero stride for the bias vector + reference_D0.stride(0), + }; + + status = reference_gemm0.can_implement(args0); + CUTLASS_CHECK(status); + status = reference_gemm0(args0); + CUTLASS_CHECK(status); + + using GemmUniversal1 = cutlass::gemm::device::GemmUniversal< + typename DualGemm::ElementA, typename DualGemm::LayoutA, + typename DualGemm::ElementB, typename DualGemm::LayoutB1, + typename DualGemm::ElementC, typename DualGemm::LayoutC, + ElementAccumulator + >; + + GemmUniversal1 reference_gemm1; + + typename GemmUniversal1::Arguments args1 { + (batch_count > 1 ? + cutlass::gemm::GemmUniversalMode::kBatched : + cutlass::gemm::GemmUniversalMode::kGemm), + problem_size, + batch_count, + {alpha1, beta1}, + tensor_A0.device_data(), + tensor_B1.device_data(), + tensor_Bias1.device_data(), + reference_D1.device_data(), + batch_stride_A, + batch_stride_B1, + batch_stride_Bias, + batch_stride_D, + tensor_A0.stride(0), + (broadcast_b1 ? 
0 : tensor_B1.stride(0)), + 0, // zero stride for the bias vector + reference_D1.stride(0), + }; + + status = reference_gemm1.can_implement(args1); + CUTLASS_CHECK(status); + status = reference_gemm1(args1); + CUTLASS_CHECK(status); + + if(relu) { + cutlass::reference::device::TensorReLu(reference_D0.device_view()); + cutlass::reference::device::TensorReLu(reference_D1.device_view()); + } + + TensorEpilogueForEach(reference_D0.device_view(), reference_D1.device_view(), reference_D2.device_view()); + cudaDeviceSynchronize(); + reference_D0.sync_host(); + reference_D1.sync_host(); + reference_D2.sync_host(); + + CHECK_GT(cutlass::reference::host::TensorNorm(reference_D0.host_view()), 0); + CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0); + CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D2.host_view()), 0); + CHECK_GT(cutlass::reference::host::TensorNorm(reference_D2.host_view()), 0); + + bool passed_out0 = true; + if (DualGemm::kStoreD0) { + CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D0.host_view()), 0); + passed_out0 = cutlass::reference::host::TensorEquals( + reference_D0.host_view(), + tensor_D0.host_view()); + } + CHECK_TRUE(passed_out0); + + bool passed_out1 = true; + if (DualGemm::kStoreD1) { + CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1.host_view()), 0); + passed_out1 = cutlass::reference::host::TensorEquals( + reference_D1.host_view(), + tensor_D1.host_view()); + } + CHECK_TRUE(passed_out1); + + bool passed_out2 = cutlass::reference::host::TensorEquals( + reference_D2.host_view(), + tensor_D2.host_view()); + CHECK_TRUE(passed_out2); + + bool passed = passed_out0 && passed_out1 && passed_out2; + if (!passed) + { + std::stringstream fname; + + fname << "error_DualGemm_device_fused.txt"; + std::cerr << "Dumping results in " << fname.str() << "\n"; + + std::ofstream file(fname.str()); + + file + << "A0 =\n" << tensor_A0.host_view() + << "\nB0 =\n" << tensor_B0.host_view() + << "\nC0 =\n" << tensor_C0.host_view() + << "\nBias0:\n" << tensor_Bias0.host_view() << "\n" + << "\nB1 =\n" << tensor_B1.host_view() + << "\nC1 =\n" << tensor_C1.host_view() + << "\nBias1:\n" << tensor_Bias1.host_view() << "\n" + << "\n\nReference0 =\n" << reference_D0.host_view() + << "\nComputed0 =\n" << tensor_D0.host_view() + << "\n\nReference1 =\n" << reference_D1.host_view() + << "\nComputed1 =\n" << tensor_D1.host_view() + << "\n\nReference2 =\n" << reference_D2.host_view() + << "\nComputed2 =\n" << tensor_D2.host_view(); + } + //std::cout << "A0 " << tensor_A0.host_view() << std::endl; + // std::cout << "reference_D0 " << reference_D0.host_view() << std::endl; + // std::cout << "reference_D1 " << reference_D1.host_view() << std::endl; + // std::cout << "reference_D2 " << reference_D2.host_view() << std::endl; + //std::cout << "reference_D0 " << reference_D0.host_view() << std::endl; + return passed; + } + +}; + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/45_dual_gemm/kernel/dual_gemm.h b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/45_dual_gemm/kernel/dual_gemm.h new file mode 100644 index 0000000000000000000000000000000000000000..2ff5b5e24f8927fe2c71e3a428203d52298e6e86 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/45_dual_gemm/kernel/dual_gemm.h @@ -0,0 +1,545 @@ 
+/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K. +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/semaphore.h" + +#include "../threadblock/dual_mma_multistage.h" +#include "../threadblock/dual_epilogue.h" +#include "../dual_gemm_common.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename DualMma_, ///! Threadblock-scoped matrix multiply-accumulate + typename Epilogue0_, ///! Epilogue + typename Epilogue1_, ///! Epilogue + typename OutputOp2_, ///! Epilogue + typename ThreadblockSwizzle_, ///! Threadblock swizzling function + bool SplitKSerial, ///! If true, code supporting split-K via serial reduction is enabled. 
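+  ///! StoreD0 / StoreD1: if true, the intermediate outputs D0 / D1 of the two GEMMs
+  ///! are written to global memory (both must be enabled when SplitKSerial is used).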
+ bool StoreD0, + bool StoreD1 +> +struct DualGemm { + + using DualMma = DualMma_; + + using Epilogue0 = Epilogue0_; + using Epilogue1 = Epilogue1_; + using OutputOp0 = typename Epilogue0::OutputOp; + using OutputOp1 = typename Epilogue1::OutputOp; + using OutputOp2 = OutputOp2_; + using ThreadblockSwizzle = ThreadblockSwizzle_; + static constexpr bool kStoreD0 = StoreD0; + static constexpr bool kStoreD1 = StoreD1; + + using DualEpilogue = cutlass::epilogue::threadblock::DualEpilogue< + typename Epilogue0::Shape, + typename Epilogue0::WarpMmaOperator, + Epilogue0::kPartitionsK, + typename Epilogue0::OutputTileIterator, + typename Epilogue0::AccumulatorFragmentIterator, + typename Epilogue0::WarpTileIterator, + typename Epilogue0::SharedLoadIterator, + OutputOp0, + OutputOp1, + OutputOp2, + typename Epilogue0::Padding, + kStoreD0, + kStoreD1, + Epilogue0::kFragmentsPerIteration, + true // IterationsUnroll + >; + + using ElementA = typename DualMma::IteratorA::Element; + using ElementB = typename DualMma::IteratorB0::Element; + using ElementC = typename DualEpilogue::OutputTileIterator::Element; + + static bool const kSplitKSerial = SplitKSerial; + static_assert(!kSplitKSerial || (kStoreD0 && kStoreD1), + "Split-K serial requires buffers for D0/D1 for reduction"); + + /// Warp count (concept: GemmShape) + using WarpCount0 = typename DualMma::WarpCount; + static int const kThreadCount = 32 * WarpCount0::kCount; + + /// Parameters structure + struct Params { + DualGemmMode mode; + cutlass::gemm::GemmCoord problem_size; + cutlass::gemm::GemmCoord grid_tiled_shape; + int swizzle_log_tile; + + // Mma0 + typename DualMma::IteratorA::Params params_A0; + typename DualMma::IteratorA::TensorRef ref_A0; + typename DualMma::IteratorB0::Params params_B0; + typename DualMma::IteratorB0::TensorRef ref_B0; + typename Epilogue0::OutputTileIterator::Params params_C0; + typename Epilogue0::OutputTileIterator::TensorRef ref_C0; + typename Epilogue0::OutputTileIterator::Params params_D0; + typename Epilogue0::OutputTileIterator::TensorRef ref_D0; + typename OutputOp0::Params output_op_0; + + // Mma1 + typename DualMma::IteratorB1::Params params_B1; + typename DualMma::IteratorB1::TensorRef ref_B1; + typename Epilogue1::OutputTileIterator::Params params_C1; + typename Epilogue1::OutputTileIterator::TensorRef ref_C1; + typename Epilogue1::OutputTileIterator::Params params_D1; + typename Epilogue1::OutputTileIterator::TensorRef ref_D1; + typename OutputOp1::Params output_op_1; + + typename Epilogue1::OutputTileIterator::Params params_D2; + typename Epilogue1::OutputTileIterator::TensorRef ref_D2; + typename OutputOp2::Params output_op_2; + + int *semaphore; + int gemm_k_size; + + int64_t batch_stride_A; + int64_t batch_stride_B0; + int64_t batch_stride_B1; + int64_t batch_stride_C; + int64_t batch_stride_D; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params(): swizzle_log_tile(0), semaphore(0), gemm_k_size(0) { } + + CUTLASS_HOST_DEVICE + Params( + DualGemmMode mode, + cutlass::gemm::GemmCoord const & problem_size, + cutlass::gemm::GemmCoord const & grid_tiled_shape, + // Mma0: D0 = A @ B0 + C0 + typename DualMma::IteratorA::TensorRef ref_A0, + typename DualMma::IteratorB0::TensorRef ref_B0, + typename Epilogue0::OutputTileIterator::TensorRef ref_C0, + typename Epilogue0::OutputTileIterator::TensorRef ref_D0, + // Mma1: D1 = A @ B1 + C1 + typename DualMma::IteratorB1::TensorRef ref_B1, + typename Epilogue1::OutputTileIterator::TensorRef ref_C1, + typename Epilogue1::OutputTileIterator::TensorRef ref_D1, + + 
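+      // D2 = output_op_2(D0, D1): fused output combining the two GEMM results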
typename Epilogue1::OutputTileIterator::TensorRef ref_D2, + typename OutputOp0::Params output_op_0 = typename OutputOp0::Params(), + typename OutputOp1::Params output_op_1 = typename OutputOp1::Params(), + typename OutputOp2::Params output_op_2 = typename OutputOp2::Params(), + int *workspace = nullptr, + int64_t batch_stride_A = 1, + int64_t batch_stride_B0 = 1, + int64_t batch_stride_B1 = 1, + int64_t batch_stride_C = 1, + int64_t batch_stride_D = 1 + ): + mode(mode), + problem_size(problem_size), + grid_tiled_shape(grid_tiled_shape), + swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)), + // Mma0 + params_A0(ref_A0.layout()), + ref_A0(ref_A0), + params_B0(ref_B0.layout()), + ref_B0(ref_B0), + params_C0(ref_C0.layout()), + ref_C0(ref_C0), + params_D0(ref_D0.layout()), + ref_D0(ref_D0), + // Mma1 + params_B1(ref_B1.layout()), + ref_B1(ref_B1), + params_C1(ref_C1.layout()), + ref_C1(ref_C1), + params_D1(ref_D1.layout()), + ref_D1(ref_D1), + params_D2(ref_D2.layout()), + ref_D2(ref_D2), + output_op_0(output_op_0), + output_op_1(output_op_1), + output_op_2(output_op_2), + batch_stride_A(batch_stride_A), + batch_stride_B0(batch_stride_B0), + batch_stride_B1(batch_stride_B1), + batch_stride_C(batch_stride_C), + batch_stride_D(batch_stride_D) { + + int total_gemm_k_iterations = (problem_size.k() + DualMma::Shape::kK - 1) / DualMma::Shape::kK; + int gemm_k_iterations = (total_gemm_k_iterations + grid_tiled_shape.k() - 1) / grid_tiled_shape.k(); + gemm_k_size = gemm_k_iterations * DualMma::Shape::kK; + + semaphore = workspace; + } + }; + + /// Shared memory storage structure + union SharedStorage { + typename DualMma::SharedStorage main_loop; + typename DualEpilogue::SharedStorage epilogue; + }; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + DualGemm() { } + + /// Determines whether kernel satisfies alignment + static Status can_implement( + cutlass::gemm::GemmCoord const & problem_size, + typename DualMma::IteratorA::TensorRef ref_A0, + typename DualMma::IteratorB0::TensorRef ref_B0, + typename Epilogue0::OutputTileIterator::TensorRef ref_C0, + typename Epilogue0::OutputTileIterator::TensorRef ref_D0, + typename DualMma::IteratorB1::TensorRef ref_B1, + typename Epilogue1::OutputTileIterator::TensorRef ref_C1, + typename Epilogue1::OutputTileIterator::TensorRef ref_D1, + typename Epilogue1::OutputTileIterator::TensorRef ref_D2) { + + static int const kAlignmentA = DualMma::IteratorA::AccessType::kElements; + static int const kAlignmentB = DualMma::IteratorB0::AccessType::kElements; + static int const kAlignmentC = Epilogue0::OutputTileIterator::kElementsPerAccess; + + if (!TensorRef_aligned(ref_A0, kAlignmentA)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(ref_B0, kAlignmentB)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(ref_C0, kAlignmentC)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(ref_D0, kAlignmentC)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(ref_B1, kAlignmentB)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(ref_C1, kAlignmentC)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(ref_D1, kAlignmentC)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(ref_D2, kAlignmentC)) { + return Status::kErrorMisalignedOperand; + } + + return Status::kSuccess; + } + + /// Executes one GEMM + CUTLASS_DEVICE + void operator()(Params const ¶ms, SharedStorage &shared_storage) { + // 
Compute threadblock location + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord threadblock_tile_offset = + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // Early exit if CTA is out of range + if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || + params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { + + return; + } + + int offset_k = 0; + int problem_size_k = params.problem_size.k(); + + ElementA *ptr_A0 = static_cast(params.ref_A0.data()); + ElementB *ptr_B0 = static_cast(params.ref_B0.data()); + ElementB *ptr_B1 = static_cast(params.ref_B1.data()); + + // + // Fetch pointers based on mode. + // + if (params.mode == DualGemmMode::kGemm) { + if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) { + problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size; + } + + offset_k = threadblock_tile_offset.k() * params.gemm_k_size; + } + else if (params.mode == DualGemmMode::kBatched) { + ptr_A0 += threadblock_tile_offset.k() * params.batch_stride_A; + ptr_B0 += threadblock_tile_offset.k() * params.batch_stride_B0; + ptr_B1 += threadblock_tile_offset.k() * params.batch_stride_B1; + } + + // Compute initial location in logical coordinates + cutlass::MatrixCoord tb_offset_A0{ + threadblock_tile_offset.m() * DualMma::Shape::kM, + offset_k, + }; + + cutlass::MatrixCoord tb_offset_B0{ + offset_k, + threadblock_tile_offset.n() * DualMma::Shape::kN + }; + + cutlass::MatrixCoord tb_offset_B1{ + offset_k, + threadblock_tile_offset.n() * DualMma::Shape::kN + }; + + // Compute position within threadblock + int thread_idx = threadIdx.x; + + // Construct iterators to A and B operands + typename DualMma::IteratorA iterator_A0( + params.params_A0, + ptr_A0, + {params.problem_size.m(), problem_size_k}, + thread_idx, + tb_offset_A0); + + typename DualMma::IteratorB0 iterator_B0( + params.params_B0, + ptr_B0, + {problem_size_k, params.problem_size.n()}, + thread_idx, + tb_offset_B0); + + typename DualMma::IteratorB1 iterator_B1( + params.params_B1, + ptr_B1, + {problem_size_k, params.problem_size.n()}, + thread_idx, + tb_offset_B1); + + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. 
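+    // threadIdx.x / 32 is this thread's warp index; __shfl_sync copies lane 0's
+    // value to every lane so the value is uniform across the warp.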
+ int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + int lane_idx = threadIdx.x % 32; + + // + // Main loop + // + + + // Construct thread-scoped matrix multiply + typename DualMma::FragmentC accum0; + typename DualMma::FragmentC accum1; + accum0.clear(); + accum1.clear(); + + // Compute threadblock-scoped matrix multiply-add + int gemm_k_iterations = (problem_size_k - offset_k + DualMma::Shape::kK - 1) / DualMma::Shape::kK; + + DualMma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); + if (!kSplitKSerial || gemm_k_iterations > 0) { + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, + accum0, accum1, + iterator_A0, iterator_B0, iterator_B1, + accum0, accum1); + } + + // + // Epilogue + // + + OutputOp0 output_op_0(params.output_op_0); + OutputOp1 output_op_1(params.output_op_1); + OutputOp2 output_op_2(params.output_op_2); + + // + // Masked tile iterators constructed from members + // + + threadblock_tile_offset = + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + //assume identity swizzle + MatrixCoord threadblock_offset( + threadblock_tile_offset.m() * DualMma::Shape::kM, + threadblock_tile_offset.n() * DualMma::Shape::kN + ); + + int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); + + ElementC *ptr_C0 = static_cast(params.ref_C0.data()); + ElementC *ptr_C1 = static_cast(params.ref_C1.data()); + ElementC *ptr_D0 = static_cast(params.ref_D0.data()); + ElementC *ptr_D1 = static_cast(params.ref_D1.data()); + ElementC *ptr_D2 = static_cast(params.ref_D2.data()); + + // Construct the semaphore. + Semaphore semaphore(params.semaphore + block_idx, thread_idx); + + if (params.mode == DualGemmMode::kGemm) { + // If performing a reduction via split-K, fetch the initial synchronization + if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { + + // Fetch the synchronization lock initially but do not block. + semaphore.fetch(); + + // Indicate which position in a serial reduction the output operator is currently updating + output_op_0.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); + output_op_1.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); + } + } + else if (params.mode == DualGemmMode::kBatched) { + ptr_C0 += threadblock_tile_offset.k() * params.batch_stride_C; + ptr_C1 += threadblock_tile_offset.k() * params.batch_stride_C; + ptr_D0 += threadblock_tile_offset.k() * params.batch_stride_D; + ptr_D1 += threadblock_tile_offset.k() * params.batch_stride_D; + ptr_D2 += threadblock_tile_offset.k() * params.batch_stride_D; + } + + // Tile iterator loading from source tensor. + typename Epilogue0::OutputTileIterator iterator_C0( + params.params_C0, + ptr_C0, + params.problem_size.mn(), + thread_idx, + threadblock_offset + ); + typename Epilogue1::OutputTileIterator iterator_C1( + params.params_C1, + ptr_C1, + params.problem_size.mn(), + thread_idx, + threadblock_offset + ); + + // Tile iterator writing to destination tensor. 
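+    // D0 and D1 are written only when kStoreD0 / kStoreD1 are enabled; D2 receives
+    // the fused result from output_op_2 (with split-K, only the final partition writes D2).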
+ typename Epilogue0::OutputTileIterator iterator_D0( + params.params_D0, + ptr_D0, + params.problem_size.mn(), + thread_idx, + threadblock_offset + ); + typename Epilogue1::OutputTileIterator iterator_D1( + params.params_D1, + ptr_D1, + params.problem_size.mn(), + thread_idx, + threadblock_offset + ); + typename Epilogue1::OutputTileIterator iterator_D2( + params.params_D2, + ptr_D2, + params.problem_size.mn(), + thread_idx, + threadblock_offset + ); + + DualEpilogue epilogue( + shared_storage.epilogue, + thread_idx, + warp_idx, + lane_idx); + + // Wait on the semaphore - this latency may have been covered by iterator construction + if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { + + // For subsequent threadblocks, the source matrix is held in the 'D' tensor. + if (threadblock_tile_offset.k()) { + iterator_C0 = iterator_D0; + iterator_C1 = iterator_D1; + } + + semaphore.wait(threadblock_tile_offset.k()); + + __threadfence(); + } + + // Execute the epilogue operator to update the destination tensor. + typename Epilogue0::OutputTileIterator source_iters[] = { + iterator_C0, iterator_C1 + }; + const bool writeToD2 = (!kSplitKSerial || params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1); + epilogue( + output_op_0, output_op_1, output_op_2, + iterator_D0, iterator_D1, iterator_D2, + accum0, accum1, + source_iters, + writeToD2 + ); + + // + // Release the semaphore + // + + if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { + + int lock = 0; + if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) { + + // The final threadblock resets the semaphore for subsequent grids. + lock = 0; + } + else { + // Otherwise, the semaphore is incremented + lock = threadblock_tile_offset.k() + 1; + } + + __threadfence(); + semaphore.release(lock); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/45_dual_gemm/test_run.h b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/45_dual_gemm/test_run.h new file mode 100644 index 0000000000000000000000000000000000000000..c92af193d21860b69ba19275dbff464c95a96a71 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/45_dual_gemm/test_run.h @@ -0,0 +1,95 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + +#include + +// Run tests on GPUs + +int testRun(int arch, std::vector & test_funcs, const std::string & test_name) { + + bool supported = false; + + int arch_major = arch / 10; + int arch_minor = arch - arch / 10 * 10; + + if(arch_major >= 8) { + // Ampere Tensor Core operations exposed with mma.sync are first available in CUDA 11.0. + // + // CUTLASS must be compiled with CUDA 11 Toolkit to run Conv2dFprop examples. + if (__CUDACC_VER_MAJOR__ > 11 || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0)) { + supported = true; + } + } + else if(arch_major >= 7) { + // Turing Tensor Core operations exposed with mma.sync are first available in CUDA 10.2. + // + // CUTLASS must be compiled with CUDA 10.2 Toolkit to run these examples. + if (__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2)) { + supported = true; + } + } + + cudaDeviceProp props; + + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (error != cudaSuccess) { + std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; + return -1; + } + + if (props.major < arch_major || (props.major == arch_major && props.minor < arch_minor) ) { + supported = false; + } + + if (!supported) { + // Returning zero so this test passes on older Toolkits. Its actions are no-op. + std::cout << "This example isn't supported on current architecture" << std::endl; + return 0; + } + + bool pass = true; + + std::cout << "Device: " << props.name << std::endl; + std::cout << "Arch: SM" << arch << std::endl; + std::cout << "Test: " << test_name << std::endl; + for(auto func : test_funcs) { + pass &= func(); + } + + + if(pass) + return 0; + else + return -1; + +} + diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/45_dual_gemm/thread/left_silu_and_mul.h b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/45_dual_gemm/thread/left_silu_and_mul.h new file mode 100644 index 0000000000000000000000000000000000000000..003b6f6c60d60ffd8660f531e7d332f3caec4dd7 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/45_dual_gemm/thread/left_silu_and_mul.h @@ -0,0 +1,150 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Functor performing linear combination operations used by epilogues. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/array.h" +#include "cutlass/functional.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/epilogue/thread/scale_type.h" +#include "cutlass/epilogue/thread/linear_combination_params.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace thread { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Applies a linear combination operator to an array of elements. +/// +/// D = alpha * accumulator + beta * source + uniform +/// +template < + typename ElementOutput_, ///< Data type used to load and store tensors + int Count, ///< Number of elements computed per operation. 
+ ///< Usually it is 128/sizeof_bits, + ///< but we use 64 or 32 sometimes when there are not enough data to store + typename ElementAccumulator_ = ElementOutput_, ///< Accumulator data type + typename ElementCompute_ = ElementOutput_, ///< Data type used to compute linear combination + FloatRoundStyle Round = FloatRoundStyle::round_to_nearest +> +class LeftSiLUAndMul { +public: + + using ElementOutput = ElementOutput_; + using ElementAccumulator = ElementAccumulator_; + using ElementCompute = ElementCompute_; + + static int const kCount = Count; + using FragmentOutput = Array; + using FragmentAccumulator = Array; + using ComputeFragment = Array; + + static FloatRoundStyle const kRound = Round; + + struct Params{}; + +private: + + // + // Data members + // + + ElementCompute alpha_; + ElementCompute beta_; + +public: + + /// Constructs the function object, possibly loading from pointers in host memory + CUTLASS_HOST_DEVICE + LeftSiLUAndMul(Params const &/*params*/) {} + + /// Returns true if source is needed + CUTLASS_HOST_DEVICE + bool is_source_needed() const { + return true; + } + + /// Functionally required for serial reduction in the epilogue + CUTLASS_HOST_DEVICE + void set_k_partition(int k_partition, int k_partition_count) { + assert(false); + } + + /// Computes linear scaling: D = alpha * accumulator + beta * source + CUTLASS_HOST_DEVICE + FragmentOutput operator()( + FragmentAccumulator const &lhs, + FragmentAccumulator const &rhs) const { + + // Convert source to interal compute numeric type + NumericArrayConverter accumulator_to_compute; + + // Convert to destination numeric type + NumericArrayConverter compute_to_output; + + ComputeFragment converted_lhs = accumulator_to_compute(lhs); + ComputeFragment converted_rhs = accumulator_to_compute(rhs); + + cutlass::epilogue::thread::SiLu silu; + cutlass::multiplies mul; + auto silu_lhs = silu(converted_lhs); + return compute_to_output(mul(silu_lhs, converted_rhs)); + } + + CUTLASS_HOST_DEVICE + ElementOutput operator()( + ElementAccumulator const& lhs, + ElementAccumulator const& rhs + ) const { + ElementCompute convert_lhs(lhs); + ElementCompute convert_rhs(rhs); + cutlass::epilogue::thread::SiLu silu; + cutlass::multiplies mul; + auto silu_lhs = silu(convert_lhs); + return ElementOutput(mul(silu_lhs, convert_rhs)); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace thread +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/45_dual_gemm/threadblock/dual_epilogue.h b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/45_dual_gemm/threadblock/dual_epilogue.h new file mode 100644 index 0000000000000000000000000000000000000000..bcb34d8970b2f8e4901426db4e6f250934cbbad3 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/45_dual_gemm/threadblock/dual_epilogue.h @@ -0,0 +1,424 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. + + The epilogue rearranges the result of a matrix product through shared memory to match canonical + tensor layouts in global memory. Epilogues support conversion and reduction operations. + +*/ + +#pragma once +#include "cutlass/array.h" +#include CUDA_STD_HEADER(cassert) +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/layout/vector.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/tensor_coord.h" +#include "cutlass/aligned_buffer.h" +#include "cutlass/functional.h" + +#include "cutlass/gemm/gemm.h" + +#include "cutlass/transform/pitch_linear_thread_map.h" +#include "cutlass/transform/threadblock/regular_tile_iterator.h" + +#include "cutlass/epilogue/threadblock/epilogue_base.h" +#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" +#include "cutlass/numeric_types.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Epilogue operator +template < + typename Shape_, ///< Shape of threadblock tile (concept: GemmShape) + typename WarpMmaOperator_, ///< Warp-level MMA operator (concept: gemm::warp::MmaTensorOp) + int PartitionsK, ///< Number of partitions of the K dimension + typename OutputTileIterator_, ///< Tile iterator reading and writing output tensors + typename AccumulatorFragmentIterator_, ///< Fragment iterator selecting accumulators + typename WarpTileIterator_, ///< Warp-scoped tile iterator writing accumulators to SMEM + typename SharedLoadIterator_, ///< Threadblock-scoped tile iterator loading from SMEM + ///< Output operator + typename OutputOp0_, + typename OutputOp1_, + typename OutputOp2_, + typename Padding_, ///< Padding added to SMEM allocation to avoid bank conflicts (concept: MatrixShape) + bool StoreD0 = true, + bool StoreD1 = true, + int FragmentsPerPartition = 1, ///< Used to coarsten the epilogue 
granularity + int IterationsUnroll = ///< Used to reduce binary size when epilogue op is large + (!IsEpilogueFunctorHeavy::value) +> +class DualEpilogue { + +public: + + using Base = EpilogueBase< + Shape_, + typename WarpMmaOperator_::Shape, + PartitionsK, + AccumulatorFragmentIterator_, + WarpTileIterator_, + Padding_, + FragmentsPerPartition>; + + using Shape = Shape_; + using WarpMmaOperator = WarpMmaOperator_; + static int const kPartitionsK = PartitionsK; + static bool constexpr kStoreD0 = StoreD0; + static bool constexpr kStoreD1 = StoreD1; + using OutputTileIterator = OutputTileIterator_; + using AccumulatorFragmentIterator = AccumulatorFragmentIterator_; + using WarpTileIterator = WarpTileIterator_; + using SharedLoadIterator = SharedLoadIterator_; + using OutputOp0 = OutputOp0_; + using OutputOp1 = OutputOp1_; + using OutputOp2 = OutputOp2_; + using Padding = Padding_; + + using Layout = layout::RowMajor; + using LongIndex = typename Layout::LongIndex; + + /// The complete warp-level accumulator tile + using AccumulatorTile = typename Base::AccumulatorTile; + + /// Accumulator element + using ElementAccumulator = typename WarpTileIterator::Element; + + /// Output element + using ElementOutput = typename OutputTileIterator::Element; + + /// Output access size + static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess; + + /// Tensor reference to destination tensor + using TensorRef = typename OutputTileIterator::TensorRef; + + /// Tensor reference to sync tensor + using SyncTensorRef = typename cutlass::TensorRef; + + /// Const tensor reference to source tensor + using ConstTensorRef = typename OutputTileIterator::ConstTensorRef; + + /// Array type used to output + using OutputAccessType = Array< + typename OutputTileIterator::Element, OutputTileIterator::kElementsPerAccess>; + + /// Array type used by output functor + using AccumulatorAccessType = Array; + + /// Number of warps + using WarpCount = typename Base::WarpCount; + + struct SharedStorage { + using Element = typename WarpTileIterator::Element; + + /// Tensor reference to shared memory allocation + using TensorRef = typename WarpTileIterator::TensorRef; + + /// Logical shape of the shared memory tile written to by all warps. + using Shape = typename Base::Shape; + + /// Shape of the shared memory allocation for the epilogue + using StorageShape = typename Base::SharedStorage::StorageShape; + + // + // Data members + // + + AlignedBuffer storage[2]; + + // + // Methods + // + + /// Returns a tensor reference to the shared memory buffer + CUTLASS_DEVICE + TensorRef reference(int i) { + return TensorRef( + storage[i].data(), + Layout::packed({StorageShape::kRow, StorageShape::kColumn})); + } + }; + + static int constexpr kSmemTiles = Base::kFragmentsPerIteration > 1 ? 
Base::kFragmentsPerIteration : kPartitionsK; + static int constexpr kSmemPointerOffset = SharedStorage::StorageShape::kCount / kSmemTiles; + +public: + + static_assert(SharedLoadIterator::Fragment::kElements == OutputTileIterator::Fragment::kElements, + "Mismatch between shared load iterator and output tile iterator."); + + static_assert(OutputTileIterator::kElementsPerAccess, "OutputTileIterator::kElementsPerAccess must not be zero."); + + static_assert(!(OutputTileIterator::Fragment::kElements % OutputTileIterator::kElementsPerAccess), + "Divisibility"); + +private: + + /// Loads fragment from shared memory aligned with output tensor + SharedLoadIterator shared_load_iterator0_; + SharedLoadIterator shared_load_iterator1_; + + /// Stores a warp's fragment of accumulators to SMEM + WarpTileIterator warp_tile_iterator0_; + WarpTileIterator warp_tile_iterator1_; + +public: + + /// Constructor + CUTLASS_DEVICE + DualEpilogue( + SharedStorage &shared_storage, ///< Shared storage object + int thread_idx, ///< ID of a thread within the threadblock + int warp_idx, ///< ID of warp within threadblock + int lane_idx ///< Id of thread within warp + ): + shared_load_iterator0_(shared_storage.reference(0), thread_idx), + shared_load_iterator1_(shared_storage.reference(1), thread_idx), + warp_tile_iterator0_(shared_storage.reference(0), lane_idx), + warp_tile_iterator1_(shared_storage.reference(1), lane_idx) + { + int warp_k = warp_idx / (WarpCount::kM * WarpCount::kN); + int warp_mn = warp_idx % (WarpCount::kM * WarpCount::kN); + int warp_m = warp_mn % WarpCount::kM; + int warp_n = warp_mn / WarpCount::kM; + + MatrixCoord warp_offset{warp_k * WarpCount::kM + warp_m, warp_n}; + + warp_tile_iterator0_.add_tile_offset(warp_offset); + warp_tile_iterator1_.add_tile_offset(warp_offset); + } + + /// Streams the result to global memory + CUTLASS_DEVICE + void operator()( + OutputOp0 const &output_op0, + OutputOp1 const &output_op1, + OutputOp2 const &output_op2, + OutputTileIterator dest0, + OutputTileIterator dest1, + OutputTileIterator dest2, + AccumulatorTile const &accumulator0, + AccumulatorTile const &accumulator1, + OutputTileIterator source_iterator[2], + bool writeToD2 // true if it's the final split-k + ) { + // TODO: Implement when no source is needed + + typename OutputTileIterator::Fragment source_fragment[2]; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < 2; ++i) { + source_fragment[i].clear(); + } + + // + // Iterator over warp-level accumulator fragment + // + + AccumulatorFragmentIterator accum_fragment_iterator[2] = {accumulator0, accumulator1}; + + // + // Iterate over accumulator tile + // + + #pragma unroll(IterationsUnroll ? 
OutputTileIterator::kIterations : 1) + for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) { + + // + // Load the source + // + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < 2; ++i) { + source_iterator[i].load(source_fragment[i]); + ++source_iterator[i]; + } + + // + // Convert and store fragment + // + + __syncthreads(); + + acc2smem_source_needed>::push( + iter, accum_fragment_iterator[0], this->warp_tile_iterator0_); + acc2smem_source_needed>::push( + iter, accum_fragment_iterator[1], this->warp_tile_iterator1_); + + __syncthreads(); + + // + // Load fragments from shared memory + // + + typename SharedLoadIterator::Fragment aligned_accum_fragment0[kPartitionsK]; + typename SharedLoadIterator::Fragment aligned_accum_fragment1[kPartitionsK]; + + shared_load_iterator0_.load(aligned_accum_fragment0[0]); + shared_load_iterator1_.load(aligned_accum_fragment1[0]); + + // If the number of k-slices is > 1 - perform a reduction amongst the k-slices + if (kPartitionsK > 1) { + + plus add_fragments; + + CUTLASS_PRAGMA_UNROLL + for ( int i = 1; i < kPartitionsK; ++i) { + shared_load_iterator0_.add_pointer_offset(kSmemPointerOffset); + shared_load_iterator1_.add_pointer_offset(kSmemPointerOffset); + shared_load_iterator0_.load(aligned_accum_fragment0[i]); + shared_load_iterator1_.load(aligned_accum_fragment1[i]); + aligned_accum_fragment0[0] = add_fragments(aligned_accum_fragment0[0], aligned_accum_fragment0[i]); + aligned_accum_fragment1[0] = add_fragments(aligned_accum_fragment1[0], aligned_accum_fragment1[i]); + } + + shared_load_iterator0_.add_pointer_offset((1 - kPartitionsK) * kSmemPointerOffset); + shared_load_iterator1_.add_pointer_offset((1 - kPartitionsK) * kSmemPointerOffset); + } + + // + // Compute the output result + // + + typename OutputTileIterator::Fragment output_fragment[3]; + + apply_output_operator_(output_fragment, + output_op0, output_op1, output_op2, + aligned_accum_fragment0[0], aligned_accum_fragment1[0], + source_fragment); + + + // + // Store the final result + // + + if (kStoreD0) { + dest0.store(output_fragment[0]); + ++dest0; + } + if (kStoreD1) { + dest1.store(output_fragment[1]); + ++dest1; + } + if (writeToD2) { + dest2.store(output_fragment[2]); + ++dest2; + } + } + } + +private: + + static_assert(kPartitionsK == 1 || Base::kFragmentsPerIteration == 1, "One of these must be exactly 1."); + + template + struct acc2smem_source_needed; + + template + struct acc2smem_source_needed> { + template + CUTLASS_DEVICE + static void helper(AccumulatorFragmentIterator accum_fragment_iterator, + WarpTileIterator &warp_tile_iterator) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Advance; i++) { + ++accum_fragment_iterator; + } + + typename AccumulatorFragmentIterator::Fragment accum_fragment; + accum_fragment_iterator.load(accum_fragment); + warp_tile_iterator.store(accum_fragment); + } + + CUTLASS_DEVICE + static void push(size_t pos, + AccumulatorFragmentIterator const &iterator_begin, + WarpTileIterator &warp_tile_iterator) { + int dummy[] = {(pos == Seq) && (helper(iterator_begin, warp_tile_iterator), 0)...}; + } + }; + + /// Helper to invoke the output functor over each vector of output + CUTLASS_DEVICE + void apply_output_operator_( + typename OutputTileIterator::Fragment (&output_fragment)[3], + OutputOp0 const &output_op0, + OutputOp1 const &output_op1, + OutputOp2 const &output_op2, + typename SharedLoadIterator::Fragment const& aligned_accum_fragment0, + typename SharedLoadIterator::Fragment const& aligned_accum_fragment1, + typename 
OutputTileIterator::Fragment const (&source_fragment)[2]) { + + OutputAccessType* output_frag_ptr[3] = { + reinterpret_cast(&output_fragment[0]), + reinterpret_cast(&output_fragment[1]), + reinterpret_cast(&output_fragment[2]) + }; + + AccumulatorAccessType const *compute_frag_ptr[2] = { + reinterpret_cast(&aligned_accum_fragment0), + reinterpret_cast(&aligned_accum_fragment1) + }; + + OutputAccessType const *source_frag_ptr[2] = { + reinterpret_cast(&source_fragment[0]), + reinterpret_cast(&source_fragment[1]) + }; + + int const kOutputOpIterations = + OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kOutputOpIterations; ++i) { + + // Call the output operators + output_frag_ptr[0][i] = output_op0(compute_frag_ptr[0][i], source_frag_ptr[0][i]); + output_frag_ptr[1][i] = output_op1(compute_frag_ptr[1][i], source_frag_ptr[1][i]); + output_frag_ptr[2][i] = output_op2(output_frag_ptr[0][i], output_frag_ptr[1][i]); + } + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/45_dual_gemm/threadblock/dual_mma_base.h b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/45_dual_gemm/threadblock/dual_mma_base.h new file mode 100644 index 0000000000000000000000000000000000000000..754719033e3ac19c7c6101db94d5a45b5d4e8b97 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/45_dual_gemm/threadblock/dual_mma_base.h @@ -0,0 +1,232 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/threadblock/mma_base.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy0_, + /// B1-specific version of the policy (concept: MmaPolicy) + typename Policy1_, + /// Number of stages, + int Stages, + /// Used for partial specialization + typename Enable = bool> +class DualMmaBase { + public: + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + + ///< Policy describing tuning details + using Policy0 = Policy0_; + using Policy1 = Policy1_; + + // + // Dependent types + // + + /// Warp-level Mma + using Operator0 = typename Policy0::Operator; + using Operator1 = typename Policy1::Operator; + + /// Shape describing the overall GEMM computed from shared memory + /// by each warp. + using WarpGemm = typename Policy0::Operator::Shape; + + /// Shape describing the number of warps filling the CTA + using WarpCount = GemmShape; + + /// Number of warp-level GEMM oeprations + static int const kWarpGemmIterations = + (WarpGemm::kK / Operator0::Policy::MmaShape::kK); + + /// Number of stages + static int const kStages = Stages; + + /// Tensor reference to the A operand + using TensorRefA = TensorRef; + + /// Tensor reference to the B operand + using TensorRefB0 = TensorRef; + using TensorRefB1 = TensorRef; + + static_assert(kWarpGemmIterations > 1, + "The pipelined structure requires at least two warp-level " + "GEMM operations."); + + static_assert((kWarpGemmIterations % 2) == 0, + "Inner loop iteration must be an even number."); + + // + // Nested structs + // + + /// Shared storage object needed by threadblock-scoped GEMM + class SharedStorage { + public: + // + // Type definitions + // + + /// Shape of the A matrix operand in shared memory + using ShapeA = MatrixShape; + + /// Shape of the B matrix operand in shared memory + using ShapeB0 = + MatrixShape; + using ShapeB1 = + MatrixShape; + + public: + // + // Data members + // + + /// Buffer for A operand + AlignedBuffer operand_A; + + /// Buffer for B operand + AlignedBuffer operand_B0; + AlignedBuffer operand_B1; + + public: + + // + // Methods + // + + /// Returns a layout object for the A matrix + CUTLASS_DEVICE + static typename Operator0::LayoutA LayoutA() { + return Operator0::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn}); + } + + /// Returns a layout object for the B matrix + CUTLASS_HOST_DEVICE + static typename Operator0::LayoutB LayoutB0() { + return Operator0::LayoutB::packed({ShapeB0::kRow, ShapeB0::kColumn}); + } + + /// Returns a layout object for the B matrix + CUTLASS_HOST_DEVICE + static typename Operator1::LayoutB LayoutB1() { + return Operator1::LayoutB::packed({ShapeB1::kRow, 
ShapeB1::kColumn}); + } + + /// Returns a TensorRef to the A operand + CUTLASS_HOST_DEVICE + TensorRefA operand_A_ref() { + return TensorRefA{operand_A.data(), LayoutA()}; + } + + /// Returns a TensorRef to the B operand + CUTLASS_HOST_DEVICE + TensorRefB0 operand_B0_ref() { + return TensorRefB0{operand_B0.data(), LayoutB0()}; + } + CUTLASS_HOST_DEVICE + TensorRefB1 operand_B1_ref() { + return TensorRefB1{operand_B1.data(), LayoutB1()}; + } + }; + + protected: + + // + // Data members + // + + /// Iterator to load a warp-scoped tile of A operand from shared memory + typename Operator0::IteratorA warp_tile_iterator_A_; + + /// Iterator to load a warp-scoped tile of B operand from shared memory + typename Operator0::IteratorB warp_tile_iterator_B0_; + typename Operator1::IteratorB warp_tile_iterator_B1_; + +public: + + /// Construct from tensor references + CUTLASS_DEVICE + DualMmaBase( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + SharedStorage &shared_storage, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx + ): + warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx), + warp_tile_iterator_B0_(shared_storage.operand_B0_ref(), lane_idx), + warp_tile_iterator_B1_(shared_storage.operand_B1_ref(), lane_idx) { + + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/45_dual_gemm/threadblock/dual_mma_multistage.h b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/45_dual_gemm/threadblock/dual_mma_multistage.h new file mode 100644 index 0000000000000000000000000000000000000000..5109b41057ce370b1a18019dfe1709611446174f --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/45_dual_gemm/threadblock/dual_mma_multistage.h @@ -0,0 +1,775 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/threadblock/mma_base.h" +#include "dual_mma_base.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Cache operation for operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Iterates over tiles of B0 operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB0_, + /// Iterates over tiles of B0 operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB0_, + /// Cache operation for operand B + cutlass::arch::CacheOperation::Kind CacheOpB, + /// Iterates over tiles of B1 operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB1_, + /// Iterates over tiles of B1 operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB1_, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy0_, + /// B1-specific version of the policy (concept: MmaPolicy) + typename Policy1_, + /// Number of stages, + int Stages, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone, + /// Used for partial specialization + typename Enable = bool> +class DualMmaMultistage : + public DualMmaBase { +public: + ///< Base class + using Base = DualMmaBase; + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + ///< Iterates over tiles of A operand in global memory + using IteratorA = IteratorA_; + ///< Iterates over tiles of B0 operand in global memory + using IteratorB0 = IteratorB0_; + ///< Iterates over tiles of 
B1 operand in global memory + using IteratorB1 = IteratorB1_; + ///< Data type of accumulator matrix + using ElementC = ElementC_; + ///< Layout of accumulator matrix + using LayoutC = LayoutC_; + ///< Policy describing tuning details + using Policy0 = Policy0_; + using Policy1 = Policy1_; + + using SmemIteratorA = SmemIteratorA_; + using SmemIteratorB0 = SmemIteratorB0_; + using SmemIteratorB1 = SmemIteratorB1_; + + static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; + + // + // Dependent types + // + + /// Fragment of accumulator tile + using FragmentC = typename Policy0::Operator::FragmentC; + + /// Warp-level Mma + using Operator0 = typename Policy0::Operator; + using Operator1 = typename Policy1::Operator; + + /// Minimum architecture is Sm80 to support cp.async + using ArchTag = arch::Sm80; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = Operator0::kTransformA; + + /// Complex transform on B operand + static ComplexTransform const kTransformB0 = Operator0::kTransformB; + static ComplexTransform const kTransformB1 = Operator1::kTransformB; + + /// Internal structure exposed for introspection. + struct Detail { + + /// Number of cp.async instructions to load one stage of operand A + static int const AsyncCopyIterationsPerStageA = + IteratorA::ThreadMap::Iterations::kCount; + + /// Number of cp.async instructions to load one stage of operand B + static int const AsyncCopyIterationsPerStageB = + IteratorB0::ThreadMap::Iterations::kCount; + + /// Number of stages + static int const kStages = Stages; + + /// Number of cp.async instructions to load on group of operand A + static int const kAccessesPerGroupA = + (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + + /// Number of cp.async instructions to load on group of operand B + static int const kAccessesPerGroupB = + (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + }; + + private: + + using WarpLoadedFragmentA = typename Operator0::FragmentA; + using WarpLoadedFragmentB0 = typename Operator0::FragmentB; + using WarpLoadedFragmentB1 = typename Operator1::FragmentB; + using WarpTransformedFragmentA = typename Operator0::TransformedFragmentA; + using WarpTransformedFragmentB0 = typename Operator0::TransformedFragmentB; + using WarpTransformedFragmentB1 = typename Operator1::TransformedFragmentB; + + private: + + // + // Data members + // + + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB0 smem_iterator_B0_; + SmemIteratorB1 smem_iterator_B1_; + +public: + + /// Construct from tensor references + CUTLASS_DEVICE + DualMmaMultistage( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + typename Base::SharedStorage &shared_storage, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx + ): + Base(shared_storage, thread_idx, warp_idx, lane_idx), + smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), + smem_iterator_B0_(shared_storage.operand_B0_ref(), thread_idx), + smem_iterator_B1_(shared_storage.operand_B1_ref(), thread_idx) + { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position 
within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset( + {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B0_.add_tile_offset( + {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); + this->warp_tile_iterator_B1_.add_tile_offset( + {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); + } + + CUTLASS_DEVICE + void copy_tiles_and_advance(IteratorA &iterator_A, IteratorB0 &iterator_B0, IteratorB1 &iterator_B1, + int group_start_A = 0, int group_start_B = 0) { + iterator_A.set_iteration_index(group_start_A * + IteratorA::kAccessesPerVector); + this->smem_iterator_A_.set_iteration_index(group_start_A); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) { + if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) { + typename IteratorA::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_A_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_A.get(); + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { + cutlass::arch::cp_async_zfill( + dst_ptr + v, gmem_ptr, iterator_A.valid()); + } else { + cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, iterator_A.valid()); + } + + ++iterator_A; + } + + ++this->smem_iterator_A_; + } + } + + iterator_B0.set_iteration_index(group_start_B * + IteratorB0::kAccessesPerVector); + iterator_B1.set_iteration_index(group_start_B * + IteratorB1::kAccessesPerVector); + this->smem_iterator_B0_.set_iteration_index(group_start_B); + this->smem_iterator_B1_.set_iteration_index(group_start_B); + + // Async Copy for operand B0 + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { + if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) { + typename IteratorB0::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_B0_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorB0::ThreadMap::kElementsPerAccess / + IteratorB0::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB0::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_B0.get(); + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { + cutlass::arch::cp_async_zfill( + dst_ptr + v, gmem_ptr, iterator_B0.valid()); + } else { + cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, iterator_B0.valid()); + } + + ++iterator_B0; + } + ++this->smem_iterator_B0_; + } + } + // Async Copy for operand B1 + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { + if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) { + typename IteratorB1::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_B1_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorB1::ThreadMap::kElementsPerAccess / + IteratorB1::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < 
IteratorB1::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_B1.get(); + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { + cutlass::arch::cp_async_zfill( + dst_ptr + v, gmem_ptr, iterator_B1.valid()); + } else { + cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, iterator_B1.valid()); + } + + ++iterator_B1; + } + ++this->smem_iterator_B1_; + } + } + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + ///< problem size of GEMM + int gemm_k_iterations, + ///< destination accumulator tile + FragmentC &accum0, + FragmentC &accum1, + ///< iterator over A operand in global memory + IteratorA iterator_A, + ///< iterator over B operand in global memory + IteratorB0 iterator_B0, + IteratorB1 iterator_B1, + ///< initial value of accumulator + FragmentC const &src_accum0, + FragmentC const &src_accum1 + ) { + + // + // Prologue + // + + // Issue several complete stages + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < Base::kStages - 1; + ++stage, --gemm_k_iterations) { + + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B0.clear_mask(gemm_k_iterations == 0); + iterator_B1.clear_mask(gemm_k_iterations == 0); + + iterator_A.set_iteration_index(0); + this->smem_iterator_A_.set_iteration_index(0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { + typename IteratorA::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_A_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; + + int src_bytes = (iterator_A.valid() ? kSrcBytes : 0); + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_A.get(), iterator_A.valid()); + + ++iterator_A; + } + + ++this->smem_iterator_A_; + } + + iterator_B0.set_iteration_index(0); + iterator_B1.set_iteration_index(0); + this->smem_iterator_B0_.set_iteration_index(0); + this->smem_iterator_B1_.set_iteration_index(0); + + // Async Copy for operand B0 + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { + typename IteratorB0::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_B0_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB0::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorB0::ThreadMap::kElementsPerAccess / + IteratorB0::kAccessesPerVector / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_B0.get(), iterator_B0.valid()); + + ++iterator_B0; + } + + ++this->smem_iterator_B0_; + } + // Async Copy for operand B1 + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { + typename IteratorB1::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_B1_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB1::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorB1::ThreadMap::kElementsPerAccess / + IteratorB1::kAccessesPerVector / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_B1.get(), iterator_B1.valid()); + + ++iterator_B1; + } + + ++this->smem_iterator_B1_; + } + + // Move to the next stage + iterator_A.add_tile_offset({0, 1}); + iterator_B0.add_tile_offset({1, 0}); + iterator_B1.add_tile_offset({1, 0}); + + this->smem_iterator_A_.add_tile_offset({0, 1}); + this->smem_iterator_B0_.add_tile_offset({1, 0}); + 
this->smem_iterator_B1_.add_tile_offset({1, 0}); + + // Defines the boundary of a stage of cp.async. + cutlass::arch::cp_async_fence(); + } + + // Perform accumulation in the 'd' output operand + accum0 = src_accum0; + accum1 = src_accum1; + + // + // Clear the remaining tiles of SMEM. This is a functional requirement for some kernels + // so that all accumulator elements outside the GEMM footprint are zero. + // + + if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage) { + + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_); + + typename IteratorA::AccessType zero_A; + zero_A.clear(); + + last_smem_iterator_A.set_iteration_index(0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { + + typename IteratorA::AccessType *dst_ptr = + reinterpret_cast( + last_smem_iterator_A.get()); + + *dst_ptr = zero_A; + + ++last_smem_iterator_A; + } + + typename IteratorB0::AccessType zero_B; + zero_B.clear(); + + /// Iterator to write threadblock-scoped tile of B0 operand to shared memory + SmemIteratorB0 last_smem_iterator_B0(this->smem_iterator_B0_); + last_smem_iterator_B0.set_iteration_index(0); + + // Async Copy for operand B0 + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { + typename IteratorB0::AccessType *dst_ptr = + reinterpret_cast( + last_smem_iterator_B0.get()); + + *dst_ptr = zero_B; + + ++last_smem_iterator_B0; + } + + /// Iterator to write threadblock-scoped tile of B1 operand to shared memory + SmemIteratorB1 last_smem_iterator_B1(this->smem_iterator_B1_); + last_smem_iterator_B1.set_iteration_index(0); + + // Async Copy for operand B1 + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { + + typename IteratorB1::AccessType *dst_ptr = + reinterpret_cast( + last_smem_iterator_B1.get()); + + *dst_ptr = zero_B; + + ++last_smem_iterator_B1; + } + } + + // Waits until stages up to the previous (kStages-2)th stage have committed. 
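+    // Illustrative note on the wait below: cp_async_wait<N> blocks until at most N of the
+    // most recently committed cp.async groups remain in flight. Assuming, for example,
+    // Stages = 4, the prologue above committed kStages - 1 = 3 stages, so waiting on
+    // <kStages - 2> = <2> guarantees the oldest stage has fully landed in shared memory
+    // before the first math iteration reads it, while the newer copies keep overlapping
+    // with compute.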
+ cutlass::arch::cp_async_wait(); + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpLoadedFragmentA warp_loaded_frag_A[2]; + WarpLoadedFragmentB0 warp_loaded_frag_B0[2]; + WarpLoadedFragmentB1 warp_loaded_frag_B1[2]; + WarpTransformedFragmentA warp_transformed_frag_A[2]; + WarpTransformedFragmentB0 warp_transformed_frag_B0[2]; + WarpTransformedFragmentB1 warp_transformed_frag_B1[2]; + + Operator0 warp_mma0; + Operator1 warp_mma1; + + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_B0_.set_kgroup_index(0); + this->warp_tile_iterator_B1_.set_kgroup_index(0); + + this->warp_tile_iterator_A_.load(warp_loaded_frag_A[0]); + this->warp_tile_iterator_B0_.load(warp_loaded_frag_B0[0]); + this->warp_tile_iterator_B1_.load(warp_loaded_frag_B1[0]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B0_; + ++this->warp_tile_iterator_B1_; + + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B0.clear_mask(gemm_k_iterations == 0); + iterator_B1.clear_mask(gemm_k_iterations == 0); + + int smem_write_stage_idx = Base::kStages - 1; + int smem_read_stage_idx = 0; + + warp_mma0.transform(warp_transformed_frag_A[0], warp_transformed_frag_B0[0], + warp_loaded_frag_A[0], warp_loaded_frag_B0[0]); + warp_mma1.transform(warp_transformed_frag_A[0], warp_transformed_frag_B1[0], + warp_loaded_frag_A[0], warp_loaded_frag_B1[0]); + + // tf32x3 kernels use staging accumulation. warp_mma uses a temporary + // accumulator and this temporary accumulator is added to the final + // accumulator once in every mainloop iteration. + plus plus_accum; + + FragmentC tmp_accum0, tmp_accum1; + + if (platform::is_same::value + || platform::is_same::value) { + + tmp_accum0.clear(); + tmp_accum1.clear(); + } + + // + // Mainloop + // + + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > (-Base::kStages + 1);) { + // + // Loop over GEMM K dimension + // + + // Computes a warp-level GEMM on data held in shared memory + // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; + ++warp_mma_k) { + + // Load warp-level tiles from shared memory, wrapping to k offset if + // this is the last group as the case may be. 
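+        // The warp fragments below act as a two-slot ring buffer: while slot
+        // (warp_mma_k % 2) feeds the current warp-level MMA, the next k-group is
+        // prefetched into slot ((warp_mma_k + 1) % 2). For instance, with
+        // kWarpGemmIterations = 4 (illustrative), iteration 0 computes on slot 0 while
+        // loading slot 1, iteration 1 computes on slot 1 while loading slot 0, and so on,
+        // overlapping shared-memory loads with math.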
+ + this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_B0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_B1_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); + + this->warp_tile_iterator_A_.load(warp_loaded_frag_A[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_B0_.load(warp_loaded_frag_B0[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_B1_.load(warp_loaded_frag_B1[(warp_mma_k + 1) % 2]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B0_; + ++this->warp_tile_iterator_B1_; + + if (warp_mma_k > 0) { + warp_mma0.transform(warp_transformed_frag_A[warp_mma_k % 2], + warp_transformed_frag_B0[warp_mma_k % 2], + warp_loaded_frag_A[warp_mma_k % 2], + warp_loaded_frag_B0[warp_mma_k % 2]); + warp_mma1.transform(warp_transformed_frag_A[warp_mma_k % 2], + warp_transformed_frag_B1[warp_mma_k % 2], + warp_loaded_frag_A[warp_mma_k % 2], + warp_loaded_frag_B1[warp_mma_k % 2]); + } + + if (platform::is_same::value + || platform::is_same::value) { + + warp_mma0( + tmp_accum0, + warp_transformed_frag_A[warp_mma_k % 2], + warp_transformed_frag_B0[warp_mma_k % 2], + tmp_accum0 + ); + warp_mma1( + tmp_accum1, + warp_transformed_frag_A[warp_mma_k % 2], + warp_transformed_frag_B1[warp_mma_k % 2], + tmp_accum1 + ); + + if (warp_mma_k == 0) { + accum0 = plus_accum(accum0, tmp_accum0); + accum1 = plus_accum(accum1, tmp_accum1); + tmp_accum0.clear(); + tmp_accum1.clear(); + } + } else { + warp_mma0( + accum0, + warp_transformed_frag_A[warp_mma_k % 2], + warp_transformed_frag_B0[warp_mma_k % 2], + accum0 + ); + warp_mma1( + accum1, + warp_transformed_frag_A[warp_mma_k % 2], + warp_transformed_frag_B1[warp_mma_k % 2], + accum1 + ); + } + + // Issue global->shared copies for the this stage + if (warp_mma_k < Base::kWarpGemmIterations - 1) { + int group_start_iteration_A, group_start_iteration_B; + + group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA; + group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB; + + copy_tiles_and_advance(iterator_A, iterator_B0, iterator_B1, group_start_iteration_A, + group_start_iteration_B); + } + + if (warp_mma_k + 2 == Base::kWarpGemmIterations) { + int group_start_iteration_A, group_start_iteration_B; + group_start_iteration_A = + (warp_mma_k + 1) * Detail::kAccessesPerGroupA; + group_start_iteration_B = + (warp_mma_k + 1) * Detail::kAccessesPerGroupB; + + copy_tiles_and_advance(iterator_A, iterator_B0, iterator_B1, group_start_iteration_A, + group_start_iteration_B); + + // Inserts a memory fence between stages of cp.async instructions. + cutlass::arch::cp_async_fence(); + + // Waits until stages up to the previous (kStages-2)th stage have committed. 
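+          // What follows advances the software pipeline by one stage: after the wait,
+          // the global and shared-memory iterators step to the next k tile, and when
+          // smem_write_stage_idx (or smem_read_stage_idx) reaches kStages - 1 the
+          // iterators are rewound with negative offsets so shared memory is reused as a
+          // circular buffer rather than running past its kStages allocation.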
+ arch::cp_async_wait(); + __syncthreads(); + + // Move to the next stage + iterator_A.add_tile_offset({0, 1}); + iterator_B0.add_tile_offset({1, 0}); + iterator_B1.add_tile_offset({1, 0}); + + this->smem_iterator_A_.add_tile_offset({0, 1}); + this->smem_iterator_B0_.add_tile_offset({1, 0}); + this->smem_iterator_B1_.add_tile_offset({1, 0}); + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory + if (smem_write_stage_idx == (Base::kStages - 1)) { + this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B0_.add_tile_offset({-Base::kStages, 0}); + this->smem_iterator_B1_.add_tile_offset({-Base::kStages, 0}); + smem_write_stage_idx = 0; + } else { + ++smem_write_stage_idx; + } + + if (smem_read_stage_idx == (Base::kStages - 1)) { + this->warp_tile_iterator_A_.add_tile_offset( + {0, -Base::kStages * Policy0::kPartitionsK * + Base::kWarpGemmIterations}); + this->warp_tile_iterator_B0_.add_tile_offset( + {-Base::kStages * Policy0::kPartitionsK * + Base::kWarpGemmIterations, + 0}); + this->warp_tile_iterator_B1_.add_tile_offset( + {-Base::kStages * Policy1::kPartitionsK * + Base::kWarpGemmIterations, + 0}); + smem_read_stage_idx = 0; + } else { + ++smem_read_stage_idx; + } + + --gemm_k_iterations; + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B0.clear_mask(gemm_k_iterations == 0); + iterator_B1.clear_mask(gemm_k_iterations == 0); + } + + // Do any conversions feeding the first stage at the end of the loop so + // we can start right away on mma instructions + if (warp_mma_k + 1 == Base::kWarpGemmIterations) { + warp_mma0.transform(warp_transformed_frag_A[(warp_mma_k + 1) % 2], + warp_transformed_frag_B0[(warp_mma_k + 1) % 2], + warp_loaded_frag_A[(warp_mma_k + 1) % 2], + warp_loaded_frag_B0[(warp_mma_k + 1) % 2]); + warp_mma1.transform(warp_transformed_frag_A[(warp_mma_k + 1) % 2], + warp_transformed_frag_B1[(warp_mma_k + 1) % 2], + warp_loaded_frag_A[(warp_mma_k + 1) % 2], + warp_loaded_frag_B1[(warp_mma_k + 1) % 2]); + } + } + + } + + if (platform::is_same::value + || platform::is_same::value) { + accum0 = plus_accum(accum0, tmp_accum0); + accum1 = plus_accum(accum1, tmp_accum1); + } + + // commit and drain all pending and predicated cp.async pnz from the GEMM mainloop + cutlass::arch::cp_async_fence(); + cutlass::arch::cp_async_wait<0>(); + __syncthreads(); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/51_hopper_gett/gett_kernel.cuh b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/51_hopper_gett/gett_kernel.cuh new file mode 100644 index 0000000000000000000000000000000000000000..609168c35480c88314229e15a4b8e405e35f5e7f --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/51_hopper_gett/gett_kernel.cuh @@ -0,0 +1,139 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cute/tensor.hpp" + +#include "cutlass/arch/arch.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "cutlass/epilogue/collective/collective_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" + +namespace example { + +// +// GETT entry point +// +template < + class ProblemShapeMNKL, + class ElementA, + class StrideA, + class ElementB, + class StrideB, + class ElementAccumulator, + class ElementC, + class StrideC, + class ElementD, + class StrideD, + class ElementEpilogue> +cutlass::Status +gett_kernel( + ProblemShapeMNKL problem_shape_mnkl, + ElementA const* ptr_A, StrideA stride_a_mkl, + ElementB const* ptr_B, StrideB stride_b_nkl, + ElementAccumulator _, + ElementC const* ptr_C, StrideC stride_c_mnl, + ElementD * ptr_D, StrideD stride_d_mnl, + ElementEpilogue alpha, ElementEpilogue beta, + cudaStream_t stream = 0) { + using namespace cute; + + // TileShape -- GETT configuration + // Specify the number of elements to take from each mode + // BLK_M = (M0,M1,...) BLK_N = (M0,M1,...) BLK_K = (K0,K1,...) 
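+  // (Illustrative reading: each CTA-tile mode is itself a tuple of per-mode extents, so a
+  // shape such as Shape<Shape<_32,_4>, Shape<_64,_2>, Shape<_8,_8>> would cover
+  // 32*4 = 128 rows, 64*2 = 128 columns, and 8*8 = 64 reduction elements per tile, drawn
+  // from the (m0,m1), (n0,n1), and (k0,k1) tensor modes respectively.)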
+ + // Take 128 from m0, 128 from n0, 64 from k0 + using TileShape = Shape, Shape<_128>, Shape<_64>>; + + /* Other examples: + * Take 32 elements from m0 and 4 elements from m1 + * Take 64 elements from n0 and 2 elements from n1 + * Take 8 elements from k0 and 8 elements from k1 + **/ + // using TileShape = Shape, Shape<_64,_2>, Shape<_8,_8>>; + + using EpilogueThreadOp = cutlass::epilogue::thread::LinearCombination< + ElementD, 1, ElementAccumulator, ElementEpilogue, cutlass::epilogue::thread::ScaleType::Default, + cutlass::FloatRoundStyle::round_to_nearest, ElementC>; + + // No changes are required to the default epilogue + using CollectiveEpilogue = cutlass::epilogue::collective::detail::Sm90TmaWarpSpecializedAdapter< + cutlass::epilogue::collective::DefaultEpilogue< + ElementC, + StrideC, + StrideD, + EpilogueThreadOp, + cutlass::gemm::EpilogueDefault>>; + + // CollectiveMma for GETTs can be built using the CollectiveBuilders + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + ElementA, StrideA, 128 / cutlass::sizeof_bits::value, + ElementB, StrideB, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + TileShape, Shape<_1,_2,_1>, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + // The GETT kernel is a composition of a collective mainloop and epilogue, just like any 3.x GEMM + using GettKernel = cutlass::gemm::kernel::GemmUniversal< + ProblemShapeMNKL, + CollectiveMainloop, + CollectiveEpilogue>; + + using GettOperator = cutlass::gemm::device::GemmUniversalAdapter; + + typename GettOperator::Arguments args { + cutlass::gemm::GemmUniversalMode::kBatched, + problem_shape_mnkl, + { ptr_A, stride_a_mkl, ptr_B, stride_b_nkl }, + { {alpha, beta}, ptr_C, stride_c_mnl, ptr_D, stride_d_mnl } + }; + +#if CUTLASS_DEBUG_TRACE_LEVEL > 0 + print("Problem shape:"); + print("\tM: "); print(cute::get<0>(problem_shape_mnkl)); print("\n"); + print("\tN: "); print(cute::get<1>(problem_shape_mnkl)); print("\n"); + print("\tK: "); print(cute::get<2>(problem_shape_mnkl)); print("\n"); + print("\tL: "); print(cute::get<3>(problem_shape_mnkl)); print("\n"); + print("TileSape:"); print(TileShape{}); print("\n"); +#endif + + GettOperator op; + return op(args, stream); +} + +} // namespace example diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/52_hopper_gather_scatter_fusion/gather_gemm.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/52_hopper_gather_scatter_fusion/gather_gemm.hpp new file mode 100644 index 0000000000000000000000000000000000000000..959e1e6392cf68f4e2b2da8e821bfa809e27b93c --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/52_hopper_gather_scatter_fusion/gather_gemm.hpp @@ -0,0 +1,421 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/kernel_hardware_info.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/dispatch_policy.hpp" + +#include "cute/tensor.hpp" + +#include "gather_tensor.hpp" + +namespace cutlass { + ///Forward declaration + struct CudaHostAdapter; +} + +namespace cutlass::gemm::kernel { + +/////////////////////////////////////////////////////////////////////////////// + +template < + class ProblemShape_, + class CollectiveMainloop_, + class CollectiveEpilogue_, + class TileScheduler_, + class GatherA_, + class GatherB_ +> +class GemmGather +{ +public: + // + // Type Aliases + // + using ProblemShape = ProblemShape_; + static_assert(cute::rank(ProblemShape{}) == 3 or cute::rank(ProblemShape{}) == 4, + "ProblemShape{} should be or "); + + // Mainloop derived types + using CollectiveMainloop = CollectiveMainloop_; + using TileShape = typename CollectiveMainloop::TileShape; + using TiledMma = typename CollectiveMainloop::TiledMma; + using ArchTag = typename CollectiveMainloop::ArchTag; + using ElementA = typename CollectiveMainloop::ElementA; + using StrideA = typename CollectiveMainloop::StrideA; + using ElementB = typename CollectiveMainloop::ElementB; + using StrideB = typename CollectiveMainloop::StrideB; + using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; + using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; + using ClusterShape = typename DispatchPolicy::ClusterShape; + using MainloopArguments = typename CollectiveMainloop::Arguments; + using MainloopParams = typename CollectiveMainloop::Params; + static_assert(ArchTag::kMinComputeCapability >= 90); + + // Epilogue derived types + using CollectiveEpilogue = CollectiveEpilogue_; + using ElementC = typename CollectiveEpilogue::ElementC; + using StrideC = typename CollectiveEpilogue::StrideC; + using ElementD = typename CollectiveEpilogue::ElementD; + using StrideD = typename CollectiveEpilogue::StrideD; + using EpilogueArguments = typename CollectiveEpilogue::Arguments; + using EpilogueParams = typename CollectiveEpilogue::Params; + + static_assert(cute::is_void_v or cute::is_same_v, + "Non-persistent warp-specialized kernel does not support 
specializing the tile scheduler."); + using TileSchedulerTag = TileScheduler_; + using TileScheduler = typename detail::TileSchedulerSelector< + TileScheduler_, ArchTag, TileShape, ClusterShape>::Scheduler; + using TileSchedulerArguments = typename TileScheduler::Arguments; + + using GatherA = GatherA_; + using GatherB = GatherB_; + + // Kernel level shared memory storage + struct SharedStorage { + union TensorStorage { + using MainloopTensorStorage = typename CollectiveMainloop::TensorStorage; + using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage; + + MainloopTensorStorage mainloop; + EpilogueTensorStorage epilogue; + } tensors; + + struct PipelineStorage : cute::aligned_struct<16, _2> { + using MainloopPipelineStorage = typename CollectiveMainloop::PipelineStorage; + using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage; + + alignas(16) MainloopPipelineStorage mainloop; + alignas(16) EpiLoadPipelineStorage epi_load; + } pipelines; + }; + + static constexpr int SharedStorageSize = sizeof(SharedStorage); + + using GmemTiledCopyA = typename CollectiveMainloop::GmemTiledCopyA; + using GmemTiledCopyB = typename CollectiveMainloop::GmemTiledCopyB; + static_assert(cute::size(GmemTiledCopyA{}) == cute::size(GmemTiledCopyB{}), "Number of threads in A/B tiled copies must be the same."); + + static constexpr uint32_t NumLoadWarpGroups = cute::size(GmemTiledCopyA{}) / NumThreadsPerWarpGroup; + static constexpr uint32_t NumMmaWarpGroups = CUTE_STATIC_V(cute::size(TiledMma{})) / NumThreadsPerWarpGroup; + static constexpr uint32_t NumWarpGroups = NumLoadWarpGroups + NumMmaWarpGroups; + static_assert(NumWarpGroups == 2 || NumWarpGroups == 3, "Number of warp groups must be 2 or 3 for good performance."); + + static constexpr uint32_t MaxThreadsPerBlock = NumWarpGroups * NumThreadsPerWarpGroup; + static constexpr uint32_t MinBlocksPerMultiprocessor = 1; + + // Device side arguments + struct Arguments { + GemmUniversalMode mode{}; + ProblemShape problem_shape{}; + MainloopArguments mainloop{}; + EpilogueArguments epilogue{}; + KernelHardwareInfo hw_info{}; + TileSchedulerArguments scheduler{}; + GatherA gather_A{}; + GatherB gather_B{}; + }; + + // Kernel entry point API + struct Params { + GemmUniversalMode mode{}; + ProblemShape problem_shape{}; + MainloopParams mainloop{}; + EpilogueParams epilogue{}; + GatherA gather_A{}; + GatherB gather_B{}; + }; + + // + // Methods + // + + // Convert to underlying arguments. In this case, a simple copy for the aliased type. 
+ static + Params + to_underlying_arguments(Arguments const& args, void* workspace) { + (void) workspace; + auto problem_shape = args.problem_shape; + if constexpr (detail::Has_SwapAB_v) { + // swap M/N + get<0>(problem_shape) = get<1>(args.problem_shape); + get<1>(problem_shape) = get<0>(args.problem_shape); + } + return { + args.mode, + problem_shape, + CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, workspace), + CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, workspace), + args.gather_A, + args.gather_B + }; + } + + static bool + can_implement(Arguments const& args) { + bool implementable = (args.mode == GemmUniversalMode::kGemm) or + (args.mode == GemmUniversalMode::kBatched && cute::rank(ProblemShape{}) == 4); + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n"); + return implementable; + } + implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop); + implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue); + return implementable; + } + + static + size_t + get_workspace_size(Arguments const& args) { + return 0; + } + + static + cutlass::Status + initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr, + CudaHostAdapter* cuda_adapter = nullptr) { + return Status::kSuccess; + } + + // Computes the kernel launch grid shape based on runtime parameters + static dim3 + get_grid_shape(Params const& params) { + auto cluster_shape = Shape<_1,_1,_1>{}; + auto tile_shape = TileShape{}; + auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); + return TileScheduler::get_tiled_cta_shape_mnl( + problem_shape_MNKL, tile_shape, cluster_shape); + } + + static dim3 + get_block_shape() { + return dim3(MaxThreadsPerBlock, 1, 1); + } + + CUTLASS_DEVICE + void + operator()(Params const& params, char* smem_buf) { + using namespace cute; + using X = Underscore; + + // Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a. + #if ! defined(__CUDA_ARCH_FEAT_SM90_ALL) + if constexpr(size<0>(typename TiledMma::AtomShape_MNK{}) == 64) { + printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n"); + return; + } + #endif + + enum class WarpGroupRole { + Producer = 0, + Consumer = 1, + }; + + // Kernel level shared memory storage + SharedStorage& shared_storage = *reinterpret_cast(smem_buf); + + int thread_idx = int(threadIdx.x); + int warp_group_thread_idx = thread_idx % NumThreadsPerWarpGroup; + int warp_group_idx = canonical_warp_group_idx(); + CUTLASS_ASSERT(warp_group_idx < NumWarpGroups); + WarpGroupRole warp_group_role = warp_group_idx < NumLoadWarpGroups ? 
WarpGroupRole::Producer : WarpGroupRole::Consumer; + + // Mainloop Load pipeline + using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline; + typename MainloopPipeline::Params mainloop_pipeline_params; + if (warp_group_role == WarpGroupRole::Producer) { + mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Producer; + } + if (warp_group_role == WarpGroupRole::Consumer) { + mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Consumer; + } + mainloop_pipeline_params.producer_arv_count = NumLoadWarpGroups * NumThreadsPerWarpGroup; + mainloop_pipeline_params.consumer_arv_count = NumMmaWarpGroups * NumThreadsPerWarpGroup; + MainloopPipeline mainloop_pipeline(shared_storage.pipelines.mainloop, mainloop_pipeline_params); + + // Epilogue Load pipeline + using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline; + typename EpiLoadPipeline::Params epi_load_pipeline_params; + if (warp_group_role == WarpGroupRole::Producer) { + epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Producer; + } + if (warp_group_role == WarpGroupRole::Consumer) { + epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Consumer; + } + epi_load_pipeline_params.producer_arv_count = NumLoadWarpGroups * NumThreadsPerWarpGroup; + epi_load_pipeline_params.consumer_arv_count = NumMmaWarpGroups * NumThreadsPerWarpGroup; + EpiLoadPipeline epi_load_pipeline(shared_storage.pipelines.epi_load, epi_load_pipeline_params); + + // Epilogue Store pipeline + using EpiStorePipeline = typename CollectiveEpilogue::StorePipeline; + typename EpiStorePipeline::Params epi_store_pipeline_params; + epi_store_pipeline_params.always_wait = true; + EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params); + + // Initialize starting pipeline states for the collectives + typename CollectiveMainloop::PipelineState mainloop_pipe_consumer_state; + typename CollectiveEpilogue::LoadPipelineState epi_load_pipe_consumer_state; + + // For the DMA Load (producer) we start with an opposite phase + // i.e., we skip all waits since we know that the buffer is indeed empty + PipelineState mainloop_pipe_producer_state = cutlass::make_producer_start_state(); + PipelineState epi_load_pipe_producer_state = cutlass::make_producer_start_state(); + PipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state(); + + // Preconditions + static_assert(cute::rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. 
If batch mode is not needed, set L stride to Int<0>."); + + // Separate out problem shape for convenience + // Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); + auto M = get<0>(problem_shape_MNKL); + auto N = get<1>(problem_shape_MNKL); + auto K = get<2>(problem_shape_MNKL); + auto L = get<3>(problem_shape_MNKL); + + // Represent the full tensors + Tensor mA_mkl = make_gather_tensor(make_gmem_ptr(params.mainloop.ptr_A), make_shape(M,K,L), params.mainloop.dA, params.gather_A); //(m,k,l) + Tensor mB_nkl = make_gather_tensor(make_gmem_ptr(params.mainloop.ptr_B), make_shape(N,K,L), params.mainloop.dB, params.gather_B); //(n,k,l) + + // Get the appropriate blocks for this thread block -- potential for thread block locality + auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K) + TiledMma tiled_mma; + + // Make tiled views, defer the slice + Tensor gA_mkl = local_tile(mA_mkl, blk_shape, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l) + Tensor gB_nkl = local_tile(mB_nkl, blk_shape, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) + + // Compute m_coord, n_coord, and l_coord with their post-tiled shapes + auto m_coord = idx2crd(int(blockIdx.x), shape<2>(gA_mkl)); + auto n_coord = idx2crd(int(blockIdx.y), shape<2>(gB_nkl)); + auto l_coord = idx2crd(int(blockIdx.z), shape<4>(gB_nkl)); + auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); + + // Slice with m_coord and n_coord + Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) + Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k) + + // Get pipeline iterators and increments from tensor shapes + auto k_tile_iter = cute::make_coord_iterator(shape<2>(gA)); + auto k_tile_count = size<2>(gA); + auto c_tile_count = CollectiveEpilogue::get_load_pipe_increment(blk_shape); + auto d_tile_count = CollectiveEpilogue::get_store_pipe_increment(blk_shape); + + // Wait for all threads in the thread block + __syncthreads(); + + // In a warp specialized kernel, collectives expose data movement and compute operations separately + CollectiveMainloop collective_mainloop; + CollectiveEpilogue collective_epilogue{params.epilogue, shared_storage.tensors.epilogue}; + + if (warp_group_role == WarpGroupRole::Producer) { + // Compute tile residues for predication + auto m_max_coord = M - size<0>(gA) * get<0>(blk_coord); // M - BLK_M * m_coord + auto n_max_coord = N - size<0>(gB) * get<1>(blk_coord); // N - BLK_N * n_coord + auto k_residue = K - size<1>(gA) * size<2>(gA); // K - BLK_K * k_coord_max + auto residue_mnk = make_tuple(m_max_coord, n_max_coord, k_residue); + + collective_mainloop.load( + mainloop_pipeline, + mainloop_pipe_producer_state, + gA, + gB, + k_tile_iter, k_tile_count, + residue_mnk, + thread_idx, + shared_storage.tensors.mainloop + ); + // Update starting mainloop pipeline state for the pipeline drain + mainloop_pipe_producer_state.advance(k_tile_count); + // Make sure mainloop consumer has been waited upon before issuing epilogue load + collective_mainloop.load_tail(mainloop_pipeline, mainloop_pipe_producer_state); + + if (collective_epilogue.is_producer_load_needed()) { + epi_load_pipe_producer_state = + collective_epilogue.load( + epi_load_pipeline, + epi_load_pipe_producer_state, + problem_shape_MNKL, + blk_shape, + blk_coord, + tiled_mma, + thread_idx, + shared_storage.tensors.epilogue + ); + collective_epilogue.load_tail(epi_load_pipeline, epi_load_pipe_producer_state); + } + } + else if 
(warp_group_role == WarpGroupRole::Consumer) { + Tensor accumulators = partition_fragment_C(tiled_mma, take<0,2>(blk_shape)); // (MMA,MMA_M,MMA_N) + + collective_mainloop.mma( + mainloop_pipeline, + mainloop_pipe_consumer_state, + accumulators, + k_tile_count, + warp_group_thread_idx, + shared_storage.tensors.mainloop, + params.mainloop + ); + + // Make sure the math instructions are done and free buffers before entering the epilogue + collective_mainloop.mma_tail( + mainloop_pipeline, + mainloop_pipe_consumer_state, + k_tile_count + ); + + // Epilogue and write to gD + collective_epilogue.store( + epi_load_pipeline, + epi_load_pipe_consumer_state, + epi_store_pipeline, + epi_store_pipe_producer_state, + problem_shape_MNKL, + blk_shape, + blk_coord, + accumulators, + tiled_mma, + warp_group_thread_idx, + shared_storage.tensors.epilogue + ); + } + } +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::kernel diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/52_hopper_gather_scatter_fusion/gather_kernel.cuh b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/52_hopper_gather_scatter_fusion/gather_kernel.cuh new file mode 100644 index 0000000000000000000000000000000000000000..b4cafccbb555a14256fecff25e55f0f514e86173 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/52_hopper_gather_scatter_fusion/gather_kernel.cuh @@ -0,0 +1,136 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
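
The GemmGather kernel above builds its A and B views with make_gather_tensor, passing a user-supplied functor (GatherA / GatherB) that remaps one coordinate through an index table; the Func parameter of the standalone gather/scatter kernels that follow plays the same role. A minimal sketch of such a functor, assuming a device-resident int32 index table (the name IndexedGather and its members are illustrative only, not taken from this diff):

#include <cstdint>

// Illustrative row-remapping functor: logical row i is read from physical row indices[i].
struct IndexedGather {
  std::int32_t const* indices = nullptr;   // assumed device pointer to the gather index table

  __host__ __device__
  std::int32_t operator()(std::int32_t i) const {
    return indices[i];
  }
};
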
+ * + **************************************************************************************************/ +#pragma once + +#include "cute/numeric/math.hpp" + +namespace example +{ + +// Naive grid-stride loop implementation of gather +template +__global__ void +gather_kernel(Element const * __restrict__ input, + Element * __restrict__ output, + Func func, + int num_elems_input, + int num_elems_output, + cutlass::FastDivmod stride_divmod) +{ + Element const * input_b = input + blockIdx.z * num_elems_input; + Element * output_b = output + blockIdx.z * num_elems_output; + int tidx = threadIdx.x + blockIdx.x * blockDim.x; + for (int k = tidx; k < num_elems_output; k += blockDim.x * gridDim.x) { + int i,j; + stride_divmod(j, i, k); + output_b[k] = input_b[i + func(j) * stride_divmod.divisor]; + } +} + +// Gather elements along strided dimension of the tensor according to given indices +template +void +gather(Element const * input, + Element * output, + Func func, + int batch_size, + int num_elems_input, + int num_elems_output, + int stride, + cutlass::KernelHardwareInfo const& hw_info) +{ + // Upcast to uint128_t data type + int factor = 128 / cutlass::sizeof_bits::value; + assert(stride % factor == 0); + int stride_upcast = stride/factor; + int num_elems_input_upcast = num_elems_input / factor; + int num_elems_output_upcast = num_elems_output / factor; + + cutlass::FastDivmod stride_divmod(stride_upcast); + dim3 blocks(hw_info.sm_count, 1, batch_size); + gather_kernel<<>>(reinterpret_cast(input), + reinterpret_cast(output), + func, + num_elems_input_upcast, + num_elems_output_upcast, + stride_divmod); +} + +// Naive grid-stride loop implementation of scatter +template +__global__ void +scatter_kernel(Element const * __restrict__ input, + Element * __restrict__ output, + Func func, + int num_elems_input, + int num_elems_output, + cutlass::FastDivmod stride_divmod) +{ + Element const * input_b = input + blockIdx.z * num_elems_input; + Element * output_b = output + blockIdx.z * num_elems_output; + int tidx = threadIdx.x + blockIdx.x * blockDim.x; + for (int k = tidx; k < num_elems_input; k += blockDim.x * gridDim.x) { + int i,j; + stride_divmod(j, i, k); + output_b[i + func(j) * stride_divmod.divisor] = input_b[k]; + } +} + +// Gather elements along strided dimension of the tensor according to given indices +template +void +scatter(Element const * input, + Element * output, + Func func, + int batch_size, + int num_elems_input, + int num_elems_output, + int stride, + cutlass::KernelHardwareInfo const& hw_info) +{ + // Upcast to uint128_t data type + int factor = 128 / cutlass::sizeof_bits::value; + assert(stride % factor == 0); + int stride_upcast = stride/factor; + int num_elems_input_upcast = num_elems_input / factor; + int num_elems_output_upcast = num_elems_output / factor; + + cutlass::FastDivmod stride_divmod(stride_upcast); + dim3 blocks(hw_info.sm_count, 1, batch_size); + scatter_kernel<<>>(reinterpret_cast(input), + reinterpret_cast(output), + func, + num_elems_input_upcast, + num_elems_output_upcast, + stride_divmod); +} + +} // namespace example diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/52_hopper_gather_scatter_fusion/scatter_epilogue.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/52_hopper_gather_scatter_fusion/scatter_epilogue.hpp new file mode 100644 index 0000000000000000000000000000000000000000..4292e018c16f58374181dfaf70d04e7bbd10fd44 --- /dev/null +++ 
b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/52_hopper_gather_scatter_fusion/scatter_epilogue.hpp @@ -0,0 +1,222 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Functor performing elementwise operations used by epilogues. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/detail.hpp" + +#include "cute/tensor.hpp" +#include "cute/numeric/numeric_types.hpp" + +#include "gather_tensor.hpp" + +namespace cutlass::epilogue::collective { + +/// Applies an element wise operation to all elements within the fragment +/// and scatter-writes them out to destination storage. +/// GatherC and ScatterD are types of user-defined functions that apply the +/// transoformation of the strided coordinate (e.g. through an index array). +template < + class StrideC_, + class StrideD_, + class ThreadEpilogueOp_, + class EpilogueSchedule_, + class GatherC_, + class ScatterD_ +> +class EpilogueGatherScatter { +public: + // + // Type Aliases + // + using EpilogueSchedule = EpilogueSchedule_; + + // derived types of output thread level operator + using ThreadEpilogueOp = ThreadEpilogueOp_; + using ElementOutput = typename ThreadEpilogueOp::ElementOutput; + using ElementAccumulator = typename ThreadEpilogueOp::ElementAccumulator; + using ElementCompute = typename ThreadEpilogueOp::ElementCompute; + using ElementScalar = ElementCompute; + using ElementC = typename ThreadEpilogueOp::ElementC; + using StrideC = StrideC_; + using ElementD = typename ThreadEpilogueOp::ElementD; + using StrideD = StrideD_; + + // Every epilogue needs these two GmemTiledCopy{C,D} aliases. 
+ // If you don't know what they should be, just use void. + using GmemTiledCopyC = void; + using GmemTiledCopyD = void; + + using GatherC = GatherC_; + using ScatterD = ScatterD_; + + static const int kOutputAlignment = ThreadEpilogueOp::kCount; + using AlignmentType = typename cute::uint_bit::value * kOutputAlignment>::type; + + static_assert(cute::rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + static_assert(cute::rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + + struct SharedStorage { }; + + // Host side epilogue arguments + struct Arguments { + typename ThreadEpilogueOp::Params thread_params{}; + ElementC const* ptr_C = nullptr; + StrideC dC{}; + ElementD* ptr_D = nullptr; + StrideD dD{}; + GatherC gather_C{}; + ScatterD scatter_D{}; + }; + + // Device side epilogue params + using Params = Arguments; + + // + // Methods + // + + template + static constexpr Params + to_underlying_arguments( + [[maybe_unused]] ProblemShape const& _, + Arguments const& args, + [[maybe_unused]] void* workspace) { + return args; + } + + template + static bool + can_implement( + [[maybe_unused]] ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + return true; + } + + CUTLASS_HOST_DEVICE + EpilogueGatherScatter(Params const& params_) : params(params_) { } + + template< + class ProblemShapeMNKL, + class BlockShapeMNK, + class BlockCoordMNKL, + class FrgEngine, class FrgLayout, + class TiledMma, + class ResidueMNK + > + CUTLASS_DEVICE void + operator()( + ProblemShapeMNKL problem_shape_mnkl, + BlockShapeMNK blk_shape_MNK, + BlockCoordMNKL blk_coord_mnkl, + cute::Tensor const& accumulators, + TiledMma tiled_mma, + ResidueMNK residue_mnk, + int thread_idx, + char* smem_buf) + { + using namespace cute; + using X = Underscore; + + static_assert(cute::rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); + static_assert(is_static::value, "ThreadBlock tile shape must be static"); + static_assert(cute::rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3"); + static_assert(cute::rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 3"); + + (void) smem_buf; + ThreadEpilogueOp epilogue_op{params.thread_params}; + + // Separate out problem shape for convenience + auto M = get<0>(problem_shape_mnkl); + auto N = get<1>(problem_shape_mnkl); + auto L = get<3>(problem_shape_mnkl); + + auto stride_c = detail::get_epilogue_stride(params.dC); + auto stride_d = detail::get_epilogue_stride(params.dD); + + // Represent the full output tensor + Tensor mC_mnl = make_gather_tensor(make_gmem_ptr(params.ptr_C), make_shape(M,N,L), stride_c, params.gather_C); // (m,n,l) + Tensor mD_mnl = make_gather_tensor(make_gmem_ptr(params.ptr_D), make_shape(M,N,L), stride_d, params.scatter_D); // (m,n,l) + + Tensor gC_mnl = local_tile(mC_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) + Tensor gD_mnl = local_tile(mD_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) + + // Slice to get the tile this CTA is responsible for + auto [m_coord, n_coord, k_coord, l_coord] = blk_coord_mnkl; + Tensor gC = gC_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) + Tensor gD = gD_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) + + // Partition source and destination tiles to match the accumulator partitioning + auto thr_mma = tiled_mma.get_thread_slice(thread_idx); + Tensor tCgD = thr_mma.partition_C(gD); // (VEC,THR_M,THR_N) + Tensor tCgC = thr_mma.partition_C(gC); // (VEC,THR_M,THR_N) + + 
static_assert(is_static::value, "Accumulator layout must be static"); + CUTE_STATIC_ASSERT_V(size(tCgC) == size(tCgD), + "Source and destination must have the same number of elements."); + CUTE_STATIC_ASSERT_V(size(tCgD) == size(accumulators), + "Accumulator count must have the same destination element count."); + + // Make an identity coordinate tensor for predicating our output MN tile + auto cD = make_identity_tensor(make_shape(unwrap(shape<0>(gD)), unwrap(shape<1>(gD)))); + Tensor tCcD = thr_mma.partition_C(cD); + + // source is needed + if (epilogue_op.is_source_needed()) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(accumulators); ++i) { + if (elem_less(tCcD(i), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) { + tCgD(i) = epilogue_op(accumulators(i), tCgC(i)); + } + } + } + // source is not needed, avoid load + else { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(accumulators); ++i) { + if (elem_less(tCcD(i), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) { + tCgD(i) = epilogue_op(accumulators(i)); + } + } + } + } + +private: + Params params; +}; + +} // namespace cutlass::epilogue::collective + diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/53_hopper_gemm_permute/permute_kernel.cuh b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/53_hopper_gemm_permute/permute_kernel.cuh new file mode 100644 index 0000000000000000000000000000000000000000..0cb1aad901891867c87c62c999be4e13de94d598 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/53_hopper_gemm_permute/permute_kernel.cuh @@ -0,0 +1,92 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! 
\file + \brief Simple permutation kernel implementation. +*/ + +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/tensor_view.h" +#include "cutlass/fast_math.h" +#include "cute/numeric/numeric_types.hpp" + +namespace example +{ + +/** + * Assumes column-major input (M mode is contiguous, N mode is strided). + * For row major, the inputs must be switched accordingly. +*/ +template +__global__ void +permute_kernel(Element const* __restrict__ input, + Element* __restrict__ output, + Permute permute, + int64_t num_elems, + cutlass::FastDivmod stride_divmod) +{ + // CUTLASS 2.x batched permute functions assume 0 batch stride for target tensor + Element const * input_b = input + blockIdx.z * num_elems; + Element * output_b = output + (Batched ? 0 : blockIdx.z * num_elems); + for (int64_t k = threadIdx.x + blockIdx.x * blockDim.x; k < num_elems; k += blockDim.x * gridDim.x) + { + int i, j; + stride_divmod(j, i, k); + output_b[permute(cutlass::PitchLinearCoord(i, j))] = input_b[i + j * stride_divmod.divisor]; + } +} + +template +void permute(Element const* input, + Element * output, + int64_t num_elems, + int stride, + int batch_count, + cutlass::KernelHardwareInfo const& hw_info) +{ + // Upcast to uint128_t data type + int factor = 128 / cutlass::sizeof_bits::value; + assert(stride % factor == 0); + int stride_upcast = stride/factor; + int64_t num_elems_upcast = num_elems / factor; + Permute permute_upcast(cutlass::PitchLinearCoord(stride_upcast, int(num_elems_upcast/stride_upcast)), stride_upcast); + + cutlass::FastDivmod stride_divmod(stride); + dim3 blocks(hw_info.sm_count, 1, batch_count); + permute_kernel<<>>(reinterpret_cast(input), + reinterpret_cast(output), + permute_upcast, + num_elems_upcast, + stride_upcast); +} + +} // namespace example diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/53_hopper_gemm_permute/permute_traits.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/53_hopper_gemm_permute/permute_traits.hpp new file mode 100644 index 0000000000000000000000000000000000000000..e7a45d7281b2b9d88414a7e37c92f20f955ebb40 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/53_hopper_gemm_permute/permute_traits.hpp @@ -0,0 +1,274 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
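
The permutation kernel above uses cutlass::FastDivmod to split each flat element index k into a contiguous coordinate i and a strided coordinate j before applying the permutation functor. The same decomposition written with plain integer arithmetic, assuming column-major storage with leading stride `stride` (a sketch for clarity only; FastDivmod replaces the division with a precomputed multiply/shift):

// For a column-major tile with leading stride `stride`:
//   k -> (i, j) with i = k % stride (contiguous mode) and j = k / stride (strided mode),
//   so the source element lives at offset i + j * stride.
inline void split_flat_index(long long k, int stride, int& i, int& j) {
  j = static_cast<int>(k / stride);
  i = static_cast<int>(k % stride);
}
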
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Additional permutation information for the example. +*/ + +#include "cutlass/layout/permute.h" +#include "cutlass/gemm/gemm.h" + +namespace example +{ + +using namespace cute; + +// This struct is specialized below for different CUTLASS 2.x permutation ops +// to describe the operation in terms of target CuTe shape and stride order. +template +struct PermuteTraits {}; + +// Use X as a placeholder for shape division result +using X = Underscore; + +// Reshape a rank-2 shape into a multidimensional shape. +// Input: +// shape = (A, B, ...) +// target_shape = ((A1, ..., X, ..., Am), (B1, ..., X, ..., Bn), ...) +// Output: +// ((A1, ..., A/prod(A1..Am), ..., Am), (B1, ..., B/prod(B1..Bn), ..., Bn), ...) +template +constexpr auto +reshape(Shape const& shape, TargetShape const& target_shape) +{ + if constexpr (is_tuple::value) { + return cute::transform(shape, target_shape, [](auto && s, auto && t){ return reshape(s, t); }); + } + else { + auto idx = find_if(target_shape, [](auto x){ return is_underscore{}; }); + constexpr int I = decltype(idx)::value; + static_assert(I < tuple_size_v, "Each mode of TargetShape must contain a placeholder X"); + auto divisors = remove(target_shape); + assert(shape % product(divisors) == 0); + return replace(target_shape, shape / product(divisors)); + } +} + +// Given a tensor layout, compute a permutation layout consisting of: +// - sub-modes corresponding to the implied multidimensional shape of the source tensor +// - strides accounting for the permutation operation being performed +template +constexpr auto +make_permute_layout(Layout const& layout) { + static_assert(cute::rank(Shape{}) == 3, "Only rank-3 layouts are supported"); + if constexpr (Transpose) { + // Deal with tensor B by transposing appropriately before and after computing the permute layout. + // Its CuTe-canonical mode order is [N,K,L], while permute operations expect [row,col,batch]. + return select<1,0,2>(make_permute_layout(select<1,0,2>(layout))); + } + else { + if constexpr (cutlass::layout::is_trivial_permute) { + // Special case for NoPermute. Use a depth-2 layout for consistency with other permutations. 
+ using ShapeProfile = tuple, tuple, tuple>; + return unflatten(layout, ShapeProfile{}); + } + else { + // Here's where the permutation layout is actually built + using ShapeProfile = typename PermuteTraits::ShapeProfile; + using StrideOrder = typename PermuteTraits::StrideOrder; + return make_ordered_layout(reshape(layout.shape(), ShapeProfile{}), StrideOrder{}); + } + } +} + +namespace detail +{ + +template +struct is_constant_pred { + template + constexpr auto operator()(T) { + return is_constant{}; + } +}; + +template +constexpr auto +inverse_impl(Permutation const & perm, seq) { + return cute::make_tuple(Int{})>{}...); +} + +} // namespace detail + +// Compute an inverse of a permutation represented as a tuple of cute::Int<> +template +constexpr auto +inverse(Permutation const & perm) { + auto flat_perm = flatten(perm); + return unflatten(detail::inverse_impl(flat_perm, tuple_seq{}), perm); +} + +template +using inverse_t = decltype(inverse(T{})); + +// Given a rank-2 layout of tensor that is assumed to have been permuted, +// compute the original rank-2 layout of the tensor prior to the permutation. +// This is needed to form the correct input to the standalone permutation kernel. +template +constexpr auto +make_original_layout(Layout const& layout) { + static_assert(cute::rank(Shape{}) == 3, "Only rank-3 layouts are supported"); + if constexpr (Transpose) { + // Deal with tensor B by transposing appropriately before and after computing the permute layout. + // Its CuTe-canonical mode order is [N,K,L], while permute operations expect [row,col,batch]. + return select<1,0,2>(make_original_layout(select<1,0,2>(layout))); + } + else { + using ShapeProfile = typename PermuteTraits::ShapeProfile; + auto re_shape = flatten(reshape(layout.shape(), ShapeProfile{})); + using IndexOrder = typename PermuteTraits::IndexOrder; + auto orig_shape = transform_leaf(IndexOrder{}, [&](auto i){ return get(re_shape); }); + using OrigOrder = conditional_t(), seq<0,1,2>, seq<1,0,2>>; + // print("Permuted shape: "); print(reshape(layout.shape(), ShapeProfile{})); print("\n"); + // print("Original shape: "); print(orig_shape); print("\n"); + return make_ordered_layout(product_each(orig_shape), OrigOrder{}); + } +} + +/////////////// Tensor4DPermute0213 //////////////////// + +template +struct PermuteTraits> +{ + static constexpr bool kBatched = false; + using ShapeProfile = Shape>, Shape,X>, Shape>; + using IndexOrder = Step, Step<_1,_3>, Step<_4>>; + using StrideOrder = inverse_t; // Step, Step<_1,_3>, Step<_4>>; +}; + +template +struct PermuteTraits> +{ + static constexpr bool kBatched = false; + using ShapeProfile = Shape>, Shape,X>, Shape>; + using IndexOrder = Step, Step<_1,_3>, Step<_4>>; + using StrideOrder = inverse_t; // Step, Step<_1,_3>, Step<_4>>; +}; + +template +struct PermuteTraits> +{ + static constexpr bool kBatched = false; + using ShapeProfile = Shape,X>, Shape>, Shape>; + using IndexOrder = Step, Step<_0,_2>, Step<_4>>; + using StrideOrder = Step, Step<_0,_2>, Step<_4>>; +}; + +template +struct PermuteTraits> +{ + static constexpr bool kBatched = false; + using ShapeProfile = Shape,X>, Shape>, Shape>; + using IndexOrder = Step, Step<_0,_2>, Step<_4>>; + using StrideOrder = Step, Step<_0,_2>, Step<_4>>; +}; + +/////////////// Tensor4DPermuteBMM0321 //////////////////// + +template +struct PermuteTraits> +{ + static constexpr bool kBatched = true; + using ShapeProfile = Shape, Shape, Shape,X>>; + using IndexOrder = Step, Step<_1>, Step<_3>>; + using StrideOrder = Step, Step<_2>, 
Step<_1,_3>>; +}; + +template +struct PermuteTraits> +{ + static constexpr bool kBatched = true; + using ShapeProfile = Shape>, Shape, Shape>; + using IndexOrder = Step, Step<_2>, Step<_1,_3>>; + using StrideOrder = Step, Step<_1>, Step<_3>>; +}; + +/////////////// Tensor4DPermuteBMM0213 //////////////////// + +template +struct PermuteTraits> +{ + static constexpr bool kBatched = true; + using ShapeProfile = Shape, Shape, Shape,X>>; + using IndexOrder = Step, Step<_1,_2>, Step<_3>>; + using StrideOrder = Step, Step<_0>, Step<_1,_3>>; +}; + +template +struct PermuteTraits> +{ + static constexpr bool kBatched = true; + using ShapeProfile = Shape, Shape>, Shape>; + using IndexOrder = Step, Step<_1>, Step<_2,_3>>; + using StrideOrder = Step, Step<_0,_2>, Step<_3>>; +}; + +/////////////// Tensor5DPermute02413 //////////////////// + +template +struct PermuteTraits> +{ + static constexpr bool kBatched = false; + using ShapeProfile = Shape>, Shape,Int,X>, Shape>; + using IndexOrder = Step, Step<_4,_1,_3>, Step<_5>>; + using StrideOrder = inverse_t; // Step, Step<_1,_4,_2>, Step<_5>>; +}; + +template +struct PermuteTraits> +{ + static constexpr bool kBatched = false; + using ShapeProfile = Shape>, Shape,Int>, Shape>; + using IndexOrder = Step, Step<_1,_4,_2>, Step<_5>>; + using StrideOrder = inverse_t; // Step, Step<_4,_1,_3>, Step<_5>>; +}; + +/////////////// Tensor5DPermute20314 //////////////////// + +template +struct PermuteTraits> +{ + static constexpr bool kBatched = false; + using ShapeProfile = Shape,X>, Shape,Int>, Shape>; + using IndexOrder = Step, Step<_3,_1,_4>, Step<_5>>; + using StrideOrder = Step, Step<_0,_2,_4>, Step<_5>>; +}; + +template +struct PermuteTraits> +{ + static constexpr bool kBatched = false; + using ShapeProfile = Shape>, Shape,Int>, Shape>; + using IndexOrder = Step, Step<_2,_4,_1>, Step<_5>>; + using StrideOrder = Step, Step<_0,_3,_1>, Step<_5>>; +}; + +} // namespace example diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/54_hopper_fp8_warp_specialized_gemm/hopper_fp8_commandline.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/54_hopper_fp8_warp_specialized_gemm/hopper_fp8_commandline.hpp new file mode 100644 index 0000000000000000000000000000000000000000..e8ea5330b3163a26e09be16e26f459b02cf5bf02 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/54_hopper_fp8_warp_specialized_gemm/hopper_fp8_commandline.hpp @@ -0,0 +1,129 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
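
permute_traits.hpp derives each StrideOrder from the corresponding IndexOrder through the compile-time inverse() helper defined above. The relationship it encodes is the ordinary inverse of a permutation; a small standalone illustration with plain std::array values (the concrete permutation (1,2,0) is an assumed example, not one of the traits above):

#include <array>

// If perm maps source position s to destination position perm[s],
// the inverse inv satisfies inv[perm[s]] == s.
constexpr std::array<int, 3> invert_perm(std::array<int, 3> perm) {
  std::array<int, 3> inv{};
  for (int s = 0; s < 3; ++s) {
    inv[perm[s]] = s;
  }
  return inv;
}

static_assert(invert_perm({1, 2, 0})[0] == 2, "inverse of (1,2,0) is (2,0,1)");
static_assert(invert_perm({1, 2, 0})[1] == 0, "inverse of (1,2,0) is (2,0,1)");
static_assert(invert_perm({1, 2, 0})[2] == 1, "inverse of (1,2,0) is (2,0,1)");
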
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +// Command line options parsing +template +struct Options { + + bool help = false; + + float alpha = 1.f, beta = 0.f; + float scale_a = 1.f, scale_b = 1.f, scale_c = 1.f, scale_d = 1.f, scale_aux = 1.f; + bool device_scale = false; + bool save_aux = true; + bool save_amax = true; + int iterations = 1000; + int m = 1024, n = 512, k = 1024, l = 1; + RasterOrderOptions raster; + int swizzle; + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("m", m); + cmd.get_cmd_line_argument("n", n); + cmd.get_cmd_line_argument("k", k); + cmd.get_cmd_line_argument("l", l); + cmd.get_cmd_line_argument("alpha", alpha, 1.f); + cmd.get_cmd_line_argument("beta", beta, 0.f); + cmd.get_cmd_line_argument("scale_a", scale_a, 1.f); + cmd.get_cmd_line_argument("scale_b", scale_b, 1.f); + cmd.get_cmd_line_argument("scale_c", scale_c, 1.f); + cmd.get_cmd_line_argument("scale_d", scale_d, 1.f); + cmd.get_cmd_line_argument("scale_aux", scale_aux, 1.f); + cmd.get_cmd_line_argument("device_scale", device_scale, false); + cmd.get_cmd_line_argument("save_aux", save_aux, true); + cmd.get_cmd_line_argument("save_amax", save_amax, true); + cmd.get_cmd_line_argument("iterations", iterations); + + char raster_char; + cmd.get_cmd_line_argument("raster", raster_char); + + if (raster_char == 'N' || raster_char == 'n') { + raster = RasterOrderOptions::AlongN; + } + else if (raster_char == 'M' || raster_char == 'm') { + raster = RasterOrderOptions::AlongM; + } + else if (raster_char == 'H' || raster_char == 'h') { + raster = RasterOrderOptions::Heuristic; + } + + cmd.get_cmd_line_argument("swizzle", swizzle, 1); + } + + /// Prints the usage statement. 
+ std::ostream & print_usage(std::ostream &out) const { + + out << "54_fp8_hopper_warp_specialized_gemm\n\n" + << " Hopper FP8 GEMM using a Warp Specialized kernel.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --l= Sets the l extent (batch) of the GEMM\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n" + << " --scale_a= Scaling factor for A\n" + << " --scale_b= Scaling factor for B\n" + << " --scale_c= Scaling factor for C\n" + << " --scale_d= Scaling factor for D (ignored for non-fp8 D)\n" + << " --scale_aux= Scaling factor for the auxiliary tensor (ignored for non-fp8 aux)\n" + << " --device_scale= Copy scalars to device memory before kernel launch (default: false)\n" + << " --save_aux= Save the pre-activation as an auxiliary tensor (default: true)\n" + << " --save_amax= Save the pre-scaled max absolute value of any fp8 outputs (aux and/or D) (default: true)\n" + << " --raster= CTA Rasterization direction (N for along N, M for along M, and H for heuristic)\n\n" + << " --swizzle= CTA Rasterization swizzle\n\n" + << " --iterations= Number of profiling iterations to perform.\n\n"; + + out + << "\n\nExamples:\n\n" + << "$ " << "54_fp8_hopper_warp_specialized_gemm" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n"; + + return out; + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s) const + { + // Two flops per multiply-add + uint64_t flop = uint64_t(2) * m * n * k; + double gflop = double(flop) / double(1.0e9); + return gflop / runtime_s; + } +}; diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/55_hopper_mixed_dtype_gemm/mixed_dtype_utils.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/55_hopper_mixed_dtype_gemm/mixed_dtype_utils.hpp new file mode 100644 index 0000000000000000000000000000000000000000..98c7440679d455c12306c8e07078aa670f47328b --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/55_hopper_mixed_dtype_gemm/mixed_dtype_utils.hpp @@ -0,0 +1,246 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
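
The gflops() helper above counts two flops per multiply-accumulate. For the default FP8 problem size (m=1024, n=512, k=1024) the arithmetic works out as below; the 0.05 ms runtime is an assumed figure used only to make the numbers concrete:

#include <cstdint>
#include <cstdio>

int main() {
  std::uint64_t flop = std::uint64_t(2) * 1024 * 512 * 1024;  // 2*M*N*K = 1,073,741,824 flops
  double gflop     = static_cast<double>(flop) / 1.0e9;       // ~1.074 GFLOP per GEMM
  double runtime_s = 0.05e-3;                                 // assumed 0.05 ms per iteration
  std::printf("%.0f GFLOP/s\n", gflop / runtime_s);           // ~21475 GFLOP/s (~21.5 TFLOP/s)
  return 0;
}
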
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/util/command_line.h" +#include "cutlass/util/reference/device/tensor_fill.h" +#include "cutlass/util/reference/device/tensor_compare.h" + +#include "cute/tensor.hpp" + +#include +#include +#include "helper.h" + +enum MixedDtypeGemmMode { + ConvertOnly, + ScaleOnly, + ScaleWithZeroPoint +}; + +/// Command line options parsing +struct MixedDtypeOptions { + + bool help = false; + + float alpha = 1.0f; + float beta = 0.0f; + int iterations = 100; + int warmup = 10; + int mode = 1; + int m = 5120, n = 4096, k = 4096; + int g = 128; + int l = 1; + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("m", m); + cmd.get_cmd_line_argument("n", n); + cmd.get_cmd_line_argument("k", k); + cmd.get_cmd_line_argument("l", l); + cmd.get_cmd_line_argument("g", g); + cmd.get_cmd_line_argument("mode", mode); + cmd.get_cmd_line_argument("alpha", alpha, 1.f); + cmd.get_cmd_line_argument("beta", beta, 0.f); + cmd.get_cmd_line_argument("iterations", iterations); + cmd.get_cmd_line_argument("warmup", warmup); + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "55_hopper_mixed_dtype_gemm\n\n" + << " Hopper Mixed Data Type GEMM using a Warp Specialized kernel.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --l= The number of independent gemm problems with mnk shape\n" + << " --g= The size of each group for the scales and zeros. To broadcast a vector of scales or zeros, set the group size to K.\n" + << " --mode= The mode to run the gemm. 
0 does (A @ B), 1 means A @ (scale * B), 2 means A @ (scale * B + zero-point).\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n\n" + << " --iterations= Number of profiling iterations to perform.\n\n" + << " --warmup= Number of warmup iterations to perform.\n\n"; + + out + << "\n\nExamples:\n\n" + << "$ " << "55_hopper_mixed_dtype_gemm" << " --m=1024 --n=512 --k=1024 -g=1024 --l=10 --alpha=2 --mode=2 --beta=0.707 \n\n"; + + return out; + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s) const + { + // Two flops per multiply-add + uint64_t flop = uint64_t(2) * m * n * k * l; + double gflop = double(flop) / double(1.0e9); + return gflop / runtime_s; + } +}; + +/// Result structure +struct MixedDtypeResult +{ + double avg_runtime_ms = 0.0; + double gflops = 0.0; + cutlass::Status status = cutlass::Status::kSuccess; + cudaError_t error = cudaSuccess; + bool passed = false; + +}; + +/// Profiling Loop +template +void mixed_dtype_profiling( + Gemm& gemm, + MixedDtypeOptions const& options, + MixedDtypeResult& result) { + + if (options.iterations <= 0) return; + + cudaEvent_t start, stop; + cudaEventCreate(&start); + cudaEventCreate(&stop); + + std::vector runtimes; + runtimes.reserve(options.iterations); + + for (int iter = 0; iter < options.warmup + options.iterations; ++iter) { + cudaEventRecord(start); + CUTLASS_CHECK(gemm.run()); + cudaEventRecord(stop); + cudaEventSynchronize(stop); + + if (iter >= options.warmup) { + float milliseconds = 0; + cudaEventElapsedTime(&milliseconds, start, stop); + runtimes.push_back(milliseconds); + } + } + + cudaEventDestroy(start); + cudaEventDestroy(stop); + + // Compute average setup and runtime and GFLOPs. + result.avg_runtime_ms = std::accumulate(runtimes.begin(), runtimes.end(), 0.0f) / runtimes.size(); + result.gflops = options.gflops(result.avg_runtime_ms / 1000.0); + + std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl; + std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl; + std::cout << " GFLOPS: " << result.gflops << std::endl; + +} + +/// Helpers to initialize a block of device data +template +bool initialize_tensor( + cutlass::DeviceAllocation& block, + uint64_t seed = 2023) { + + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } + else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } + else if (bits_output == 16) { + scope_max = 5; + scope_min = -5; + } + else { + scope_max = 8; + scope_min = -8; + } + cutlass::reference::device::BlockFillRandomUniform( + block.get(), block.size(), seed, Element(scope_max), Element(scope_min)); + + return true; +} + +template +bool initialize_scale( + cutlass::DeviceAllocation& block, + MixedDtypeOptions const& options, + uint64_t seed = 2023) { + + // If no scales, initialize with 1 so we can use the same kernel to dequantize the data + float scope_max = 1.0f, scope_min = 1.0f; + if (options.mode != MixedDtypeGemmMode::ConvertOnly) { + float elt_max_f = float(cutlass::platform::numeric_limits::max()); + scope_max = 2.f; + scope_min = 0.1f; + } + cutlass::reference::device::BlockFillRandomUniform( + block.get(), block.size(), seed, Element(scope_max), Element(scope_min)); + + return true; +} + +template +bool initialize_zero( + cutlass::DeviceAllocation& block, + MixedDtypeOptions const& options, + uint64_t seed = 
2023) { + + // If no bias, initialize with 0 so we can use the same kernel to dequantize the data + float scope_max = 0.0f, scope_min = 0.0f; + if (options.mode == MixedDtypeGemmMode::ScaleWithZeroPoint) { + scope_max = 2.0f; + scope_min = -2.0f; + } + cutlass::reference::device::BlockFillRandomUniform( + block.get(), block.size(), seed, Element(scope_max), Element(scope_min)); + + return true; +} diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/59_ampere_gather_scatter_conv/ampere_conv_kernel.h b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/59_ampere_gather_scatter_conv/ampere_conv_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..d3d0982d14cc84c9d3bcf0cace2973a0ca931302 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/59_ampere_gather_scatter_conv/ampere_conv_kernel.h @@ -0,0 +1,320 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
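
The mixed-dtype helpers above distinguish three MixedDtypeGemmMode values, and the usage text spells out how each reconstructs B before the multiply. A scalar sketch of that dequantization step, following the mode descriptions (the function and the use of float throughout are purely illustrative; the actual kernel converts packed fragments in registers):

// mode 0 (ConvertOnly):        B
// mode 1 (ScaleOnly):          scale * B
// mode 2 (ScaleWithZeroPoint): scale * B + zero_point
inline float dequantize_b(float b_quant, float scale, float zero_point, int mode) {
  if (mode == 0) { return b_quant; }
  if (mode == 1) { return scale * b_quant; }
  return scale * b_quant + zero_point;
}
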
+ * + **************************************************************************************************/ +#pragma once + +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/atom/copy_atom.hpp" +#include + +#include "cutlass/util/print_error.hpp" + +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/collective_mma.hpp" + +using namespace cute; + +struct AmpereUnpredicatedFprop { + // + // Static config for conv problem shape + // + using D = _6; + using H = _4; + using W = _4; + + using T = _3; + using R = _3; + using S = _3; + + using Z = _4; + using P = _2; + using Q = _2; + + using C = _64; + using K = _128; + + // Tiler config + using Tiler_K = decltype(cute::min(K{}, _128{})); + using Tiler_C = decltype(cute::min(C{}, _32{})); + using Tiler_N = _4; + using TileM = Tiler_K; + using TileN = Shape; + using TileK = Shape; + using PIPE = _3; + using TilerFlt = Shape; + using TilerAct = Shape; + using TilerOut = Shape; + + using TileSizeM = Int; + using TileSizeN = Int; + using TileSizeK = Int; + static constexpr int Stages = PIPE::value; + + using ElementFlt = tfloat32_t; + using ElementAct = tfloat32_t; + using ElementOut = float; + + using TiledMma = TiledMMA< + MMA_Atom, + Layout>, + Tile<_32,_32,Underscore>>; + + static constexpr int MaxThreadsPerBlock = size(TiledMma{}); + static constexpr int MinBlocksPerMultiprocessor = 1; + + union SharedStorage { + struct { + ElementFlt sAMatrix[size(TileM{}) * size(TileK{}) * size(PIPE{})]; + ElementAct sBMatrix[size(TileN{}) * size(TileK{}) * size(PIPE{})]; + } mainloop; + + struct { + ElementOut sCMatrix[size(TileM{}) * size(TileN{})]; + } epilogue; + }; + + // + // Stencil tensor + // + + using GmemLayoutFlt = decltype(make_ordered_layout( + Shape< K, Shape< C, T, R, S>>{}, + tuple<_4, tuple<_0,_3,_2,_1>>{})); + + // We have 64 elements * 32b each in the major mode that we can vectorize + // Max vector size is 128b, so lay 16 threads along the major mode with a vector size of 4 + // Rest along the minor mode + using GmemTiledCopyFlt = decltype(make_tiled_copy( + Copy_Atom, ElementFlt>{}, + Layout, + Stride< _8, _1>>{}, + Layout>{})); + + // Following layout is also correct, but trades off dynamic strides in the slice for bank conflict free accesses + // using SmemLayoutFlt = decltype( + // composition(Swizzle<3,2,3>{}, + // make_ordered_layout( + // Shape{}, + // tuple< _1, _0, _2>{}))); + + using SmemLayoutAtomFlt = decltype( + composition(Swizzle<1,2,3>{}, + Layout>, + Stride<_4,Stride<_1,_32>>>{})); + + using SmemCopyAtomFlt = Copy_Atom; + + // + // Activation tensor + // + + // Activation tensor is major in the contraction mode, so vectorize that mode first + // Then lay out the rest of the threads along the other mode + using GmemTiledCopyAct = decltype(make_tiled_copy( + Copy_Atom, ElementAct>{}, + Layout, + Stride< _8, _1>>{}, + Layout>{})); + + // Following layout is also correct, but trades off dynamic strides in the slice for bank conflict free accesses + // using SmemLayoutAct = decltype( + // composition(Swizzle<3,2,3>{}, + // make_ordered_layout( + // Shape{}, + // tuple< _1, _0, _2>{}))); + + using SmemLayoutAtomAct = decltype( + composition(Swizzle<1,2,3>{}, + Layout>, + Stride<_4,Stride<_1,_32>>>{})); + + using SmemCopyAtomAct = Copy_Atom; + + // + // Output tensor + // + + using GmemTiledCopyOut = decltype(make_tiled_copy( + Copy_Atom, ElementAct>{}, + Layout, + Stride<_1, _8>>{}, + Layout>{})); + + using SmemCopyAtomOut = Copy_Atom, ElementOut>; + + // This can be 
optimized to make accesses BCF, but we use a col-major layout here to show off composability + using SmemLayoutOut = Layout>; + + // + // Conv functor + // + template + void __device__ + operator()(cute::Tensor mFlt, // ( K, (C,T,R,S)) + TensorActivation mAct, // ((N,Z,P,Q), (C,T,R,S)) + TensorOutput mOut, // ( K, (N,Z,P,Q)) + char* smem_buf) const { + using namespace cute; + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveMma< + cutlass::gemm::MainloopSm80CpAsyncUnpredicated, + Shape, + ElementFlt, + Underscore, // Ignore the stride, we are passing full cute::Tensor to operator() + ElementAct, + Underscore, // Ignore the stride, we are passing full cute::Tensor to operator() + TiledMma, + GmemTiledCopyFlt, + SmemLayoutAtomFlt, + SmemCopyAtomFlt, + cute::identity, + GmemTiledCopyAct, + SmemLayoutAtomAct, + SmemCopyAtomAct, + cute::identity>; + + TiledMma tiled_mma; + Tensor accum = partition_fragment_C(tiled_mma, TilerOut{}); + clear(accum); + + // Set up tensors + // NOTE: blockIdx.x projects onto act-NDHW mode, y along the flt-K mode for the sake of higher dynamic range in NDHW + Tensor gA_mk = local_tile(mFlt, TilerFlt{}, make_coord(_,_)); // (BLK_M,BLK_K,m',k') + Tensor gB_nk = local_tile(mAct, TilerAct{}, make_coord(_,_)); // (BLK_N,BLK_K,n',_1) + Tensor gC_mn = local_tile(mOut, TilerOut{}, make_coord(_,_)); // (BLK_M,BLK_N,m',n') + + // Compute m_coord and n_coord with their post-tiled shapes + auto m_coord = idx2crd(int(blockIdx.y), shape<2>(gA_mk)); + auto n_coord = idx2crd(int(blockIdx.x), shape<2>(gB_nk)); + Tensor gA = gA_mk(_,_,m_coord,_); // (BLK_M,BLK_K,k') + Tensor gB = gB_nk(_,_,n_coord,_); // (BLK_N,BLK_K,_1) + Tensor gC = gC_mn(_,_,m_coord,n_coord); // (BLK_M,BLK_N) + + auto k_tile_iter = cute::make_coord_iterator(size<2>(gA)); + int k_tile_count = size<2>(gA); + + CollectiveMainloop collective_mma; + collective_mma( + accum, + gA, + gB, + accum, + k_tile_iter, k_tile_count, + Underscore{}, // no residue since we do not support predication + threadIdx.x, + smem_buf); + + // + // Epilogue + // + SharedStorage& storage = *reinterpret_cast(smem_buf); + Tensor sC = make_tensor(make_smem_ptr(&storage.epilogue.sCMatrix[0]), SmemLayoutOut{}); + + auto smem_tiled_copy_C = make_tiled_copy_C(SmemCopyAtomOut{}, tiled_mma); + auto smem_thr_copy_C = smem_tiled_copy_C.get_slice(threadIdx.x); + auto tCrC = smem_thr_copy_C.retile_S(accum); + auto tCsC = smem_thr_copy_C.partition_D(sC); + copy(smem_tiled_copy_C, tCrC, tCsC); + + __syncthreads(); + + GmemTiledCopyOut gmem_tiled_copy_C; + auto gmem_thr_copy_C = gmem_tiled_copy_C.get_slice(threadIdx.x); + auto tDsC = gmem_thr_copy_C.partition_S(sC); + auto tDgC = gmem_thr_copy_C.partition_D(gC); + copy(gmem_tiled_copy_C, tDsC, tDgC); + + #if 0 + if (thread0()) { + print("mAct = "); print(mAct); print('\n'); + print("mFlt = "); print(mFlt); print('\n'); + print("mOut = "); print(mOut); print('\n'); + print("gA = "); print(gA); print('\n'); + print("gB = "); print(gB); print('\n'); + print("gC = "); print(gC); print('\n'); + print("sA = "); print(sA.layout()); print('\n'); + print("sB = "); print(sB.layout()); print('\n'); + print("sC = "); print(sC.layout()); print('\n'); + print("tAgA = "); print(tAgA.layout()); print('\n'); + print("tBgB = "); print(tBgB.layout()); print('\n'); + print("tAsA = "); print(tAsA.layout()); print('\n'); + print("tBsB = "); print(tBsB.layout()); print('\n'); + print("tCsA = "); print(tCsA.layout()); print('\n'); + print("tCsB = "); print(tCsB.layout()); print('\n'); + print("tCrC = "); 
print(tCrC.layout()); print('\n'); + print("tCsC = "); print(tCsC.layout()); print('\n'); + print("tDsC = "); print(tDsC.layout()); print('\n'); + print("tDgC = "); print(tDgC.layout()); print('\n'); + print("gmem tiled copy A = "); print(gmem_tiled_copy_A); print('\n'); + print("gmem tiled copy B = "); print(gmem_tiled_copy_B); print('\n'); + print("gmem tiled copy C = "); print(gmem_tiled_copy_C); print('\n'); + print("k_tile_count = "); print(size<2>(gA)); print('\n'); + print("k_tile_iter = "); print(*k_tile_iter); print('\n'); + print("K_BLOCK_MAX = "); print(K_BLOCK_MAX); print('\n'); + } + #endif + } +}; + +template +inline int +fprop_reference( + TensorFlt mStencil, // Logical MK: ( K, (C,T,R,S)) + TensorAct mActivation, // Logical NK: ((N,Z,P,Q), (C,T,R,S)) + TensorOut mOutput, // Logical MN: ( K, (N,Z,P,Q)) + TensorOut mOutputRef) { + int32_t N = size<1,0>(mOutputRef); + int32_t Z = size<1,1>(mOutputRef); + int32_t P = size<1,2>(mOutputRef); + int32_t Q = size<1,3>(mOutputRef); + int32_t T = size<1,3>(mStencil); + int32_t R = size<1,2>(mStencil); + int32_t S = size<1,1>(mStencil); + int32_t C = size<1,0>(mStencil); + + size_t K = static_cast(size<0>(mOutputRef)); + size_t NZPQ = static_cast(size<1>(mOutputRef)); + size_t CTRS = static_cast(size<1>(mStencil)); + +#if defined(_OPENMP) + #pragma omp parallel for +#endif + for (size_t logical_m = 0; logical_m < K; ++logical_m) { + for (size_t logical_n = 0; logical_n < NZPQ; ++logical_n) { + auto accumulator = float(0); + for (size_t logical_k = 0; logical_k < CTRS; ++logical_k) { + accumulator += mStencil(logical_m, logical_k) * mActivation(logical_n, logical_k); + } + mOutputRef(logical_m, logical_n) = accumulator; + } + } + + return print_relative_error(mOutput, mOutputRef, /*print_verbose*/ false, /*print_error*/ true, /*error_margin*/ 0.01); +} diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/63_hopper_gemm_with_weight_prefetch/collective/builder.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/63_hopper_gemm_with_weight_prefetch/collective/builder.hpp new file mode 100644 index 0000000000000000000000000000000000000000..09bb9b6599062a8bc69cd7474a361d031943c652 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/63_hopper_gemm_with_weight_prefetch/collective/builder.hpp @@ -0,0 +1,242 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "dispatch_policy_extra.hpp" +#include "sm90_mma_tma_gmma_ss_warpspecialized_with_prefetch.hpp" +#include "../pipeline/prefetch_pipeline_sm90.hpp" + +namespace cutlass::gemm::collective { + +namespace detail { + +// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count. +template +constexpr int +compute_stage_count_or_override_prefetch(StageCount stage_count) { + return stages; +} + +// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count. +template +constexpr int +compute_stage_count_or_override_prefetch(StageCountAutoCarveout stage_count) { + constexpr auto mainloop_pipeline_bytes = sizeof(typename cutlass::PipelineTmaAsync<1>::SharedStorage); + constexpr auto prefetch_pipeline_bytes = sizeof(typename cutlass::detail::PrefetcherPipelineSharedStorage); + constexpr auto a_bits = cute::sizeof_bits_v; + constexpr auto b_bits = cute::sizeof_bits_v; + constexpr int MK_bytes = cutlass::bits_to_bytes(a_bits * size<0>(TileShapeMNK{}) * size<2>(TileShapeMNK{})); //also the prefetch smem size + constexpr int NK_bytes = cutlass::bits_to_bytes(b_bits * size<1>(TileShapeMNK{}) * size<2>(TileShapeMNK{})); + constexpr int stage_bytes = MK_bytes + NK_bytes + static_cast(mainloop_pipeline_bytes); + + return (CapacityBytes - carveout_bytes - MK_bytes * PrefetchStagesActual - prefetch_pipeline_bytes) / stage_bytes; +} + +} // namespace detail + +// GMMA_TMA_WS_FP8_FAST_ACCUM_SS + prefetch +template < + class ElementA, + class GmemLayoutATag, + int AlignmentA, + class ElementB, + class GmemLayoutBTag, + int AlignmentB, + class ElementAccumulator, + class TileShape_MNK, + class ClusterShape_MNK, + class StageCountType, + class KernelScheduleType +> +struct CollectiveBuilder< + arch::Sm90, + arch::OpClassTensorOp, + ElementA, + GmemLayoutATag, + AlignmentA, + ElementB, + GmemLayoutBTag, + AlignmentB, + ElementAccumulator, + TileShape_MNK, + ClusterShape_MNK, + StageCountType, + KernelScheduleType, + cute::enable_if_t< + cute::is_same_v> +> { + static_assert(is_static::value); + static_assert(is_static::value); + static_assert(detail::is_aligned(), + "Not meet TMA alignment requirement yet\n"); + static_assert(detail::is_input_fp8(), + "Only FP8 datatypes are compatible with these kernel schedules\n"); + // Dispatch TN fp8 kernels only to TMA warp specialized FP8 builder + static_assert(!detail::is_use_rmem_A(), + "Not supported for fp8 non-TN warp specialized kernels yet\n"); +#ifndef 
CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED + static_assert(cutlass::detail::dependent_false, "Unsupported Toolkit for SM90 Collective Builder\n"); +#endif + + static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A(); + static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B(); + + using AtomLayoutMNK = Layout>; + + using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::ss_op_selector< + ElementA, ElementB, ElementAccumulator, TileShape_MNK, GmmaMajorA, GmmaMajorB>(), AtomLayoutMNK{})); + + using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{}))); + using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{}))); + + using SmemLayoutAtomA = decltype(detail::ss_smem_selector< + GmmaMajorA, ElementA, decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + using SmemLayoutAtomB = decltype(detail::ss_smem_selector< + GmmaMajorB, ElementB, decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + + static constexpr int PipelineStages = detail::compute_stage_count_or_override_prefetch(StageCountType{}); + using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedWithPrefetch; + + using SmemCopyAtomA = void; + using SmemCopyAtomB = void; + + using CollectiveOp = CollectiveMma< + DispatchPolicy, + TileShape_MNK, + ElementA, + TagToStrideA_t, + ElementB, + TagToStrideB_t, + TiledMma, + GmemTiledCopyA, + SmemLayoutAtomA, + SmemCopyAtomA, + cute::identity, + GmemTiledCopyB, + SmemLayoutAtomB, + SmemCopyAtomB, + cute::identity + >; +}; + +// GMMA_TMA_WS_FP8_FAST_ACCUM_SS + prefetch and split DMA warps +template < + class ElementA, + class GmemLayoutATag, + int AlignmentA, + class ElementB, + class GmemLayoutBTag, + int AlignmentB, + class ElementAccumulator, + class TileShape_MNK, + class ClusterShape_MNK, + class StageCountType, + class KernelScheduleType +> +struct CollectiveBuilder< + arch::Sm90, + arch::OpClassTensorOp, + ElementA, + GmemLayoutATag, + AlignmentA, + ElementB, + GmemLayoutBTag, + AlignmentB, + ElementAccumulator, + TileShape_MNK, + ClusterShape_MNK, + StageCountType, + KernelScheduleType, + cute::enable_if_t< + cute::is_same_v> +> { + static_assert(is_static::value); + static_assert(is_static::value); + static_assert(detail::is_aligned(), + "Not meet TMA alignment requirement yet\n"); + static_assert(detail::is_input_fp8(), + "Only FP8 datatypes are compatible with these kernel schedules\n"); + // Dispatch TN fp8 kernels only to TMA warp specialized FP8 builder + static_assert(!detail::is_use_rmem_A(), + "Not supported for fp8 non-TN warp specialized kernels yet\n"); +#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED + static_assert(cutlass::detail::dependent_false, "Unsupported Toolkit for SM90 Collective Builder\n"); +#endif + + static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A(); + static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B(); + + using AtomLayoutMNK = Layout>; + + using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::ss_op_selector< + ElementA, ElementB, ElementAccumulator, TileShape_MNK, GmmaMajorA, GmmaMajorB>(), AtomLayoutMNK{})); + + using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{}))); + using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{}))); + + using SmemLayoutAtomA = decltype(detail::ss_smem_selector< + GmmaMajorA, 
ElementA, decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + using SmemLayoutAtomB = decltype(detail::ss_smem_selector< + GmmaMajorB, ElementB, decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + + static constexpr int PipelineStages = detail::compute_stage_count_or_override_prefetch(StageCountType{}); + using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedWithPrefetch; + + using SmemCopyAtomA = void; + using SmemCopyAtomB = void; + + using CollectiveOp = CollectiveMma< + DispatchPolicy, + TileShape_MNK, + ElementA, + TagToStrideA_t, + ElementB, + TagToStrideB_t, + TiledMma, + GmemTiledCopyA, + SmemLayoutAtomA, + SmemCopyAtomA, + cute::identity, + GmemTiledCopyB, + SmemLayoutAtomB, + SmemCopyAtomB, + cute::identity + >; +}; + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/63_hopper_gemm_with_weight_prefetch/collective/dispatch_policy_extra.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/63_hopper_gemm_with_weight_prefetch/collective/dispatch_policy_extra.hpp new file mode 100644 index 0000000000000000000000000000000000000000..8bf0bf37202eaadd2f8e7f0573c196c7fa77a7be --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/63_hopper_gemm_with_weight_prefetch/collective/dispatch_policy_extra.hpp @@ -0,0 +1,61 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once + +namespace cutlass::gemm { + +// Standard non-persistent kernel with a single producer warp, and one prefetch warp. +// `A` is assumed to be static, and therefore the producer warp for `A` attempts to load `A` +// while the producer warp is waiting on griddepcontrol. +// GDC `launch_dependent_grids` is issued from the producer warp instead of math warps, and +// according to prefetch ratio. +struct KernelTmaWarpSpecializedFP8FastAccumWithPrefetch { }; + +// Non-persistent kernel with two producer warps (one for each of A and B), and one prefetch warp. +// `A` is assumed to be static, and therefore the producer warp for `A` attempts to load `A` +// while the producer warp for `B` is waiting on griddepcontrol. Producer warp for `A` does not +// wait on griddepcontrol and loads immediately. +struct KernelTmaWarpSpecializedFP8FastAccumWithPrefetchAndSplitDMA { }; + +template< + int Stages_, + class ClusterShape_ = Shape<_1,_1,_1>, + class KernelSchedule = KernelTmaWarpSpecializedFP8FastAccumWithPrefetch +> +struct MainloopSm90TmaGmmaWarpSpecializedWithPrefetch { + constexpr static int Stages = Stages_; + using ClusterShape = ClusterShape_; + using ArchTag = arch::Sm90; + using Schedule = KernelSchedule; +}; + +} // namespace cutlass::gemm diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/63_hopper_gemm_with_weight_prefetch/collective/sm90_mma_tma_gmma_ss_warpspecialized_with_prefetch.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/63_hopper_gemm_with_weight_prefetch/collective/sm90_mma_tma_gmma_ss_warpspecialized_with_prefetch.hpp new file mode 100644 index 0000000000000000000000000000000000000000..4d9325b4e20d92587fa9f70e607223c4d7b8cf47 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/63_hopper_gemm_with_weight_prefetch/collective/sm90_mma_tma_gmma_ss_warpspecialized_with_prefetch.hpp @@ -0,0 +1,871 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/numeric_types.h" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/trace.h" + +#include "cute/arch/cluster_sm90.hpp" +#include "cute/arch/copy_sm90.hpp" +#include "cute/algorithm/functional.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" +#include "cutlass/arch/grid_dependency_control.h" + +#include "dispatch_policy_extra.hpp" + +#include "../pipeline/prefetch_pipeline_sm90.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +constexpr int PrefetchStages = 4; +constexpr int PrefetchInitialStages = 1; +// This determines how much shmem we set aside for prefetch. +// We don't reuse anything loaded by prefetcher, so we can keep +// loading into the same place -- there will be a conflict when +// writing, but it doesn't affect performance as much as the doors +// that this opens. 
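// Illustrative sketch of the shared-memory budget (the helper name and all byte counts
// passed to it are example values, not the kernel's actual configuration): the builder's
// compute_stage_count_or_override_prefetch() sizes the mainloop pipeline by dividing the
// capacity left over -- after the epilogue carveout, one prefetch tile of A, and the
// prefetcher barrier storage -- by the per-stage footprint of one A tile, one B tile,
// and the mainloop barrier. With PrefetchStagesActual == 1 below, only a single A tile
// is set aside for the prefetcher.
constexpr int example_prefetch_stage_count(
    int capacity_bytes,           // total smem available to the kernel
    int carveout_bytes,           // epilogue carveout
    int mk_bytes,                 // one A (MxK) tile in smem; also the prefetch buffer size
    int nk_bytes,                 // one B (NxK) tile in smem
    int mainloop_barrier_bytes,   // PipelineTmaAsync shared storage
    int prefetch_barrier_bytes) { // PrefetcherPipelineSharedStorage
  int stage_bytes = mk_bytes + nk_bytes + mainloop_barrier_bytes;
  return (capacity_bytes - carveout_bytes - mk_bytes - prefetch_barrier_bytes) / stage_bytes;
}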
+constexpr int PrefetchStagesActual = 1; + +} // namespace detail + +// WarpSpecialized Mainloop +template < + int Stages, + class ClusterShape, + class KernelSchedule, + class TileShape_, + class ElementA_, + class StrideA_, + class ElementB_, + class StrideB_, + class TiledMma_, + class GmemTiledCopyA_, + class SmemLayoutAtomA_, + class SmemCopyAtomA_, + class TransformA_, + class GmemTiledCopyB_, + class SmemLayoutAtomB_, + class SmemCopyAtomB_, + class TransformB_> +struct CollectiveMma< + MainloopSm90TmaGmmaWarpSpecializedWithPrefetch, + TileShape_, + ElementA_, + StrideA_, + ElementB_, + StrideB_, + TiledMma_, + GmemTiledCopyA_, + SmemLayoutAtomA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyB_, + SmemLayoutAtomB_, + SmemCopyAtomB_, + TransformB_> +{ + // + // Type Aliases + // + using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedWithPrefetch; + using TileShape = TileShape_; + using ElementA = ElementA_; + using StrideA = StrideA_; + using ElementB = ElementB_; + using StrideB = StrideB_; + using TiledMma = TiledMma_; + using ElementAccumulator = typename TiledMma::ValTypeC; + using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = GmemTiledCopyB_; + using SmemLayoutAtomA = SmemLayoutAtomA_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + + static_assert(size<1>(ClusterShape{}) == 1, "Cluster shape N must be 1"); + using CtaShape_MNK = decltype(shape_div(TileShape{}, ClusterShape{})); + + using PrefetcherPipeline = cutlass::PrefetchPipeline; + + using MainloopPipeline = cutlass::PipelineTmaAsync; + using PipelineState = cutlass::PipelineState; + using PipelineParams = typename MainloopPipeline::Params; + + static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(cute::rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + // Tile along modes in a way that maximizes the TMA box size. 
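// For orientation (the tile sizes here are only an example): tile_to_shape() repeats the
// swizzled smem atom over (BLK_M, BLK_K, PIPE) for A and (BLK_N, BLK_K, PIPE) for B, so a
// hypothetical 128x128x64 tile with 4 stages would give smem volumes of (128, 64, 4) for
// each operand. The Step<> ordering below follows the major mode of the corresponding
// global stride, so each TMA box stays contiguous in that mode.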
+ using SmemLayoutA = decltype(tile_to_shape( + SmemLayoutAtomA{}, + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int{}), + cute::conditional_t< ::cutlass::gemm::detail::is_major<0,StrideA>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + using SmemLayoutB = decltype(tile_to_shape( + SmemLayoutAtomB{}, + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int{}), + cute::conditional_t< ::cutlass::gemm::detail::is_major<0,StrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + + static_assert(rank(SmemLayoutA{}) == 3 && size<2>(SmemLayoutA{}) == DispatchPolicy::Stages); + static_assert(rank(SmemLayoutB{}) == 3 && size<2>(SmemLayoutB{}) == DispatchPolicy::Stages); + + using PrefetchSmemLayoutA = decltype(make_layout(make_shape( + cute::Int(SmemLayoutA{})>{}, + cute::Int(SmemLayoutA{})>{}, + cute::Int{}))); + + static constexpr auto prefetch_smem_size = cute::cosize_v; + + static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 2 or more."); + static_assert(cute::is_base_of::value && + cute::is_base_of::value, + "MMA atom must source both A and B operand from smem_desc for this mainloop."); + static_assert(cute::is_same_v || cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + static_assert(cute::is_same_v || cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + + // TMA converts f32 input to tf32 when copying from GMEM to SMEM + // For all other types, cast to size equivalent uint type to avoid any rounding by TMA. + static constexpr bool ConvertF32toTF32A = cute::is_same_v; + static constexpr bool ConvertF32toTF32B = cute::is_same_v; + using InternalElementA = cute::conditional_t>>; + using InternalElementB = cute::conditional_t>>; + + // Defined outside the class where it's used, to work around MSVC issues + using PrefetcherPipelineStorage = ::cutlass::detail::PrefetcherPipelineSharedStorage; + + struct SharedStorage { + struct TensorStorage : cute::aligned_struct<128, _0> { + cute::array_aligned> smem_A; + cute::array_aligned> smem_B; + cute::array_aligned smem_prefetch; + } tensors; + + using PipelineStorage = typename MainloopPipeline::SharedStorage; + PipelineStorage pipeline; + PrefetcherPipelineStorage prefetcher_pipeline; + }; + using TensorStorage = typename SharedStorage::TensorStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + // Host side kernel arguments + struct Arguments { + ElementA const* ptr_A; + StrideA dA; + ElementB const* ptr_B; + StrideB dB; + uint32_t mma_promotion_interval = 4; + float overlap_ratio = 0.5; + float prefetch_ratio = -1.0; + }; + + // Device side kernel params + struct Params { + // Assumption: StrideA is congruent with Problem_MK + using TMA_A = decltype(make_tma_copy_A_sm90( + GmemTiledCopyA{}, + make_tensor(static_cast(nullptr), repeat_like(StrideA{}, int32_t(0)), StrideA{}), + SmemLayoutA{}(_,_,cute::Int<0>{}), + TileShape{}, + ClusterShape{})); + // Assumption: StrideB is congruent with Problem_NK + using TMA_B = decltype(make_tma_copy_B_sm90( + GmemTiledCopyB{}, + make_tensor(static_cast(nullptr), repeat_like(StrideB{}, int32_t(0)), StrideB{}), + SmemLayoutB{}(_,_,cute::Int<0>{}), + TileShape{}, + ClusterShape{})); + + TMA_A tma_load_a; + TMA_B tma_load_b; + uint32_t tma_transaction_bytes = TmaTransactionBytesMK + TmaTransactionBytesNK; + uint32_t tma_transaction_bytes_mk = TmaTransactionBytesMK; + uint32_t tma_transaction_bytes_nk = TmaTransactionBytesNK; + float overlap_ratio = 0.5; + float prefetch_ratio = -1.0; + }; + + // + // 
Methods + // + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + (void) workspace; + + // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + auto ptr_A = reinterpret_cast(args.ptr_A); + auto ptr_B = reinterpret_cast(args.ptr_B); + + Tensor tensor_a = make_tensor(ptr_A, make_layout(make_shape(M,K,L), args.dA)); + Tensor tensor_b = make_tensor(ptr_B, make_layout(make_shape(N,K,L), args.dB)); + + typename Params::TMA_A tma_load_a = make_tma_copy_A_sm90( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,cute::Int<0>{}), + TileShape{}, + ClusterShape{}); + typename Params::TMA_B tma_load_b = make_tma_copy_B_sm90( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_,_,cute::Int<0>{}), + TileShape{}, + ClusterShape{}); + uint32_t transaction_bytes_mk = TmaTransactionBytesMK; + uint32_t transaction_bytes_nk = TmaTransactionBytesNK; + uint32_t transaction_bytes = transaction_bytes_mk + transaction_bytes_nk; + + return { + tma_load_a, + tma_load_b, + transaction_bytes, + transaction_bytes_mk, + transaction_bytes_nk, + args.overlap_ratio, + args.prefetch_ratio + }; + } + + template + static bool + can_implement( + ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + constexpr int tma_alignment_bits = 128; + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits::value; + bool implementable = cutlass::detail::check_alignment(cute::make_shape(M,K,L), StrideA{}); + constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), StrideB{}); + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + return false; + } + + if (args.overlap_ratio > 1.0) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: `overlap_ratio` must be either negative (disabled) or in [0, 1].\n"); + return false; + } + + if (args.prefetch_ratio > 1.0) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: `prefetch_ratio` must be either negative (disabled) or in [0, 1].\n"); + return false; + } + + return true; + } + + static constexpr int K_PIPE_MAX = DispatchPolicy::Stages; + static constexpr int K_PIPE_MMAS = 1; + static constexpr uint32_t TmaTransactionBytesMK = + cutlass::bits_to_bytes(size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast(sizeof_bits::value)); + static constexpr uint32_t TmaTransactionBytesNK = + cutlass::bits_to_bytes(size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast(sizeof_bits::value)); + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& mainloop_params) { + cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor()); + cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor()); + } + + /// Set up the data needed by this collective for load and mma. + /// Returns a tuple of tensors. 
The collective and the kernel layer have the contract + /// Returned tuple must contain at least two elements, with the first two elements being: + /// gA_mkl - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k,l) + /// gB_nkl - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k,l) + /// The rest of the tensors can be specified as needed by this collective. + template + CUTLASS_DEVICE auto + load_init(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params) const { + using X = Underscore; + // Separate out problem shape for convenience + auto [M,N,K,L] = problem_shape_MNKL; + + // TMA requires special handling of strides to deal with coord codomain mapping + // Represent the full tensors -- get these from TMA + Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(make_shape(M,K,L)); // (m,k,l) + Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N,K,L)); // (n,k,l) + + // Make tiled views, defer the slice + Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l) + Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) + + return cute::make_tuple(gA_mkl, gB_nkl); + } + + template < + class TensorA, class TensorB, + class KTileIterator, class BlockCoord + > + CUTLASS_DEVICE void + load( + Params const& mainloop_params, + MainloopPipeline pipeline, + PrefetcherPipeline prefetcher_pipeline, + PipelineState smem_pipe_write, + TensorA const& gA_mkl, + TensorB const& gB_nkl, + BlockCoord const& blk_coord, + KTileIterator k_tile_iter, int k_tile_count, + int thread_idx, + uint32_t block_rank_in_cluster, + TensorStorage& shared_tensors) { + int lane_predicate = cute::elect_one_sync(); + + if (lane_predicate) { + bool disable_gdc = mainloop_params.overlap_ratio < 0.0; + float overlap_ratio = mainloop_params.overlap_ratio; + int launch_dep_grids_threshold = static_cast(static_cast(k_tile_count - 1) * overlap_ratio); + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // + // Prepare the TMA loads for A + // + + constexpr uint32_t cluster_shape_x = get<0>(typename DispatchPolicy::ClusterShape()); + uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; + + auto cta_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y); + auto cta_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x); + + // Partition the inputs based on the current block coordinates. 
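// blk_coord is this CTA's (m, n, k, l) tile coordinate from the scheduler. Slicing
// gA_mkl / gB_nkl with the m/n and batch (l) coordinates leaves the k mode free, so the
// producer loop below can walk k tiles via k_tile_iter; partition_S / partition_D then
// give matching gmem-source and smem-destination views for each TMA copy.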
+ auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; + Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) + Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k) + + // Applies the mapping from cta_tma_a + Tensor tAgA = cta_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) + Tensor tAsA = cta_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) + + // Applies the mapping from cta_tma_b + Tensor tBgB = cta_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k) + Tensor tBsB = cta_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE) + + uint16_t mcast_mask_a = 0; + uint16_t mcast_mask_b = 0; + + // Issue TmaLoads + // Maps the tile -> block, value + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int n = 0; n < size<1>(block_layout); ++n) { + mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{})); + } + } + + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int m = 0; m < size<0>(block_layout); ++m) { + mcast_mask_b |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{})); + } + } + + // We have to wait on dependent grids because of B. + cutlass::arch::wait_on_dependent_grids(); + + // Signal prefetcher to stop + prefetcher_pipeline.producer_arrive(); + + bool launch_dep_grids = false; + // Mainloop + CUTLASS_PRAGMA_NO_UNROLL + for (int cnt=0 ; k_tile_count > 0; --k_tile_count, ++cnt) { + // LOCK smem_pipe_write for _writing_ + pipeline.producer_acquire(smem_pipe_write); + + // + // Copy gmem to smem for *k_tile_iter + // + + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write); + + int write_stage = smem_pipe_write.index(); + copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a, cute::TMA::CacheHintSm90::EVICT_FIRST), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage)); + copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b, cute::TMA::CacheHintSm90::EVICT_LAST), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage)); + ++k_tile_iter; + + if (!disable_gdc && cnt >= launch_dep_grids_threshold && !launch_dep_grids) { + launch_dep_grids = true; + cutlass::arch::launch_dependent_grids(); + } + + // Advance smem_pipe_write + ++smem_pipe_write; + } + if (!disable_gdc && !launch_dep_grids) { + cutlass::arch::launch_dependent_grids(); + } + } + } + + template < + class TensorA, + class KTileIterator, class BlockCoord + > + CUTLASS_DEVICE void + load_MK( + Params const& mainloop_params, + MainloopPipeline pipeline, + PrefetcherPipeline prefetcher_pipeline, + PipelineState smem_pipe_write, + TensorA const& gA_mkl, + BlockCoord const& blk_coord, + KTileIterator k_tile_iter, int k_tile_count, + int thread_idx, + uint32_t block_rank_in_cluster, + TensorStorage& shared_tensors) { + int lane_predicate = cute::elect_one_sync(); + + if (lane_predicate) { + bool disable_gdc = mainloop_params.overlap_ratio < 0.0; + float overlap_ratio = mainloop_params.overlap_ratio; + int launch_dep_grids_threshold = static_cast(static_cast(k_tile_count - 1) * overlap_ratio); + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + + // + // Prepare the TMA loads for A + // + + constexpr uint32_t cluster_shape_x = get<0>(typename DispatchPolicy::ClusterShape()); + uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; + + auto cta_tma_a = 
mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y); + + // Partition the inputs based on the current block coordinates. + auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; + Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) + + // Applies the mapping from cta_tma_a + Tensor tAgA = cta_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) + Tensor tAsA = cta_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) + + uint16_t mcast_mask_a = 0; + + // Issue TmaLoads + // Maps the tile -> block, value + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int n = 0; n < size<1>(block_layout); ++n) { + mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{})); + } + } + + // Don't wait on dependent grids when loading `A`, because + // we assume `A` (weights) are static. + + bool launch_dep_grids = false; + // Mainloop + CUTLASS_PRAGMA_NO_UNROLL + for (int cnt=0 ; k_tile_count > 0; --k_tile_count, ++cnt) { + // LOCK smem_pipe_write for _writing_ + pipeline.producer_acquire(smem_pipe_write); + + // + // Copy gmem to smem for *k_tile_iter + // + + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write); + + int write_stage = smem_pipe_write.index(); + copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a, cute::TMA::CacheHintSm90::EVICT_FIRST), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage)); + ++k_tile_iter; + + if (!disable_gdc && cnt >= launch_dep_grids_threshold && !launch_dep_grids) { + launch_dep_grids = true; + cutlass::arch::launch_dependent_grids(); + } + + // Advance smem_pipe_write + ++smem_pipe_write; + } + if (!disable_gdc && !launch_dep_grids) { + cutlass::arch::launch_dependent_grids(); + } + } + } + + template < + class TensorB, + class KTileIterator, class BlockCoord + > + CUTLASS_DEVICE void + load_NK( + Params const& mainloop_params, + MainloopPipeline pipeline, + PrefetcherPipeline prefetcher_pipeline, + PipelineState smem_pipe_write, + TensorB const& gB_nkl, + BlockCoord const& blk_coord, + KTileIterator k_tile_iter, int k_tile_count, + int thread_idx, + uint32_t block_rank_in_cluster, + TensorStorage& shared_tensors) { + int lane_predicate = cute::elect_one_sync(); + + if (lane_predicate) { + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // + // Prepare the TMA loads for B + // + + constexpr uint32_t cluster_shape_x = get<0>(typename DispatchPolicy::ClusterShape()); + uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; + + auto cta_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x); + + // Partition the inputs based on the current block coordinates. 
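// In the split-DMA schedule this producer warp handles only B: it slices gB_nkl by the
// n and batch coordinates, waits on griddepcontrol so it does not read B before the
// previous grid has flushed it, and then signals the prefetcher to stop. A is loaded
// concurrently by load_MK.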
+ auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; + Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k) + + // Applies the mapping from cta_tma_b + Tensor tBgB = cta_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k) + Tensor tBsB = cta_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE) + + uint16_t mcast_mask_b = 0; + + // Issue TmaLoads + // Maps the tile -> block, value + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int m = 0; m < size<0>(block_layout); ++m) { + mcast_mask_b |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{})); + } + } + + // Ensure that the prefetched kernel does not touch + // unflushed global memory prior to this instruction + cutlass::arch::wait_on_dependent_grids(); + + // Signal prefetcher to stop + prefetcher_pipeline.producer_arrive(); + + // Mainloop + CUTLASS_PRAGMA_NO_UNROLL + for (; k_tile_count > 0; --k_tile_count) { + // LOCK smem_pipe_write for _writing_ + pipeline.producer_acquire(smem_pipe_write); + + // + // Copy gmem to smem for *k_tile_iter + // + + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write); + + int write_stage = smem_pipe_write.index(); + copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b, cute::TMA::CacheHintSm90::EVICT_LAST), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage)); + ++k_tile_iter; + + // Advance smem_pipe_write + ++smem_pipe_write; + } + } + } + + /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster + CUTLASS_DEVICE void + load_tail(MainloopPipeline pipeline, PipelineState smem_pipe_write) { + int lane_predicate = cute::elect_one_sync(); + + // Issue the epilogue waits + if (lane_predicate) { + /* This helps avoid early exit of blocks in Cluster + * Waits for all stages to either be released (all + * Consumer UNLOCKs), or if the stage was never used + * then would just be acquired since the phase was + * still inverted from make_producer_start_state + */ + pipeline.producer_tail(smem_pipe_write); + } + } + + + template < + class TensorA, + class KTileIterator, class BlockCoord + > + CUTLASS_DEVICE void + prefetch_MK( + Params const& mainloop_params, + PrefetcherPipeline prefetcher_pipeline, + PipelineState smem_pipe_write, + TensorA const& gA_mkl, + BlockCoord const& blk_coord, + KTileIterator k_tile_iter, int k_tile_count, + int thread_idx, + uint32_t block_rank_in_cluster, + TensorStorage& shared_tensors) { + int lane_predicate = cute::elect_one_sync(); + + if (lane_predicate) { + bool do_best_effort_prefetch = mainloop_params.prefetch_ratio < 0; + float prefetch_ratio = do_best_effort_prefetch ? 1.0 : mainloop_params.prefetch_ratio; + int prefetch_iters = static_cast(static_cast(k_tile_count) * 0.5 * prefetch_ratio); + prefetch_iters = min(k_tile_count, ((prefetch_iters + detail::PrefetchStages - 1) / detail::PrefetchStages) * detail::PrefetchStages); + + Tensor sA = make_tensor( + make_smem_ptr(shared_tensors.smem_prefetch.data()), PrefetchSmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + + // + // Prepare the TMA loads for A + // + + constexpr uint32_t cluster_shape_x = get<0>(typename DispatchPolicy::ClusterShape()); + uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; + + auto cta_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y); + + // Partition the inputs based on the current block coordinates. 
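// The prefetch warp reuses the A TMA descriptor but lands every copy in the small
// dedicated smem_prefetch buffer (write_stage stays 0); the point is to warm L2 with
// weight tiles, and nothing written here is read back. prefetch_ratio bounds how many
// k tiles are touched (every other tile is requested), and a negative ratio means
// best-effort prefetching until the real producer arrives.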
+ auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; + Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) + + // Applies the mapping from cta_tma_a + Tensor tAgA = cta_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) + Tensor tAsA = cta_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) + + uint16_t mcast_mask_a = 0; + + // Issue TmaLoads + // Maps the tile -> block, value + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int n = 0; n < size<1>(block_layout); ++n) { + mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{})); + } + } + + uint32_t prefetcher_stage = 0; + uint32_t prefetcher_phase = 0; + CUTLASS_PRAGMA_NO_UNROLL + for (int cnt = 0 ; cnt < prefetch_iters; ++cnt) { + + if (do_best_effort_prefetch && prefetcher_pipeline.have_producers_arrived()) { + break; + } + + prefetcher_pipeline.prefetcher_acquire(prefetcher_stage, prefetcher_phase, cnt >= detail::PrefetchStages); + using BarrierType = typename PrefetcherPipeline::PrefetcherBarrierType; + BarrierType* tma_barrier = prefetcher_pipeline.prefetcher_get_barrier(prefetcher_stage); + + int write_stage = 0; + copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a, cute::TMA::CacheHintSm90::EVICT_FIRST), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage)); + ++k_tile_iter; + ++k_tile_iter; + + prefetcher_pipeline.advance_prefetcher_state(prefetcher_stage, prefetcher_phase); + } + prefetcher_pipeline.prefetcher_tail(prefetcher_stage, prefetcher_phase); + } + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template < + class FrgTensorC + > + CUTLASS_DEVICE void + mma(MainloopPipeline pipeline, + PipelineState smem_pipe_read, + FrgTensorC& accum, + int k_tile_count, + int thread_idx, + TensorStorage& shared_tensors, + Params const& mainloop_params) { + static_assert(is_rmem::value, "C tensor must be rmem resident."); + static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); + static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3."); + static_assert(cute::is_void_v, + "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); + static_assert(cute::is_void_v, + "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // + // Define C accumulators and A/B partitioning + // + + TiledMma tiled_mma; + auto thread_mma = tiled_mma.get_thread_slice(thread_idx); + + Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + + // Allocate "fragments/descriptors" + Tensor tCrA = thread_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE) + + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum)); // M + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum)); // N + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K + CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); // PIPE + + // + // PIPELINED MAIN LOOP + // + static_assert((0 <= K_PIPE_MMAS) && (K_PIPE_MMAS < K_PIPE_MAX), + "ERROR : Incorrect number of MMAs in flight"); + 
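// The consumer is organized as a prologue plus a steady-state loop: the prologue issues
// up to K_PIPE_MMAS GMMAs without releasing any smem stage, and each steady-state
// iteration waits for the producer, issues the GMMAs for one k tile, then releases the
// stage that is now K_PIPE_MMAS iterations behind, keeping that many MMAs in flight to
// overlap TMA traffic with tensor-core work.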
+ // We release buffers to producer warps(dma load) with some mmas in flight + PipelineState smem_pipe_release = smem_pipe_read; + + // Prologue GMMAs + int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); + + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + + warpgroup_fence_operand(accum); + CUTLASS_PRAGMA_UNROLL + for (int k_tile_prologue = prologue_mma_count; k_tile_prologue > 0; --k_tile_prologue) + { + // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) + auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + + int read_stage = smem_pipe_read.index(); + warpgroup_arrive(); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // (V,M,K) x (V,N,K) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accum); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + + warpgroup_commit_batch(); + + ++smem_pipe_read; + } + + warpgroup_fence_operand(accum); + // Mainloop GMMAs + k_tile_count -= prologue_mma_count; + + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) + { + // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) + auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + + // + // Compute on k_tile + // + + int read_stage = smem_pipe_read.index(); + warpgroup_fence_operand(accum); + warpgroup_arrive(); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // (V,M,K) x (V,N,K) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accum); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + warpgroup_commit_batch(); + + /// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_write is consumed + warpgroup_wait(); + warpgroup_fence_operand(accum); + + // UNLOCK smem_pipe_release, done _computing_ on it + pipeline.consumer_release(smem_pipe_release); + + // Advance smem_pipe_read and smem_pipe_release + ++smem_pipe_read; + ++smem_pipe_release; + } + + warpgroup_fence_operand(accum); + } + + /// Perform a Consumer Epilogue to release all buffers + CUTLASS_DEVICE void + mma_tail(MainloopPipeline pipeline, PipelineState smem_pipe_release, int k_tile_count) { + // Prologue GMMAs + int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); + k_tile_count -= prologue_mma_count; + + smem_pipe_release.advance(k_tile_count); + + // Wait on all GMMAs to complete + warpgroup_wait<0>(); + + for (int count = 0; count < prologue_mma_count; ++count) { + pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it + ++smem_pipe_release; + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/63_hopper_gemm_with_weight_prefetch/gemm_with_weight_prefetch_commandline.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/63_hopper_gemm_with_weight_prefetch/gemm_with_weight_prefetch_commandline.hpp new file mode 100644 index 
0000000000000000000000000000000000000000..8c85edb71bb7396ee3b83afd1fe9d56510f1a759 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/63_hopper_gemm_with_weight_prefetch/gemm_with_weight_prefetch_commandline.hpp @@ -0,0 +1,117 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +// Command line options parsing +struct Options { + + bool help = false; + + float alpha = 1.f, beta = 0.f; + float overlap_ratio = 0.5f, prefetch_ratio = 0.5f; + int iterations = 1000; + int n = 64, m = 1280, k = 8192, l = 1; + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("m", m); + cmd.get_cmd_line_argument("n", n); + cmd.get_cmd_line_argument("k", k); + cmd.get_cmd_line_argument("l", l); + cmd.get_cmd_line_argument("alpha", alpha, 1.f); + cmd.get_cmd_line_argument("beta", beta, 0.f); + cmd.get_cmd_line_argument("p", prefetch_ratio, 0.5f); + cmd.get_cmd_line_argument("o", overlap_ratio, 0.5f); + cmd.get_cmd_line_argument("iterations", iterations); + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "63_hopper_gemm_with_weight_prefetch\n\n" + << " Hopper FP8 GEMM using a non-persistent kernel with L2 weight prefetch. 
\n" + << " For more details please refer to the source file.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --l= Sets the l extent (batch) of the GEMM\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n" + << " --p= Prefetch ratio\n" + << " --o= Overlap ratio\n" + << " --iterations= Number of profiling iterations to perform.\n\n"; + + out + << "\n\nExamples:\n\n" + << "$ " << "63_hopper_gemm_with_weight_prefetch" << + " --m=1024 --n=512 --k=1024 --o=0.5 --p=0.5 \n\n"; + + return out; + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s) const + { + // Two flops per multiply-add + uint64_t flop = uint64_t(2) * m * n * k * l; + double gflop = double(flop) / double(1.0e9); + return gflop / runtime_s; + } + + /// Compute effective bandwidth in GB/sec + double effective_bandwidth( + double runtime_s, + size_t bytes_a, + size_t bytes_b, + size_t bytes_c, + size_t bytes_d + ) const + { + static double const kBytesPerGiB = double(1ull << 30); + + double bytes_in = + (double)(l) * (double)(m) * (double)(k) * (double)(bytes_a) + // A + (double)(l) * (double)(n) * (double)(k) * (double)(bytes_b) + // B + (beta != 0.f ? (double)(l) * (double)(m) * (double)(n) * (double)(bytes_c) : 0.f); // C + double bytes_out = (double)(l) * (double)(m) * (double)(n) * (double)(bytes_d); // D + + double gb_total = (bytes_in + bytes_out) / kBytesPerGiB; + return gb_total / runtime_s; + } +}; diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/63_hopper_gemm_with_weight_prefetch/kernel/sm90_gemm_tma_warpspecialized_with_prefetch.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/63_hopper_gemm_with_weight_prefetch/kernel/sm90_gemm_tma_warpspecialized_with_prefetch.hpp new file mode 100644 index 0000000000000000000000000000000000000000..73655ad2ace6674ed0ea7acb5b185505379d599e --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/63_hopper_gemm_with_weight_prefetch/kernel/sm90_gemm_tma_warpspecialized_with_prefetch.hpp @@ -0,0 +1,561 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/kernel_hardware_info.hpp" +#include "cute/arch/cluster_sm90.hpp" +#include "cutlass/arch/reg_reconfig.h" +#include "cutlass/arch/mma_sm90.h" +#include "cutlass/epilogue/collective/detail.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/kernel/sm90_tile_scheduler.hpp" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/trace.h" + +#include "cute/tensor.hpp" + +#include "../collective/dispatch_policy_extra.hpp" + +/////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::kernel { + +/////////////////////////////////////////////////////////////////////////////// + +// GEMM + Prefetch for the A tensor + (optional) split DMA warps +template < + class ProblemShape_, + class CollectiveMainloop_, + class CollectiveEpilogue_, + class TileScheduler_ +> +class GemmUniversal< + ProblemShape_, + CollectiveMainloop_, + CollectiveEpilogue_, + TileScheduler_, + cute::enable_if_t< + cute::is_same_v || + cute::is_same_v + > +> +{ +public: + // + // Type Aliases + // + using ProblemShape = ProblemShape_; + static_assert(cute::rank(ProblemShape{}) == 3 or cute::rank(ProblemShape{}) == 4, + "ProblemShape{} should be or "); + static constexpr bool IsGdcEnabled = cutlass::arch::IsGdcGloballyEnabled; + + static constexpr bool SplitWarps = cute::is_same_v; + + // Mainloop derived types + using CollectiveMainloop = CollectiveMainloop_; + using TileShape = typename CollectiveMainloop::TileShape; + using TiledMma = typename CollectiveMainloop::TiledMma; + using ArchTag = typename CollectiveMainloop::ArchTag; + using ElementA = typename CollectiveMainloop::ElementA; + using StrideA = typename CollectiveMainloop::StrideA; + using ElementB = typename CollectiveMainloop::ElementB; + using StrideB = typename CollectiveMainloop::StrideB; + using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; + using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; + using ClusterShape = typename DispatchPolicy::ClusterShape; + using MainloopArguments = typename CollectiveMainloop::Arguments; + using MainloopParams = typename CollectiveMainloop::Params; + static_assert(ArchTag::kMinComputeCapability >= 90); + + // Epilogue derived types + using CollectiveEpilogue = CollectiveEpilogue_; + using ElementC = typename CollectiveEpilogue::ElementC; + using StrideC = typename CollectiveEpilogue::StrideC; + using ElementD = typename CollectiveEpilogue::ElementD; + using StrideD = typename CollectiveEpilogue::StrideD; + using EpilogueArguments = typename CollectiveEpilogue::Arguments; + using EpilogueParams = typename CollectiveEpilogue::Params; + + static_assert(cute::is_void_v or cute::is_same_v, + "TMA warp-specialized kernel does not support specializing the tile 
scheduler."); + using TileSchedulerTag = TileScheduler_; + using TileScheduler = typename detail::TileSchedulerSelector< + TileScheduler_, ArchTag, TileShape, ClusterShape>::Scheduler; + using TileSchedulerArguments = typename TileScheduler::Arguments; + + // Kernel level shared memory storage + struct SharedStorage { + // Mainloop and epilogue don't use smem concurrently since kernel is non-persistent, so we can use a union + union TensorStorage { + using MainloopTensorStorage = typename CollectiveMainloop::TensorStorage; + using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage; + + MainloopTensorStorage mainloop; + EpilogueTensorStorage epilogue; + } tensors; + + struct PipelineStorage : cute::aligned_struct<16, _1> { + using MainloopPipelineStorage = typename CollectiveMainloop::PipelineStorage; + using PrefetcherPipelineStorage = typename CollectiveMainloop::PrefetcherPipelineStorage; + using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage; + + alignas(16) MainloopPipelineStorage mainloop; + alignas(16) EpiLoadPipelineStorage epi_load; + alignas(16) PrefetcherPipelineStorage prefetcher; + } pipelines; + }; + + static constexpr int SharedStorageSize = sizeof(SharedStorage); + + static constexpr uint32_t NumLoadWarpGroups = 1; + static constexpr uint32_t NumMmaWarpGroups = 1; + static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMma{})) + (NumLoadWarpGroups * NumThreadsPerWarpGroup); + static constexpr uint32_t MinBlocksPerMultiprocessor = 1; + + // Device side arguments + struct Arguments { + GemmUniversalMode mode{}; + ProblemShape problem_shape{}; + MainloopArguments mainloop{}; + EpilogueArguments epilogue{}; + KernelHardwareInfo hw_info{}; + TileSchedulerArguments scheduler{}; + }; + + // Kernel entry point API + struct Params { + GemmUniversalMode mode{}; + ProblemShape problem_shape{}; + MainloopParams mainloop{}; + EpilogueParams epilogue{}; + }; + + // + // Methods + // + + // Convert to underlying arguments. In this case, a simple copy for the aliased type. 
+ static + Params + to_underlying_arguments(Arguments const& args, void* workspace) { + (void) workspace; + auto problem_shape = args.problem_shape; + if constexpr (detail::Has_SwapAB_v) { + // swap M/N + get<0>(problem_shape) = get<1>(args.problem_shape); + get<1>(problem_shape) = get<0>(args.problem_shape); + } + return { + args.mode, + problem_shape, + CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, workspace), + CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, workspace) + }; + } + + static bool + can_implement(Arguments const& args) { + bool implementable = (args.mode == GemmUniversalMode::kGemm) or + (args.mode == GemmUniversalMode::kBatched && cute::rank(ProblemShape{}) == 4); + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n"); + return implementable; + } + implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop); + implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue); + implementable &= TileScheduler::can_implement(args.scheduler); + + return implementable; + } + + static + size_t + get_workspace_size(Arguments const& args) { + return 0; + } + + static + cutlass::Status + initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr, + CudaHostAdapter* cuda_adapter = nullptr) { + return Status::kSuccess; + } + + // Computes the kernel launch grid shape based on runtime parameters + static dim3 + get_grid_shape(Params const& params) { + auto cluster_shape = ClusterShape{}; + auto tile_shape = TileShape{}; + auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); + return TileScheduler::get_tiled_cta_shape_mnl( + problem_shape_MNKL, tile_shape, cluster_shape); + } + + static dim3 + get_block_shape() { + return dim3(MaxThreadsPerBlock, 1, 1); + } + + CUTLASS_DEVICE + void + operator()(Params const& params, char* smem_buf) { + using namespace cute; + using X = Underscore; + +#if defined(__CUDA_ARCH_FEAT_SM90_ALL) +# define ENABLE_SM90_KERNEL_LEVEL 1 +#endif + +// Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a. +#if ! defined(ENABLE_SM90_KERNEL_LEVEL) + printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n"); +#else + + enum class WarpGroupRole { + Producer = 0, + Consumer = 1, + }; + // Split mode: use Warp0 to load NK and epilogue, Warp2 to load MK. + // Non-split mode: use Warp0 to load MK, NK and epilogue, Warp2 is unused. + // Both modes use Warp1 to prefetch. 
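The role dispatch that follows hinges on two indices: the warp-group index selects Producer versus Consumer, and the warp index within the producer group selects Warp0, PrefetchMK, or Warp2. A host-side sketch of that mapping, assuming one producer and one consumer warp group of 128 threads each:

// Hypothetical mapping check; the constants mirror the usual CUTLASS warp sizes.
#include <cstdio>

constexpr int NumThreadsPerWarp = 32;
constexpr int NumThreadsPerWarpGroup = 128;
constexpr int NumWarpsPerWarpGroup = 4;

int main() {
  for (int tid : {0, 32, 64, 96, 128, 255}) {
    int warp_idx      = tid / NumThreadsPerWarp;
    int warp_group    = tid / NumThreadsPerWarpGroup;       // 0 = Producer, 1 = Consumer
    int warp_in_group = warp_idx % NumWarpsPerWarpGroup;    // 0 = Warp0, 1 = PrefetchMK, 2 = Warp2, 3 = unused
    std::printf("tid %3d -> warp group %d, warp-in-group %d\n", tid, warp_group, warp_in_group);
  }
}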
+ enum class ProducerWarpRole { + Warp0 = 0, + PrefetchMK = 1, + Warp2 = 2, + UnusedWarp = 3 + }; + + // Kernel level shared memory storage + SharedStorage& shared_storage = *reinterpret_cast(smem_buf); + + int thread_idx = int(threadIdx.x); + int lane_idx = canonical_lane_idx(); + int warp_idx = canonical_warp_idx_sync(); + int warp_idx_in_warp_group = warp_idx % NumWarpsPerWarpGroup; + int warp_group_thread_idx = thread_idx % NumThreadsPerWarpGroup; + auto warp_group_role = WarpGroupRole(canonical_warp_group_idx()); + auto producer_warp_role = ProducerWarpRole(warp_idx_in_warp_group); + int lane_predicate = cute::elect_one_sync(); + uint32_t block_rank_in_cluster = cute::block_rank_in_cluster(); + + + // Issue Tma Descriptor Prefetch from a single thread + if ((warp_idx == 0) && lane_predicate) { + CollectiveMainloop::prefetch_tma_descriptors(params.mainloop); + CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue); + } + + // Mainloop Load pipeline + using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline; + typename MainloopPipeline::Params mainloop_pipeline_params; + mainloop_pipeline_params.is_leader = warp_group_thread_idx == 0; + if (warp_group_role == WarpGroupRole::Producer && ( + producer_warp_role == ProducerWarpRole::Warp0 || + producer_warp_role == ProducerWarpRole::Warp2)) { + mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Producer; + mainloop_pipeline_params.transaction_bytes = params.mainloop.tma_transaction_bytes; + } + if (warp_group_role == WarpGroupRole::Consumer) { + mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Consumer; + } + mainloop_pipeline_params.num_consumers = NumThreadsPerWarpGroup; + MainloopPipeline mainloop_pipeline(shared_storage.pipelines.mainloop, mainloop_pipeline_params, ClusterShape{}); + bool should_prefetch = params.mainloop.prefetch_ratio > 0; + using PrefetcherPipeline = typename CollectiveMainloop::PrefetcherPipeline; + typename PrefetcherPipeline::Params prefetcher_pipeline_params; + prefetcher_pipeline_params.num_prefetchers = 1; + if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::PrefetchMK) { + prefetcher_pipeline_params.should_prefetch = should_prefetch; + prefetcher_pipeline_params.transaction_bytes = params.mainloop.tma_transaction_bytes_mk; + } + PrefetcherPipeline prefetcher_pipeline(shared_storage.pipelines.prefetcher, prefetcher_pipeline_params); + + // Epilogue Load pipeline + using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline; + typename EpiLoadPipeline::Params epi_load_pipeline_params; + if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::Warp0) { + epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Producer; + } + if (warp_group_role == WarpGroupRole::Consumer) { + epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Consumer; + } + epi_load_pipeline_params.dst_blockid = cute::block_rank_in_cluster(); + epi_load_pipeline_params.producer_arv_count = NumThreadsPerWarp; + epi_load_pipeline_params.consumer_arv_count = NumThreadsPerWarpGroup; + if constexpr (CollectiveEpilogue::RequiresTransactionBytes) { + epi_load_pipeline_params.transaction_bytes = params.epilogue.tma_transaction_bytes; + } + EpiLoadPipeline epi_load_pipeline(shared_storage.pipelines.epi_load, epi_load_pipeline_params); + + // Epilogue Store pipeline + using EpiStorePipeline = typename CollectiveEpilogue::StorePipeline; + typename EpiStorePipeline::Params epi_store_pipeline_params; + 
epi_store_pipeline_params.always_wait = true; + EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params); + + // Initialize starting pipeline states for the collectives + // Epilogue store pipe is producer-only (consumer is TMA unit, waits via scoreboarding) + typename CollectiveMainloop::PipelineState mainloop_pipe_consumer_state; + typename CollectiveEpilogue::LoadPipelineState epi_load_pipe_consumer_state; + + // For the DMA Load (producer) we start with an opposite phase + // i.e., we skip all waits since we know that the buffer is indeed empty + PipelineState mainloop_pipe_producer_state = cutlass::make_producer_start_state(); + PipelineState epi_load_pipe_producer_state = cutlass::make_producer_start_state(); + PipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state(); + + auto cluster_wait_fn = [&] () { + // We need this to guarantee that the Pipeline init is visible + // To all producers and consumer thread blocks in the Cluster + if constexpr (size(ClusterShape{}) > 1) { + // Non-prefetcher warps arrive and wait, + // Prefetcher warp can go ahead without waiting. + cute::cluster_arrive_relaxed(); + if (warp_group_role != WarpGroupRole::Producer || + producer_warp_role != ProducerWarpRole::PrefetchMK) { + cute::cluster_wait(); + } + return [] () {}; + } + else { + // __syncthreads() but only for non prefetcher warps + if (should_prefetch) { + + // Use a named barrier to let the prefetcher warp start loading into the L2 + // without waiting to sync with all other warps. + // All other warps need to sync because the mainloop pipeline init + // should be visible to all of them. + // Prefetcher has its own barriers, and the only warps it would need to sync + // with would be the DMA warps. + using ClusterSyncWithPrefetchBarrier = typename cutlass::arch::NamedBarrier; + auto prefetcher_arrive_barrier = ClusterSyncWithPrefetchBarrier( + blockDim.x * blockDim.y * blockDim.z, + /*id*/ 0); + // Prefetcher warp doesn't arrive on this barrier. + auto cluster_arrive_barrier = ClusterSyncWithPrefetchBarrier( + blockDim.x * blockDim.y * blockDim.z - NumThreadsPerWarp, + /*id*/ 1); + + if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::PrefetchMK) { + __syncwarp(); + prefetcher_arrive_barrier.arrive(); + } + else if (warp_group_role == WarpGroupRole::Producer) { + prefetcher_arrive_barrier.arrive_and_wait(); + cluster_arrive_barrier.arrive_and_wait(); + } + else { + prefetcher_arrive_barrier.arrive(); + cluster_arrive_barrier.arrive_and_wait(); + } + } else { + __syncthreads(); + } + return [] () {}; + } + } (); + + // Preconditions + static_assert(cute::rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. 
If batch mode is not needed, set L stride to Int<0>."); + + // Optionally append 1s until problem shape is rank-4 in case it is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); + + // Get the appropriate blocks for this thread block -- potential for thread block locality + auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K) + TiledMma tiled_mma; + + // In a warp specialized kernel, collectives expose data movement and compute operations separately + CollectiveMainloop collective_mainloop; + CollectiveEpilogue collective_epilogue(params.epilogue, shared_storage.tensors.epilogue); + + // Prepare and partition the input tensors. Expects a tuple of tensors where: + // get<0>(load_inputs) is the tma tensor A after local tiling so that it has shape (BLK_M,BLK_K,m,k,l) + // get<1>(load_inputs) is the tma tensor B after local tiling so that it has shape (BLK_N,BLK_K,n,k,l) + auto load_inputs = collective_mainloop.load_init(problem_shape_MNKL, params.mainloop); + static_assert(cute::tuple_size_v >= 2, "Output of load_init must have at least two elements (A, B)"); + + // Extract out partitioned A and B. + Tensor gA_mkl = get<0>(load_inputs); + Tensor gB_nkl = get<1>(load_inputs); + + // Compute m_coord, n_coord, and l_coord with their post-tiled shapes + auto m_coord = idx2crd(int(blockIdx.x), shape<2>(gA_mkl)); + auto n_coord = idx2crd(int(blockIdx.y), shape<2>(gB_nkl)); + auto l_coord = idx2crd(int(blockIdx.z), shape<4>(gB_nkl)); + auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); + + // Get pipeline iterators and increments from tensor shapes + auto k_tile_iter = cute::make_coord_iterator(shape<3>(gA_mkl)); + auto k_tile_count = size<3>(gA_mkl); + + // Wait for all thread blocks in the Cluster + cluster_wait_fn(); + + if (warp_group_role == WarpGroupRole::Producer) { + if (producer_warp_role == ProducerWarpRole::Warp0) { + if constexpr(SplitWarps) { + collective_mainloop.load_NK( + params.mainloop, + mainloop_pipeline, + prefetcher_pipeline, + mainloop_pipe_producer_state, + gB_nkl, + blk_coord, + k_tile_iter, k_tile_count, + lane_idx, + block_rank_in_cluster, + shared_storage.tensors.mainloop + ); + } + else { + collective_mainloop.load( + params.mainloop, + mainloop_pipeline, + prefetcher_pipeline, + mainloop_pipe_producer_state, + gA_mkl, gB_nkl, + blk_coord, + k_tile_iter, k_tile_count, + lane_idx, + block_rank_in_cluster, + shared_storage.tensors.mainloop + ); + } + // Update starting mainloop pipeline state for the pipeline drain + mainloop_pipe_producer_state.advance(k_tile_count); + // Make sure mainloop consumer has been waited upon before issuing epilogue load + collective_mainloop.load_tail(mainloop_pipeline, mainloop_pipe_producer_state); + + if (collective_epilogue.is_producer_load_needed()) { + // Ensure warp is converged before issuing epilogue loads + __syncwarp(); + epi_load_pipe_producer_state = collective_epilogue.load( + epi_load_pipeline, + epi_load_pipe_producer_state, + problem_shape_MNKL, + blk_shape, + blk_coord, + tiled_mma, + lane_idx, + shared_storage.tensors.epilogue + ); + collective_epilogue.load_tail(epi_load_pipeline, epi_load_pipe_producer_state); + } + } + else if (SplitWarps && producer_warp_role == ProducerWarpRole::Warp2) { + collective_mainloop.load_MK( + params.mainloop, + mainloop_pipeline, + prefetcher_pipeline, + mainloop_pipe_producer_state, + gA_mkl, + blk_coord, + k_tile_iter, k_tile_count, + lane_idx, + block_rank_in_cluster, + shared_storage.tensors.mainloop + ); + // Update starting mainloop 
pipeline state for the pipeline drain + mainloop_pipe_producer_state.advance(k_tile_count); + // Make sure mainloop consumer has been waited upon before issuing epilogue load + collective_mainloop.load_tail(mainloop_pipeline, mainloop_pipe_producer_state); + } else if (producer_warp_role == ProducerWarpRole::PrefetchMK && should_prefetch) { + collective_mainloop.prefetch_MK( + params.mainloop, + prefetcher_pipeline, + mainloop_pipe_producer_state, + gA_mkl, + blk_coord, + k_tile_iter, k_tile_count, + lane_idx, + block_rank_in_cluster, + shared_storage.tensors.mainloop + ); + } + } + else if (warp_group_role == WarpGroupRole::Consumer) { + Tensor accumulators = partition_fragment_C(tiled_mma, take<0,2>(blk_shape)); // (MMA,MMA_M,MMA_N) + + collective_mainloop.mma( + mainloop_pipeline, + mainloop_pipe_consumer_state, + accumulators, + k_tile_count, + warp_group_thread_idx, + shared_storage.tensors.mainloop, + params.mainloop + ); + + // Make sure the math instructions are done and free buffers before entering the epilogue + collective_mainloop.mma_tail( + mainloop_pipeline, + mainloop_pipe_consumer_state, + k_tile_count + ); + + // Epilogue and write to gD + auto [epi_load_pipe_consumer_state_next, epi_store_pipe_producer_state_next] = + collective_epilogue.store( + epi_load_pipeline, + epi_load_pipe_consumer_state, + epi_store_pipeline, + epi_store_pipe_producer_state, + problem_shape_MNKL, + blk_shape, + blk_coord, + accumulators, + tiled_mma, + warp_group_thread_idx, + shared_storage.tensors.epilogue + ); + + collective_epilogue.store_tail( + epi_load_pipeline, + epi_load_pipe_consumer_state_next, + epi_store_pipeline, + epi_store_pipe_producer_state_next + ); + } +#endif + } +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::kernel diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/63_hopper_gemm_with_weight_prefetch/pipeline/prefetch_pipeline_sm90.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/63_hopper_gemm_with_weight_prefetch/pipeline/prefetch_pipeline_sm90.hpp new file mode 100644 index 0000000000000000000000000000000000000000..157d201adec5ecd36ad76dedefd72ac3404d54a6 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/63_hopper_gemm_with_weight_prefetch/pipeline/prefetch_pipeline_sm90.hpp @@ -0,0 +1,161 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cute/arch/cluster_sm90.hpp" +#include "cutlass/arch/barrier.h" +#include "cute/container/array.hpp" + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { + +namespace detail { + +// MSVC work-around +template +struct PrefetcherPipelineSharedStorage { + using TransactionBarrier = cutlass::arch::ClusterTransactionBarrier; + using Barrier = cutlass::arch::ClusterBarrier; + + TransactionBarrier tma_barrier[Stages]; + Barrier producer_ready_barrier; +}; + +} // end namespace detail + +using namespace cute; + +// Prefetcher pipeline is modeled after PipelineTmaAsync, with a cluster transaction +// barrier providing control over the number of concurrent outstanding TMA loads. +// There is also an additional cluster barrier which is only used when `prefetch_ratio` is unset. +// `prefetch_ratio` determines how many K tiles get loaded, and when unset, the prefetcher checks +// whether DMA warps are done waiting on griddepcontrol, and if so, stops issuing more TMA loads. 
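The pipeline below tracks a stage index and a phase bit. A standalone model of that bookkeeping (plain C++, not the CUTLASS barrier API): the stage index walks a ring of Stages slots and the phase flips on every wrap, which is why prefetcher_acquire waits on phase ^ 1 only once every slot has already been used once.

// Hypothetical model of advance_prefetcher_state(); Stages and the issue count are assumed.
#include <cstdio>

int main() {
  constexpr int Stages = 3;
  unsigned stage = 0, phase = 0;
  for (int issued = 0; issued < 8; ++issued) {
    bool should_wait = issued >= Stages;  // only wait once each barrier slot has been filled once
    std::printf("issue %d: stage %u, phase %u, would wait on phase %u: %s\n",
                issued, stage, phase, phase ^ 1, should_wait ? "yes" : "no");
    if (++stage == Stages) { stage = 0; phase ^= 1; }  // advance_prefetcher_state equivalent
  }
}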
+template +class PrefetchPipeline { +public : + static constexpr uint32_t Stages = Stages_; + using SharedStorage = detail::PrefetcherPipelineSharedStorage; + + using TransactionBarrier = typename SharedStorage::TransactionBarrier; + using Barrier = typename SharedStorage::Barrier; + using PrefetcherBarrierType = typename TransactionBarrier::ValueType; + + struct Params { + uint32_t transaction_bytes = 0; + uint32_t num_prefetchers = 1; + bool should_prefetch = false; + }; + + // Constructor + CUTLASS_DEVICE + PrefetchPipeline(SharedStorage& storage, Params params) + : params_(params) + , tma_barrier_ptr_(&storage.tma_barrier[0]) + , producer_ready_barrier_ptr_(&storage.producer_ready_barrier) { + + int lane_predicate = cute::elect_one_sync(); + if (params.should_prefetch && lane_predicate) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Stages; ++i) { + tma_barrier_ptr_[i].init(params.num_prefetchers); + } + producer_ready_barrier_ptr_[0].init(1); + } + } + + CUTLASS_DEVICE + void producer_arrive() { + if (params_.should_prefetch) { + producer_ready_barrier_ptr_[0].arrive(); + } + } + + CUTLASS_DEVICE + bool have_producers_arrived() { + if (params_.should_prefetch) { + uint32_t barrier_status_ = producer_ready_barrier_ptr_[0].try_wait(0); + auto barrier_status = static_cast(barrier_status_); + if (barrier_status == BarrierStatus::WaitDone) { + return true; // exit prefetcher loop + } + return false; + } + return true; + } + + CUTLASS_DEVICE + void prefetcher_acquire(uint32_t stage, uint32_t phase, bool should_wait) { + if (params_.should_prefetch) { + if (should_wait) { + tma_barrier_ptr_[stage].wait(phase ^ 1); + } + tma_barrier_ptr_[stage].arrive_and_expect_tx(params_.transaction_bytes); + } + } + + CUTLASS_DEVICE + void advance_prefetcher_state(uint32_t& stage, uint32_t& phase) { + if (params_.should_prefetch) { + stage++; + if (stage == Stages) { + stage = 0; + phase ^= 1; + } + } + } + + CUTLASS_DEVICE + void prefetcher_tail(uint32_t stage, uint32_t phase) { + if (params_.should_prefetch) { + // Wait on any already-issued loads + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < stage; ++i) { + tma_barrier_ptr_[i].wait(phase); + } + } + } + + CUTLASS_DEVICE + PrefetcherBarrierType* prefetcher_get_barrier(uint32_t stage) { + return reinterpret_cast(&tma_barrier_ptr_[stage]); + } + +private : + TransactionBarrier* tma_barrier_ptr_ = nullptr; + Barrier* producer_ready_barrier_ptr_ = nullptr; + Params params_; + +}; + +} // end namespace cutlass diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/hopper_fp8_commandline.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/hopper_fp8_commandline.hpp new file mode 100644 index 0000000000000000000000000000000000000000..85aff7567234e9812d0cd09f5a85f9fec1840bdc --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling/hopper_fp8_commandline.hpp @@ -0,0 +1,140 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +// Command line options parsing +template +struct Options { + + bool help = false; + bool verify = true; + + float alpha = 1.f, beta = 0.f; + float scale_a = 1.f, scale_b = 1.f, scale_c = 1.f, scale_d = 1.f, scale_aux = 1.f; + bool device_scale = false; + bool save_aux = true; + bool save_amax = true; + int iterations = 1000; + int warmup = 1000; + int m = 1024, n = 512, k = 1024, l = 1; + RasterOrderOptions raster; + int swizzle; + float epsilon = 0.02f; + float non_zero_floor = 1.f; + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("m", m); + cmd.get_cmd_line_argument("n", n); + cmd.get_cmd_line_argument("k", k); + cmd.get_cmd_line_argument("l", l); + cmd.get_cmd_line_argument("alpha", alpha, 1.f); + cmd.get_cmd_line_argument("beta", beta, 0.f); + cmd.get_cmd_line_argument("scale_a", scale_a, 1.f); + cmd.get_cmd_line_argument("scale_b", scale_b, 1.f); + cmd.get_cmd_line_argument("scale_c", scale_c, 1.f); + cmd.get_cmd_line_argument("scale_d", scale_d, 1.f); + cmd.get_cmd_line_argument("scale_aux", scale_aux, 1.f); + cmd.get_cmd_line_argument("device_scale", device_scale, false); + cmd.get_cmd_line_argument("save_aux", save_aux, true); + cmd.get_cmd_line_argument("save_amax", save_amax, true); + cmd.get_cmd_line_argument("warmup", warmup); + cmd.get_cmd_line_argument("iterations", iterations); + cmd.get_cmd_line_argument("verify", verify); + cmd.get_cmd_line_argument("epsilon", epsilon); + cmd.get_cmd_line_argument("non-zero-floor", non_zero_floor); + + char raster_char; + cmd.get_cmd_line_argument("raster", raster_char); + + if (raster_char == 'N' || raster_char == 'n') { + raster = RasterOrderOptions::AlongN; + } + else if (raster_char == 'M' || raster_char == 'm') { + raster = RasterOrderOptions::AlongM; + } + else if (raster_char == 'H' || raster_char == 'h') { + raster = RasterOrderOptions::Heuristic; + } + + cmd.get_cmd_line_argument("swizzle", swizzle, 1); + } + + /// Prints the 
usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling\n\n" + << " Hopper FP8 GEMM using a Warp Specialized kernel with Blockwise Scaling.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --l= Sets the l extent (batch) of the GEMM\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n" + << " --scale_a= Scaling factor for A\n" + << " --scale_b= Scaling factor for B\n" + << " --scale_c= Scaling factor for C\n" + << " --scale_d= Scaling factor for D (ignored for non-fp8 D)\n" + << " --scale_aux= Scaling factor for the auxiliary tensor (ignored for non-fp8 aux)\n" + << " --device_scale= Copy scalars to device memory before kernel launch (default: false)\n" + << " --save_aux= Save the pre-activation as an auxiliary tensor (default: true)\n" + << " --save_amax= Save the pre-scaled max absolute value of any fp8 outputs (aux and/or D) (default: true)\n" + << " --raster= CTA Rasterization direction (N for along N, M for along M, and H for heuristic)\n\n" + << " --swizzle= CTA Rasterization swizzle\n\n" + << " --iterations= Number of profiling iterations to perform.\n\n" + << " --verify= Verify the results.\n\n" + << " --epsilon= The epsilon value for comparing the results.\n\n" + << " --non-zero-floor= The none zero floor for comparing the results.\n\n"; + + out + << "\n\nExamples:\n\n" + << "$ " << "67_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n"; + + return out; + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s) const + { + // Two flops per multiply-add + uint64_t flop = uint64_t(2) * m * n * k; + double gflop = double(flop) / double(1.0e9); + return gflop / runtime_s; + } +}; diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/hopper_fp8_commandline.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/hopper_fp8_commandline.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ff446a91353ab2b2acd27528741a8ac076e90b3f --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling/hopper_fp8_commandline.hpp @@ -0,0 +1,270 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +// Command line options parsing +using RasterOrderOptions = cutlass::gemm::kernel::detail::RasterOrderOptions; +template +struct Options { + using ProblemShape = _ProblemShape; + + bool help = false; + + float alpha = 1.f, beta = 0.f; + int iterations = 1000; + int m = 1024, n = 512, k = 1024, groups = 10; + std::string benchmark_path; + std::vector problem_sizes_after_alignment_host; + std::vector problem_sizes_host; + int const tma_alignment_bits = 128; + int const alignment = tma_alignment_bits / cutlass::sizeof_bits::value; + int const k_alignment = 128; + int const m_alignment = 128; + int const n_alignment = 128; + + RasterOrderOptions raster_order; + int swizzle; + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("m", m); + cmd.get_cmd_line_argument("n", n); + cmd.get_cmd_line_argument("k", k); + cmd.get_cmd_line_argument("groups", groups); + cmd.get_cmd_line_argument("alpha", alpha, 1.f); + cmd.get_cmd_line_argument("beta", beta, 0.f); + cmd.get_cmd_line_argument("iterations", iterations); + + char raster_char; + cmd.get_cmd_line_argument("raster", raster_char); + + if (raster_char == 'N' || raster_char == 'n') { + raster_order = RasterOrderOptions::AlongN; + } + else if (raster_char == 'M' || raster_char == 'm') { + raster_order = RasterOrderOptions::AlongM; + } + else if (raster_char == 'H' || raster_char == 'h') { + raster_order = RasterOrderOptions::Heuristic; + } + + cmd.get_cmd_line_argument("swizzle", swizzle, 1); + cmd.get_cmd_line_argument("benchmark", benchmark_path); + + // Decide how to initialize the problems + if (!benchmark_path.empty()) { + if (!benchmark_problems()) { + problem_sizes_after_alignment_host.clear(); + problem_sizes_host.clear(); + return; + } + } + else { + randomize_problems(cmd); + } + + } + + void randomize_problems(cutlass::CommandLine &cmd) { + int cmd_line_m = -1, cmd_line_n = -1, cmd_line_k = -1; + cmd.get_cmd_line_argument("m", cmd_line_m); + cmd.get_cmd_line_argument("n", cmd_line_n); + cmd.get_cmd_line_argument("k", cmd_line_k); + + problem_sizes_after_alignment_host.reserve(groups); + problem_sizes_host.reserve(groups); + for (int i = groups; i > 0; i--) { + int m = cmd_line_m; + int n = cmd_line_n; + int k = cmd_line_k; + if (m < 0) { + m = m_alignment * (rand() % (64 * alignment / 
m_alignment)); + } + if (n < 0) { + n = n_alignment * (rand() % (64 * alignment / n_alignment)); + } + if (k < 0) { + k = k_alignment * (rand() % (32 * alignment / k_alignment)); + } + problem_sizes_after_alignment_host.push_back({m, n, k}); + problem_sizes_host.push_back({m, n, k}); + } + } + + /// Load a benchmark + bool benchmark_problems() { + std::ifstream file(benchmark_path); + if (!file.good()) { + return false; + } + + while (file.good()) { + + int idx = -1; + std::string extent_str; + + file >> idx >> extent_str; + + if (idx < 0 || extent_str.empty()) { + break; + } + + cutlass::gemm::GemmCoord extent_after_alignment, extent; + std::vector tokens; + + cutlass::CommandLine::tokenize(tokens, extent_str, 'x'); + + for (int i = 0; i < int(tokens.size()); ++i) { + int x = std::atoi(tokens.at(i).c_str()); + + extent.at(i) = x; + // round up + if (x % alignment) { + x += (alignment - (x % alignment)); + } + + extent_after_alignment.at(i) = x; + } + + problem_sizes_after_alignment_host.push_back({extent_after_alignment.m(), extent_after_alignment.n(), extent_after_alignment.k()}); + problem_sizes_host.push_back({extent.m(), extent.n(), extent.k()}); + } + groups = static_cast(problem_sizes_after_alignment_host.size()); + + return true; + } + + /// Calculate memory bandwidth statistics + template + auto gbps(double runtime_s) const { + double total_read_bytes = 0; + double total_write_bytes = 0; + + // Calculate bytes read and written for each problem + for (int i = 0; i < groups; ++i) { + auto problem = problem_sizes_host.at(i); + auto M = cute::get<0>(problem); + auto N = cute::get<1>(problem); + auto K = cute::get<2>(problem); + + if (M > 0) { // Only count active problems + // Matrix A: M*K elements read + total_read_bytes += M * K * sizeof(ElementA); + + // Matrix B: K*N elements read + total_read_bytes += K * N * sizeof(ElementB); + + // Matrix C: M*N elements read (for beta operation) + total_read_bytes += M * N * sizeof(ElementC); + + // Block scales for A and B + auto blockscale_shape = cute::shape(cute::get<1>(cute::zipped_divide(cute::make_layout(problem), TileShape{}))); + auto blockscale_m = cute::get<0>(blockscale_shape); + auto blockscale_n = cute::get<1>(blockscale_shape); + auto blockscale_k = cute::get<2>(blockscale_shape); + auto groupscale_m = blockscale_m * ScaleMsPerTile; + auto groupscale_n = blockscale_n * ScaleNsPerTile; + + total_read_bytes += groupscale_m * blockscale_k * sizeof(ElementBlockScale); // A scales + total_read_bytes += groupscale_n * blockscale_k * sizeof(ElementBlockScale); // B scales + + // Matrix D: M*N elements written + total_write_bytes += M * N * sizeof(ElementD); + } + } + + return (total_read_bytes + total_write_bytes) / 1.0e9 / runtime_s; + } + + double bandwidth_util(double eff_bandwidth) const { + int memoryClockRate; + int memoryBusWidth; + cudaDeviceGetAttribute(&memoryClockRate, cudaDevAttrMemoryClockRate, 0); + cudaDeviceGetAttribute(&memoryBusWidth, cudaDevAttrGlobalMemoryBusWidth , 0); + double bw = 2.0 * memoryClockRate * (memoryBusWidth / 8) / 1.0e6; + return eff_bandwidth / bw * 100.0; + } + + /// Prints the usage statement. 
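bandwidth_util() derives peak DRAM bandwidth from the device attributes as 2 (DDR factor) x memory clock (returned in kHz) x bus width in bytes, scaled to GB/s. A quick numeric check with placeholder attribute values, not real hardware numbers:

// Hypothetical attribute values; only the formula is taken from bandwidth_util() above.
#include <cstdio>

int main() {
  int memoryClockRate = 1'000'000;  // kHz, placeholder
  int memoryBusWidth  = 4096;       // bits, placeholder
  double peak_gbs = 2.0 * memoryClockRate * (memoryBusWidth / 8) / 1.0e6;
  double eff_gbs  = 800.0;          // pretend measured effective bandwidth
  std::printf("peak %.0f GB/s, utilization %.1f%%\n", peak_gbs, eff_gbs / peak_gbs * 100.0);
}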
+ std::ostream & print_usage(std::ostream &out) const { + + out << "68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling\n\n" + << " Hopper FP8 Grouped GEMM using a Warp Specialized kernel with Blockwise Scaling.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --groups= Sets the number of individual GEMM problems for Grouped GEMM\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n" + << " --raster= CTA Rasterization direction (N for along N, M for along M, and H for heuristic)\n\n" + << " --swizzle= CTA Rasterization swizzle\n\n" + << " --benchmark= Executes a benchmark problem size.\n\n" + << " --iterations= Number of profiling iterations to perform.\n\n"; + + out + << "\n\nExamples:\n\n" + << "$ " << "68_hopper_fp8_warp_specialized_grouped_gemm_with_blockwise_scaling" << " --m=1024 --n=512 --k=1024 --groups=10 --alpha=2 --beta=0.707 \n\n"; + + return out; + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s) const + { + // Number of real-valued multiply-adds + uint64_t fmas = 0ull; + + for (auto const [m, n, k] : problem_sizes_host) { + fmas += static_cast(m) * + static_cast(n) * + static_cast(k); + } + // Two flops per multiply-add + uint64_t flop = uint64_t(2) * uint64_t(fmas); + double gflop = double(flop) / double(1.0e9); + return gflop / runtime_s; + } +}; diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/69_hopper_mixed_dtype_grouped_gemm/grouped_mixed_dtype_utils.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/69_hopper_mixed_dtype_grouped_gemm/grouped_mixed_dtype_utils.hpp new file mode 100644 index 0000000000000000000000000000000000000000..8568b467dd87ca5c1b0111e6bd4584b9d3da0960 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/69_hopper_mixed_dtype_grouped_gemm/grouped_mixed_dtype_utils.hpp @@ -0,0 +1,191 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include +#include +#include + +#include "../55_hopper_mixed_dtype_gemm/mixed_dtype_utils.hpp" + +template +class GroupedMixedDtypeOptions : public MixedDtypeOptions { +public: + using ProblemShape = cutlass::gemm::GroupProblemShape>; + using UnderlyingProblemShape = typename ProblemShape::UnderlyingProblemShape; + + int groups = 6; + int c = 512; + std::string benchmark_path; + std::vector problem_sizes_host; + + GroupedMixedDtypeOptions() : MixedDtypeOptions() + { + m = 1024; + n = 2048; + k = 512; + }; + + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + cmd.get_cmd_line_argument("groups", groups); + cmd.get_cmd_line_argument("benchmark", benchmark_path); + cmd.get_cmd_line_argument("c", c); + MixedDtypeOptions::parse(argc, args); + + problem_sizes_host = benchmark_path.empty() ? randomize_problems(cmd) : load_benchmark_problems(); + } + + std::ostream& print_usage(std::ostream& out) const { + out << "69_hopper_mixed_dtype_grouped_gemm\n\n" + << "Options:\n" + << " --help Display this usage statement\n" + << " --m= Sets the M extent of the GEMM for all groups\n" + << " --n= Sets the N extent of the GEMM for all groups\n" + << " --k= Sets the K extent of the GEMM for all groups\n" + << " --c= Sets the chunk size for scaling the quantized weights\n" + << " --groups= Sets the number of individual GEMM problems\n" + << " --mode= The mode to run the gemm\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n" + << " --iterations= Number of profiling iterations\n" + << " --warmup= Number of warmup iterations\n" + << " --benchmark= Executes a benchmark problem size\n"; + return out; + } + + double gflops(double runtime_s) const { + uint64_t fmas = std::accumulate(problem_sizes_host.begin(), problem_sizes_host.end(), 0ULL, + [](uint64_t sum, const UnderlyingProblemShape& problem) { + return sum + static_cast(cute::get<0>(problem)) * + static_cast(cute::get<1>(problem)) * + static_cast(cute::get<2>(problem)); + }); + return (2.0 * fmas) / (runtime_s * 1e9); + } + +private: + static constexpr int tma_alignment_bits = 128; + const int alignment = tma_alignment_bits / cutlass::sizeof_bits::value; + + std::vector randomize_problems(cutlass::CommandLine& cmd) { + std::vector problems; + problems.reserve(groups); + + int cmd_line_m = -1, cmd_line_n = -1, cmd_line_k = -1; + cmd.get_cmd_line_argument("m", cmd_line_m); + cmd.get_cmd_line_argument("n", cmd_line_n); + cmd.get_cmd_line_argument("k", cmd_line_k); + + for (int i = 0; i < groups; ++i) { + int m = (cmd_line_m >= 0) ? cmd_line_m : alignment * ((rand() % 64) + 1); + int n = (cmd_line_n >= 0) ? cmd_line_n : this->n; + int k = (cmd_line_k >= 0) ? 
cmd_line_k : this->k; + + if (k % alignment != 0) { + throw std::runtime_error("Error: k dimension must be a multiple of " + std::to_string(alignment)); + } + problems.push_back({m, n, k}); + } + return problems; + } + + std::vector load_benchmark_problems() { + std::ifstream file(benchmark_path); + if (!file) { + throw std::runtime_error("Failed to open benchmark file: " + benchmark_path); + } + + std::vector problems; + int idx; + std::string extent_str; + + while (file >> idx >> extent_str) { + if (idx < 0 || extent_str.empty()) break; + + std::vector tokens; + cutlass::CommandLine::tokenize(tokens, extent_str, 'x'); + + cutlass::gemm::GemmCoord extent; + for (int i = 0; i < std::min(3, static_cast(tokens.size())); ++i) { + int x = std::stoi(tokens[i]); + extent.at(i) = (x % alignment) ? x + (alignment - (x % alignment)) : x; + } + + if (extent.product()) { + problems.push_back({extent.m(), extent.n(), extent.k()}); + } + } + groups = static_cast(problems.size()); + return problems; + } +}; + +template +void grouped_mixed_dtype_profiling( + Gemm& gemm, + const GroupedMixedDtypeOptions& options, + MixedDtypeResult& result, + const std::vector& alpha_host, + const std::vector& beta_host) { + + if (options.iterations <= 0) return; + + cudaEvent_t start, stop; + cudaEventCreate(&start); + cudaEventCreate(&stop); + + std::vector runtimes; + runtimes.reserve(options.iterations); + + for (int iter = 0; iter < options.warmup + options.iterations; ++iter) { + cudaEventRecord(start); + CUTLASS_CHECK(gemm.run()); + cudaEventRecord(stop); + cudaEventSynchronize(stop); + + if (iter >= options.warmup) { + float milliseconds = 0; + cudaEventElapsedTime(&milliseconds, start, stop); + runtimes.push_back(milliseconds); + } + } + + cudaEventDestroy(start); + cudaEventDestroy(stop); + + result.avg_runtime_ms = std::accumulate(runtimes.begin(), runtimes.end(), 0.0f) / runtimes.size(); + result.gflops = options.gflops(result.avg_runtime_ms / 1000.0); + std::cout << " Groups : " << options.groups << '\n' + << " Avg runtime : " << result.avg_runtime_ms << " ms\n" + << " GFLOPS : " << result.gflops << '\n'; +} diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/collective/fmha_common.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/collective/fmha_common.hpp new file mode 100644 index 0000000000000000000000000000000000000000..c60d9e953e2ae4afee89269f0a405b5be19ff3b5 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/collective/fmha_common.hpp @@ -0,0 +1,127 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
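grouped_mixed_dtype_profiling() above times warmup + iterations launches but only averages the post-warmup samples. A host-only sketch of that pattern, with std::chrono and a sleep standing in for the cudaEvent timing and gemm.run():

// Hypothetical stand-in: no CUDA calls, just the warmup-exclusion and averaging logic.
#include <chrono>
#include <cstdio>
#include <numeric>
#include <thread>
#include <vector>

int main() {
  int warmup = 2, iterations = 5;
  std::vector<double> runtimes;
  runtimes.reserve(iterations);
  for (int iter = 0; iter < warmup + iterations; ++iter) {
    auto t0 = std::chrono::steady_clock::now();
    std::this_thread::sleep_for(std::chrono::milliseconds(3));  // stand-in for gemm.run()
    double ms = std::chrono::duration<double, std::milli>(std::chrono::steady_clock::now() - t0).count();
    if (iter >= warmup) runtimes.push_back(ms);   // warmup samples are discarded
  }
  double avg = std::accumulate(runtimes.begin(), runtimes.end(), 0.0) / runtimes.size();
  std::printf("avg runtime: %.3f ms over %zu timed iterations\n", avg, runtimes.size());
}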
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/kernel_hardware_info.h" +#include "cutlass/arch/reg_reconfig.h" +#include "cute/tensor.hpp" + +namespace cutlass::fmha::collective { + +using namespace cute; + +template +CUTE_DEVICE void gemm_reset_zero_acc(Atom& atom, TA const& tA, TB const& tB, TC&& tC) { + constexpr int rA = decltype(rank(tA))::value; + constexpr int rB = decltype(rank(tB))::value; + constexpr int rC = decltype(rank(tC))::value; + static_assert(rA == 3 && rB == 3 && rC == 3); + + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tA); k_block++) { + cute::gemm(atom, tA(_,_,k_block), tB(_,_,k_block), tC); + atom.accumulate_ = decltype(atom.accumulate_)::One; + } +} + +template +CUTE_DEVICE void gemm_zero_acc(Atom& atom, TA const& tA, TB const& tB, TC&& tC) { + atom.accumulate_ = decltype(atom.accumulate_)::Zero; + gemm_reset_zero_acc(atom, tA, tB, tC); +} + +template +CUTE_DEVICE constexpr auto unstageSmemLayout(Layout const& layout, Stages stages = {}) { + return composition(layout, prepend(make_layout(stages), _)); +} + +template +CUTE_DEVICE T warp_uniform(T a) { + return __shfl_sync(0xffffffff, a, 0); +} + +template +CUTE_HOST_DEVICE constexpr +auto +to_tiled_mma_sm100_ts( + TiledMMA, cute::C, + cute::integral_constant, + cute::integral_constant, + cute::integral_constant, + cute::integral_constant>, + TAs...>, TMs...>) { + + return TiledMMA>, + TAs...>, TMs...>{}; +} + +template +CUTE_HOST_DEVICE constexpr +auto +to_tiled_mma_sm100_ts( + TiledMMA, + TAs...>, TMs...>) { + return TiledMMA, + TAs...>, TMs...>{}; +} + +template +CUTLASS_DEVICE +void warpgroup_reg_set() { + if constexpr (RegCount < 128) { + cutlass::arch::warpgroup_reg_dealloc(); + } + else { + cutlass::arch::warpgroup_reg_alloc(); + } +} + +} // namespace cutlass::fmha::collective diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/collective/fmha_fusion.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/collective/fmha_fusion.hpp new file mode 100644 index 0000000000000000000000000000000000000000..a33ce2d2ce04c357fe49ac6752e99e5f6e8c32b6 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/collective/fmha_fusion.hpp @@ -0,0 +1,396 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 
2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" + +namespace cutlass::fmha::collective { + +using namespace cute; + +struct NoMask { + template + CUTLASS_DEVICE + int get_trip_count( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size) { + + return ceil_div(get<1>(problem_size), get<1>(tile_shape)); + } + + template + CUTLASS_DEVICE + int get_masked_trip_count( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size) { + + return 0; + } + + template + CUTLASS_DEVICE + int get_unmasked_trip_count( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size) { + + return get_trip_count(blk_coord, tile_shape, problem_size); + } + + template + CUTLASS_DEVICE + void apply_mask( + AccQK& acc_qk, + IndexQK const& index_qk, + ProblemSize const& problem_size) { + + return; + } +}; + +struct ResidualMask : NoMask { + + using Base = NoMask; + + template + CUTLASS_DEVICE int get_masked_trip_count( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size) { + + if (get<1>(problem_size) % get<1>(tile_shape) != 0) { + return 1; + } + return 0; + } + + template + CUTLASS_DEVICE + int get_unmasked_trip_count( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size) { + + // if the sequence length does not divide the tile size evenly + if (get<1>(problem_size) % get<1>(tile_shape) != 0) { + return get_trip_count(blk_coord, tile_shape, problem_size) - 1; + } + return get_trip_count(blk_coord, tile_shape, problem_size); + } + + template + CUTLASS_DEVICE + void apply_mask( + AccQK& acc_qk, + IndexQK const& index_qk, + ProblemSize const& problem_size) { + + // This is useful is seqlen_k % kBlockN != 0 
since it masks + // the remaining elements out from softmax. + // d % kHeadDim != 0 or seqlen_q % kBlockM do not suffer from similar + // issues as they are transparently taken care of by TMA and the + // epilogue, if it is instantiated with predication support. + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(acc_qk); i++) { + auto pos = index_qk(i); + if (get<1>(pos) >= get<1>(problem_size)) { + acc_qk(i) = -INFINITY; + } + } + } +}; + +struct ResidualMaskForBackward : NoMask { + + using Base = NoMask; + + template + CUTLASS_DEVICE int get_masked_trip_count( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size) { + + if (get<1>(problem_size) % get<1>(tile_shape) != 0) { + return 1; + } + return 0; + } + + template + CUTLASS_DEVICE + int get_unmasked_trip_count( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size) { + + // if the sequence length does not divide the tile size evenly + if (get<1>(problem_size) % get<1>(tile_shape) != 0) { + return get_trip_count(blk_coord, tile_shape, problem_size) - 1; + } + return get_trip_count(blk_coord, tile_shape, problem_size); + } + + template + CUTLASS_DEVICE + void apply_mask( + AccQK& acc_qk, + IndexQK const& index_qk, + ProblemSize const& problem_size) { + + // This is useful is seqlen_k % kBlockN != 0 since it masks + // the remaining elements out from softmax. + // d % kHeadDim != 0 or seqlen_q % kBlockM do not suffer from similar + // issues as they are transparently taken care of by TMA and the + // epilogue, if it is instantiated with predication support. + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(acc_qk); i++) { + auto pos = index_qk(i); + if (! elem_less(pos, select<0,1>(problem_size))) { + acc_qk(i) = -INFINITY; + } + } + } +}; + +// There are two ways to do causal if N_Q != N_K +// (1) The Q is at the beginning of the matrix +// (2) The Q is at the end of the matrix +template +struct CausalMask : NoMask { + + using Base = NoMask; + + static constexpr bool IsQBegin = kIsQBegin; + + template + CUTLASS_DEVICE + int get_trip_count( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size) { + + // See note below on different ways to think about causal attention + // Again, we'd add the offset_q into the max_blocks_q calculation + int max_blocks_k = Base::get_trip_count(blk_coord, tile_shape, problem_size); + if constexpr (IsQBegin) { + int max_blocks_q = ceil_div((get<0>(blk_coord) + 1) * get<0>(tile_shape), get<1>(tile_shape)); + return std::min(max_blocks_k, max_blocks_q); + } else { + const int offset_q = get<1>(problem_size) - get<0>(problem_size); + int max_blocks_q = ceil_div((get<0>(blk_coord) + 1) * get<0>(tile_shape) + offset_q, get<1>(tile_shape)); + return std::min(max_blocks_k, max_blocks_q); + } + } + + template + CUTLASS_DEVICE + int get_masked_trip_count( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size) { + + int trip_count = get_trip_count(blk_coord, tile_shape, problem_size); + if constexpr (IsQBegin) { + return std::min(trip_count, int(ceil_div(size<0>(tile_shape), size<1>(tile_shape)))); + } else { + const int offset_tile_q = (get<1>(problem_size) - get<0>(problem_size)) % get<1>(tile_shape); + return std::min(trip_count, int(ceil_div(get<0>(tile_shape) + offset_tile_q, get<1>(tile_shape)))); + } + } + + template + CUTLASS_DEVICE + int get_unmasked_trip_count( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& 
problem_size) { + + return get_trip_count(blk_coord, tile_shape, problem_size) - get_masked_trip_count(blk_coord, tile_shape, problem_size); + } + + template + CUTLASS_DEVICE + void apply_mask( + AccQK& acc_qk, + IndexQK const& index_qk, + ProblemSize const& problem_size) { + + // There are two ways to do causal if N_Q != N_K + // (1) is to assume that the Q is at the beginning of the matrix + // - this is the default setting. + // (2) is that it is at the end of the matrix + // - this is usually what we want for inference settings + // where we only compute the next row and use cache for the rest + // - if you'd like this, you only need to set kIsQBegin=false + + if constexpr (IsQBegin) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(acc_qk); i++) { + auto pos = index_qk(i); + if ((get<0>(pos) < get<1>(pos)) || (get<1>(pos) >= get<1>(problem_size))) { + acc_qk(i) = -INFINITY; + } + } + } else { + const auto offset_q = get<1>(problem_size) - get<0>(problem_size); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(acc_qk); i++) { + auto pos = index_qk(i); + if ((get<0>(pos) + offset_q < get<1>(pos)) || (get<1>(pos) >= get<1>(problem_size))) { + acc_qk(i) = -INFINITY; + } + } + } + } +}; + +template +struct CausalForBackwardMask : CausalMask, ResidualMaskForBackward { + + using Base = CausalMask; + + template + CUTLASS_DEVICE + void apply_mask( + AccQK& acc_qk, + IndexQK const& index_qk, + ProblemSize const& problem_size) { + + // There are two ways to do causal if N_Q != N_K + // (1) is to assume that the Q is at the beginning of the matrix + // - this is what we demonstrate here + // (2) is that it is at the end of the matrix + // - this is usually what we want for inference settings + // where we only compute the next row and use cache for the rest + // - if you'd like this, you only need to add an offset like so: + // get<0>(pos) + offset_q < get<1>(pos) + int offset_q = 0; + if constexpr (!kIsQBegin) { + offset_q = get<1>(problem_size) - get<0>(problem_size); + } + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(acc_qk); i++) { + auto pos = index_qk(i); + bool masked = (get<0>(pos) + offset_q < get<1>(pos)) || !elem_less(pos, problem_size); + if (masked) { + acc_qk(i) = -INFINITY; + } + } + } + +}; + +struct VariableLength { + int max_length; + int* cumulative_length = nullptr; + int total_length = -1; + + CUTE_HOST_DEVICE operator int() const { + return max_length; + } +}; + +template struct is_variable_length_impl : std::false_type {}; +template<> struct is_variable_length_impl : std::true_type {}; +template constexpr bool is_variable_length_v = is_variable_length_impl>::value; + +template +CUTE_HOST_DEVICE +constexpr auto +apply_variable_length(Shape const& shape, Idx const& idx) { + return transform_leaf(shape, [&](auto const& s) { + if constexpr (is_variable_length_v) { + return s.cumulative_length[idx+1] - s.cumulative_length[idx]; + } + else { + return s; + } + }); +} + +template +CUTE_HOST_DEVICE +constexpr auto +apply_variable_length(Shape const& shape, Coord const& coord, Idx const& idx) { + auto new_shape = apply_variable_length(shape, idx); + auto new_coord = transform_leaf(shape, coord, [&](auto const& s, auto const& c) { + if constexpr (is_variable_length_v) { + return cute::make_tuple(c, s.cumulative_length[idx]); + } + else { + return c; + } + }); + return cute::make_tuple(new_shape, new_coord); +} + +template +CUTE_HOST_DEVICE +constexpr auto +apply_variable_length_offset(Shape const& shape, Coord const& coord) { + auto idx = back(back(coord)); + auto 
result_shape = transform_leaf(shape, [&](auto const& s) { + if constexpr (is_variable_length_v) { + return s.cumulative_length[idx+1] - s.cumulative_length[idx]; + } + else { + return s; + } + }); + auto result_offset = transform_leaf(coord, shape, [&](auto const& c, auto const& s) { + if constexpr (is_variable_length_v) { + return s.cumulative_length[idx]; + } + else { + return _0{}; + } + }); + return cute::make_tuple(result_shape, result_offset); +} + +} // namespace cutlass::fmha::collective + +namespace cute { + +template<> +struct is_integral : true_type {}; + +CUTE_HOST_DEVICE +void print(cutlass::fmha::collective::VariableLength a) { + printf("Varlen<%d, %p>", a.max_length, a.cumulative_length); +} + +} diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/collective/sm100_fmha_fwd_epilogue_tma_warpspecialized.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/collective/sm100_fmha_fwd_epilogue_tma_warpspecialized.hpp new file mode 100644 index 0000000000000000000000000000000000000000..616357cb0ef3ef4a585be39825ea1f34807e561f --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/collective/sm100_fmha_fwd_epilogue_tma_warpspecialized.hpp @@ -0,0 +1,234 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cute/layout.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" + +namespace cutlass::fmha::collective { + +template< + class Element, + class ElementAcc, + class TileShape, // Q, D, _ + class StrideO, // Q, D, B + class StrideLSE_, // Q, B + class OrderLoadEpilogue = cute::false_type +> +struct Sm100FmhaFwdEpilogueTmaWarpspecialized { + + using Pipeline = cutlass::PipelineAsync<2>; + +// using SmemLayoutO = decltypa(make_layout(append<3>(select<0,1>(TileShape_WG{}), _2{}))); + using SmemLayoutAtomO = decltype(cutlass::gemm::collective::detail::sm100_smem_selector< + cute::UMMA::Major::K, Element, tuple_element_t<0, TileShape>, tuple_element_t<1, TileShape>>()); +// using SmemLayoutAtomO = decltype(make_ordered_layout(select<0,1>(TileShape{}), Step<_1, _0>{})); + using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, replace<2>(TileShape{}, _2{}), Step<_2, _1, _3>{})); + using SmemLayoutO_ = SmemLayoutO; + using StrideLSE = StrideLSE_; + using ElementOut = Element; + + static const int NumWarpsEpilogue = 1; + static const int NumWarpsLoad = 1; + + struct TensorStorage { + + using SmemLayoutO = SmemLayoutO_; + cute::array_aligned> smem_o; + + }; + + struct Arguments { + Element* ptr_O; + StrideO dO; + + ElementAcc* ptr_LSE; + StrideLSE dLSE; + }; + + using TMA_O = decltype(make_tma_copy( + SM90_TMA_STORE{}, + make_tensor((Element*) nullptr, repeat_like(StrideO{}, 0), StrideO{}), + SmemLayoutO{}(_,_,_0{}) + )); + + + struct Params { + TMA_O tma_store_o; + + ElementAcc* ptr_LSE; + StrideLSE dLSE; + }; + + // FMHA and MLA have different input ProblemShapes; + // get problem_shape_O according to the input ProblemShape. 
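+  // Illustrative sketch (hedged, not part of the original source): assuming the
+  // forward problem shape has the form (seq_q, seq_kv, head_dim, (num_heads, batch)),
+  // select<0,2,3>(problem_shape) below yields the O extents
+  // (seq_q, head_dim, (num_heads, batch)), matching StrideO = (Q, D, B).
+  // If mode 2 is itself a tuple (the MLA-style case referred to above), its first
+  // component get<2,0>(problem_shape) is substituted as the O head dim instead.
+  // For example (values are hypothetical):
+  //   auto ps  = cute::make_shape(1024, 1024, 128, cute::make_shape(8, 2));
+  //   auto psO = cute::select<0,2,3>(ps);   // (1024, 128, (8, 2))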
+ template + CUTLASS_DEVICE static constexpr + auto get_problem_shape_O ( + ProblemShape const& problem_shape) { + if constexpr (rank_v(ProblemShape{}))> == 2) { + return replace<1>(select<0,2,3>(problem_shape), get<2, 0>(problem_shape)); + } else { + return select<0,2,3>(problem_shape); + } + } + + template + static Params to_underlying_arguments( + ProblemShape const& problem_shape, + Arguments const& args, + void* workspace = nullptr) { + + auto ptr_O = args.ptr_O; + StrideO dO = args.dO; + + auto problem_shape_O = get_problem_shape_O(problem_shape); + + if constexpr (is_variable_length_v>) { + auto cumulative_length_q = get<0>(problem_shape).cumulative_length; + if (cumulative_length_q != nullptr) { + int max_length_q = get<0>(problem_shape).max_length; + // for variable sequence lenght, the batch is in units of row_stride + get<2,1>(dO) = get<0>(dO); + get<2,1>(problem_shape_O) = max_length_q * (1 + get<2,1>(problem_shape_O)); + // offset ptr by the amount we add back in later + ptr_O -= max_length_q * get<0>(dO); + } + } + + auto tma_store_o = make_tma_copy( + SM90_TMA_STORE{}, + make_tensor(ptr_O, problem_shape_O, dO), + SmemLayoutO{}(_,_,_0{}) + ); + + return { + tma_store_o, + args.ptr_LSE, + args.dLSE + }; + } + + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& params) { + cute::prefetch_tma_descriptor(params.tma_store_o.get_tma_descriptor()); + } + + const Params& params; + + CUTLASS_DEVICE Sm100FmhaFwdEpilogueTmaWarpspecialized(const Params& params) : params(params) {} + + template + CUTLASS_DEVICE auto + store( + BlkCoord const& blk_coord_in, ProblemShape const& problem_shape, + Params const& params, ParamsProblemShape const& params_problem_shape, + TensorStorage& shared_storage, + Pipeline& pipeline, typename Pipeline::PipelineState& pipeline_consumer_state) { + + BlkCoord blk_coord = blk_coord_in; + uint32_t lane_predicate = cute::elect_one_sync(); + + using X = Underscore; + + int o0_index = 2 * get<0>(blk_coord); + int o1_index = 2 * get<0>(blk_coord) + 1; + + Tensor mO_qdl_p = params.tma_store_o.get_tma_tensor(get_problem_shape_O(problem_shape)); + // offset mode 0 by (max_length - real_length) + // offset mode 3,1 by cumulative_length + real_length + // the ptr is already offset by - max_length + // so in total this achieves + int offs_0 = 0; + int offs_2_1 = 0; + + if constexpr (is_variable_length_v>) { + auto cumulative_length_q = get<0>(params_problem_shape).cumulative_length; + if (cumulative_length_q != nullptr) { + int max_length_q = get<0>(params_problem_shape).max_length; + offs_0 = max_length_q - get<0>(problem_shape); + offs_2_1 = cumulative_length_q[get<2,1>(blk_coord)] + get<0>(problem_shape); + get<2,1>(blk_coord) = 0; + } + } + + Tensor mO_qdl = domain_offset(make_coord(offs_0, _0{}, make_coord(_0{}, offs_2_1)), mO_qdl_p); + + Tensor gO_qdl = local_tile(mO_qdl, TileShape{}, make_coord(_, _, _), Step<_1, _1, X>{}); + Tensor gO = gO_qdl(_, _, _, _0{}, get<2>(blk_coord)); + Tensor sO = make_tensor(make_smem_ptr(shared_storage.smem_o.data()), SmemLayoutO{}); + auto block_tma = params.tma_store_o.get_slice(0); + Tensor tOsO = block_tma.partition_S(sO); + Tensor tOgO = block_tma.partition_D(gO); + + auto pipeline_release_state = pipeline_consumer_state; + + // O1 O2 + // one pipeline: O + // wait from corr, issue tma store on smem + pipeline.consumer_wait(pipeline_consumer_state); + ++pipeline_consumer_state; + + if (lane_predicate) { + copy(params.tma_store_o, tOsO(_,_,_,_0{}), tOgO(_,_,_,o0_index)); + } + tma_store_arrive(); + + 
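+    // Illustrative sketch (hedged): the two halves of this CTA's O tile are drained
+    // through one two-stage smem pipeline. Assuming the usual cute TMA-store semantics
+    // (tma_store_arrive() commits the issued store into a group, tma_store_wait<N>()
+    // blocks until at most N committed groups are still pending), the store sequence
+    // here is, with a pipeline consumer_wait before each copy:
+    //   copy(tma, sO stage 0, gO half o0); tma_store_arrive();   // O0 in flight
+    //   copy(tma, sO stage 1, gO half o1); tma_store_arrive();   // O1 in flight
+    //   tma_store_wait<1>();   // O0 has landed -> release its pipeline stage
+    //   tma_store_wait<0>();   // O1 has landed -> release the second stage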
pipeline.consumer_wait(pipeline_consumer_state); + ++pipeline_consumer_state; + + if (lane_predicate) { + copy(params.tma_store_o, tOsO(_,_,_,_1{}), tOgO(_,_,_,o1_index)); + } + tma_store_arrive(); + + tma_store_wait<1>(); + + pipeline.consumer_release(pipeline_release_state); + ++pipeline_release_state; + + tma_store_wait<0>(); + + if constexpr (cute::is_same_v) { + cutlass::arch::NamedBarrier::arrive((NumWarpsLoad + NumWarpsEpilogue) * NumThreadsPerWarp, + cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); + } + + pipeline.consumer_release(pipeline_release_state); + ++pipeline_release_state; + + } + +}; + +} // namespace cutlass::fmha::collective diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp new file mode 100644 index 0000000000000000000000000000000000000000..1e094bf42dd7235f69f9db5f8e5e569d313c7ce6 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/collective/sm100_fmha_fwd_mainloop_tma_warpspecialized.hpp @@ -0,0 +1,1218 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/arch/memory_sm80.h" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cute/arch/simd_sm100.hpp" +#include "cute/tensor.hpp" +#include "cute/layout.hpp" + +#include "collective/fmha_common.hpp" +#include "collective/fmha_fusion.hpp" +#include "collective/sm100_fmha_load_tma_warpspecialized.hpp" + +namespace cutlass::fmha::collective { + +using namespace cute; + +template< + class Element_, + class ElementQK_, + class ElementPV_, + class TileShape_, + class StrideQ_, + class StrideK_, + class StrideV_, + class Mask_, + // shape here is QG K H + // and referes to the two softmax warps + // (2, 1, 1) means that they are stacked (best for large Q since it loads the least K/V) + // (1, 2, 1) means they sit side by side (best for small Q / large K) + class ThreadShape = Shape<_2, _1, _1>, + // Since shared memory is sufficient for FMHA, there is no need to reuse shared memory. + class OrderLoadEpilogue = cute::false_type +> +struct Sm100FmhaFwdMainloopTmaWarpspecialized { + + using Element = Element_; + using ElementQK = ElementQK_; + using ElementPV = ElementPV_; + using TileShape = TileShape_; + using StrideQ = StrideQ_; + using StrideK = StrideK_; + using StrideV = StrideV_; + using Mask = Mask_; + + static constexpr int StageCountQ = 2; + static constexpr int StageCountKV = sizeof(Element_) == 1 ? 4 : 3; + + using StagesQ = cutlass::gemm::collective::StageCount; + using StagesKV = cutlass::gemm::collective::StageCount; + + using ClusterShape = Shape<_1, _1, _1>; + + static const int Alignment = 128 / sizeof_bits_v; + + using TileShapeQK = decltype(shape_div(TileShape{}, ThreadShape{})); + + using TileShapePV = decltype(select<0,2,1>(TileShapeQK{})); + + using CollectiveMmaQK = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + Element, StrideQ, Alignment, + Element, StrideK, Alignment, + ElementQK, + TileShapeQK, ClusterShape, cutlass::gemm::collective::StageCount<3> /* we change it later anyways*/, + cutlass::gemm::KernelTmaWarpSpecialized1SmSm100>::CollectiveOp; + + using CollectiveMmaPV = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + // the stride for A does not matter since we do not load from smem at all + Element, StrideK, Alignment, + Element, decltype(select<1,0,2>(StrideV{})), Alignment, + ElementPV, + TileShapePV, ClusterShape, cutlass::gemm::collective::StageCount<3> /* we change it later anyways*/, + cutlass::gemm::KernelTmaWarpSpecialized1SmSm100>::CollectiveOp; + + using SmemLayoutQ = decltype(unstageSmemLayout(typename CollectiveMmaQK::SmemLayoutA{}, Int{})); + using SmemLayoutK = decltype(unstageSmemLayout(typename CollectiveMmaQK::SmemLayoutB{}, Int{})); + using SmemLayoutV = decltype(unstageSmemLayout(typename CollectiveMmaPV::SmemLayoutB{}, Int{})); + + // Reuse shared memory for V and O. 
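+  // Illustrative note (hedged): Q is double-buffered (StageCountQ = 2) while K and V
+  // share a deeper ring (StageCountKV = 4 for 8-bit elements, otherwise 3). The union
+  // in TensorStorage below overlays the K and V buffers, so each stage slot holds
+  // either a K tile or a V tile at any given time; the static_assert on the K/V
+  // transaction bytes further down relies on their per-stage footprints matching.
+  // Rough smem budget for a hypothetical 128x128 half_t KV tile:
+  //   128 * 128 * sizeof(half_t) = 32 KiB per stage, ~96 KiB across 3 stages.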
+ static constexpr bool IsOrderLoadEpilogue = std::is_same_v; + struct TensorStorage { + cute::array_aligned> smem_q; + union { + cute::array_aligned> smem_k; + cute::array_aligned> smem_v; + }; + }; + + enum class TmemAllocation : uint32_t { + kSizeS = 128, + kSizeO = 128, + kSizeP = 32, + S0 = 0, + S1 = S0 + kSizeS, + V0 = S0, // stats storage from softmax to correction + V1 = S1, + P0 = S0 + kSizeP, + P1 = S1 + kSizeP, + O0 = S1 + kSizeS, + O1 = O0 + kSizeO, + kEnd = O1 + kSizeO + }; + + // indices for V0 / V1 + enum : int { + kIdxOldRowMax = 0, + kIdxNewRowMax = 1, + kIdxFinalRowSum = 0, + kIdxFinalRowMax = 1 + }; + + // from load to mma warp, protects q in smem + using PipelineQ = cutlass::PipelineTmaUmmaAsync< + StageCountQ, + typename CollectiveMmaQK::AtomThrShapeMNK + >; + + // from load to mma warp, protects k/v in smem + using PipelineKV = cutlass::PipelineTmaUmmaAsync< + StageCountKV, + typename CollectiveMmaQK::AtomThrShapeMNK + >; + + // from mma to softmax0/1 warp, protects S in tmem + // (not sure yet about the reverse direction) + // there is one pipe per softmax warp, and the mma warp alternates between them + using PipelineS = cutlass::PipelineUmmaAsync<1>; + + // from softmax0/1/ to correction wg + using PipelineC = cutlass::PipelineAsync<1>; + + // from mma to correction + using PipelineO = cutlass::PipelineUmmaAsync<2>; + + // from corr to epilogue + using PipelineE = cutlass::PipelineAsync<2>; + + using OrderBarrierSoftmax = cutlass::OrderedSequenceBarrier< + /*stages*/ 1, /*groups*/ 2>; + + static const int TransactionBytesLoadQ = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutQ{})) * cute::sizeof_bits_v); + + static const int TransactionBytesLoadK = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutK{})) * cute::sizeof_bits_v); + static const int TransactionBytesLoadV = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutV{})) * cute::sizeof_bits_v); + + static_assert(TransactionBytesLoadK == TransactionBytesLoadV, "K and V smem layouts must be of equal size"); + + using Load = Sm100FmhaLoadTmaWarpspecialized< + Element, StrideQ, StrideK, StrideV, + CollectiveMmaQK, CollectiveMmaPV, + SmemLayoutQ, SmemLayoutK, SmemLayoutV, + TensorStorage, PipelineQ, PipelineKV, Mask, TileShape + >; + + struct Arguments { + typename Load::Arguments load; + + // if zero, defaults to 1/sqrt(D) + float scale_softmax = 0.0f; + + // scaling factors to dequantize QKV + float scale_q = 1.0f; + float scale_k = 1.0f; + float scale_v = 1.0f; + + // scaling factor to quantize O + float inv_scale_o = 1.0f; + }; + + struct Params { + typename Load::Params load; + + float scale_softmax; + float scale_softmax_log2; + + float scale_output; + }; + + template + static bool can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return true; + } + + template + static Params to_underlying_arguments( + ProblemShape const& problem_shape, + Arguments const& args, + void* workspace) { + + float scale_softmax = args.scale_softmax; + if (scale_softmax == 0.0f) { + scale_softmax = 1.0f / (float) std::sqrt(get<2>(problem_shape)); + } + float log2_e = static_cast(std::log2(std::exp(1.0))); + + return Params{ + Load::to_underlying_arguments(problem_shape, args.load, workspace), + args.scale_q * args.scale_k * scale_softmax, + args.scale_q * args.scale_k * log2_e * scale_softmax, + args.scale_v * args.inv_scale_o + }; + } + + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& params) { + Load::prefetch_tma_descriptors(params.load); + } + + template + CUTLASS_DEVICE void + 
load( + BlkCoord const& blk_coord, ProblemShape const& problem_shape, + Params const& params, ParamsProblemShape const& params_problem_shape, + TensorStorage& storage, + PipelineQ& pipeline_q, typename PipelineQ::PipelineState& pipeline_q_producer_state, + PipelineKV& pipeline_kv, typename PipelineKV::PipelineState& pipeline_kv_producer_state) { + + Load load; + load.load(blk_coord, problem_shape, params.load, params_problem_shape, + storage, + pipeline_q, pipeline_q_producer_state, + pipeline_kv, pipeline_kv_producer_state); + } + + template + CUTLASS_DEVICE auto + mma( + BlkCoord const& blk_coord, + Params const& params, ProblemShape const& problem_shape, + TensorStorage& storage, + PipelineQ& pipeline_q, typename PipelineQ::PipelineState& pipeline_q_consumer_state, + PipelineKV& pipeline_kv, typename PipelineKV::PipelineState& pipeline_kv_consumer_state, + PipelineS& pipeline_s0, typename PipelineS::PipelineState& pipeline_s0_producer_state, + PipelineS& pipeline_s1, typename PipelineS::PipelineState& pipeline_s1_producer_state, + PipelineO& pipeline_corr, typename PipelineO::PipelineState& pipeline_corr_producer_state) { + + auto pipeline_q_release_state = pipeline_q_consumer_state; + auto pipeline_kv_release_state = pipeline_kv_consumer_state; + + int mask_tile_count = Mask{}.get_trip_count(blk_coord, TileShape{}, problem_shape); + + typename CollectiveMmaQK::TiledMma mma_qk; + ThrMMA thr_mma_qk = mma_qk.get_slice(0); + + typename CollectiveMmaPV::TiledMma mma_pv; + TiledMMA mma_pv_ts = to_tiled_mma_sm100_ts(mma_pv); + ThrMMA thr_mma_pv = mma_pv_ts.get_slice(0); + + Tensor sQ = make_tensor(make_smem_ptr(storage.smem_q.data()), SmemLayoutQ{}); + Tensor sK = make_tensor(make_smem_ptr(storage.smem_k.data()), SmemLayoutK{}); + Tensor sV = make_tensor(make_smem_ptr(storage.smem_v.data()), SmemLayoutV{}); + + Tensor tSrQ = thr_mma_qk.make_fragment_A(sQ); + Tensor tSrK = thr_mma_qk.make_fragment_B(sK); + Tensor tOrV = thr_mma_pv.make_fragment_B(sV); + + // tmem layout is + // S0 S1`O0 O1 + // sequential in memory, where S overlaps with P and V + + Tensor tStS = partition_fragment_C(mma_qk, select<0,1>(TileShapeQK{})); + Tensor tOtO = partition_fragment_C(mma_pv_ts, select<0,1>(TileShapePV{})); + + Tensor tStS0 = tStS; + tStS0.data() = tStS.data().get() + uint32_t(TmemAllocation::S0); + Tensor tStS1 = tStS; + tStS1.data() = tStS.data().get() + uint32_t(TmemAllocation::S1); + + Tensor tOtO0 = tOtO; + tOtO0.data() = tOtO.data().get() + uint32_t(TmemAllocation::O0); + Tensor tOtO1 = tOtO; + tOtO1.data() = tOtO.data().get() + uint32_t(TmemAllocation::O1); + + Tensor sP = make_tensor(make_smem_ptr((Element*)nullptr), typename CollectiveMmaPV::SmemLayoutA{}); + Tensor tOrP = thr_mma_pv.make_fragment_A(sP)(_, _, _, _0{}); // slice out staging + + Tensor tOrP0 = tOrP; + tOrP0.data() = tOrP0.data().get() + uint32_t(TmemAllocation::P0); + Tensor tOrP1 = tOrP; + tOrP1.data() = tOrP1.data().get() + uint32_t(TmemAllocation::P1); + + int k_index = 0; + int v_index = 0; + int q_index = 0; + + // wait for Q1 + q_index = pipeline_q_consumer_state.index(); + pipeline_q.consumer_wait(pipeline_q_consumer_state); + ++pipeline_q_consumer_state; + + Tensor tSrQ0 = tSrQ(_,_,_,q_index); + + + // wait for K1 + k_index = pipeline_kv_consumer_state.index(); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + + // gemm Q1 * K1 -> S1 + pipeline_s0.producer_acquire(pipeline_s0_producer_state); + + gemm_zero_acc(mma_qk, tSrQ0, tSrK(_,_,_,k_index), tStS0); + + 
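+    // Illustrative note (hedged) on the helpers from fmha_common.hpp: gemm_zero_acc()
+    // clears the accumulate flag before the first k-block, so S is overwritten rather
+    // than accumulated into, and every later k-block accumulates; roughly:
+    //   atom.accumulate_ = Zero;
+    //   for (k_block) { cute::gemm(atom, A(k_block), B(k_block), C); atom.accumulate_ = One; }
+    // gemm_reset_zero_acc() is the same loop without the initial clear, which is why the
+    // P*V GEMMs further down set mma_pv_ts.accumulate_ explicitly and then keep accumulating.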
pipeline_s0.producer_commit(pipeline_s0_producer_state); + ++pipeline_s0_producer_state; + + // release K1 + if constexpr (get<1>(ThreadShape{}) > 1) { + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + } + + // wait for Q2 + if constexpr (get<0>(ThreadShape{}) > 1 || get<2>(ThreadShape{}) > 1) { + q_index = pipeline_q_consumer_state.index(); + pipeline_q.consumer_wait(pipeline_q_consumer_state); + ++pipeline_q_consumer_state; + } + + Tensor tSrQ1 = tSrQ(_,_,_,q_index); + + if constexpr (get<1>(ThreadShape{}) > 1) { + k_index = pipeline_kv_consumer_state.index(); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + } + + pipeline_s1.producer_acquire(pipeline_s1_producer_state); + + // gemm Q2 * K1 -> S2 + gemm_zero_acc(mma_qk, tSrQ1, tSrK(_,_,_,k_index), tStS1); + + pipeline_s1.producer_commit(pipeline_s1_producer_state); + ++pipeline_s1_producer_state; + + // release K1 + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + + // wait for V1 + v_index = pipeline_kv_consumer_state.index(); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + + // this acquire returns the ownership of all of S0 to the mma warp + // including the P0 part + // acquire corr first to take it out of the critical + // path since softmax takes longer + pipeline_corr.producer_acquire(pipeline_corr_producer_state); + pipeline_s0.producer_acquire(pipeline_s0_producer_state); + + // gemm P1 * V1 -> O1 + gemm_zero_acc(mma_pv_ts, tOrP0, tOrV(_,_,_,v_index), tOtO0); + + pipeline_corr.producer_commit(pipeline_corr_producer_state); + ++pipeline_corr_producer_state; + + if constexpr (get<1>(ThreadShape{}) > 1) { + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + } + + mma_pv_ts.accumulate_ = UMMA::ScaleOut::Zero; + + // loop: + mask_tile_count -= 1; + for (; mask_tile_count > 0; mask_tile_count -= 1) { + + // wait for Ki + k_index = (pipeline_kv_consumer_state.index()); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + + // gemm Q1 * Ki -> S1 + gemm_zero_acc(mma_qk, tSrQ0, tSrK(_,_,_,k_index), tStS0); + + pipeline_s0.producer_commit(pipeline_s0_producer_state); + ++pipeline_s0_producer_state; + + if constexpr (get<1>(ThreadShape{}) > 1) { + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + } + + // gemm P2 * V(i-1) -> O2 + if constexpr (get<1>(ThreadShape{}) > 1) { + v_index = pipeline_kv_consumer_state.index(); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + } + + pipeline_corr.producer_acquire(pipeline_corr_producer_state); + pipeline_s1.producer_acquire(pipeline_s1_producer_state); + + gemm_reset_zero_acc(mma_pv_ts, tOrP1, tOrV(_,_,_,v_index), tOtO1); + + pipeline_corr.producer_commit(pipeline_corr_producer_state); + ++pipeline_corr_producer_state; + + // release V(i-1) + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + + if constexpr (get<1>(ThreadShape{}) > 1) { + k_index = (pipeline_kv_consumer_state.index()); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + } + + // gemm Q2 * Ki -> S2 + gemm_zero_acc(mma_qk, tSrQ1, tSrK(_,_,_,k_index), tStS1); + + pipeline_s1.producer_commit(pipeline_s1_producer_state); + ++pipeline_s1_producer_state; + + // release Ki + pipeline_kv.consumer_release(pipeline_kv_release_state); + 
++pipeline_kv_release_state; + + // wait for Vi + v_index = (pipeline_kv_consumer_state.index()); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + + // gemm P1 * Vi -> O1 + pipeline_corr.producer_acquire(pipeline_corr_producer_state); + + pipeline_s0.producer_acquire(pipeline_s0_producer_state); + + gemm_reset_zero_acc(mma_pv_ts, tOrP0, tOrV(_,_,_,v_index), tOtO0); + + pipeline_corr.producer_commit(pipeline_corr_producer_state); + ++pipeline_corr_producer_state; + + if constexpr (get<1>(ThreadShape{}) > 1) { + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + } + } + + // release Q1 + pipeline_q.consumer_release(pipeline_q_release_state); + ++pipeline_q_release_state; + + // release Q2 + if constexpr (get<0>(ThreadShape{}) > 1) { + pipeline_q.consumer_release(pipeline_q_release_state); + ++pipeline_q_release_state; + } + + // wait for Vi + if constexpr (get<1>(ThreadShape{}) > 1) { + v_index = pipeline_kv_consumer_state.index(); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + } + + // gemm P2 * Vi -> O2 + pipeline_corr.producer_acquire(pipeline_corr_producer_state); + pipeline_s1.producer_acquire(pipeline_s1_producer_state); + + gemm_reset_zero_acc(mma_pv_ts, tOrP1, tOrV(_,_,_,v_index), tOtO1); + + pipeline_corr.producer_commit(pipeline_corr_producer_state); + ++pipeline_corr_producer_state; + + // release Vi + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + + pipeline_s0.producer_commit(pipeline_s0_producer_state); + ++pipeline_s0_producer_state; + + pipeline_s1.producer_commit(pipeline_s1_producer_state); + ++pipeline_s1_producer_state; + + // T0 S00 B1, T0 S10 B1, T0 S00 B2, T0 S01 B1, T0 S10 B2, T0 S11 B1, T0 S01 B2, T1 S00 B1, T0 S11 B2, ... + // Q1 * K1 , Q2 * K1 , S11 * V1 , Q1 * K2 , S21 * V1 , Q2 * K2 , S12 * V2 , Q1 * K3 , S22 * K2 , ... + } + + template + CUTLASS_DEVICE auto + softmax_step( + float& row_max, float& row_sum, + Stage stage, bool final_call, + BlkCoord const& blk_coord, CoordTensor const& cS, + Params const& params, ProblemShape const& problem_shape, + PipelineS& pipeline_s, typename PipelineS::PipelineState& pipeline_s_consumer_state, + PipelineC& pipeline_c, typename PipelineC::PipelineState& pipeline_c_producer_state, + OrderBarrierSoftmax& order_s) { + + Tensor tScS = typename CollectiveMmaQK::TiledMma{}.get_slice(0).partition_C(cS); + + Tensor tStS = partition_fragment_C(typename CollectiveMmaQK::TiledMma{}, select<0,1>(TileShapeQK{})); + tStS.data() = uint32_t(stage == _0{} ? TmemAllocation::S0 : TmemAllocation::S1); + + Tensor tStS_v = tStS.compose(make_layout(make_shape(_128{}, _2{}))); + tStS_v.data() = uint32_t(stage == _0{} ? TmemAllocation::V0 : TmemAllocation::V1); + Tensor tScS_v = tScS.compose(make_layout(make_shape(_128{}, _2{}))); + + auto tilePlikeFP32 = size<1>(TileShapeQK{}) / Int{} * Int{}; + Tensor tStS_P = tStS.compose(make_layout(make_shape(_128{}, tilePlikeFP32))); + tStS_P.data() = warp_uniform(uint32_t(stage == _0{} ? 
TmemAllocation::P0 : TmemAllocation::P1)); + Tensor tScS_P = tScS.compose(make_layout(make_shape(_128{}, tilePlikeFP32))); + + // Each thread owns a single row + using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b32x; // 4x32 threads with 128 cols of 32b elem + using TMEM_STORE = SM100_TMEM_STORE_32dp32b32x; // 4x32 threads with 128 cols of 8b elem + using TMEM_STORE_V = SM100_TMEM_STORE_32dp32b2x; // 4x32 threads with 2 cols of 32b elem + + int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp); + + auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tStS); + auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx); + + Tensor tTMEM_LOADtS = thr_tmem_load.partition_S(tStS); + Tensor tTMEM_LOADcS = thr_tmem_load.partition_D(tScS); + + auto tiled_tmem_storev = make_tmem_copy(TMEM_STORE_V{}, tStS_v); + auto thr_tmem_storev = tiled_tmem_storev.get_slice(thread_idx); + + Tensor tTMEM_STOREVtS = thr_tmem_storev.partition_D(tStS_v); + Tensor tTMEM_STOREVcS = thr_tmem_storev.partition_S(tScS_v); + + auto tiled_tmem_store = make_tmem_copy(TMEM_STORE{}, tStS_P); + auto thr_tmem_store = tiled_tmem_store.get_slice(thread_idx); + + Tensor tTMEM_STOREtS_x4 = thr_tmem_store.partition_D(tStS_P); + tTMEM_STOREtS_x4.data() = warp_uniform(tTMEM_STOREtS_x4.data().get()); + Tensor tTMEM_STOREcS = thr_tmem_store.partition_S(tScS_P); + + // wait on tensor core pipe + pipeline_s.consumer_wait(pipeline_s_consumer_state); + + // read all of S from tmem into reg mem + Tensor tTMEM_LOADrS = make_tensor(shape(tTMEM_LOADcS)); + copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS); + + if constexpr (need_apply_mask) { + Mask{}.apply_mask(tTMEM_LOADrS, tTMEM_LOADcS, problem_shape); + } + + ElementQK old_row_max = row_max; + { + // compute rowmax + float row_max_0 = row_max; + float row_max_1 = row_max; + float row_max_2 = row_max; + float row_max_3 = row_max; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTMEM_LOADrS); i += 4) { + row_max_0 = ::fmax(row_max_0, tTMEM_LOADrS(i)); + row_max_1 = ::fmax(row_max_1, tTMEM_LOADrS(i+1)); + row_max_2 = ::fmax(row_max_2, tTMEM_LOADrS(i+2)); + row_max_3 = ::fmax(row_max_3, tTMEM_LOADrS(i+3)); + } + row_max = ::fmax(row_max_0, row_max_1); + row_max = ::fmax(row_max, row_max_2); + row_max = ::fmax(row_max, row_max_3); + } + + ElementQK row_max_safe = row_max == -INFINITY ? 
0 : row_max; + + Tensor tTMEM_STOREVrS = make_tensor(shape(tTMEM_STOREVcS)); + tTMEM_STOREVrS(kIdxOldRowMax) = old_row_max; + tTMEM_STOREVrS(kIdxNewRowMax) = row_max_safe; + copy(tiled_tmem_storev, tTMEM_STOREVrS, tTMEM_STOREVtS); + + pipeline_c.producer_commit(pipeline_c_producer_state); + ++pipeline_c_producer_state; + + // notify correction wg that they are ready (might need addtl ordering between S0 and S1 WG's) + + ElementQK scale = params.scale_softmax_log2; + ElementQK row_max_scale = row_max_safe * scale; + + float2 scale_fp32x2 = make_float2(scale, scale); + float2 minus_row_max_scale_fp32x2 = make_float2(-row_max_scale, -row_max_scale); + + Tensor tTMEM_STORErS_x4 = make_tensor(shape(tTMEM_STOREcS)); + + constexpr int kConversionsPerStep = 2; + + Tensor tTMEM_STORErS_x4_e = recast>(tTMEM_STORErS_x4); + + NumericArrayConverter convert; + + const int kReleasePipeCount = 10; // must be multiple of 2 + + order_s.wait(); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTMEM_LOADrS); i += 2) { + float2 in = make_float2( + tTMEM_LOADrS(i + 0), + tTMEM_LOADrS(i + 1) + ); + float2 out; + cute::fma(out, scale_fp32x2, in, minus_row_max_scale_fp32x2); + tTMEM_LOADrS(i + 0) = out.x; + tTMEM_LOADrS(i + 1) = out.y; + + tTMEM_LOADrS(i+0) = ::exp2f(tTMEM_LOADrS(i+0)); + tTMEM_LOADrS(i+1) = ::exp2f(tTMEM_LOADrS(i+1)); + + Array in_conv; + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < kConversionsPerStep; j++) { + in_conv[j] = tTMEM_LOADrS(i + j); + } + tTMEM_STORErS_x4_e[i / kConversionsPerStep] = convert(in_conv); + + + if (i == size(tTMEM_LOADrS) - kReleasePipeCount) { + order_s.arrive(); + } + + // this prevents register spills in fp16 + if constexpr (size<2>(tTMEM_STORErS_x4) == _2{}) { + if (i == size(tTMEM_LOADrS) - 6) { + copy(tiled_tmem_store, tTMEM_STORErS_x4(_, _, 0), tTMEM_STOREtS_x4(_, _, 0)); + } + } + } + + // tmem_store(reg_S8) -> op_P + CUTE_STATIC_ASSERT_V(size<2>(tTMEM_STORErS_x4) <= _2{}); + CUTE_STATIC_ASSERT_V(size<1>(tTMEM_STORErS_x4) == _1{}); + copy(tiled_tmem_store, tTMEM_STORErS_x4(_, _, size<2>(tTMEM_STORErS_x4) - 1), tTMEM_STOREtS_x4(_, _, size<2>(tTMEM_STORErS_x4) - 1)); + + cutlass::arch::fence_view_async_tmem_store(); + + // notify tensor core warp that P is ready + pipeline_s.consumer_release(pipeline_s_consumer_state); + ++pipeline_s_consumer_state; + + pipeline_c.producer_acquire(pipeline_c_producer_state); + + ElementQK acc_scale = 0.5f * ::exp2f(scale * (old_row_max - row_max_safe)); + row_sum *= acc_scale; + // row_sum = sum(reg_S) + float2 local_row_sum_f32x2 = make_float2(row_sum, row_sum); + float2 local_row_sum_1 = make_float2(0, 0); + float2 local_row_sum_2 = make_float2(0, 0); + float2 local_row_sum_3 = make_float2(0, 0); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTMEM_LOADrS); i += 8) { + // row_sum += tTMEM_LOADrS(i); + float2 in = make_float2(tTMEM_LOADrS(i), tTMEM_LOADrS(i+1)); + cute::add(local_row_sum_f32x2, local_row_sum_f32x2, in); + + in = make_float2(tTMEM_LOADrS(i+2), tTMEM_LOADrS(i+2+1)); + cute::add(local_row_sum_1, local_row_sum_1, in); + + in = make_float2(tTMEM_LOADrS(i+4), tTMEM_LOADrS(i+4+1)); + cute::add(local_row_sum_2, local_row_sum_2, in); + + in = make_float2(tTMEM_LOADrS(i+6), tTMEM_LOADrS(i+6+1)); + cute::add(local_row_sum_3, local_row_sum_3, in); + } + + cute::add(local_row_sum_f32x2, local_row_sum_f32x2, local_row_sum_1); + cute::add(local_row_sum_2, local_row_sum_2, local_row_sum_3); + cute::add(local_row_sum_f32x2, local_row_sum_f32x2, local_row_sum_2); + float local_row_sum = local_row_sum_f32x2.x + 
local_row_sum_f32x2.y; + + row_sum = local_row_sum; + + if (final_call) { + // re-acquire the S part in the final step + pipeline_s.consumer_wait(pipeline_s_consumer_state); + + Tensor tTMEM_STOREVrS = make_tensor(shape(tTMEM_STOREVcS)); + tTMEM_STOREVrS(kIdxFinalRowMax) = row_max; + tTMEM_STOREVrS(kIdxFinalRowSum) = row_sum; + copy(tiled_tmem_storev, tTMEM_STOREVrS, tTMEM_STOREVtS); + } + } + + template + CUTLASS_DEVICE auto + softmax( + Stage stage, + BlkCoord const& blk_coord, + Params const& params, ProblemShape const& problem_shape, + PipelineS& pipeline_s, typename PipelineS::PipelineState& pipeline_s_consumer_state, + PipelineC& pipeline_c, typename PipelineC::PipelineState& pipeline_c_producer_state, + OrderBarrierSoftmax& order_s) { + + int mask_tile_count = Mask{}.get_unmasked_trip_count(blk_coord, TileShape{}, problem_shape); + + ElementQK row_max = -INFINITY; + ElementQK row_sum = 0; + + Tensor cS_base = make_identity_tensor(select<0,1>(TileShapeQK{})); + auto logical_offset = make_coord( + get<0>(blk_coord) * get<0>(TileShape{}) + (stage % get<0>(ThreadShape{})) * get<0>(TileShapeQK{}), + 0 + (stage % get<1>(ThreadShape{})) * get<1>(TileShapeQK{}) + ); + Tensor cS = domain_offset(logical_offset, cS_base); + + pipeline_c.producer_acquire(pipeline_c_producer_state); + + CUTLASS_PRAGMA_NO_UNROLL + for (; mask_tile_count > 0; mask_tile_count -= 1) { + softmax_step( + row_max, row_sum, stage, + (mask_tile_count == 1) && + (Mask{}.get_masked_trip_count(blk_coord, TileShape{}, problem_shape) == 0), + blk_coord, cS, params, problem_shape, + pipeline_s, pipeline_s_consumer_state, + pipeline_c, pipeline_c_producer_state, + order_s + ); + + cS.data() = cS.data() + E<1>{} * get<1>(ThreadShape{}) * get<1>(TileShapeQK{}); + } + + // Masked iterations + mask_tile_count = Mask{}.get_masked_trip_count(blk_coord, TileShape{}, problem_shape); + + CUTLASS_PRAGMA_NO_UNROLL + for (; mask_tile_count > 0; mask_tile_count -= 1) { + softmax_step( + row_max, row_sum, stage, mask_tile_count == 1, + blk_coord, cS, params, problem_shape, + pipeline_s, pipeline_s_consumer_state, + pipeline_c, pipeline_c_producer_state, + order_s + ); + + cS.data() = cS.data() + E<1>{} * get<1>(ThreadShape{}) * get<1>(TileShapeQK{}); + } + + pipeline_c.producer_commit(pipeline_c_producer_state); + ++pipeline_c_producer_state; + + pipeline_c.producer_acquire(pipeline_c_producer_state); + // empty step to sync against pipe s + pipeline_s.consumer_release(pipeline_s_consumer_state); + ++pipeline_s_consumer_state; + } + + template + CUTLASS_DEVICE auto + correction_epilogue( + float scale, + Stage stage, + TensorO const& sO_01) { + + using ElementOut = typename TensorO::value_type; + + int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp); + + Tensor sO = sO_01(_,_,stage); + + // As opposed to the softmax, we do not have enough registers here + // to load all of the values (for tile kv = 128), so we loop + // good values would be either 32 or 64 + const int kCorrectionTileSize = 32 / sizeof(ElementOut); + + using TMEM_LOAD = std::conditional_t; // 4x32 threads with 64 cols of 32b elem + + typename CollectiveMmaPV::TiledMma mma; + Tensor cO = make_identity_tensor(select<0,1>(TileShapePV{})); + Tensor tOtO = partition_fragment_C(mma, select<0,1>(TileShapePV{})); + Tensor tOcO = mma.get_slice(0).partition_C(cO); + Tensor tOsO = mma.get_slice(0).partition_C(sO); + + Tensor tOtO_i = logical_divide(tOtO, make_layout(make_shape(_128{}, Int{}))); + Tensor tOcO_i = logical_divide(tOcO, make_layout(make_shape(_128{}, Int{}))); + 
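+    // Illustrative note (hedged): unlike the softmax path, the full O tile does not fit
+    // in registers here, so the accumulator / coordinate / smem views are chunked into
+    // kCorrectionTileSize = 32 / sizeof(ElementOut) columns (e.g. 16 columns for a
+    // 2-byte output type). Each chunk is TMEM_LOADed, scaled with packed float2 FMULs,
+    // converted to ElementOut, and written to smem, giving head_dim / kCorrectionTileSize
+    // iterations of the loop below (e.g. 128 / 16 = 8 for a hypothetical head dim of 128).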
Tensor tOsO_i = logical_divide(tOsO, make_layout(make_shape(_128{}, Int{}))); + + if constexpr (decltype(stage == _0{})::value) { + tOtO_i.data() = tOtO_i.data().get() + uint32_t(TmemAllocation::O0); + } + else { + static_assert(decltype(stage == _1{})::value, "stage is either 0 or 1"); + tOtO_i.data() = tOtO_i.data().get() + uint32_t(TmemAllocation::O1); + } + + auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tOtO_i(make_coord(_, _), _0{})); + auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx); + + Tensor tTMEM_LOADtO = thr_tmem_load.partition_S(tOtO_i(make_coord(_, _), _)); + Tensor tTMEM_LOADcO = thr_tmem_load.partition_D(tOcO_i(make_coord(_, _), _)); + Tensor tTMEM_LOADsO = thr_tmem_load.partition_D(tOsO_i(make_coord(_, _), _)); + + float2 scale_f32x2 = make_float2(scale, scale); + + // loop: + // TMEM_LOAD, FMUL2 scale, TMEM_STORE + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < get<2>(TileShape{}) / kCorrectionTileSize; i++) { + Tensor tTMEM_LOADtO_i = tTMEM_LOADtO(_, _0{}, _0{}, i); + Tensor tTMEM_LOADsO_i = tTMEM_LOADsO(_, _0{}, _0{}, i); + + Tensor tTMrO = make_tensor(shape(tTMEM_LOADcO(_, _0{}, _0{}, i))); + + copy(tiled_tmem_load, tTMEM_LOADtO_i, tTMrO); + +#ifndef ONLY_SOFTMAX + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size(tTMrO); j += 2) { + float2 in = make_float2(tTMrO(j), tTMrO(j+1)); + float2 out; + cute::mul(out, scale_f32x2, in); + tTMrO(j) = out.x; + tTMrO(j+1) = out.y; + } +#endif + + constexpr int N = 4 / sizeof(ElementOut); + NumericArrayConverter convert; + + Tensor tSMrO = make_tensor_like(tTMrO); + + Tensor tCs = recast(tTMrO); + Tensor tCd = recast(tSMrO); + + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size(tCs); j++) { + tCd(j) = convert.convert(tCs(j)); + } + + Tensor tSMsO_i = recast(tTMEM_LOADsO_i); + Tensor tSMrO_i = recast(tSMrO); + + copy(AutoVectorizingCopyWithAssumedAlignment<128>{}, tSMrO_i, tSMsO_i); + } + + cutlass::arch::fence_view_async_shared(); + } + + CUTLASS_DEVICE auto + correction_rescale( + float scale, + uint32_t tmem_O) { + + int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp); + + // As opposed to the softmax, we do not have enough registers here + // to load all of the values (for tile kv = 128), so we loop + // good values would be either 32 or 64 + const int kCorrectionTileSize = 16; + + using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b16x; // 4x32 threads with 64 cols of 32b elem + using TMEM_STORE = SM100_TMEM_STORE_32dp32b16x; // 4x32 threads with 64 cols of 32b elem + + typename CollectiveMmaPV::TiledMma mma; + Tensor cO = make_identity_tensor(select<0,1>(TileShapePV{})); + Tensor tOtO = partition_fragment_C(mma, select<0,1>(TileShapePV{})); + Tensor tOcO = mma.get_slice(0).partition_C(cO); + + Tensor tOtO_i = tOtO.compose(make_layout(make_shape(_128{}, Int{}))); + Tensor tOcO_i = tOcO.compose(make_layout(make_shape(_128{}, Int{}))); + + tOtO_i.data() = tOtO_i.data().get() + tmem_O; + + auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tOtO_i); + auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx); + auto tiled_tmem_store = make_tmem_copy(TMEM_STORE{}, tOtO_i); + auto thr_tmem_store = tiled_tmem_store.get_slice(thread_idx); + + Tensor tTMEM_LOADtO = thr_tmem_load.partition_S(tOtO_i); + Tensor tTMEM_LOADcO = thr_tmem_load.partition_D(tOcO_i); + Tensor tTMEM_STOREtO = thr_tmem_store.partition_D(tOtO_i); + Tensor tTMEM_STOREcO = thr_tmem_store.partition_S(tOcO_i); + static_assert(shape(tTMEM_STOREcO) == shape(tTMEM_LOADcO)); + + float2 scale_f32x2 = make_float2(scale, scale); + + Tensor tTMrO = 
make_tensor(make_shape(shape(tTMEM_LOADcO), Int<128 / kCorrectionTileSize>{})); + + auto copy_in = [&](int i) { + Tensor tTMEM_LOADtO_i = tTMEM_LOADtO; + tTMEM_LOADtO_i.data() = tTMEM_LOADtO_i.data().get() + uint32_t(i * kCorrectionTileSize); + Tensor tTMrO_i = tTMrO(_, i).compose(make_layout(shape<0>(tTMrO))); + copy(tiled_tmem_load, tTMEM_LOADtO_i, tTMrO_i); + }; + + auto copy_out = [&](int i) { + Tensor tTMEM_STOREtO_i = tTMEM_STOREtO; + tTMEM_STOREtO_i.data() = tTMEM_STOREtO_i.data().get() + uint32_t(i * kCorrectionTileSize); + Tensor tTMrO_i = tTMrO(_, i).compose(make_layout(shape<0>(tTMrO))); + copy(tiled_tmem_store, tTMrO_i, tTMEM_STOREtO_i); + }; + + // sequence: LLMSLMSLMSS + + // loop: + // TMEM_LOAD, FMUL2 scale, TMEM_STORE + copy_in(0); + + int count = get<2>(TileShape{}) / kCorrectionTileSize; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < count; i++) { + if (i != count - 1) { + copy_in(i+1); + } + + Tensor tTMrO_i = tTMrO(_, i).compose(make_layout(shape<0>(tTMrO))); + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size(tTMrO_i); j += 2) { + float2 in = make_float2(tTMrO_i(j), tTMrO_i(j+1)); + float2 out; + cute::mul(out, scale_f32x2, in); + tTMrO_i(j) = out.x; + tTMrO_i(j+1) = out.y; + } + + copy_out(i); + } + } + + template< + class BlkCoord, class ProblemShape, class ParamsProblemShape, + class TensorStorageEpi, class CollectiveEpilogue + > + CUTLASS_DEVICE auto + correction( + BlkCoord const& blk_coord, + Params const& params, ProblemShape const& problem_shape, + ParamsProblemShape const& params_problem_shape, + TensorStorageEpi& shared_storage_epi, + PipelineC& pipeline_s0_c, typename PipelineC::PipelineState& pipeline_s0_c_consumer_state, + PipelineC& pipeline_s1_c, typename PipelineC::PipelineState& pipeline_s1_c_consumer_state, + PipelineO& pipeline_o, typename PipelineO::PipelineState& pipeline_o_consumer_state, + PipelineE& pipeline_epi, typename PipelineE::PipelineState& pipeline_epi_producer_state, + CollectiveEpilogue& epilogue) { + + int mask_tile_count = Mask{}.get_trip_count(blk_coord, TileShape{}, problem_shape); + + int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp); + + Tensor tStS = partition_fragment_C(typename CollectiveMmaQK::TiledMma{}, select<0,1>(TileShapeQK{})); + + Tensor cS = make_identity_tensor(select<0,1>(TileShapeQK{})); + Tensor tScS = typename CollectiveMmaQK::TiledMma{}.get_slice(0).partition_C(cS); + + Tensor tStS_v = tStS.compose(make_layout(make_shape(_128{}, _2{}))); + Tensor tScS_v = tScS.compose(make_layout(make_shape(_128{}, _2{}))); + + using TMEM_LOAD_V = SM100_TMEM_LOAD_32dp32b2x; // 4x32 threads with 2 cols of 32b elem + + auto tiled_tmem_loadv = make_tmem_copy(TMEM_LOAD_V{}, tStS_v); + auto thr_tmem_loadv = tiled_tmem_loadv.get_slice(thread_idx); + + Tensor tTMEM_LOADVtS = thr_tmem_loadv.partition_S(tStS_v); + Tensor tTMEM_LOADVcS = thr_tmem_loadv.partition_D(tScS_v); + + Tensor tTMEM_LOADVtS0 = tTMEM_LOADVtS; + tTMEM_LOADVtS0.data() = tTMEM_LOADVtS0.data().get() + uint32_t(TmemAllocation::V0); + Tensor tTMEM_LOADVtS1 = tTMEM_LOADVtS; + tTMEM_LOADVtS1.data() = tTMEM_LOADVtS1.data().get() + uint32_t(TmemAllocation::V1); + + // ignore first signal from softmax as no correction is required + pipeline_s0_c.consumer_wait(pipeline_s0_c_consumer_state); + pipeline_s0_c.consumer_release(pipeline_s0_c_consumer_state); + ++pipeline_s0_c_consumer_state; + + pipeline_s1_c.consumer_wait(pipeline_s1_c_consumer_state); + + // handle the last iteration differently (i.e. 
tmem_load/stsm for epi) + mask_tile_count -= 1; + + CUTLASS_PRAGMA_NO_UNROLL + for (; mask_tile_count > 0; mask_tile_count -= 1) { + + pipeline_s0_c.consumer_wait(pipeline_s0_c_consumer_state); + + Tensor tTMEM_LOADVrS = make_tensor(shape(tTMEM_LOADVcS)); + + // read row_wise new global max + copy(tiled_tmem_loadv, tTMEM_LOADVtS0, tTMEM_LOADVrS); + + // e^(scale * (old_max - new_max) + float scale = ::exp2f(params.scale_softmax_log2 * (tTMEM_LOADVrS(kIdxOldRowMax) - tTMEM_LOADVrS(kIdxNewRowMax))); + + pipeline_o.consumer_wait(pipeline_o_consumer_state); + + correction_rescale(scale, uint32_t(TmemAllocation::O0)); + + pipeline_s1_c.consumer_release(pipeline_s1_c_consumer_state); + ++pipeline_s1_c_consumer_state; + + cutlass::arch::fence_view_async_tmem_store(); + + pipeline_o.consumer_release(pipeline_o_consumer_state); + ++pipeline_o_consumer_state; + + pipeline_s1_c.consumer_wait(pipeline_s1_c_consumer_state); + + copy(tiled_tmem_loadv, tTMEM_LOADVtS1, tTMEM_LOADVrS); + + scale = ::exp2f(params.scale_softmax_log2 * (tTMEM_LOADVrS(kIdxOldRowMax) - tTMEM_LOADVrS(kIdxNewRowMax))); + + pipeline_o.consumer_wait(pipeline_o_consumer_state); + + correction_rescale(scale, uint32_t(TmemAllocation::O1)); + + pipeline_s0_c.consumer_release(pipeline_s0_c_consumer_state); + ++pipeline_s0_c_consumer_state; + + cutlass::arch::fence_view_async_tmem_store(); + + pipeline_o.consumer_release(pipeline_o_consumer_state); + ++pipeline_o_consumer_state; + } + + pipeline_s1_c.consumer_release(pipeline_s1_c_consumer_state); + ++pipeline_s1_c_consumer_state; + + // do the final correction to O1 + // better to somehow special-case it in the loop above + // doesn't matter for non-persistent code, but if it were + // persistent we do not want to release O too early + + pipeline_s0_c.consumer_wait(pipeline_s0_c_consumer_state); + + // read from V0 + // read row_sum and final row_max here + Tensor tTMEM_LOADVrS = make_tensor(shape(tTMEM_LOADVcS)); + copy(tiled_tmem_loadv, tTMEM_LOADVtS0, tTMEM_LOADVrS); + + pipeline_s0_c.consumer_release(pipeline_s0_c_consumer_state); + ++pipeline_s0_c_consumer_state; + + pipeline_o.consumer_wait(pipeline_o_consumer_state); + pipeline_epi.producer_acquire(pipeline_epi_producer_state); + // store to epi smem + + // loop: + // TMEM_LOAD + // FMUL2 scale = 1 / global_sum * out_quant_scale + // F2FP + // store to smem + Tensor sO = make_tensor(make_smem_ptr(shared_storage_epi.smem_o.data()), typename TensorStorageEpi::SmemLayoutO{}); + Tensor gLSE = make_tensor(make_gmem_ptr(epilogue.params.ptr_LSE), select<0,3>(problem_shape), epilogue.params.dLSE); + + correction_epilogue(params.scale_output / tTMEM_LOADVrS(kIdxFinalRowSum), _0{}, sO); + + if (epilogue.params.ptr_LSE != nullptr) { + int row_idx = get<0>(tTMEM_LOADVcS(_0{})) + get<0>(TileShape{}) * get<0>(blk_coord); + + int row_offset = 0; + if constexpr (is_variable_length_v>) { + row_offset = get<0>(params_problem_shape).cumulative_length[get<2,1>(blk_coord)]; + } + + ElementPV lse = cutlass::fast_log(tTMEM_LOADVrS(kIdxFinalRowSum)) + params.scale_softmax * tTMEM_LOADVrS(kIdxFinalRowMax); + + if (row_idx < get<0>(problem_shape)) { + gLSE(row_idx + row_offset, get<2>(blk_coord)) = lse; + } + } + + cutlass::arch::fence_view_async_tmem_load(); + + pipeline_o.consumer_release(pipeline_o_consumer_state); + ++pipeline_o_consumer_state; + + pipeline_epi.producer_commit(pipeline_epi_producer_state); + ++pipeline_epi_producer_state; + + pipeline_s1_c.consumer_wait(pipeline_s1_c_consumer_state); + + // load from V1 + copy(tiled_tmem_loadv, 
tTMEM_LOADVtS1, tTMEM_LOADVrS); + + pipeline_s1_c.consumer_release(pipeline_s1_c_consumer_state); + ++pipeline_s1_c_consumer_state; + + pipeline_o.consumer_wait(pipeline_o_consumer_state); + pipeline_epi.producer_acquire(pipeline_epi_producer_state); + + correction_epilogue(params.scale_output / tTMEM_LOADVrS(kIdxFinalRowSum), _1{}, sO); + + if (epilogue.params.ptr_LSE != nullptr) { + int row_idx = get<0>(tTMEM_LOADVcS(_0{})) + get<0>(TileShape{}) * get<0>(blk_coord) + get<0>(TileShapeQK{}); + + ElementPV lse = cutlass::fast_log(tTMEM_LOADVrS(kIdxFinalRowSum)) + params.scale_softmax * tTMEM_LOADVrS(kIdxFinalRowMax); + + int row_offset = 0; + if constexpr (is_variable_length_v>) { + row_offset = get<0>(params_problem_shape).cumulative_length[get<2,1>(blk_coord)]; + } + + if (row_idx < get<0>(problem_shape)) { + gLSE(row_idx + row_offset, get<2>(blk_coord)) = lse; + } + } + + cutlass::arch::fence_view_async_tmem_load(); + + pipeline_o.consumer_release(pipeline_o_consumer_state); + ++pipeline_o_consumer_state; + + pipeline_epi.producer_commit(pipeline_epi_producer_state); + ++pipeline_epi_producer_state; + } + + + template< + class BlkCoord, class ProblemShape, class ParamsProblemShape, + class TensorStorageEpi, class CollectiveEpilogue + > + CUTLASS_DEVICE auto + correction_empty( + BlkCoord const& blk_coord, + Params const& params, ProblemShape const& problem_shape, + ParamsProblemShape const& params_problem_shape, + TensorStorageEpi& shared_storage_epi, + PipelineE& pipeline_epi, typename PipelineE::PipelineState& pipeline_epi_producer_state, + CollectiveEpilogue& epilogue) { + + pipeline_epi.producer_acquire(pipeline_epi_producer_state); + + Tensor sO = make_tensor(make_smem_ptr(shared_storage_epi.smem_o.data()), typename TensorStorageEpi::SmemLayoutO{}); + Tensor gLSE = make_tensor(make_gmem_ptr(epilogue.params.ptr_LSE), select<0,3>(problem_shape), epilogue.params.dLSE); + float lse = -INFINITY; + int thread_idx = threadIdx.x % (4 * NumThreadsPerWarp); + +#define DSHOW(x) print(#x ": "); print(x); print("\n") + if (threadIdx.x % 128 == 0 && block0()) { + DSHOW(sO); + } +#if 1 + + using ElementOut = typename CollectiveEpilogue::ElementOut; + auto tiled_copy = make_cotiled_copy( + Copy_Atom, ElementOut>{}, + make_ordered_layout(make_shape(_128{}, Int{}), Step<_1, _0>{}), + sO.layout()); + + auto thr_copy = tiled_copy.get_slice(thread_idx); + auto tOgO = thr_copy.partition_D(sO); + auto tOrO = make_tensor(shape(tOgO(_,_,_,_0{}))); + clear(tOrO); + + copy(tiled_copy, tOrO, tOgO(_,_,_,_0{})); +#endif + + if (epilogue.params.ptr_LSE != nullptr) { + int row_idx = thread_idx + get<0>(TileShape{}) * get<0>(blk_coord); + + int row_offset = 0; + if constexpr (is_variable_length_v>) { + row_offset = get<0>(params_problem_shape).cumulative_length[get<2,1>(blk_coord)]; + } + + if (row_idx < get<0>(problem_shape)) { + gLSE(row_idx + row_offset, get<2>(blk_coord)) = lse; + } + } + + pipeline_epi.producer_commit(pipeline_epi_producer_state); + ++pipeline_epi_producer_state; + + copy(tiled_copy, tOrO, tOgO(_,_,_,_1{})); + cutlass::arch::fence_view_async_shared(); + pipeline_epi.producer_acquire(pipeline_epi_producer_state); + + if (epilogue.params.ptr_LSE != nullptr) { + int row_idx = thread_idx + get<0>(TileShape{}) * get<0>(blk_coord) + get<0>(TileShapeQK{}); + + int row_offset = 0; + if constexpr (is_variable_length_v>) { + row_offset = get<0>(params_problem_shape).cumulative_length[get<2,1>(blk_coord)]; + } + + if (row_idx < get<0>(problem_shape)) { + gLSE(row_idx + row_offset, get<2>(blk_coord)) = 
lse; + } + } + + cutlass::arch::fence_view_async_shared(); + pipeline_epi.producer_commit(pipeline_epi_producer_state); + ++pipeline_epi_producer_state; + } + +}; + +} // namespace cutlass::fmha::collective diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/collective/sm100_fmha_gen_epilogue_warpspecialized.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/collective/sm100_fmha_gen_epilogue_warpspecialized.hpp new file mode 100644 index 0000000000000000000000000000000000000000..3d1e1e8be3b3738140b2d89a8f4f4bacdce34b77 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/collective/sm100_fmha_gen_epilogue_warpspecialized.hpp @@ -0,0 +1,94 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cute/layout.hpp" + +namespace cutlass::fmha::collective { + +template< + class Element_, + class StrideO_ +> +struct Sm100FmhaGenEpilogueWarpspecialized { + + using Pipeline = cutlass::PipelineAsync<2>; + + using SmemLayoutO = Layout>; + using SmemLayoutO_ = SmemLayoutO; + using Element = Element_; + using StrideOOrig = StrideO_; + using StrideO = decltype(replace<0>(StrideOOrig{}, 0)); + + struct TensorStorage { + + using SmemLayoutO = SmemLayoutO_; + cute::array_aligned> smem_o; + + }; + + struct Arguments { + Element* ptr_o; + StrideO dO; + }; + + using Params = Arguments; + + const Params& params; + + CUTLASS_DEVICE Sm100FmhaGenEpilogueWarpspecialized(const Params& params) : params(params) {} + + template + static Params to_underlying_arguments( + ProblemShape const& problem_shape, + Arguments const& args, + void* workspace = nullptr) { + return args; + } + + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& params) { + /* no-op */ + } + + template + CUTLASS_DEVICE auto + store( + BlkCoord const& blk_coord_in, ProblemShape const& problem_shape, + Params const& params, ParamsProblemShape const& params_problem_shape, + TensorStorage& shared_storage, + Pipeline& pipeline, typename Pipeline::PipelineState& pipeline_consumer_state) { + /* no-op */ + } +}; + +} // namespace cutlass::fmha::collective diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/collective/sm100_fmha_gen_mainloop_warpspecialized.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/collective/sm100_fmha_gen_mainloop_warpspecialized.hpp new file mode 100644 index 0000000000000000000000000000000000000000..afef02245f316cfb03624d9ab2f4ca92553372c9 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/collective/sm100_fmha_gen_mainloop_warpspecialized.hpp @@ -0,0 +1,1113 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/arch/memory_sm80.h" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cute/arch/simd_sm100.hpp" +#include "cute/tensor.hpp" +#include "cute/layout.hpp" + +#include "collective/fmha_common.hpp" +#include "collective/fmha_fusion.hpp" +#include "collective/sm100_fmha_load_cpasync_warpspecialized.hpp" + +namespace cutlass::fmha::collective { + +using namespace cute; + +template< + class Element_, + class ElementQK_, + class ElementPV_, + class ElementOut_, + class TileShape_, + class StrideQ_, + class StrideNewK_, + class StrideNewV_, + class StrideK_, + class StrideV_, + class StrideO_, + class Mask_ = ResidualMask, + // shape here is QG K H + // and referes to the two softmax warps + // (2, 1, 1) means that they are stacked (best for large Q since it loads the least K/V) + // (1, 2, 1) means they sit side by side (best for small Q / large K) + class ThreadShape = Shape<_1, _2, _1> +> +struct Sm100FmhaGenMainloopWarpspecialized { + + using Element = Element_; + using ElementQK = ElementQK_; + using ElementPV = ElementPV_; + using ElementAcc = ElementPV_; + using ElementOut = ElementOut_; + using TileShape = TileShape_; + using StrideQOrig = StrideQ_; + using StrideQ = decltype(replace<0>(StrideQ_{}, 0)); + using StrideNewK = StrideNewK_; + using StrideNewV = StrideNewV_; + using StrideCacheK = StrideK_; + using StrideCacheV = StrideV_; + using StrideK = StrideK_; + using StrideV = StrideV_; + using StrideOOrig = StrideO_; + using StrideO = decltype(replace<0>(StrideO_{}, 0)); + using Mask = Mask_; + + static constexpr int StageCountQ = get<1>(TileShape{}) == 256 ? 
1 : 2; + static constexpr int StageCountKV = 256 * 11 / get<1>(TileShape{}); + + using StagesQ = cutlass::gemm::collective::StageCount; + using StagesKV = cutlass::gemm::collective::StageCount; + + using ClusterShape = Shape<_1, _1, _1>; + + static const int Alignment = 128 / sizeof_bits_v; + + using TileShapeQK = decltype(shape_div(TileShape{}, ThreadShape{})); + + using TileShapePV = decltype(select<0,2,1>(TileShapeQK{})); + + using CollectiveMmaQK = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + Element, StrideQ, Alignment, + Element, StrideK, Alignment, + ElementQK, + TileShapeQK, ClusterShape, cutlass::gemm::collective::StageCount<3> /* we change it later anyways*/, + cutlass::gemm::KernelTmaWarpSpecialized1SmSm100>::CollectiveOp; + + using CollectiveMmaPV = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + // the stride for A does not matter since we do not load from smem at all + Element, StrideK, Alignment, + Element, decltype(select<1,0,2>(StrideV{})), Alignment, + ElementPV, + TileShapePV, ClusterShape, cutlass::gemm::collective::StageCount<3> /* we change it later anyways*/, + cutlass::gemm::KernelTmaWarpSpecialized1SmSm100>::CollectiveOp; + + using SmemLayoutQ = decltype(unstageSmemLayout(typename CollectiveMmaQK::SmemLayoutA{}, Int{})); + using SmemLayoutK = decltype(unstageSmemLayout(typename CollectiveMmaQK::SmemLayoutB{}, Int{})); + using SmemLayoutV = decltype(unstageSmemLayout(typename CollectiveMmaPV::SmemLayoutB{}, Int{})); + + struct TensorStorage { + cute::array_aligned> smem_q; + union { + cute::array_aligned> smem_k; + cute::array_aligned> smem_v; + }; + }; + + enum class TmemAllocation : uint32_t { + kSizeS = 128, + kSizeO = 128, + kSizeP = 32, + S0 = 0, + S1 = S0 + kSizeS, + V0 = S0, // stats storage from softmax to correction + V1 = S1, + P0 = S0 + kSizeP, + P1 = S1 + kSizeP, + O0 = S1 + kSizeS, + O1 = O0 + kSizeO, + kEnd = O1 + kSizeO + }; + + // indices for V0 / V1 + enum : int { + kIdxOldRowMax = 0, + kIdxNewRowMax = 1, + kIdxFinalRowSum = 0, + kIdxFinalRowMax = 1 + }; + + // from load to mma warp, protects q in smem + using PipelineQ = cutlass::PipelineUmmaConsumerAsync< + StageCountQ, + typename CollectiveMmaQK::AtomThrShapeMNK + >; + + // from load to mma warp, protects k/v in smem + using PipelineKV = cutlass::PipelineUmmaConsumerAsync< + StageCountKV, + typename CollectiveMmaQK::AtomThrShapeMNK + >; + + // from mma to softmax0/1 warp, protects S in tmem + // (not sure yet about the reverse direction) + // there is one pipe per softmax warp, and the mma warp alternates between them + using PipelineS = cutlass::PipelineUmmaAsync<1>; + + // from softmax0/1/ to correction wg + using PipelineC = cutlass::PipelineAsync<1>; + + // from mma to correction + using PipelineO = cutlass::PipelineUmmaAsync<2>; + + // from corr to epilogue + using PipelineE = cutlass::PipelineAsync<2>; + + using OrderBarrierSoftmax = cutlass::OrderedSequenceBarrier< + /*stages*/ 1, /*groups*/ 2>; + + static_assert(cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutK{})) * cute::sizeof_bits_v) == cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutV{})) * cute::sizeof_bits_v), "K and V smem layouts must be of equal size"); + + using Load = Sm100FmhaLoadCpAsyncWarpspecialized< + Element, StrideQ, StrideNewK, StrideNewV, StrideCacheK, StrideCacheV, + TensorStorage, CollectiveMmaQK, CollectiveMmaPV, + SmemLayoutQ, SmemLayoutK, SmemLayoutV, + PipelineQ, PipelineKV, 
TileShape, Mask + >; + + struct Arguments { + typename Load::Arguments load; + + // if zero, defaults to 1/sqrt(D) + float scale_softmax = 0.0f; + + // scaling factors to dequantize QKV + float scale_q = 1.0f; + float scale_k = 1.0f; + float scale_v = 1.0f; + + // scaling factor to quantize O + float inv_scale_o = 1.0f; + }; + + struct Params { + typename Load::Params load; + + float scale_softmax; + float scale_softmax_log2; + + float scale_output; + }; + + template + static bool can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return true; + } + + template + static Params to_underlying_arguments( + ProblemShape const& problem_shape, + Arguments const& args, + void* workspace) { + + float scale_softmax = args.scale_softmax; + if (scale_softmax == 0.0f) { + scale_softmax = 1.0f / (float) std::sqrt(get<2>(problem_shape)); + } + float log2_e = static_cast(std::log2(std::exp(1.0))); + + return Params{ + Load::to_underlying_arguments(problem_shape, args.load, workspace), + args.scale_q * args.scale_k * scale_softmax, + args.scale_q * args.scale_k * log2_e * scale_softmax, + args.scale_v * args.inv_scale_o + }; + } + + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& params) { + Load::prefetch_tma_descriptors(params.load); + } + + template + CUTLASS_DEVICE void + load( + BlkCoord const& blk_coord, ProblemShape const& problem_shape, + Params const& params, ParamsProblemShape const& params_problem_shape, + TensorStorage& storage, + PipelineQ& pipeline_q, typename PipelineQ::PipelineState& pipeline_q_producer_state, + PipelineKV& pipeline_kv, typename PipelineKV::PipelineState& pipeline_kv_producer_state) { + + Load load; + load.load(blk_coord, problem_shape, params.load, params_problem_shape, + storage, + pipeline_q, pipeline_q_producer_state, + pipeline_kv, pipeline_kv_producer_state); + } + + template + CUTLASS_DEVICE auto + mma( + BlkCoord const& blk_coord, + Params const& params, ProblemShape const& problem_shape, + TensorStorage& storage, + PipelineQ& pipeline_q, typename PipelineQ::PipelineState& pipeline_q_consumer_state, + PipelineKV& pipeline_kv, typename PipelineKV::PipelineState& pipeline_kv_consumer_state, + PipelineS& pipeline_s0, typename PipelineS::PipelineState& pipeline_s0_producer_state, + PipelineS& pipeline_s1, typename PipelineS::PipelineState& pipeline_s1_producer_state, + PipelineO& pipeline_corr, typename PipelineO::PipelineState& pipeline_corr_producer_state) { + + auto pipeline_q_release_state = pipeline_q_consumer_state; + auto pipeline_kv_release_state = pipeline_kv_consumer_state; + + int mask_tile_count = Mask{}.get_trip_count(blk_coord, TileShape{}, problem_shape); + + typename CollectiveMmaQK::TiledMma mma_qk; + ThrMMA thr_mma_qk = mma_qk.get_slice(0); + + typename CollectiveMmaPV::TiledMma mma_pv; + TiledMMA mma_pv_ts = to_tiled_mma_sm100_ts(mma_pv); + ThrMMA thr_mma_pv = mma_pv_ts.get_slice(0); + + Tensor sQ = make_tensor(make_smem_ptr(storage.smem_q.data()), SmemLayoutQ{}); + Tensor sK = make_tensor(make_smem_ptr(storage.smem_k.data()), SmemLayoutK{}); + Tensor sV = make_tensor(make_smem_ptr(storage.smem_v.data()), SmemLayoutV{}); + + Tensor tSrQ = thr_mma_qk.make_fragment_A(sQ); + Tensor tSrK = thr_mma_qk.make_fragment_B(sK); + Tensor tOrV = thr_mma_pv.make_fragment_B(sV); + + // tmem layout is + // S0 S1`O0 O1 + // sequential in memory, where S overlaps with P and V + + Tensor tStS = partition_fragment_C(mma_qk, select<0,1>(TileShapeQK{})); + Tensor tOtO = partition_fragment_C(mma_pv_ts, 
select<0,1>(TileShapePV{})); + + Tensor tStS0 = tStS; + tStS0.data() = tStS.data().get() + uint32_t(TmemAllocation::S0); + Tensor tStS1 = tStS; + tStS1.data() = tStS.data().get() + uint32_t(TmemAllocation::S1); + + Tensor tOtO0 = tOtO; + tOtO0.data() = tOtO.data().get() + uint32_t(TmemAllocation::O0); + Tensor tOtO1 = tOtO; + tOtO1.data() = tOtO.data().get() + uint32_t(TmemAllocation::O1); + + Tensor sP = make_tensor(make_smem_ptr((Element*)nullptr), typename CollectiveMmaPV::SmemLayoutA{}); + Tensor tOrP = thr_mma_pv.make_fragment_A(sP)(_, _, _, _0{}); // slice out staging + + Tensor tOrP0 = tOrP; + tOrP0.data() = tOrP0.data().get() + uint32_t(TmemAllocation::P0); + Tensor tOrP1 = tOrP; + tOrP1.data() = tOrP1.data().get() + uint32_t(TmemAllocation::P1); + + int k_index = 0; + int v_index = 0; + int q_index = 0; + + // wait for Q1 + q_index = pipeline_q_consumer_state.index(); + pipeline_q.consumer_wait(pipeline_q_consumer_state); + ++pipeline_q_consumer_state; + + Tensor tSrQ0 = tSrQ(_,_,_,q_index); + + + // wait for K1 + k_index = pipeline_kv_consumer_state.index(); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + + // gemm Q1 * K1 -> S1 + pipeline_s0.producer_acquire(pipeline_s0_producer_state); + + gemm_zero_acc(mma_qk, tSrQ0, tSrK(_,_,_,k_index), tStS0); + + pipeline_s0.producer_commit(pipeline_s0_producer_state); + ++pipeline_s0_producer_state; + + // release K1 + if constexpr (get<1>(ThreadShape{}) > 1) { + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + } + + // wait for Q2 + if constexpr (get<0>(ThreadShape{}) > 1 || get<2>(ThreadShape{}) > 1) { + q_index = pipeline_q_consumer_state.index(); + pipeline_q.consumer_wait(pipeline_q_consumer_state); + ++pipeline_q_consumer_state; + } + + Tensor tSrQ1 = tSrQ(_,_,_,q_index); + + if constexpr (get<1>(ThreadShape{}) > 1) { + k_index = pipeline_kv_consumer_state.index(); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + } + + pipeline_s1.producer_acquire(pipeline_s1_producer_state); + + // gemm Q2 * K1 -> S2 + gemm_zero_acc(mma_qk, tSrQ1, tSrK(_,_,_,k_index), tStS1); + + pipeline_s1.producer_commit(pipeline_s1_producer_state); + ++pipeline_s1_producer_state; + + // release K1 + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + + // wait for V1 + v_index = pipeline_kv_consumer_state.index(); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + + // this acquire returns the ownership of all of S0 to the mma warp + // including the P0 part + // acquire corr first to take it out of the critical + // path since softmax takes longer + pipeline_corr.producer_acquire(pipeline_corr_producer_state); + pipeline_s0.producer_acquire(pipeline_s0_producer_state); + + // gemm P1 * V1 -> O1 + gemm_zero_acc(mma_pv_ts, tOrP0, tOrV(_,_,_,v_index), tOtO0); + + pipeline_corr.producer_commit(pipeline_corr_producer_state); + ++pipeline_corr_producer_state; + + if constexpr (get<1>(ThreadShape{}) > 1) { + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + } + + mma_pv_ts.accumulate_ = UMMA::ScaleOut::Zero; + + // loop: + mask_tile_count -= 1; + for (; mask_tile_count > 0; mask_tile_count -= 1) { + + // wait for Ki + k_index = (pipeline_kv_consumer_state.index()); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + + // gemm Q1 * Ki -> S1 + gemm_zero_acc(mma_qk, tSrQ0, tSrK(_,_,_,k_index), 
tStS0); + + pipeline_s0.producer_commit(pipeline_s0_producer_state); + ++pipeline_s0_producer_state; + + if constexpr (get<1>(ThreadShape{}) > 1) { + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + } + + // gemm P2 * V(i-1) -> O2 + if constexpr (get<1>(ThreadShape{}) > 1) { + v_index = pipeline_kv_consumer_state.index(); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + } + + pipeline_corr.producer_acquire(pipeline_corr_producer_state); + pipeline_s1.producer_acquire(pipeline_s1_producer_state); + + gemm_reset_zero_acc(mma_pv_ts, tOrP1, tOrV(_,_,_,v_index), tOtO1); + + pipeline_corr.producer_commit(pipeline_corr_producer_state); + ++pipeline_corr_producer_state; + + // release V(i-1) + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + + if constexpr (get<1>(ThreadShape{}) > 1) { + k_index = (pipeline_kv_consumer_state.index()); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + } + + // gemm Q2 * Ki -> S2 + gemm_zero_acc(mma_qk, tSrQ1, tSrK(_,_,_,k_index), tStS1); + + pipeline_s1.producer_commit(pipeline_s1_producer_state); + ++pipeline_s1_producer_state; + + // release Ki + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + + // wait for Vi + v_index = (pipeline_kv_consumer_state.index()); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + + // gemm P1 * Vi -> O1 + pipeline_corr.producer_acquire(pipeline_corr_producer_state); + + pipeline_s0.producer_acquire(pipeline_s0_producer_state); + + gemm_reset_zero_acc(mma_pv_ts, tOrP0, tOrV(_,_,_,v_index), tOtO0); + + pipeline_corr.producer_commit(pipeline_corr_producer_state); + ++pipeline_corr_producer_state; + + if constexpr (get<1>(ThreadShape{}) > 1) { + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + } + } + + // release Q1 + pipeline_q.consumer_release(pipeline_q_release_state); + ++pipeline_q_release_state; + + // release Q2 + if constexpr (get<0>(ThreadShape{}) > 1) { + pipeline_q.consumer_release(pipeline_q_release_state); + ++pipeline_q_release_state; + } + + // wait for Vi + if constexpr (get<1>(ThreadShape{}) > 1) { + v_index = pipeline_kv_consumer_state.index(); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + } + + // gemm P2 * Vi -> O2 + pipeline_corr.producer_acquire(pipeline_corr_producer_state); + pipeline_s1.producer_acquire(pipeline_s1_producer_state); + + gemm_reset_zero_acc(mma_pv_ts, tOrP1, tOrV(_,_,_,v_index), tOtO1); + + pipeline_corr.producer_commit(pipeline_corr_producer_state); + ++pipeline_corr_producer_state; + + // release Vi + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + + pipeline_s0.producer_commit(pipeline_s0_producer_state); + ++pipeline_s0_producer_state; + + pipeline_s1.producer_commit(pipeline_s1_producer_state); + ++pipeline_s1_producer_state; + + // T0 S00 B1, T0 S10 B1, T0 S00 B2, T0 S01 B1, T0 S10 B2, T0 S11 B1, T0 S01 B2, T1 S00 B1, T0 S11 B2, ... + // Q1 * K1 , Q2 * K1 , S11 * V1 , Q1 * K2 , S21 * V1 , Q2 * K2 , S12 * V2 , Q1 * K3 , S22 * K2 , ... 
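// Taken per query row, the interleaved QK/PV schedule listed above computes the standard
// streaming (online-softmax) attention update; the softmax and correction warps supply the
// exponentiation and rescaling between the two GEMMs. Below is a minimal scalar sketch of
// that update in plain host C++ — names (attention_row, keys, values) are illustrative and
// not part of this file; the real kernel keeps S/P/O in TMEM and splits rows across two
// softmax warps.

#include <cmath>
#include <cstddef>
#include <vector>

// One iteration per key/value stands in for one K/V tile of the pipelined schedule:
// "Q * K" score, online-softmax correction, then "P * V" accumulation.
float attention_row(const std::vector<float>& q,
                    const std::vector<std::vector<float>>& keys,
                    const std::vector<float>& values,
                    float scale) {
  float row_max = -INFINITY, row_sum = 0.0f, acc = 0.0f;
  for (std::size_t n = 0; n < keys.size(); ++n) {
    float s = 0.0f;                                   // QK GEMM stand-in
    for (std::size_t d = 0; d < q.size(); ++d) s += q[d] * keys[n][d];
    s *= scale;
    float new_max = std::fmax(row_max, s);
    float corr = std::exp(row_max - new_max);         // e^(old_max - new_max), as in correction_rescale
    float p = std::exp(s - new_max);
    acc = acc * corr + p * values[n];                 // PV GEMM stand-in
    row_sum = row_sum * corr + p;
    row_max = new_max;
  }
  return acc / row_sum;                               // final 1/row_sum correction (correction epilogue)
}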
+ } + + template + CUTLASS_DEVICE auto + softmax_step( + float& row_max, float& row_sum, + Stage stage, bool final_call, + BlkCoord const& blk_coord, CoordTensor const& cS, + Params const& params, ProblemShape const& problem_shape, + PipelineS& pipeline_s, typename PipelineS::PipelineState& pipeline_s_consumer_state, + PipelineC& pipeline_c, typename PipelineC::PipelineState& pipeline_c_producer_state, + OrderBarrierSoftmax& order_s) { + + Tensor tScS = typename CollectiveMmaQK::TiledMma{}.get_slice(0).partition_C(cS); + + Tensor tStS = partition_fragment_C(typename CollectiveMmaQK::TiledMma{}, select<0,1>(TileShapeQK{})); + tStS.data() = uint32_t(stage == _0{} ? TmemAllocation::S0 : TmemAllocation::S1); + + Tensor tStS_v = tStS.compose(make_layout(make_shape(_128{}, _2{}))); + tStS_v.data() = uint32_t(stage == _0{} ? TmemAllocation::V0 : TmemAllocation::V1); + Tensor tScS_v = tScS.compose(make_layout(make_shape(_128{}, _2{}))); + + auto tilePlikeFP32 = size<1>(TileShapeQK{}) / Int{} * Int{}; + Tensor tStS_P = tStS.compose(make_layout(make_shape(_128{}, tilePlikeFP32))); + tStS_P.data() = warp_uniform(uint32_t(stage == _0{} ? TmemAllocation::P0 : TmemAllocation::P1)); + Tensor tScS_P = tScS.compose(make_layout(make_shape(_128{}, tilePlikeFP32))); + + // Each thread owns a single row + using TMEM_LOAD = conditional_t(TileShapeQK{}) < _128{}, SM100_TMEM_LOAD_32dp32b8x, SM100_TMEM_LOAD_32dp32b32x>; // 4x32 threads with 128 cols of 8b elem + using TMEM_STORE = conditional_t(TileShapeQK{}) < _128{}, SM100_TMEM_STORE_32dp32b8x, SM100_TMEM_STORE_32dp32b32x>; // 4x32 threads with 128 cols of 8b elem + using TMEM_STORE_V = SM100_TMEM_STORE_32dp32b2x; // 4x32 threads with 2 cols of 32b elem + + int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp); + + auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tStS); + auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx); + + Tensor tTMEM_LOADtS = thr_tmem_load.partition_S(tStS); + Tensor tTMEM_LOADcS = thr_tmem_load.partition_D(tScS); + + auto tiled_tmem_storev = make_tmem_copy(TMEM_STORE_V{}, tStS_v); + auto thr_tmem_storev = tiled_tmem_storev.get_slice(thread_idx); + + Tensor tTMEM_STOREVtS = thr_tmem_storev.partition_D(tStS_v); + Tensor tTMEM_STOREVcS = thr_tmem_storev.partition_S(tScS_v); + + auto tiled_tmem_store = make_tmem_copy(TMEM_STORE{}, tStS_P); + auto thr_tmem_store = tiled_tmem_store.get_slice(thread_idx); + + Tensor tTMEM_STOREtS_x4 = thr_tmem_store.partition_D(tStS_P); + tTMEM_STOREtS_x4.data() = warp_uniform(tTMEM_STOREtS_x4.data().get()); + Tensor tTMEM_STOREcS = thr_tmem_store.partition_S(tScS_P); + + // wait on tensor core pipe + pipeline_s.consumer_wait(pipeline_s_consumer_state); + + // read all of S from tmem into reg mem + Tensor tTMEM_LOADrS = make_tensor(shape(tTMEM_LOADcS)); + copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS); + + if constexpr (need_apply_mask) { + Mask{}.apply_mask(tTMEM_LOADrS, tTMEM_LOADcS, problem_shape); + } + + ElementQK old_row_max = row_max; + { + // compute rowmax + float row_max_0 = row_max; + float row_max_1 = row_max; + float row_max_2 = row_max; + float row_max_3 = row_max; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTMEM_LOADrS); i += 4) { + row_max_0 = ::fmax(row_max_0, tTMEM_LOADrS(i)); + row_max_1 = ::fmax(row_max_1, tTMEM_LOADrS(i+1)); + row_max_2 = ::fmax(row_max_2, tTMEM_LOADrS(i+2)); + row_max_3 = ::fmax(row_max_3, tTMEM_LOADrS(i+3)); + } + row_max = ::fmax(row_max_0, row_max_1); + row_max = ::fmax(row_max, row_max_2); + row_max = ::fmax(row_max, row_max_3); + } + + 
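// The row maximum above is accumulated in four independent chains and reduced pairwise to
// keep the FMAX dependency short; the exponentials that follow use exp2f with a scale that
// already folds in log2(e) (built once in to_underlying_arguments). A host-side sketch of
// both, with illustrative names only (row_max_4way, softmax_term are not part of this file):

#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Four independent max chains, reduced pairwise at the end — mirrors the unrolled loop above.
float row_max_4way(const std::vector<float>& s) {
  assert(s.size() % 4 == 0);
  float m0 = -INFINITY, m1 = -INFINITY, m2 = -INFINITY, m3 = -INFINITY;
  for (std::size_t i = 0; i < s.size(); i += 4) {
    m0 = std::fmax(m0, s[i + 0]);
    m1 = std::fmax(m1, s[i + 1]);
    m2 = std::fmax(m2, s[i + 2]);
    m3 = std::fmax(m3, s[i + 3]);
  }
  return std::fmax(std::fmax(m0, m1), std::fmax(m2, m3));
}

// exp(scale * x) == exp2(scale * log2(e) * x): folding log2(e) into the scale once lets the
// inner loop use the cheaper exp2f path.
float softmax_term(float x, float row_max, float scale) {
  const float scale_log2 = scale * static_cast<float>(std::log2(std::exp(1.0)));
  float a = std::exp(scale * (x - row_max));          // reference form
  float b = std::exp2(scale_log2 * (x - row_max));    // exp2 form used by the kernel
  (void)a;                                            // a and b agree up to rounding
  return b;
}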
ElementQK row_max_safe = row_max == -INFINITY ? 0 : row_max; + + Tensor tTMEM_STOREVrS = make_tensor(shape(tTMEM_STOREVcS)); + tTMEM_STOREVrS(kIdxOldRowMax) = old_row_max; + tTMEM_STOREVrS(kIdxNewRowMax) = row_max_safe; + copy(tiled_tmem_storev, tTMEM_STOREVrS, tTMEM_STOREVtS); + + pipeline_c.producer_commit(pipeline_c_producer_state); + ++pipeline_c_producer_state; + + // notify correction wg that they are ready (might need addtl ordering between S0 and S1 WG's) + + ElementQK scale = params.scale_softmax_log2; + ElementQK row_max_scale = row_max_safe * scale; + + float2 scale_fp32x2 = make_float2(scale, scale); + float2 minus_row_max_scale_fp32x2 = make_float2(-row_max_scale, -row_max_scale); + + Tensor tTMEM_STORErS_x4 = make_tensor(shape(tTMEM_STOREcS)); + + constexpr int kConversionsPerStep = 2; + + Tensor tTMEM_STORErS_x4_e = recast>(tTMEM_STORErS_x4); + + NumericArrayConverter convert; + + const int kReleasePipeCount = 10; // must be multiple of 2 + + order_s.wait(); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTMEM_LOADrS); i += 2) { + float2 in = make_float2( + tTMEM_LOADrS(i + 0), + tTMEM_LOADrS(i + 1) + ); + float2 out; + cute::fma(out, scale_fp32x2, in, minus_row_max_scale_fp32x2); + tTMEM_LOADrS(i + 0) = out.x; + tTMEM_LOADrS(i + 1) = out.y; + + tTMEM_LOADrS(i+0) = ::exp2f(tTMEM_LOADrS(i+0)); + tTMEM_LOADrS(i+1) = ::exp2f(tTMEM_LOADrS(i+1)); + + Array in_conv; + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < kConversionsPerStep; j++) { + in_conv[j] = tTMEM_LOADrS(i + j); + } + tTMEM_STORErS_x4_e[i / kConversionsPerStep] = convert(in_conv); + + + if (i == size(tTMEM_LOADrS) - kReleasePipeCount) { + order_s.arrive(); + } + + // this prevents register spills in fp16 + if constexpr (size<2>(tTMEM_STORErS_x4) == _2{}) { + if (i == size(tTMEM_LOADrS) - 6) { + copy(tiled_tmem_store, tTMEM_STORErS_x4(_, _, 0), tTMEM_STOREtS_x4(_, _, 0)); + } + } + } + + // tmem_store(reg_S8) -> op_P + CUTE_STATIC_ASSERT_V(size<2>(tTMEM_STORErS_x4) <= _2{}); + CUTE_STATIC_ASSERT_V(size<1>(tTMEM_STORErS_x4) == _1{}); + copy(tiled_tmem_store, tTMEM_STORErS_x4(_, _, size<2>(tTMEM_STORErS_x4) - 1), tTMEM_STOREtS_x4(_, _, size<2>(tTMEM_STORErS_x4) - 1)); + + cutlass::arch::fence_view_async_tmem_store(); + + // notify tensor core warp that P is ready + pipeline_s.consumer_release(pipeline_s_consumer_state); + ++pipeline_s_consumer_state; + + pipeline_c.producer_acquire(pipeline_c_producer_state); + + ElementQK acc_scale = 0.5f * ::exp2f(scale * (old_row_max - row_max_safe)); + row_sum *= acc_scale; + // row_sum = sum(reg_S) + float2 local_row_sum_f32x2 = make_float2(row_sum, row_sum); + float2 local_row_sum_1 = make_float2(0, 0); + float2 local_row_sum_2 = make_float2(0, 0); + float2 local_row_sum_3 = make_float2(0, 0); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTMEM_LOADrS); i += 8) { + // row_sum += tTMEM_LOADrS(i); + float2 in = make_float2(tTMEM_LOADrS(i), tTMEM_LOADrS(i+1)); + cute::add(local_row_sum_f32x2, local_row_sum_f32x2, in); + + in = make_float2(tTMEM_LOADrS(i+2), tTMEM_LOADrS(i+2+1)); + cute::add(local_row_sum_1, local_row_sum_1, in); + + in = make_float2(tTMEM_LOADrS(i+4), tTMEM_LOADrS(i+4+1)); + cute::add(local_row_sum_2, local_row_sum_2, in); + + in = make_float2(tTMEM_LOADrS(i+6), tTMEM_LOADrS(i+6+1)); + cute::add(local_row_sum_3, local_row_sum_3, in); + } + + cute::add(local_row_sum_f32x2, local_row_sum_f32x2, local_row_sum_1); + cute::add(local_row_sum_2, local_row_sum_2, local_row_sum_3); + cute::add(local_row_sum_f32x2, local_row_sum_f32x2, local_row_sum_2); + float 
local_row_sum = local_row_sum_f32x2.x + local_row_sum_f32x2.y; + + row_sum = local_row_sum; + + if (final_call) { + // re-acquire the S part in the final step + pipeline_s.consumer_wait(pipeline_s_consumer_state); + + Tensor tTMEM_STOREVrS = make_tensor(shape(tTMEM_STOREVcS)); + tTMEM_STOREVrS(kIdxFinalRowMax) = row_max; + tTMEM_STOREVrS(kIdxFinalRowSum) = row_sum; + copy(tiled_tmem_storev, tTMEM_STOREVrS, tTMEM_STOREVtS); + } + } + + template + CUTLASS_DEVICE auto + softmax( + Stage stage, + BlkCoord const& blk_coord, + Params const& params, ProblemShape const& problem_shape, + PipelineS& pipeline_s, typename PipelineS::PipelineState& pipeline_s_consumer_state, + PipelineC& pipeline_c, typename PipelineC::PipelineState& pipeline_c_producer_state, + OrderBarrierSoftmax& order_s) { + + int mask_tile_count = Mask{}.get_unmasked_trip_count(blk_coord, TileShape{}, problem_shape); + + ElementQK row_max = -INFINITY; + ElementQK row_sum = 0; + + Tensor cS_base = make_identity_tensor(select<0,1>(TileShapeQK{})); + auto logical_offset = make_coord( + get<0>(blk_coord) * get<0>(TileShape{}) + (stage % get<0>(ThreadShape{})) * get<0>(TileShapeQK{}), + 0 + (stage % get<1>(ThreadShape{})) * get<1>(TileShapeQK{}) + ); + Tensor cS = domain_offset(logical_offset, cS_base); + + pipeline_c.producer_acquire(pipeline_c_producer_state); + + CUTLASS_PRAGMA_NO_UNROLL + for (; mask_tile_count > 0; mask_tile_count -= 1) { + softmax_step( + row_max, row_sum, stage, + (mask_tile_count == 1) && + (Mask{}.get_masked_trip_count(blk_coord, TileShape{}, problem_shape) == 0), + blk_coord, cS, params, problem_shape, + pipeline_s, pipeline_s_consumer_state, + pipeline_c, pipeline_c_producer_state, + order_s + ); + + cS.data() = cS.data() + E<1>{} * get<1>(ThreadShape{}) * get<1>(TileShapeQK{}); + } + + // Masked iterations + mask_tile_count = Mask{}.get_masked_trip_count(blk_coord, TileShape{}, problem_shape); + + CUTLASS_PRAGMA_NO_UNROLL + for (; mask_tile_count > 0; mask_tile_count -= 1) { + softmax_step( + row_max, row_sum, stage, mask_tile_count == 1, + blk_coord, cS, params, problem_shape, + pipeline_s, pipeline_s_consumer_state, + pipeline_c, pipeline_c_producer_state, + order_s + ); + + cS.data() = cS.data() + E<1>{} * get<1>(ThreadShape{}) * get<1>(TileShapeQK{}); + } + + pipeline_c.producer_commit(pipeline_c_producer_state); + ++pipeline_c_producer_state; + + pipeline_c.producer_acquire(pipeline_c_producer_state); + // empty step to sync against pipe s + pipeline_s.consumer_release(pipeline_s_consumer_state); + ++pipeline_s_consumer_state; + } + + template + CUTLASS_DEVICE auto + correction_epilogue( + float scale_softmax_log2, float scale_out, Vector const& v0, Vector const& v1, + GTensor& gO, CTensor const& cO, Shape const& g_shape, + Epilogue const& epilogue) { + + using ElementOut = typename GTensor::value_type; + + int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp); + + // As opposed to the softmax, we do not have enough registers here + // to load all of the values (for tile kv = 128), so we loop + // good values would be either 32 or 64 + const int kCorrectionTileSize = 32 / sizeof(ElementOut); + + using TMEM_LOAD = std::conditional_t; // 4x32 threads with 64 cols of 32b elem + + typename CollectiveMmaPV::TiledMma mma; + Tensor tOtO = partition_fragment_C(mma, select<0,1>(TileShapePV{})); + Tensor tOcO = mma.get_slice(0).partition_C(cO); + Tensor tOgO = mma.get_slice(0).partition_C(gO); + + Tensor tOtO_i = tOtO.compose(make_layout(make_shape(_128{}, Int{}))); + Tensor tOcO_i = 
tOcO.compose(make_layout(make_shape(_128{}, Int{}))); + Tensor tOgO_i = tOgO.compose(make_layout(make_shape(_128{}, Int{}))); + + Tensor tOtO0 = tOtO_i; + tOtO0.data() = tOtO0.data().get() + uint32_t(TmemAllocation::O0); + Tensor tOtO1 = tOtO_i; + tOtO1.data() = tOtO1.data().get() + uint32_t(TmemAllocation::O1); + + auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tOtO_i); + auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx); + + Tensor tTMEM_LOADtO0 = thr_tmem_load.partition_S(tOtO0); + Tensor tTMEM_LOADtO1 = thr_tmem_load.partition_S(tOtO1); + Tensor tTMEM_LOADcO = thr_tmem_load.partition_D(tOcO_i); + Tensor tTMEM_LOADgO = thr_tmem_load.partition_D(tOgO_i); + + float row_max = std::max(v0(kIdxFinalRowMax), v1(kIdxFinalRowMax)); + float adj0 = ::exp2f(scale_softmax_log2 * (v0(kIdxFinalRowMax) - row_max)); + float adj1 = ::exp2f(scale_softmax_log2 * (v1(kIdxFinalRowMax) - row_max)); + float row_sum = adj0 * v0(kIdxFinalRowSum) + adj1 * v1(kIdxFinalRowSum); + float scale0 = scale_out * adj0 / row_sum; + float scale1 = scale_out * adj1 / row_sum; + + float2 scale0_f32x2 = make_float2(scale0, scale0); + float2 scale1_f32x2 = make_float2(scale1, scale1); + + // loop: + // TMEM_LOAD, TMEM_LOAD, FMUL2, FFMA2, STG + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < get<2>(TileShape{}) / kCorrectionTileSize; i++) { + Tensor tTMEM_LOADtO0_i = tTMEM_LOADtO0; + tTMEM_LOADtO0_i.data() = tTMEM_LOADtO0_i.data().get() + uint32_t(i * kCorrectionTileSize); + Tensor tTMEM_LOADtO1_i = tTMEM_LOADtO1; + tTMEM_LOADtO1_i.data() = tTMEM_LOADtO1_i.data().get() + uint32_t(i * kCorrectionTileSize); + Tensor tTMEM_LOADgO_i = tTMEM_LOADgO; + tTMEM_LOADgO_i.data() = tTMEM_LOADgO_i.data().get() + i * kCorrectionTileSize * stride<1>(gO); + + Tensor tTMrO0 = make_tensor(shape(tTMEM_LOADcO)); + Tensor tTMrO1 = make_tensor(shape(tTMEM_LOADcO)); + + copy(tiled_tmem_load, tTMEM_LOADtO0_i, tTMrO0); + copy(tiled_tmem_load, tTMEM_LOADtO1_i, tTMrO1); + + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size(tTMrO0); j += 2) { + float2 in0 = make_float2(tTMrO0(j), tTMrO0(j+1)); + float2 in1 = make_float2(tTMrO1(j), tTMrO1(j+1)); + float2 out; + cute::mul(out, scale0_f32x2, in0); + cute::fma(out, scale1_f32x2, in1, out); + tTMrO0(j) = out.x; + tTMrO0(j+1) = out.y; + } + + constexpr int N = 4 / sizeof(ElementOut); + NumericArrayConverter convert; + + Tensor tSMrO = make_tensor_like(tTMrO0); + + Tensor tCs = recast(tTMrO0); + Tensor tCd = recast(tSMrO); + + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size(tCs); j++) { + tCd(j) = convert.convert(tCs(j)); + } + + Tensor tSMgO_i = recast(tTMEM_LOADgO_i); + Tensor tSMrO_i = recast(tSMrO); + + // could use masking do this right for smaller D + if (get<0>(tTMEM_LOADcO(_0{})) < get<0>(g_shape)) { + copy(AutoVectorizingCopyWithAssumedAlignment<128>{}, tSMrO_i, tSMgO_i); + } + } + } + + CUTLASS_DEVICE auto + correction_rescale( + float scale, + uint32_t tmem_O) { + + int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp); + + // As opposed to the softmax, we do not have enough registers here + // to load all of the values (for tile kv = 128), so we loop + // good values would be either 32 or 64 + const int kCorrectionTileSize = 32; + + using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b32x; // 4x32 threads with 64 cols of 32b elem + using TMEM_STORE = SM100_TMEM_STORE_32dp32b32x; // 4x32 threads with 64 cols of 32b elem + + typename CollectiveMmaPV::TiledMma mma; + Tensor cO = make_identity_tensor(select<0,1>(TileShapePV{})); + Tensor tOtO = partition_fragment_C(mma, 
select<0,1>(TileShapePV{})); + Tensor tOcO = mma.get_slice(0).partition_C(cO); + + Tensor tOtO_i = tOtO.compose(make_layout(make_shape(_128{}, Int{}))); + Tensor tOcO_i = tOcO.compose(make_layout(make_shape(_128{}, Int{}))); + + tOtO_i.data() = tOtO_i.data().get() + tmem_O; + + auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tOtO_i); + auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx); + auto tiled_tmem_store = make_tmem_copy(TMEM_STORE{}, tOtO_i); + auto thr_tmem_store = tiled_tmem_store.get_slice(thread_idx); + + Tensor tTMEM_LOADtO = thr_tmem_load.partition_S(tOtO_i); + Tensor tTMEM_LOADcO = thr_tmem_load.partition_D(tOcO_i); + Tensor tTMEM_STOREtO = thr_tmem_store.partition_D(tOtO_i); + Tensor tTMEM_STOREcO = thr_tmem_store.partition_S(tOcO_i); + static_assert(shape(tTMEM_STOREcO) == shape(tTMEM_LOADcO)); + + float2 scale_f32x2 = make_float2(scale, scale); + + Tensor tTMrO = make_tensor(make_shape(shape(tTMEM_LOADcO), Int(TileShape{}) / kCorrectionTileSize>{})); + + auto copy_in = [&](int i) { + Tensor tTMEM_LOADtO_i = tTMEM_LOADtO; + tTMEM_LOADtO_i.data() = tTMEM_LOADtO_i.data().get() + uint32_t(i * kCorrectionTileSize); + Tensor tTMrO_i = tTMrO(_, i).compose(make_layout(shape<0>(tTMrO))); + copy(tiled_tmem_load, tTMEM_LOADtO_i, tTMrO_i); + }; + + auto copy_out = [&](int i) { + Tensor tTMEM_STOREtO_i = tTMEM_STOREtO; + tTMEM_STOREtO_i.data() = tTMEM_STOREtO_i.data().get() + uint32_t(i * kCorrectionTileSize); + Tensor tTMrO_i = tTMrO(_, i).compose(make_layout(shape<0>(tTMrO))); + copy(tiled_tmem_store, tTMrO_i, tTMEM_STOREtO_i); + }; + + // sequence: LLMSLMSLMSS + + // loop: + // TMEM_LOAD, FMUL2 scale, TMEM_STORE + copy_in(0); + + int count = get<2>(TileShape{}) / kCorrectionTileSize; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < count; i++) { + if (i != count - 1) { + copy_in(i+1); + } + + Tensor tTMrO_i = tTMrO(_, i).compose(make_layout(shape<0>(tTMrO))); + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size(tTMrO_i); j += 2) { + float2 in = make_float2(tTMrO_i(j), tTMrO_i(j+1)); + float2 out; + cute::mul(out, scale_f32x2, in); + tTMrO_i(j) = out.x; + tTMrO_i(j+1) = out.y; + } + + copy_out(i); + } + } + + template + CUTLASS_DEVICE auto + correction( + BlkCoord const& blk_coord, + Params const& params, ProblemShape const& problem_shape, + TensorStorageEpi& shared_storage_epi, + PipelineC& pipeline_s0_c, typename PipelineC::PipelineState& pipeline_s0_c_consumer_state, + PipelineC& pipeline_s1_c, typename PipelineC::PipelineState& pipeline_s1_c_consumer_state, + PipelineO& pipeline_o, typename PipelineO::PipelineState& pipeline_o_consumer_state, + PipelineE& pipeline_epi, typename PipelineE::PipelineState& pipeline_epi_producer_state, + Epilogue const& epilogue) { + + int mask_tile_count = Mask{}.get_trip_count(blk_coord, TileShape{}, problem_shape); + + int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp); + + Tensor tStS = partition_fragment_C(typename CollectiveMmaQK::TiledMma{}, select<0,1>(TileShapeQK{})); + + Tensor cS = make_identity_tensor(select<0,1>(TileShapeQK{})); + Tensor tScS = typename CollectiveMmaQK::TiledMma{}.get_slice(0).partition_C(cS); + + Tensor tStS_v = tStS.compose(make_layout(make_shape(_128{}, _2{}))); + Tensor tScS_v = tScS.compose(make_layout(make_shape(_128{}, _2{}))); + + using TMEM_LOAD_V = SM100_TMEM_LOAD_32dp32b2x; // 4x32 threads with 2 cols of 32b elem + + auto tiled_tmem_loadv = make_tmem_copy(TMEM_LOAD_V{}, tStS_v); + auto thr_tmem_loadv = tiled_tmem_loadv.get_slice(thread_idx); + + Tensor tTMEM_LOADVtS = 
thr_tmem_loadv.partition_S(tStS_v); + Tensor tTMEM_LOADVcS = thr_tmem_loadv.partition_D(tScS_v); + + Tensor tTMEM_LOADVtS0 = tTMEM_LOADVtS; + tTMEM_LOADVtS0.data() = tTMEM_LOADVtS0.data().get() + uint32_t(TmemAllocation::V0); + Tensor tTMEM_LOADVtS1 = tTMEM_LOADVtS; + tTMEM_LOADVtS1.data() = tTMEM_LOADVtS1.data().get() + uint32_t(TmemAllocation::V1); + + // ignore first signal from softmax as no correction is required + pipeline_s0_c.consumer_wait(pipeline_s0_c_consumer_state); + pipeline_s0_c.consumer_release(pipeline_s0_c_consumer_state); + ++pipeline_s0_c_consumer_state; + + pipeline_s1_c.consumer_wait(pipeline_s1_c_consumer_state); + + // handle the last iteration differently (i.e. tmem_load/stsm for epi) + mask_tile_count -= 1; + + CUTLASS_PRAGMA_NO_UNROLL + for (; mask_tile_count > 0; mask_tile_count -= 1) { + + pipeline_s0_c.consumer_wait(pipeline_s0_c_consumer_state); + + Tensor tTMEM_LOADVrS = make_tensor(shape(tTMEM_LOADVcS)); + + // read row_wise new global max + copy(tiled_tmem_loadv, tTMEM_LOADVtS0, tTMEM_LOADVrS); + + // e^(scale * (old_max - new_max) + float scale = ::exp2f(params.scale_softmax_log2 * (tTMEM_LOADVrS(kIdxOldRowMax) - tTMEM_LOADVrS(kIdxNewRowMax))); + + pipeline_o.consumer_wait(pipeline_o_consumer_state); + + correction_rescale(scale, uint32_t(TmemAllocation::O0)); + + pipeline_s1_c.consumer_release(pipeline_s1_c_consumer_state); + ++pipeline_s1_c_consumer_state; + + cutlass::arch::fence_view_async_tmem_store(); + + pipeline_o.consumer_release(pipeline_o_consumer_state); + ++pipeline_o_consumer_state; + + pipeline_s1_c.consumer_wait(pipeline_s1_c_consumer_state); + + copy(tiled_tmem_loadv, tTMEM_LOADVtS1, tTMEM_LOADVrS); + + scale = ::exp2f(params.scale_softmax_log2 * (tTMEM_LOADVrS(kIdxOldRowMax) - tTMEM_LOADVrS(kIdxNewRowMax))); + + pipeline_o.consumer_wait(pipeline_o_consumer_state); + + correction_rescale(scale, uint32_t(TmemAllocation::O1)); + + pipeline_s0_c.consumer_release(pipeline_s0_c_consumer_state); + ++pipeline_s0_c_consumer_state; + + cutlass::arch::fence_view_async_tmem_store(); + + pipeline_o.consumer_release(pipeline_o_consumer_state); + ++pipeline_o_consumer_state; + } + + pipeline_s1_c.consumer_release(pipeline_s1_c_consumer_state); + ++pipeline_s1_c_consumer_state; + + // do the final correction to O1 + // better to somehow special-case it in the loop above + // doesn't matter for non-persistent code, but if it were + // persistent we do not want to release O too early + + pipeline_s0_c.consumer_wait(pipeline_s0_c_consumer_state); + + // read from V0 + // read row_sum and final row_max here + Tensor tTMEM_LOADVrS0 = make_tensor(shape(tTMEM_LOADVcS)); + copy(tiled_tmem_loadv, tTMEM_LOADVtS0, tTMEM_LOADVrS0); + + pipeline_s0_c.consumer_release(pipeline_s0_c_consumer_state); + ++pipeline_s0_c_consumer_state; + + pipeline_s1_c.consumer_wait(pipeline_s1_c_consumer_state); + + // load from V1 + Tensor tTMEM_LOADVrS1 = make_tensor(shape(tTMEM_LOADVcS)); + copy(tiled_tmem_loadv, tTMEM_LOADVtS1, tTMEM_LOADVrS1); + + pipeline_s1_c.consumer_release(pipeline_s1_c_consumer_state); + ++pipeline_s1_c_consumer_state; + + auto pipeline_o_release_state = pipeline_o_consumer_state; + pipeline_o.consumer_wait(pipeline_o_consumer_state); + ++pipeline_o_consumer_state; + pipeline_o.consumer_wait(pipeline_o_consumer_state); + ++pipeline_o_consumer_state; + // store to epi smem + + // loop: + // TMEM_LOAD + // FMUL2 scale = 1 / global_sum * out_quant_scale + // F2FP + // store to smem + + Tensor cO = make_identity_tensor(select<0,1>(TileShapePV{})); + auto g_shape 
= select<0,2>(problem_shape); + auto mO = make_tensor(make_gmem_ptr(epilogue.params.ptr_o), append<3>(select<0,1>(TileShapePV{}), get<3>(problem_shape)), epilogue.params.dO); + auto gO = mO(_, _, get<2>(blk_coord)); + + correction_epilogue(params.scale_softmax_log2, params.scale_output, tTMEM_LOADVrS0, tTMEM_LOADVrS1, gO, cO, g_shape, epilogue); + + cutlass::arch::fence_view_async_tmem_load(); + + pipeline_o.consumer_release(pipeline_o_release_state); + ++pipeline_o_release_state; + + pipeline_o.consumer_release(pipeline_o_release_state); + ++pipeline_o_release_state; + } + +}; + +} // namespace cutlass::fmha::collective diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/collective/sm100_fmha_load_cpasync_warpspecialized.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/collective/sm100_fmha_load_cpasync_warpspecialized.hpp new file mode 100644 index 0000000000000000000000000000000000000000..259f1b4766e501420442c6c393a5fda8e89004ed --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/collective/sm100_fmha_load_cpasync_warpspecialized.hpp @@ -0,0 +1,384 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/arch/memory_sm80.h" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cute/tensor.hpp" +#include "cute/layout.hpp" + +#include "collective/fmha_common.hpp" +#include "collective/fmha_fusion.hpp" + +namespace cutlass::fmha::collective { + +using namespace cute; + +template< + class Element, + class StrideQ, + class StrideNewK, + class StrideNewV, + class StrideCacheK, + class StrideCacheV, + class TensorStorage, + class CollectiveMmaQK, + class CollectiveMmaPV, + class SmemLayoutQ, + class SmemLayoutK, + class SmemLayoutV, + class PipelineQ, + class PipelineKV, + class TileShape, + class Mask +> +struct Sm100FmhaLoadCpAsyncWarpspecialized { + + using TileShapeQK = typename CollectiveMmaQK::TileShape; + using TileShapePV = typename CollectiveMmaPV::TileShape; + + struct Arguments { + + const int* cache_batch_idx; + + const Element* ptr_q; + StrideQ dQ; + + const Element* ptr_new_k; + StrideNewK dNewK; + const Element* ptr_new_v; + StrideNewV dNewV; + + Element* ptr_cache_k; + StrideCacheK dCacheK; + Element* ptr_cache_v; + StrideCacheV dCacheV; + }; + + using Params = Arguments; + + template + static Params to_underlying_arguments( + ProblemShape const& problem_shape, + Arguments const& args, + void* workspace) { + + return args; + } + + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& params) { + } + + template + CUTLASS_DEVICE auto constexpr transpose(Tensor const& t) { + CUTE_STATIC_ASSERT_V(rank(t) == _2{}); + return t.compose(make_layout(make_shape(size<1>(t), size<0>(t)), make_stride(size<0>(t), _1{}))); + } + + template< + class CAtom, class TA, class TB, + class CountTensor, class CountLimit, + class SrcTensor, class DstTensor + > + CUTLASS_DEVICE void copy_with_limit( + TiledCopy const& tiled_copy, + CountTensor const& c, CountLimit const& l, + SrcTensor const& src, DstTensor&& dst) { + + //copy(tiled_copy, src, dst); +#if 1 + auto c_f = make_tensor(c.data(), flatten(c.layout())); + auto src_f = make_tensor(src.data(), flatten(src.layout())); + auto dst_f = make_tensor(dst.data(), flatten(dst.layout())); + auto c_v = group_modes<1,rank_v>(c_f); + auto src_v = group_modes<1,rank_v>(src_f); + auto dst_v = group_modes<1,rank_v>(dst_f); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<1>(src_v); i++) { + if (elem_less(c_v(_0{}, i), l)) { + copy(CAtom{}, src_v(_, i), dst_v(_, i)); + } + else { + clear(dst_v(_, i)); + } + } +#endif + } + + template + CUTLASS_DEVICE void + load( + BlkCoord const& blk_coord, ProblemShape const& problem_shape, + Params const& params, ParamsProblemShape const& params_problem_shape, + TensorStorage& storage, + PipelineQ& pipeline_q, typename PipelineQ::PipelineState& pipeline_q_producer_state, + PipelineKV& pipeline_kv, typename PipelineKV::PipelineState& pipeline_kv_producer_state) { + + int mask_tile_count = Mask{}.get_trip_count(blk_coord, TileShape{}, problem_shape); + mask_tile_count *= 2; + + int warp_idx = (threadIdx.x / 32) % 2; + int thread_idx = warp_idx * 32 + (threadIdx.x % 32); + + using X = Underscore; + + // this one is only executed by one thread, no need to elect_one + auto blk_coord_cache = blk_coord; + if (params.cache_batch_idx != nullptr) { + get<2,1>(blk_coord_cache) = params.cache_batch_idx[get<2,1>(blk_coord_cache)]; + } + + // Q1, K1, K2, V1, K3, V2, ... 
Kn, Vn-1, Vn + // two pipes: Q and KV + auto cQ = make_identity_tensor(select<0,2>(TileShape{})); + auto mQ = make_tensor(make_gmem_ptr(params.ptr_q), append<3>(select<0,2>(TileShapeQK{}), get<3>(problem_shape)), params.dQ); + auto gQ = mQ(_, _, get<2>(blk_coord)); + auto sQ = make_tensor(make_smem_ptr(storage.smem_q.data()), SmemLayoutQ{}); + + typename CollectiveMmaQK::TiledMma mma_qk; + ThrMMA thr_mma_qk = mma_qk.get_slice(0); + auto tSgQ = thr_mma_qk.partition_A(gQ); + auto tScQ = thr_mma_qk.partition_A(cQ); + + auto atom_q_tv = Layout, _16>, Stride, _1>>{}; + auto atom_kv_tv = Layout, _16>, Stride, _1>>{}; + + auto tiled_copy_q = make_cotiled_copy( + Copy_Atom, Element>{}, + atom_q_tv, + make_layout(shape(tSgQ), replace<0>(stride(tSgQ), replace<0>(stride<0>(tSgQ), get<2>(TileShape{}))))); + + auto thr_copy_q = tiled_copy_q.get_slice(thread_idx); + + auto tQsQ = thr_copy_q.partition_D(sQ); + auto tQgQ = thr_copy_q.partition_S(tSgQ); + auto tQcQ = thr_copy_q.partition_S(tScQ); + + auto limitQ = append<2>(get<0>(problem_shape), _128{}); + + // Q1 + int q0_index = get<0>(blk_coord); + + auto load_q = [&](int q_index, auto& state) { + pipeline_q.producer_acquire(state); + + // q is always loaded masked + using Vec = uint128_t; + Vec vzero = uint128_t(0, 0); + auto src = recast(tQgQ(_, _, _, _)); + auto dst = recast(tQsQ(_, _, _, _, state.index())); + auto c = tQcQ(_, _, _, _); + int vlen = sizeof(Vec) / sizeof(Element); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(src); i++) { + auto cc = c(vlen*i); + Vec* dst_ptr = &dst(i); + const Vec* src_ptr = &src(i); + bool guard = elem_less(cc, limitQ); + cutlass::arch::cp_async_zfill<16, cutlass::arch::CacheOperation::Always>( + dst_ptr, src_ptr, guard + ); + } + + pipeline_q.producer_commit(state, cutlass::arch::cpasync_barrier_arrive); + }; + + load_q(q0_index, pipeline_q_producer_state); + ++pipeline_q_producer_state; + + auto cK_t = make_identity_tensor(select<1,2>(TileShapeQK{})); + auto cK = make_tensor(cK_t.data(), make_layout(get<0>(cK_t.layout()), get<1>(cK_t.layout()), make_layout(_2{}, get<1>(TileShapeQK{}) * stride<0>(cK_t)))); + auto mK = make_tensor(make_gmem_ptr(params.ptr_cache_k), select<1,2,3>(problem_shape), params.dCacheK); + auto gK = local_tile(mK(_, _, get<2>(blk_coord_cache)), TileShapeQK{}, make_coord(_, _, _0{}), Step{}); + auto sK = make_tensor(make_smem_ptr(storage.smem_k.data()), SmemLayoutK{}); + + auto tSgK = thr_mma_qk.partition_B(gK); + auto tScK = thr_mma_qk.partition_B(cK); + + auto tSlK = thr_mma_qk.partition_B(make_tensor((Element*) nullptr, make_ordered_layout(select<1,2>(TileShapeQK{}), Step<_1, _0>{}))); + auto tiled_copy_k = make_cotiled_copy( + Copy_Atom, Element>{}, + atom_kv_tv, + tSlK.layout()); + + auto thr_copy_k = tiled_copy_k.get_slice(thread_idx); + + auto tKsK = thr_copy_k.partition_D(sK); + auto tKgK = thr_copy_k.partition_S(tSgK); + auto tKcK = thr_copy_k.partition_S(tScK); + + int seqlen_cache_kv = get<1>(problem_shape) - ((params.ptr_new_k != nullptr) ? 
1 : 0); + auto limitK = append<2>(seqlen_cache_kv, _128{}); + + auto cV_t = make_identity_tensor(select<1,2>(TileShapePV{})); + auto cV = make_tensor(cV_t.data(), make_layout(get<0>(cV_t.layout()), get<1>(cV_t.layout()), make_layout(_2{}, get<2>(TileShapePV{}) * stride<1>(cV_t)))); + auto mV = make_tensor(make_gmem_ptr(params.ptr_cache_v), select<2,1,3>(problem_shape), select<1,0,2>(params.dCacheV)); + auto gV = local_tile(mV(_, _, get<2>(blk_coord_cache)), TileShapePV{}, make_coord(_, _0{}, _), Step{}); + auto sV = make_tensor(make_smem_ptr(storage.smem_v.data()), SmemLayoutV{}); + + typename CollectiveMmaPV::TiledMma mma_pv; + ThrMMA thr_mma_pv = mma_pv.get_slice(0); + auto tOgV = thr_mma_pv.partition_B(gV); + auto tOcV = thr_mma_pv.partition_B(cV); + auto tOlV = thr_mma_pv.partition_B(make_tensor((Element*) nullptr, make_layout(select<1,2>(TileShapePV{})))); + + auto tiled_copy_v = make_cotiled_copy( + Copy_Atom, Element>{}, + atom_kv_tv, + tOlV.layout()); + + auto thr_copy_v = tiled_copy_v.get_slice(thread_idx); + + auto tVsV = thr_copy_v.partition_D(sV); + auto tVgV = thr_copy_v.partition_S(tOgV); + auto tVcV = thr_copy_v.partition_S(tOcV); + + auto limitV = select<1,0>(limitK); + + int full_tiles_cache = seqlen_cache_kv / get<1>(TileShapeQK{}); + + bool has_new = params.ptr_new_k != nullptr; + Tensor mNewK = make_tensor(make_gmem_ptr(params.ptr_new_k), select<1,2,3>(problem_shape), params.dNewK); + Tensor mNewV = make_tensor(make_gmem_ptr(params.ptr_new_v), select<1,2,3>(problem_shape), params.dNewV); + Tensor gNewK = mNewK(_, _, get<2>(blk_coord)); + Tensor gNewV = mNewV(_, _, get<2>(blk_coord)); + + auto load_k = [&](int k_index, auto& state) { + pipeline_kv.producer_acquire(state); + + if (k_index < full_tiles_cache) { + copy(tiled_copy_k, tKgK(_, _, _, _, k_index), tKsK(_, _, _, _, state.index())); + pipeline_kv.producer_commit(state, cutlass::arch::cpasync_barrier_arrive); + } else { + using Vec = uint128_t; + Vec vzero = uint128_t(0, 0); + auto src = recast(tKgK(_, _, _, _, k_index)); + auto dst = recast(tKsK(_, _, _, _, state.index())); + auto src2 = recast(gNewK); + auto c = tKcK(_, _, _, _, k_index); + int vlen = sizeof(Vec) / sizeof(Element); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(src); i++) { + auto cc = c(vlen*i); + Vec* dst_ptr = &dst(i); + const Vec* src_ptr = &src(i); + bool guard = elem_less(cc, limitK); + if (get<0>(cc) == seqlen_cache_kv && has_new) { + src_ptr = &src2(_0{}, get<1>(cc) / vlen); + guard = true; + } + cutlass::arch::cp_async_zfill<16, cutlass::arch::CacheOperation::Global>( + dst_ptr, src_ptr, guard + ); + } + + pipeline_kv.producer_commit(state, cutlass::arch::cpasync_barrier_arrive); + } + }; + + auto load_v = [&](int v_index, auto& state) { + pipeline_kv.producer_acquire(state); + + if (v_index < full_tiles_cache) { + copy(tiled_copy_v, tVgV(_, _, _, _, v_index), tVsV(_, _, _, _, state.index())); + pipeline_kv.producer_commit(state, cutlass::arch::cpasync_barrier_arrive); + } else { + using Vec = uint128_t; + Vec vzero = uint128_t(0, 0); + auto src = recast(tVgV(_, _, _, _, v_index)); + auto dst = recast(tVsV(_, _, _, _, state.index())); + auto src2 = recast(gNewV); + int vlen = sizeof(Vec) / sizeof(Element); + auto c = tVcV(_, _, _, _, v_index); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(src); i++) { + auto cc = c(vlen*i); + Vec* dst_ptr = &dst(i); + const Vec* src_ptr = &src(i); + bool guard = elem_less(cc, limitV); + if (get<1>(cc) == seqlen_cache_kv && has_new) { + src_ptr = &src2(_0{}, get<0>(cc) / vlen); + guard = true; 
+ } + cutlass::arch::cp_async_zfill<16, cutlass::arch::CacheOperation::Global>( + dst_ptr, src_ptr, guard + ); + } + + pipeline_kv.producer_commit(state, cutlass::arch::cpasync_barrier_arrive); + } + }; + + // K1 + int k_index = 0; + int v_index = 0; + + load_k(k_index, pipeline_kv_producer_state); + + ++pipeline_kv_producer_state; + k_index += 1; + + mask_tile_count -= 1; + + for (; mask_tile_count > 0; mask_tile_count -= 1) { + + load_k(k_index, pipeline_kv_producer_state); + + ++pipeline_kv_producer_state; + k_index += 1; + + load_v(v_index, pipeline_kv_producer_state); + + ++pipeline_kv_producer_state; + v_index += 1; + } + + // V1 + + load_v(v_index, pipeline_kv_producer_state); + + ++pipeline_kv_producer_state; + v_index += 1; + + if (has_new) { + for (int i = thread_idx; i < get<2>(TileShape{}); i += 64) { + gK(seqlen_cache_kv, i, 0) = gNewK(0, i); + gV(i, seqlen_cache_kv, 0) = gNewV(0, i); + } + } + } + +}; + +} // namespace cutlass::fmha::collective diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/collective/sm100_fmha_load_tma_warpspecialized.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/collective/sm100_fmha_load_tma_warpspecialized.hpp new file mode 100644 index 0000000000000000000000000000000000000000..3606dcc7d9a0e9b764a4f361daed83da3715d67a --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/collective/sm100_fmha_load_tma_warpspecialized.hpp @@ -0,0 +1,299 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/arch/memory_sm80.h" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cute/tensor.hpp" +#include "cute/layout.hpp" + +#include "collective/fmha_common.hpp" +#include "collective/fmha_fusion.hpp" + +namespace cutlass::fmha::collective { + +using namespace cute; + +template< + class Element, + class StrideQ, + class StrideK, + class StrideV, + class CollectiveMmaQK, + class CollectiveMmaPV, + class SmemLayoutQ, + class SmemLayoutK, + class SmemLayoutV, + class TensorStorage, + class PipelineQ, + class PipelineKV, + class Mask, + class TileShape +> +struct Sm100FmhaLoadTmaWarpspecialized { + + using TileShapeQK = typename CollectiveMmaQK::TileShape; + using TileShapePV = typename CollectiveMmaPV::TileShape; + + struct Arguments { + const Element* ptr_Q; + StrideQ dQ; + const Element* ptr_K; + StrideK dK; + const Element* ptr_V; + StrideV dV; + }; + + using TMA_Q = typename CollectiveMmaQK::Params::TMA_A; + using TMA_K = typename CollectiveMmaQK::Params::TMA_B; + using TMA_V = typename CollectiveMmaPV::Params::TMA_B; + + struct Params { + TMA_Q tma_load_q; + TMA_K tma_load_k; + TMA_V tma_load_v; + }; + + template + static Params to_underlying_arguments( + ProblemShape const& problem_shape, + Arguments const& args, + void* workspace) { + + auto ptr_Q = args.ptr_Q; + auto ptr_K = args.ptr_K; + auto ptr_V = args.ptr_V; + auto dQ = args.dQ; + auto dK = args.dK; + auto dV = args.dV; + + using IntProblemShape = cute::tuple, int>>; + + IntProblemShape problem_shape_qk; + if constexpr (is_variable_length_v>) { + auto cumulative_length_q = get<0>(problem_shape).cumulative_length; + auto cumulative_length_k = get<1>(problem_shape).cumulative_length; + if (cumulative_length_q != nullptr && cumulative_length_k != nullptr ) { + get<0>(problem_shape_qk) = get<0>(problem_shape).total_length; + get<1>(problem_shape_qk) = get<1>(problem_shape).total_length; + get<2>(problem_shape_qk) = get<2>(problem_shape); + get<3>(problem_shape_qk) = get<3>(problem_shape); + } + } else { + problem_shape_qk = problem_shape; + } + + auto params_qk = CollectiveMmaQK::to_underlying_arguments( + problem_shape_qk, + typename CollectiveMmaQK::Arguments { + ptr_Q, dQ, + ptr_K, dK, + }, /*workspace=*/ nullptr); + + auto problem_shape_pv = select<0,2,1,3>(problem_shape_qk); + auto params_pv = CollectiveMmaPV::to_underlying_arguments( + problem_shape_pv, + typename CollectiveMmaPV::Arguments { + ptr_K, dK, // never used, dummy + ptr_V, select<1,0,2>(dV), + }, /*workspace=*/ nullptr); + + return Params{ + params_qk.tma_load_a, + params_qk.tma_load_b, + params_pv.tma_load_b + }; + } + + + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& params) { + cute::prefetch_tma_descriptor(params.tma_load_q.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_k.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_v.get_tma_descriptor()); + } + + template + CUTLASS_DEVICE void + load( + BlkCoord const& blk_coord_in, ProblemShape const& problem_shape, + Params const& params, ParamsProblemShape const& params_problem_shape, + TensorStorage& storage, + PipelineQ& pipeline_q, typename PipelineQ::PipelineState& pipeline_q_producer_state, + PipelineKV& pipeline_kv, typename PipelineKV::PipelineState& pipeline_kv_producer_state) { + + BlkCoord blk_coord_q = blk_coord_in; + BlkCoord blk_coord_kv = blk_coord_in; + + 
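+    // Q and K/V keep separate copies of the block coordinate: when the problem shape is
+    // variable-length, each side is re-based below against its own cumulative-length offset
+    // before the global Q/K/V tensors are indexed.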
int mask_tile_count = Mask{}.get_trip_count(blk_coord_in, TileShape{}, problem_shape); + + using X = Underscore; + + // this one is only executed by one thread, no need to elect_one + + // Q1, K1, Q2, V1, K2, V2, K3, V3, ... + // two pipes: Q and KV + // from Memory (prod) to TensorCore (cons) + + // compute gQ, sQ + // we load 2*get<0>(blk_coord), and 2*get<0>(blk_coord) + 1 + ThrMMA mma_qk = typename CollectiveMmaQK::TiledMma{}.get_slice(0); + Tensor mQ_qdl_p = params.tma_load_q.get_tma_tensor(select<0,2,3>(problem_shape)); + + int q_offs_0 = 0; + + if constexpr (is_variable_length_v>) { + auto cumulative_length_q = get<0>(params_problem_shape).cumulative_length; + if (cumulative_length_q != nullptr) { + q_offs_0 = cumulative_length_q[get<2,1>(blk_coord_q)]; + get<2,1>(blk_coord_q) = 0; + } + } + + Tensor mQ_qdl = domain_offset(make_coord(q_offs_0, _0{}, make_coord(_0{}, _0{})), mQ_qdl_p); + + Tensor gQ_qdl = local_tile(mQ_qdl, TileShapeQK{}, make_coord(_, _, _), Step<_1, X, _1>{}); + Tensor tSgQ_qdl = mma_qk.partition_A(gQ_qdl); + Tensor sQ = make_tensor(make_smem_ptr(storage.smem_q.data()), SmemLayoutQ{}); + auto [tQgQ_qdl, tQsQ] = tma_partition( + params.tma_load_q, _0{}, make_layout(_1{}), + group_modes<0,3>(sQ), group_modes<0,3>(tSgQ_qdl) + ); + Tensor tQgQ = tQgQ_qdl(_, _, _0{}, get<2>(blk_coord_q)); + + // compute gK, sK + Tensor mK_kdl_p = params.tma_load_k.get_tma_tensor(select<1,2,3>(problem_shape)); + + int kv_offs_0 = 0; + + if constexpr (is_variable_length_v>) { + auto cumulative_length = get<1>(params_problem_shape).cumulative_length; + if (cumulative_length != nullptr) { + kv_offs_0 = cumulative_length[get<2,1>(blk_coord_kv)]; + get<2,1>(blk_coord_kv) = 0; + } + } + + Tensor mK_kdl = domain_offset(make_coord(kv_offs_0, _0{}, make_coord(_0{}, _0{})), mK_kdl_p); + + Tensor gK_kdl = local_tile(mK_kdl, TileShapeQK{}, make_coord(_, _, _), Step{}); + Tensor tSgK_kdl = mma_qk.partition_B(gK_kdl); + Tensor sK = make_tensor(make_smem_ptr(storage.smem_k.data()), SmemLayoutK{}); + auto [tKgK_kdl, tKsK] = tma_partition( + params.tma_load_k, _0{}, make_layout(_1{}), + group_modes<0,3>(sK), group_modes<0,3>(tSgK_kdl) + ); + Tensor tKgK = tKgK_kdl(_, _, _0{}, get<2>(blk_coord_kv)); + + // compute gV, sV + ThrMMA mma_pv = typename CollectiveMmaPV::TiledMma{}.get_slice(0); + Tensor mV_dkl_p = params.tma_load_v.get_tma_tensor(select<2,1,3>(problem_shape)); + + Tensor mV_dkl = domain_offset(make_coord(_0{}, kv_offs_0, make_coord(_0{}, _0{})), mV_dkl_p); + + Tensor gV_dkl = local_tile(mV_dkl, TileShapePV{}, make_coord(_, _, _), Step{}); + Tensor tOgV_dkl = mma_pv.partition_B(gV_dkl); + Tensor sV = make_tensor(make_smem_ptr(storage.smem_v.data()), SmemLayoutV{}); + auto [tVgV_dkl, tVsV] = tma_partition( + params.tma_load_v, _0{}, make_layout(_1{}), + group_modes<0,3>(sV), group_modes<0,3>(tOgV_dkl) + ); + auto tVgV = tVgV_dkl(_, _0{}, _, get<2>(blk_coord_kv)); + + // blk_coord in decomposed in terms of TileShape, not TileShapeQK + // As such, it needs to be transformed as + // (a,b,c): a -> 2*a (Q0) 2*a+1 (Q1) + // b -> 2*a (Ki i even) 2*a+1 (Ki i odd) + + uint32_t lane_predicate = cute::elect_one_sync(); + + // Q1 + int q0_index = 2 * get<0>(blk_coord_q); + int q1_index = 2 * get<0>(blk_coord_q) + 1; + pipeline_q.producer_acquire(pipeline_q_producer_state); + if (lane_predicate) { + auto tma_barrier = pipeline_q.producer_get_barrier(pipeline_q_producer_state); + copy(params.tma_load_q.with(*tma_barrier, 0), tQgQ(_, q0_index), tQsQ(_, pipeline_q_producer_state.index())); + } + 
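+    // Only the elected lane issues the TMA copy; completion is tracked through the
+    // pipeline's transaction barrier, so all lanes advance the producer state together.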
++pipeline_q_producer_state; + + // K1 + int k_index = 0; + pipeline_kv.producer_acquire(pipeline_kv_producer_state); + if (lane_predicate) { + auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state); + copy(params.tma_load_k.with(*tma_barrier, 0), tKgK(_, k_index), tKsK(_, pipeline_kv_producer_state.index())); + } + ++pipeline_kv_producer_state; + + // Q2 + pipeline_q.producer_acquire(pipeline_q_producer_state); + if (lane_predicate) { + auto tma_barrier = pipeline_q.producer_get_barrier(pipeline_q_producer_state); + copy(params.tma_load_q.with(*tma_barrier, 0), tQgQ(_, q1_index), tQsQ(_, pipeline_q_producer_state.index())); + } + ++pipeline_q_producer_state; + + // V1 + pipeline_kv.producer_acquire(pipeline_kv_producer_state); + if (lane_predicate) { + auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state); + copy(params.tma_load_v.with(*tma_barrier, 0), tVgV(_, k_index), tVsV(_, pipeline_kv_producer_state.index())); + } + ++pipeline_kv_producer_state; + k_index += 1; + + // loop: + mask_tile_count -= 1; + for (; mask_tile_count > 0; mask_tile_count -= 1) { + + // Ki + pipeline_kv.producer_acquire(pipeline_kv_producer_state); + if (lane_predicate) { + auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state); + copy(params.tma_load_k.with(*tma_barrier, 0), tKgK(_, k_index), tKsK(_, pipeline_kv_producer_state.index())); + } + ++pipeline_kv_producer_state; + + // Vi + pipeline_kv.producer_acquire(pipeline_kv_producer_state); + if (lane_predicate) { + auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state); + copy(params.tma_load_v.with(*tma_barrier, 0), tVgV(_, k_index), tVsV(_, pipeline_kv_producer_state.index())); + } + ++pipeline_kv_producer_state; + k_index += 1; + } + } +}; + +} // namespace cutlass::fmha::collective diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/collective/sm100_fmha_mla_fwd_mainloop_tma_warpspecialized.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/collective/sm100_fmha_mla_fwd_mainloop_tma_warpspecialized.hpp new file mode 100644 index 0000000000000000000000000000000000000000..bf41af9f533ad7e5099d5ae483c93605b44c71f5 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/collective/sm100_fmha_mla_fwd_mainloop_tma_warpspecialized.hpp @@ -0,0 +1,1225 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/arch/memory_sm80.h" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cute/arch/simd_sm100.hpp" +#include "cute/tensor.hpp" +#include "cute/layout.hpp" + +#include "collective/fmha_common.hpp" +#include "collective/fmha_fusion.hpp" +#include "collective/sm100_fmha_mla_load_tma_warpspecialized.hpp" +#include "common/pipeline_mla.hpp" + +namespace cutlass::fmha::collective { + +using namespace cute; + +template< + class Element_, + class ElementQK_, + class ElementPV_, + class ComposedTileShape_, + class StrideQ_, + class StrideK_, + class StrideV_, + class Mask_, + // shape here is QG K H + // and referes to the two softmax warps + // (2, 1, 1) means that they are stacked (best for large Q since it loads the least K/V) + // (1, 2, 1) means they sit side by side (best for small Q / large K) + class ThreadShape = Shape<_2, _1, _1>, + class OrderLoadEpilogue = cute::false_type +> +struct Sm100MlaFwdMainloopTmaWarpspecialized { + + using Element = Element_; + using ElementQK = ElementQK_; + using ElementPV = ElementPV_; + using ComposedTileShape = ComposedTileShape_; + using StrideQ = StrideQ_; + using StrideK = StrideK_; + using StrideV = StrideV_; + using Mask = Mask_; + + static constexpr int StageCountQ = 2; + static constexpr int StageCountK = 1; + static constexpr int StageCountV = 1; + static constexpr int StageCountKV = StageCountK + StageCountV; + // Support StageCountKV > 2 in the future. 
+ static_assert(StageCountK == 1 && StageCountV == 1, "Only support StageCountK = StageCountV = 1!"); + static_assert(std::is_same_v>, "Only support ThreadShape = Shape<_2, _1, _1>"); + + using ClusterShape = Shape<_1, _1, _1>; + + static const int Alignment = 128 / sizeof_bits_v; + + static constexpr auto HeadDimLatent = size<2, 0>(ComposedTileShape{}); + static constexpr auto HeadDimRope = size<2, 1>(ComposedTileShape{}); + static constexpr auto HeadDimQK = HeadDimLatent + HeadDimRope; + static constexpr auto HeadDimPV = HeadDimLatent; + + using TileShapeQK = decltype(shape_div(replace<2>(ComposedTileShape{}, HeadDimQK), ThreadShape{})); + using TileShapePV = decltype(select<0,2,1>(shape_div(replace<2>(ComposedTileShape{}, HeadDimPV), ThreadShape{}))); + using TileShape = decltype(replace<2>(ComposedTileShape{}, HeadDimLatent)); + + using CollectiveMmaQK = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + Element, StrideQ, Alignment, + Element, StrideK, Alignment, + ElementQK, + TileShapeQK, ClusterShape, cutlass::gemm::collective::StageCount<3> /* we change it later anyways*/, + cutlass::gemm::KernelTmaWarpSpecialized1SmSm100>::CollectiveOp; + + using CollectiveMmaPV = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + // the stride for A does not matter since we do not load from smem at all + Element, StrideK, Alignment, + Element, decltype(select<1,0,2>(StrideV{})), Alignment, + ElementPV, + TileShapePV, ClusterShape, cutlass::gemm::collective::StageCount<3> /* we change it later anyways*/, + cutlass::gemm::KernelTmaWarpSpecialized1SmSm100>::CollectiveOp; + + using SmemLayoutQ = decltype(unstageSmemLayout(typename CollectiveMmaQK::SmemLayoutA{}, Int{})); + using SmemLayoutK = decltype(unstageSmemLayout(typename CollectiveMmaQK::SmemLayoutB{}, Int{})); + using SmemLayoutV = decltype(unstageSmemLayout(typename CollectiveMmaPV::SmemLayoutB{}, Int{})); + + using SmemStorageOneStageO = decltype(make_layout(replace<2>(TileShapePV{}, _1{}))); + + // Since the shared memory is not sufficient if we use separate Q, K, V, and O shared memory, + // we reuse shared memory for V and O to address this problem, + // and a barrier has been added to coordinate access to shared memory. 
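+  // Concretely: smem_o holds the first output tile (O0) while smem_v doubles as the V
+  // staging buffer and the second output tile (O1), so the epilogue's use of O1 has to be
+  // ordered against the mainloop's last read of V (the barrier mentioned above provides
+  // that ordering); the OrderLoadEpilogue flag selects between the two TensorStorage
+  // variants accordingly.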
+ static constexpr bool IsOrderLoadEpilogue = std::is_same_v; + static const int NumWarpsEpilogue = 1; + static const int NumWarpsLoad = 1; + + struct TensorStorageQKVO { + cute::array_aligned> smem_q; + cute::array_aligned> smem_k; + cute::array_aligned> smem_o; // use as O0 + cute::array_aligned> smem_v; // use as V0 and O1 + }; + + struct TensorStorageQKV { + cute::array_aligned> smem_q; + cute::array_aligned> smem_k; + cute::array_aligned> smem_v; + }; + + using TensorStorage = std::conditional_t; + + enum class TmemAllocation : uint32_t { + kSizeS = 128, + kSizeO = 128, + kSizeP = 32, + S0 = 0, + S1 = S0 + kSizeS, + V0 = S0, // stats storage from softmax to correction + V1 = S1, + P0 = S0 + kSizeP, + P1 = S1 + kSizeP, + O0 = S1 + kSizeS, + O1 = O0 + kSizeO, + kEnd = O1 + kSizeO + }; + + // indices for V0 / V1 + enum : int { + kIdxOldRowMax = 0, + kIdxNewRowMax = 1, + kIdxFinalRowSum = 0, + kIdxFinalRowMax = 1 + }; + + // from load to mma warp, protects q in smem + using PipelineQ = cutlass::PipelineTmaUmmaAsync< + StageCountQ, + typename CollectiveMmaQK::AtomThrShapeMNK + >; + + // from load to mma warp, protects k/v in smem + using PipelineKV = cutlass::PipelineTmaAsyncMla< + StageCountKV, + typename CollectiveMmaQK::AtomThrShapeMNK + >; + + // from mma to softmax0/1 warp, protects S in tmem + // (not sure yet about the reverse direction) + // there is one pipe per softmax warp, and the mma warp alternates between them + using PipelineS = cutlass::PipelineUmmaAsync<1>; + + // from softmax0/1/ to correction wg + using PipelineC = cutlass::PipelineAsync<1>; + + // from mma to correction + using PipelineO = cutlass::PipelineUmmaAsync<2>; + + // from corr to epilogue + using PipelineE = cutlass::PipelineAsync<2>; + + using OrderBarrierSoftmax = cutlass::OrderedSequenceBarrier< + /*stages*/ 1, /*groups*/ 2>; + + static constexpr int TransactionBytesLoadQ = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutQ{})) * cute::sizeof_bits_v); + static constexpr int TransactionBytesLoadK = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutK{})) * cute::sizeof_bits_v); + static constexpr int TransactionBytesLoadV = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutV{})) * cute::sizeof_bits_v); + + using Load = Sm100MlaFwdLoadTmaWarpspecialized< + Element, StrideQ, StrideK, StrideV, + CollectiveMmaQK, CollectiveMmaPV, + SmemLayoutQ, SmemLayoutK, SmemLayoutV, + TensorStorage, PipelineQ, PipelineKV, Mask, TileShape, OrderLoadEpilogue + >; + + struct Arguments { + typename Load::Arguments load; + + // if zero, defaults to 1/sqrt(D) + float scale_softmax = 0.0f; + + // scaling factors to dequantize QKV + float scale_q = 1.0f; + float scale_k = 1.0f; + float scale_v = 1.0f; + + // scaling factor to quantize O + float inv_scale_o = 1.0f; + }; + + struct Params { + typename Load::Params load; + + float scale_softmax; + float scale_softmax_log2; + + float scale_output; + }; + + template + static bool can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return true; + } + + template + static Params to_underlying_arguments( + ProblemShape const& problem_shape, + Arguments const& args, + void* workspace) { + + float scale_softmax = args.scale_softmax; + if (scale_softmax == 0.0f) { + scale_softmax = 1.0f / (float) std::sqrt(get<2, 0>(problem_shape) + get<2, 1>(problem_shape)); + } + float log2_e = static_cast(std::log2(std::exp(1.0))); + + return Params{ + Load::to_underlying_arguments(problem_shape, args.load, workspace), + args.scale_q * args.scale_k * scale_softmax, + args.scale_q * 
args.scale_k * log2_e * scale_softmax, + args.scale_v * args.inv_scale_o + }; + } + + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& params) { + Load::prefetch_tma_descriptors(params.load); + } + + template + CUTLASS_DEVICE void + load( + BlkCoord const& blk_coord, ProblemShape const& problem_shape, + Params const& params, ParamsProblemShape const& params_problem_shape, + TensorStorage& storage, + PipelineQ& pipeline_q, typename PipelineQ::PipelineState& pipeline_q_producer_state, + PipelineKV& pipeline_kv, typename PipelineKV::PipelineState& pipeline_kv_producer_state) { + + Load load; + load.load(blk_coord, problem_shape, params.load, params_problem_shape, + storage, + pipeline_q, pipeline_q_producer_state, + pipeline_kv, pipeline_kv_producer_state); + } + + template + CUTLASS_DEVICE auto + mma( + BlkCoord const& blk_coord, + Params const& params, ProblemShape const& problem_shape, + TensorStorage& storage, + PipelineQ& pipeline_q, typename PipelineQ::PipelineState& pipeline_q_consumer_state, + PipelineKV& pipeline_kv, typename PipelineKV::PipelineState& pipeline_kv_consumer_state, + PipelineS& pipeline_s0, typename PipelineS::PipelineState& pipeline_s0_producer_state, + PipelineS& pipeline_s1, typename PipelineS::PipelineState& pipeline_s1_producer_state, + PipelineO& pipeline_corr, typename PipelineO::PipelineState& pipeline_corr_producer_state) { + + auto pipeline_q_release_state = pipeline_q_consumer_state; + auto pipeline_kv_release_state = pipeline_kv_consumer_state; + + int mask_tile_count = Mask{}.get_trip_count(blk_coord, TileShape{}, problem_shape); + + typename CollectiveMmaQK::TiledMma mma_qk; + ThrMMA thr_mma_qk = mma_qk.get_slice(0); + + typename CollectiveMmaPV::TiledMma mma_pv; + TiledMMA mma_pv_ts = to_tiled_mma_sm100_ts(mma_pv); + ThrMMA thr_mma_pv = mma_pv_ts.get_slice(0); + + Tensor sQ = make_tensor(make_smem_ptr(storage.smem_q.data()), SmemLayoutQ{}); + Tensor sK = make_tensor(make_smem_ptr(storage.smem_k.data()), SmemLayoutK{}); + Tensor sV = make_tensor(make_smem_ptr(storage.smem_v.data()), SmemLayoutV{}); + + Tensor tSrQ = thr_mma_qk.make_fragment_A(sQ); + Tensor tSrK = thr_mma_qk.make_fragment_B(sK); + Tensor tOrV = thr_mma_pv.make_fragment_B(sV); + + // tmem layout is + // S0 S1`O0 O1 + // sequential in memory, where S overlaps with P and V + + Tensor tStS = partition_fragment_C(mma_qk, select<0,1>(TileShapeQK{})); + Tensor tOtO = partition_fragment_C(mma_pv_ts, select<0,1>(TileShapePV{})); + + Tensor tStS0 = tStS; + tStS0.data() = tStS.data().get() + uint32_t(TmemAllocation::S0); + Tensor tStS1 = tStS; + tStS1.data() = tStS.data().get() + uint32_t(TmemAllocation::S1); + + Tensor tOtO0 = tOtO; + tOtO0.data() = tOtO.data().get() + uint32_t(TmemAllocation::O0); + Tensor tOtO1 = tOtO; + tOtO1.data() = tOtO.data().get() + uint32_t(TmemAllocation::O1); + + Tensor sP = make_tensor(make_smem_ptr((Element*)nullptr), typename CollectiveMmaPV::SmemLayoutA{}); + Tensor tOrP = thr_mma_pv.make_fragment_A(sP)(_, _, _, _0{}); // slice out staging + + Tensor tOrP0 = tOrP; + tOrP0.data() = tOrP0.data().get() + uint32_t(TmemAllocation::P0); + Tensor tOrP1 = tOrP; + tOrP1.data() = tOrP1.data().get() + uint32_t(TmemAllocation::P1); + + int k_index = 0; + int v_index = 0; + int q_index = 0; + + // wait for Q1 + q_index = pipeline_q_consumer_state.index(); + pipeline_q.consumer_wait(pipeline_q_consumer_state); + ++pipeline_q_consumer_state; + + Tensor tSrQ0 = tSrQ(_,_,_,q_index); + + + // wait for K1 + k_index = pipeline_kv_consumer_state.index(); + 
pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + + // gemm Q1 * K1 -> S1 + pipeline_s0.producer_acquire(pipeline_s0_producer_state); + + gemm_zero_acc(mma_qk, tSrQ0, tSrK(_,_,_,k_index / 2), tStS0); + + pipeline_s0.producer_commit(pipeline_s0_producer_state); + ++pipeline_s0_producer_state; + + // release K1 + if constexpr (get<1>(ThreadShape{}) > 1) { + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + } + + // wait for Q2 + if constexpr (get<0>(ThreadShape{}) > 1 || get<2>(ThreadShape{}) > 1) { + q_index = pipeline_q_consumer_state.index(); + pipeline_q.consumer_wait(pipeline_q_consumer_state); + ++pipeline_q_consumer_state; + } + + Tensor tSrQ1 = tSrQ(_,_,_,q_index); + + if constexpr (get<1>(ThreadShape{}) > 1) { + k_index = pipeline_kv_consumer_state.index(); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + } + + pipeline_s1.producer_acquire(pipeline_s1_producer_state); + + // gemm Q2 * K1 -> S2 + gemm_zero_acc(mma_qk, tSrQ1, tSrK(_,_,_,k_index / 2), tStS1); + + pipeline_s1.producer_commit(pipeline_s1_producer_state); + ++pipeline_s1_producer_state; + + // release K1 + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + + // wait for V1 + v_index = pipeline_kv_consumer_state.index(); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + + // this acquire returns the ownership of all of S0 to the mma warp + // including the P0 part + // acquire corr first to take it out of the critical + // path since softmax takes longer + pipeline_corr.producer_acquire(pipeline_corr_producer_state); + pipeline_s0.producer_acquire(pipeline_s0_producer_state); + + // gemm P1 * V1 -> O1 + gemm_zero_acc(mma_pv_ts, tOrP0, tOrV(_,_,_,v_index / 2), tOtO0); + + pipeline_corr.producer_commit(pipeline_corr_producer_state); + ++pipeline_corr_producer_state; + + if constexpr (get<1>(ThreadShape{}) > 1) { + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + } + + mma_pv_ts.accumulate_ = UMMA::ScaleOut::Zero; + + // loop: + mask_tile_count -= 1; + for (; mask_tile_count > 0; mask_tile_count -= 1) { + + // wait for Ki + k_index = (pipeline_kv_consumer_state.index()); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + + // gemm Q1 * Ki -> S1 + gemm_zero_acc(mma_qk, tSrQ0, tSrK(_,_,_,k_index / 2), tStS0); + + pipeline_s0.producer_commit(pipeline_s0_producer_state); + ++pipeline_s0_producer_state; + + if constexpr (get<1>(ThreadShape{}) > 1) { + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + } + + // gemm P2 * V(i-1) -> O2 + if constexpr (get<1>(ThreadShape{}) > 1) { + v_index = pipeline_kv_consumer_state.index(); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + } + + pipeline_corr.producer_acquire(pipeline_corr_producer_state); + pipeline_s1.producer_acquire(pipeline_s1_producer_state); + + gemm_reset_zero_acc(mma_pv_ts, tOrP1, tOrV(_,_,_,v_index / 2), tOtO1); + + pipeline_corr.producer_commit(pipeline_corr_producer_state); + ++pipeline_corr_producer_state; + + // release V(i-1) + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + + if constexpr (get<1>(ThreadShape{}) > 1) { + k_index = (pipeline_kv_consumer_state.index()); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + } + + // 
gemm Q2 * Ki -> S2 + gemm_zero_acc(mma_qk, tSrQ1, tSrK(_,_,_,k_index / 2), tStS1); + + pipeline_s1.producer_commit(pipeline_s1_producer_state); + ++pipeline_s1_producer_state; + + // release Ki + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + + // wait for Vi + v_index = (pipeline_kv_consumer_state.index()); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + + // gemm P1 * Vi -> O1 + pipeline_corr.producer_acquire(pipeline_corr_producer_state); + + pipeline_s0.producer_acquire(pipeline_s0_producer_state); + + gemm_reset_zero_acc(mma_pv_ts, tOrP0, tOrV(_,_,_,v_index / 2), tOtO0); + + pipeline_corr.producer_commit(pipeline_corr_producer_state); + ++pipeline_corr_producer_state; + + if constexpr (get<1>(ThreadShape{}) > 1) { + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + } + } + + // release Q1 + pipeline_q.consumer_release(pipeline_q_release_state); + ++pipeline_q_release_state; + + // release Q2 + if constexpr (get<0>(ThreadShape{}) > 1) { + pipeline_q.consumer_release(pipeline_q_release_state); + ++pipeline_q_release_state; + } + + // wait for Vi + if constexpr (get<1>(ThreadShape{}) > 1) { + v_index = pipeline_kv_consumer_state.index(); + pipeline_kv.consumer_wait(pipeline_kv_consumer_state); + ++pipeline_kv_consumer_state; + } + + // gemm P2 * Vi -> O2 + pipeline_corr.producer_acquire(pipeline_corr_producer_state); + pipeline_s1.producer_acquire(pipeline_s1_producer_state); + + gemm_reset_zero_acc(mma_pv_ts, tOrP1, tOrV(_,_,_,v_index / 2), tOtO1); + + pipeline_corr.producer_commit(pipeline_corr_producer_state); + ++pipeline_corr_producer_state; + + // release Vi + pipeline_kv.consumer_release(pipeline_kv_release_state); + ++pipeline_kv_release_state; + + pipeline_s0.producer_commit(pipeline_s0_producer_state); + ++pipeline_s0_producer_state; + + pipeline_s1.producer_commit(pipeline_s1_producer_state); + ++pipeline_s1_producer_state; + + // T0 S00 B1, T0 S10 B1, T0 S00 B2, T0 S01 B1, T0 S10 B2, T0 S11 B1, T0 S01 B2, T1 S00 B1, T0 S11 B2, ... + // Q1 * K1 , Q2 * K1 , S11 * V1 , Q1 * K2 , S21 * V1 , Q2 * K2 , S12 * V2 , Q1 * K3 , S22 * K2 , ... + } + + template + CUTLASS_DEVICE auto + softmax_step( + bool need_apply_mask, + float& row_max, float& row_sum, + Stage stage, bool final_call, + BlkCoord const& blk_coord, CoordTensor const& cS, + Params const& params, ProblemShape const& problem_shape, + PipelineS& pipeline_s, typename PipelineS::PipelineState& pipeline_s_consumer_state, + PipelineC& pipeline_c, typename PipelineC::PipelineState& pipeline_c_producer_state, + OrderBarrierSoftmax& order_s) { + + Tensor tScS = typename CollectiveMmaQK::TiledMma{}.get_slice(0).partition_C(cS); + + Tensor tStS = partition_fragment_C(typename CollectiveMmaQK::TiledMma{}, select<0,1>(TileShapeQK{})); + tStS.data() = uint32_t(stage == _0{} ? TmemAllocation::S0 : TmemAllocation::S1); + + Tensor tStS_v = tStS.compose(make_layout(make_shape(_128{}, _2{}))); + tStS_v.data() = uint32_t(stage == _0{} ? TmemAllocation::V0 : TmemAllocation::V1); + Tensor tScS_v = tScS.compose(make_layout(make_shape(_128{}, _2{}))); + + auto tilePlikeFP32 = size<1>(TileShapeQK{}) / Int{} * Int{}; + Tensor tStS_P = tStS.compose(make_layout(make_shape(_128{}, tilePlikeFP32))); + tStS_P.data() = warp_uniform(uint32_t(stage == _0{} ? 
TmemAllocation::P0 : TmemAllocation::P1)); + Tensor tScS_P = tScS.compose(make_layout(make_shape(_128{}, tilePlikeFP32))); + + // Each thread owns a single row + using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b32x; // 4x32 threads with 128 cols of 32b elem + using TMEM_STORE = SM100_TMEM_STORE_32dp32b32x; // 4x32 threads with 128 cols of 8b elem + using TMEM_STORE_V = SM100_TMEM_STORE_32dp32b2x; // 4x32 threads with 2 cols of 32b elem + + int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp); + + auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tStS); + auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx); + + Tensor tTMEM_LOADtS = thr_tmem_load.partition_S(tStS); + Tensor tTMEM_LOADcS = thr_tmem_load.partition_D(tScS); + + auto tiled_tmem_storev = make_tmem_copy(TMEM_STORE_V{}, tStS_v); + auto thr_tmem_storev = tiled_tmem_storev.get_slice(thread_idx); + + Tensor tTMEM_STOREVtS = thr_tmem_storev.partition_D(tStS_v); + Tensor tTMEM_STOREVcS = thr_tmem_storev.partition_S(tScS_v); + + auto tiled_tmem_store = make_tmem_copy(TMEM_STORE{}, tStS_P); + auto thr_tmem_store = tiled_tmem_store.get_slice(thread_idx); + + Tensor tTMEM_STOREtS_x4 = thr_tmem_store.partition_D(tStS_P); + tTMEM_STOREtS_x4.data() = warp_uniform(tTMEM_STOREtS_x4.data().get()); + Tensor tTMEM_STOREcS = thr_tmem_store.partition_S(tScS_P); + + // wait on tensor core pipe + pipeline_s.consumer_wait(pipeline_s_consumer_state); + + // read all of S from tmem into reg mem + Tensor tTMEM_LOADrS = make_tensor(shape(tTMEM_LOADcS)); + copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS); + + if constexpr (need_mask) { + if(need_apply_mask) { + Mask{}.apply_mask(tTMEM_LOADrS, tTMEM_LOADcS, problem_shape); + } + } + + ElementQK old_row_max = row_max; + { + // compute rowmax + float row_max_0 = row_max; + float row_max_1 = row_max; + float row_max_2 = row_max; + float row_max_3 = row_max; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTMEM_LOADrS); i += 4) { + row_max_0 = ::fmax(row_max_0, tTMEM_LOADrS(i)); + row_max_1 = ::fmax(row_max_1, tTMEM_LOADrS(i+1)); + row_max_2 = ::fmax(row_max_2, tTMEM_LOADrS(i+2)); + row_max_3 = ::fmax(row_max_3, tTMEM_LOADrS(i+3)); + } + row_max = ::fmax(row_max_0, row_max_1); + row_max = ::fmax(row_max, row_max_2); + row_max = ::fmax(row_max, row_max_3); + } + + ElementQK row_max_safe = row_max == -INFINITY ? 
0 : row_max; + + Tensor tTMEM_STOREVrS = make_tensor(shape(tTMEM_STOREVcS)); + tTMEM_STOREVrS(kIdxOldRowMax) = old_row_max; + tTMEM_STOREVrS(kIdxNewRowMax) = row_max_safe; + copy(tiled_tmem_storev, tTMEM_STOREVrS, tTMEM_STOREVtS); + + pipeline_c.producer_commit(pipeline_c_producer_state); + ++pipeline_c_producer_state; + + // notify correction wg that they are ready (might need addtl ordering between S0 and S1 WG's) + + ElementQK scale = params.scale_softmax_log2; + ElementQK row_max_scale = row_max_safe * scale; + + float2 scale_fp32x2 = make_float2(scale, scale); + float2 minus_row_max_scale_fp32x2 = make_float2(-row_max_scale, -row_max_scale); + + Tensor tTMEM_STORErS_x4 = make_tensor(shape(tTMEM_STOREcS)); + + constexpr int kConversionsPerStep = 2; + + Tensor tTMEM_STORErS_x4_e = recast>(tTMEM_STORErS_x4); + + NumericArrayConverter convert; + + constexpr int kReleasePipeCount = 10; // must be multiple of 2 + + order_s.wait(); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTMEM_LOADrS); i += 2) { + float2 in = make_float2( + tTMEM_LOADrS(i + 0), + tTMEM_LOADrS(i + 1) + ); + float2 out; + cute::fma(out, scale_fp32x2, in, minus_row_max_scale_fp32x2); + tTMEM_LOADrS(i + 0) = out.x; + tTMEM_LOADrS(i + 1) = out.y; + + tTMEM_LOADrS(i+0) = ::exp2f(tTMEM_LOADrS(i+0)); + tTMEM_LOADrS(i+1) = ::exp2f(tTMEM_LOADrS(i+1)); + + Array in_conv; + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < kConversionsPerStep; j++) { + in_conv[j] = tTMEM_LOADrS(i + j); + } + tTMEM_STORErS_x4_e[i / kConversionsPerStep] = convert(in_conv); + + + if (i == size(tTMEM_LOADrS) - kReleasePipeCount) { + order_s.arrive(); + } + + // this prevents register spills in fp16 + if constexpr (size<2>(tTMEM_STORErS_x4) == _2{}) { + if (i == size(tTMEM_LOADrS) - 6) { + copy(tiled_tmem_store, tTMEM_STORErS_x4(_, _, 0), tTMEM_STOREtS_x4(_, _, 0)); + } + } + } + + // tmem_store(reg_S8) -> op_P + CUTE_STATIC_ASSERT_V(size<2>(tTMEM_STORErS_x4) <= _2{}); + CUTE_STATIC_ASSERT_V(size<1>(tTMEM_STORErS_x4) == _1{}); + copy(tiled_tmem_store, tTMEM_STORErS_x4(_, _, size<2>(tTMEM_STORErS_x4) - 1), tTMEM_STOREtS_x4(_, _, size<2>(tTMEM_STORErS_x4) - 1)); + + cutlass::arch::fence_view_async_tmem_store(); + + // notify tensor core warp that P is ready + pipeline_s.consumer_release(pipeline_s_consumer_state); + ++pipeline_s_consumer_state; + + pipeline_c.producer_acquire(pipeline_c_producer_state); + + ElementQK acc_scale = 0.5f * ::exp2f(scale * (old_row_max - row_max_safe)); + row_sum *= acc_scale; + // row_sum = sum(reg_S) + float2 local_row_sum_f32x2 = make_float2(row_sum, row_sum); + float2 local_row_sum_1 = make_float2(0, 0); + float2 local_row_sum_2 = make_float2(0, 0); + float2 local_row_sum_3 = make_float2(0, 0); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTMEM_LOADrS); i += 8) { + // row_sum += tTMEM_LOADrS(i); + float2 in = make_float2(tTMEM_LOADrS(i), tTMEM_LOADrS(i+1)); + cute::add(local_row_sum_f32x2, local_row_sum_f32x2, in); + + in = make_float2(tTMEM_LOADrS(i+2), tTMEM_LOADrS(i+2+1)); + cute::add(local_row_sum_1, local_row_sum_1, in); + + in = make_float2(tTMEM_LOADrS(i+4), tTMEM_LOADrS(i+4+1)); + cute::add(local_row_sum_2, local_row_sum_2, in); + + in = make_float2(tTMEM_LOADrS(i+6), tTMEM_LOADrS(i+6+1)); + cute::add(local_row_sum_3, local_row_sum_3, in); + } + + cute::add(local_row_sum_f32x2, local_row_sum_f32x2, local_row_sum_1); + cute::add(local_row_sum_2, local_row_sum_2, local_row_sum_3); + cute::add(local_row_sum_f32x2, local_row_sum_f32x2, local_row_sum_2); + float local_row_sum = local_row_sum_f32x2.x + 
local_row_sum_f32x2.y; + + row_sum = local_row_sum; + + if (final_call) { + // re-acquire the S part in the final step + pipeline_s.consumer_wait(pipeline_s_consumer_state); + + Tensor tTMEM_STOREVrS = make_tensor(shape(tTMEM_STOREVcS)); + tTMEM_STOREVrS(kIdxFinalRowMax) = row_max; + tTMEM_STOREVrS(kIdxFinalRowSum) = row_sum; + copy(tiled_tmem_storev, tTMEM_STOREVrS, tTMEM_STOREVtS); + } + } + + template + CUTLASS_DEVICE auto + softmax( + Stage stage, + BlkCoord const& blk_coord, + Params const& params, ProblemShape const& problem_shape, + PipelineS& pipeline_s, typename PipelineS::PipelineState& pipeline_s_consumer_state, + PipelineC& pipeline_c, typename PipelineC::PipelineState& pipeline_c_producer_state, + OrderBarrierSoftmax& order_s) { + const int mask_trip_count = Mask{}.get_masked_trip_count(blk_coord, TileShape{}, problem_shape); + const int total_trip_count = Mask{}.get_trip_count(blk_coord, TileShape{}, problem_shape); + int trip_idx = total_trip_count; + + ElementQK row_max = -INFINITY; + ElementQK row_sum = 0; + + Tensor cS_base = make_identity_tensor(select<0,1>(TileShapeQK{})); + auto logical_offset = make_coord( + get<0>(blk_coord) * get<0>(TileShape{}) + (stage % get<0>(ThreadShape{})) * get<0>(TileShapeQK{}), + 0 + (stage % get<1>(ThreadShape{})) * get<1>(TileShapeQK{}) + ); + Tensor cS = domain_offset(logical_offset, cS_base); + + pipeline_c.producer_acquire(pipeline_c_producer_state); + + constexpr bool NeedMask = !std::is_same_v; + + CUTLASS_PRAGMA_NO_UNROLL + for (; trip_idx > 0; trip_idx -= 1) { + softmax_step( + trip_idx <= mask_trip_count, + row_max, row_sum, stage, + trip_idx == 1, + blk_coord, cS, params, problem_shape, + pipeline_s, pipeline_s_consumer_state, + pipeline_c, pipeline_c_producer_state, + order_s + ); + + cS.data() = cS.data() + E<1>{} * get<1>(ThreadShape{}) * get<1>(TileShapeQK{}); + } + + pipeline_c.producer_commit(pipeline_c_producer_state); + ++pipeline_c_producer_state; + + pipeline_c.producer_acquire(pipeline_c_producer_state); + // empty step to sync against pipe s + pipeline_s.consumer_release(pipeline_s_consumer_state); + ++pipeline_s_consumer_state; + } + + template + CUTLASS_DEVICE auto + correction_epilogue( + float scale, + Stage stage, + TensorO const& sO_01) { + + using ElementOut = typename TensorO::value_type; + + int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp); + + Tensor sO = sO_01(_,_,stage); + + // As opposed to the softmax, we do not have enough registers here + // to load all of the values (for tile kv = 128), so we loop + // good values would be either 32 or 64 + constexpr int kCorrectionTileSize = 32 / sizeof(ElementOut); + + using TMEM_LOAD = std::conditional_t; // 4x32 threads with 64 cols of 32b elem + + typename CollectiveMmaPV::TiledMma mma; + Tensor cO = make_identity_tensor(select<0,1>(TileShapePV{})); + Tensor tOtO = partition_fragment_C(mma, select<0,1>(TileShapePV{})); + Tensor tOcO = mma.get_slice(0).partition_C(cO); + Tensor tOsO = mma.get_slice(0).partition_C(sO); + + Tensor tOtO_i = logical_divide(tOtO, make_layout(make_shape(_128{}, Int{}))); + Tensor tOcO_i = logical_divide(tOcO, make_layout(make_shape(_128{}, Int{}))); + Tensor tOsO_i = logical_divide(tOsO, make_layout(make_shape(_128{}, Int{}))); + + if constexpr (decltype(stage == _0{})::value) { + tOtO_i.data() = tOtO_i.data().get() + uint32_t(TmemAllocation::O0); + } + else { + static_assert(decltype(stage == _1{})::value, "stage is either 0 or 1"); + tOtO_i.data() = tOtO_i.data().get() + uint32_t(TmemAllocation::O1); + } + + auto 
tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tOtO_i(make_coord(_, _), _0{})); + auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx); + + Tensor tTMEM_LOADtO = thr_tmem_load.partition_S(tOtO_i(make_coord(_, _), _)); + Tensor tTMEM_LOADcO = thr_tmem_load.partition_D(tOcO_i(make_coord(_, _), _)); + Tensor tTMEM_LOADsO = thr_tmem_load.partition_D(tOsO_i(make_coord(_, _), _)); + + float2 scale_f32x2 = make_float2(scale, scale); + + // loop: + // TMEM_LOAD, FMUL2 scale, TMEM_STORE + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < get<2>(TileShape{}) / kCorrectionTileSize; i++) { + Tensor tTMEM_LOADtO_i = tTMEM_LOADtO(_, _0{}, _0{}, i); + Tensor tTMEM_LOADsO_i = tTMEM_LOADsO(_, _0{}, _0{}, i); + + Tensor tTMrO = make_tensor(shape(tTMEM_LOADcO(_, _0{}, _0{}, i))); + + copy(tiled_tmem_load, tTMEM_LOADtO_i, tTMrO); + +#ifndef ONLY_SOFTMAX + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size(tTMrO); j += 2) { + float2 in = make_float2(tTMrO(j), tTMrO(j+1)); + float2 out; + cute::mul(out, scale_f32x2, in); + tTMrO(j) = out.x; + tTMrO(j+1) = out.y; + } +#endif + + constexpr int N = 4 / sizeof(ElementOut); + NumericArrayConverter convert; + + Tensor tSMrO = make_tensor_like(tTMrO); + + Tensor tCs = recast(tTMrO); + Tensor tCd = recast(tSMrO); + + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size(tCs); j++) { + tCd(j) = convert.convert(tCs(j)); + } + + Tensor tSMsO_i = recast(tTMEM_LOADsO_i); + Tensor tSMrO_i = recast(tSMrO); + + copy(AutoVectorizingCopyWithAssumedAlignment<128>{}, tSMrO_i, tSMsO_i); + } + + cutlass::arch::fence_view_async_shared(); + } + + CUTLASS_DEVICE auto + correction_rescale( + float scale, + uint32_t tmem_O) { + + int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp); + + // As opposed to the softmax, we do not have enough registers here + // to load all of the values (for tile kv = 128), so we loop + // good values would be either 32 or 64 + constexpr int kCorrectionTileSize = 16; + + using TMEM_LOAD = SM100_TMEM_LOAD_32dp32b16x; // 4x32 threads with 64 cols of 32b elem + using TMEM_STORE = SM100_TMEM_STORE_32dp32b16x; // 4x32 threads with 64 cols of 32b elem + + typename CollectiveMmaPV::TiledMma mma; + Tensor cO = make_identity_tensor(select<0,1>(TileShapePV{})); + Tensor tOtO = partition_fragment_C(mma, select<0,1>(TileShapePV{})); + Tensor tOcO = mma.get_slice(0).partition_C(cO); + + Tensor tOtO_i = tOtO.compose(make_layout(make_shape(_128{}, Int{}))); + Tensor tOcO_i = tOcO.compose(make_layout(make_shape(_128{}, Int{}))); + + tOtO_i.data() = tOtO_i.data().get() + tmem_O; + + auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD{}, tOtO_i); + auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx); + auto tiled_tmem_store = make_tmem_copy(TMEM_STORE{}, tOtO_i); + auto thr_tmem_store = tiled_tmem_store.get_slice(thread_idx); + + Tensor tTMEM_LOADtO = thr_tmem_load.partition_S(tOtO_i); + Tensor tTMEM_LOADcO = thr_tmem_load.partition_D(tOcO_i); + Tensor tTMEM_STOREtO = thr_tmem_store.partition_D(tOtO_i); + Tensor tTMEM_STOREcO = thr_tmem_store.partition_S(tOcO_i); + static_assert(shape(tTMEM_STOREcO) == shape(tTMEM_LOADcO)); + + float2 scale_f32x2 = make_float2(scale, scale); + + Tensor tTMrO = make_tensor(make_shape(shape(tTMEM_LOADcO), Int<128 / kCorrectionTileSize>{})); + + auto copy_in = [&](int i) { + Tensor tTMEM_LOADtO_i = tTMEM_LOADtO; + tTMEM_LOADtO_i.data() = tTMEM_LOADtO_i.data().get() + uint32_t(i * kCorrectionTileSize); + Tensor tTMrO_i = tTMrO(_, i).compose(make_layout(shape<0>(tTMrO))); + copy(tiled_tmem_load, tTMEM_LOADtO_i, tTMrO_i); + }; + + auto 
copy_out = [&](int i) { + Tensor tTMEM_STOREtO_i = tTMEM_STOREtO; + tTMEM_STOREtO_i.data() = tTMEM_STOREtO_i.data().get() + uint32_t(i * kCorrectionTileSize); + Tensor tTMrO_i = tTMrO(_, i).compose(make_layout(shape<0>(tTMrO))); + copy(tiled_tmem_store, tTMrO_i, tTMEM_STOREtO_i); + }; + + // sequence: LLMSLMSLMSS + + // loop: + // TMEM_LOAD, FMUL2 scale, TMEM_STORE + copy_in(0); + + constexpr int count = get<2>(TileShape{}) / kCorrectionTileSize; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < count; i++) { + if (i != count - 1) { + copy_in(i+1); + } + + Tensor tTMrO_i = tTMrO(_, i).compose(make_layout(shape<0>(tTMrO))); + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size(tTMrO_i); j += 2) { + float2 in = make_float2(tTMrO_i(j), tTMrO_i(j+1)); + float2 out; + cute::mul(out, scale_f32x2, in); + tTMrO_i(j) = out.x; + tTMrO_i(j+1) = out.y; + } + + copy_out(i); + } + } + + template< + class BlkCoord, class ProblemShape, class ParamsProblemShape, + class TensorStorageEpi, class CollectiveEpilogue + > + CUTLASS_DEVICE auto + correction( + BlkCoord const& blk_coord, + Params const& params, ProblemShape const& problem_shape, + ParamsProblemShape const& params_problem_shape, + TensorStorageEpi& shared_storage_epi, + PipelineC& pipeline_s0_c, typename PipelineC::PipelineState& pipeline_s0_c_consumer_state, + PipelineC& pipeline_s1_c, typename PipelineC::PipelineState& pipeline_s1_c_consumer_state, + PipelineO& pipeline_o, typename PipelineO::PipelineState& pipeline_o_consumer_state, + PipelineE& pipeline_epi, typename PipelineE::PipelineState& pipeline_epi_producer_state, + CollectiveEpilogue& epilogue) { + + int mask_tile_count = Mask{}.get_trip_count(blk_coord, TileShape{}, problem_shape); + + int thread_idx = threadIdx.x % (4 * cutlass::NumThreadsPerWarp); + + Tensor tStS = partition_fragment_C(typename CollectiveMmaQK::TiledMma{}, select<0,1>(TileShapeQK{})); + + Tensor cS = make_identity_tensor(select<0,1>(TileShapeQK{})); + Tensor tScS = typename CollectiveMmaQK::TiledMma{}.get_slice(0).partition_C(cS); + + Tensor tStS_v = tStS.compose(make_layout(make_shape(_128{}, _2{}))); + Tensor tScS_v = tScS.compose(make_layout(make_shape(_128{}, _2{}))); + + using TMEM_LOAD_V = SM100_TMEM_LOAD_32dp32b2x; // 4x32 threads with 2 cols of 32b elem + + auto tiled_tmem_loadv = make_tmem_copy(TMEM_LOAD_V{}, tStS_v); + auto thr_tmem_loadv = tiled_tmem_loadv.get_slice(thread_idx); + + Tensor tTMEM_LOADVtS = thr_tmem_loadv.partition_S(tStS_v); + Tensor tTMEM_LOADVcS = thr_tmem_loadv.partition_D(tScS_v); + + Tensor tTMEM_LOADVtS0 = tTMEM_LOADVtS; + tTMEM_LOADVtS0.data() = tTMEM_LOADVtS0.data().get() + uint32_t(TmemAllocation::V0); + Tensor tTMEM_LOADVtS1 = tTMEM_LOADVtS; + tTMEM_LOADVtS1.data() = tTMEM_LOADVtS1.data().get() + uint32_t(TmemAllocation::V1); + + // ignore first signal from softmax as no correction is required + pipeline_s0_c.consumer_wait(pipeline_s0_c_consumer_state); + pipeline_s0_c.consumer_release(pipeline_s0_c_consumer_state); + ++pipeline_s0_c_consumer_state; + + pipeline_s1_c.consumer_wait(pipeline_s1_c_consumer_state); + + // handle the last iteration differently (i.e. 
tmem_load/stsm for epi) + mask_tile_count -= 1; + + CUTLASS_PRAGMA_NO_UNROLL + for (; mask_tile_count > 0; mask_tile_count -= 1) { + + pipeline_s0_c.consumer_wait(pipeline_s0_c_consumer_state); + + Tensor tTMEM_LOADVrS = make_tensor(shape(tTMEM_LOADVcS)); + + // read row_wise new global max + copy(tiled_tmem_loadv, tTMEM_LOADVtS0, tTMEM_LOADVrS); + + // e^(scale * (old_max - new_max) + float scale = ::exp2f(params.scale_softmax_log2 * (tTMEM_LOADVrS(kIdxOldRowMax) - tTMEM_LOADVrS(kIdxNewRowMax))); + + pipeline_o.consumer_wait(pipeline_o_consumer_state); + + correction_rescale(scale, uint32_t(TmemAllocation::O0)); + + pipeline_s1_c.consumer_release(pipeline_s1_c_consumer_state); + ++pipeline_s1_c_consumer_state; + + cutlass::arch::fence_view_async_tmem_store(); + + pipeline_o.consumer_release(pipeline_o_consumer_state); + ++pipeline_o_consumer_state; + + pipeline_s1_c.consumer_wait(pipeline_s1_c_consumer_state); + + copy(tiled_tmem_loadv, tTMEM_LOADVtS1, tTMEM_LOADVrS); + + scale = ::exp2f(params.scale_softmax_log2 * (tTMEM_LOADVrS(kIdxOldRowMax) - tTMEM_LOADVrS(kIdxNewRowMax))); + + pipeline_o.consumer_wait(pipeline_o_consumer_state); + + correction_rescale(scale, uint32_t(TmemAllocation::O1)); + + pipeline_s0_c.consumer_release(pipeline_s0_c_consumer_state); + ++pipeline_s0_c_consumer_state; + + cutlass::arch::fence_view_async_tmem_store(); + + pipeline_o.consumer_release(pipeline_o_consumer_state); + ++pipeline_o_consumer_state; + } + + pipeline_s1_c.consumer_release(pipeline_s1_c_consumer_state); + ++pipeline_s1_c_consumer_state; + + // do the final correction to O1 + // better to somehow special-case it in the loop above + // doesn't matter for non-persistent code, but if it were + // persistent we do not want to release O too early + + pipeline_s0_c.consumer_wait(pipeline_s0_c_consumer_state); + + // read from V0 + // read row_sum and final row_max here + Tensor tTMEM_LOADVrS = make_tensor(shape(tTMEM_LOADVcS)); + copy(tiled_tmem_loadv, tTMEM_LOADVtS0, tTMEM_LOADVrS); + + pipeline_s0_c.consumer_release(pipeline_s0_c_consumer_state); + ++pipeline_s0_c_consumer_state; + + pipeline_o.consumer_wait(pipeline_o_consumer_state); + pipeline_epi.producer_acquire(pipeline_epi_producer_state); + // store to epi smem + + // loop: + // TMEM_LOAD + // FMUL2 scale = 1 / global_sum * out_quant_scale + // F2FP + // store to smem + Tensor sO = make_tensor(make_smem_ptr(shared_storage_epi.smem_o.data()), typename TensorStorageEpi::SmemLayoutO{}); + Tensor gLSE = make_tensor(make_gmem_ptr(epilogue.params.ptr_LSE), select<0,3>(problem_shape), epilogue.params.dLSE); + correction_epilogue(params.scale_output / tTMEM_LOADVrS(kIdxFinalRowSum), _0{}, sO); + + if (epilogue.params.ptr_LSE != nullptr) { + int row_idx = get<0>(tTMEM_LOADVcS(_0{})) + get<0>(TileShape{}) * get<0>(blk_coord); + + int row_offset = 0; + if constexpr (is_variable_length_v>) { + row_offset = get<0>(params_problem_shape).cumulative_length[get<2,1>(blk_coord)]; + } + + ElementPV lse = cutlass::fast_log(tTMEM_LOADVrS(kIdxFinalRowSum)) + params.scale_softmax * tTMEM_LOADVrS(kIdxFinalRowMax); + + if (row_idx < get<0>(problem_shape)) { + gLSE(row_idx + row_offset, get<2>(blk_coord)) = lse; + } + } + + cutlass::arch::fence_view_async_tmem_load(); + + pipeline_o.consumer_release(pipeline_o_consumer_state); + ++pipeline_o_consumer_state; + + pipeline_epi.producer_commit(pipeline_epi_producer_state); + ++pipeline_epi_producer_state; + + pipeline_s1_c.consumer_wait(pipeline_s1_c_consumer_state); + + // load from V1 + copy(tiled_tmem_loadv, 
tTMEM_LOADVtS1, tTMEM_LOADVrS); + + pipeline_s1_c.consumer_release(pipeline_s1_c_consumer_state); + ++pipeline_s1_c_consumer_state; + + pipeline_o.consumer_wait(pipeline_o_consumer_state); + pipeline_epi.producer_acquire(pipeline_epi_producer_state); + + correction_epilogue(params.scale_output / tTMEM_LOADVrS(kIdxFinalRowSum), _1{}, sO); + + if (epilogue.params.ptr_LSE != nullptr) { + int row_idx = get<0>(tTMEM_LOADVcS(_0{})) + get<0>(TileShape{}) * get<0>(blk_coord) + get<0>(TileShapeQK{}); + + ElementPV lse = cutlass::fast_log(tTMEM_LOADVrS(kIdxFinalRowSum)) + params.scale_softmax * tTMEM_LOADVrS(kIdxFinalRowMax); + + int row_offset = 0; + if constexpr (is_variable_length_v>) { + row_offset = get<0>(params_problem_shape).cumulative_length[get<2,1>(blk_coord)]; + } + + if (row_idx < get<0>(problem_shape)) { + gLSE(row_idx + row_offset, get<2>(blk_coord)) = lse; + } + } + + cutlass::arch::fence_view_async_tmem_load(); + + pipeline_o.consumer_release(pipeline_o_consumer_state); + ++pipeline_o_consumer_state; + + pipeline_epi.producer_commit(pipeline_epi_producer_state); + ++pipeline_epi_producer_state; + } + + + template< + class BlkCoord, class ProblemShape, class ParamsProblemShape, + class TensorStorageEpi, class CollectiveEpilogue + > + CUTLASS_DEVICE auto + correction_empty( + BlkCoord const& blk_coord, + Params const& params, ProblemShape const& problem_shape, + ParamsProblemShape const& params_problem_shape, + TensorStorageEpi& shared_storage_epi, + PipelineE& pipeline_epi, typename PipelineE::PipelineState& pipeline_epi_producer_state, + CollectiveEpilogue& epilogue) { + + pipeline_epi.producer_acquire(pipeline_epi_producer_state); + + Tensor sO = make_tensor(make_smem_ptr(shared_storage_epi.smem_o.data()), typename TensorStorageEpi::SmemLayoutO{}); + Tensor gLSE = make_tensor(make_gmem_ptr(epilogue.params.ptr_LSE), select<0,3>(problem_shape), epilogue.params.dLSE); + float lse = -INFINITY; + int thread_idx = threadIdx.x % (4 * NumThreadsPerWarp); + +#define DSHOW(x) print(#x ": "); print(x); print("\n") + if (threadIdx.x % 128 == 0 && block0()) { + DSHOW(sO); + } +#if 1 + + using ElementOut = typename CollectiveEpilogue::ElementOut; + auto tiled_copy = make_cotiled_copy( + Copy_Atom, ElementOut>{}, + make_ordered_layout(make_shape(_128{}, Int{}), Step<_1, _0>{}), + sO.layout()); + + auto thr_copy = tiled_copy.get_slice(thread_idx); + auto tOgO = thr_copy.partition_D(sO); + auto tOrO = make_tensor(shape(tOgO(_,_,_,_0{}))); + clear(tOrO); + + copy(tiled_copy, tOrO, tOgO(_,_,_,_0{})); +#endif + + if (epilogue.params.ptr_LSE != nullptr) { + int row_idx = thread_idx + get<0>(TileShape{}) * get<0>(blk_coord); + + int row_offset = 0; + if constexpr (is_variable_length_v>) { + row_offset = get<0>(params_problem_shape).cumulative_length[get<2,1>(blk_coord)]; + } + + if (row_idx < get<0>(problem_shape)) { + gLSE(row_idx + row_offset, get<2>(blk_coord)) = lse; + } + } + + pipeline_epi.producer_commit(pipeline_epi_producer_state); + ++pipeline_epi_producer_state; + + copy(tiled_copy, tOrO, tOgO(_,_,_,_1{})); + cutlass::arch::fence_view_async_shared(); + pipeline_epi.producer_acquire(pipeline_epi_producer_state); + + if (epilogue.params.ptr_LSE != nullptr) { + int row_idx = thread_idx + get<0>(TileShape{}) * get<0>(blk_coord) + get<0>(TileShapeQK{}); + + int row_offset = 0; + if constexpr (is_variable_length_v>) { + row_offset = get<0>(params_problem_shape).cumulative_length[get<2,1>(blk_coord)]; + } + + if (row_idx < get<0>(problem_shape)) { + gLSE(row_idx + row_offset, get<2>(blk_coord)) = 
lse; + } + } + + cutlass::arch::fence_view_async_shared(); + pipeline_epi.producer_commit(pipeline_epi_producer_state); + ++pipeline_epi_producer_state; + } + +}; + +} // namespace cutlass::fmha::collective diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/collective/sm100_fmha_mla_load_tma_warpspecialized.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/collective/sm100_fmha_mla_load_tma_warpspecialized.hpp new file mode 100644 index 0000000000000000000000000000000000000000..c8fc13b95286fe18e94536204a304eefdabc67ca --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/collective/sm100_fmha_mla_load_tma_warpspecialized.hpp @@ -0,0 +1,323 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/arch/memory_sm80.h" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cute/tensor.hpp" +#include "cute/layout.hpp" + +#include "collective/fmha_common.hpp" +#include "collective/fmha_fusion.hpp" + +namespace cutlass::fmha::collective { + +using namespace cute; + +template< + class Element, + class StrideQ, + class StrideK, + class StrideV, + class CollectiveMmaQK, + class CollectiveMmaPV, + class SmemLayoutQ, + class SmemLayoutK, + class SmemLayoutV, + class TensorStorage, + class PipelineQ, + class PipelineKV, + class Mask, + class TileShape, + class OrderLoadEpilogue = cute::false_type +> +struct Sm100MlaFwdLoadTmaWarpspecialized { + + using TileShapeQK = typename CollectiveMmaQK::TileShape; + using TileShapePV = typename CollectiveMmaPV::TileShape; + + static constexpr int TransactionBytesLoadK = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutK{})) * cute::sizeof_bits_v); + static constexpr int TransactionBytesLoadV = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutV{})) * cute::sizeof_bits_v); + + static const int NumWarpsEpilogue = 1; + static const int NumWarpsLoad = 1; + + struct Arguments { + const Element* ptr_Q; + StrideQ dQ; + const Element* ptr_K; + StrideK dK; + const Element* ptr_V; + StrideV dV; + }; + + using TMA_Q = typename CollectiveMmaQK::Params::TMA_A; + using TMA_K = typename CollectiveMmaQK::Params::TMA_B; + using TMA_V = typename CollectiveMmaPV::Params::TMA_B; + + struct Params { + TMA_Q tma_load_q; + TMA_K tma_load_k; + TMA_V tma_load_v; + }; + + template + static Params to_underlying_arguments( + ProblemShape const& problem_shape, + Arguments const& args, + void* workspace) { + + auto ptr_Q = args.ptr_Q; + auto ptr_K = args.ptr_K; + auto ptr_V = args.ptr_V; + auto dQ = args.dQ; + auto dK = args.dK; + auto dV = args.dV; + + using IntProblemShape = cute::tuple, int>>; + + IntProblemShape problem_shape_qk; + if constexpr (is_variable_length_v>) { + auto cumulative_length_q = get<0>(problem_shape).cumulative_length; + auto cumulative_length_k = get<1>(problem_shape).cumulative_length; + if (cumulative_length_q != nullptr && cumulative_length_k != nullptr ) { + get<0>(problem_shape_qk) = get<0>(problem_shape).total_length; + get<1>(problem_shape_qk) = get<1>(problem_shape).total_length; + get<2>(problem_shape_qk) = get<2, 0>(problem_shape) + get<2, 1>(problem_shape); + get<3>(problem_shape_qk) = get<3>(problem_shape); + } + } else { + problem_shape_qk = replace<2>(problem_shape, get<2, 0>(problem_shape) + get<2, 1>(problem_shape));; + } + + auto problem_shape_pv = replace<1>(select<0,2,1,3>(problem_shape_qk), get<2, 0>(problem_shape)); + + auto params_qk = CollectiveMmaQK::to_underlying_arguments( + problem_shape_qk, + typename CollectiveMmaQK::Arguments { + ptr_Q, dQ, + ptr_K, dK, + }, /*workspace=*/ nullptr); + + auto params_pv = CollectiveMmaPV::to_underlying_arguments( + problem_shape_pv, + typename CollectiveMmaPV::Arguments { + ptr_K, dK, // never used, dummy + ptr_V, select<1,0,2>(dV), + }, /*workspace=*/ nullptr); + + return Params{ + params_qk.tma_load_a, + params_qk.tma_load_b, + params_pv.tma_load_b + }; + } + + + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& params) { + cute::prefetch_tma_descriptor(params.tma_load_q.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_k.get_tma_descriptor()); + 
cute::prefetch_tma_descriptor(params.tma_load_v.get_tma_descriptor()); + } + + template + CUTLASS_DEVICE void + load( + BlkCoord const& blk_coord_in, ProblemShape const& problem_shape, + Params const& params, ParamsProblemShape const& params_problem_shape, + TensorStorage& storage, + PipelineQ& pipeline_q, typename PipelineQ::PipelineState& pipeline_q_producer_state, + PipelineKV& pipeline_kv, typename PipelineKV::PipelineState& pipeline_kv_producer_state) { + + BlkCoord blk_coord_q = blk_coord_in; + BlkCoord blk_coord_kv = blk_coord_in; + + auto problem_shape_qk = replace<2>(problem_shape, get<2, 0>(problem_shape) + get<2, 1>(problem_shape)); + auto problem_shape_v = replace<2>(problem_shape, get<2, 0>(problem_shape)); + + int mask_tile_count = Mask{}.get_trip_count(blk_coord_in, TileShape{}, problem_shape); + + using X = Underscore; + + // this one is only executed by one thread, no need to elect_one + + // Q1, K1, Q2, V1, K2, V2, K3, V3, ... + // two pipes: Q and KV + // from Memory (prod) to TensorCore (cons) + + // compute gQ, sQ + // we load 2*get<0>(blk_coord), and 2*get<0>(blk_coord) + 1 + ThrMMA mma_qk = typename CollectiveMmaQK::TiledMma{}.get_slice(0); + Tensor mQ_qdl_p = params.tma_load_q.get_tma_tensor(select<0,2,3>(problem_shape_qk)); + + int q_offs_0 = 0; + + if constexpr (is_variable_length_v>) { + auto cumulative_length_q = get<0>(params_problem_shape).cumulative_length; + if (cumulative_length_q != nullptr) { + q_offs_0 = cumulative_length_q[get<2,1>(blk_coord_q)]; + get<2,1>(blk_coord_q) = 0; + } + } + + Tensor mQ_qdl = domain_offset(make_coord(q_offs_0, _0{}, make_coord(_0{}, _0{})), mQ_qdl_p); + + Tensor gQ_qdl = local_tile(mQ_qdl, TileShapeQK{}, make_coord(_, _, _), Step<_1, X, _1>{}); + Tensor tSgQ_qdl = mma_qk.partition_A(gQ_qdl); + Tensor sQ = make_tensor(make_smem_ptr(storage.smem_q.data()), SmemLayoutQ{}); + auto [tQgQ_qdl, tQsQ] = tma_partition( + params.tma_load_q, _0{}, make_layout(_1{}), + group_modes<0,3>(sQ), group_modes<0,3>(tSgQ_qdl) + ); + Tensor tQgQ = tQgQ_qdl(_, _, _0{}, get<2>(blk_coord_q)); + + // compute gK, sK + Tensor mK_kdl_p = params.tma_load_k.get_tma_tensor(select<1,2,3>(problem_shape_qk)); + + int kv_offs_0 = 0; + + if constexpr (is_variable_length_v>) { + auto cumulative_length = get<1>(params_problem_shape).cumulative_length; + if (cumulative_length != nullptr) { + kv_offs_0 = cumulative_length[get<2,1>(blk_coord_kv)]; + get<2,1>(blk_coord_kv) = 0; + } + } + + Tensor mK_kdl = domain_offset(make_coord(kv_offs_0, _0{}, make_coord(_0{}, _0{})), mK_kdl_p); + + Tensor gK_kdl = local_tile(mK_kdl, TileShapeQK{}, make_coord(_, _, _), Step{}); + Tensor tSgK_kdl = mma_qk.partition_B(gK_kdl); + Tensor sK = make_tensor(make_smem_ptr(storage.smem_k.data()), SmemLayoutK{}); + auto [tKgK_kdl, tKsK] = tma_partition( + params.tma_load_k, _0{}, make_layout(_1{}), + group_modes<0,3>(sK), group_modes<0,3>(tSgK_kdl) + ); + Tensor tKgK = tKgK_kdl(_, _, _0{}, get<2>(blk_coord_kv)); + + // compute gV, sV + ThrMMA mma_pv = typename CollectiveMmaPV::TiledMma{}.get_slice(0); + Tensor mV_dkl_p = params.tma_load_v.get_tma_tensor(select<2,1,3>(problem_shape_v)); + + Tensor mV_dkl = domain_offset(make_coord(_0{}, kv_offs_0, make_coord(_0{}, _0{})), mV_dkl_p); + + Tensor gV_dkl = local_tile(mV_dkl, TileShapePV{}, make_coord(_, _, _), Step{}); + Tensor tOgV_dkl = mma_pv.partition_B(gV_dkl); + Tensor sV = make_tensor(make_smem_ptr(storage.smem_v.data()), SmemLayoutV{}); + auto [tVgV_dkl, tVsV] = tma_partition( + params.tma_load_v, _0{}, make_layout(_1{}), + 
group_modes<0,3>(sV), group_modes<0,3>(tOgV_dkl) + ); + auto tVgV = tVgV_dkl(_, _0{}, _, get<2>(blk_coord_kv)); + + // blk_coord in decomposed in terms of TileShape, not TileShapeQK + // As such, it needs to be transformed as + // (a,b,c): a -> 2*a (Q0) 2*a+1 (Q1) + // b -> 2*a (Ki i even) 2*a+1 (Ki i odd) + + uint32_t lane_predicate = cute::elect_one_sync(); + + // Q1 + int q0_index = 2 * get<0>(blk_coord_q); + int q1_index = 2 * get<0>(blk_coord_q) + 1; + pipeline_q.producer_acquire(pipeline_q_producer_state); + if (lane_predicate) { + auto tma_barrier = pipeline_q.producer_get_barrier(pipeline_q_producer_state); + copy(params.tma_load_q.with(*tma_barrier, 0), tQgQ(_, q0_index), tQsQ(_, pipeline_q_producer_state.index())); + } + ++pipeline_q_producer_state; + + // K1 + int k_index = 0; + pipeline_kv.producer_acquire(pipeline_kv_producer_state); + if (lane_predicate) { + auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state); + copy(params.tma_load_k.with(*tma_barrier, 0), tKgK(_, k_index), tKsK(_, pipeline_kv_producer_state.index() / 2)); + } + ++pipeline_kv_producer_state; + + // Q2 + pipeline_q.producer_acquire(pipeline_q_producer_state); + if (lane_predicate) { + auto tma_barrier = pipeline_q.producer_get_barrier(pipeline_q_producer_state); + copy(params.tma_load_q.with(*tma_barrier, 0), tQgQ(_, q1_index), tQsQ(_, pipeline_q_producer_state.index())); + } + ++pipeline_q_producer_state; + + if constexpr (cute::is_same_v) { + cutlass::arch::NamedBarrier::sync((NumWarpsLoad + NumWarpsEpilogue) * NumThreadsPerWarp, + cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); + } + + // V1 + pipeline_kv.producer_acquire_bytes(pipeline_kv_producer_state, TransactionBytesLoadV); + if (lane_predicate) { + auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state); + copy(params.tma_load_v.with(*tma_barrier, 0), tVgV(_, k_index), tVsV(_, pipeline_kv_producer_state.index() / 2)); + } + ++pipeline_kv_producer_state; + k_index += 1; + + // loop: + mask_tile_count -= 1; + for (; mask_tile_count > 0; mask_tile_count -= 1) { + + // Ki + pipeline_kv.producer_acquire(pipeline_kv_producer_state); + if (lane_predicate) { + auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state); + copy(params.tma_load_k.with(*tma_barrier, 0), tKgK(_, k_index), tKsK(_, pipeline_kv_producer_state.index() / 2)); + + // prefetch vi + cute::prefetch(params.tma_load_v, tVgV(_, k_index)); + } + ++pipeline_kv_producer_state; + + // Vi + pipeline_kv.producer_acquire_bytes(pipeline_kv_producer_state, TransactionBytesLoadV); + if (lane_predicate) { + auto tma_barrier = pipeline_kv.producer_get_barrier(pipeline_kv_producer_state); + copy(params.tma_load_v.with(*tma_barrier, 0), tVgV(_, k_index), tVsV(_, pipeline_kv_producer_state.index() / 2)); + + // prefetch ki+1 + if(mask_tile_count > 1) { + cute::prefetch(params.tma_load_k, tKgK(_, k_index + 1)); + } + } + ++pipeline_kv_producer_state; + k_index += 1; + } + } +}; + +} // namespace cutlass::fmha::collective diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/common/pipeline_mla.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/common/pipeline_mla.hpp new file mode 100644 index 0000000000000000000000000000000000000000..5bbeed9106b9c9f2631fe5c5ab635213295f92b7 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/common/pipeline_mla.hpp @@ -0,0 +1,250 
@@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! + \file + \brief Support the producer to acquire specific bytes of data. 
+*/ + +#pragma once + +#include "cutlass/pipeline/sm100_pipeline.hpp" + +namespace cutlass { + +using namespace cute; + +template < + int Stages_, + class ClusterShape = Shape, + class AtomThrShape_MNK_ = Shape<_1,_1,_1> +> +class PipelineTmaAsyncMla { + +public: + static constexpr uint32_t Stages = Stages_; + using AtomThrShape_MNK = AtomThrShape_MNK_; + +private: + using Impl = PipelineTmaUmmaAsync; + +public: + using FullBarrier = typename Impl::FullBarrier; + using EmptyBarrier = typename Impl::EmptyBarrier; + using ProducerBarrierType = typename Impl::ProducerBarrierType; + using ConsumerBarrierType = typename Impl::ConsumerBarrierType; + using PipelineState = typename Impl::PipelineState; + using SharedStorage = typename Impl::SharedStorage; + using ThreadCategory = typename Impl::ThreadCategory; + using Params = typename Impl::Params; + + + using McastDirection = McastDirection; + + // Helper function to initialize barriers + static + CUTLASS_DEVICE + void + init_barriers(SharedStorage& storage, Params params, ClusterShape cluster_shape) { + int warp_idx = canonical_warp_idx_sync(); + if (warp_idx == params.initializing_warp) { + // Barrier FULL and EMPTY init + constexpr int producer_arv_cnt = 1; + auto atom_thr_shape = AtomThrShape_MNK{}; + uint32_t const multicast_consumer_arrival_count = (cute::size<0>(cluster_shape) / cute::size<0>(atom_thr_shape)) + + (cute::size<1>(cluster_shape) / cute::size<1>(atom_thr_shape)) - 1; + + cutlass::arch::detail::initialize_barrier_array_pair_aligned( + storage.full_barrier_, storage.empty_barrier_, producer_arv_cnt, multicast_consumer_arrival_count); + } + cutlass::arch::fence_barrier_init(); + } + + static + CUTLASS_DEVICE + void + init_barriers(SharedStorage& storage, Params params, ClusterShape cluster_shape, McastDirection mcast_direction) { + auto atom_thr_shape = AtomThrShape_MNK{}; + + int warp_idx = canonical_warp_idx_sync(); + if (warp_idx == params.initializing_warp) { + // Barrier FULL and EMPTY init + constexpr int producer_arv_cnt = 1; + uint32_t const multicast_consumer_arrival_count = (mcast_direction == McastDirection::kRow) ? 
+ cute::size<1>(cluster_shape) / cute::size<1>(atom_thr_shape) : // Mcast with row ctas + cute::size<0>(cluster_shape) / cute::size<0>(atom_thr_shape); // Mcast with col ctas + + cutlass::arch::detail::initialize_barrier_array_pair_aligned( + storage.full_barrier_, storage.empty_barrier_, producer_arv_cnt, multicast_consumer_arrival_count); + } + cutlass::arch::fence_barrier_init(); + } + + CUTLASS_DEVICE + void init_masks(ClusterShape cluster_shape, dim3 block_id_in_cluster = cute::block_id_in_cluster()) { + // Calculate consumer mask + if (params_.role == ThreadCategory::Consumer) { + auto cluster_layout = make_layout(cluster_shape); + block_id_mask_ = detail::calculate_multicast_mask(cluster_shape, AtomThrShape_MNK{}, block_id_in_cluster); + } + } + + CUTLASS_DEVICE + void init_masks(ClusterShape cluster_shape, McastDirection mcast_direction) { + // Calculate consumer mask + dim3 block_id_in_cluster = cute::block_id_in_cluster(); + auto cluster_layout = make_layout(cluster_shape); + if (mcast_direction == McastDirection::kRow) { + block_id_mask_ = detail::calculate_multicast_mask(cluster_shape, AtomThrShape_MNK{}, block_id_in_cluster); + } + else { + block_id_mask_ = detail::calculate_multicast_mask(cluster_shape, AtomThrShape_MNK{}, block_id_in_cluster); + } + } + + +public: + template + CUTLASS_DEVICE + PipelineTmaAsyncMla(SharedStorage& storage, Params params, ClusterShape cluster_shape, InitBarriers = {}, InitMasks = {}) + : impl_(storage, params, cluster_shape, cute::false_type{}, cute::false_type{}) + , params_(params) + , empty_barrier_ptr_(&storage.empty_barrier_[0]) + , full_barrier_ptr_(&storage.full_barrier_[0]) { + static_assert(cute::is_same_v || cute::is_same_v); + if constexpr (cute::is_same_v) { + init_barriers(storage, params_, cluster_shape); + } + + static_assert(cute::is_same_v || cute::is_same_v); + if constexpr (cute::is_same_v) { + init_masks(cluster_shape); + } + } + + template + CUTLASS_DEVICE + PipelineTmaAsyncMla(SharedStorage& storage, Params params, ClusterShape cluster_shape, McastDirection mcast_direction, InitBarriers = {}, InitMasks = {}) + : impl_(storage, params, cluster_shape, cute::false_type{}, cute::false_type{}) + , params_(params) + , empty_barrier_ptr_(&storage.empty_barrier_[0]) + , full_barrier_ptr_(&storage.full_barrier_[0]) { + static_assert(cute::is_same_v || cute::is_same_v); + if constexpr (cute::is_same_v) { + init_barriers(storage, params_, cluster_shape, mcast_direction); + } + + static_assert(cute::is_same_v || cute::is_same_v); + if constexpr (cute::is_same_v) { + init_masks(cluster_shape, mcast_direction); + } + } + + + CUTLASS_DEVICE + void producer_acquire(PipelineState state, ProducerToken barrier_token = {BarrierStatus::WaitAgain}) { + impl_.producer_acquire(state, barrier_token); + } + + CUTLASS_DEVICE + void producer_acquire_bytes(uint32_t stage, uint32_t bytes, uint32_t phase, ProducerToken barrier_token) { + detail::pipeline_check_is_producer(params_.role); + if (barrier_token != BarrierStatus::WaitDone) { + empty_barrier_ptr_[stage].wait(phase); + } + + if (params_.is_leader) { + full_barrier_ptr_[stage].arrive_and_expect_tx(bytes); + } + #ifndef NDEBUG + if (params_.role == ThreadCategory::Consumer || params_.role == ThreadCategory::NonParticipant) { + asm volatile ("brkpt;\n" ::); + } + + // Most likely you have elected more than one leader + if (params_.is_leader && (threadIdx.x % 32 != 0)) { + asm volatile ("brkpt;\n" ::); + } + #endif + } + + CUTLASS_DEVICE + void producer_acquire_bytes(PipelineState state, uint32_t 
bytes, ProducerToken barrier_token = {BarrierStatus::WaitAgain}) { + producer_acquire_bytes(state.index(), bytes, state.phase(), barrier_token); + } + + CUTLASS_DEVICE + ProducerBarrierType* producer_get_barrier(PipelineState state) { + return impl_.producer_get_barrier(state); + } + + CUTLASS_DEVICE + void consumer_wait(PipelineState state, ConsumerToken barrier_token = {BarrierStatus::WaitAgain}) { + impl_.consumer_wait(state, barrier_token); + } + + CUTLASS_DEVICE + void consumer_release(PipelineState state) { + consumer_release(state.index(), false); + } + +private: + Impl impl_; + Params params_; + EmptyBarrier *empty_barrier_ptr_; + FullBarrier *full_barrier_ptr_; + uint16_t block_id_mask_ = 0; + static constexpr bool is_2sm_mma = size(AtomThrShape_MNK{}) > 1; + + // Consumer signalling Producer of completion + // Ensures all blocks in the Same Row and Column get notifed. + CUTLASS_DEVICE + void consumer_release(uint32_t stage, uint32_t skip) { + detail::pipeline_check_is_consumer(params_.role); + uint64_t* smem_ptr = reinterpret_cast(&empty_barrier_ptr_[stage]); + if constexpr (is_2sm_mma) { // Mma cluster shape is 2x1 + if (!skip) { + cutlass::arch::umma_arrive_multicast_2x1SM(smem_ptr, block_id_mask_); + } + } + else { + if (!skip) { + if constexpr (cute::is_static_v and size(ClusterShape{}) == 1) { + cutlass::arch::umma_arrive(smem_ptr); + } + else { + cutlass::arch::umma_arrive_multicast(smem_ptr, block_id_mask_); + } + } + } + } +}; + +} diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/common/pow_2.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/common/pow_2.hpp new file mode 100644 index 0000000000000000000000000000000000000000..eca93250f4c0cfdafac485ce5857f10469b3540d --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/common/pow_2.hpp @@ -0,0 +1,92 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include +#include + +#include + +namespace cutlass::fmha { + +struct Pow2 { + int n; + int log2_n; + + explicit CUTE_DEVICE Pow2(int n) : n(n) { +#ifdef __CUDA_ARCH__ + log2_n = __ffs(n) - 1; +#endif + } + + template + CUTE_HOST_DEVICE T operator *(T const& b) const { + return n * b; + } + + template + CUTE_HOST_DEVICE auto operator *(Int const&) const { + if constexpr (N & (N - 1) == 0) { + return Pow2{n * N}; + } + return n * N; + } + +}; + +template +CUTE_HOST_DEVICE auto operator/(T const& a, Pow2 const& b) { + return a >> b.log2_n; +} + +template +CUTE_HOST_DEVICE auto operator%(T const& a, Pow2 const& b) { + return a & (b.n - 1); +} + +template +CUTE_HOST_DEVICE bool operator<(T const& a, Pow2 const& b) { + return a < b.n; +} + +CUTE_HOST_DEVICE void print(Pow2 const& a) { + printf("2^%d", a.log2_n); +} + +} // end namespace cutlass::fmha + +namespace cute { + +template <> +struct is_integral : true_type {}; + +} // end namespace cute diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/device/fmha.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/device/fmha.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f8406d3eb14de7e51aa069332fece686fbf4fa80 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/device/fmha.hpp @@ -0,0 +1,276 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! + \file + \brief An universal device layer for cutlass 3.x-style kernels. +*/ + +#pragma once + +// common +#include "cutlass/cutlass.h" +#include "cutlass/device_kernel.h" + +#if !defined(__CUDACC_RTC__) +#include "cutlass/cluster_launch.hpp" +#include "cutlass/trace.h" +#endif // !defined(__CUDACC_RTC__) + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::fmha::device { + +//////////////////////////////////////////////////////////////////////////////// +////////////////////////////// CUTLASS 3.x API ///////////////////////////////// +//////////////////////////////////////////////////////////////////////////////// + +template +class FMHA { +public: + using Kernel = Kernel_; + + static int const kThreadCount = Kernel::MaxThreadsPerBlock; + + /// Argument structure: User API + using Arguments = typename Kernel::Arguments; + /// Argument structure: Kernel API + using Params = typename Kernel::Params; + +private: + + /// Kernel API parameters object + Params params_; + + bool is_initialized(bool set = false) { + static bool initialized = false; + if (set) initialized = true; + return initialized; + } + +public: + + /// Access the Params structure + Params const& params() const { + return params_; + } + + /// Determines whether the GEMM can execute the given problem. 
+ static Status + can_implement(Arguments const& args) { + if (Kernel::can_implement(args)) { + return Status::kSuccess; + } + else { + return Status::kInvalid; + } + } + + /// Gets the workspace size + static size_t + get_workspace_size(Arguments const& args) { + size_t workspace_bytes = 0; + workspace_bytes += Kernel::get_workspace_size(args); + return workspace_bytes; + } + + /// Computes the grid shape + static dim3 + get_grid_shape(Params const& params) { + return Kernel::get_grid_shape(params); + } + + /// Computes the maximum number of active blocks per multiprocessor + static int maximum_active_blocks(int /* smem_capacity */ = -1) { + CUTLASS_TRACE_HOST("FMHA::maximum_active_blocks()"); + int max_active_blocks = -1; + int smem_size = Kernel::SharedStorageSize; + + // first, account for dynamic smem capacity if needed + cudaError_t result; + if (smem_size >= (48 << 10)) { + CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size); + result = cudaFuncSetAttribute( + device_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST( + " cudaFuncSetAttribute() returned error: " + << cudaGetErrorString(result)); + return -1; + } + } + + // query occupancy after setting smem size + result = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks, + device_kernel, + Kernel::MaxThreadsPerBlock, + smem_size); + + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST( + " cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error: " + << cudaGetErrorString(result)); + return -1; + } + + CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks); + return max_active_blocks; + } + + /// Initializes GEMM state from arguments. + Status + initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { + CUTLASS_TRACE_HOST("FMHA::initialize() - workspace " + << workspace << ", stream: " << (stream ? "non-null" : "null")); + + // Initialize the workspace + Status status = Kernel::initialize_workspace(args, workspace, stream); + if (status != Status::kSuccess) { + return status; + } + + // Initialize the Params structure + params_ = Kernel::to_underlying_arguments(args, workspace); + + if (is_initialized()) return Status::kSuccess; + + // account for dynamic smem capacity if needed + int smem_size = Kernel::SharedStorageSize; + if (smem_size >= (48 << 10)) { + CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size); + cudaError_t result = cudaFuncSetAttribute( + device_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error: " << cudaGetErrorString(result)); + return Status::kErrorInternal; + } + } + + is_initialized(true); + + return Status::kSuccess; + } + + /// Update API is preserved in 3.0, but does not guarantee a lightweight update of params. 
+ Status + update(Arguments const& args, void* workspace = nullptr) { + CUTLASS_TRACE_HOST("FMHA()::update() - workspace: " << workspace); + + size_t workspace_bytes = get_workspace_size(args); + if (workspace_bytes > 0 && nullptr == workspace) { + return Status::kErrorWorkspaceNull; + } + + params_ = Kernel::to_underlying_arguments(args, workspace); + return Status::kSuccess; + } + + /// Primary run() entry point API that is static allowing users to create and manage their own params. + /// Supplied params struct must be construct by calling Kernel::to_underling_arguments() + static Status + run(Params& params, cudaStream_t stream = nullptr) { + CUTLASS_TRACE_HOST("FMHA::run()"); + dim3 const block = Kernel::get_block_shape(); + dim3 const grid = get_grid_shape(params); + + // configure smem size and carveout + int smem_size = Kernel::SharedStorageSize; + + Status launch_result; + // Use extended launch API only for mainloops that use it + if constexpr(Kernel::ArchTag::kMinComputeCapability >= 90) { + dim3 cluster(cute::size<0>(typename Kernel::ClusterShape{}), + cute::size<1>(typename Kernel::ClusterShape{}), + cute::size<2>(typename Kernel::ClusterShape{})); + void const* kernel = (void const*) device_kernel; + void* kernel_params[] = {¶ms}; + launch_result = ClusterLauncher::launch(grid, cluster, block, smem_size, stream, kernel, kernel_params); + } + else { + launch_result = Status::kSuccess; + device_kernel<<>>(params); + } + + cudaError_t result = cudaGetLastError(); + if (cudaSuccess == result && Status::kSuccess == launch_result) { + return Status::kSuccess; + } + else { + CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << result); + return Status::kErrorInternal; + } + } + + // + // Non-static launch overloads that first create and set the internal params struct of this kernel handle. + // + + /// Launches the kernel after first constructing Params internal state from supplied arguments. + Status + run(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { + Status status = initialize(args, workspace, stream); + if (Status::kSuccess == status) { + status = run(params_, stream); + } + return status; + } + + /// Launches the kernel after first constructing Params internal state from supplied arguments. + Status + operator()(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { + return run(args, workspace, stream); + } + + /// Overload that allows a user to re-launch the same kernel without updating internal params struct. + Status + run(cudaStream_t stream = nullptr) { + return run(params_, stream); + } + + /// Overload that allows a user to re-launch the same kernel without updating internal params struct. 
+ Status + operator()(cudaStream_t stream = nullptr) { + return run(params_, stream); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::device + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/device/fmha_device_bwd.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/device/fmha_device_bwd.hpp new file mode 100644 index 0000000000000000000000000000000000000000..9e4efb34ec54bfb7183b6f6f1fdd80ea21ac7e28 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/device/fmha_device_bwd.hpp @@ -0,0 +1,375 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + + +#pragma once + +// common +#include "cutlass/cutlass.h" +#include "cutlass/kernel_hardware_info.hpp" +#include "cute/tensor.hpp" + +#include "../device/fmha.hpp" +#include "../kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp" +#include "../kernel/sm100_fmha_bwd_mla_kernel_tma_warpspecialized.hpp" +#include "../kernel/fmha_kernel_bwd_sum_OdO.hpp" +#include "../kernel/fmha_kernel_bwd_convert.hpp" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::fmha::device { + +//////////////////////////////////////////////////////////////////////////////// +////////////////////////////// CUTLASS 3.x API ///////////////////////////////// +//////////////////////////////////////////////////////////////////////////////// + +template< + class ProblemShape, + class Element, + class ElementAccumulator, + class TileShape, + bool IsMla, + class Mask +> +class Sm100FmhaBwd { +private: + template + constexpr static auto to_bwd_shape(T shape) { + if constexpr (IsMla) { // remove GQA mode + constexpr int R = decltype(rank(shape))::value; + auto HB = get(shape); + auto rest = take<0,R-1>(shape); + return append(rest, make_shape(size<0>(HB), get<1>(HB))); + } + else { + return shape; + } + } + + template + constexpr static auto to_bwd_stride(T stride) { + if constexpr (IsMla) { // remove GQA mode + constexpr int R = decltype(rank(stride))::value; + auto HB = get(stride); + auto rest = take<0,R-1>(stride); + if constexpr (is_same_v(HB))>, _0>) { + return append(rest, make_stride(get<0,1>(HB), get<1>(HB))); + } + else { + return append(rest, make_stride(get<0,0>(HB), get<1>(HB))); + } + } + else { + return stride; + } + } + +public: + /// Argument structure: User API + struct Arguments { + // Q K D D_VO HB + ProblemShape problem_shape; + + const Element* ptr_Q; + cute::tuple, int>> stride_Q; + const Element* ptr_K; + cute::tuple, int>> stride_K; + const Element* ptr_V; + cute::tuple, int>> stride_V; + + const Element* ptr_O; + cute::tuple, int>> stride_O; + const ElementAccumulator* ptr_LSE; + cute::tuple, int>> stride_LSE; + + const Element* ptr_dO; + cute::tuple, int>> stride_dO; + + Element* ptr_dQ; + cute::tuple, int>> stride_dQ; + Element* ptr_dK; + cute::tuple, int>> stride_dK; + Element* ptr_dV; + cute::tuple, int>> stride_dV; + + ElementAccumulator softmax_scale; + + cutlass::KernelHardwareInfo hw_info; + }; + + using OperationSumOdO = cutlass::fmha::device::FMHA< + cutlass::fmha::kernel::FmhaKernelBwdSumOdO + >; + using OperationConvert = cutlass::fmha::device::FMHA< + cutlass::fmha::kernel::FmhaKernelBwdConvert + >; + + using OperationNormal= cutlass::fmha::device::FMHA< + cutlass::fmha::kernel::Sm100FmhaBwdKernelTmaWarpSpecialized< + ProblemShape, Element, ElementAccumulator, TileShape, Mask + > + >; + + using ProblemShapeMLA = decltype(to_bwd_shape(ProblemShape{})); + using OperationMla = cutlass::fmha::device::FMHA< + cutlass::fmha::kernel::Sm100FmhaBwdMlaKernelTmaWarpSpecialized< + ProblemShapeMLA, Element, ElementAccumulator, TileShape, Mask + > + >; + + using Operation = std::conditional_t; + + using Kernel = typename Operation::Kernel; + + struct Params { + OperationSumOdO op_sum_OdO; + Operation op; + OperationConvert op_convert; + ElementAccumulator* dQ_acc; + size_t dQ_acc_size; + }; + +private: + Params params_; + + static typename OperationSumOdO::Arguments to_sum_OdO_arguments( + Arguments const& args, + ElementAccumulator* sum_odo 
= nullptr, + ElementAccumulator* scaled_lse = nullptr) { + using namespace cute; + auto [Q_, K, D, D_VO, HB] = args.problem_shape; + auto [H, B] = HB; + auto [H_R, H_K] = H; + D = cutlass::round_up(D, 8); // Alignment + int Q = cutlass::round_up(static_cast(Q_), 8); // Alignment + auto stride_sum_OdO = make_stride(_1{}, make_stride(make_stride(Q, Q*H_R), B == 1 ? 0 : Q*H_R*H_K)); + auto stride_scaled_lse = make_stride(_1{}, make_stride(make_stride(Q, Q*H_R), B == 1 ? 0 : Q*H_R*H_K)); + auto log2_e = log2f(expf(1.0f)); + return typename OperationSumOdO::Arguments { + args.problem_shape, + args.ptr_O, args.stride_O, + args.ptr_dO, args.stride_dO, + sum_odo, stride_sum_OdO, + args.ptr_LSE, args.stride_LSE, + scaled_lse, stride_scaled_lse, + -1.0f, -log2_e + }; + } + + static typename OperationConvert::Arguments to_convert_arguments(Arguments const& args, ElementAccumulator* src = nullptr) { + using namespace cute; + auto [Q_, K, D, D_VO, HB] = args.problem_shape; + auto [H, B] = HB; + auto [H_R, H_K] = H; + D = cutlass::round_up(D, 8); // Alignment + int Q = cutlass::round_up(static_cast(Q_), 8); // Alignment + auto stride_src_dQ = make_stride(D, _1{}, make_stride(make_stride(D*Q, D*Q*H_R), B == 1 ? 0 : D*Q*H_R*H_K)); + return typename OperationConvert::Arguments { + args.problem_shape, + src, stride_src_dQ, + nullptr, args.stride_dK, + nullptr, args.stride_dV, + args.ptr_dQ, args.stride_dQ, + nullptr, args.stride_dK, + nullptr, args.stride_dV, + args.softmax_scale + }; + } + + static typename Operation::Arguments to_bwd_arguments( + Arguments const& args, + ElementAccumulator* sum_OdO = nullptr, cute::tuple, int>> const& stride_sum_OdO = {}, + ElementAccumulator* scaled_lse = nullptr, cute::tuple, int>> const& stride_scaled_lse = {}, + ElementAccumulator* dQ_acc = nullptr, cute::tuple, int>> const& stride_dQ = {}) { + + return typename Operation::Arguments{ + to_bwd_shape(args.problem_shape), + { args.ptr_Q, to_bwd_stride(args.stride_Q), + args.ptr_K, to_bwd_stride(args.stride_K), + args.ptr_V, to_bwd_stride(args.stride_V), + args.ptr_dO, to_bwd_stride(args.stride_dO), + scaled_lse, to_bwd_stride(stride_scaled_lse), + sum_OdO, to_bwd_stride(stride_sum_OdO), + dQ_acc, to_bwd_stride(stride_dQ), + args.softmax_scale }, + { args.ptr_dK, to_bwd_stride(args.stride_dK), + args.ptr_dV, to_bwd_stride(args.stride_dV) }, + args.hw_info + }; + } + +public: + + /// Determines whether the GEMM can execute the given problem. 
+ static Status + can_implement(Arguments const& args) { + Status status = Status::kSuccess; + + status = OperationSumOdO::can_implement(to_sum_OdO_arguments(args)); + if (status != Status::kSuccess) { + return status; + } + + status = OperationConvert::can_implement(to_convert_arguments(args)); + if (status != Status::kSuccess) { + return status; + } + + status = Operation::can_implement(to_bwd_arguments(args)); + if (status != Status::kSuccess) { + return status; + } + + return status; + } + + /// Gets the workspace size + static size_t + get_workspace_size(Arguments const& args) { + auto [Q_, K, D, D_VO, HB] = args.problem_shape; + auto [H, B] = product_each(HB); + D = cutlass::round_up(D, 8); // Alignment + int Q = cutlass::round_up(static_cast(Q_), 8); // Alignment + size_t workspace_bytes = 0; + // OdO vector + workspace_bytes += B*H*Q * sizeof(ElementAccumulator); + // scaled LSE vector + workspace_bytes += B*H*Q * sizeof(ElementAccumulator); + // FP32 versions of outputs that are churned (start off with Q only) + workspace_bytes += B*H*Q*D * sizeof(ElementAccumulator); + return workspace_bytes; + } + + /// Initializes state from arguments. + Status + initialize_split(Arguments const& args, void* workspace_dQ, void* workspace_sum_OdO, void* workspace_scaled_lse, cudaStream_t stream = nullptr) { + CUTLASS_TRACE_HOST("Universal::initialize_split() - workspace_dQ=" + << workspace_dQ << ", workspace_sum_OdO=" << workspace_sum_OdO << "stream: " << (stream ? "non-null" : "null")); + + auto [Q_, K, D, D_VO, HB] = args.problem_shape; + auto [H, B] = product_each(HB); + D = cutlass::round_up(D, 8); // Alignment + int Q = cutlass::round_up(static_cast(Q_), 8); // Alignment + ElementAccumulator* sum_OdO = reinterpret_cast(workspace_sum_OdO); + ElementAccumulator* scaled_lse = reinterpret_cast(workspace_scaled_lse); + ElementAccumulator* dQ_acc = reinterpret_cast(workspace_dQ); + params_.dQ_acc = dQ_acc; + params_.dQ_acc_size = B*H*Q*D * sizeof(ElementAccumulator); + auto args_sum_OdO = to_sum_OdO_arguments(args, sum_OdO, scaled_lse); + auto args_convert = to_convert_arguments(args, dQ_acc); + params_.op_sum_OdO.initialize(args_sum_OdO, nullptr, stream); + params_.op_convert.initialize(args_convert, nullptr, stream); + auto args_bwd = to_bwd_arguments( + args, sum_OdO, args_sum_OdO.stride_sum_OdO, + scaled_lse, args_sum_OdO.stride_scaled_lse, + dQ_acc, args_convert.stride_src_dQ + ); + params_.op.initialize(args_bwd, nullptr, stream); + + return Status::kSuccess; + } + + /// Initializes state from arguments. + Status + initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { + CUTLASS_TRACE_HOST("Universal::initialize() - workspace " + << workspace << ", stream: " << (stream ? "non-null" : "null")); + + auto [Q_, K, D, D_VO, HB] = args.problem_shape; + auto [H, B] = product_each(HB); + D = cutlass::round_up(D, 8); // Alignment + int Q = cutlass::round_up(static_cast(Q_), 8); // Alignment + char* workspace_chr = reinterpret_cast(workspace); + ElementAccumulator* sum_OdO = reinterpret_cast(workspace_chr); + workspace_chr += B*H*Q * sizeof(ElementAccumulator); + ElementAccumulator* scaled_lse = reinterpret_cast(workspace_chr); + workspace_chr += B*H*Q * sizeof(ElementAccumulator); + ElementAccumulator* dQ_acc = reinterpret_cast(workspace_chr); + return initialize_split(args, dQ_acc, sum_OdO, scaled_lse, stream); + } + + /// Primary run() entry point API that is static allowing users to create and manage their own params. 
+ /// Supplied params struct must be construct by calling Kernel::to_underling_arguments() + static Status + run(Params& params, cudaStream_t stream = nullptr) { + CUTLASS_TRACE_HOST("FmhaDeviceBwd::run()"); + + Status result = Status::kSuccess; + result = params.op_sum_OdO.run(stream); + if (result != Status::kSuccess) { + return result; + } + + auto cuda_result = cudaMemsetAsync(params.dQ_acc, 0, params.dQ_acc_size, stream); + if (cuda_result != cudaSuccess) { + return Status::kErrorInternal; + } + + result = params.op.run(stream); + if (result != Status::kSuccess) { + return result; + } + + result = params.op_convert.run(stream); + if (result != Status::kSuccess) { + return result; + } + + return Status::kSuccess; + } + + // + // Non-static launch overloads that first create and set the internal params struct of this kernel handle. + // + + /// Launches the kernel after first constructing Params internal state from supplied arguments. + Status + run(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { + Status status = initialize(args, workspace, stream); + if (Status::kSuccess == status) { + status = run(params_, stream); + } + return status; + } + + /// Overload that allows a user to re-launch the same kernel without updating internal params struct. + Status + run(cudaStream_t stream = nullptr) { + return run(params_, stream); + } + +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::fmha::device + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/device/sm100_mla.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/device/sm100_mla.hpp new file mode 100644 index 0000000000000000000000000000000000000000..4706d9fee5ab523e5c4f62cefd8ca0370178f253 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/device/sm100_mla.hpp @@ -0,0 +1,383 @@ +/*************************************************************************************************** + * Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! + \file + \brief An universal device layer for cutlass 3.x-style kernels. +*/ + +#pragma once + +// common +#include "cutlass/cutlass.h" +#include "cutlass/device_kernel.h" + +#if !defined(__CUDACC_RTC__) +#include "cutlass/cluster_launch.hpp" +#include "cutlass/trace.h" +#endif // !defined(__CUDACC_RTC__) + +#include "kernel/sm100_fmha_mla_tma_warpspecialized.hpp" +#include "kernel/sm100_fmha_mla_reduction.hpp" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::fmha::device { + +using namespace cute; +using namespace cutlass::fmha::kernel; + + +//////////////////////////////////////////////////////////////////////////////// +////////////////////////////// CUTLASS 3.x API ///////////////////////////////// +//////////////////////////////////////////////////////////////////////////////// + +template< + class Kernel_ +> +class MLA { +public: + + using Kernel = Kernel_; + + using ReductionKernel = cutlass::fmha::kernel::Sm100FmhaMlaReductionKernel< + typename Kernel::ElementOut, + typename Kernel::ElementAcc, + typename Kernel::ElementAcc, + Kernel::TileShapeH::value, + Kernel::TileShapeL::value, + 256 /*Max split*/ + >; + + /// Argument structure: User API + using KernelArguments = typename Kernel::Arguments; + using ReductionArguments = typename ReductionKernel::Arguments; + + using Arguments = KernelArguments; + + /// Argument structure: Kernel API + using KernelParams = typename Kernel::Params; + using ReductionParams = typename ReductionKernel::Params; + struct Params { + KernelParams fmha_params; + ReductionParams reduction_params; + }; + +private: + + /// Kernel API parameters object + Params params_; + + bool is_initialized(bool set = false) { + static bool initialized = false; + if (set) initialized = true; + return initialized; + } + + static ReductionArguments to_reduction_args(Arguments const& args) { + auto [H, K, D, B] = args.problem_shape; + return ReductionArguments{ + nullptr, args.epilogue.ptr_o, nullptr, args.epilogue.ptr_lse, + args.mainloop.softmax_scale, B, args.split_kv, K, args.mainloop.ptr_seq, + args.ptr_split_kv, Kernel::TileShapeS::value + }; + } + +public: + + /// Access the Params structure + Params const& params() const { + return params_; + } + + static void set_split_kv (KernelArguments& args) { + if (args.split_kv >= 1) return; + auto [H, K, D, B] = args.problem_shape; + int sm_count = args.hw_info.sm_count; + int max_splits = ceil_div(K, 128); + int sms_per_batch = max(1, sm_count / B); + int split_heur = min(max_splits, sms_per_batch); + int waves = ceil_div(B * split_heur, sm_count); + int k_waves = ceil_div(max_splits, split_heur); + int split_wave_aware = ceil_div(max_splits, k_waves); + if (args.is_fused_reduction && split_wave_aware > 1) { + args.split_kv = std::min(split_wave_aware, static_cast(sm_count/2)); + } else { + args.split_kv = split_wave_aware; + } + 
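+    // Editorial note (not in the upstream file): the heuristic above picks
+    // split_kv as min(ceil_div(K, 128), SMs available per batch entry), then
+    // rounds the split count so the resulting K-tile passes fill whole waves
+    // across the device; with fused reduction it is further capped at sm_count / 2.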
} + + /// Determines whether the GEMM can execute the given problem. + static Status + can_implement(Arguments const& args) { + if (! Kernel::can_implement(args)) { + return Status::kInvalid; + } + if (! ReductionKernel::can_implement(to_reduction_args(args))) { + return Status::kInvalid; + } + return Status::kSuccess; + } + + /// Gets the workspace size + static size_t + get_workspace_size(Arguments const& args) { + size_t workspace_bytes = 0; + workspace_bytes += Kernel::get_workspace_size(args); + workspace_bytes += ReductionKernel::get_workspace_size(to_reduction_args(args)); + return workspace_bytes; + } + + /// Computes the maximum number of active blocks per multiprocessor + static int maximum_active_blocks(int /* smem_capacity */ = -1) { + CUTLASS_TRACE_HOST("MLA::maximum_active_blocks()"); + int max_active_blocks = -1; + int smem_size = Kernel::SharedStorageSize; + + // first, account for dynamic smem capacity if needed + cudaError_t result; + if (smem_size >= (48 << 10)) { + CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size); + result = cudaFuncSetAttribute( + device_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST( + " cudaFuncSetAttribute() returned error: " + << cudaGetErrorString(result)); + return -1; + } + } + + // query occupancy after setting smem size + result = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks, + device_kernel, + Kernel::MaxThreadsPerBlock, + smem_size); + + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST( + " cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error: " + << cudaGetErrorString(result)); + return -1; + } + + CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks); + return max_active_blocks; + } + + /// Initializes GEMM state from arguments. + Status + initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { + CUTLASS_TRACE_HOST("MLA::initialize() - workspace " + << workspace << ", stream: " << (stream ? 
"non-null" : "null")); + + // Initialize the workspace + Status status = Kernel::initialize_workspace(args, workspace, stream); + if (status != Status::kSuccess) { + return status; + } + status = ReductionKernel::initialize_workspace(to_reduction_args(args), workspace, stream); + if (status != Status::kSuccess) { + return status; + } + KernelParams kernel_params = Kernel::to_underlying_arguments(args, workspace); + + ReductionArguments reduction_args = to_reduction_args(args); + if (reduction_args.split_kv > 1) { + reduction_args.ptr_oaccum = kernel_params.epilogue.ptr_o_acc; + reduction_args.ptr_lseaccum = kernel_params.epilogue.ptr_lse_acc; + } + ReductionParams reduction_params = ReductionKernel::to_underlying_arguments(reduction_args, workspace); + // Initialize the Params structure + params_ = Params {kernel_params, reduction_params}; + + if (is_initialized()) return Status::kSuccess; + + // account for dynamic smem capacity if needed + // no dynamic smem is needed for reduction kernel + int smem_size = Kernel::SharedStorageSize; + if (smem_size >= (48 << 10)) { + CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size); + cudaError_t result = cudaFuncSetAttribute( + device_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error: " << cudaGetErrorString(result)); + return Status::kErrorInternal; + } + } + + is_initialized(true); + + return Status::kSuccess; + } + + /// Update API is preserved in 3.0, but does not guarantee a lightweight update of params. + Status + update(Arguments const& args, void* workspace = nullptr) { + CUTLASS_TRACE_HOST("MLA()::update() - workspace: " << workspace); + + size_t workspace_bytes = get_workspace_size(args); + if (workspace_bytes > 0 && nullptr == workspace) { + return Status::kErrorWorkspaceNull; + } + + auto fmha_params = Kernel::to_underlying_arguments(args, workspace); + + ReductionArguments reduction_args = to_reduction_args(args); + if (reduction_args.split_kv > 1) { + reduction_args.ptr_oaccum = fmha_params.epilogue.ptr_o_acc; + reduction_args.ptr_lseaccum = fmha_params.epilogue.ptr_lse_acc; + } + ReductionParams reduction_params = ReductionKernel::to_underlying_arguments(reduction_args, workspace); + // Initialize the Params structure + params_ = Params {fmha_params, reduction_params}; + + return Status::kSuccess; + } + + /// Primary run() entry point API that is static allowing users to create and manage their own params. 
+ /// Supplied params struct must be construct by calling Kernel::to_underling_arguments() + static Status + run(Params& params, cudaStream_t stream = nullptr) { + CUTLASS_TRACE_HOST("MLA::run()"); + dim3 const block = Kernel::get_block_shape(); + dim3 const grid = Kernel::get_grid_shape(params.fmha_params); + auto [H, K, D, B] = params.fmha_params.problem_shape; + auto [D_latent, D_rope] = D; + + // configure smem size and carveout + int smem_size = Kernel::SharedStorageSize; + + Status launch_result; + if (params.fmha_params.is_fused_reduction && params.reduction_params.split_kv > 1) { + auto result = cudaMemsetAsync(params.fmha_params.epilogue.ptr_o, 0, sizeof(typename Kernel::ElementOut) * H * D_latent * B, stream); + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST( + " cudaMemsetAsync() returned error: " + << cudaGetErrorString(result)); + return Status::kErrorInternal; + } + auto total_bytes = H * B * (sizeof(int) + sizeof(typename Kernel::ElementLSE)) + 2 * B * sizeof(int); + uint8_t* ws = reinterpret_cast(params.fmha_params.epilogue.ptr_lse_exchange_buff); + result = cudaMemsetAsync(ws, 0, total_bytes, stream); + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST( + " cudaMemsetAsync() returned error: " + << cudaGetErrorString(result)); + return Status::kErrorInternal;; + } + } + // Use extended launch API only for mainloops that use it + if constexpr(Kernel::ArchTag::kMinComputeCapability >= 90) { + dim3 cluster(cute::size<0>(typename Kernel::ClusterShape{}), + cute::size<1>(typename Kernel::ClusterShape{}), + cute::size<2>(typename Kernel::ClusterShape{})); + void const* kernel = (void const*) device_kernel; + void* kernel_params[] = {¶ms.fmha_params}; + launch_result = ClusterLauncher::launch(grid, cluster, block, smem_size, stream, kernel, kernel_params); + } + else { + launch_result = Status::kSuccess; + device_kernel<<>>(params.fmha_params); + } + + cudaError_t result = cudaGetLastError(); + if (cudaSuccess != result or Status::kSuccess != launch_result) { + //return Status::kSuccess; + CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << result); + return Status::kErrorInternal; + } + if (!params.fmha_params.is_fused_reduction && params.reduction_params.split_kv > 1) { + // launch reduction kernel + dim3 const block = ReductionKernel::get_block_shape(); + dim3 const grid = ReductionKernel::get_grid_shape(params.reduction_params); + device_kernel<<>>(params.reduction_params); + cudaError_t result = cudaGetLastError(); + if (cudaSuccess == result) { + return Status::kSuccess; + } + else { + CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << result); + return Status::kErrorInternal; + } + } + else { + return Status::kSuccess; + } + } + + // + // Non-static launch overloads that first create and set the internal params struct of this kernel handle. + // + + /// Launches the kernel after first constructing Params internal state from supplied arguments. + Status + run(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { + Status status = initialize(args, workspace, stream); + if (Status::kSuccess == status) { + status = run(params_, stream); + } + return status; + } + + /// Launches the kernel after first constructing Params internal state from supplied arguments. 
+ Status + operator()(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { + return run(args, workspace, stream); + } + + /// Overload that allows a user to re-launch the same kernel without updating internal params struct. + Status + run(cudaStream_t stream = nullptr) { + return run(params_, stream); + } + + /// Overload that allows a user to re-launch the same kernel without updating internal params struct. + Status + operator()(cudaStream_t stream = nullptr) { + return run(params_, stream); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::fmha::device + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/kernel/fmha_causal_tile_scheduler.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/kernel/fmha_causal_tile_scheduler.hpp new file mode 100644 index 0000000000000000000000000000000000000000..572e67f6967c1e0c0d675cbadf120b8cf6114707 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/kernel/fmha_causal_tile_scheduler.hpp @@ -0,0 +1,197 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" + +namespace cutlass::fmha::kernel { + +//////////////////////////////////////////////////////////////////////////////// + +// Swizzle Q tile and H tile to improve L2 cache hit rate, +// and launch the longest main loop first to keep most SMs busy. 
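+//
+// Sketch of the index math below (purely illustrative, using the constants defined by
+// this scheduler: TileQ = 16, TileH = 8, TileSize = 128). Assume gridDim.x = 32 heads,
+// so divmod_tile_col divides by 32 / 8 = 4:
+//
+//   block_idx = blockIdx.y * gridDim.x + blockIdx.x      // e.g. (x=5, y=3) -> 3*32 + 5 = 101
+//   tile_idx  = block_idx / 128 = 0;   tile_tail = 101
+//   tile_row  = tile_idx  / 4   = 0;   tile_col  = tile_idx  % 4 = 0
+//   row_off   = tile_tail / 8   = 12;  col_off   = tile_tail % 8 = 5
+//   row_idx   = tile_row * 16 + row_off = 12;  col_idx = tile_col * 8 + col_off = 5
+//
+// Consecutive blocks therefore sweep a 16 (Q blocks) x 8 (heads) footprint before moving
+// to the next tile, improving K/V reuse in L2; the returned coordinate is mirrored as
+// (gridDim.y - 1 - row_idx) so the longest causal main loops are launched first.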
+ +struct CausalIndividualTileScheduler { + + static constexpr int TileQ = 16; + static constexpr int TileH = 8; + static constexpr int TileSize = TileQ * TileH; + + struct Params { + dim3 grid; + int tile_max_q; + FastDivmod divmod_tile_col; + FastDivmod divmod_tile_size; + FastDivmod divmod_tile_head; + }; + + bool valid_ = true; + Params params; + + CUTLASS_DEVICE + CausalIndividualTileScheduler(Params const& params) : params(params) {} + + template + static Params to_underlying_arguments( + ProblemSize const& problem_size, KernelHardwareInfo hw_info, + ClusterShape const& cluster_shape, TileShape const& tile_shape) { + using namespace cute; + + dim3 grid(size<3,0>(problem_size), round_up(ceil_div(size<0>(problem_size), size<0>(tile_shape)), size<0>(cluster_shape)), size<3,1>(problem_size)); + // gridDim.x must multiple of TileH + const int tile_col_count = grid.x / TileH; + const int tile_max_q = grid.y / TileQ * TileQ; + return Params{ grid , tile_max_q, tile_col_count, TileSize, TileH}; + } + + static dim3 get_grid_shape(Params const& params) { + return params.grid; + } + + CUTLASS_DEVICE + bool is_valid() { + return valid_; + } + + CUTLASS_DEVICE + auto get_block_coord() { + using namespace cute; + const int block_idx = blockIdx.y * gridDim.x + blockIdx.x; + + int tile_idx, tile_tail; + params.divmod_tile_size(tile_idx, tile_tail, block_idx); + + int tile_row_idx, tile_col_idx; + params.divmod_tile_col(tile_row_idx,tile_col_idx, tile_idx); + + int row_offset_in_tail, col_offset_in_tail; + params.divmod_tile_head(row_offset_in_tail,col_offset_in_tail, tile_tail); + + const int row_idx = tile_row_idx * TileQ + row_offset_in_tail; + const int col_idx = tile_col_idx * TileH + col_offset_in_tail; + + // last q tile launch first + if(blockIdx.y >= params.tile_max_q) { + return make_coord(int(gridDim.y - 1 - blockIdx.y), _0{}, make_coord(int(blockIdx.x), int(blockIdx.z))); + } + + return make_coord(int(gridDim.y) - 1 - row_idx, _0{}, make_coord(col_idx, int(blockIdx.z))); + } + + CUTLASS_DEVICE + CausalIndividualTileScheduler& operator++() { + valid_ = false; + return *this; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + + +//////////////////////////////////////////////////////////////////////////////// + +// Launch order: H Q B +struct CausalPersistentTileScheduler { + + struct Params { + int num_blocks; + FastDivmod divmod_h; + FastDivmod divmod_m_block; + FastDivmod divmod_b; + + KernelHardwareInfo hw_info; + }; + + int block_idx = 0; + Params params; + + CUTLASS_DEVICE + CausalPersistentTileScheduler(Params const& params) : block_idx(blockIdx.x), params(params) {} + + template + static Params to_underlying_arguments( + ProblemSize const& problem_size, KernelHardwareInfo hw_info, + ClusterShape const& cluster_shape, TileShape const& tile_shape) { + using namespace cute; + // Get SM count if needed, otherwise use user supplied SM count + int sm_count = hw_info.sm_count; + if (sm_count <= 0) { + CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n" + " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count."); + sm_count = KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + } + + CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count); + hw_info.sm_count = sm_count; + + int num_m_blocks = cutlass::round_up(ceil_div(size<0>(problem_size), size<0>(tile_shape)), size<0>(cluster_shape)); + int num_blocks = num_m_blocks * 
size<3,0>(problem_size) * size<3,1>(problem_size); + + return Params { + num_blocks, + { size<3,0>(problem_size) }, { num_m_blocks}, { size<3,1>(problem_size) }, + hw_info + }; + } + + static dim3 get_grid_shape(Params const& params) { + dim3 grid(std::min(params.num_blocks, params.hw_info.sm_count), 1, 1); + return grid; + } + + CUTLASS_DEVICE + bool is_valid() { + return block_idx < params.num_blocks; + } + + CUTLASS_DEVICE + auto get_block_coord() { + using namespace cute; + int block_decode = block_idx; + int m_block, bidb, bidh; + params.divmod_h(block_decode, bidh, block_decode); + params.divmod_m_block(block_decode, m_block, block_decode); + params.divmod_b(block_decode, bidb, block_decode); + return make_coord(m_block, _0{}, make_coord(bidh, bidb)); + } + + CUTLASS_DEVICE + CausalPersistentTileScheduler& operator++() { + block_idx += gridDim.x; + return *this; + } +}; +//////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::fmha::kernel diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/kernel/fmha_kernel_bwd_convert.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/kernel/fmha_kernel_bwd_convert.hpp new file mode 100644 index 0000000000000000000000000000000000000000..7bee3d5f77c6518a13201a310a6adffeebc25432 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/kernel/fmha_kernel_bwd_convert.hpp @@ -0,0 +1,153 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + + +#pragma once + +#include "cutlass/cutlass.h" +#include "cute/layout.hpp" + +namespace cutlass::fmha::kernel { + +using namespace cute; + +template +struct FmhaKernelBwdConvert { + + struct Arguments { + ProblemShape problem_shape; + + const ElementAcc* ptr_src_dQ; + tuple, int>> stride_src_dQ; + const ElementAcc* ptr_src_dK; + tuple, int>> stride_src_dK; + const ElementAcc* ptr_src_dV; + tuple, int>> stride_src_dV; + + Element* ptr_dest_dQ; + tuple, int>> stride_dest_dQ; + Element* ptr_dest_dK; + tuple, int>> stride_dest_dK; + Element* ptr_dest_dV; + tuple, int>> stride_dest_dV; + + ElementAcc scale = 1.0; + }; + + using Params = Arguments; + + using ClusterShape = Shape<_1, _1, _1>; + static constexpr int SharedStorageSize = 0; + + static const int MinBlocksPerMultiprocessor = 1; + static const int MaxThreadsPerBlock = 128; + using ArchTag = cutlass::arch::Sm90; + + static const int kBlockSeq = 8; + + static size_t get_workspace_size(Arguments const& args) { return 0; } + static cutlass::Status initialize_workspace(Arguments const&, void*, cudaStream_t) { + return cutlass::Status::kSuccess; + } + + static const int kNumThreadsD = 16; + static const int kNumThreadsSeq = MaxThreadsPerBlock / kNumThreadsD; + static const int kElementsPerLoad = 4; + + static const int kIterationsSeq = kBlockSeq / kNumThreadsSeq; + + static bool can_implement(Arguments const& args) { + return get<2>(args.problem_shape) % kElementsPerLoad == 0 && get<3>(args.problem_shape) % kElementsPerLoad == 0; + } + + static dim3 get_grid_shape(Params const& params) { + dim3 grid(size<4,0>(params.problem_shape), size<4,1>(params.problem_shape), ceil_div(std::max(size<0>(params.problem_shape), size<1>(params.problem_shape)), kBlockSeq)); + return grid; + } + + static dim3 get_block_shape() { + dim3 block(kNumThreadsD, kNumThreadsSeq, 1); + return block; + } + + static Params to_underlying_arguments(Arguments const& args, void* workspace) { + return args; + } + + template + CUTLASS_DEVICE void copy(Params const& params, const ElementAcc* ptr_src, StrideSrc const& stride_src, Element* ptr_dest, StrideDest const& stride_dest, Count const& count, int d_dim) { + auto ptr_src_bh = ptr_src + get<2,0,0>(stride_src) * blockIdx.x + get<2,1>(stride_src) * blockIdx.y; + auto ptr_dest_bh = ptr_dest + get<2,0,0>(stride_dest) * blockIdx.x + get<2,1>(stride_dest) * blockIdx.y; + + int seqlen = count; + if constexpr (is_variable_length_v) { + int offset = count.cumulative_length[blockIdx.y]; + ptr_dest_bh += offset * get<0>(stride_dest); + seqlen = count.cumulative_length[blockIdx.y + 1] - offset; + } + + for (int idx_s_t = threadIdx.y; idx_s_t < kBlockSeq; idx_s_t += kNumThreadsSeq) { + int idx_s = idx_s_t + kBlockSeq * blockIdx.z; + if (idx_s >= seqlen) continue; + auto ptr_src_bhs = ptr_src_bh + idx_s * get<0>(stride_src); + auto ptr_dest_bhs = ptr_dest_bh + idx_s * get<0>(stride_dest); + + for (int idx_d = threadIdx.x * kElementsPerLoad; idx_d < d_dim; idx_d += kElementsPerLoad * kNumThreadsD) { + ElementAcc value_src[kElementsPerLoad]; + Element value_dest[kElementsPerLoad]; + + using VecSrc = uint_bit_t * kElementsPerLoad>; + using VecDest = uint_bit_t * kElementsPerLoad>; + *reinterpret_cast(value_src) = *reinterpret_cast(&ptr_src_bhs[idx_d]); + + for (int v = 0; v < kElementsPerLoad; v++) { + value_dest[v] = static_cast(params.scale * value_src[v]); + } + + *reinterpret_cast(&ptr_dest_bhs[idx_d]) = 
*reinterpret_cast(value_dest); + } + } + } + + CUTLASS_DEVICE void operator()(const Params ¶ms, char* smem) { + if (params.ptr_src_dQ != nullptr) { + copy(params, params.ptr_src_dQ, params.stride_src_dQ, params.ptr_dest_dQ, params.stride_dest_dQ, get<0>(params.problem_shape), get<2>(params.problem_shape)); + } + if (params.ptr_src_dK != nullptr) { + copy(params, params.ptr_src_dK, params.stride_src_dK, params.ptr_dest_dK, params.stride_dest_dK, get<1>(params.problem_shape), get<2>(params.problem_shape)); + } + if (params.ptr_src_dV != nullptr) { + copy(params, params.ptr_src_dV, params.stride_src_dV, params.ptr_dest_dV, params.stride_dest_dV, get<1>(params.problem_shape), get<3>(params.problem_shape)); + } + } +}; + +} // namespace cutlass::fmha::kernel diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/kernel/fmha_kernel_bwd_sum_OdO.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/kernel/fmha_kernel_bwd_sum_OdO.hpp new file mode 100644 index 0000000000000000000000000000000000000000..66780f353da60bd64347f06f30526873a72a21f2 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/kernel/fmha_kernel_bwd_sum_OdO.hpp @@ -0,0 +1,161 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + + +#pragma once + +#include "cutlass/cutlass.h" +#include "cute/layout.hpp" + +namespace cutlass::fmha::kernel { + +using namespace cute; + +template +struct FmhaKernelBwdSumOdO { + + struct Arguments { + ProblemShape problem_shape; + + const Element* ptr_O; + cute::tuple, int>> stride_O; + const Element* ptr_dO; + cute::tuple, int>> stride_dO; + + ElementAcc* ptr_sum_OdO; + cute::tuple, int>> stride_sum_OdO; + + const ElementAcc* ptr_lse = nullptr; + cute::tuple, int>> stride_lse; + + ElementAcc* ptr_scaled_lse = nullptr; + cute::tuple, int>> stride_scaled_lse; + + ElementAcc sum_odo_scale = 1.0; + ElementAcc lse_scale = 1.0; + }; + + using Params = Arguments; + + using ClusterShape = Shape<_1, _1, _1>; + static constexpr int SharedStorageSize = 0; + + static const int MinBlocksPerMultiprocessor = 1; + static const int MaxThreadsPerBlock = 128; + using ArchTag = cutlass::arch::Sm100; + + static size_t get_workspace_size(Arguments const& args) { return 0; } + static cutlass::Status initialize_workspace(Arguments const&, void*, cudaStream_t) { + return cutlass::Status::kSuccess; + } + + static const int kBlockQ = 16; + + static const int kNumThreadsD = 8; + static const int kNumThreadsQ = MaxThreadsPerBlock / kNumThreadsD; + static const int kElementsPerLoad = 2; + + static const int kIterationsQ = kBlockQ / kNumThreadsQ; + + static bool can_implement(Arguments const& args) { + return get<2>(args.problem_shape) % kElementsPerLoad == 0 && get<3>(args.problem_shape) % kElementsPerLoad == 0; + } + + static dim3 get_grid_shape(Params const& params) { + dim3 grid(ceil_div(size<0>(params.problem_shape), kBlockQ), size<4,0>(params.problem_shape), size<4,1>(params.problem_shape)); + return grid; + } + + static dim3 get_block_shape() { + dim3 block(kNumThreadsD, kNumThreadsQ, 1); + return block; + } + + static Params to_underlying_arguments(Arguments const& args, void* workspace) { + return args; + } + + CUTLASS_DEVICE void operator()(const Params ¶ms, char* smem) { + auto ptr_O_bh = params.ptr_O + blockIdx.y * get<2,0,0>(params.stride_O) + blockIdx.z * get<2,1>(params.stride_O); + auto ptr_dO_bh = params.ptr_dO + blockIdx.y * get<2,0,0>(params.stride_dO) + blockIdx.z * get<2,1>(params.stride_dO); + auto ptr_sum_OdO_bh = params.ptr_sum_OdO + blockIdx.y * get<1,0,0>(params.stride_sum_OdO) + blockIdx.z * get<1,1>(params.stride_sum_OdO); + auto ptr_lse_bh = params.ptr_lse + blockIdx.y * get<1,0,0>(params.stride_lse) + blockIdx.z * get<1,1>(params.stride_lse); + auto ptr_scaled_lse_bh = params.ptr_scaled_lse + blockIdx.y * get<1,0,0>(params.stride_scaled_lse) + blockIdx.z * get<1,1>(params.stride_scaled_lse); + + auto problem_q = get<0>(params.problem_shape); + int seqlen_q = problem_q; + if constexpr (is_variable_length_v) { + int offset = problem_q.cumulative_length[blockIdx.z]; + ptr_O_bh += offset * get<0>(params.stride_O); + ptr_dO_bh += offset * get<0>(params.stride_dO); + ptr_lse_bh += offset * get<0>(params.stride_lse); + seqlen_q = problem_q.cumulative_length[blockIdx.z + 1] - offset; + } + + CUTLASS_PRAGMA_UNROLL + for (int idx_q_t = threadIdx.y; idx_q_t < kBlockQ; idx_q_t += kNumThreadsQ) { + int idx_q = idx_q_t + kBlockQ * blockIdx.x; + if (idx_q >= seqlen_q) continue; + ElementAcc acc = 0; + auto ptr_O_bhq = ptr_O_bh + idx_q * get<0>(params.stride_O); + auto ptr_dO_bhq = ptr_dO_bh + idx_q * get<0>(params.stride_dO); + auto ptr_sum_OdO_bhq = ptr_sum_OdO_bh + idx_q * 
get<0>(params.stride_sum_OdO); + auto ptr_lse_bhq = ptr_lse_bh + idx_q * get<0>(params.stride_lse); + auto ptr_scaled_lse_bhq = ptr_scaled_lse_bh + idx_q * get<0>(params.stride_scaled_lse); + + for (int idx_d = threadIdx.x * kElementsPerLoad; idx_d < get<3>(params.problem_shape); idx_d += kElementsPerLoad * kNumThreadsD) { + Element value_O[kElementsPerLoad]; + Element value_dO[kElementsPerLoad]; + + using Vec = uint_bit_t * kElementsPerLoad>; + *reinterpret_cast(value_O) = *reinterpret_cast(&ptr_O_bhq[idx_d]); + *reinterpret_cast(value_dO) = *reinterpret_cast(&ptr_dO_bhq[idx_d]); + + for (int v = 0; v < kElementsPerLoad; v++) { + acc += value_O[v] * value_dO[v]; + } + } + + for (int i = 1; i < kNumThreadsD; i *= 2) { + acc += __shfl_xor_sync((uint32_t)-1, acc, i, kNumThreadsD); + } + + if (threadIdx.x == 0) { + *ptr_sum_OdO_bhq = params.sum_odo_scale * acc; + if (params.ptr_scaled_lse) { + *ptr_scaled_lse_bhq = params.lse_scale * *ptr_lse_bhq; + } + } + } + } +}; + +} // namespace cutlass::fmha::kernel diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/kernel/fmha_options.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/kernel/fmha_options.hpp new file mode 100644 index 0000000000000000000000000000000000000000..d4faa8d2151497fe1151ba879111c43cd1e4ccc7 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/kernel/fmha_options.hpp @@ -0,0 +1,85 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + + +#pragma once + + +#include "cutlass/cutlass.h" + +namespace cutlass::fmha::kernel { + +template +struct find_option; + +template +struct find_option { + using option_value = Default; +}; + +template +struct find_option : + std::conditional_t< + Option::tag == kTag, + Option, + find_option + > +{}; + +template +using find_option_t = typename find_option::option_value; + +enum class Tag { + kIsPersistent, + kNumMmaWarpGroups, + kLoadsQSeparately, + + kIsMainloopLocked, + kIsEpilogueLocked, + + kStagesQ, + kStagesKV, + + kEpilogueKind, + + kBlocksPerSM, + kClusterM, + + kAccQK +}; + +template +struct Option { + static constexpr auto tag = kTag; + using option_value = Value; +}; + +} // namespace cutlass::fmha::kernel diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/kernel/fmha_tile_scheduler.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/kernel/fmha_tile_scheduler.hpp new file mode 100644 index 0000000000000000000000000000000000000000..119f069ccee5eb757edbf368e1cc2e66b3e5253c --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/kernel/fmha_tile_scheduler.hpp @@ -0,0 +1,162 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + + +#pragma once + + +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/kernel_hardware_info.h" + +namespace cutlass::fmha::kernel { + +//////////////////////////////////////////////////////////////////////////////// + +struct IndividualTileScheduler { + + struct Params { + dim3 grid; + }; + + bool valid_ = true; + + CUTLASS_DEVICE + IndividualTileScheduler(Params const&) {} + + template + static Params to_underlying_arguments( + ProblemSize const& problem_size, KernelHardwareInfo hw_info, + ClusterShape const& cluster_shape, TileShape const& tile_shape) { + using namespace cute; + dim3 grid(round_up(ceil_div(size<0>(problem_size), size<0>(tile_shape)), size<0>(cluster_shape)), size<3,0>(problem_size), size<3,1>(problem_size)); + return Params{ grid }; + } + + static dim3 get_grid_shape(Params const& params) { + return params.grid; + } + + CUTLASS_DEVICE + bool is_valid() { + return valid_; + } + + CUTLASS_DEVICE + auto get_block_coord() { + using namespace cute; + return make_coord(blockIdx.x, _0{}, make_coord(blockIdx.y, blockIdx.z)); + } + + CUTLASS_DEVICE + IndividualTileScheduler& operator++() { + valid_ = false; + return *this; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +struct PersistentTileScheduler { + + struct Params { + int num_blocks; + FastDivmod divmod_m_block; + FastDivmod divmod_h; + FastDivmod divmod_b; + + KernelHardwareInfo hw_info; + }; + + int block_idx = 0; + Params params; + + CUTLASS_DEVICE + PersistentTileScheduler(Params const& params) : block_idx(blockIdx.x), params(params) {} + + template + static Params to_underlying_arguments( + ProblemSize const& problem_size, KernelHardwareInfo hw_info, + ClusterShape const& cluster_shape, TileShape const& tile_shape) { + using namespace cute; + // Get SM count if needed, otherwise use user supplied SM count + int sm_count = hw_info.sm_count; + if (sm_count <= 0) { + CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n" + " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count."); + sm_count = KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + } + + CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count); + hw_info.sm_count = sm_count; + + int num_m_blocks = cutlass::round_up(ceil_div(size<0>(problem_size), size<0>(tile_shape)), size<0>(cluster_shape)); + int num_blocks = num_m_blocks * size<3,0>(problem_size) * size<3,1>(problem_size); + + return Params { + num_blocks, + { num_m_blocks}, { size<3,0>(problem_size) }, { size<3,1>(problem_size) }, + hw_info + }; + } + + static dim3 get_grid_shape(Params const& params) { + dim3 grid(std::min(params.num_blocks, params.hw_info.sm_count), 1, 1); + return grid; + } + + CUTLASS_DEVICE + bool is_valid() { + return block_idx < params.num_blocks; + } + + CUTLASS_DEVICE + auto get_block_coord() { + using namespace cute; + int block_decode = block_idx; + int m_block, bidb, bidh; + params.divmod_m_block(block_decode, m_block, block_decode); + params.divmod_b(block_decode, bidb, block_decode); + params.divmod_h(block_decode, bidh, block_decode); + return make_coord(m_block, _0{}, make_coord(bidh, bidb)); + } + + CUTLASS_DEVICE + PersistentTileScheduler& operator++() { + block_idx += gridDim.x; + return *this; + } +}; + + 
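+// Illustrative sketch (not part of the original file): both schedulers above expose the
+// same is_valid() / get_block_coord() / operator++() interface, so a kernel body drives
+// them with a loop of the following shape. `WorkTile` stands in for the per-tile work
+// that the real FMHA kernels inline at this point.
+//
+//   template <class Scheduler, class WorkTile>
+//   CUTLASS_DEVICE void scheduler_loop(typename Scheduler::Params const& params, WorkTile&& work) {
+//     Scheduler scheduler{params};
+//     for (; scheduler.is_valid(); ++scheduler) {
+//       auto blk_coord = scheduler.get_block_coord();  // (m_block, _0, (head, batch))
+//       work(blk_coord);
+//     }
+//   }
+//
+// IndividualTileScheduler covers one tile per CTA (is_valid() flips to false after the
+// first increment), while PersistentTileScheduler strides block_idx by gridDim.x so each
+// CTA processes multiple tiles.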
+//////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::fmha::kernel diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp new file mode 100644 index 0000000000000000000000000000000000000000..742b507d4deb31ba9a9ae8b7bef3cc53b076014a --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/kernel/sm100_fmha_bwd_kernel_tma_warpspecialized.hpp @@ -0,0 +1,1859 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cute/arch/simd_sm100.hpp" + +#include "cutlass/arch/arch.h" +#include "cutlass/arch/memory_sm80.h" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "collective/fmha_common.hpp" + +#include + +namespace cutlass::fmha::kernel { + +using namespace cutlass::fmha::collective; + +using namespace cute; + +template< + class ProblemShape, + class Element, + class ElementAcc, + class TileShape, + class Mask +> +struct Sm100FmhaBwdKernelTmaWarpSpecialized { + + using TileShapeQ = decltype(get<0>(TileShape{})); + static_assert(std::is_same_v, "tile shape K must be 128"); + using TileShapeK = decltype(get<1>(TileShape{})); + static_assert(std::is_same_v, "tile shape K must be 128"); + using TileShapeDQK = decltype(get<2>(TileShape{})); + using TileShapeDVO = decltype(get<2>(TileShape{})); + + using TmemAllocator = cute::TMEM::Allocator1Sm; + struct TmemAllocation { + static constexpr uint32_t kDK = 0; // TileShapeK x TileShapeDQK x acc + static constexpr uint32_t kDV = kDK + TileShapeDQK{}; // TileShapeK x TileShapeDVO x acc + static constexpr uint32_t kDQ = kDV + TileShapeDVO{}; // TileShapeQ x TileShapeDQK x acc + static constexpr uint32_t kDP = kDQ; // TileShapeK x TileShapeQ x inp + static constexpr uint32_t kS = kDQ + max(TileShapeQ{}, TileShapeDQK{}); + static constexpr uint32_t kP = kS; + static constexpr uint32_t kTotal = kS + TileShapeQ{}; + }; + + static_assert( + static_cast(TmemAllocation::kTotal) <= TmemAllocator::Sm100TmemCapacityColumns, + "using too much tmem" + ); + + enum class WarpRole { + Empty = 0x0, Load = 0x1, Mma = 0x2, Compute = 0x3, Reduce = 0x4 + }; + + static constexpr unsigned long long kWarpAssignment = 0x12'3333'3333'4444ull; + static constexpr int kNumComputeWarps = 8; + static constexpr int kNumReduceWarps = 4; + CUTLASS_DEVICE WarpRole warp_idx_to_role(int warp_idx) { + return static_cast((kWarpAssignment >> (4 * warp_idx)) & 0xF); + } + + struct RegisterAllocation { + static constexpr int kWarpgroup0 = 160-8; + static constexpr int kWarpgroup1 = 128; + static constexpr int kWarpgroup2 = 96; + static constexpr int kReduce = kWarpgroup0; + static constexpr int kCompute = kWarpgroup1; + static constexpr int kMma = kWarpgroup2; + static constexpr int kEmpty = kWarpgroup2; + static constexpr int kLoad = kWarpgroup2; + + static_assert(kWarpgroup0 + 2 * kWarpgroup1 + kWarpgroup2 <= 512); + }; + + using ArchTag = cutlass::arch::Sm100; + + using ClusterShape = Shape<_1, _1, _1>; + using Schedule = cutlass::gemm::KernelTmaWarpSpecialized1SmSm100; + + static constexpr int MinBlocksPerMultiprocessor = 1; + static constexpr int kNumWarps = kNumComputeWarps + kNumReduceWarps + 4; + static constexpr int MaxThreadsPerBlock = NumThreadsPerWarp * kNumWarps; + + static constexpr int Alignment = 128 / sizeof_bits_v; + static constexpr int kStages = 2; + + using TensorStrideContiguousK = Stride, int>>; + using TensorStrideContiguousMN = Stride<_1, int, Stride, int>>; + using TensorStrideContiguousK_GQA = Stride, int>>; + using TensorStrideContiguousMN_GQA = Stride<_1, int, Stride, int>>; + + // compute S + using CollectiveMmaKQ = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + Element, TensorStrideContiguousK_GQA, Alignment, + Element, TensorStrideContiguousK, Alignment, + ElementAcc, + Shape, + ClusterShape, 
cutlass::gemm::collective::StageCount, + Schedule>::CollectiveOp; + using TileShapeKQ = typename CollectiveMmaKQ::TileShape; + using TiledMmaKQ = typename CollectiveMmaKQ::TiledMma; + + // compute dP + using CollectiveMmaVDO = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + Element, TensorStrideContiguousK_GQA, Alignment, + Element, TensorStrideContiguousK, Alignment, + ElementAcc, + Shape, + ClusterShape, cutlass::gemm::collective::StageCount, + Schedule>::CollectiveOp; + using TileShapeVDO = typename CollectiveMmaVDO::TileShape; + using TiledMmaVDO = typename CollectiveMmaVDO::TiledMma; + + // compute dV + using CollectiveMmaPDO = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + // needs to match ordering of S calculation + Element, TensorStrideContiguousK, Alignment, + Element, TensorStrideContiguousMN, Alignment, + ElementAcc, + Shape, + ClusterShape, cutlass::gemm::collective::StageCount, + Schedule>::CollectiveOp; + using TileShapePDO = typename CollectiveMmaPDO::TileShape; + using TiledMmaPDO = decltype(to_tiled_mma_sm100_ts(typename CollectiveMmaPDO::TiledMma{})); + + // compute dK + using CollectiveMmaDSQ = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + // somewhat arbitrary since we dump to smem, need to agree with the next one + Element, TensorStrideContiguousK , Alignment, + Element, TensorStrideContiguousMN, Alignment, + ElementAcc, + Shape, + ClusterShape, cutlass::gemm::collective::StageCount, + Schedule>::CollectiveOp; + using TileShapeDSQ = typename CollectiveMmaDSQ::TileShape; + using TiledMmaDSQ = typename CollectiveMmaDSQ::TiledMma; + + // compute dQ + using CollectiveMmaDSK = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + // somewhat arbitrary since we dump to smem, need to agree with the previous one + Element, TensorStrideContiguousMN, Alignment, + Element, TensorStrideContiguousMN_GQA, Alignment, + ElementAcc, + Shape, + ClusterShape, cutlass::gemm::collective::StageCount, + Schedule>::CollectiveOp; + using TileShapeDSK = typename CollectiveMmaDSK::TileShape; + using TiledMmaDSK = typename CollectiveMmaDSK::TiledMma; + + // pipelines are named Pipeline + static constexpr int kStagesComputeSmem = 1; + using PipelineLoadMmaQ = PipelineTmaUmmaAsync<2, ClusterShape>; + using PipelineLoadMmaDO = PipelineTmaUmmaAsync<1, ClusterShape>; + using PipelineLoadComputeLSE = PipelineAsync<1>; + using PipelineLoadComputeSumOdO = PipelineAsync<1>; + using PipelineMmaComputeS = PipelineUmmaAsync<1>; + using PipelineMmaComputeDP = PipelineUmmaAsync<1>; + using PipelineMmaReduceDQ = PipelineUmmaAsync<1>; + using PipelineComputeMmaP = PipelineUmmaConsumerAsync<1>; + using PipelineComputeMmaDS = PipelineUmmaConsumerAsync; + using PipelineMmaComputeDKDV = PipelineUmmaAsync<2>; + static constexpr int kStagesReduceTmaStore = 2; + using PipelineReduceTmaStore = PipelineTmaStore; + + struct PipelineStorage { + alignas(16) typename PipelineLoadMmaQ::SharedStorage load_mma_q; + alignas(16) typename PipelineLoadMmaDO::SharedStorage load_mma_do; + alignas(16) typename PipelineLoadComputeLSE::SharedStorage load_compute_lse; + alignas(16) typename PipelineLoadComputeSumOdO::SharedStorage load_compute_sum_odo; + alignas(16) typename PipelineMmaComputeS::SharedStorage mma_compute_s; + alignas(16) typename PipelineMmaComputeDP::SharedStorage 
mma_compute_dp; + alignas(16) typename PipelineMmaReduceDQ::SharedStorage mma_reduce_dq; + alignas(16) typename PipelineComputeMmaP::SharedStorage compute_mma_p; + alignas(16) typename PipelineComputeMmaDS::SharedStorage compute_mma_ds; + alignas(16) typename PipelineMmaComputeDKDV::SharedStorage mma_compute_dkdv; + }; + + template + static CUTE_DEVICE constexpr auto restage(Layout const& layout, Stages stages = {}) { + return composition(layout, make_tuple(_, _, _, make_layout(stages))); + } + + using SmemLayoutK = decltype(restage(typename CollectiveMmaKQ::SmemLayoutA{})); + using SmemLayoutV = decltype(restage(typename CollectiveMmaVDO::SmemLayoutA{})); + using SmemLayoutQ = decltype(restage(typename CollectiveMmaKQ::SmemLayoutB{}, _2{})); + using SmemLayoutDO = decltype(restage(typename CollectiveMmaVDO::SmemLayoutB{}, _1{})); + using SmemLayoutDS = decltype(restage(typename CollectiveMmaDSK::SmemLayoutA{}, Int{})); + using SmemLayoutLSE = Layout>; + using SmemLayoutSumOdO = Layout>; + + using SmemLayoutQT = decltype(restage(typename CollectiveMmaDSQ::SmemLayoutB{}, _2{})); + using SmemLayoutKT = decltype(restage(typename CollectiveMmaDSK::SmemLayoutB{})); + using SmemLayoutDST = decltype(restage(typename CollectiveMmaDSQ::SmemLayoutA{}, Int{})); + using SmemLayoutDOT = decltype(restage(typename CollectiveMmaPDO::SmemLayoutB{}, _1{})); + + using TileShapeDQ = _32; + using SmemAtomDQ = decltype(cutlass::gemm::collective::detail::sm100_smem_selector< + cute::UMMA::Major::K, ElementAcc, TileShapeQ, TileShapeDQ + >()); + using SmemShapeDQ = Shape>; + using SmemLayoutDQ = decltype(tile_to_shape(SmemAtomDQ{}, SmemShapeDQ{}, Step<_2, _1, _3>{})); + + struct TensorStorage { + union { + alignas(2048) cute::array> smem_k; + alignas(2048) cute::array> smem_k_t; + }; + alignas(2048) cute::array> smem_v; + union { + alignas(2048) cute::array> smem_q; + alignas(2048) cute::array> smem_q_t; + }; + union { + alignas(2048) cute::array> smem_do; + alignas(2048) cute::array> smem_do_t; + }; + union { + alignas(2048) cute::array> smem_ds; + alignas(2048) cute::array> smem_ds_t; + }; + alignas(1024) cute::array> smem_dq; + alignas(16) cute::array> smem_lse; + alignas(16) cute::array> smem_sum_odo; + }; + + static constexpr int kTransactionsBytesLoadQ = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutQ{})) * cute::sizeof_bits_v); + static constexpr int kTransactionsBytesLoadDO = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutDO{})) * cute::sizeof_bits_v); + + static constexpr int kTransactionsBytesLoadK = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutK{})) * cute::sizeof_bits_v); + static constexpr int kTransactionsBytesLoadV = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutV{})) * cute::sizeof_bits_v); + + struct SharedStorage { + TensorStorage tensors; + PipelineStorage pipelines; + uint32_t tmem_base_ptr; + }; + + // this is tight enough that it won't work with sizeof due to padding for alignment + static constexpr int SharedStorageSize = offsetof(SharedStorage, tmem_base_ptr) + sizeof(uint32_t); + static_assert(SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, "using too much smem"); + + using TensorStride = TensorStrideContiguousK; // S D (H B) + using TensorStride_GQA = TensorStrideContiguousK_GQA; + using RowTensorStride = Stride<_1, Stride, int>>; // S (H B) + + struct MainloopArguments { + const Element* ptr_q; + TensorStride stride_q; + const Element* ptr_k; + TensorStride_GQA stride_k; + const Element* ptr_v; + TensorStride_GQA stride_v; + const Element* ptr_do; + 
TensorStride stride_do; + + const ElementAcc* ptr_lse; + RowTensorStride stride_lse; + + const ElementAcc* ptr_sum_odo; + RowTensorStride stride_sum_odo; + + ElementAcc* ptr_dq_acc; + TensorStride stride_dq_acc; + + ElementAcc softmax_scale = 1.0f / sqrtf(TileShapeDQK{}); + }; + + using TMA_K = typename CollectiveMmaKQ::Params::TMA_A; + using TMA_V = typename CollectiveMmaVDO::Params::TMA_A; + using TMA_Q = typename CollectiveMmaKQ::Params::TMA_B; + using TMA_DO = typename CollectiveMmaVDO::Params::TMA_B; + + using TMA_DQ = decltype(make_tma_copy(SM90_TMA_REDUCE_ADD{}, + make_tensor((const ElementAcc*)nullptr, make_shape(1, 1, make_shape(make_shape(1,1), 1)), TensorStride{}), + SmemLayoutDQ{}(_, _, _0{}) + )); + + struct MainloopParams { + TMA_K tma_load_k; + TMA_V tma_load_v; + TMA_Q tma_load_q; + TMA_DO tma_load_do; + TMA_DQ tma_red_dq; + }; + + struct EpilogueArguments { + Element* ptr_dk; + TensorStride_GQA stride_dk; + Element* ptr_dv; + TensorStride_GQA stride_dv; + }; + + struct Arguments { + ProblemShape problem_shape; + MainloopArguments mainloop; + EpilogueArguments epilogue; + KernelHardwareInfo hw_info; + }; + + struct Params { + ProblemShape problem_shape; + MainloopArguments mainloop; + MainloopParams mainloop_params; + EpilogueArguments epilogue; + KernelHardwareInfo hw_info; + }; + + + static bool can_implement(Arguments const& args) { + auto [Q, K, D, D_VO, HB] = args.problem_shape; + auto [H, B] = HB; + auto [H_R, H_K] = H; + if (Q <= 0 || K <= 0 || D <= 0 || D_VO <= 0 || H_R <= 0 || H_K <= 0 || B <= 0) { + return false; + } + if (D % Alignment != 0 || D_VO % Alignment != 0) { + return false; + } + return true; + } + + + static Status initialize_workspace(Arguments const&, void*, cudaStream_t) { + return Status::kSuccess; + } + + + static Params to_underlying_arguments(Arguments const& args, void*) { + auto [Q_, K_, D, D_VO, HB] = args.problem_shape; + int Q = Q_; + int K = K_; + + if constexpr (is_variable_length_v) { + Q = Q_.total_length; + } + if constexpr (is_variable_length_v) { + K = K_.total_length; + } + + auto params_kq = CollectiveMmaKQ::to_underlying_arguments( + make_shape(K, Q, D, HB), + typename CollectiveMmaKQ::Arguments { + args.mainloop.ptr_k, args.mainloop.stride_k, + args.mainloop.ptr_q, args.mainloop.stride_q, + }, /*workspace=*/nullptr); + + auto params_vdo = CollectiveMmaVDO::to_underlying_arguments( + make_shape(K, Q, D_VO, HB), + typename CollectiveMmaVDO::Arguments { + args.mainloop.ptr_v, args.mainloop.stride_v, + args.mainloop.ptr_do, args.mainloop.stride_do, + }, /*workspace=*/nullptr); + + TMA_DQ tma_red_dq = make_tma_copy( + SM90_TMA_REDUCE_ADD{}, + make_tensor(args.mainloop.ptr_dq_acc, make_shape(Q_, D, HB), args.mainloop.stride_dq_acc), + SmemLayoutDQ{}(_, _, _0{}) + ); + + return Params{ + args.problem_shape, + args.mainloop, + MainloopParams{ + params_kq.tma_load_a, + params_vdo.tma_load_a, + params_kq.tma_load_b, + params_vdo.tma_load_b, + tma_red_dq + }, + args.epilogue, + args.hw_info + }; + } + + + template + static CUTLASS_DEVICE auto quantize(T const& input) { + constexpr int AlignmentS = 4; + auto output = make_tensor(shape(input)); + auto input_vec = recast>(input); + auto output_vec = recast>(output); + + cutlass::NumericArrayConverter epilogue_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(input_vec); i++) { + output_vec(i) = epilogue_op(input_vec(i)); + } + + return output; + } + + + template + CUTLASS_DEVICE void load( + BlkCoord const& blk_coord, + BlkOffset const& blk_offset, + ProblemShape_ const& problem_shape, 
+ int iter_start, + int iter_end, + int iter_count, + MainloopArguments const& mainloop_args, + MainloopParams const& mainloop_params, + TensorStorage& shared_tensors, + PipelineLoadMmaQ& pipeline_load_mma_q, + typename PipelineLoadMmaQ::PipelineState& pipeline_load_mma_q_producer_state, + PipelineLoadMmaDO& pipeline_load_mma_do, + typename PipelineLoadMmaDO::PipelineState& pipeline_load_mma_do_producer_state, + PipelineLoadComputeLSE& pipeline_load_compute_lse, + typename PipelineLoadComputeLSE::PipelineState& pipeline_load_compute_lse_producer_state, + PipelineLoadComputeSumOdO& pipeline_load_compute_sum_odo, + typename PipelineLoadComputeSumOdO::PipelineState& pipeline_load_compute_sum_odo_producer_state) { + + auto [Q, K, D, D_VO, HB] = problem_shape; + int iter_index = iter_start; + + using X = Underscore; + + uint16_t mcast_mask = 0; + + auto mK_in = mainloop_params.tma_load_k.get_tma_tensor(make_shape(K, D, HB)); + auto mV_in = mainloop_params.tma_load_v.get_tma_tensor(make_shape(K, D_VO, HB)); + auto mQ_in = mainloop_params.tma_load_q.get_tma_tensor(make_shape(Q, D, HB)); + auto mDO_in = mainloop_params.tma_load_do.get_tma_tensor(make_shape(Q, D_VO, HB)); + + auto mK = domain_offset(select<1,2,4>(blk_offset), mK_in); + auto mV = domain_offset(select<1,3,4>(blk_offset), mV_in); + auto mQ = domain_offset(select<0,2,4>(blk_offset), mQ_in); + auto mDO = domain_offset(select<0,3,4>(blk_offset), mDO_in); + + auto gK = local_tile(mK, TileShapeKQ{}, make_coord(_,_,_), Step<_1, X, _1>{}); + auto gQ = local_tile(mQ, TileShapeKQ{}, make_coord(_,_,_), Step{}); + auto gV = local_tile(mV, TileShapeVDO{}, make_coord(_,_,_), Step<_1, X, _1>{}); + auto gDO = local_tile(mDO, TileShapeVDO{}, make_coord(_,_,_), Step{}); + + ThrMMA cta_mma_kq = TiledMmaKQ{}.get_slice(_0{}); + ThrMMA cta_mma_vdo = TiledMmaVDO{}.get_slice(_0{}); + + auto tSTgK = cta_mma_kq.partition_A(gK); + auto tSTgQ = cta_mma_kq.partition_B(gQ); + auto tDPTgV = cta_mma_vdo.partition_A(gV); + auto tDPTgDO = cta_mma_vdo.partition_B(gDO); + + auto sQ = make_tensor(make_smem_ptr(shared_tensors.smem_q.begin()), SmemLayoutQ{}); + auto sK = make_tensor(make_smem_ptr(shared_tensors.smem_k.begin()), SmemLayoutK{}); + auto sV = make_tensor(make_smem_ptr(shared_tensors.smem_v.begin()), SmemLayoutV{}); + auto sDO = make_tensor(make_smem_ptr(shared_tensors.smem_do.begin()), SmemLayoutDO{}); + + auto [tKgK_mkl, tKsK] = tma_partition( + mainloop_params.tma_load_k, _0{}, make_layout(_1{}), + group_modes<0,3>(sK), group_modes<0,3>(tSTgK)); + auto [tQgQ_mkl, tQsQ] = tma_partition( + mainloop_params.tma_load_q, _0{}, make_layout(_1{}), + group_modes<0,3>(sQ), group_modes<0,3>(tSTgQ)); + auto [tVgV_mkl, tVsV] = tma_partition( + mainloop_params.tma_load_v, _0{}, make_layout(_1{}), + group_modes<0,3>(sV), group_modes<0,3>(tDPTgV)); + auto [tDOgDO_mkl, tDOsDO] = tma_partition( + mainloop_params.tma_load_do, _0{}, make_layout(_1{}), + group_modes<0,3>(sDO), group_modes<0,3>(tDPTgDO)); + + // set up lse and sum_odo + + auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_dv, blk_coord_batch] = blk_coord; + + pipeline_load_mma_q.producer_acquire(pipeline_load_mma_q_producer_state); + auto tma_barrier = pipeline_load_mma_q.producer_get_barrier(pipeline_load_mma_q_producer_state); + + pipeline_load_mma_q.producer_expect_transaction(pipeline_load_mma_q_producer_state, kTransactionsBytesLoadK); + + // load K + if (cute::elect_one_sync()) { + cute::copy( + mainloop_params.tma_load_k.with(*tma_barrier, mcast_mask), + tKgK_mkl(_, blk_coord_k, _0{}, 
blk_coord_batch), + tKsK(_, _0{}) + ); + } + + // load Q + if (cute::elect_one_sync()) { + cute::copy( + mainloop_params.tma_load_q.with(*tma_barrier, mcast_mask), + tQgQ_mkl(_, iter_index, _0{}, blk_coord_batch), + tQsQ(_, pipeline_load_mma_q_producer_state.index()) + ); + } + + ++pipeline_load_mma_q_producer_state; + + pipeline_load_compute_lse.producer_acquire(pipeline_load_compute_lse_producer_state); + + // load LSE + // 32 threads loading 128 values of 32b each + // so 4*32b=128b + + int thread_idx = threadIdx.x % NumThreadsPerWarp; + int smem_idx = TileShapeQ{} * pipeline_load_compute_lse_producer_state.index() + thread_idx * 4; + int gmem_idx = TileShapeQ{} * iter_index + thread_idx * 4; + auto mLSE = make_tensor(mainloop_args.ptr_lse, make_shape(Q, HB), mainloop_args.stride_lse); + for (int i = 0; i < 4; i++) { + cutlass::arch::cp_async_zfill<4>( + shared_tensors.smem_lse.begin() + smem_idx + i, + &mLSE(gmem_idx + i, blk_coord_batch), + gmem_idx + i < Q + ); + } + + pipeline_load_compute_lse.producer_commit(pipeline_load_compute_lse_producer_state, cutlass::arch::cpasync_barrier_arrive); + ++pipeline_load_compute_lse_producer_state; + + + pipeline_load_mma_do.producer_acquire(pipeline_load_mma_do_producer_state); + tma_barrier = pipeline_load_mma_do.producer_get_barrier(pipeline_load_mma_do_producer_state); + + pipeline_load_mma_do.producer_expect_transaction(pipeline_load_mma_do_producer_state, kTransactionsBytesLoadV); + + // load V + if (cute::elect_one_sync()) { + cute::copy( + mainloop_params.tma_load_v.with(*tma_barrier, mcast_mask), + tVgV_mkl(_, blk_coord_k, _0{}, blk_coord_batch), + tVsV(_, _0{}) + ); + } + + // load dO + if (cute::elect_one_sync()) { + cute::copy( + mainloop_params.tma_load_do.with(*tma_barrier, mcast_mask), + tDOgDO_mkl(_, iter_index, _0{}, blk_coord_batch), + tDOsDO(_, pipeline_load_mma_do_producer_state.index()) + ); + } + + ++pipeline_load_mma_do_producer_state; + + pipeline_load_compute_sum_odo.producer_acquire(pipeline_load_compute_sum_odo_producer_state); + + // load sum_OdO + smem_idx = TileShapeQ{} * pipeline_load_compute_sum_odo_producer_state.index() + thread_idx * 4; + gmem_idx = TileShapeQ{} * iter_index + thread_idx * 4; + auto mSumOdO = make_tensor(mainloop_args.ptr_sum_odo, make_shape(Q, HB), mainloop_args.stride_sum_odo); + for (int i = 0; i < 4; i++) { + cutlass::arch::cp_async_zfill<4>( + shared_tensors.smem_sum_odo.begin() + smem_idx + i, + &mSumOdO(gmem_idx + i, blk_coord_batch), + gmem_idx + i < Q + ); + } + + pipeline_load_compute_sum_odo.producer_commit(pipeline_load_compute_sum_odo_producer_state, cutlass::arch::cpasync_barrier_arrive); + ++pipeline_load_compute_sum_odo_producer_state; + + iter_count -= 1; + iter_index += 1; + + while (iter_count > 0) { + if (iter_index == iter_end) { + iter_index = iter_start; + get<0,0>(blk_coord_batch) += 1; + } + + pipeline_load_mma_q.producer_acquire(pipeline_load_mma_q_producer_state); + tma_barrier = pipeline_load_mma_q.producer_get_barrier(pipeline_load_mma_q_producer_state); + + // load Q + if (cute::elect_one_sync()) { + cute::copy( + mainloop_params.tma_load_q.with(*tma_barrier, mcast_mask), + tQgQ_mkl(_, iter_index, _0{}, blk_coord_batch), + tQsQ(_, pipeline_load_mma_q_producer_state.index()) + ); + } + + ++pipeline_load_mma_q_producer_state; + + pipeline_load_compute_lse.producer_acquire(pipeline_load_compute_lse_producer_state); + + // load LSE + smem_idx = TileShapeQ{} * pipeline_load_compute_lse_producer_state.index() + thread_idx * 4; + gmem_idx = TileShapeQ{} * iter_index + 
thread_idx * 4; + for (int i = 0; i < 4; i++) { + cutlass::arch::cp_async_zfill<4>( + shared_tensors.smem_lse.begin() + smem_idx + i, + &mLSE(gmem_idx + i, blk_coord_batch), + gmem_idx + i < Q + ); + } + + pipeline_load_compute_lse.producer_commit(pipeline_load_compute_lse_producer_state, cutlass::arch::cpasync_barrier_arrive); + ++pipeline_load_compute_lse_producer_state; + + pipeline_load_mma_do.producer_acquire(pipeline_load_mma_do_producer_state); + tma_barrier = pipeline_load_mma_do.producer_get_barrier(pipeline_load_mma_do_producer_state); + + // load dO + if (cute::elect_one_sync()) { + cute::copy( + mainloop_params.tma_load_do.with(*tma_barrier, mcast_mask), + tDOgDO_mkl(_, iter_index, _0{}, blk_coord_batch), + tDOsDO(_, pipeline_load_mma_do_producer_state.index()) + ); + } + + ++pipeline_load_mma_do_producer_state; + + pipeline_load_compute_sum_odo.producer_acquire(pipeline_load_compute_sum_odo_producer_state); + + // load sum_OdO + smem_idx = TileShapeQ{} * pipeline_load_compute_sum_odo_producer_state.index() + thread_idx * 4; + gmem_idx = TileShapeQ{} * iter_index + thread_idx * 4; + for (int i = 0; i < 4; i++) { + cutlass::arch::cp_async_zfill<4>( + shared_tensors.smem_sum_odo.begin() + smem_idx + i, + &mSumOdO(gmem_idx + i, blk_coord_batch), + gmem_idx + i < Q + ); + } + + pipeline_load_compute_sum_odo.producer_commit(pipeline_load_compute_sum_odo_producer_state, cutlass::arch::cpasync_barrier_arrive); + ++pipeline_load_compute_sum_odo_producer_state; + + iter_count -= 1; + iter_index += 1; + } + } + + + template + CUTLASS_DEVICE void mma( + BlkCoord const& blk_coord, + ProblemShape_ const& problem_shape, + int iter_start, + int iter_end, + int iter_count, + MainloopArguments const& mainloop_args, + TensorStorage& shared_tensors, + PipelineLoadMmaQ& pipeline_load_mma_q, + typename PipelineLoadMmaQ::PipelineState& pipeline_load_mma_q_consumer_state, + PipelineLoadMmaDO& pipeline_load_mma_do, + typename PipelineLoadMmaDO::PipelineState& pipeline_load_mma_do_consumer_state, + PipelineMmaComputeS& pipeline_mma_compute_s, + typename PipelineMmaComputeS::PipelineState& pipeline_mma_compute_s_producer_state, + PipelineMmaComputeDP& pipeline_mma_compute_dp, + typename PipelineMmaComputeDP::PipelineState& pipeline_mma_compute_dp_producer_state, + PipelineMmaReduceDQ& pipeline_mma_reduce_dq, + typename PipelineMmaReduceDQ::PipelineState& pipeline_mma_reduce_dq_producer_state, + PipelineComputeMmaP& pipeline_compute_mma_p, + typename PipelineComputeMmaP::PipelineState& pipeline_compute_mma_p_consumer_state, + PipelineComputeMmaDS& pipeline_compute_mma_ds, + typename PipelineComputeMmaDS::PipelineState& pipeline_compute_mma_ds_consumer_state, + PipelineMmaComputeDKDV& pipeline_mma_compute_dkdv, + typename PipelineMmaComputeDKDV::PipelineState& pipeline_mma_compute_dkdv_producer_state) { + + auto [Q, K, D, D_VO, HB] = problem_shape; + + auto sQ = make_tensor(make_smem_ptr(shared_tensors.smem_q.begin()), SmemLayoutQ{}); + auto sK = make_tensor(make_smem_ptr(shared_tensors.smem_k.begin()), SmemLayoutK{}); + auto sV = make_tensor(make_smem_ptr(shared_tensors.smem_v.begin()), SmemLayoutV{}); + auto sDO = make_tensor(make_smem_ptr(shared_tensors.smem_do.begin()), SmemLayoutDO{}); + + auto sQT = make_tensor(make_smem_ptr(shared_tensors.smem_q_t.begin()), SmemLayoutQT{}); + auto sKT = make_tensor(make_smem_ptr(shared_tensors.smem_k_t.begin()), SmemLayoutKT{}); + auto sDS = make_tensor(make_smem_ptr(shared_tensors.smem_ds.begin()), SmemLayoutDS{}); + auto sDST = 
make_tensor(make_smem_ptr(shared_tensors.smem_ds_t.begin()), SmemLayoutDST{}); + auto sP = make_tensor(make_smem_ptr((Element*) nullptr), typename CollectiveMmaPDO::SmemLayoutA{}); + auto sDOT = make_tensor(make_smem_ptr(shared_tensors.smem_do_t.begin()), SmemLayoutDOT{}); + + Tensor tSTrK = TiledMmaKQ::make_fragment_A(sK); + Tensor tSTrQ = TiledMmaKQ::make_fragment_B(sQ); + + Tensor tDPTrV = TiledMmaVDO::make_fragment_A(sV); + Tensor tDPTrDO = TiledMmaVDO::make_fragment_B(sDO); + + Tensor tDQrDS = TiledMmaDSK::make_fragment_A(sDS); + Tensor tDQrKT = TiledMmaDSK::make_fragment_B(sKT); + + Tensor tDKrDST = TiledMmaDSQ::make_fragment_A(sDST); + Tensor tDKrQT = TiledMmaDSQ::make_fragment_B(sQT); + + Tensor tDVrP = TiledMmaPDO::make_fragment_A(sP)(_, _, _, _0{}); + tDVrP.data() = TmemAllocation::kP; + Tensor tDVrDOT = TiledMmaPDO::make_fragment_B(sDOT); + + TiledMmaKQ tiled_mma_kq; + TiledMmaVDO tiled_mma_vdo; + TiledMmaDSK tiled_mma_dsk; + TiledMmaDSQ tiled_mma_dsq; + TiledMmaPDO tiled_mma_pdo; + + tiled_mma_dsq.accumulate_ = UMMA::ScaleOut::Zero; + tiled_mma_pdo.accumulate_ = UMMA::ScaleOut::Zero; + + Tensor tSTtST = partition_fragment_C(tiled_mma_kq, select<0,1>(TileShapeKQ{})); + tSTtST.data() = TmemAllocation::kS; + + Tensor tDPTtDPT = partition_fragment_C(tiled_mma_vdo, select<0,1>(TileShapeVDO{})); + tDPTtDPT.data() = TmemAllocation::kDP; + + Tensor tDQtDQ = partition_fragment_C(tiled_mma_dsk, select<0,1>(TileShapeDSK{})); + tDQtDQ.data() = TmemAllocation::kDQ; + + Tensor tDKtDK = partition_fragment_C(tiled_mma_dsq, select<0,1>(TileShapeDSQ{})); + tDKtDK.data() = TmemAllocation::kDK; + + Tensor tDVtDV = partition_fragment_C(tiled_mma_pdo, select<0,1>(TileShapePDO{})); + tDVtDV.data() = TmemAllocation::kDV; + + auto pipeline_load_mma_q_release_state = pipeline_load_mma_q_consumer_state; + + pipeline_load_mma_q.consumer_wait(pipeline_load_mma_q_consumer_state); + pipeline_mma_compute_s.producer_acquire(pipeline_mma_compute_s_producer_state); + + // S = Q*K + tiled_mma_kq.accumulate_ = UMMA::ScaleOut::Zero; + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tSTrQ); ++k_block) { + cute::gemm(tiled_mma_kq, + tSTrK(_,_,k_block,_0{}), + tSTrQ(_,_,k_block,pipeline_load_mma_q_consumer_state.index()), + tSTtST); + tiled_mma_kq.accumulate_ = UMMA::ScaleOut::One; + } + + ++pipeline_load_mma_q_consumer_state; + + pipeline_mma_compute_s.producer_commit(pipeline_mma_compute_s_producer_state); + ++pipeline_mma_compute_s_producer_state; + + pipeline_load_mma_do.consumer_wait(pipeline_load_mma_do_consumer_state); + + pipeline_mma_compute_dp.producer_acquire(pipeline_mma_compute_dp_producer_state); + pipeline_mma_reduce_dq.producer_acquire(pipeline_mma_reduce_dq_producer_state); + + // dP = dO*V + tiled_mma_vdo.accumulate_ = UMMA::ScaleOut::Zero; + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tDPTrV); ++k_block) { + cute::gemm(tiled_mma_vdo, + tDPTrV(_,_,k_block,_0{}), + tDPTrDO(_,_,k_block,pipeline_load_mma_do_consumer_state.index()), + tDPTtDPT); + tiled_mma_vdo.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_mma_compute_dp.producer_commit(pipeline_mma_compute_dp_producer_state); + ++pipeline_mma_compute_dp_producer_state; + + pipeline_compute_mma_p.consumer_wait(pipeline_compute_mma_p_consumer_state); + + // dV = P*dO + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tDVrP); ++k_block) { + cute::gemm(tiled_mma_pdo, + tDVrP(_,_,k_block), + tDVrDOT(_,_,k_block,pipeline_load_mma_do_consumer_state.index()), + tDVtDV); + tiled_mma_pdo.accumulate_ = 
UMMA::ScaleOut::One; + } + + pipeline_compute_mma_p.consumer_release(pipeline_compute_mma_p_consumer_state); + ++pipeline_compute_mma_p_consumer_state; + + pipeline_load_mma_do.consumer_release(pipeline_load_mma_do_consumer_state); + ++pipeline_load_mma_do_consumer_state; + + iter_count -= 1; + + // in tmem, S & P overlap + // and dP and dQ overlap + // so we need to acquire dQ and dP at the same time + while (iter_count > 0) { + pipeline_load_mma_q.consumer_wait(pipeline_load_mma_q_consumer_state); + pipeline_mma_compute_s.producer_acquire(pipeline_mma_compute_s_producer_state); + + // S = Q*K + tiled_mma_kq.accumulate_ = UMMA::ScaleOut::Zero; + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tSTrQ); ++k_block) { + cute::gemm(tiled_mma_kq, + tSTrK(_,_,k_block,_0{}), + tSTrQ(_,_,k_block,pipeline_load_mma_q_consumer_state.index()), + tSTtST); + tiled_mma_kq.accumulate_ = UMMA::ScaleOut::One; + } + + ++pipeline_load_mma_q_consumer_state; + + pipeline_mma_compute_s.producer_commit(pipeline_mma_compute_s_producer_state); + ++pipeline_mma_compute_s_producer_state; + + pipeline_compute_mma_ds.consumer_wait(pipeline_compute_mma_ds_consumer_state); + + // we need to acquire dP here, because tmem dQ == tmem dP + pipeline_mma_compute_dp.producer_acquire(pipeline_mma_compute_dp_producer_state); + + // dQ = dS*K + tiled_mma_dsk.accumulate_ = UMMA::ScaleOut::Zero; + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tDQrDS); ++k_block) { + cute::gemm(tiled_mma_dsk, + tDQrDS(_,_,k_block,pipeline_compute_mma_ds_consumer_state.index()), + tDQrKT(_,_,k_block,_0{}), + tDQtDQ); + tiled_mma_dsk.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_mma_reduce_dq.producer_commit(pipeline_mma_reduce_dq_producer_state); + ++pipeline_mma_reduce_dq_producer_state; + + // dK = dS*Q + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tDKrDST); ++k_block) { + cute::gemm(tiled_mma_dsq, + tDKrDST(_,_,k_block,pipeline_compute_mma_ds_consumer_state.index()), + tDKrQT(_,_,k_block,pipeline_load_mma_q_release_state.index()), + tDKtDK); + tiled_mma_dsq.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_load_mma_q.consumer_release(pipeline_load_mma_q_release_state); + ++pipeline_load_mma_q_release_state; + + pipeline_compute_mma_ds.consumer_release(pipeline_compute_mma_ds_consumer_state); + ++pipeline_compute_mma_ds_consumer_state; + + // we grab dq here, because in tmem dq == dp + pipeline_mma_reduce_dq.producer_acquire(pipeline_mma_reduce_dq_producer_state); + + pipeline_load_mma_do.consumer_wait(pipeline_load_mma_do_consumer_state); + + // dP = dO*V + tiled_mma_vdo.accumulate_ = UMMA::ScaleOut::Zero; + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tDPTrV); ++k_block) { + cute::gemm(tiled_mma_vdo, + tDPTrV(_,_,k_block,_0{}), + tDPTrDO(_,_,k_block,pipeline_load_mma_do_consumer_state.index()), + tDPTtDPT); + tiled_mma_vdo.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_mma_compute_dp.producer_commit(pipeline_mma_compute_dp_producer_state); + ++pipeline_mma_compute_dp_producer_state; + + pipeline_compute_mma_p.consumer_wait(pipeline_compute_mma_p_consumer_state); + + // dV = P*dO + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tDVrP); ++k_block) { + cute::gemm(tiled_mma_pdo, + tDVrP(_,_,k_block), + tDVrDOT(_,_,k_block,pipeline_load_mma_do_consumer_state.index()), + tDVtDV); + tiled_mma_pdo.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_compute_mma_p.consumer_release(pipeline_compute_mma_p_consumer_state); + 
++pipeline_compute_mma_p_consumer_state; + + pipeline_load_mma_do.consumer_release(pipeline_load_mma_do_consumer_state); + ++pipeline_load_mma_do_consumer_state; + + iter_count -= 1; + } + + // signal to the epilogue that dV is ready + pipeline_mma_compute_dkdv.producer_acquire(pipeline_mma_compute_dkdv_producer_state); + pipeline_mma_compute_dkdv.producer_commit(pipeline_mma_compute_dkdv_producer_state); + ++pipeline_mma_compute_dkdv_producer_state; + + pipeline_mma_compute_dkdv.producer_acquire(pipeline_mma_compute_dkdv_producer_state); + + pipeline_compute_mma_ds.consumer_wait(pipeline_compute_mma_ds_consumer_state); + + // dK = dS*Q + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tDKrDST); ++k_block) { + cute::gemm(tiled_mma_dsq, + tDKrDST(_,_,k_block,pipeline_compute_mma_ds_consumer_state.index()), + tDKrQT(_,_,k_block,pipeline_load_mma_q_release_state.index()), + tDKtDK); + tiled_mma_dsq.accumulate_ = UMMA::ScaleOut::One; + } + + // signal to the epilogue that dK is ready + pipeline_mma_compute_dkdv.producer_commit(pipeline_mma_compute_dkdv_producer_state); + ++pipeline_mma_compute_dkdv_producer_state; + + // we've already acquired mma_reduce_dq in the loop + + // dQ = dS*K + tiled_mma_dsk.accumulate_ = UMMA::ScaleOut::Zero; + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tDQrDS); ++k_block) { + cute::gemm(tiled_mma_dsk, + tDQrDS(_,_,k_block,pipeline_compute_mma_ds_consumer_state.index()), + tDQrKT(_,_,k_block,_0{}), + tDQtDQ); + tiled_mma_dsk.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_mma_reduce_dq.producer_commit(pipeline_mma_reduce_dq_producer_state); + ++pipeline_mma_reduce_dq_producer_state; + + pipeline_load_mma_q.consumer_release(pipeline_load_mma_q_release_state); + ++pipeline_load_mma_q_release_state; + + pipeline_compute_mma_ds.consumer_release(pipeline_compute_mma_ds_consumer_state); + ++pipeline_compute_mma_ds_consumer_state; + } + + + + template + CUTLASS_DEVICE void store( + TensorG gmem, + TensorR const& regs, + TensorC const& coord, + TensorShape const& tensor_shape) { + + Tensor preds = cute::lazy::transform(coord, [&](auto const& c) { return elem_less(c, tensor_shape); }); + + auto copy_op = make_cotiled_copy( + Copy_Atom, Element>{}, + make_layout(make_shape(_1{}, Int{})), + regs.layout() + ); + auto thr_copy = copy_op.get_slice(_0{}); + + Tensor quantized_regs = quantize(regs); + Tensor tCr = thr_copy.partition_S(quantized_regs); + Tensor tCg = thr_copy.partition_D(gmem); + Tensor tPc = thr_copy.partition_D(preds); + + copy_if(copy_op, tPc, tCr, tCg); + } + + + template + CUTLASS_DEVICE void epilogue_clear( + BlkCoord const& blk_coord, + BlkOffset const& blk_offset, + ProblemShape_ const& problem_shape, + MainloopArguments const& mainloop_args, + EpilogueArguments const& epilogue_args) { + + auto [Q, K, D, D_VO, HB] = problem_shape; + auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_dv, blk_coord_batch] = blk_coord; + + auto mDK_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dk), make_shape(K, TileShapeDQK{}, HB), epilogue_args.stride_dk); + auto mDK = domain_offset(select<1,2,4>(blk_offset), mDK_in); + auto gDK = local_tile(mDK, TileShapeDSQ{}, make_coord(_,_,_), Step<_1, _1, X>{}) + (_, _, blk_coord_k, _0{}, blk_coord_batch); + + Tensor cDK = domain_offset( + make_coord(get<1>(blk_coord) * TileShapeK{}, _0{}), + make_identity_tensor(take<0,2>(TileShapeDSQ{})) + ); + + auto mDV_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dv), make_shape(K, TileShapeDVO{}, HB), epilogue_args.stride_dv); + auto mDV = 
domain_offset(select<1,3,4>(blk_offset), mDV_in); + auto gDV = local_tile(mDV, TileShapePDO{}, make_coord(_,_,_), Step<_1, _1, X>{}) + (_, _, blk_coord_k, _0{}, blk_coord_batch); + + Tensor cDV = domain_offset( + make_coord(blk_coord_k * TileShapeK{}, _0{}), + make_identity_tensor(take<0,2>(TileShapePDO{})) + ); + + for (int i = threadIdx.x; i < size(gDK); i += blockDim.x) { + if (elem_less(cDK(i), select<1,2>(problem_shape))) { + gDK(i) = Element(0); + } + } + for (int i = threadIdx.x; i < size(gDV); i += blockDim.x) { + if (elem_less(cDV(i), select<1,3>(problem_shape))) { + gDV(i) = Element(0); + } + } + } + + + template + CUTLASS_DEVICE void epilogue( + BlkCoord const& blk_coord, + BlkOffset const& blk_offset, + ProblemShape_ const& problem_shape, + MainloopArguments const& mainloop_args, + EpilogueArguments const& epilogue_args, + PipelineMmaComputeDKDV& pipeline_mma_compute_dkdv, + typename PipelineMmaComputeDKDV::PipelineState& pipeline_mma_compute_dkdv_consumer_state) { + + auto [Q, K, D, D_VO, HB] = problem_shape; + auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_dv, blk_coord_batch] = blk_coord; + + auto load_op = SM100_TMEM_LOAD_32dp32b16x{}; + + auto tDKtDK = partition_fragment_C(TiledMmaDSQ{}, select<0,1>(TileShapeDSQ{}))(make_coord(_,_),_0{},_0{}); + tDKtDK.data() = TmemAllocation::kDK; + + auto mDK_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dk), make_shape(K, TileShapeDQK{}, HB), epilogue_args.stride_dk); + auto mDK = domain_offset(select<1,2,4>(blk_offset), mDK_in); + auto gDK = local_tile(mDK, TileShapeDSQ{}, make_coord(_,_,_), Step<_1, _1, X>{}) + (_, _, blk_coord_k, _0{}, blk_coord_batch); + + Tensor cDK = domain_offset( + make_coord(get<1>(blk_coord) * TileShapeK{}, _0{}), + make_identity_tensor(take<0,2>(TileShapeDSQ{})) + ); + + constexpr int kNumWarpgroups = kNumComputeWarps / 4; + int dp_idx = threadIdx.x % 128; + int wg_idx = (threadIdx.x % (kNumComputeWarps * NumThreadsPerWarp)) / 128; + + auto split_wg = [&](auto const& t) { + if constexpr (decltype(rank(t))::value == 3) { + auto p = t.compose(make_layout(make_shape(size<0>(t), size<1>(t), make_shape(Int{}, size<2>(t) / Int{})))); + return p(_, _, make_coord(wg_idx, _)); + } + else { + auto p = t.compose(make_layout(make_shape(size<0>(t), size<1>(t), size<2>(t), make_shape(Int{}, size<3>(t) / Int{})))); + return p(_, _, _, make_coord(wg_idx, _)); + } + }; + + auto tiled_t2r_dk = make_tmem_copy(load_op, tDKtDK); + auto thread_t2r_dk = tiled_t2r_dk.get_slice(dp_idx); + + Tensor tTR_cDK = split_wg(thread_t2r_dk.partition_D(cDK)); + Tensor tTR_gDK = split_wg(thread_t2r_dk.partition_D(gDK)); + Tensor tTR_rDK = make_tensor(shape(tTR_cDK)); + Tensor tTR_tDK = split_wg(thread_t2r_dk.partition_S(tDKtDK)); + + auto tDVtDV = partition_fragment_C(TiledMmaDSQ{}, select<0,1>(TileShapeDSQ{}))(make_coord(_,_),_0{},_0{}); + tDVtDV.data() = TmemAllocation::kDV; + + auto mDV_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dv), make_shape(K, TileShapeDVO{}, HB), epilogue_args.stride_dv); + auto mDV = domain_offset(select<1,3,4>(blk_offset), mDV_in); + auto gDV = local_tile(mDV, TileShapePDO{}, make_coord(_,_,_), Step<_1, _1, X>{}) + (_, _, blk_coord_k, _0{}, blk_coord_batch); + + Tensor cDV = domain_offset( + make_coord(blk_coord_k * TileShapeK{}, _0{}), + make_identity_tensor(take<0,2>(TileShapePDO{})) + ); + + auto tiled_t2r_dv = make_tmem_copy(load_op, tDVtDV); + auto thread_t2r_dv = tiled_t2r_dv.get_slice(dp_idx); + + Tensor tTR_cDV = split_wg(thread_t2r_dv.partition_D(cDV)); + Tensor tTR_gDV = 
split_wg(thread_t2r_dv.partition_D(gDV)); + Tensor tTR_rDV = make_tensor(shape(tTR_cDV)); + Tensor tTR_tDV = split_wg(thread_t2r_dv.partition_S(tDVtDV)); + + pipeline_mma_compute_dkdv.consumer_wait(pipeline_mma_compute_dkdv_consumer_state); + + // load tDVtDV + cute::copy(tiled_t2r_dv, tTR_tDV, tTR_rDV); + + // store tDVgDV + store(tTR_gDV, tTR_rDV, tTR_cDV, select<1,3>(problem_shape)); + + cutlass::arch::fence_view_async_tmem_load(); + pipeline_mma_compute_dkdv.consumer_release(pipeline_mma_compute_dkdv_consumer_state); + ++pipeline_mma_compute_dkdv_consumer_state; + + pipeline_mma_compute_dkdv.consumer_wait(pipeline_mma_compute_dkdv_consumer_state); + + // load tDKtDK + cute::copy(tiled_t2r_dk, tTR_tDK, tTR_rDK); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rDK); i++) { + tTR_rDK(i) = mainloop_args.softmax_scale * tTR_rDK(i); + } + + // store tDKgDK + store(tTR_gDK, tTR_rDK, tTR_cDK, select<1,2>(problem_shape)); + + cutlass::arch::fence_view_async_tmem_load(); + pipeline_mma_compute_dkdv.consumer_release(pipeline_mma_compute_dkdv_consumer_state); + ++pipeline_mma_compute_dkdv_consumer_state; + + } + + + template + CUTLASS_DEVICE void compute( + BlkCoord const& blk_coord, + BlkOffset const& blk_offset, + ProblemShape_ const& problem_shape, + int iter_start, + int iter_end, + int iter_count, + MainloopArguments const& mainloop_args, + EpilogueArguments const& epilogue_args, + TensorStorage& shared_tensors, + PipelineLoadComputeLSE& pipeline_load_compute_lse, + typename PipelineLoadComputeLSE::PipelineState& pipeline_load_compute_lse_consumer_state, + PipelineLoadComputeSumOdO& pipeline_load_compute_sum_odo, + typename PipelineLoadComputeSumOdO::PipelineState& pipeline_load_compute_sum_odo_consumer_state, + PipelineMmaComputeS& pipeline_mma_compute_s, + typename PipelineMmaComputeS::PipelineState& pipeline_mma_compute_s_consumer_state, + PipelineMmaComputeDP& pipeline_mma_compute_dp, + typename PipelineMmaComputeDP::PipelineState& pipeline_mma_compute_dp_consumer_state, + PipelineComputeMmaP& pipeline_compute_mma_p, + typename PipelineComputeMmaP::PipelineState& pipeline_compute_mma_p_producer_state, + PipelineComputeMmaDS& pipeline_compute_mma_ds, + typename PipelineComputeMmaDS::PipelineState& pipeline_compute_mma_ds_producer_state, + PipelineMmaComputeDKDV& pipeline_mma_compute_dkdv, + typename PipelineMmaComputeDKDV::PipelineState& pipeline_mma_compute_dkdv_consumer_state) { + + + auto [Q, K, D, D_VO, HB] = problem_shape; + int iter_index = iter_start; + + // in tmem, S & P overlap + // and dP and dQ overlap + + // there are two compute wg's that cooperatively compute softmax + // they are striped by this tmem atom, i.e. 
wg0 has 16 elems, then wg1 etc + + auto load_op = SM100_TMEM_LOAD_32dp32b16x{}; + auto store_op = []() { + if constexpr (sizeof(Element) == 1) { + return SM100_TMEM_STORE_32dp32b4x{}; + } + else { + return SM100_TMEM_STORE_32dp32b8x{}; + } + }(); + + Tensor tSTtST = partition_fragment_C(TiledMmaKQ{}, select<0,1>(TileShapeKQ{}))(make_coord(_,_),_0{},_0{}); + tSTtST.data() = TmemAllocation::kS; + + Tensor tDPTtDPT = partition_fragment_C(TiledMmaVDO{}, select<0,1>(TileShapeVDO{}))(make_coord(_,_),_0{},_0{}); + tDPTtDPT.data() = TmemAllocation::kDP; + + Tensor cST = make_identity_tensor(take<0,2>(TileShapeKQ{})); + Tensor cDPT = make_identity_tensor(take<0,2>(TileShapeVDO{})); + + constexpr int kNumWarpgroups = kNumComputeWarps / 4; + int dp_idx = threadIdx.x % 128; + int wg_idx = (threadIdx.x % (kNumComputeWarps * NumThreadsPerWarp)) / 128; + auto tiled_t2r = make_tmem_copy(load_op, tSTtST); + auto thread_t2r = tiled_t2r.get_slice(dp_idx); + + auto split_wg = [&](auto const& t) { + if constexpr (decltype(size<1>(t))::value > 1) { + if constexpr (decltype(rank(t))::value == 3) { + auto p = t.compose(make_layout(make_shape(size<0>(t), make_shape(Int{}, size<1>(t) / Int{}), size<2>(t)))); + return p(_, make_coord(wg_idx, _), _); + } + else { + auto p = t.compose(make_layout(make_shape(size<0>(t), make_shape(Int{}, size<1>(t) / Int{}), size<2>(t), size<3>(t)))); + return p(_, make_coord(wg_idx, _), _, _); + } + } + else { + if constexpr (decltype(rank(t))::value == 3) { + auto p = t.compose(make_layout(make_shape(size<0>(t), size<1>(t), make_shape(Int{}, size<2>(t) / Int{})))); + return p(_, _, make_coord(wg_idx, _)); + } + else { + auto p = t.compose(make_layout(make_shape(size<0>(t), size<1>(t), size<2>(t), make_shape(Int{}, size<3>(t) / Int{})))); + return p(_, _, _, make_coord(wg_idx, _)); + } + + } + }; + + + Tensor tTR_cST_p = thread_t2r.partition_D(cST); + Tensor tTR_cST = split_wg(tTR_cST_p); + Tensor tTR_rST = make_tensor(shape(tTR_cST)); + // Tensor tTR_tST_p = thread_t2r.partition_S(tSTtST); + Tensor tTR_tST = split_wg(thread_t2r.partition_S(tSTtST)); + + Tensor tTR_cDPT_p = thread_t2r.partition_D(cDPT); + Tensor tTR_cDPT = split_wg(tTR_cDPT_p); + Tensor tTR_rDPT = make_tensor(shape(tTR_cDPT)); + Tensor tTR_tDPT = split_wg(thread_t2r.partition_S(tDPTtDPT)); + + Tensor sLSE = make_tensor(make_smem_ptr(shared_tensors.smem_lse.begin()), SmemLayoutLSE{}); + Tensor sSumOdO = make_tensor(make_smem_ptr(shared_tensors.smem_sum_odo.begin()), SmemLayoutSumOdO{}); + + auto sP = make_tensor(make_smem_ptr((Element*) nullptr), typename CollectiveMmaPDO::SmemLayoutA{}); + + auto tDVrP = TiledMmaPDO::make_fragment_A(sP)(_, _, _, _0{}); + auto tDVcST = TiledMmaPDO{}.get_slice(_0{}).partition_A(cST); + tDVrP.data() = TmemAllocation::kP; + + auto tiled_r2t = make_tmem_copy(store_op, tDVrP); + auto thread_r2t = tiled_r2t.get_slice(dp_idx); + + auto tRT_tP = split_wg(thread_r2t.partition_D(tDVrP)); + auto tRT_cST_p = thread_r2t.partition_S(tDVcST); + auto tRT_cST = split_wg(tRT_cST_p); + + bool is_residual_k = get<1>(blk_coord) * TileShapeK{} + TileShapeK{} > get<1>(problem_shape); + + CUTLASS_PRAGMA_NO_UNROLL + while (iter_count > 0) { + // wait for S and P + pipeline_mma_compute_s.consumer_wait(pipeline_mma_compute_s_consumer_state); + pipeline_compute_mma_p.producer_acquire(pipeline_compute_mma_p_producer_state); + // wait for LSE + pipeline_load_compute_lse.consumer_wait(pipeline_load_compute_lse_consumer_state); + + auto dispatch_bool = [](bool b, auto fn) { + if (b) { + fn(cute::true_type{}); + } + 
else { + fn(cute::false_type{}); + } + }; + + bool leading_causal_masking = false; + if constexpr (std::is_base_of_v, Mask>) { + leading_causal_masking = warp_uniform(iter_index == iter_start); + } else if constexpr (std::is_base_of_v, Mask>) { + int offset = get<1>(problem_shape) - get<0>(problem_shape); + int kv_left = get<1>(blk_coord) * TileShapeK{}; + int kv_right = kv_left + TileShapeK{} - 1; + int q_left = iter_index * TileShapeQ{} + offset; + int q_right = q_left + TileShapeQ{} - 1; + + leading_causal_masking = warp_uniform(!((q_left > kv_right) || (q_right < kv_left))); + } + bool trailing_residual_masking = false; + if constexpr (std::is_base_of_v) { + trailing_residual_masking = warp_uniform((iter_index == iter_end - 1) || is_residual_k); + } + + dispatch_bool(leading_causal_masking || trailing_residual_masking, [&](auto is_masked_tile) { + + // compute P = softmax(S, LSE) + cute::copy(tiled_t2r, tTR_tST, tTR_rST); + + if constexpr (decltype(is_masked_tile)::value) { + Mask{}.apply_mask(tTR_rST, [&](int i) { + auto c_transpose = tTR_cST(i); + return make_coord(get<1>(c_transpose) + iter_index * TileShapeQ{}, get<0>(c_transpose) + get<1>(blk_coord) * TileShapeK{}); + }, problem_shape); + } + + ElementAcc log2_e = static_cast(M_LOG2E); + float2 softmax_scale_log2_e; + softmax_scale_log2_e.x = mainloop_args.softmax_scale * log2_e; + softmax_scale_log2_e.y = mainloop_args.softmax_scale * log2_e; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rST); i += 2) { + float2 acc; + float2 lse; + float2 out; + acc.x = tTR_rST(i); + acc.y = tTR_rST(i + 1); + lse.x = sLSE(get<1>(tTR_cST(i)), pipeline_load_compute_lse_consumer_state.index()); + lse.y = sLSE(get<1>(tTR_cST(i+1)), pipeline_load_compute_lse_consumer_state.index()); + cute::fma(out, softmax_scale_log2_e, acc, lse); + tTR_rST(i) = ::exp2f(out.x); + tTR_rST(i+1) = ::exp2f(out.y); + } + + auto tRT_rST = quantize(tTR_rST); + auto tRT_rST_reshaped = make_tensor(tRT_rST.data(), shape(tRT_cST)); + + cutlass::arch::fence_view_async_tmem_load(); + cutlass::arch::NamedBarrier( + kNumComputeWarps * NumThreadsPerWarp, + cutlass::arch::ReservedNamedBarriers::TransformBarrier + ).arrive_and_wait(); + + cute::copy(tiled_r2t, tRT_rST_reshaped, tRT_tP); + }); + + // notify for P + cutlass::arch::fence_view_async_tmem_store(); + pipeline_compute_mma_p.producer_commit(pipeline_compute_mma_p_producer_state); + ++pipeline_compute_mma_p_producer_state; + // release S + pipeline_mma_compute_s.consumer_release(pipeline_mma_compute_s_consumer_state); + ++pipeline_mma_compute_s_consumer_state; + // release LSE + pipeline_load_compute_lse.consumer_release(pipeline_load_compute_lse_consumer_state); + ++pipeline_load_compute_lse_consumer_state; + + // wait for OdO + pipeline_load_compute_sum_odo.consumer_wait(pipeline_load_compute_sum_odo_consumer_state); + // wait for dP + pipeline_mma_compute_dp.consumer_wait(pipeline_mma_compute_dp_consumer_state); + + // wait for dS + // in principle, we could defer waiting for dS, and move in the freeing of dP + // however, that would force us to keep dS in registers longer + pipeline_compute_mma_ds.producer_acquire(pipeline_compute_mma_ds_producer_state); + + // compute dS = dsoftmax(P, dP, sum_OdO) + cute::copy(tiled_t2r, tTR_tDPT, tTR_rDPT); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rDPT); i += 2) { + float2 st; + st.x = tTR_rST(i); + st.y = tTR_rST(i+1); + float2 dpt; + dpt.x = tTR_rDPT(i); + dpt.y = tTR_rDPT(i+1); + float2 odo; + odo.x = sSumOdO(get<1>(tTR_cDPT(i)), 
pipeline_load_compute_sum_odo_consumer_state.index()); + odo.y = sSumOdO(get<1>(tTR_cDPT(i+1)), pipeline_load_compute_sum_odo_consumer_state.index()); + float2 dif; + // sum odo is negated during preprocess + cute::add(dif, dpt, odo); + float2 out; + cute::mul(out, dif, st); + tTR_rDPT(i) = out.x; + tTR_rDPT(i+1) = out.y; + } + + auto tTR_rDST = quantize(tTR_rDPT); + + // release dP + cutlass::arch::fence_view_async_tmem_load(); + pipeline_mma_compute_dp.consumer_release(pipeline_mma_compute_dp_consumer_state); + ++pipeline_mma_compute_dp_consumer_state; + + Tensor sDS = make_tensor(make_smem_ptr((Element*) shared_tensors.smem_ds.begin()), SmemLayoutDS{}) + (_, _, _, pipeline_compute_mma_ds_producer_state.index()); + + auto thread_layout = make_ordered_layout( + make_shape(_128{}, _128{}), + make_stride(_1{}, _0{}) + ); + + auto sDS_pi = as_position_independent_swizzle_tensor(sDS); + auto sDS_pi_slice_p = sDS_pi.compose(thread_layout)(dp_idx, _).compose(make_layout(shape(tTR_cDPT_p))); + auto sDS_pi_slice = split_wg(sDS_pi_slice_p); + + copy_aligned(tTR_rDST, sDS_pi_slice); + + // notify for dS + cutlass::arch::fence_view_async_shared(); + pipeline_compute_mma_ds.producer_commit(pipeline_compute_mma_ds_producer_state); + ++pipeline_compute_mma_ds_producer_state; + // release OdO + pipeline_load_compute_sum_odo.consumer_release(pipeline_load_compute_sum_odo_consumer_state); + ++pipeline_load_compute_sum_odo_consumer_state; + + iter_count -= 1; + iter_index += 1; + if (iter_index == iter_end) { + iter_index = iter_start; + } + } + + epilogue( + blk_coord, blk_offset, problem_shape, mainloop_args, epilogue_args, + pipeline_mma_compute_dkdv, pipeline_mma_compute_dkdv_consumer_state + ); + } + + template + CUTLASS_DEVICE void reduce( + BlkCoord const& blk_coord, + ProblemShape_ const& problem_shape, + int iter_start, + int iter_end, + int iter_count, + MainloopArguments const& mainloop_args, + MainloopParams const& mainloop_params, + TensorStorage& shared_tensors, + PipelineMmaReduceDQ& pipeline_mma_reduce_dq, + typename PipelineMmaReduceDQ::PipelineState& pipeline_mma_reduce_dq_consumer_state, + PipelineReduceTmaStore& pipeline_reduce_tma_store, + typename PipelineReduceTmaStore::PipelineState& pipeline_reduce_tma_store_producer_state) { + + using X = Underscore; + + auto [Q, K, D, D_VO, HB] = problem_shape; + int iter_index = iter_start; + + auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_dv, blk_coord_batch] = blk_coord; + + // must match TileShapeDQ + auto load_op = SM100_TMEM_LOAD_32dp32b32x{}; + + auto tDQtDQ = partition_fragment_C(TiledMmaDSK{}, select<0,1>(TileShapeDSK{}))(make_coord(_,_),_0{},_0{}); + tDQtDQ.data() = TmemAllocation::kDQ; + + Tensor mDQ = mainloop_params.tma_red_dq.get_tma_tensor(make_shape(Q, D, HB)); + auto gDQ = local_tile(mDQ, TileShapeKQ{}, make_coord(_,_,_), Step{}) + (_, _, _, _0{}, _); + + Tensor cDQ = make_identity_tensor(take<0,2>(TileShapeDSK{})); + + Tensor sDQ = make_tensor(make_smem_ptr(shared_tensors.smem_dq.begin()), SmemLayoutDQ{}); + + int thread_idx = threadIdx.x % (kNumComputeWarps * NumThreadsPerWarp); + auto tiled_t2r = make_tmem_copy(load_op, tDQtDQ); + auto thread_t2r = tiled_t2r.get_slice(thread_idx); + + Tensor tTR_cDQ = thread_t2r.partition_D(cDQ); + Tensor tTR_sDQ = thread_t2r.partition_D(sDQ); + Tensor tTR_tDQ = thread_t2r.partition_S(tDQtDQ); + + auto block_tma = mainloop_params.tma_red_dq.get_slice(_0{}); + + Tensor tDQsDQ = block_tma.partition_S(sDQ); + Tensor tDQcDQ = block_tma.partition_S(cDQ); + Tensor tDQgDQ = 
block_tma.partition_D(gDQ); + + int lane_predicate = (threadIdx.x % (kNumReduceWarps * NumThreadsPerWarp)) == 0; + + while (iter_count > 0) { + pipeline_mma_reduce_dq.consumer_wait(pipeline_mma_reduce_dq_consumer_state); + + Tensor tTR_rDQ = make_tensor(shape(tTR_cDQ)); + + // load dQ from tmem to rmem + cute::copy(tiled_t2r, tTR_tDQ, tTR_rDQ); + + cutlass::arch::fence_view_async_tmem_load(); + pipeline_mma_reduce_dq.consumer_release(pipeline_mma_reduce_dq_consumer_state); + ++pipeline_mma_reduce_dq_consumer_state; + + // we don't have enough smem to dump it all to smem, so we do it in stages + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<2>(tTR_cDQ); i++) { + if (lane_predicate) { + pipeline_reduce_tma_store.producer_acquire(pipeline_reduce_tma_store_producer_state); + } + // wait in all threads for the acquire to complete + cutlass::arch::NamedBarrier( + kNumReduceWarps * NumThreadsPerWarp, + cutlass::arch::ReservedNamedBarriers::TransposeBarrier + ).arrive_and_wait(); + + cute::copy(tTR_rDQ(_, _, i), tTR_sDQ(_, _, _0{}, pipeline_reduce_tma_store_producer_state.index())); + + // wait for the stores to all be visible to the TMA + cutlass::arch::fence_view_async_shared(); + cutlass::arch::NamedBarrier( + kNumReduceWarps * NumThreadsPerWarp, + cutlass::arch::ReservedNamedBarriers::TransposeBarrier + ).arrive_and_wait(); + if (lane_predicate) { + // launch tma store + copy(mainloop_params.tma_red_dq, tDQsDQ(_,_,_0{}, pipeline_reduce_tma_store_producer_state.index()), tDQgDQ(_,_,i,iter_index,blk_coord_batch)); + pipeline_reduce_tma_store.producer_commit(pipeline_reduce_tma_store_producer_state); + } + + ++pipeline_reduce_tma_store_producer_state; + } + + iter_count -= 1; + iter_index += 1; + if (iter_index == iter_end) { + iter_index = iter_start; + get<0,0>(blk_coord_batch) += 1; + } + } + } + + + CUTLASS_DEVICE void operator()(Params const& params, char* smem) { +#if (! defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) && ! defined(CUTLASS_ARCH_MMA_SM100F_ENABLED)) + printf("ERROR : Arch conditional MMA instruction used without targeting appropriate compute capability. 
Aborting.\n"); +#else + int warp_idx = cutlass::canonical_warp_idx_sync(); + auto role = warp_idx_to_role(warp_idx); + uint32_t lane_predicate = cute::elect_one_sync(); + + if (role == WarpRole::Load && lane_predicate) { + prefetch_tma_descriptor(params.mainloop_params.tma_load_q.get_tma_descriptor()); + prefetch_tma_descriptor(params.mainloop_params.tma_load_k.get_tma_descriptor()); + prefetch_tma_descriptor(params.mainloop_params.tma_load_v.get_tma_descriptor()); + prefetch_tma_descriptor(params.mainloop_params.tma_load_do.get_tma_descriptor()); + } + + SharedStorage& shared_storage = *reinterpret_cast(smem); + + int initializing_warp = 0; + typename PipelineLoadMmaQ::Params pipeline_load_mma_q_params; + if (role == WarpRole::Load) { + pipeline_load_mma_q_params.role = PipelineLoadMmaQ::ThreadCategory::Producer; + } + if (role == WarpRole::Mma) { + pipeline_load_mma_q_params.role = PipelineLoadMmaQ::ThreadCategory::Consumer; + } + pipeline_load_mma_q_params.is_leader = lane_predicate && (role == WarpRole::Load); + // Also loads K in the first iteration + pipeline_load_mma_q_params.transaction_bytes = kTransactionsBytesLoadQ; + pipeline_load_mma_q_params.initializing_warp = initializing_warp++; + PipelineLoadMmaQ pipeline_load_mma_q(shared_storage.pipelines.load_mma_q, pipeline_load_mma_q_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename PipelineLoadMmaDO::Params pipeline_load_mma_do_params; + if (role == WarpRole::Load) { + pipeline_load_mma_do_params.role = PipelineLoadMmaDO::ThreadCategory::Producer; + } + if (role == WarpRole::Mma) { + pipeline_load_mma_do_params.role = PipelineLoadMmaDO::ThreadCategory::Consumer; + } + pipeline_load_mma_do_params.is_leader = lane_predicate && (role == WarpRole::Load); + // Also loads V in the first iteration + pipeline_load_mma_do_params.transaction_bytes = kTransactionsBytesLoadDO; + pipeline_load_mma_do_params.initializing_warp = initializing_warp++; + PipelineLoadMmaDO pipeline_load_mma_do(shared_storage.pipelines.load_mma_do, pipeline_load_mma_do_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename PipelineLoadComputeLSE::Params pipeline_load_compute_lse_params; + if (role == WarpRole::Load) { + pipeline_load_compute_lse_params.role = PipelineLoadComputeLSE::ThreadCategory::Producer; + } + if (role == WarpRole::Compute) { + pipeline_load_compute_lse_params.role = PipelineLoadComputeLSE::ThreadCategory::Consumer; + } + pipeline_load_compute_lse_params.producer_arv_count = NumThreadsPerWarp; + pipeline_load_compute_lse_params.consumer_arv_count = kNumComputeWarps * NumThreadsPerWarp; + pipeline_load_compute_lse_params.initializing_warp = initializing_warp++; + PipelineLoadComputeLSE pipeline_load_compute_lse( + shared_storage.pipelines.load_compute_lse, + pipeline_load_compute_lse_params, + /*barrier init*/ cute::true_type{}); + + typename PipelineLoadComputeSumOdO::Params pipeline_load_compute_sum_odo_params; + if (role == WarpRole::Load) { + pipeline_load_compute_sum_odo_params.role = PipelineLoadComputeSumOdO::ThreadCategory::Producer; + } + if (role == WarpRole::Compute) { + pipeline_load_compute_sum_odo_params.role = PipelineLoadComputeSumOdO::ThreadCategory::Consumer; + } + pipeline_load_compute_sum_odo_params.producer_arv_count = NumThreadsPerWarp; + pipeline_load_compute_sum_odo_params.consumer_arv_count = kNumComputeWarps * NumThreadsPerWarp; + pipeline_load_compute_sum_odo_params.initializing_warp = initializing_warp++; + 
PipelineLoadComputeSumOdO pipeline_load_compute_sum_odo( + shared_storage.pipelines.load_compute_sum_odo, + pipeline_load_compute_sum_odo_params, + /*barrier init*/ cute::true_type{}); + + typename PipelineMmaComputeS::Params pipeline_mma_compute_s_params; + if (role == WarpRole::Mma) { + pipeline_mma_compute_s_params.role = PipelineMmaComputeS::ThreadCategory::Producer; + } + if (role == WarpRole::Compute) { + pipeline_mma_compute_s_params.role = PipelineMmaComputeS::ThreadCategory::Consumer; + } + pipeline_mma_compute_s_params.consumer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp; + pipeline_mma_compute_s_params.initializing_warp = initializing_warp++; + PipelineMmaComputeS pipeline_mma_compute_s( + shared_storage.pipelines.mma_compute_s, + pipeline_mma_compute_s_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename PipelineMmaComputeDP::Params pipeline_mma_compute_dp_params; + if (role == WarpRole::Mma) { + pipeline_mma_compute_dp_params.role = PipelineMmaComputeDP::ThreadCategory::Producer; + } + if (role == WarpRole::Compute) { + pipeline_mma_compute_dp_params.role = PipelineMmaComputeDP::ThreadCategory::Consumer; + } + pipeline_mma_compute_dp_params.consumer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp; + pipeline_mma_compute_dp_params.initializing_warp = initializing_warp++; + PipelineMmaComputeDP pipeline_mma_compute_dp( + shared_storage.pipelines.mma_compute_dp, + pipeline_mma_compute_dp_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename PipelineMmaReduceDQ::Params pipeline_mma_reduce_dq_params; + if (role == WarpRole::Mma) { + pipeline_mma_reduce_dq_params.role = PipelineMmaReduceDQ::ThreadCategory::Producer; + } + if (role == WarpRole::Reduce) { + pipeline_mma_reduce_dq_params.role = PipelineMmaReduceDQ::ThreadCategory::Consumer; + } + pipeline_mma_reduce_dq_params.consumer_arv_count = kNumReduceWarps * cutlass::NumThreadsPerWarp; + pipeline_mma_reduce_dq_params.initializing_warp = initializing_warp++; + PipelineMmaReduceDQ pipeline_mma_reduce_dq( + shared_storage.pipelines.mma_reduce_dq, + pipeline_mma_reduce_dq_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename PipelineComputeMmaP::Params pipeline_compute_mma_p_params; + if (role == WarpRole::Mma) { + pipeline_compute_mma_p_params.role = PipelineComputeMmaP::ThreadCategory::Consumer; + } + if (role == WarpRole::Compute) { + pipeline_compute_mma_p_params.role = PipelineComputeMmaP::ThreadCategory::Producer; + } + pipeline_compute_mma_p_params.producer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp; + pipeline_compute_mma_p_params.consumer_arv_count = 1; + pipeline_compute_mma_p_params.initializing_warp = initializing_warp++; + PipelineComputeMmaP pipeline_compute_mma_p( + shared_storage.pipelines.compute_mma_p, + pipeline_compute_mma_p_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename PipelineComputeMmaDS::Params pipeline_compute_mma_ds_params; + if (role == WarpRole::Mma) { + pipeline_compute_mma_ds_params.role = PipelineComputeMmaDS::ThreadCategory::Consumer; + } + if (role == WarpRole::Compute) { + pipeline_compute_mma_ds_params.role = PipelineComputeMmaDS::ThreadCategory::Producer; + } + pipeline_compute_mma_ds_params.producer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp; + pipeline_compute_mma_ds_params.consumer_arv_count = 1; + 
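+ // Note: dS is written to smem by all compute warps (per-thread producer arrivals), while the single MMA warp consumes it for the dQ/dK GEMMs, hence consumer_arv_count = 1.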
pipeline_compute_mma_ds_params.initializing_warp = initializing_warp++; + PipelineComputeMmaDS pipeline_compute_mma_ds( + shared_storage.pipelines.compute_mma_ds, + pipeline_compute_mma_ds_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename PipelineMmaComputeDKDV::Params pipeline_mma_compute_dkdv_params; + if (role == WarpRole::Mma) { + pipeline_mma_compute_dkdv_params.role = PipelineMmaComputeDKDV::ThreadCategory::Producer; + } + if (role == WarpRole::Compute) { + pipeline_mma_compute_dkdv_params.role = PipelineMmaComputeDKDV::ThreadCategory::Consumer; + } + pipeline_mma_compute_dkdv_params.consumer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp; + pipeline_mma_compute_dkdv_params.initializing_warp = initializing_warp++; + PipelineMmaComputeDKDV pipeline_mma_compute_dkdv( + shared_storage.pipelines.mma_compute_dkdv, + pipeline_mma_compute_dkdv_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + PipelineReduceTmaStore pipeline_reduce_tma_store; + + TmemAllocator tmem_allocator; + + pipeline_init_arrive_relaxed(size(ClusterShape{})); + + pipeline_load_mma_q.init_masks(ClusterShape{}); + pipeline_load_mma_do.init_masks(ClusterShape{}); + pipeline_mma_compute_s.init_masks(ClusterShape{}); + pipeline_mma_compute_dp.init_masks(ClusterShape{}); + pipeline_mma_reduce_dq.init_masks(ClusterShape{}); + pipeline_compute_mma_p.init_masks(ClusterShape{}); + pipeline_compute_mma_ds.init_masks(ClusterShape{}); + pipeline_mma_compute_dkdv.init_masks(ClusterShape{}); + + typename decltype(pipeline_load_mma_q)::PipelineState pipeline_load_mma_q_consumer_state; + typename decltype(pipeline_load_mma_do)::PipelineState pipeline_load_mma_do_consumer_state; + typename decltype(pipeline_load_compute_lse)::PipelineState pipeline_load_compute_lse_consumer_state; + typename decltype(pipeline_load_compute_sum_odo)::PipelineState pipeline_load_compute_sum_odo_consumer_state; + typename decltype(pipeline_mma_compute_s)::PipelineState pipeline_mma_compute_s_consumer_state; + typename decltype(pipeline_mma_compute_dp)::PipelineState pipeline_mma_compute_dp_consumer_state; + typename decltype(pipeline_mma_reduce_dq)::PipelineState pipeline_mma_reduce_dq_consumer_state; + typename decltype(pipeline_compute_mma_p)::PipelineState pipeline_compute_mma_p_consumer_state; + typename decltype(pipeline_compute_mma_ds)::PipelineState pipeline_compute_mma_ds_consumer_state; + typename decltype(pipeline_mma_compute_dkdv)::PipelineState pipeline_mma_compute_dkdv_consumer_state; + + auto pipeline_load_mma_q_producer_state = make_producer_start_state(); + auto pipeline_load_mma_do_producer_state = make_producer_start_state(); + auto pipeline_load_compute_lse_producer_state = make_producer_start_state(); + auto pipeline_load_compute_sum_odo_producer_state = make_producer_start_state(); + auto pipeline_mma_compute_s_producer_state = make_producer_start_state(); + auto pipeline_mma_compute_dp_producer_state = make_producer_start_state(); + auto pipeline_mma_reduce_dq_producer_state = make_producer_start_state(); + auto pipeline_compute_mma_p_producer_state = make_producer_start_state(); + auto pipeline_compute_mma_ds_producer_state = make_producer_start_state(); + auto pipeline_mma_compute_dkdv_producer_state = make_producer_start_state(); + auto pipeline_reduce_tma_store_producer_state = make_producer_start_state(); + + pipeline_init_wait(size(ClusterShape{})); + + auto blk_coord = make_coord(_0{}, blockIdx.x, _0{}, _0{}, 
make_coord(make_coord(0, blockIdx.y), blockIdx.z)); + auto [problem_shape, blk_offset] = apply_variable_length_offset( + params.problem_shape, + blk_coord + ); + int iter_end = ceil_div(get<0>(problem_shape), TileShapeQ{}); + int iter_start = 0; + if constexpr (std::is_base_of_v, Mask>) { + iter_start = (get<1>(blk_coord) * TileShapeK{}) / TileShapeQ{}; + } else if constexpr (std::is_base_of_v, Mask>) { + int offset = get<1>(problem_shape) - get<0>(problem_shape); + iter_start = max(0, (int(get<1>(blk_coord) * TileShapeK{}) - offset) / (int)TileShapeQ{}); + } + if (get<1>(blk_coord) * TileShapeK{} >= get<1>(problem_shape)) { + return; + } + int iter_count = (iter_end - iter_start) * get<4,0,0>(problem_shape); + + if (iter_count <= 0) { + epilogue_clear( + blk_coord, + blk_offset, + problem_shape, + params.mainloop, + params.epilogue + ); + return; + } + + if (role == WarpRole::Load) { + warpgroup_reg_set(); + + load( + blk_coord, + blk_offset, + problem_shape, + iter_start, + iter_end, + iter_count, + params.mainloop, + params.mainloop_params, + shared_storage.tensors, + pipeline_load_mma_q, pipeline_load_mma_q_producer_state, + pipeline_load_mma_do, pipeline_load_mma_do_producer_state, + pipeline_load_compute_lse, pipeline_load_compute_lse_producer_state, + pipeline_load_compute_sum_odo, pipeline_load_compute_sum_odo_producer_state + ); + + } + else if (role == WarpRole::Mma) { + warpgroup_reg_set(); + + tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr); + __syncwarp(); + + mma( + blk_coord, + problem_shape, + iter_start, + iter_end, + iter_count, + params.mainloop, + shared_storage.tensors, + pipeline_load_mma_q, pipeline_load_mma_q_consumer_state, + pipeline_load_mma_do, pipeline_load_mma_do_consumer_state, + pipeline_mma_compute_s, pipeline_mma_compute_s_producer_state, + pipeline_mma_compute_dp, pipeline_mma_compute_dp_producer_state, + pipeline_mma_reduce_dq, pipeline_mma_reduce_dq_producer_state, + pipeline_compute_mma_p, pipeline_compute_mma_p_consumer_state, + pipeline_compute_mma_ds, pipeline_compute_mma_ds_consumer_state, + pipeline_mma_compute_dkdv, pipeline_mma_compute_dkdv_producer_state + ); + + } + else if (role == WarpRole::Compute) { + warpgroup_reg_set(); + + compute( + blk_coord, + blk_offset, + problem_shape, + iter_start, + iter_end, + iter_count, + params.mainloop, + params.epilogue, + shared_storage.tensors, + pipeline_load_compute_lse, pipeline_load_compute_lse_consumer_state, + pipeline_load_compute_sum_odo, pipeline_load_compute_sum_odo_consumer_state, + pipeline_mma_compute_s, pipeline_mma_compute_s_consumer_state, + pipeline_mma_compute_dp, pipeline_mma_compute_dp_consumer_state, + pipeline_compute_mma_p, pipeline_compute_mma_p_producer_state, + pipeline_compute_mma_ds, pipeline_compute_mma_ds_producer_state, + pipeline_mma_compute_dkdv, pipeline_mma_compute_dkdv_consumer_state + ); + + cutlass::arch::NamedBarrier( + kNumComputeWarps * NumThreadsPerWarp, + cutlass::arch::ReservedNamedBarriers::EpilogueBarrier + ).arrive_and_wait(); + + if (warp_idx % kNumComputeWarps == 0) { + uint32_t free_stage_ptr = shared_storage.tmem_base_ptr; + tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns); + } + + } + else if (role == WarpRole::Reduce) { + warpgroup_reg_set(); + + reduce( + blk_coord, + problem_shape, + iter_start, + iter_end, + iter_count, + params.mainloop, + params.mainloop_params, + shared_storage.tensors, + pipeline_mma_reduce_dq, pipeline_mma_reduce_dq_consumer_state, + 
pipeline_reduce_tma_store, pipeline_reduce_tma_store_producer_state + ); + + pipeline_reduce_tma_store.producer_tail(pipeline_reduce_tma_store_producer_state); + } + else { + warpgroup_reg_set(); + + /* no-op */ + + } +#endif + } + + static dim3 get_block_shape() { + dim3 block(MaxThreadsPerBlock, 1, 1); + return block; + } + + static dim3 get_grid_shape(Params const& params) { + auto [Q, K, D, D_VO, HB] = params.problem_shape; + auto [H, B] = HB; + auto [H_R, H_K] = H; + dim3 grid(ceil_div(K, TileShapeK{}), H_K, B); + return grid; + } +}; + +} // namespace cutlass::fmha::kernel diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/kernel/sm100_fmha_bwd_mla_kernel_tma_warpspecialized.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/kernel/sm100_fmha_bwd_mla_kernel_tma_warpspecialized.hpp new file mode 100644 index 0000000000000000000000000000000000000000..bf72843a43ecd36a98ef9b938d872700057e59a9 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/kernel/sm100_fmha_bwd_mla_kernel_tma_warpspecialized.hpp @@ -0,0 +1,1826 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cute/arch/simd_sm100.hpp" + +#include "cutlass/arch/arch.h" +#include "cutlass/arch/memory_sm80.h" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "collective/fmha_common.hpp" + +#include + +namespace cutlass::fmha::kernel { + +using namespace cutlass::fmha::collective; + +using namespace cute; + +template< + class ProblemShape, + class Element, + class ElementAcc, + class TileShape, + class Mask +> +struct Sm100FmhaBwdMlaKernelTmaWarpSpecialized { + + using TileShapeQ = decltype(get<0>(TileShape{})); + using TileShapeK = decltype(get<1>(TileShape{})); + using TileShapeDQK = decltype(get<2>(TileShape{})); + using TileShapeDVO = decltype(get<3>(TileShape{})); + + using TmemAllocator = cute::TMEM::Allocator1Sm; + struct TmemAllocation { + static constexpr uint32_t kDK = 0; // TileShapeK x TileShapeDQK x acc + static constexpr uint32_t kDV = kDK + TileShapeDQK{}; // TileShapeK x TileShapeDVO x acc + static constexpr uint32_t kDQ = kDV + TileShapeDVO{}; // TileShapeQ x TileShapeDQK x acc + static constexpr uint32_t kDP = kDQ; // TileShapeK x TileShapeQ x inp + static constexpr uint32_t kS = kDQ + 65536 * 16; + static constexpr uint32_t kP = kS; + static constexpr uint32_t kTotal = kDQ + TileShapeDQK{}; + }; + + static_assert( + static_cast(TmemAllocation::kTotal) <= TmemAllocator::Sm100TmemCapacityColumns, + "using too much tmem" + ); + + enum class WarpRole { + Empty = 0x0, Load = 0x1, Mma = 0x2, Compute = 0x3, Reduce = 0x4 + }; + + static constexpr unsigned long long kWarpAssignment = 0x12'3333'3333'4444ull; + static constexpr int kNumComputeWarps = 8; + static constexpr int kNumReduceWarps = 4; + + static constexpr int kLoadPerThread = TileShapeQ{} / NumThreadsPerWarp; + static_assert(TileShapeQ{} % NumThreadsPerWarp == 0, "TileShapeQ must be divisible by NumThreadsPerWarp"); + CUTLASS_DEVICE WarpRole warp_idx_to_role(int warp_idx) { + return static_cast((kWarpAssignment >> (4 * warp_idx)) & 0xF); + } + + struct RegisterAllocation { + static constexpr int kWarpgroup0 = 160-8; + static constexpr int kWarpgroup1 = 128; + static constexpr int kWarpgroup2 = 96; + static constexpr int kReduce = kWarpgroup0; + static constexpr int kCompute = kWarpgroup1; + static constexpr int kMma = kWarpgroup2; + static constexpr int kEmpty = kWarpgroup2; + static constexpr int kLoad = kWarpgroup2; + + static_assert(kWarpgroup0 + 2 * kWarpgroup1 + kWarpgroup2 <= 512); + }; + + using ArchTag = cutlass::arch::Sm100; + + using ClusterShape = Shape<_1, _1, _1>; + using Schedule = cutlass::gemm::KernelTmaWarpSpecialized1SmSm100; + + static constexpr int MinBlocksPerMultiprocessor = 1; + static constexpr int kNumWarps = kNumComputeWarps + kNumReduceWarps + 4; + static constexpr int MaxThreadsPerBlock = NumThreadsPerWarp * kNumWarps; + + static constexpr int Alignment = 128 / sizeof_bits_v; + static constexpr int kStages = 2; + + using TensorStrideContiguousK = Stride>; + using TensorStrideContiguousMN = Stride<_1, int, Stride>; + + // compute S + using CollectiveMmaQK = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + Element, TensorStrideContiguousK, Alignment, + Element, TensorStrideContiguousK, Alignment, + ElementAcc, + Shape, + ClusterShape, cutlass::gemm::collective::StageCount, + Schedule>::CollectiveOp; + using TileShapeQK = 
typename CollectiveMmaQK::TileShape; + using TiledMmaQK = typename CollectiveMmaQK::TiledMma; + + // compute dP + using CollectiveMmaDOV = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + Element, TensorStrideContiguousK, Alignment, + Element, TensorStrideContiguousK, Alignment, + ElementAcc, + Shape, + ClusterShape, cutlass::gemm::collective::StageCount, + Schedule>::CollectiveOp; + using TileShapeDOV = typename CollectiveMmaDOV::TileShape; + using TiledMmaDOV = typename CollectiveMmaDOV::TiledMma; + + // compute dV + using CollectiveMmaPDO = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + // needs to match ordering of S calculation + Element, TensorStrideContiguousK, Alignment, + Element, TensorStrideContiguousMN, Alignment, + ElementAcc, + Shape, + ClusterShape, cutlass::gemm::collective::StageCount, + Schedule>::CollectiveOp; + using TileShapePDO = typename CollectiveMmaPDO::TileShape; + using TiledMmaPDO = typename CollectiveMmaPDO::TiledMma; + + // compute dK + using CollectiveMmaDSQ = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + // somewhat arbitrary since we dump to smem, need to agree with the next one + Element, TensorStrideContiguousK , Alignment, + Element, TensorStrideContiguousMN, Alignment, + ElementAcc, + Shape, + ClusterShape, cutlass::gemm::collective::StageCount, + Schedule>::CollectiveOp; + using TileShapeDSQ = typename CollectiveMmaDSQ::TileShape; + using TiledMmaDSQ = typename CollectiveMmaDSQ::TiledMma; + + // compute dQ + using CollectiveMmaDSK = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + // somewhat arbitrary since we dump to smem, need to agree with the previous one + Element, TensorStrideContiguousMN, Alignment, + Element, TensorStrideContiguousMN, Alignment, + ElementAcc, + Shape, + ClusterShape, cutlass::gemm::collective::StageCount, + Schedule>::CollectiveOp; + using TileShapeDSK = typename CollectiveMmaDSK::TileShape; + using TiledMmaDSK = typename CollectiveMmaDSK::TiledMma; + + // pipelines are named Pipeline + static constexpr int kStagesComputeSmem = 1; + using PipelineLoadMmaQ = PipelineTmaUmmaAsync<2, ClusterShape>; + using PipelineLoadMmaDO = PipelineTmaUmmaAsync<1, ClusterShape>; + using PipelineLoadComputeLSE = PipelineAsync<1>; + using PipelineLoadComputeSumOdO = PipelineAsync<1>; + using PipelineMmaComputeS = PipelineUmmaAsync<1>; + using PipelineMmaComputeDP = PipelineUmmaAsync<1>; + using PipelineMmaReduceDQ = PipelineUmmaAsync<1>; + using PipelineComputeMmaP = PipelineUmmaConsumerAsync<1>; + using PipelineComputeMmaDS = PipelineUmmaConsumerAsync; + using PipelineMmaComputeDKDV = PipelineUmmaAsync<2>; + static constexpr int kStagesReduceTmaStore = 2; + using PipelineReduceTmaStore = PipelineTmaStore; + + struct PipelineStorage { + alignas(16) typename PipelineLoadMmaQ::SharedStorage load_mma_q; + alignas(16) typename PipelineLoadMmaDO::SharedStorage load_mma_do; + alignas(16) typename PipelineLoadComputeLSE::SharedStorage load_compute_lse; + alignas(16) typename PipelineLoadComputeSumOdO::SharedStorage load_compute_sum_odo; + alignas(16) typename PipelineMmaComputeS::SharedStorage mma_compute_s; + alignas(16) typename PipelineMmaComputeDP::SharedStorage mma_compute_dp; + alignas(16) typename PipelineMmaReduceDQ::SharedStorage mma_reduce_dq; + alignas(16) typename 
PipelineComputeMmaP::SharedStorage compute_mma_p; + alignas(16) typename PipelineComputeMmaDS::SharedStorage compute_mma_ds; + alignas(16) typename PipelineMmaComputeDKDV::SharedStorage mma_compute_dkdv; + }; + + template + static CUTE_DEVICE constexpr auto restage(Layout const& layout, Stages stages = {}) { + return composition(layout, make_tuple(_, _, _, make_layout(stages))); + } + + using SmemLayoutK = decltype(restage(typename CollectiveMmaQK::SmemLayoutB{})); + using SmemLayoutV = decltype(restage(typename CollectiveMmaDOV::SmemLayoutB{})); + using SmemLayoutQ = decltype(restage(typename CollectiveMmaQK::SmemLayoutA{}, _2{})); + using SmemLayoutDO = decltype(restage(typename CollectiveMmaDOV::SmemLayoutA{}, _1{})); + using SmemLayoutDS = decltype(restage(typename CollectiveMmaDSK::SmemLayoutA{}, Int{})); + using SmemLayoutLSE = Layout>; + using SmemLayoutSumOdO = Layout>; + + using SmemLayoutQT = decltype(restage(typename CollectiveMmaDSQ::SmemLayoutB{}, _2{})); + using SmemLayoutKT = decltype(restage(typename CollectiveMmaDSK::SmemLayoutB{})); + using SmemLayoutDST = decltype(restage(typename CollectiveMmaDSQ::SmemLayoutA{}, Int{})); + using SmemLayoutDOT = decltype(restage(typename CollectiveMmaPDO::SmemLayoutB{}, _1{})); + using SmemLayoutP = decltype(restage(typename CollectiveMmaPDO::SmemLayoutA{}, _1{})); + using SmemLayoutPT = decltype(restage(typename CollectiveMmaDSK::SmemLayoutA{}, _1{})); + + using TileShapeDQ = _32; + using SmemAtomDQ = decltype(cutlass::gemm::collective::detail::sm100_smem_selector< + cute::UMMA::Major::K, ElementAcc, TileShapeQ, TileShapeDQ + >()); + using SmemShapeDQ = Shape>; + using SmemLayoutDQ = decltype(tile_to_shape(SmemAtomDQ{}, SmemShapeDQ{}, Step<_2, _1, _3>{})); + + struct TensorStorage { + union { + alignas(2048) cute::array> smem_k; + alignas(2048) cute::array> smem_k_t; + }; + alignas(2048) cute::array> smem_v; + union { + alignas(2048) cute::array> smem_q; + alignas(2048) cute::array> smem_q_t; + }; + union { + alignas(2048) cute::array> smem_do; + alignas(2048) cute::array> smem_do_t; + }; + union { + alignas(2048) cute::array> smem_ds; + alignas(2048) cute::array> smem_ds_t; + }; + union{ + alignas(2048) cute::array> smem_p; + alignas(2048) cute::array> smem_p_t; + }; + alignas(1024) cute::array> smem_dq; + alignas(16) cute::array> smem_lse; + alignas(16) cute::array> smem_sum_odo; + }; + + static constexpr int kTransactionsBytesLoadQ = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutQ{})) * cute::sizeof_bits_v); + static constexpr int kTransactionsBytesLoadDO = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutDO{})) * cute::sizeof_bits_v); + + static constexpr int kTransactionsBytesLoadK = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutK{})) * cute::sizeof_bits_v); + static constexpr int kTransactionsBytesLoadV = cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutV{})) * cute::sizeof_bits_v); + + struct SharedStorage { + TensorStorage tensors; + PipelineStorage pipelines; + uint32_t tmem_base_ptr; + }; + + // this is tight enough that it won't work with sizeof due to padding for alignment + static constexpr int SharedStorageSize = offsetof(SharedStorage, tmem_base_ptr) + sizeof(uint32_t); + static_assert(SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, "using too much smem"); + + using TensorStride = TensorStrideContiguousK; // S D (H B) + using RowTensorStride = Stride<_1, Stride>; // S (H B) + + struct MainloopArguments { + const Element* ptr_q; + TensorStride stride_q; + const Element* ptr_k; + TensorStride 
stride_k; + const Element* ptr_v; + TensorStride stride_v; + const Element* ptr_do; + TensorStride stride_do; + + const ElementAcc* ptr_lse; + RowTensorStride stride_lse; + + const ElementAcc* ptr_sum_odo; + RowTensorStride stride_sum_odo; + + ElementAcc* ptr_dq_acc; + TensorStride stride_dq_acc; + + ElementAcc softmax_scale = 1.0f / sqrtf(TileShapeDQK{}); + }; + + using TMA_K = typename CollectiveMmaQK::Params::TMA_B; + using TMA_V = typename CollectiveMmaDOV::Params::TMA_B; + using TMA_Q = typename CollectiveMmaQK::Params::TMA_A; + using TMA_DO = typename CollectiveMmaDOV::Params::TMA_A; + + using TMA_DQ = decltype(make_tma_copy(SM90_TMA_REDUCE_ADD{}, + make_tensor((const ElementAcc*)nullptr, make_shape(1, 1, make_shape(1, 1)), TensorStride{}), + SmemLayoutDQ{}(_, _, _0{}) + )); + + struct MainloopParams { + TMA_K tma_load_k; + TMA_V tma_load_v; + TMA_Q tma_load_q; + TMA_DO tma_load_do; + TMA_DQ tma_red_dq; + }; + + struct EpilogueArguments { + Element* ptr_dk; + TensorStride stride_dk; + Element* ptr_dv; + TensorStride stride_dv; + }; + + struct Arguments { + ProblemShape problem_shape; + MainloopArguments mainloop; + EpilogueArguments epilogue; + KernelHardwareInfo hw_info; + }; + + struct Params { + ProblemShape problem_shape; + MainloopArguments mainloop; + MainloopParams mainloop_params; + EpilogueArguments epilogue; + KernelHardwareInfo hw_info; + }; + + + static bool can_implement(Arguments const& args) { + auto [Q, K, D, D_VO, HB] = args.problem_shape; + auto [H, B] = HB; + if (Q <= 0 || K <= 0 || D <= 0 || H <= 0 || B <= 0 || D_VO <= 0) { + return false; + } + if (D % Alignment != 0 || D_VO % Alignment != 0) { + return false; + } + return true; + } + + + static Status initialize_workspace(Arguments const&, void*, cudaStream_t) { + return Status::kSuccess; + } + + + static Params to_underlying_arguments(Arguments const& args, void*) { + auto [Q_, K_, D, D_VO, HB] = args.problem_shape; + int Q = Q_; + int K = K_; + + if constexpr (is_variable_length_v) { + Q = Q_.total_length; + } + if constexpr (is_variable_length_v) { + K = K_.total_length; + } + + auto params_kq = CollectiveMmaQK::to_underlying_arguments( + make_shape(Q, K, D, HB), + typename CollectiveMmaQK::Arguments { + args.mainloop.ptr_q, args.mainloop.stride_q, + args.mainloop.ptr_k, args.mainloop.stride_k, + }, /*workspace=*/nullptr); + + auto params_vdo = CollectiveMmaDOV::to_underlying_arguments( + make_shape(Q, K, D_VO, HB), + typename CollectiveMmaDOV::Arguments { + args.mainloop.ptr_do, args.mainloop.stride_do, + args.mainloop.ptr_v, args.mainloop.stride_v, + }, /*workspace=*/nullptr); + + TMA_DQ tma_red_dq = make_tma_copy( + SM90_TMA_REDUCE_ADD{}, + make_tensor(args.mainloop.ptr_dq_acc, make_shape(Q_, D, HB), args.mainloop.stride_dq_acc), + SmemLayoutDQ{}(_, _, _0{}) + ); + + return Params{ + args.problem_shape, + args.mainloop, + MainloopParams{ + params_kq.tma_load_b, + params_vdo.tma_load_b, + params_kq.tma_load_a, + params_vdo.tma_load_a, + tma_red_dq + }, + args.epilogue, + args.hw_info + }; + } + + + template + static CUTLASS_DEVICE auto quantize(T const& input) { + constexpr int AlignmentS = 4; + auto output = make_tensor(shape(input)); + auto input_vec = recast>(input); + auto output_vec = recast>(output); + + cutlass::NumericArrayConverter epilogue_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(input_vec); i++) { + output_vec(i) = epilogue_op(input_vec(i)); + } + + return output; + } + + + template + CUTLASS_DEVICE void load( + BlkCoord const& blk_coord, + BlkOffset const& blk_offset, + 
ProblemShape_ const& problem_shape, + int iter_index, + int iter_count, + MainloopArguments const& mainloop_args, + MainloopParams const& mainloop_params, + TensorStorage& shared_tensors, + PipelineLoadMmaQ& pipeline_load_mma_q, + typename PipelineLoadMmaQ::PipelineState& pipeline_load_mma_q_producer_state, + PipelineLoadMmaDO& pipeline_load_mma_do, + typename PipelineLoadMmaDO::PipelineState& pipeline_load_mma_do_producer_state, + PipelineLoadComputeLSE& pipeline_load_compute_lse, + typename PipelineLoadComputeLSE::PipelineState& pipeline_load_compute_lse_producer_state, + PipelineLoadComputeSumOdO& pipeline_load_compute_sum_odo, + typename PipelineLoadComputeSumOdO::PipelineState& pipeline_load_compute_sum_odo_producer_state) { + + auto [Q, K, D, D_VO, HB] = problem_shape; + + using X = Underscore; + + uint16_t mcast_mask = 0; + + auto mK_in = mainloop_params.tma_load_k.get_tma_tensor(make_shape(K, D, HB)); + auto mV_in = mainloop_params.tma_load_v.get_tma_tensor(make_shape(K, D_VO, HB)); + auto mQ_in = mainloop_params.tma_load_q.get_tma_tensor(make_shape(Q, D, HB)); + auto mDO_in = mainloop_params.tma_load_do.get_tma_tensor(make_shape(Q, D_VO, HB)); + + auto mK = domain_offset(select<1,2,4>(blk_offset), mK_in); + auto mV = domain_offset(select<1,3,4>(blk_offset), mV_in); + auto mQ = domain_offset(select<0,2,4>(blk_offset), mQ_in); + auto mDO = domain_offset(select<0,3,4>(blk_offset), mDO_in); + + auto gK = local_tile(mK, TileShapeQK{}, make_coord(_,_,_), Step{}); + auto gQ = local_tile(mQ, TileShapeQK{}, make_coord(_,_,_), Step<_1, X, _1>{}); + auto gV = local_tile(mV, TileShapeDOV{}, make_coord(_,_,_), Step{}); + auto gDO = local_tile(mDO, TileShapeDOV{}, make_coord(_,_,_), Step<_1, X, _1>{}); + + ThrMMA cta_mma_kq = TiledMmaQK{}.get_slice(_0{}); + ThrMMA cta_mma_vdo = TiledMmaDOV{}.get_slice(_0{}); + + auto tSTgK = cta_mma_kq.partition_B(gK); + auto tSTgQ = cta_mma_kq.partition_A(gQ); + auto tDPTgV = cta_mma_vdo.partition_B(gV); + auto tDPTgDO = cta_mma_vdo.partition_A(gDO); + + auto sQ = make_tensor(make_smem_ptr(shared_tensors.smem_q.begin()), SmemLayoutQ{}); + auto sK = make_tensor(make_smem_ptr(shared_tensors.smem_k.begin()), SmemLayoutK{}); + auto sV = make_tensor(make_smem_ptr(shared_tensors.smem_v.begin()), SmemLayoutV{}); + auto sDO = make_tensor(make_smem_ptr(shared_tensors.smem_do.begin()), SmemLayoutDO{}); + + auto [tKgK_mkl, tKsK] = tma_partition( + mainloop_params.tma_load_k, _0{}, make_layout(_1{}), + group_modes<0,3>(sK), group_modes<0,3>(tSTgK)); + auto [tQgQ_mkl, tQsQ] = tma_partition( + mainloop_params.tma_load_q, _0{}, make_layout(_1{}), + group_modes<0,3>(sQ), group_modes<0,3>(tSTgQ)); + auto [tVgV_mkl, tVsV] = tma_partition( + mainloop_params.tma_load_v, _0{}, make_layout(_1{}), + group_modes<0,3>(sV), group_modes<0,3>(tDPTgV)); + auto [tDOgDO_mkl, tDOsDO] = tma_partition( + mainloop_params.tma_load_do, _0{}, make_layout(_1{}), + group_modes<0,3>(sDO), group_modes<0,3>(tDPTgDO)); + + // set up lse and sum_odo + + auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_dv, blk_coord_batch] = blk_coord; + + pipeline_load_mma_q.producer_acquire(pipeline_load_mma_q_producer_state); + auto tma_barrier = pipeline_load_mma_q.producer_get_barrier(pipeline_load_mma_q_producer_state); + + pipeline_load_mma_q.producer_expect_transaction(pipeline_load_mma_q_producer_state, kTransactionsBytesLoadK); + + // load K + if (cute::elect_one_sync()) { + cute::copy( + mainloop_params.tma_load_k.with(*tma_barrier, mcast_mask), + tKgK_mkl(_, blk_coord_k, _0{}, blk_coord_batch), + tKsK(_, 
_0{}) + ); + } + + // load Q + if (cute::elect_one_sync()) { + cute::copy( + mainloop_params.tma_load_q.with(*tma_barrier, mcast_mask), + tQgQ_mkl(_, iter_index, _0{}, blk_coord_batch), + tQsQ(_, pipeline_load_mma_q_producer_state.index()) + ); + } + + ++pipeline_load_mma_q_producer_state; + + pipeline_load_compute_lse.producer_acquire(pipeline_load_compute_lse_producer_state); + + // load LSE + // 32 threads loading kLoadPerThread * 32 values of 32b each + + int thread_idx = threadIdx.x % NumThreadsPerWarp; + int smem_idx = TileShapeQ{} * pipeline_load_compute_lse_producer_state.index() + thread_idx * kLoadPerThread; + int gmem_idx = TileShapeQ{} * iter_index + thread_idx * kLoadPerThread; + auto mLSE = make_tensor(mainloop_args.ptr_lse, make_shape(Q, HB), mainloop_args.stride_lse); + for (int i = 0; i < kLoadPerThread; i++) { + cutlass::arch::cp_async_zfill<4>( + shared_tensors.smem_lse.begin() + smem_idx + i, + &mLSE(gmem_idx + i, blk_coord_batch), + gmem_idx + i < Q + ); + } + + pipeline_load_compute_lse.producer_commit(pipeline_load_compute_lse_producer_state, cutlass::arch::cpasync_barrier_arrive); + ++pipeline_load_compute_lse_producer_state; + + + pipeline_load_mma_do.producer_acquire(pipeline_load_mma_do_producer_state); + tma_barrier = pipeline_load_mma_do.producer_get_barrier(pipeline_load_mma_do_producer_state); + + pipeline_load_mma_do.producer_expect_transaction(pipeline_load_mma_do_producer_state, kTransactionsBytesLoadV); + + // load V + if (cute::elect_one_sync()) { + cute::copy( + mainloop_params.tma_load_v.with(*tma_barrier, mcast_mask), + tVgV_mkl(_, blk_coord_k, _0{}, blk_coord_batch), + tVsV(_, _0{}) + ); + } + + // load dO + if (cute::elect_one_sync()) { + cute::copy( + mainloop_params.tma_load_do.with(*tma_barrier, mcast_mask), + tDOgDO_mkl(_, iter_index, _0{}, blk_coord_batch), + tDOsDO(_, pipeline_load_mma_do_producer_state.index()) + ); + } + + ++pipeline_load_mma_do_producer_state; + + pipeline_load_compute_sum_odo.producer_acquire(pipeline_load_compute_sum_odo_producer_state); + + // load sum_OdO + smem_idx = TileShapeQ{} * pipeline_load_compute_sum_odo_producer_state.index() + thread_idx * kLoadPerThread; + gmem_idx = TileShapeQ{} * iter_index + thread_idx * kLoadPerThread; + auto mSumOdO = make_tensor(mainloop_args.ptr_sum_odo, make_shape(Q, HB), mainloop_args.stride_sum_odo); + for (int i = 0; i < kLoadPerThread; i++) { + cutlass::arch::cp_async_zfill<4>( + shared_tensors.smem_sum_odo.begin() + smem_idx + i, + &mSumOdO(gmem_idx + i, blk_coord_batch), + gmem_idx + i < Q + ); + } + + pipeline_load_compute_sum_odo.producer_commit(pipeline_load_compute_sum_odo_producer_state, cutlass::arch::cpasync_barrier_arrive); + ++pipeline_load_compute_sum_odo_producer_state; + + iter_count -= 1; + iter_index += 1; + + while (iter_count > 0) { + pipeline_load_mma_q.producer_acquire(pipeline_load_mma_q_producer_state); + tma_barrier = pipeline_load_mma_q.producer_get_barrier(pipeline_load_mma_q_producer_state); + + // load Q + if (cute::elect_one_sync()) { + cute::copy( + mainloop_params.tma_load_q.with(*tma_barrier, mcast_mask), + tQgQ_mkl(_, iter_index, _0{}, blk_coord_batch), + tQsQ(_, pipeline_load_mma_q_producer_state.index()) + ); + } + + ++pipeline_load_mma_q_producer_state; + + pipeline_load_compute_lse.producer_acquire(pipeline_load_compute_lse_producer_state); + + // load LSE + smem_idx = TileShapeQ{} * pipeline_load_compute_lse_producer_state.index() + thread_idx * kLoadPerThread; + gmem_idx = TileShapeQ{} * iter_index + thread_idx * kLoadPerThread; + for (int i = 
0; i < kLoadPerThread; i++) { + cutlass::arch::cp_async_zfill<4>( + shared_tensors.smem_lse.begin() + smem_idx + i, + &mLSE(gmem_idx + i, blk_coord_batch), + gmem_idx + i < Q + ); + } + + pipeline_load_compute_lse.producer_commit(pipeline_load_compute_lse_producer_state, cutlass::arch::cpasync_barrier_arrive); + ++pipeline_load_compute_lse_producer_state; + + pipeline_load_mma_do.producer_acquire(pipeline_load_mma_do_producer_state); + tma_barrier = pipeline_load_mma_do.producer_get_barrier(pipeline_load_mma_do_producer_state); + + // load dO + if (cute::elect_one_sync()) { + cute::copy( + mainloop_params.tma_load_do.with(*tma_barrier, mcast_mask), + tDOgDO_mkl(_, iter_index, _0{}, blk_coord_batch), + tDOsDO(_, pipeline_load_mma_do_producer_state.index()) + ); + } + + ++pipeline_load_mma_do_producer_state; + + pipeline_load_compute_sum_odo.producer_acquire(pipeline_load_compute_sum_odo_producer_state); + + // load sum_OdO + smem_idx = TileShapeQ{} * pipeline_load_compute_sum_odo_producer_state.index() + thread_idx * kLoadPerThread; + gmem_idx = TileShapeQ{} * iter_index + thread_idx * kLoadPerThread; + for (int i = 0; i < kLoadPerThread; i++) { + cutlass::arch::cp_async_zfill<4>( + shared_tensors.smem_sum_odo.begin() + smem_idx + i, + &mSumOdO(gmem_idx + i, blk_coord_batch), + gmem_idx + i < Q + ); + } + + pipeline_load_compute_sum_odo.producer_commit(pipeline_load_compute_sum_odo_producer_state, cutlass::arch::cpasync_barrier_arrive); + ++pipeline_load_compute_sum_odo_producer_state; + + iter_count -= 1; + iter_index += 1; + } + } + + + template + CUTLASS_DEVICE void mma( + BlkCoord const& blk_coord, + ProblemShape_ const& problem_shape, + int iter_index, + int iter_count, + MainloopArguments const& mainloop_args, + TensorStorage& shared_tensors, + PipelineLoadMmaQ& pipeline_load_mma_q, + typename PipelineLoadMmaQ::PipelineState& pipeline_load_mma_q_consumer_state, + PipelineLoadMmaDO& pipeline_load_mma_do, + typename PipelineLoadMmaDO::PipelineState& pipeline_load_mma_do_consumer_state, + PipelineMmaComputeS& pipeline_mma_compute_s, + typename PipelineMmaComputeS::PipelineState& pipeline_mma_compute_s_producer_state, + PipelineMmaComputeDP& pipeline_mma_compute_dp, + typename PipelineMmaComputeDP::PipelineState& pipeline_mma_compute_dp_producer_state, + PipelineMmaReduceDQ& pipeline_mma_reduce_dq, + typename PipelineMmaReduceDQ::PipelineState& pipeline_mma_reduce_dq_producer_state, + PipelineComputeMmaP& pipeline_compute_mma_p, + typename PipelineComputeMmaP::PipelineState& pipeline_compute_mma_p_consumer_state, + PipelineComputeMmaDS& pipeline_compute_mma_ds, + typename PipelineComputeMmaDS::PipelineState& pipeline_compute_mma_ds_consumer_state, + PipelineMmaComputeDKDV& pipeline_mma_compute_dkdv, + typename PipelineMmaComputeDKDV::PipelineState& pipeline_mma_compute_dkdv_producer_state) { + + auto [Q, K, D, D_VO, HB] = problem_shape; + + auto sQ = make_tensor(make_smem_ptr(shared_tensors.smem_q.begin()), SmemLayoutQ{}); + auto sK = make_tensor(make_smem_ptr(shared_tensors.smem_k.begin()), SmemLayoutK{}); + auto sV = make_tensor(make_smem_ptr(shared_tensors.smem_v.begin()), SmemLayoutV{}); + auto sDO = make_tensor(make_smem_ptr(shared_tensors.smem_do.begin()), SmemLayoutDO{}); + + auto sQT = make_tensor(make_smem_ptr(shared_tensors.smem_q_t.begin()), SmemLayoutQT{}); + auto sKT = make_tensor(make_smem_ptr(shared_tensors.smem_k_t.begin()), SmemLayoutKT{}); + auto sDS = make_tensor(make_smem_ptr(shared_tensors.smem_ds.begin()), SmemLayoutDS{}); + auto sDST = 
make_tensor(make_smem_ptr(shared_tensors.smem_ds_t.begin()), SmemLayoutDST{}); + auto sP = make_tensor(make_smem_ptr(shared_tensors.smem_p.begin()), SmemLayoutP{}); + auto sDOT = make_tensor(make_smem_ptr(shared_tensors.smem_do_t.begin()), SmemLayoutDOT{}); + + Tensor tSTrK = TiledMmaQK::make_fragment_B(sK); + Tensor tSTrQ = TiledMmaQK::make_fragment_A(sQ); + + Tensor tDPTrV = TiledMmaDOV::make_fragment_B(sV); + Tensor tDPTrDO = TiledMmaDOV::make_fragment_A(sDO); + + Tensor tDQrDS = TiledMmaDSK::make_fragment_A(sDS); + Tensor tDQrKT = TiledMmaDSK::make_fragment_B(sKT); + + Tensor tDKrDST = TiledMmaDSQ::make_fragment_A(sDST); + Tensor tDKrQT = TiledMmaDSQ::make_fragment_B(sQT); + + Tensor tDVrP = TiledMmaPDO::make_fragment_A(sP); + Tensor tDVrDOT = TiledMmaPDO::make_fragment_B(sDOT); + + TiledMmaQK tiled_mma_qk; + TiledMmaDOV tiled_mma_dov; + TiledMmaDSK tiled_mma_dsk; + TiledMmaDSQ tiled_mma_dsq; + TiledMmaPDO tiled_mma_pdo; + + tiled_mma_dsq.accumulate_ = UMMA::ScaleOut::Zero; + tiled_mma_pdo.accumulate_ = UMMA::ScaleOut::Zero; + + Tensor tSTtST = partition_fragment_C(tiled_mma_qk, select<0,1>(TileShapeQK{})); + tSTtST.data() = TmemAllocation::kS; + + Tensor tDPTtDPT = partition_fragment_C(tiled_mma_dov, select<0,1>(TileShapeDOV{})); + tDPTtDPT.data() = TmemAllocation::kDP; + + Tensor tDQtDQ = partition_fragment_C(tiled_mma_dsk, select<0,1>(TileShapeDSK{})); + tDQtDQ.data() = TmemAllocation::kDQ; + + Tensor tDKtDK = partition_fragment_C(tiled_mma_dsq, select<0,1>(TileShapeDSQ{})); + tDKtDK.data() = TmemAllocation::kDK; + + Tensor tDVtDV = partition_fragment_C(tiled_mma_pdo, select<0,1>(TileShapePDO{})); + tDVtDV.data() = TmemAllocation::kDV; + + auto pipeline_load_mma_q_release_state = pipeline_load_mma_q_consumer_state; + + pipeline_load_mma_q.consumer_wait(pipeline_load_mma_q_consumer_state); + pipeline_mma_compute_s.producer_acquire(pipeline_mma_compute_s_producer_state); + + // S = Q*K + tiled_mma_qk.accumulate_ = UMMA::ScaleOut::Zero; + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tSTrQ); ++k_block) { + cute::gemm(tiled_mma_qk, + tSTrQ(_,_,k_block,pipeline_load_mma_q_consumer_state.index()), + tSTrK(_,_,k_block,_0{}), + tSTtST); + tiled_mma_qk.accumulate_ = UMMA::ScaleOut::One; + } + + ++pipeline_load_mma_q_consumer_state; + + pipeline_mma_compute_s.producer_commit(pipeline_mma_compute_s_producer_state); + ++pipeline_mma_compute_s_producer_state; + + pipeline_load_mma_do.consumer_wait(pipeline_load_mma_do_consumer_state); + + pipeline_mma_compute_dp.producer_acquire(pipeline_mma_compute_dp_producer_state); + pipeline_mma_reduce_dq.producer_acquire(pipeline_mma_reduce_dq_producer_state); + + // dP = dO*V + tiled_mma_dov.accumulate_ = UMMA::ScaleOut::Zero; + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tDPTrV); ++k_block) { + cute::gemm(tiled_mma_dov, + tDPTrDO(_,_,k_block,pipeline_load_mma_do_consumer_state.index()), + tDPTrV(_,_,k_block,_0{}), + tDPTtDPT); + tiled_mma_dov.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_mma_compute_dp.producer_commit(pipeline_mma_compute_dp_producer_state); + ++pipeline_mma_compute_dp_producer_state; + + pipeline_compute_mma_p.consumer_wait(pipeline_compute_mma_p_consumer_state); + + // dV = P*dO + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tDVrP); ++k_block) { + cute::gemm(tiled_mma_pdo, + tDVrP(_,_,k_block,_0{}), + tDVrDOT(_,_,k_block,pipeline_load_mma_do_consumer_state.index()), + tDVtDV); + tiled_mma_pdo.accumulate_ = UMMA::ScaleOut::One; + } + + 
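+      // The dV GEMM above has consumed the current P tile from shared memory;
+      // releasing compute_mma_p below lets the compute warps write the next P tile.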
pipeline_compute_mma_p.consumer_release(pipeline_compute_mma_p_consumer_state); + ++pipeline_compute_mma_p_consumer_state; + + pipeline_load_mma_do.consumer_release(pipeline_load_mma_do_consumer_state); + ++pipeline_load_mma_do_consumer_state; + + iter_count -= 1; + + // in tmem, S & P overlap + // and dP and dQ overlap + // so we need to acquire dQ and dP at the same time + while (iter_count > 0) { + pipeline_load_mma_q.consumer_wait(pipeline_load_mma_q_consumer_state); + pipeline_mma_compute_s.producer_acquire(pipeline_mma_compute_s_producer_state); + + // S = Q*K + tiled_mma_qk.accumulate_ = UMMA::ScaleOut::Zero; + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tSTrQ); ++k_block) { + cute::gemm(tiled_mma_qk, + tSTrQ(_,_,k_block,pipeline_load_mma_q_consumer_state.index()), + tSTrK(_,_,k_block,_0{}), + tSTtST); + tiled_mma_qk.accumulate_ = UMMA::ScaleOut::One; + } + + ++pipeline_load_mma_q_consumer_state; + + pipeline_mma_compute_s.producer_commit(pipeline_mma_compute_s_producer_state); + ++pipeline_mma_compute_s_producer_state; + + pipeline_compute_mma_ds.consumer_wait(pipeline_compute_mma_ds_consumer_state); + + // we need to acquire dP here, because tmem dQ == tmem dP + pipeline_mma_compute_dp.producer_acquire(pipeline_mma_compute_dp_producer_state); + + // dQ = dS*K + tiled_mma_dsk.accumulate_ = UMMA::ScaleOut::Zero; + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tDQrDS); ++k_block) { + cute::gemm(tiled_mma_dsk, + tDQrDS(_,_,k_block,pipeline_compute_mma_ds_consumer_state.index()), + tDQrKT(_,_,k_block,_0{}), + tDQtDQ); + tiled_mma_dsk.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_mma_reduce_dq.producer_commit(pipeline_mma_reduce_dq_producer_state); + ++pipeline_mma_reduce_dq_producer_state; + + // dK = dS*Q + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tDKrDST); ++k_block) { + cute::gemm(tiled_mma_dsq, + tDKrDST(_,_,k_block,pipeline_compute_mma_ds_consumer_state.index()), + tDKrQT(_,_,k_block,pipeline_load_mma_q_release_state.index()), + tDKtDK); + tiled_mma_dsq.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_load_mma_q.consumer_release(pipeline_load_mma_q_release_state); + ++pipeline_load_mma_q_release_state; + + pipeline_compute_mma_ds.consumer_release(pipeline_compute_mma_ds_consumer_state); + ++pipeline_compute_mma_ds_consumer_state; + + // we grab dq here, because in tmem dq == dp + pipeline_mma_reduce_dq.producer_acquire(pipeline_mma_reduce_dq_producer_state); + + pipeline_load_mma_do.consumer_wait(pipeline_load_mma_do_consumer_state); + + // dP = dO*V + tiled_mma_dov.accumulate_ = UMMA::ScaleOut::Zero; + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tDPTrV); ++k_block) { + cute::gemm(tiled_mma_dov, + tDPTrDO(_,_,k_block,pipeline_load_mma_do_consumer_state.index()), + tDPTrV(_,_,k_block,_0{}), + tDPTtDPT); + tiled_mma_dov.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_mma_compute_dp.producer_commit(pipeline_mma_compute_dp_producer_state); + ++pipeline_mma_compute_dp_producer_state; + + pipeline_compute_mma_p.consumer_wait(pipeline_compute_mma_p_consumer_state); + + // dV = P*dO + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tDVrP); ++k_block) { + cute::gemm(tiled_mma_pdo, + tDVrP(_,_,k_block,_0{}), + tDVrDOT(_,_,k_block,pipeline_load_mma_do_consumer_state.index()), + tDVtDV); + tiled_mma_pdo.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_compute_mma_p.consumer_release(pipeline_compute_mma_p_consumer_state); + ++pipeline_compute_mma_p_consumer_state; + + 
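+      // Both GEMMs that read this dO tile (dP = dO*V and dV = P*dO) have been issued,
+      // so the load warp may reuse this dO smem stage.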
pipeline_load_mma_do.consumer_release(pipeline_load_mma_do_consumer_state); + ++pipeline_load_mma_do_consumer_state; + + iter_count -= 1; + } + + // signal to the epilogue that dV is ready + pipeline_mma_compute_dkdv.producer_acquire(pipeline_mma_compute_dkdv_producer_state); + pipeline_mma_compute_dkdv.producer_commit(pipeline_mma_compute_dkdv_producer_state); + ++pipeline_mma_compute_dkdv_producer_state; + + pipeline_mma_compute_dkdv.producer_acquire(pipeline_mma_compute_dkdv_producer_state); + + pipeline_compute_mma_ds.consumer_wait(pipeline_compute_mma_ds_consumer_state); + + // dK = dS*Q + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tDKrDST); ++k_block) { + cute::gemm(tiled_mma_dsq, + tDKrDST(_,_,k_block,pipeline_compute_mma_ds_consumer_state.index()), + tDKrQT(_,_,k_block,pipeline_load_mma_q_release_state.index()), + tDKtDK); + tiled_mma_dsq.accumulate_ = UMMA::ScaleOut::One; + } + + // signal to epilgue that dK is ready + pipeline_mma_compute_dkdv.producer_commit(pipeline_mma_compute_dkdv_producer_state); + ++pipeline_mma_compute_dkdv_producer_state; + + // we've already acquired mma_reduce_dq in the loop + + // dQ = dS*K + tiled_mma_dsk.accumulate_ = UMMA::ScaleOut::Zero; + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tDQrDS); ++k_block) { + cute::gemm(tiled_mma_dsk, + tDQrDS(_,_,k_block,pipeline_compute_mma_ds_consumer_state.index()), + tDQrKT(_,_,k_block,_0{}), + tDQtDQ); + tiled_mma_dsk.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_mma_reduce_dq.producer_commit(pipeline_mma_reduce_dq_producer_state); + ++pipeline_mma_reduce_dq_producer_state; + + pipeline_load_mma_q.consumer_release(pipeline_load_mma_q_release_state); + ++pipeline_load_mma_q_release_state; + + pipeline_compute_mma_ds.consumer_release(pipeline_compute_mma_ds_consumer_state); + ++pipeline_compute_mma_ds_consumer_state; + } + + + + template + CUTLASS_DEVICE void store( + TensorG gmem, + TensorR const& regs, + TensorC const& coord, + TensorShape const& tensor_shape) { + + Tensor preds = cute::lazy::transform(coord, [&](auto const& c) { return elem_less(c, tensor_shape); }); + + auto copy_op = make_cotiled_copy( + Copy_Atom, Element>{}, + make_layout(make_shape(_1{}, Int{})), + regs.layout() + ); + auto thr_copy = copy_op.get_slice(_0{}); + + Tensor quantized_regs = quantize(regs); + Tensor tCr = thr_copy.partition_S(quantized_regs); + Tensor tCg = thr_copy.partition_D(gmem); + Tensor tPc = thr_copy.partition_D(preds); + + copy_if(copy_op, tPc, tCr, tCg); + } + + + template + CUTLASS_DEVICE void epilogue_clear( + BlkCoord const& blk_coord, + BlkOffset const& blk_offset, + ProblemShape_ const& problem_shape, + MainloopArguments const& mainloop_args, + EpilogueArguments const& epilogue_args) { + + auto [Q, K, D, D_VO, HB] = problem_shape; + auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_dv, blk_coord_batch] = blk_coord; + + auto mDK_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dk), make_shape(K, TileShapeDQK{}, HB), epilogue_args.stride_dk); + auto mDK = domain_offset(select<1,2,4>(blk_offset), mDK_in); + auto gDK = local_tile(mDK, TileShapeDSQ{}, make_coord(_,_,_), Step<_1, _1, X>{}) + (_, _, blk_coord_k, _0{}, blk_coord_batch); + + Tensor cDK = domain_offset( + make_coord(get<1>(blk_coord) * TileShapeK{}, _0{}), + make_identity_tensor(take<0,2>(TileShapeDSQ{})) + ); + + auto mDV_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dv), make_shape(K, TileShapeDVO{}, HB), epilogue_args.stride_dv); + auto mDV = domain_offset(select<1,3,4>(blk_offset), mDV_in); + auto 
gDV = local_tile(mDV, TileShapePDO{}, make_coord(_,_,_), Step<_1, _1, X>{}) + (_, _, blk_coord_k, _0{}, blk_coord_batch); + + Tensor cDV = domain_offset( + make_coord(blk_coord_k * TileShapeK{}, _0{}), + make_identity_tensor(take<0,2>(TileShapePDO{})) + ); + + for (int i = threadIdx.x; i < size(gDK); i += blockDim.x) { + if (elem_less(cDK(i), select<1,2>(problem_shape))) { + gDK(i) = Element(0); + } + } + for (int i = threadIdx.x; i < size(gDV); i += blockDim.x) { + if (elem_less(cDV(i), select<1,3>(problem_shape))) { + gDV(i) = Element(0); + } + } + + } + + + template + CUTLASS_DEVICE void epilogue( + BlkCoord const& blk_coord, + BlkOffset const& blk_offset, + ProblemShape_ const& problem_shape, + MainloopArguments const& mainloop_args, + EpilogueArguments const& epilogue_args, + PipelineMmaComputeDKDV& pipeline_mma_compute_dkdv, + typename PipelineMmaComputeDKDV::PipelineState& pipeline_mma_compute_dkdv_consumer_state) { + + auto [Q, K, D, D_VO, HB] = problem_shape; + auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_dv, blk_coord_batch] = blk_coord; + + auto load_op = SM100_TMEM_LOAD_32dp32b16x{}; + + auto tDKtDK = partition_fragment_C(TiledMmaDSQ{}, select<0,1>(TileShapeDSQ{}))(make_coord(_,_),_0{},_0{}); + tDKtDK.data() = TmemAllocation::kDK; + + auto mDK_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dk), make_shape(K, TileShapeDQK{}, HB), epilogue_args.stride_dk); + auto mDK = domain_offset(select<1,2,4>(blk_offset), mDK_in); + auto gDK = local_tile(mDK, TileShapeDSQ{}, make_coord(_,_,_), Step<_1, _1, X>{}) + (_, _, blk_coord_k, _0{}, blk_coord_batch); + + Tensor cDK = domain_offset( + make_coord(get<1>(blk_coord) * TileShapeK{}, _0{}), + make_identity_tensor(take<0,2>(TileShapeDSQ{})) + ); + + constexpr int kNumWarpgroups = kNumComputeWarps / 4; + int dp_idx = threadIdx.x % 128; + int wg_idx = (threadIdx.x % (kNumComputeWarps * NumThreadsPerWarp)) / 128; + + auto split_wg = [&](auto const& t) { + if constexpr (decltype(rank(t))::value == 3) { + auto p = t.compose(make_layout(make_shape(size<0>(t), size<1>(t), make_shape(Int{}, size<2>(t) / Int{})))); + return p(_, _, make_coord(wg_idx, _)); + } + else { + auto p = t.compose(make_layout(make_shape(size<0>(t), size<1>(t), size<2>(t), make_shape(Int{}, size<3>(t) / Int{})))); + return p(_, _, _, make_coord(wg_idx, _)); + } + }; + + auto tiled_t2r_dk = make_tmem_copy(load_op, tDKtDK); + auto thread_t2r_dk = tiled_t2r_dk.get_slice(dp_idx); + + Tensor tTR_cDK = split_wg(thread_t2r_dk.partition_D(cDK)); + Tensor tTR_gDK = split_wg(thread_t2r_dk.partition_D(gDK)); + Tensor tTR_rDK = make_tensor(shape(tTR_cDK)); + Tensor tTR_tDK = split_wg(thread_t2r_dk.partition_S(tDKtDK)); + + auto tDVtDV = partition_fragment_C(TiledMmaPDO{}, select<0,1>(TileShapePDO{}))(make_coord(_,_),_0{},_0{}); + tDVtDV.data() = TmemAllocation::kDV; + + auto mDV_in = make_tensor(make_gmem_ptr(epilogue_args.ptr_dv), make_shape(K, TileShapeDVO{}, HB), epilogue_args.stride_dv); + auto mDV = domain_offset(select<1,3,4>(blk_offset), mDV_in); + auto gDV = local_tile(mDV, TileShapePDO{}, make_coord(_,_,_), Step<_1, _1, X>{}) + (_, _, blk_coord_k, _0{}, blk_coord_batch); + + Tensor cDV = domain_offset( + make_coord(blk_coord_k * TileShapeK{}, _0{}), + make_identity_tensor(take<0,2>(TileShapePDO{})) + ); + + auto tiled_t2r_dv = make_tmem_copy(load_op, tDVtDV); + auto thread_t2r_dv = tiled_t2r_dv.get_slice(dp_idx); + + Tensor tTR_cDV = split_wg(thread_t2r_dv.partition_D(cDV)); + Tensor tTR_gDV = split_wg(thread_t2r_dv.partition_D(gDV)); + Tensor tTR_rDV = 
make_tensor(shape(tTR_cDV)); + Tensor tTR_tDV = split_wg(thread_t2r_dv.partition_S(tDVtDV)); + + pipeline_mma_compute_dkdv.consumer_wait(pipeline_mma_compute_dkdv_consumer_state); + + // load tDVtDV + cute::copy(tiled_t2r_dv, tTR_tDV, tTR_rDV); + + // store tDVgDV + store(tTR_gDV, tTR_rDV, tTR_cDV, select<1,3>(problem_shape)); + + cutlass::arch::fence_view_async_tmem_load(); + pipeline_mma_compute_dkdv.consumer_release(pipeline_mma_compute_dkdv_consumer_state); + ++pipeline_mma_compute_dkdv_consumer_state; + + pipeline_mma_compute_dkdv.consumer_wait(pipeline_mma_compute_dkdv_consumer_state); + + // load tDKtDK + cute::copy(tiled_t2r_dk, tTR_tDK, tTR_rDK); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rDK); i++) { + tTR_rDK(i) = mainloop_args.softmax_scale * tTR_rDK(i); + } + + // store tDKgDK + store(tTR_gDK, tTR_rDK, tTR_cDK, select<1,2>(problem_shape)); + + cutlass::arch::fence_view_async_tmem_load(); + pipeline_mma_compute_dkdv.consumer_release(pipeline_mma_compute_dkdv_consumer_state); + ++pipeline_mma_compute_dkdv_consumer_state; + + } + + + template + CUTLASS_DEVICE void compute( + BlkCoord const& blk_coord, + BlkOffset const& blk_offset, + ProblemShape_ const& problem_shape, + int iter_index, + int iter_count, + MainloopArguments const& mainloop_args, + EpilogueArguments const& epilogue_args, + TensorStorage& shared_tensors, + PipelineLoadComputeLSE& pipeline_load_compute_lse, + typename PipelineLoadComputeLSE::PipelineState& pipeline_load_compute_lse_consumer_state, + PipelineLoadComputeSumOdO& pipeline_load_compute_sum_odo, + typename PipelineLoadComputeSumOdO::PipelineState& pipeline_load_compute_sum_odo_consumer_state, + PipelineMmaComputeS& pipeline_mma_compute_s, + typename PipelineMmaComputeS::PipelineState& pipeline_mma_compute_s_consumer_state, + PipelineMmaComputeDP& pipeline_mma_compute_dp, + typename PipelineMmaComputeDP::PipelineState& pipeline_mma_compute_dp_consumer_state, + PipelineComputeMmaP& pipeline_compute_mma_p, + typename PipelineComputeMmaP::PipelineState& pipeline_compute_mma_p_producer_state, + PipelineComputeMmaDS& pipeline_compute_mma_ds, + typename PipelineComputeMmaDS::PipelineState& pipeline_compute_mma_ds_producer_state, + PipelineMmaComputeDKDV& pipeline_mma_compute_dkdv, + typename PipelineMmaComputeDKDV::PipelineState& pipeline_mma_compute_dkdv_consumer_state) { + + + auto [Q, K, D, D_VO, HB] = problem_shape; + + // in tmem, S & P overlap + // and dP and dQ overlap + + // there are two compute wg's that cooperatively compute softmax + // they are striped by this tmem atom, i.e. 
wg0 has 16 elems, then wg1 etc + + auto load_op = SM100_TMEM_LOAD_16dp32b32x{}; + + Tensor tSTtST = partition_fragment_C(TiledMmaQK{}, select<0,1>(TileShapeQK{}))(make_coord(_,_),_0{},_0{}); + tSTtST.data() = TmemAllocation::kS; + + Tensor tDPTtDPT = partition_fragment_C(TiledMmaDOV{}, select<0,1>(TileShapeDOV{}))(make_coord(_,_),_0{},_0{}); + tDPTtDPT.data() = TmemAllocation::kDP; + + Tensor cST = make_identity_tensor(take<0,2>(TileShapeQK{})); + Tensor cDPT = make_identity_tensor(take<0,2>(TileShapeDOV{})); + Tensor cPT = make_identity_tensor(take<0,2>(TileShapeQK{})); + + constexpr int kNumWarpgroups = kNumComputeWarps / 4; + int dp_idx = threadIdx.x % 128; + int wg_idx = (threadIdx.x % (kNumComputeWarps * NumThreadsPerWarp)) / 128; + auto tiled_t2r = make_tmem_copy(load_op, tSTtST); + auto thread_t2r = tiled_t2r.get_slice(dp_idx); + + auto split_wg = [&](auto const& t) { + if constexpr (decltype(size<1>(t))::value > 1) { + if constexpr (decltype(rank(t))::value == 3) { + auto p = t.compose(make_layout(make_shape(size<0>(t), make_shape(Int{}, size<1>(t) / Int{}), size<2>(t)))); + return p(_, make_coord(wg_idx, _), _); + } + else { + auto p = t.compose(make_layout(make_shape(size<0>(t), make_shape(Int{}, size<1>(t) / Int{}), size<2>(t), size<3>(t)))); + return p(_, make_coord(wg_idx, _), _, _); + } + } + else { + if constexpr (decltype(rank(t))::value == 3) { + auto p = t.compose(make_layout(make_shape(size<0>(t), size<1>(t), make_shape(Int{}, size<2>(t) / Int{})))); + return p(_, _, make_coord(wg_idx, _)); + } + else { + auto p = t.compose(make_layout(make_shape(size<0>(t), size<1>(t), size<2>(t), make_shape(Int{}, size<3>(t) / Int{})))); + return p(_, _, _, make_coord(wg_idx, _)); + } + } + }; + + Tensor tTR_cST_p = thread_t2r.partition_D(cST); + Tensor tTR_cST = split_wg(tTR_cST_p); + Tensor tTR_rST = make_tensor(shape(tTR_cST)); + Tensor tTR_tST = split_wg(thread_t2r.partition_S(tSTtST)); + + Tensor tTR_cDPT_p = thread_t2r.partition_D(cDPT); + Tensor tTR_cPT_p = thread_t2r.partition_D(cPT); + Tensor tTR_cDPT = split_wg(tTR_cDPT_p); + Tensor tTR_rDPT = make_tensor(shape(tTR_cDPT)); + Tensor tTR_tDPT = split_wg(thread_t2r.partition_S(tDPTtDPT)); + + Tensor sLSE = make_tensor(make_smem_ptr(shared_tensors.smem_lse.begin()), SmemLayoutLSE{}); + Tensor sSumOdO = make_tensor(make_smem_ptr(shared_tensors.smem_sum_odo.begin()), SmemLayoutSumOdO{}); + + bool is_residual_k = get<1>(blk_coord) * TileShapeK{} + TileShapeK{} >= get<1>(problem_shape); + int last_iter = iter_count - 1 + iter_index; + + CUTLASS_PRAGMA_NO_UNROLL + while (iter_count > 0) { + // wait for S and P + pipeline_mma_compute_s.consumer_wait(pipeline_mma_compute_s_consumer_state); + pipeline_compute_mma_p.producer_acquire(pipeline_compute_mma_p_producer_state); + // wait for LSE + pipeline_load_compute_lse.consumer_wait(pipeline_load_compute_lse_consumer_state); + + auto dispatch_bool = [](bool b, auto fn) { + if (b) { + fn(cute::true_type{}); + } + else { + fn(cute::false_type{}); + } + }; + + bool leading_causal_masking = false; + if constexpr (std::is_base_of_v, Mask>) { + leading_causal_masking = warp_uniform(iter_index == get<1>(blk_coord)); + } else if constexpr (std::is_base_of_v, Mask>) { + int offset = get<1>(problem_shape) - get<0>(problem_shape); + int kv_left = get<1>(blk_coord) * TileShapeK{}; + int kv_right = kv_left + TileShapeK{} - 1; + int q_left = iter_index * TileShapeQ{} + offset; + int q_right = q_left + TileShapeQ{} - 1; + + leading_causal_masking = warp_uniform(!((q_left > kv_right) || (q_right < 
kv_left))); + } + bool trailing_residual_masking = false; + if constexpr (std::is_base_of_v) { + trailing_residual_masking = warp_uniform((iter_index == last_iter) || is_residual_k); + } + + dispatch_bool(leading_causal_masking || trailing_residual_masking, [&](auto is_masked_tile) { + + // compute P = softmax(S, LSE) + cute::copy(tiled_t2r, tTR_tST, tTR_rST); + + if constexpr (decltype(is_masked_tile)::value) { + Mask{}.apply_mask(tTR_rST, [&](int i) { + auto c_transpose = tTR_cST(i); + return make_coord(get<0>(c_transpose) + iter_index * TileShapeQ{}, get<1>(c_transpose) + get<1>(blk_coord) * TileShapeK{}); + }, problem_shape); + } + + ElementAcc log2_e = static_cast(M_LOG2E); + float2 softmax_scale_log2_e; + softmax_scale_log2_e.x = mainloop_args.softmax_scale * log2_e; + softmax_scale_log2_e.y = mainloop_args.softmax_scale * log2_e; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rST); i += 2) { + float2 acc; + float2 lse; + float2 out; + acc.x = tTR_rST(i); + acc.y = tTR_rST(i + 1); + lse.x = sLSE(get<0>(tTR_cST(i)), pipeline_load_compute_lse_consumer_state.index()); + lse.y = sLSE(get<0>(tTR_cST(i+1)), pipeline_load_compute_lse_consumer_state.index()); + cute::fma(out, softmax_scale_log2_e, acc, lse); + tTR_rST(i) = ::exp2f(out.x); + tTR_rST(i+1) = ::exp2f(out.y); + } + + auto tRT_rST = quantize(tTR_rST); + + Tensor sP = make_tensor(make_smem_ptr((Element*) shared_tensors.smem_p.begin()), SmemLayoutP{}) + (_, _, _, pipeline_compute_mma_p_producer_state.index()); + + cutlass::arch::fence_view_async_tmem_load(); + cutlass::arch::NamedBarrier( + kNumComputeWarps * NumThreadsPerWarp, + cutlass::arch::ReservedNamedBarriers::TransformBarrier + ).arrive_and_wait(); + + auto sP_pi = as_position_independent_swizzle_tensor(sP); + + auto thread_layout = make_ordered_layout( + make_shape(_64{}, _32{}, _2{}, _2{}), + make_stride(_3{}, _0{}, _1{}, _2{}) + ); + auto sP_pi_slice_p = sP_pi.compose(thread_layout)(((dp_idx/32) * 16) + (dp_idx % 16) , _, (dp_idx % 32 / 16), _).compose(make_layout(shape(tTR_cPT_p))); + auto sP_pi_slice = split_wg(sP_pi_slice_p); + copy_aligned(tRT_rST, sP_pi_slice); + }); + + // notify for P + cutlass::arch::fence_view_async_shared(); + pipeline_compute_mma_p.producer_commit(pipeline_compute_mma_p_producer_state); + ++pipeline_compute_mma_p_producer_state; + // release S + pipeline_mma_compute_s.consumer_release(pipeline_mma_compute_s_consumer_state); + ++pipeline_mma_compute_s_consumer_state; + // release LSE + pipeline_load_compute_lse.consumer_release(pipeline_load_compute_lse_consumer_state); + ++pipeline_load_compute_lse_consumer_state; + + // wait for OdO + pipeline_load_compute_sum_odo.consumer_wait(pipeline_load_compute_sum_odo_consumer_state); + // wait for dP + pipeline_mma_compute_dp.consumer_wait(pipeline_mma_compute_dp_consumer_state); + + // wait for dS + // in principle, we could defer waiting for dS, and move in the freeing of dP + // however, that would force us to keep dS in registers longer + pipeline_compute_mma_ds.producer_acquire(pipeline_compute_mma_ds_producer_state); + + // compute dS = dsoftmax(P, dP, sum_OdO) + cute::copy(tiled_t2r, tTR_tDPT, tTR_rDPT); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rDPT); i += 2) { + float2 st; + st.x = tTR_rST(i); + st.y = tTR_rST(i+1); + float2 dpt; + dpt.x = tTR_rDPT(i); + dpt.y = tTR_rDPT(i+1); + float2 odo; + odo.x = sSumOdO(get<0>(tTR_cDPT(i)), pipeline_load_compute_sum_odo_consumer_state.index()); + odo.y = sSumOdO(get<0>(tTR_cDPT(i+1)), 
pipeline_load_compute_sum_odo_consumer_state.index()); + float2 dif; + // sum odo is negated during preprocess + cute::add(dif, dpt, odo); + float2 out; + cute::mul(out, dif, st); + tTR_rDPT(i) = out.x; + tTR_rDPT(i+1) = out.y; + } + + auto tTR_rDST = quantize(tTR_rDPT); + + // release dP + cutlass::arch::fence_view_async_tmem_load(); + pipeline_mma_compute_dp.consumer_release(pipeline_mma_compute_dp_consumer_state); + ++pipeline_mma_compute_dp_consumer_state; + + Tensor sDS = make_tensor(make_smem_ptr((Element*) shared_tensors.smem_ds_t.begin()), SmemLayoutDST{}) + (_, _, _, pipeline_compute_mma_ds_producer_state.index()); + + auto thread_layout = make_ordered_layout( + make_shape(_64{}, _32{}, _2{}, _2{}), + make_stride(_3{}, _0{}, _1{}, _2{}) + ); + auto sDS_pi = as_position_independent_swizzle_tensor(sDS); + auto sDS_pi_slice_p = sDS_pi.compose(thread_layout)(((dp_idx/32) * 16) + (dp_idx % 16) , _, (dp_idx % 32 / 16), _).compose(make_layout(shape (tTR_cDPT_p))); + auto sDS_pi_slice = split_wg(sDS_pi_slice_p); + + copy_aligned(tTR_rDST, sDS_pi_slice); + + // notify for dS + cutlass::arch::fence_view_async_shared(); + pipeline_compute_mma_ds.producer_commit(pipeline_compute_mma_ds_producer_state); + ++pipeline_compute_mma_ds_producer_state; + // release OdO + pipeline_load_compute_sum_odo.consumer_release(pipeline_load_compute_sum_odo_consumer_state); + ++pipeline_load_compute_sum_odo_consumer_state; + + iter_count -= 1; + iter_index += 1; + } + + epilogue( + blk_coord, blk_offset, problem_shape, mainloop_args, epilogue_args, + pipeline_mma_compute_dkdv, pipeline_mma_compute_dkdv_consumer_state + ); + } + + template + CUTLASS_DEVICE void reduce( + BlkCoord const& blk_coord, + ProblemShape_ const& problem_shape, + int iter_index, + int iter_count, + MainloopArguments const& mainloop_args, + MainloopParams const& mainloop_params, + TensorStorage& shared_tensors, + PipelineMmaReduceDQ& pipeline_mma_reduce_dq, + typename PipelineMmaReduceDQ::PipelineState& pipeline_mma_reduce_dq_consumer_state, + PipelineReduceTmaStore& pipeline_reduce_tma_store, + typename PipelineReduceTmaStore::PipelineState& pipeline_reduce_tma_store_producer_state) { + + using X = Underscore; + + auto [Q, K, D, D_VO, HB] = problem_shape; + + auto [blk_coord_q, blk_coord_k, blk_coord_d, blk_coord_dv, blk_coord_batch] = blk_coord; + + // must match TileShapeDQ + auto load_op = SM100_TMEM_LOAD_16dp32b16x{}; + + auto tDQtDQ = partition_fragment_C(TiledMmaDSK{}, select<0,1>(TileShapeDSK{}))(make_coord(_,_),_0{},_0{}); + tDQtDQ.data() = TmemAllocation::kDQ; + + Tensor mDQ = mainloop_params.tma_red_dq.get_tma_tensor(make_shape(Q, D, HB)); + auto gDQ = local_tile(mDQ, TileShapeQK{}, make_coord(_,_,_), Step<_1, X, _1>{}) + (_, _, _, _0{}, blk_coord_batch); + + Tensor cDQ = make_identity_tensor(take<0,2>(TileShapeDSK{})); + + Tensor sDQ = make_tensor(make_smem_ptr(shared_tensors.smem_dq.begin()), SmemLayoutDQ{}); + + int thread_idx = threadIdx.x % (kNumReduceWarps * NumThreadsPerWarp); + auto tiled_t2r = make_tmem_copy(load_op, tDQtDQ); + auto thread_t2r = tiled_t2r.get_slice(thread_idx); + + Tensor tTR_cDQ = thread_t2r.partition_D(cDQ); + Tensor tTR_gDQ = thread_t2r.partition_D(gDQ); + Tensor tTR_sDQ = thread_t2r.partition_D(sDQ); + Tensor tTR_tDQ = thread_t2r.partition_S(tDQtDQ); + + auto block_tma = mainloop_params.tma_red_dq.get_slice(_0{}); + + Tensor tDQsDQ = block_tma.partition_S(sDQ); + Tensor tDQcDQ = block_tma.partition_S(cDQ); + Tensor tDQgDQ = block_tma.partition_D(gDQ); + + int lane_predicate = (threadIdx.x % 
(kNumReduceWarps * NumThreadsPerWarp)) == 0; + + while (iter_count > 0) { + pipeline_mma_reduce_dq.consumer_wait(pipeline_mma_reduce_dq_consumer_state); + + Tensor tTR_rDQ = make_tensor(shape(tTR_cDQ)); + + // load dQ from tmem to rmem + cute::copy(tiled_t2r, tTR_tDQ, tTR_rDQ); + + cutlass::arch::fence_view_async_tmem_load(); + pipeline_mma_reduce_dq.consumer_release(pipeline_mma_reduce_dq_consumer_state); + ++pipeline_mma_reduce_dq_consumer_state; + + // we don't have enough smem to dump it all to smem, so we do it in stages + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<2>(tTR_cDQ); i++) { + if (lane_predicate) { + pipeline_reduce_tma_store.producer_acquire(pipeline_reduce_tma_store_producer_state); + } + // wait in all threads for the acquire to complete + cutlass::arch::NamedBarrier( + kNumReduceWarps * NumThreadsPerWarp, + cutlass::arch::ReservedNamedBarriers::TransposeBarrier + ).arrive_and_wait(); + + cute::copy(tTR_rDQ(_, _, i), tTR_sDQ(_, _, _0{}, pipeline_reduce_tma_store_producer_state.index())); + + // wait for the stores to all be visible to the TMA + cutlass::arch::fence_view_async_shared(); + cutlass::arch::NamedBarrier( + kNumReduceWarps * NumThreadsPerWarp, + cutlass::arch::ReservedNamedBarriers::TransposeBarrier + ).arrive_and_wait(); + if (lane_predicate) { + // launch tma store + copy(mainloop_params.tma_red_dq, tDQsDQ(_,_,_0{}, pipeline_reduce_tma_store_producer_state.index()), tDQgDQ(_,_,i,iter_index)); + pipeline_reduce_tma_store.producer_commit(pipeline_reduce_tma_store_producer_state); + } + + ++pipeline_reduce_tma_store_producer_state; + } + + iter_count -= 1; + iter_index += 1; + } + } + + + CUTLASS_DEVICE void operator()(Params const& params, char* smem) { +#if (! defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) && ! defined(CUTLASS_ARCH_MMA_SM100F_ENABLED)) + printf("ERROR : Arch conditional MMA instruction used without targeting appropriate compute capability. 
Aborting.\n"); +#else + int warp_idx = cutlass::canonical_warp_idx_sync(); + auto role = warp_idx_to_role(warp_idx); + uint32_t lane_predicate = cute::elect_one_sync(); + + if (role == WarpRole::Load && lane_predicate) { + prefetch_tma_descriptor(params.mainloop_params.tma_load_q.get_tma_descriptor()); + prefetch_tma_descriptor(params.mainloop_params.tma_load_k.get_tma_descriptor()); + prefetch_tma_descriptor(params.mainloop_params.tma_load_v.get_tma_descriptor()); + prefetch_tma_descriptor(params.mainloop_params.tma_load_do.get_tma_descriptor()); + } + + SharedStorage& shared_storage = *reinterpret_cast(smem); + + int initializing_warp = 0; + typename PipelineLoadMmaQ::Params pipeline_load_mma_q_params; + if (role == WarpRole::Load) { + pipeline_load_mma_q_params.role = PipelineLoadMmaQ::ThreadCategory::Producer; + } + if (role == WarpRole::Mma) { + pipeline_load_mma_q_params.role = PipelineLoadMmaQ::ThreadCategory::Consumer; + } + pipeline_load_mma_q_params.is_leader = lane_predicate && (role == WarpRole::Load); + // Also loads K in the first iteration + pipeline_load_mma_q_params.transaction_bytes = kTransactionsBytesLoadQ; + pipeline_load_mma_q_params.initializing_warp = initializing_warp++; + PipelineLoadMmaQ pipeline_load_mma_q(shared_storage.pipelines.load_mma_q, pipeline_load_mma_q_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename PipelineLoadMmaDO::Params pipeline_load_mma_do_params; + if (role == WarpRole::Load) { + pipeline_load_mma_do_params.role = PipelineLoadMmaDO::ThreadCategory::Producer; + } + if (role == WarpRole::Mma) { + pipeline_load_mma_do_params.role = PipelineLoadMmaDO::ThreadCategory::Consumer; + } + pipeline_load_mma_do_params.is_leader = lane_predicate && (role == WarpRole::Load); + // Also loads V in the first iteration + pipeline_load_mma_do_params.transaction_bytes = kTransactionsBytesLoadDO; + pipeline_load_mma_do_params.initializing_warp = initializing_warp++; + PipelineLoadMmaDO pipeline_load_mma_do(shared_storage.pipelines.load_mma_do, pipeline_load_mma_do_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename PipelineLoadComputeLSE::Params pipeline_load_compute_lse_params; + if (role == WarpRole::Load) { + pipeline_load_compute_lse_params.role = PipelineLoadComputeLSE::ThreadCategory::Producer; + } + if (role == WarpRole::Compute) { + pipeline_load_compute_lse_params.role = PipelineLoadComputeLSE::ThreadCategory::Consumer; + } + pipeline_load_compute_lse_params.producer_arv_count = NumThreadsPerWarp; + pipeline_load_compute_lse_params.consumer_arv_count = kNumComputeWarps * NumThreadsPerWarp; + pipeline_load_compute_lse_params.initializing_warp = initializing_warp++; + PipelineLoadComputeLSE pipeline_load_compute_lse( + shared_storage.pipelines.load_compute_lse, + pipeline_load_compute_lse_params, + /*barrier init*/ cute::true_type{}); + + typename PipelineLoadComputeSumOdO::Params pipeline_load_compute_sum_odo_params; + if (role == WarpRole::Load) { + pipeline_load_compute_sum_odo_params.role = PipelineLoadComputeSumOdO::ThreadCategory::Producer; + } + if (role == WarpRole::Compute) { + pipeline_load_compute_sum_odo_params.role = PipelineLoadComputeSumOdO::ThreadCategory::Consumer; + } + pipeline_load_compute_sum_odo_params.producer_arv_count = NumThreadsPerWarp; + pipeline_load_compute_sum_odo_params.consumer_arv_count = kNumComputeWarps * NumThreadsPerWarp; + pipeline_load_compute_sum_odo_params.initializing_warp = initializing_warp++; + 
PipelineLoadComputeSumOdO pipeline_load_compute_sum_odo( + shared_storage.pipelines.load_compute_sum_odo, + pipeline_load_compute_sum_odo_params, + /*barrier init*/ cute::true_type{}); + + typename PipelineMmaComputeS::Params pipeline_mma_compute_s_params; + if (role == WarpRole::Mma) { + pipeline_mma_compute_s_params.role = PipelineMmaComputeS::ThreadCategory::Producer; + } + if (role == WarpRole::Compute) { + pipeline_mma_compute_s_params.role = PipelineMmaComputeS::ThreadCategory::Consumer; + } + pipeline_mma_compute_s_params.consumer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp; + pipeline_mma_compute_s_params.initializing_warp = initializing_warp++; + PipelineMmaComputeS pipeline_mma_compute_s( + shared_storage.pipelines.mma_compute_s, + pipeline_mma_compute_s_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename PipelineMmaComputeDP::Params pipeline_mma_compute_dp_params; + if (role == WarpRole::Mma) { + pipeline_mma_compute_dp_params.role = PipelineMmaComputeDP::ThreadCategory::Producer; + } + if (role == WarpRole::Compute) { + pipeline_mma_compute_dp_params.role = PipelineMmaComputeDP::ThreadCategory::Consumer; + } + pipeline_mma_compute_dp_params.consumer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp; + pipeline_mma_compute_dp_params.initializing_warp = initializing_warp++; + PipelineMmaComputeDP pipeline_mma_compute_dp( + shared_storage.pipelines.mma_compute_dp, + pipeline_mma_compute_dp_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename PipelineMmaReduceDQ::Params pipeline_mma_reduce_dq_params; + if (role == WarpRole::Mma) { + pipeline_mma_reduce_dq_params.role = PipelineMmaReduceDQ::ThreadCategory::Producer; + } + if (role == WarpRole::Reduce) { + pipeline_mma_reduce_dq_params.role = PipelineMmaReduceDQ::ThreadCategory::Consumer; + } + pipeline_mma_reduce_dq_params.consumer_arv_count = kNumReduceWarps * cutlass::NumThreadsPerWarp; + pipeline_mma_reduce_dq_params.initializing_warp = initializing_warp++; + PipelineMmaReduceDQ pipeline_mma_reduce_dq( + shared_storage.pipelines.mma_reduce_dq, + pipeline_mma_reduce_dq_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename PipelineComputeMmaP::Params pipeline_compute_mma_p_params; + if (role == WarpRole::Mma) { + pipeline_compute_mma_p_params.role = PipelineComputeMmaP::ThreadCategory::Consumer; + } + if (role == WarpRole::Compute) { + pipeline_compute_mma_p_params.role = PipelineComputeMmaP::ThreadCategory::Producer; + } + pipeline_compute_mma_p_params.producer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp; + pipeline_compute_mma_p_params.consumer_arv_count = 1; + pipeline_compute_mma_p_params.initializing_warp = initializing_warp++; + PipelineComputeMmaP pipeline_compute_mma_p( + shared_storage.pipelines.compute_mma_p, + pipeline_compute_mma_p_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename PipelineComputeMmaDS::Params pipeline_compute_mma_ds_params; + if (role == WarpRole::Mma) { + pipeline_compute_mma_ds_params.role = PipelineComputeMmaDS::ThreadCategory::Consumer; + } + if (role == WarpRole::Compute) { + pipeline_compute_mma_ds_params.role = PipelineComputeMmaDS::ThreadCategory::Producer; + } + pipeline_compute_mma_ds_params.producer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp; + pipeline_compute_mma_ds_params.consumer_arv_count = 1; + 
pipeline_compute_mma_ds_params.initializing_warp = initializing_warp++; + PipelineComputeMmaDS pipeline_compute_mma_ds( + shared_storage.pipelines.compute_mma_ds, + pipeline_compute_mma_ds_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename PipelineMmaComputeDKDV::Params pipeline_mma_compute_dkdv_params; + if (role == WarpRole::Mma) { + pipeline_mma_compute_dkdv_params.role = PipelineMmaComputeDKDV::ThreadCategory::Producer; + } + if (role == WarpRole::Compute) { + pipeline_mma_compute_dkdv_params.role = PipelineMmaComputeDKDV::ThreadCategory::Consumer; + } + pipeline_mma_compute_dkdv_params.consumer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp; + pipeline_mma_compute_dkdv_params.initializing_warp = initializing_warp++; + PipelineMmaComputeDKDV pipeline_mma_compute_dkdv( + shared_storage.pipelines.mma_compute_dkdv, + pipeline_mma_compute_dkdv_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + PipelineReduceTmaStore pipeline_reduce_tma_store; + + TmemAllocator tmem_allocator; + + pipeline_init_arrive_relaxed(size(ClusterShape{})); + + pipeline_load_mma_q.init_masks(ClusterShape{}); + pipeline_load_mma_do.init_masks(ClusterShape{}); + pipeline_mma_compute_s.init_masks(ClusterShape{}); + pipeline_mma_compute_dp.init_masks(ClusterShape{}); + pipeline_mma_reduce_dq.init_masks(ClusterShape{}); + pipeline_compute_mma_p.init_masks(ClusterShape{}); + pipeline_compute_mma_ds.init_masks(ClusterShape{}); + pipeline_mma_compute_dkdv.init_masks(ClusterShape{}); + + typename decltype(pipeline_load_mma_q)::PipelineState pipeline_load_mma_q_consumer_state; + typename decltype(pipeline_load_mma_do)::PipelineState pipeline_load_mma_do_consumer_state; + typename decltype(pipeline_load_compute_lse)::PipelineState pipeline_load_compute_lse_consumer_state; + typename decltype(pipeline_load_compute_sum_odo)::PipelineState pipeline_load_compute_sum_odo_consumer_state; + typename decltype(pipeline_mma_compute_s)::PipelineState pipeline_mma_compute_s_consumer_state; + typename decltype(pipeline_mma_compute_dp)::PipelineState pipeline_mma_compute_dp_consumer_state; + typename decltype(pipeline_mma_reduce_dq)::PipelineState pipeline_mma_reduce_dq_consumer_state; + typename decltype(pipeline_compute_mma_p)::PipelineState pipeline_compute_mma_p_consumer_state; + typename decltype(pipeline_compute_mma_ds)::PipelineState pipeline_compute_mma_ds_consumer_state; + typename decltype(pipeline_mma_compute_dkdv)::PipelineState pipeline_mma_compute_dkdv_consumer_state; + + auto pipeline_load_mma_q_producer_state = make_producer_start_state(); + auto pipeline_load_mma_do_producer_state = make_producer_start_state(); + auto pipeline_load_compute_lse_producer_state = make_producer_start_state(); + auto pipeline_load_compute_sum_odo_producer_state = make_producer_start_state(); + auto pipeline_mma_compute_s_producer_state = make_producer_start_state(); + auto pipeline_mma_compute_dp_producer_state = make_producer_start_state(); + auto pipeline_mma_reduce_dq_producer_state = make_producer_start_state(); + auto pipeline_compute_mma_p_producer_state = make_producer_start_state(); + auto pipeline_compute_mma_ds_producer_state = make_producer_start_state(); + auto pipeline_mma_compute_dkdv_producer_state = make_producer_start_state(); + auto pipeline_reduce_tma_store_producer_state = make_producer_start_state(); + + pipeline_init_wait(size(ClusterShape{})); + + auto blk_coord = make_coord(_0{}, blockIdx.x, _0{}, _0{}, 
make_coord(blockIdx.y, blockIdx.z)); + auto [problem_shape, blk_offset] = apply_variable_length_offset( + params.problem_shape, + blk_coord + ); + int iter_count = ceil_div(get<0>(problem_shape), TileShapeQ{}); + int iter_start = 0; + if constexpr (std::is_base_of_v, Mask>) { + iter_start = (get<1>(blk_coord) * TileShapeK{}) / TileShapeQ{}; + } else if constexpr (std::is_base_of_v, Mask>) { + int offset = get<1>(problem_shape) - get<0>(problem_shape); + iter_start = max(0, (int(get<1>(blk_coord) * TileShapeK{}) - offset) / (int)TileShapeQ{}); + } + if (get<1>(blk_coord) * TileShapeK{} >= get<1>(problem_shape)) { + return; + } + iter_count -= iter_start; + + if (iter_count <= 0) { + epilogue_clear( + blk_coord, + blk_offset, + problem_shape, + params.mainloop, + params.epilogue + ); + return; + } + + if (role == WarpRole::Load) { + warpgroup_reg_set(); + + load( + blk_coord, + blk_offset, + problem_shape, + iter_start, + iter_count, + params.mainloop, + params.mainloop_params, + shared_storage.tensors, + pipeline_load_mma_q, pipeline_load_mma_q_producer_state, + pipeline_load_mma_do, pipeline_load_mma_do_producer_state, + pipeline_load_compute_lse, pipeline_load_compute_lse_producer_state, + pipeline_load_compute_sum_odo, pipeline_load_compute_sum_odo_producer_state + ); + + } + else if (role == WarpRole::Mma) { + warpgroup_reg_set(); + + tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr); + __syncwarp(); + + mma( + blk_coord, + problem_shape, + iter_start, + iter_count, + params.mainloop, + shared_storage.tensors, + pipeline_load_mma_q, pipeline_load_mma_q_consumer_state, + pipeline_load_mma_do, pipeline_load_mma_do_consumer_state, + pipeline_mma_compute_s, pipeline_mma_compute_s_producer_state, + pipeline_mma_compute_dp, pipeline_mma_compute_dp_producer_state, + pipeline_mma_reduce_dq, pipeline_mma_reduce_dq_producer_state, + pipeline_compute_mma_p, pipeline_compute_mma_p_consumer_state, + pipeline_compute_mma_ds, pipeline_compute_mma_ds_consumer_state, + pipeline_mma_compute_dkdv, pipeline_mma_compute_dkdv_producer_state + ); + + } + else if (role == WarpRole::Compute) { + warpgroup_reg_set(); + + compute( + blk_coord, + blk_offset, + problem_shape, + iter_start, + iter_count, + params.mainloop, + params.epilogue, + shared_storage.tensors, + pipeline_load_compute_lse, pipeline_load_compute_lse_consumer_state, + pipeline_load_compute_sum_odo, pipeline_load_compute_sum_odo_consumer_state, + pipeline_mma_compute_s, pipeline_mma_compute_s_consumer_state, + pipeline_mma_compute_dp, pipeline_mma_compute_dp_consumer_state, + pipeline_compute_mma_p, pipeline_compute_mma_p_producer_state, + pipeline_compute_mma_ds, pipeline_compute_mma_ds_producer_state, + pipeline_mma_compute_dkdv, pipeline_mma_compute_dkdv_consumer_state + ); + + cutlass::arch::NamedBarrier( + kNumComputeWarps * NumThreadsPerWarp, + cutlass::arch::ReservedNamedBarriers::EpilogueBarrier + ).arrive_and_wait(); + + if (warp_idx % kNumComputeWarps == 0) { + uint32_t free_stage_ptr = shared_storage.tmem_base_ptr; + tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns); + } + + } + else if (role == WarpRole::Reduce) { + warpgroup_reg_set(); + + reduce( + blk_coord, + problem_shape, + iter_start, + iter_count, + params.mainloop, + params.mainloop_params, + shared_storage.tensors, + pipeline_mma_reduce_dq, pipeline_mma_reduce_dq_consumer_state, + pipeline_reduce_tma_store, pipeline_reduce_tma_store_producer_state + ); + + 
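+      // Drain the TMA-store pipeline before this warp exits: producer_tail is expected
+      // to block until all committed dQ reduce-add stores have completed.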
pipeline_reduce_tma_store.producer_tail(pipeline_reduce_tma_store_producer_state); + } + else { + warpgroup_reg_set(); + + /* no-op */ + + } +#endif + } + + static dim3 get_block_shape() { + dim3 block(MaxThreadsPerBlock, 1, 1); + return block; + } + + static dim3 get_grid_shape(Params const& params) { + auto [Q, K, D, D_VO, HB] = params.problem_shape; + auto [H, B] = HB; + dim3 grid(ceil_div(K, TileShapeK{}), H, B); + return grid; + } +}; + +} // namespace cutlass::fmha::kernel diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp new file mode 100644 index 0000000000000000000000000000000000000000..a88b9a871098f27cf569dc20434c1f35a103d9ff --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/kernel/sm100_fmha_fwd_kernel_tma_warpspecialized.hpp @@ -0,0 +1,640 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cute/layout.hpp" +#include "cutlass/arch/arch.h" +#include "cutlass/kernel_hardware_info.h" +#include "cutlass/pipeline/pipeline.hpp" +#include "cute/arch/tmem_allocator_sm100.hpp" + +#include "kernel/fmha_options.hpp" +#include "kernel/fmha_tile_scheduler.hpp" +#include "kernel/fmha_causal_tile_scheduler.hpp" +#include "collective/fmha_fusion.hpp" +#include "collective/fmha_common.hpp" + +namespace cutlass::fmha::kernel { + +using namespace cute; +using namespace cutlass::fmha::collective; + +struct Sm100FmhaCtxKernelWarpspecializedSchedule { + + enum class WarpRole { + Softmax0, + Softmax1, + Correction, + MMA, + Load, + Epilogue, + Empty + }; + + static constexpr WarpRole warp_idx_to_WarpRole(int warp_idx) { + int wg_idx = warp_idx / 4; // warp_idx + if (wg_idx == 0) return WarpRole::Softmax0; // 0 - 3 + if (wg_idx == 1) return WarpRole::Softmax1; // 4 - 7 + if (wg_idx == 2) return WarpRole::Correction; // 8 - 11 + if (warp_idx == 12) return WarpRole::MMA; // 12 + if (warp_idx == 13) return WarpRole::Load; // 13 + if (warp_idx == 14) return WarpRole::Epilogue; // 14 + return WarpRole::Empty; // 15 + } + + static const int NumWarpsSoftmax = 4; + static const int NumWarpsCorrection = 4; + static const int NumWarpsEpilogue = 1; + static const int NumWarpsLoad = 1; + + static const bool kDebugUsingPrintf = false; + static const int NumRegsSoftmax = 192; + static const int NumRegsCorrection = 96 - (kDebugUsingPrintf ? 16 : 0); + static const int NumRegsOther = 32 + (kDebugUsingPrintf ? 16 : 0); + static const int NumRegsEmpty = 24; + + static const int NumWarps = 16; + +}; + + +struct Sm100MlaFwdCtxKernelWarpspecializedSchedule { + + enum class WarpRole { + Softmax0, + Softmax1, + Correction, + MMA, + Load, + Epilogue, + Empty + }; + + static constexpr WarpRole warp_idx_to_WarpRole(int warp_idx) { + int wg_idx = warp_idx / 4; // warp_idx + if (wg_idx == 0) return WarpRole::Softmax0; // 0 - 3 + if (wg_idx == 1) return WarpRole::Softmax1; // 4 - 7 + if (wg_idx == 2) return WarpRole::Correction; // 8 - 11 + if (warp_idx == 12) return WarpRole::MMA; // 12 + if (warp_idx == 13) return WarpRole::Load; // 13 + if (warp_idx == 14) return WarpRole::Epilogue; // 14 + return WarpRole::Empty; // 15 + } + + static const int NumWarpsSoftmax = 4; + static const int NumWarpsCorrection = 4; + static const int NumWarpsEpilogue = 1; + static const int NumWarpsLoad = 1; + + static const bool kDebugUsingPrintf = false; + static const int NumRegsSoftmax = 184; + static const int NumRegsCorrection = 96 - (kDebugUsingPrintf ? 16 : 0); + static const int NumRegsOther = 48 + (kDebugUsingPrintf ? 
16 : 0); + static const int NumRegsEmpty = 24; + + static const int NumWarps = 16; + +}; + +template< + class ProblemShapeIn, + class CollectiveMainloop, + class CollectiveEpilogue, + class TileScheduler, + class KernelSchedule = Sm100FmhaCtxKernelWarpspecializedSchedule +> +struct Sm100FmhaFwdKernelTmaWarpspecialized { + + using TileShape = typename CollectiveMainloop::TileShape; + using ProblemShape = ProblemShapeIn; + + using WarpRole = typename KernelSchedule::WarpRole; + + constexpr WarpRole warp_idx_to_WarpRole(int warp_idx) { + return KernelSchedule::warp_idx_to_WarpRole(warp_idx); + } + + static const int NumWarpsSoftmax = KernelSchedule::NumWarpsSoftmax; + static const int NumWarpsCorrection = KernelSchedule::NumWarpsCorrection; + static const int NumWarpsEpilogue = KernelSchedule::NumWarpsEpilogue; + static const int NumWarpsLoad = KernelSchedule::NumWarpsLoad; + + static_assert(NumWarpsEpilogue == CollectiveEpilogue::NumWarpsEpilogue); + static_assert(NumWarpsLoad == CollectiveEpilogue::NumWarpsLoad); + + static const int NumRegsSoftmax = KernelSchedule::NumRegsSoftmax; + static const int NumRegsCorrection = KernelSchedule::NumRegsCorrection; + static const int NumRegsOther = KernelSchedule::NumRegsOther; + static const int NumRegsEmpty = 24; + + static const int NumWarps = KernelSchedule::NumWarps; + + static constexpr bool IsMla = std::is_same_v; + + using ClusterShape = typename CollectiveMainloop::ClusterShape; + + using TmemAllocator = cute::TMEM::Allocator1Sm; + + struct SharedStorage { + using UnionType = union { + typename CollectiveMainloop::TensorStorage mainloop; + typename CollectiveEpilogue::TensorStorage epilogue; + }; + + using StructType = struct { + typename CollectiveMainloop::TensorStorage mainloop; + typename CollectiveEpilogue::TensorStorage epilogue; + }; + + static constexpr bool IsPersistent = std::is_same_v || std::is_same_v; + using MainloopEpilogueStorage = std::conditional_t, + StructType>, + UnionType>; + + MainloopEpilogueStorage mainloop_epilogue; + + struct PipelineStorage { + alignas(16) typename CollectiveMainloop::PipelineQ::SharedStorage load_q; + alignas(16) typename CollectiveMainloop::PipelineKV::SharedStorage load_kv; + alignas(16) typename CollectiveMainloop::PipelineS::SharedStorage mma_s0; + alignas(16) typename CollectiveMainloop::PipelineS::SharedStorage mma_s1; + alignas(16) typename CollectiveMainloop::PipelineC::SharedStorage s0_corr; + alignas(16) typename CollectiveMainloop::PipelineC::SharedStorage s1_corr; + alignas(16) typename CollectiveMainloop::PipelineO::SharedStorage mma_corr; + alignas(16) typename CollectiveMainloop::PipelineE::SharedStorage corr_epi; + alignas(16) typename CollectiveMainloop::OrderBarrierSoftmax::SharedStorage order_s01; + } pipelines; + + uint32_t tmem_base_ptr; + }; + + static constexpr int SharedStorageSize = sizeof(SharedStorage); + + struct Arguments { + ProblemShape problem_shape; + typename CollectiveMainloop::Arguments mainloop; + typename CollectiveEpilogue::Arguments epilogue; + cutlass::KernelHardwareInfo hw_info; + }; + + struct Params { + ProblemShape problem_shape; + typename CollectiveMainloop::Params mainloop; + typename CollectiveEpilogue::Params epilogue; + typename TileScheduler::Params tile_scheduler; + }; + + static const int MinBlocksPerMultiprocessor = 1; + static const int MaxThreadsPerBlock = NumWarps * cutlass::NumThreadsPerWarp; + using ArchTag = cutlass::arch::Sm100; + + static size_t get_workspace_size(Arguments const& args) { return 0; } + static cutlass::Status 
initialize_workspace(Arguments const&, void*, cudaStream_t) { + return cutlass::Status::kSuccess; + } + + static bool can_implement(Arguments const& args) { + return CollectiveMainloop::can_implement(args.problem_shape, args.mainloop); + } + + static dim3 get_grid_shape(Params const& params) { + return TileScheduler::get_grid_shape(params.tile_scheduler); + } + + static dim3 get_block_shape() { + dim3 block(MaxThreadsPerBlock, 1, 1); + return block; + } + + static Params to_underlying_arguments(Arguments const& args, void* workspace) { + return Params{ + args.problem_shape, + CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, workspace), + CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, workspace), + TileScheduler::to_underlying_arguments(args.problem_shape, args.hw_info, ClusterShape{}, TileShape{}) + }; + } + + CUTLASS_DEVICE auto apply_batch(const Params ¶ms, ProblemShape const& problem_shape, int batch_idx) { + return apply_variable_length(params.problem_shape, batch_idx); + } + + CUTLASS_DEVICE void operator()(const Params ¶ms, char* smem) { +#if (! defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) && ! defined(CUTLASS_ARCH_MMA_SM100F_ENABLED)) + printf("ERROR : Arch conditional MMA instruction used without targeting appropriate compute capability. Aborting.\n"); +#else + + TileScheduler tile_scheduler{params.tile_scheduler}; + + int warp_idx = cutlass::canonical_warp_idx_sync(); + auto role = warp_idx_to_WarpRole(warp_idx); + uint32_t lane_predicate = cute::elect_one_sync(); + + if (role == WarpRole::Load && lane_predicate) { + CollectiveMainloop::prefetch_tma_descriptors(params.mainloop); + } + + if (role == WarpRole::Epilogue && lane_predicate) { + CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue); + } + + SharedStorage& shared_storage = *reinterpret_cast(smem); + + auto get_epilogue_storage = [&]() { + if constexpr (IsMla && CollectiveMainloop::IsOrderLoadEpilogue) { + return reinterpret_cast(shared_storage.mainloop_epilogue.mainloop.smem_o.data()); + } else { + return &shared_storage.mainloop_epilogue.epilogue; + } + }; + typename CollectiveEpilogue::TensorStorage & epilogue_storage = *get_epilogue_storage(); + + + typename CollectiveMainloop::PipelineQ::Params pipeline_load_q_params; + if (role == WarpRole::Load) { + pipeline_load_q_params.role = CollectiveMainloop::PipelineQ::ThreadCategory::Producer; + } + if (role == WarpRole::MMA) { + pipeline_load_q_params.role = CollectiveMainloop::PipelineQ::ThreadCategory::Consumer; + } + pipeline_load_q_params.is_leader = lane_predicate && (role == WarpRole::Load); + pipeline_load_q_params.transaction_bytes = CollectiveMainloop::TransactionBytesLoadQ; + typename CollectiveMainloop::PipelineQ pipeline_load_q( + shared_storage.pipelines.load_q, + pipeline_load_q_params, + ClusterShape{}, cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename CollectiveMainloop::PipelineKV::Params pipeline_load_kv_params; + if (role == WarpRole::Load) { + pipeline_load_kv_params.role = CollectiveMainloop::PipelineKV::ThreadCategory::Producer; + } + if (role == WarpRole::MMA) { + pipeline_load_kv_params.role = CollectiveMainloop::PipelineKV::ThreadCategory::Consumer; + } + pipeline_load_kv_params.is_leader = lane_predicate && (role == WarpRole::Load); + pipeline_load_kv_params.transaction_bytes = CollectiveMainloop::TransactionBytesLoadK; + typename CollectiveMainloop::PipelineKV pipeline_load_kv( + shared_storage.pipelines.load_kv, + pipeline_load_kv_params, + ClusterShape{}, 
/*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename CollectiveMainloop::PipelineS::Params pipeline_mma_s0_params; + if (role == WarpRole::MMA) { + pipeline_mma_s0_params.role = CollectiveMainloop::PipelineS::ThreadCategory::Producer; + } + if (role == WarpRole::Softmax0) { + pipeline_mma_s0_params.role = CollectiveMainloop::PipelineS::ThreadCategory::Consumer; + } + pipeline_mma_s0_params.consumer_arv_count = NumWarpsSoftmax * cutlass::NumThreadsPerWarp; + typename CollectiveMainloop::PipelineS pipeline_mma_s0( + shared_storage.pipelines.mma_s0, + pipeline_mma_s0_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename CollectiveMainloop::PipelineS::Params pipeline_mma_s1_params; + if (role == WarpRole::MMA) { + pipeline_mma_s1_params.role = CollectiveMainloop::PipelineS::ThreadCategory::Producer; + } + if (role == WarpRole::Softmax1) { + pipeline_mma_s1_params.role = CollectiveMainloop::PipelineS::ThreadCategory::Consumer; + } + pipeline_mma_s1_params.consumer_arv_count = NumWarpsSoftmax * cutlass::NumThreadsPerWarp; + typename CollectiveMainloop::PipelineS pipeline_mma_s1( + shared_storage.pipelines.mma_s1, + pipeline_mma_s1_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename CollectiveMainloop::PipelineC::Params pipeline_s0_corr_params; + if (role == WarpRole::Softmax0) { + pipeline_s0_corr_params.role = CollectiveMainloop::PipelineC::ThreadCategory::Producer; + } + if (role == WarpRole::Correction) { + pipeline_s0_corr_params.role = CollectiveMainloop::PipelineC::ThreadCategory::Consumer; + } + pipeline_s0_corr_params.producer_arv_count = NumWarpsSoftmax * cutlass::NumThreadsPerWarp; + pipeline_s0_corr_params.consumer_arv_count = NumWarpsCorrection * cutlass::NumThreadsPerWarp; + typename CollectiveMainloop::PipelineC pipeline_s0_corr( + shared_storage.pipelines.s0_corr, + pipeline_s0_corr_params, + /*barrier init*/ cute::true_type{}); + + typename CollectiveMainloop::PipelineC::Params pipeline_s1_corr_params; + if (role == WarpRole::Softmax1) { + pipeline_s1_corr_params.role = CollectiveMainloop::PipelineC::ThreadCategory::Producer; + } + if (role == WarpRole::Correction) { + pipeline_s1_corr_params.role = CollectiveMainloop::PipelineC::ThreadCategory::Consumer; + } + pipeline_s1_corr_params.producer_arv_count = NumWarpsSoftmax * cutlass::NumThreadsPerWarp; + pipeline_s1_corr_params.consumer_arv_count = NumWarpsCorrection * cutlass::NumThreadsPerWarp; + typename CollectiveMainloop::PipelineC pipeline_s1_corr( + shared_storage.pipelines.s1_corr, + pipeline_s1_corr_params, + /*barrier init*/ cute::true_type{}); + + typename CollectiveMainloop::PipelineO::Params pipeline_mma_corr_params; + if (role == WarpRole::MMA) { + pipeline_mma_corr_params.role = CollectiveMainloop::PipelineO::ThreadCategory::Producer; + } + if (role == WarpRole::Correction) { + pipeline_mma_corr_params.role = CollectiveMainloop::PipelineO::ThreadCategory::Consumer; + } + pipeline_mma_corr_params.consumer_arv_count = NumWarpsCorrection * cutlass::NumThreadsPerWarp; + typename CollectiveMainloop::PipelineO pipeline_mma_corr( + shared_storage.pipelines.mma_corr, + pipeline_mma_corr_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename CollectiveMainloop::PipelineE::Params pipeline_corr_epi_params; + if (role == WarpRole::Correction) { + pipeline_corr_epi_params.role = CollectiveMainloop::PipelineE::ThreadCategory::Producer; + 
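+      // (The correction->epilogue pipeline follows the same convention as the ones above:
+      //  the producing warp group registers as Producer, the consuming one as Consumer,
+      //  and the arrive counts are sized to the participating thread counts.)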
} + if (role == WarpRole::Epilogue) { + pipeline_corr_epi_params.role = CollectiveMainloop::PipelineE::ThreadCategory::Consumer; + } + pipeline_corr_epi_params.producer_arv_count = NumWarpsCorrection * cutlass::NumThreadsPerWarp; + pipeline_corr_epi_params.consumer_arv_count = NumWarpsEpilogue * cutlass::NumThreadsPerWarp; + typename CollectiveMainloop::PipelineE pipeline_corr_epi( + shared_storage.pipelines.corr_epi, + pipeline_corr_epi_params, + /*barrier init*/ cute::true_type{}); + + typename CollectiveMainloop::OrderBarrierSoftmax::Params params_order_s01; + params_order_s01.group_id = role == WarpRole::Softmax1 ? 1 : 0; + params_order_s01.group_size = NumWarpsSoftmax * cutlass::NumThreadsPerWarp; + typename CollectiveMainloop::OrderBarrierSoftmax order_s01( + shared_storage.pipelines.order_s01, params_order_s01); + + TmemAllocator tmem_allocator; + + __syncthreads(); + + pipeline_load_q.init_masks(ClusterShape{}); + pipeline_load_kv.init_masks(ClusterShape{}); + pipeline_mma_s0.init_masks(ClusterShape{}); + pipeline_mma_s1.init_masks(ClusterShape{}); + pipeline_mma_corr.init_masks(ClusterShape{}); + + typename CollectiveMainloop::PipelineQ::PipelineState pipeline_load_q_consumer_state; + typename CollectiveMainloop::PipelineQ::PipelineState pipeline_load_q_producer_state = cutlass::make_producer_start_state(); + + typename CollectiveMainloop::PipelineKV::PipelineState pipeline_load_kv_consumer_state; + typename CollectiveMainloop::PipelineKV::PipelineState pipeline_load_kv_producer_state = cutlass::make_producer_start_state(); + + typename CollectiveMainloop::PipelineS::PipelineState pipeline_mma_s0_consumer_state; + typename CollectiveMainloop::PipelineS::PipelineState pipeline_mma_s0_producer_state = cutlass::make_producer_start_state(); + + typename CollectiveMainloop::PipelineS::PipelineState pipeline_mma_s1_consumer_state; + typename CollectiveMainloop::PipelineS::PipelineState pipeline_mma_s1_producer_state = cutlass::make_producer_start_state(); + + typename CollectiveMainloop::PipelineC::PipelineState pipeline_s0_corr_consumer_state; + typename CollectiveMainloop::PipelineC::PipelineState pipeline_s0_corr_producer_state = cutlass::make_producer_start_state(); + + typename CollectiveMainloop::PipelineC::PipelineState pipeline_s1_corr_consumer_state; + typename CollectiveMainloop::PipelineC::PipelineState pipeline_s1_corr_producer_state = cutlass::make_producer_start_state(); + + typename CollectiveMainloop::PipelineE::PipelineState pipeline_corr_epi_consumer_state; + typename CollectiveMainloop::PipelineE::PipelineState pipeline_corr_epi_producer_state = cutlass::make_producer_start_state(); + + typename CollectiveMainloop::PipelineO::PipelineState pipeline_mma_corr_consumer_state; + typename CollectiveMainloop::PipelineO::PipelineState pipeline_mma_corr_producer_state = cutlass::make_producer_start_state(); + + CollectiveMainloop mainloop; + CollectiveEpilogue epilogue{params.epilogue}; + + if (role == WarpRole::Softmax0 || role == WarpRole::Softmax1) { + warpgroup_reg_set(); + + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + + auto logical_problem_shape = apply_batch(params, + params.problem_shape, get<2,1>(blk_coord)); + + if (get<0>(blk_coord) * get<0>(TileShape{}) >= get<0>(logical_problem_shape)) { + continue; + } + + if (get<1>(logical_problem_shape) == 0) { + continue; + } + + bool is_softmax_0 = role == WarpRole::Softmax0; + + mainloop.softmax( + is_softmax_0 ? 
0 : 1, blk_coord, + params.mainloop, logical_problem_shape, + is_softmax_0 ? pipeline_mma_s0 : pipeline_mma_s1, + is_softmax_0 ? pipeline_mma_s0_consumer_state : pipeline_mma_s1_consumer_state, + is_softmax_0 ? pipeline_s0_corr : pipeline_s1_corr, + is_softmax_0 ? pipeline_s0_corr_producer_state : pipeline_s1_corr_producer_state, + order_s01 + ); + + } + } + else if (role == WarpRole::Correction) { + cutlass::arch::warpgroup_reg_dealloc(); + + bool has_valid = false; + + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + + auto logical_problem_shape = apply_batch(params, + params.problem_shape, get<2,1>(blk_coord)); + + if (get<0>(blk_coord) * get<0>(TileShape{}) >= get<0>(logical_problem_shape)) { + continue; + } + + has_valid = true; + + if (get<1>(logical_problem_shape) == 0) { + mainloop.correction_empty( + blk_coord, + params.mainloop, logical_problem_shape, + params.problem_shape, + epilogue_storage, + pipeline_corr_epi, pipeline_corr_epi_producer_state, + epilogue + ); + continue; + } + + mainloop.correction( + blk_coord, + params.mainloop, logical_problem_shape, + params.problem_shape, + epilogue_storage, + pipeline_s0_corr, pipeline_s0_corr_consumer_state, + pipeline_s1_corr, pipeline_s1_corr_consumer_state, + pipeline_mma_corr, pipeline_mma_corr_consumer_state, + pipeline_corr_epi, pipeline_corr_epi_producer_state, + epilogue + ); + + } + + if constexpr (NumWarpsEpilogue == 0) { + static_assert(NumWarpsCorrection == 1); + + if (has_valid) { + uint32_t free_stage_ptr = shared_storage.tmem_base_ptr; + tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns); + } + } + + } + else if (role == WarpRole::MMA) { + warpgroup_reg_set(); + + bool allocated = false; + + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + + auto logical_problem_shape = apply_batch(params, + params.problem_shape, get<2,1>(blk_coord)); + + if (get<0>(blk_coord) * get<0>(TileShape{}) >= get<0>(logical_problem_shape)) { + continue; + } + + if (!allocated) { + tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr); + __syncwarp(); + allocated = true; + } + + if (get<1>(logical_problem_shape) == 0) { + continue; + } + + mainloop.mma( + blk_coord, + params.mainloop, logical_problem_shape, + shared_storage.mainloop_epilogue.mainloop, + pipeline_load_q, pipeline_load_q_consumer_state, + pipeline_load_kv, pipeline_load_kv_consumer_state, + pipeline_mma_s0, pipeline_mma_s0_producer_state, + pipeline_mma_s1, pipeline_mma_s1_producer_state, + pipeline_mma_corr, pipeline_mma_corr_producer_state + ); + + } + } + else if (role == WarpRole::Load) { + warpgroup_reg_set(); + + if constexpr (IsMla && CollectiveMainloop::IsOrderLoadEpilogue) { + cutlass::arch::NamedBarrier::arrive((NumWarpsLoad + NumWarpsEpilogue) * NumThreadsPerWarp, + cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); + } + + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + + auto logical_problem_shape = apply_batch(params, + params.problem_shape, get<2,1>(blk_coord)); + + if (get<0>(blk_coord) * get<0>(TileShape{}) >= get<0>(logical_problem_shape)) { + continue; + } + + if (get<1>(logical_problem_shape) == 0) { + continue; + } + + mainloop.load( + blk_coord, logical_problem_shape, + params.mainloop, params.problem_shape, + 
shared_storage.mainloop_epilogue.mainloop, + pipeline_load_q, pipeline_load_q_producer_state, + pipeline_load_kv, pipeline_load_kv_producer_state + ); + + } + } + else if (role == WarpRole::Epilogue) { + warpgroup_reg_set(); + + bool has_valid = false; + + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + + auto logical_problem_shape = apply_batch(params, + params.problem_shape, get<2,1>(blk_coord)); + + if (get<0>(blk_coord) * get<0>(TileShape{}) >= get<0>(logical_problem_shape)) { + continue; + } + + has_valid = true; + + epilogue.store( + blk_coord, logical_problem_shape, + params.epilogue, params.problem_shape, + epilogue_storage, + pipeline_corr_epi, pipeline_corr_epi_consumer_state + ); + + } + + static_assert(NumWarpsEpilogue <= 1); + if constexpr (NumWarpsEpilogue == 1) { + if(has_valid) { + uint32_t free_stage_ptr = shared_storage.tmem_base_ptr; + tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns); + } + } + + } + else if (role == WarpRole::Empty) { + warpgroup_reg_set(); + + /* no-op, donate regs and exit */ + } +#endif + } + +}; + +} // namespace cutlass::fmha::kernel diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/kernel/sm100_fmha_gen_kernel_warpspecialized.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/kernel/sm100_fmha_gen_kernel_warpspecialized.hpp new file mode 100644 index 0000000000000000000000000000000000000000..d59b20ff194c2eb6c0c18b2e0af03047d15601ea --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/kernel/sm100_fmha_gen_kernel_warpspecialized.hpp @@ -0,0 +1,580 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#include "cutlass/cutlass.h" +#include "cute/layout.hpp" +#include "cutlass/arch/arch.h" +#include "cutlass/kernel_hardware_info.h" +#include "cutlass/pipeline/pipeline.hpp" +#include "cute/arch/tmem_allocator_sm100.hpp" + +#include "kernel/fmha_options.hpp" +#include "kernel/fmha_tile_scheduler.hpp" +#include "collective/fmha_fusion.hpp" + +namespace cutlass::fmha::kernel { + +using namespace cute; +using namespace cutlass::fmha::collective; + +struct Sm100FmhaGenKernelWarpspecializedSchedule { + + enum class WarpRole { + Softmax0, + Softmax1, + Correction, + MMA, + Load, + Epilogue, + Empty + }; + + static constexpr WarpRole warp_idx_to_WarpRole(int warp_idx) { + if (warp_idx == 0) return WarpRole::Softmax0; // 0 - 3 + if (warp_idx == 1) return WarpRole::MMA; // 12 + if (warp_idx == 2 || warp_idx == 3) return WarpRole::Load; // 13 + if (warp_idx == 4) return WarpRole::Softmax1; // 4 - 7 + if (warp_idx == 8) return WarpRole::Correction; // 8 - 11 + return WarpRole::Empty; // 15 + } + + static const int NumWarpsSoftmax = 1; + static const int NumWarpsCorrection = 1; + static const int NumWarpsEpilogue = 0; + static const int NumWarpsLoad = 2; + + static const int NumRegsSoftmax = 192; + static const int NumRegsCorrection = 104; + static const int NumRegsOther = 248; + static const int NumRegsEmpty = 24; + + static const int NumWarps = 12; + +}; + +template< + class ProblemShapeIn, + class CollectiveMainloop, + class CollectiveEpilogue, + class TileScheduler, + class KernelSchedule = Sm100FmhaGenKernelWarpspecializedSchedule +> +struct Sm100FmhaGenKernelWarpspecialized { + + using TileShape = typename CollectiveMainloop::TileShape; + using ProblemShape = decltype(replace<0>(ProblemShapeIn{}, 0)); + + using WarpRole = typename KernelSchedule::WarpRole; + + constexpr WarpRole warp_idx_to_WarpRole(int warp_idx) { + return KernelSchedule::warp_idx_to_WarpRole(warp_idx); + } + + static const int NumWarpsSoftmax = KernelSchedule::NumWarpsSoftmax; + static const int NumWarpsCorrection = KernelSchedule::NumWarpsCorrection; + static const int NumWarpsEpilogue = KernelSchedule::NumWarpsEpilogue; + static const int NumWarpsLoad = KernelSchedule::NumWarpsLoad; + + static const int NumRegsSoftmax = KernelSchedule::NumRegsSoftmax; + static const int NumRegsCorrection = KernelSchedule::NumRegsCorrection; + static const int NumRegsOther = KernelSchedule::NumRegsOther; + static const int NumRegsEmpty = 24; + + static const int NumWarps = KernelSchedule::NumWarps; + + using ClusterShape = typename CollectiveMainloop::ClusterShape; + + using TmemAllocator = cute::TMEM::Allocator1Sm; + + struct SharedStorage { + typename CollectiveMainloop::TensorStorage mainloop; + typename CollectiveEpilogue::TensorStorage epilogue; + + struct PipelineStorage { + alignas(16) typename CollectiveMainloop::PipelineQ::SharedStorage load_q; + alignas(16) typename CollectiveMainloop::PipelineKV::SharedStorage load_kv; + alignas(16) typename CollectiveMainloop::PipelineS::SharedStorage mma_s0; + alignas(16) typename CollectiveMainloop::PipelineS::SharedStorage mma_s1; + alignas(16) typename CollectiveMainloop::PipelineC::SharedStorage s0_corr; + alignas(16) typename CollectiveMainloop::PipelineC::SharedStorage s1_corr; + alignas(16) typename CollectiveMainloop::PipelineO::SharedStorage mma_corr; + alignas(16) typename CollectiveMainloop::PipelineE::SharedStorage corr_epi; + alignas(16) typename 
CollectiveMainloop::OrderBarrierSoftmax::SharedStorage order_s01; + } pipelines; + + uint32_t tmem_base_ptr; + }; + + static constexpr int SharedStorageSize = sizeof(SharedStorage); + + using StrideQOrig = typename CollectiveMainloop::StrideQOrig; + using StrideOOrig = typename CollectiveMainloop::StrideOOrig; + using StrideQ = typename CollectiveMainloop::StrideQ; + using StrideO = typename CollectiveMainloop::StrideO; + using StrideCacheK = typename CollectiveMainloop::StrideCacheK; + using StrideCacheV = typename CollectiveMainloop::StrideCacheV; + using StrideNewK = typename CollectiveMainloop::StrideNewK; + using StrideNewV = typename CollectiveMainloop::StrideNewV; + using Element = typename CollectiveMainloop::Element; + using ElementAcc = typename CollectiveMainloop::ElementAcc; + using ElementOut = typename CollectiveMainloop::ElementOut; + + struct Arguments { + // _1, max_seqlen_k, head_dim, ((h_g, h_kv), b) + ProblemShapeIn problem_shape; + const int* seqlen_kv; + const int* cache_batch_idx; + + const Element* ptr_q; // 1 x D x (H x B) + StrideQOrig dQ; + const Element* ptr_new_k; // 1 x D x (H x B) + StrideNewK dNewK; + const Element* ptr_new_v; // 1 x D x (H x B) + StrideNewV dNewV; + + Element* ptr_cache_k; // seqlen_max x D x (H x B) + StrideCacheK dCacheK; + Element* ptr_cache_v; // seqlen_max x D x (H x B) + StrideCacheV dCacheV; + ElementOut* ptr_o; // 1 x D x (H x B) + StrideOOrig dO; + + cutlass::KernelHardwareInfo hw_info; + + ElementAcc scale_softmax = 0.0f; + }; + + struct Params { + ProblemShape problem_shape; + const int* seqlen_kv; + typename CollectiveMainloop::Params mainloop; + typename CollectiveEpilogue::Params epilogue; + typename TileScheduler::Params tile_scheduler; + }; + + static const int MinBlocksPerMultiprocessor = 1; + static const int MaxThreadsPerBlock = NumWarps * cutlass::NumThreadsPerWarp; + using ArchTag = cutlass::arch::Sm100; + + static size_t get_workspace_size(Arguments const& args) { return 0; } + static cutlass::Status initialize_workspace(Arguments const&, void*, cudaStream_t) { + return cutlass::Status::kSuccess; + } + + static bool can_implement(Arguments const& args) { + return true; + } + + static dim3 get_grid_shape(Params const& params) { + return TileScheduler::get_grid_shape(params.tile_scheduler); + } + + static dim3 get_block_shape() { + dim3 block(MaxThreadsPerBlock, 1, 1); + return block; + } + + static Params to_underlying_arguments(Arguments const& args, void* workspace) { + ProblemShape problem_shape = replace<0>(args.problem_shape, static_cast(get<0>(args.problem_shape))); + CUTE_STATIC_ASSERT_V(get<0>(args.problem_shape) == _1{}); + StrideQ dQ = replace<0>(args.dQ, 0); + StrideO dO = replace<0>(args.dO, 0); + get<0>(problem_shape) = get<3,0,0>(args.problem_shape); + get<3,0,0>(problem_shape) = 1; + get<0>(dQ) = get<2,0,0>(dQ); + get<0>(dO) = get<2,0,0>(dO); + + typename CollectiveMainloop::Arguments mainloop_args { + { + args.cache_batch_idx, + args.ptr_q, dQ, + args.ptr_new_k, args.dNewK, + args.ptr_new_v, args.dNewV, + args.ptr_cache_k, args.dCacheK, + args.ptr_cache_v, args.dCacheV, + }, + args.scale_softmax + }; + + typename CollectiveEpilogue::Arguments epilogue_args { + args.ptr_o, dO, + }; + + return Params{ + problem_shape, + args.seqlen_kv, + CollectiveMainloop::to_underlying_arguments(problem_shape, mainloop_args, workspace), + CollectiveEpilogue::to_underlying_arguments(problem_shape, epilogue_args, workspace), + TileScheduler::to_underlying_arguments(problem_shape, args.hw_info, ClusterShape{}, TileShape{}) 
+ }; + } + + CUTLASS_DEVICE auto apply_batch(const Params ¶ms, ProblemShape const& problem_shape, int batch_idx) { + ProblemShape result = problem_shape; + get<1>(result) = params.seqlen_kv[batch_idx]; + if (params.mainloop.load.ptr_new_k != nullptr) { + get<1>(result) += 1; + } + return result; + } + + CUTLASS_DEVICE void operator()(const Params ¶ms, char* smem) { +#if (! defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) && ! defined(CUTLASS_ARCH_MMA_SM100F_ENABLED)) + printf("ERROR : Arch conditional MMA instruction used without targeting appropriate compute capability. Aborting.\n"); +#else + + TileScheduler tile_scheduler{params.tile_scheduler}; + + int warp_idx = cutlass::canonical_warp_idx_sync(); + auto role = warp_idx_to_WarpRole(warp_idx); + uint32_t lane_predicate = cute::elect_one_sync(); + + if (role == WarpRole::Load && lane_predicate) { + CollectiveMainloop::prefetch_tma_descriptors(params.mainloop); + } + + if (role == WarpRole::Epilogue && lane_predicate) { + CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue); + } + + SharedStorage& shared_storage = *reinterpret_cast(smem); + + typename CollectiveMainloop::PipelineQ::Params pipeline_load_q_params; + if (role == WarpRole::Load) { + pipeline_load_q_params.role = CollectiveMainloop::PipelineQ::ThreadCategory::Producer; + } + if (role == WarpRole::MMA) { + pipeline_load_q_params.role = CollectiveMainloop::PipelineQ::ThreadCategory::Consumer; + } + pipeline_load_q_params.producer_arv_count = NumWarpsLoad * cutlass::NumThreadsPerWarp; + typename CollectiveMainloop::PipelineQ pipeline_load_q( + shared_storage.pipelines.load_q, + pipeline_load_q_params, + ClusterShape{}, cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename CollectiveMainloop::PipelineKV::Params pipeline_load_kv_params; + if (role == WarpRole::Load) { + pipeline_load_kv_params.role = CollectiveMainloop::PipelineKV::ThreadCategory::Producer; + } + if (role == WarpRole::MMA) { + pipeline_load_kv_params.role = CollectiveMainloop::PipelineKV::ThreadCategory::Consumer; + } + pipeline_load_kv_params.producer_arv_count = NumWarpsLoad * cutlass::NumThreadsPerWarp; + typename CollectiveMainloop::PipelineKV pipeline_load_kv( + shared_storage.pipelines.load_kv, + pipeline_load_kv_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename CollectiveMainloop::PipelineS::Params pipeline_mma_s0_params; + if (role == WarpRole::MMA) { + pipeline_mma_s0_params.role = CollectiveMainloop::PipelineS::ThreadCategory::Producer; + } + if (role == WarpRole::Softmax0) { + pipeline_mma_s0_params.role = CollectiveMainloop::PipelineS::ThreadCategory::Consumer; + } + pipeline_mma_s0_params.consumer_arv_count = NumWarpsSoftmax * cutlass::NumThreadsPerWarp; + typename CollectiveMainloop::PipelineS pipeline_mma_s0( + shared_storage.pipelines.mma_s0, + pipeline_mma_s0_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename CollectiveMainloop::PipelineS::Params pipeline_mma_s1_params; + if (role == WarpRole::MMA) { + pipeline_mma_s1_params.role = CollectiveMainloop::PipelineS::ThreadCategory::Producer; + } + if (role == WarpRole::Softmax1) { + pipeline_mma_s1_params.role = CollectiveMainloop::PipelineS::ThreadCategory::Consumer; + } + pipeline_mma_s1_params.consumer_arv_count = NumWarpsSoftmax * cutlass::NumThreadsPerWarp; + typename CollectiveMainloop::PipelineS pipeline_mma_s1( + shared_storage.pipelines.mma_s1, + pipeline_mma_s1_params, + ClusterShape{}, /*barrier init*/ 
cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename CollectiveMainloop::PipelineC::Params pipeline_s0_corr_params; + if (role == WarpRole::Softmax0) { + pipeline_s0_corr_params.role = CollectiveMainloop::PipelineC::ThreadCategory::Producer; + } + if (role == WarpRole::Correction) { + pipeline_s0_corr_params.role = CollectiveMainloop::PipelineC::ThreadCategory::Consumer; + } + pipeline_s0_corr_params.producer_arv_count = NumWarpsSoftmax * cutlass::NumThreadsPerWarp; + pipeline_s0_corr_params.consumer_arv_count = NumWarpsCorrection * cutlass::NumThreadsPerWarp; + typename CollectiveMainloop::PipelineC pipeline_s0_corr( + shared_storage.pipelines.s0_corr, + pipeline_s0_corr_params, + /*barrier init*/ cute::true_type{}); + + typename CollectiveMainloop::PipelineC::Params pipeline_s1_corr_params; + if (role == WarpRole::Softmax1) { + pipeline_s1_corr_params.role = CollectiveMainloop::PipelineC::ThreadCategory::Producer; + } + if (role == WarpRole::Correction) { + pipeline_s1_corr_params.role = CollectiveMainloop::PipelineC::ThreadCategory::Consumer; + } + pipeline_s1_corr_params.producer_arv_count = NumWarpsSoftmax * cutlass::NumThreadsPerWarp; + pipeline_s1_corr_params.consumer_arv_count = NumWarpsCorrection * cutlass::NumThreadsPerWarp; + typename CollectiveMainloop::PipelineC pipeline_s1_corr( + shared_storage.pipelines.s1_corr, + pipeline_s1_corr_params, + /*barrier init*/ cute::true_type{}); + + typename CollectiveMainloop::PipelineO::Params pipeline_mma_corr_params; + if (role == WarpRole::MMA) { + pipeline_mma_corr_params.role = CollectiveMainloop::PipelineO::ThreadCategory::Producer; + } + if (role == WarpRole::Correction) { + pipeline_mma_corr_params.role = CollectiveMainloop::PipelineO::ThreadCategory::Consumer; + } + pipeline_mma_corr_params.consumer_arv_count = NumWarpsCorrection * cutlass::NumThreadsPerWarp; + typename CollectiveMainloop::PipelineO pipeline_mma_corr( + shared_storage.pipelines.mma_corr, + pipeline_mma_corr_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename CollectiveMainloop::PipelineE::Params pipeline_corr_epi_params; + if (role == WarpRole::Correction) { + pipeline_corr_epi_params.role = CollectiveMainloop::PipelineE::ThreadCategory::Producer; + } + if (role == WarpRole::Epilogue) { + pipeline_corr_epi_params.role = CollectiveMainloop::PipelineE::ThreadCategory::Consumer; + } + pipeline_corr_epi_params.producer_arv_count = NumWarpsCorrection * cutlass::NumThreadsPerWarp; + pipeline_corr_epi_params.consumer_arv_count = cute::max(1, NumWarpsEpilogue * cutlass::NumThreadsPerWarp); + typename CollectiveMainloop::PipelineE pipeline_corr_epi( + shared_storage.pipelines.corr_epi, + pipeline_corr_epi_params, + /*barrier init*/ cute::true_type{}); + + typename CollectiveMainloop::OrderBarrierSoftmax::Params params_order_s01; + params_order_s01.group_id = role == WarpRole::Softmax1 ? 
1 : 0; + params_order_s01.group_size = NumWarpsSoftmax * cutlass::NumThreadsPerWarp; + typename CollectiveMainloop::OrderBarrierSoftmax order_s01( + shared_storage.pipelines.order_s01, params_order_s01); + + TmemAllocator tmem_allocator; + + __syncthreads(); + + pipeline_load_q.init_masks(ClusterShape{}); + pipeline_load_kv.init_masks(ClusterShape{}); + pipeline_mma_s0.init_masks(ClusterShape{}); + pipeline_mma_s1.init_masks(ClusterShape{}); + pipeline_mma_corr.init_masks(ClusterShape{}); + + typename CollectiveMainloop::PipelineQ::PipelineState pipeline_load_q_consumer_state; + typename CollectiveMainloop::PipelineQ::PipelineState pipeline_load_q_producer_state = cutlass::make_producer_start_state(); + + typename CollectiveMainloop::PipelineKV::PipelineState pipeline_load_kv_consumer_state; + typename CollectiveMainloop::PipelineKV::PipelineState pipeline_load_kv_producer_state = cutlass::make_producer_start_state(); + + typename CollectiveMainloop::PipelineS::PipelineState pipeline_mma_s0_consumer_state; + typename CollectiveMainloop::PipelineS::PipelineState pipeline_mma_s0_producer_state = cutlass::make_producer_start_state(); + + typename CollectiveMainloop::PipelineS::PipelineState pipeline_mma_s1_consumer_state; + typename CollectiveMainloop::PipelineS::PipelineState pipeline_mma_s1_producer_state = cutlass::make_producer_start_state(); + + typename CollectiveMainloop::PipelineC::PipelineState pipeline_s0_corr_consumer_state; + typename CollectiveMainloop::PipelineC::PipelineState pipeline_s0_corr_producer_state = cutlass::make_producer_start_state(); + + typename CollectiveMainloop::PipelineC::PipelineState pipeline_s1_corr_consumer_state; + typename CollectiveMainloop::PipelineC::PipelineState pipeline_s1_corr_producer_state = cutlass::make_producer_start_state(); + + typename CollectiveMainloop::PipelineE::PipelineState pipeline_corr_epi_consumer_state; + typename CollectiveMainloop::PipelineE::PipelineState pipeline_corr_epi_producer_state = cutlass::make_producer_start_state(); + + typename CollectiveMainloop::PipelineO::PipelineState pipeline_mma_corr_consumer_state; + typename CollectiveMainloop::PipelineO::PipelineState pipeline_mma_corr_producer_state = cutlass::make_producer_start_state(); + + CollectiveMainloop mainloop; + CollectiveEpilogue epilogue(params.epilogue); + + if (role == WarpRole::Softmax0 || role == WarpRole::Softmax1) { + warpgroup_reg_set(); + + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + + auto logical_problem_shape = apply_batch(params, + params.problem_shape, get<2,1>(blk_coord)); + + if (get<0>(blk_coord) * get<0>(TileShape{}) >= get<0>(logical_problem_shape)) { + continue; + } + + bool is_softmax_0 = role == WarpRole::Softmax0; + + mainloop.softmax( + is_softmax_0 ? 0 : 1, blk_coord, + params.mainloop, logical_problem_shape, + is_softmax_0 ? pipeline_mma_s0 : pipeline_mma_s1, + is_softmax_0 ? pipeline_mma_s0_consumer_state : pipeline_mma_s1_consumer_state, + is_softmax_0 ? pipeline_s0_corr : pipeline_s1_corr, + is_softmax_0 ? 
pipeline_s0_corr_producer_state : pipeline_s1_corr_producer_state, + order_s01 + ); + + } + } + else if (role == WarpRole::Correction) { + cutlass::arch::warpgroup_reg_dealloc(); + + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + + auto logical_problem_shape = apply_batch(params, + params.problem_shape, get<2,1>(blk_coord)); + + if (get<0>(blk_coord) * get<0>(TileShape{}) >= get<0>(logical_problem_shape)) { + continue; + } + + mainloop.correction( + blk_coord, + params.mainloop, logical_problem_shape, + shared_storage.epilogue, + pipeline_s0_corr, pipeline_s0_corr_consumer_state, + pipeline_s1_corr, pipeline_s1_corr_consumer_state, + pipeline_mma_corr, pipeline_mma_corr_consumer_state, + pipeline_corr_epi, pipeline_corr_epi_producer_state, + epilogue + ); + + + } + + if constexpr (NumWarpsEpilogue == 0) { + static_assert(NumWarpsCorrection == 1); + + uint32_t free_stage_ptr = shared_storage.tmem_base_ptr; + tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns); + } + + } + else if (role == WarpRole::MMA) { + warpgroup_reg_set(); + + tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr); + __syncwarp(); + + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + + auto logical_problem_shape = apply_batch(params, + params.problem_shape, get<2,1>(blk_coord)); + + if (get<0>(blk_coord) * get<0>(TileShape{}) >= get<0>(logical_problem_shape)) { + continue; + } + + + mainloop.mma( + blk_coord, + params.mainloop, logical_problem_shape, + shared_storage.mainloop, + pipeline_load_q, pipeline_load_q_consumer_state, + pipeline_load_kv, pipeline_load_kv_consumer_state, + pipeline_mma_s0, pipeline_mma_s0_producer_state, + pipeline_mma_s1, pipeline_mma_s1_producer_state, + pipeline_mma_corr, pipeline_mma_corr_producer_state + ); + + + } + } + else if (role == WarpRole::Load) { + warpgroup_reg_set(); + + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + + auto logical_problem_shape = apply_batch(params, + params.problem_shape, get<2,1>(blk_coord)); + + if (get<0>(blk_coord) * get<0>(TileShape{}) >= get<0>(logical_problem_shape)) { + continue; + } + + mainloop.load( + blk_coord, logical_problem_shape, + params.mainloop, params.problem_shape, + shared_storage.mainloop, + pipeline_load_q, pipeline_load_q_producer_state, + pipeline_load_kv, pipeline_load_kv_producer_state + ); + + } + } + else if (role == WarpRole::Epilogue) { + warpgroup_reg_set(); + + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + + auto logical_problem_shape = apply_batch(params, + params.problem_shape, get<2,1>(blk_coord)); + + if (get<0>(blk_coord) * get<0>(TileShape{}) >= get<0>(logical_problem_shape)) { + continue; + } + + epilogue.store( + blk_coord, logical_problem_shape, + params.epilogue, params.problem_shape, + shared_storage.epilogue, + pipeline_corr_epi, pipeline_corr_epi_consumer_state + ); + + } + + static_assert(NumWarpsEpilogue <= 1); + if constexpr (NumWarpsEpilogue == 1) { + uint32_t free_stage_ptr = shared_storage.tmem_base_ptr; + tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns); + } + + } + else if (role == WarpRole::Empty) { + warpgroup_reg_set(); + + /* no-op, donate regs and exit */ + } 
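+    // TMEM ownership note: the MMA warp allocates the tensor-memory columns once, before
+    // its scheduler loop, and publishes the base pointer through shared memory; the last
+    // consumer of the accumulator (the epilogue warp, or the correction warp when
+    // NumWarpsEpilogue == 0) frees those columns on exit.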
+#endif + } + +}; + +} // namespace cutlass::fmha::kernel diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/kernel/sm100_fmha_mla_reduction.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/kernel/sm100_fmha_mla_reduction.hpp new file mode 100644 index 0000000000000000000000000000000000000000..2418bcf804a6bb57fc3358a741c3b333f4d19cfb --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/kernel/sm100_fmha_mla_reduction.hpp @@ -0,0 +1,201 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/arch/arch.h" +#include "cute/tensor.hpp" + +namespace cutlass::fmha::kernel { + +using namespace cute; +template< + class ElementOut, + class ElementAcc, + class ElementScale, + size_t kNumHeads, + size_t kHeadDimLatent, + int kMaxSplits +> +struct Sm100FmhaMlaReductionKernel { + + static const int SharedStorageSize = 0; + static const int MaxThreadsPerBlock = 128; + static const int MinBlocksPerMultiprocessor = 1; + + using ArchTag = cutlass::arch::Sm100; + + static_assert(kHeadDimLatent % MaxThreadsPerBlock == 0); + struct Arguments { + ElementAcc* ptr_oaccum = nullptr; + ElementOut* ptr_o = nullptr; + ElementAcc* ptr_lseaccum = nullptr; + ElementAcc* ptr_lse = nullptr; + ElementScale scale = 1.f; + int num_batches = 0; + int split_kv = -1; + int dim_k = -1; + int* ptr_seq = nullptr; + int* ptr_split_kv = nullptr; + int tile_shape_s = 128; + }; + using Params = Arguments; + + static Params to_underlying_arguments(Arguments const& args, void* workspace) { + return {args.ptr_oaccum, args.ptr_o, args.ptr_lseaccum, args.ptr_lse, + args.scale, args.num_batches, args.split_kv, args.dim_k, args.ptr_seq, + args.ptr_split_kv, args.tile_shape_s}; + } + + static size_t get_workspace_size(Arguments const& /*args*/) { + return 0; + } + + static Status initialize_workspace( + Arguments const& /*args*/, void* /*ws*/, cudaStream_t /*stream*/) { + return Status::kSuccess; + } + + static dim3 get_grid_shape(Params const& params) { + return dim3(kNumHeads, 1, params.num_batches); + } + + static dim3 get_block_shape() { + return dim3(MaxThreadsPerBlock, 1, 1); + } + + static bool can_implement(Arguments const& args) { + if (args.num_batches <= 0) return false; + if (args.split_kv <= 0) return false; + return true; + } + + CUTLASS_DEVICE void operator() (Params const& params, char* smem_raw) { + if (params.split_kv <= 1) return; + + auto blk_coord = make_coord(blockIdx.x, _0{}, blockIdx.z); + + auto dim_k = params.ptr_seq == nullptr ? params.dim_k : params.ptr_seq[get<2>(blk_coord)]; + auto local_split_kv = params.ptr_split_kv == nullptr ? params.split_kv : params.ptr_split_kv[get<2>(blk_coord)]; + auto k_tile_total = ceil_div(dim_k, params.tile_shape_s); + auto k_tile_per_cta = ceil_div(k_tile_total, local_split_kv); + local_split_kv = ceil_div(k_tile_total, k_tile_per_cta); + + if (local_split_kv == 1) return; + + __shared__ ElementAcc sLseScale[kMaxSplits]; + const size_t offset_lseaccum = get<0>(blk_coord) + kNumHeads * params.split_kv * get<2>(blk_coord); + const size_t offset_lse = get<0>(blk_coord) + kNumHeads * get<2>(blk_coord); + + Tensor gLSEaccum = make_tensor(make_gmem_ptr(params.ptr_lseaccum + offset_lseaccum), + make_shape(params.split_kv), Stride>{}); + + Tensor gLSE = make_tensor(make_gmem_ptr(params.ptr_lse + offset_lse), + Shape<_1>{}, Stride<_1>{}); + + int warp_idx = cutlass::canonical_warp_idx_sync(); + if (warp_idx == 0) { + constexpr int kNLsePerThread = cute::ceil_div(kMaxSplits, 32); + + ElementAcc local_lse[kNLsePerThread]; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kNLsePerThread; ++i) { + const int split = i * 32 + threadIdx.x; + local_lse[i] = split < local_split_kv ? 
gLSEaccum(split) : -std::numeric_limits::infinity(); + } + + ElementAcc lse_max = -std::numeric_limits::infinity(); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kNLsePerThread; ++i) { + lse_max = fmax(local_lse[i], lse_max); + } + + CUTLASS_PRAGMA_UNROLL + for (int offset = 16; offset >= 1; offset /= 2) { + lse_max = fmax(__shfl_xor_sync(0xffffffff, lse_max, offset), lse_max); + } + + lse_max = __shfl_sync(0xffffffff, lse_max, 0); + + ElementAcc sum_lse = 0; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kNLsePerThread; ++i) { + sum_lse = sum_lse + expf(local_lse[i] - lse_max); + } + + CUTLASS_PRAGMA_UNROLL + for (int offset = 16; offset >= 1; offset /= 2) { + sum_lse = sum_lse + __shfl_xor_sync(0xffffffff, sum_lse, offset); + } + + sum_lse = __shfl_sync(0xffffffff, sum_lse, 0); + + ElementAcc global_lse = (sum_lse == 0.f || sum_lse != sum_lse) ? std::numeric_limits::infinity() : logf(sum_lse) + lse_max; + if (threadIdx.x == 0 and params.ptr_lse != nullptr) { + gLSE(0) = global_lse; + } + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kNLsePerThread; ++i) { + const int split = i * 32 + threadIdx.x; + if (split < local_split_kv) { + sLseScale[split] = expf(local_lse[i] - global_lse); + } + } + } + __syncthreads(); + + constexpr int Elements = kHeadDimLatent / MaxThreadsPerBlock; + const size_t offset_oaccum = kHeadDimLatent * params.split_kv * (get<0>(blk_coord) + kNumHeads * get<2>(blk_coord)); + Tensor gOaccum = make_tensor(make_gmem_ptr(params.ptr_oaccum + offset_oaccum), + Shape>{}, Stride<_1>{}); + ElementAcc local_val[Elements] = {0}; + for (int split = 0; split < local_split_kv; ++split) { + ElementAcc lse_scale = sLseScale[split]; + CUTLASS_PRAGMA_UNROLL + for(int i = 0; i < Elements; ++i) { + local_val[i] += lse_scale * gOaccum(threadIdx.x + MaxThreadsPerBlock * i); + } + gOaccum.data() = gOaccum.data() + kHeadDimLatent; + } + auto ptr_o_local = params.ptr_o + (get<0>(blk_coord) + get<2>(blk_coord) * kNumHeads) * kHeadDimLatent; + Tensor gO = make_tensor(make_gmem_ptr(ptr_o_local), Shape>{}, Stride<_1>{}); + + CUTLASS_PRAGMA_UNROLL + for(int i = 0; i < Elements; ++i) { + gO(threadIdx.x + MaxThreadsPerBlock * i) = static_cast(local_val[i]); + } + } +}; + +} // namespace cutlass::fmha::kernel diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/kernel/sm100_fmha_mla_tma_warpspecialized.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/kernel/sm100_fmha_mla_tma_warpspecialized.hpp new file mode 100644 index 0000000000000000000000000000000000000000..e9edb90e571a58d7c1ec954e5f7c8bb11489fcf4 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/kernel/sm100_fmha_mla_tma_warpspecialized.hpp @@ -0,0 +1,2237 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cute/arch/simd_sm100.hpp" +#include "cutlass/barrier.h" + +#include "cutlass/arch/arch.h" +#include "cutlass/arch/memory_sm80.h" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "gather_tensor.hpp" // from examples/common +#include "common/pow_2.hpp" +#include "sm100_mla_tile_scheduler.hpp" + +namespace cutlass::fmha::kernel { + +using namespace cute; + +template< + class TileShape, + class Element_, + class ElementAcc_, + class ElementOut_, + class ElementLSE_, + class TileScheduler, +#ifdef CPASYNC + bool kIsCpAsync = true +#else + bool kIsCpAsync = false +#endif +> +struct Sm100FmhaMlaKernelTmaWarpspecialized { + + using Element = Element_; + using ElementAcc = ElementAcc_; + using ElementOut = ElementOut_; + using ElementLSE = ElementLSE_; + + // only 2Sm mode is supported + static const bool kIs2Sm = true; + static const int MaxThreadsPerBlock = 256; + static const int MinBlocksPerMultiprocessor = 1; + static const int TotalSNum = 2; + static const int TotalPNum = 2; + using ArchTag = cutlass::arch::Sm100; + + using ClusterShape = cute::conditional_t, Shape<_1, _1, _1>>; + + using TileShapeH = tuple_element_t<0, TileShape>; + using TileShapeS = tuple_element_t<1, TileShape>; + using TileShapeD = tuple_element_t<2, TileShape>; + + using TileShapeL = tuple_element_t<0, TileShapeD>; + using TileShapeR = tuple_element_t<1, TileShapeD>; + static_assert(TileShapeL{} % TileShapeR{} == 0, "Rope head dim must divide latent head dim"); + + using ProblemShape = Shape; + using TensorStride = Stride; + using TmemAllocator = cute::conditional_t; + + static_assert(TileShapeH{} == 128); + static const int kWarpsInN = kIs2Sm ? 2 : 1; + + static const int kNumComputeWarps = 4; + static const int kNumLoadWarps = kIsCpAsync ? 2 : 1; + + enum class WarpRole { + kMma = 0x1, kLoad = 0x2, kCompute = 0x3, kLoadPageTable = 0x4, kEmpty=0x0 + }; + + static const long long unsigned int kWarpAssignment = kIsCpAsync ? 
0x4221'3333ull : 0x0021'3333ull; + + static CUTLASS_DEVICE WarpRole warp_idx_to_role(int warp_idx) { + return static_cast((kWarpAssignment >> (4 * warp_idx)) & 0xF); + } + + static const int Alignment = 128 / sizeof_bits_v; + static const int AlignmentOut = 128 / sizeof_bits_v; + + using TileShapeQK = Shape; + static const int StagesQK = 24 / sizeof(Element); // free parameter + static const int IterationsQKLatent = decltype(TileShapeL{} / get<2>(TileShapeQK{}))::value; + static const int IterationsQKRope = decltype(TileShapeR{} / get<2>(TileShapeQK{}))::value; + static const int IterationsQK = IterationsQKLatent + IterationsQKRope; + + using Schedule = cute::conditional_t; + using CollectiveMmaQK = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + Element, TensorStride, Alignment, + Element, TensorStride, Alignment, + ElementAcc, + TileShapeQK, ClusterShape, cutlass::gemm::collective::StageCount, + Schedule>::CollectiveOp; + using TiledMmaQK = typename CollectiveMmaQK::TiledMma; + using CtaShapeQK = typename CollectiveMmaQK::CtaShape_MNK; + + // chosen for unified smem staging between K and V + using TileShapePV = Shape; + using TransposeTensorStride = decltype(select<1,0,2>(TensorStride{})); + static const int StagesPV = StagesQK; // not sure why, but must be at least two. check pipes + static const int IterationsPV_K = decltype(TileShapeS{} / get<2>(TileShapePV{}))::value; + static const int IterationsPV_N = decltype(TileShapeL{} / get<1>(TileShapePV{}))::value; + + using CollectiveMmaPV = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp, + Element, TensorStride, Alignment, + Element, TransposeTensorStride, Alignment, + ElementAcc, + TileShapePV, ClusterShape, cutlass::gemm::collective::StageCount, + Schedule>::CollectiveOp; + using CtaShapePV = typename CollectiveMmaPV::CtaShape_MNK; + static_assert(std::is_same_v); + + using TiledMmaPV = typename CollectiveMmaPV::TiledMma; + + using AtomThrShapeMNK = typename CollectiveMmaQK::AtomThrShapeMNK; + static_assert(typename CollectiveMmaQK::AtomThrShapeMNK{} == typename CollectiveMmaPV::AtomThrShapeMNK{}, "schedule must match"); + + static const int StagesPageTable = kIsCpAsync ? 
StagesPV : 1; + + // pipelines from load to mma, PipelineTmaUmmaAsync, stages tbd + // use expect_tx for Q load + using PipelineLoadQK = cute::conditional_t, PipelineTmaUmmaAsync>; + using PipelineLoadPV = PipelineLoadQK; + // pipeline from mma (Q@K) to softmax, PipelineUmmaAsync, 2 stages + using PipelineS = PipelineUmmaAsync; + // pipeline from softmax (P) to mma (bmm2), PipelineUmmaAsync, 2 stages + using PipelineP = PipelineUmmaConsumerAsync; + // pipeline from mma to softmax (for rescale), PipelineUmmaAsync, 1 stage + using PipelineO = PipelineUmmaAsync<1, AtomThrShapeMNK>; + + using PipelinePT = PipelineAsync; + + struct PipelineStorage { + alignas(16) typename PipelineLoadQK::SharedStorage load_qk; + alignas(16) typename PipelineS::SharedStorage mma_s; + alignas(16) typename PipelineP::SharedStorage p_mma; + alignas(16) typename PipelineO::SharedStorage mma_o; + alignas(16) typename PipelinePT::SharedStorage load_page_table; + }; + + template + static CUTE_DEVICE constexpr auto unstageSmemLayout(Layout const& layout, Stages stages = {}) { + return composition(layout, make_tuple(_, _, _, make_layout(stages))); + } + + using SmemLayoutQ = decltype(unstageSmemLayout(typename CollectiveMmaQK::SmemLayoutA{}, Int{})); + using SmemLayoutKC = typename CollectiveMmaQK::SmemLayoutB; + using SmemLayoutVC = typename CollectiveMmaPV::SmemLayoutB; + using SmemLayoutP = decltype(unstageSmemLayout(typename CollectiveMmaPV::SmemLayoutA{}, make_shape(Int{}, _2{}))); + using SmemLayoutOut = decltype(take<0,2>(typename CollectiveMmaQK::CtaShape_MNK{})); + using TileShapeAcc = decltype(take<0,2>(typename CollectiveMmaPV::CtaShape_MNK{})); + + static const int kBytesLoadQ = size(AtomThrShapeMNK{}) * cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutQ{})) * cute::sizeof_bits_v); + static const int kBytesLoadKC = size(AtomThrShapeMNK{}) * cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutKC{})) * cute::sizeof_bits_v); + static const int kBytesLoadVC = size(AtomThrShapeMNK{}) * cutlass::bits_to_bytes(cosize(take<0,3>(SmemLayoutVC{})) * cute::sizeof_bits_v); + + // pre-condition for overlapped smem staging + static_assert(kBytesLoadKC == kBytesLoadVC); + static_assert(StagesQK == StagesPV); + + static const int kTransactionsBytesLoadQK = kBytesLoadKC; + static const int kTransactionsBytesLoadExtraQ = kBytesLoadQ; + static const int kTransactionsBytesLoadPV = kBytesLoadVC; + + static const int kNamedBarrierExchange = (int) cutlass::arch::ReservedNamedBarriers::TransformBarrier; + // This Named Barrier is introduced to solve Q tile loading overwritten issue when enable persistent + // tile scheduler for FP8 MLA. 
+ static const int kNamedBarrierEpilogue = (int) cutlass::arch::ReservedNamedBarriers::EpilogueBarrier; + // + static const int kNamedBarrierTmemDealloc = (int) cutlass::arch::ReservedNamedBarriers::TmemAllocBarrier; + + enum class TmemAllocation : uint32_t { + kSizeS = TileShapeS::value / kWarpsInN, + // Overall + kSizeO = TileShapeL::value / kWarpsInN, + // Between accumulators we loop over + kSizeAccO = decltype(get<1>(TileShapePV{}))::value / kWarpsInN, + kNumS = TotalSNum, + kNumP = TotalPNum, + kNumO = 1, + kS0 = 0, + kS1 = kS0 + kSizeS, + kO0 = kS1 + kSizeS, + kTotal = kO0 + kSizeO + }; + + static_assert(static_cast(TmemAllocation::kTotal) <= TmemAllocator::Sm100TmemCapacityColumns, "using too much tmem"); + + struct TensorStorage { + // to communicate max and row_sum + cute::array smem_exchange; + cute::array smem_page_table; + alignas(2048) cute::array> smem_q; + union { + alignas(2048) cute::array> smem_kc; + alignas(2048) cute::array> smem_vc; + }; + union { + alignas(2048) cute::array> smem_p; + alignas(2048) cute::array smem_acc; + }; + }; + + struct SharedStorage { + PipelineStorage pipelines; + TensorStorage tensors; + uint32_t tmem_base_ptr; + }; + + static const int SharedStorageSize = sizeof(SharedStorage); + static_assert(SharedStorageSize <= cutlass::arch::sm100_smem_capacity_bytes, "using too much smem"); + + struct MainloopArguments { + ElementAcc softmax_scale; + + // all tensors strides are (num_heads or seqlen, head_dim, batch) + // head_dim stride is always 1 + Element* ptr_q_latent; + TensorStride stride_q_latent; + Element* ptr_q_rope; + TensorStride stride_q_rope; + + Element* ptr_c_latent; + TensorStride stride_c_latent; + Element* ptr_k_rope; + TensorStride stride_k_rope; + + // for paged attention, we interpret what was previously [batch, seqlen] + // as [page_count, page_size], and index according to page_table + int* ptr_seq = nullptr; + int* ptr_page_table = nullptr; + // page table is [batch, seqlen or similar] + Stride<_1, int> stride_page_table = {}; + int page_count = 0; + int page_size = TileShapeS{}; // powers of two if kIsCpAsync, otherwise TileShapeS + }; + + struct EpilogueArguments { + ElementOut* ptr_o = nullptr; + TensorStride stride_o; + ElementLSE* ptr_lse = nullptr; + Stride<_1, int> stride_lse; + ElementAcc output_scale = 1.0f; + }; + + struct Arguments { + // (num_heads=128, seqlen, (d_latent=512, d_rope=64), batch_count) + // for paged attention, seqlen is max seqlen + ProblemShape problem_shape; + MainloopArguments mainloop; + EpilogueArguments epilogue; + KernelHardwareInfo hw_info; + int split_kv = -1; + int* ptr_split_kv = nullptr; + bool is_fused_reduction = false; + }; + + using TmaLoadQLatent = typename CollectiveMmaQK::Params::TMA_A; + using TmaLoadQRope = typename CollectiveMmaQK::Params::TMA_A; + using TmaLoadCLatent = typename CollectiveMmaQK::Params::TMA_B; + using TmaLoadKRope = typename CollectiveMmaQK::Params::TMA_B; + using TmaLoadCLatentTranspose = typename CollectiveMmaPV::Params::TMA_B; + + using GmemLayout = decltype(make_layout(Shape{}, Stride{})); + using SmemLayout = decltype(make_layout(TileShapeAcc{}, LayoutRight{})); + + using TmaReduceSum = decltype(make_tma_copy(SM90_TMA_REDUCE_ADD{}, + make_tensor(recast_ptr(nullptr), GmemLayout{}), SmemLayout{})); + + struct MainloopParams { + TmaLoadQLatent tma_load_q_latent; + TmaLoadQRope tma_load_q_rope; + TmaLoadCLatent tma_load_c_latent; + TmaLoadKRope tma_load_k_rope; + TmaLoadCLatentTranspose tma_load_c_latent_transpose; + }; + + struct EpilogueParams { + ElementOut* 
ptr_o = nullptr; + ElementAcc* ptr_o_acc = nullptr; + TensorStride stride_o; + TensorStride stride_o_acc; + ElementLSE* ptr_lse = nullptr; + ElementLSE* ptr_lse_acc = nullptr; + Stride<_1, int> stride_lse; + Stride<_1, int> stride_lse_acc; + ElementAcc output_scale = 1.0f; + ElementLSE* ptr_lse_exchange_buff = nullptr; + int* ptr_lse_max_exchange_buff = nullptr; + int* ptr_lock = nullptr; // semaphore + TmaReduceSum tma_reduce_sum; + }; + + struct Params { + ProblemShape problem_shape; + MainloopArguments mainloop; + EpilogueParams epilogue; + MainloopParams mainloop_params; + typename TileScheduler::Params tile_scheduler; + int split_kv = -1; + int* ptr_split_kv = nullptr; + bool is_fused_reduction = false; + }; + + static Params to_underlying_arguments(Arguments const& args, void* workspace) { + //workspace = nullptr; // let's get an error if one of these needs workspace + + auto [H, K, D, B] = args.problem_shape; + auto [L, R] = D; + + int paged_B = B; + int paged_K = K; + if (args.mainloop.ptr_page_table != nullptr) { + paged_B = args.mainloop.page_count; + paged_K = args.mainloop.page_size; + } + + auto params_qk_latent = CollectiveMmaQK::to_underlying_arguments( + make_shape(H, K, L, B), + typename CollectiveMmaQK::Arguments { + args.mainloop.ptr_q_latent, args.mainloop.stride_q_latent, + args.mainloop.ptr_c_latent, args.mainloop.stride_c_latent, + }, nullptr); + + auto params_qk_latent_paged = CollectiveMmaQK::to_underlying_arguments( + make_shape(H, paged_K, L, paged_B), + typename CollectiveMmaQK::Arguments { + args.mainloop.ptr_q_latent, args.mainloop.stride_q_latent, + args.mainloop.ptr_c_latent, args.mainloop.stride_c_latent, + }, nullptr); + + auto params_qk_rope = CollectiveMmaQK::to_underlying_arguments( + make_shape(H, K, R, B), + typename CollectiveMmaQK::Arguments { + args.mainloop.ptr_q_rope, args.mainloop.stride_q_rope, + args.mainloop.ptr_k_rope, args.mainloop.stride_k_rope, + }, nullptr); + + auto params_qk_rope_paged = CollectiveMmaQK::to_underlying_arguments( + make_shape(H, paged_K, R, paged_B), + typename CollectiveMmaQK::Arguments { + args.mainloop.ptr_q_rope, args.mainloop.stride_q_rope, + args.mainloop.ptr_k_rope, args.mainloop.stride_k_rope, + }, nullptr); + + + auto stride_c_latent_transpose = select<1,0,2>(args.mainloop.stride_c_latent); + auto params_pv_latent = CollectiveMmaPV::to_underlying_arguments( + make_shape(H, L, paged_K, paged_B), + typename CollectiveMmaPV::Arguments { + args.mainloop.ptr_q_latent, args.mainloop.stride_q_latent, // dummy, never used + args.mainloop.ptr_c_latent, stride_c_latent_transpose, + }, nullptr); + + MainloopParams mainloop_params { + params_qk_latent.tma_load_a, + params_qk_rope.tma_load_a, + params_qk_latent_paged.tma_load_b, + params_qk_rope_paged.tma_load_b, + params_pv_latent.tma_load_b + }; + + EpilogueParams epilogue_params; + + epilogue_params.ptr_o = args.epilogue.ptr_o; + epilogue_params.stride_o = args.epilogue.stride_o; + epilogue_params.ptr_lse = args.epilogue.ptr_lse; + epilogue_params.stride_lse = args.epilogue.stride_lse; + epilogue_params.output_scale = args.epilogue.output_scale; + epilogue_params.tma_reduce_sum = make_tma_copy(SM90_TMA_REDUCE_ADD{}, make_tensor(recast_ptr(args.epilogue.ptr_o), make_layout(make_shape(H, L, B), args.epilogue.stride_o)), SmemLayout{}); + + if (!args.is_fused_reduction && args.split_kv > 1) { + ElementAcc* ptr_o_acc = reinterpret_cast(workspace); + ElementLSE* ptr_lse_acc = reinterpret_cast(ptr_o_acc + H * L * args.split_kv * B); + epilogue_params.ptr_o_acc = ptr_o_acc; + 
epilogue_params.ptr_lse_acc = ptr_lse_acc; + + epilogue_params.stride_o_acc = make_tuple(static_cast(0 + L) * args.split_kv, _1{}, static_cast(0 + H * L) * args.split_kv); + epilogue_params.stride_lse_acc = make_tuple(_1{}, (0 + H) * args.split_kv); + } else if (args.is_fused_reduction && args.split_kv > 1) { + ElementLSE* ptr_lse_exchange_buff = reinterpret_cast(workspace); + epilogue_params.ptr_lse_exchange_buff = ptr_lse_exchange_buff; + int* ptr_lse_max_exchange_buff = reinterpret_cast(ptr_lse_exchange_buff + H * B); + epilogue_params.ptr_lse_max_exchange_buff = ptr_lse_max_exchange_buff; + int* ptr_lock = ptr_lse_max_exchange_buff + H * B; + epilogue_params.ptr_lock = ptr_lock; + } + + return {args.problem_shape, args.mainloop, epilogue_params, mainloop_params, + TileScheduler::to_underlying_arguments(args.problem_shape, args.hw_info, ClusterShape{}, args.split_kv), + args.split_kv, args.ptr_split_kv, args.is_fused_reduction}; + } + + static size_t get_workspace_size(Arguments const& args) { + ProblemShape problem_shape = args.problem_shape; + auto [H, K, D, B] = problem_shape; + auto [D_latent, D_rope] = D; + auto split_kv = args.split_kv; + size_t workspace_size {0}; + if (args.is_fused_reduction && args.split_kv > 1) { + // one exchange buffer for LSE max and another buffer for total LSE + // two locks per batch, frist lock is for CTA0 / H=0..63 and the second is for CTA1 / H=64..127 + workspace_size = H * B * (sizeof(int) + sizeof(ElementLSE)) + 2 * B * sizeof(int); + } else if (!args.is_fused_reduction && args.split_kv > 1) { + workspace_size = (sizeof(ElementAcc) * D_latent + sizeof(ElementLSE)) * H * split_kv * B; + } + return workspace_size; + } + static Status initialize_workspace( + Arguments const& args, void* ws, cudaStream_t stream) { + auto workspace_size = get_workspace_size(args); + if (args.is_fused_reduction && args.split_kv > 1) { + auto result = cudaMemsetAsync(ws, 0, workspace_size); + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST( + " cudaMemsetAsync() returned error: " + << cudaGetErrorString(result)); + return Status::kErrorInternal;; + } + } + return Status::kSuccess; + } + + static dim3 get_grid_shape(Params const& params) { + return TileScheduler::get_grid_shape(params.tile_scheduler); + } + + static dim3 get_block_shape() { + dim3 block(MaxThreadsPerBlock, 1, 1); + return block; + } + + static bool can_implement(Arguments const& args) { + if (kIsCpAsync) { + if ((args.mainloop.page_size & (args.mainloop.page_size - 1)) != 0) { + std::cerr << __FILE__ << "(" << __LINE__ << "): cpasync page size pow2\n"; + return false; + } + if (args.mainloop.page_size > TileShapeS{}) { + std::cerr << __FILE__ << "(" << __LINE__ << "): cpasync page size too big\n"; + return false; + } + } + else { + if (args.mainloop.ptr_page_table != nullptr && args.mainloop.page_size != TileShapeS{}) { + std::cerr << __FILE__ << "(" << __LINE__ << "): tma page size off\n"; + return false; + } + } + if (get<0>(args.problem_shape) != 128) { + std::cerr << __FILE__ << "(" << __LINE__ << "): heads off\n"; + return false; + } + if (get<1>(args.problem_shape) <= 0) { + std::cerr << __FILE__ << "(" << __LINE__ << "): heads off\n"; + return false; + } + if (args.split_kv <= 0) { + std::cerr << __FILE__ << "(" << __LINE__ << "): split-k off\n"; + return false; + } + if (args.is_fused_reduction && args.split_kv > 1) { + if (2 * args.split_kv > args.hw_info.sm_count || + std::is_same_v) { + return false; + } + } + return true; + } + + + 
CUTLASS_DEVICE void operator()(Params const& params, char* smem_raw) { +#if (! defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) && ! defined(CUTLASS_ARCH_MMA_SM100F_ENABLED)) + printf("ERROR : Arch conditional MMA instruction used without targeting appropriate compute capability. Aborting.\n"); +#else + + TileScheduler tile_scheduler(params.tile_scheduler); + + int warp_idx = cutlass::canonical_warp_idx_sync(); + auto role = warp_idx_to_role(warp_idx); + uint32_t lane_predicate = cute::elect_one_sync(); + + uint32_t cta_rank_in_cluster = cute::block_rank_in_cluster(); + int cta_coord_v = cta_rank_in_cluster % size<0>(AtomThrShapeMNK{}); + bool is_mma_leader_cta = cta_coord_v == 0; + + if (role == WarpRole::kLoad && lane_predicate && ! kIsCpAsync) { + prefetch_tma_descriptor(params.mainloop_params.tma_load_q_latent.get_tma_descriptor()); + prefetch_tma_descriptor(params.mainloop_params.tma_load_c_latent.get_tma_descriptor()); + prefetch_tma_descriptor(params.mainloop_params.tma_load_q_rope.get_tma_descriptor()); + prefetch_tma_descriptor(params.mainloop_params.tma_load_k_rope.get_tma_descriptor()); + prefetch_tma_descriptor(params.mainloop_params.tma_load_c_latent_transpose.get_tma_descriptor()); + } + SharedStorage& shared_storage = *reinterpret_cast(smem_raw); + + typename PipelineLoadQK::Params pipeline_load_qk_params; + if (role == WarpRole::kLoad) { + pipeline_load_qk_params.role = PipelineLoadQK::ThreadCategory::Producer; + } + if (role == WarpRole::kMma) { + pipeline_load_qk_params.role = PipelineLoadQK::ThreadCategory::Consumer; + } + if constexpr (kIsCpAsync) { + // we can make our life easier by unconditionally loading blocks + // since we know it'll always be legal + pipeline_load_qk_params.producer_arv_count = kNumLoadWarps * cutlass::NumThreadsPerWarp * size(AtomThrShapeMNK{}); + } + else { + pipeline_load_qk_params.is_leader = lane_predicate && (role == WarpRole::kLoad) && is_mma_leader_cta; + pipeline_load_qk_params.transaction_bytes = kTransactionsBytesLoadQK; + } + pipeline_load_qk_params.initializing_warp = 0; + PipelineLoadQK pipeline_load_qk(shared_storage.pipelines.load_qk, pipeline_load_qk_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename PipelineS::Params pipeline_mma_s_params; + if (role == WarpRole::kMma) { + pipeline_mma_s_params.role = PipelineS::ThreadCategory::Producer; + } + if (role == WarpRole::kCompute) { + pipeline_mma_s_params.role = PipelineS::ThreadCategory::Consumer; + } + pipeline_mma_s_params.consumer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp * size(AtomThrShapeMNK{}); + pipeline_mma_s_params.initializing_warp = 1; + PipelineS pipeline_mma_s( + shared_storage.pipelines.mma_s, + pipeline_mma_s_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename PipelineP::Params pipeline_p_mma_params; + if (role == WarpRole::kMma) { + pipeline_p_mma_params.role = PipelineP::ThreadCategory::Consumer; + } + if (role == WarpRole::kCompute) { + pipeline_p_mma_params.role = PipelineP::ThreadCategory::Producer; + } + pipeline_p_mma_params.producer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp * size(AtomThrShapeMNK{}); + pipeline_p_mma_params.consumer_arv_count = 1; + pipeline_p_mma_params.initializing_warp = 2; + PipelineP pipeline_p_mma( + shared_storage.pipelines.p_mma, + pipeline_p_mma_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename PipelineO::Params pipeline_mma_o_params; + if 
(role == WarpRole::kMma) { + pipeline_mma_o_params.role = PipelineO::ThreadCategory::Producer; + } + if (role == WarpRole::kCompute) { + pipeline_mma_o_params.role = PipelineO::ThreadCategory::Consumer; + } + pipeline_mma_o_params.consumer_arv_count = kNumComputeWarps * cutlass::NumThreadsPerWarp * size(AtomThrShapeMNK{}); + pipeline_mma_o_params.initializing_warp = 3; + PipelineO pipeline_mma_o( + shared_storage.pipelines.mma_o, + pipeline_mma_o_params, + ClusterShape{}, /*barrier init*/ cute::true_type{}, /*mask calc*/cute::false_type{}); + + typename PipelinePT::Params pipeline_pt_params; + if (role == WarpRole::kLoad) { + pipeline_pt_params.role = PipelinePT::ThreadCategory::Consumer; + } + if (role == WarpRole::kLoadPageTable) { + pipeline_pt_params.role = PipelinePT::ThreadCategory::Producer; + } + pipeline_pt_params.consumer_arv_count = kNumLoadWarps * cutlass::NumThreadsPerWarp; + pipeline_pt_params.producer_arv_count = cutlass::NumThreadsPerWarp; + pipeline_pt_params.initializing_warp = 4; + PipelinePT pipeline_page_table( + shared_storage.pipelines.load_page_table, + pipeline_pt_params); + + TmemAllocator tmem_allocator; + + pipeline_init_arrive_relaxed(size(ClusterShape{})); + + pipeline_load_qk.init_masks(ClusterShape{}); // do we need an update here for 2Sm? + pipeline_mma_s.init_masks(ClusterShape{}); + pipeline_p_mma.init_masks(ClusterShape{}); + pipeline_mma_o.init_masks(ClusterShape{}); + + typename PipelineLoadQK::PipelineState pipeline_load_qk_consumer_state; + typename PipelineLoadQK::PipelineState pipeline_load_qk_producer_state = cutlass::make_producer_start_state(); + + typename PipelineS::PipelineState pipeline_mma_s_consumer_state; + typename PipelineS::PipelineState pipeline_mma_s_producer_state = cutlass::make_producer_start_state(); + + typename PipelineP::PipelineState pipeline_p_mma_consumer_state; + typename PipelineP::PipelineState pipeline_p_mma_producer_state = cutlass::make_producer_start_state(); + + typename PipelineO::PipelineState pipeline_mma_o_consumer_state; + typename PipelineO::PipelineState pipeline_mma_o_producer_state = cutlass::make_producer_start_state(); + + typename PipelinePT::PipelineState pipeline_pt_consumer_state; + typename PipelinePT::PipelineState pipeline_pt_producer_state = cutlass::make_producer_start_state(); + + pipeline_init_wait(size(ClusterShape{})); + + if (role == WarpRole::kLoadPageTable) { + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + auto problem_shape = params.problem_shape; + auto local_split_kv = params.split_kv; + if (params.mainloop.ptr_seq != nullptr) { + get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)]; + if (params.ptr_split_kv != nullptr) { + local_split_kv = params.ptr_split_kv[get<2>(blk_coord)]; + } + } + if (local_split_kv <= get<3>(blk_coord)) + continue; + load_page_table( + blk_coord, + problem_shape, + params.mainloop, + shared_storage.tensors, + pipeline_page_table, pipeline_pt_producer_state, + local_split_kv + ); + } + } + else if (role == WarpRole::kLoad) { + if constexpr (kIsCpAsync) { + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + auto problem_shape = params.problem_shape; + auto local_split_kv = params.split_kv; + if (params.mainloop.ptr_seq != nullptr) { + get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)]; + if (params.ptr_split_kv != nullptr) { + local_split_kv = 
params.ptr_split_kv[get<2>(blk_coord)]; + } + } + if (local_split_kv <= get<3>(blk_coord)) + continue; + load_cpasync( + blk_coord, + problem_shape, + params.mainloop, + params.mainloop_params, + shared_storage.tensors, + pipeline_load_qk, pipeline_load_qk_producer_state, + local_split_kv, + /* must be shared pipe */ + pipeline_page_table, pipeline_pt_consumer_state + ); + cutlass::arch::NamedBarrier((kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp, kNamedBarrierEpilogue).arrive_and_wait(); + } + } + else { + if (params.mainloop.ptr_page_table != nullptr) { + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + auto problem_shape = params.problem_shape; + auto local_split_kv = params.split_kv; + if (params.mainloop.ptr_seq != nullptr) { + get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)]; + if (params.ptr_split_kv != nullptr) { + local_split_kv = params.ptr_split_kv[get<2>(blk_coord)]; + } + } + if (local_split_kv <= get<3>(blk_coord)) + continue; + load_tma( + blk_coord, + problem_shape, + params.mainloop, + params.mainloop_params, + shared_storage.tensors, + pipeline_load_qk, pipeline_load_qk_producer_state, + pipeline_load_qk, pipeline_load_qk_producer_state, + local_split_kv + ); + cutlass::arch::NamedBarrier((kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp, kNamedBarrierEpilogue).arrive_and_wait(); + } + } + else { + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + auto problem_shape = params.problem_shape; + auto local_split_kv = params.split_kv; + if (params.mainloop.ptr_seq != nullptr) { + get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)]; + if (params.ptr_split_kv != nullptr) { + local_split_kv = params.ptr_split_kv[get<2>(blk_coord)]; + } + } + if (local_split_kv <= get<3>(blk_coord)) + continue; + load_tma( + blk_coord, + problem_shape, + params.mainloop, + params.mainloop_params, + shared_storage.tensors, + pipeline_load_qk, pipeline_load_qk_producer_state, + pipeline_load_qk, pipeline_load_qk_producer_state, + local_split_kv + ); + cutlass::arch::NamedBarrier((kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp, kNamedBarrierEpilogue).arrive_and_wait(); + } + } + } + } + else if (role == WarpRole::kMma) { + tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr); + __syncwarp(); + + if (is_mma_leader_cta) { + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + auto problem_shape = params.problem_shape; + auto local_split_kv = params.split_kv; + if (params.mainloop.ptr_seq != nullptr) { + get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)]; + if (params.ptr_split_kv != nullptr) { + local_split_kv = params.ptr_split_kv[get<2>(blk_coord)]; + } + } + if (local_split_kv <= get<3>(blk_coord)) + continue; + mma(blk_coord, + problem_shape, + shared_storage.tensors, + pipeline_load_qk, pipeline_load_qk_consumer_state, + pipeline_load_qk, pipeline_load_qk_consumer_state, + pipeline_mma_s, pipeline_mma_s_producer_state, + pipeline_p_mma, pipeline_p_mma_consumer_state, + pipeline_mma_o, pipeline_mma_o_producer_state, + local_split_kv + ); + } + } + + //cutlass::arch::NamedBarrier((kNumComputeWarps + 1) * NumThreadsPerWarp, kNamedBarrierTmemDealloc).arrive_and_wait(); + + //uint32_t free_stage_ptr = shared_storage.tmem_base_ptr; + 
//tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns); + } + else if (role == WarpRole::kCompute) { + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + auto problem_shape = params.problem_shape; + auto split_kv = params.split_kv; + auto local_split_kv = split_kv; + if (params.mainloop.ptr_seq != nullptr) { + get<1>(problem_shape) = params.mainloop.ptr_seq[get<2>(blk_coord)]; + if (params.ptr_split_kv != nullptr) { + local_split_kv = params.ptr_split_kv[get<2>(blk_coord)]; + } + } + if (local_split_kv <= get<3>(blk_coord)) + continue; + compute( + blk_coord, + problem_shape, + params.mainloop, // for softmax_scale + params.epilogue, + shared_storage.tensors, // for smem_comm + pipeline_mma_s, pipeline_mma_s_consumer_state, + pipeline_p_mma, pipeline_p_mma_producer_state, + pipeline_mma_o, pipeline_mma_o_consumer_state, + local_split_kv, + params.is_fused_reduction + ); + } + + //cutlass::arch::NamedBarrier((kNumComputeWarps + 1) * NumThreadsPerWarp, kNamedBarrierTmemDealloc).arrive(); + } + + cute::cluster_sync(); + cutlass::arch::NamedBarrier((kNumComputeWarps + 1) * NumThreadsPerWarp, kNamedBarrierTmemDealloc).arrive(); + if (role == WarpRole::kMma) { + uint32_t free_stage_ptr = shared_storage.tmem_base_ptr; + tmem_allocator.free(free_stage_ptr, TmemAllocator::Sm100TmemCapacityColumns); + } +#endif + } + + template + CUTLASS_DEVICE void load_page_table( + BlkCoord const& blk_coord, + ProblemShape const& problem_shape, + MainloopArguments const& mainloop_args, + TensorStorage& shared_tensors, + PipelinePT& pipeline_page_table, + typename PipelinePT::PipelineState& pipeline_pt_producer_state, int const& split_kv) { + + auto [H, K, D, B] = problem_shape; + int batch_coord = get<2>(blk_coord); + + auto mPT_l = make_tensor(make_gmem_ptr(mainloop_args.ptr_page_table), + make_shape(mainloop_args.page_count, B), + mainloop_args.stride_page_table); + auto mPT = mPT_l(_, batch_coord); + + int k_tile_total = ceil_div(K, TileShapeS{}); + int k_tile_per_cta = ceil_div(k_tile_total, split_kv); + int k_index = get<3>(blk_coord) * k_tile_per_cta; // lower limit + int k_tile_count = max(0, min(k_tile_total, k_index + k_tile_per_cta) - k_index); + if (k_tile_count == 0) { + return; + } + + auto page_size = Pow2{mainloop_args.page_size}; + auto pages_per_tile = Pow2{TileShapeS{} / page_size}; + int thread_idx = threadIdx.x % cutlass::NumThreadsPerWarp; + + for (; k_tile_count > 0; ++k_index, --k_tile_count) { + pipeline_page_table.producer_acquire(pipeline_pt_producer_state); + + // assume a single warp + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < TileShapeS{}; i += cutlass::NumThreadsPerWarp) { + int idx = i + thread_idx; + bool guard = idx < pages_per_tile; + int smem_idx = pipeline_pt_producer_state.index() * TileShapeS::value + idx; + int pt_idx = pages_per_tile * k_index + idx; + + cutlass::arch::cp_async_zfill( + &shared_tensors.smem_page_table[smem_idx], &mPT(pt_idx), guard + ); + } + + pipeline_page_table.producer_commit(pipeline_pt_producer_state, cutlass::arch::cpasync_barrier_arrive); + ++pipeline_pt_producer_state; + } + } + + + struct Gather { + int& page_table_stage; + Pow2 pages_per_tile; + const int * __restrict__ smem_page_table; + + CUTLASS_DEVICE int operator()(int idx) const { + return smem_page_table[page_table_stage * TileShapeS::value + idx % pages_per_tile]; + } + + CUTLASS_DEVICE friend void print(Gather const&) { + printf(""); + } + + }; + + + template + CUTLASS_DEVICE 
void load_cpasync( + BlkCoord const& blk_coord, + ProblemShape const& problem_shape, + MainloopArguments const& mainloop_args, + MainloopParams const& mainloop_params, + TensorStorage& shared_tensors, + PipelineLoadQK& pipeline_load, + typename PipelineLoadQK::PipelineState& pipeline_load_producer_state, + int const& split_kv, + PipelinePT& pipeline_page_table, + typename PipelinePT::PipelineState& pipeline_pt_consumer_state) { + + auto [H, K, D, B] = problem_shape; + auto [D_latent, D_rope] = D; + + using X = Underscore; + + int k_tile_total = ceil_div(K, TileShapeS{}); + int k_tile_per_cta = ceil_div(k_tile_total, split_kv); + int k_index = get<3>(blk_coord) * k_tile_per_cta; // lower limit + int k_tile_count = max(0, min(k_tile_total, k_index + k_tile_per_cta) - k_index); + if (k_tile_count == 0) { + return; + } + + // partition all tensors + auto mQL = make_tensor(make_gmem_ptr(mainloop_args.ptr_q_latent), make_shape(H, D_latent, B), mainloop_args.stride_q_latent); + auto mQR = make_tensor(make_gmem_ptr(mainloop_args.ptr_q_rope), make_shape(H, D_rope, B), mainloop_args.stride_q_rope); + + int paged_B = mainloop_args.page_count; + auto paged_K = Pow2{mainloop_args.page_size}; + auto mPT_l = make_tensor(make_gmem_ptr(mainloop_args.ptr_page_table), make_shape(paged_B, B), mainloop_args.stride_page_table); + + int batch_coord = get<2>(blk_coord); + auto mPT = mPT_l(_, batch_coord); + + auto gQL = local_tile(mQL, TileShapeQK{}, make_coord(_,_,_), Step<_1, X, _1>{}); + auto gQR = local_tile(mQR, TileShapeQK{}, make_coord(_,_,_), Step<_1, X, _1>{}); + + ThrMMA cta_mma_qk = TiledMmaQK{}.get_slice(get<0>(blk_coord) % size(AtomThrShapeMNK{})); + ThrMMA cta_mma_pv = TiledMmaPV{}.get_slice(get<0>(blk_coord) % size(AtomThrShapeMNK{})); + + auto tSgQL = cta_mma_qk.partition_A(gQL); + auto tSgQR = cta_mma_qk.partition_A(gQR); + + Tensor sQ = make_tensor(make_smem_ptr(shared_tensors.smem_q.begin()), SmemLayoutQ{}); + Tensor sKC = make_tensor(make_smem_ptr(shared_tensors.smem_kc.begin()), SmemLayoutKC{}); + Tensor sVC = make_tensor(make_smem_ptr(shared_tensors.smem_vc.begin()), SmemLayoutVC{}); + + auto make_copy_for = [](auto sT) { + auto rT_a = sT.layout()(_, _, _, _0{}); + auto rT = make_ordered_layout(shape(rT_a), stride(rT_a)); + auto threads = Int{}; + auto values = Int{}; + return make_cotiled_copy( + Copy_Atom, Element>{}, + make_ordered_layout( + make_shape(threads, values), + make_stride(_1{}, _0{})), + rT); + }; + + // like cute::copy, but makes sure we do all page table lookups first + auto copy_split = [](auto atom, auto src, auto dst) { + auto src_v = group_modes<1, rank_v>(src); + auto dst_v = group_modes<1, rank_v>(dst); + + auto src_v_ptrs = make_tensor(size<1>(src_v)); + for (int i = 0; i < size<1>(src_v); i++) { + src_v_ptrs(i) = &src_v(_0{}, i); + } + + + for (int i = 0; i < size<1>(src_v); i++) { + auto src_v_i = make_tensor( + make_gmem_ptr(src_v_ptrs(i)), + make_shape(shape<0>(src_v)), + make_stride(make_stride(_1{}, _0{})) + ); + atom.call(src_v_i, dst_v(_, i)); + } + }; + + auto tiled_copy_q = make_copy_for(sQ); + auto tiled_copy_kc = make_copy_for(sKC); + auto tiled_copy_vc = make_copy_for(sVC); + + auto thr_copy_q = tiled_copy_q.get_thread_slice(threadIdx.x % (kNumLoadWarps * cutlass::NumThreadsPerWarp)); + auto thr_copy_kc = tiled_copy_kc.get_thread_slice(threadIdx.x % (kNumLoadWarps * cutlass::NumThreadsPerWarp)); + auto thr_copy_vc = tiled_copy_vc.get_thread_slice(threadIdx.x % (kNumLoadWarps * cutlass::NumThreadsPerWarp)); + + auto tQsQ = thr_copy_q.partition_D(sQ); + 
auto tQgQL = thr_copy_q.partition_S(tSgQL); + auto tQgQR = thr_copy_q.partition_S(tSgQR); + + auto tKCsKC = thr_copy_kc.partition_D(sKC); + auto tVCsVC = thr_copy_vc.partition_D(sVC); + + auto pipeline_pt_release_state = pipeline_pt_consumer_state; + + int page_table_stage = -1; + Pow2 pages_per_tile{TileShapeS{} / paged_K}; + const int * __restrict__ smem_page_table = shared_tensors.smem_page_table.begin(); + Gather gather{page_table_stage, pages_per_tile, smem_page_table}; + + auto mCL = make_tensor( + make_gmem_ptr(mainloop_args.ptr_c_latent), + ComposedLayout{ + make_layout( + make_shape(make_shape(paged_K, paged_B), _1{}), + make_stride(make_stride(get<0>(mainloop_args.stride_c_latent), example::CustomStride(gather, get<2>(mainloop_args.stride_c_latent))), get<1>(mainloop_args.stride_c_latent))), + make_coord(_0{}, _0{}), + make_identity_layout(make_shape(paged_K * paged_B, D_latent))}); + + auto mKR = make_tensor( + make_gmem_ptr(mainloop_args.ptr_k_rope), + ComposedLayout{ + make_layout( + make_shape(make_shape(paged_K, paged_B), _1{}), + make_stride(make_stride(get<0>(mainloop_args.stride_k_rope), example::CustomStride(gather, get<2>(mainloop_args.stride_k_rope))), get<1>(mainloop_args.stride_k_rope))), + make_coord(_0{}, _0{}), + make_identity_layout(make_shape(paged_K * paged_B, D_latent))}); + + auto mCLT = make_tensor( + make_gmem_ptr(mainloop_args.ptr_c_latent), + ComposedLayout{ + make_layout( + make_shape(_1{}, make_shape(paged_K, paged_B)), + make_stride(get<1>(mainloop_args.stride_c_latent), make_stride(get<0>(mainloop_args.stride_c_latent), example::CustomStride(gather, get<2>(mainloop_args.stride_c_latent))))), + make_coord(_0{}, _0{}), + make_identity_layout(make_shape(D_latent, paged_K * paged_B))}); + + auto gCL = local_tile(mCL, TileShapeQK{}, make_coord(_,_,_), Step{}); + auto gKR = local_tile(mKR, TileShapeQK{}, make_coord(_,_,_), Step{}); + auto gCLT = local_tile(mCLT, TileShapePV{}, make_coord(_,_,_), Step{}); + + auto tSgCL = cta_mma_qk.partition_B(gCL); + auto tSgKR = cta_mma_qk.partition_B(gKR); + auto tOgCLT = cta_mma_pv.partition_B(gCLT); + + auto tKCgCL = thr_copy_kc.partition_S(tSgCL); + auto tKCgKR = thr_copy_kc.partition_S(tSgKR); + auto tVCgCLT = thr_copy_vc.partition_S(tOgCLT); + + // latent is first in memory, so let's load it first always + // startup: alternate Q and K, set tx count appropriately, for k_idx = 0 + auto& pipeline_acquire_state = pipeline_load_producer_state; + auto pipeline_commit_state = pipeline_acquire_state; + int pipeline_offset = 0; + + for (int i = 0; i < StagesPV; i++) { + cutlass::arch::cp_async_fence(); + } + + auto load_stage = [&](auto fn) { + pipeline_load.producer_acquire(pipeline_acquire_state); + fn(pipeline_acquire_state.index()); + cutlass::arch::cp_async_fence(); + + ++pipeline_acquire_state; + ++pipeline_offset; + + if (pipeline_offset == StagesPV - 1) { + cutlass::arch::cp_async_wait(); + pipeline_load.producer_commit(pipeline_commit_state); + ++pipeline_commit_state; + --pipeline_offset; + } + }; + + pipeline_page_table.consumer_wait(pipeline_pt_consumer_state); + page_table_stage = pipeline_pt_consumer_state.index(); + ++pipeline_pt_consumer_state; + + // each Q/K tile consists of rope and latent + for (int i = 0; i < IterationsQKLatent; i++) { + load_stage([&](int index) { + cute::copy(tiled_copy_q, tQgQL(_, _, _, _, _0{}, i, batch_coord), tQsQ(_, _, _, _, i)); + copy_split(tiled_copy_kc, tKCgCL(_, _, _, _, k_index, i), tKCsKC(_, _, _, _, index)); + }); + } + + for (int i = 0; i < IterationsQKRope; i++) { + 
load_stage([&](int index) { + cute::copy(tiled_copy_q, tQgQR(_, _, _, _, _0{}, i, batch_coord), tQsQ(_, _, _, _, IterationsQKLatent + i)); + copy_split(tiled_copy_kc, tKCgKR(_, _, _, _, k_index, i), tKCsKC(_, _, _, _, index)); + }); + } + + k_index += 1; + k_tile_count -= 1; + + // assume k_tile_count >= 1 + // perform K+Q load here + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + + pipeline_page_table.consumer_wait(pipeline_pt_consumer_state); + page_table_stage = pipeline_pt_consumer_state.index(); + ++pipeline_pt_consumer_state; + + for (int i = 0; i < IterationsQKLatent; i++) { + load_stage([&](int index) { + copy_split(tiled_copy_kc, tKCgCL(_, _, _, _, k_index, i), tKCsKC(_, _, _, _, index)); + }); + } + + for (int i = 0; i < IterationsQKRope; i++) { + load_stage([&](int index) { + copy_split(tiled_copy_kc, tKCgKR(_, _, _, _, k_index, i), tKCsKC(_, _, _, _, index)); + }); + } + + page_table_stage = pipeline_pt_release_state.index(); + + for (int i = 0; i < IterationsPV_K; i++) { + for (int j = 0; j < IterationsPV_N; j++) { + load_stage([&](int index) { + copy_split(tiled_copy_vc, tVCgCLT(_, _, _, _, j, IterationsPV_K * (k_index - 1) + i), tVCsVC(_, _, _, _, index)); + }); + } + } + + pipeline_page_table.consumer_release(pipeline_pt_release_state); + ++pipeline_pt_release_state; + + k_index += 1; + k_tile_count -= 1; + } + + page_table_stage = pipeline_pt_release_state.index(); + + for (int i = 0; i < IterationsPV_K; i++) { + for (int j = 0; j < IterationsPV_N; j++) { + load_stage([&](int index) { + copy_split(tiled_copy_vc, tVCgCLT(_, _, _, _, j, IterationsPV_K * (k_index - 1) + i), tVCsVC(_, _, _, _, index)); + }); + } + } + + pipeline_page_table.consumer_release(pipeline_pt_release_state); + ++pipeline_pt_release_state; + + while (pipeline_offset > 0) { + cutlass::arch::cp_async_fence(); + + cutlass::arch::cp_async_wait(); + pipeline_load.producer_commit(pipeline_commit_state); + ++pipeline_commit_state; + --pipeline_offset; + } + + cutlass::arch::cp_async_wait<0>(); + + } + + + template + CUTLASS_DEVICE void load_tma( + BlkCoord const& blk_coord, + ProblemShape const& problem_shape, + MainloopArguments const& mainloop_args, + MainloopParams const& mainloop_params, + TensorStorage& shared_tensors, + PipelineLoadQK& pipeline_load_qk, + typename PipelineLoadQK::PipelineState& pipeline_load_qk_producer_state, + PipelineLoadPV& pipeline_load_pv, + typename PipelineLoadPV::PipelineState& pipeline_load_pv_producer_state, + int const& split_kv) { + + auto [H, K, D, B] = problem_shape; + auto [D_latent, D_rope] = D; + + int k_tile_total = ceil_div(K, TileShapeS{}); + int k_tile_per_cta = ceil_div(k_tile_total, split_kv); + int k_index = get<3>(blk_coord) * k_tile_per_cta; // lower limit + int k_tile_count = max(0, min(k_tile_total, k_index + k_tile_per_cta) - k_index); + if (k_tile_count == 0) { + return; + } + + using X = Underscore; + + // partition all tensors + auto mQL = mainloop_params.tma_load_q_latent.get_tma_tensor(make_shape(H, D_latent, B)); + auto mQR = mainloop_params.tma_load_q_rope.get_tma_tensor(make_shape(H, D_rope, B)); + + int paged_B = B; + int paged_K = K; + if constexpr (kIsPaged) { + paged_B = mainloop_args.page_count; + paged_K = mainloop_args.page_size; + } + auto mPT_l = make_tensor(make_gmem_ptr(mainloop_args.ptr_page_table), make_shape(paged_B, B), mainloop_args.stride_page_table); + + auto mCL = mainloop_params.tma_load_c_latent.get_tma_tensor(make_shape(paged_K, D_latent, paged_B)); + auto mKR = 
mainloop_params.tma_load_k_rope.get_tma_tensor(make_shape(paged_K, D_rope, paged_B)); + + auto mCLT = mainloop_params.tma_load_c_latent_transpose.get_tma_tensor(make_shape(D_latent, paged_K, paged_B)); + + auto gQL = local_tile(mQL, TileShapeQK{}, make_coord(_,_,_), Step<_1, X, _1>{}); + auto gQR = local_tile(mQR, TileShapeQK{}, make_coord(_,_,_), Step<_1, X, _1>{}); + + auto gCL = local_tile(mCL, TileShapeQK{}, make_coord(_,_,_), Step{}); + auto gKR = local_tile(mKR, TileShapeQK{}, make_coord(_,_,_), Step{}); + auto gCLT = local_tile(mCLT, TileShapePV{}, make_coord(_,_,_), Step{}); + + ThrMMA cta_mma_qk = TiledMmaQK{}.get_slice(get<0>(blk_coord) % size(AtomThrShapeMNK{})); + ThrMMA cta_mma_pv = TiledMmaPV{}.get_slice(get<0>(blk_coord) % size(AtomThrShapeMNK{})); + + auto tSgQL = cta_mma_qk.partition_A(gQL); + auto tSgQR = cta_mma_qk.partition_A(gQR); + + auto tSgCL = cta_mma_qk.partition_B(gCL); + auto tSgKR = cta_mma_qk.partition_B(gKR); + + auto tOgCLT = cta_mma_pv.partition_B(gCLT); + + Tensor sQ = make_tensor(make_smem_ptr(shared_tensors.smem_q.begin()), SmemLayoutQ{}); + Tensor sKC = make_tensor(make_smem_ptr(shared_tensors.smem_kc.begin()), SmemLayoutKC{}); + Tensor sVC = make_tensor(make_smem_ptr(shared_tensors.smem_vc.begin()), SmemLayoutVC{}); + + auto [tQLgQL_mkl, tQsQ] = tma_partition( + mainloop_params.tma_load_q_latent, _0{}, make_layout(_1{}), + group_modes<0,3>(sQ), group_modes<0,3>(tSgQL)); + + auto [tQRgQR_mkl, tQsQ_ignore] = tma_partition( + mainloop_params.tma_load_q_rope, _0{}, make_layout(_1{}), + group_modes<0,3>(sQ), group_modes<0,3>(tSgQR)); + + auto [tCLgCL_nkl, tKCsKC] = tma_partition( + mainloop_params.tma_load_c_latent, _0{}, make_layout(_1{}), + group_modes<0,3>(sKC), group_modes<0,3>(tSgCL)); + + auto [tKRgKR_nkl, tKCsKC_ignore] = tma_partition( + mainloop_params.tma_load_k_rope, _0{}, make_layout(_1{}), + group_modes<0,3>(sKC), group_modes<0,3>(tSgKR)); + + auto [tCLTgCLT_nkl, tVCsVC] = tma_partition( + mainloop_params.tma_load_c_latent_transpose, _0{}, make_layout(_1{}), + group_modes<0,3>(sVC), group_modes<0,3>(tOgCLT)); + + uint16_t mcast_mask = 0; + + int batch_coord = get<2>(blk_coord); + Tensor tQLgQL = tQLgQL_mkl(_, _, _, batch_coord); + Tensor tQRgQR = tQRgQR_mkl(_, _, _, batch_coord); + + auto mPT = mPT_l(_, batch_coord); + + Tensor tCLgCL = tCLgCL_nkl(_, _, _, _); + Tensor tKRgKR = tKRgKR_nkl(_, _, _, _); + + // careful: stage and k are swapped here! 
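+    // (Concretely: tCLTgCLT below is indexed as (_, latent tile j, K position, page/batch),
+    //  whereas tCLgCL / tKRgKR above are indexed as (_, K position, chunk i, page/batch).)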
+ Tensor tCLTgCLT = tCLTgCLT_nkl(_, _, _, _); + + // latent is first in memory, so let's load it first always + // startup: alternate Q and K, set tx count appropriately, for k_idx = 0 + + // each Q/K tile consists of rope and latent + for (int i = 0; i < IterationsQKLatent; i++) { + pipeline_load_qk.producer_expect_transaction(pipeline_load_qk_producer_state, kTransactionsBytesLoadExtraQ); + pipeline_load_qk.producer_acquire(pipeline_load_qk_producer_state); + auto tma_barrier = pipeline_load_qk.producer_get_barrier(pipeline_load_qk_producer_state); + + if (cute::elect_one_sync()) { + // expect the extra bytes + // load_qk ql + cute::copy(mainloop_params.tma_load_q_latent.with(*tma_barrier, mcast_mask), tQLgQL(_, _0{}, i), tQsQ(_, i)); + // load_qk cl + if constexpr (kIsPaged) { + cute::copy( + mainloop_params.tma_load_c_latent.with(*tma_barrier, mcast_mask), + tCLgCL(_, _0{}, i, mPT(k_index)), + tKCsKC(_, pipeline_load_qk_producer_state.index()) + ); + } + else { + cute::copy( + mainloop_params.tma_load_c_latent.with(*tma_barrier, mcast_mask), + tCLgCL(_, k_index, i, batch_coord), + tKCsKC(_, pipeline_load_qk_producer_state.index())); + } + } + ++pipeline_load_qk_producer_state; + } + + for (int i = 0; i < IterationsQKRope; i++) { + pipeline_load_qk.producer_expect_transaction(pipeline_load_qk_producer_state, kTransactionsBytesLoadExtraQ); + pipeline_load_qk.producer_acquire(pipeline_load_qk_producer_state); + auto tma_barrier = pipeline_load_qk.producer_get_barrier(pipeline_load_qk_producer_state); + + if (cute::elect_one_sync()) { + // expect the extra bytes + // load_qk ql + cute::copy(mainloop_params.tma_load_q_rope.with(*tma_barrier, mcast_mask), tQRgQR(_, _0{}, i), tQsQ(_, i + IterationsQKLatent)); + // load_qk cl + if constexpr (kIsPaged) { + cute::copy( + mainloop_params.tma_load_k_rope.with(*tma_barrier, mcast_mask), + tKRgKR(_, _0{}, i, mPT(k_index)), + tKCsKC(_, pipeline_load_qk_producer_state.index()) + ); + } + else { + cute::copy( + mainloop_params.tma_load_k_rope.with(*tma_barrier, mcast_mask), + tKRgKR(_, k_index, i, batch_coord), + tKCsKC(_, pipeline_load_qk_producer_state.index())); + } + } + ++pipeline_load_qk_producer_state; + } + + k_index += 1; + k_tile_count -= 1; + + // assume k_tile_count >= 1 + // perform K+Q load here + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + + // perform K load + for (int i = 0; i < IterationsQKLatent; i++) { + pipeline_load_qk.producer_acquire(pipeline_load_qk_producer_state); + auto tma_barrier = pipeline_load_qk.producer_get_barrier(pipeline_load_qk_producer_state); + + if (cute::elect_one_sync()) { + // load_qk cl + if constexpr (kIsPaged) { + cute::copy( + mainloop_params.tma_load_c_latent.with(*tma_barrier, mcast_mask), + tCLgCL(_, _0{}, i, mPT(k_index)), + tKCsKC(_, pipeline_load_qk_producer_state.index()) + ); + } + else { + cute::copy( + mainloop_params.tma_load_c_latent.with(*tma_barrier, mcast_mask), + tCLgCL(_, k_index, i, batch_coord), + tKCsKC(_, pipeline_load_qk_producer_state.index())); + } + } + ++pipeline_load_qk_producer_state; + } + + for (int i = 0; i < IterationsQKRope; i++) { + pipeline_load_qk.producer_acquire(pipeline_load_qk_producer_state); + auto tma_barrier = pipeline_load_qk.producer_get_barrier(pipeline_load_qk_producer_state); + + if (cute::elect_one_sync()) { + // load_qk cl + if constexpr (kIsPaged) { + cute::copy( + mainloop_params.tma_load_k_rope.with(*tma_barrier, mcast_mask), + tKRgKR(_, _0{}, i, mPT(k_index)), + tKCsKC(_, pipeline_load_qk_producer_state.index()) + ); + } + else { + 
cute::copy( + mainloop_params.tma_load_k_rope.with(*tma_barrier, mcast_mask), + tKRgKR(_, k_index, i, batch_coord), + tKCsKC(_, pipeline_load_qk_producer_state.index())); + } + } + ++pipeline_load_qk_producer_state; + } + + // prefetch next K load to keep busy while we transpose-load from cache + const int kPrefetchDistance = 1; + for (int i = 0; i < IterationsQKLatent; i++) { + if (cute::elect_one_sync()) { + if constexpr (kIsPaged) { + if (k_tile_count > kPrefetchDistance) { + cute::prefetch( + mainloop_params.tma_load_c_latent, + tCLgCL(_, _0{}, i, mPT(k_index + kPrefetchDistance)) + ); + } + } + else { + cute::prefetch( + mainloop_params.tma_load_c_latent, + tCLgCL(_, k_index + kPrefetchDistance, i, batch_coord) + ); + } + } + } + + for (int i = 0; i < IterationsQKRope; i++) { + if (cute::elect_one_sync()) { + if constexpr (kIsPaged) { + if (k_tile_count > kPrefetchDistance) { + cute::prefetch( + mainloop_params.tma_load_k_rope, + tKRgKR(_, _0{}, i, mPT(k_index + kPrefetchDistance)) + ); + } + } + else { + cute::prefetch( + mainloop_params.tma_load_k_rope, + tKRgKR(_, k_index + kPrefetchDistance, i, batch_coord) + ); + } + } + } + + // perform V load (k_idx - 1) + + for (int i = 0; i < IterationsPV_K; i++) { + for (int j = 0; j < IterationsPV_N; j++) { + pipeline_load_pv.producer_acquire(pipeline_load_pv_producer_state); + auto tma_barrier = pipeline_load_pv.producer_get_barrier(pipeline_load_pv_producer_state); + + if (cute::elect_one_sync()) { + // load_pv cl + // note the transpose in indices! + // note we are off-by-one on k_index + if constexpr (kIsPaged) { + cute::copy( + mainloop_params.tma_load_c_latent_transpose.with(*tma_barrier, mcast_mask, cute::TMA::CacheHintSm100::EVICT_FIRST), + tCLTgCLT(_, j, i, mPT(k_index - 1)), + tVCsVC(_, pipeline_load_pv_producer_state.index()) + ); + } + else { + cute::copy( + mainloop_params.tma_load_c_latent_transpose.with(*tma_barrier, mcast_mask, cute::TMA::CacheHintSm100::EVICT_FIRST), + tCLTgCLT(_, j, IterationsPV_K * (k_index - 1) + i, batch_coord), + tVCsVC(_, pipeline_load_pv_producer_state.index()) + ); + } + } + ++pipeline_load_pv_producer_state; + } + } + + k_index += 1; + k_tile_count -= 1; + } + + for (int i = 0; i < IterationsPV_K; i++) { + for (int j = 0; j < IterationsPV_N; j++) { + pipeline_load_pv.producer_acquire(pipeline_load_pv_producer_state); + auto tma_barrier = pipeline_load_pv.producer_get_barrier(pipeline_load_pv_producer_state); + + if (cute::elect_one_sync()) { + // load_pv cl + // note the transpose in indices + // note we are off-by-one on k_index + + if constexpr (kIsPaged) { + cute::copy( + mainloop_params.tma_load_c_latent_transpose.with(*tma_barrier, mcast_mask, cute::TMA::CacheHintSm100::EVICT_FIRST), + tCLTgCLT(_, j, i, mPT(k_index - 1)), + tVCsVC(_, pipeline_load_pv_producer_state.index()) + ); + } + else { + cute::copy( + mainloop_params.tma_load_c_latent_transpose.with(*tma_barrier, mcast_mask, cute::TMA::CacheHintSm100::EVICT_FIRST), + tCLTgCLT(_, j, IterationsPV_K * (k_index - 1) + i, batch_coord), + tVCsVC(_, pipeline_load_pv_producer_state.index()) + ); + } + } + ++pipeline_load_pv_producer_state; + } + } + } + + template + CUTLASS_DEVICE void mma( + BlkCoord const& blk_coord, + ProblemShape const& problem_shape, + TensorStorage& shared_tensors, + PipelineLoadQK& pipeline_load_qk, + typename PipelineLoadQK::PipelineState& pipeline_load_qk_consumer_state, + PipelineLoadPV& pipeline_load_pv, + typename PipelineLoadPV::PipelineState& pipeline_load_pv_consumer_state, + PipelineS& pipeline_mma_s, + typename 
PipelineS::PipelineState& pipeline_mma_s_producer_state, + PipelineP& pipeline_p_mma, + typename PipelineP::PipelineState& pipeline_p_mma_consumer_state, + PipelineO& pipeline_mma_o, + typename PipelineO::PipelineState& pipeline_mma_o_producer_state, + int const& split_kv) { + + auto [H, K, D, B] = problem_shape; + + int k_tile_total = ceil_div(K, TileShapeS{}); + int k_tile_per_cta = ceil_div(k_tile_total, split_kv); + int k_index = get<3>(blk_coord) * k_tile_per_cta; // lower limit + int k_tile_count = max(0, min(k_tile_total, k_index + k_tile_per_cta) - k_index); + if (k_tile_count == 0) { + return; + } + + // mma init + Tensor sQ = make_tensor(make_smem_ptr(shared_tensors.smem_q.begin()), SmemLayoutQ{}); + Tensor sKC = make_tensor(make_smem_ptr(shared_tensors.smem_kc.begin()), SmemLayoutKC{}); + Tensor sVC = make_tensor(make_smem_ptr(shared_tensors.smem_vc.begin()), SmemLayoutVC{}); + Tensor sP = make_tensor(make_smem_ptr((Element*) shared_tensors.smem_p.begin()), SmemLayoutP{}); + + Tensor tSrQ = TiledMmaQK::make_fragment_A(sQ); + Tensor tSrKC = TiledMmaQK::make_fragment_B(sKC); + Tensor tOrP = TiledMmaPV::make_fragment_A(sP); + Tensor tOrVC = TiledMmaPV::make_fragment_B(sVC); + + TiledMmaQK tiled_mma_qk; + TiledMmaPV tiled_mma_pv; + + Tensor tStS = partition_fragment_C(tiled_mma_qk, select<0,1>(TileShapeQK{})); + Tensor tOtO = partition_fragment_C(tiled_mma_pv, select<0,1>(TileShapePV{})); + + tiled_mma_pv.accumulate_ = UMMA::ScaleOut::Zero; + + pipeline_mma_s.producer_acquire(pipeline_mma_s_producer_state); + + // Mma S0 S1 O0 S2 O1 ... Sn On-1 On + // S0 ownership -- ----- -- -- + // S1 ownership -- ----- ---- + // O ownership -- -- ---- -- + + tiled_mma_qk.accumulate_ = UMMA::ScaleOut::Zero; + for (int i = 0; i < IterationsQK; i++) { + pipeline_load_qk.consumer_wait(pipeline_load_qk_consumer_state); + int read_stage = pipeline_load_qk_consumer_state.index(); + + tStS.data() = uint32_t(pipeline_mma_s_producer_state.index() == 0 ? TmemAllocation::kS0 : TmemAllocation::kS1); + + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tSrQ); ++k_block) { + cute::gemm(tiled_mma_qk, + tSrQ(_,_,k_block,i), + tSrKC(_,_,k_block,read_stage), + tStS); + tiled_mma_qk.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_load_qk.consumer_release(pipeline_load_qk_consumer_state); + ++pipeline_load_qk_consumer_state; + } + + pipeline_mma_s.producer_commit(pipeline_mma_s_producer_state); + ++pipeline_mma_s_producer_state; + + k_tile_count -= 1; + + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + + pipeline_mma_s.producer_acquire(pipeline_mma_s_producer_state); + tiled_mma_qk.accumulate_ = UMMA::ScaleOut::Zero; + for (int i = 0; i < IterationsQK; i++) { + pipeline_load_qk.consumer_wait(pipeline_load_qk_consumer_state); + int read_stage = pipeline_load_qk_consumer_state.index(); + + tStS.data() = uint32_t(pipeline_mma_s_producer_state.index() == 0 ? 
TmemAllocation::kS0 : TmemAllocation::kS1); + + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tSrQ); ++k_block) { + cute::gemm(tiled_mma_qk, + tSrQ(_,_,k_block,i), + tSrKC(_,_,k_block,read_stage), + tStS); + tiled_mma_qk.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_load_qk.consumer_release(pipeline_load_qk_consumer_state); + ++pipeline_load_qk_consumer_state; + } + + pipeline_mma_s.producer_commit(pipeline_mma_s_producer_state); + ++pipeline_mma_s_producer_state; + + pipeline_mma_o.producer_acquire(pipeline_mma_o_producer_state); + pipeline_p_mma.consumer_wait(pipeline_p_mma_consumer_state); + + for (int i = 0; i < IterationsPV_K; i++) { + auto acc_flag = tiled_mma_pv.accumulate_; + for (int j = 0; j < IterationsPV_N; j++) { + pipeline_load_pv.consumer_wait(pipeline_load_pv_consumer_state); + + int read_stage = pipeline_load_pv_consumer_state.index(); + + tOtO.data() = uint32_t(TmemAllocation::kO0) + j * uint32_t(TmemAllocation::kSizeAccO); + tiled_mma_pv.accumulate_ = acc_flag; + + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tOrP); ++k_block) { + cute::gemm(tiled_mma_pv, + tOrP(_,_,k_block, make_coord(i, pipeline_p_mma_consumer_state.index())), + tOrVC(_,_,k_block,read_stage), + tOtO); + tiled_mma_pv.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_load_pv.consumer_release(pipeline_load_pv_consumer_state); + ++pipeline_load_pv_consumer_state; + } + } + + pipeline_p_mma.consumer_release(pipeline_p_mma_consumer_state); + ++pipeline_p_mma_consumer_state; + pipeline_mma_o.producer_commit(pipeline_mma_o_producer_state); + ++pipeline_mma_o_producer_state; + + --k_tile_count; + } + + pipeline_mma_o.producer_acquire(pipeline_mma_o_producer_state); + pipeline_p_mma.consumer_wait(pipeline_p_mma_consumer_state); + + for (int i = 0; i < IterationsPV_K; i++) { + auto acc_flag = tiled_mma_pv.accumulate_; + for (int j = 0; j < IterationsPV_N; j++) { + pipeline_load_pv.consumer_wait(pipeline_load_pv_consumer_state); + + int read_stage = pipeline_load_pv_consumer_state.index(); + + tOtO.data() = uint32_t(TmemAllocation::kO0) + j * uint32_t(TmemAllocation::kSizeAccO); + tiled_mma_pv.accumulate_ = acc_flag; + + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tOrP); ++k_block) { + cute::gemm(tiled_mma_pv, + tOrP(_,_,k_block, make_coord(i, pipeline_p_mma_consumer_state.index())), + tOrVC(_,_,k_block,read_stage), + tOtO); + tiled_mma_pv.accumulate_ = UMMA::ScaleOut::One; + } + + pipeline_load_pv.consumer_release(pipeline_load_pv_consumer_state); + ++pipeline_load_pv_consumer_state; + } + } + + pipeline_p_mma.consumer_release(pipeline_p_mma_consumer_state); + ++pipeline_p_mma_consumer_state; + pipeline_mma_o.producer_commit(pipeline_mma_o_producer_state); + ++pipeline_mma_o_producer_state; + } + + + template + CUTLASS_DEVICE void softmax( + IsLastTile const& is_last_tile, + ElementAcc& row_max, + ElementAcc& row_sum, + ElementAcc& correction_factor, + ProblemShape const& problem_shape, + MainloopArguments const& mainloop_args, + TensorStorage& shared_tensors, + int k_index, + uint32_t tmem_s, + int smem_p_index) { + + auto load_op = cute::SM100_TMEM_LOAD_32dp32b32x{}; + + TiledMmaQK tiled_mma_qk; + + Tensor tStS = partition_fragment_C(tiled_mma_qk, select<0,1>(TileShapeQK{})); + tStS.data() = tmem_s; + + CUTE_STATIC_ASSERT_V(shape<1>(tStS) == _1{}); + CUTE_STATIC_ASSERT_V(shape<2>(tStS) == _1{}); + Tensor tAcc = tStS(make_coord(_,_),_0{},_0{}); + + Tensor cS = make_identity_tensor(take<0,2>(CtaShapeQK{})); + + auto tiled_t2r = 
make_tmem_copy(load_op, tAcc); + auto thread_idx = threadIdx.x % size(tiled_t2r); + + auto thread_t2r = tiled_t2r.get_slice(thread_idx); + Tensor tTR_cS = thread_t2r.partition_D(cS); + Tensor tTR_rAcc = make_tensor(shape(tTR_cS)); + + Tensor tTR_rS_frag = make_tensor(shape(tTR_rAcc)); + const int AlignmentS = 4; + Tensor tTR_tAcc = thread_t2r.partition_S(tAcc); + Tensor tTR_rAcc_vec = recast>(tTR_rAcc); + Tensor tTR_rS_vec = recast>(tTR_rS_frag); + + // load s + copy(tiled_t2r, tTR_tAcc, tTR_rAcc); + + if (is_last_tile) { + for (int i = 0; i < size(tTR_rAcc); i++) { + if (get<1>(tTR_cS(i)) + TileShapeS{} * k_index >= get<1>(problem_shape)) { + tTR_rAcc(i) = -std::numeric_limits::infinity(); + } + } + } + + // max + ElementAcc row_max_new = row_max; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rAcc); i += 1) { + row_max_new = ::fmax(row_max_new, tTR_rAcc(i)); + } + + // for 2x2 dp, reduce here + if constexpr (kWarpsInN > 1) { + shared_tensors.smem_exchange[threadIdx.x] = row_max_new; + cutlass::arch::NamedBarrier(kNumComputeWarps*NumThreadsPerWarp, kNamedBarrierExchange).sync(); + // (64, 2) shape + int peer_index = (threadIdx.x + 64) % 128; + row_max_new = cutlass::max(row_max_new, shared_tensors.smem_exchange[peer_index]); + } + + // find correction factor + ElementAcc softmax_scale_log2 = mainloop_args.softmax_scale * static_cast(M_LOG2E); + correction_factor = ::exp2f(softmax_scale_log2 * (row_max - row_max_new)); + row_max = row_max_new; + + // softmax + ElementAcc row_max_scale_log2 = row_max * softmax_scale_log2; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rAcc); i++) { + tTR_rAcc(i) = ::exp2f(softmax_scale_log2 * tTR_rAcc(i) - row_max_scale_log2); + } + + // quantize + cutlass::NumericArrayConverter epilogue_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rAcc_vec); i++) { + tTR_rS_vec(i) = epilogue_op(tTR_rAcc_vec(i)); + } + + Tensor sP = make_tensor(make_smem_ptr((Element*) shared_tensors.smem_p.begin()), SmemLayoutP{})(_, _, _, make_coord(_, smem_p_index)); + + Tensor tOcP = TiledMmaPV{}.get_slice(_0{}).partition_A(cS); + + // have a mapping for each thread to coord + // find identical mapping to coords for the MMA + auto l = make_ordered_layout(make_shape(make_shape(_64{}, _2{}), make_shape(_16{}, TileShapeS{} / _32{})), make_stride(make_stride(_0{}, _3{}), make_stride(_1{}, _2{}))); + auto sP_ = as_position_independent_swizzle_tensor(sP); + copy_aligned(tTR_rS_frag, sP_.compose(l)(threadIdx.x, _)); + + // sum + row_sum *= correction_factor; + + static_assert(cute::is_same_v); + auto tTR_rAcc_float2 = recast(tTR_rAcc); + auto sums = make_tensor(_4{}); + static_assert(size(tTR_rAcc_float2) % size(sums) == 0); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(sums); i++) { + sums(i) = tTR_rAcc_float2(i); + } + CUTLASS_PRAGMA_UNROLL + for (int i = size(sums); i < size(tTR_rAcc_float2); i += size(sums)) { + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size(sums); j++) { + cute::add(sums(j), sums(j), tTR_rAcc_float2(i + j)); + } + } + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < size(sums); i *= 2) { + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size(sums); j += 2*i) { + cute::add(sums(j), sums(j), sums(j+i)); + } + } + row_sum += sums(0).x + sums(0).y; + } + + + CUTLASS_DEVICE void rescale( + ElementAcc correction_factor, + uint32_t tmem_o) { + + // for b2b gemm, do nothing + auto load_op = cute::SM100_TMEM_LOAD_32dp32b32x{}; + auto store_op = TMEM::tmem_load_to_store(load_op); + + TiledMmaPV tiled_mma_pv; + + Tensor tOtO = 
partition_fragment_C(tiled_mma_pv, select<0,1>(TileShapePV{})); + tOtO.data() = tmem_o; + + CUTE_STATIC_ASSERT_V(shape<1>(tOtO) == _1{}); + CUTE_STATIC_ASSERT_V(shape<2>(tOtO) == _1{}); + Tensor tAcc = tOtO(make_coord(_,_),_0{},_0{}); + + auto cta_tiler_pv = take<0,2>(typename CollectiveMmaPV::CtaShape_MNK{}); + Tensor gO = make_tensor(make_gmem_ptr((ElementAcc*) nullptr), cta_tiler_pv, make_stride(0, 0)); + + auto tiled_t2r = make_tmem_copy(load_op, tAcc); + auto tiled_r2t = make_tmem_copy(store_op, tAcc); + auto thread_idx = threadIdx.x % size(tiled_t2r); + + auto thread_t2r = tiled_t2r.get_slice(thread_idx); + auto thread_r2t = tiled_r2t.get_slice(thread_idx); + Tensor tTR_gO = thread_t2r.partition_D(gO); + Tensor tTR_rAcc = make_tensor(shape(tTR_gO)); + + Tensor tTR_tAcc = thread_t2r.partition_S(tAcc); + + // load o + copy(tiled_t2r, tTR_tAcc, tTR_rAcc); + + // multiply by correction factor + float2 correction_factor_vec = make_float2(correction_factor, correction_factor); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rAcc); i += 2) { + float2 in = make_float2(tTR_rAcc(i + 0), tTR_rAcc(i + 1)); + float2 out; + cute::mul(out, in, correction_factor_vec); + tTR_rAcc(i + 0) = out.x; + tTR_rAcc(i + 1) = out.y; + } + + // store o + copy(tiled_r2t, tTR_rAcc, tTR_tAcc); + } + + + template + CUTLASS_DEVICE void epilogue( + ElementAcc& row_max, + ElementAcc& row_sum, + BlkCoord const& cta_coord, + ProblemShape const& problem_shape, + MainloopArguments const& mainloop_args, + EpilogueParams const& epilogue_args, + TensorStorage& shared_tensors, + uint32_t tmem_o, + int const& split_kv) { + + auto load_op = cute::SM100_TMEM_LOAD_32dp32b32x{}; + + TiledMmaPV tiled_mma_pv; + + Tensor tOtO = TiledMmaPV::make_fragment_C(partition_shape_C(TiledMmaPV{}, take<0, 2>(TileShapePV{}))); + tOtO.data() = tmem_o; + + CUTE_STATIC_ASSERT_V(shape<1>(tOtO) == _1{}); + CUTE_STATIC_ASSERT_V(shape<2>(tOtO) == _1{}); + Tensor tAcc = tOtO(make_coord(_,_),_0{},_0{}); + + auto [H, K, D, B] = problem_shape; + auto [D_latent, D_rope] = D; + + if (split_kv > 1) { + using ElementOutAcc = ElementAcc; + constexpr auto AlignmentOutAcc = 128 / cute::sizeof_bits_v; + Tensor mO = make_tensor(make_gmem_ptr(epilogue_args.ptr_o_acc + get<3>(cta_coord) * D_latent), make_shape(H, D_latent, B), epilogue_args.stride_o_acc); + auto cta_tiler_pv = take<0,2>(typename CollectiveMmaPV::CtaShape_MNK{}); + Tensor gO = local_tile(mO, cta_tiler_pv, take<0,3>(cta_coord)); + + auto tiled_t2r = make_tmem_copy(load_op, tAcc); + auto thread_idx = threadIdx.x % size(tiled_t2r); + + auto thread_t2r = tiled_t2r.get_slice(thread_idx); + Tensor tTR_gO = thread_t2r.partition_D(gO); + Tensor tTR_rAcc = make_tensor(shape(tTR_gO)); + + Tensor tTR_rO_frag = make_tensor(shape(tTR_rAcc)); + Tensor tTR_rO_src = recast>(coalesce(tTR_rO_frag)); + Tensor tR2G_rO_dst = recast>(coalesce(tTR_gO)); + Tensor tTR_tAcc = thread_t2r.partition_S(tAcc); + + copy(tiled_t2r, tTR_tAcc, tTR_rAcc); + + cutlass::epilogue::thread::LinearCombination epilogue_op({epilogue_args.output_scale / row_sum}); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rAcc); i++) { + tTR_rO_frag(i) = epilogue_op(tTR_rAcc(i)); + } + + copy(tTR_rO_src, tR2G_rO_dst); + + if (get<1>(cta_coord) == 0) { + if (epilogue_args.ptr_lse != nullptr) { + // compute LSE + ElementAcc lse = cutlass::fast_log(row_sum) + mainloop_args.softmax_scale * row_max; + + // store LSE + Tensor mLSE = make_tensor(make_gmem_ptr(epilogue_args.ptr_lse_acc + H * get<3>(cta_coord)), make_shape(H, B), 
epilogue_args.stride_lse_acc); + Tensor gLSE = local_tile(mLSE, append<3>(cta_tiler_pv, _1{}), take<0,3>(cta_coord), Step<_1, Underscore, _1>{}); + // for 2x2 dp, this must be conditional and the index is wrong + if (! kIs2Sm || (threadIdx.x < 64)) + { + gLSE(threadIdx.x) = lse; + } + } + } + } + else { + Tensor mO = make_tensor(make_gmem_ptr(epilogue_args.ptr_o), make_shape(H, D_latent, B), epilogue_args.stride_o); + auto cta_tiler_pv = take<0,2>(typename CollectiveMmaPV::CtaShape_MNK{}); + Tensor gO = local_tile(mO, cta_tiler_pv, take<0,3>(cta_coord)); + + auto tiled_t2r = make_tmem_copy(load_op, tAcc); + auto thread_idx = threadIdx.x % size(tiled_t2r); + + auto thread_t2r = tiled_t2r.get_slice(thread_idx); + Tensor tTR_gO = thread_t2r.partition_D(gO); + Tensor tTR_rAcc = make_tensor(shape(tTR_gO)); + + Tensor tTR_rO_frag = make_tensor(shape(tTR_rAcc)); + Tensor tTR_rO_src = recast>(coalesce(tTR_rO_frag)); + Tensor tR2G_rO_dst = recast>(coalesce(tTR_gO)); + Tensor tTR_tAcc = thread_t2r.partition_S(tAcc); + + copy(tiled_t2r, tTR_tAcc, tTR_rAcc); + + cutlass::epilogue::thread::LinearCombination epilogue_op({epilogue_args.output_scale / row_sum}); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rAcc); i++) { + tTR_rO_frag(i) = epilogue_op(tTR_rAcc(i)); + } + + copy(tTR_rO_src, tR2G_rO_dst); + + if (get<1>(cta_coord) == 0) { + if (epilogue_args.ptr_lse != nullptr) { + // compute LSE + ElementAcc lse = cutlass::fast_log(row_sum) + mainloop_args.softmax_scale * row_max; + + // store LSE + Tensor mLSE = make_tensor(make_gmem_ptr(epilogue_args.ptr_lse), make_shape(H, B), epilogue_args.stride_lse); + Tensor gLSE = local_tile(mLSE, append<3>(cta_tiler_pv, _1{}), take<0,3>(cta_coord), Step<_1, Underscore, _1>{}); + + // for 2x2 dp, this must be conditional and the index is wrong + if (! kIs2Sm || (threadIdx.x < 64)) + { + gLSE(threadIdx.x) = lse; + } + } + } + } + } + + template + CUTLASS_DEVICE ElementLSE epilogue_lse_reduction( + ElementAcc& row_max, + ElementAcc& row_sum, + BlkCoord const& cta_coord, + ProblemShape const& problem_shape, + MainloopArguments const& mainloop_args, + EpilogueParams const& epilogue_args, + int const& local_split_kv) { + + auto [H, K, D, B] = problem_shape; + auto [D_latent, D_rope] = D; + auto cta_tiler_pv = take<0,2>(typename CollectiveMmaPV::CtaShape_MNK{}); + + constexpr int kNumThreads = kNumComputeWarps * NumThreadsPerWarp; + using Sync = cutlass::detail::NamedBarrierSync; + + auto wait = [](int* lock, int count) { + __threadfence(); + if (threadIdx.x == 0) { + atomicAdd(lock, 1); + while (atomicCAS(lock, count, count) != count) {}; + } + __threadfence(); + Sync::sync(); + }; + + const ElementLSE lse = cutlass::fast_log(row_sum) + mainloop_args.softmax_scale * row_max; + Tensor mLSE_max_buff = make_tensor(make_gmem_ptr(epilogue_args.ptr_lse_max_exchange_buff), make_shape(H, B), epilogue_args.stride_lse); + Tensor gLSE_max_buff = local_tile(mLSE_max_buff, append<3>(cta_tiler_pv, _1{}), take<0,3>(cta_coord), Step<_1, Underscore, _1>{}); + + int* local_lock = epilogue_args.ptr_lock + get<0>(cta_coord) + 2 * get<2>(cta_coord); + + if (! kIs2Sm || (threadIdx.x < 64)) { + atomicMax(&(gLSE_max_buff(threadIdx.x)), __float2int_rn(lse)); + } + wait(local_lock, local_split_kv); + + auto global_lse_max = static_cast(gLSE_max_buff(kIs2Sm ? 
threadIdx.x % 64 : threadIdx.x)); + + Tensor mLSE_buff = make_tensor(make_gmem_ptr(epilogue_args.ptr_lse_exchange_buff), make_shape(H, B), epilogue_args.stride_lse); + Tensor gLSE_buff = local_tile(mLSE_buff, append<3>(cta_tiler_pv, _1{}), take<0,3>(cta_coord), Step<_1, Underscore, _1>{}); + + if (! kIs2Sm || (threadIdx.x < 64)) { + atomicAdd(&(gLSE_buff(threadIdx.x)), expf(lse - global_lse_max)); + } + wait(local_lock, 2*local_split_kv); + + const auto sum_lse = gLSE_buff(kIs2Sm ? threadIdx.x % 64 : threadIdx.x); + const auto global_lse = (sum_lse == 0.f || sum_lse != sum_lse) ? std::numeric_limits::infinity() : + cutlass::fast_log(sum_lse) + global_lse_max; + const auto lse_scale = expf(lse - global_lse); + + if (epilogue_args.ptr_lse != nullptr) { + Tensor mLSE = make_tensor(make_gmem_ptr(epilogue_args.ptr_lse), make_shape(H, B), epilogue_args.stride_lse); + Tensor gLSE = local_tile(mLSE, append<3>(cta_tiler_pv, _1{}), take<0,3>(cta_coord), Step<_1, Underscore, _1>{}); + + // write out the global LSE + if (! kIs2Sm || (threadIdx.x < 64)) { + gLSE(threadIdx.x) = global_lse; + } + } + return lse_scale; + } + + + template + CUTLASS_DEVICE void epilogue_reduction( + ElementAcc& row_max, + ElementAcc& row_sum, + BlkCoord const& blk_coord, + ProblemShape const& problem_shape, + MainloopArguments const& mainloop_args, + EpilogueParams const& epilogue_args, + TensorStorage& shared_tensors, + int const& local_split_kv, + ElementLSE const& lse_scale) { + + constexpr int kNumThreads = kNumComputeWarps * NumThreadsPerWarp; + using Sync = cutlass::detail::NamedBarrierSync; + + auto [H, K, D, B] = problem_shape; + auto [D_latent, D_rope] = D; + + auto load_op = cute::SM100_TMEM_LOAD_32dp32b32x{}; + + TiledMmaPV tiled_mma_pv; + Tensor tOtO = TiledMmaPV::make_fragment_C(partition_shape_C(TiledMmaPV{}, take<0, 2>(TileShapePV{}))); + + CUTE_STATIC_ASSERT_V(shape<1>(tOtO) == _1{}); + CUTE_STATIC_ASSERT_V(shape<2>(tOtO) == _1{}); + + using EpilogueLinearCombination = cutlass::epilogue::thread::LinearCombination; + EpilogueLinearCombination epilogue_op({epilogue_args.output_scale / row_sum * lse_scale}); + + CUTLASS_PRAGMA_UNROLL + for(int k = 0; k < IterationsPV_N; ++k) { + auto cta_coord = replace<1>(blk_coord, k); + + uint32_t tmem_o = uint32_t(TmemAllocation::kO0) + k * uint32_t(TmemAllocation::kSizeAccO); + tOtO.data() = tmem_o; + + Tensor tAcc = tOtO(make_coord(_,_),_0{},_0{}); + + Tensor mO = make_tensor(make_gmem_ptr(epilogue_args.ptr_o), make_shape(H, D_latent, B), epilogue_args.stride_o); + Tensor gO = local_tile(mO, TileShapeAcc{}, take<0,3>(cta_coord)); + + auto tiled_t2r = make_tmem_copy(load_op, tAcc); + auto thread_idx = threadIdx.x % size(tiled_t2r); + + auto thread_t2r = tiled_t2r.get_slice(thread_idx); + Tensor tTR_gO = thread_t2r.partition_D(gO); + Tensor tTR_rAcc = make_tensor(shape(tTR_gO)); + Tensor tTR_tAcc = thread_t2r.partition_S(tAcc); + + copy(tiled_t2r, tTR_tAcc, tTR_rAcc); + + Tensor sO = make_tensor(make_smem_ptr(reinterpret_cast(shared_tensors.smem_acc.begin())), SmemLayout{}); + Tensor tTR_sO = thread_t2r.partition_D(sO); + + Sync::sync(); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tTR_rAcc); i++) { + tTR_sO(i) = epilogue_op(tTR_rAcc(i)); + } + tma_store_fence(); + Sync::sync(); + + auto tma_reduce_sum_per_cta = epilogue_args.tma_reduce_sum.get_slice(_0{}); + auto gmem_tensor_coord = epilogue_args.tma_reduce_sum.get_tma_tensor(shape(mO)); + auto gmem_tensor_coord_per_cta = local_tile(gmem_tensor_coord, TileShapeAcc{}, take<0,3>(cta_coord)); + if (threadIdx.x % 
kNumThreads == 0) { + copy(epilogue_args.tma_reduce_sum, + tma_reduce_sum_per_cta.partition_S(sO), + tma_reduce_sum_per_cta.partition_D(gmem_tensor_coord_per_cta)); + tma_store_arrive(); + } + tma_store_wait<0>(); + } + } + + template + CUTLASS_DEVICE void compute( + CtaCoord const& cta_coord, + ProblemShape const& problem_shape, + MainloopArguments const& mainloop_args, + EpilogueParams const& epilogue_args, + TensorStorage& shared_tensors, + PipelineS& pipeline_mma_s, + typename PipelineS::PipelineState& pipeline_mma_s_consumer_state, + PipelineP& pipeline_p_mma, + typename PipelineP::PipelineState& pipeline_p_mma_producer_state, + PipelineO& pipeline_mma_o, + typename PipelineO::PipelineState& pipeline_mma_o_consumer_state, + int const& split_kv, + bool const& is_fused_reduction) { + + auto [H, K, D, B] = problem_shape; + + int k_tile_total = ceil_div(K, TileShapeS{}); + int k_tile_per_cta = ceil_div(k_tile_total, split_kv); + int k_index = get<3>(cta_coord) * k_tile_per_cta; // lower limit + int k_tile_count = max(0, min(k_tile_total, k_index + k_tile_per_cta) - k_index); + if (k_tile_count == 0) { + + // if we return early, we have to make sure we release the load warp + cutlass::arch::NamedBarrier( + (kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp, + kNamedBarrierEpilogue + ).arrive(); + + return; + } + int k_index_final = k_tile_total - 1; + + ElementAcc row_max = -std::numeric_limits::infinity(); + ElementAcc row_sum = 0; + ElementAcc correction_factor = 1; + + pipeline_p_mma.producer_acquire(pipeline_p_mma_producer_state); + pipeline_mma_s.consumer_wait(pipeline_mma_s_consumer_state); + + auto dispatch_bool = [](bool b, auto fn) { + if (b) { + fn(cute::true_type{}); + } + else { + fn(cute::false_type{}); + } + }; + + // softmax s0 -> p0 + dispatch_bool(k_index == k_index_final, [&](auto is_last_tile) { + softmax( + is_last_tile, + row_max, row_sum, correction_factor, + problem_shape, mainloop_args, shared_tensors, k_index, + uint32_t(pipeline_mma_s_consumer_state.index() == 0 ? TmemAllocation::kS0 : TmemAllocation::kS1), + pipeline_p_mma_producer_state.index() + ); + }); + + k_index += 1; + + cutlass::arch::fence_view_async_tmem_load(); + cutlass::arch::fence_view_async_shared(); + pipeline_mma_s.consumer_release(pipeline_mma_s_consumer_state); + ++pipeline_mma_s_consumer_state; + pipeline_p_mma.producer_commit(pipeline_p_mma_producer_state); + ++pipeline_p_mma_producer_state; + + k_tile_count -= 1; + + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + pipeline_p_mma.producer_acquire(pipeline_p_mma_producer_state); + pipeline_mma_s.consumer_wait(pipeline_mma_s_consumer_state); + + // softmax s1 -> p1 + dispatch_bool(k_index == k_index_final, [&](auto is_last_tile) { + softmax( + is_last_tile, + row_max, row_sum, correction_factor, + problem_shape, mainloop_args, shared_tensors, k_index, + uint32_t(pipeline_mma_s_consumer_state.index() == 0 ? 
TmemAllocation::kS0 : TmemAllocation::kS1), + pipeline_p_mma_producer_state.index() + ); + }); + + cutlass::arch::fence_view_async_tmem_load(); + cutlass::arch::fence_view_async_shared(); + pipeline_mma_s.consumer_release(pipeline_mma_s_consumer_state); + ++pipeline_mma_s_consumer_state; + pipeline_p_mma.producer_commit(pipeline_p_mma_producer_state); + ++pipeline_p_mma_producer_state; + + pipeline_mma_o.consumer_wait(pipeline_mma_o_consumer_state); + + // rescale + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < IterationsPV_N; j++) { + rescale(correction_factor, uint32_t(TmemAllocation::kO0) + j * uint32_t(TmemAllocation::kSizeAccO)); + } + + cutlass::arch::fence_view_async_tmem_store(); + pipeline_mma_o.consumer_release(pipeline_mma_o_consumer_state); + ++pipeline_mma_o_consumer_state; + + --k_tile_count; + k_index += 1; + } + + pipeline_mma_o.consumer_wait(pipeline_mma_o_consumer_state); + + if constexpr (kWarpsInN > 1) { + // reduce row_sum if needed (for 2x2 dp) + shared_tensors.smem_exchange[threadIdx.x] = row_sum; + cutlass::arch::NamedBarrier(kNumComputeWarps*NumThreadsPerWarp, kNamedBarrierExchange).sync(); + // (64, 2) shape + int peer_index = (threadIdx.x + 64) % 128; + row_sum += shared_tensors.smem_exchange[peer_index]; + } + + cutlass::arch::NamedBarrier((kNumComputeWarps + kNumLoadWarps) * NumThreadsPerWarp, kNamedBarrierEpilogue).arrive(); + + const int actual_split_kv = ceil_div(k_tile_total, k_tile_per_cta); + if (!is_fused_reduction || actual_split_kv == 1) { + // epilogue + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < IterationsPV_N; j++) { + epilogue( + row_max, row_sum, + replace<1>(cta_coord, j), problem_shape, + mainloop_args, epilogue_args, shared_tensors, + uint32_t(TmemAllocation::kO0) + j * uint32_t(TmemAllocation::kSizeAccO), + actual_split_kv + ); + } + } else { + const ElementLSE lse_scale = + epilogue_lse_reduction( + row_max, row_sum, + cta_coord, + problem_shape, + mainloop_args, epilogue_args, + actual_split_kv + ); + + epilogue_reduction(row_max, row_sum, + cta_coord, + problem_shape, + mainloop_args, epilogue_args, + shared_tensors, + actual_split_kv, + lse_scale + ); + } + cutlass::arch::fence_view_async_tmem_load(); + pipeline_mma_o.consumer_release(pipeline_mma_o_consumer_state); + ++pipeline_mma_o_consumer_state; + } + +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::fmha::kernel diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/kernel/sm100_mla_tile_scheduler.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/kernel/sm100_mla_tile_scheduler.hpp new file mode 100644 index 0000000000000000000000000000000000000000..489f62f85368577a9df8471df19e98f4a23bc0c6 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/kernel/sm100_mla_tile_scheduler.hpp @@ -0,0 +1,160 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/kernel_hardware_info.h" + +namespace cutlass::fmha::kernel { + +//////////////////////////////////////////////////////////////////////////////// + +struct Sm100MlaIndividualTileScheduler { + + struct Params { + dim3 grid; + }; + + bool valid_ = true; + + CUTLASS_DEVICE + Sm100MlaIndividualTileScheduler(Params const&) {} + + template + static Params to_underlying_arguments( + ProblemShape const& problem_shape, KernelHardwareInfo hw_info, + ClusterShape const& cluster_shape, int const& split_kv) { + using namespace cute; + dim3 grid(get<0>(cluster_shape), get<3>(problem_shape) /* Batch */, split_kv /*Maximum Split KV*/); + return Params{ grid }; + } + + static dim3 get_grid_shape(Params const& params) { + return params.grid; + } + + CUTLASS_DEVICE + bool is_valid() { + return valid_; + } + + CUTLASS_DEVICE + auto get_block_coord() { + using namespace cute; + return make_coord(blockIdx.x, _0{}, blockIdx.y, blockIdx.z); + } + + CUTLASS_DEVICE + Sm100MlaIndividualTileScheduler& operator++() { + valid_ = false; + return *this; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +struct Sm100MlaPersistentTileScheduler { + + struct Params { + int num_blocks; + FastDivmod divmod_m_block; + FastDivmod divmod_b; + FastDivmod divmod_split_kv; + KernelHardwareInfo hw_info; + }; + + int block_idx = 0; + Params params; + + CUTLASS_DEVICE + Sm100MlaPersistentTileScheduler(Params const& params) : block_idx(blockIdx.x), params(params) {} + + template + static Params to_underlying_arguments( + ProblemShape const& problem_shape, KernelHardwareInfo hw_info, + ClusterShape const& cluster_shape, int const& split_kv) { + using namespace cute; + // Get SM count if needed, otherwise use user supplied SM count + int sm_count = hw_info.sm_count; + if (sm_count <= 1 || sm_count % size<0>(cluster_shape) != 0) { + CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n" + " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count."); + sm_count = KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + } + + 
CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count); + hw_info.sm_count = sm_count; + + int num_m_blocks = size<0>(cluster_shape); + int num_blocks = num_m_blocks * get<3>(problem_shape) /* Batch */; + num_blocks *= split_kv; /* Maximum Split KV*/ + + return Params { + num_blocks, + { num_m_blocks}, { get<3>(problem_shape) }, {split_kv}, + hw_info + }; + } + + static dim3 get_grid_shape(Params const& params) { + dim3 grid(std::min(params.num_blocks, params.hw_info.sm_count), 1, 1); + return grid; + } + + CUTLASS_DEVICE + bool is_valid() { + return block_idx < params.num_blocks; + } + + CUTLASS_DEVICE + auto get_block_coord() { + using namespace cute; + int block_decode = block_idx; + int m_block, bidb, n_split_kv; + params.divmod_m_block(block_decode, m_block, block_decode); + params.divmod_split_kv(block_decode, n_split_kv, block_decode); + params.divmod_b(block_decode, bidb, block_decode); + return make_coord(m_block, _0{}, bidb, n_split_kv); + } + + CUTLASS_DEVICE + Sm100MlaPersistentTileScheduler& operator++() { + block_idx += gridDim.x; + return *this; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::fmha::kernel + diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/reference/fmha_bwd_reference.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/reference/fmha_bwd_reference.hpp new file mode 100644 index 0000000000000000000000000000000000000000..465a9871ab8040c17ac78000d6f580a7045f8e2d --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/reference/fmha_bwd_reference.hpp @@ -0,0 +1,410 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + + +#pragma once + +#include "cute/tensor.hpp" +#include "collective/fmha_fusion.hpp" + +using namespace cutlass::fmha::collective; +///////////////////////////////////////////////////////////////////////////////////////////////// + +template< + class ProblemShape, + class TensorQ, class TensorK, class TensorV, + class TensorO, class TensorLSE, class TensorDO, + class TensorDQ, /* class TensorDK, class TensorDV, */ + class Fusion +> +void __global__ fmha_bwd_reference_dQ_kernel( + ProblemShape problem_shape_in, + TensorQ mQ_in, TensorK mK_in, TensorV mV_in, + TensorO mO_in, TensorLSE mLSE_in, TensorDO mDO_in, + TensorDQ mDQ_in, /* TensorDK mDK, TensorDV mDV, */ + Fusion fusion) { + + using namespace cute; + using namespace cutlass::fmha::collective; + + using Element = typename TensorO::value_type; + using ElementAcc = typename TensorLSE::value_type; + + extern __shared__ char mS_mem[]; + Element* mS = reinterpret_cast(mS_mem); + + ElementAcc softmax_scale = 1.0f / sqrtf(size<2>(problem_shape_in)); + + for (int idx_L = blockIdx.y; idx_L < size<4>(problem_shape_in); idx_L += gridDim.y) { + auto [problem_shape, offset] = apply_variable_length_offset( + problem_shape_in, + make_coord(_0{}, _0{}, _0{}, _0{},idx2crd(idx_L, get<4>(problem_shape_in))) + ); + // problem_shape = problem_shape_in; + // offset = repeat_like(problem_shape_in, _0{}); + auto mQ = domain_offset(select<0,2,4>(offset), mQ_in); + auto mK = domain_offset(select<1,2,4>(offset), mK_in); + auto mV = domain_offset(select<1,3,4>(offset), mV_in); + auto mO = domain_offset(select<0,3,4>(offset), mO_in); + auto mLSE = domain_offset(select<0,4>(offset), mLSE_in); + auto mDO = domain_offset(select<0,3,4>(offset), mDO_in); + auto mDQ = domain_offset(select<0,2,4>(offset), mDQ_in); + for (int idx_Q = blockIdx.x; idx_Q < size<0>(problem_shape); idx_Q += gridDim.x) { + for (int idx_K = threadIdx.x; idx_K < size<1>(problem_shape); idx_K += blockDim.x) { + ElementAcc acc_qk = 0; + ElementAcc acc_dov = 0; + ElementAcc acc_doo = 0; + for (int idx_D0 = 0; idx_D0 < size<2>(problem_shape); idx_D0++) { + acc_qk += mQ(idx_Q, idx_D0, idx_L) * mK(idx_K, idx_D0, idx_L); + // acc_dov += mDO(idx_Q, idx_D0, idx_L) * mV(idx_K, idx_D0, idx_L); + // acc_doo += mDO(idx_Q, idx_D0, idx_L) * mO(idx_Q, idx_D0, idx_L); + } // for idx_D0 + + for (int idx_D1 = 0; idx_D1 < size<3>(problem_shape); idx_D1++) { + acc_dov += mDO(idx_Q, idx_D1, idx_L) * mV(idx_K, idx_D1, idx_L); + acc_doo += mDO(idx_Q, idx_D1, idx_L) * mO(idx_Q, idx_D1, idx_L); + } + + auto id = make_identity_tensor(make_shape(1, 1)); + auto frag = make_tensor(Shape<_1, _1>{}); + frag(0) = acc_qk; + fusion.apply_mask(frag, make_tensor(id.data() + make_arithmetic_tuple(idx_Q, idx_K), id.layout()), problem_shape); + acc_qk = frag(0); + + mS[idx_K] = static_cast(expf(softmax_scale * acc_qk - mLSE(idx_Q, idx_L)) * softmax_scale * (acc_dov - acc_doo)); + } // for idx_K + + __syncthreads(); + + for (int idx_D = threadIdx.x; idx_D < size<2>(problem_shape); idx_D += blockDim.x) { + ElementAcc acc = 0; + for (int idx_K = 0; idx_K < size<1>(problem_shape); idx_K++) { + ElementAcc rK = mK(idx_K, idx_D, idx_L); + ElementAcc rDS = mS[idx_K]; + acc += rDS * rK; + } + mDQ(idx_Q, idx_D, idx_L) = static_cast(acc); + } // for idx_D + } + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template< + class ProblemShape, + class TensorQ, class TensorK, class 
TensorV, + class TensorO, class TensorLSE, class TensorDO, + /* class TensorDQ, */ class TensorDK, /* class TensorDV, */ + class Fusion +> +void __global__ fmha_bwd_reference_dK_kernel( + ProblemShape problem_shape_in, + TensorQ mQ_in, TensorK mK_in, TensorV mV_in, + TensorO mO_in, TensorLSE mLSE_in, TensorDO mDO_in, + /* TensorDQ mDQ_in, */ TensorDK mDK_in, /* TensorDV mDV_in, */ + Fusion fusion) { + + using namespace cute; + using namespace cutlass::fmha::collective; + + using Element = typename TensorO::value_type; + using ElementAcc = typename TensorLSE::value_type; + + extern __shared__ char mS_mem[]; + Element* mS = reinterpret_cast(mS_mem); + + ElementAcc softmax_scale = 1.0f / sqrtf(size<2>(problem_shape_in)); + + auto [H, B] = get<4>(problem_shape_in); + auto [H_R, H_K] = H; + + for (int idx_HB = blockIdx.y; idx_HB < H_K * B; idx_HB += gridDim.y) { + auto [idx_H_K, idx_B] = idx2crd(idx_HB, make_shape(H_K, B)); + auto [problem_shape, offset] = apply_variable_length_offset( + problem_shape_in, + make_coord(_0{}, _0{}, _0{}, _0{}, make_coord(make_coord(_0{}, idx_H_K), idx_B)) + ); + auto [Q, K, D, D_VO, HB] = problem_shape; + auto [offset_Q, offset_K, offset_D, offset_D_VO, offset_HB] = offset; + + auto mQ = domain_offset(make_coord(offset_Q, offset_D, offset_HB), mQ_in); + auto mK = domain_offset(make_coord(offset_K, offset_D, offset_HB), mK_in); + auto mV = domain_offset(make_coord(offset_K, offset_D_VO, offset_HB), mV_in); + auto mO = domain_offset(make_coord(offset_Q, offset_D_VO, offset_HB), mO_in); + auto mLSE = domain_offset(make_coord(offset_Q, offset_HB), mLSE_in); + auto mDO = domain_offset(make_coord(offset_Q, offset_D_VO, offset_HB), mDO_in); + auto mDK = domain_offset(make_coord(offset_K, offset_D, offset_HB), mDK_in); + + for (int idx_K = blockIdx.x; idx_K < K; idx_K += gridDim.x) { + ElementAcc acc_dk = 0; + for (int idx_H_R = 0; idx_H_R < H_R; idx_H_R++) { + auto coord_HB = make_coord(make_coord(idx_H_R, idx_H_K), idx_B); + for (int idx_Q = threadIdx.x; idx_Q < Q; idx_Q += blockDim.x) { + ElementAcc acc_qk = 0; + ElementAcc acc_dov = 0; + ElementAcc acc_doo = 0; + for (int idx_D0 = 0; idx_D0 < D; idx_D0++) { + ElementAcc rQ = mQ(idx_Q, idx_D0, coord_HB); + ElementAcc rK = mK(idx_K, idx_D0, coord_HB); + acc_qk += rQ * rK; + } // for idx_D0 + + for (int idx_D1 = 0; idx_D1 < D_VO; idx_D1++) { + ElementAcc rDO = mDO(idx_Q, idx_D1, coord_HB); + ElementAcc rV = mV(idx_K, idx_D1, coord_HB); + ElementAcc rO = mO(idx_Q, idx_D1, coord_HB); + acc_dov += rDO * rV; + acc_doo += rDO * rO ; + } + auto id = make_identity_tensor(make_shape(1, 1)); + auto frag = make_tensor(Shape<_1, _1>{}); + frag(0) = acc_qk; + fusion.apply_mask(frag, make_tensor(id.data() + make_arithmetic_tuple(idx_Q, idx_K), id.layout()), problem_shape); + acc_qk = frag(0); + + mS[idx_Q] = static_cast(expf(softmax_scale * acc_qk - mLSE(idx_Q, coord_HB)) * softmax_scale * (acc_dov - acc_doo)); + } // for idx_Q + + __syncthreads(); + + int idx_D = threadIdx.x; + if (idx_D < D) { + for (int idx_Q = 0; idx_Q < Q; idx_Q++) { + ElementAcc rQ = mQ(idx_Q, idx_D, coord_HB); + ElementAcc rDS = mS[idx_Q]; + acc_dk += rDS * rQ; + } + } + __syncthreads(); + } // for idx_H_R + + int idx_D = threadIdx.x; + if (idx_D < D) { + auto coord_HB = make_coord(make_coord(0, idx_H_K), idx_B); + mDK(idx_K, idx_D, coord_HB) = static_cast(acc_dk); + } + } // for idx_K + } // for idx_HB +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template< + class ProblemShape, + class TensorQ, 
class TensorK, class TensorV, + class TensorO, class TensorLSE, class TensorDO, + /* class TensorDQ, class TensorDK, */ class TensorDV, + class Fusion +> +void __global__ fmha_bwd_reference_dV_kernel( + ProblemShape problem_shape_in, + TensorQ mQ_in, TensorK mK_in, TensorV mV_in, + TensorO mO_in, TensorLSE mLSE_in, TensorDO mDO_in, + /* TensorDQ mDQ_in, TensorDK mDK_in, */ TensorDV mDV_in, + Fusion fusion) { + + using namespace cute; + using namespace cutlass::fmha::collective; + + using Element = typename TensorO::value_type; + using ElementAcc = typename TensorLSE::value_type; + + extern __shared__ char mS_mem[]; + Element* mS = reinterpret_cast(mS_mem); + + ElementAcc softmax_scale = 1.0f / sqrtf(size<2>(problem_shape_in)); + + auto [H, B] = get<4>(problem_shape_in); + auto [H_R, H_K] = H; + + for (int idx_HB = blockIdx.y; idx_HB < H_K * B; idx_HB += gridDim.y) { + auto [idx_H_K, idx_B] = idx2crd(idx_HB, make_shape(H_K, B)); + auto [problem_shape, offset] = apply_variable_length_offset( + problem_shape_in, + make_coord(_0{}, _0{}, _0{}, _0{}, make_coord(make_coord(_0{}, idx_H_K), idx_B)) + ); + auto [Q, K, D, D_VO, HB] = problem_shape; + auto [offset_Q, offset_K, offset_D, offset_D_VO, offset_HB] = offset; + + auto mQ = domain_offset(make_coord(offset_Q, offset_D, offset_HB), mQ_in); + auto mK = domain_offset(make_coord(offset_K, offset_D, offset_HB), mK_in); + auto mV = domain_offset(make_coord(offset_K, offset_D_VO, offset_HB), mV_in); + auto mO = domain_offset(make_coord(offset_Q, offset_D_VO, offset_HB), mO_in); + auto mLSE = domain_offset(make_coord(offset_Q, offset_HB), mLSE_in); + auto mDO = domain_offset(make_coord(offset_Q, offset_D_VO, offset_HB), mDO_in); + auto mDV = domain_offset(make_coord(offset_K, offset_D_VO, offset_HB), mDV_in); + + for (int idx_K = blockIdx.x; idx_K < K; idx_K += gridDim.x) { + ElementAcc acc_dv = 0; + for (int idx_H_R = 0; idx_H_R < H_R; idx_H_R++) { + auto coord_HB = make_coord(make_coord(idx_H_R, idx_H_K), idx_B); + for (int idx_Q = threadIdx.x; idx_Q < Q; idx_Q += blockDim.x) { + ElementAcc acc_qk = 0; + + for (int idx_D0 = 0; idx_D0 < D; idx_D0++) { + ElementAcc rQ = mQ(idx_Q, idx_D0, coord_HB); + ElementAcc rK = mK(idx_K, idx_D0, coord_HB); + acc_qk += rQ * rK; + } // for idx_D0 + + auto id = make_identity_tensor(make_shape(1, 1)); + auto frag = make_tensor(Shape<_1, _1>{}); + frag(0) = acc_qk; + fusion.apply_mask(frag, make_tensor(id.data() + make_arithmetic_tuple(idx_Q, idx_K), id.layout()), problem_shape); + acc_qk = frag(0); + + mS[idx_Q] = static_cast(expf(softmax_scale * acc_qk - mLSE(idx_Q, coord_HB))); + } // for idx_Q + + __syncthreads(); + + int idx_D_VO = threadIdx.x; + if (idx_D_VO < D_VO) { + for (int idx_Q = 0; idx_Q < Q; idx_Q++) { + ElementAcc rDO = mDO(idx_Q, idx_D_VO, coord_HB); + ElementAcc rP = mS[idx_Q]; + acc_dv += rP * rDO; + } + } // for idx_D + + __syncthreads(); + } // for idx_H_R + + int idx_D_VO = threadIdx.x; + if (idx_D_VO < D_VO) { + auto coord_HB = make_coord(make_coord(0, idx_H_K), idx_B); + mDV(idx_K, idx_D_VO, coord_HB) = static_cast(acc_dv); + } + } // for idx_K + } // for idx_L +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template< + class ProblemShape, + class TensorQ, class TensorK, class TensorV, + class TensorO, class TensorLSE, class TensorDO, + /**/ class TensorDQ, /** / class TensorDK, / ** / class TensorDV, / **/ + class Fusion +> +void fmha_bwd_reference_dQ( + ProblemShape problem_shape, + TensorQ mQ, TensorK mK, TensorV mV, + TensorO mO, 
TensorLSE mLSE, TensorDO mDO, + /**/ TensorDQ mDQ, /** / TensorDK mDK, / ** / TensorDV mDV, / **/ + Fusion fusion) { + + using namespace cute; + + dim3 grid(size<0>(mDQ), size<2>(mDQ), 1); + dim3 block(256); + int shared_mem = size<0>(mK) * sizeof(typename TensorDQ::value_type); + fmha_bwd_reference_dQ_kernel<<<grid, block, shared_mem>>>(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDQ, fusion); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template< + class ProblemShape, + class TensorQ, class TensorK, class TensorV, + class TensorO, class TensorLSE, class TensorDO, + /** / class TensorDQ, / **/ class TensorDK, /** / class TensorDV, / **/ + class Fusion +> +void fmha_bwd_reference_dK( + ProblemShape problem_shape, + TensorQ mQ, TensorK mK, TensorV mV, + TensorO mO, TensorLSE mLSE, TensorDO mDO, + /** / TensorDQ mDQ, / **/ TensorDK mDK, /** / TensorDV mDV, / **/ + Fusion fusion) { + + using namespace cute; + + auto [K, D, HB] = mDK.shape(); + auto [H, B] = HB; + auto [H_R, H_K] = H; + dim3 grid(K, H_K * B, 1); + dim3 block(std::max(D, 256)); + int shared_mem = size<0>(mDO) * sizeof(typename TensorDK::value_type); + fmha_bwd_reference_dK_kernel<<<grid, block, shared_mem>>>(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDK, fusion); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template< + class ProblemShape, + class TensorQ, class TensorK, class TensorV, + class TensorO, class TensorLSE, class TensorDO, + /** / class TensorDQ, / ** / class TensorDK, / **/ class TensorDV, /**/ + class Fusion +> +void fmha_bwd_reference_dV( + ProblemShape problem_shape, + TensorQ mQ, TensorK mK, TensorV mV, + TensorO mO, TensorLSE mLSE, TensorDO mDO, + /** / TensorDQ mDQ, / ** / TensorDK mDK, / **/ TensorDV mDV, /**/ + Fusion fusion) { + + using namespace cute; + + auto [K, D_VO, HB] = mDV.shape(); + auto [H, B] = HB; + auto [H_R, H_K] = H; + dim3 grid(K, H_K * B, 1); + dim3 block(std::max(D_VO, 256)); + int shared_mem = size<0>(mDO) * sizeof(typename TensorDV::value_type); + fmha_bwd_reference_dV_kernel<<<grid, block, shared_mem>>>(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDV, fusion); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template< + class ProblemShape, + class TensorQ, class TensorK, class TensorV, + class TensorO, class TensorLSE, class TensorDO, + class TensorDQ, class TensorDK, class TensorDV, + class Fusion +> +void fmha_bwd_reference( + ProblemShape problem_shape, + TensorQ mQ, TensorK mK, TensorV mV, + TensorO mO, TensorLSE mLSE, TensorDO mDO, + TensorDQ mDQ, TensorDK mDK, TensorDV mDV, + Fusion fusion) { + + fmha_bwd_reference_dQ(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDQ, fusion); + fmha_bwd_reference_dK(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDK, fusion); + fmha_bwd_reference_dV(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDV, fusion); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/reference/fmha_fwd_gen_reference.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/reference/fmha_fwd_gen_reference.hpp new file mode 100644 index 0000000000000000000000000000000000000000..40239c564b001b6571bda5efcc151395d239f4f1 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/reference/fmha_fwd_gen_reference.hpp @@ -0,0 +1,185 @@ 
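+// Informal overview of the generation-phase (single query token) reference that follows:
+// if new K/V rows are supplied, they are first appended to the cache at position
+// seqlen_kv[b] (optionally remapped through cache_batch_idx). Each thread then runs an
+// online softmax over its slice of the cached keys, i.e. m' = max(m, s);
+// o *= exp(scale*(m - m')); p = exp(scale*(s - m')); sum = sum*exp(scale*(m - m')) + p; o += p*v,
+// and the per-thread (m, sum, o) partials are combined through shared memory before the
+// normalized output O is written.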
+/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include +#include "cute/tensor.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template< + class ElementAcc, + class ProblemShape, + class TensorQ, + class TensorNewK, + class TensorNewV, + class TensorCacheK, + class TensorCacheV, + class TensorO +> +void __global__ fmha_fwd_gen_reference_kernel( + ProblemShape problem_shape, + const int* seqlen_kv, const int* cache_batch_idx, + TensorQ mQ, TensorNewK mNewK, TensorNewV mNewV, + TensorCacheK mCacheK, TensorCacheV mCacheV, TensorO mO) { + + using namespace cute; + extern __shared__ char mS_mem[]; + ElementAcc* mS = reinterpret_cast(mS_mem); + + float scale = 1.0f / std::sqrt(float(get<2>(problem_shape))); + + if (mNewK.data() != nullptr) { + // 1. copy in new_k to cache + for (int idx_h = blockIdx.x; idx_h < size<3,0,1>(problem_shape); idx_h += gridDim.x) { + for (int idx_b = blockIdx.z; idx_b < size<3,1>(problem_shape); idx_b += gridDim.z) { + int idx_b_kv = cache_batch_idx != nullptr ? cache_batch_idx[idx_b] : idx_b; + for (int idx_d = threadIdx.x; idx_d < size<2>(problem_shape); idx_d += blockDim.x) { + mCacheK(seqlen_kv[idx_b], idx_d, make_coord(make_coord(_0{}, idx_h), idx_b_kv)) = + mNewK(_0{}, idx_d, make_coord(make_coord(_0{}, idx_h), idx_b)); + mCacheV(seqlen_kv[idx_b], idx_d, make_coord(make_coord(_0{}, idx_h), idx_b_kv)) = + mNewV(_0{}, idx_d, make_coord(make_coord(_0{}, idx_h), idx_b)); + } + } + } + } + + // 2. 
compute attention + for (int idx_h_kv = blockIdx.x; idx_h_kv < size<3,0,1>(problem_shape); idx_h_kv += gridDim.x) { + for (int idx_h_qo = blockIdx.y; idx_h_qo < size<3,0,0>(problem_shape); idx_h_qo += gridDim.y) { + int idx_h = idx_h_qo + size<3,0,0>(problem_shape) * idx_h_kv; + for (int idx_b = blockIdx.z; idx_b < size<3,1>(problem_shape); idx_b += gridDim.z) { + int idx_b_kv = cache_batch_idx != nullptr ? cache_batch_idx[idx_b] : idx_b; + const int kDim = 128; + ElementAcc reg_o[kDim] = {0}; + ElementAcc row_max = -INFINITY; + ElementAcc row_sum = 0; + auto iteration = [&](auto const& tK, auto const& tV) { + ElementAcc reg_s = 0; + for (int idx_d = 0; idx_d < kDim; idx_d++) { + ElementAcc eQ = mQ(_0{}, idx_d, make_coord(idx_h, idx_b)); + ElementAcc eK = tK(idx_d); + reg_s += eQ * eK; + } + + ElementAcc old_row_max = row_max; + row_max = std::max(row_max, reg_s); + + ElementAcc adjustment = std::exp(scale * (old_row_max - row_max)); + row_sum *= adjustment; + for (int idx_d = 0; idx_d < kDim; idx_d++) { + reg_o[idx_d] *= adjustment; + } + + ElementAcc reg_p = std::exp(scale * (reg_s - row_max)); + row_sum += reg_p; + + for (int idx_d = 0; idx_d < kDim; idx_d++) { + ElementAcc eV = tV(idx_d); + reg_o[idx_d] += reg_p * eV; + } + }; + + for (int idx_s = threadIdx.x; idx_s < seqlen_kv[idx_b]; idx_s += blockDim.x) { + iteration(mCacheK(idx_s, _, make_coord(idx_h, idx_b_kv)), mCacheV(idx_s, _, make_coord(idx_h, idx_b_kv))); + } + + if (mNewK.data() != nullptr && threadIdx.x == 0) { + iteration(mNewK(_0{}, _, make_coord(idx_h, idx_b)), mNewV(_0{}, _, make_coord(idx_h, idx_b))); + } + + mS[threadIdx.x] = row_max; + __syncthreads(); + float old_row_max = row_max; + for (int i = 0; i < blockDim.x; i++) { + row_max = std::max(row_max, mS[i]); + } + __syncthreads(); + + ElementAcc adjustment = std::exp(scale * (old_row_max - row_max)); + row_sum *= adjustment; + for (int idx_d = 0; idx_d < kDim; idx_d++) { + reg_o[idx_d] *= adjustment; + } + mS[threadIdx.x] = row_sum; + __syncthreads(); + + row_sum = 0; + for (int i = 0; i < blockDim.x; i++) { + row_sum += mS[i]; + } + __syncthreads(); + for (int idx_d = 0; idx_d < kDim; idx_d++) { + mS[idx_d] = 0; + } + __syncthreads(); + + for (int idx_d = 0; idx_d < kDim; idx_d++) { + reg_o[idx_d] /= row_sum; + atomicAdd(&mS[idx_d], reg_o[idx_d]); + } + + __syncthreads(); + for (int idx_d = threadIdx.x; idx_d < kDim; idx_d += blockDim.x) { + mO(_0{}, idx_d, make_coord(idx_h, idx_b)) = static_cast(mS[idx_d]); + } + } + } + } +} + +template< + class ElementAcc, + class ProblemShape, + class TensorQ, + class TensorNewK, + class TensorNewV, + class TensorCacheK, + class TensorCacheV, + class TensorO +> +void fmha_fwd_gen_reference( + ProblemShape problem_shape, + const int* seqlen_kv, const int* cache_batch_idx, + TensorQ mQ, TensorNewK mNewK, TensorNewV mNewV, + TensorCacheK mCacheK, TensorCacheV mCacheV, TensorO mO) { + + using namespace cute; + + dim3 grid(get<3,0,1>(problem_shape), get<3,0,0>(problem_shape), get<3,1>(problem_shape)); + dim3 block(128); + int shared_mem = int(sizeof(ElementAcc)) * std::max(128, block.x); + assert(get<2>(problem_shape) == 128); + fmha_fwd_gen_reference_kernel<<>>( + problem_shape, seqlen_kv, cache_batch_idx, + mQ, mNewK, mNewV, mCacheK, mCacheV, mO + ); +} diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/reference/fmha_fwd_reference.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/reference/fmha_fwd_reference.hpp new file 
mode 100644 index 0000000000000000000000000000000000000000..d674ee95d2a7e30ef9c870506aaa414066e3df70 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/reference/fmha_fwd_reference.hpp @@ -0,0 +1,195 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include "cute/tensor.hpp" +#include "collective/fmha_fusion.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template< + class ProblemShapeIn, + class TensorQ, + class TensorK, + class TensorV, + class TensorO, + class TensorLSE, + class Mask +> +void __global__ fmha_reference_kernel( + ProblemShapeIn problem_shape_in, + TensorQ mQ, TensorK mK, TensorV mV, + TensorO mO, TensorLSE mLSE, + Mask mask) { + + using namespace cute; + using namespace cutlass::fmha::collective; + + using Element = typename TensorO::value_type; + using ElementAccumulator = typename TensorLSE::value_type; + + extern __shared__ char mS_mem[]; + ElementAccumulator* mS = reinterpret_cast<ElementAccumulator*>(mS_mem); + + ElementAccumulator softmax_scale = static_cast<ElementAccumulator>(1.0 / sqrt(1.0 * size<1>(mQ))); + + auto id = make_identity_tensor(make_shape(1, 1)); + + for (int idx_L = blockIdx.y; idx_L < size<4>(problem_shape_in); idx_L += gridDim.y) { + for (int idx_Q = blockIdx.x; idx_Q < size<0>(problem_shape_in); idx_Q += gridDim.x) { + + auto coord_L = idx2crd(idx_L, shape<4>(problem_shape_in)); + auto get_coord_in = [&]() { + if constexpr (rank_v<decltype(get<2>(ProblemShapeIn{}))> == 2) { + return cute::make_tuple(idx_Q, _0{}, cute::make_tuple(_0{}, _0{}), cute::make_tuple(_0{}, _0{}), coord_L); + } else { + return cute::make_tuple(idx_Q, _0{}, _0{}, _0{}, coord_L); + } + }; + auto coord_in = get_coord_in(); + auto [problem_shape, coord] = apply_variable_length(problem_shape_in, coord_in, get<4,1>(coord_in)); + + int head_qk = 0; + int head_v = 0; + if constexpr (rank_v<decltype(get<2>(problem_shape))> == 2) { + // MLA case: head_qk = 192, head_v = 128 + head_qk = size<2, 0>(problem_shape) + size<2, 1>(problem_shape); + head_v = size<2, 0>(problem_shape); + } else { + head_qk = size<3>(problem_shape); + head_v = head_qk; + } + + if (get<0,0>(coord) >= get<0>(problem_shape)) continue; + + int offset_Q = 0; + if constexpr (rank<0>(decltype(coord){}) == 2) { + offset_Q = get<0,1>(coord); + } + + int offset_K = 0; + if constexpr (rank<1>(decltype(coord){}) == 2) { + offset_K = get<1,1>(coord); + } + + if (get<1>(problem_shape) == 0) { + for (int idx_D = threadIdx.x; idx_D < head_qk; idx_D += blockDim.x) { + mO(idx_Q + offset_Q, idx_D, idx_L) = Element(0); + } + + if (threadIdx.x == 0 && mLSE.data() != nullptr) { + mLSE(idx_Q + offset_Q, idx_L) = -INFINITY; + } + continue; + } + + for (int idx_K = threadIdx.x; idx_K < size<1>(problem_shape); idx_K += blockDim.x) { + ElementAccumulator acc = 0; + for (int idx_D = 0; idx_D < head_qk; idx_D++) { + ElementAccumulator eQ = mQ(idx_Q + offset_Q, idx_D, idx_L); + ElementAccumulator eK = mK(idx_K + offset_K, idx_D, idx_L); + acc += eQ * eK; + } + auto frag = make_tensor<ElementAccumulator>(Shape<_1, _1>{}); + frag(0) = acc; + mask.apply_mask(frag, make_tensor(id.data() + make_arithmetic_tuple(idx_Q, idx_K), id.layout()), problem_shape); + mS[idx_K] = frag(0); + } + + __syncthreads(); + + ElementAccumulator maxS = -std::numeric_limits<ElementAccumulator>::infinity(); + for (int idx_K = 0; idx_K < size<1>(problem_shape); idx_K++) { + maxS = std::max(maxS, mS[idx_K]); + } + if (maxS == -std::numeric_limits<ElementAccumulator>::infinity()) maxS = 0; + + __syncthreads(); + + for (int idx_K = threadIdx.x; idx_K < size<1>(problem_shape); idx_K += blockDim.x) { + mS[idx_K] = expf(softmax_scale * (mS[idx_K] - maxS)); + } + + __syncthreads(); + + ElementAccumulator sum = 0; + for (int idx_K = 0; idx_K < size<1>(problem_shape); 
idx_K++) { + sum += mS[idx_K]; + } + + ElementAccumulator scale = 1.0f / sum; + + + for (int idx_D = threadIdx.x; idx_D < head_v; idx_D += blockDim.x) { + ElementAccumulator acc = 0; + for (int idx_K = 0; idx_K < size<1>(problem_shape); idx_K++) { + ElementAccumulator eV = mV(idx_K + offset_K, idx_D, idx_L); + ElementAccumulator eK = static_cast(mS[idx_K]); + acc += eK * eV; + } + mO(idx_Q + offset_Q, idx_D, idx_L) = static_cast(acc * scale); + } + + + if (threadIdx.x == 0 && mLSE.data() != nullptr) { + mLSE(idx_Q + offset_Q, idx_L) = log(sum) + softmax_scale * maxS; + } + + } + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template< + class ProblemShapeIn, + class TensorQ, + class TensorK, + class TensorV, + class TensorO, + class TensorLSE, + class Mask +> +void fmha_reference( + ProblemShapeIn problem_shape_in, + TensorQ mQ, TensorK mK, TensorV mV, + TensorO mO, TensorLSE mLSE, + Mask mask) { + + using namespace cute; + + dim3 grid(size<0>(mO), size<2>(mO), 1); + dim3 block(256); + int shared_mem = size<0>(mK) * int(sizeof(typename TensorLSE::value_type)); + fmha_reference_kernel<<>>(problem_shape_in, mQ, mK, mV, mO, mLSE, mask); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/reference/fmha_mla_reference.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/reference/fmha_mla_reference.hpp new file mode 100644 index 0000000000000000000000000000000000000000..c83ebdb7479159943c8edc0974c998d37052405e --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/reference/fmha_mla_reference.hpp @@ -0,0 +1,201 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cute/tensor.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template< + class ProblemShape, + class TensorSeq, + class TensorPageTable, + class TensorQL, + class TensorQR, + class TensorCL, + class TensorKR, + class TensorO, + class TensorLSE, + class Scale +> +void __global__ fmha_mla_reference_kernel( + ProblemShape problem_shape, + TensorSeq mSeq, TensorPageTable mPT, + TensorQL mQL, TensorQR mQR, + TensorCL mCL, TensorKR mKR, + TensorO mO, TensorLSE mLSE, + Scale softmax_scale) { + + using namespace cute; + + auto [H, K, D, B] = problem_shape; + auto [D_latent, D_rope] = D; + + using Element = typename TensorO::value_type; + using ElementAcc = typename TensorLSE::value_type; + + extern __shared__ ElementAcc mS[]; + // ElementAcc* mS = reinterpret_cast(mS_mem); + + for (int idx_B = blockIdx.y; idx_B < B; idx_B += gridDim.y) { + if (mSeq.data() != nullptr) { + K = mSeq(idx_B); + } + + for (int idx_H = blockIdx.x; idx_H < H; idx_H += gridDim.x) { + + for (int idx_K = threadIdx.x; idx_K < K; idx_K += blockDim.x) { + ElementAcc acc = 0; + + for (int idx_D = 0; idx_D < D_latent; idx_D++) { + int page_idx_K = idx_K; + int page_idx_B = idx_B; + if (mPT.data() != nullptr) { + page_idx_B = mPT(idx_K / size<0>(mCL), idx_B); + page_idx_K = idx_K % size<0>(mCL); + } + ElementAcc eQ = mQL(idx_H, idx_D, idx_B); + ElementAcc eK = mCL(page_idx_K, idx_D, page_idx_B); + acc += eQ * eK; + } + + for (int idx_D = 0; idx_D < D_rope; idx_D++) { + int page_idx_K = idx_K; + int page_idx_B = idx_B; + if (mPT.data() != nullptr) { + page_idx_B = mPT(idx_K / size<0>(mCL), idx_B); + page_idx_K = idx_K % size<0>(mCL); + } + ElementAcc eQ = mQR(idx_H, idx_D, idx_B); + ElementAcc eK = mKR(page_idx_K, idx_D, page_idx_B); + acc += eQ * eK; + } + mS[idx_K] = acc; + } + + __syncthreads(); + + ElementAcc maxS = -std::numeric_limits::infinity(); + for (int idx_K = 0; idx_K < K; idx_K++) { + maxS = std::max(maxS, mS[idx_K]); + } + if (maxS == -std::numeric_limits::infinity()) maxS = 0; + + __syncthreads(); + + for (int idx_K = threadIdx.x; idx_K < K; idx_K += blockDim.x) { + mS[idx_K] = expf(softmax_scale * (mS[idx_K] - maxS)); + } + + __syncthreads(); + + ElementAcc sum = 0; + for (int idx_K = 0; idx_K < K; idx_K++) { + sum += mS[idx_K]; + } + + ElementAcc o_scale = 1.0f / sum; + + for (int idx_D = threadIdx.x; idx_D < D_latent; idx_D += blockDim.x) { + ElementAcc acc = 0; + for (int idx_K = 0; idx_K < K; idx_K++) { + int page_idx_K = idx_K; + int page_idx_B = idx_B; + if (mPT.data() != nullptr) { + page_idx_B = mPT(idx_K / size<0>(mCL), idx_B); + page_idx_K = idx_K % size<0>(mCL); + } + ElementAcc eV = mCL(page_idx_K, idx_D, page_idx_B); + ElementAcc eK = static_cast(mS[idx_K]); + acc += eK * eV; + } + mO(idx_H, idx_D, idx_B) = static_cast(acc * o_scale); + } + + if (threadIdx.x == 0) { + mLSE(idx_H, 
idx_B) = log(sum) + softmax_scale * maxS; + } + + } + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template< + class ProblemShape, + class TensorSeq, + class TensorPageTable, + class TensorQL, + class TensorQR, + class TensorCL, + class TensorKR, + class TensorO, + class TensorLSE, + class Scale +> +void fmha_mla_reference( + ProblemShape problem_shape, + TensorSeq mSeq, TensorPageTable mPT, + TensorQL mQL, TensorQR mQR, + TensorCL mCL, TensorKR mKR, + TensorO mO, TensorLSE mLSE, + Scale scale) { + + using namespace cute; + + auto [H, K, D, B] = problem_shape; + auto [D_latent, D_rope] = D; + + dim3 grid(H, B, 1); + dim3 block(256); + int shared_mem = K * int(sizeof(typename TensorLSE::value_type)) + 16; + cudaError_t result; + if (shared_mem >= (48 << 10)) { + result = cudaFuncSetAttribute( + &fmha_mla_reference_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + shared_mem); + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + throw std::runtime_error("couldn't perform smem optin"); + } + } + fmha_mla_reference_kernel<<>>( + problem_shape, mSeq, mPT, mQL, mQR, mCL, mKR, mO, mLSE, scale); + cudaDeviceSynchronize(); + result = cudaGetLastError(); + if (cudaSuccess != result) { + throw std::runtime_error("couldn't execute reference"); + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/reference/reference_abs_error.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/reference/reference_abs_error.hpp new file mode 100644 index 0000000000000000000000000000000000000000..8c29289c989931970fafd507a4e42f0502cbd2eb --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/77_blackwell_fmha/reference/reference_abs_error.hpp @@ -0,0 +1,282 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + + +#pragma once + +#include + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct DeviceAllocation { + T* ptr_ = nullptr; + size_t offset_ = 0; + size_t size_ = 0; + + DeviceAllocation(DeviceAllocation const&) = delete; + DeviceAllocation& operator=(DeviceAllocation const&) = delete; + + DeviceAllocation() = default; + DeviceAllocation(size_t size) { reset(size); } + ~DeviceAllocation() { reset(); } + + void reset(size_t size, size_t offset=0) { + reset(); + auto ret = cudaMalloc(&ptr_, sizeof(T) * (size + offset)); + assert(ret == cudaSuccess); + size_ = size; + offset_ = offset; + } + + T* get() { + return ptr_ + offset_; + } + + const T* get() const { + return ptr_ + offset_; + } + + void reset() { + if (ptr_ != nullptr) { + auto ret = cudaFree(ptr_); + assert(ret == cudaSuccess); + } + } + + size_t size() const { return size_; } + + size_t get_storage_size() const { return (size_ + offset_) * sizeof(T); } + + void copy_from_host(const T* ptr, size_t sz) { + auto ret = cudaMemcpy(ptr_, ptr, sz * sizeof(T), cudaMemcpyDefault); + assert(ret == cudaSuccess); + } + + void copy_from_device(const T* ptr, size_t sz) { + auto ret = cudaMemcpy(ptr_, ptr, sz * sizeof(T), cudaMemcpyDefault); + assert(ret == cudaSuccess); + } +}; + +template +__global__ void reference_abs_diff_kernel( + Element* data, Element* data_ref, size_t count, + double* max_diff, double* sum_diff, + bool print_diff ) { + + double thread_max_diff = 0; + double thread_sum_diff = 0; + + __shared__ double block_max_diff; + __shared__ double block_sum_diff; + + for (size_t i = threadIdx.x + blockIdx.x * blockDim.x; i < count; i += blockDim.x * gridDim.x) { + if (data[i] == data_ref[i]) { + continue; + } + + double diff = fabs(data[i] - data_ref[i]); + if (print_diff) if (not isfinite(diff) || diff > 0.01f) printf("difference at %lld: %f ... 
%f vs %f\n", static_cast(i), diff, (double)data[i], (double)data_ref[i]); + thread_max_diff = fmax(diff, thread_max_diff); + thread_sum_diff += diff; + } + + for (int i = 0; i < blockDim.x; i++) { + if (i == threadIdx.x) { + if (i == 0) { + block_max_diff = thread_max_diff; + block_sum_diff = thread_sum_diff; + } + else { + block_max_diff = fmax(block_max_diff, thread_max_diff); + block_sum_diff += thread_sum_diff; + } + } + __syncthreads(); + } + + if (threadIdx.x == 0) { + atomicAdd(sum_diff, block_sum_diff); + + for (;;) { + unsigned long long prev = *reinterpret_cast(max_diff); + double prev_diff = reinterpret_cast(prev); + double new_max_diff = fmax(block_max_diff, prev_diff); + unsigned long long found = atomicCAS(reinterpret_cast(max_diff), prev, reinterpret_cast(new_max_diff)); + if (found == prev) break; + } + } +} + +template +void reference_abs_diff( + DeviceAllocation const& data, + DeviceAllocation const& data_ref, + double& max_diff, double& mean_diff) { + + static bool kPrintDiff = getenv("REF_PRINT_DIFF") && atoi(getenv("REF_PRINT_DIFF")) == 1; + + DeviceAllocation result; + result.reset(2); + assert(data.size() == data_ref.size()); + + cudaError_t err = cudaMemset(result.get(), 0, result.size() * sizeof(double)); + if (err != cudaSuccess) { + std::cerr << "Memset failed. Last CUDA error: " + << cudaGetErrorString(err) << std::endl; + max_diff = mean_diff = 1e20; + return; + } + + dim3 block(256, 1, 1); + dim3 grid(1024, 1, 1); + reference_abs_diff_kernel<<>>( + data.get(), data_ref.get(), data.size(), + result.get(), result.get() + 1, kPrintDiff); + + err = cudaDeviceSynchronize(); + if (err != cudaSuccess) { + std::cerr << "Difference kernel failed. Last CUDA error: " + << cudaGetErrorString(err) << std::endl; + max_diff = mean_diff = 1e20; + return; + } + + double result_host[2]; + err = cudaMemcpy(result_host, result.get(), result.size() * sizeof(double), cudaMemcpyDefault); + if (err != cudaSuccess) { + std::cerr << "Copy failed. Last CUDA error: " + << cudaGetErrorString(err) << std::endl; + max_diff = mean_diff = 1e20; + return; + } + + max_diff = result_host[0]; + mean_diff = result_host[1] / static_cast(data.size()); +} + +template +__global__ void reference_rel_diff_kernel( + Element* data, Element* data_ref, size_t count, + double* max_diff, double* sum_diff, + bool print_diff ) { + + double thread_max_diff = 0; + double thread_sum_diff = 0; + + __shared__ double block_max_diff; + __shared__ double block_sum_diff; + + for (size_t i = threadIdx.x + blockIdx.x * blockDim.x; i < count; i += blockDim.x * gridDim.x) { + if (data[i] == data_ref[i]) { + continue; + } + double diff = fabs(data[i] - data_ref[i]) / fabs(data_ref[i]); + if (print_diff) if (not isfinite(diff) || diff > 0.01f) printf("difference at %lld: %f ... 
%f vs %f\n", static_cast(i), diff, (double)data[i], (double)data_ref[i]); + thread_max_diff = fmax(diff, thread_max_diff); + thread_sum_diff += diff; + } + + for (int i = 0; i < blockDim.x; i++) { + if (i == threadIdx.x) { + if (i == 0) { + block_max_diff = thread_max_diff; + block_sum_diff = thread_sum_diff; + } + else { + block_max_diff = fmax(block_max_diff, thread_max_diff); + block_sum_diff += thread_sum_diff; + } + } + __syncthreads(); + } + + if (threadIdx.x == 0) { + atomicAdd(sum_diff, block_sum_diff); + + for (;;) { + unsigned long long prev = *reinterpret_cast(max_diff); + double prev_diff = reinterpret_cast(prev); + double new_max_diff = fmax(block_max_diff, prev_diff); + unsigned long long found = atomicCAS(reinterpret_cast(max_diff), prev, reinterpret_cast(new_max_diff)); + if (found == prev) break; + } + } +} + +template +void reference_rel_diff( + DeviceAllocation const& data, + DeviceAllocation const& data_ref, + double& max_diff, double& mean_diff) { + + static bool kPrintDiff = getenv("REF_PRINT_DIFF") && atoi(getenv("REF_PRINT_DIFF")) == 1; + + DeviceAllocation result; + result.reset(2); + assert(data.size() == data_ref.size()); + + cudaError_t err = cudaMemset(result.get(), 0, result.size() * sizeof(double)); + if (err != cudaSuccess) { + std::cerr << "Memset failed. Last CUDA error: " + << cudaGetErrorString(err) << std::endl; + max_diff = mean_diff = 1e20; + return; + } + + dim3 block(256, 1, 1); + dim3 grid(1024, 1, 1); + reference_rel_diff_kernel<<>>( + data.get(), data_ref.get(), data.size(), + result.get(), result.get() + 1, kPrintDiff); + + err = cudaDeviceSynchronize(); + if (err != cudaSuccess) { + std::cerr << "Difference kernel failed. Last CUDA error: " + << cudaGetErrorString(err) << std::endl; + max_diff = mean_diff = 1e20; + return; + } + + double result_host[2]; + err = cudaMemcpy(result_host, result.get(), result.size() * sizeof(double), cudaMemcpyDefault); + if (err != cudaSuccess) { + std::cerr << "Copy failed. Last CUDA error: " + << cudaGetErrorString(err) << std::endl; + max_diff = mean_diff = 1e20; + return; + } + + max_diff = result_host[0]; + mean_diff = result_host[1] / static_cast(data.size()); +} diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/86_blackwell_mixed_dtype_gemm/mixed_dtype_helper.cuh b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/86_blackwell_mixed_dtype_gemm/mixed_dtype_helper.cuh new file mode 100644 index 0000000000000000000000000000000000000000..f26e6be8246e4f0318d6ce4973beac1e50dafa5b --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/86_blackwell_mixed_dtype_gemm/mixed_dtype_helper.cuh @@ -0,0 +1,269 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/util/command_line.h" +#include "cutlass/util/reference/device/tensor_fill.h" +#include "cutlass/util/reference/device/tensor_compare.h" + +#include "cute/tensor.hpp" + +#include +#include +#include "helper.h" + +enum MixedDtypeGemmMode { + ConvertOnly, + ScaleOnly, + ScaleWithZeroPoint +}; + +/// Command line options parsing +struct MixedDtypeOptions { + + bool help = false; + bool verify = false; + + float alpha = 1.0f; + float beta = 0.0f; + int iterations = 1000; + int warmup = 1000; + int mode = 1; + int m = 5120, n = 4096, k = 4096; + int l = 1; + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + if (cmd.check_cmd_line_flag("verify")) { + verify = true; + } + + cmd.get_cmd_line_argument("m", m); + cmd.get_cmd_line_argument("n", n); + cmd.get_cmd_line_argument("k", k); + cmd.get_cmd_line_argument("l", l); + cmd.get_cmd_line_argument("mode", mode); + cmd.get_cmd_line_argument("alpha", alpha, 1.f); + cmd.get_cmd_line_argument("beta", beta, 0.f); + cmd.get_cmd_line_argument("iterations", iterations); + cmd.get_cmd_line_argument("warmup", warmup); + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "86_blackwell_mixed_dtype_gemm\n\n" + << " Blackwell Mixed Data Type GEMM using a Warp Specialized kernel.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --l= The number of independent gemm problems with mnk shape\n" + << " --mode= The mode to run the gemm. 
0 does (A @ B), 1 means A @ (scale * B), 2 means A @ (scale * B + zero-point).\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n\n" + << " --iterations= Number of profiling iterations to perform.\n\n" + << " --warmup= Number of warmup iterations to perform.\n\n" + << " --verify= Run verification.\n\n"; + + out + << "\n\nExamples:\n\n" + << "$ " << "86_blackwell_mixed_dtype_gemm" << " --m=1024 --n=512 --k=1024 --l=10 --alpha=2 --mode=2 --beta=0.707 \n\n"; + + return out; + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s) const + { + // Two flops per multiply-add + uint64_t flop = uint64_t(2) * m * n * k * l; + double gflop = double(flop) / double(1.0e9); + return gflop / runtime_s; + } +}; + +/// Result structure +struct MixedDtypeResult +{ + double avg_runtime_ms = 0.0; + double gflops = 0.0; + cutlass::Status status = cutlass::Status::kSuccess; + cudaError_t error = cudaSuccess; + bool passed = false; + +}; + +/// Profiling Loop +template +void mixed_dtype_profiling( + Gemm& gemm, + MixedDtypeOptions const& options, + MixedDtypeResult& result) { + + if (options.iterations <= 0) return; + + cudaEvent_t start, stop; + cudaEventCreate(&start); + cudaEventCreate(&stop); + + std::vector runtimes; + runtimes.reserve(options.iterations); + + for (int iter = 0; iter < options.warmup + options.iterations; ++iter) { + cudaEventRecord(start); + CUTLASS_CHECK(gemm.run()); + cudaEventRecord(stop); + cudaEventSynchronize(stop); + + if (iter >= options.warmup) { + float milliseconds = 0; + cudaEventElapsedTime(&milliseconds, start, stop); + runtimes.push_back(milliseconds); + } + } + + cudaEventDestroy(start); + cudaEventDestroy(stop); + + // Compute average setup and runtime and GFLOPs. + result.avg_runtime_ms = std::accumulate(runtimes.begin(), runtimes.end(), 0.0f) / runtimes.size(); + result.gflops = options.gflops(result.avg_runtime_ms / 1000.0); + + std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl; + std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl; + std::cout << " GFLOPS: " << result.gflops << std::endl; + +} + +/// Helpers to initialize a block of device data +template +bool initialize_tensor( + cutlass::DeviceAllocation& block, + uint64_t seed = 2023) { + + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } + else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } + else if (bits_output == 16) { + scope_max = 5; + scope_min = -5; + } + else { + scope_max = 8; + scope_min = -8; + } + cutlass::reference::device::BlockFillRandomUniform( + block.get(), block.size(), seed, Element(scope_max), Element(scope_min)); + + return true; +} + +template +bool initialize_quant_tensor( + cutlass::DeviceAllocation& block, + uint64_t seed = 2023) { + + float scope_min = float(cutlass::platform::numeric_limits::lowest()); + float scope_max = float(cutlass::platform::numeric_limits::max()); + + cutlass::reference::device::BlockFillRandomUniform( + block.get(), block.size(), seed, Element(scope_max), Element(scope_min)); + + return true; +} + +template +bool initialize_scale( + cutlass::DeviceAllocation& block, + MixedDtypeOptions const& options, + uint64_t seed = 2023) { + + if (options.mode == MixedDtypeGemmMode::ConvertOnly) { + // No scales, so just initialize with 1 so we can use the same kernel to dequantize 
the data. + std::vector stage(block.size(), Element(1.0f)); + block.copy_from_host(stage.data()); + } + else { + float elt_max_f = float(cutlass::platform::numeric_limits::max()); + const float max_dequant_val = 4.f; + const float min_dequant_val = 0.5f; + + float scope_max(max_dequant_val / elt_max_f); + float scope_min(min_dequant_val / elt_max_f); + + cutlass::reference::device::BlockFillRandomUniform( + block.get(), block.size(), seed, Element(scope_max), Element(scope_min)); + } + return true; +} + +template +bool initialize_zero( + cutlass::DeviceAllocation& block, + MixedDtypeOptions const& options, + uint64_t seed = 2023) { + + if (options.mode == MixedDtypeGemmMode::ScaleWithZeroPoint) { + cutlass::reference::device::BlockFillRandomUniform( + block.get(), block.size(), seed, Element(2.0f), Element(-2.0f)); + } else { + // No bias, so just initialize with 1 so we can use the same kernel to dequantize the data. + std::vector stage(block.size(), Element(0.0f)); + block.copy_from_host(stage.data()); + } + return true; +} + diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/87_blackwell_geforce_gemm_blockwise/utils.h b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/87_blackwell_geforce_gemm_blockwise/utils.h new file mode 100644 index 0000000000000000000000000000000000000000..72735303883bba3d185e8d56df06273f33938ad5 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/87_blackwell_geforce_gemm_blockwise/utils.h @@ -0,0 +1,83 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +/// Helper to initialize a block of device data +template +bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } + else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } + else if (bits_input == 16) { + scope_max = 5; + scope_min = -5; + } + else { + scope_max = 8; + scope_min = -8; + } + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, 0); + } + else if (dist_kind == cutlass::Distribution::AllZeros) { + cutlass::reference::host::TensorFill(view); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + cutlass::reference::host::BlockFillSequential(view.data(), view.capacity()); + } + else { + throw std::runtime_error("Not implementated."); + } + + return true; +} diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/collective/fmha_collective_bwd_tma_warpspecialized.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/collective/fmha_collective_bwd_tma_warpspecialized.hpp new file mode 100644 index 0000000000000000000000000000000000000000..d9d311a54e61ede4407b2ab9320b85f73d15ef8f --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/collective/fmha_collective_bwd_tma_warpspecialized.hpp @@ -0,0 +1,863 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "../collective/fmha_common.hpp" +#include "../collective/fmha_collective_load.hpp" +#include "../collective/fmha_collective_softmax.hpp" +#include "../kernel/fmha_options.hpp" + +namespace cutlass::fmha::collective { + +template< + typename Element_, + typename ElementAccumulator_, + typename TileShape_, // BlockQO, BlockKV, BlockHead + class Fusion, + class... Options +> +struct FmhaBwdMainloopTmaWarpSpecialized { + + using Element = Element_; + using ElementAccumulator = ElementAccumulator_; + using TileShape = TileShape_; + + static constexpr bool kIsPersistent = false; + + static const int NumLoadWarpGroups = 1; + static constexpr int NumMmaWarpGroups = 2; + static constexpr int StageCountQ = 2 /*K, V*/ * NumMmaWarpGroups; + static constexpr int StageCount = 2 /*Q, dO*/ * 2 /* actual stages */; + + static const int kOuterLoads = 2; + using StagesQ = cutlass::gemm::collective::StageCount; + using Stages = cutlass::gemm::collective::StageCount; + using ClusterShape = Shape<_1, _1, _1>; + static_assert(StagesQ::value >= 2); + static_assert(Stages::value >= 2 * NumMmaWarpGroups); + + // 16B alignment lets us use TMA + static constexpr int Alignment = 16 / sizeof(Element); + + using TileShapeNM = Shape< // (N,M,D) + decltype(tuple_element_t<1, TileShape>{} / Int{}), + tuple_element_t<0, TileShape>, + tuple_element_t<2, TileShape>>; + + using TileShapeND = decltype(select<0,2,1>(TileShapeNM{})); // (N,D,M) + + using TileShapeMD = decltype(select<2,1,0>(TileShapeND{})); // (M,D,N) + + using CollectiveMmaNM = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Element, cute::tuple>, Alignment, + Element, cute::tuple>, Alignment, + ElementAccumulator, + TileShapeNM, ClusterShape, Stages, + cutlass::gemm::KernelTmaWarpSpecialized>::CollectiveOp; + + using CollectiveMmaND = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Element, cute::tuple>, Alignment, // from register, doesn't matter + Element, cute::tuple<_1, int, cute::tuple>, Alignment, + ElementAccumulator, + TileShapeND, ClusterShape, Stages, + cutlass::gemm::KernelTmaWarpSpecialized>::CollectiveOp; + + using CollectiveMmaND_SS = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Element, cute::tuple>, Alignment, // from register, doesn't matter + Element, cute::tuple<_1, int, cute::tuple>, Alignment, + ElementAccumulator, + TileShapeND, ClusterShape, Stages, + cutlass::gemm::KernelTmaWarpSpecialized>::CollectiveOp; + + + using CollectiveMmaMD = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Element, cute::tuple<_1, int, cute::tuple>, 
Alignment, // from smem, might matter (?) + Element, cute::tuple<_1, int, cute::tuple>, Alignment, + ElementAccumulator, + TileShapeMD, ClusterShape, Stages, + cutlass::gemm::KernelTmaWarpSpecialized>::CollectiveOp; + + using TiledMmaNM = typename CollectiveMmaNM::TiledMma; + using TiledMmaND_SS = typename CollectiveMmaND_SS::TiledMma; + using TiledMmaND_RS = decltype(convert_to_gmma_rs(typename CollectiveMmaND::TiledMma{})); + using TiledMmaND = TiledMmaND_RS; + using TiledMmaMD = typename CollectiveMmaMD::TiledMma; + + using SmemLayoutQ = typename CollectiveMmaNM::SmemLayoutB; + using SmemLayoutK = typename CollectiveMmaNM::SmemLayoutA; + using SmemLayoutV = typename CollectiveMmaNM::SmemLayoutA; + using SmemLayoutDO = typename CollectiveMmaNM::SmemLayoutB; + + //using SmemLayoutDQ = Layout< + // Shape< + // tuple_element_t<0, TileShapeMD>, + // Shape<_2, _4, decltype(tuple_element_t<1, TileShapeMD>{} / _8{})>, + // _2 + // >, + // Stride< + // _4, + // Stride{} * _4{}), _1, decltype(tuple_element_t<0, TileShapeMD>{} * _8{})>, + // decltype(tuple_element_t<0, TileShapeMD>{} * tuple_element_t<1, TileShapeMD>{}) + // >>; + + using SmemLayoutDQ_0 = Layout< + Shape< + tuple_element_t<0, TileShapeMD>, + tuple_element_t<1, TileShapeMD>, + _2 + >, + Stride< + tuple_element_t<1, TileShapeMD>, + _1, + decltype(tuple_element_t<0, TileShapeMD>{} * tuple_element_t<1, TileShapeMD>{}) + >>; + + using SmemAtomDQ = decltype(cutlass::gemm::collective::detail::ss_smem_selector< + cute::GMMA::Major::K, ElementAccumulator, tuple_element_t<0, TileShapeMD>, tuple_element_t<1, TileShapeMD>>()); + using SmemLayoutDQ_1 = decltype(tile_to_shape(SmemAtomDQ{}, make_shape(get<0>(TileShapeMD{}), get<1>(TileShapeMD{}), _2{}), Step<_2, _1, _3>{})); + using SmemLayoutDQ = SmemLayoutDQ_1; + + + using PipelineDQ = cutlass::PipelineAsync<2>; + + + using SmemLayoutDS_0 = decltype(unstageSmemLayout(typename CollectiveMmaMD::SmemLayoutA{}, Int{})); + + using SmemLayoutDS = decltype(tile_to_shape(GMMA::Layout_MN_INTER_Atom{}, make_shape(size<0>(SmemLayoutDS_0{}), size<1>(SmemLayoutDS_0{}), size<2>(SmemLayoutDS_0{})), Step<_1, _2, _3>{})); + using SmemLayoutKp = typename CollectiveMmaMD::SmemLayoutB; + + using SmemLayoutQp = typename CollectiveMmaND::SmemLayoutB; + using SmemLayoutDOp = typename CollectiveMmaND::SmemLayoutB; + + using SmemLayoutLSE = Layout, Int>>; + + using MainloopPipeline = cutlass::PipelineTmaAsync; + using MainloopPipelineQ = cutlass::PipelineTmaAsync; + + using PipelineState = typename cutlass::PipelineState; + using PipelineStateQ = typename cutlass::PipelineState; + + using TileShapePV = TileShapeND; // To work with the kernel level + using TiledMmaPV = TiledMmaND; + + static constexpr int kInnerLoadBytes = size(SmemLayoutQ{}(_,_,_0{})) * sizeof(Element) + size(SmemLayoutLSE{}(_,_0{})) * sizeof(ElementAccumulator); + static constexpr int kOuterLoadBytes = size(SmemLayoutK{}(_,_,_0{})) * sizeof(Element); + + struct SharedStorage { + // One for each consumer WG + union { + cute::array_aligned> smem_k; + cute::array_aligned> smem_kp; + cute::array_aligned> smem_v; + }; + + cute::array_aligned> smem_ds; + + // Loaded by producer, consumed by both WGs + union { + cute::array_aligned> smem_q; + cute::array_aligned> smem_do; + cute::array_aligned> smem_qp; + cute::array_aligned> smem_dop; + }; + + // Accumulated into by both consumers, potentially loaded, potentially written + cute::array_aligned> smem_dq; + + union { + cute::array_aligned> smem_lse; + cute::array_aligned> smem_sumOdO; + }; + }; + + struct 
Arguments { + const Element* ptr_Q; + cute::tuple dQ; + const Element* ptr_K; + cute::tuple dK; + const Element* ptr_V; + cute::tuple dV; + + const Element* ptr_dO; + cute::tuple dDO; + + const ElementAccumulator* ptr_LSE; + cute::tuple dLSE; + const ElementAccumulator* ptr_sum_OdO; + cute::tuple dSumOdO; + + ElementAccumulator* ptr_dQ; + cute::tuple dDQ; + }; + + using TMA_Q = typename CollectiveMmaNM::Params::TMA_B; + using TMA_K = typename CollectiveMmaNM::Params::TMA_A; + using TMA_V = typename CollectiveMmaNM::Params::TMA_A; + using TMA_DO = typename CollectiveMmaNM::Params::TMA_B; + + using TMA_LSE = decltype(make_tma_copy(SM90_TMA_LOAD{}, make_tensor((const ElementAccumulator*)nullptr, make_shape(1, 1, 1), make_stride(_1{}, 0, 0)), SmemLayoutLSE{}(_,_0{}))); + using TMA_ODO = TMA_LSE; + + using TMA_DQ = decltype(make_tma_copy(SM90_TMA_REDUCE_ADD{}, make_tensor((const ElementAccumulator*)nullptr, make_shape(1, 1, 1, 1), make_stride(0, _1{}, 0, 0)), SmemLayoutDQ{}(_,_,_0{}))); + + using LoadQ = CollectiveLoadTma< + LoadKind::kBwdM, + MainloopPipeline, + Element, + SmemLayoutQ, + TMA_Q + >; + + using LoadK = CollectiveLoadTma< + LoadKind::kBwdN, + MainloopPipelineQ, + Element, + SmemLayoutK, + TMA_K + >; + + using LoadV = CollectiveLoadTma< + LoadKind::kBwdN, + MainloopPipelineQ, + Element, + SmemLayoutV, + TMA_V + >; + + using LoadDO = CollectiveLoadTma< + LoadKind::kBwdM, + MainloopPipeline, + Element, + SmemLayoutDO, + TMA_DO + >; + + using LoadLSE = CollectiveLoadTma< + LoadKind::kBwdScalar, + MainloopPipeline, + ElementAccumulator, + SmemLayoutLSE, + TMA_LSE + >; + + using LoadODO = CollectiveLoadTma< + LoadKind::kBwdScalar, + MainloopPipeline, + ElementAccumulator, + SmemLayoutLSE, + TMA_ODO + >; + + struct Params { + TMA_Q tma_load_q; + TMA_K tma_load_k; + TMA_V tma_load_v; + TMA_DO tma_load_do; + + TMA_LSE tma_load_lse; + TMA_ODO tma_load_odo; + + TMA_DQ tma_red_dq; + + float scale_softmax; + float scale_softmax_log2; + }; + + static_assert(size(TiledMmaNM{}) == size(TiledMmaND{})); + static_assert(size(TiledMmaNM{}) == size(TiledMmaMD{})); + + template + static bool can_implement(ProblemShape const& problem_size, Arguments const& args) { + return true + && (get<4>(problem_size) <= get<2>(TileShape{})) + && ((get<4>(problem_size) % Alignment) == 0) + && ((get<2>(problem_size) % Alignment) == 0) + ; + } + + template + static Params to_underlying_arguments(ProblemShape const& problem_size, Arguments const& args, void* workspace) { + auto problem_shape_nm = make_shape(get<3>(problem_size), get<2>(problem_size), get<4>(problem_size), make_shape(get<0>(problem_size), get<1>(problem_size))); + + auto dK = make_stride(get<2>(args.dK), get<3>(args.dK), make_stride(get<0>(args.dK), get<1>(args.dK))); + auto dQ = make_stride(get<2>(args.dQ), get<3>(args.dQ), make_stride(get<0>(args.dQ), get<1>(args.dQ))); + auto params_nm_kq = CollectiveMmaNM::to_underlying_arguments(problem_shape_nm, + typename CollectiveMmaNM::Arguments { + args.ptr_K, dK, + args.ptr_Q, dQ, + }, /*workspace=*/ nullptr); + + auto dV = make_stride(get<2>(args.dV), get<3>(args.dV), make_stride(get<0>(args.dV), get<1>(args.dV))); + auto dDO = make_stride(get<2>(args.dDO), get<3>(args.dDO), make_stride(get<0>(args.dDO), get<1>(args.dDO))); + auto params_nm_vdo = CollectiveMmaNM::to_underlying_arguments(problem_shape_nm, + typename CollectiveMmaNM::Arguments { + args.ptr_V, dV, + args.ptr_dO, dDO, + }, /*workspace=*/ nullptr); + + + TMA_LSE tma_load_lse = make_tma_copy(SM90_TMA_LOAD{}, make_tensor(args.ptr_LSE, 
select<2,0,1>(problem_size), select<2,0,1>(args.dLSE)), SmemLayoutLSE{}(_,_0{})); + TMA_ODO tma_load_odo = make_tma_copy(SM90_TMA_LOAD{}, make_tensor(args.ptr_sum_OdO, select<2,0,1>(problem_size), select<2,0,1>(args.dSumOdO)), SmemLayoutLSE{}(_,_0{})); + + TMA_DQ tma_red_dq = make_tma_copy(SM90_TMA_REDUCE_ADD{}, make_tensor(args.ptr_dQ, select<2,4,0,1>(problem_size), select<2,3,0,1>(args.dDQ)), SmemLayoutDQ{}(_,_,_0{})); + + return Params{ + params_nm_kq.tma_load_b, + params_nm_kq.tma_load_a, + params_nm_vdo.tma_load_a, + params_nm_vdo.tma_load_b, + tma_load_lse, tma_load_odo, + tma_red_dq, + 1.0f / (float) std::sqrt(get<4>(problem_size)), + (float) (std::log2(std::exp(1.0)) / std::sqrt(get<4>(problem_size))) + }; + } + + template + CUTLASS_DEVICE + auto + get_inner_tile_count(BlkCoord const& blk_coord, ProblemSize const& problem_size) { + return Fusion{}.get_trip_count(blk_coord, TileShape{}, problem_size); + } + + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& params) { + cute::prefetch_tma_descriptor(params.tma_load_q.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_k.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_v.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_do.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_odo.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_lse.get_tma_descriptor()); + } + + template + CUTLASS_DEVICE void + load_kv_maybe_q( + int block_rank_in_cluster, + BlkCoord const& blk_coord, Params const& params, ProblemShape const& problem_size, + MainloopPipeline& pipeline_inner, PipelineState& smem_pipe_write_inner, + MainloopPipelineQ& pipeline_outer, PipelineStateQ& smem_pipe_write_outer, + SharedStorage& storage, + LoadWarpBarrier& load_warp_barrier, bool do_barrier) + { + // Load pattern: + // K0 V0 K1 V1 + // Q0 DO0 Q1 DO1 Q2 DO2 ... + // K0 Q0 V0 K1 DO0 V1 ... 
+ int lane_predicate = cute::elect_one_sync(); + + int outer_tile_count = NumMmaWarpGroups; + int inner_tile_count = get_inner_tile_count(blk_coord, problem_size); + + auto outer_tile_iter = cute::make_coord_iterator(outer_tile_count); + auto inner_tile_iter = cute::make_coord_iterator(inner_tile_count); + + uint16_t mcast_mask_b = 0; + + LoadQ load_q{params.tma_load_q, pipeline_inner, storage.smem_q}; + auto load_state_q = load_q.init_state(block_rank_in_cluster, problem_size, TileShapeNM{}, blk_coord, inner_tile_count); + + LoadDO load_do{params.tma_load_do, pipeline_inner, storage.smem_do}; + auto load_state_do = load_do.init_state(block_rank_in_cluster, problem_size, TileShapeNM{}, blk_coord, inner_tile_count); + + LoadK load_k{params.tma_load_k, pipeline_outer, storage.smem_k}; + auto load_state_k = load_k.init_state(_0{}, problem_size, TileShapeNM{}, blk_coord, outer_tile_count); + + LoadV load_v{params.tma_load_v, pipeline_outer, storage.smem_v}; + auto load_state_v = load_v.init_state(_0{}, problem_size, TileShapeNM{}, blk_coord, outer_tile_count); + + LoadLSE load_lse{params.tma_load_lse, pipeline_inner, storage.smem_lse}; + auto load_state_lse = load_lse.init_state(_0{}, problem_size, TileShapeNM{}, blk_coord, outer_tile_count); + + LoadODO load_odo{params.tma_load_odo, pipeline_inner, storage.smem_sumOdO}; + auto load_state_odo = load_odo.init_state(_0{}, problem_size, TileShapeNM{}, blk_coord, outer_tile_count); + + outer_tile_count *= 2; // K & V + inner_tile_count *= 4; // Q & dO & LSE & sumOdO + + while (inner_tile_count > 0) { + if (Fusion{}.is_contributing(make_coord(*inner_tile_iter, get<1>(blk_coord)), TileShape{}, problem_size)) { + break; + } + inner_tile_count -= 4; + ++inner_tile_iter; + } + + if constexpr (kLoadOuter) { + load_k.template step(outer_tile_iter, load_state_k, smem_pipe_write_outer, lane_predicate, outer_tile_count); + } + + load_q.template step(inner_tile_iter, load_state_q, smem_pipe_write_inner, lane_predicate, inner_tile_count, mcast_mask_b); + load_lse.template step(inner_tile_iter, load_state_lse, smem_pipe_write_inner, lane_predicate, inner_tile_count, mcast_mask_b); + + if constexpr (! 
kLoadOuter) { + if (do_barrier) { + load_warp_barrier.arrive(); + load_warp_barrier.wait(/*phase=*/ 0); + do_barrier = false; + } + } + + if constexpr (kLoadOuter) { + load_v.template step(outer_tile_iter, load_state_v, smem_pipe_write_outer, lane_predicate, outer_tile_count); + load_k.template step(outer_tile_iter, load_state_k, smem_pipe_write_outer, lane_predicate, outer_tile_count); + } + + load_do.template step(inner_tile_iter, load_state_do, smem_pipe_write_inner, lane_predicate, inner_tile_count, mcast_mask_b); + load_odo.template step(inner_tile_iter, load_state_odo, smem_pipe_write_inner, lane_predicate, inner_tile_count, mcast_mask_b); + + if constexpr (kLoadOuter) { + load_v.template step(outer_tile_iter, load_state_v, smem_pipe_write_outer, lane_predicate, outer_tile_count); + } + + if constexpr (kLoadOuter) { + while (outer_tile_count > 0) { + load_k.template step(outer_tile_iter, load_state_k, smem_pipe_write_outer, lane_predicate, outer_tile_count); + load_v.template step(outer_tile_iter, load_state_v, smem_pipe_write_outer, lane_predicate, outer_tile_count); + } + } + + CUTLASS_PRAGMA_NO_UNROLL + while (inner_tile_count > 0) { + while (inner_tile_count > 0) { + if (Fusion{}.is_contributing(make_coord(*inner_tile_iter, get<1>(blk_coord)), TileShape{}, problem_size)) { + break; + } + inner_tile_count -= 4; + ++inner_tile_iter; + } + load_q.template step(inner_tile_iter, load_state_q, smem_pipe_write_inner, lane_predicate, inner_tile_count, mcast_mask_b); + load_lse.template step(inner_tile_iter, load_state_lse, smem_pipe_write_inner, lane_predicate, inner_tile_count, mcast_mask_b); + + load_do.template step(inner_tile_iter, load_state_do, smem_pipe_write_inner, lane_predicate, inner_tile_count, mcast_mask_b); + load_odo.template step(inner_tile_iter, load_state_odo, smem_pipe_write_inner, lane_predicate, inner_tile_count, mcast_mask_b); + } + } + + template + CUTLASS_DEVICE void + load_maybe_q( + BlkCoord const& blk_coord, Params const& params, ProblemShape const& problem_size, + MainloopPipelineQ& pipeline_outer, PipelineStateQ& smem_pipe_write_outer, + SharedStorage& storage, + LoadWarpBarrier& load_warp_barrier, bool do_barrier) + { + // Load pattern: + // K0 V0 K1 V1 + // Q0 DO0 Q1 DO1 Q2 DO2 ... + // K0 Q0 V0 K1 DO0 V1 ... 
+ int lane_predicate = cute::elect_one_sync(); + + int outer_tile_count = NumMmaWarpGroups; + + auto outer_tile_iter = cute::make_coord_iterator(outer_tile_count); + + LoadK load_k{params.tma_load_k, pipeline_outer, storage.smem_k}; + auto load_state_k = load_k.init_state(_0{}, problem_size, TileShapeNM{}, blk_coord, outer_tile_count); + + LoadV load_v{params.tma_load_v, pipeline_outer, storage.smem_v}; + auto load_state_v = load_v.init_state(_0{}, problem_size, TileShapeNM{}, blk_coord, outer_tile_count); + + outer_tile_count *= 2; // K & V + + load_k.template step(outer_tile_iter, load_state_k, smem_pipe_write_outer, lane_predicate, outer_tile_count); + + if (do_barrier) { + load_warp_barrier.arrive(); + load_warp_barrier.wait(/*phase=*/ 0); + do_barrier = false; + } + + load_v.template step(outer_tile_iter, load_state_v, smem_pipe_write_outer, lane_predicate, outer_tile_count); + + while (outer_tile_count > 0) { + load_k.template step(outer_tile_iter, load_state_k, smem_pipe_write_outer, lane_predicate, outer_tile_count); + load_v.template step(outer_tile_iter, load_state_v, smem_pipe_write_outer, lane_predicate, outer_tile_count); + } + } + + template + CUTLASS_DEVICE void + reduce( + BlkCoord const& blk_coord, Params const& params, ProblemShape const& problem_size, + MainloopPipelineReducer& pipeline_reducer, PipelineStateReducer& smem_pipe_read_reducer, + SharedStorage& storage) + { + int lane_predicate = cute::elect_one_sync(); + + Tensor mDQ_full = params.tma_red_dq.get_tma_tensor(select<2,4,0,1>(problem_size)); + Tensor gDQ_full = local_tile(mDQ_full, TileShapeMD{}, make_coord(_, _, _), Step<_1, _1, Underscore>{}); + Tensor gDQ = gDQ_full(_, _, _, _0{}, get<2,0>(blk_coord), get<2,1>(blk_coord)); + Tensor sDQ = make_tensor(make_smem_ptr(storage.smem_dq.data()), SmemLayoutDQ{}); + + auto block_tma = params.tma_red_dq.get_slice(_0{}); + + Tensor tDQsDQ = block_tma.partition_S(sDQ); + Tensor tDQgDQ = block_tma.partition_D(gDQ); + + int inner_tile_count = get_inner_tile_count(blk_coord, problem_size); + int g_index = 0; + + auto smem_pipe_release_reducer = smem_pipe_read_reducer; + bool first = true; + while (inner_tile_count > 0) { + while (inner_tile_count > 0) { + if (Fusion{}.is_contributing(make_coord(g_index, get<1>(blk_coord)), TileShape{}, problem_size)) { + break; + } + inner_tile_count -= 1; + ++g_index; + } + if (inner_tile_count == 0) break; + + pipeline_reducer.consumer_wait(smem_pipe_read_reducer); + if (lane_predicate == 1) { + tma_store_wait<1>(); + } + if (! 
first) { + pipeline_reducer.consumer_release(smem_pipe_release_reducer); + ++smem_pipe_release_reducer; + } else { + first = false; + } + if (lane_predicate == 1) { + copy(params.tma_red_dq, tDQsDQ(_,_,_,smem_pipe_read_reducer.index()), tDQgDQ(_,_,_,g_index)); + tma_store_arrive(); + } + ++smem_pipe_read_reducer; + --inner_tile_count; + ++g_index; + } + if (lane_predicate) { + tma_store_wait<0>(); + } + pipeline_reducer.consumer_release(smem_pipe_release_reducer); + ++smem_pipe_release_reducer; + } + + template + CUTLASS_DEVICE auto + compute( + BlkCoord const& blk_coord, BlkCoord const& wg_coord, + Params const& params, ProblemShape const& problem_size, + MainloopPipeline& pipeline_inner, PipelineState& smem_pipe_read_inner, + MainloopPipelineQ& pipeline_outer, PipelineStateQ& smem_pipe_read_outer, + MainloopPipelineReducer& pipeline_reducer, PipelineStateReducer& smem_pipe_write_reducer, + SharedStorage& storage, + MathWgOrderBarrier& math_wg_order_barrier) + { + TiledMmaND tiled_mma_nd; + + Tensor acc_DV = partition_fragment_C(tiled_mma_nd, take<0,2>(TileShapeND{})); + clear(acc_DV); + + Tensor acc_DK = partition_fragment_C(tiled_mma_nd, take<0,2>(TileShapeND{})); + clear(acc_DK); + + int thread_idx = int(threadIdx.x) % cutlass::NumThreadsPerWarpGroup; + + PipelineState smem_pipe_release_inner = smem_pipe_read_inner; + + pipeline_outer.consumer_wait(smem_pipe_read_outer); + PipelineStateQ smem_pipe_read_k = smem_pipe_read_outer; + ++smem_pipe_read_outer; + pipeline_outer.consumer_wait(smem_pipe_read_outer); + PipelineStateQ smem_pipe_read_v = smem_pipe_read_outer; + + int inner_tile_count = get_inner_tile_count(wg_coord, problem_size); + + TiledMmaNM tiled_mma_nm; + Tensor sK = make_tensor(make_smem_ptr(storage.smem_k.data()), SmemLayoutK{}); + Tensor sQ = make_tensor(make_smem_ptr(storage.smem_q.data()), SmemLayoutQ{}); + auto thr_mma_nm = tiled_mma_nm.get_thread_slice(thread_idx); + Tensor tSsK = thr_mma_nm.partition_A(sK); + Tensor tSsQ = thr_mma_nm.partition_B(sQ); + Tensor tSrK = thr_mma_nm.make_fragment_A(tSsK); + Tensor tSrQ = thr_mma_nm.make_fragment_B(tSsQ); + + Tensor sV = make_tensor(make_smem_ptr(storage.smem_v.data()), SmemLayoutV{}); + Tensor sDO = make_tensor(make_smem_ptr(storage.smem_do.data()), SmemLayoutDO{}); + + Tensor tDPsV = thr_mma_nm.partition_A(sV); + Tensor tDPsDO = thr_mma_nm.partition_B(sDO); + Tensor tDPrV = thr_mma_nm.make_fragment_A(tDPsV); + Tensor tDPrDO = thr_mma_nm.make_fragment_B(tDPsDO); + + auto thr_mma_nd = tiled_mma_nd.get_thread_slice(thread_idx); + + Tensor sDOp = make_tensor(make_smem_ptr(storage.smem_dop.data()), SmemLayoutDOp{}); + Tensor tDV_sDO = thr_mma_nd.partition_B(sDOp); + Tensor tDVrDO = thr_mma_nd.make_fragment_B(tDV_sDO); + + + Tensor sQp = make_tensor(make_smem_ptr(storage.smem_qp.data()), SmemLayoutQp{}); + Tensor tDK_sQ = thr_mma_nd.partition_B(sQp); + Tensor tDKrQ = thr_mma_nd.make_fragment_B(tDK_sQ); + + + int wg_idx = __shfl_sync(0xffffffff, get<1>(wg_coord) % NumMmaWarpGroups, 0); + + TiledMmaMD tiled_mma_md; + auto thr_mma_md = tiled_mma_md.get_thread_slice(thread_idx); + Tensor sDS = make_tensor(make_smem_ptr(storage.smem_ds.data()), SmemLayoutDS{}); + Tensor tDQsDS = thr_mma_md.partition_A(sDS); + Tensor tDQrDS_full = thr_mma_md.make_fragment_A(tDQsDS); + Tensor tDQrDS = tDQrDS_full(_,_,_,_); + Tensor sKp = make_tensor(make_smem_ptr(storage.smem_kp.data()), SmemLayoutKp{}); + Tensor tDQsK = thr_mma_md.partition_B(sKp); + Tensor tDQrK = thr_mma_md.make_fragment_B(tDQsK); + + Tensor sLSE = 
make_tensor(make_smem_ptr(storage.smem_lse.data()), make_shape(get<0>(TileShapeNM{}), get<1>(TileShapeNM{}), Int{}), make_stride(_0{}, _1{}, get<1>(TileShapeNM{}))); + Tensor tSsLSE = thr_mma_nm.partition_C(sLSE); + + Tensor sODO = make_tensor(make_smem_ptr(storage.smem_sumOdO.data()), make_shape(get<0>(TileShapeNM{}), get<1>(TileShapeNM{}), Int{}), make_stride(_0{}, _1{}, get<1>(TileShapeNM{}))); + Tensor tDPsODO = thr_mma_nm.partition_C(sODO); + + Tensor cS = make_identity_tensor(take<0,2>(TileShapeNM{})); + Tensor tScS = thr_mma_nm.partition_C(cS); + int n_block = get<1>(wg_coord); + tScS.data() = tScS.data() + E<0>{} * n_block * get<0>(TileShapeNM{}); + + + // Transpose + Tensor sDSp_full = sDS.compose(make_layout(make_shape(size<1>(sDS), size<0>(sDS), size<2>(sDS)), make_stride(size<0>(sDS), _1{}, size<1>(sDS) * size<0>(sDS)))); + Tensor sDSp = sDSp_full(_,_,_); + Tensor tDPsDS = thr_mma_nm.partition_C(sDSp); + + auto thr_mma_nd_ss = TiledMmaND_SS{}.get_thread_slice(thread_idx); + Tensor tDKsDSp = thr_mma_nd_ss.partition_A(sDSp); + + Tensor tDKrDSp = thr_mma_nd_ss.make_fragment_A(tDKsDSp); + + Tensor sDQ = make_tensor(make_smem_ptr(storage.smem_dq.data()), SmemLayoutDQ{}); + auto tDQsDQ_full = thr_mma_md.partition_C(sDQ); + + + auto smem_pipe_read_k_other = smem_pipe_read_k; + smem_pipe_read_k_other.advance(2); + + int k_index = 0; + + while (inner_tile_count > 0) { + while (inner_tile_count > 0) { + if (Fusion{}.is_contributing(make_coord(k_index, get<1>(blk_coord)), TileShape{}, problem_size)) { + break; + } + inner_tile_count -= 1; + tScS.data() = tScS.data() + E<1>{} * get<1>(TileShapeNM{}); + k_index += 1; + } + if (inner_tile_count == 0) break; + + pipeline_inner.consumer_wait(smem_pipe_read_inner); + PipelineState smem_pipe_read_q = smem_pipe_read_inner; + ++smem_pipe_read_inner; + PipelineState smem_pipe_read_do = smem_pipe_read_inner; + ++smem_pipe_read_inner; + + // GEMM KQ -> S + Tensor acc_S = partition_fragment_C(tiled_mma_nm, take<0,2>(TileShapeNM{})); + + warpgroup_fence_operand(acc_S); + warpgroup_arrive(); + gemm_zero_acc(tiled_mma_nm, tSrK(_,_,_,smem_pipe_read_k.index()), tSrQ(_,_,_,smem_pipe_read_q.index()), acc_S); + warpgroup_commit_batch(); + + pipeline_inner.consumer_wait(smem_pipe_read_do); + + // GEMM VdO -> dP + Tensor acc_DP = partition_fragment_C(tiled_mma_nm, take<0,2>(TileShapeNM{})); + + warpgroup_fence_operand(acc_DP); + warpgroup_arrive(); + gemm_zero_acc(tiled_mma_nm, tDPrV(_,_,_,smem_pipe_read_v.index()), tDPrDO(_,_,_,smem_pipe_read_do.index()), acc_DP); + warpgroup_commit_batch(); + + Tensor reg_LSE = make_fragment_like(acc_S); + for (int i = 0; i < size(reg_LSE); i++) { + reg_LSE(i) = ((ElementAccumulator)std::log2(std::exp(1.0))) * tSsLSE(_,_,_,smem_pipe_read_q.index())(i); + } + + Tensor reg_ODO = make_fragment_like(acc_S); + if constexpr (decltype(get<0>(TileShape{}) != _128{})::value) { + for (int i = 0; i < size(reg_ODO); i++) { + reg_ODO(i) = tDPsODO(_,_,_,smem_pipe_read_do.index())(i); + } + } + + warpgroup_wait<1>(); + warpgroup_fence_operand(acc_S); + + math_wg_order_barrier.wait(); + // Compute S -> P + Fusion{}.before_softmax(acc_S, tScS, problem_size); + auto acc_P = make_fragment_like(acc_S); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(acc_P); i++) { + acc_P(i) = ::exp2f(params.scale_softmax_log2 * acc_S(i) - reg_LSE(i)); + } + math_wg_order_barrier.arrive(); + + if constexpr (decltype(get<0>(TileShape{}) == _128{})::value) { + for (int i = 0; i < size(reg_ODO); i++) { + reg_ODO(i) = tDPsODO(_,_,_,smem_pipe_read_do.index())(i); 
+ } + } + + warpgroup_wait<0>(); + warpgroup_fence_operand(acc_DP); + + // Compute dP P -> dS + auto acc_DS = make_fragment_like(acc_DP); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(acc_DS); i++) { + // We could move the scale out and into the respective epilogues (or a final scaling step) + acc_DS(i) = acc_P(i) * params.scale_softmax * (acc_DP(i) - reg_ODO(i)); + } + + // GEMM PdO -> dV + auto op_P = make_acc_into_op(acc_P, typename TiledMmaND::LayoutA_TV{}); + warpgroup_fence_operand(acc_DV); + warpgroup_fence_operand(op_P); + warpgroup_arrive(); + cute::gemm(tiled_mma_nd, op_P, tDVrDO(_,_,_,smem_pipe_read_do.index()), acc_DV); + warpgroup_commit_batch(); + + // Store dS to smem dS' + if (wg_idx == 0) math_wg_order_barrier.wait(); + + auto recast_bits = [](auto sz, auto t) { + return recast>(t); + }; + auto tDPsDS_v = recast_bits(Int * 2>{}, tDPsDS); + auto acc_DS_v = recast_bits(Int * 2>{}, acc_DS); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(acc_DS_v); i++) { + tDPsDS_v(_,_,_,wg_idx)(i) = acc_DS_v(i); + } + + cutlass::arch::fence_view_async_shared(); + if (wg_idx == 0) math_wg_order_barrier.arrive(); + + // GEMM dS Q -> dK + if (wg_idx == 1) { + + math_wg_order_barrier.wait(); + + // GEMM dS' K -> dQ + Tensor acc_DQ = partition_fragment_C(tiled_mma_md, take<0,2>(TileShapeMD{})); + + warpgroup_fence_operand(acc_DQ); + warpgroup_arrive(); + gemm_zero_acc(tiled_mma_md, tDQrDS(_,_,_,0), tDQrK(_,_,_,smem_pipe_read_k_other.index()), acc_DQ); + cute::gemm(tiled_mma_md, tDQrDS(_,_,_,1), tDQrK(_,_,_,smem_pipe_read_k.index()), acc_DQ); + warpgroup_commit_batch(); + + warpgroup_fence_operand(acc_DK); + warpgroup_arrive(); + cute::gemm(TiledMmaND_SS{}, tDKrDSp(_,_,_,wg_idx), tDKrQ(_,_,_,smem_pipe_read_q.index()), acc_DK); + warpgroup_commit_batch(); + + warpgroup_wait<1>(); + warpgroup_fence_operand(acc_DK); + + warpgroup_wait<1>(); + warpgroup_fence_operand(acc_DQ); + + math_wg_order_barrier.arrive(); + + pipeline_reducer.producer_acquire(smem_pipe_write_reducer); + auto tDQsDQ = tDQsDQ_full(_,_,_,smem_pipe_write_reducer.index()); + + // Store dQ to smem dQ' + // Invoke TMA reduce on dQ' + using Vec = uint_bit_t * 2>; + auto tDQsDQ_v = recast(tDQsDQ); + auto acc_DQ_v = recast(acc_DQ); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(acc_DQ_v); i++) { + tDQsDQ_v(i) = acc_DQ_v(i); + } + + cutlass::arch::fence_view_async_shared(); + + pipeline_reducer.producer_commit(smem_pipe_write_reducer); + ++smem_pipe_write_reducer; + } else { + + warpgroup_fence_operand(acc_DK); + warpgroup_arrive(); + cute::gemm(TiledMmaND_SS{}, tDKrDSp(_,_,_,wg_idx), tDKrQ(_,_,_,smem_pipe_read_q.index()), acc_DK); + warpgroup_commit_batch(); + + warpgroup_wait<1>(); + warpgroup_fence_operand(acc_DK); + + pipeline_reducer.producer_acquire(smem_pipe_write_reducer); + pipeline_reducer.producer_commit(smem_pipe_write_reducer); + ++smem_pipe_write_reducer; + } + + --inner_tile_count; + + pipeline_inner.consumer_release(smem_pipe_release_inner); + ++smem_pipe_release_inner; + pipeline_inner.consumer_release(smem_pipe_release_inner); + ++smem_pipe_release_inner; + + tScS.data() = tScS.data() + E<1>{} * get<1>(TileShapeNM{}); + k_index += 1; + } + + pipeline_outer.consumer_release(smem_pipe_read_k); + pipeline_outer.consumer_release(smem_pipe_read_outer); + pipeline_reducer.producer_tail(smem_pipe_write_reducer); + ++smem_pipe_read_outer; + + warpgroup_wait<0>(); + warpgroup_fence_operand(acc_DK); + warpgroup_fence_operand(acc_DV); + + return make_tuple(acc_DK, acc_DV); + } +}; + +} // namespace 
cutlass::fmha::collective diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/collective/fmha_collective_load.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/collective/fmha_collective_load.hpp new file mode 100644 index 0000000000000000000000000000000000000000..55029faffbf5585ec8b41cd07781dc84ea6b7d6d --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/collective/fmha_collective_load.hpp @@ -0,0 +1,140 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" + +namespace cutlass::fmha::collective { + +enum class LoadKind { + kQ, kK, kV, + kBwdN, kBwdM, kBwdScalar +}; + +template< + LoadKind kKind, + class Pipeline, + class Element, + class SmemLayout, + class TMA +> +struct CollectiveLoadTma { + + using Params = TMA; + using SharedStorage = cute::array_aligned>; + using PipelineState = typename cutlass::PipelineState; + + Params const& params; + Pipeline& pipeline; + SharedStorage& storage; + + CUTLASS_DEVICE + CollectiveLoadTma(Params const& params, Pipeline& pipeline, SharedStorage& storage) + : params(params), pipeline(pipeline), storage(storage) {} + + template + CUTLASS_DEVICE auto init_g(ProblemSize const& problem_size, TileShape const& tile_shape, + BlockCoord const& blk_coord, int loop_count + ) { + using X = Underscore; + if constexpr (kKind == LoadKind::kK) { + Tensor mK_full = params.get_tma_tensor(make_shape(get<3>(problem_size), get<4>(problem_size), select<0,1>(problem_size))); + Tensor gK_full = local_tile(mK_full, tile_shape, make_coord(_, _, _), Step{}); + Tensor gK = gK_full(_, _, _, _0{}, get<2>(blk_coord)); + return gK; + } else if constexpr (kKind == LoadKind::kQ) { + Tensor mQ_full = params.get_tma_tensor(make_shape(get<2>(problem_size), get<4>(problem_size), select<0,1>(problem_size))); + Tensor gQ_full = local_tile(mQ_full, tile_shape, make_coord(_, _, _), Step<_1, X, _1>{}); + Tensor gQ = gQ_full(_, _, _, _0{}, get<2>(blk_coord)); + return make_tensor(gQ.data() + loop_count * get<0>(blk_coord) * stride<2>(gQ), gQ.layout()); + } else if constexpr (kKind == LoadKind::kV) { + Tensor mV_full = params.get_tma_tensor(make_shape(get<4>(problem_size), get<3>(problem_size), select<0,1>(problem_size))); + Tensor gV_full = local_tile(mV_full, tile_shape, make_coord(_, _, _), Step{}); + Tensor gV = gV_full(_, _, _0{}, _, get<2>(blk_coord)); + return gV; + } else if constexpr (kKind == LoadKind::kBwdN) { + Tensor m_full = params.get_tma_tensor(make_shape(get<3>(problem_size), get<4>(problem_size), select<0,1>(problem_size))); + Tensor g_full = local_tile(m_full, tile_shape, make_coord(_, _, _), Step<_1, X, _1>{}); + Tensor g = g_full(_, _, _, _0{}, get<2>(blk_coord)); + return make_tensor(g.data() + loop_count * get<1>(blk_coord) * stride<2>(g), g.layout()); + } else if constexpr (kKind == LoadKind::kBwdM) { + Tensor m_full = params.get_tma_tensor(make_shape(get<2>(problem_size), get<4>(problem_size), select<0,1>(problem_size))); + Tensor g_full = local_tile(m_full, tile_shape, make_coord(_, _, _), Step{}); + Tensor g = g_full(_, _, _, _0{}, get<2>(blk_coord)); + return g; + } else if constexpr (kKind == LoadKind::kBwdScalar) { + Tensor m_full = params.get_tma_tensor(select<2,0,1>(problem_size)); + Tensor g_full = local_tile(m_full, tile_shape, make_coord(_, _, _), Step{}); + Tensor g = g_full(_, _, get<2,0>(blk_coord), get<2,1>(blk_coord)); + return g; + } + } + + template + CUTLASS_DEVICE auto init_state(ClusterRank const& block_rank_in_cluster, + ProblemSize const& problem_size, TileShape const& tile_shape, + BlockCoord const& block_coord, int loop_count + ) { + Tensor g = init_g(problem_size, tile_shape, block_coord, loop_count); + Tensor s = make_tensor(make_smem_ptr(storage.data()), SmemLayout{}); + + auto block_tma = params.get_slice(block_rank_in_cluster); + Tensor ts = block_tma.partition_D(s); + Tensor tg = block_tma.partition_S(g); + + return 
make_tuple(tg, ts); + } + + template + CUTLASS_DEVICE void step(TileIterator& tile_iter, State const& state, + PipelineState& smem_pipe_write, + int lane_predicate, int& tile_count, uint16_t mcast_mask = 0 + ) { + if ((lane_predicate == 1) && (tile_count > 0)) { + if constexpr (kAcquireBarrier) pipeline.producer_acquire(smem_pipe_write); + using BarrierType = typename Pipeline::ProducerBarrierType; + BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write); + + if constexpr (kKind == LoadKind::kBwdScalar) { + copy(params.with(*tma_barrier, mcast_mask), get<0>(state)(_,_,*tile_iter), get<1>(state)(_,_,smem_pipe_write.index())); + } else { + copy(params.with(*tma_barrier, mcast_mask), get<0>(state)(_,_,_,*tile_iter), get<1>(state)(_,_,_,smem_pipe_write.index())); + } + if constexpr (kAdvancePipe) ++smem_pipe_write; + if constexpr (kAdvanceIterator) ++tile_iter; + } + --tile_count; + } +}; + +} // namespace cutlass::fmha::collective diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/collective/fmha_collective_softmax.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/collective/fmha_collective_softmax.hpp new file mode 100644 index 0000000000000000000000000000000000000000..9c958da082bc69faf04bb1db00aa48fc93353a2c --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/collective/fmha_collective_softmax.hpp @@ -0,0 +1,305 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" + +#include "../collective/fmha_common.hpp" + +namespace cutlass::fmha::collective { + +template< + class ElementAccumulator, + class Fusion, + class Params +> +struct CollectiveSoftmax { + Params const& params; + CUTLASS_DEVICE CollectiveSoftmax(Params const& params) : params(params) {} + + using SumType = float; + using MaxType = ElementAccumulator; + + template + CUTLASS_DEVICE auto init(AccPV const& acc_pv, TiledMmaPV const& tiled_mma_pv) { + Tensor s_max = make_fragment_like(size<0>(layout_acc_mn(tiled_mma_pv, acc_pv.layout()))); + Tensor a_sum = make_fragment_like(s_max); + return make_tuple(s_max, a_sum); + } + + CUTLASS_DEVICE float overload_exp2(float f) { + return ::exp2f(f); + } + + CUTLASS_DEVICE cutlass::half_t overload_exp2(cutlass::half_t f) { + auto a = f.raw(); + decltype(a) d; + asm("ex2.approx.f16 %0, %1;" : "=h"(d) : "h"(a)); + return cutlass::half_t::bitcast(d); + } + + + CUTLASS_DEVICE float overload_max(float a, float b) { + return ::max(a, b); + } + + CUTLASS_DEVICE cutlass::half_t overload_max(cutlass::half_t a, cutlass::half_t b) { + return cutlass::half_t{__hmax_nan(a.to_half(), b.to_half())}; + } + + CUTLASS_DEVICE half overload_to_native(cutlass::half_t f) { + return f.to_half(); + } + + CUTLASS_DEVICE float overload_to_native(float f) { + return f; + } + + template + CUTLASS_DEVICE auto step(AccQK& acc_qk, TiledMmaQK const& tiled_mma_qk, CountQK const& count_qk, State& state, ProblemShape const& problem_shape) { + Fusion{}.before_softmax(acc_qk, count_qk, problem_shape); + Tensor acc_qk_mn = make_tensor(acc_qk.data(), layout_acc_mn(tiled_mma_qk, acc_qk.layout())); + auto reduction_target_qk = reduction_target_n(tiled_mma_qk); + constexpr int red_rank = decltype(rank(reduction_target_qk))::value; + + auto& s_max = get<0>(state); + auto& a_sum = get<1>(state); + + // Linear reduction is faster for the first iteration + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<0>(acc_qk_mn); i++) { + s_max(i) = acc_qk_mn(i, 0); + } + CUTLASS_PRAGMA_UNROLL + for (int j = 1; j < size<1>(acc_qk_mn); j++) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<0>(acc_qk_mn); i++) { + s_max(i) = overload_max(s_max(i), acc_qk_mn(i, j)); + } + } + + for_each(make_seq{}, [&](auto r) { + CUTLASS_PRAGMA_UNROLL + for (int j = 1; j < shape(reduction_target_qk); j *= 2) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<0>(acc_qk_mn); i++) { + s_max(i) = overload_max(s_max(i), MaxType{__shfl_xor_sync(uint32_t(-1), overload_to_native(s_max(i)), stride(reduction_target_qk) * j)}); + } + } + }); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<0>(acc_qk_mn); i++) { + MaxType local_max = s_max(i) == static_cast(-INFINITY) ? 
static_cast(0) : s_max(i); + MaxType scale = static_cast(params.scale_softmax_log2); + MaxType scale_max = scale * local_max; + + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size<1>(acc_qk_mn); j++) { + acc_qk_mn(i, j) = overload_exp2(scale * acc_qk_mn(i, j) - scale_max); + } + } + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<0>(acc_qk_mn); i++) { + a_sum(i) = SumType{reduce(acc_qk_mn(i, _), cute::plus{})}; + } + } + + template + CUTLASS_DEVICE auto step_interleave_begin(AccQK& acc_qk, TiledMmaQK const& tiled_mma_qk, CountQK const& count_qk, State& state, AccPV& acc_pv, TiledMmaPV const& tiled_mma_pv, ProblemShape const& problem_shape) { + + if constexpr (kUseFusion) { + Fusion{}.before_softmax(acc_qk, count_qk, problem_shape); + } + + Tensor acc_qk_mn = make_tensor(acc_qk.data(), layout_acc_mn(tiled_mma_qk, acc_qk.layout())); + Tensor acc_pv_mn = make_tensor(acc_pv.data(), layout_acc_mn(tiled_mma_pv, acc_pv.layout())); + + static_assert(size<0>(acc_qk_mn) == size<0>(acc_pv_mn)); + auto reduction_target_qk = reduction_target_n(tiled_mma_qk); + constexpr int red_rank = decltype(rank(reduction_target_qk))::value; + + auto& s_max = get<0>(state); + auto& a_sum = get<1>(state); + + Tensor s_max_prev = make_fragment_like(s_max); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<0>(acc_qk_mn); i++) { + s_max_prev(i) = s_max(i); + } + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<0>(acc_qk_mn); i++) { + // Linear reduction is faster here, as well + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size<1>(acc_qk_mn); j++) { + s_max(i) = overload_max(s_max(i), acc_qk_mn(i, j)); + } + } + // reduce max + for_each(make_seq{}, [&](auto r) { + CUTLASS_PRAGMA_UNROLL + for (int j = 1; j < shape(reduction_target_qk); j *= 2) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<0>(acc_qk_mn); i++) { + s_max(i) = overload_max(s_max(i), __shfl_xor_sync(uint32_t(-1), s_max(i), stride(reduction_target_qk) * j)); + } + } + }); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<0>(acc_pv_mn); i++) { + float s_max_cur = s_max(i) == -INFINITY ? 0.0f : s_max(i); + float scale = ::exp2f((s_max_prev(i) - s_max_cur) * params.scale_softmax_log2); + a_sum(i) *= scale; + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size<1>(acc_pv_mn); j++) { + acc_pv_mn(i, j) *= scale; + } + } + } + + template + CUTLASS_DEVICE auto step_interleave_step(AccQK_MN& acc_qk_mn, State& state) { + + auto& s_max = get<0>(state); + auto& a_sum = get<1>(state); + + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size<0>(acc_qk_mn); j++) { + float local_max = s_max(j) == -INFINITY ? 
0.f : s_max(j); + float scale_max = params.scale_softmax_log2 * local_max; + + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < size<1>(acc_qk_mn); k++) { + acc_qk_mn(j, k) = ::exp2f(params.scale_softmax_log2 * acc_qk_mn(j, k) - scale_max); + a_sum(j) += acc_qk_mn(j, k); + } + } + } + + template + CUTLASS_DEVICE auto step(AccQK& acc_qk, TiledMmaQK const& tiled_mma_qk, CountQK const& count_qk, State& state, AccPV& acc_pv, TiledMmaPV const& tiled_mma_pv, ProblemShape const& problem_shape) { + + if constexpr (kUseFusion) { + Fusion{}.before_softmax(acc_qk, count_qk, problem_shape); + } + + Tensor acc_qk_mn = make_tensor(acc_qk.data(), layout_acc_mn(tiled_mma_qk, acc_qk.layout())); + Tensor acc_pv_mn = make_tensor(acc_pv.data(), layout_acc_mn(tiled_mma_pv, acc_pv.layout())); + + static_assert(size<0>(acc_qk_mn) == size<0>(acc_pv_mn)); + auto reduction_target_qk = reduction_target_n(tiled_mma_qk); + constexpr int red_rank = decltype(rank(reduction_target_qk))::value; + + auto& s_max = get<0>(state); + auto& a_sum = get<1>(state); + + Tensor s_max_prev = make_fragment_like(s_max); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<0>(acc_qk_mn); i++) { + s_max_prev(i) = s_max(i); + + // Linear reduction is faster here, as well + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size<1>(acc_qk_mn); j++) { + s_max(i) = overload_max(s_max(i), acc_qk_mn(i, j)); + } + // reduce max + for_each(make_seq{}, [&](auto r) { + CUTLASS_PRAGMA_UNROLL + for (int j = 1; j < shape(reduction_target_qk); j *= 2) { + s_max(i) = overload_max(s_max(i), MaxType{__shfl_xor_sync(uint32_t(-1), overload_to_native(s_max(i)), stride(reduction_target_qk) * j)}); + } + }); + + MaxType local_max = s_max(i) == static_cast(-INFINITY) ? static_cast(0) : s_max(i); + MaxType scale = static_cast(params.scale_softmax_log2); + MaxType scale_max = scale * local_max; + + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size<1>(acc_qk_mn); j++) { + acc_qk_mn(i, j) = overload_exp2(scale * acc_qk_mn(i, j) - scale_max); + } + + MaxType s_max_cur = s_max(i) == static_cast(-INFINITY) ? static_cast(0) : s_max(i); + SumType scale_pv = overload_exp2((s_max_prev(i) - s_max_cur) * scale); + a_sum(i) *= scale_pv; + + using ElementPV = typename AccPV::value_type; + ElementPV scale_pv_ele = static_cast(scale_pv); + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size<1>(acc_pv_mn); j++) { + acc_pv_mn(i, j) *= scale_pv_ele; + } + a_sum(i) += SumType{reduce(acc_qk_mn(i, _), cute::plus{})}; + } + } + + + template + CUTLASS_DEVICE auto tail(State& state, AccPV& acc_pv, TiledMmaPV const& tiled_mma_pv) { + auto& s_max = get<0>(state); + auto& a_sum = get<1>(state); + + Tensor acc_pv_mn = make_tensor(acc_pv.data(), layout_acc_mn(tiled_mma_pv, acc_pv.layout())); + + auto reduction_target = reduction_target_n(tiled_mma_pv); + constexpr int red_rank = decltype(rank(reduction_target))::value; + for_each(make_seq{}, [&](auto r) { + CUTLASS_PRAGMA_UNROLL + for (int j = 1; j < shape(reduction_target); j *= 2) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<0>(acc_pv_mn); i++) { + a_sum(i) = a_sum(i) + __shfl_xor_sync(uint32_t(-1), a_sum(i), stride(reduction_target) * j); + } + } + }); + + Tensor acc_mn = make_tensor(acc_pv.data(), layout_acc_mn(tiled_mma_pv, acc_pv.layout())); + + Tensor lse = make_fragment_like(a_sum); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<0>(acc_mn); i++) { + float sum = a_sum(i); + float inv_sum = (sum == 0.f || sum != sum) ? 1.f : __frcp_rn(sum); + lse(i) = (sum == 0.f || sum != sum) ? 
INFINITY : s_max(i) * params.scale_softmax + __logf(sum); + float scale = params.rp_dropout * inv_sum; + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size<1>(acc_mn); j++) { + acc_mn(i, j) *= scale; + } + } + + return lse; + } +}; + +} // namespace cutlass::fmha::collective diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/collective/fmha_collective_tma.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/collective/fmha_collective_tma.hpp new file mode 100644 index 0000000000000000000000000000000000000000..c3dfda63c7bdcbbcb7a711b328942101047b14cc --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/collective/fmha_collective_tma.hpp @@ -0,0 +1,526 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "../collective/fmha_common.hpp" +#include "../collective/fmha_collective_load.hpp" +#include "../collective/fmha_collective_softmax.hpp" +#include "../kernel/fmha_options.hpp" + +namespace cutlass::fmha::collective { + +using namespace cute; +using cutlass::fmha::kernel::Tag; +using cutlass::fmha::kernel::find_option_t; + +template< + typename Element_, + typename ElementAccumulator_, + typename TileShape_, // BlockQO, BlockKV, BlockHead + class Fusion, + class... 
Options +> +struct FmhaMainloopTma { + + using Element = Element_; + using ElementAccumulator = ElementAccumulator_; + using TileShape = TileShape_; + + // Options + using kClusterM = find_option_t, Options...>; + static constexpr int StageCount = find_option_t, Options...>::value; + static constexpr int StageCountQ = find_option_t, Options...>::value; + + using StagesQ = cutlass::gemm::collective::StageCount; + using Stages = cutlass::gemm::collective::StageCount; + using ClusterShape = Shape; + + // 16B alignment lets us use TMA + static constexpr int Alignment = 16 / sizeof(Element); + + using TileShapeQK = TileShape; + using TileShapePV = decltype(select<0,2,1>(TileShapeQK{})); + + using LayoutQKV = cute::tuple>; + using LayoutQ = LayoutQKV; + using LayoutK = LayoutQKV; + using LayoutV = LayoutQKV; + + using CollectiveMmaQK = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Element, LayoutQ, Alignment, + Element, LayoutK, Alignment, + ElementAccumulator, + TileShapeQK, ClusterShape, Stages, + cutlass::gemm::KernelTmaWarpSpecialized>::CollectiveOp; + + using CollectiveMmaPV = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + // the stride for A does not matter since we do not load from smem at all + Element, LayoutK, Alignment, + Element, decltype(select<1,0,2>(LayoutV{})), Alignment, + ElementAccumulator, + TileShapePV, ClusterShape, Stages, + cutlass::gemm::KernelTmaWarpSpecialized>::CollectiveOp; + + using TiledMmaQK = typename CollectiveMmaQK::TiledMma; + using TiledMmaPV = decltype(convert_to_gmma_rs(typename CollectiveMmaPV::TiledMma{})); + + using SmemLayoutQ = decltype(unstageSmemLayout(typename CollectiveMmaQK::SmemLayoutA{}, Int{})); + using SmemLayoutK = typename CollectiveMmaQK::SmemLayoutB; + using SmemLayoutV = typename CollectiveMmaPV::SmemLayoutB; + + using MainloopPipeline = cutlass::PipelineTmaAsync; + using MainloopPipelineQ = cutlass::PipelineTmaAsync; + + using PipelineState = typename cutlass::PipelineState; + using PipelineStateQ = typename cutlass::PipelineState; + + using TileShapeOut = TileShapePV; + using TiledMmaOut = TiledMmaPV; + using ElementOut = ElementAccumulator; + + struct SharedStorage { + cute::array_aligned> smem_q; + union { + cute::array_aligned> smem_k; + cute::array_aligned> smem_v; + }; + }; + + struct Arguments { + const Element* ptr_Q; + LayoutQ dQ; + const Element* ptr_K; + LayoutK dK; + const Element* ptr_V; + LayoutV dV; + }; + + using TMA_Q = typename CollectiveMmaQK::Params::TMA_A; + using TMA_K = typename CollectiveMmaQK::Params::TMA_B; + using TMA_V = typename CollectiveMmaPV::Params::TMA_B; + + struct Params { + TMA_Q tma_load_q; + TMA_K tma_load_k; + TMA_V tma_load_v; + + float scale_softmax; + float scale_softmax_log2; + float rp_dropout; + }; + + using LoadQ = cutlass::fmha::collective::CollectiveLoadTma< + cutlass::fmha::collective::LoadKind::kQ, + MainloopPipelineQ, + Element, + SmemLayoutQ, + TMA_Q + >; + + using LoadK = cutlass::fmha::collective::CollectiveLoadTma< + cutlass::fmha::collective::LoadKind::kK, + MainloopPipeline, + Element, + SmemLayoutK, + TMA_K + >; + + using LoadV = cutlass::fmha::collective::CollectiveLoadTma< + cutlass::fmha::collective::LoadKind::kV, + MainloopPipeline, + Element, + SmemLayoutV, + TMA_V + >; + + static_assert(size(typename CollectiveMmaQK::TiledMma{}) == size(typename CollectiveMmaPV::TiledMma{})); + + static const int MaxThreadsPerBlock = size(typename 
CollectiveMmaQK::TiledMma{}); + + template + static bool can_implement(ProblemShape const& problem_size, Arguments const& args) { + return true + && (get<4>(problem_size) <= get<2>(TileShape{})) + && ((get<4>(problem_size) % Alignment) == 0) + && ((get<2>(problem_size) % Alignment) == 0) + ; + } + + template + static Params to_underlying_arguments(ProblemShape const& problem_size, Arguments const& args, void* workspace) { + + auto problem_shape_qk = make_shape(get<2>(problem_size), get<3>(problem_size), get<4>(problem_size), make_shape(get<0>(problem_size), get<1>(problem_size))); + auto params_qk = CollectiveMmaQK::to_underlying_arguments(problem_shape_qk, + typename CollectiveMmaQK::Arguments { + args.ptr_Q, args.dQ, + args.ptr_K, args.dK, + }, /*workspace=*/ nullptr); + + auto problem_shape_pv = select<0,2,1,3>(problem_shape_qk); + auto params_pv = CollectiveMmaPV::to_underlying_arguments(problem_shape_pv, + typename CollectiveMmaPV::Arguments { + args.ptr_K, args.dK, // never used, dummy + args.ptr_V, select<1,0,2>(args.dV), + }, /*workspace=*/ nullptr); + + return Params{ + params_qk.tma_load_a, + params_qk.tma_load_b, + params_pv.tma_load_b, + 1.0f / (float) std::sqrt(get<4>(problem_size)), + (float) (std::log2(std::exp(1.0)) / std::sqrt(get<4>(problem_size))), + 1.0f + }; + } + + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& params) { + cute::prefetch_tma_descriptor(params.tma_load_q.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_k.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_v.get_tma_descriptor()); + } + + template + CUTLASS_DEVICE auto + compute( + int block_rank_in_cluster, + BlkCoord const& blk_coord, Params const& params, ProblemShape const& problem_size, + MainloopPipeline& pipeline, PipelineState& smem_pipe_read, PipelineState& smem_pipe_write, + MainloopPipelineQ& pipeline_q, PipelineStateQ& smem_pipe_read_q, PipelineStateQ& smem_pipe_write_q, + SharedStorage& storage) + { + int warp_idx = cutlass::canonical_warp_idx_sync(); + int thread_idx = threadIdx.x; + + PipelineState smem_pipe_release = smem_pipe_read; + [[maybe_unused]] PipelineStateQ smem_pipe_release_q = smem_pipe_read_q; + + + int fusion_tile_count = Fusion{}.get_trip_count(blk_coord, TileShape{}, problem_size); + + LoadQ load_q{params.tma_load_q, pipeline_q, storage.smem_q}; + auto load_state_q = load_q.init_state(_0{}, problem_size, TileShapeQK{}, blk_coord, 1); + + LoadK load_k{params.tma_load_k, pipeline, storage.smem_k}; + auto load_state_k = load_k.init_state(block_rank_in_cluster, problem_size, TileShapeQK{}, blk_coord, fusion_tile_count); + + LoadV load_v{params.tma_load_v, pipeline, storage.smem_v}; + auto load_state_v = load_v.init_state(block_rank_in_cluster, problem_size, TileShapePV{}, blk_coord, fusion_tile_count); + + // Set predicate for the lowest lane_id in the warp + int lane_predicate = cute::elect_one_sync(); + + // Issue TmaLoads (Prologue fetches) + if (warp_idx == 0) { + auto q_tile_iter = cute::make_coord_iterator(1); + int q_tile_count = 1; + load_q.step(q_tile_iter, load_state_q, smem_pipe_write_q, lane_predicate, q_tile_count); + } + + // Loop over K elems + auto k_tile_iter = cute::make_coord_iterator(fusion_tile_count); + + int k_tile_count_tma = 2 * fusion_tile_count; + + uint16_t mcast_mask_b = 0; + + if (warp_idx == 0 && lane_predicate == 1) { + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int m = 0; m < size<0>(block_layout); ++m) { + mcast_mask_b |= 
(uint16_t(1) << block_layout(m,_0{},Int<0>{})); + } + } + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < StageCount; i++) { + if (i % 2 == 0) { + load_k.template step(k_tile_iter, load_state_k, smem_pipe_write, lane_predicate, k_tile_count_tma, mcast_mask_b); + } else { + load_v.template step(k_tile_iter, load_state_k, smem_pipe_write, lane_predicate, k_tile_count_tma, mcast_mask_b); + } + } + } + + TiledMmaQK tiled_mma_qk; + auto thr_mma_qk = tiled_mma_qk.get_thread_slice(thread_idx); + + // Mainloop setup QK + Tensor sQ = make_tensor(make_smem_ptr(storage.smem_q.data()), SmemLayoutQ{}); + Tensor sK = make_tensor(make_smem_ptr(storage.smem_k.data()), SmemLayoutK{}); + + Tensor tSsQ = thr_mma_qk.partition_A(sQ); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tSsK = thr_mma_qk.partition_B(sK); // (MMA,MMA_N,MMA_K,PIPE) + Tensor tSrQ = thr_mma_qk.make_fragment_A(tSsQ); // (MMA,MMA_N,MMA_K,PIPE) + Tensor tSrK = thr_mma_qk.make_fragment_B(tSsK); // (MMA,MMA_M,MMA_N,PIPE) + + // Prepare: MMA PV + TiledMmaPV tiled_mma_pv; + auto thr_mma_pv = tiled_mma_pv.get_thread_slice(thread_idx); + + // Mainloop setup PV + Tensor sV = make_tensor(make_smem_ptr(storage.smem_v.data()), SmemLayoutV{}); + + Tensor tOsV = thr_mma_pv.partition_B(sV); // (MMA,MMA_N,MMA_K,PIPE) + Tensor tOrV = thr_mma_pv.make_fragment_B(tOsV); // (MMA,MMA_M,MMA_N,PIPE) + + int k_tile_count = Fusion{}.get_unmasked_trip_count(blk_coord, TileShape{}, problem_size); + + pipeline_q.consumer_wait(smem_pipe_read_q); + + // mapping into QK accumulator + Tensor cP = make_identity_tensor(take<0,2>(TileShapeQK{})); + Tensor tPcP = thr_mma_qk.partition_C(cP); + int m_block = get<0>(blk_coord); + tPcP.data() = tPcP.data() + E<0>{} * m_block * get<0>(TileShapeQK{}); + + // Allocate PV acc + Tensor acc_pv = partition_fragment_C(tiled_mma_pv, take<0, 2>(TileShapePV{})); + + cutlass::fmha::collective::CollectiveSoftmax softmax{params}; + auto softmax_state = softmax.init(acc_pv, tiled_mma_pv); + + if (true) + { + --k_tile_count; + // Allocate QK acc + Tensor acc_qk = partition_fragment_C(tiled_mma_qk, take<0, 2>(TileShapeQK{})); + + pipeline.consumer_wait(smem_pipe_read); + + // MMA QK + warpgroup_fence_operand(acc_qk); + warpgroup_arrive(); + + gemm_zero_acc(tiled_mma_qk, tSrQ(_,_,_,_0{}), tSrK(_,_,_,smem_pipe_read.index()), acc_qk); + warpgroup_commit_batch(); + + ++smem_pipe_read; + + // Wait for the pipeline MMAs to drain + warpgroup_wait<0>(); + warpgroup_fence_operand(acc_qk); + + softmax.step(acc_qk, tiled_mma_qk, tPcP, softmax_state, problem_size); + + Tensor acc_qk_fixed = make_fragment_like(convert_c_layout_to_a_layout(acc_qk.layout(), shape<1>(typename decltype(tiled_mma_pv)::LayoutA_TV{}))); + + Tensor acc_qk_input = make_tensor(acc_qk_fixed.data(), acc_qk.layout()); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(acc_qk); i++) { + acc_qk_input(i) = static_cast(acc_qk(i)); + } + + pipeline.consumer_wait(smem_pipe_read); + + // MMA PV + warpgroup_fence_operand(acc_pv); + warpgroup_fence_operand(acc_qk_fixed); + warpgroup_arrive(); + + gemm_zero_acc(tiled_mma_pv, acc_qk_fixed, tOrV(_,_,_,smem_pipe_read.index()), acc_pv); + warpgroup_commit_batch(); + + // + // Advance the pipe + // + + // Advance consumer pipeline + ++smem_pipe_read; + + pipeline.consumer_release(smem_pipe_release); + ++smem_pipe_release; + + tPcP.data() = tPcP.data() + E<1>{} * get<1>(TileShapeQK{}); + } + + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) + { + // Allocate QK acc + Tensor acc_qk = partition_fragment_C(tiled_mma_qk, take<0, 2>(TileShapeQK{})); + 
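+      // One KV tile per iteration: wait on the K stage and issue the
+      // S = Q * K^T GEMM, rescale the running PV accumulator against the
+      // updated row maxima (step_interleave_begin) while warp 0 issues the
+      // next K TMA load, then feed the V stage into the PV GEMM slice by
+      // slice, exponentiating each slice just before it enters the MMA.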
+ pipeline.consumer_wait(smem_pipe_read); + + // MMA QK + warpgroup_fence_operand(acc_qk); + warpgroup_arrive(); + + gemm_zero_acc(tiled_mma_qk, tSrQ(_,_,_,_0{}), tSrK(_,_,_,smem_pipe_read.index()), acc_qk); + warpgroup_commit_batch(); + + ++smem_pipe_read; + + if (warp_idx == 0) { + load_k.template step(k_tile_iter, load_state_k, smem_pipe_write, lane_predicate, k_tile_count_tma, mcast_mask_b); + } + + // Wait for the pipeline MMAs to drain + warpgroup_wait<0>(); + warpgroup_fence_operand(acc_qk); + warpgroup_fence_operand(acc_pv); + + softmax.template step_interleave_begin(acc_qk, tiled_mma_qk, tPcP, softmax_state, acc_pv, tiled_mma_pv, problem_size); + + pipeline.consumer_release(smem_pipe_release); + + ++smem_pipe_release; + + pipeline.consumer_wait(smem_pipe_read); + + // MMA PV + auto layout_qk_input = convert_c_layout_to_a_layout(acc_qk.layout(), shape<1>(typename decltype(tiled_mma_pv)::LayoutA_TV{})); + + Tensor acc_qk_input = make_tensor(acc_qk.data(), layout_qk_input); + + static_assert(decltype(size<1>(layout_qk_input) == _1{})::value); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<2>(tOrV); i++) { + Tensor acc_qk_element = make_fragment_like(layout_qk_input(_, _0{}, _0{})); + Tensor acc_qk_element_mk = tensor_op_mk_v(tiled_mma_pv, acc_qk_element); + Tensor acc_qk_input_mk = tensor_op_mk_v(tiled_mma_pv, acc_qk_input(_, _0{}, i)); + softmax.step_interleave_step(acc_qk_input_mk, softmax_state); + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size(acc_qk_element_mk); j++) { + acc_qk_element_mk(j) = static_cast(acc_qk_input_mk(j)); + } + warpgroup_arrive(); + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size<1>(tOrV); j++) { + cute::gemm(tiled_mma_pv, acc_qk_element, tOrV(_,j,i,smem_pipe_read.index()), acc_pv(_,_0{},j)); + } + } + warpgroup_commit_batch(); + + // Wait for the pipeline MMAs to drain + pipeline.consumer_release(smem_pipe_release); + ++smem_pipe_release; + + ++smem_pipe_read; + + if (warp_idx == 0) { + load_v.template step(k_tile_iter, load_state_v, smem_pipe_write, lane_predicate, k_tile_count_tma, mcast_mask_b); + } + + tPcP.data() = tPcP.data() + E<1>{} * get<1>(TileShapeQK{}); + } + + k_tile_count += Fusion{}.get_masked_trip_count(blk_coord, TileShape{}, problem_size); + + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) + { + // Allocate QK acc + Tensor acc_qk = partition_fragment_C(tiled_mma_qk, take<0, 2>(TileShapeQK{})); + + pipeline.consumer_wait(smem_pipe_read); + + // MMA QK + warpgroup_fence_operand(acc_qk); + warpgroup_arrive(); + + gemm_zero_acc(tiled_mma_qk, tSrQ(_,_,_,_0{}), tSrK(_,_,_,smem_pipe_read.index()), acc_qk); + warpgroup_commit_batch(); + + ++smem_pipe_read; + + if (warp_idx == 0) { + load_k.template step(k_tile_iter, load_state_k, smem_pipe_write, lane_predicate, k_tile_count_tma, mcast_mask_b); + } + + // Wait for the pipeline MMAs to drain + warpgroup_wait<0>(); + warpgroup_fence_operand(acc_qk); + warpgroup_fence_operand(acc_pv); + + softmax.step_interleave_begin(acc_qk, tiled_mma_qk, tPcP, softmax_state, acc_pv, tiled_mma_pv, problem_size); + + pipeline.consumer_release(smem_pipe_release); + + ++smem_pipe_release; + + pipeline.consumer_wait(smem_pipe_read); + + // MMA PV + auto layout_qk_input = convert_c_layout_to_a_layout(acc_qk.layout(), shape<1>(typename decltype(tiled_mma_pv)::LayoutA_TV{})); + + Tensor acc_qk_input = make_tensor(acc_qk.data(), layout_qk_input); + + static_assert(decltype(size<1>(layout_qk_input) == _1{})::value); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<2>(tOrV); i++) { + Tensor 
acc_qk_element = make_fragment_like(layout_qk_input(_, _0{}, _0{})); + Tensor acc_qk_element_mk = tensor_op_mk_v(tiled_mma_pv, acc_qk_element); + Tensor acc_qk_input_mk = tensor_op_mk_v(tiled_mma_pv, acc_qk_input(_, _0{}, i)); + softmax.step_interleave_step(acc_qk_input_mk, softmax_state); + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size(acc_qk_element_mk); j++) { + acc_qk_element_mk(j) = static_cast(acc_qk_input_mk(j)); + } + warpgroup_arrive(); + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size<1>(tOrV); j++) { + cute::gemm(tiled_mma_pv, acc_qk_element, tOrV(_,j,i,smem_pipe_read.index()), acc_pv(_,_0{},j)); + } + } + warpgroup_commit_batch(); + + // Wait for the pipeline MMAs to drain + pipeline.consumer_release(smem_pipe_release); + ++smem_pipe_release; + + ++smem_pipe_read; + + if (warp_idx == 0) { + load_v.template step(k_tile_iter, load_state_v, smem_pipe_write, lane_predicate, k_tile_count_tma, mcast_mask_b); + } + + tPcP.data() = tPcP.data() + E<1>{} * get<1>(TileShapeQK{}); + } + + // Wait for the pipeline MMAs to drain + warpgroup_wait<0>(); + warpgroup_fence_operand(acc_pv); + + Tensor lse = softmax.tail(softmax_state, acc_pv, tiled_mma_pv); + + return make_tuple(acc_pv, lse); + } +}; + +} // namespace cutlass::fmha::collective + diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/collective/fmha_collective_tma_warpspecialized.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/collective/fmha_collective_tma_warpspecialized.hpp new file mode 100644 index 0000000000000000000000000000000000000000..9fd45baf3806c7246a4c358273bc544ca941d022 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/collective/fmha_collective_tma_warpspecialized.hpp @@ -0,0 +1,560 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "../collective/fmha_common.hpp" +#include "../collective/fmha_collective_load.hpp" +#include "../collective/fmha_collective_softmax.hpp" +#include "../kernel/fmha_options.hpp" + +namespace cutlass::fmha::collective { + +using namespace cute; +using cutlass::fmha::kernel::Tag; +using cutlass::fmha::kernel::find_option_t; + +template< + class Element_, + class ElementAccumulatorQK_, + class ElementAccumulatorPV_, + class TileShape_, // SeqQ, SeqKV, Head + class LayoutQ_, class LayoutK_, class LayoutV_, // SeqX, Head, (Batches) + class Fusion, + class... Options +> +struct FmhaMainloopTmaWarpSpecialized { + + using Element = Element_; + using ElementAccumulatorQK = ElementAccumulatorQK_; + using ElementAccumulatorPV = ElementAccumulatorPV_; + using TileShape = TileShape_; + + using LayoutQ = LayoutQ_; + using LayoutK = LayoutK_; + using LayoutV = LayoutV_; + + // Options + static constexpr bool kIsPersistent = find_option_t::value; + static constexpr bool kIsMainloopLocked = find_option_t::value; + + static constexpr int NumLoadWarpGroups = 1; + static constexpr int NumMmaWarpGroups = find_option_t, Options...>::value; + static constexpr int StageCount = find_option_t, Options...>::value; + static constexpr int StageCountQ = find_option_t, Options...>::value; + + static const int kOuterLoads = 1; + using StagesQ = cutlass::gemm::collective::StageCount; + using Stages = cutlass::gemm::collective::StageCount; + using ClusterShape = Shape<_1, _1, _1>; + static_assert(StagesQ::value >= NumMmaWarpGroups); + static_assert(Stages::value >= 2); + + // 16B alignment lets us use TMA + static constexpr int Alignment = 16 / sizeof(Element); + + using TileShapeQK = Shape< + decltype(tuple_element_t<0, TileShape>{} / Int{}), + tuple_element_t<1, TileShape>, + tuple_element_t<2, TileShape>>; + + using TileShapePV = decltype(select<0,2,1>(TileShapeQK{})); + + using CollectiveMmaQK = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Element, LayoutQ, Alignment, + Element, LayoutK, Alignment, + ElementAccumulatorQK, + TileShapeQK, ClusterShape, Stages, + cutlass::gemm::KernelTmaWarpSpecialized>::CollectiveOp; + + using CollectiveMmaPV = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + // the stride for A does not matter since we do not load from smem at all + Element, LayoutK, Alignment, + Element, decltype(select<1,0,2>(LayoutV{})), Alignment, + ElementAccumulatorPV, + TileShapePV, ClusterShape, Stages, + cutlass::gemm::KernelTmaWarpSpecialized>::CollectiveOp; + + using TiledMmaQK = typename CollectiveMmaQK::TiledMma; + using TiledMmaPV = decltype(convert_to_gmma_rs(typename CollectiveMmaPV::TiledMma{})); + + using SmemLayoutQ = 
decltype(unstageSmemLayout(typename CollectiveMmaQK::SmemLayoutA{}, Int{})); + using SmemLayoutK = typename CollectiveMmaQK::SmemLayoutB; + using SmemLayoutV = typename CollectiveMmaPV::SmemLayoutB; + + using MainloopPipeline = cutlass::PipelineTmaAsync; + using MainloopPipelineQ = cutlass::PipelineTmaAsync; + + using PipelineState = typename cutlass::PipelineState; + using PipelineStateQ = typename cutlass::PipelineState; + + static constexpr int kInnerLoadBytes = size(SmemLayoutK{}(_,_,_0{})) * sizeof(Element); + static constexpr int kOuterLoadBytes = size(SmemLayoutQ{}(_,_,_0{})) * sizeof(Element); + + using TileShapeOut = TileShapePV; + using TiledMmaOut = TiledMmaPV; + using ElementOut = ElementAccumulatorPV; + + struct SharedStorage { + cute::array_aligned> smem_q; + union { + cute::array_aligned> smem_k; + cute::array_aligned> smem_v; + }; + }; + + struct Arguments { + const Element* ptr_Q; + LayoutQ dQ; + const Element* ptr_K; + LayoutK dK; + const Element* ptr_V; + LayoutV dV; + }; + + using TMA_Q = typename CollectiveMmaQK::Params::TMA_A; + using TMA_K = typename CollectiveMmaQK::Params::TMA_B; + using TMA_V = typename CollectiveMmaPV::Params::TMA_B; + + struct Params { + TMA_Q tma_load_q; + TMA_K tma_load_k; + TMA_V tma_load_v; + + float scale_softmax; + float scale_softmax_log2; + float rp_dropout; + }; + + using LoadQ = cutlass::fmha::collective::CollectiveLoadTma< + cutlass::fmha::collective::LoadKind::kQ, + MainloopPipelineQ, + Element, + SmemLayoutQ, + TMA_Q + >; + + using LoadK = cutlass::fmha::collective::CollectiveLoadTma< + cutlass::fmha::collective::LoadKind::kK, + MainloopPipeline, + Element, + SmemLayoutK, + TMA_K + >; + + using LoadV = cutlass::fmha::collective::CollectiveLoadTma< + cutlass::fmha::collective::LoadKind::kV, + MainloopPipeline, + Element, + SmemLayoutV, + TMA_V + >; + + static_assert(size(typename CollectiveMmaQK::TiledMma{}) == size(typename CollectiveMmaPV::TiledMma{})); + + template + static bool can_implement(ProblemShape const& problem_size, Arguments const& args) { + return true + && (get<4>(problem_size) <= get<2>(TileShape{})) + && ((get<4>(problem_size) % Alignment) == 0) + && ((get<2>(problem_size) % Alignment) == 0) + ; + } + + template + static Params to_underlying_arguments(ProblemShape const& problem_size, Arguments const& args, void* workspace) { + + auto problem_shape_qk = make_shape(get<2>(problem_size), get<3>(problem_size), get<4>(problem_size), make_shape(get<0>(problem_size), get<1>(problem_size))); + auto params_qk = CollectiveMmaQK::to_underlying_arguments(problem_shape_qk, + typename CollectiveMmaQK::Arguments { + args.ptr_Q, args.dQ, + args.ptr_K, args.dK, + }, /*workspace=*/ nullptr); + + auto problem_shape_pv = select<0,2,1,3>(problem_shape_qk); + auto params_pv = CollectiveMmaPV::to_underlying_arguments(problem_shape_pv, + typename CollectiveMmaPV::Arguments { + args.ptr_K, args.dK, // never used, dummy + args.ptr_V, select<1,0,2>(args.dV), + }, /*workspace=*/ nullptr); + + return Params{ + params_qk.tma_load_a, + params_qk.tma_load_b, + params_pv.tma_load_b, + 1.0f / (float) std::sqrt(get<4>(problem_size)), + (float) (std::log2(std::exp(1.0)) / std::sqrt(get<4>(problem_size))), + 1.0f + }; + } + + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& params) { + cute::prefetch_tma_descriptor(params.tma_load_q.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_k.get_tma_descriptor()); + cute::prefetch_tma_descriptor(params.tma_load_v.get_tma_descriptor()); + } + + template + CUTLASS_DEVICE 
void + load_kv_maybe_q( + int block_rank_in_cluster, + BlkCoord const& blk_coord, Params const& params, ProblemShape const& problem_size, + MainloopPipeline& pipeline, PipelineState& smem_pipe_write, + MainloopPipelineQ& pipeline_q, PipelineStateQ& smem_pipe_write_q, + SharedStorage& storage, + LoadWarpBarrier& load_warp_barrier, bool do_barrier) + { + int fusion_tile_count = Fusion{}.get_trip_count(blk_coord, TileShape{}, problem_size); + + int lane_predicate = cute::elect_one_sync(); + + uint16_t mcast_mask_b = 0; + + if (lane_predicate == 1) { + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int m = 0; m < size<0>(block_layout); ++m) { + mcast_mask_b |= (uint16_t(1) << block_layout(m,_0{},Int<0>{})); + } + } + } + + auto q_tile_iter = cute::make_coord_iterator(Int{}); + [[maybe_unused]] int q_tile_count = NumMmaWarpGroups; + + auto k_tile_iter = cute::make_coord_iterator(fusion_tile_count); + int k_tile_count = 2 * fusion_tile_count; + + LoadQ load_q{params.tma_load_q, pipeline_q, storage.smem_q}; + auto load_state_q = load_q.init_state(_0{}, problem_size, TileShapeQK{}, blk_coord, NumMmaWarpGroups); + + LoadK load_k{params.tma_load_k, pipeline, storage.smem_k}; + auto load_state_k = load_k.init_state(block_rank_in_cluster, problem_size, TileShapeQK{}, blk_coord, fusion_tile_count); + + LoadV load_v{params.tma_load_v, pipeline, storage.smem_v}; + auto load_state_v = load_v.init_state(block_rank_in_cluster, problem_size, TileShapePV{}, blk_coord, fusion_tile_count); + + if constexpr (kLoadQ) { + load_q.step(q_tile_iter, load_state_q, smem_pipe_write_q, lane_predicate, q_tile_count); + } + + load_k.template step(k_tile_iter, load_state_k, smem_pipe_write, lane_predicate, k_tile_count, mcast_mask_b); + + if constexpr (kLoadQ) { + load_q.step(q_tile_iter, load_state_q, smem_pipe_write_q, lane_predicate, q_tile_count); + } + + if constexpr (! 
kLoadQ) { + if (do_barrier) { + load_warp_barrier.arrive(); + load_warp_barrier.wait(/*phase=*/ 0); + do_barrier = false; + } + } + + load_v.template step(k_tile_iter, load_state_v, smem_pipe_write, lane_predicate, k_tile_count, mcast_mask_b); + + if constexpr (kLoadQ) { + while (q_tile_count > 0) { + load_q.step(q_tile_iter, load_state_q, smem_pipe_write_q, lane_predicate, q_tile_count); + } + } + + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) { + load_k.template step(k_tile_iter, load_state_k, smem_pipe_write, lane_predicate, k_tile_count, mcast_mask_b); + load_v.template step(k_tile_iter, load_state_v, smem_pipe_write, lane_predicate, k_tile_count, mcast_mask_b); + } + } + + template + CUTLASS_DEVICE void + load_maybe_q( + BlkCoord const& blk_coord, Params const& params, ProblemShape const& problem_size, + MainloopPipelineQ& pipeline_q, PipelineStateQ& smem_pipe_write_q, + SharedStorage& storage, + LoadWarpBarrier& load_warp_barrier, bool do_barrier) + { + int lane_predicate = cute::elect_one_sync(); + + LoadQ load_q{params.tma_load_q, pipeline_q, storage.smem_q}; + auto load_state_q = load_q.init_state(_0{}, problem_size, TileShapeQK{}, blk_coord, NumMmaWarpGroups); + + auto q_tile_iter = cute::make_coord_iterator(Int{}); + + CUTLASS_PRAGMA_UNROLL + for (int q_tile_count = 0; q_tile_count < NumMmaWarpGroups; q_tile_count++) { + int count = 1; + load_q.step(q_tile_iter, load_state_q, smem_pipe_write_q, lane_predicate, count); + if (q_tile_count == 0 && do_barrier) { + load_warp_barrier.arrive(); + load_warp_barrier.wait(/*phase=*/ 0); + do_barrier = false; + } + } + } + + template + CUTLASS_DEVICE void + reduce( + BlkCoord const& blk_coord, Params const& params, ProblemShape const& problem_size, + MainloopPipelineReducer& pipeline_reducer, PipelineStateReducer& smem_pipe_write_reducer, + SharedStorage& storage) + { /* no-op */ } + + template + CUTLASS_DEVICE auto + compute( + BlkCoord const& blk_coord, BlkCoord const& wg_coord, + Params const& params, ProblemShape const& problem_size, + MainloopPipeline& pipeline, PipelineState& smem_pipe_read, + MainloopPipelineQ& pipeline_q, PipelineStateQ& smem_pipe_read_q, + MainloopPipelineReducer&, PipelineStateReducer&, + SharedStorage& storage, + MathWgOrderBarrier& math_wg_order_barrier) + { + int thread_idx = int(threadIdx.x); + + PipelineState smem_pipe_release = smem_pipe_read; + PipelineStateQ smem_pipe_release_q = smem_pipe_read_q; + + TiledMmaQK tiled_mma_qk; + auto thr_mma_qk = tiled_mma_qk.get_thread_slice(thread_idx); + + // Mainloop setup QK + Tensor sQ = make_tensor(make_smem_ptr(storage.smem_q.data()), SmemLayoutQ{}); + Tensor sK = make_tensor(make_smem_ptr(storage.smem_k.data()), SmemLayoutK{}); + + Tensor tSsQ = thr_mma_qk.partition_A(sQ); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tSsK = thr_mma_qk.partition_B(sK); // (MMA,MMA_N,MMA_K,PIPE) + Tensor tSrQ = thr_mma_qk.make_fragment_A(tSsQ); // (MMA,MMA_N,MMA_K,PIPE) + Tensor tSrK = thr_mma_qk.make_fragment_B(tSsK); // (MMA,MMA_M,MMA_N,PIPE) + + // Prepare: MMA PV + TiledMmaPV tiled_mma_pv; + auto thr_mma_pv = tiled_mma_pv.get_thread_slice(thread_idx); + + // Mainloop setup PV + Tensor sV = make_tensor(make_smem_ptr(storage.smem_v.data()), SmemLayoutV{}); + + Tensor tOsV = thr_mma_pv.partition_B(sV); // (MMA,MMA_N,MMA_K,PIPE) + Tensor tOrV = thr_mma_pv.make_fragment_B(tOsV); // (MMA,MMA_M,MMA_N,PIPE) + + int k_tile_count = Fusion{}.get_unmasked_trip_count(blk_coord, TileShape{}, problem_size); + + pipeline_q.consumer_wait(smem_pipe_read_q); + + // mapping into QK accumulator + 
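+    // cP carries the logical (seq_q, seq_kv) coordinate of every QK
+    // accumulator element; it is offset by this warp-group's Q block here
+    // and advanced by one KV tile per main-loop iteration so the Fusion
+    // (via before_softmax) can compare coordinates against the problem extents.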
Tensor cP = make_identity_tensor(take<0,2>(TileShapeQK{})); + Tensor tPcP = thr_mma_qk.partition_C(cP); + int m_block = get<0>(wg_coord); + tPcP.data() = tPcP.data() + E<0>{} * m_block * get<0>(TileShapeQK{}); + + // Allocate PV acc + Tensor acc_pv = partition_fragment_C(tiled_mma_pv, take<0, 2>(TileShapePV{})); + + cutlass::fmha::collective::CollectiveSoftmax softmax{params}; + auto softmax_state = softmax.init(acc_pv, tiled_mma_pv); + + if (true) + { + --k_tile_count; + // Allocate QK acc + Tensor acc_qk = partition_fragment_C(tiled_mma_qk, take<0, 2>(TileShapeQK{})); + + pipeline.consumer_wait(smem_pipe_read); + math_wg_order_barrier.wait(); + + // MMA QK + warpgroup_fence_operand(acc_qk); + warpgroup_arrive(); + + gemm_zero_acc(tiled_mma_qk, tSrQ(_,_,_,smem_pipe_read_q.index()), tSrK(_,_,_,smem_pipe_read.index()), acc_qk); + warpgroup_commit_batch(); + math_wg_order_barrier.arrive(); + + ++smem_pipe_read; + + // Wait for the pipeline MMAs to drain + warpgroup_wait<0>(); + warpgroup_fence_operand(acc_qk); + + softmax.step(acc_qk, tiled_mma_qk, tPcP, softmax_state, problem_size); + + Tensor acc_qk_fixed = make_acc_into_op(acc_qk, typename TiledMmaPV::LayoutA_TV{}); + + pipeline.consumer_wait(smem_pipe_read); + + // MMA PV + warpgroup_fence_operand(acc_pv); + warpgroup_fence_operand(acc_qk_fixed); + warpgroup_arrive(); + + gemm_zero_acc(tiled_mma_pv, acc_qk_fixed, tOrV(_,_,_,smem_pipe_read.index()), acc_pv); + warpgroup_commit_batch(); + + pipeline.consumer_release(smem_pipe_release); + ++smem_pipe_release; + + // Advance consumer pipeline + ++smem_pipe_read; + tPcP.data() = tPcP.data() + E<1>{} * get<1>(TileShapeQK{}); + } + + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) + { + --k_tile_count; + + // Allocate QK acc + Tensor acc_qk = partition_fragment_C(tiled_mma_qk, take<0, 2>(TileShapeQK{})); + + pipeline.consumer_wait(smem_pipe_read); + + // MMA QK + warpgroup_fence_operand(acc_qk); + warpgroup_arrive(); + + gemm_zero_acc(tiled_mma_qk, tSrQ(_,_,_,smem_pipe_read_q.index()), tSrK(_,_,_,smem_pipe_read.index()), acc_qk); + warpgroup_commit_batch(); + + ++smem_pipe_read; + auto tok = pipeline.consumer_try_wait(smem_pipe_read); + + // Wait for the pipeline MMAs to drain + warpgroup_wait<0>(); + warpgroup_fence_operand(acc_qk); + warpgroup_fence_operand(acc_pv); + + if constexpr (kIsMainloopLocked) math_wg_order_barrier.wait(); + softmax.template step(acc_qk, tiled_mma_qk, tPcP, softmax_state, acc_pv, tiled_mma_pv, problem_size); + if constexpr (kIsMainloopLocked) math_wg_order_barrier.arrive(); + + Tensor acc_qk_fixed = make_acc_into_op(acc_qk, typename TiledMmaPV::LayoutA_TV{}); + + pipeline.consumer_wait(smem_pipe_read, tok); + + // MMA PV + warpgroup_fence_operand(acc_pv); + warpgroup_fence_operand(acc_qk_fixed); + warpgroup_arrive(); + + cute::gemm(tiled_mma_pv, acc_qk_fixed, tOrV(_,_,_,smem_pipe_read.index()), acc_pv); + warpgroup_commit_batch(); + + pipeline.consumer_release(smem_pipe_release); + ++smem_pipe_release; + + pipeline.consumer_release(smem_pipe_release); + ++smem_pipe_release; + + ++smem_pipe_read; + tPcP.data() = tPcP.data() + E<1>{} * get<1>(TileShapeQK{}); + } + + k_tile_count += Fusion{}.get_masked_trip_count(blk_coord, TileShape{}, problem_size); + + CUTLASS_PRAGMA_NO_UNROLL + while (k_tile_count > 0) + { + --k_tile_count; + + // Allocate QK acc + Tensor acc_qk = partition_fragment_C(tiled_mma_qk, take<0, 2>(TileShapeQK{})); + + pipeline.consumer_wait(smem_pipe_read); + + // MMA QK + warpgroup_fence_operand(acc_qk); + warpgroup_arrive(); + + 
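+      // Trailing KV tiles that may need masking (get_masked_trip_count):
+      // the QK GEMM is unchanged; the Fusion hook inside the softmax step
+      // below handles masked positions before the exponentials.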
gemm_zero_acc(tiled_mma_qk, tSrQ(_,_,_,smem_pipe_read_q.index()), tSrK(_,_,_,smem_pipe_read.index()), acc_qk); + warpgroup_commit_batch(); + + ++smem_pipe_read; + auto tok = pipeline.consumer_try_wait(smem_pipe_read); + + // Wait for the pipeline MMAs to drain + warpgroup_wait<0>(); + warpgroup_fence_operand(acc_qk); + warpgroup_fence_operand(acc_pv); + + //if constexpr (kIsPersistent) + // if (k_tile_count == 0) pipeline_q.consumer_release(smem_pipe_release_q); + + if constexpr (kIsMainloopLocked) math_wg_order_barrier.wait(); + softmax.step(acc_qk, tiled_mma_qk, tPcP, softmax_state, acc_pv, tiled_mma_pv, problem_size); + if constexpr (kIsMainloopLocked) math_wg_order_barrier.arrive(); + + Tensor acc_qk_fixed = make_acc_into_op(acc_qk, typename TiledMmaPV::LayoutA_TV{}); + + pipeline.consumer_wait(smem_pipe_read, tok); + + // MMA PV + warpgroup_fence_operand(acc_pv); + warpgroup_fence_operand(acc_qk_fixed); + warpgroup_arrive(); + + cute::gemm(tiled_mma_pv, acc_qk_fixed, tOrV(_,_,_,smem_pipe_read.index()), acc_pv); + warpgroup_commit_batch(); + + pipeline.consumer_release(smem_pipe_release); + ++smem_pipe_release; + + pipeline.consumer_release(smem_pipe_release); + ++smem_pipe_release; + + ++smem_pipe_read; + tPcP.data() = tPcP.data() + E<1>{} * get<1>(TileShapeQK{}); + } + + if (kIsPersistent) pipeline_q.consumer_release(smem_pipe_release_q); + + // Wait for the pipeline MMAs to drain + warpgroup_wait<0>(); + warpgroup_fence_operand(acc_pv); + + if (kIsPersistent) pipeline.consumer_release(smem_pipe_release); + ++smem_pipe_release; + + Tensor lse = softmax.tail(softmax_state, acc_pv, tiled_mma_pv); + + return make_tuple(acc_pv, lse); + } +}; + +} // namespace cutlass::fmha::collective diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/collective/fmha_common.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/collective/fmha_common.hpp new file mode 100644 index 0000000000000000000000000000000000000000..40c23f48bbebb9296292149e193e6e8bfa2e3752 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/collective/fmha_common.hpp @@ -0,0 +1,245 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/kernel_hardware_info.h" +#include "cute/tensor.hpp" + +namespace cutlass::fmha::collective { + +using namespace cute; + +template +CUTE_DEVICE void gemm_reset_zero_acc(Atom& atom, TA const& tA, TB const& tB, TC&& tC) { + constexpr int rA = decltype(rank(tA))::value; + constexpr int rB = decltype(rank(tB))::value; + constexpr int rC = decltype(rank(tC))::value; + if constexpr (rA == 2 && rB == 2 && rC == 1) { + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<1>(tA); k_block++) { + cute::gemm(atom, tA(_,k_block), tB(_,k_block), tC); + atom.accumulate_ = GMMA::ScaleOut::One; + } + } else { + static_assert(rA == 3 && rB == 3 && rC == 3); + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tA); k_block++) { + cute::gemm(atom, tA(_,_,k_block), tB(_,_,k_block), tC); + atom.accumulate_ = GMMA::ScaleOut::One; + } + } +} + +template +CUTE_DEVICE void gemm_zero_acc(Atom& atom, TA const& tA, TB const& tB, TC&& tC) { + atom.accumulate_ = GMMA::ScaleOut::Zero; + gemm_reset_zero_acc(atom, tA, tB, tC); +} + +template +CUTE_DEVICE constexpr typename T::value_type reduce(T const& t, Fn fn) { + if constexpr (decltype(size(t) % _2{} == _0{})::value) { + auto partial = make_tensor(size(t) / _2{}); + CUTE_UNROLL + for (int i = 0; i < size(partial); i++) { + partial(i) = fn(t(i), t(i + size(partial))); + } + return reduce(partial, fn); + } else { + auto result = t(_0{}); + CUTE_UNROLL + for (int i = 1; i < size(t); i++) { + result = fn(result, t(i)); + } + return result; + } +} + +struct fmha_max { + CUTE_DEVICE float operator()(float a, float b) { return ::max(a, b); } +}; + +template +inline auto __device__ constexpr layout_separate(Threshold const& thr, + Source const& src, Reference const& ref) { + auto lt = filter(transform_layout(src, ref, [&](auto const& s, auto const& r) { + if constexpr(decltype(r < thr)::value) { + return s; + } else { + return make_layout(_1{}, _0{}); + } + })); + auto ge = filter(transform_layout(src, ref, [&](auto const& s, auto const& r) { + if constexpr(decltype(r >= thr)::value) { + return s; + } else { + return make_layout(_1{}, _0{}); + } + })); + return make_layout(lt, ge); +} + +template +inline auto __device__ constexpr layout_acc_mn(TiledMma const& tiled_mma, Acc const& acc) { + auto separated = layout_separate(get<0>(typename TiledMma::Shape_MNK{}), + get<0>(acc), stride<1>(typename TiledMma::LayoutC_TV{})); + auto V_M = get<0>(separated); + auto V_N = get<1>(separated); + return make_layout(make_layout(V_M, get<1>(acc)), make_layout(V_N, get<2>(acc))); +} + +template +inline auto __device__ constexpr layout_op_mk_v(TiledMma const& tiled_mma, Acc const& acc) { + return layout_separate(get<0>(typename TiledMma::Shape_MNK{}), + get<0>(acc), stride<1>(typename TiledMma::LayoutA_TV{})); +} + +template +inline auto __device__ constexpr tensor_op_mk_v(TiledMma const& tiled_mma, Acc&& 
acc) { + return make_tensor(acc.data(), layout_op_mk_v(tiled_mma, acc.layout())); +} + +template +inline auto __device__ constexpr reduction_target_n(TiledMma const& tiled_mma) { + auto separated = layout_separate(get<0>(typename TiledMma::Shape_MNK{}), + make_layout(shape<0>(typename TiledMma::LayoutC_TV{})), + stride<0>(typename TiledMma::LayoutC_TV{})); + return get<1>(separated); +} + + +template class Primitive, cute::GMMA::Major tA, cute::GMMA::Major tB, cute::GMMA::ScaleIn sA, cute::GMMA::ScaleIn sB> +inline auto __device__ constexpr convert_to_gmma_rs(cute::MMA_Atom> const& tiled_mma) { + using Atom = cute::MMA_Atom>; + using ElementA = typename Atom::ValTypeA; + using ElementB = typename Atom::ValTypeB; + using ElementC = typename Atom::ValTypeC; + using Shape_MNK = typename Atom::Shape_MNK; + using RS = decltype(cute::GMMA::rs_op_selector()); + return cute::MMA_Atom{}; +} + +template class Primitive, cute::GMMA::ScaleIn sA, cute::GMMA::ScaleIn sB> +inline auto __device__ constexpr convert_to_gmma_rs(cute::MMA_Atom> const& tiled_mma) { + using Atom = cute::MMA_Atom>; + using ElementA = typename Atom::ValTypeA; + using ElementB = typename Atom::ValTypeB; + using ElementC = typename Atom::ValTypeC; + using Shape_MNK = typename Atom::Shape_MNK; + constexpr auto tA = cute::GMMA::Major::K; + constexpr auto tB = cute::GMMA::Major::K; + using RS = decltype(cute::GMMA::rs_op_selector()); + return cute::MMA_Atom{}; +} + +template +CUTE_DEVICE auto constexpr convert_to_gmma_rs(cute::TiledMMA const& tiled_mma) { + return cute::TiledMMA{}; +} + +template +CUTE_DEVICE auto constexpr convert_c_layout_to_a_layout(CLayout const& c, AValueShape const& a) { + return make_layout( + make_shape(a, shape<1>(c), make_shape(shape<2>(c), size<0>(c) / size(a))), + make_stride(stride<0>(c), stride<1>(c), make_stride(stride<2>(c), size<2>(a) * stride<0,2>(c)))); +} + +template +CUTE_DEVICE constexpr auto unstageSmemLayout(Layout const& layout, Stages stages = {}) { + return composition(layout, make_tuple(_, _, make_layout(stages))); +} + +template +CUTE_DEVICE auto make_acc_into_op(Accumulator const& acc, OperandLayout_TV const& operand_layout_tv) { + Tensor operand = make_fragment_like(convert_c_layout_to_a_layout(acc.layout(), shape<1>(operand_layout_tv))); + Tensor operand_as_acc = make_tensor(operand.data(), acc.layout()); + + cute::copy(acc, operand_as_acc); + + if constexpr (sizeof(Element) == 1) { + + // 00 11 22 33 00 11 22 33 acc layout + // 00 00 11 11 22 22 33 33 operand layout + // BB AA AA BB AA BB BB AA conflict-free exchange pattern + // 16-bit exchange; so process two at a time potentially + int tid = threadIdx.x % 4; + auto values_u32 = recast(operand); + + CUTE_UNROLL + for (int n = 0; n < size<1>(values_u32); n++) { + CUTE_UNROLL + for (int k = 0; k < size<2>(values_u32); k++) { + CUTE_UNROLL + for (int ii = 0; ii < 8; ii += 4) { + + uint32_t values_tmp_0 = values_u32(ii / 2 + 0, n, k); + uint32_t values_tmp_1 = values_u32(ii / 2 + 1, n, k); + + // step A: + // t 1 v 0 -> t 0 v 1 + // t 2 v 0 -> t 1 v 0 + // t 0 v 1 -> t 2 v 0 + // t 3 v 1 -> t 3 v 1 + + int v_to_send = tid == 1 || tid == 2 ? 0 : 1; + int v_to_recv = v_to_send; + int t_to_recv_from = (0x3021 >> (tid * 4)) & 0xF; + + uint32_t values_tmp_a = v_to_send == 0 ? 
values_tmp_0 : values_tmp_1; + + values_tmp_a = __shfl_sync(0xFFFFFFFF, values_tmp_a, t_to_recv_from, 4); + + // step B: + // t 0 v 0 -> t 0 v 0 + // t 3 v 0 -> t 1 v 1 + // t 1 v 1 -> t 2 v 1 + // t 2 v 1 -> t 3 v 0 + + v_to_send = 1 - v_to_send; + v_to_recv = 1 - v_to_recv; + t_to_recv_from = (0x2130 >> (tid * 4)) & 0xF; + + uint32_t values_tmp_b = v_to_send == 0 ? values_tmp_0 : values_tmp_1; + + values_tmp_b = __shfl_sync(0xFFFFFFFF, values_tmp_b, t_to_recv_from, 4); + + values_u32(ii / 2 + 0, n, k) = __byte_perm(values_tmp_a, values_tmp_b, v_to_send == 0 ? 0x1054 : 0x5410); + values_u32(ii / 2 + 1, n, k) = __byte_perm(values_tmp_a, values_tmp_b, v_to_send == 0 ? 0x3276 : 0x7632); + } + } + } + } + + return operand; +} + +} // namespace cutlass::fmha::collective diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/collective/fmha_epilogue.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/collective/fmha_epilogue.hpp new file mode 100644 index 0000000000000000000000000000000000000000..246a625ff7bf3591558e11cb42d53e64750a865a --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/collective/fmha_epilogue.hpp @@ -0,0 +1,156 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" + +#include "../collective/fmha_common.hpp" + +namespace cutlass::fmha::collective { + +template +struct FmhaFwdEpilogue { + + static constexpr int Alignment = 16 / sizeof(Element); + + using DefaultOperation = cutlass::epilogue::fusion::LinearCombination; + using CollectiveEpilogueTMA = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_WG, Shape<_1,_1,_1>, cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + void, cute::tuple>, Alignment, + Element, cute::tuple>, Alignment, + cutlass::epilogue::TmaWarpSpecialized, + DefaultOperation + >::CollectiveOp; + + struct Arguments { + Element* ptr_O; + cute::tuple> dO; + + ElementAccumulator* ptr_LSE; + cute::tuple> dLSE; + }; + + struct Params { + ElementAccumulator* ptr_LSE; + cute::tuple> dLSE; + + typename CollectiveEpilogueTMA::Params epilogue_TMA; + }; + + using TensorStorage = typename CollectiveEpilogueTMA::TensorStorage; + using PipelineStorage = typename CollectiveEpilogueTMA::PipelineStorage; + using LoadPipeline = typename CollectiveEpilogueTMA::LoadPipeline; + static constexpr int TmaTransactionBytes = CollectiveEpilogueTMA::TmaTransactionBytes; + + template + static Params to_underlying_arguments(ProblemShape const& problem_size, Arguments const& args, void* workspace = nullptr) { + auto problem_size_o = make_shape(get<2>(problem_size), get<4>(problem_size), 1, + make_shape(get<0>(problem_size), get<1>(problem_size))); + typename CollectiveEpilogueTMA::Arguments args_tma{{}, args.ptr_O, args.dO, args.ptr_O, args.dO}; + return Params{ + args.ptr_LSE, args.dLSE, + CollectiveEpilogueTMA::to_underlying_arguments(problem_size_o, args_tma, workspace) + }; + } + + template + CUTLASS_DEVICE void operator()( + TileShape const& tile_shape, BlkCoord const& blk_coord, + ResultTuple const& result, TiledMma const& tiled_mma, + ProblemShape const& problem_size, Params const& params, + LoadPipeline epi_load_pipeline, + TensorStorage& epi_tensor_storage) + { + using X = Underscore; + + auto acc = get<0>(result); + auto lse = get<1>(result); + + auto thr_mma = tiled_mma.get_thread_slice(threadIdx.x); + + int seqlen_q = get<2>(problem_size); + int num_batch = get<0>(problem_size); + int num_heads = get<1>(problem_size); + // Epilogue for lse + Tensor mLSE = make_tensor(make_gmem_ptr(params.ptr_LSE), + make_shape(seqlen_q, get<1>(tile_shape), make_shape(num_batch, num_heads)), + make_stride(_1{}, _0{}, get<1>(params.dLSE))); + Tensor gLSE_full = local_tile(mLSE, tile_shape, make_coord(_, _, _), Step<_1, _1, X>{}); + Tensor gLSE = gLSE_full(_, _, get<0>(blk_coord), get<1>(blk_coord), get<2>(blk_coord)); + Tensor tOgLSE = thr_mma.partition_C(gLSE); + Tensor cO = make_identity_tensor(take<0,2>(tile_shape)); + Tensor tOcO = thr_mma.partition_C(cO); + if (get<1>(tOcO(_0{})) == 0) { + auto tOgLSE_mn = make_tensor(tOgLSE.data(), layout_acc_mn(tiled_mma, tOgLSE.layout())); + auto tOcO_mn = make_tensor(tOcO.data(), layout_acc_mn(tiled_mma, tOcO.layout())); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<0>(tOgLSE_mn); i++) { + if (get<0>(tOcO_mn(i)) + get<0>(blk_coord) * get<0>(tile_shape) < get<2>(problem_size)) { + 
tOgLSE_mn(i, _0{}) = lse(i); + } + } + } + auto problem_size_o = make_shape(get<2>(problem_size), get<4>(problem_size), _, + make_shape(get<0>(problem_size), get<1>(problem_size))); + + CollectiveEpilogueTMA epilogue_tma(params.epilogue_TMA, epi_tensor_storage); + + using EpiStorePipeline = typename CollectiveEpilogueTMA::StorePipeline; + typename EpiStorePipeline::Params epi_store_pipeline_params; + epi_store_pipeline_params.always_wait = true; + EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params); + + typename CollectiveEpilogueTMA::LoadPipelineState epi_load_pipe_consumer_state; + PipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state(); + + auto [epi_load_pipe_consumer_state_next, epi_store_pipe_producer_state_next] = + epilogue_tma.store( + epi_load_pipeline, epi_load_pipe_consumer_state, + epi_store_pipeline, epi_store_pipe_producer_state, + problem_size_o, tile_shape, make_coord(get<0>(blk_coord), _0{}, _, get<2>(blk_coord)), + acc, tiled_mma, threadIdx.x % cutlass::NumThreadsPerWarpGroup, + epi_tensor_storage + ); + + epilogue_tma.store_tail( + epi_load_pipeline, epi_load_pipe_consumer_state_next, + epi_store_pipeline, epi_store_pipe_producer_state_next + ); + } +}; + +} // namespace cutlass::fmha::collective diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/collective/fmha_epilogue_bwd.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/collective/fmha_epilogue_bwd.hpp new file mode 100644 index 0000000000000000000000000000000000000000..677722529f5ea59e5a7b1f37823b9ad96506f2a6 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/collective/fmha_epilogue_bwd.hpp @@ -0,0 +1,157 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/collective/default_epilogue.hpp" + +#include "../collective/fmha_epilogue.hpp" + +namespace cutlass::fmha::collective { + +template +struct FmhaBwdEpilogueKV { + + static constexpr int Alignment = 16 / sizeof(Element); + + struct Arguments { + Element* ptr_K; + cute::tuple dK; + + Element* ptr_V; + cute::tuple dV; + }; + + //using DefaultOperation = cutlass::epilogue::fusion::LinearCombination; + static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest; + using DefaultOperation = cutlass::epilogue::fusion::Sm90EVT< + cutlass::epilogue::fusion::Sm90Compute, + cutlass::epilogue::fusion::Sm90AccFetch + >; + using CollectiveEpilogueTMA = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_WG, Shape<_1,_1,_1>, cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + void, cute::tuple>, Alignment, + Element, cute::tuple>, Alignment, + cutlass::epilogue::TmaWarpSpecialized, + DefaultOperation + >::CollectiveOp; + + struct Params { + typename CollectiveEpilogueTMA::Params epilogue_K; + typename CollectiveEpilogueTMA::Params epilogue_V; + }; + + + using TensorStorage = typename CollectiveEpilogueTMA::TensorStorage[2]; + using PipelineStorage = typename CollectiveEpilogueTMA::PipelineStorage; + using LoadPipeline = typename CollectiveEpilogueTMA::LoadPipeline; + static constexpr int TmaTransactionBytes = CollectiveEpilogueTMA::TmaTransactionBytes; + + template + static Params to_underlying_arguments(ProblemShape const& problem_size, Arguments const& args, void* workspace = nullptr) { + auto dK = make_stride(get<2>(args.dK), get<3>(args.dK), + make_stride(get<0>(args.dK), get<1>(args.dK))); + auto dV = make_stride(get<2>(args.dV), get<3>(args.dV), + make_stride(get<0>(args.dV), get<1>(args.dV))); + + auto problem_size_kv = make_shape(get<3>(problem_size), get<4>(problem_size), 1, + make_shape(get<0>(problem_size), get<1>(problem_size))); + typename CollectiveEpilogueTMA::Arguments args_k{{}, args.ptr_K, dK, args.ptr_K, dK}; + typename CollectiveEpilogueTMA::Arguments args_v{{}, args.ptr_V, dV, args.ptr_V, dV}; + return Params{ + CollectiveEpilogueTMA::to_underlying_arguments(problem_size_kv, args_k, nullptr), + CollectiveEpilogueTMA::to_underlying_arguments(problem_size_kv, args_v, nullptr) + }; + } + + template + CUTLASS_DEVICE void operator()( + TileShape const& tile_shape, BlkCoord const& blk_coord, + ResultTuple const& result, TiledMma const& tiled_mma, + ProblemShape const& problem_size, Params const& params, + LoadPipeline epi_load_pipeline, TensorStorage& epi_tensor_storage) + { + auto acc_k = get<0>(result); + auto acc_v = get<1>(result); + + auto problem_size_kv = make_shape(get<3>(problem_size), get<4>(problem_size), _, + make_shape(get<0>(problem_size), 
get<1>(problem_size))); + + using EpiStorePipeline = typename CollectiveEpilogueTMA::StorePipeline; + typename EpiStorePipeline::Params epi_store_pipeline_params; + epi_store_pipeline_params.always_wait = true; + EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params); + + typename CollectiveEpilogueTMA::LoadPipelineState epi_load_pipe_consumer_state; + PipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state(); + + CollectiveEpilogueTMA epilogue_k{params.epilogue_K, epi_tensor_storage[0]}; + CollectiveEpilogueTMA epilogue_v{params.epilogue_V, epi_tensor_storage[1]}; + + { + auto [epi_load_pipe_consumer_state_next, epi_store_pipe_producer_state_next] = + epilogue_k.store( + epi_load_pipeline, epi_load_pipe_consumer_state, + epi_store_pipeline, epi_store_pipe_producer_state, + problem_size_kv, tile_shape, make_coord(get<1>(blk_coord), _0{}, _, get<2>(blk_coord)), + acc_k, tiled_mma, threadIdx.x % cutlass::NumThreadsPerWarpGroup, + epi_tensor_storage[0] + ); + + } + + { + auto [epi_load_pipe_consumer_state_next, epi_store_pipe_producer_state_next] = + epilogue_v.store( + epi_load_pipeline, epi_load_pipe_consumer_state, + epi_store_pipeline, epi_store_pipe_producer_state, + problem_size_kv, tile_shape, make_coord(get<1>(blk_coord), _0{}, _, get<2>(blk_coord)), + acc_v, tiled_mma, threadIdx.x % cutlass::NumThreadsPerWarpGroup, + epi_tensor_storage[1] + ); + + epilogue_k.store_tail( + epi_load_pipeline, epi_load_pipe_consumer_state_next, + epi_store_pipeline, epi_store_pipe_producer_state_next + ); + + epilogue_v.store_tail( + epi_load_pipeline, epi_load_pipe_consumer_state_next, + epi_store_pipeline, epi_store_pipe_producer_state_next + ); + } + } +}; + +} // namespace cutlass::fmha::collective diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/collective/fmha_fusion.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/collective/fmha_fusion.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ac51f250763e6890673d57d8238f0be464f70e77 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/collective/fmha_fusion.hpp @@ -0,0 +1,283 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" + +namespace cutlass::fmha::collective { + +using namespace cute; + +struct DefaultFusion { + template + CUTLASS_DEVICE + int get_trip_count( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size + ) { + return ceil_div(get<3>(problem_size), get<1>(tile_shape)); + } + + template + CUTLASS_DEVICE + int get_masked_trip_count( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size + ) { + return get_trip_count(blk_coord, tile_shape, problem_size); + } + + template + CUTLASS_DEVICE + int get_unmasked_trip_count( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size + ) { + return 0; + } + + template + CUTLASS_DEVICE + void before_softmax( + AccQK& acc_qk, + IndexQK const& index_qk, + ProblemSize const& problem_size + + ) { + return; + } +}; + +struct ResidualFusion : DefaultFusion { + + using Base = DefaultFusion; + + template + CUTLASS_DEVICE + int get_masked_trip_count( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size + ) { + return 1; + } + + template + CUTLASS_DEVICE + int get_unmasked_trip_count( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size + ) { + return get_trip_count(blk_coord, tile_shape, problem_size) - 1; + } + + template + CUTLASS_DEVICE + void before_softmax( + AccQK& acc_qk, + IndexQK const& index_qk, + ProblemSize const& problem_size + ) { + // This is useful is seqlen_k % kBlockN != 0 since it masks + // the remaining elements out from softmax. + // d % kHeadDim != 0 or seqlen_q % kBlockM do not suffer from similar + // issues as they are transparently taken care of by TMA and the + // epilogue, if it is instantiated with predication support. 
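+    // Example (illustrative values only): with a kv tile size kBlockN = 128 and
+    // seqlen_k = 300, the final kv tile spans positions 256..383; the loop below
+    // assigns -INFINITY to every entry whose kv index (get<1>(pos)) is >= seqlen_k
+    // (get<3>(problem_size)), so those entries contribute exp(-inf) = 0 to the
+    // softmax normalization and are effectively masked out.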
+ CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(acc_qk); i++) { + auto pos = index_qk(i); + if (get<1>(pos) >= get<3>(problem_size)) { + acc_qk(i) = -INFINITY; + } + } + } +}; + +struct CausalFusion : DefaultFusion { + + using Base = DefaultFusion; + + template + CUTLASS_DEVICE + int get_trip_count( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size + ) { + // See note below on different ways to think about causal attention + // Again, we'd add the offset_q into the max_blocks_q calculation + int max_blocks_k = Base::get_trip_count(blk_coord, tile_shape, problem_size); + int max_blocks_q = ceil_div((get<0>(blk_coord) + 1) * get<0>(tile_shape), get<1>(tile_shape)); + return std::min(max_blocks_k, max_blocks_q); + } + + template + CUTLASS_DEVICE + int get_masked_trip_count( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size + ) { + return ceil_div(get<0>(tile_shape), get<1>(tile_shape)); + } + + template + CUTLASS_DEVICE + int get_unmasked_trip_count( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size + ) { + return get_trip_count(blk_coord, tile_shape, problem_size) - get_masked_trip_count(blk_coord, tile_shape, problem_size); + } + + template + CUTLASS_DEVICE + void before_softmax( + AccQK& acc_qk, + IndexQK const& index_qk, + ProblemSize const& problem_size + ) { + // There are two ways to do causal if N_Q != N_K + // (1) is to assume that the Q is at the beginning of the matrix + // - this is what we demonstrate here + // (2) is that it is at the end of the matrix + // - this is usually what we want for inference settings + // where we only compute the next row and use cache for the rest + // - if you'd like this, you only need to add an offset like so: + // get<0>(pos) + offset_q < get<1>(pos) + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(acc_qk); i++) { + auto pos = index_qk(i); + if (get<0>(pos) < get<1>(pos)) { + acc_qk(i) = -INFINITY; + } + } + } + +}; + +template +struct FusionBwdAdapter { + template + CUTLASS_DEVICE + int get_trip_count( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size + ) { + return Base{}.get_trip_count(select<1,0,2>(blk_coord), select<1,0,2>(tile_shape), select<0,1,3,2,4>(problem_size)); + } + + template + CUTLASS_DEVICE + void before_softmax( + AccQK& acc_qk, + IndexQK const& index_qk, + ProblemSize const& problem_size + ) { + auto index_base = index_qk(_0{}); + auto index_shape = shape(index_qk); + auto index_stride = transform_leaf(stride(index_qk), [](auto elem) { + if constexpr (is_scaled_basis::value) { + if constexpr(decltype(elem.mode() == _0{})::value) { + return ScaledBasis(elem.value()); + } else { + return ScaledBasis(elem.value()); + } + } else { + return elem; + } + }); + auto index_qk_bwd = make_tensor(make_inttuple_iter(select<1,0>(index_base)), make_layout(index_shape, index_stride)); + Base{}.before_softmax(acc_qk, index_qk_bwd, problem_size); + } + + template + CUTLASS_DEVICE + bool is_contributing( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size + ) { + return true; + } +}; + +template<> +struct FusionBwdAdapter { + template + CUTLASS_DEVICE + int get_trip_count( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size + ) { + return get<2>(problem_size) / get<0>(TileShape{}); + } + + template + CUTLASS_DEVICE + void before_softmax( + AccQK& acc_qk, + IndexQK const& index_qk, + 
ProblemSize const& problem_size + + ) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(acc_qk); i++) { + auto pos = index_qk(i); + if (get<1>(pos) < get<0>(pos)) { + acc_qk(i) = -INFINITY; + } + } + } + + template + CUTLASS_DEVICE + bool is_contributing( + BlkCoord const& blk_coord, + TileShape const& tile_shape, + ProblemSize const& problem_size + ) { + int max_q = get<0>(blk_coord) * get<0>(tile_shape) + get<0>(tile_shape); + int min_k = get<1>(blk_coord) * get<1>(tile_shape); + return min_k <= max_q; + } +}; + +} // namespace cutlass::fmha::collective diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/device/device_universal.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/device/device_universal.hpp new file mode 100644 index 0000000000000000000000000000000000000000..beadb688730bd3a7936a7885ab72dd60b903c326 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/device/device_universal.hpp @@ -0,0 +1,278 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! + \file + \brief An universal device layer for cutlass 3.x-style kernels. 
+*/ + +#pragma once + +// common +#include "cutlass/cutlass.h" +#include "cutlass/device_kernel.h" + +#if !defined(__CUDACC_RTC__) +#include "cutlass/cluster_launch.hpp" +#include "cutlass/trace.h" +#endif // !defined(__CUDACC_RTC__) + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::device { + +//////////////////////////////////////////////////////////////////////////////// +////////////////////////////// CUTLASS 3.x API ///////////////////////////////// +//////////////////////////////////////////////////////////////////////////////// + +template +class Universal { +public: + using Kernel = Kernel_; + + static int const kThreadCount = Kernel::MaxThreadsPerBlock; + + /// Argument structure: User API + using Arguments = typename Kernel::Arguments; + /// Argument structure: Kernel API + using Params = typename Kernel::Params; + +private: + + /// Kernel API parameters object + Params params_; + + bool is_initialized(bool set = false) { + static bool initialized = false; + if (set) initialized = true; + return initialized; + } + +public: + + /// Access the Params structure + Params const& params() const { + return params_; + } + + /// Determines whether the GEMM can execute the given problem. + static Status + can_implement(Arguments const& args) { + if (Kernel::can_implement(args)) { + return Status::kSuccess; + } + else { + return Status::kInvalid; + } + } + + /// Gets the workspace size + static size_t + get_workspace_size(Arguments const& args) { + size_t workspace_bytes = 0; + workspace_bytes += Kernel::get_workspace_size(args); + return workspace_bytes; + } + + /// Computes the grid shape + static dim3 + get_grid_shape(Params const& params) { + return Kernel::get_grid_shape(params); + } + + /// Computes the maximum number of active blocks per multiprocessor + static int maximum_active_blocks(int /* smem_capacity */ = -1) { + CUTLASS_TRACE_HOST("Universal::maximum_active_blocks()"); + int max_active_blocks = -1; + int smem_size = Kernel::SharedStorageSize; + + // first, account for dynamic smem capacity if needed + cudaError_t result; + if (smem_size >= (48 << 10)) { + CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size); + result = cudaFuncSetAttribute( + device_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST( + " cudaFuncSetAttribute() returned error: " + << cudaGetErrorString(result)); + return -1; + } + } + + // query occupancy after setting smem size + result = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks, + device_kernel, + Kernel::MaxThreadsPerBlock, + smem_size); + + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST( + " cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error: " + << cudaGetErrorString(result)); + return -1; + } + + CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks); + return max_active_blocks; + } + + /// Initializes GEMM state from arguments. + Status + initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { + CUTLASS_TRACE_HOST("Universal::initialize() - workspace " + << workspace << ", stream: " << (stream ? 
"non-null" : "null")); + + // Initialize the workspace + Status status = Kernel::initialize_workspace(args, workspace, stream); + if (status != Status::kSuccess) { + return status; + } + + // Initialize the Params structure + params_ = Kernel::to_underlying_arguments(args, workspace); + + if (is_initialized()) return Status::kSuccess; + + // account for dynamic smem capacity if needed + int smem_size = Kernel::SharedStorageSize; + if (smem_size >= (48 << 10)) { + CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size); + cudaError_t result = cudaFuncSetAttribute( + device_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error: " << cudaGetErrorString(result)); + return Status::kErrorInternal; + } + } + + is_initialized(true); + + return Status::kSuccess; + } + + /// Update API is preserved in 3.0, but does not guarantee a lightweight update of params. + Status + update(Arguments const& args, void* workspace = nullptr) { + CUTLASS_TRACE_HOST("Universal()::update() - workspace: " << workspace); + + size_t workspace_bytes = get_workspace_size(args); + if (workspace_bytes > 0 && nullptr == workspace) { + return Status::kErrorWorkspaceNull; + } + + params_ = Kernel::to_underlying_arguments(args, workspace); + return Status::kSuccess; + } + + /// Primary run() entry point API that is static allowing users to create and manage their own params. + /// Supplied params struct must be construct by calling Kernel::to_underling_arguments() + static Status + run(Params& params, cudaStream_t stream = nullptr) { + CUTLASS_TRACE_HOST("Universal::run()"); + dim3 const block = Kernel::get_block_shape(); + dim3 const grid = get_grid_shape(params); + + // configure smem size and carveout + int smem_size = Kernel::SharedStorageSize; + + Status launch_result; + // Use extended launch API only for mainloops that use it + if constexpr(Kernel::ArchTag::kMinComputeCapability >= 90) { + dim3 cluster(cute::size<0>(typename Kernel::ClusterShape{}), + cute::size<1>(typename Kernel::ClusterShape{}), + cute::size<2>(typename Kernel::ClusterShape{})); + void const* kernel = (void const*) device_kernel; + void* kernel_params[] = {¶ms}; + launch_result = ClusterLauncher::launch(grid, cluster, block, smem_size, stream, kernel, kernel_params); + } + else { + launch_result = Status::kSuccess; + cutlass::arch::synclog_setup(); + device_kernel<<>>(params); + } + + cudaError_t result = cudaGetLastError(); + if (cudaSuccess == result && Status::kSuccess == launch_result) { + return Status::kSuccess; + } + else { + CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << result); + return Status::kErrorInternal; + } + } + + // + // Non-static launch overloads that first create and set the internal params struct of this kernel handle. + // + + /// Launches the kernel after first constructing Params internal state from supplied arguments. + Status + run(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { + Status status = initialize(args, workspace, stream); + if (Status::kSuccess == status) { + status = run(params_, stream); + } + return status; + } + + /// Launches the kernel after first constructing Params internal state from supplied arguments. 
+ Status + operator()(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { + return run(args, workspace, stream); + } + + /// Overload that allows a user to re-launch the same kernel without updating internal params struct. + Status + run(cudaStream_t stream = nullptr) { + return run(params_, stream); + } + + /// Overload that allows a user to re-launch the same kernel without updating internal params struct. + Status + operator()(cudaStream_t stream = nullptr) { + return run(params_, stream); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::device + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/device/fmha_device_bwd.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/device/fmha_device_bwd.hpp new file mode 100644 index 0000000000000000000000000000000000000000..47c43b716b1e318f8504eb4220a66692a6fb1106 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/device/fmha_device_bwd.hpp @@ -0,0 +1,299 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +/*! + \file + \brief An universal device layer for cutlass 3.x-style kernels. 
+*/ + +// common +#include "cutlass/cutlass.h" + +#include "../device/device_universal.hpp" +#include "../collective/fmha_collective_bwd_tma_warpspecialized.hpp" +#include "../collective/fmha_fusion.hpp" +#include "../collective/fmha_epilogue_bwd.hpp" +#include "../kernel/fmha_kernel_bwd_sum_OdO.hpp" +#include "../kernel/fmha_kernel_bwd_convert.hpp" +#include "../kernel/fmha_kernel_tma_warpspecialized.hpp" +#include "../kernel/fmha_tile_scheduler.hpp" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::fmha::device { + +//////////////////////////////////////////////////////////////////////////////// +////////////////////////////// CUTLASS 3.x API ///////////////////////////////// +//////////////////////////////////////////////////////////////////////////////// + +template +class FmhaBwd { +public: + /// Argument structure: User API + struct Arguments { + cute::tuple problem_size; + + const Element* ptr_Q; + cute::tuple stride_Q; + const Element* ptr_K; + cute::tuple stride_K; + const Element* ptr_V; + cute::tuple stride_V; + + const Element* ptr_O; + cute::tuple stride_O; + const ElementAccumulator* ptr_LSE; + cute::tuple stride_LSE; + + const Element* ptr_dO; + cute::tuple stride_dO; + + Element* ptr_dQ; + cute::tuple stride_dQ; + Element* ptr_dK; + cute::tuple stride_dK; + Element* ptr_dV; + cute::tuple stride_dV; + + cutlass::KernelHardwareInfo hw_info; + }; + + using OperationSumOdO = cutlass::device::Universal>; + using OperationConvert = cutlass::device::Universal>; + + using Mainloop = cutlass::fmha::collective::FmhaBwdMainloopTmaWarpSpecialized< + Element, ElementAccumulator, TileShape, + cutlass::fmha::collective::FusionBwdAdapter, Options...>; + + using Epilogue = cutlass::fmha::collective::FmhaBwdEpilogueKV; + + using Operation = cutlass::device::Universal< + cutlass::fmha::kernel::FmhaKernelTmaWarpSpecialized< + Mainloop, + Epilogue, + cutlass::fmha::kernel::TileSchedulerBwdAdapter, Options...>>; + + struct Params { + OperationSumOdO op_sum_OdO; + Operation op; + OperationConvert op_convert; + ElementAccumulator* dQ_acc; + size_t dQ_acc_size; + }; + +private: + Params params_; + + static typename OperationSumOdO::Arguments to_sum_OdO_arguments(Arguments const& args, ElementAccumulator* dest = nullptr) { + auto [B, H, Q, K, D] = args.problem_size; + D = cutlass::round_up(D, 8); // Alignment + Q = cutlass::round_up(Q, 8); // Alignment + auto stride_sum_OdO = make_stride(H*Q, Q, _1{}); + return typename OperationSumOdO::Arguments { + args.problem_size, + args.ptr_O, args.stride_O, + args.ptr_dO, args.stride_dO, + dest, stride_sum_OdO + }; + } + + static typename OperationConvert::Arguments to_convert_arguments(Arguments const& args, ElementAccumulator* src = nullptr) { + auto [B, H, Q, K, D] = args.problem_size; + D = cutlass::round_up(D, 8); // Alignment + Q = cutlass::round_up(Q, 8); // Alignment + auto stride_src_dQ = make_stride(B == 1 ? 
0 : (H*Q*D), Q*D, D, _1{}); + return typename OperationConvert::Arguments { + args.problem_size, + src, stride_src_dQ, + nullptr, stride_src_dQ, + nullptr, stride_src_dQ, + args.ptr_dQ, args.stride_dQ, + nullptr, args.stride_dK, + nullptr, args.stride_dV + }; + } + + static typename Operation::Arguments to_bwd_arguments( + Arguments const& args, + ElementAccumulator* sum_OdO = nullptr, cute::tuple const& stride_sum_OdO = {}, + ElementAccumulator* dQ_acc = nullptr, cute::tuple const& stride_dQ = {} + ) { + return typename Operation::Arguments{ + args.problem_size, + { args.ptr_Q, args.stride_Q, + args.ptr_K, args.stride_K, + args.ptr_V, args.stride_V, + args.ptr_dO, args.stride_dO, + args.ptr_LSE, args.stride_LSE, + sum_OdO, stride_sum_OdO, + dQ_acc, stride_dQ }, + { args.ptr_dK, args.stride_dK, + args.ptr_dV, args.stride_dV }, + args.hw_info + }; + } + +public: + + /// Determines whether the GEMM can execute the given problem. + static Status + can_implement(Arguments const& args) { + Status status = Status::kSuccess; + + status = OperationSumOdO::can_implement(to_sum_OdO_arguments(args)); + if (status != Status::kSuccess) { + return status; + } + + status = OperationConvert::can_implement(to_convert_arguments(args)); + if (status != Status::kSuccess) { + return status; + } + + status = Operation::can_implement(to_bwd_arguments(args)); + if (status != Status::kSuccess) { + return status; + } + + return status; + } + + /// Gets the workspace size + static size_t + get_workspace_size(Arguments const& args) { + auto [B, H, Q, K, D] = args.problem_size; + D = cutlass::round_up(D, 8); // Alignment + Q = cutlass::round_up(Q, 8); // Alignment + size_t workspace_bytes = 0; + // OdO vector + workspace_bytes += B*H*Q * sizeof(ElementAccumulator); + // FP32 versions of outputs that are churned (start off with Q only) + workspace_bytes += B*H*Q*D * sizeof(ElementAccumulator); + return workspace_bytes; + } + + /// Initializes state from arguments. + Status + initialize_split(Arguments const& args, void* workspace_dQ, void* workspace_sum_OdO, cudaStream_t stream = nullptr) { + CUTLASS_TRACE_HOST("Universal::initialize_split() - workspace_dQ=" + << workspace_dQ << ", workspace_sum_OdO=" << workspace_sum_OdO << "stream: " << (stream ? "non-null" : "null")); + + auto [B, H, Q, K, D] = args.problem_size; + D = cutlass::round_up(D, 8); // Alignment + Q = cutlass::round_up(Q, 8); // Alignment + ElementAccumulator* sum_OdO = reinterpret_cast(workspace_sum_OdO); + ElementAccumulator* dQ_acc = reinterpret_cast(workspace_dQ); + params_.dQ_acc = dQ_acc; + params_.dQ_acc_size = B*H*Q*D * sizeof(ElementAccumulator); + auto args_sum_OdO = to_sum_OdO_arguments(args, sum_OdO); + auto args_convert = to_convert_arguments(args, dQ_acc); + params_.op_sum_OdO.initialize(args_sum_OdO, nullptr, stream); + params_.op_convert.initialize(args_convert, nullptr, stream); + auto args_bwd = to_bwd_arguments(args, sum_OdO, args_sum_OdO.stride_sum_OdO, dQ_acc, args_convert.stride_src_dQ); + params_.op.initialize(args_bwd, nullptr, stream); + + return Status::kSuccess; + } + + /// Initializes state from arguments. + Status + initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { + CUTLASS_TRACE_HOST("Universal::initialize() - workspace " + << workspace << ", stream: " << (stream ? 
"non-null" : "null")); + + auto [B, H, Q, K, D] = args.problem_size; + D = cutlass::round_up(D, 8); // Alignment + Q = cutlass::round_up(Q, 8); // Alignment + char* workspace_chr = reinterpret_cast(workspace); + ElementAccumulator* sum_OdO = reinterpret_cast(workspace_chr); + workspace_chr += B*H*Q * sizeof(ElementAccumulator); + ElementAccumulator* dQ_acc = reinterpret_cast(workspace_chr); + return initialize_split(args, dQ_acc, sum_OdO, stream); + } + + /// Primary run() entry point API that is static allowing users to create and manage their own params. + /// Supplied params struct must be construct by calling Kernel::to_underling_arguments() + static Status + run(Params& params, cudaStream_t stream = nullptr) { + CUTLASS_TRACE_HOST("FmhaDeviceBwd::run()"); + + Status result = Status::kSuccess; + result = params.op_sum_OdO.run(stream); + if (result != Status::kSuccess) { + return result; + } + + auto cuda_result = cudaMemsetAsync(params.dQ_acc, 0, params.dQ_acc_size, stream); + if (cuda_result != cudaSuccess) { + return Status::kErrorInternal; + } + result = params.op.run(stream); + if (result != Status::kSuccess) { + return result; + } + + result = params.op_convert.run(stream); + if (result != Status::kSuccess) { + return result; + } + + return Status::kSuccess; + } + + // + // Non-static launch overloads that first create and set the internal params struct of this kernel handle. + // + + /// Launches the kernel after first constructing Params internal state from supplied arguments. + Status + run(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { + Status status = initialize(args, workspace, stream); + if (Status::kSuccess == status) { + status = run(params_, stream); + } + return status; + } + + /// Overload that allows a user to re-launch the same kernel without updating internal params struct. + Status + run(cudaStream_t stream = nullptr) { + return run(params_, stream); + } + +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::fmha::device + +//////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/kernel/fmha_kernel_builder.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/kernel/fmha_kernel_builder.hpp new file mode 100644 index 0000000000000000000000000000000000000000..bd0e0f9d51620e90545f390cd397a15b0e7873ed --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/kernel/fmha_kernel_builder.hpp @@ -0,0 +1,158 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "../collective/fmha_collective_tma.hpp" +#include "../collective/fmha_collective_tma_warpspecialized.hpp" +#include "../collective/fmha_epilogue.hpp" +#include "../kernel/fmha_kernel_tma.hpp" +#include "../kernel/fmha_kernel_tma_warpspecialized.hpp" +#include "../kernel/fmha_options.hpp" + +namespace cutlass::fmha::kernel { + +template< + class Element_, + class ElementAccumulatorQK_, + class ElementAccumulatorPV_, + class TileShape_, // BlockQO, BlockKV, BlockHead + class LayoutQ_, + class LayoutK_, + class LayoutV_, + class Fusion, + class DispatchPolicy, + class... Options +> +struct FmhaBuilder; + +template< + class Element, + class ElementAccumulator, + class TileShape, // BlockQO, BlockKV, BlockHead + class Fusion, + class... Options +> +struct FmhaBuilder< + Element, + ElementAccumulator, + ElementAccumulator, + TileShape, + cute::tuple>, + cute::tuple>, + cute::tuple>, + Fusion, + cutlass::gemm::KernelTma, + Options... +> { + + using CollectiveMainloop = cutlass::fmha::collective::FmhaMainloopTma; + + using CollectiveEpilogue = cutlass::fmha::collective::FmhaFwdEpilogue< + Element, ElementAccumulator, typename CollectiveMainloop::TileShapePV>; + + using Kernel = cutlass::fmha::kernel::FmhaKernelTma; +}; + +template< + class Element, + class ElementAccumulatorQK, + class ElementAccumulatorPV, + class TileShape, // BlockQO, BlockKV, BlockHead + class LayoutQ, + class LayoutK, + class LayoutV, + class Fusion, + class... Options +> +struct FmhaBuilder< + Element, + ElementAccumulatorQK, + ElementAccumulatorPV, + TileShape, + LayoutQ, + LayoutK, + LayoutV, + Fusion, + cutlass::gemm::KernelTmaWarpSpecializedCooperative, + Options... 
+> { + + using CollectiveMainloop = cutlass::fmha::collective::FmhaMainloopTmaWarpSpecialized< + Element, ElementAccumulatorQK, ElementAccumulatorPV, + TileShape, LayoutQ, LayoutK, LayoutV, + Fusion, Options...>; + + using CollectiveEpilogue = cutlass::fmha::collective::FmhaFwdEpilogue< + Element, ElementAccumulatorPV, typename CollectiveMainloop::TileShapePV>; + + static constexpr bool kIsPersistent = find_option_t::value; + using TileScheduler = std::conditional_t; + + using Kernel = cutlass::fmha::kernel::FmhaKernelTmaWarpSpecialized; +}; + +template< + class Element, + class ElementAccumulatorQK, + class ElementAccumulatorPV, + class TileShape, // BlockQO, BlockKV, BlockHead + class LayoutQ, + class LayoutK, + class LayoutV, + class Fusion, + class... Options +> +struct FmhaBuilder< + Element, + ElementAccumulatorQK, + ElementAccumulatorPV, + TileShape, + LayoutQ, + LayoutK, + LayoutV, + Fusion, + cutlass::gemm::KernelTmaWarpSpecializedPingpong, + Options... +> { + using Kernel = typename FmhaBuilder< + Element, ElementAccumulatorQK, ElementAccumulatorPV, + TileShape, + LayoutQ, LayoutK, LayoutV, + Fusion, + cutlass::gemm::KernelTmaWarpSpecializedCooperative, + Options..., + Option, + Option + >::Kernel; +}; + +} // namespace cutlass::fmha::kernel diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/kernel/fmha_kernel_bwd_convert.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/kernel/fmha_kernel_bwd_convert.hpp new file mode 100644 index 0000000000000000000000000000000000000000..3ef1946ba6056c51bf882124d7d132e98b984ae2 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/kernel/fmha_kernel_bwd_convert.hpp @@ -0,0 +1,143 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cute/layout.hpp" + +namespace cutlass::fmha::kernel { + +using namespace cute; + +template +struct FmhaKernelBwdConvert { + + struct Arguments { + tuple problem_size; + + const ElementAccumulator* ptr_src_dQ; + tuple stride_src_dQ; + const ElementAccumulator* ptr_src_dK; + tuple stride_src_dK; + const ElementAccumulator* ptr_src_dV; + tuple stride_src_dV; + + Element* ptr_dest_dQ; + tuple stride_dest_dQ; + Element* ptr_dest_dK; + tuple stride_dest_dK; + Element* ptr_dest_dV; + tuple stride_dest_dV; + }; + + using Params = Arguments; + + using ClusterShape = Shape<_1, _1, _1>; + static constexpr int SharedStorageSize = 0; + + static const int MinBlocksPerMultiprocessor = 1; + static const int MaxThreadsPerBlock = 128; + using ArchTag = cutlass::arch::Sm90; + + static const int kBlockSeq = 8; + + static size_t get_workspace_size(Arguments const& args) { return 0; } + static cutlass::Status initialize_workspace(Arguments const&, void*, cudaStream_t) { + return cutlass::Status::kSuccess; + } + + static const int kNumThreadsD = 16; + static const int kNumThreadsSeq = MaxThreadsPerBlock / kNumThreadsD; + static const int kElementsPerLoad = 4; + + static const int kIterationsSeq = kBlockSeq / kNumThreadsSeq; + + static bool can_implement(Arguments const& args) { + return get<4>(args.problem_size) % kElementsPerLoad == 0; + } + + static dim3 get_grid_shape(Params const& params) { + dim3 grid(size<0>(params.problem_size), size<1>(params.problem_size), ceil_div(std::max(size<2>(params.problem_size), size<3>(params.problem_size)), kBlockSeq)); + return grid; + } + + static dim3 get_block_shape() { + dim3 block(kNumThreadsD, kNumThreadsSeq, 1); + return block; + } + + static Params to_underlying_arguments(Arguments const& args, void* workspace) { + return args; + } + + template + CUTLASS_DEVICE void copy(Params const& params, const ElementAccumulator* ptr_src, StrideSrc const& stride_src, Element* ptr_dest, StrideDest const& stride_dest, int count) { + auto ptr_src_bh = ptr_src + get<0>(stride_src) * blockIdx.x + get<1>(stride_src) * blockIdx.y; + auto ptr_dest_bh = ptr_dest + get<0>(stride_dest) * blockIdx.x + get<1>(stride_dest) * blockIdx.y; + + for (int idx_s_t = threadIdx.y; idx_s_t < kBlockSeq; idx_s_t += kNumThreadsSeq) { + int idx_s = idx_s_t + kBlockSeq * blockIdx.z; + if (idx_s >= count) continue; + auto ptr_src_bhs = ptr_src_bh + idx_s * get<2>(stride_src); + auto ptr_dest_bhs = ptr_dest_bh + idx_s * get<2>(stride_dest); + + for (int idx_d = threadIdx.x * kElementsPerLoad; idx_d < get<4>(params.problem_size); idx_d += kElementsPerLoad * kNumThreadsD) { + ElementAccumulator value_src[kElementsPerLoad]; + Element value_dest[kElementsPerLoad]; + + using VecSrc = uint_bit_t * kElementsPerLoad>; + using VecDest = uint_bit_t * kElementsPerLoad>; + *reinterpret_cast(value_src) = 
*reinterpret_cast(&ptr_src_bhs[idx_d]); + + for (int v = 0; v < kElementsPerLoad; v++) { + value_dest[v] = value_src[v]; + } + + *reinterpret_cast(&ptr_dest_bhs[idx_d]) = *reinterpret_cast(value_dest); + } + } + } + + CUTLASS_DEVICE void operator()(const Params ¶ms, char* smem) { + if (params.ptr_src_dQ != nullptr) { + copy(params, params.ptr_src_dQ, params.stride_src_dQ, params.ptr_dest_dQ, params.stride_dest_dQ, get<2>(params.problem_size)); + } + if (params.ptr_src_dK != nullptr) { + copy(params, params.ptr_src_dK, params.stride_src_dK, params.ptr_dest_dK, params.stride_dest_dK, get<3>(params.problem_size)); + } + if (params.ptr_src_dV != nullptr) { + copy(params, params.ptr_src_dV, params.stride_src_dV, params.ptr_dest_dV, params.stride_dest_dV, get<3>(params.problem_size)); + } + } +}; + +} // namespace cutlass::fmha::kernel diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/kernel/fmha_kernel_bwd_sum_OdO.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/kernel/fmha_kernel_bwd_sum_OdO.hpp new file mode 100644 index 0000000000000000000000000000000000000000..4e276a6235107463369ce50b7ab5e50b33a55350 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/kernel/fmha_kernel_bwd_sum_OdO.hpp @@ -0,0 +1,134 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cute/layout.hpp" + +namespace cutlass::fmha::kernel { + +using namespace cute; + +template +struct FmhaKernelBwdSumOdO { + + struct Arguments { + cute::tuple problem_size; + + const Element* ptr_O; + cute::tuple stride_O; + const Element* ptr_dO; + cute::tuple stride_dO; + + ElementAccumulator* ptr_sum_OdO; + cute::tuple stride_sum_OdO; + }; + + using Params = Arguments; + + using ClusterShape = Shape<_1, _1, _1>; + static constexpr int SharedStorageSize = 0; + + static const int MinBlocksPerMultiprocessor = 1; + static const int MaxThreadsPerBlock = 128; + using ArchTag = cutlass::arch::Sm90; + + static size_t get_workspace_size(Arguments const& args) { return 0; } + static cutlass::Status initialize_workspace(Arguments const&, void*, cudaStream_t) { + return cutlass::Status::kSuccess; + } + + static const int kBlockQ = 16; + + static const int kNumThreadsD = 8; + static const int kNumThreadsQ = MaxThreadsPerBlock / kNumThreadsD; + static const int kElementsPerLoad = 2; + + static const int kIterationsQ = kBlockQ / kNumThreadsQ; + + static bool can_implement(Arguments const& args) { + return get<4>(args.problem_size) % kElementsPerLoad == 0; + } + + static dim3 get_grid_shape(Params const& params) { + dim3 grid(ceil_div(size<2>(params.problem_size), kBlockQ), size<1>(params.problem_size), size<0>(params.problem_size)); + return grid; + } + + static dim3 get_block_shape() { + dim3 block(kNumThreadsD, kNumThreadsQ, 1); + return block; + } + + static Params to_underlying_arguments(Arguments const& args, void* workspace) { + return args; + } + + CUTLASS_DEVICE void operator()(const Params ¶ms, char* smem) { + auto ptr_O_bh = params.ptr_O + blockIdx.y * get<1>(params.stride_O) + blockIdx.z * get<0>(params.stride_O); + auto ptr_dO_bh = params.ptr_dO + blockIdx.y * get<1>(params.stride_dO) + blockIdx.z * get<0>(params.stride_dO); + auto ptr_sum_OdO_bh = params.ptr_sum_OdO + blockIdx.y * get<1>(params.stride_sum_OdO) + blockIdx.z * get<0>(params.stride_sum_OdO); + + CUTLASS_PRAGMA_UNROLL + for (int idx_q_t = threadIdx.y; idx_q_t < kBlockQ; idx_q_t += kNumThreadsQ) { + int idx_q = idx_q_t + kBlockQ * blockIdx.x; + if (idx_q >= get<2>(params.problem_size)) continue; + ElementAccumulator acc = 0; + auto ptr_O_bhq = ptr_O_bh + idx_q * get<2>(params.stride_O); + auto ptr_dO_bhq = ptr_dO_bh + idx_q * get<2>(params.stride_dO); + auto ptr_sum_OdO_bhq = ptr_sum_OdO_bh + idx_q * get<2>(params.stride_sum_OdO); + + for (int idx_d = threadIdx.x * kElementsPerLoad; idx_d < get<4>(params.problem_size); idx_d += kElementsPerLoad * kNumThreadsD) { + Element value_O[kElementsPerLoad]; + Element value_dO[kElementsPerLoad]; + + using Vec = uint_bit_t * kElementsPerLoad>; + *reinterpret_cast(value_O) = *reinterpret_cast(&ptr_O_bhq[idx_d]); + *reinterpret_cast(value_dO) = *reinterpret_cast(&ptr_dO_bhq[idx_d]); + + for (int v = 0; v < kElementsPerLoad; v++) { + acc += value_O[v] * value_dO[v]; + } + } + + for (int i = 1; i < kNumThreadsD; i *= 2) { + acc += __shfl_xor_sync((uint32_t)-1, acc, i, kNumThreadsD); + } + + if (threadIdx.x == 0) { + *ptr_sum_OdO_bhq = acc; + } + } + } +}; + +} // namespace cutlass::fmha::kernel diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/kernel/fmha_kernel_tma.hpp 
b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/kernel/fmha_kernel_tma.hpp new file mode 100644 index 0000000000000000000000000000000000000000..eeee7fa286a7d1039ff0518f4446a9edd3055521 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/kernel/fmha_kernel_tma.hpp @@ -0,0 +1,226 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/arch/arch.h" + +#include "../kernel/fmha_tile_scheduler.hpp" +#include "../kernel/fmha_options.hpp" + +namespace cutlass::fmha::kernel { + +template< + class CollectiveMainloop, + class CollectiveEpilogue, + class... 
Options +> +struct FmhaKernelTma { + + // Options + static constexpr int kBlocksPerSM = find_option_t, Options...>::value; + + using Element = typename CollectiveMainloop::Element; + using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; + + using TileScheduler = IndividualTileScheduler; + + using StagesQ = typename CollectiveMainloop::StagesQ; + using Stages = typename CollectiveMainloop::Stages; + + using TileShape = typename CollectiveMainloop::TileShape; + using ClusterShape = typename CollectiveMainloop::ClusterShape; + + using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline; + using MainloopPipelineQ = typename CollectiveMainloop::MainloopPipelineQ; + + using SmemLayoutQ = typename CollectiveMainloop::SmemLayoutQ; + using SmemLayoutK = typename CollectiveMainloop::SmemLayoutK; + + struct SharedStorage { + union { + typename CollectiveMainloop::SharedStorage mainloop; + typename CollectiveEpilogue::TensorStorage epilogue; + }; + + using PipelineStorage = typename MainloopPipeline::SharedStorage; + using PipelineStorageQ = typename MainloopPipelineQ::SharedStorage; + alignas(16) PipelineStorage pipeline_storage; + alignas(16) PipelineStorageQ pipeline_storage_q; + + using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage; + alignas(16) EpiLoadPipelineStorage epi_load; + }; + + static constexpr int SharedStorageSize = sizeof(SharedStorage); + + using ProblemShape = cute::tuple; + + struct Arguments { + ProblemShape problem_size; + typename CollectiveMainloop::Arguments mainloop; + typename CollectiveEpilogue::Arguments epilogue; + KernelHardwareInfo hw_info; + }; + + struct Params { + ProblemShape problem_size; + typename CollectiveMainloop::Params mainloop; + typename CollectiveEpilogue::Params epilogue; + typename TileScheduler::Params tile_scheduler; + }; + + using PipelineParams = typename MainloopPipeline::Params; + using PipelineState = typename cutlass::PipelineState; + using PipelineParamsQ = typename MainloopPipelineQ::Params; + using PipelineStateQ = typename cutlass::PipelineState; + + static const int MinBlocksPerMultiprocessor = kBlocksPerSM; + static const int MaxThreadsPerBlock = CollectiveMainloop::MaxThreadsPerBlock; + using ArchTag = cutlass::arch::Sm90; + + static size_t get_workspace_size(Arguments const& args) { return 0; } + static cutlass::Status initialize_workspace(Arguments const&, void*, cudaStream_t) { + return cutlass::Status::kSuccess; + } + + static bool can_implement(Arguments const& args) { + return CollectiveMainloop::can_implement(args.problem_size, args.mainloop); + } + + static dim3 get_grid_shape(Params const& params) { + return TileScheduler::get_grid_shape(params.tile_scheduler); + } + + static dim3 get_block_shape() { + dim3 block(MaxThreadsPerBlock, 1, 1); + return block; + } + + static Params to_underlying_arguments(Arguments const& args, void* workspace) { + return Params{ + args.problem_size, + CollectiveMainloop::to_underlying_arguments(args.problem_size, args.mainloop, workspace), + CollectiveEpilogue::to_underlying_arguments(args.problem_size, args.epilogue, workspace), + TileScheduler::to_underlying_arguments(args.problem_size, args.hw_info, ClusterShape{}, TileShape{}) + }; + } + + CUTLASS_DEVICE void operator()(const Params ¶ms, char* smem) { +#if ! defined(CUTLASS_ARCH_MMA_SM90A_ENABLED) + printf("ERROR : Arch conditional MMA instruction used without targeting appropriate compute capability. 
Aborting.\n"); +#else + TileScheduler tile_scheduler{params.tile_scheduler}; + + // Shared memory. + auto& storage = *reinterpret_cast(smem); + + int thread_idx = int(threadIdx.x); + + uint32_t block_rank_in_cluster = cute::block_rank_in_cluster(); + + int warp_idx = cutlass::canonical_warp_idx_sync(); + int warp_group_thread_idx = thread_idx % cutlass::NumThreadsPerWarpGroup; + int lane_predicate = cute::elect_one_sync(); + + // Issue Tma Descriptor Prefetch from a single thread + if ((warp_idx == 0) && lane_predicate) { + CollectiveMainloop::prefetch_tma_descriptors(params.mainloop); + } + + + PipelineParamsQ pipeline_params_q; + pipeline_params_q.transaction_bytes = size(SmemLayoutQ{}(_,_,_0{})) * sizeof(Element); // Q + pipeline_params_q.role = MainloopPipelineQ::ThreadCategory::ProducerConsumer; + pipeline_params_q.is_leader = warp_group_thread_idx == 0; + pipeline_params_q.num_consumers = cutlass::NumThreadsPerWarpGroup; + + PipelineParams pipeline_params; + pipeline_params.transaction_bytes = size(SmemLayoutK{}(_,_,_0{})) * sizeof(Element); // KV + pipeline_params.role = MainloopPipeline::ThreadCategory::ProducerConsumer; + pipeline_params.is_leader = warp_group_thread_idx == 0; + pipeline_params.num_consumers = cutlass::NumThreadsPerWarpGroup; + + MainloopPipelineQ pipeline_q(storage.pipeline_storage_q, pipeline_params_q, Shape<_1, _1, _1>{}); + MainloopPipeline pipeline(storage.pipeline_storage, pipeline_params, ClusterShape{}); + + using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline; + typename EpiLoadPipeline::Params epi_load_pipeline_params; + epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::ProducerConsumer; + epi_load_pipeline_params.dst_blockid = cute::block_rank_in_cluster(); + epi_load_pipeline_params.producer_arv_count = NumThreadsPerWarp; + epi_load_pipeline_params.consumer_arv_count = NumThreadsPerWarpGroup; + epi_load_pipeline_params.transaction_bytes = CollectiveEpilogue::TmaTransactionBytes; + EpiLoadPipeline epi_load_pipeline(storage.epi_load, epi_load_pipeline_params); + + // State variables used for iterating the circular buffer + // smem_pipe_read / release is used by the consumer of SMEM data - i.e MMA + // smem_pipe_write is used by the producer of SMEM data - i.e TMA + PipelineState smem_pipe_read; + PipelineState smem_pipe_write = cutlass::make_producer_start_state(); + + PipelineStateQ smem_pipe_read_q; + PipelineStateQ smem_pipe_write_q = cutlass::make_producer_start_state(); + + // We need this to guarantee that the Pipeline init is visible + // To all producers and consumer blocks in the Cluster + // and to finish smem init + if constexpr (size(ClusterShape{}) > 1) { + cute::cluster_arrive_relaxed(); + cute::cluster_wait(); + } + else { + __syncthreads(); + } + + auto blk_coord = tile_scheduler.get_block_coord(); + + CollectiveMainloop collective_mainloop; + auto result = collective_mainloop.compute( + block_rank_in_cluster, + blk_coord, params.mainloop, params.problem_size, + pipeline, smem_pipe_read, smem_pipe_write, + pipeline_q, smem_pipe_read_q, smem_pipe_write_q, + storage.mainloop + ); + + CollectiveEpilogue epilogue; + epilogue(typename CollectiveMainloop::TileShapePV{}, blk_coord, + result, typename CollectiveMainloop::TiledMmaPV{}, + params.problem_size, params.epilogue, + epi_load_pipeline, storage.epilogue); +#endif + } +}; + +} // namespace cutlass::fmha::kernel diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/kernel/fmha_kernel_tma_warpspecialized.hpp 
b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/kernel/fmha_kernel_tma_warpspecialized.hpp new file mode 100644 index 0000000000000000000000000000000000000000..4b96d71135f7396af3f9295f8fcaa3d4528fc381 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/kernel/fmha_kernel_tma_warpspecialized.hpp @@ -0,0 +1,422 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/arch/reg_reconfig.h" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/arch/arch.h" + +#include "../kernel/fmha_options.hpp" + +namespace cutlass::fmha::kernel { + +using namespace cute; + +template< + class CollectiveMainloop, + class CollectiveEpilogue, + class TileScheduler, + class... 
Options +> +struct FmhaKernelTmaWarpSpecialized { + + // Options + static constexpr bool kIsEpilogueLocked = find_option_t::value; + static constexpr bool kLoadsQSeparately = find_option_t::value; + + + static const int NumLoadWarpGroups = 1; + static constexpr int NumMmaWarpGroups = CollectiveMainloop::NumMmaWarpGroups; + + using TileShape = typename CollectiveMainloop::TileShape; + using ClusterShape = typename CollectiveMainloop::ClusterShape; + + using MainloopPipelineOuter = typename CollectiveMainloop::MainloopPipelineQ; + using MainloopPipelineInner = typename CollectiveMainloop::MainloopPipeline; + using MainloopPipelineReducer = cutlass::PipelineAsync<2>; + + static constexpr uint32_t StagesPerMathWarpGroup = 2; + using MathWarpGroupOrderBarrier = cutlass::OrderedSequenceBarrier< + StagesPerMathWarpGroup, NumMmaWarpGroups>; + + struct TensorStorageStruct { + typename CollectiveMainloop::SharedStorage mainloop; + typename CollectiveEpilogue::TensorStorage epilogue[NumMmaWarpGroups]; + }; + union TensorStorageUnion { + typename CollectiveMainloop::SharedStorage mainloop; + typename CollectiveEpilogue::TensorStorage epilogue[NumMmaWarpGroups]; + }; + using TensorStorage = std::conditional_t; + + struct SharedStorage { + TensorStorage tensors; + + using PipelineStorageInner = typename MainloopPipelineInner::SharedStorage; + using PipelineStorageOuter = typename MainloopPipelineOuter::SharedStorage; + using PipelineStorageReducer = typename MainloopPipelineReducer::SharedStorage; + + alignas(16) PipelineStorageInner pipeline_storage_inner; + alignas(16) PipelineStorageOuter pipeline_storage_outer; + alignas(16) PipelineStorageReducer pipeline_storage_reducer; + + using MathWarpGroupOrderBarrierStorage = typename MathWarpGroupOrderBarrier::SharedStorage; + alignas(16) MathWarpGroupOrderBarrierStorage math_wg_order; + + alignas(16) cutlass::arch::ClusterBarrier load_warp_barrier; + + using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage; + alignas(16) EpiLoadPipelineStorage epi_load; + }; + + static constexpr int SharedStorageSize = sizeof(SharedStorage); + + using ProblemShape = cute::tuple; + + struct Arguments { + ProblemShape problem_size; + typename CollectiveMainloop::Arguments mainloop; + typename CollectiveEpilogue::Arguments epilogue; + KernelHardwareInfo hw_info; + }; + + struct Params { + ProblemShape problem_size; + typename CollectiveMainloop::Params mainloop; + typename CollectiveEpilogue::Params epilogue; + typename TileScheduler::Params tile_scheduler; + }; + + using PipelineParamsInner = typename MainloopPipelineInner::Params; + using PipelineStateInner = typename cutlass::PipelineState; + using PipelineParamsOuter = typename MainloopPipelineOuter::Params; + using PipelineStateOuter = typename cutlass::PipelineState; + using PipelineParamsReducer = typename MainloopPipelineReducer::Params; + using PipelineStateReducer = typename cutlass::PipelineState; + + static const int MinBlocksPerMultiprocessor = 1; + static const int MaxThreadsPerBlock = (NumMmaWarpGroups + NumLoadWarpGroups) * cutlass::NumThreadsPerWarpGroup; + using ArchTag = cutlass::arch::Sm90; + + static constexpr uint32_t LoadRegisterRequirement = 40 - 2 * 8; + static constexpr uint32_t TotalRegisterSupply = (64*1024 / MaxThreadsPerBlock / MinBlocksPerMultiprocessor / 8) * 8 * MaxThreadsPerBlock / cutlass::NumThreadsPerWarpGroup; + static constexpr uint32_t MmaRegisterRequirement = ((TotalRegisterSupply - LoadRegisterRequirement) / NumMmaWarpGroups / 8) * 8; + + static size_t 
get_workspace_size(Arguments const& args) { return 0; } + static cutlass::Status initialize_workspace(Arguments const&, void*, cudaStream_t) { + return cutlass::Status::kSuccess; + } + + static bool can_implement(Arguments const& args) { + return CollectiveMainloop::can_implement(args.problem_size, args.mainloop); + } + + static dim3 get_grid_shape(Params const& params) { + return TileScheduler::get_grid_shape(params.tile_scheduler); + } + + static dim3 get_block_shape() { + dim3 block(MaxThreadsPerBlock, 1, 1); + return block; + } + + static Params to_underlying_arguments(Arguments const& args, void* workspace) { + return Params{ + args.problem_size, + CollectiveMainloop::to_underlying_arguments(args.problem_size, args.mainloop, workspace), + CollectiveEpilogue::to_underlying_arguments(args.problem_size, args.epilogue, workspace), + TileScheduler::to_underlying_arguments(args.problem_size, args.hw_info, ClusterShape{}, TileShape{}) + }; + } + + CUTLASS_DEVICE void operator()(const Params ¶ms, char* smem) { +#if ! defined(CUTLASS_ARCH_MMA_SM90A_ENABLED) + printf("ERROR : Arch conditional MMA instruction used without targeting appropriate compute capability. Aborting.\n"); +#else + + enum class WarpGroupRole { + Producer = 0, + Consumer0 = 1, + Consumer1 = 2, + Consumer2 = 3, + Consumer3 = 4, + }; + enum class ProducerWarpRole { + LoadKV = 1, + Reducer = 0, + MaybeLoadQ = 2, // is kLoadsQSeparately is true, this warp loads Q (otherwise warp 0 does it) + MainloopEpilogue = 3, + }; + + static constexpr ProducerWarpRole WarpRoleLoadQ = kLoadsQSeparately ? ProducerWarpRole::MaybeLoadQ : ProducerWarpRole::LoadKV; + + TileScheduler tile_scheduler{params.tile_scheduler}; + + // Shared memory. + auto& storage = *reinterpret_cast(smem); + + int lane_idx = cutlass::canonical_lane_idx(); + int warp_idx = cutlass::canonical_warp_idx_sync(); + int warp_idx_in_warp_group = warp_idx % cutlass::NumWarpsPerWarpGroup; + int warp_group_idx = cutlass::canonical_warp_group_idx(); + auto warp_group_role = WarpGroupRole(warp_group_idx); + auto producer_warp_role = ProducerWarpRole(warp_idx_in_warp_group); + int consumer_warp_group_idx = warp_group_idx - (int) WarpGroupRole::Consumer0; + int lane_predicate = cute::elect_one_sync(); + uint32_t block_rank_in_cluster = cute::block_rank_in_cluster(); + + // Issue Tma Descriptor Prefetch from a single thread + if ((warp_idx == 0) && lane_predicate) { + CollectiveMainloop::prefetch_tma_descriptors(params.mainloop); + } + + PipelineParamsOuter pipeline_params_outer; + pipeline_params_outer.transaction_bytes = CollectiveMainloop::kOuterLoadBytes; + pipeline_params_outer.is_leader = lane_predicate && (producer_warp_role == WarpRoleLoadQ); + pipeline_params_outer.num_consumers = cutlass::NumThreadsPerWarpGroup; + + PipelineParamsInner pipeline_params_inner; + pipeline_params_inner.transaction_bytes = CollectiveMainloop::kInnerLoadBytes; + pipeline_params_inner.is_leader = lane_predicate && (producer_warp_role == ProducerWarpRole::LoadKV); + pipeline_params_inner.num_consumers = NumMmaWarpGroups * cutlass::NumThreadsPerWarpGroup; + + PipelineParamsReducer pipeline_params_reducer; + pipeline_params_reducer.producer_arv_count = NumMmaWarpGroups * cutlass::NumThreadsPerWarpGroup; + pipeline_params_reducer.consumer_arv_count = cutlass::NumThreadsPerWarp; + + using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline; + typename EpiLoadPipeline::Params epi_load_pipeline_params; + + if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == 
ProducerWarpRole::MainloopEpilogue) { + epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Producer; + } + if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::LoadKV) { + pipeline_params_inner.role = MainloopPipelineInner::ThreadCategory::Producer; + } + if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == WarpRoleLoadQ) { + pipeline_params_outer.role = MainloopPipelineOuter::ThreadCategory::Producer; + } + if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::Reducer) { + pipeline_params_reducer.role = MainloopPipelineReducer::ThreadCategory::Consumer; + } + if (warp_group_role == WarpGroupRole::Consumer0 || + warp_group_role == WarpGroupRole::Consumer1 || + warp_group_role == WarpGroupRole::Consumer2 || + warp_group_role == WarpGroupRole::Consumer3 + ) { + pipeline_params_inner.role = MainloopPipelineInner::ThreadCategory::Consumer; + pipeline_params_outer.role = MainloopPipelineOuter::ThreadCategory::Consumer; + pipeline_params_reducer.role = MainloopPipelineReducer::ThreadCategory::Producer; + epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Consumer; + } + + MainloopPipelineOuter pipeline_outer(storage.pipeline_storage_outer, pipeline_params_outer, Shape<_1, _1, _1>{}); + MainloopPipelineInner pipeline_inner(storage.pipeline_storage_inner, pipeline_params_inner, ClusterShape{}); + MainloopPipelineReducer pipeline_reducer(storage.pipeline_storage_reducer, pipeline_params_reducer); + + // State variables used for iterating the circular buffer + // smem_pipe_read / release is used by the consumer of SMEM data - i.e MMA + // smem_pipe_write is used by the producer of SMEM data - i.e TMA + PipelineStateInner smem_pipe_read_inner; + PipelineStateInner smem_pipe_write_inner = cutlass::make_producer_start_state(); + + PipelineStateOuter smem_pipe_read_outer; + PipelineStateOuter smem_pipe_write_outer = cutlass::make_producer_start_state(); + + PipelineStateReducer smem_pipe_read_reducer; + PipelineStateReducer smem_pipe_write_reducer = cutlass::make_producer_start_state(); + + typename MathWarpGroupOrderBarrier::Params params_math_wg_order_barrier; + // DMA Load WG will not participate in these Ordered Barrier syncs + params_math_wg_order_barrier.group_id = consumer_warp_group_idx; + params_math_wg_order_barrier.group_size = cutlass::NumThreadsPerWarpGroup; // Number of threads / participants in a group + MathWarpGroupOrderBarrier math_wg_order_barrier(storage.math_wg_order, params_math_wg_order_barrier); + + // Epilogue Load pipeline + epi_load_pipeline_params.dst_blockid = cute::block_rank_in_cluster(); + epi_load_pipeline_params.producer_arv_count = NumThreadsPerWarp; + epi_load_pipeline_params.consumer_arv_count = NumThreadsPerWarpGroup; + epi_load_pipeline_params.transaction_bytes = CollectiveEpilogue::TmaTransactionBytes; + EpiLoadPipeline epi_load_pipeline(storage.epi_load, epi_load_pipeline_params); + + if constexpr (kLoadsQSeparately) { + if ((warp_idx == 0) && lane_predicate) { + storage.load_warp_barrier.init(2 * cutlass::NumThreadsPerWarp); + } + cutlass::arch::fence_barrier_init(); + } + + // We need this to guarantee that the Pipeline init is visible + // To all producers and consumer blocks in the Cluster + // and to finish smem init + if constexpr (size(ClusterShape{}) > 1) { + cute::cluster_arrive_relaxed(); + cute::cluster_wait(); + } + else { + __syncthreads(); + } + + CollectiveMainloop collective_mainloop; + + if (warp_group_role == 
WarpGroupRole::Producer) { + cutlass::arch::warpgroup_reg_dealloc(); + if (producer_warp_role == ProducerWarpRole::LoadKV) { + bool do_barrier = kLoadsQSeparately; + + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + collective_mainloop.template load_kv_maybe_q( + block_rank_in_cluster, + blk_coord, params.mainloop, params.problem_size, + pipeline_inner, smem_pipe_write_inner, + pipeline_outer, smem_pipe_write_outer, + storage.tensors.mainloop, + storage.load_warp_barrier, do_barrier + ); + do_barrier = false; + } + } + else if (kLoadsQSeparately && (producer_warp_role == ProducerWarpRole::MaybeLoadQ)) { + bool do_barrier = true; + + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + collective_mainloop.load_maybe_q( + blk_coord, params.mainloop, params.problem_size, + pipeline_outer, smem_pipe_write_outer, + storage.tensors.mainloop, + storage.load_warp_barrier, do_barrier + ); + do_barrier = false; + } + } else if (producer_warp_role == ProducerWarpRole::Reducer) { + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + collective_mainloop.reduce( + blk_coord, params.mainloop, params.problem_size, + pipeline_reducer, smem_pipe_read_reducer, + storage.tensors.mainloop + ); + } + } + } + else if ( + warp_group_role == WarpGroupRole::Consumer0 || + warp_group_role == WarpGroupRole::Consumer1 || + warp_group_role == WarpGroupRole::Consumer2 || + warp_group_role == WarpGroupRole::Consumer3 + ) { + cutlass::arch::warpgroup_reg_alloc(); + CUTLASS_PRAGMA_NO_UNROLL + for (; tile_scheduler.is_valid(); ++tile_scheduler) { + auto blk_coord = tile_scheduler.get_block_coord(); + auto wg_coord = blk_coord; + + constexpr int kOuterLoads = CollectiveMainloop::kOuterLoads; + + if (warp_group_role == WarpGroupRole::Consumer0) { + smem_pipe_read_outer.advance(0 * kOuterLoads); + } + else if (warp_group_role == WarpGroupRole::Consumer1) { + smem_pipe_read_outer.advance(1 * kOuterLoads); + } + else if (warp_group_role == WarpGroupRole::Consumer2) { + smem_pipe_read_outer.advance(2 * kOuterLoads); + } + else if (warp_group_role == WarpGroupRole::Consumer3) { + smem_pipe_read_outer.advance(3 * kOuterLoads); + } + + constexpr int wg_dim = is_constant<0, decltype(get<1>(wg_coord))>::value ? 
0 : 1; + auto& wg_block = get(wg_coord); + if (warp_group_role == WarpGroupRole::Consumer0) { + wg_block = NumMmaWarpGroups * wg_block + 0; + } + else if (warp_group_role == WarpGroupRole::Consumer1) { + wg_block = NumMmaWarpGroups * wg_block + 1; + } + else if (warp_group_role == WarpGroupRole::Consumer2) { + wg_block = NumMmaWarpGroups * wg_block + 2; + } + else if (warp_group_role == WarpGroupRole::Consumer3) { + wg_block = NumMmaWarpGroups * wg_block + 3; + } + + auto result = collective_mainloop.compute( + blk_coord, wg_coord, + params.mainloop, params.problem_size, + pipeline_inner, smem_pipe_read_inner, + pipeline_outer, smem_pipe_read_outer, + pipeline_reducer, smem_pipe_write_reducer, + storage.tensors.mainloop, + math_wg_order_barrier + ); + + if (warp_group_role == WarpGroupRole::Consumer0) { + smem_pipe_read_outer.advance(kOuterLoads * (NumMmaWarpGroups - 0)); + } + if constexpr (NumMmaWarpGroups >= 2) { + if (warp_group_role == WarpGroupRole::Consumer1) { + smem_pipe_read_outer.advance(kOuterLoads * (NumMmaWarpGroups - 1)); + } + } + if constexpr (NumMmaWarpGroups >= 3) { + if (warp_group_role == WarpGroupRole::Consumer2) { + smem_pipe_read_outer.advance(kOuterLoads * (NumMmaWarpGroups - 2)); + } + } + if constexpr (NumMmaWarpGroups >= 4) { + if (warp_group_role == WarpGroupRole::Consumer3) { + smem_pipe_read_outer.advance(kOuterLoads * (NumMmaWarpGroups - 3)); + } + } + + if constexpr (kIsEpilogueLocked) ; math_wg_order_barrier.wait(); + + CollectiveEpilogue epilogue; + epilogue(typename CollectiveMainloop::TileShapePV{}, wg_coord, + result, typename CollectiveMainloop::TiledMmaPV{}, + params.problem_size, params.epilogue, + epi_load_pipeline, storage.tensors.epilogue[consumer_warp_group_idx]); + + if constexpr (kIsEpilogueLocked) ; math_wg_order_barrier.arrive(); + } + } +#endif + } +}; + +} // namespace cutlass::fmha::kernel diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/kernel/fmha_options.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/kernel/fmha_options.hpp new file mode 100644 index 0000000000000000000000000000000000000000..6746bf797c5905cbc005c787852fab6a0c7f1a09 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/kernel/fmha_options.hpp @@ -0,0 +1,83 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" + +namespace cutlass::fmha::kernel { + +template +struct find_option; + +template +struct find_option { + using option_value = Default; +}; + +template +struct find_option : + std::conditional_t< + Option::tag == kTag, + Option, + find_option + > +{}; + +template +using find_option_t = typename find_option::option_value; + +enum class Tag { + kIsPersistent, + kNumMmaWarpGroups, + kLoadsQSeparately, + + kIsMainloopLocked, + kIsEpilogueLocked, + + kStagesQ, + kStagesKV, + + kEpilogueKind, + + kBlocksPerSM, + kClusterM, + + kAccQK +}; + +template +struct Option { + static constexpr auto tag = kTag; + using option_value = Value; +}; + +} // namespace cutlass::fmha::kernel diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/kernel/fmha_tile_scheduler.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/kernel/fmha_tile_scheduler.hpp new file mode 100644 index 0000000000000000000000000000000000000000..0b46e27e142973e675a9824a4eec4e784f1615d7 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/kernel/fmha_tile_scheduler.hpp @@ -0,0 +1,204 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/kernel_hardware_info.h" + +namespace cutlass::fmha::kernel { + +//////////////////////////////////////////////////////////////////////////////// + +struct IndividualTileScheduler { + + struct Params { + dim3 grid; + }; + + bool valid_ = true; + + CUTLASS_DEVICE + IndividualTileScheduler(Params const&) {} + + template + static Params to_underlying_arguments( + ProblemSize const& problem_size, KernelHardwareInfo hw_info, + ClusterShape const& cluster_shape, TileShape const& tile_shape) + { + using namespace cute; + dim3 grid(round_up(ceil_div(size<2>(problem_size), size<0>(tile_shape)), size<0>(cluster_shape)), size<0>(problem_size), size<1>(problem_size)); + return Params{ grid }; + } + + static dim3 get_grid_shape(Params const& params) { + return params.grid; + } + + CUTLASS_DEVICE + bool is_valid() { + return valid_; + } + + CUTLASS_DEVICE + auto get_block_coord() { + using namespace cute; + return make_coord(blockIdx.x, _0{}, make_coord(blockIdx.y, blockIdx.z)); + } + + CUTLASS_DEVICE + IndividualTileScheduler& operator++() { + valid_ = false; + return *this; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +struct PersistentTileScheduler { + + struct Params { + int num_blocks; + FastDivmod divmod_m_block; + FastDivmod divmod_b; + FastDivmod divmod_h; + + KernelHardwareInfo hw_info; + }; + + int block_idx = 0; + Params params; + + CUTLASS_DEVICE + PersistentTileScheduler(Params const& params) : block_idx(blockIdx.x), params(params) {} + + template + static Params to_underlying_arguments( + ProblemSize const& problem_size, KernelHardwareInfo hw_info, + ClusterShape const& cluster_shape, TileShape const& tile_shape) + { + using namespace cute; + // Get SM count if needed, otherwise use user supplied SM count + int sm_count = hw_info.sm_count; + if (sm_count <= 0) { + CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n" + " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count."); + sm_count = KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + } + + CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count); + hw_info.sm_count = sm_count; + + int num_m_blocks = cutlass::round_up(ceil_div(size<2>(problem_size), size<0>(tile_shape)), size<0>(cluster_shape)); + int num_blocks = num_m_blocks * size<0>(problem_size) * size<1>(problem_size); + + return Params { + num_blocks, + { num_m_blocks}, { size<0>(problem_size) }, { size<1>(problem_size) }, + hw_info + }; + } + + static dim3 get_grid_shape(Params const& params) { + dim3 grid(std::min(params.num_blocks, params.hw_info.sm_count), 1, 1); + return grid; + } + + CUTLASS_DEVICE + bool is_valid() { + return block_idx < params.num_blocks; + } + + 
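  // Descriptive sketch of the decode performed by get_block_coord() below (no new logic):
  // a persistent CTA walks the flattened (m_block, batch, head) tile space with a
  // grid-stride loop — operator++ advances block_idx by gridDim.x — and the coordinates
  // are peeled off the linear index with FastDivmod, remainder first, quotient carried on:
  //   m_block = idx % num_m_blocks;  idx /= num_m_blocks;
  //   bidb    = idx % B;             idx /= B;
  //   bidh    = idx % H;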
CUTLASS_DEVICE + auto get_block_coord() { + using namespace cute; + int block_decode = block_idx; + int m_block, bidb, bidh; + params.divmod_m_block(block_decode, m_block, block_decode); + params.divmod_b(block_decode, bidb, block_decode); + params.divmod_h(block_decode, bidh, block_decode); + return make_coord(m_block, _0{}, make_coord(bidb, bidh)); + } + + CUTLASS_DEVICE + PersistentTileScheduler& operator++() { + block_idx += gridDim.x; + return *this; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +template +struct TileSchedulerBwdAdapter { + + using Params = typename Base::Params; + + Base base_; + + CUTLASS_DEVICE + TileSchedulerBwdAdapter(Params const& params) : base_(params) {} + + template + static Params to_underlying_arguments( + ProblemSize const& problem_size, KernelHardwareInfo hw_info, + ClusterShape const& cluster_shape, TileShape const& tile_shape) + { + using namespace cute; + return Base::to_underlying_arguments(select<0,1,3,2,4>(problem_size), hw_info, select<1,0,2>(cluster_shape), select<1,0,2>(tile_shape)); + } + + static dim3 get_grid_shape(Params const& params) { + return Base::get_grid_shape(params); + } + + CUTLASS_DEVICE + bool is_valid() { + return base_.is_valid(); + } + + CUTLASS_DEVICE + auto get_block_coord() { + using namespace cute; + return select<1,0,2>(base_.get_block_coord()); + } + + CUTLASS_DEVICE + TileSchedulerBwdAdapter& operator++() { + ++base_; + return *this; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::fmha::kernel diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/reference/fmha_bwd_reference.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/reference/fmha_bwd_reference.hpp new file mode 100644 index 0000000000000000000000000000000000000000..503fb3aad4976fb02d80312a3ad7e53300c80e6c --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/reference/fmha_bwd_reference.hpp @@ -0,0 +1,357 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + +#pragma once + +#include "cute/tensor.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template< + class ProblemShape, + class TensorQ, class TensorK, class TensorV, + class TensorO, class TensorLSE, class TensorDO, + class TensorDQ, /* class TensorDK, class TensorDV, */ + class Fusion +> +void __global__ fmha_bwd_reference_dQ_kernel( + ProblemShape problem_shape, + TensorQ mQ, TensorK mK, TensorV mV, + TensorO mO, TensorLSE mLSE, TensorDO mDO, + TensorDQ mDQ, /* TensorDK mDK, TensorDV mDV, */ + Fusion fusion +) { + using namespace cute; + + using Element = typename TensorO::value_type; + using ElementAccumulator = typename TensorLSE::value_type; + + extern __shared__ char mS_mem[]; + Element* mS = reinterpret_cast(mS_mem); + + Element softmax_scale = static_cast(1.0 / sqrt(1.0 * size<1>(mO))); + + for (int idx_L = blockIdx.y; idx_L < size<2>(mDQ); idx_L += gridDim.y) { + for (int idx_Q = blockIdx.x; idx_Q < size<0>(mDQ); idx_Q += gridDim.x) { + + for (int idx_K = threadIdx.x; idx_K < size<0>(mK); idx_K += blockDim.x) { + ElementAccumulator acc_qk = 0; + ElementAccumulator acc_dov = 0; + ElementAccumulator acc_doo = 0; + for (int idx_D0 = 0; idx_D0 < size<1>(mK); idx_D0++) { + acc_qk += mQ(idx_Q, idx_D0, idx_L) * mK(idx_K, idx_D0, idx_L); + acc_dov += mDO(idx_Q, idx_D0, idx_L) * mV(idx_K, idx_D0, idx_L); + acc_doo += mDO(idx_Q, idx_D0, idx_L) * mO(idx_Q, idx_D0, idx_L); + } + + auto id = make_identity_tensor(make_shape(1, 1)); + auto frag = make_tensor(Shape<_1, _1>{}); + frag(0) = acc_qk; + fusion.before_softmax(frag, make_tensor(id.data() + make_arithmetic_tuple(idx_Q, idx_K), id.layout()), problem_shape); + acc_qk = frag(0); + + mS[idx_K] = static_cast(exp(softmax_scale * acc_qk - mLSE(idx_Q, idx_L)) * softmax_scale * (acc_dov - acc_doo)); + } + + __syncthreads(); + + for (int idx_D = threadIdx.x; idx_D < size<1>(mDQ); idx_D += blockDim.x) { + ElementAccumulator acc = 0; + for (int idx_K = 0; idx_K < size<0>(mK); idx_K++) { + acc += mS[idx_K] * mK(idx_K, idx_D, idx_L); + } + mDQ(idx_Q, idx_D, idx_L) = acc; + } + } + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template< + class ProblemShape, + class TensorQ, class TensorK, class TensorV, + class TensorO, class TensorLSE, class TensorDO, + /* class TensorDQ, */ class TensorDK, /* class TensorDV, */ + class Fusion +> +void __global__ fmha_bwd_reference_dK_kernel( + ProblemShape problem_shape, + TensorQ mQ, TensorK mK, TensorV mV, + TensorO mO, TensorLSE mLSE, TensorDO mDO, + /* TensorDQ mDQ, */ TensorDK mDK, /* TensorDV mDV, */ + Fusion fusion +) { + using namespace cute; + + using Element = typename TensorO::value_type; + using ElementAccumulator = typename TensorLSE::value_type; + + extern __shared__ char mS_mem[]; + Element* mS = reinterpret_cast(mS_mem); + + Element 
softmax_scale = static_cast(1.0 / sqrt(1.0 * size<1>(mO))); + + for (int idx_L = blockIdx.y; idx_L < size<2>(mDK); idx_L += gridDim.y) { + for (int idx_K = blockIdx.x; idx_K < size<0>(mDK); idx_K += gridDim.x) { + + for (int idx_Q = threadIdx.x; idx_Q < size<0>(mDO); idx_Q += blockDim.x) { + ElementAccumulator acc_qk = 0; + ElementAccumulator acc_dov = 0; + ElementAccumulator acc_doo = 0; + for (int idx_D0 = 0; idx_D0 < size<1>(mK); idx_D0++) { + acc_qk += mQ(idx_Q, idx_D0, idx_L) * mK(idx_K, idx_D0, idx_L); + acc_dov += mDO(idx_Q, idx_D0, idx_L) * mV(idx_K, idx_D0, idx_L); + acc_doo += mDO(idx_Q, idx_D0, idx_L) * mO(idx_Q, idx_D0, idx_L); + } + + auto id = make_identity_tensor(make_shape(1, 1)); + auto frag = make_tensor(Shape<_1, _1>{}); + frag(0) = acc_qk; + fusion.before_softmax(frag, make_tensor(id.data() + make_arithmetic_tuple(idx_Q, idx_K), id.layout()), problem_shape); + acc_qk = frag(0); + + mS[idx_Q] = static_cast(exp(softmax_scale * acc_qk - mLSE(idx_Q, idx_L)) * softmax_scale * (acc_dov - acc_doo)); + } + + __syncthreads(); + + for (int idx_D = threadIdx.x; idx_D < size<1>(mDK); idx_D += blockDim.x) { + ElementAccumulator acc = 0; + for (int idx_Q = 0; idx_Q < size<0>(mDO); idx_Q++) { + acc += mS[idx_Q] * mQ(idx_Q, idx_D, idx_L); + } + mDK(idx_K, idx_D, idx_L) = acc; + } + } + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template< + class ProblemShape, + class TensorQ, class TensorK, class TensorV, + class TensorO, class TensorLSE, class TensorDO, + /* class TensorDQ, class TensorDK, */ class TensorDV, + class Fusion +> +void __global__ fmha_bwd_reference_dV_kernel( + ProblemShape problem_shape, + TensorQ mQ, TensorK mK, TensorV mV, + TensorO mO, TensorLSE mLSE, TensorDO mDO, + /* TensorDQ mDQ, TensorDK mDK, */ TensorDV mDV, + Fusion fusion +) { + using namespace cute; + + using Element = typename TensorO::value_type; + using ElementAccumulator = typename TensorLSE::value_type; + + extern __shared__ char mS_mem[]; + Element* mS = reinterpret_cast(mS_mem); + + Element softmax_scale = static_cast(1.0 / sqrt(1.0 * size<1>(mO))); + + for (int idx_L = blockIdx.y; idx_L < size<2>(mDV); idx_L += gridDim.y) { + for (int idx_K = blockIdx.x; idx_K < size<0>(mDV); idx_K += gridDim.x) { + + for (int idx_Q = threadIdx.x; idx_Q < size<0>(mDO); idx_Q += blockDim.x) { + ElementAccumulator acc_qk = 0; + for (int idx_D0 = 0; idx_D0 < size<1>(mK); idx_D0++) { + acc_qk += mQ(idx_Q, idx_D0, idx_L) * mK(idx_K, idx_D0, idx_L); + } + + auto id = make_identity_tensor(make_shape(1, 1)); + auto frag = make_tensor(Shape<_1, _1>{}); + frag(0) = acc_qk; + fusion.before_softmax(frag, make_tensor(id.data() + make_arithmetic_tuple(idx_Q, idx_K), id.layout()), problem_shape); + acc_qk = frag(0); + + mS[idx_Q] = static_cast(exp(softmax_scale * acc_qk - mLSE(idx_Q, idx_L))); + } + + __syncthreads(); + + for (int idx_D = threadIdx.x; idx_D < size<1>(mDV); idx_D += blockDim.x) { + ElementAccumulator acc = 0; + for (int idx_Q = 0; idx_Q < size<0>(mDO); idx_Q++) { + acc += mS[idx_Q] * mDO(idx_Q, idx_D, idx_L); + } + mDV(idx_K, idx_D, idx_L) = acc; + } + } + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template< + class ProblemShape, + class TensorQ, class TensorK, class TensorV, + class TensorO, class TensorLSE, class TensorDO, + /**/ class TensorDQ, /** / class TensorDK, / ** / class TensorDV, / **/ + class Fusion +> +void fmha_bwd_reference_dQ( + ProblemShape problem_shape, + TensorQ mQ, 
TensorK mK, TensorV mV, + TensorO mO, TensorLSE mLSE, TensorDO mDO, + /**/ TensorDQ mDQ, /** / TensorDK mDK, / ** / TensorDV mDV, / **/ + Fusion fusion +) { + using namespace cute; + + dim3 grid(size<0>(mDQ), size<2>(mDQ), 1); + dim3 block(256); + int shared_mem = size<0>(mK) * sizeof(typename TensorO::value_type); + + if (shared_mem >= (48 << 10)) { + CUTLASS_TRACE_HOST(" Setting smem size to " << shared_mem); + auto result = cudaFuncSetAttribute( + fmha_bwd_reference_dQ_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + shared_mem); + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST( + " cudaFuncSetAttribute() returned error: " + << cudaGetErrorString(result)); + return; + } + } + + fmha_bwd_reference_dQ_kernel<<>>(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDQ, fusion); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template< + class ProblemShape, + class TensorQ, class TensorK, class TensorV, + class TensorO, class TensorLSE, class TensorDO, + /** / class TensorDQ, / **/ class TensorDK, /** / class TensorDV, / **/ + class Fusion +> +void fmha_bwd_reference_dK( + ProblemShape problem_shape, + TensorQ mQ, TensorK mK, TensorV mV, + TensorO mO, TensorLSE mLSE, TensorDO mDO, + /** / TensorDQ mDQ, / **/ TensorDK mDK, /** / TensorDV mDV, / **/ + Fusion fusion +) { + using namespace cute; + + dim3 grid(size<0>(mDK), size<2>(mDK), 1); + dim3 block(256); + int shared_mem = size<0>(mDO) * sizeof(typename TensorO::value_type); + + if (shared_mem >= (48 << 10)) { + CUTLASS_TRACE_HOST(" Setting smem size to " << shared_mem); + auto result = cudaFuncSetAttribute( + fmha_bwd_reference_dK_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + shared_mem); + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST( + " cudaFuncSetAttribute() returned error: " + << cudaGetErrorString(result)); + return; + } + } + + fmha_bwd_reference_dK_kernel<<>>(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDK, fusion); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template< + class ProblemShape, + class TensorQ, class TensorK, class TensorV, + class TensorO, class TensorLSE, class TensorDO, + /** / class TensorDQ, / ** / class TensorDK, / **/ class TensorDV, /**/ + class Fusion +> +void fmha_bwd_reference_dV( + ProblemShape problem_shape, + TensorQ mQ, TensorK mK, TensorV mV, + TensorO mO, TensorLSE mLSE, TensorDO mDO, + /** / TensorDQ mDQ, / ** / TensorDK mDK, / **/ TensorDV mDV, /**/ + Fusion fusion +) { + using namespace cute; + + dim3 grid(size<0>(mDV), size<2>(mDV), 1); + dim3 block(256); + int shared_mem = size<0>(mDO) * sizeof(typename TensorO::value_type); + + if (shared_mem >= (48 << 10)) { + CUTLASS_TRACE_HOST(" Setting smem size to " << shared_mem); + auto result = cudaFuncSetAttribute( + fmha_bwd_reference_dV_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + shared_mem); + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST( + " cudaFuncSetAttribute() returned error: " + << cudaGetErrorString(result)); + return; + } + } + + fmha_bwd_reference_dV_kernel<<>>(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDV, fusion); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template< + class ProblemShape, + class TensorQ, class TensorK, class TensorV, + class TensorO, class 
TensorLSE, class TensorDO, + class TensorDQ, class TensorDK, class TensorDV, + class Fusion +> +void fmha_bwd_reference( + ProblemShape problem_shape, + TensorQ mQ, TensorK mK, TensorV mV, + TensorO mO, TensorLSE mLSE, TensorDO mDO, + TensorDQ mDQ, TensorDK mDK, TensorDV mDV, + Fusion fusion +) { + fmha_bwd_reference_dQ(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDQ, fusion); + fmha_bwd_reference_dK(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDK, fusion); + fmha_bwd_reference_dV(problem_shape, mQ, mK, mV, mO, mLSE, mDO, mDV, fusion); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/reference/fmha_reference.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/reference/fmha_reference.hpp new file mode 100644 index 0000000000000000000000000000000000000000..94f443713d6ccbbd05e36d0ace33f8534fd8db22 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/reference/fmha_reference.hpp @@ -0,0 +1,156 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include "cute/tensor.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template< + class ProblemShape, + class TensorQ, + class TensorK, + class TensorV, + class TensorO, + class TensorLSE, + class Fusion +> +void __global__ fmha_reference_kernel( + ProblemShape problem_shape, + TensorQ mQ, TensorK mK, TensorV mV, + TensorO mO, TensorLSE mLSE, + Fusion fusion +) { + using namespace cute; + + using Element = typename TensorO::value_type; + using ElementAccumulator = typename TensorLSE::value_type; + + extern __shared__ char mS_mem[]; + Element* mS = reinterpret_cast(mS_mem); + + ElementAccumulator softmax_scale = static_cast(1.0 / sqrt(1.0 * size<1>(mO))); + + auto id = make_identity_tensor(make_shape(1, 1)); + for (int idx_L = blockIdx.y; idx_L < size<2>(mO); idx_L += gridDim.y) { + for (int idx_Q = blockIdx.x; idx_Q < size<0>(mO); idx_Q += gridDim.x) { + for (int idx_K = threadIdx.x; idx_K < size<0>(mK); idx_K += blockDim.x) { + ElementAccumulator acc = 0; + for (int idx_D = 0; idx_D < size<1>(mK); idx_D++) { + acc += mQ(idx_Q, idx_D, idx_L) * mK(idx_K, idx_D, idx_L); + } + auto frag = make_tensor(Shape<_1, _1>{}); + frag(0) = acc; + fusion.before_softmax(frag, make_tensor(id.data() + make_arithmetic_tuple(idx_Q, idx_K), id.layout()), problem_shape); + mS[idx_K] = static_cast(frag(0) * softmax_scale); + } + + __syncthreads(); + + ElementAccumulator maxS = -std::numeric_limits::infinity(); + for (int idx_K = 0; idx_K < size<0>(mK); idx_K++) { + maxS = std::max(maxS, mS[idx_K]); + } + if (maxS == -std::numeric_limits::infinity()) maxS = 0; + + __syncthreads(); + + for (int idx_K = threadIdx.x; idx_K < size<0>(mK); idx_K += blockDim.x) { + mS[idx_K] = static_cast(exp(mS[idx_K] - maxS)); + } + + __syncthreads(); + + ElementAccumulator sum = 0; + for (int idx_K = 0; idx_K < size<0>(mK); idx_K++) { + sum += mS[idx_K]; + } + + Element scale = static_cast(1.0 / sum); + + for (int idx_D = threadIdx.x; idx_D < size<1>(mO); idx_D += blockDim.x) { + ElementAccumulator acc = 0; + for (int idx_K = 0; idx_K < size<0>(mK); idx_K++) { + acc += mS[idx_K] * mV(idx_K, idx_D, idx_L) * scale; + } + mO(idx_Q, idx_D, idx_L) = static_cast(acc); + } + + if (threadIdx.x == 0) { + mLSE(idx_Q, idx_L) = log(sum) + maxS; + } + + } + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template< + class ProblemShape, + class TensorQ, + class TensorK, + class TensorV, + class TensorO, + class TensorLSE, + class Fusion +> +void fmha_reference( + ProblemShape problem_shape, + TensorQ mQ, TensorK mK, TensorV mV, + TensorO mO, TensorLSE mLSE, + Fusion fusion +) { + using namespace cute; + + dim3 grid(size<0>(mO), size<2>(mO), 1); + dim3 block(256); + int shared_mem = size<0>(mK) * sizeof(typename TensorO::value_type); + + if (shared_mem >= (48 << 10)) { + CUTLASS_TRACE_HOST(" Setting smem size to " << shared_mem); + auto result = cudaFuncSetAttribute( + fmha_reference_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + shared_mem); + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST( + " cudaFuncSetAttribute() returned error: " + << cudaGetErrorString(result)); + return; + } + } + + fmha_reference_kernel<<>>(problem_shape, mQ, mK, mV, mO, mLSE, fusion); +} + 
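For orientation, the kernel above is the plain, numerically stable softmax-attention reference. In math terms (notation mine, not from the source), for each batch index l and query row q it produces

    S_{qk} = \frac{1}{\sqrt{D}} \sum_{d} Q_{qd} K_{kd}, \qquad
    O_{qd} = \sum_{k} \frac{e^{S_{qk} - \max_{k'} S_{qk'}}}{\sum_{k'} e^{S_{qk'} - \max_{k''} S_{qk''}}}\, V_{kd}, \qquad
    \mathrm{LSE}_{q} = \log \sum_{k} e^{S_{qk}},

where D = size<1>(mO) is the head dimension, the max-subtraction corresponds to the maxS bookkeeping in the kernel, and LSE is materialized as log(sum) + maxS. The fusion.before_softmax hook (e.g. masking) is applied to the raw dot products before the 1/sqrt(D) scaling.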
+///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/reference/reference_abs_error.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/reference/reference_abs_error.hpp new file mode 100644 index 0000000000000000000000000000000000000000..5eb2413afe1db1c5ff87d1bda32c9815c51d39e4 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/88_hopper_fmha/reference/reference_abs_error.hpp @@ -0,0 +1,129 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + + +#pragma once + +#include +#include "cutlass/util/device_memory.h" + +template +__global__ void reference_abs_diff_kernel( + Element* data, Element* data_ref, size_t count, + double* max_diff, double* sum_diff, + bool print_diff +) { + double thread_max_diff = 0; + double thread_sum_diff = 0; + + __shared__ double block_max_diff; + __shared__ double block_sum_diff; + + for (size_t i = threadIdx.x + blockIdx.x * blockDim.x; i < count; i += blockDim.x * gridDim.x) { + double diff = fabs(data[i] - data_ref[i]); + if (print_diff) if (diff != diff || diff > 0.01f) printf("difference at %lld: %f ... 
%f vs %f\n", static_cast(i), diff, (double)data[i], (double)data_ref[i]); + thread_max_diff = fmax(diff, thread_max_diff); + thread_sum_diff += diff; + } + + for (int i = 0; i < blockDim.x; i++) { + if (i == threadIdx.x) { + if (i == 0) { + block_max_diff = thread_max_diff; + block_sum_diff = thread_sum_diff; + } else { + block_max_diff = fmax(block_max_diff, thread_max_diff); + block_sum_diff += thread_sum_diff; + } + } + __syncthreads(); + } + + if (threadIdx.x == 0) { + atomicAdd(sum_diff, block_sum_diff); + + for (;;) { + unsigned long long prev = *reinterpret_cast(max_diff); + double prev_diff = reinterpret_cast(prev); + double new_max_diff = fmax(block_max_diff, prev_diff); + unsigned long long found = atomicCAS(reinterpret_cast(max_diff), prev, reinterpret_cast(new_max_diff)); + if (found == prev) break; + } + } +} + +template +void reference_abs_diff( + cutlass::DeviceAllocation const& data, + cutlass::DeviceAllocation const& data_ref, + double& max_diff, double& mean_diff +) { + static bool kPrintDiff = getenv("REF_PRINT_DIFF") && atoi(getenv("REF_PRINT_DIFF")) == 1; + + cutlass::DeviceAllocation result; + result.reset(2); + assert(data.size() == data_ref.size()); + + cudaError_t err = cudaMemset(result.get(), 0, result.size() * sizeof(double)); + if (err != cudaSuccess) { + std::cerr << "Memset failed. Last CUDA error: " + << cudaGetErrorString(err) << std::endl; + max_diff = mean_diff = 1e20; + return; + } + + dim3 block(256, 1, 1); + dim3 grid(1024, 1, 1); + reference_abs_diff_kernel<<>>( + data.get(), data_ref.get(), data.size(), + result.get(), result.get() + 1, kPrintDiff); + + err = cudaDeviceSynchronize(); + if (err != cudaSuccess) { + std::cerr << "Difference kernel failed. Last CUDA error: " + << cudaGetErrorString(err) << std::endl; + max_diff = mean_diff = 1e20; + return; + } + + double result_host[2]; + err = cudaMemcpy(result_host, result.get(), result.size() * sizeof(double), cudaMemcpyDefault); + if (err != cudaSuccess) { + std::cerr << "Copy failed. Last CUDA error: " + << cudaGetErrorString(err) << std::endl; + max_diff = mean_diff = 1e20; + return; + } + + max_diff = result_host[0]; + mean_diff = result_host[1] / static_cast(data.size()); +} diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/common/dist_gemm_helpers.h b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/common/dist_gemm_helpers.h new file mode 100644 index 0000000000000000000000000000000000000000..35f05442cdfe6d445f3a91f0891e27ff7cc4d330 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/common/dist_gemm_helpers.h @@ -0,0 +1,164 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Benchmark helpers for Distributed GEMM + + A delay kernel to gate all GEMMs across devices, controlled by a flag that + the host will set off once it launches DistGEMM across all devices. + + DistGpuTimer extends cutlass's existing cudaEvent-based timer to multiple devices. +*/ + +#pragma once +#include "cutlass/cutlass.h" +#include +#include +#include CUDA_STD_HEADER(atomic) + +#include "cute/layout.hpp" +#include "cute/tensor.hpp" +#include "cutlass/cuda_host_adapter.hpp" + + +namespace cutlass { + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Delay kernel +///////////////////////////////////////////////////////////////////////////////////////////////// + +using AtomicBoolean = cuda::atomic; + +__global__ void delay_kernel(const AtomicBoolean* atomic_flag_ptr) { + while (not atomic_flag_ptr->load()) { + __nanosleep(40); + } +} + + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Distributed GPU Timer +/// Sets up cuda events for multiple processors. 
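A minimal usage sketch of the timer defined below, under stated assumptions: the device count, the stream array (one stream created on each matching device), and the work being enqueued are placeholders, and CUDA_CHECK is the panic macro from the examples' common helper.h.

    // Sketch only: time per-device work on NP devices with DistGpuTimer.
    // Everything except DistGpuTimer / CUDA_CHECK is an illustrative placeholder.
    #include <iostream>

    void time_per_device_work(cudaStream_t streams[/* one per device */]) {
      constexpr int NP = 2;                    // assumes at least 2 visible devices
      cutlass::DistGpuTimer<NP> timer;         // creates start/stop events on each device
      for (int d = 0; d < NP; ++d) {
        CUDA_CHECK(cudaSetDevice(d));
        timer.start(d, streams[d]);            // record the start event on device d's stream
        // ... enqueue the GEMM / kernel under test on streams[d] here ...
        timer.stop(d, streams[d]);             // record the stop event
      }
      for (int d = 0; d < NP; ++d) {
        // elapsed_millis() synchronizes on the stop event before measuring
        std::cout << "device " << d << ": " << timer.elapsed_millis(d) << " ms\n";
      }
    }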
+///////////////////////////////////////////////////////////////////////////////////////////////// +template +struct DistGpuTimer { + int _primary_device; + cudaEvent_t _start[NP]; + cudaEvent_t _stop[NP]; + + /// Constructor + DistGpuTimer() + { + CUDA_CHECK(cudaGetDevice(&_primary_device)); + for (int device = 0; device < NP; ++device) { + CUDA_CHECK(cudaSetDevice(device)); + CUDA_CHECK(cudaEventCreate(&_start[device])); + CUDA_CHECK(cudaEventCreate(&_stop[device])); + } + CUDA_CHECK(cudaSetDevice(_primary_device)); + } + + /// Destructor + ~DistGpuTimer() + { + for (int device = 0; device < NP; ++device) { + CUDA_CHECK(cudaSetDevice(device)); + CUDA_CHECK(cudaEventDestroy(_start[device])); + CUDA_CHECK(cudaEventDestroy(_stop[device])); + } + CUDA_CHECK(cudaSetDevice(_primary_device)); + } + + /// Start the timer for a given stream (defaults to the default stream) + void start(int device, cudaStream_t stream) { + assert(device >= 0 && device < NP); + CUDA_CHECK(cudaEventRecord(_start[device], stream)); + } + + /// Stop the timer + void stop(int device, cudaStream_t stream) { + assert(device >= 0 && device < NP); + CUDA_CHECK(cudaEventRecord(_stop[device], stream)); + } + + /// Return the elapsed time (in milliseconds) + float elapsed_millis(int device) { + assert(device >= 0 && device < NP); + float elapsed = 0.0; + CUDA_CHECK(cudaEventSynchronize(_stop[device])); + CUDA_CHECK(cudaEventElapsedTime(&elapsed, _start[device], _stop[device])); + return elapsed; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Generic device-to-device data movement kernel based for CuTe tensors. +/// +/// NOTE: this kernel assigns one element copy to every thread, and is by no means +/// an efficient way of copying tensors. It should only be used for convenience in +/// reference checks. 
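A similarly hedged sketch of the copy helper declared just below: it reads a device buffer of one element type, converts each element, and writes into a buffer of another type. The pointer names, shape, and stream are illustrative, and both buffers are assumed to be device-resident.

    // Sketch only: convert-and-copy a float device buffer into a half_t device buffer
    // with cutlass::device_copy. Names are placeholders.
    #include "cute/tensor.hpp"
    #include "cutlass/half.h"

    void copy_and_convert(float const* d_src, cutlass::half_t* d_dst,
                          int M, int N, cudaStream_t stream) {
      auto layout = cute::make_layout(cute::make_shape(M, N));    // compact column-major
      auto src = cute::make_tensor(cute::make_gmem_ptr(d_src), layout);
      auto dst = cute::make_tensor(cute::make_gmem_ptr(d_dst), layout);
      // One thread per element; NumericConverter performs the float -> half_t rounding.
      cutlass::device_copy(src, dst, stream);
    }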
+///////////////////////////////////////////////////////////////////////////////////////////////// + +template +void device_copy(TensorSource tensor_source, + TensorDestination tensor_destination, + cudaStream_t stream); + + +template +__global__ void device_copy_kernel(TensorSource const tensor_source, + TensorDestination tensor_destination) { + auto linear_idx = blockIdx.x * blockDim.x + threadIdx.x; + using ElementSrc = typename TensorSource::value_type; + using ElementDst = typename TensorDestination::value_type; + NumericConverter converter; + if (linear_idx < size(tensor_source)) { + tensor_destination(linear_idx) = converter(tensor_source(linear_idx)); + } +} + +template +void device_copy(TensorSource tensor_source, + TensorDestination tensor_destination, + cudaStream_t stream) { + + assert(tensor_source.size() == tensor_destination.size()); + + auto numel = tensor_source.size(); + static constexpr int NumThreads = 128; + auto grid_size = cute::ceil_div(numel, NumThreads); + + dim3 grid(grid_size); + dim3 block(NumThreads); + device_copy_kernel<<>>(tensor_source, tensor_destination); +} + +} //namespace cutlass diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/common/gather_tensor.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/common/gather_tensor.hpp new file mode 100644 index 0000000000000000000000000000000000000000..46fb64009735e4459ac6927e999921474f1ccd2d --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/common/gather_tensor.hpp @@ -0,0 +1,215 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include "cute/layout.hpp" +#include "cute/tensor.hpp" +#include "cute/util/print.hpp" + +namespace example { + +using namespace cute; + +// Empty type used to disable gather/scatter for a GEMM argument +struct NoGather +{ + template + NoGather(Ts...) {}; +}; + +/// Function object that applies an index to its argument +template +struct IndexedGather +{ + CUTE_HOST_DEVICE constexpr + IndexedGather(Index const *indices = {}): indices_(indices) {} + + template + CUTE_HOST_DEVICE constexpr + Index + operator()(I i) const { return indices_[i]; } + + CUTE_HOST_DEVICE friend + void + print(IndexedGather const &s) { + cute::print("Indexed"); + } + + Index const *indices_; +}; + +/// Function object that applies a stride to its argument +/// Example: StridedFunc gathers every other row/column +template +struct StridedGather +{ + CUTE_HOST_DEVICE constexpr + StridedGather(Stride stride = {}): stride_(stride) {} + + template + CUTE_HOST_DEVICE constexpr + auto + operator()(I i) const { return i * stride_; } + + CUTE_HOST_DEVICE friend + void + print(StridedGather const &s) { + cute::print("Strided{"); + print(s.stride_); + cute::print("}"); + } + + Stride stride_; +}; + +/// Custom stride object that applies a function followed by a stride +template +struct CustomStride +{ + CUTE_HOST_DEVICE constexpr + CustomStride(Func const &func, Stride const &stride): func_(func), stride_(stride) {} + + template + CUTE_HOST_DEVICE constexpr friend + auto + operator*(I i, CustomStride const &s) { return s.func_(i) * s.stride_; } + + template + CUTE_HOST_DEVICE constexpr friend + auto + operator*(CustomStride const &s, I i) { return s.func_(i) * s.stride_; } + + CUTE_HOST_DEVICE friend + void + print(CustomStride const & s) { + cute::print("Custom{"); + print(s.func_); + cute::print(","); + print(s.stride_); + cute::print("}"); + } + + template + CUTE_HOST_DEVICE constexpr friend + auto + safe_div(CustomStride const &s, Div const &div) + { + return CustomStride(s.func_, safe_div(s.stride_, div)); + } + + // Circumvent the requirement on make_layout that shape and stride are integral + template + CUTE_HOST_DEVICE constexpr friend + auto + make_layout(Shape const &shape, CustomStride const &stride) + { + return Layout(shape, stride); + } + + Func func_; + Stride stride_; +}; + +template +CUTLASS_HOST_DEVICE +auto +make_custom_stride_layout(Stride const &stride, Func&& func) +{ + // Use a dummy shape and replace the first non-unit stride with a custom gather stride + auto idx = find_if(stride, [](auto x){ return not is_constant<1, decltype(x)>{}; }); + constexpr int I = decltype(idx)::value; + return make_layout(repeat_like(stride, _1{}), + replace(stride, CustomStride{static_cast(func), get(stride)})); +} + +/// Helper function to optionally create a gather tensor +template +CUTLASS_HOST_DEVICE +auto +make_gather_tensor(Iterator iter, Shape const &shape, Stride const &stride, Func &&func) +{ + if constexpr (not cutlass::platform::is_same, NoGather>::value) { + Layout matrix_layout = make_identity_layout(shape); + auto offset = as_arithmetic_tuple(repeat_like(shape, _0{})); + Layout gather_layout = make_custom_stride_layout(stride, static_cast(func)); + return make_tensor(iter, ComposedLayout{gather_layout, offset, matrix_layout}); + } else { + return make_tensor(iter, shape, stride); + } +} + +} // namespace example + +namespace cute +{ + +template +CUTE_HOST_DEVICE constexpr +auto 
+upcast(Shape const& shape, Stride const& stride) +{ + if constexpr (is_tuple::value) { + return transform_layout(shape, stride, [](auto const& s, auto const& d) { return upcast(s,d); }); + } else if constexpr (is_scaled_basis::value) { + if constexpr (Stride::mode() == I) { + return make_layout(ceil_div(shape, Int{}), ceil_div(stride, Int{})); + } else { + return make_layout(shape, stride); + } + } else { + return upcast(shape, stride); + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +upcast(ComposedLayout,Offset,Layout> const& layout) +{ + // Find index of the stride-1 mode - that is the only one that requires updating inner shape and offset + auto idx = find_if(layout.layout_a().stride(), [](auto x){ return is_constant<1, decltype(x)>{}; }); + constexpr int I = decltype(idx)::value; + + // Upcast the outer layout (works as expected) + auto outer = upcast(layout.layout_a()); + + // Upcast the accumulated offset along stride-1 mode + auto offset = as_arithmetic_tuple(replace(layout.offset(), upcast(get(layout.offset())))); + + // Upcast the inner layout's shape along stride-1 mode + auto inner = upcast(layout.layout_b().shape(), layout.layout_b().stride()); + + return composition(outer, offset, inner); +} + +} // namespace example diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/common/helper.h b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/common/helper.h new file mode 100644 index 0000000000000000000000000000000000000000..aec1c7197c6e8140f3bceb00e891ca5a94cc6d92 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/common/helper.h @@ -0,0 +1,108 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include "cuda_runtime.h" +#include + +/** + * Panic wrapper for unwinding CUTLASS errors + */ +#define CUTLASS_CHECK(status) \ + { \ + cutlass::Status error = status; \ + if (error != cutlass::Status::kSuccess) { \ + std::cerr << "Got cutlass error: " << cutlassGetStatusString(error) << " at: " << __LINE__ \ + << std::endl; \ + exit(EXIT_FAILURE); \ + } \ + } + + +/** + * Panic wrapper for unwinding CUDA runtime errors + */ +#define CUDA_CHECK(status) \ + { \ + cudaError_t error = status; \ + if (error != cudaSuccess) { \ + std::cerr << "Got bad cuda status: " << cudaGetErrorString(error) \ + << " at line: " << __LINE__ << std::endl; \ + exit(EXIT_FAILURE); \ + } \ + } + + +/** + * GPU timer for recording the elapsed time across kernel(s) launched in GPU stream + */ +struct GpuTimer +{ + cudaStream_t _stream_id; + cudaEvent_t _start; + cudaEvent_t _stop; + + /// Constructor + GpuTimer() : _stream_id(0) + { + CUDA_CHECK(cudaEventCreate(&_start)); + CUDA_CHECK(cudaEventCreate(&_stop)); + } + + /// Destructor + ~GpuTimer() + { + CUDA_CHECK(cudaEventDestroy(_start)); + CUDA_CHECK(cudaEventDestroy(_stop)); + } + + /// Start the timer for a given stream (defaults to the default stream) + void start(cudaStream_t stream_id = 0) + { + _stream_id = stream_id; + CUDA_CHECK(cudaEventRecord(_start, _stream_id)); + } + + /// Stop the timer + void stop() + { + CUDA_CHECK(cudaEventRecord(_stop, _stream_id)); + } + + /// Return the elapsed time (in milliseconds) + float elapsed_millis() + { + float elapsed = 0.0; + CUDA_CHECK(cudaEventSynchronize(_stop)); + CUDA_CHECK(cudaEventElapsedTime(&elapsed, _start, _stop)); + return elapsed; + } +}; diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/cute/tutorial/blackwell/example_utils.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/cute/tutorial/blackwell/example_utils.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f6332002bb77b30728071de83cae689ce5e68594 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/cute/tutorial/blackwell/example_utils.hpp @@ -0,0 +1,105 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once +#include // CuTe tensor implementation +#include + +template +void +reference_gemm(TensorA const& tensor_A, TensorB const& tensor_B, + TensorC const& tensor_C, TensorD & tensor_D, + Alpha alpha, Beta beta) +{ + using namespace cute; + for (int m = 0; m < size<0>(tensor_D); ++m) { + for (int n = 0; n < size<1>(tensor_D); ++n) { + AccType c = AccType(0.f); + for (int k = 0; k < size<1>(tensor_A); ++k) { + c += tensor_A(m,k) * tensor_B(n,k); + } + tensor_D(m,n) = alpha * c + beta * tensor_C(m,n); + } + } +} + +template +bool +compare_results(TensorA const& tensor_A, TensorB const& tensor_B, + TensorC const& tensor_C, TensorD const& tensor_D, + RefTensorD const& ref_tensor_D, + bool print_diff = false) +{ + using namespace cute; + auto norm_A = matrix_inf_norm(tensor_A); + auto norm_B = matrix_inf_norm(tensor_B); + auto norm_C = matrix_inf_norm(tensor_C); + auto norm_D = matrix_inf_norm(tensor_D); + auto norm_ref_D = matrix_inf_norm(ref_tensor_D); + auto norm_diff = matrix_diff_inf_norm(tensor_D, ref_tensor_D); + + if (print_diff) { + for (int m = 0; m < size<0>(tensor_D); ++m) { + for (int n = 0; n < size<1>(tensor_D); ++n) { + std::cout << m << "," << n << " : " << tensor_D(m,n) << " vs. 
" << ref_tensor_D(m,n) << std::endl; + } + } + } + + std::cout << "norm (A) : " << norm_A.inf_norm << std::endl; + std::cout << "norm (B) : " << norm_B.inf_norm << std::endl; + std::cout << "norm (C) : " << norm_C.inf_norm << std::endl; + std::cout << "norm (D) : " << norm_D.inf_norm << std::endl; + std::cout << "norm (ref_D) : " << norm_ref_D.inf_norm << std::endl; + std::cout << "norm (D-ref_D) : " << norm_diff.inf_norm << std::endl; + + return (!norm_A.found_nan) && (!norm_B.found_nan) && + (!norm_C.found_nan) && (!norm_D.found_nan) && (!norm_ref_D.found_nan) && // There are no NaNs + (norm_A.inf_norm > 0.0) && (norm_B.inf_norm > 0.0) && + (norm_C.inf_norm > 0.0) && (norm_D.inf_norm > 0.0) && (norm_ref_D.inf_norm > 0.0) && // Values in tensors aren't zeros + (norm_diff.inf_norm <= 0.0); // Diff (ref_D-D) == 0 +} + +template +void +initialize_tensor(Tensor& tensor, cute::tuple value_range = {-2, 2}) +{ + using DataType = typename Tensor::element_type; + auto [min, max] = value_range; + for (int i = 0; i < cute::size(tensor); i++) { + tensor(i) = DataType(int((max-min)*(rand() / double(RAND_MAX)) + min)); + } +} diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/ampere/call_bypass_dlpack.py b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/ampere/call_bypass_dlpack.py new file mode 100644 index 0000000000000000000000000000000000000000..79a23079077e8da6971e883c592def544aaa3579 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/ampere/call_bypass_dlpack.py @@ -0,0 +1,166 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +import sys +import os +from typing import Tuple +import torch + +import cutlass +import cutlass.cute as cute +from cutlass.cute.runtime import make_ptr + + +""" +An Example demonstrating how to call off-the-shelf kernel by-passing dlpack protocol + +The example shows how to directly pass pointers from PyTorch tensors to off-the-shelf kernels +written by CuTe DSL with a thin customized wrapper jit function. The jit function will be +compiled with inline without introducing overhead. + +To run this example: + +.. code-block:: bash + + python examples/ampere/call_bypass_dlpack.py + + +It's worth to mention that by-passing dlpack protocol can resolve the issue that dlpack doesn't handle shape-1 +mode correctly. For example, the following code will fail, because dlpack will convert the shape-1 mode +with stride-1 which propagate alignment incorrectly. + +.. code-block:: python + + @cute.kernel + def fails_kernel(gX: cute.Tensor): + bidx, _, _ = cute.arch.block_idx() + mX = gX[None, bidx, None] # We wish to retain alignment + # assert mX.iterator.alignment == 16 + + + @cute.jit + def fails(gX_: cute.Tensor): + gX = gX_ + fails_kernel(gX).launch(grid=(1, 1, 1), block=(128, 1, 1)) + + + gX_torch = torch.rand((128, 1, 128), device="cuda", dtype=torch.bfloat16) + fails(from_dlpack(gX_torch, assumed_align=16)) + +""" + +# Add the current directory to sys.path +sys.path.append(os.path.dirname(os.path.abspath(__file__))) +from tensorop_gemm import TensorOpGemm + + +@cute.jit +def tensor_op_gemm_wrapper( + a_ptr: cute.Pointer, + b_ptr: cute.Pointer, + c_ptr: cute.Pointer, + m: cutlass.Int32, + n: cutlass.Int32, + k: cutlass.Int32, + l: cutlass.Int32, +): + print(f"\n[DSL INFO] Input Parameters:") + print(f"[DSL INFO] mnkl: {(m, n, k, l)}") + + # Assume alignment of shape to call tensorop_gemm example + m = cute.assume(m, divby=8) + n = cute.assume(n, divby=8) + + # Torch is row major + a_layout = cute.make_ordered_layout((m, k, l), order=(0, 1, 2)) + b_layout = cute.make_ordered_layout((n, k, l), order=(0, 1, 2)) + c_layout = cute.make_ordered_layout((m, n, l), order=(1, 0, 2)) + mA = cute.make_tensor(a_ptr, layout=a_layout) + mB = cute.make_tensor(b_ptr, layout=b_layout) + mC = cute.make_tensor(c_ptr, layout=c_layout) + + print(f"[DSL INFO] mA: {mA}") + print(f"[DSL INFO] mB: {mB}") + print(f"[DSL INFO] mC: {mC}") + + tensor_op_gemm = TensorOpGemm( + a_ptr.value_type, c_ptr.value_type, cutlass.Float32, (2, 2, 1) + ) + print(f"\n[DSL INFO] Created TensorOpGemm instance") + print(f"[DSL INFO] Input dtype: {a_ptr.value_type}") + print(f"[DSL INFO] Output dtype: {c_ptr.value_type}") + print(f"[DSL INFO] Accumulation dtype: {cutlass.Float32}") + print(f"[DSL INFO] Atom layout: {(2, 2, 1)}") + + # No need to compile inside jit function + tensor_op_gemm(mA, mB, mC) + print(f"\n[DSL INFO] Executed TensorOpGemm") + + +def run_tensor_op_gemm_wrapper(mnkl: Tuple[int, int, int, int]): + print(f"\nRunning TensorOpGemm test with:") + print(f"Tensor dimensions: {mnkl}") + + # (M,K,L) + a = torch.randn( + mnkl[3], mnkl[2], mnkl[0], dtype=torch.float16, device="cuda" + ).permute(2, 1, 0) + # (N,K,L) + b = torch.randn( + mnkl[3], mnkl[2], mnkl[1], dtype=torch.float16, device="cuda" + ).permute(2, 1, 0) + # (N,M,L) + c = torch.randn( + mnkl[3], mnkl[0], mnkl[1], dtype=torch.float16, device="cuda" + ).permute(1, 2, 0) + + print(f"Input tensor shapes:") + print(f"a: {a.shape}, dtype: {a.dtype}") + print(f"b: {b.shape}, dtype: {b.dtype}") + print(f"c: {c.shape}, dtype: {c.dtype}\n") + + a_ptr = make_ptr( + 
cutlass.Float16, a.data_ptr(), cute.AddressSpace.gmem, assumed_align=32 + ) + b_ptr = make_ptr( + cutlass.Float16, b.data_ptr(), cute.AddressSpace.gmem, assumed_align=32 + ) + c_ptr = make_ptr( + cutlass.Float16, c.data_ptr(), cute.AddressSpace.gmem, assumed_align=32 + ) + tensor_op_gemm_wrapper(a_ptr, b_ptr, c_ptr, *mnkl) + torch.cuda.synchronize() + + ref = torch.einsum("mkl,nkl->mnl", a, b) + torch.testing.assert_close(c, ref, atol=1e-05, rtol=1e-05) + print(f"\n[DSL INFO] Results verified successfully!") + print(f"First few elements of result: \n{c[:3, :3, :3]}") + + +if __name__ == "__main__": + run_tensor_op_gemm_wrapper((512, 256, 128, 16)) diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/ampere/call_from_jit.py b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/ampere/call_from_jit.py new file mode 100644 index 0000000000000000000000000000000000000000..ee71e53fe87931a2ddf0fed944991a3ce313a29e --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/ampere/call_from_jit.py @@ -0,0 +1,259 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +""" +Demonstrating JIT GEMM Implementation with Static Shape Wrapper + +This example illustrates how to invoke a JIT-compiled GEMM implementation through a wrapper function +with static shapes. It showcases the integration between PyTorch and CuTe tensors in a JIT context. + +Key features demonstrated: +1. Seamless conversion between PyTorch and CuTe tensors using the JitArgument protocol +2. Integration of static shape GEMM operations within a JIT-compiled wrapper function + +Core components: +- BufferWithLayout: Handles memory buffer management with configurable stride ordering +- tensor_op_gemm_wrapper: JIT-compiled entry point that orchestrates the GEMM operation + +Usage: + +.. 
code-block:: bash + + python examples/ampere/call_from_jit.py + +Default configuration: +- Batch dimension (L): 16 +- Matrix dimensions: M=512, N=256, K=128 +- Precision: Float16 inputs with Float32 accumulation + +Requirements: +- CUDA-capable GPU +- PyTorch with CUDA support +""" + +import os +import sys +from typing import Type, Tuple + +import torch + +import cutlass +import cutlass.cute as cute +from cutlass.torch import dtype as torch_dtype +from cutlass.cute.runtime import make_ptr + + +# Add the current directory to sys.path +sys.path.append(os.path.dirname(os.path.abspath(__file__))) +from tensorop_gemm import TensorOpGemm + + +class BufferWithLayout: + def __init__(self, ptr: cute.Pointer, stride_order: tuple[int, int, int]): + self.ptr = ptr + + # static properties + self.stride_order = stride_order + + def to_tensor( + self, shape: tuple[int, int, int], *, loc=None, ip=None + ) -> cute.Tensor: + assert len(shape) == len(self.stride_order), ( + f"Shape {shape} and stride_order {self.stride_order} must have the " + "same rank." + ) + layout = cute.make_ordered_layout(shape, self.stride_order) + # permute (l, mn, k) -> (mn, k, l) + res = cute.make_tensor(self.ptr, cute.select(layout, mode=[1, 2, 0])) + return res + + # Implement JitArgument Protocol and DynamicExpression Protocol + + def __c_pointers__(self): + """Get the C pointers for the underlying pointer. + + This method is part of the JitArgument Protocol and returns the C pointers + from the underlying pointer object. + + This is required for user to define a custom data type which can pass to JIT function. + When JIT compiled function is called, JIT executor will call this method to get raw pointers + to underlying data object. + + Following condition must be satisfied: + + len(__c_pointers__()) == len(__get_mlir_types__()) == len(__extract_mlir_values__()) + + :return: The C pointers from the underlying pointer object + :rtype: Any + """ + return self.ptr.__c_pointers__() + + def __get_mlir_types__(self): + """Get the MLIR types for the underlying pointer. + + This method is part of the JitArgument Protocol and returns the MLIR types + used for compiler to generate code. It must match the type of the underlying pointers + returned by __c_pointers__(). + + :return: The MLIR types from the underlying pointer object + :rtype: Any + """ + return self.ptr.__get_mlir_types__() + + def __extract_mlir_values__(self): + """Extract MLIR values from the underlying pointer. + + This method is part of the DynamicExpression Protocol and extracts MLIR values + from the underlying pointer object. + + It is used by compiler to generate function call in MLIR to another JIT function. + It must match the types returned by __get_mlir_types__(). + + :return: The MLIR values extracted from the underlying pointer object + :rtype: Any + """ + return self.ptr.__extract_mlir_values__() + + def __new_from_mlir_values__(self, values): + """Create a new BufferWithLayout instance from MLIR values. + + This method is part of the JitArgument & DynamicExpression Protocol and creates a new + BufferWithLayout instance with pointer initialized from the given MLIR values. + + It is used by compiler to generate function body in MLIR called by JIT function. + It must match the types returned by __c_pointers__() and __get_mlir_types__(). + code generator takes function arguments and reconstructs python object which is legal + inside function body. 
+ + :param values: MLIR values to initialize the underlying pointer + :type values: Any + :return: A new BufferWithLayout instance with pointer initialized from values + :rtype: BufferWithLayout + """ + return BufferWithLayout( + self.ptr.__new_from_mlir_values__(values), self.stride_order + ) + + +@cute.jit +def tensor_op_gemm_wrapper( + buffer_a: BufferWithLayout, + buffer_b: BufferWithLayout, + buffer_c: BufferWithLayout, + mnkl: cutlass.Constexpr[tuple[int, int, int, int]], + acc_dtype: Type[cutlass.Numeric], + atom_layout_mnk: cutlass.Constexpr[tuple[int, int, int]], +): + print(f"\n[DSL INFO] Input Parameters:") + print(f"[DSL INFO] mnkl: {mnkl}") + print(f"[DSL INFO] buffer_a: {buffer_a}") + print(f"[DSL INFO] buffer_b: {buffer_b}") + print(f"[DSL INFO] buffer_c: {buffer_c}") + print(f"[DSL INFO] acc_dtype: {acc_dtype}") + print(f"[DSL INFO] atom_layout_mnk: {atom_layout_mnk}") + + mA = buffer_a.to_tensor(cute.select(mnkl, mode=[3, 0, 2])) + mB = buffer_b.to_tensor(cute.select(mnkl, mode=[3, 1, 2])) + mC = buffer_c.to_tensor(cute.select(mnkl, mode=[3, 0, 1])) + + print(f"\n[DSL INFO] Created Tensors:") + print(f"[DSL INFO] mA = {mA}") + print(f"[DSL INFO] mB = {mB}") + print(f"[DSL INFO] mC = {mC}") + + tensor_op_gemm = TensorOpGemm( + buffer_a.ptr.value_type, + buffer_c.ptr.value_type, + acc_dtype, + atom_layout_mnk, + ) + print(f"\n[DSL INFO] Created TensorOpGemm instance") + print(f"[DSL INFO] Input dtype: {buffer_a.ptr.value_type}") + print(f"[DSL INFO] Output dtype: {buffer_c.ptr.value_type}") + print(f"[DSL INFO] Accumulation dtype: {acc_dtype}") + print(f"[DSL INFO] Atom layout: {atom_layout_mnk}") + + # No need to compile inside jit function + tensor_op_gemm(mA, mB, mC) + print(f"\n[DSL INFO] Executed TensorOpGemm") + + +def run_tensor_op_gemm_wrapper(mnkl: Tuple[int, int, int, int]): + print(f"\nRunning TensorOpGemm test with:") + print(f"Tensor dimensions: {mnkl}") + + ab_dtype = cutlass.Float16 + c_dtype = cutlass.Float16 + + a = torch.randn( + mnkl[3], mnkl[0], mnkl[2], dtype=torch_dtype(ab_dtype), device="cuda" + ) + b = torch.randn( + mnkl[3], mnkl[1], mnkl[2], dtype=torch_dtype(ab_dtype), device="cuda" + ) + c = torch.randn( + mnkl[3], mnkl[0], mnkl[1], dtype=torch_dtype(c_dtype), device="cuda" + ) + + print(f"Input tensor shapes:") + print(f"a: {a.shape}, dtype: {a.dtype}") + print(f"b: {b.shape}, dtype: {b.dtype}") + print(f"c: {c.shape}, dtype: {c.dtype}\n") + + buffer_a = BufferWithLayout( + make_ptr(ab_dtype, a.data_ptr(), cute.AddressSpace.gmem, assumed_align=32), + (2, 1, 0), + ) + buffer_b = BufferWithLayout( + make_ptr(ab_dtype, b.data_ptr(), cute.AddressSpace.gmem, assumed_align=32), + (2, 1, 0), + ) + buffer_c = BufferWithLayout( + make_ptr(c_dtype, c.data_ptr(), cute.AddressSpace.gmem, assumed_align=32), + (2, 1, 0), + ) + + tensor_op_gemm_wrapper( + buffer_a, + buffer_b, + buffer_c, + mnkl, # pass shape as static value + # no stride passing + cutlass.Float32, + (2, 2, 1), + ) + torch.cuda.synchronize() + + ref = torch.einsum("lmk,lnk->lmn", a, b) + torch.testing.assert_close(c, ref, atol=1e-05, rtol=1e-05) + print(f"\n[DSL INFO] Results verified successfully!") + print(f"First few elements of result: \n{c[:3, :3, :3]}") + + +if __name__ == "__main__": + run_tensor_op_gemm_wrapper((512, 256, 128, 16)) diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/ampere/dynamic_smem_size.py b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/ampere/dynamic_smem_size.py new 
file mode 100644 index 0000000000000000000000000000000000000000..dba505487bf53c5a2d63334d340449a0f1927f0f --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/ampere/dynamic_smem_size.py @@ -0,0 +1,114 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import cutlass.cute as cute +import cutlass + +""" +Example of automatic shared memory size computation for configuring kernel launch + +This example demonstrates how to let the DSL automatically set shared memory + size for a kernel launch rather explicitly configuring it at launch time, + provided that developers are using `SmemAllocator` for all allocations. + +Usage: + python dynamic_smem_size.py # Show auto inference +""" + + +@cute.struct +class SharedData: + """A struct to demonstrate shared memory allocation.""" + + values: cute.struct.MemRange[cutlass.Float32, 64] # 256 bytes + counter: cutlass.Int32 # 4 bytes + flag: cutlass.Int8 # 1 byte + + +@cute.kernel +def kernel(): + """ + Example kernel that allocates shared memory. + The total allocation will be automatically calculated when smem=None. + """ + allocator = cutlass.utils.SmemAllocator() + + # Allocate various types of shared memory + shared_data = allocator.allocate(SharedData) + raw_buffer = allocator.allocate(512, byte_alignment=64) + int_array = allocator.allocate_array(element_type=cutlass.Int32, num_elems=128) + tensor_smem = allocator.allocate_tensor( + element_type=cutlass.Float16, + layout=cute.make_layout((32, 16)), + byte_alignment=16, + swizzle=None, + ) + return + + +@cute.kernel +def kernel_no_smem(): + """ + Example kernel that does not allocates shared memory. + The total allocation will be automatically calculated as 0 when smem=None. 
+ """ + tidx, _, _ = cute.arch.block_idx() + if tidx == 0: + cute.printf("Hello world") + return + + +if __name__ == "__main__": + # Initialize CUDA context + cutlass.cuda.initialize_cuda_context() + + print("Launching kernel with auto smem size. (launch config `smem=None`)") + + # Compile the example + @cute.jit + def launch_kernel1(): + k = kernel() + k.launch( + grid=(1, 1, 1), + block=(1, 1, 1), + ) + print(f"Kernel recorded internal smem usage: {k.smem_usage()}") + + @cute.jit + def launch_kernel2(): + k = kernel_no_smem() + k.launch( + grid=(1, 1, 1), + block=(1, 1, 1), + ) + print(f"Kernel recorded internal smem usage: {k.smem_usage()}") + + cute.compile(launch_kernel1) + cute.compile(launch_kernel2) + + print("PASS") diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/ampere/elementwise_add.py b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/ampere/elementwise_add.py new file mode 100644 index 0000000000000000000000000000000000000000..596822ccec9fc5085ea4d6d2d9659992f7c7caab --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/ampere/elementwise_add.py @@ -0,0 +1,403 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +import argparse +import time +from typing import Type + +import cuda.bindings.driver as cuda +import torch + +import cutlass +import cutlass.cute as cute +import cutlass.cute.testing as testing +import cutlass.torch as cutlass_torch +from cutlass.cute.runtime import from_dlpack + +""" +An Elementwise Addition Example using CuTe DSL. + +This example kernel copies data from global memory to register memory (rmem), performs the elementwise +addition operation, and stores the result back to global memory. + +Primary goals of this example are to demonstrate how basic global memory copies can be expressed in +CuTe DSL and illustrate canonical partitioning patterns in CuTe. 
It also implements canonical +predication for tensors whose shape is not multiple of tile size to guard OOB reads. + +Thread-value (or TV) layouts are central to canonical partitioning patterns in CuTe. They provide a +mapping from thread and a thread's value to the set of coordinates within a tile that we have sliced +out from a data tensor. + +The input tensors are row-major layout, that leading dimension is the right most dimension. In order +to efficiently copy data from global memory, we must map threads contiguously on row dimension. + +Thread ID mapping to 2D coordinates with layout `(4,32):(32,1)`: + + +----+----+----+----+-----+----+ + | | 0 | 1 | 2 | ... | 31 | + +----+----+----+----+-----+----+ + | 0 | T0 | T1 | T2 | ... | T31| + +----+----+----+----+-----+----+ + | 1 |T32 |T33 |T34 | ... |T63 | + +----+----+----+----+-----+----+ + | 2 |T64 |T65 |T66 | ... |T95 | + +----+----+----+----+-----+----+ + | 3 |T96 |T97 |T98 | ... |T127| + +----+----+----+----+-----+----+ + +As Ampere GPU supports a maximum of 128bit per load/store instruction and each element is 32bit, we +can load 4 elements per instruction. Having additional contiguous values allows for vectorization +across threads (coalesced accesses) and is required for saturating the memory bandwidth. + +We use `(4,4):(4,1)` as the val layout in this example. Notice that the major mode is the same as +the major mode of the input tensor - without which vectorization would not be possible. + +If you already know the TV layout you want to use for your tiled copy, CuTe DSL provides utility +`cute.make_layout_tv` to build the tiled copy type around it and the atom of your choice. + +.. code-block:: python + + thr_layout = cute.make_layout((4, 32), stride=(32, 1)) + val_layout = cute.make_layout((4, 4), stride=(4, 1)) + tiler_mn, tv_layout = cute.make_layout_tv(thr_layout, val_layout) + + # Tile input tensor to thread blocks: ((TileM,TileN),(RestM,RestN)) + gA = cute.zipped_divide(mA, tiler_mn) + +Then we can build tiled copy for input and output tensors with `cute.make_tiled_copy_tv` utility, which +infers the tiler and tv layout for the tiled copy automatically, where `tiler` is the tile size per thread +block and `tv_layout` is the TV layout which maps thread index and inter-thread index of data array per +thread to logical coordinates of elements in input and output tensors. + +.. code-block:: python + + blkA = gA[((None, None), bidx)] # (TileM,TileN) + + copy_atom_load = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), gA.element_type) + tiled_copy_A = cute.make_tiled_copy_tv(copy_atom_load, thr_layout, val_layout) + + # get slice of tiled_copy_A for current thread + thr_copy_A = tiled_copy_A.get_slice(tidx) + + # partition per thread block tensor as source of tiled copy + thrA = thr_copy_A.partition_S(blkA) + + # allocate fragment for gmem->rmem + frgA = cute.make_fragment_like(thrA) + + # copy data from global memory to register memory + cute.copy(copy_atom_load, thrA, frgA) + + +To run this example: + +.. code-block:: bash + + python examples/ampere/elementwise_add.py --M 3 --N 12 + python examples/ampere/elementwise_add.py --M 1024 --N 512 + python examples/ampere/elementwise_add.py --M 1024 --N 1024 --benchmark --warmup_iterations 2 --iterations 1000 + +To collect performance with NCU profiler: + +.. 
code-block:: bash + + # Don't iterate too many times when profiling with ncu + ncu python examples/ampere/elementwise_add.py --M 2048 --N 2048 --benchmark --iterations 10 --skip_ref_check +""" + + +@cute.kernel +def elementwise_add_kernel( + gA: cute.Tensor, + gB: cute.Tensor, + gC: cute.Tensor, + cC: cute.Tensor, # coordinate tensor + shape: cute.Shape, + thr_layout: cute.Layout, + val_layout: cute.Layout, +): + tidx, _, _ = cute.arch.thread_idx() + bidx, _, _ = cute.arch.block_idx() + + # slice for CTAs + # logical id -> address + blk_coord = ((None, None), bidx) + blkA = gA[blk_coord] # (TileM,TileN) + blkB = gB[blk_coord] # (TileM,TileN) + blkC = gC[blk_coord] # (TileM,TileN) + blkCrd = cC[blk_coord] # (TileM, TileN) + + # Note: these prints only run at compile/jit time + print(f"[DSL INFO] Sliced Tensors per thread block:") + print(f"[DSL INFO] blkA = {blkA.type}") + print(f"[DSL INFO] blkB = {blkB.type}") + print(f"[DSL INFO] blkC = {blkC.type}") + print(f"[DSL INFO] blkCrd = {blkCrd.type}") + + # # declare the atoms which will be used later for memory copy + copy_atom_load = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), gA.element_type) + copy_atom_store = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), gC.element_type) + + tiled_copy_A = cute.make_tiled_copy_tv(copy_atom_load, thr_layout, val_layout) + tiled_copy_B = cute.make_tiled_copy_tv(copy_atom_load, thr_layout, val_layout) + tiled_copy_C = cute.make_tiled_copy_tv(copy_atom_store, thr_layout, val_layout) + + thr_copy_A = tiled_copy_A.get_slice(tidx) + thr_copy_B = tiled_copy_B.get_slice(tidx) + thr_copy_C = tiled_copy_C.get_slice(tidx) + + thrA = thr_copy_A.partition_S(blkA) + thrB = thr_copy_B.partition_S(blkB) + thrC = thr_copy_C.partition_S(blkC) + + # allocate fragments for gmem->rmem + frgA = cute.make_fragment_like(thrA) + frgB = cute.make_fragment_like(thrB) + frgC = cute.make_fragment_like(thrC) + + thrCrd = thr_copy_C.partition_S(blkCrd) + frgPred = cute.make_fragment(thrCrd.shape, cutlass.Boolean) + + print(f"[DSL INFO] Sliced Tensors per thread:") + print(f"[DSL INFO] thrA = {thrA.type}") + print(f"[DSL INFO] thrB = {thrB.type}") + print(f"[DSL INFO] thrC = {thrC.type}") + print(f"[DSL INFO] thrCrd = {thrCrd.type}") + + for i in range(0, cute.size(frgPred), 1): + val = cute.elem_less(thrCrd[i], shape) + frgPred[i] = val + + # Print per thread predicate mask + # if tidx == 0 and bidx == 0: + # cute.printf("block_dim = {}", cute.arch.grid_dim()) + # cute.printf("shape = {}", shape) + # cute.print_tensor(thrA) + # cute.print_tensor(thrB) + # cute.print_tensor(frgPred) + + ########################################################## + # Move data to reg address space + ########################################################## + + cute.copy(copy_atom_load, thrA, frgA, pred=frgPred) + cute.copy(copy_atom_load, thrB, frgB, pred=frgPred) + + # if tidx == 0 and bidx == 0: + # cute.print_tensor(frgA) + # cute.print_tensor(frgB) + + # Load data before use. The compiler will optimize the copy and load + # operations to convert some memory ld/st into register uses. + result = frgA.load() + frgB.load() + + # Save the results back to registers. Here we reuse b's registers. 
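+    # (frgC was allocated above with make_fragment_like(thrC); the store below
+    # writes the computed vector into those per-thread registers before the
+    # predicated copy back to global memory.)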
+ frgC.store(result) + + # Copy the results back to c + cute.copy(copy_atom_store, frgC, thrC, pred=frgPred) + + +@cute.jit +def elementwise_add(mA, mB, mC, copy_bits: cutlass.Constexpr = 128): + dtype = mA.element_type + vector_size = copy_bits // dtype.width + + thr_layout = cute.make_ordered_layout((4, 32), order=(1, 0)) + val_layout = cute.make_ordered_layout((4, vector_size), order=(1, 0)) + tiler_mn, tv_layout = cute.make_layout_tv(thr_layout, val_layout) + + print(f"[DSL INFO] Input Tensors:") + print(f"[DSL INFO] mA = {mA.type}") + print(f"[DSL INFO] mB = {mB.type}") + + print(f"[DSL INFO] Tiling Parameters:") + print(f"[DSL INFO] tiler_mn = {tiler_mn} per thread block") + print(f"[DSL INFO] tv_layout = {tv_layout}") + + gA = cute.zipped_divide(mA, tiler_mn) # ((TileM,TileN),(RestM,RestN)) + gB = cute.zipped_divide(mB, tiler_mn) # ((TileM,TileN),(RestM,RestN)) + gC = cute.zipped_divide(mC, tiler_mn) # ((TileM,TileN),(RestM,RestN)) + print(f"[DSL INFO] Tiled Tensors:") + print(f"[DSL INFO] gA = {gA.type}") + print(f"[DSL INFO] gB = {gB.type}") + print(f"[DSL INFO] gC = {gC.type}") + + idC = cute.make_identity_tensor(mC.shape) + cC = cute.zipped_divide(idC, tiler=tiler_mn) + print(f"[DSL INFO] coord tensor = {cC.type}") + + elementwise_add_kernel(gA, gB, gC, cC, mC.shape, thr_layout, val_layout).launch( + grid=[cute.size(gC, mode=[1]), 1, 1], + block=[cute.size(tv_layout, mode=[0]), 1, 1], + ) + + +def run_elementwise_add( + M, + N, + dtype: Type[cutlass.Numeric], + is_a_dynamic_layout=False, + is_b_dynamic_layout=False, + is_result_dynamic_layout=False, + skip_ref_check=False, + benchmark=True, + warmup_iterations=2, + iterations=200, +): + print(f"\nRunning Elementwise Add test with:") + print(f"Tensor dimensions: [{M}, {N}]") + print(f"Input and Output Data type: {dtype}") + + torch_dtype = cutlass_torch.dtype(dtype) + if dtype.is_integer: + a = torch.randint(0, 10, (M, N), device=torch.device("cuda"), dtype=torch_dtype) + b = torch.randint(0, 10, (M, N), device=torch.device("cuda"), dtype=torch_dtype) + else: + a = torch.randn(M, N, device=torch.device("cuda"), dtype=torch_dtype) + b = torch.randn(M, N, device=torch.device("cuda"), dtype=torch_dtype) + + c = torch.zeros_like(a) + + print(f"Input tensor shapes:") + print(f"a: {a.shape}, dtype: {a.dtype}") + print(f"b: {b.shape}, dtype: {b.dtype}") + print(f"c: {c.shape}, dtype: {c.dtype}\n") + + if not is_a_dynamic_layout: + a_tensor = from_dlpack(a).mark_layout_dynamic() + else: + a_tensor = a + + if not is_b_dynamic_layout: + b_tensor = from_dlpack(b).mark_layout_dynamic() + else: + b_tensor = b + + if not is_result_dynamic_layout: + c_tensor = from_dlpack(c).mark_layout_dynamic() + else: + c_tensor = c + + print("Compiling kernel with cute.compile ...") + start_time = time.time() + compiled_func = cute.compile(elementwise_add, a_tensor, b_tensor, c_tensor) + compilation_time = time.time() - start_time + print(f"Compilation time: {compilation_time:.4f} seconds") + + print("Executing vector add kernel...") + + # Get current CUstream from torch + current_stream = cutlass_torch.current_stream() + + if not skip_ref_check: + compiled_func(a_tensor, b_tensor, c_tensor) + print("Verifying results...") + torch.testing.assert_close(a + b, c) + print("Results verified successfully!") + + if not benchmark: + return + + def generate_tensors(): + if dtype.is_integer: + a = torch.randint( + 0, 10, (M, N), device=torch.device("cuda"), dtype=torch_dtype + ) + b = torch.randint( + 0, 10, (M, N), device=torch.device("cuda"), dtype=torch_dtype + ) 
+ else: + a = torch.randn(M, N, device=torch.device("cuda"), dtype=torch_dtype) + b = torch.randn(M, N, device=torch.device("cuda"), dtype=torch_dtype) + + c = torch.zeros_like(a) + + if not is_a_dynamic_layout: + a_tensor = from_dlpack(a).mark_layout_dynamic() + else: + a_tensor = a + + if not is_b_dynamic_layout: + b_tensor = from_dlpack(b).mark_layout_dynamic() + else: + b_tensor = b + + if not is_result_dynamic_layout: + c_tensor = from_dlpack(c).mark_layout_dynamic() + else: + c_tensor = c + + return testing.JitArguments(a_tensor, b_tensor, c_tensor) + + avg_time_us = testing.benchmark( + compiled_func, + workspace_generator=generate_tensors, + workspace_count=10, + warmup_iterations=warmup_iterations, + iterations=iterations, + ) + + # Print execution results + print(f"Kernel execution time: {avg_time_us / 1e3:.4f} ms") + print( + f"Achieved memory throughput: {(3 * a.numel() * dtype.width // 8) / (avg_time_us / 1e6) / 1e9:.2f} GB/s" + ) + print(f"First few elements of result: \n{c[:3, :3]}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="example of elementwise add to demonstrate the numpy/pytorch as input for kernels" + ) + parser.add_argument("--M", default=1024, type=int) + parser.add_argument("--N", default=1024, type=int) + parser.add_argument("--warmup_iterations", default=2, type=int) + parser.add_argument("--iterations", default=100, type=int) + parser.add_argument("--skip_ref_check", action="store_true") + parser.add_argument("--benchmark", action="store_true") + + args = parser.parse_args() + + if not torch.cuda.is_available(): + raise RuntimeError(f"Ampere GPU is required to run this example!") + + run_elementwise_add( + args.M, + args.N, + dtype=cutlass.Float32, + is_a_dynamic_layout=True, + is_b_dynamic_layout=True, + is_result_dynamic_layout=True, + skip_ref_check=args.skip_ref_check, + benchmark=args.benchmark, + warmup_iterations=args.warmup_iterations, + iterations=args.iterations, + ) + print("\nPASS") diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/ampere/elementwise_apply.py b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/ampere/elementwise_apply.py new file mode 100644 index 0000000000000000000000000000000000000000..43c3b5fc76ce8bca016b43198194ca5f849b4a4a --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/ampere/elementwise_apply.py @@ -0,0 +1,393 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. 
+ +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +import argparse +import operator +import time +from typing import Type, List + +import cuda.bindings.driver as cuda +import torch + +import cutlass +import cutlass.cute as cute +import cutlass.cute.testing as testing +import cutlass.torch as cutlass_torch +from cutlass.cute.runtime import from_dlpack + +""" +An Elementwise Apply Example using CuTe DSL. + +This example kernel demonstrates the meta-programming capability of the CuTe DSL by allowing +customization of elementwise operations through lambda functions. The kernel copies data from +global memory to register memory (rmem), applies a user-defined operation to the elements, +and stores the result back to global memory. + +Primary goals of this example: +1. Demonstrate meta-programming capability by passing lambda functions to customize elementwise operations +2. Show how to apply different operations (add, multiply, etc.) using the same kernel structure +3. Illustrate how to parameterize CUDA kernels with operation types at compile time + +To run this example: + +.. code-block:: bash + + # Run with addition operation + python examples/ampere/elementwise_apply.py --M 1024 --N 512 --op add + + # Run with multiplication operation + python examples/ampere/elementwise_apply.py --M 1024 --N 512 --op mul + + # Run with subtraction operation + python examples/ampere/elementwise_apply.py --M 1024 --N 512 --op sub + + # Benchmark performance + python examples/ampere/elementwise_apply.py --M 2048 --N 2048 --op add --benchmark --warmup_iterations 2 --iterations 10 + +The example demonstrates how to express complex CUDA kernels with customizable operations +while maintaining high performance through efficient memory access patterns. 
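+
+As a quick sketch of the call pattern (the tensors and ``stream`` below are only
+illustrative; see ``elementwise_apply`` and ``run_elementwise_apply_and_verify``
+further down for the full, tested flow):
+
+.. code-block:: python
+
+    import torch
+    import cuda.bindings.driver as cuda
+    from cutlass.cute.runtime import from_dlpack
+
+    x = torch.randn(128, 128, device="cuda")
+    y = torch.randn(128, 128, device="cuda")
+    out = torch.zeros_like(x)
+    stream = cuda.CUstream(torch.cuda.Stream().cuda_stream)
+
+    # `op` is a cutlass.Constexpr, so the lambda is inlined into the generated kernel.
+    elementwise_apply(
+        lambda a, b: a * a + b * b,
+        from_dlpack(x),
+        from_dlpack(y),
+        from_dlpack(out).mark_layout_dynamic(),
+        stream,
+    )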
+""" + + +@cute.kernel +def elementwise_apply_kernel( + op: cutlass.Constexpr, + inputs: List[cute.Tensor], + gC: cute.Tensor, + cC: cute.Tensor, # coordinate tensor + shape: cute.Shape, + tv_layout: cute.Layout, # (tid, vid) -> logic coord +): + tidx, _, _ = cute.arch.thread_idx() + bidx, _, _ = cute.arch.block_idx() + + # slice for CTAs + cta_coord = ((None, None), bidx) + # logical coord -> address + # Leverage the meta-programming capability of the DSL to slice the tensors for each input + # All for loops below on input tensors would be fully unrolled automatically at compile time + ctaInputs = [t[cta_coord] for t in inputs] # (TileM, TileN) + ctaC = gC[cta_coord] # (TileM, TileN) + ctaCrd = cC[cta_coord] # (TileM, TileN) + + print(f"[DSL INFO] Sliced Tensors per thread block:") + for i in cutlass.range_constexpr(len(ctaInputs)): + print(f"[DSL INFO] ctaInputs{i} = {ctaInputs[i].type}") + print(f"[DSL INFO] ctaC = {ctaC.type}") + print(f"[DSL INFO] ctaCrd = {ctaCrd.type}") + + # compose with CTA TV layout + # (tid, vid) -> address + tidfrgInputs = [cute.composition(t, tv_layout) for t in ctaInputs] + tidfrgC = cute.composition(ctaC, tv_layout) + tidfrgCrd = cute.composition(ctaCrd, tv_layout) + # print(f"{tv_layout = }") + # print(f"{tidfrgAB[0] = }") + + thr_coord = (tidx, (None, None)) + + # slice for threads + # vid -> address + thrInputs = [t[thr_coord] for t in tidfrgInputs] # (V) + thrC = tidfrgC[thr_coord] # (V) + thrCrd = tidfrgCrd[thr_coord] + + print(f"[DSL INFO] Sliced Tensors per thread:") + for i in cutlass.range_constexpr(len(thrInputs)): + print(f"[DSL INFO] thrInputs{i} = {thrInputs[i].type}") + print(f"[DSL INFO] thrC = {thrC.type}") + print(f"[DSL INFO] thrCrd = {thrCrd.type}") + + # allocate fragments for gmem->rmem + frgInputs = [cute.make_fragment_like(t, t.element_type) for t in thrInputs] + frgC = cute.make_fragment_like(thrC, gC.element_type) + frgPred = cute.make_fragment(thrCrd.shape, cutlass.Boolean) + + for i in cutlass.range(cute.size(frgPred), unroll=1): + frgPred[i] = cute.elem_less(thrCrd[i], shape) + + # if tidx == 0 and bidx == 0: + # cute.print_tensor(frgPred) + + ########################################################## + # Move data to reg address space + ########################################################## + + # declare the atoms which will be used later for memory copy + # Compile time validation: expect same element type for all input tensors so as to reuse the copy atom for load + assert all(t.element_type == inputs[0].element_type for t in inputs) + + copy_atom_load = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + inputs[0].element_type, + num_bits_per_copy=inputs[0].element_type.width, + ) + copy_atom_store = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + gC.element_type, + num_bits_per_copy=gC.element_type.width, + ) + + for thrInput, frgInput in zip(thrInputs, frgInputs): + cute.copy(copy_atom_load, thrInput, frgInput, pred=frgPred) + + # Load data before use. The compiler will optimize the copy and load + # operations to convert some memory ld/st into register uses. + result = op(*[frgInput.load() for frgInput in frgInputs]) + + # Save the results back to registers. Here we reuse b's registers. 
+ frgC.store(result) + + # Copy the results back to c + cute.copy(copy_atom_store, frgC, thrC, pred=frgPred) + + +@cute.jit +def elementwise_apply( + op: cutlass.Constexpr, + a: cute.Tensor, + b: cute.Tensor, + result: cute.Tensor, + stream: cuda.CUstream, +): + """CUDA kernel applying binary operator on each element of two n-D input tensors in + CuTe Python and store to result tensor. + + :param op: Binary operator or lambda function to apply element-wise + :type op: cutlass.Constexpr + :param a: First input tensor + :type a: cute.Tensor + :param b: Second input tensor + :type b: cute.Tensor + :param result: Output tensor to store the results of op(a, b) + :type result: cute.Tensor + :return: None + :rtype: None + + .. code-block:: python + + # Example 1: Adding two tensors + x = torch.tensor([[1, 2], [3, 4]], dtype=torch.float32, device="cuda") + y = torch.tensor([[5, 6], [7, 8]], dtype=torch.float32, device="cuda") + result = torch.empty_like(x) + elementwise_apply(operator.add, from_dlpack(x), from_dlpack(y), from_dlpack(result)) + # result: + # tensor([[6.0, 8.0], + # [10.0, 12.0]], device='cuda:0') + + # Example 2: Using a lambda function + elementwise_apply(lambda a, b: a * a + b * b, from_dlpack(x), from_dlpack(y), from_dlpack(result)) + # result: + # tensor([[ 2., 8.], + # [ 54., 512.]], device='cuda:0') + """ + + # Baseline: naive TV layout + # * mA layout: (4096, 4096):(4096, 1) + # * TV layout map to (512, 4) tile + # * tidx maps to mode-0 but input layout is contiguous on mode-1, performance will be bad + # tv_layout = cute.make_layout((128, (4, 4)), stride=(4, (512, 1))) + # cta_tiler = (512, 4) + + # Opt-1: better TV layout with better 1D thread layout (SOL with 1D thread layout) + # * mA layout: (4096, 4096):(4096, 1) + # * TV layout map to (4, 512) tile + # * tidx maps to mode-1 which is leading mode of input tensor for coalesced load + # tv_layout = cute.make_layout((128, (4, 4)), stride=(16, (4, 1))) + # cta_tiler = (4, 512) + + # Opt-2: 2D tile but worse + # * mA layout: (4096, 4096):(4096, 1) + # * TV layout map to (128, 16) logical tile + # * V layout is bad as contiguous mode is not on right-most + # * `cute.copy` only supports vectorize when stride-1 of v-layout on right-most ) + # tv_layout = cute.make_layout(((32, 4), (4, 4)), stride=((4, 512), (1, 128))) + # cta_tiler = (128, 16) + + # Opt-3: SOL with 2D thread tile + # * mA layout: (4096, 4096):(4096, 1) + # * TV layout map to (16, 128) logical tile + # * tidx maps to mode-1 and input layout is contiguous on mode-1 for coalesced load-store + thr_layout = cute.make_layout((4, 32), stride=(32, 1)) + val_layout = cute.make_layout((4, 4), stride=(4, 1)) + tiler_mn, tv_layout = cute.make_layout_tv(thr_layout, val_layout) + + print(f"[DSL INFO] Input Tensors:") + print(f"[DSL INFO] a = {a.type}") + print(f"[DSL INFO] b = {b.type}") + print(f"[DSL INFO] result = {result.type}") + + print(f"[DSL INFO] Tiling Parameters:") + print(f"[DSL INFO] tiler_mn = {tiler_mn} per thread block") + print(f"[DSL INFO] tv_layout = {tv_layout}") + + gA = cute.zipped_divide(a, tiler_mn) # ((TileM, TileN), (RestM, RestN)) + gB = cute.zipped_divide(b, tiler_mn) # ((TileM, TileN), (RestM, RestN)) + gC = cute.zipped_divide(result, tiler_mn) # ((TileM, TileN), (RestM, RestN)) + + print(f"[DSL INFO] Tiled Tensors:") + print(f"[DSL INFO] gA = {gA.type}") + print(f"[DSL INFO] gB = {gB.type}") + print(f"[DSL INFO] gC = {gC.type}") + + idC = cute.make_identity_tensor(result.shape) + cC = cute.zipped_divide(idC, tiler=tiler_mn) + print(f"[DSL INFO] 
coord tensor = {cC.type}") + + # Launch the kernel asynchronously + # Async token(s) can also be specified as dependencies + elementwise_apply_kernel( + op, + [gA, gB], # Group input tensors into a list as a single argument + gC, + cC, + result.shape, + tv_layout, + ).launch( + grid=[cute.size(gC, mode=[1]), 1, 1], + block=[cute.size(tv_layout, mode=[0]), 1, 1], + stream=stream, + ) + + +def run_elementwise_apply_and_verify( + op, + M, + N, + dtype: Type[cutlass.Numeric], + skip_ref_check=False, + benchmark=True, + warmup_iterations=2, + iterations=100, +): + if not torch.cuda.is_available(): + raise RuntimeError(f"Ampere GPU is required to run this example!") + + # Create non default CUDA stream from PyTorch + torch_stream = torch.cuda.Stream() + # Get the raw stream pointer as a CUstream + current_stream = cuda.CUstream(torch_stream.cuda_stream) + + print(f"\nRunning Elementwise Apply test with:") + print(f"Tensor dimensions: [{M}, {N}]") + print(f"Input and Output Data type: {dtype}") + print(f"Warmup iterations: {warmup_iterations}") + print(f"Measurement iterations: {iterations}\n") + + torch_dtype = cutlass_torch.dtype(dtype) + + # Allocate tensors with random values. + a = torch.randn(M, N, device=torch.device("cuda"), dtype=torch_dtype) + b = torch.randn(M, N, device=torch.device("cuda"), dtype=torch_dtype) + c = torch.zeros_like(a) + + print(f"Input tensor shapes:") + print(f"a: {a.shape}, dtype: {a.dtype}") + print(f"b: {b.shape}, dtype: {b.dtype}") + print(f"c: {c.shape}, dtype: {c.dtype}\n") + + epsilon = 1.2 + if op in (operator.truediv, operator.floordiv): + b = torch.where(b == 0, torch.tensor(epsilon), b) + + print("Executing elementwise apply kernel...") + + if not skip_ref_check: + elementwise_apply( + op, + from_dlpack(a), + from_dlpack(b), + from_dlpack(c).mark_layout_dynamic(), + current_stream, + ) + print("Verifying results...") + torch.testing.assert_close(op(a, b), c) + print("Results verified successfully!") + + if not benchmark: + return + + compiled_func = cute.compile( + elementwise_apply, + op, + from_dlpack(a), + from_dlpack(b), + from_dlpack(c).mark_layout_dynamic(), + current_stream, + ) + + # When compiled we inlined op in the kernel, so we do not pass it when benchmarking + + avg_time_us = testing.benchmark( + compiled_func, + kernel_arguments=testing.JitArguments( + from_dlpack(a), + from_dlpack(b), + from_dlpack(c).mark_layout_dynamic(), + current_stream, + ), + warmup_iterations=warmup_iterations, + iterations=iterations, + use_cuda_graphs=True, + stream=current_stream, + ) + + avg_time = avg_time_us / 1e3 + + # Print execution results + print(f"Kernel execution time: {avg_time:.4f} ms") + print( + f"Achieved memory throughput: {(3 * a.numel() * dtype.width // 8) / (avg_time / 1000) / 1e9:.2f} GB/s" + ) + print(f"First few elements of result: \n{c[:3, :3]}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="example of elementwise apply to demonstrate building elementwise kernels" + ) + parser.add_argument("--M", default=128, type=int) + parser.add_argument("--N", default=128, type=int) + parser.add_argument("--op", default="add", type=str) + parser.add_argument("--warmup_iterations", default=2, type=int) + parser.add_argument("--iterations", default=100, type=int) + parser.add_argument("--skip_ref_check", action="store_true") + parser.add_argument("--benchmark", action="store_true") + args = parser.parse_args() + run_elementwise_apply_and_verify( + getattr(operator, args.op), + args.M, + args.N, + dtype=cutlass.Float32, 
+ warmup_iterations=args.warmup_iterations, + iterations=args.iterations, + skip_ref_check=args.skip_ref_check, + benchmark=args.benchmark, + ) + print("\nPASS") diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/ampere/flash_attention_v2.py b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/ampere/flash_attention_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..c6f8ff4ccdc2949724133686984334411f8e6482 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/ampere/flash_attention_v2.py @@ -0,0 +1,1334 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import argparse +from types import SimpleNamespace +from typing import Type, Union, Callable + +import torch +import cuda.bindings.driver as cuda +import cutlass.cute.testing as testing +import cutlass +import cutlass.cute as cute +from cutlass.cute.nvgpu import cpasync, warp +import cutlass.torch as cutlass_torch +from cutlass.cute.runtime import from_dlpack +import cutlass.utils as utils + +""" +A flash attention v2 forward pass example for NVIDIA Ampere SM80 architecture using CUTE DSL. + +- Matrix Q is BxSqxNxH, B is batch dimension, Sq is query sequence length, N is number of heads, H is head dimension +- Matrix K is BxSkxNxH, B is batch dimension, Sk is key sequence length, N is number of heads, H is head dimension +- Matrix V is BxSkxNxH, B is batch dimension, Sk is key sequence length, N is number of heads, H is head dimension +- Matrix O is BxSqxNxH, B is batch dimension, Sq is query sequence length, N is number of heads, H is head dimension + +This kernel supports the following features: + - Utilizes CpAsync for efficient memory operations + - Utilizes Ampere's tensor core for matrix multiply-accumulate (MMA) operations + - Utilizes register pipeline to overlap shared memory-to-register transfers with computations. 
+ - Leverages DSL to implement an integrated online softmax fusion pattern. + +This kernel works as follows: +1. Load Q and K matrices from global memory (GMEM) to shared memory (SMEM) using CpAsync operations. +2. Perform matrix multiply-accumulate (MMA) operations using tensor core instructions to compute intermediate result S. +3. Apply padding mask or causal mask to S during initial iterations. +4. Apply online softmax to S and rescale O using results from previous iteration. +5. Load V matrices and perform matrix multiply-accumulate (MMA) operations to compute final result O. +6. Normalize O after all iterations complete and store result back to global memory (GMEM). + +To run this example: + +.. code-block:: bash + + python examples/ampere/flash_attention_v2.py \ + --dtype Float16 --head_dim 128 --m_block_size 128 --n_block_size 128 \ + --num_threads 128 --batch_size 1 --seqlen_q 1280 --seqlen_k 1536 \ + --num_head 16 --softmax_scale 1.0 --is_causal + +The above command configures the model to use float16 for inputs and outputs. The problem dimensions +are set to: batch size of 1, query sequence length of 1280, key sequence length of 1536, head dimension +of 128, and 16 attention heads. The softmax scale is set to 1.0 and causal masking is enabled. The computation +uses tiles of size 128x128 for m and n dimensions, and utilizes 128 parallel threads. + +To collect the performance with NCU profiler: + +.. code-block:: bash + + ncu python examples/ampere/flash_attention_v2.py \ + --dtype Float16 --head_dim 128 --m_block_size 128 --n_block_size 128 \ + --num_threads 128 --batch_size 1 --seqlen_q 1280 --seqlen_k 1536 \ + --num_head 16 --softmax_scale 1.0 --is_causal --skip_ref_check + +There are some constraints for this example: +* Only fp16 and bf16 data types are supported. +* The contiguous dimension of each tensor must be at least 16 bytes aligned. +* The log-sum-exp(for training) is not computed in the kernel. +* The values of `m_block_size`, `n_block_size`, and `head_dim` must be selected to stay within shared memory capacity limits. +* `m_block_size * 2` must be divisible by `num_threads`, otherwise the kernel will not be able to get the correct result. +""" + + +class FlashAttentionForwardAmpere: + def __init__( + self, + head_dim: int, + m_block_size: int = 128, + n_block_size: int = 128, + num_threads: int = 128, + is_causal: bool = False, + ): + """Initializes the configuration for a flash attention v2 kernel. + + All contiguous dimensions must be at least 16 bytes aligned which indicates the head dimension + should be a multiple of 8. + + :param head_dim: head dimension + :type head_dim: int + :param m_block_size: m block size + :type m_block_size: int + :param n_block_size: n block size + :type n_block_size: int + :param num_threads: number of threads + :type num_threads: int + :param is_causal: is causal + """ + self._head_dim = head_dim + self._m_block_size = m_block_size + self._n_block_size = n_block_size + # padding head_dim to a multiple of 32 as k_block_size + self._head_dim_padded = (head_dim + 31) // 32 * 32 + self._num_threads = num_threads + self._is_causal = is_causal + + @staticmethod + def can_implement( + dtype, head_dim, m_block_size, n_block_size, num_threads, is_causal + ) -> bool: + """Check if the kernel can be implemented with the given parameters. 
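+
+        A typical guard before constructing the kernel, mirroring ``run`` at the
+        bottom of this file (the concrete values are only an illustration):
+
+        .. code-block:: python
+
+            if not FlashAttentionForwardAmpere.can_implement(
+                cutlass.Float16, 128, 128, 128, 128, False
+            ):
+                raise TypeError("Unsupported flash attention configuration")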
+ + :param dtype: data type + :type dtype: cutlass.Numeric + :param head_dim: head dimension + :type head_dim: int + :param m_block_size: m block size + :type m_block_size: int + :param n_block_size: n block size + :type n_block_size: int + :param num_threads: number of threads + :type num_threads: int + :param is_causal: is causal + :type is_causal: bool + + :return: True if the kernel can be implemented, False otherwise + :rtype: bool + """ + # Check if data type is fp16 or bf16 + if dtype != cutlass.Float16 and dtype != cutlass.BFloat16: + return False + + # Check if head dimension is a multiple of 8 + if head_dim % 8 != 0: + return False + + # Check if number of threads is a multiple of 32 + if num_threads % 32 != 0: + return False + + # Check if block size setting is out of shared memory capacity + # Shared memory usage: Q tile + (K tile + V tile) where K and V use the same tile size + smem_usage = (m_block_size * head_dim + n_block_size * head_dim * 2) * 2 + smem_capacity = utils.get_smem_capacity_in_bytes("sm_80") + if smem_usage > smem_capacity: + return False + + # Check if twice the block size is divisible by the number of threads + if (m_block_size * 2) % num_threads != 0: + return False + + return True + + @cute.jit + def __call__( + self, + mQ: cute.Tensor, + mK: cute.Tensor, + mV: cute.Tensor, + mO: cute.Tensor, + softmax_scale: cutlass.Float32, + stream: cuda.CUstream, + ): + """Configures and launches the flash attention v2 kernel. + + mQ/mK/mV/mO has same data types(supports fp16 and bf16) and same layout: + (batch_size, seqlen_q, num_head, head_dim):(seqlen_q * num_head * head_dim, num_head * head_dim, head_dim, 1) + + Prepares the shared memory layout, tiled copy atoms, tiled mma and shared memory storage. + Then launches the kernel function with the prepared parameters. 
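+
+        A sketch of a direct invocation (the method is a ``cute.jit`` callable; the
+        tensor and stream names here are placeholders):
+
+        .. code-block:: python
+
+            fa2 = FlashAttentionForwardAmpere(head_dim=128, is_causal=True)
+            fa2(mQ, mK, mV, mO, softmax_scale, stream)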
+ + :param mQ: query tensor + :type mQ: cute.Tensor + :param mK: key tensor + :type mK: cute.Tensor + :param mV: value tensor + :type mV: cute.Tensor + :param mO: output tensor + :type mO: cute.Tensor + :param softmax_scale: softmax scale + :type softmax_scale: cutlass.Float32 + """ + # Get the data type and check if it is fp16 or bf16 + if cutlass.const_expr( + not ( + mQ.element_type == mK.element_type == mV.element_type == mO.element_type + ) + ): + raise TypeError("All tensors must have the same data type") + if cutlass.const_expr( + not ( + mQ.element_type == cutlass.Float16 + or mQ.element_type == cutlass.BFloat16 + ) + ): + raise TypeError("Only Float16 or BFloat16 is supported") + self._dtype: Type[cutlass.Numeric] = mQ.element_type + # /////////////////////////////////////////////////////////////////////////////// + # Shared memory layout: Q/K/V + # /////////////////////////////////////////////////////////////////////////////// + smem_k_block_size = 64 if self._head_dim_padded % 64 == 0 else 32 + swizzle_bits = 3 if smem_k_block_size == 64 else 2 + sQ_layout_atom = cute.make_composed_layout( + cute.make_swizzle(swizzle_bits, 3, 3), + 0, + cute.make_layout((8, smem_k_block_size), stride=(smem_k_block_size, 1)), + ) + sQ_layout = cute.tile_to_shape( + sQ_layout_atom, + (self._m_block_size, self._head_dim_padded), + (0, 1), + ) + + sKV_layout_atom = sQ_layout_atom + sKV_layout = cute.tile_to_shape( + sKV_layout_atom, + (self._n_block_size, self._head_dim_padded), + (0, 1), + ) + + sO_layout = sQ_layout + + @cute.struct + class SharedStorage: + sQ: cute.struct.Align[ + cute.struct.MemRange[self._dtype, cute.cosize(sQ_layout)], 1024 + ] + sK: cute.struct.Align[ + cute.struct.MemRange[self._dtype, cute.cosize(sKV_layout)], 1024 + ] + sV: cute.struct.Align[ + cute.struct.MemRange[self._dtype, cute.cosize(sKV_layout)], 1024 + ] + + # /////////////////////////////////////////////////////////////////////////////// + # GMEM Tiled copy: + # /////////////////////////////////////////////////////////////////////////////// + # Thread layouts for copies + universal_copy_bits = 128 + async_copy_elems = universal_copy_bits // self._dtype.width + # atom_async_copy: async copy atom for QKV load + atom_async_copy = cute.make_copy_atom( + cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL), + self._dtype, + num_bits_per_copy=universal_copy_bits, + ) + # atom_universal_copy: universal copy atom for O store + atom_universal_copy = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + self._dtype, + num_bits_per_copy=universal_copy_bits, + ) + # tQKV_layout: thread layout for QKV load + tQKV_shape_dim_1 = sQ_layout_atom.outer.shape[1] // async_copy_elems + tQKV_layout = cute.make_layout( + (self._num_threads // tQKV_shape_dim_1, tQKV_shape_dim_1), + stride=(tQKV_shape_dim_1, 1), + ) + # tO_layout: thread layout for O store + tO_layout = tQKV_layout + + # Value layouts for copies + vQKV_layout = cute.make_layout((1, async_copy_elems)) + vO_layout = vQKV_layout + + # gmem_tiled_copy_QKV: tiled copy for QKV load + gmem_tiled_copy_QKV = cute.make_tiled_copy_tv( + atom_async_copy, tQKV_layout, vQKV_layout + ) + # gmem_tiled_copy_O: tiled copy for O store + gmem_tiled_copy_O = cute.make_tiled_copy_tv( + atom_universal_copy, tO_layout, vO_layout + ) + + # /////////////////////////////////////////////////////////////////////////////// + # Tiled mma + # /////////////////////////////////////////////////////////////////////////////// + tiled_mma = cute.make_tiled_mma( + warp.MmaF16BF16Op(self._dtype, 
cutlass.Float32, (16, 8, 16)), + (self._num_threads // 32, 1, 1), + permutation_mnk=(self._num_threads // 32 * 16, 16, 16), + ) + + # grid_dim: (m_block, batch_size, num_head) + grid_dim = ( + cute.ceil_div(mQ.shape[1], self._m_block_size), + cute.size(mQ.shape[0]), + cute.size(mQ.shape[2]), + ) + LOG2_E = 1.4426950408889634074 + softmax_scale_log2 = softmax_scale * LOG2_E + self.kernel( + mQ, + mK, + mV, + mO, + softmax_scale_log2, + sQ_layout, + sKV_layout, + sO_layout, + gmem_tiled_copy_QKV, + gmem_tiled_copy_O, + tiled_mma, + SharedStorage, + ).launch( + grid=grid_dim, + block=[self._num_threads, 1, 1], + stream=stream, + ) + + @cute.kernel + def kernel( + self, + mQ: cute.Tensor, + mK: cute.Tensor, + mV: cute.Tensor, + mO: cute.Tensor, + softmax_scale_log2: cutlass.Float32, + sQ_layout: cute.ComposedLayout, + sKV_layout: cute.ComposedLayout, + sO_layout: cute.ComposedLayout, + gmem_tiled_copy_QKV: cute.TiledCopy, + gmem_tiled_copy_O: cute.TiledCopy, + tiled_mma: cute.TiledMma, + SharedStorage: cutlass.Constexpr, + ): + """Kernel function for flash attention v2. + + :param mQ: query tensor + :type mQ: cute.Tensor + :param mK: key tensor + :type mK: cute.Tensor + :param mV: value tensor + :type mV: cute.Tensor + :param mO: output tensor + :type mO: cute.Tensor + :param softmax_scale_log2: softmax scale log2 + :type softmax_scale_log2: cutlass.Float32 + :param sQ_layout: query layout + :type sQ_layout: cute.ComposedLayout + :param sKV_layout: key/value layout + :type sKV_layout: cute.ComposedLayout + :param sO_layout: output layout + :type sO_layout: cute.ComposedLayout + :param gmem_tiled_copy_QKV: tiled copy for QKV load + :type gmem_tiled_copy_QKV: cute.TiledCopy + :param gmem_tiled_copy_O: tiled copy for O store + :type gmem_tiled_copy_O: cute.TiledCopy + :param tiled_mma: tiled mma + :type tiled_mma: cute.TiledMma + :param SharedStorage: shared storage + :type SharedStorage: cutlass.Constexpr + """ + # Thread index, block index + tidx, _, _ = cute.arch.thread_idx() + m_block, batch_size, num_head = cute.arch.block_idx() + + n_block_max = cute.ceil_div(mK.shape[1], self._n_block_size) + if self._is_causal: + n_block_max = min( + cute.ceil_div( + (m_block + 1) * self._m_block_size, + self._n_block_size, + ), + n_block_max, + ) + n_block = n_block_max - 1 + + # /////////////////////////////////////////////////////////////////////////////// + # Get the appropriate tiles for this thread block. 
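+        # gQ is fixed for this m_block, while gK/gV keep the n_block mode free so
+        # the main loop below can step through successive K/V tiles.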
+ # /////////////////////////////////////////////////////////////////////////////// + # (m_block_size, head_dim) + gQ = cute.local_tile( + mQ[batch_size, None, num_head, None], + (self._m_block_size, self._head_dim_padded), + (m_block, 0), + ) + # (n_block_size, head_dim, n_block) + gK = cute.local_tile( + mK[batch_size, None, num_head, None], + (self._n_block_size, self._head_dim_padded), + (None, 0), + ) + # (n_block_size, head_dim, n_block) + gV = cute.local_tile( + mV[batch_size, None, num_head, None], + (self._n_block_size, self._head_dim_padded), + (None, 0), + ) + + # /////////////////////////////////////////////////////////////////////////////// + # Get shared memory buffer + # /////////////////////////////////////////////////////////////////////////////// + smem = cutlass.utils.SmemAllocator() + + storage = smem.allocate(SharedStorage) + sQ = storage.sQ.get_tensor(sQ_layout) + sK = storage.sK.get_tensor(sKV_layout) + sV = storage.sV.get_tensor(sKV_layout) + + # Transpose view of V to tensor with layout (head_dim, n_block_size) for tiled mma + sVt = cute.composition( + sV, + cute.make_layout( + (self._head_dim_padded, self._n_block_size), + stride=(self._n_block_size, 1), + ), + ) + + gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_slice(tidx) + # (CPY_Atom, CPY_M, CPY_K) + tQgQ = gmem_thr_copy_QKV.partition_S(gQ) + tQsQ = gmem_thr_copy_QKV.partition_D(sQ) + # (CPY_Atom, CPY_N, CPY_K, n_block) + tKgK = gmem_thr_copy_QKV.partition_S(gK) + tKsK = gmem_thr_copy_QKV.partition_D(sK) + # (CPY_Atom, CPY_N, CPY_K, n_block) + tVgV = gmem_thr_copy_QKV.partition_S(gV) + tVsV = gmem_thr_copy_QKV.partition_D(sV) + + # /////////////////////////////////////////////////////////////////////////////// + # Tile MMA compute thread partitions and allocate accumulators + # /////////////////////////////////////////////////////////////////////////////// + thr_mma = tiled_mma.get_slice(tidx) + tSrQ = thr_mma.make_fragment_A(thr_mma.partition_A(sQ)) + tSrK = thr_mma.make_fragment_B(thr_mma.partition_B(sK)) + tOrVt = thr_mma.make_fragment_B(thr_mma.partition_B(sVt)) + acc_shape_O = thr_mma.partition_shape_C( + (self._m_block_size, self._head_dim_padded) + ) + acc_O = cute.make_fragment(acc_shape_O, cutlass.Float32) + acc_O.fill(0.0) + + # /////////////////////////////////////////////////////////////////////////////// + # Smem copy atom tiling + # /////////////////////////////////////////////////////////////////////////////// + smem_copy_atom_Q = cute.make_copy_atom( + warp.LdMatrix8x8x16bOp(transpose=False, num_matrices=4), + self._dtype, + ) + smem_copy_atom_K = cute.make_copy_atom( + warp.LdMatrix8x8x16bOp(transpose=False, num_matrices=4), + self._dtype, + ) + smem_copy_atom_V = cute.make_copy_atom( + warp.LdMatrix8x8x16bOp(transpose=True, num_matrices=4), + self._dtype, + ) + smem_tiled_copy_Q = cute.make_tiled_copy_A(smem_copy_atom_Q, tiled_mma) + smem_tiled_copy_K = cute.make_tiled_copy_B(smem_copy_atom_K, tiled_mma) + smem_tiled_copy_V = cute.make_tiled_copy_B(smem_copy_atom_V, tiled_mma) + + smem_thr_copy_Q = smem_tiled_copy_Q.get_slice(tidx) + smem_thr_copy_K = smem_tiled_copy_K.get_slice(tidx) + smem_thr_copy_V = smem_tiled_copy_V.get_slice(tidx) + + tSsQ = smem_thr_copy_Q.partition_S(sQ) + tSrQ_copy_view = smem_thr_copy_Q.retile(tSrQ) + tSsK = smem_thr_copy_K.partition_S(sK) + tSrK_copy_view = smem_thr_copy_K.retile(tSrK) + tOsVt = smem_thr_copy_V.partition_S(sVt) + tOrVt_copy_view = smem_thr_copy_V.retile(tOrVt) + + # /////////////////////////////////////////////////////////////////////////////// + # 
Predicate: Mark indices that need to copy when problem_shape isn't a multiple + # of tile_shape + # /////////////////////////////////////////////////////////////////////////////// + # Construct identity layout for Q and KV + mcQ = cute.make_identity_tensor(mQ.layout.shape) + mcKV = cute.make_identity_tensor(mK.layout.shape) + cQ = cute.local_tile( + mcQ[batch_size, None, num_head, None], + (self._m_block_size, self._head_dim_padded), + (m_block, 0), + ) + cKV = cute.local_tile( + mcKV[batch_size, None, num_head, None], + (self._n_block_size, self._head_dim_padded), + (n_block, 0), + ) + + # Repeat the partitioning with identity layouts + tQcQ = gmem_thr_copy_QKV.partition_S(cQ) + tKVcKV = gmem_thr_copy_QKV.partition_S(cKV) + # Allocate predicate tensors for m and n, here we only allocate the tile of k, and do special process for mn. + # This is to reduce register pressure and gets 2-3% performance gain compared with allocating the whole tile. + tQpQ = cute.make_fragment( + cute.make_layout( + ( + tQsQ.shape[0][1], + cute.size(tQsQ, mode=[1]), + cute.size(tQsQ, mode=[2]), + ), + stride=(cute.size(tQsQ, mode=[2]), 0, 1), + ), + cutlass.Boolean, + ) + tKVpKV = cute.make_fragment( + cute.make_layout( + ( + tKsK.shape[0][1], + cute.size(tKsK, mode=[1]), + cute.size(tKsK, mode=[2]), + ), + stride=(cute.size(tKsK, mode=[2]), 0, 1), + ), + cutlass.Boolean, + ) + # Set predicates for head_dim bounds, seqlen_q/k bounds is processed at the first tile. + for rest_v in cutlass.range_constexpr(tQpQ.shape[0]): + for rest_k in cutlass.range_constexpr(tQpQ.shape[2]): + tQpQ[rest_v, 0, rest_k] = cute.elem_less( + tQcQ[(0, rest_v), 0, rest_k][3], mQ.layout.shape[3] + ) + for rest_v in cutlass.range_constexpr(tKVpKV.shape[0]): + for rest_k in cutlass.range_constexpr(tKVpKV.shape[2]): + tKVpKV[rest_v, 0, rest_k] = cute.elem_less( + tKVcKV[(0, rest_v), 0, rest_k][3], mK.layout.shape[3] + ) + # /////////////////////////////////////////////////////////////////////////////// + # Prefetch Prologue + # /////////////////////////////////////////////////////////////////////////////// + # Start async loads of the last mn-tile, where we take care of the mn residue + for m in cutlass.range_constexpr(cute.size(tQsQ.shape[1])): + if cute.elem_less(tQcQ[0, m, 0][1], mQ.layout.shape[1]): + cute.copy( + gmem_tiled_copy_QKV, + tQgQ[None, m, None], + tQsQ[None, m, None], + pred=tQpQ[None, m, None], + ) + else: + # Clear the smem tiles to account for predicated off loads + tQsQ[None, m, None].fill(0) + for n in cutlass.range_constexpr(cute.size(tKsK.shape[1])): + if cute.elem_less(tKVcKV[0, n, 0][1], mK.layout.shape[1]): + cute.copy( + gmem_tiled_copy_QKV, + tKgK[None, n, None, n_block], + tKsK[None, n, None], + pred=tKVpKV[None, n, None], + ) + else: + # Clear the smem tiles to account for predicated off loads + tKsK[None, n, None].fill(0) + + cute.arch.cp_async_commit_group() + # /////////////////////////////////////////////////////////////////////////////// + # Softmax intermediate result: row_max and row_sum + # /////////////////////////////////////////////////////////////////////////////// + # shape: (atom_v_m * rest_m) + row_max = cute.make_fragment( + (acc_O.shape[0][0] * acc_O.shape[1]), cutlass.Float32 + ) + # shape: (atom_v_m * rest_m) + row_sum = cute.make_fragment( + (acc_O.shape[0][0] * acc_O.shape[1]), cutlass.Float32 + ) + row_max.fill(-cutlass.Float32.inf) + row_sum.fill(0.0) + + # group parameters for compute_one_n_block + basic_params = SimpleNamespace( + m_block=m_block, + n_block=n_block, + mQ=mQ, + mK=mK, + 
batch_size=batch_size, + num_head=num_head, + ) + mma_params = SimpleNamespace( + thr_mma=thr_mma, + tiled_mma=tiled_mma, + tSrQ=tSrQ, + tSrK=tSrK, + tOrVt=tOrVt, + acc_O=acc_O, + ) + gmem_copy_params = SimpleNamespace( + gmem_tiled_copy_QKV=gmem_tiled_copy_QKV, + tKVcKV=tKVcKV, + tKgK=tKgK, + tKsK=tKsK, + tVgV=tVgV, + tVsV=tVsV, + tKVpKV=tKVpKV, + ) + smem_copy_params = SimpleNamespace( + smem_tiled_copy_Q=smem_tiled_copy_Q, + smem_tiled_copy_K=smem_tiled_copy_K, + smem_tiled_copy_V=smem_tiled_copy_V, + tSsQ=tSsQ, + tSrQ_copy_view=tSrQ_copy_view, + tSsK=tSsK, + tSrK_copy_view=tSrK_copy_view, + tOsVt=tOsVt, + tOrVt_copy_view=tOrVt_copy_view, + ) + softmax_params = SimpleNamespace( + row_max=row_max, + row_sum=row_sum, + softmax_scale_log2=softmax_scale_log2, + ) + + # Start processing of the first n-block. + # For performance reason, we separate out two kinds of iterations: + # those that need masking on S, and those that don't. + # We need masking on S for the very last block when K and V has length not multiple of n_block_size. + # We also need masking on S if it's causal, for the last ceil_div(m_block_size, n_block_size) blocks. + # We will have at least 1 "masking" iteration. + mask_steps = 1 + if cutlass.const_expr(self._is_causal): + mask_steps = cute.ceil_div(self._m_block_size, self._n_block_size) + + for n_tile in cutlass.range_constexpr(mask_steps): + n_block = n_block_max - n_tile - 1 + basic_params.n_block = n_block + if cutlass.const_expr(self._is_causal): + if n_block >= 0: + self.compute_one_n_block( + basic_params, + mma_params, + gmem_copy_params, + smem_copy_params, + softmax_params, + is_first_n_block=(n_tile == 0), + in_mask_steps=True, + ) + else: + self.compute_one_n_block( + basic_params, + mma_params, + gmem_copy_params, + smem_copy_params, + softmax_params, + is_first_n_block=True, + in_mask_steps=True, + ) + + # Start async loads of rest k-tiles in reverse order, no k-residue handling needed + for n_tile in range(mask_steps, n_block_max, 1): + n_block = n_block_max - n_tile - 1 + basic_params.n_block = n_block + self.compute_one_n_block( + basic_params, + mma_params, + gmem_copy_params, + smem_copy_params, + softmax_params, + is_first_n_block=False, + in_mask_steps=False, + ) + + # /////////////////////////////////////////////////////////////////////////////// + # Epilogue + # /////////////////////////////////////////////////////////////////////////////// + # normalize acc_O by row_sum and calculate the lse + self.normalize_softmax(acc_O, row_sum) + # store acc_O + rO = cute.make_fragment_like(acc_O, self._dtype) + rO.store(acc_O.load().to(self._dtype)) + # reuse sQ's data iterator + sO = cute.make_tensor(sQ.iterator, sO_layout) + + # smem copy atom for O + smem_copy_atom_O = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), self._dtype + ) + # tiled copy atom for O + smem_tiled_copy_O = cute.make_tiled_copy_C(smem_copy_atom_O, tiled_mma) + smem_thr_copy_O = smem_tiled_copy_O.get_slice(tidx) + taccOrO = smem_thr_copy_O.retile(rO) + taccOsO = smem_thr_copy_O.partition_D(sO) + # copy acc O from rmem to smem with the smem copy atom + cute.copy( + smem_copy_atom_O, + taccOrO, + taccOsO, + ) + gO = cute.local_tile( + mO[batch_size, None, num_head, None], + (self._m_block_size, self._head_dim_padded), + (m_block, 0), + ) + + gmem_thr_copy_O = gmem_tiled_copy_O.get_slice(tidx) + tOsO = gmem_thr_copy_O.partition_S(sO) + tOgO = gmem_thr_copy_O.partition_D(gO) + tOrO = cute.make_fragment_like(tOgO, self._dtype) + # sync before all smem stores are done. 
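+        # (the barrier guarantees every thread's store into sO is visible before
+        # any thread reads sO back for the wider, vectorized gmem store below)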
+ cute.arch.barrier() + # load acc O from smem to rmem for wider vectorization + cute.copy( + gmem_tiled_copy_O, + tOsO, + tOrO, + ) + mcO = cute.make_identity_tensor(mO.layout.shape) + cO = cute.local_tile( + mcO[batch_size, None, num_head, None], + (self._m_block_size, self._head_dim_padded), + (m_block, 0), + ) + tOcO = gmem_thr_copy_O.partition_D(cO) + tOpO = cute.make_fragment( + cute.make_layout( + (tOgO.shape[0][1], tOgO.shape[1], tOgO.shape[2]), + stride=(tOgO.shape[2], 0, 1), + ), + cutlass.Boolean, + ) + for rest_v in cutlass.range_constexpr(tOpO.shape[0]): + for rest_n in cutlass.range_constexpr(cute.size(tOpO.shape[2])): + tOpO[rest_v, 0, rest_n] = cute.elem_less( + tOcO[(0, rest_v), 0, rest_n][3], mO.layout.shape[3] + ) + # copy acc O from rmem to gmem + for rest_m in cutlass.range_constexpr(cute.size(tOpO.shape[1])): + if cute.elem_less(tOcO[0, rest_m, 0][1], mO.layout.shape[1]): + cute.copy( + gmem_tiled_copy_O, + tOrO[None, rest_m, None], + tOgO[None, rest_m, None], + pred=tOpO[None, rest_m, None], + ) + + @cute.jit + def compute_one_n_block( + self, + basic_params: SimpleNamespace, + mma_params: SimpleNamespace, + gmem_copy_params: SimpleNamespace, + smem_copy_params: SimpleNamespace, + softmax_params: SimpleNamespace, + is_first_n_block: cutlass.Constexpr, + in_mask_steps: cutlass.Constexpr, + ): + """Compute one n_block of S/O. + + This function provides different variants for processing the first n block versus subsequent blocks, + as well as variants for handling masked and unmasked steps. + + :param basic_params: basic parameters + :type basic_params: SimpleNamespace + :param mma_params: mma parameters + :type mma_params: SimpleNamespace + :param gmem_copy_params: gmem copy parameters + :type gmem_copy_params: SimpleNamespace + :param smem_copy_params: smem copy parameters + :type smem_copy_params: SimpleNamespace + :param softmax_params: softmax parameters + :type softmax_params: SimpleNamespace + :param is_first_n_block: is first n block + :type is_first_n_block: cutlass.Constexpr + """ + acc_shape_S = mma_params.thr_mma.partition_shape_C( + (self._m_block_size, self._n_block_size) + ) + acc_S = cute.make_fragment(acc_shape_S, cutlass.Float32) + acc_S.fill(0.0) + + # wait for smem tile QK before mma calculation for S + cute.arch.cp_async_wait_group(0) + cute.arch.barrier() + # load smem tile V for O, special process for the first tile to avoid loading nan. + # The `if` here is a constexpr, won't be generated in the IR. 
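+        # On the first processed n_block (the highest index) the K/V tile can be
+        # partial along seqlen_k, so the V copy is predicated row by row; later
+        # n_blocks are always full tiles and take the simpler else branch.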
+ if is_first_n_block: + for n in cutlass.range_constexpr(cute.size(gmem_copy_params.tVsV.shape[1])): + if cute.elem_less( + gmem_copy_params.tKVcKV[0, n, 0][1], + basic_params.mK.layout.shape[1], + ): + cute.copy( + gmem_copy_params.gmem_tiled_copy_QKV, + gmem_copy_params.tVgV[None, n, None, basic_params.n_block], + gmem_copy_params.tVsV[None, n, None], + pred=gmem_copy_params.tKVpKV[None, n, None], + ) + else: + gmem_copy_params.tVsV[None, n, None].fill(0.0) + else: + cute.copy( + gmem_copy_params.gmem_tiled_copy_QKV, + gmem_copy_params.tVgV[None, None, None, basic_params.n_block], + gmem_copy_params.tVsV, + pred=gmem_copy_params.tKVpKV, + ) + + cute.arch.cp_async_commit_group() + # /////////////////////////////////////////////////////////////////////////////// + # S gemm calculation + # /////////////////////////////////////////////////////////////////////////////// + # load first QK k-block from smem to rmem for mma + cute.copy( + smem_copy_params.smem_tiled_copy_Q, + smem_copy_params.tSsQ[None, None, 0], + smem_copy_params.tSrQ_copy_view[None, None, 0], + ) + cute.copy( + smem_copy_params.smem_tiled_copy_K, + smem_copy_params.tSsK[None, None, 0], + smem_copy_params.tSrK_copy_view[None, None, 0], + ) + # mma for S + for k in cutlass.range_constexpr(cute.size(smem_copy_params.tSsQ.shape[2])): + # load next QK k-block from smem to rmem for mma + k_next = (k + 1) % cute.size(smem_copy_params.tSsQ.shape[2]) + cute.copy( + smem_copy_params.smem_tiled_copy_Q, + smem_copy_params.tSsQ[None, None, k_next], + smem_copy_params.tSrQ_copy_view[None, None, k_next], + ) + cute.copy( + smem_copy_params.smem_tiled_copy_K, + smem_copy_params.tSsK[None, None, k_next], + smem_copy_params.tSrK_copy_view[None, None, k_next], + ) + cute.gemm( + mma_params.tiled_mma, + acc_S, + mma_params.tSrQ[None, None, k], + mma_params.tSrK[None, None, k], + acc_S, + ) + + # wait for smem tile V for O + cute.arch.cp_async_wait_group(0) + cute.arch.barrier() + + if basic_params.n_block > 0: + cute.copy( + gmem_copy_params.gmem_tiled_copy_QKV, + gmem_copy_params.tKgK[None, None, None, basic_params.n_block - 1], + gmem_copy_params.tKsK, + pred=gmem_copy_params.tKVpKV, + ) + cute.arch.cp_async_commit_group() + # /////////////////////////////////////////////////////////////////////////////// + # online softmax + # /////////////////////////////////////////////////////////////////////////////// + self.softmax_rescale_O( + basic_params, + mma_params, + softmax_params, + acc_S, + is_first_n_block, + in_mask_steps, + ) + + rP = cute.make_fragment_like(acc_S, self._dtype) + rP.store(acc_S.load().to(self._dtype)) + # /////////////////////////////////////////////////////////////////////////////// + # O gemm calculation + # /////////////////////////////////////////////////////////////////////////////// + # Convert layout of acc_S to gemm O accept layout. 
+ # Due to the mma instruction shape is 16x8x16, we need to convert from (4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) + # (4, MMA_M, MMA_N) -> (4, MMA_M, (2, MMA_N / 2)) + rP_layout_divided = cute.logical_divide(rP.layout, (None, None, 2)) + rP_mma_view = cute.make_layout( + ( + (rP_layout_divided.shape[0], rP_layout_divided.shape[2][0]), + rP_layout_divided.shape[1], + rP_layout_divided.shape[2][1], + ), + stride=( + (rP_layout_divided.stride[0], rP_layout_divided.stride[2][0]), + rP_layout_divided.stride[1], + rP_layout_divided.stride[2][1], + ), + ) + tOrS = cute.make_tensor(rP.iterator, rP_mma_view) + + # load first V k-block from smem to rmem for mma + cute.copy( + smem_copy_params.smem_tiled_copy_V, + smem_copy_params.tOsVt[None, None, 0], + smem_copy_params.tOrVt_copy_view[None, None, 0], + ) + # mma for O + for k in cutlass.range_constexpr(cute.size(tOrS.shape[2])): + # load next V k-block from smem to rmem for mma + k_next = (k + 1) % cute.size(tOrS.shape[2]) + cute.copy( + smem_copy_params.smem_tiled_copy_V, + smem_copy_params.tOsVt[None, None, k_next], + smem_copy_params.tOrVt_copy_view[None, None, k_next], + ) + cute.gemm( + mma_params.tiled_mma, + mma_params.acc_O, + tOrS[None, None, k], + mma_params.tOrVt[None, None, k], + mma_params.acc_O, + ) + + @cute.jit + def softmax_rescale_O( + self, + basic_params: SimpleNamespace, + mma_params: SimpleNamespace, + softmax_params: SimpleNamespace, + acc_S: cute.Tensor, + is_first_n_block: cutlass.Constexpr, + in_mask_steps: cutlass.Constexpr, + ): + """Apply online softmax and rescale acc_O. + + This function provides different variants for processing the first n block versus subsequent blocks, + as well as variants for handling masked and unmasked steps. + + :param basic_params: basic parameters + :type basic_params: SimpleNamespace + :param mma_params: mma parameters + :type mma_params: SimpleNamespace + :param softmax_params: softmax parameters + :type softmax_params: SimpleNamespace + :param acc_S: acc_S tensor + :type acc_S: cute.Tensor + :param is_first_n_block: is first n_block + :type is_first_n_block: cutlass.Constexpr + :param in_mask_steps: in mask steps + :type in_mask_steps: cutlass.Constexpr + """ + # Change acc_S to M,N layout view. + acc_S_mn = self._make_acc_tensor_mn_view(acc_S) + acc_O_mn = self._make_acc_tensor_mn_view(mma_params.acc_O) + row_max_prev = None + # if it is not the first tile, load the row r of previous row_max and compare with row_max_cur_row. + if cutlass.const_expr(not is_first_n_block): + row_max_prev = cute.make_fragment_like( + softmax_params.row_max, cutlass.Float32 + ) + cute.basic_copy(softmax_params.row_max, row_max_prev) + # if it is the first tile, create a mask for residual of S to -inf for softmax. + tScS_mn = None + if cutlass.const_expr(in_mask_steps): + mcS = cute.make_identity_tensor( + ( + basic_params.mQ.shape[0], + basic_params.mQ.shape[1], + basic_params.mQ.shape[2], + basic_params.mK.shape[1], + ) + ) + cS = cute.local_tile( + mcS[basic_params.batch_size, None, basic_params.num_head, None], + (self._m_block_size, self._n_block_size), + (basic_params.m_block, basic_params.n_block), + ) + tScS = mma_params.thr_mma.partition_C(cS) + tScS_mn = self._make_acc_tensor_mn_view(tScS) + + # Each iteration processes one row of acc_S + for r in cutlass.range_constexpr(cute.size(softmax_params.row_max)): + # mask residual of S with -inf + if cutlass.const_expr(in_mask_steps): + if cutlass.const_expr(not self._is_causal): + # traverse column index. 
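+                    # Padding mask: columns at or beyond seqlen_k are set to -inf, so
+                    # they do not raise the row max and contribute exp(-inf) = 0 to the row sum.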
+ for c in cutlass.range_constexpr(cute.size(tScS_mn.shape[1])): + if cute.elem_less( + basic_params.mK.shape[1], tScS_mn[0, c][3] + 1 + ): + acc_S_mn[r, c] = -cutlass.Float32.inf + else: + # get the column index limit based on current row. Only consider the row index, so the column index sets to 0. + col_idx_limit = cutlass.min( + tScS_mn[r, 0][1] + 1, basic_params.mK.shape[1] + ) + # traverse column index. + for c in cutlass.range_constexpr(cute.size(tScS_mn.shape[1])): + # only consider the column index, so the row index sets to 0. + if cute.elem_less(col_idx_limit, tScS_mn[0, c][3] + 1): + acc_S_mn[r, c] = -cutlass.Float32.inf + + # (n_block_size) + acc_S_row = acc_S_mn[r, None].load() + # row_max_cur_row => f32 + row_max_cur_row = acc_S_row.reduce( + cute.ReductionOp.MAX, -cutlass.Float32.inf, 0 + ) + # quad reduction for row_max + row_max_cur_row = self._threadquad_reduce_max(row_max_cur_row) + row_max_prev_row = None + # if it is not the first tile, load the row r of previous row_max and compare with row_max_cur_row. + if cutlass.const_expr(not is_first_n_block): + row_max_prev_row = row_max_prev[r] + row_max_cur_row = cute.arch.fmax(row_max_prev_row, row_max_cur_row) + if cutlass.const_expr(self._is_causal): + row_max_cur_row = ( + 0.0 if row_max_cur_row == -cutlass.Float32.inf else row_max_cur_row + ) + + # compute exp(x - max) using exp2(x * log_2(e) - max * log_2(e)) + acc_S_row_exp = cute.math.exp2( + acc_S_row * softmax_params.softmax_scale_log2 + - row_max_cur_row * softmax_params.softmax_scale_log2, + fastmath=True, + ) + # acc_S_row_sum => f32 + acc_S_row_sum = acc_S_row_exp.reduce( + cute.ReductionOp.ADD, cutlass.Float32.zero, 0 + ) + # if it is not the first tile, load the row r of previous row_max and minus row_max_cur_row to update row_sum. + if cutlass.const_expr(not is_first_n_block): + prev_minus_cur_exp = cute.math.exp2( + row_max_prev_row * softmax_params.softmax_scale_log2 + - row_max_cur_row * softmax_params.softmax_scale_log2, + fastmath=True, + ) + acc_S_row_sum = ( + acc_S_row_sum + softmax_params.row_sum[r] * prev_minus_cur_exp + ) + acc_O_mn[r, None] = acc_O_mn[r, None].load() * prev_minus_cur_exp + # update row_max, row_sum and acc_S + softmax_params.row_max[r] = row_max_cur_row + softmax_params.row_sum[r] = acc_S_row_sum + acc_S_mn[r, None] = acc_S_row_exp + + @cute.jit + def normalize_softmax( + self, + acc_O: cute.Tensor, + row_sum: cute.Tensor, + ): + """Normalize acc_O by row_sum. + + :param acc_O: input tensor + :type acc_O: cute.Tensor + :param row_sum: row_sum tensor + :type row_sum: cute.Tensor + """ + # do quad reduction for row_sum. 
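A rough, standalone NumPy sketch of the online-softmax bookkeeping implemented by softmax_rescale_O above and finished by normalize_softmax below (illustrative only, not part of the CUTLASS example; softmax_scale is taken as 1 and the exp2 trick is replaced by plain exp):

import numpy as np

np.random.seed(0)
m, n_block, num_blocks, d = 4, 8, 3, 5
S = np.random.randn(m, n_block * num_blocks)   # scores for one query-row tile
V = np.random.randn(n_block * num_blocks, d)

row_max = np.full(m, -np.inf)
row_sum = np.zeros(m)
acc_O = np.zeros((m, d))
for b in range(num_blocks):                    # one iteration per KV (n_block) tile
    S_blk = S[:, b * n_block:(b + 1) * n_block]
    V_blk = V[b * n_block:(b + 1) * n_block]
    new_max = np.maximum(row_max, S_blk.max(axis=1))
    scale = np.exp(row_max - new_max)          # rescale previously accumulated stats
    P = np.exp(S_blk - new_max[:, None])       # exp(x - running max)
    row_sum = row_sum * scale + P.sum(axis=1)
    acc_O = acc_O * scale[:, None] + P @ V_blk
    row_max = new_max

out = acc_O / row_sum[:, None]                 # the normalize_softmax step

P_full = np.exp(S - S.max(axis=1, keepdims=True))
ref = (P_full / P_full.sum(axis=1, keepdims=True)) @ V
assert np.allclose(out, ref)

The key point is that whenever the running row maximum grows, both the running row sum and the accumulated output are multiplied by exp(old_max - new_max), so the final division by row_sum yields exactly the full softmax-weighted sum.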
+ acc_O_mn = self._make_acc_tensor_mn_view(acc_O) + for r in cutlass.range_constexpr(cute.size(row_sum)): + row_sum[r] = self._threadquad_reduce_sum(row_sum[r]) + # if row_sum is zero or nan, set acc_O_mn_row to 1.0 + acc_O_mn_row_is_zero_or_nan = row_sum[r] == 0.0 or row_sum[r] != row_sum[r] + + scale = ( + 1.0 if acc_O_mn_row_is_zero_or_nan else cute.arch.rcp_approx(row_sum[r]) + ) + + acc_O_mn[r, None] = acc_O_mn[r, None].load() * scale + + def _make_acc_tensor_mn_view(self, acc: cute.Tensor) -> cute.Tensor: + """make acc tensor as mn layout view + + :param acc: input tensor + :type acc: cute.Tensor + :return: acc tensor mn layout view + :rtype: cute.Tensor + """ + acc_layout_col_major = cute.make_layout(acc.layout.shape) + acc_layout_mn = cute.make_layout( + ( + ( + acc_layout_col_major.shape[0][1], + acc_layout_col_major.shape[1], + ), # MMA_M + ( + acc_layout_col_major.shape[0][0], + acc_layout_col_major.shape[2], + ), # MMA_N + ), + stride=( + ( + acc_layout_col_major.stride[0][1], + acc_layout_col_major.stride[1], + ), # MMA_M + ( + acc_layout_col_major.stride[0][0], + acc_layout_col_major.stride[2], + ), # MMA_N + ), + ) + acc_layout_mn = cute.composition(acc.layout, acc_layout_mn) + return cute.make_tensor(acc.iterator, acc_layout_mn) + + def _threadquad_reduce(self, val: cutlass.Float32, op: Callable) -> cutlass.Float32: + """thread quad reduction + + :param val: register value + :type val: cutlass.Float32 + :param op: binary operator + :type op: Callable + :return: reduced value + :rtype: cutlass.Float32 + """ + val = op( + val, + cute.arch.shuffle_sync_bfly(val, offset=2, mask=-1, mask_and_clamp=31), + ) + val = op( + val, + cute.arch.shuffle_sync_bfly(val, offset=1, mask=-1, mask_and_clamp=31), + ) + return val + + def _threadquad_reduce_max(self, val: cutlass.Float32) -> cutlass.Float32: + """thread quad reduction max + + :param val: register value + :type val: cutlass.Float32 + :return: max value + :rtype: cutlass.Float32 + """ + return self._threadquad_reduce(val, lambda x, y: cute.arch.fmax(x, y)) + + def _threadquad_reduce_sum(self, val: cutlass.Float32) -> cutlass.Float32: + """thread quad reduction sum + + :param val: register value + :type val: cutlass.Float32 + :return: sum value + :rtype: cutlass.Float32 + """ + return self._threadquad_reduce(val, lambda x, y: x + y) + + +def run( + dtype: Type[cutlass.Numeric], + batch_size: int, + seqlen_q: int, + seqlen_k: int, + num_head: int, + head_dim: int, + softmax_scale: float = 1.0, + m_block_size: int = 128, + n_block_size: int = 128, + num_threads: int = 128, + is_causal: bool = False, + warmup_iterations: int = 0, + iterations: int = 1, + skip_ref_check: bool = False, + use_cold_l2: bool = False, + **kwargs, +): + # Skip unsupported testcase + if not FlashAttentionForwardAmpere.can_implement( + dtype, + head_dim, + m_block_size, + n_block_size, + num_threads, + is_causal, + ): + raise TypeError( + f"Unsupported testcase {dtype}, {head_dim}, {m_block_size}, {n_block_size}, {num_threads}, {is_causal}" + ) + + print(f"Running Ampere SM80 FlashAttentionForward test with:") + print(f" dtype: {dtype}") + print(f" batch_size: {batch_size}") + print(f" seqlen_q: {seqlen_q}") + print(f" seqlen_k: {seqlen_k}") + print(f" num_head: {num_head}") + print(f" head_dim: {head_dim}") + print(f" softmax_scale: {softmax_scale}") + print(f" m_block_size: {m_block_size}") + print(f" n_block_size: {n_block_size}") + print(f" num_threads: {num_threads}") + print(f" is_causal: {is_causal}") + print(f" warmup_iterations: {warmup_iterations}") 
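The _threadquad_reduce helpers above rely on two butterfly shuffles (offsets 2 and 1) so that each quad of consecutive lanes ends up holding the reduction over that quad, matching how the m16n8k16 accumulator spreads one row across four threads. A small host-side model of that exchange pattern (illustrative only, assuming one value per lane):

import numpy as np

vals = np.random.randn(32)                  # one value per lane of a warp

def shuffle_bfly(v, offset):
    # Models shuffle_sync_bfly: lane i exchanges its value with lane i XOR offset.
    return v[np.arange(v.size) ^ offset]

out = np.maximum(vals, shuffle_bfly(vals, 2))
out = np.maximum(out, shuffle_bfly(out, 1))

# Every lane in a quad {4k, 4k+1, 4k+2, 4k+3} now holds that quad's maximum.
assert np.allclose(out, vals.reshape(-1, 4).max(axis=1).repeat(4))

Replacing np.maximum with addition gives the row-sum variant used by _threadquad_reduce_sum.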
+ print(f" iterations: {iterations}") + print(f" skip_ref_check: {skip_ref_check}") + print(f" use_cold_l2: {use_cold_l2}") + + # Create tensor Q/K/V/O + def create_tensor( + batch_size: int, + seqlen: int, + num_head: int, + head_dim: int, + dtype: Type[cutlass.Numeric], + ) -> cute.Tensor: + # (batch_size, seqlen, num_head, head_dim) + shape = (batch_size, seqlen, num_head, head_dim) + torch_tensor = ( + torch.empty(*shape, dtype=torch.int32) + .random_(-2, 2) + .to(dtype=cutlass_torch.dtype(dtype)) + .cuda() + ) + # assume input is 16B aligned. + cute_tensor = ( + from_dlpack(torch_tensor, assumed_align=16) + .mark_layout_dynamic(leading_dim=3) + .mark_compact_shape_dynamic( + mode=3, + stride_order=torch_tensor.dim_order(), + divisibility=(128 // dtype.width), + ) + ) + return cute_tensor, torch_tensor + + q, q_torch = create_tensor(batch_size, seqlen_q, num_head, head_dim, dtype) + k, k_torch = create_tensor(batch_size, seqlen_k, num_head, head_dim, dtype) + v, v_torch = create_tensor(batch_size, seqlen_k, num_head, head_dim, dtype) + o, o_torch = create_tensor(batch_size, seqlen_q, num_head, head_dim, dtype) + + fa2_fwd = FlashAttentionForwardAmpere( + head_dim, + m_block_size, + n_block_size, + num_threads, + is_causal, + ) + + # Get current CUDA stream from PyTorch + torch_stream = torch.cuda.current_stream() + # Get the raw stream pointer as a CUstream + current_stream = cuda.CUstream(torch_stream.cuda_stream) + # compile the fa2 forward pass + compiled_fa2_fwd = cute.compile(fa2_fwd, q, k, v, o, softmax_scale, current_stream) + + if not skip_ref_check: + compiled_fa2_fwd(q, k, v, o, softmax_scale, current_stream) + torch.cuda.synchronize() + q_ref = q_torch.permute(0, 2, 1, 3) + k_ref = k_torch.permute(0, 2, 1, 3) + v_ref = v_torch.permute(0, 2, 1, 3) + torch.backends.cuda.enable_flash_sdp(enabled=True) + ref_o = torch.nn.functional.scaled_dot_product_attention( + q_ref, k_ref, v_ref, scale=softmax_scale, is_causal=is_causal + ).permute(0, 2, 1, 3) + torch.testing.assert_close(o_torch.cpu(), ref_o.cpu(), atol=1e-02, rtol=1e-04) + print("Results verified successfully!") + + def generate_tensors(): + q_workspace, _ = create_tensor(batch_size, seqlen_q, num_head, head_dim, dtype) + k_workspace, _ = create_tensor(batch_size, seqlen_k, num_head, head_dim, dtype) + v_workspace, _ = create_tensor(batch_size, seqlen_k, num_head, head_dim, dtype) + o_workspace, _ = create_tensor(batch_size, seqlen_q, num_head, head_dim, dtype) + return testing.JitArguments( + q_workspace, + k_workspace, + v_workspace, + o_workspace, + softmax_scale, + current_stream, + ) + + workspace_count = 1 + if use_cold_l2: + one_workspace_bytes = ( + q_torch.numel() * q_torch.element_size() + + k_torch.numel() * k_torch.element_size() + + v_torch.numel() * v_torch.element_size() + + o_torch.numel() * o_torch.element_size() + ) + workspace_count = testing.get_workspace_count( + one_workspace_bytes, warmup_iterations, iterations + ) + + avg_time_us = testing.benchmark( + compiled_fa2_fwd, + workspace_generator=generate_tensors, + workspace_count=workspace_count, + stream=current_stream, + warmup_iterations=warmup_iterations, + iterations=iterations, + ) + + return avg_time_us # Return execution time in microseconds + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="example of flash attention v2 with CuTe on GPU" + ) + parser.add_argument("--dtype", type=cutlass.dtype, default=cutlass.BFloat16) + parser.add_argument("--batch_size", type=int, default=4) + parser.add_argument("--seqlen_q", 
type=int, default=8192) + parser.add_argument("--seqlen_k", type=int, default=8192) + parser.add_argument("--num_head", type=int, default=16) + parser.add_argument("--head_dim", type=int, default=128) + parser.add_argument("--softmax_scale", type=float, default=0.5) + parser.add_argument("--m_block_size", type=int, default=128) + parser.add_argument("--n_block_size", type=int, default=64) + parser.add_argument("--num_threads", type=int, default=128) + parser.add_argument("--is_causal", action="store_true", help="Enable causal mask") + parser.add_argument("--warmup_iterations", type=int, default=3) + parser.add_argument("--iterations", type=int, default=10) + parser.add_argument( + "--skip_ref_check", action="store_true", help="Skip reference check" + ) + parser.add_argument( + "--use_cold_l2", + action="store_true", + default=False, + help="Use circular buffer tensor sets to ensure L2 cold cache", + ) + + args = parser.parse_args() + run( + args.dtype, + args.batch_size, + args.seqlen_q, + args.seqlen_k, + args.num_head, + args.head_dim, + args.softmax_scale, + args.m_block_size, + args.n_block_size, + args.num_threads, + args.is_causal, + args.warmup_iterations, + args.iterations, + args.skip_ref_check, + args.use_cold_l2, + ) + + print("PASS") diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/ampere/sgemm.py b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/ampere/sgemm.py new file mode 100644 index 0000000000000000000000000000000000000000..e7722a8d62a7529387dfbf738292b545738eb28a --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/ampere/sgemm.py @@ -0,0 +1,866 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
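The SGEMM example that follows pads the leading dimension of its shared-memory tiles by 4 elements when A/B are k-major in order to reduce bank conflicts in the gmem-to-smem copies. A rough model of the bank arithmetic behind that choice (assuming 32 four-byte banks and the (32, 8):(8, 1) copy-thread layout used below; the numbers are illustrative, not a claim about exact cp.async behavior):

import numpy as np

# Offsets written into an fp32 smem tile stored m-major with leading dimension `ld`.
# One warp of the gmem->smem copy handles elements (m, k) with m = t // 8, k = t % 8.
def distinct_banks(ld):
    t = np.arange(32)
    m, k = t // 8, t % 8
    offsets = m + k * ld
    return np.unique(offsets % 32).size     # 32 banks, one 4-byte word wide

print(distinct_banks(128))      # 4  -> the warp piles onto 4 banks (8-way conflicts)
print(distinct_banks(128 + 4))  # 32 -> the padded leading dimension spreads the accesses

With the unpadded leading dimension, every element of a k-column lands in the same bank; adding 4 elements of padding skews consecutive k slices across banks.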
+ +import argparse +import time +from typing import Tuple + +import cuda.bindings.driver as cuda +import torch + +import cutlass +import cutlass.cute as cute +import cutlass.cute.testing as testing +import cutlass.torch as cutlass_torch +import cutlass.utils as utils +from cutlass.cute.runtime import from_dlpack + +""" +A dense FP32 SIMT GEMM (C = A * B) example using CUTE DSL. +- Matrix A is MxK, A can be row-major("K") or column-major("M") +- Matrix B is NxK, B can be row-major("N") or column-major("K") +- Matrix C is MxN, C can be row-major("N") or column-major("M") + +This GEMM kernel supports the following features: + - Utilizes FPU for matrix multiply-accumulate (MMA) operations + - Use multistage pipeline to overlap computation and memory access + * Shared memory pipeline: hides gmem-to-smem latency. + * Register pipeline: overlaps shared memory-to-register transfers with + computations and eliminates false data dependencies for + better parallelism. + - Use vectorized copies + - Add padding to reduce bank conflicts in global -> shared memory copies + - Use predication to avoid unnecessary copies or copies of stale data + +This GEMM works as follows: +1. Load A and B matrices from global memory (GMEM) to shared memory (SMEM) using asynchronous copies. +2. Perform matrix multiply-accumulate (MMA) operations using simple fused multiply-add atomics. +3. Store results from registers (RMEM) to global memory (GMEM). + +To run this example: + +.. code-block:: bash + + python examples/ampere/sgemm.py \ + --mnk 8192,8192,8192 \ + --a_major m --b_major n --c_major n + +To collect performance with NCU profiler: + +.. code-block:: bash + + ncu python examples/ampere/sgemm.py \ + --mnk 8192,8192,8192 \ + --a_major m --b_major n --c_major n \ + --skip_ref_check --iterations 2 + +Constraints: +* Supported input, output, and accumulator data types: fp32 +* Default tile shape is set to be 128x128x8 +* The contiguous dimension of A/B/C tensors must be at least 16 bytes aligned +""" + + +class SGemm: + def __init__( + self, + cta_tiler: Tuple[int, int, int] = (128, 128, 8), + num_stages: int = 3, + num_threads: int = 256, + ): + self._cta_tiler = cta_tiler + self._num_stages = num_stages + self._num_threads = num_threads + assert num_threads > 0, "needs at least one thread" + assert num_threads % 16 == 0, "multiples of 16 required for MMA thread layout" + + self._bM, self._bN, self._bK = self._cta_tiler + assert self._bM % 16 == 0, "multiple of 16 required for tile dimension M" + assert self._bN % 16 == 0, "multiple of 16 required for tile dimension N" + assert self._num_stages >= 3, "num_stages must be greater than or equal to 3" + + @cute.jit + def __call__( + self, + mA: cute.Tensor, + mB: cute.Tensor, + mC: cute.Tensor, + epilogue_op: cutlass.Constexpr = lambda x: x, + stream: cuda.CUstream = cuda.CUstream(cuda.CUstream_flags.CU_STREAM_DEFAULT), + ): + self.a_major_mode = utils.LayoutEnum.from_tensor(mA) + self.b_major_mode = utils.LayoutEnum.from_tensor(mB) + self.c_major_mode = utils.LayoutEnum.from_tensor(mC) + + # /////////////////////////////////////////////////////////////////////////////// + # Create layouts for shared memory for A and B: + # - sA/sB is m/n-major to vectorized copies from shared + # memory to registers. 
This is because the MMA layouts + # for sA/sB are also m/n-major + # - When gA/gB is k-major, pad 4 elements to reduce bank conflicts + # /////////////////////////////////////////////////////////////////////////////// + + padding_a = 4 if self.a_major_mode == utils.LayoutEnum.ROW_MAJOR else 0 + padding_b = 4 if self.b_major_mode == utils.LayoutEnum.ROW_MAJOR else 0 + sA_layout = cute.make_layout( + (self._bM, self._bK, self._num_stages), + stride=(1, (self._bM + padding_a), self._bK * (self._bM + padding_a)), + ) + sB_layout = cute.make_layout( + (self._bN, self._bK, self._num_stages), + stride=(1, (self._bN + padding_b), self._bK * (self._bN + padding_b)), + ) + + # /////////////////////////////////////////////////////////////////////////////// + # Create copy layouts that will be used for asynchronous + # global memory -> shared memory copies: + # - The majorness of tA/tB follows the majorness of gA/gB + # - For k-major, these layouts will copy values one-by-one from + # from global memory, without vectorizing + # - For m/n-major, it will vectorize to a 128bit copy for faster + # data transfer between global and shared memory, as long + # as the alignment of the tensor allows it. Otherwise, it + # defaults to a non-vectorized copy + # /////////////////////////////////////////////////////////////////////////////// + + tA = cute.make_layout( + (self._num_threads // self._bK, self._bK), stride=(self._bK, 1) + ) + tB = cute.make_layout( + (self._num_threads // self._bK, self._bK), stride=(self._bK, 1) + ) + vA = cute.make_layout((1, 1)) + vB = cute.make_layout((1, 1)) + atom_async_copy_A = cute.make_copy_atom( + cute.nvgpu.cpasync.CopyG2SOp(), + mA.element_type, + num_bits_per_copy=mA.element_type.width, + ) + atom_async_copy_B = cute.make_copy_atom( + cute.nvgpu.cpasync.CopyG2SOp(), + mA.element_type, + num_bits_per_copy=mB.element_type.width, + ) + + if cutlass.const_expr(self.a_major_mode == utils.LayoutEnum.COL_MAJOR): + num_vectorized = 4 if (mA.layout.max_alignment % 16 == 0) else 1 + atom_async_copy_A = cute.make_copy_atom( + cute.nvgpu.cpasync.CopyG2SOp(), + mA.element_type, + num_bits_per_copy=mA.element_type.width * num_vectorized, + ) + major_mode_size = self._bM // num_vectorized + tA = cute.make_layout( + (major_mode_size, self._num_threads // major_mode_size), + stride=(1, major_mode_size), + ) + vA = cute.make_layout((num_vectorized, 1)) + + if cutlass.const_expr(self.b_major_mode == utils.LayoutEnum.COL_MAJOR): + num_vectorized = 4 if (mB.layout.max_alignment % 16 == 0) else 1 + atom_async_copy_B = cute.make_copy_atom( + cute.nvgpu.cpasync.CopyG2SOp(), + mA.element_type, + num_bits_per_copy=mB.element_type.width * num_vectorized, + ) + major_mode_size = self._bN // num_vectorized + tB = cute.make_layout( + (major_mode_size, self._num_threads // major_mode_size), + stride=(1, major_mode_size), + ) + vB = cute.make_layout((num_vectorized, 1)) + + tiled_copy_A = cute.make_tiled_copy_tv(atom_async_copy_A, tA, vA) + tiled_copy_B = cute.make_tiled_copy_tv(atom_async_copy_B, tB, vB) + + # /////////////////////////////////////////////////////////////////////////////// + # Create layouts for GEMM: + # We tile an MMA atom across a tensor. `atoms_layout` is the layout + # of atoms in the tiled MMA. (Because we use an `MmaUniversalOp`, + # which has a trivial 1x1x1 MMA trait, `atoms_layout` is also + # simply the thread layout for C.) `permutation_tiler` reorders the + # elements of the tensor that the tiled MMA is applied to. 
+ # Different combinations of `atoms_layout` and `permutation_tiler` + # values can create different MMA thread-value patterns. + # + # Here, the MMA layout is set so that each thread copies four + # consecutive elements from shared memory to registers. + # `permutation_tiler_M/N` maps the elements handled by each thread + # to the permuted element in the tensor. + # For increasing indices in the tensor, the thread ID that reads it is: + # - (without permutation) ==> + # 0 1 2 ... 15 0 1 2 ... 15 0 1 2 ... 15 0 1 2 ... 15 ...... + # - (with permutation) ==> + # 0 0 0 0 1 1 1 1 2 2 2 2 ... 15 15 15 15 0 0 0 0 1 1 1 1 ...... + # /////////////////////////////////////////////////////////////////////////////// + atoms_layout = cute.make_layout( + (self._num_threads // 16, 16, 1), stride=(16, 1, 0) + ) + if cutlass.const_expr(self.c_major_mode == utils.LayoutEnum.COL_MAJOR): + atoms_layout = cute.make_layout( + (16, self._num_threads // 16, 1), stride=(1, 16, 0) + ) + op = cute.nvgpu.MmaUniversalOp(cutlass.Float32) + permutation_tiler_M = cute.make_layout( + (atoms_layout.shape[0], 4), stride=(4, 1) + ) + permutation_tiler_N = cute.make_layout( + (atoms_layout.shape[1], 4), stride=(4, 1) + ) + tiled_mma = cute.make_tiled_mma( + op, + atoms_layout, + permutation_mnk=(permutation_tiler_M, permutation_tiler_N, None), + ) + + # grid_dim: ((m + BLK_M - 1) // BLK_M, (n + BLK_N - 1) // BLK_N, 1) + grid_dim = *cute.ceil_div(mC.shape, (self._bM, self._bN)), 1 + + self.kernel( + mA, + mB, + mC, + sA_layout, + sB_layout, + tiled_copy_A, + tiled_copy_B, + tiled_mma, + epilogue_op, + ).launch( + grid=grid_dim, + block=[cute.size(atoms_layout), 1, 1], + stream=stream, + ) + + @cute.kernel + def kernel( + self, + mA: cute.Tensor, + mB: cute.Tensor, + mC: cute.Tensor, + sA_layout: cute.Layout, + sB_layout: cute.Layout, + tiled_copy_A: cute.TiledCopy, + tiled_copy_B: cute.TiledCopy, + tiled_mma: cute.TiledMma, + epilogue_op: cutlass.Constexpr = lambda x: x, + ): + # Thread and block indices + tidx, tidy, tidz = cute.arch.thread_idx() + bidx, bidy, bidz = cute.arch.block_idx() + tiler_coord = (bidx, bidy, None) + thr_mma = tiled_mma.get_slice(tidx) + + # /////////////////////////////////////////////////////////////////////////////// + # Get the appropriate tiles for this thread block. + # gA: (BLK_M, BLK_K, k), gB: (BLK_N, BLK_K, k), gC: (BLK_M, BLK_N) + # /////////////////////////////////////////////////////////////////////////////// + gA = cute.local_tile( + mA, tiler=self._cta_tiler, coord=tiler_coord, proj=(1, None, 1) + ) + gB = cute.local_tile( + mB, tiler=self._cta_tiler, coord=tiler_coord, proj=(None, 1, 1) + ) + gC = cute.local_tile( + mC, tiler=self._cta_tiler, coord=tiler_coord, proj=(1, 1, None) + ) + + # Move the pointer of gA/gB in the `-k`` direction, making the first + # tile (instead of the last one) irregular in shape when k is irregular. + # We first handle the irregular tile to avoid checking for this + # condition within the mainloop. + residue_k = mA.shape[1] - cutlass.Int32(self._bK) * gA.shape[2] + gA = cute.domain_offset((0, residue_k, 0), gA) + gB = cute.domain_offset((0, residue_k, 0), gB) + + # /////////////////////////////////////////////////////////////////////////////// + # Get the appropriate tiles for this thread. 
+ # sA: (BLK_M, BLK_K, PIPE) , sB: (BLK_N, BLK_K, PIPE) + # tAgA: (CPY, CPY_M, CPY_K, k) , tBgB: (CPY, CPY_N, CPY_K, k) + # tAsA: (CPY, CPY_M, CPY_K, PIPE) , tBsB: (CPY, CPY_N, CPY_K, PIPE) + # /////////////////////////////////////////////////////////////////////////////// + # Create shared memory buffer + smem = cutlass.utils.SmemAllocator() + sA = smem.allocate_tensor(mA.element_type, sA_layout, 16) + sB = smem.allocate_tensor(mB.element_type, sB_layout, 16) + thr_copy_A = tiled_copy_A.get_slice(tidx) + thr_copy_B = tiled_copy_B.get_slice(tidx) + tAgA = thr_copy_A.partition_S(gA) + tAsA = thr_copy_A.partition_D(sA) + tBgB = thr_copy_B.partition_S(gB) + tBsB = thr_copy_B.partition_D(sB) + + # /////////////////////////////////////////////////////////////////////////////// + # Predicate: Mark indices that need to copy when the problem shape + # isn't a multiple of the tile shape. If tApA/B[i] is 0, then do not + # do the copy atom associated with index i. + # cA: (BLK_M, BLK_K) => (blk_m, blk_k) + # cB: (BLK_N, BLK_K) => (blk_n, blk_k) + # tAcA: (CPY, CPY_M, CPY_K) => (blk_m, blk_k) + # tBcB: (CPY, CPY_N, CPY_K) => (blk_n, blk_k) + # tApA: (rest_v, CPY_M, CPY_K), stride=(..., ..., 0) + # tBpB: (rest_v, CPY_N, CPY_K), stride=(..., ..., 0) + # CPY = (atom_v, rest_v) + # /////////////////////////////////////////////////////////////////////////////// + # Construct identity layout for sA and sB, used for predication + mcA = cute.make_identity_tensor(mA.shape) + mcB = cute.make_identity_tensor(mB.shape) + cA = cute.local_tile( + mcA, tiler=self._cta_tiler, coord=tiler_coord, proj=(1, None, 1) + ) + cB = cute.local_tile( + mcB, tiler=self._cta_tiler, coord=tiler_coord, proj=(None, 1, 1) + ) + cA = cute.domain_offset((0, residue_k, 0), cA) + cB = cute.domain_offset((0, residue_k, 0), cB) + # Repeat the partitioning with identity layouts + tAcA = thr_copy_A.partition_S(cA) + tBcB = thr_copy_B.partition_S(cB) + # Allocate predicate tensors for m and n + tApA = cute.make_fragment( + cute.make_layout( + ( + tAsA.shape[0][1], + cute.size(tAsA, mode=[1]), + cute.size(tAsA, mode=[2]), + ), + stride=(cute.size(tAsA, mode=[1]), 1, 0), + ), + cutlass.Boolean, + ) + tBpB = cute.make_fragment( + cute.make_layout( + ( + tBsB.shape[0][1], + cute.size(tBsB, mode=[1]), + cute.size(tBsB, mode=[2]), + ), + stride=(cute.size(tBsB, mode=[1]), 1, 0), + ), + cutlass.Boolean, + ) + # Allocate predicate tensors for m, n and k for residue k-tile + tApA_residue_k = cute.make_fragment( + cute.make_layout( + ( + tAsA.shape[0][1], + cute.size(tAsA, mode=[1]), + cute.size(tAsA, mode=[2]), + ), + stride=( + cute.size(tAsA, mode=[1]) * cute.size(tAsA, mode=[2]), + cute.size(tAsA, mode=[2]), + 1, + ), + ), + cutlass.Boolean, + ) + tBpB_residue_k = cute.make_fragment( + cute.make_layout( + ( + tBsB.shape[0][1], + cute.size(tBsB, mode=[1]), + cute.size(tBsB, mode=[2]), + ), + stride=( + cute.size(tBsB, mode=[1]) * cute.size(tBsB, mode=[2]), + cute.size(tBsB, mode=[2]), + 1, + ), + ), + cutlass.Boolean, + ) + # Set predicates for m/n bounds for mainloop + for rest_v in range(tApA.shape[0]): + for m in range(tApA.shape[1]): + tApA[rest_v, m, 0] = cute.elem_less( + tAcA[(0, rest_v), m, 0, 0][0], mA.shape[0] + ) + for rest_v in range(tBpB.shape[0]): + for n in range(tBpB.shape[1]): + tBpB[rest_v, n, 0] = cute.elem_less( + tBcB[(0, rest_v), n, 0, 0][0], mB.shape[0] + ) + + # Set predicates for m/n/k bounds for residue k tile + for rest_v in range(tApA_residue_k.shape[0]): + for m in range(tApA_residue_k.shape[1]): + for k in 
range(tApA_residue_k.shape[2]): + coord_A = tAcA[(0, rest_v), m, k, 0] + tApA_residue_k[rest_v, m, k] = cute.elem_less( + (coord_A[0], cutlass.Int32(-1)), (mA.shape[0], coord_A[1]) + ) + for rest_v in range(tBpB_residue_k.shape[0]): + for n in range(tBpB_residue_k.shape[1]): + for k in range(tBpB_residue_k.shape[2]): + coord_B = tBcB[(0, rest_v), n, k, 0] + tBpB_residue_k[rest_v, n, k] = cute.elem_less( + (coord_B[0], cutlass.Int32(-1)), (mB.shape[0], coord_B[1]) + ) + + # /////////////////////////////////////////////////////////////////////////////// + # Prefetch Prologue + # /////////////////////////////////////////////////////////////////////////////// + # Start async loads for 0th k-tile, where we take care of the k-residue + k_pipe_max = cute.size(tAsA, mode=[3]) + k_tile_count = cute.size(tAgA, mode=[3]) + gmem_pipe_read = cutlass.Int32(0) + cute.copy( + tiled_copy_A, + tAgA[None, None, None, gmem_pipe_read], + tAsA[None, None, None, 0], + pred=tApA_residue_k, + ) + cute.copy( + tiled_copy_B, + tBgB[None, None, None, gmem_pipe_read], + tBsB[None, None, None, 0], + pred=tBpB_residue_k, + ) + cute.arch.cp_async_commit_group() + gmem_pipe_read = ( + gmem_pipe_read + 1 + if gmem_pipe_read + 1 < k_tile_count + else cutlass.Int32(0) + ) + # Start async loads for 1st k-tile onwards, no k-residue handling needed + for k_tile in range(1, k_pipe_max - 1): + if k_tile < k_tile_count: + cute.copy( + tiled_copy_A, + tAgA[None, None, None, gmem_pipe_read], + tAsA[None, None, None, k_tile], + pred=tApA, + ) + cute.copy( + tiled_copy_B, + tBgB[None, None, None, gmem_pipe_read], + tBsB[None, None, None, k_tile], + pred=tBpB, + ) + + gmem_pipe_read = ( + gmem_pipe_read + 1 + if gmem_pipe_read + 1 < k_tile_count + else cutlass.Int32(0) + ) + cute.arch.cp_async_commit_group() + + # all tiles have been copied from global memory, so clear the + # predicate tensor + if k_tile_count < k_pipe_max: + for rest_v in range(tApA.shape[0]): + for m in range(tApA.shape[1]): + tApA[rest_v, m, 0] = cutlass.Boolean(0) + for rest_v in range(tBpB.shape[0]): + for n in range(tBpB.shape[1]): + tBpB[rest_v, n, 0] = cutlass.Boolean(0) + + # /////////////////////////////////////////////////////////////////////////////// + # Define A/B partitioning and C accumulators. 
+ # /////////////////////////////////////////////////////////////////////////////// + tCsA = thr_mma.partition_A(sA) + tCsB = thr_mma.partition_B(sB) + tCgC = thr_mma.partition_C(gC) + tCrA = tiled_mma.make_fragment_A(tCsA[None, None, None, 0]) + tCrB = tiled_mma.make_fragment_B(tCsB[None, None, None, 0]) + tCrC = tiled_mma.make_fragment_C(tCgC) + # Clear the accumulator + tCrC.fill(0.0) + + # Current pipe index in smem to read from / write to + smem_pipe_read = cutlass.Int32(0) + smem_pipe_write = cutlass.Int32(k_pipe_max - 1) + + tCsA_p = tCsA[None, None, None, smem_pipe_read] + tCsB_p = tCsB[None, None, None, smem_pipe_read] + + # /////////////////////////////////////////////////////////////////////////////// + # PREFETCH register pipeline + # /////////////////////////////////////////////////////////////////////////////// + k_block_max = cute.size(tCrA, mode=[2]) + + if k_block_max > 1: + # Wait until our first prefetched tile is loaded in + cute.arch.cp_async_wait_group(k_pipe_max - 2) + cute.arch.barrier() + # Prefetch the first rmem from the first k-tile + cute.autovec_copy(tCsA_p[None, None, 0], tCrA[None, None, 0]) + cute.autovec_copy(tCsB_p[None, None, 0], tCrB[None, None, 0]) + + # /////////////////////////////////////////////////////////////////////////////// + # Mainloop + # 1. Shared memory pipeline (gmem -> smem): + # The default smem pipeline depth is 3, meaning that for shared + # memory buffers, we allocate three times the size described by the + # CTA tiler. We prefetch 2 of these buffers before entering the main + # loop. Considering only the transfer from global memory to shared + # memory, the general structure of the mainloop is: + # (1) copy k-tile from gmem to smem; + # (2) perform gemm computation on k-tile; + # (3) wait for the next copy to finish. + # The `cute.arch.cp_async_wait_group(num_smem_stages - 2)` command + # waits for the number of unfinished 'copy' to be <= 1. The advantage + # of this approach is that it allows for simultaneous production + # (i.e., step (1)) and consumption (i.e., step (2)) of smem. + # A common misconception is to prefetch N buffers and rewrite + # the pipeline logic to wait on N-1 pending copies. The disadvantage + # of this approach is that it requires fully consuming a buffer in + # order to open an empty buffer for the next copy. + # 2. Register pipeline (smem -> register): + # Similarly, the register pipeline produces i+1, consumes i, and + # produces i+2... Notably, i and i+1 do not use the same register, + # eliminating dependencies on the same register for better parallelism. + # 3. Combining the smem and register pipelines results in the mainloop. + # /////////////////////////////////////////////////////////////////////////////// + + for _ in range(k_tile_count): + for k_block in range(k_block_max, unroll_full=True): + if k_block == k_block_max - 1: + tCsA_p = tCsA[None, None, None, smem_pipe_read] + tCsB_p = tCsB[None, None, None, smem_pipe_read] + cute.arch.cp_async_wait_group(k_pipe_max - 2) + cute.arch.barrier() + + # Load A, B from shared memory to registers for k_block + 1 + k_block_next = (k_block + 1) % k_block_max # static + cute.autovec_copy( + tCsA_p[None, None, k_block_next], + tCrA[None, None, k_block_next], + ) + cute.autovec_copy( + tCsB_p[None, None, k_block_next], + tCrB[None, None, k_block_next], + ) + + # Fetch next A: To better interleave global memory access and + # compute instructions, we intentionally use the sequence: + # copy A, perform GEMM, then copy B. 
+ if k_block == 0: + cute.copy( + tiled_copy_A, + tAgA[None, None, None, gmem_pipe_read], + tAsA[None, None, None, smem_pipe_write], + # Use predicates because the m-mode may be irregular + pred=tApA, + ) + + # Thread-level register gemm for k_block + cute.gemm( + tiled_mma, + tCrC, + tCrA[None, None, k_block], + tCrB[None, None, k_block], + tCrC, + ) + + # Fetch next B and update smem pipeline read/write + if k_block == 0: + cute.copy( + tiled_copy_B, + tBgB[None, None, None, gmem_pipe_read], + tBsB[None, None, None, smem_pipe_write], + # Use predicates because the n-mode may be irregular + pred=tBpB, + ) + cute.arch.cp_async_commit_group() + smem_pipe_write = smem_pipe_read + smem_pipe_read = smem_pipe_read + 1 + if smem_pipe_read == k_pipe_max: + smem_pipe_read = cutlass.Int32(0) + # After copying all tiles, we avoid clearing the predicate + # tensor in the `mainloop` to prevent increasing its + # instruction count. Instead, we continue copying the + # first tile, though it won't be used. The 0-th tile is not + # copied due to its irregular shape, which could lead to + # illegal memory accesses. + gmem_pipe_read = ( + gmem_pipe_read + 1 + if gmem_pipe_read + 1 < k_tile_count + else cutlass.Int32(1) + ) + + # /////////////////////////////////////////////////////////////////////////////// + # Epilogue + # Applies the epilogue operation to the accumulated results and copies + # them without vectorization. + # /////////////////////////////////////////////////////////////////////////////// + cute.arch.cp_async_wait_group(0) + cute.arch.barrier() + tCrC.store(epilogue_op(tCrC.load())) + + # predicate + cC = cute.make_identity_tensor(gC.shape) + tCpC = thr_mma.partition_C(cC) + predC = cute.make_fragment(tCrC.layout, cutlass.Boolean) + residue_m = mC.shape[0] - cutlass.Int32(self._bM) * bidx + residue_n = mC.shape[1] - cutlass.Int32(self._bN) * bidy + for i in range(cute.size(tCrC.shape)): + predC[i] = cute.elem_less(tCpC[i], (residue_m, residue_n)) + numIterM = cute.size(tCrC, mode=[1]) + numIterN = cute.size(tCrC, mode=[2]) + atom = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), mC.element_type) + cute.copy(atom, tCrC, tCgC, pred=predC) + return + + +def run( + mnk: Tuple[int, int, int], + a_major: str, + b_major: str, + c_major: str, + static_shape: bool = False, + warmup_iterations: int = 2, + iterations: int = 100, + skip_ref_check: bool = False, + use_cold_l2: bool = False, + **kwargs, +): + """Execute SIMT GEMM operation and benchmark performance. 
+ + :param mnk: GEMM problem size (M, N, K, L) + :type mnk: Tuple[int, int, int, int] + :param a_major: Memory layout of tensor A + :type a_major: str + :param b_major: Memory layout of tensor B + :type b_major: str + :param c_major: Memory layout of tensor C + :type c_major: str + :param static_shape: Whether to use static shape optimization, defaults to False + :type static_shape: bool, optional + :param warmup_iterations: Number of warmup iterations before benchmarking, defaults to 2 + :type warmup_iterations: int, optional + :param iterations: Number of benchmark iterations to run, defaults to 100 + :type iterations: int, optional + :param skip_ref_check: Skip validation against reference implementation, defaults to False + :type skip_ref_check: bool, optional + :param use_cold_l2: Whether to use circular buffer strategy to ensure cold L2 cache, defaults to False + :type use_cold_l2: bool, optional + :return: Execution time of the GEMM kernel in microseconds + :rtype: float + """ + print(f"Running Ampere SIMT GEMM example:") + print(f"mnk: {mnk}") + print(f"A major: {a_major}, B major: {b_major}, C major: {c_major}") + print(f"Static shape: {static_shape}") + print(f"Warmup iterations: {warmup_iterations}") + print(f"Iterations: {iterations}") + print(f"Skip reference checking: {skip_ref_check}") + print(f"Use cold L2: {use_cold_l2}") + M, N, K = mnk + + # Create and permute tensor A/B/C + def create_and_permute_tensor(mode0, mode1, is_mode0_major, dtype): + # is_mode0_major: (mode1, mode0) -> (mode0, mode1) + # else: (mode0, mode1) -> (mode0, mode1) + shape = (mode1, mode0) if is_mode0_major else (mode0, mode1) + permute_order = (1, 0) if is_mode0_major else (0, 1) + + return ( + torch.empty(*shape, dtype=torch.int32) + .random_(-5, 5) + .to(dtype=dtype) + .permute(permute_order) + .cuda() + ) + + a = create_and_permute_tensor(M, K, a_major == "m", torch.float32) + b = create_and_permute_tensor(N, K, b_major == "n", torch.float32) + c = create_and_permute_tensor(M, N, c_major == "m", torch.float32) + + divisibility_a = a.shape[1] if a_major == "k" else a.shape[0] + divisibility_b = b.shape[1] if b_major == "k" else b.shape[0] + divisibility_c = c.shape[1] if c_major == "n" else c.shape[0] + + a_tensor = ( + from_dlpack(a, assumed_align=16) + .mark_layout_dynamic(leading_dim=(1 if a_major == "k" else 0)) + .mark_compact_shape_dynamic( + mode=(1 if a_major == "k" else 0), + divisibility=divisibility_a, + ) + ) + + b_tensor = ( + from_dlpack(b, assumed_align=16) + .mark_layout_dynamic(leading_dim=(1 if b_major == "k" else 0)) + .mark_compact_shape_dynamic( + mode=(1 if b_major == "k" else 0), + divisibility=divisibility_b, + ) + ) + + c_tensor = ( + from_dlpack(c, assumed_align=16) + .mark_layout_dynamic(leading_dim=(1 if c_major == "n" else 0)) + .mark_compact_shape_dynamic( + mode=(1 if c_major == "n" else 0), + divisibility=divisibility_c, + ) + ) + + sgemm = SGemm() + + # Get current CUDA stream from PyTorch + torch_stream = torch.cuda.current_stream() + # Get the raw stream pointer as a CUstream + current_stream = cuda.CUstream(torch_stream.cuda_stream) + + print("Compiling kernel with cute.compile ...") + start_time = time.time() + compiled_fn = cute.compile( + sgemm, + a_tensor, + b_tensor, + c_tensor, + stream=current_stream, + ) + compilation_time = time.time() - start_time + print(f"Compilation time: {compilation_time:.4f} seconds") + + print("Executing GEMM kernel...") + + if not skip_ref_check: + compiled_fn(a_tensor, b_tensor, c_tensor) + torch.cuda.synchronize() + 
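Stepping back from the kernel above: its mainloop comment describes a pipeline in which num_stages - 1 k-tiles are prefetched and cp_async_wait_group(num_stages - 2) guarantees that the tile about to be consumed is resident while one more copy stays in flight. A toy host-side model of that index bookkeeping (illustrative only; copy completion is simplified to FIFO order, and the trailing dummy copies stand in for the kernel's re-copy of an already-loaded tile):

num_stages, k_tiles = 3, 10
filled = [None] * num_stages      # which k-tile each smem buffer currently holds
in_flight = []                    # committed (buffer, k_tile) copies, oldest first

def commit(buf, k):
    in_flight.append((buf, k))

def wait_group(n):                # models cp_async_wait_group(n): <= n copies pending
    while len(in_flight) > n:
        buf, k = in_flight.pop(0)
        filled[buf] = k

for k in range(num_stages - 1):   # prologue: prefetch num_stages - 1 tiles
    commit(k, k)
read, write = 0, num_stages - 1
for k in range(k_tiles):
    wait_group(num_stages - 2)    # tile k must be resident before it is consumed
    assert filled[read] == k
    nxt = k + num_stages - 1      # keep committing one group per tile, as the kernel does
    commit(write, nxt if nxt < k_tiles else -1)
    write, read = read, (read + 1) % num_stages
print("pipeline bookkeeping OK")

The assertion holding for every iteration is the property the real kernel depends on: the consumer never reads a buffer whose asynchronous copy has not yet been waited on, while one copy is always left outstanding to overlap with compute.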
print("Verifying results...") + ref = torch.einsum("mk,nk->mn", a, b) + torch.testing.assert_close(c.cpu(), ref.cpu(), atol=1e-03, rtol=1e-05) + print("Results verified successfully!") + + def generate_tensors(): + # Create new tensors for each workspace to ensure cold L2 cache + a_workspace = create_and_permute_tensor(M, K, a_major == "m", torch.float32) + b_workspace = create_and_permute_tensor(N, K, b_major == "n", torch.float32) + c_workspace = create_and_permute_tensor(M, N, c_major == "m", torch.float32) + + if static_shape: + a_tensor_workspace = ( + from_dlpack(a_workspace, assumed_align=16) + .mark_layout_dynamic(leading_dim=(1 if a_major == "k" else 0)) + .mark_compact_shape_dynamic( + mode=(1 if a_major == "k" else 0), + divisibility=divisibility_a, + ) + ) + else: + a_tensor_workspace = from_dlpack(a_workspace, assumed_align=16) + + b_tensor_workspace = ( + from_dlpack(b_workspace, assumed_align=16) + .mark_layout_dynamic(leading_dim=(1 if b_major == "k" else 0)) + .mark_compact_shape_dynamic( + mode=(1 if b_major == "k" else 0), + divisibility=divisibility_b, + ) + ) + + c_tensor_workspace = ( + from_dlpack(c_workspace, assumed_align=16) + .mark_layout_dynamic(leading_dim=(1 if c_major == "n" else 0)) + .mark_compact_shape_dynamic( + mode=(1 if c_major == "n" else 0), + divisibility=divisibility_c, + ) + ) + + return testing.JitArguments( + a_tensor_workspace, b_tensor_workspace, c_tensor_workspace, current_stream + ) + + workspace_count = 1 + if use_cold_l2: + one_workspace_bytes = ( + a.numel() * a.element_size() + + b.numel() * b.element_size() + + c.numel() * c.element_size() + ) + workspace_count = testing.get_workspace_count( + one_workspace_bytes, warmup_iterations, iterations + ) + + avg_time_us = testing.benchmark( + compiled_fn, + workspace_generator=generate_tensors, + workspace_count=workspace_count, + stream=current_stream, + warmup_iterations=warmup_iterations, + iterations=iterations, + ) + + # Print execution results + print(f"Kernel execution time: {avg_time_us / 1e3:.4f} ms") + + return avg_time_us # Return execution time in microseconds + + +if __name__ == "__main__": + + def parse_comma_separated_ints(s: str) -> Tuple[int, ...]: + try: + return tuple(int(x.strip()) for x in s.split(",")) + except ValueError: + raise argparse.ArgumentTypeError( + "Invalid format. Expected comma-separated integers." 
+ ) + + parser = argparse.ArgumentParser() + parser.add_argument( + "--mnk", type=parse_comma_separated_ints, default=(256, 256, 64) + ) + parser.add_argument("--a_major", choices=["k", "m"], default="k") + parser.add_argument("--b_major", choices=["k", "n"], default="k") + parser.add_argument("--c_major", choices=["n", "m"], default="n") + parser.add_argument("--warmup_iterations", default=2, type=int) + parser.add_argument("--iterations", default=100, type=int) + parser.add_argument("--static_shape", action="store_true") + parser.add_argument("--skip_ref_check", action="store_true") + parser.add_argument( + "--use_cold_l2", + action="store_true", + default=False, + help="Use circular buffer tensor sets to ensure L2 cold cache", + ) + + args = parser.parse_args() + print("Running SIMT GEMM example:") + + torch.manual_seed(1024) + + run( + args.mnk, + args.a_major, + args.b_major, + args.c_major, + args.static_shape, + args.warmup_iterations, + args.iterations, + args.skip_ref_check, + args.use_cold_l2, + ) + print("PASS") diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/ampere/smem_allocator.py b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/ampere/smem_allocator.py new file mode 100644 index 0000000000000000000000000000000000000000..f9f5c1e03f78c1f1656e06c34e4d9a5007af6b64 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/ampere/smem_allocator.py @@ -0,0 +1,200 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import cutlass.cute as cute +import cutlass +import torch +import numpy as np +from cutlass.cute.runtime import from_dlpack + +""" +A Shared Memory Allocator Example on NVIDIA Ampere architecture using CuTe DSL. + +This example demonstrates how to allocate and manage shared memory in JIT kernels by using the SmemAllocator in CuTe DSL. 
+It shows various ways to allocate different data structures in shared memory: + +1. Struct allocation with natural and strict alignment +2. Raw memory block allocation with custom alignment +3. Array allocation with automatic alignment +4. Tensor allocation with layout specification + +The example includes: +- Shared storage struct with mixed alignment requirements +- Memory allocation patterns for different data types +- Tensor operations on allocated memory + +To run this example: + +.. code-block:: bash + + python examples/ampere/smem_allocator.py + +The example will allocate shared memory, perform tensor operations, and verify the results. +""" + + +@cute.struct +class complex: + real: cutlass.Float32 + imag: cutlass.Float32 + + +# SharedStorage size is 512, alignment is 128 +@cute.struct +class SharedStorage: + # struct elements with natural alignment + a: cute.struct.MemRange[cutlass.Float32, 32] # array + b: cutlass.Int64 # saclar + c: complex # nested struct + # struct elements with strict alignment + x: cute.struct.Align[ + cute.struct.MemRange[cutlass.Float32, 32], + 128, + ] + y: cute.struct.Align[cutlass.Int32, 8] + z: cute.struct.Align[complex, 16] + + +@cute.kernel +def kernel( + const_a: cutlass.Constexpr, + dst_a: cute.Tensor, + const_b: cutlass.Constexpr, + dst_b: cute.Tensor, + const_c: cutlass.Constexpr, + dst_c: cute.Tensor, +): + # Note: SMEM_SIZE bytes (specified in kernel().launch(smem=...)) can be reserved for developer to utilize + # Note: alignment of inital allocator base ptr is 1024 + allocator = cutlass.utils.SmemAllocator() + # base ptr of allocator points at: SMEM_ADDR_START (the starting address of available shared memory) + + # -- Allocate a struct -- + # Note: when specified alignment, max(alignment, alignof(struct)) will be applied + # reserves the section of struct in smem, elements in the struct can be accessed by ptr + struct_in_smem = allocator.allocate(SharedStorage) + # base ptr of allocator now points at: SMEM_ADDR_AFTER_STRUCT = SMEM_ADDR_START + aligned_size(struct) + + # -- Allocate a block of memory -- + # reserves a section of 64 bytes in smem, align to 128 bytes, returns the section base ptr + section_in_smem = allocator.allocate(64, byte_alignment=128) + # base ptr of allocator now points at: SMEM_ADDR_AFTER_SECTION = SMEM_ADDR_AFTER_STRUCT + aligned_size(section) + + # -- Allocate an array -- + # reserves an int64 array of size 14 in smem, returns the array base ptr + array_in_smem = allocator.allocate_array(element_type=cutlass.Int64, num_elems=14) + # base ptr of allocator now points at: SMEM_ADDR_AFTER_ARRAY = SMEM_ADDR_AFTER_SECTION + aligned_size(array) + + # -- Allocate a tensor -- + # Note: use cute.ComposedLayout or cute.Layout to specify layout of tensor + # Note: iterator swizzle with swizzle layout is currently not supported + layout = cute.make_layout((16, 2)) + tensor_in_smem = allocator.allocate_tensor( + element_type=cutlass.Float32, layout=layout, byte_alignment=32, swizzle=None + ) + # base ptr of allocator now points at: SMEM_ADDR_AFTER_TENSOR = SMEM_ADDR_AFTER_ARRAY + aligned_size(tensor) + + # ptr> + # ptr> + # ptr> + print(struct_in_smem.a.data_ptr()) + print(struct_in_smem.b) + print(struct_in_smem.c.real) + # ptr> + print(section_in_smem) + # ptr> + print(array_in_smem) + # tensor> o (16,4):(1,16)> + print(tensor_in_smem) + + # fill MemRange tensor in struct and copy to dst + a_tensor = struct_in_smem.a.get_tensor(cute.make_layout((8, 4))) + a_tensor.fill(const_a) + cute.printf("cute.struct.MemRange: {}", a_tensor) + 
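One way to see how these allocations advance the allocator's base pointer is to model it as a bump allocator over byte offsets (sizes and alignments taken from the comments above; offsets are relative to the allocator's start and are purely illustrative, since the real SmemAllocator tracks a device pointer rather than an integer):

def align_up(offset, alignment):
    return (offset + alignment - 1) // alignment * alignment

class BumpAllocator:
    def __init__(self):
        self.offset = 0
    def allocate(self, size, alignment=1):
        self.offset = align_up(self.offset, alignment)
        base = self.offset
        self.offset += size
        return base

alloc = BumpAllocator()
print(alloc.allocate(512, 128))       # SharedStorage struct            -> 0
print(alloc.allocate(64, 128))        # raw 64-byte section             -> 512
print(alloc.allocate(14 * 8, 8))      # 14 x Int64 array                -> 576
print(alloc.allocate(16 * 2 * 4, 32)) # (16, 2) Float32 tensor          -> 704
print(alloc.offset)                   # total bytes consumed            -> 832

The 320 bytes consumed past the struct (including 16 bytes of alignment padding before the 32-byte-aligned tensor) fit within the extra 384 bytes reserved when the kernel is launched.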
dst_a.store(a_tensor.load()) + + # convert block of smem to fill tensor and copy to dst + layout = cute.make_layout((8, 2)) + sec_ptr = cute.recast_ptr(section_in_smem, dtype=cutlass.Float32) + sec_tensor = cute.make_tensor(sec_ptr, layout) + sec_tensor.fill(const_b) + cute.printf("block of memory: {}", sec_tensor) + dst_b.store(sec_tensor.load()) + + # fill allocated tensor in smem and copy to dst + tensor_in_smem.fill(const_c) + cute.printf("tensor in smem: {}", tensor_in_smem) + dst_c.store(tensor_in_smem.load()) + + +@cute.jit +def run_allocation_kernel( + const_a: cutlass.Constexpr, + dst_a: cute.Tensor, + const_b: cutlass.Constexpr, + dst_b: cute.Tensor, + const_c: cutlass.Constexpr, + dst_c: cute.Tensor, +): + # additional size for the example, 64(section) + 112(array) + 128(tensor) < 384 + addtional_bytes = 384 + # Note: launch shared memory size is: SMEM_SIZE = 512 + 384 = 896 bytes + kernel(const_a, dst_a, const_b, dst_b, const_c, dst_c).launch( + grid=(1, 1, 1), + block=(1, 1, 1), + smem=SharedStorage.size_in_bytes() + addtional_bytes, + ) + + +def veify_allocation_kernel(const_a, const_b, const_c): + dst_a = torch.zeros((8, 4), dtype=torch.float32, device="cuda") + dst_b = torch.zeros((8, 2), dtype=torch.float32, device="cuda") + dst_c = torch.zeros((16, 2), dtype=torch.float32, device="cuda") + + run_allocation_kernel( + const_a, + from_dlpack(dst_a), + const_b, + from_dlpack(dst_b), + const_c, + from_dlpack(dst_c), + ) + + np.testing.assert_equal(const_a, dst_a.detach().cpu().numpy()[0]) + np.testing.assert_equal(const_b, dst_b.detach().cpu().numpy()[0]) + np.testing.assert_equal(const_c, dst_c.detach().cpu().numpy()[0]) + + +if __name__ == "__main__": + # prepare cuda context + cutlass.cuda.initialize_cuda_context() + # An example for shared memory allocation + const_a = 0.5 + const_b = 1.0 + const_c = 2.0 + veify_allocation_kernel(const_a, const_b, const_c) diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/ampere/tensorop_gemm.py b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/ampere/tensorop_gemm.py new file mode 100644 index 0000000000000000000000000000000000000000..413d1fbf0a239f8e2655ef5a530d7aa9bd6b2988 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/ampere/tensorop_gemm.py @@ -0,0 +1,1012 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import argparse +import math +import time +from typing import Tuple, Type + +import cuda.bindings.driver as cuda +import torch + +import cutlass +import cutlass.cute as cute +import cutlass.cute.testing as testing +import cutlass.torch as cutlass_torch +import cutlass.utils as utils +from cutlass.cute.runtime import from_dlpack + +""" +A dense GEMM (C = A * B) example for the NVIDIA Ampere architecture using CUTE DSL. +- Matrix A is MxKxL, L is batch dimension, A can be row-major("K") or column-major("M") +- Matrix B is NxKxL, L is batch dimension, B can be row-major("N") or column-major("K") +- Matrix C is MxNxL, L is batch dimension, C can be row-major("N") or column-major("M") + +This GEMM kernel supports the following features: + - Utilizes Ampere's tensor cores for matrix multiply-accumulate (MMA) operations + - Threadblock rasterization to improve data re-use + - Supports multi-stage pipeline to overlap computation and memory access + - Implements shared memory buffering for epilogue to increase coalesced global memory access + +This GEMM works as follows: +1. Load A and B matrices from global memory (GMEM) to shared memory (SMEM) using asynchronous copies. +2. Perform matrix multiply-accumulate (MMA) operations. +3. Store results from registers (RMEM) to shared memory (SMEM), then to global memory (GMEM). + +The Ampere tensor core instruction used operates as follows: +- Read matrix A from SMEM +- Read matrix B from SMEM +- Perform MMA operation and store the result in Accumulator(register) + +To run this example: + +.. code-block:: bash + + python examples/ampere/tensorop_gemm.py \ + --mnkl 8192,8192,8192,1 --atom_layout_mnk 2,2,1 \ + --ab_dtype Float16 \ + --c_dtype Float16 --acc_dtype Float32 \ + --a_major m --b_major n --c_major n + +The above example command computes with M=8192, N=8192, K=8192, +batch_count=1. The atom layout's shape is 2x2x1 and the input, mma +accumulator, and output data type are set as fp16, fp32 and fp16, +respectively. + +To collect performance with NCU profiler: + +.. 
code-block:: bash + + ncu python examples/ampere/tensorop_gemm.py \ + --mnkl 8192,8192,8192,1 --atom_layout_mnk 2,2,1 \ + --ab_dtype Float16 \ + --c_dtype Float16 --acc_dtype Float32 \ + --a_major m --b_major n --c_major n \ + --skip_ref_check --iterations 2 + +Constraints: +* Supported input and output data types: fp16 +* Support accumulator data types: f32 +* Default tile shape is set to be 128x128x32 +* Atom layout's MNK shape is set so that tile shape can be divided by MMA + instruction shape +* The contiguous dimension of A/B/C tensors must be at least 16 bytes aligned, + i.e, number of elements is a multiple of 8 +""" + + +class TensorOpGemm: + def __init__( + self, + ab_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + acc_dtype: Type[cutlass.Numeric], + atom_layout_mnk: Tuple[int, int, int], + ): + self.ab_dtype = ab_dtype + self.c_dtype = c_dtype + self.acc_dtype = acc_dtype + self.cta_tiler = (128, 128, 32) + self.num_stages = 3 + self.atom_layout_mnk = atom_layout_mnk + atom_lay_M, atom_lay_N, atom_lay_K = self.atom_layout_mnk + self.num_threads = atom_lay_M * atom_lay_N * atom_lay_K * 32 + + self.bM, self.bN, self.bK = self.cta_tiler + self.mma_inst_shape = (16, 8, 16) + mmaM, mmaN, mmaK = self.mma_inst_shape + + assert ( + self.bM % (atom_lay_M * mmaM) == 0 + ), "bM must be divisible by MMA instruction" + assert ( + self.bN % (atom_lay_N * mmaN) == 0 + ), "bN must be divisible by MMA instruction" + assert atom_lay_K == 1, "this example does not support atom layout K > 1" + assert self.bK % mmaK == 0, "bK must be divisible by MMA instruction" + assert self.num_stages >= 3, "num_stages must be greater than or equal to 3" + + @cute.jit + def __call__( + self, + mA: cute.Tensor, + mB: cute.Tensor, + mC: cute.Tensor, + epilogue_op: cutlass.Constexpr = lambda x: x, + ): + # The grid divides the problems's M, N, and L dimensions by the + # respective modes of the tile shape (bM, bN, 1). The K dimension is + # handled within a block via a multistage process. + + self.a_major_mode = utils.LayoutEnum.from_tensor(mA) + self.b_major_mode = utils.LayoutEnum.from_tensor(mB) + self.c_major_mode = utils.LayoutEnum.from_tensor(mC) + + # /////////////////////////////////////////////////////////////////////////////// + # Shared memory layout: + # /////////////////////////////////////////////////////////////////////////////// + + # Creates a layout with the size required for the provided tile + # size and num stages (stages are used for K dimension) that is also + # sectioned into 64x8 or 8x32 layout atoms. The swizzle is set so that + # the atom for the shared memory -> register copy does not encounter + # bank conflicts + + # assume the input is 16B align + ab_copy_bits = 128 + sA_layout = self._make_smem_layout_AB( + mA.element_type, + self.a_major_mode, + ab_copy_bits, + (self.cta_tiler[0], self.cta_tiler[2], self.num_stages), + ) + sB_layout = self._make_smem_layout_AB( + mB.element_type, + self.b_major_mode, + ab_copy_bits, + (self.cta_tiler[1], self.cta_tiler[2], self.num_stages), + ) + + # Creates a similar layout but without num_stages or layout atoms + sC_layout = self._make_smem_layout_C( + mC.element_type, + self.c_major_mode, + ab_copy_bits, + (self.cta_tiler[0], self.cta_tiler[1]), + ) + + # Shared memory allocated for operations with A, B will be + # overwritten for operations on C. 
This is to improve performance + # by reducing the size of shared memory requested by each block + smem_size = max( + cute.size_in_bytes(mC.element_type, sC_layout), + cute.size_in_bytes(mA.element_type, sA_layout) + + cute.size_in_bytes(mB.element_type, sB_layout), + ) + + # /////////////////////////////////////////////////////////////////////////////// + # Tiled copy: + # The majorness of tA/tB/tC follows the majorness of gA/gB/gC, + # enabling merged accesses to global memory for faster data + # transfer between global and shared memory. + # /////////////////////////////////////////////////////////////////////////////// + + # Create a copy atom for a global to shared memory asynchronous copy + atom_async_copy = cute.make_copy_atom( + cute.nvgpu.cpasync.CopyG2SOp( + cache_mode=cute.nvgpu.cpasync.LoadCacheMode.GLOBAL + ), + mA.element_type, + num_bits_per_copy=ab_copy_bits, + ) + + # Create thread layouts for tiled copy from the copy atom where the + # thread layout simply follows the leading dimension of the tensor + tiled_copy_A = self._make_gmem_tiled_copy_AB( + atom_async_copy, mA.element_type, self.a_major_mode, ab_copy_bits + ) + tiled_copy_B = self._make_gmem_tiled_copy_AB( + atom_async_copy, mB.element_type, self.b_major_mode, ab_copy_bits + ) + + # Creates a synchronous copy atom and thread layouts for the epilogue + c_copy_bits = 128 + atom_sync_copy = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + mC.element_type, + num_bits_per_copy=c_copy_bits, + ) + tiled_copy_C = self._make_gmem_tiled_copy_C( + atom_sync_copy, mC.element_type, self.c_major_mode, c_copy_bits + ) + + # /////////////////////////////////////////////////////////////////////////////// + # Tiled MMA + # /////////////////////////////////////////////////////////////////////////////// + + # Creates a mma atom with 16x8x16 shape for MNK + op = cute.nvgpu.warp.MmaF16BF16Op( + self.ab_dtype, self.acc_dtype, self.mma_inst_shape + ) + + permutation_mnk = ( + self.atom_layout_mnk[0] * self.mma_inst_shape[0], + # if atom layout's N-mode is 1, to leverage the largest coalesced + # shared memory -> register copy, set the tiled mma's N mode to 16 + self.atom_layout_mnk[1] * self.mma_inst_shape[1] * 2, + self.atom_layout_mnk[2] * self.mma_inst_shape[2], + ) + + # Created a tiled mma that tiles the atom according to specified layout. 
+ # For a 2x2x1 atom layout, the mma atom is duplicated 4 times, twice + # across M and twice across N + tC = cute.make_layout(self.atom_layout_mnk) + tiled_mma = cute.make_tiled_mma( + op, + tC, + permutation_mnk=permutation_mnk, + ) + + # grid_dim: ((m + BLK_M - 1) // BLK_M, (n + BLK_N - 1) // BLK_N, l) + grid_dim = cute.ceil_div(mC.shape, (self.bM, self.bN, 1)) + + # Add threadblock rasterization to improve re-use of data + raster_factor = 1 + grid_dim_n = cute.size(grid_dim[1]) + # Thresholds picked so that it doesn't cause too many no-op CTAs + if grid_dim_n > 5: + raster_factor = 8 + elif grid_dim_n > 2: + raster_factor = 4 + elif grid_dim_n > 1: + raster_factor = 2 + rasterization_remap_grid_dim = ( + cute.size(grid_dim[0]) * raster_factor, + (cute.size(grid_dim[1]) + raster_factor - 1) // raster_factor, + cute.size(grid_dim[2]), + ) + + self.kernel( + mA, + mB, + mC, + sA_layout, + sB_layout, + sC_layout, + tiled_copy_A, + tiled_copy_B, + tiled_copy_C, + tiled_mma, + raster_factor, + epilogue_op, + ).launch( + grid=rasterization_remap_grid_dim, + block=[self.num_threads, 1, 1], + smem=smem_size, + ) + + @cute.kernel + def kernel( + self, + mA: cute.Tensor, + mB: cute.Tensor, + mC: cute.Tensor, + sA_layout: cute.ComposedLayout, + sB_layout: cute.ComposedLayout, + sC_layout: cute.ComposedLayout, + tiled_copy_A: cute.TiledCopy, + tiled_copy_B: cute.TiledCopy, + tiled_copy_C: cute.TiledCopy, + tiled_mma: cute.TiledMma, + rasterization_factor: cutlass.Int32, + epilogue_op: cutlass.Constexpr = lambda x: x, + ): + # Thread index, block index + tidx, _, _ = cute.arch.thread_idx() + bidx, bidy, bidz = cute.arch.block_idx() + grid_dim = cute.ceil_div(mC.shape, (self.bM, self.bN, 1)) + offset_tile_x, offset_tile_y = self.raster_tile( + bidx, bidy, rasterization_factor + ) + # Early exit if CTA is out of range + if grid_dim[0] <= offset_tile_x or grid_dim[1] <= offset_tile_y: + pass + else: + tiler_coord = (offset_tile_x, offset_tile_y, None) + + # /////////////////////////////////////////////////////////////////////////////// + # Get the appropriate tiles for this thread block. + # gA: (BLK_M, BLK_N, k), gB: (BLK_N, BLK_K, k), gC: (BLK_M, BLK_N) + # /////////////////////////////////////////////////////////////////////////////// + gA = cute.local_tile( + mA[None, None, bidz], + tiler=self.cta_tiler, + coord=tiler_coord, + proj=(1, None, 1), + ) + gB = cute.local_tile( + mB[None, None, bidz], + tiler=self.cta_tiler, + coord=tiler_coord, + proj=(None, 1, 1), + ) + gC = cute.local_tile( + mC[None, None, bidz], + tiler=self.cta_tiler, + coord=tiler_coord, + proj=(1, 1, None), + ) + + # By default, if the tensor k mode does not divide into the tile k + # size, then last tiles in the k dimension are irregular. + # Instead, make the first tiles irregular when k is irregular. + # This allows us to handle the irregular tile first to avoid + # checking for this condition within the mainloop. 
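+            # Illustrative arithmetic only (values follow the script's default
+            # --mnkl 112,136,40,1): with K = 40 and bK = 32, gA holds
+            # ceil_div(40, 32) = 2 k-tiles covering 64 columns, so
+            # residual_k = 40 - 32 * 2 = -24. After the domain_offset below, the
+            # first k-tile spans global k-columns [-24, 8); only its last 8
+            # columns are in bounds (the identity-tensor checks mask the rest),
+            # while every later k-tile is full and needs no k bounds check in
+            # the mainloop.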
+ + # residual_k is a negative number indicating the amount needed to + # shift the pointer by in dimension k + residual_k = cute.size(mA, mode=[1]) - cutlass.Int32(self.bK) * cute.size( + gA, mode=[2] + ) + + # move the pointer of gA/gB in the `-k` direction + gA = cute.domain_offset((0, residual_k, 0), gA) + gB = cute.domain_offset((0, residual_k, 0), gB) + # input is 16B aligned + gA = cute.make_tensor(gA.iterator.align(16), gA.layout) + gB = cute.make_tensor(gB.iterator.align(16), gB.layout) + + # Construct identity layout for sA and sB (mirrors global tensors, + # used for predication only) + mcA = cute.make_identity_tensor(mA.layout.shape) + mcB = cute.make_identity_tensor(mB.layout.shape) + cA = cute.local_tile( + mcA[None, None, bidz], + tiler=self.cta_tiler, + coord=tiler_coord, + proj=(1, None, 1), + ) + cB = cute.local_tile( + mcB[None, None, bidz], + tiler=self.cta_tiler, + coord=tiler_coord, + proj=(None, 1, 1), + ) + + cA = cute.domain_offset((0, residual_k, 0), cA) + cB = cute.domain_offset((0, residual_k, 0), cB) + + # /////////////////////////////////////////////////////////////////////////////// + # Create shared memory buffers and get the appropriate fragments for this thread. + # sA: (BLK_M, BLK_K, PIPE) , sB: (BLK_N, BLK_K, PIPE) + # tAgA: (CPY, CPY_M, CPY_K, k) , tBgB: (CPY, CPY_N, CPY_K, k) + # tAsA: (CPY, CPY_M, CPY_K, PIPE) , tBsB: (CPY, CPY_N, CPY_K, PIPE) + # /////////////////////////////////////////////////////////////////////////////// + # Shared memory buffer + smem = cutlass.utils.SmemAllocator() + + sA = smem.allocate_tensor(mA.element_type, sA_layout, 16) + sB = smem.allocate_tensor(mB.element_type, sB_layout, 16) + sC = cute.make_tensor( + cute.recast_ptr(sA.iterator, dtype=self.c_dtype), sC_layout + ) + + thr_copy_A = tiled_copy_A.get_slice(tidx) + thr_copy_B = tiled_copy_B.get_slice(tidx) + thr_copy_C = tiled_copy_C.get_slice(tidx) + tAgA = thr_copy_A.partition_S(gA) + tAsA = thr_copy_A.partition_D(sA) + tBgB = thr_copy_B.partition_S(gB) + tBsB = thr_copy_B.partition_D(sB) + tCsC_epilogue = thr_copy_C.partition_S(sC) + tCgC_epilogue = thr_copy_C.partition_D(gC) + + # Repeat the partitioning with identity layouts + tAcA = thr_copy_A.partition_S(cA) + tBcB = thr_copy_B.partition_S(cB) + + # /////////////////////////////////////////////////////////////////////////////// + # Predicate: Mark indices that need to copy when problem_shape isn't a multiple + # of tile_shape + # /////////////////////////////////////////////////////////////////////////////// + + # For predication over the tensors A (M/K), B (N/K), and (in the + # epilogue) C (M/N), we will compute it in a fashion similar to an + # outer product. The predication along one of the dimensions is + # evaluated and stored in a predication tensor. Then, the + # predication for the remaining dimension is handled later via an + # if/else branch at the copy. + # For A and B, predication booleans along M/N are stored in a + # predication tensor and along K is handled via a if/else branch. + + # Allocate predicate tensors for M and N. Predication is checked + # at the granularity of a copy atom, so the predicate tensor does not + # need separate booleans for individual elements within a copy + # atom (for example, the elements of tAgA.shape[0][0].) 
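+            # Concrete illustration (uses the default M = 112 with bM = 128): the
+            # single M-tile covers rows 0..127, so the predicates computed below
+            # are True for rows 0..111 and False for rows 112..127, with one
+            # boolean per row a thread copies rather than one per element. The K
+            # mode is not stored at all (its stride in tApA/tBpB is 0); it is
+            # re-checked with an if/else around each copy, giving the
+            # "outer product" structure described above.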
+ tApA = cute.make_fragment( + cute.make_layout( + ( + tAgA.shape[0][1], + cute.size(tAgA, mode=[1]), + cute.size(tAgA, mode=[2]), + ), + stride=(cute.size(tAgA, mode=[1]), 1, 0), + ), + cutlass.Boolean, + ) + tBpB = cute.make_fragment( + cute.make_layout( + ( + tBsB.shape[0][1], + cute.size(tBsB, mode=[1]), + cute.size(tBsB, mode=[2]), + ), + stride=(cute.size(tBsB, mode=[1]), 1, 0), + ), + cutlass.Boolean, + ) + # Set predicates for M/N bounds + for rest_v in range(tApA.shape[0]): + for m in range(tApA.shape[1]): + tApA[rest_v, m, 0] = cute.elem_less( + tAcA[(0, rest_v), m, 0, 0][0], mA.shape[0] + ) + for rest_v in range(tBpB.shape[0]): + for n in range(tBpB.shape[1]): + tBpB[rest_v, n, 0] = cute.elem_less( + tBcB[(0, rest_v), n, 0, 0][0], mB.shape[0] + ) + + # /////////////////////////////////////////////////////////////////////////////// + # Prefetch Prologue + # /////////////////////////////////////////////////////////////////////////////// + # Clear the smem tiles to account for predicated off loads + tAsA.fill(0) + tBsB.fill(0) + cute.arch.sync_threads() + # Start async loads for the first k-tile. Here we take care of the k residue + # via if/else check along the k dimension. Because we shifted the identity tensor + # by the residue_k and because the identity tensor is a coord tensor, the + # values of any identity tensor element that is poison is less than -1 + num_smem_stages = cute.size(tAsA, mode=[3]) + k_tile_count = cute.size(tAgA, mode=[3]) + k_tile_index = cutlass.Int32(0) + + for k in range(tApA.shape[2]): + if cute.elem_less(cutlass.Int32(-1), tAcA[0, 0, k, 0][1]): + cute.copy( + tiled_copy_A, + tAgA[None, None, k, k_tile_index], + tAsA[None, None, k, 0], + pred=tApA[None, None, k], + ) + for k in range(tBpB.shape[2]): + if cute.elem_less(cutlass.Int32(-1), tBcB[0, 0, k, 0][1]): + cute.copy( + tiled_copy_B, + tBgB[None, None, k, k_tile_index], + tBsB[None, None, k, 0], + pred=tBpB[None, None, k], + ) + k_tile_index = k_tile_index + 1 + cute.arch.cp_async_commit_group() + + # Start async loads for rest of the k-tiles + for k_tile in range(1, num_smem_stages - 1): + if k_tile == k_tile_count: + tApA.fill(0) + tBpB.fill(0) + cute.copy( + tiled_copy_A, + tAgA[None, None, None, k_tile_index], + tAsA[None, None, None, k_tile], + pred=tApA, + ) + cute.copy( + tiled_copy_B, + tBgB[None, None, None, k_tile_index], + tBsB[None, None, None, k_tile], + pred=tBpB, + ) + k_tile_index = k_tile_index + 1 + cute.arch.cp_async_commit_group() + + # /////////////////////////////////////////////////////////////////////////////// + # Tile MMA compute thread partitions and allocate accumulators + # /////////////////////////////////////////////////////////////////////////////// + thr_mma = tiled_mma.get_slice(tidx) + tCsA = thr_mma.partition_A(sA) + tCsB = thr_mma.partition_B(sB) + tCsC = thr_mma.partition_C(sC) + tCgC = thr_mma.partition_C(gC) + tCrA = tiled_mma.make_fragment_A(tCsA[None, None, None, 0]) + tCrB = tiled_mma.make_fragment_B(tCsB[None, None, None, 0]) + tCrC = tiled_mma.make_fragment_C(tCgC) + # Clear the accumulator + tCrC.fill(0.0) + + # /////////////////////////////////////////////////////////////////////////////// + # Copy Atom A/B retiling + # /////////////////////////////////////////////////////////////////////////////// + + # Create the copy atoms for the copy from shared memory to register + atom_copy_s2r_A = cute.make_copy_atom( + cute.nvgpu.warp.LdMatrix8x8x16bOp( + self.a_major_mode != utils.LayoutEnum.ROW_MAJOR, 4 + ), + mA.element_type, + ) + atom_copy_s2r_B = 
cute.make_copy_atom( + cute.nvgpu.warp.LdMatrix8x8x16bOp( + self.b_major_mode != utils.LayoutEnum.ROW_MAJOR, 4 + ), + mB.element_type, + ) + + # Creates the tiled copy so that it matches the thread-value layout + # expected by the tiled mma + tiled_copy_s2r_A = cute.make_tiled_copy_A(atom_copy_s2r_A, tiled_mma) + tiled_copy_s2r_B = cute.make_tiled_copy_B(atom_copy_s2r_B, tiled_mma) + + thr_copy_ldmatrix_A = tiled_copy_s2r_A.get_slice(tidx) + thr_copy_ldmatrix_B = tiled_copy_s2r_B.get_slice(tidx) + tCsA_copy_view = thr_copy_ldmatrix_A.partition_S(sA) + tCrA_copy_view = thr_copy_ldmatrix_A.retile(tCrA) + tCsB_copy_view = thr_copy_ldmatrix_B.partition_S(sB) + tCrB_copy_view = thr_copy_ldmatrix_B.retile(tCrB) + + # Current pipe index in smem to read from / write to + smem_pipe_read = 0 + smem_pipe_write = num_smem_stages - 1 + + tCsA_p = tCsA_copy_view[None, None, None, smem_pipe_read] + tCsB_p = tCsB_copy_view[None, None, None, smem_pipe_read] + + # /////////////////////////////////////////////////////////////////////////////// + # PREFETCH register pipeline + # /////////////////////////////////////////////////////////////////////////////// + num_k_block = cute.size(tCrA, mode=[2]) + if num_k_block > 1: + # Wait until our first prefetched tile is loaded in + cute.arch.cp_async_wait_group(num_smem_stages - 2) + cute.arch.sync_threads() + # Prefetch the first k-block rmem from the first k-tile + cute.copy( + tiled_copy_s2r_A, + tCsA_p[None, None, 0], + tCrA_copy_view[None, None, 0], + ) + cute.copy( + tiled_copy_s2r_B, + tCsB_p[None, None, 0], + tCrB_copy_view[None, None, 0], + ) + + # /////////////////////////////////////////////////////////////////////////////// + # Mainloop + # 1. Shared memory pipeline (gmem -> smem): + # The default smem pipeline depth is 3, meaning that for shared + # memory buffers, we allocate three times the size described by the + # CTA tiler. We prefetch 2 of these buffers before entering the main + # loop. Considering only the transfer from global memory to shared + # memory, the general structure of the mainloop is: + # (1) copy k-tile from gmem to smem; + # (2) perform gemm computation on k-tile; + # (3) wait for the next copy to finish. + # The `cute.arch.cp_async_wait_group(num_smem_stages - 2)` command + # waits for the number of unfinished 'copy' to be <= 1. The advantage + # of this approach is that it allows for simultaneous production + # (i.e., step (1)) and consumption (i.e., step (2)) of smem. + # A common misconception is to prefetch N buffers and rewrite + # the pipeline logic to wait on N-1 pending copies. The disadvantage + # of this approach is that it requires fully consuming a buffer in + # order to open an empty buffer for the next copy. + # 2. Register pipeline (smem -> register): + # Similarly, the register pipeline produces i+1, consumes i, and + # produces i+2... Notably, i and i+1 do not use the same register, + # eliminating dependencies on the same register for better parallelism. + # 3. Combining the smem and register pipelines results in the mainloop. 
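+            #    Illustrative trace for the default num_smem_stages = 3: the
+            #    prologue fills stages 0 and 1, and each mainloop iteration waits
+            #    until at most one cp.async group is still in flight
+            #    (cp_async_wait_group(1)) before overlapping compute and copy:
+            #      k_tile 0: compute stage 0, copy the next unread k-tile -> stage 2
+            #      k_tile 1: compute stage 1, copy -> stage 0
+            #      k_tile 2: compute stage 2, copy -> stage 1, ...
+            #    (copies are skipped once no k-tiles remain), so one buffer is
+            #    refilled while another is being consumed.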
+ # /////////////////////////////////////////////////////////////////////////////// + for k_tile in range(k_tile_count): + for k_block in cutlass.range(num_k_block, unroll_full=True): + if k_block == num_k_block - 1: + tCsA_p = tCsA_copy_view[None, None, None, smem_pipe_read] + tCsB_p = tCsB_copy_view[None, None, None, smem_pipe_read] + cute.arch.cp_async_wait_group(num_smem_stages - 2) + cute.arch.sync_threads() + + # Load A, B from shared memory to registers for k_block + 1 + k_block_next = (k_block + 1) % num_k_block # static + cute.copy( + tiled_copy_s2r_A, + tCsA_p[None, None, k_block_next], + tCrA_copy_view[None, None, k_block_next], + ) + cute.copy( + tiled_copy_s2r_B, + tCsB_p[None, None, k_block_next], + tCrB_copy_view[None, None, k_block_next], + ) + + # Fetch next A: To better interleave global memory access and compute + # instructions, we intentionally use the sequence: copy A, perform GEMM, + # then copy B. + if k_block == 0: + if k_tile + num_smem_stages - 1 < k_tile_count: + cute.copy( + tiled_copy_A, + tAgA[None, None, None, k_tile_index], + tAsA[None, None, None, smem_pipe_write], + pred=tApA, + ) + + # Thread-level register gemm for k_block + cute.gemm( + tiled_mma, + tCrC, + tCrA[None, None, k_block], + tCrB[None, None, k_block], + tCrC, + ) + + # Fetch next B and update smem pipeline read/write + if k_block == 0: + if k_tile + num_smem_stages - 1 < k_tile_count: + cute.copy( + tiled_copy_B, + tBgB[None, None, None, k_tile_index], + tBsB[None, None, None, smem_pipe_write], + pred=tBpB, + ) + k_tile_index = k_tile_index + 1 + cute.arch.cp_async_commit_group() + smem_pipe_write = smem_pipe_read + smem_pipe_read = smem_pipe_read + 1 + if smem_pipe_read == num_smem_stages: + smem_pipe_read = 0 + + # Sync before epilogue + cute.arch.cp_async_wait_group(0) + cute.arch.sync_threads() + + # /////////////////////////////////////////////////////////////////////////////// + # Epilogue with fusion + # /////////////////////////////////////////////////////////////////////////////// + tCrD = cute.make_fragment_like(tCrC, self.c_dtype) + tCrD[None] = epilogue_op(tCrC.load()).to(self.c_dtype) + + # Copy results of D back to shared memory + cute.autovec_copy(tCrD, tCsC) + + # Create coord tensor for C + ceilM, ceilN, _ = cute.ceil_div(mC.shape, (self.bM, self.bN, 1)) + mcC = cute.make_identity_tensor( + ( + cute.size(ceilM) * self.cta_tiler[0], + cute.size(ceilN) * self.cta_tiler[1], + 1, + ) + ) + cC = cute.local_tile( + mcC[None, None, bidz], + tiler=self.cta_tiler, + coord=tiler_coord, + proj=(1, 1, None), + ) + tCcC = thr_copy_C.partition_S(cC) + + tCrC_epilogue = cute.make_fragment_like(tCsC_epilogue) + # Wait for all writes to shared memory to finish before starting copies + # using the new layouts + cute.arch.sync_threads() + cute.autovec_copy(tCsC_epilogue, tCrC_epilogue) + + # Create predication tensor for m + tCpC = cute.make_fragment( + cute.make_layout( + ( + tCgC_epilogue.shape[0][1], + cute.size(tCgC_epilogue, mode=[1]), + cute.size(tCgC_epilogue, mode=[2]), + ), + stride=(cute.size(tCgC_epilogue, mode=[1]), 1, 0), + ), + cutlass.Boolean, + ) + for rest_v in range(tCpC.shape[0]): + for m in range(tCpC.shape[1]): + tCpC[rest_v, m, 0] = cute.elem_less( + tCcC[(0, rest_v), m, 0][0], mC.shape[0] + ) + + # Copy to global memory using better vectorization + for rest_v in range(tCpC.shape[0]): + for n in range(tCpC.shape[2]): + if cute.elem_less(tCcC[(0, rest_v), 0, n][1], mC.shape[1]): + cute.copy( + tiled_copy_C, + tCrC_epilogue[None, None, n], + tCgC_epilogue[None, None, n], + 
pred=tCpC[None, None, n], + ) + return + + def _make_smem_layout_AB(self, dtype, major_mode, copy_bits, smem_tiler): + major_mode_size = ( + smem_tiler[1] if major_mode == utils.LayoutEnum.ROW_MAJOR else smem_tiler[0] + ) + major_mode_size = 64 if major_mode_size >= 64 else major_mode_size + + swizzle_bits = int(math.log2(major_mode_size * dtype.width // copy_bits)) + swizzle_bits = min(swizzle_bits, 3) + + layout_atom_outer = ( + cute.make_layout((8, major_mode_size), stride=(major_mode_size, 1)) + if major_mode == utils.LayoutEnum.ROW_MAJOR + else cute.make_layout((major_mode_size, 8), stride=(1, major_mode_size)) + ) + layout_atom = cute.make_composed_layout( + cute.make_swizzle(swizzle_bits, 3, 3), + 0, + layout_atom_outer, + ) + layout = cute.tile_to_shape(layout_atom, smem_tiler, (0, 1, 2)) + return layout + + def _make_smem_layout_C(self, dtype, major_mode, copy_bits, smem_tiler): + major_mode_size = ( + smem_tiler[1] if major_mode == utils.LayoutEnum.ROW_MAJOR else smem_tiler[0] + ) + + swizzle_bits = int(math.log2(major_mode_size * dtype.width // copy_bits)) + swizzle_bits = min(swizzle_bits, 3) + + layout_atom_outer = ( + cute.make_layout((8, major_mode_size), stride=(major_mode_size, 1)) + if major_mode == utils.LayoutEnum.ROW_MAJOR + else cute.make_layout((major_mode_size, 8), stride=(1, major_mode_size)) + ) + layout_atom = cute.make_composed_layout( + cute.make_swizzle(swizzle_bits, 3, 4), + 0, + layout_atom_outer, + ) + + # Due to the thread layout of the mma, remove swizzle in C to + # prevent shared memory fragments owned by an single thread from + # holding swizzles + if major_mode == utils.LayoutEnum.COL_MAJOR: + layout_atom = cute.make_composed_layout( + cute.make_swizzle(0, 3, 4), 0, layout_atom_outer + ) + layout = cute.tile_to_shape( + layout_atom, + smem_tiler, + (0, 1), + ) + return layout + + def _make_gmem_tiled_copy_AB(self, atom_copy, dtype, major_mode, copy_bits): + copy_elems = copy_bits // dtype.width + shape_dim_1 = cute.size(self.bK) // copy_elems + # thread layout for copy + thread_layout = cute.make_layout( + (self.num_threads // shape_dim_1, shape_dim_1), stride=(shape_dim_1, 1) + ) + if major_mode != utils.LayoutEnum.ROW_MAJOR: + shape_dim_0 = cute.size(self.bM) // copy_elems + thread_layout = cute.make_layout( + (shape_dim_0, self.num_threads // shape_dim_0), stride=(1, shape_dim_0) + ) + # Value layout for copy + value_layout = ( + cute.make_layout((1, copy_elems)) + if major_mode == utils.LayoutEnum.ROW_MAJOR + else cute.make_layout((copy_elems, 1)) + ) + return cute.make_tiled_copy_tv(atom_copy, thread_layout, value_layout) + + def _make_gmem_tiled_copy_C(self, atom_copy, dtype, major_mode, copy_bits): + copy_elems = copy_bits // dtype.width + shape_dim_1 = cute.size(self.bN) // copy_elems + # thread layout for copy + thread_layout = cute.make_layout( + (self.num_threads // shape_dim_1, shape_dim_1), stride=(shape_dim_1, 1) + ) + if major_mode != utils.LayoutEnum.ROW_MAJOR: + shape_dim_0 = cute.size(self.bM) // copy_elems + thread_layout = cute.make_layout( + (shape_dim_0, self.num_threads // shape_dim_0), stride=(1, shape_dim_0) + ) + value_layout = ( + cute.make_layout((1, copy_elems)) + if major_mode == utils.LayoutEnum.ROW_MAJOR + else cute.make_layout((copy_elems, 1)) + ) + return cute.make_tiled_copy_tv(atom_copy, thread_layout, value_layout) + + def raster_tile(self, i, j, f): + new_i = i // f + new_j = (i % f) + (j * f) + return (new_i, new_j) + + +def run( + a_major: str, + b_major: str, + c_major: str, + ab_dtype: Type[cutlass.Numeric], + 
c_dtype: Type[cutlass.Numeric], + acc_dtype: Type[cutlass.Numeric], + mnkl: Tuple[int, int, int, int], + atom_layout_mnk: Tuple[int, int, int], + warmup_iterations: int = 2, + iterations: int = 100, + skip_ref_check: bool = False, + use_cold_l2: bool = False, + **kwargs, +): + print(f"Running Ampere tensor core GEMM example:") + print(f"mnkl: {mnkl}") + print( + f"A dtype: {ab_dtype}, B dtype: {ab_dtype}, C dtype: {c_dtype}, Acc dtype: {acc_dtype}" + ) + print(f"Matrix majors - A: {a_major}, B: {b_major}, C: {c_major}") + print(f"Atoms layout: {atom_layout_mnk}") + print(f"Warmup iterations: {warmup_iterations}") + print(f"Iterations: {iterations}") + print(f"Skip reference checking: {skip_ref_check}") + print(f"Use cold L2: {use_cold_l2}") + M, N, K, L = mnkl + + # Create and permute tensor A/B/C + def create_and_permute_tensor(l, mode0, mode1, is_mode0_major, dtype): + # is_mode0_major: (l, mode1, mode0) -> (mode0, mode1, l) + # else: (l, mode0, mode1) -> (mode0, mode1, l) + shape = (l, mode1, mode0) if is_mode0_major else (l, mode0, mode1) + permute_order = (2, 1, 0) if is_mode0_major else (1, 2, 0) + torch_tensor = ( + torch.empty(*shape, dtype=torch.int32) + .random_(-2, 2) + .to(dtype=cutlass_torch.dtype(dtype)) + .permute(permute_order) + .cuda() + ) + # assume input is 16B aligned + cute_tensor = ( + from_dlpack(torch_tensor, assumed_align=16) + .mark_layout_dynamic(leading_dim=(1 if not is_mode0_major else 0)) + .mark_compact_shape_dynamic( + mode=(1 if not is_mode0_major else 0), + stride_order=(2, 0, 1) if not is_mode0_major else (2, 1, 0), + divisibility=(128 // dtype.width), + ) + ) + return cute_tensor, torch_tensor + + mA, a_torch = create_and_permute_tensor(L, M, K, a_major == "m", ab_dtype) + mB, b_torch = create_and_permute_tensor(L, N, K, b_major == "n", ab_dtype) + mC, c_torch = create_and_permute_tensor(L, M, N, c_major == "m", c_dtype) + + tensor_op_gemm = TensorOpGemm( + ab_dtype, + c_dtype, + acc_dtype, + atom_layout_mnk, + ) + + print("Compiling kernel with cute.compile ...") + compiled_gemm = cute.compile(tensor_op_gemm, mA, mB, mC) + + print("Executing GEMM kernel...") + + if not skip_ref_check: + ref = torch.einsum( + "mkl,nkl->mnl", + a_torch.to(dtype=torch.float32), + b_torch.to(dtype=torch.float32), + ).to(cutlass_torch.dtype(c_dtype)) + compiled_gemm(mA, mB, mC) + print("Verifying results...") + torch.testing.assert_close(c_torch.cpu(), ref.cpu(), atol=1e-03, rtol=1e-05) + print("Results verified successfully!") + + def generate_tensors(): + a_workspace, _ = create_and_permute_tensor(L, M, K, a_major == "m", ab_dtype) + b_workspace, _ = create_and_permute_tensor(L, N, K, b_major == "n", ab_dtype) + c_workspace, _ = create_and_permute_tensor(L, M, N, c_major == "m", c_dtype) + return testing.JitArguments(a_workspace, b_workspace, c_workspace) + + workspace_count = 1 + if use_cold_l2: + one_workspace_bytes = ( + a_torch.numel() * a_torch.element_size() + + b_torch.numel() * b_torch.element_size() + + c_torch.numel() * c_torch.element_size() + ) + workspace_count = testing.get_workspace_count( + one_workspace_bytes, warmup_iterations, iterations + ) + + avg_time_us = testing.benchmark( + compiled_gemm, + workspace_generator=generate_tensors, + workspace_count=workspace_count, + warmup_iterations=warmup_iterations, + iterations=iterations, + use_cuda_graphs=False, + ) + + return avg_time_us # Return execution time in microseconds + +if __name__ == "__main__": + + def parse_comma_separated_ints(s: str) -> Tuple[int, ...]: + try: + return tuple(int(x.strip()) for 
x in s.split(",")) + except ValueError: + raise argparse.ArgumentTypeError( + "Invalid format. Expected comma-separated integers." + ) + + parser = argparse.ArgumentParser( + description="example of multistage block matmul with CuTe on GPU" + ) + parser.add_argument( + "--mnkl", type=parse_comma_separated_ints, default=(112, 136, 40, 1) + ) + parser.add_argument( + "--atom_layout_mnk", type=parse_comma_separated_ints, default=(2, 2, 1) + ) + parser.add_argument( + "--ab_dtype", + type=cutlass.dtype, + choices=[cutlass.Float16], + default=cutlass.Float16, + ) + parser.add_argument( + "--acc_dtype", + type=cutlass.dtype, + choices=[cutlass.Float32], + default=cutlass.Float32, + ) + parser.add_argument( + "--c_dtype", + type=cutlass.dtype, + choices=[cutlass.Float16], + default=cutlass.Float16, + ) + parser.add_argument("--a_major", choices=["k", "m"], default="m") + parser.add_argument("--b_major", choices=["k", "n"], default="n") + parser.add_argument("--c_major", choices=["n", "m"], default="n") + parser.add_argument("--warmup_iterations", default=2, type=int) + parser.add_argument("--iterations", default=100, type=int) + parser.add_argument("--skip_ref_check", action="store_true") + parser.add_argument( + "--use_cold_l2", + action="store_true", + default=False, + help="Use circular buffer tensor sets to ensure L2 cold cache", + ) + + args = parser.parse_args() + run( + args.a_major, + args.b_major, + args.c_major, + args.ab_dtype, + args.c_dtype, + args.acc_dtype, + args.mnkl, + args.atom_layout_mnk, + args.warmup_iterations, + args.iterations, + args.skip_ref_check, + args.use_cold_l2, + ) + print("PASS") diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/blackwell/dense_blockscaled_gemm_persistent.py b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/blackwell/dense_blockscaled_gemm_persistent.py new file mode 100644 index 0000000000000000000000000000000000000000..c9c24b84265834ffe4062e14aaa255bc0c4610c9 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/blackwell/dense_blockscaled_gemm_persistent.py @@ -0,0 +1,2466 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import argparse +from typing import Optional, Type, Tuple, Union + +import cuda.bindings.driver as cuda +import torch + +import cutlass +import cutlass.cute as cute +from cutlass.cute.nvgpu import cpasync, tcgen05 +import cutlass.torch as cutlass_torch +import cutlass.utils as utils +import cutlass.pipeline as pipeline +import cutlass.utils.blackwell_helpers as sm100_utils +import cutlass.utils.blockscaled_layout as blockscaled_utils +from cutlass.cute.runtime import from_dlpack + +""" +This example provides an experimental implementation of the SM100 batched dense blockscaled GEMM kernel, please note that the APIs and implementation details related to this kernel may change in future releases. + +A high-performance persistent batched dense blockscaled GEMM example for the NVIDIA Blackwell SM100 architecture +using CUTE DSL. +- Matrix A is MxKxL, L is batch dimension, A can be row-major("K") or column-major("M") for MXF8 input type and can only be row-major("K") for MXF4/NVF4 input type +- Matrix B is NxKxL, L is batch dimension, B can be row-major("N") or column-major("K") for MXF8 input type and can only be row-major("K") for MXF4/NVF4 input type +- Matrix C is MxNxL, L is batch dimension, C can be row-major("N") or column-major("M") +- Matrix SFA layout is filled internally according to A shape and BlockScaledBasicChunk, which has M×ceil_div(K, sf_vec_size)×L elements respectively +- Matrix SFB layout is filled internally according to B shape and BlockScaledBasicChunk, which has N×ceil_div(K, sf_vec_size)×L elements respectively + +This GEMM kernel supports the following features: + - Utilizes Tensor Memory Access (TMA) for efficient memory operations + - Utilizes Blackwell's tcgen05.mma for matrix multiply-accumulate (MMA) operations (including 2cta mma instructions) + - Implements TMA multicast with cluster to reduce L2 memory traffic + - Support persistent tile scheduling to better overlap memory load/store with mma between tiles + - Support warp specialization to avoid explicit pipelining between mainloop load and mma + +This GEMM works as follows: +1. DMA warp: Load A and B matrices from global memory (GMEM) to shared memory (SMEM) using TMA operations. +2. MMA warp: + - Load scale factor A/B from shared memory (SMEM) to tensor memory (TMEM) using tcgen05.cp instruction. + - Perform matrix multiply-accumulate (MMA) operations using tcgen05.mma instruction. +3. EPILOGUE warp: + - Load completed accumulator from tensor memory (TMEM) to registers (RMEM) using tcgen05.ld. + - Type convert C matrix to output type. + - Optionally store C matrix from registers (RMEM) to shared memory (SMEM) to global memory (GMEM) with TMA operations, + or directly store C matrix from registers (RMEM) to global memory (GMEM) without TMA operations. 
+ - Optionally accept an elementwise lambda function epilogue_op to apply to the output tensor: + e.g., relu can set epilogue_op = lambda x: cute.where(x > 0, x, cute.full_like(x, 0)) + +SM100 tcgen05.mma.kind.block_scale instructions operate as follows: +- Read matrix A from SMEM +- Read matrix B from SMEM +- Read scalefactor A from TMEM +- Read scalefactor B from TMEM +- Write accumulator to TMEM +The accumulator in TMEM must then be loaded to registers before writing back to GMEM. + +Input arguments to this example is shown below: + +.. code-block:: bash + + python examples/blackwell/dense_blockscaled_gemm_persistent.py \ + --ab_dtype Float4E2M1FN --sf_dtype Float8E8M0FNU --sf_vec_size 16 \ + --c_dtype Float16 \ + --mma_tiler_mn 256,128 --cluster_shape_mn 2,1 \ + --mnkl 8192,8192,1024,1 + +To collect performance with NCU profiler: + +.. code-block:: bash + + ncu python examples/blackwell/dense_blockscaled_gemm_persistent.py \ + --ab_dtype Float4E2M1FN --sf_dtype Float8E8M0FNU --sf_vec_size 16 \ + --c_dtype Float16 \ + --mma_tiler_mn 256,128 --cluster_shape_mn 2,1 \ + --mnkl 8192,8192,1024,1 \ + --warmup_iterations 1 --iterations 10 --skip_ref_check + + +Constraints: +* Supported input data types: mxf8, mxf4, nvf4 + see detailed valid dtype combinations in below Sm100BlockScaledPersistentDenseGemmKernel class documentation +* A/B tensor must have the same data type, mixed data type is not supported (e.g., mxf8 x mxf4) +* Mma tiler M must be 128 or 256(use_2cta_instrs) +* Mma tiler N must be 128 or 256 +* Cluster shape M/N must be positive and power of 2, total cluster size <= 16 +* Cluster shape M must be multiple of 2 if Mma tiler M is 256(use_2cta_instrs) +* The contiguous dimension of A/B/C tensors must be at least 16 bytes aligned, + i.e, number of elements is a multiple of 16 and 32 for Float8 and Float4, respectively. +""" + + +class Sm100BlockScaledPersistentDenseGemmKernel: + """This class implements batched matrix multiplication (C = A x SFA x B x SFB) with support for various data types + and architectural features specific to Blackwell GPUs with persistent tile scheduling and warp specialization. + + :param sf_vec_size: Scalefactor vector size. + :type sf_vec_size: int + :param mma_tiler_mn: Shape of the Matrix Multiply-Accumulate (MMA) tile (M,N) + :type mma_tiler_mn: Tuple[int, int] + :param cluster_shape_mn: Cluster dimensions (M,N) for parallel processing + :type cluster_shape_mn: Tuple[int, int] + + :note: In current version, A and B tensor must have the same data type + - i.e., Float8E4M3FN for A and Float8E5M2 for B is not supported + + :note: Supported combinations of A/B data types, SF data typs and SF vector size: + - MXF8: A/B: Float8E5M2/Float8E4M3FN + SF: Float8E8M0FNU + sf_vec_size: 32 + - MXF4: A/B: Float4E2M1FN + SF: Float8E8M0FNU + sf_vec_size: 32 + - NVF4: A/B: Float4E2M1FN + SF: Float8E8M0FNU/Float8E4M3FN + sf_vec_size: 16 + + :note: Supported accumulator data types: + - Float32 + + :note: Supported C data types: + - Float32 + - Float16/BFloat16 + - Float8E4M3FN/Float8E5M2 + :note: Constraints: + - MMA tiler M must be 128 or 256 (use_2cta_instrs) + - MMA tiler N must be 128/256 + - Cluster shape M must be multiple of 2 if Mma tiler M is 256 + - Cluster shape M/N must be positive and power of 2, total cluster size <= 16 + - Also, Cluster shape M/N must be <= 4 for scale factor multicasts due to limited size of scale factors + + Example: + >>> gemm = Sm100BlockScaledPersistentDenseGemmKernel( + ... sf_vec_size=16, + ... mma_tiler_mn=(256, 128), + ... 
cluster_shape_mn=(2, 1) + ... ) + >>> gemm(a_tensor, b_tensor, sfa_tensor, sfb_tensor, c_tensor, max_active_clusters, stream) + """ + + def __init__( + self, + sf_vec_size: int, + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + ): + """Initializes the configuration for a Blackwell dense GEMM kernel. + + This configuration includes several key aspects: + + 1. MMA Instruction Settings (tcgen05): + - acc_dtype: Data types for MMA accumulator, always set to Float32 + - sf_vec_size: Scalefactor A/B vector size. + - mma_tiler_mn: The (M, N) shape of the MMA instruction tiler. + + 2. Cluster Shape: + - cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster. + + :param sf_vec_size: Scalefactor vector size. + :type sf_vec_size: int + :param mma_tiler_mn: Tuple (M, N) shape of the MMA instruction. + :type mma_tiler_mn: Tuple[int, int] + :param cluster_shape_mn: Tuple (ClusterM, ClusterN) shape of the cluster. + :type cluster_shape_mn: Tuple[int, int] + """ + + self.acc_dtype = cutlass.Float32 + self.sf_vec_size = sf_vec_size + self.use_2cta_instrs = mma_tiler_mn[0] == 256 + self.cluster_shape_mn = cluster_shape_mn + # K dimension is deferred in _setup_attributes + self.mma_tiler = (*mma_tiler_mn, 1) + + self.cta_group = ( + tcgen05.CtaGroup.TWO if self.use_2cta_instrs else tcgen05.CtaGroup.ONE + ) + + self.occupancy = 1 + # Set specialized warp ids + self.epilog_warp_id = ( + 0, + 1, + 2, + 3, + ) + self.mma_warp_id = 4 + self.tma_warp_id = 5 + self.threads_per_cta = 32 * len( + (self.mma_warp_id, self.tma_warp_id, *self.epilog_warp_id) + ) + # Set barrier id for cta sync, epilogue sync and tmem ptr sync + self.cta_sync_bar_id = 0 + self.epilog_sync_bar_id = 1 + self.tmem_ptr_sync_bar_id = 2 + self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_100") + SM100_TMEM_CAPACITY_COLUMNS = 512 + self.num_tmem_alloc_cols = SM100_TMEM_CAPACITY_COLUMNS + + def _setup_attributes(self): + """Set up configurations that are dependent on GEMM inputs + + This method configures various attributes based on the input tensor properties + (data types, leading dimensions) and kernel settings: + - Configuring tiled MMA + - Computing MMA/cluster/tile shapes + - Computing cluster layout + - Computing multicast CTAs for A/B/SFA/SFB + - Computing epilogue subtile + - Setting up A/B/SFA/SFB/C stage counts in shared memory + - Computing A/B/SFA/SFB/C shared memory layout + - Computing tensor memory allocation columns + """ + # Compute mma instruction shapes + mma_inst_bits_k = 256 + # (MMA_Tile_Shape_M, MMA_Tile_Shape_N, MMA_Inst_Shape_K) + self.mma_inst_shape_mnk = ( + self.mma_tiler[0], + self.mma_tiler[1], + mma_inst_bits_k // self.a_dtype.width, + ) + # (CTA_Tile_Shape_M, Round_Up(MMA_Tile_Shape_N, 128), MMA_Inst_Shape_K) + self.mma_inst_shape_mnk_sfb = ( + self.mma_inst_shape_mnk[0] // (2 if self.use_2cta_instrs else 1), + cute.round_up(self.mma_inst_shape_mnk[1], 128), + self.mma_inst_shape_mnk[2], + ) + + tiled_mma = sm100_utils.make_blockscaled_trivial_tiled_mma( + self.a_dtype, + self.a_major_mode, + self.b_major_mode, + self.sf_dtype, + self.sf_vec_size, + self.cta_group, + self.mma_inst_shape_mnk[:2], + ) + + tiled_mma_sfb = sm100_utils.make_blockscaled_trivial_tiled_mma( + self.a_dtype, + self.a_major_mode, + self.b_major_mode, + self.sf_dtype, + self.sf_vec_size, + cute.nvgpu.tcgen05.CtaGroup.ONE, + self.mma_inst_shape_mnk_sfb[:2], + ) + + # Compute mma/cluster/tile shapes + mma_inst_tile_k = 4 + self.mma_tiler = ( + self.mma_inst_shape_mnk[0], + self.mma_inst_shape_mnk[1], + 
self.mma_inst_shape_mnk[2] * mma_inst_tile_k, + ) + self.mma_tiler_sfb = ( + self.mma_inst_shape_mnk_sfb[0], + self.mma_inst_shape_mnk_sfb[1], + self.mma_inst_shape_mnk_sfb[2] * mma_inst_tile_k, + ) + self.cta_tile_shape_mnk = ( + self.mma_tiler[0] // cute.size(tiled_mma.thr_id.shape), + self.mma_tiler[1], + self.mma_tiler[2], + ) + + # Compute cluster layout + self.cluster_layout_vmnk = cute.tiled_divide( + cute.make_layout((*self.cluster_shape_mn, 1)), + (tiled_mma.thr_id.shape,), + ) + self.cluster_layout_sfb_vmnk = cute.tiled_divide( + cute.make_layout((*self.cluster_shape_mn, 1)), + (tiled_mma_sfb.thr_id.shape,), + ) + + # Compute number of multicast CTAs for A/B + self.num_mcast_ctas_a = cute.size(self.cluster_layout_vmnk.shape[2]) + self.num_mcast_ctas_b = cute.size(self.cluster_layout_vmnk.shape[1]) + self.num_mcast_ctas_sfb = cute.size(self.cluster_layout_sfb_vmnk.shape[1]) + self.is_a_mcast = self.num_mcast_ctas_a > 1 + self.is_b_mcast = self.num_mcast_ctas_b > 1 + self.is_sfb_mcast = self.num_mcast_ctas_sfb > 1 + + # Compute epilogue subtile + self.epi_tile = sm100_utils.compute_epilogue_tile_shape( + self.cta_tile_shape_mnk, + self.use_2cta_instrs, + self.c_layout, + self.c_dtype, + ) + + # Setup A/B/C stage count in shared memory and ACC stage count in tensor memory + self.num_acc_stage, self.num_ab_stage, self.num_c_stage = self._compute_stages( + tiled_mma, + self.mma_tiler, + self.a_dtype, + self.a_major_mode, + self.b_dtype, + self.b_major_mode, + self.epi_tile, + self.c_dtype, + self.c_layout, + self.sf_dtype, + self.sf_vec_size, + self.smem_capacity, + self.occupancy, + ) + + # Compute A/B/SFA/SFB/C shared memory layout + self.a_smem_layout_staged = sm100_utils.make_smem_layout_a( + tiled_mma, + self.mma_tiler, + self.a_dtype, + self.num_ab_stage, + ) + self.b_smem_layout_staged = sm100_utils.make_smem_layout_b( + tiled_mma, + self.mma_tiler, + self.b_dtype, + self.num_ab_stage, + ) + self.sfa_smem_layout_staged = blockscaled_utils.make_smem_layout_sfa( + tiled_mma, + self.mma_tiler, + self.sf_vec_size, + self.num_ab_stage, + ) + self.sfb_smem_layout_staged = blockscaled_utils.make_smem_layout_sfb( + tiled_mma, + self.mma_tiler, + self.sf_vec_size, + self.num_ab_stage, + ) + self.c_smem_layout_staged = sm100_utils.make_smem_layout_epi( + self.c_dtype, + self.c_layout, + self.epi_tile, + self.num_c_stage, + ) + + @cute.jit + def __call__( + self, + a_tensor: cute.Tensor, + b_tensor: cute.Tensor, + sfa_tensor: cute.Tensor, + sfb_tensor: cute.Tensor, + c_tensor: cute.Tensor, + max_active_clusters: cutlass.Constexpr, + stream: cuda.CUstream, + epilogue_op: cutlass.Constexpr = lambda x: x, + ): + """Execute the GEMM operation in steps: + - Setup static attributes before smem/grid/tma computation + - Setup TMA load/store atoms and tensors + - Compute grid size with regard to hardware constraints + - Define shared storage for kernel + - Launch the kernel synchronously + + :param a_tensor: Input tensor A + :type a_tensor: cute.Tensor + :param b_tensor: Input tensor B + :type b_tensor: cute.Tensor + :param sfa_tensor: Scale factor tensor A + :type sfa_tensor: cute.Tensor + :param sfb_tensor: Scale factor tensor B + :type sfb_tensor: cute.Tensor + :param c_tensor: Output tensor C + :type c_tensor: cute.Tensor + :param max_active_clusters: Maximum number of active clusters + :type max_active_clusters: cutlass.Constexpr + :param stream: CUDA stream for asynchronous execution + :type stream: cuda.CUstream + :param epilogue_op: Optional elementwise lambda function to apply to the 
output tensor + :type epilogue_op: cutlass.Constexpr + :raises TypeError: If input data types are incompatible with the MMA instruction. + """ + # Setup static attributes before smem/grid/tma computation + self.a_dtype: Type[cutlass.Numeric] = a_tensor.element_type + self.b_dtype: Type[cutlass.Numeric] = b_tensor.element_type + self.sf_dtype: Type[cutlass.Numeric] = sfa_tensor.element_type + self.c_dtype: Type[cutlass.Numeric] = c_tensor.element_type + self.a_major_mode = utils.LayoutEnum.from_tensor(a_tensor).mma_major_mode() + self.b_major_mode = utils.LayoutEnum.from_tensor(b_tensor).mma_major_mode() + self.c_layout = utils.LayoutEnum.from_tensor(c_tensor) + + # Check if input data types are compatible with MMA instruction + if cutlass.const_expr(self.a_dtype != self.b_dtype): + raise TypeError(f"Type must match: {self.a_dtype} != {self.b_dtype}") + + # Setup attributes that dependent on gemm inputs + self._setup_attributes() + + # Setup sfa/sfb tensor by filling A/B tensor to scale factor atom layout + # ((Atom_M, Rest_M),(Atom_K, Rest_K),RestL) + sfa_layout = blockscaled_utils.tile_atom_to_shape_SF( + a_tensor.shape, self.sf_vec_size + ) + sfa_tensor = cute.make_tensor(sfa_tensor.iterator, sfa_layout) + + # ((Atom_N, Rest_N),(Atom_K, Rest_K),RestL) + sfb_layout = blockscaled_utils.tile_atom_to_shape_SF( + b_tensor.shape, self.sf_vec_size + ) + sfb_tensor = cute.make_tensor(sfb_tensor.iterator, sfb_layout) + + tiled_mma = sm100_utils.make_blockscaled_trivial_tiled_mma( + self.a_dtype, + self.a_major_mode, + self.b_major_mode, + self.sf_dtype, + self.sf_vec_size, + self.cta_group, + self.mma_inst_shape_mnk[:2], + ) + + tiled_mma_sfb = sm100_utils.make_blockscaled_trivial_tiled_mma( + self.a_dtype, + self.a_major_mode, + self.b_major_mode, + self.sf_dtype, + self.sf_vec_size, + cute.nvgpu.tcgen05.CtaGroup.ONE, + self.mma_inst_shape_mnk_sfb[:2], + ) + atom_thr_size = cute.size(tiled_mma.thr_id.shape) + + # Setup TMA load for A + a_op = sm100_utils.cluster_shape_to_tma_atom_A( + self.cluster_shape_mn, tiled_mma.thr_id + ) + a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0)) + tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A( + a_op, + a_tensor, + a_smem_layout, + self.mma_tiler, + tiled_mma, + self.cluster_layout_vmnk.shape, + ) + + # Setup TMA load for B + b_op = sm100_utils.cluster_shape_to_tma_atom_B( + self.cluster_shape_mn, tiled_mma.thr_id + ) + b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0)) + tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B( + b_op, + b_tensor, + b_smem_layout, + self.mma_tiler, + tiled_mma, + self.cluster_layout_vmnk.shape, + ) + + # Setup TMA load for SFA + sfa_op = sm100_utils.cluster_shape_to_tma_atom_A( + self.cluster_shape_mn, tiled_mma.thr_id + ) + sfa_smem_layout = cute.slice_( + self.sfa_smem_layout_staged, (None, None, None, 0) + ) + tma_atom_sfa, tma_tensor_sfa = cute.nvgpu.make_tiled_tma_atom_A( + sfa_op, + sfa_tensor, + sfa_smem_layout, + self.mma_tiler, + tiled_mma, + self.cluster_layout_vmnk.shape, + internal_type=cutlass.Int16, + ) + + # Setup TMA load for SFB + sfb_op = sm100_utils.cluster_shape_to_tma_atom_SFB( + self.cluster_shape_mn, tiled_mma.thr_id + ) + sfb_smem_layout = cute.slice_( + self.sfb_smem_layout_staged, (None, None, None, 0) + ) + tma_atom_sfb, tma_tensor_sfb = cute.nvgpu.make_tiled_tma_atom_B( + sfb_op, + sfb_tensor, + sfb_smem_layout, + self.mma_tiler_sfb, + tiled_mma_sfb, + self.cluster_layout_sfb_vmnk.shape, + internal_type=cutlass.Int16, + ) + + 
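+        # The per-stage shared-memory byte counts of A, B, SFA and SFB computed
+        # below are summed and scaled by the number of CTAs in the MMA atom
+        # (atom_thr_size); the total is later passed as the TMA transaction count
+        # (tx_count) of the mainloop AB pipeline, so a stage is only marked full
+        # once all four TMA loads for that stage have arrived.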
a_copy_size = cute.size_in_bytes(self.a_dtype, a_smem_layout) + b_copy_size = cute.size_in_bytes(self.b_dtype, b_smem_layout) + sfa_copy_size = cute.size_in_bytes(self.sf_dtype, sfa_smem_layout) + sfb_copy_size = cute.size_in_bytes(self.sf_dtype, sfb_smem_layout) + self.num_tma_load_bytes = ( + a_copy_size + b_copy_size + sfa_copy_size + sfb_copy_size + ) * atom_thr_size + + # Setup TMA store for C + epi_smem_layout = cute.slice_(self.c_smem_layout_staged, (None, None, 0)) + tma_atom_c, tma_tensor_c = cpasync.make_tiled_tma_atom( + cpasync.CopyBulkTensorTileS2GOp(), + c_tensor, + epi_smem_layout, + self.epi_tile, + ) + + # Compute grid size + self.tile_sched_params, grid = self._compute_grid( + c_tensor, + self.cta_tile_shape_mnk, + self.cluster_shape_mn, + max_active_clusters, + ) + + self.buffer_align_bytes = 1024 + + # Define shared storage for kernel + @cute.struct + class SharedStorage: + ab_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage] + ab_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage] + acc_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage] + acc_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage] + tmem_dealloc_mbar_ptr: cutlass.Int64 + tmem_holding_buf: cutlass.Int32 + # (EPI_TILE_M, EPI_TILE_N, STAGE) + sC: cute.struct.Align[ + cute.struct.MemRange[ + self.c_dtype, + cute.cosize(self.c_smem_layout_staged.outer), + ], + self.buffer_align_bytes, + ] + # (MMA, MMA_M, MMA_K, STAGE) + sA: cute.struct.Align[ + cute.struct.MemRange[ + self.a_dtype, cute.cosize(self.a_smem_layout_staged.outer) + ], + self.buffer_align_bytes, + ] + # (MMA, MMA_N, MMA_K, STAGE) + sB: cute.struct.Align[ + cute.struct.MemRange[ + self.b_dtype, cute.cosize(self.b_smem_layout_staged.outer) + ], + self.buffer_align_bytes, + ] + # (MMA, MMA_M, MMA_K, STAGE) + sSFA: cute.struct.Align[ + cute.struct.MemRange[ + self.sf_dtype, cute.cosize(self.sfa_smem_layout_staged) + ], + self.buffer_align_bytes, + ] + # (MMA, MMA_N, MMA_K, STAGE) + sSFB: cute.struct.Align[ + cute.struct.MemRange[ + self.sf_dtype, cute.cosize(self.sfb_smem_layout_staged) + ], + self.buffer_align_bytes, + ] + + self.shared_storage = SharedStorage + + # Launch the kernel synchronously + self.kernel( + tiled_mma, + tiled_mma_sfb, + tma_atom_a, + tma_tensor_a, + tma_atom_b, + tma_tensor_b, + tma_atom_sfa, + tma_tensor_sfa, + tma_atom_sfb, + tma_tensor_sfb, + tma_atom_c, + tma_tensor_c, + self.cluster_layout_vmnk, + self.cluster_layout_sfb_vmnk, + self.a_smem_layout_staged, + self.b_smem_layout_staged, + self.sfa_smem_layout_staged, + self.sfb_smem_layout_staged, + self.c_smem_layout_staged, + self.epi_tile, + self.tile_sched_params, + epilogue_op, + ).launch( + grid=grid, + block=[self.threads_per_cta, 1, 1], + cluster=(*self.cluster_shape_mn, 1), + stream=stream, + ) + return + + # GPU device kernel + @cute.kernel + def kernel( + self, + tiled_mma: cute.TiledMma, + tiled_mma_sfb: cute.TiledMma, + tma_atom_a: cute.CopyAtom, + mA_mkl: cute.Tensor, + tma_atom_b: cute.CopyAtom, + mB_nkl: cute.Tensor, + tma_atom_sfa: cute.CopyAtom, + mSFA_mkl: cute.Tensor, + tma_atom_sfb: cute.CopyAtom, + mSFB_nkl: cute.Tensor, + tma_atom_c: Optional[cute.CopyAtom], + mC_mnl: cute.Tensor, + cluster_layout_vmnk: cute.Layout, + cluster_layout_sfb_vmnk: cute.Layout, + a_smem_layout_staged: cute.ComposedLayout, + b_smem_layout_staged: cute.ComposedLayout, + sfa_smem_layout_staged: cute.Layout, + sfb_smem_layout_staged: cute.Layout, + c_smem_layout_staged: Union[cute.Layout, 
cute.ComposedLayout, None], + epi_tile: cute.Tile, + tile_sched_params: utils.PersistentTileSchedulerParams, + epilogue_op: cutlass.Constexpr, + ): + """ + GPU device kernel performing the Persistent batched GEMM computation. + """ + warp_idx = cute.arch.warp_idx() + warp_idx = cute.arch.make_warp_uniform(warp_idx) + + # + # Prefetch tma desc + # + if warp_idx == self.tma_warp_id: + cpasync.prefetch_descriptor(tma_atom_a) + cpasync.prefetch_descriptor(tma_atom_b) + cpasync.prefetch_descriptor(tma_atom_sfa) + cpasync.prefetch_descriptor(tma_atom_sfb) + cpasync.prefetch_descriptor(tma_atom_c) + + use_2cta_instrs = cute.size(tiled_mma.thr_id.shape) == 2 + + # + # Setup cta/thread coordinates + # + # Coords inside cluster + bidx, bidy, bidz = cute.arch.block_idx() + mma_tile_coord_v = bidx % cute.size(tiled_mma.thr_id.shape) + is_leader_cta = mma_tile_coord_v == 0 + cta_rank_in_cluster = cute.arch.make_warp_uniform( + cute.arch.block_idx_in_cluster() + ) + block_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord( + cta_rank_in_cluster + ) + block_in_cluster_coord_sfb_vmnk = cluster_layout_sfb_vmnk.get_flat_coord( + cta_rank_in_cluster + ) + # Coord inside cta + tidx, _, _ = cute.arch.thread_idx() + + # + # Alloc and init: a+b full/empty, accumulator full/empty, tensor memory dealloc barrier + # + smem = utils.SmemAllocator() + storage = smem.allocate(self.shared_storage) + + tmem_dealloc_mbar_ptr = storage.tmem_dealloc_mbar_ptr + tmem_holding_buf = storage.tmem_holding_buf + + # Initialize mainloop ab_pipeline (barrier) and states + ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) + num_tma_producer = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1 + ab_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, num_tma_producer + ) + ab_pipeline = pipeline.PipelineTmaUmma.create( + barrier_storage=storage.ab_full_mbar_ptr.data_ptr(), + num_stages=self.num_ab_stage, + producer_group=ab_pipeline_producer_group, + consumer_group=ab_pipeline_consumer_group, + tx_count=self.num_tma_load_bytes, + cta_layout_vmnk=cluster_layout_vmnk, + ) + + # Initialize acc_pipeline (barrier) and states + acc_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) + num_acc_consumer_threads = len(self.epilog_warp_id) * ( + 2 if use_2cta_instrs else 1 + ) + acc_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, num_acc_consumer_threads + ) + acc_pipeline = pipeline.PipelineUmmaAsync.create( + barrier_storage=storage.acc_full_mbar_ptr.data_ptr(), + num_stages=self.num_acc_stage, + producer_group=acc_pipeline_producer_group, + consumer_group=acc_pipeline_consumer_group, + cta_layout_vmnk=cluster_layout_vmnk, + ) + + # Tensor memory dealloc barrier init + if use_2cta_instrs: + if warp_idx == self.tma_warp_id: + num_tmem_dealloc_threads = 32 + with cute.arch.elect_one(): + cute.arch.mbarrier_init( + tmem_dealloc_mbar_ptr, num_tmem_dealloc_threads + ) + cute.arch.mbarrier_init_fence() + + # Cluster arrive after barrier init + if cute.size(self.cluster_shape_mn) > 1: + cute.arch.cluster_arrive_relaxed() + + # + # Setup smem tensor A/B/SFA/SFB/C + # + # (EPI_TILE_M, EPI_TILE_N, STAGE) + sC = storage.sC.get_tensor( + c_smem_layout_staged.outer, swizzle=c_smem_layout_staged.inner + ) + # (MMA, MMA_M, MMA_K, STAGE) + sA = storage.sA.get_tensor( + a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner + ) + # (MMA, MMA_N, MMA_K, STAGE) + sB = storage.sB.get_tensor( + b_smem_layout_staged.outer, 
swizzle=b_smem_layout_staged.inner + ) + # (MMA, MMA_M, MMA_K, STAGE) + sSFA = storage.sSFA.get_tensor(sfa_smem_layout_staged) + # (MMA, MMA_N, MMA_K, STAGE) + sSFB = storage.sSFB.get_tensor(sfb_smem_layout_staged) + + # + # Compute multicast mask for A/B/SFA/SFB buffer full + # + a_full_mcast_mask = None + b_full_mcast_mask = None + sfa_full_mcast_mask = None + sfb_full_mcast_mask = None + if cutlass.const_expr(self.is_a_mcast or self.is_b_mcast or use_2cta_instrs): + a_full_mcast_mask = cpasync.create_tma_multicast_mask( + cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2 + ) + b_full_mcast_mask = cpasync.create_tma_multicast_mask( + cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=1 + ) + sfa_full_mcast_mask = cpasync.create_tma_multicast_mask( + cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2 + ) + sfb_full_mcast_mask = cpasync.create_tma_multicast_mask( + cluster_layout_sfb_vmnk, block_in_cluster_coord_sfb_vmnk, mcast_mode=1 + ) + + # + # Local_tile partition global tensors + # + # (bM, bK, RestM, RestK, RestL) + gA_mkl = cute.local_tile( + mA_mkl, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None) + ) + # (bN, bK, RestN, RestK, RestL) + gB_nkl = cute.local_tile( + mB_nkl, cute.slice_(self.mma_tiler, (0, None, None)), (None, None, None) + ) + # (bM, bK, RestM, RestK, RestL) + gSFA_mkl = cute.local_tile( + mSFA_mkl, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None) + ) + # (bN, bK, RestN, RestK, RestL) + gSFB_nkl = cute.local_tile( + mSFB_nkl, cute.slice_(self.mma_tiler, (0, None, None)), (None, None, None) + ) + # (bM, bN, RestM, RestN, RestL) + gC_mnl = cute.local_tile( + mC_mnl, cute.slice_(self.mma_tiler, (None, None, 0)), (None, None, None) + ) + k_tile_cnt = cute.size(gA_mkl, mode=[3]) + + # + # Partition global tensor for TiledMMA_A/B/C + # + thr_mma = tiled_mma.get_slice(mma_tile_coord_v) + thr_mma_sfb = tiled_mma_sfb.get_slice(mma_tile_coord_v) + # (MMA, MMA_M, MMA_K, RestM, RestK, RestL) + tCgA = thr_mma.partition_A(gA_mkl) + # (MMA, MMA_N, MMA_K, RestN, RestK, RestL) + tCgB = thr_mma.partition_B(gB_nkl) + # (MMA, MMA_M, MMA_K, RestM, RestK, RestL) + tCgSFA = thr_mma.partition_A(gSFA_mkl) + # (MMA, MMA_N, MMA_K, RestN, RestK, RestL) + tCgSFB = thr_mma_sfb.partition_B(gSFB_nkl) + # (MMA, MMA_M, MMA_N, RestM, RestN, RestL) + tCgC = thr_mma.partition_C(gC_mnl) + + # + # Partition global/shared tensor for TMA load A/B + # + # TMA load A partition_S/D + a_cta_layout = cute.make_layout( + cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape + ) + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), RestM, RestK, RestL) + tAsA, tAgA = cpasync.tma_partition( + tma_atom_a, + block_in_cluster_coord_vmnk[2], + a_cta_layout, + cute.group_modes(sA, 0, 3), + cute.group_modes(tCgA, 0, 3), + ) + # TMA load B partition_S/D + b_cta_layout = cute.make_layout( + cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape + ) + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), RestN, RestK, RestL) + tBsB, tBgB = cpasync.tma_partition( + tma_atom_b, + block_in_cluster_coord_vmnk[1], + b_cta_layout, + cute.group_modes(sB, 0, 3), + cute.group_modes(tCgB, 0, 3), + ) + + # TMA load SFA partition_S/D + sfa_cta_layout = a_cta_layout + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), RestM, RestK, RestL) + tAsSFA, tAgSFA = cute.nvgpu.cpasync.tma_partition( + tma_atom_sfa, + block_in_cluster_coord_vmnk[2], + sfa_cta_layout, + cute.group_modes(sSFA, 0, 3), + cute.group_modes(tCgSFA, 0, 3), + ) + tAsSFA = cute.filter_zeros(tAsSFA) + tAgSFA 
= cute.filter_zeros(tAgSFA) + + # TMA load SFB partition_S/D + sfb_cta_layout = cute.make_layout( + cute.slice_(cluster_layout_sfb_vmnk, (0, None, 0, 0)).shape + ) + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), RestN, RestK, RestL) + tBsSFB, tBgSFB = cute.nvgpu.cpasync.tma_partition( + tma_atom_sfb, + block_in_cluster_coord_sfb_vmnk[1], + sfb_cta_layout, + cute.group_modes(sSFB, 0, 3), + cute.group_modes(tCgSFB, 0, 3), + ) + tBsSFB = cute.filter_zeros(tBsSFB) + tBgSFB = cute.filter_zeros(tBgSFB) + + # + # Partition shared/tensor memory tensor for TiledMMA_A/B/C + # + # (MMA, MMA_M, MMA_K, STAGE) + tCrA = tiled_mma.make_fragment_A(sA) + # (MMA, MMA_N, MMA_K, STAGE) + tCrB = tiled_mma.make_fragment_B(sB) + # (MMA, MMA_M, MMA_N) + acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2]) + # (MMA, MMA_M, MMA_N, STAGE) + tCtAcc_fake = tiled_mma.make_fragment_C( + cute.append(acc_shape, self.num_acc_stage) + ) + + # + # Cluster wait before tensor memory alloc + # + if cute.size(self.cluster_shape_mn) > 1: + cute.arch.cluster_wait() + else: + cute.arch.barrier( + barrier_id=self.cta_sync_bar_id, number_of_threads=self.threads_per_cta + ) + + # + # Specialized TMA load warp + # + if warp_idx == self.tma_warp_id: + # + # Persistent tile scheduling loop + # + tile_sched = utils.StaticPersistentTileScheduler.create( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + + ab_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.num_ab_stage + ) + + while work_tile.is_valid_tile: + + # Get tile coord from tile scheduler + cur_tile_coord = work_tile.tile_idx + mma_tile_coord_mnl = ( + cur_tile_coord[0] // cute.size(tiled_mma.thr_id.shape), + cur_tile_coord[1], + cur_tile_coord[2], + ) + + # + # Slice to per mma tile index + # + # ((atom_v, rest_v), RestK) + tAgA_slice = tAgA[ + (None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2]) + ] + # ((atom_v, rest_v), RestK) + tBgB_slice = tBgB[ + (None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2]) + ] + + # ((atom_v, rest_v), RestK) + tAgSFA_slice = tAgSFA[ + (None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2]) + ] + # ((atom_v, rest_v), RestK) + tBgSFB_slice = tBgSFB[ + (None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2]) + ] + + # Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt + ab_producer_state.reset_count() + peek_ab_empty_status = cutlass.Boolean(1) + if ab_producer_state.count < k_tile_cnt: + peek_ab_empty_status = ab_pipeline.producer_try_acquire( + ab_producer_state + ) + # + # Tma load loop + # + for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1): + # Conditionally wait for AB buffer empty + ab_pipeline.producer_acquire( + ab_producer_state, peek_ab_empty_status + ) + + # TMA load A/B/SFA/SFB + cute.copy( + tma_atom_a, + tAgA_slice[(None, ab_producer_state.count)], + tAsA[(None, ab_producer_state.index)], + tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state), + mcast_mask=a_full_mcast_mask, + ) + cute.copy( + tma_atom_b, + tBgB_slice[(None, ab_producer_state.count)], + tBsB[(None, ab_producer_state.index)], + tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state), + mcast_mask=b_full_mcast_mask, + ) + cute.copy( + tma_atom_sfa, + tAgSFA_slice[(None, ab_producer_state.count)], + tAsSFA[(None, ab_producer_state.index)], + tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state), + mcast_mask=sfa_full_mcast_mask, + ) + cute.copy( + tma_atom_sfb, + tBgSFB_slice[(None, 
ab_producer_state.count)], + tBsSFB[(None, ab_producer_state.index)], + tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state), + mcast_mask=sfb_full_mcast_mask, + ) + + # Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt + k_tile + 1 + ab_producer_state.advance() + peek_ab_empty_status = cutlass.Boolean(1) + if ab_producer_state.count < k_tile_cnt: + peek_ab_empty_status = ab_pipeline.producer_try_acquire( + ab_producer_state + ) + + # + # Advance to next tile + # + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + + # + # Wait A/B buffer empty + # + ab_pipeline.producer_tail(ab_producer_state) + + # + # Specialized MMA warp + # + if warp_idx == self.mma_warp_id: + # + # Bar sync for retrieve tensor memory ptr from shared mem + # + tmem_ptr_read_threads = 32 * len((self.mma_warp_id, *self.epilog_warp_id)) + cute.arch.barrier( + barrier_id=self.tmem_ptr_sync_bar_id, + number_of_threads=tmem_ptr_read_threads, + ) + + # + # Retrieving tensor memory ptr and make accumulator/SFA/SFB tensor + # + # Make accumulator tmem tensor + acc_tmem_ptr = cute.arch.retrieve_tmem_ptr( + self.acc_dtype, + alignment=16, + ptr_to_buffer_holding_addr=tmem_holding_buf, + ) + # (MMA, MMA_M, MMA_N, STAGE) + tCtAcc_base = cute.make_tensor(acc_tmem_ptr, tCtAcc_fake.layout) + + # Make SFA tmem tensor + sfa_tmem_ptr = cute.recast_ptr( + acc_tmem_ptr + tcgen05.find_tmem_tensor_col_offset(tCtAcc_base), + dtype=self.sf_dtype, + ) + # (MMA, MMA_M, MMA_K) + tCtSFA_layout = blockscaled_utils.make_tmem_layout_sfa( + tiled_mma, + self.mma_tiler, + self.sf_vec_size, + cute.slice_(sfa_smem_layout_staged, (None, None, None, 0)), + ) + tCtSFA = cute.make_tensor(sfa_tmem_ptr, tCtSFA_layout) + + # Make SFB tmem tensor + sfb_tmem_ptr = cute.recast_ptr( + acc_tmem_ptr + + tcgen05.find_tmem_tensor_col_offset(tCtAcc_base) + + tcgen05.find_tmem_tensor_col_offset(tCtSFA), + dtype=self.sf_dtype, + ) + # (MMA, MMA_N, MMA_K) + tCtSFB_layout = blockscaled_utils.make_tmem_layout_sfb( + tiled_mma, + self.mma_tiler, + self.sf_vec_size, + cute.slice_(sfb_smem_layout_staged, (None, None, None, 0)), + ) + tCtSFB = cute.make_tensor(sfb_tmem_ptr, tCtSFB_layout) + # + # Partition for S2T copy of SFA/SFB + # + tiled_copy_s2t_sfa, tCsSFA_compact_s2t, tCtSFA_compact_s2t = ( + self.mainloop_s2t_copy_and_partition(sSFA, tCtSFA) + ) + tiled_copy_s2t_sfb, tCsSFB_compact_s2t, tCtSFB_compact_s2t = ( + self.mainloop_s2t_copy_and_partition(sSFB, tCtSFB) + ) + + # + # Persistent tile scheduling loop + # + tile_sched = utils.StaticPersistentTileScheduler.create( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + + ab_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_ab_stage + ) + acc_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.num_acc_stage + ) + + while work_tile.is_valid_tile: + # Get tile coord from tile scheduler + cur_tile_coord = work_tile.tile_idx + mma_tile_coord_mnl = ( + cur_tile_coord[0] // cute.size(tiled_mma.thr_id.shape), + cur_tile_coord[1], + cur_tile_coord[2], + ) + + # Set tensor memory buffer for current tile + # (MMA, MMA_M, MMA_N) + tCtAcc = tCtAcc_base[(None, None, None, acc_producer_state.index)] + + # Peek (try_wait) AB buffer full for k_tile = 0 + ab_consumer_state.reset_count() + peek_ab_full_status = cutlass.Boolean(1) + if ab_consumer_state.count < k_tile_cnt and is_leader_cta: + peek_ab_full_status = 
ab_pipeline.consumer_try_wait( + ab_consumer_state + ) + + # + # Wait for accumulator buffer empty + # + if is_leader_cta: + acc_pipeline.producer_acquire(acc_producer_state) + + # + # Reset the ACCUMULATE field for each tile + # + tiled_mma.set(tcgen05.Field.ACCUMULATE, False) + + # + # Mma mainloop + # + for k_tile in range(k_tile_cnt): + if is_leader_cta: + # Conditionally wait for AB buffer full + ab_pipeline.consumer_wait( + ab_consumer_state, peek_ab_full_status + ) + + # Copy SFA/SFB from smem to tmem + s2t_stage_coord = ( + None, + None, + None, + None, + ab_consumer_state.index, + ) + tCsSFA_compact_s2t_staged = tCsSFA_compact_s2t[s2t_stage_coord] + tCsSFB_compact_s2t_staged = tCsSFB_compact_s2t[s2t_stage_coord] + cute.copy( + tiled_copy_s2t_sfa, + tCsSFA_compact_s2t_staged, + tCtSFA_compact_s2t, + ) + cute.copy( + tiled_copy_s2t_sfb, + tCsSFB_compact_s2t_staged, + tCtSFB_compact_s2t, + ) + + # tCtAcc += tCrA * tCrSFA * tCrB * tCrSFB + num_kblocks = cute.size(tCrA, mode=[2]) + for kblock_idx in cutlass.range(num_kblocks, unroll_full=True): + kblock_coord = ( + None, + None, + kblock_idx, + ab_consumer_state.index, + ) + + # Set SFA/SFB tensor to tiled_mma + sf_kblock_coord = (None, None, kblock_idx) + tiled_mma.set( + tcgen05.Field.SFA, + tCtSFA[sf_kblock_coord].iterator, + ) + tiled_mma.set( + tcgen05.Field.SFB, + tCtSFB[sf_kblock_coord].iterator, + ) + + cute.gemm( + tiled_mma, + tCtAcc, + tCrA[kblock_coord], + tCrB[kblock_coord], + tCtAcc, + ) + + # Enable accumulate on tCtAcc after first kblock + tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + + # Async arrive AB buffer empty + ab_pipeline.consumer_release(ab_consumer_state) + + # Peek (try_wait) AB buffer full for k_tile = k_tile + 1 + ab_consumer_state.advance() + peek_ab_full_status = cutlass.Boolean(1) + if ab_consumer_state.count < k_tile_cnt: + if is_leader_cta: + peek_ab_full_status = ab_pipeline.consumer_try_wait( + ab_consumer_state + ) + + # + # Async arrive accumulator buffer full + # + if is_leader_cta: + acc_pipeline.producer_commit(acc_producer_state) + acc_producer_state.advance() + + # + # Advance to next tile + # + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + + # + # Wait for accumulator buffer empty + # + acc_pipeline.producer_tail(acc_producer_state) + # + # Specialized epilogue warps + # + if warp_idx < self.mma_warp_id: + # + # Alloc tensor memory buffer + # + if warp_idx == self.epilog_warp_id[0]: + cute.arch.alloc_tmem( + self.num_tmem_alloc_cols, + tmem_holding_buf, + is_two_cta=use_2cta_instrs, + ) + + # + # Bar sync for retrieve tensor memory ptr from shared memory + # + tmem_ptr_read_threads = 32 * len((self.mma_warp_id, *self.epilog_warp_id)) + cute.arch.barrier( + barrier_id=self.tmem_ptr_sync_bar_id, + number_of_threads=tmem_ptr_read_threads, + ) + + # + # Retrieving tensor memory ptr and make accumulator tensor + # + acc_tmem_ptr = cute.arch.retrieve_tmem_ptr( + self.acc_dtype, + alignment=16, + ptr_to_buffer_holding_addr=tmem_holding_buf, + ) + # (MMA, MMA_M, MMA_N, STAGE) + tCtAcc_base = cute.make_tensor(acc_tmem_ptr, tCtAcc_fake.layout) + + # + # Partition for epilogue + # + epi_tidx = tidx + tiled_copy_t2r, tTR_tAcc_base, tTR_rAcc = ( + self.epilog_tmem_copy_and_partition( + epi_tidx, tCtAcc_base, tCgC, epi_tile, use_2cta_instrs + ) + ) + + tTR_rC = cute.make_fragment(tTR_rAcc.shape, self.c_dtype) + tiled_copy_r2s, tRS_rC, tRS_sC = self.epilog_smem_copy_and_partition( + tiled_copy_t2r, tTR_rC, epi_tidx, sC + ) + tma_atom_c, bSG_sC, bSG_gC_partitioned = ( + 
self.epilog_gmem_copy_and_partition( + epi_tidx, tma_atom_c, tCgC, epi_tile, sC + ) + ) + + # + # Persistent tile scheduling loop + # + tile_sched = utils.StaticPersistentTileScheduler.create( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + + acc_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_acc_stage + ) + + # Threads/warps participating in tma store pipeline + c_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + 32 * len(self.epilog_warp_id), + 32 * len(self.epilog_warp_id), + ) + c_pipeline = pipeline.PipelineTmaStore.create( + num_stages=self.num_c_stage, + producer_group=c_producer_group, + ) + + while work_tile.is_valid_tile: + + # Get tile coord from tile scheduler + cur_tile_coord = work_tile.tile_idx + mma_tile_coord_mnl = ( + cur_tile_coord[0] // cute.size(tiled_mma.thr_id.shape), + cur_tile_coord[1], + cur_tile_coord[2], + ) + + # + # Slice to per mma tile index + # + # ((ATOM_V, REST_V), EPI_M, EPI_N) + bSG_gC = bSG_gC_partitioned[ + ( + None, + None, + None, + *mma_tile_coord_mnl, + ) + ] + + # Set tensor memory buffer for current tile + # (T2R, T2R_M, T2R_N, EPI_M, EPI_M) + tTR_tAcc = tTR_tAcc_base[ + (None, None, None, None, None, acc_consumer_state.index) + ] + + # + # Wait for accumulator buffer full + # + acc_pipeline.consumer_wait(acc_consumer_state) + + tTR_tAcc = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc)) + bSG_gC = cute.group_modes(bSG_gC, 1, cute.rank(bSG_gC)) + + # + # Store accumulator to global memory in subtiles + # + subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3]) + num_prev_subtiles = tile_sched.num_tiles_executed * subtile_cnt + for subtile_idx in cutlass.range(subtile_cnt): + # + # Load accumulator from tensor memory buffer to register + # + tTR_tAcc_mn = tTR_tAcc[(None, None, None, subtile_idx)] + cute.copy(tiled_copy_t2r, tTR_tAcc_mn, tTR_rAcc) + + # + # Convert to C type + # + acc_vec = tiled_copy_r2s.retile(tTR_rAcc).load() + acc_vec = epilogue_op(acc_vec.to(self.c_dtype)) + tRS_rC.store(acc_vec) + + # + # Store C to shared memory + # + c_buffer = (num_prev_subtiles + subtile_idx) % self.num_c_stage + cute.copy( + tiled_copy_r2s, + tRS_rC, + tRS_sC[(None, None, None, c_buffer)], + ) + # Fence and barrier to make sure shared memory store is visible to TMA store + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + epilog_threads = 32 * len(self.epilog_warp_id) + cute.arch.barrier( + barrier_id=self.epilog_sync_bar_id, + number_of_threads=epilog_threads, + ) + + # + # TMA store C to global memory + # + if warp_idx == self.epilog_warp_id[0]: + cute.copy( + tma_atom_c, + bSG_sC[(None, c_buffer)], + bSG_gC[(None, subtile_idx)], + ) + # Fence and barrier to make sure shared memory store is visible to TMA store + c_pipeline.producer_commit() + c_pipeline.producer_acquire() + cute.arch.barrier( + barrier_id=self.epilog_sync_bar_id, + number_of_threads=epilog_threads, + ) + + # + # Async arrive accumulator buffer empty + # + with cute.arch.elect_one(): + acc_pipeline.consumer_release(acc_consumer_state) + acc_consumer_state.advance() + + # + # Advance to next tile + # + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + + # + # Dealloc the tensor memory buffer + # + if warp_idx == self.epilog_warp_id[0]: + cute.arch.relinquish_tmem_alloc_permit(is_two_cta=use_2cta_instrs) + epilog_threads = 32 * len(self.epilog_warp_id) + 
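# Sync all epilogue warps so none is still reading accumulator data out of
+            # tensor memory when the first epilogue warp deallocates it below.
+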
cute.arch.barrier( + barrier_id=self.epilog_sync_bar_id, number_of_threads=epilog_threads + ) + if warp_idx == self.epilog_warp_id[0]: + if use_2cta_instrs: + cute.arch.mbarrier_arrive( + tmem_dealloc_mbar_ptr, cta_rank_in_cluster ^ 1 + ) + cute.arch.mbarrier_wait(tmem_dealloc_mbar_ptr, 0) + cute.arch.dealloc_tmem( + acc_tmem_ptr, self.num_tmem_alloc_cols, is_two_cta=use_2cta_instrs + ) + # + # Wait for C store complete + # + c_pipeline.producer_tail() + + def mainloop_s2t_copy_and_partition( + self, + sSF: cute.Tensor, + tSF: cute.Tensor, + ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]: + """ + Make tiledCopy for smem to tmem load for scale factor tensor, then use it to partition smem memory (source) and tensor memory (destination). + + :param sSF: The scale factor tensor in smem + :type sSF: cute.Tensor + :param tSF: The scale factor tensor in tmem + :type tSF: cute.Tensor + + :return: A tuple containing (tiled_copy_s2t, tCsSF_compact_s2t, tCtSF_compact_s2t) where: + - tiled_copy_s2t: The tiled copy operation for smem to tmem load for scale factor tensor(s2t) + - tCsSF_compact_s2t: The partitioned scale factor tensor in smem + - tSF_compact_s2t: The partitioned scale factor tensor in tmem + :rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor] + """ + # (MMA, MMA_MN, MMA_K, STAGE) + tCsSF_compact = cute.filter_zeros(sSF) + # (MMA, MMA_MN, MMA_K) + tCtSF_compact = cute.filter_zeros(tSF) + + # Make S2T CopyAtom and tiledCopy + copy_atom_s2t = cute.make_copy_atom( + tcgen05.Cp4x32x128bOp(self.cta_group), + self.sf_dtype, + ) + tiled_copy_s2t = tcgen05.make_s2t_copy(copy_atom_s2t, tCtSF_compact) + thr_copy_s2t = tiled_copy_s2t.get_slice(0) + + # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K, STAGE) + tCsSF_compact_s2t_ = thr_copy_s2t.partition_S(tCsSF_compact) + # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K, STAGE) + tCsSF_compact_s2t = tcgen05.get_s2t_smem_desc_tensor( + tiled_copy_s2t, tCsSF_compact_s2t_ + ) + # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K) + tCtSF_compact_s2t = thr_copy_s2t.partition_D(tCtSF_compact) + + return tiled_copy_s2t, tCsSF_compact_s2t, tCtSF_compact_s2t + + def epilog_tmem_copy_and_partition( + self, + tidx: cutlass.Int32, + tAcc: cute.Tensor, + gC_mnl: cute.Tensor, + epi_tile: cute.Tile, + use_2cta_instrs: Union[cutlass.Boolean, bool], + ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]: + """ + Make tiledCopy for tensor memory load, then use it to partition tensor memory (source) and register array (destination). 
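+
+        For example, the epilogue warps in the kernel body call it as follows
+        (a sketch; the operands are the tensors constructed earlier in the kernel):
+
+        .. code-block:: python
+
+            tiled_copy_t2r, tTR_tAcc_base, tTR_rAcc = self.epilog_tmem_copy_and_partition(
+                epi_tidx, tCtAcc_base, tCgC, epi_tile, use_2cta_instrs
+            )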
+ + :param tidx: The thread index in epilogue warp groups + :type tidx: cutlass.Int32 + :param tAcc: The accumulator tensor to be copied and partitioned + :type tAcc: cute.Tensor + :param gC_mnl: The global tensor C + :type gC_mnl: cute.Tensor + :param epi_tile: The epilogue tiler + :type epi_tile: cute.Tile + :param use_2cta_instrs: Whether use_2cta_instrs is enabled + :type use_2cta_instrs: bool + + :return: A tuple containing (tiled_copy_t2r, tTR_tAcc, tTR_rAcc) where: + - tiled_copy_t2r: The tiled copy operation for tmem to register copy(t2r) + - tTR_tAcc: The partitioned accumulator tensor + - tTR_rAcc: The accumulated tensor in register used to hold t2r results + :rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor] + """ + # Make tiledCopy for tensor memory load + copy_atom_t2r = sm100_utils.get_tmem_load_op( + self.cta_tile_shape_mnk, + self.c_layout, + self.c_dtype, + self.acc_dtype, + epi_tile, + use_2cta_instrs, + ) + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, STAGE) + tAcc_epi = cute.flat_divide( + tAcc[((None, None), 0, 0, None)], + epi_tile, + ) + # (EPI_TILE_M, EPI_TILE_N) + tiled_copy_t2r = tcgen05.make_tmem_copy( + copy_atom_t2r, tAcc_epi[(None, None, 0, 0, 0)] + ) + + thr_copy_t2r = tiled_copy_t2r.get_slice(tidx) + # (T2R, T2R_M, T2R_N, EPI_M, EPI_M, STAGE) + tTR_tAcc = thr_copy_t2r.partition_S(tAcc_epi) + + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, RestM, RestN, RestL) + gC_mnl_epi = cute.flat_divide( + gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile + ) + # (T2R, T2R_M, T2R_N, EPI_M, EPI_N, RestM, RestN, RestL) + tTR_gC = thr_copy_t2r.partition_D(gC_mnl_epi) + # (T2R, T2R_M, T2R_N) + tTR_rAcc = cute.make_fragment( + tTR_gC[(None, None, None, 0, 0, 0, 0, 0)].shape, self.acc_dtype + ) + return tiled_copy_t2r, tTR_tAcc, tTR_rAcc + + def epilog_smem_copy_and_partition( + self, + tiled_copy_t2r: cute.TiledCopy, + tTR_rC: cute.Tensor, + tidx: cutlass.Int32, + sC: cute.Tensor, + ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]: + """ + Make tiledCopy for shared memory store, then use it to partition register array (source) and shared memory (destination). 
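+
+        For example, the epilogue warps call it as follows (a sketch;
+        ``tiled_copy_t2r`` and ``tTR_rC`` come from the preceding tmem-load partitioning):
+
+        .. code-block:: python
+
+            tiled_copy_r2s, tRS_rC, tRS_sC = self.epilog_smem_copy_and_partition(
+                tiled_copy_t2r, tTR_rC, epi_tidx, sC
+            )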
+ + :param tiled_copy_t2r: The tiled copy operation for tmem to register copy(t2r) + :type tiled_copy_t2r: cute.TiledCopy + :param tTR_rC: The partitioned accumulator tensor + :type tTR_rC: cute.Tensor + :param tidx: The thread index in epilogue warp groups + :type tidx: cutlass.Int32 + :param sC: The shared memory tensor to be copied and partitioned + :type sC: cute.Tensor + :type sepi: cute.Tensor + + :return: A tuple containing (tiled_copy_r2s, tRS_rC, tRS_sC) where: + - tiled_copy_r2s: The tiled copy operation for register to smem copy(r2s) + - tRS_rC: The partitioned tensor C (register source) + - tRS_sC: The partitioned tensor C (smem destination) + :rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor] + """ + copy_atom_r2s = sm100_utils.get_smem_store_op( + self.c_layout, self.c_dtype, self.acc_dtype, tiled_copy_t2r + ) + tiled_copy_r2s = cute.make_tiled_copy_D(copy_atom_r2s, tiled_copy_t2r) + # (R2S, R2S_M, R2S_N, PIPE_D) + thr_copy_r2s = tiled_copy_r2s.get_slice(tidx) + tRS_sC = thr_copy_r2s.partition_D(sC) + # (R2S, R2S_M, R2S_N) + tRS_rC = tiled_copy_r2s.retile(tTR_rC) + return tiled_copy_r2s, tRS_rC, tRS_sC + + def epilog_gmem_copy_and_partition( + self, + tidx: cutlass.Int32, + atom: Union[cute.CopyAtom, cute.TiledCopy], + gC_mnl: cute.Tensor, + epi_tile: cute.Tile, + sC: cute.Tensor, + ) -> Tuple[cute.CopyAtom, cute.Tensor, cute.Tensor]: + """Make tiledCopy for global memory store, then use it to: + partition shared memory (source) and global memory (destination) for TMA store version. + + :param tidx: The thread index in epilogue warp groups + :type tidx: cutlass.Int32 + :param atom: The copy_atom_c to be used for TMA store version, or tiled_copy_t2r for none TMA store version + :type atom: cute.CopyAtom or cute.TiledCopy + :param gC_mnl: The global tensor C + :type gC_mnl: cute.Tensor + :param epi_tile: The epilogue tiler + :type epi_tile: cute.Tile + :param sC: The shared memory tensor to be copied and partitioned + :type sC: cute.Tensor + + :return: A tuple containing (tma_atom_c, bSG_sC, bSG_gC) where: + - tma_atom_c: The TMA copy atom + - bSG_sC: The partitioned shared memory tensor C + - bSG_gC: The partitioned global tensor C + :rtype: Tuple[cute.CopyAtom, cute.Tensor, cute.Tensor] + """ + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, RestM, RestN, RestL) + gC_epi = cute.flat_divide( + gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile + ) + + tma_atom_c = atom + sC_for_tma_partition = cute.group_modes(sC, 0, 2) + gC_for_tma_partition = cute.group_modes(gC_epi, 0, 2) + # ((ATOM_V, REST_V), EPI_M, EPI_N) + # ((ATOM_V, REST_V), EPI_M, EPI_N, RestM, RestN, RestL) + bSG_sC, bSG_gC = cpasync.tma_partition( + tma_atom_c, + 0, + cute.make_layout(1), + sC_for_tma_partition, + gC_for_tma_partition, + ) + return tma_atom_c, bSG_sC, bSG_gC + + @staticmethod + def _compute_stages( + tiled_mma: cute.TiledMma, + mma_tiler_mnk: Tuple[int, int, int], + a_dtype: Type[cutlass.Numeric], + a_major_mode: tcgen05.OperandMajorMode, + b_dtype: Type[cutlass.Numeric], + b_major_mode: tcgen05.OperandMajorMode, + epi_tile: cute.Tile, + c_dtype: Type[cutlass.Numeric], + c_layout: utils.LayoutEnum, + sf_dtype: Type[cutlass.Numeric], + sf_vec_size: int, + smem_capacity: int, + occupancy: int, + ) -> Tuple[int, int, int]: + """Computes the number of stages for A/B/C operands based on heuristics. + + :param tiled_mma: The tiled MMA object defining the core computation. + :type tiled_mma: cute.TiledMma + :param mma_tiler_mnk: The shape (M, N, K) of the MMA tiler. 
+ :type mma_tiler_mnk: tuple[int, int, int] + :param a_dtype: Data type of operand A. + :type a_dtype: type[cutlass.Numeric] + :param a_major_mode: Major mode of operand A. + :type a_major_mode: tcgen05.OperandMajorMode + :param b_dtype: Data type of operand B. + :type b_dtype: type[cutlass.Numeric] + :param b_major_mode: Major mode of operand B. + :type b_major_mode: tcgen05.OperandMajorMode + :param epi_tile: The epilogue tile shape. + :type epi_tile: cute.Tile + :param c_dtype: Data type of operand C (output). + :type c_dtype: type[cutlass.Numeric] + :param c_layout: Layout enum of operand C. + :type c_layout: utils.LayoutEnum + :param sf_dtype: Data type of Scale factor. + :type sf_dtype: type[cutlass.Numeric] + :param sf_vec_size: Scale factor vector size. + :type sf_vec_size: int + :param smem_capacity: Total available shared memory capacity in bytes. + :type smem_capacity: int + :param occupancy: Target number of CTAs per SM (occupancy). + :type occupancy: int + + :return: A tuple containing the computed number of stages for: + (ACC stages, A/B operand stages, C stages) + :rtype: tuple[int, int, int] + """ + # ACC stages + num_acc_stage = 1 if mma_tiler_mnk[1] == 256 else 2 + + # Default C stages + num_c_stage = 2 + + # Calculate smem layout and size for one stage of A, B, SFA, SFB and C + a_smem_layout_stage_one = sm100_utils.make_smem_layout_a( + tiled_mma, + mma_tiler_mnk, + a_dtype, + 1, # a tmp 1 stage is provided + ) + b_smem_layout_staged_one = sm100_utils.make_smem_layout_b( + tiled_mma, + mma_tiler_mnk, + b_dtype, + 1, # a tmp 1 stage is provided + ) + sfa_smem_layout_staged_one = blockscaled_utils.make_smem_layout_sfa( + tiled_mma, + mma_tiler_mnk, + sf_vec_size, + 1, # a tmp 1 stage is provided + ) + sfb_smem_layout_staged_one = blockscaled_utils.make_smem_layout_sfb( + tiled_mma, + mma_tiler_mnk, + sf_vec_size, + 1, # a tmp 1 stage is provided + ) + + c_smem_layout_staged_one = sm100_utils.make_smem_layout_epi( + c_dtype, + c_layout, + epi_tile, + 1, + ) + + ab_bytes_per_stage = ( + cute.size_in_bytes(a_dtype, a_smem_layout_stage_one) + + cute.size_in_bytes(b_dtype, b_smem_layout_staged_one) + + cute.size_in_bytes(sf_dtype, sfa_smem_layout_staged_one) + + cute.size_in_bytes(sf_dtype, sfb_smem_layout_staged_one) + ) + mbar_helpers_bytes = 1024 + c_bytes_per_stage = cute.size_in_bytes(c_dtype, c_smem_layout_staged_one) + c_bytes = c_bytes_per_stage * num_c_stage + + # Calculate A/B/SFA/SFB stages: + # Start with total smem per CTA (capacity / occupancy) + # Subtract reserved bytes and initial C stages bytes + # Divide remaining by bytes needed per A/B/SFA/SFB stage + num_ab_stage = ( + smem_capacity // occupancy - (mbar_helpers_bytes + c_bytes) + ) // ab_bytes_per_stage + + # Refine epilogue stages: + # Calculate remaining smem after allocating for A/B/SFA/SFB stages and reserved bytes + # Add remaining unused smem to epilogue + num_c_stage += ( + smem_capacity + - occupancy * ab_bytes_per_stage * num_ab_stage + - occupancy * (mbar_helpers_bytes + c_bytes) + ) // (occupancy * c_bytes_per_stage) + + return num_acc_stage, num_ab_stage, num_c_stage + + @staticmethod + def _compute_grid( + c: cute.Tensor, + cta_tile_shape_mnk: Tuple[int, int, int], + cluster_shape_mn: Tuple[int, int], + max_active_clusters: cutlass.Constexpr, + ) -> Tuple[utils.PersistentTileSchedulerParams, Tuple[int, int, int]]: + """Use persistent tile scheduler to compute the grid size for the output tensor C. 
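+
+        A sketch of the intended use (the host-side ``__call__`` is expected to pass
+        the configured CTA tile shape, cluster shape, and active-cluster count):
+
+        .. code-block:: python
+
+            tile_sched_params, grid = self._compute_grid(
+                c, self.cta_tile_shape_mnk, self.cluster_shape_mn, max_active_clusters
+            )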
+ + :param c: The output tensor C + :type c: cute.Tensor + :param cta_tile_shape_mnk: The shape (M, N, K) of the CTA tile. + :type cta_tile_shape_mnk: tuple[int, int, int] + :param cluster_shape_mn: Shape of each cluster in M, N dimensions. + :type cluster_shape_mn: tuple[int, int] + :param max_active_clusters: Maximum number of active clusters. + :type max_active_clusters: cutlass.Constexpr + + :return: A tuple containing: + - tile_sched_params: Parameters for the persistent tile scheduler. + - grid: Grid shape for kernel launch. + :rtype: Tuple[utils.PersistentTileSchedulerParams, tuple[int, int, int]] + """ + c_shape = cute.slice_(cta_tile_shape_mnk, (None, None, 0)) + gc = cute.zipped_divide(c, tiler=c_shape) + num_ctas_mnl = gc[(0, (None, None, None))].shape + cluster_shape_mnl = (*cluster_shape_mn, 1) + + tile_sched_params = utils.PersistentTileSchedulerParams( + num_ctas_mnl, cluster_shape_mnl + ) + grid = utils.StaticPersistentTileScheduler.get_grid_shape( + tile_sched_params, max_active_clusters + ) + + return tile_sched_params, grid + + @staticmethod + def is_valid_dtypes_and_scale_factor_vec_size( + ab_dtype: Type[cutlass.Numeric], + sf_dtype: Type[cutlass.Numeric], + sf_vec_size: int, + c_dtype: Type[cutlass.Numeric], + ) -> bool: + """ + Check if the dtypes and sf_vec_size are valid combinations + + :param ab_dtype: The data type of the A and B operands + :type ab_dtype: Type[cutlass.Numeric] + :param sf_dtype: The data type of the scale factor + :type sf_dtype: Type[cutlass.Numeric] + :param sf_vec_size: The vector size of the scale factor + :type sf_vec_size: int + :param c_dtype: The data type of the output tensor + :type c_dtype: Type[cutlass.Numeric] + + :return: True if the dtypes and sf_vec_size are valid, False otherwise + :rtype: bool + """ + is_valid = True + + # Check valid ab_dtype + if ab_dtype not in { + cutlass.Float4E2M1FN, + cutlass.Float8E5M2, + cutlass.Float8E4M3FN, + }: + is_valid = False + + # Check valid sf_vec_size + if sf_vec_size not in {16, 32}: + is_valid = False + + # Check valid sf_dtype + if sf_dtype not in {cutlass.Float8E8M0FNU, cutlass.Float8E4M3FN}: + is_valid = False + + # Check valid sf_dtype and sf_vec_size combinations + if sf_dtype == cutlass.Float8E4M3FN and sf_vec_size == 32: + is_valid = False + if ab_dtype in {cutlass.Float8E5M2, cutlass.Float8E4M3FN} and sf_vec_size == 16: + is_valid = False + + # Check valid c_dtype + if c_dtype not in { + cutlass.Float32, + cutlass.Float16, + cutlass.BFloat16, + cutlass.Float8E5M2, + cutlass.Float8E4M3FN, + }: + is_valid = False + + return is_valid + + @staticmethod + def is_valid_layouts( + ab_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + a_major: str, + b_major: str, + c_major: str, + ) -> bool: + """ + Check if the dtypes and sf_vec_size are valid combinations + + :param ab_dtype: The data type of the A and B operands + :type ab_dtype: Type[cutlass.Numeric] + :param c_dtype: The data type of the output tensor + :type c_dtype: Type[cutlass.Numeric] + :param a_major: The major dimension of the A tensor + :type a_major: str + :param b_major: The major dimension of the B tensor + :type b_major: str + :param c_major: The major dimension of the C tensor + :type c_major: str + + :return: True if the layouts are valid, False otherwise + :rtype: bool + """ + is_valid = True + + if ab_dtype is cutlass.Float4E2M1FN and not (a_major == "k" and b_major == "k"): + is_valid = False + return is_valid + + @staticmethod + def is_valid_mma_tiler_and_cluster_shape( + mma_tiler_mn: Tuple[int, 
int], + cluster_shape_mn: Tuple[int, int], + ) -> bool: + """ + Check if the mma tiler and cluster shape are valid + + :param mma_tiler_mn: The (M, N) shape of the MMA instruction tiler + :type mma_tiler_mn: Tuple[int, int] + :param cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster + :type cluster_shape_mn: Tuple[int, int] + + :return: True if the mma tiler and cluster shape are valid, False otherwise + :rtype: bool + """ + is_valid = True + # Skip invalid mma tile shape + if not mma_tiler_mn[0] in [128, 256]: + is_valid = False + if not mma_tiler_mn[1] in [128, 256]: + is_valid = False + # Skip illegal cluster shape + if cluster_shape_mn[0] % (2 if mma_tiler_mn[0] == 256 else 1) != 0: + is_valid = False + # Skip invalid cluster shape + is_power_of_2 = lambda x: x > 0 and (x & (x - 1)) == 0 + if ( + cluster_shape_mn[0] * cluster_shape_mn[1] > 16 + or cluster_shape_mn[0] <= 0 + or cluster_shape_mn[1] <= 0 + # Special cluster shape check for scale factor multicasts. + # Due to limited size of scale factors, we can't multicast among more than 4 CTAs. + or cluster_shape_mn[0] > 4 + or cluster_shape_mn[1] > 4 + or not is_power_of_2(cluster_shape_mn[0]) + or not is_power_of_2(cluster_shape_mn[1]) + ): + is_valid = False + return is_valid + + @staticmethod + def is_valid_tensor_alignment( + m: int, + n: int, + k: int, + l: int, + ab_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + a_major: str, + b_major: str, + c_major: str, + ) -> bool: + """ + Check if the tensor alignment is valid + + :param m: The number of rows in the A tensor + :type m: int + :param n: The number of columns in the B tensor + :type n: int + :param k: The number of columns in the A tensor + :type k: int + :param l: The number of columns in the C tensor + :type l: int + :param ab_dtype: The data type of the A and B operands + :type ab_dtype: Type[cutlass.Numeric] + :param c_dtype: The data type of the output tensor + :type c_dtype: Type[cutlass.Numeric] + :param a_major: The major axis of the A tensor + :type a_major: str + :param b_major: The major axis of the B tensor + :type b_major: str + :param c_major: The major axis of the C tensor + :type c_major: str + + :return: True if the problem shape is valid, False otherwise + :rtype: bool + """ + is_valid = True + + def check_contigous_16B_alignment(dtype, is_mode0_major, tensor_shape): + major_mode_idx = 0 if is_mode0_major else 1 + num_major_elements = tensor_shape[major_mode_idx] + num_contiguous_elements = 16 * 8 // dtype.width + return num_major_elements % num_contiguous_elements == 0 + + if ( + not check_contigous_16B_alignment(ab_dtype, a_major == "m", (m, k, l)) + or not check_contigous_16B_alignment(ab_dtype, b_major == "n", (n, k, l)) + or not check_contigous_16B_alignment(c_dtype, c_major == "m", (m, n, l)) + ): + is_valid = False + return is_valid + + @staticmethod + def can_implement( + ab_dtype: Type[cutlass.Numeric], + sf_dtype: Type[cutlass.Numeric], + sf_vec_size: int, + c_dtype: Type[cutlass.Numeric], + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + m: int, + n: int, + k: int, + l: int, + a_major: str, + b_major: str, + c_major: str, + ) -> bool: + """ + Check if the gemm can be implemented + + :param ab_dtype: The data type of the A and B operands + :type ab_dtype: Type[cutlass.Numeric] + :param sf_dtype: The data type of the scale factor tensor + :type sf_dtype: Type[cutlass.Numeric] + :param sf_vec_size: The vector size + :type sf_vec_size: int + :param c_dtype: The data type of the output tensor + 
:type c_dtype: Type[cutlass.Numeric] + :param mma_tiler_mn: The (M, N) shape of the MMA instruction tiler + :type mma_tiler_mn: Tuple[int, int] + :param cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster + :type cluster_shape_mn: Tuple[int, int] + :param m: The number of rows in the A tensor + :type m: int + :param n: The number of columns in the B tensor + :type n: int + :param k: The number of columns in the A tensor + :type k: int + :param l: The number of columns in the C tensor + :type l: int + :param a_major: The major axis of the A tensor + :type a_major: str + :param b_major: The major axis of the B tensor + :type b_major: str + :param c_major: The major axis of the C tensor + :type c_major: str + + :return: True if the gemm can be implemented, False otherwise + :rtype: bool + """ + can_implement = True + # Skip unsupported types + if not Sm100BlockScaledPersistentDenseGemmKernel.is_valid_dtypes_and_scale_factor_vec_size( + ab_dtype, sf_dtype, sf_vec_size, c_dtype + ): + can_implement = False + # Skip unsupported layouts + if not Sm100BlockScaledPersistentDenseGemmKernel.is_valid_layouts( + ab_dtype, c_dtype, a_major, b_major, c_major + ): + can_implement = False + # Skip invalid mma tile shape and cluster shape + if not Sm100BlockScaledPersistentDenseGemmKernel.is_valid_mma_tiler_and_cluster_shape( + mma_tiler_mn, cluster_shape_mn + ): + can_implement = False + # Skip illegal problem shape for load/store alignment + if not Sm100BlockScaledPersistentDenseGemmKernel.is_valid_tensor_alignment( + m, n, k, l, ab_dtype, c_dtype, a_major, b_major, c_major + ): + can_implement = False + return can_implement + + +@cute.jit +def cvt_sf_MKL_to_M32x4xrm_K4xrk_L( + sf_ref_tensor: cute.Tensor, + sf_mma_tensor: cute.Tensor, +): + """Convert scale factor tensor from MKL layout to mma specification M(32x4xrest_m)xK(4xrest_k)xL layout""" + # sf_mma_tensor has flatten shape (32, 4, rest_m, 4, rest_k, l) + # group to ((32, 4, rest_m), (4, rest_k), l) + sf_mma_tensor = cute.group_modes(sf_mma_tensor, 0, 3) + sf_mma_tensor = cute.group_modes(sf_mma_tensor, 1, 3) + for i in cutlass.range(cute.size(sf_ref_tensor)): + mkl_coord = sf_ref_tensor.layout.get_hier_coord(i) + sf_mma_tensor[mkl_coord] = sf_ref_tensor[mkl_coord] + + +def run( + mnkl: Tuple[int, int, int, int], + ab_dtype: Type[cutlass.Numeric], + sf_dtype: Type[cutlass.Numeric], + sf_vec_size: int, + c_dtype: Type[cutlass.Numeric], + a_major: str, + b_major: str, + c_major: str, + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + tolerance: float = 1e-01, + warmup_iterations: int = 0, + iterations: int = 1, + skip_ref_check: bool = False, + use_cold_l2: bool = False, + **kwargs, +): + """Execute a persistent batched dense blockscaled GEMM operation on Blackwell architecture with performance benchmarking. + + This function prepares input tensors, configures and launches the persistent GEMM kernel, + optionally performs reference validation, and benchmarks the execution performance. 
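+
+    A minimal invocation sketch, using the same values as the command-line defaults
+    defined at the bottom of this file:
+
+    .. code-block:: python
+
+        exec_time = run(
+            (512, 256, 256, 1),          # mnkl
+            cutlass.Float4E2M1FN,        # ab_dtype
+            cutlass.Float8E8M0FNU,       # sf_dtype
+            16,                          # sf_vec_size
+            cutlass.Float16,             # c_dtype
+            "k", "k", "n",               # a_major, b_major, c_major
+            (128, 128),                  # mma_tiler_mn
+            (1, 1),                      # cluster_shape_mn
+        )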
+ + :param mnkl: Problem size (M, N, K, L) + :type mnkl: Tuple[int, int, int, int] + :param ab_dtype: Data type for input tensors A and B + :type ab_dtype: Type[cutlass.Numeric] + :param sf_dtype: Data type for scale factor tensor + :type sf_dtype: Type[cutlass.Numeric] + :param sf_vec_size: Vector size for scale factor tensor + :type sf_vec_size: int + :param c_dtype: Data type for output tensor C + :type c_dtype: Type[cutlass.Numeric] + :param a_major/b_major/c_major: Memory layout of tensor A/B/C + :type a_major/b_major/c_major: str + :param mma_tiler_mn: MMA tiling size. + :type mma_tiler_mn: Tuple[int, int] + :param cluster_shape_mn: Cluster shape. + :type cluster_shape_mn: Tuple[int, int] + :param tolerance: Tolerance value for reference validation comparison, defaults to 1e-01 + :type tolerance: float, optional + :param warmup_iterations: Number of warmup iterations before benchmarking, defaults to 0 + :type warmup_iterations: int, optional + :param iterations: Number of benchmark iterations to run, defaults to 1 + :type iterations: int, optional + :param skip_ref_check: Whether to skip reference result validation, defaults to False + :type skip_ref_check: bool, optional + :param use_cold_l2: Whether to use circular buffer strategy to ensure cold L2 cache, defaults to False + :type use_cold_l2: bool, optional + :raises RuntimeError: If CUDA GPU is not available + :raises ValueError: If the configuration is invalid or unsupported by the kernel + :return: Execution time of the GEMM kernel + :rtype: float + """ + print(f"Running Sm100 Persistent Dense BlockScaled GEMM test with:") + print(f"mnkl: {mnkl}") + print(f"AB dtype: {ab_dtype}, SF dtype: {sf_dtype}, SF Vec size: {sf_vec_size}") + print(f"C dtype: {c_dtype}") + print(f"Matrix majors - A: {a_major}, B: {b_major}, C: {c_major}") + print(f"Mma Tiler (M, N): {mma_tiler_mn}, Cluster Shape (M, N): {cluster_shape_mn}") + print(f"Tolerance: {tolerance}") + print(f"Warmup iterations: {warmup_iterations}") + print(f"Iterations: {iterations}") + print(f"Skip reference checking: {skip_ref_check}") + print(f"Use cold L2: {'True' if use_cold_l2 else 'False'}") + + # Unpack parameters + m, n, k, l = mnkl + + # Skip unsupported testcase + if not Sm100BlockScaledPersistentDenseGemmKernel.can_implement( + ab_dtype, + sf_dtype, + sf_vec_size, + c_dtype, + mma_tiler_mn, + cluster_shape_mn, + m, + n, + k, + l, + a_major, + b_major, + c_major, + ): + raise TypeError( + f"Unsupported testcase {ab_dtype}, {sf_dtype}, {sf_vec_size}, {c_dtype}, {mma_tiler_mn}, {cluster_shape_mn}, {m}, {n}, {k}, {l}, {a_major}, {b_major}, {c_major}" + ) + + if not torch.cuda.is_available(): + raise RuntimeError("GPU is required to run this example!") + + torch.manual_seed(1111) + + # Create tensor A/B/C + a_ref = cutlass_torch.matrix(l, m, k, a_major == "m", cutlass.Float32) + b_ref = cutlass_torch.matrix(l, n, k, b_major == "n", cutlass.Float32) + c_ref = cutlass_torch.matrix(l, m, n, c_major == "m", cutlass.Float32) + + a_tensor, a_torch = cutlass_torch.cute_tensor_like( + a_ref, ab_dtype, is_dynamic_layout=True, assumed_align=16 + ) + b_tensor, b_torch = cutlass_torch.cute_tensor_like( + b_ref, ab_dtype, is_dynamic_layout=True, assumed_align=16 + ) + c_tensor, c_torch = cutlass_torch.cute_tensor_like( + c_ref, c_dtype, is_dynamic_layout=True, assumed_align=16 + ) + + # Mark tensor to be byte aligned + a_tensor.mark_compact_shape_dynamic( + mode=1 if a_major == "k" else 0, + stride_order=(2, 0, 1) if a_major == "k" else (2, 1, 0), + divisibility=2 if ab_dtype == 
cutlass.Float4E2M1FN else 1, + ) + b_tensor.mark_compact_shape_dynamic( + mode=1 if b_major == "k" else 0, + stride_order=(2, 0, 1) if b_major == "k" else (2, 1, 0), + divisibility=2 if ab_dtype == cutlass.Float4E2M1FN else 1, + ) + c_tensor.mark_compact_shape_dynamic( + mode=1 if c_major == "n" else 0, + stride_order=(2, 0, 1) if c_major == "n" else (2, 1, 0), + divisibility=2 if c_dtype == cutlass.Float4E2M1FN else 1, + ) + + # Create scale factor tensor SFA/SFB + def create_scale_factor_tensor(l, mn, k, sf_vec_size, dtype): + def ceil_div(a, b): + return (a + b - 1) // b + + sf_k = ceil_div(k, sf_vec_size) + ref_shape = (l, mn, sf_k) + + atom_m = (32, 4) + atom_k = 4 + mma_shape = ( + l, + ceil_div(mn, atom_m[0] * atom_m[1]), + ceil_div(sf_k, atom_k), + atom_m[0], + atom_m[1], + atom_k, + ) + + ref_permute_order = (1, 2, 0) + mma_permute_order = (3, 4, 1, 5, 2, 0) + + # Create f32 ref torch tensor (cpu) + ref_f32_torch_tensor_cpu = cutlass_torch.create_and_permute_torch_tensor( + ref_shape, + torch.float32, + permute_order=ref_permute_order, + init_type=cutlass_torch.TensorInitType.RANDOM, + init_config=cutlass_torch.RandomInitConfig( + min_val=1, + max_val=3, + ), + ) + + # Create f32 cute torch tensor (cpu) + cute_f32_torch_tensor_cpu = cutlass_torch.create_and_permute_torch_tensor( + mma_shape, + torch.float32, + permute_order=mma_permute_order, + init_type=cutlass_torch.TensorInitType.RANDOM, + init_config=cutlass_torch.RandomInitConfig( + min_val=0, + max_val=1, + ), + ) + + # convert ref f32 tensor to cute f32 tensor + cvt_sf_MKL_to_M32x4xrm_K4xrk_L( + from_dlpack(ref_f32_torch_tensor_cpu), + from_dlpack(cute_f32_torch_tensor_cpu), + ) + cute_f32_torch_tensor = cute_f32_torch_tensor_cpu.cuda() + + # reshape makes memory contiguous + ref_f32_torch_tensor_cpu = ( + ref_f32_torch_tensor_cpu.permute(2, 0, 1) + .unsqueeze(-1) + .expand(l, mn, sf_k, sf_vec_size) + .reshape(l, mn, sf_k * sf_vec_size) + .permute(*ref_permute_order) + ) + # prune to mkl for reference check. 
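+            # (sf_k * sf_vec_size can exceed k when k is not a multiple of
+            #  sf_vec_size, so trim the expanded K extent back to k)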
+ ref_f32_torch_tensor_cpu = ref_f32_torch_tensor_cpu[:, :k, :] + + # Create dtype cute torch tensor (cpu) + cute_tensor, cute_torch_tensor = cutlass_torch.cute_tensor_like( + cute_f32_torch_tensor_cpu, + dtype, + is_dynamic_layout=True, + assumed_align=16, + ) + + # Convert f32 cute tensor to dtype cute tensor + cute_tensor = cutlass_torch.convert_cute_tensor( + cute_f32_torch_tensor, + cute_tensor, + dtype, + is_dynamic_layout=True, + ) + return ref_f32_torch_tensor_cpu, cute_tensor, cute_torch_tensor + + sfa_ref, sfa_tensor, sfa_torch = create_scale_factor_tensor( + l, m, k, sf_vec_size, sf_dtype + ) + sfb_ref, sfb_tensor, sfb_torch = create_scale_factor_tensor( + l, n, k, sf_vec_size, sf_dtype + ) + + # Configure gemm kernel + gemm = Sm100BlockScaledPersistentDenseGemmKernel( + sf_vec_size, + mma_tiler_mn, + cluster_shape_mn, + ) + + # Compute max active clusters on current device + hardware_info = cutlass.utils.HardwareInfo() + max_active_clusters = hardware_info.get_max_active_clusters( + cluster_shape_mn[0] * cluster_shape_mn[1] + ) + + # Initialize Stream + current_stream = cutlass_torch.default_stream() + + # Compile gemm kernel + compiled_gemm = cute.compile( + gemm, + a_tensor, + b_tensor, + sfa_tensor, + sfb_tensor, + c_tensor, + max_active_clusters, + current_stream, + ) + + # Compute reference result + if not skip_ref_check: + # Execute kernel once for reference checking + compiled_gemm( + a_tensor, b_tensor, sfa_tensor, sfb_tensor, c_tensor, current_stream + ) + print("Verifying results...") + res_a = torch.einsum("mkl,mkl->mkl", a_ref, sfa_ref) + res_b = torch.einsum("nkl,nkl->nkl", b_ref, sfb_ref) + ref = torch.einsum("mkl,nkl->mnl", res_a, res_b) + + # Convert c back to f32 for comparison. + c_ref_device = c_ref.cuda() + cute.testing.convert( + c_tensor, + from_dlpack(c_ref_device, assumed_align=16).mark_layout_dynamic( + leading_dim=(1 if c_major == "n" else 0) + ), + ) + c_ref = c_ref_device.cpu() + + if c_dtype in (cutlass.Float32, cutlass.Float16, cutlass.BFloat16): + torch.testing.assert_close(c_ref, ref, atol=tolerance, rtol=1e-02) + elif c_dtype in (cutlass.Float8E5M2, cutlass.Float8E4M3FN): + # Convert ref : f32 -> f8 -> f32 + ref_f8_ = torch.empty(*(l, m, n), dtype=torch.uint8, device="cuda").permute( + 1, 2, 0 + ) + ref_f8 = from_dlpack(ref_f8_, assumed_align=16).mark_layout_dynamic( + leading_dim=1 + ) + ref_f8.element_type = c_dtype + ref_device = ref.permute(2, 0, 1).contiguous().permute(1, 2, 0).cuda() + ref_tensor = from_dlpack(ref_device, assumed_align=16).mark_layout_dynamic( + leading_dim=1 + ) + cute.testing.convert(ref_tensor, ref_f8) + cute.testing.convert(ref_f8, ref_tensor) + ref = ref_device.cpu() + torch.testing.assert_close(c_ref, ref, atol=tolerance, rtol=1e-02) + def generate_tensors(): + a_tensor, _ = cutlass_torch.cute_tensor_like( + a_ref, ab_dtype, is_dynamic_layout=True, assumed_align=16 + ) + b_tensor, _ = cutlass_torch.cute_tensor_like( + b_ref, ab_dtype, is_dynamic_layout=True, assumed_align=16 + ) + c_tensor, _ = cutlass_torch.cute_tensor_like( + c_ref, c_dtype, is_dynamic_layout=True, assumed_align=16 + ) + + # Mark tensor to be byte aligned + a_tensor.mark_compact_shape_dynamic( + mode=1 if a_major == "k" else 0, + stride_order=(2, 0, 1) if a_major == "k" else (2, 1, 0), + divisibility=2 if ab_dtype == cutlass.Float4E2M1FN else 1, + ) + b_tensor.mark_compact_shape_dynamic( + mode=1 if b_major == "k" else 0, + stride_order=(2, 0, 1) if b_major == "k" else (2, 1, 0), + divisibility=2 if ab_dtype == cutlass.Float4E2M1FN else 1, + ) + 
c_tensor.mark_compact_shape_dynamic( + mode=1 if c_major == "n" else 0, + stride_order=(2, 0, 1) if c_major == "n" else (2, 1, 0), + divisibility=2 if c_dtype == cutlass.Float4E2M1FN else 1, + ) + + _, sfa_tensor, _ = create_scale_factor_tensor(l, m, k, sf_vec_size, sf_dtype) + _, sfb_tensor, _ = create_scale_factor_tensor(l, n, k, sf_vec_size, sf_dtype) + return cute.testing.JitArguments( + a_tensor, b_tensor, sfa_tensor, sfb_tensor, c_tensor, current_stream + ) + + workspace_count = 1 + if use_cold_l2: + one_workspace_bytes = ( + a_torch.numel() * a_torch.element_size() + + b_torch.numel() * b_torch.element_size() + + sfa_torch.numel() * sfa_torch.element_size() + + sfb_torch.numel() * sfb_torch.element_size() + + c_torch.numel() * c_torch.element_size() + ) + workspace_count = cute.testing.get_workspace_count( + one_workspace_bytes, warmup_iterations, iterations + ) + + exec_time = cute.testing.benchmark( + compiled_gemm, + workspace_generator=generate_tensors, + workspace_count=workspace_count, + stream=current_stream, + warmup_iterations=warmup_iterations, + iterations=iterations, + ) + + return exec_time # Return execution time in microseconds + +if __name__ == "__main__": + + def parse_comma_separated_ints(s: str) -> Tuple[int, ...]: + try: + return tuple(int(x.strip()) for x in s.split(",")) + except ValueError: + raise argparse.ArgumentTypeError( + "Invalid format. Expected comma-separated integers." + ) + + parser = argparse.ArgumentParser( + description="Example of Sm100 Dense Persistent BlockScaled GEMM." + ) + + parser.add_argument( + "--mnkl", + type=parse_comma_separated_ints, + default=(512, 256, 256, 1), + help="mnkl dimensions (comma-separated)", + ) + parser.add_argument( + "--mma_tiler_mn", + type=parse_comma_separated_ints, + default=(128, 128), + help="Mma tile shape (comma-separated)", + ) + parser.add_argument( + "--cluster_shape_mn", + type=parse_comma_separated_ints, + default=(1, 1), + help="Cluster shape (comma-separated)", + ) + parser.add_argument("--ab_dtype", type=cutlass.dtype, default=cutlass.Float4E2M1FN) + parser.add_argument("--sf_dtype", type=cutlass.dtype, default=cutlass.Float8E8M0FNU) + parser.add_argument("--sf_vec_size", type=int, default=16) + parser.add_argument("--c_dtype", type=cutlass.dtype, default=cutlass.Float16) + parser.add_argument("--a_major", choices=["k", "m"], type=str, default="k") + parser.add_argument("--b_major", choices=["k", "n"], type=str, default="k") + parser.add_argument("--c_major", choices=["n", "m"], type=str, default="n") + parser.add_argument( + "--tolerance", type=float, default=1e-01, help="Tolerance for validation" + ) + parser.add_argument( + "--warmup_iterations", type=int, default=0, help="Warmup iterations" + ) + parser.add_argument( + "--iterations", + type=int, + default=1, + help="Number of iterations to run the kernel", + ) + parser.add_argument( + "--skip_ref_check", action="store_true", help="Skip reference checking" + ) + parser.add_argument( + "--use_cold_l2", + action="store_true", + default=False, + help="Use circular buffer tensor sets to ensure L2 cold cache", + ) + + args = parser.parse_args() + + if len(args.mnkl) != 4: + parser.error("--mnkl must contain exactly 4 values") + + if len(args.mma_tiler_mn) != 2: + parser.error("--mma_tiler_mn must contain exactly 2 values") + + if len(args.cluster_shape_mn) != 2: + parser.error("--cluster_shape_mn must contain exactly 2 values") + + run( + args.mnkl, + args.ab_dtype, + args.sf_dtype, + args.sf_vec_size, + args.c_dtype, + args.a_major, + args.b_major, 
+ args.c_major, + args.mma_tiler_mn, + args.cluster_shape_mn, + args.tolerance, + args.warmup_iterations, + args.iterations, + args.skip_ref_check, + args.use_cold_l2, + ) + print("PASS") diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/blackwell/dense_gemm.py b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/blackwell/dense_gemm.py new file mode 100644 index 0000000000000000000000000000000000000000..f5a8372950cb194521dd474a52a9dfeeff02fa83 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/blackwell/dense_gemm.py @@ -0,0 +1,1890 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import argparse +from typing import Optional, Type, Tuple, Union +import cuda.bindings.driver as cuda + +import torch + +import cutlass +import cutlass.cute as cute +import cutlass.utils as utils +import cutlass.pipeline as pipeline +from cutlass.cute.nvgpu import cpasync, tcgen05 +import cutlass.torch as cutlass_torch +import cutlass.utils.blackwell_helpers as sm100_utils +from cutlass.cute.runtime import from_dlpack + +""" +A high-performance batched dense GEMM (C = A * B) example for the NVIDIA Blackwell SM100 architecture +using CUTE DSL. +- Matrix A is MxKxL, L is batch dimension, A can be row-major("K") or column-major("M") +- Matrix B is NxKxL, L is batch dimension, B can be row-major("N") or column-major("K") +- Matrix C is MxNxL, L is batch dimension, C can be row-major("N") or column-major("M") + +This GEMM kernel supports the following features: + - Utilizes Tensor Memory Access (TMA) for efficient memory operations + - Utilizes Blackwell's tcgen05.mma for matrix multiply-accumulate (MMA) operations (including 2cta mma instructions) + - Implements TMA multicast with cluster to reduce L2 memory traffic + - Supports multi-stage pipeline to overlap computation and memory access + +This GEMM works as follows: +1. 
Load A and B matrices from global memory (GMEM) to shared memory (SMEM) using TMA operations. +2. Perform matrix multiply-accumulate (MMA) operations using tcgen05.mma instruction. +3. Load completed accumulator from tensor memory (TMEM) to registers (RMEM) using tcgen05.ld. +4. Type convert C matrix to output type. +5. Optionally store C matrix from registers (RMEM) to shared memory (SMEM) to global memory (GMEM) with TMA operations, + or directly store C matrix from registers (RMEM) to global memory (GMEM) without TMA operations. +6. Optionally accept an elementwise lambda function epilogue_op to apply to the output tensor: + e.g., relu can set epilogue_op = lambda x: cute.where(x > 0, x, cute.full_like(x, 0)) + +SM100 tcgen05.mma instructions operate as follows: +- Read matrix A from SMEM +- Read matrix B from SMEM +- Write accumulator to TMEM +The accumulator in TMEM must then be loaded to registers before writing back to GMEM. + +To run this example: + +.. code-block:: bash + + python examples/blackwell/dense_gemm.py \ + --ab_dtype Float16 --c_dtype Float16 --acc_dtype Float32 \ + --mma_tiler_mn 256,128 --cluster_shape_mn 2,1 \ + --mnkl 8192,8192,8192,1 \ + --use_tma_store --use_2cta_instrs + +The above example command compute batched gemm with M=8192, N=8192, K=8192, +batch_count=1. The Blackwell tcgen05 MMA tile shape used 2 cta with 256x128 +MMA tile and the cluster shape is (2,1). The input, mma accumulator and output +data type are set as fp16, fp32 and fp16, respectively. + +To collect performance with NCU profiler: + +.. code-block:: bash + + ncu python examples/blackwell/dense_gemm.py \ + --ab_dtype Float16 --c_dtype Float16 --acc_dtype Float32 \ + --mma_tiler_mn 256,128 --cluster_shape_mn 2,1 \ + --mnkl 8192,8192,8192,1 \ + --use_tma_store --use_2cta_instrs + +Constraints: +* Supported input data types: fp16, bf16, tf32, int8, uint8, fp8 (e4m3fn, e5m2), + see detailed valid dtype combinations in below DenseGemmKernel class documentation +* A/B tensor must have the same data type +* Mma tiler M must be 64/128 (use_2cta_instrs=False) or 128/256 (use_2cta_instrs=True) +* Mma tiler N must be 32-256, step 32 +* Cluster shape M/N must be positive and power of 2, total cluster size <= 16 +* Cluster shape M must be multiple of 2 if use_2cta_instrs=True +* The contiguous dimension of A/B/C tensors must be at least 16 bytes aligned, + i.e, number of elements is a multiple of 4, 8, and 16 for TFloat32, + Float16/BFloat16, and Int8/Uint8/Float8, respectively. +* OOB tiles are not allowed when TMA store is disabled +""" + + +class DenseGemmKernel: + """ + This class implements batched matrix multiplication (C = A x B) with support for various data types + and architectural features specific to Blackwell GPUs. 
+ + :param acc_dtype: Data type for accumulation during computation + :type acc_dtype: type[cutlass.Numeric] + :param use_2cta_instrs: Whether to use CTA group 2 for advanced thread cooperation + :type use_2cta_instrs: bool + :param mma_tiler_mn: Shape of the Matrix Multiply-Accumulate (MMA) tiler (M,N) + :type mma_tiler_mn: Tuple[int, int] + :param cluster_shape_mn: Cluster dimensions (M,N) for parallel processing + :type cluster_shape_mn: Tuple[int, int] + :param use_tma_store: Whether to use Tensor Memory Access (TMA) for storing results + :type use_tma_store: bool + + :note: In current version, A and B tensor must have the same data type + - i.e., Float8E4M3FN for A and Float8E5M2 for B is not supported + + :note: Supported A/B data types: + - TFloat32 + - Float16/BFloat16 + - Int8/Uint8 + - Float8E4M3FN/Float8E5M2 + + :note: Supported accumulator data types: + - Float32 (for all floating point A/B data types) + - Float16 (only for fp16 and fp8 A/B data types) + - Int32 (only for uint8/int8 A/B data types) + + :note: Supported C data types: + - Float32 (for float32 and int32 accumulator data types) + - Int32 (for float32 and int32 accumulator data types) + - Float16/BFloat16 (for fp16 and fp8 accumulator data types) + - Int8/Uint8 (for uint8/int8 accumulator data types) + - Float8E4M3FN/Float8E5M2 (for float32 accumulator data types) + + :note: Constraints: + - MMA tiler M must be 64/128 (use_2cta_instrs=False) or 128/256 (use_2cta_instrs=True) + - MMA tiler N must be 32-256, step 32 + - Cluster shape M must be multiple of 2 if use_2cta_instrs=True + - Cluster shape M/N must be positive and power of 2, total cluster size <= 16 + + Example: + >>> gemm = DenseGemmKernel( + ... acc_dtype=cutlass.Float32, + ... use_2cta_instrs=True, + ... mma_tiler_mn=(128, 128), + ... cluster_shape_mn=(2, 2) + ... ) + >>> gemm(a_tensor, b_tensor, c_tensor, stream) + """ + + def __init__( + self, + acc_dtype: Type[cutlass.Numeric], + use_2cta_instrs: bool, + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + use_tma_store: bool, + ): + """Initializes the configuration for a Blackwell dense GEMM kernel. + + This configuration includes several key aspects: + + 1. MMA Instruction Settings (tcgen05): + - acc_dtype: Data types for MMA accumulator. + - mma_tiler_mn: The (M, N) shape of the MMA instruction tiler. + - use_2cta_instrs: Boolean indicating if the tcgen05 MMA variant + with cta_group=2 should be used. + + 2. Cluster Shape: + - cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster. + + 3. Output C tensor store mode: + - use_tma_store: Boolean indicating whether to use Tensor Memory Access (TMA) for storing results. + + :param acc_dtype: Data type of the accumulator. + :type acc_dtype: type[cutlass.Numeric] + :param mma_tiler_mn: Tuple (M, N) shape of the MMA instruction. + :type mma_tiler_mn: Tuple[int, int] + :param use_2cta_instrs: Boolean, True to use cta_group=2 MMA variant. + :type use_2cta_instrs: bool + :param cluster_shape_mn: Tuple (ClusterM, ClusterN) shape of the cluster. + :type cluster_shape_mn: Tuple[int, int] + :param use_tma_store: Use Tensor Memory Access (TMA) or normal store for output C tensor. 
+ :type use_tma_store: bool + """ + + self.acc_dtype: Type[cutlass.Numeric] = acc_dtype + self.use_2cta_instrs = use_2cta_instrs + self.cluster_shape_mn = cluster_shape_mn + # K dimension is deferred in _setup_attributes + self.mma_tiler = (*mma_tiler_mn, 1) + self.use_tma_store = use_tma_store + + self.cta_group = ( + tcgen05.CtaGroup.TWO if self.use_2cta_instrs else tcgen05.CtaGroup.ONE + ) + + self.occupancy = 1 + self.threads_per_cta = 128 + self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_100") + + def _setup_attributes(self): + """Set up configurations that are dependent on GEMM inputs + + This method configures various attributes based on the input tensor properties + (data types, leading dimensions) and kernel settings: + - Configuring tiled MMA + - Computing MMA/cluster/tile shapes + - Computing cluster layout + - Computing multicast CTAs for A/B + - Computing epilogue subtile + - Setting up A/B/C stage counts in shared memory + - Computing A/B/C shared memory layout + - Computing tensor memory allocation columns + """ + # Configure tiled mma + tiled_mma = sm100_utils.make_trivial_tiled_mma( + self.a_dtype, + self.a_major_mode, + self.b_major_mode, + self.acc_dtype, + self.cta_group, + self.mma_tiler[:2], + ) + + # Compute mma/cluster/tile shapes + mma_inst_shape_k = cute.size(tiled_mma.shape_mnk, mode=[2]) + mma_inst_tile_k = 4 + self.mma_tiler = ( + self.mma_tiler[0], + self.mma_tiler[1], + mma_inst_shape_k * mma_inst_tile_k, + ) + self.cta_tile_shape_mnk = ( + self.mma_tiler[0] // cute.size(tiled_mma.thr_id.shape), + self.mma_tiler[1], + self.mma_tiler[2], + ) + + # Compute cluster layout + self.cluster_layout_vmnk = cute.tiled_divide( + cute.make_layout((*self.cluster_shape_mn, 1)), + (tiled_mma.thr_id.shape,), + ) + + # Compute number of multicast CTAs for A/B + self.num_mcast_ctas_a = cute.size(self.cluster_layout_vmnk.shape[2]) + self.num_mcast_ctas_b = cute.size(self.cluster_layout_vmnk.shape[1]) + self.is_a_mcast = self.num_mcast_ctas_a > 1 + self.is_b_mcast = self.num_mcast_ctas_b > 1 + + # Compute epilogue subtile + if cutlass.const_expr(self.use_tma_store): + self.epi_tile = sm100_utils.compute_epilogue_tile_shape( + self.cta_tile_shape_mnk, + self.use_2cta_instrs, + self.c_layout, + self.c_dtype, + ) + else: + self.epi_tile = self.cta_tile_shape_mnk[:2] + + # Setup A/B/C stage count in shared memory + self.num_acc_stage, self.num_ab_stage, self.num_c_stage = self._compute_stages( + tiled_mma, + self.mma_tiler, + self.a_dtype, + self.b_dtype, + self.epi_tile, + self.c_dtype, + self.c_layout, + self.smem_capacity, + self.occupancy, + self.use_tma_store, + ) + + # Compute A/B/C shared memory layout + self.a_smem_layout_staged = sm100_utils.make_smem_layout_a( + tiled_mma, + self.mma_tiler, + self.a_dtype, + self.num_ab_stage, + ) + self.b_smem_layout_staged = sm100_utils.make_smem_layout_b( + tiled_mma, + self.mma_tiler, + self.b_dtype, + self.num_ab_stage, + ) + self.c_smem_layout_staged = ( + sm100_utils.make_smem_layout_epi( + self.c_dtype, + self.c_layout, + self.epi_tile, + self.num_c_stage, + ) + if self.use_tma_store + else None + ) + + # Compute the number of tensor memory allocation columns + self.num_tmem_alloc_cols = self._compute_num_tmem_alloc_cols( + tiled_mma, self.mma_tiler + ) + + @cute.jit + def __call__( + self, + a: cute.Tensor, + b: cute.Tensor, + c: cute.Tensor, + stream: cuda.CUstream, + epilogue_op: cutlass.Constexpr = lambda x: x, + ): + """Execute the GEMM operation in steps: + - Setup static attributes + - Setup TMA load/store atoms 
and tensors + - Compute grid size + - Define shared storage for kernel + - Launch the kernel synchronously + + :param a: Input tensor A + :type a: cute.Tensor + :param b: Input tensor B + :type b: cute.Tensor + :param c: Output tensor C + :type c: cute.Tensor + :param stream: CUDA stream for asynchronous execution + :type stream: cuda.CUstream + :param epilogue_op: Optional elementwise lambda function to apply to the output tensor + :type epilogue_op: cutlass.Constexpr + :raises TypeError: If input data types are incompatible with the MMA instruction. + :raises AssertionError: If OOB (Out-Of-Bounds) tiles are present when TMA store is disabled. + """ + # Setup static attributes before smem/grid/tma computation + self.a_dtype: Type[cutlass.Numeric] = a.element_type + self.b_dtype: Type[cutlass.Numeric] = b.element_type + self.c_dtype: Type[cutlass.Numeric] = c.element_type + self.a_major_mode = utils.LayoutEnum.from_tensor(a).mma_major_mode() + self.b_major_mode = utils.LayoutEnum.from_tensor(b).mma_major_mode() + self.c_layout = utils.LayoutEnum.from_tensor(c) + + # Check if input data types are compatible with MMA instruction + if cutlass.const_expr(self.a_dtype != self.b_dtype): + raise TypeError(f"Type must match: {self.a_dtype} != {self.b_dtype}") + + # Setup attributes that dependent on gemm inputs + self._setup_attributes() + + tiled_mma = sm100_utils.make_trivial_tiled_mma( + self.a_dtype, + self.a_major_mode, + self.b_major_mode, + self.acc_dtype, + self.cta_group, + self.mma_tiler[:2], + ) + atom_thr_size = cute.size(tiled_mma.thr_id.shape) + + # Setup TMA load for A + a_op = sm100_utils.cluster_shape_to_tma_atom_A( + self.cluster_shape_mn, tiled_mma.thr_id + ) + a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0)) + tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A( + a_op, + a, + a_smem_layout, + self.mma_tiler, + tiled_mma, + self.cluster_layout_vmnk.shape, + internal_type=( + cutlass.TFloat32 if a.element_type is cutlass.Float32 else None + ), + ) + + # Setup TMA load for B + b_op = sm100_utils.cluster_shape_to_tma_atom_B( + self.cluster_shape_mn, tiled_mma.thr_id + ) + b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0)) + tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B( + b_op, + b, + b_smem_layout, + self.mma_tiler, + tiled_mma, + self.cluster_layout_vmnk.shape, + internal_type=( + cutlass.TFloat32 if b.element_type is cutlass.Float32 else None + ), + ) + + a_copy_size = cute.size_in_bytes(self.a_dtype, a_smem_layout) + b_copy_size = cute.size_in_bytes(self.b_dtype, b_smem_layout) + self.num_tma_load_bytes = (a_copy_size + b_copy_size) * atom_thr_size + + # Setup store for C + tma_atom_c = None + tma_tensor_c = None + if cutlass.const_expr(self.use_tma_store): + c_cta_v_layout = cute.composition( + cute.make_identity_layout(c.shape), self.epi_tile + ) + epi_smem_layout = cute.slice_(self.c_smem_layout_staged, (None, None, 0)) + tma_atom_c, tma_tensor_c = cpasync.make_tiled_tma_atom( + cpasync.CopyBulkTensorTileS2GOp(), + c, + epi_smem_layout, + c_cta_v_layout, + ) + + # Compute grid size + grid = self._compute_grid(c, self.cta_tile_shape_mnk, self.cluster_shape_mn) + + self.buffer_align_bytes = 1024 + + c_smem_size = ( + cute.cosize(self.c_smem_layout_staged.outer) if self.use_tma_store else 0 + ) + + # Define shared storage for kernel + @cute.struct + class SharedStorage: + ab_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage] + ab_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 
self.num_ab_stage] + acc_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage] + tmem_dealloc_mbar_ptr: cutlass.Int64 + tmem_holding_buf: cutlass.Int32 + # (EPI_TILE_M, EPI_TILE_N, STAGE) + sC: cute.struct.Align[ + cute.struct.MemRange[ + self.c_dtype, + c_smem_size, + ], + self.buffer_align_bytes, + ] + # (MMA, MMA_M, MMA_K, STAGE) + sA: cute.struct.Align[ + cute.struct.MemRange[ + self.a_dtype, cute.cosize(self.a_smem_layout_staged.outer) + ], + self.buffer_align_bytes, + ] + # (MMA, MMA_N, MMA_K, STAGE) + sB: cute.struct.Align[ + cute.struct.MemRange[ + self.b_dtype, cute.cosize(self.b_smem_layout_staged.outer) + ], + self.buffer_align_bytes, + ] + + self.shared_storage = SharedStorage + + # Launch the kernel synchronously + self.kernel( + tiled_mma, + tma_atom_a, + tma_tensor_a, + tma_atom_b, + tma_tensor_b, + tma_atom_c, + tma_tensor_c if self.use_tma_store else c, + self.cluster_layout_vmnk, + self.a_smem_layout_staged, + self.b_smem_layout_staged, + self.c_smem_layout_staged, + self.epi_tile, + epilogue_op, + ).launch( + grid=grid, + block=[self.threads_per_cta, 1, 1], + cluster=(*self.cluster_shape_mn, 1), + stream=stream, + ) + return + + # GPU device kernel + @cute.kernel + def kernel( + self, + tiled_mma: cute.TiledMma, + tma_atom_a: cute.CopyAtom, + mA_mkl: cute.Tensor, + tma_atom_b: cute.CopyAtom, + mB_nkl: cute.Tensor, + tma_atom_c: Optional[cute.CopyAtom], + mC_mnl: cute.Tensor, + cluster_layout_vmnk: cute.Layout, + a_smem_layout_staged: cute.ComposedLayout, + b_smem_layout_staged: cute.ComposedLayout, + c_smem_layout_staged: Union[cute.Layout, cute.ComposedLayout, None], + epi_tile: cute.Tile, + epilogue_op: cutlass.Constexpr, + ): + """ + GPU device kernel performing the batched GEMM computation. + """ + warp_idx = cute.arch.warp_idx() + warp_idx = cute.arch.make_warp_uniform(warp_idx) + + # + # Prefetch tma descriptor + # + if warp_idx == 0: + cpasync.prefetch_descriptor(tma_atom_a) + cpasync.prefetch_descriptor(tma_atom_b) + if cutlass.const_expr(self.use_tma_store): + cpasync.prefetch_descriptor(tma_atom_c) + + use_2cta_instrs = cute.size(tiled_mma.thr_id.shape) == 2 + + # + # Setup cta/thread coordinates + # + # Coords inside cluster + bidx, bidy, bidz = cute.arch.block_idx() + mma_tile_coord_v = bidx % cute.size(tiled_mma.thr_id.shape) + is_leader_cta = mma_tile_coord_v == 0 + cta_rank_in_cluster = cute.arch.make_warp_uniform( + cute.arch.block_idx_in_cluster() + ) + block_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord( + cta_rank_in_cluster + ) + # Coords outside cluster + cta_coord = (bidx, bidy, bidz) + mma_tile_coord_mnl = ( + cta_coord[0] // cute.size(tiled_mma.thr_id.shape), + cta_coord[1], + cta_coord[2], + ) + # Coords inside cta + tidx, _, _ = cute.arch.thread_idx() + + # + # Alloc and init: a+b full/empty, accumulator full, tensor memory dealloc barrier + # + smem = utils.SmemAllocator() + storage = smem.allocate(self.shared_storage) + + tmem_dealloc_mbar_ptr = storage.tmem_dealloc_mbar_ptr + tmem_holding_buf = storage.tmem_holding_buf + + # Initialize mainloop ab_pipeline (barrier) and states + ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) + num_tma_producer = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1 + ab_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, num_tma_producer + ) + ab_pipeline = pipeline.PipelineTmaUmma.create( + barrier_storage=storage.ab_full_mbar_ptr.data_ptr(), + num_stages=self.num_ab_stage, + producer_group=ab_pipeline_producer_group, + 
consumer_group=ab_pipeline_consumer_group, + tx_count=self.num_tma_load_bytes, + cta_layout_vmnk=cluster_layout_vmnk, + ) + ab_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.num_ab_stage + ) + ab_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_ab_stage + ) + + # Initialize acc_pipeline (barrier) and states + acc_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) + acc_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, self.threads_per_cta, self.threads_per_cta + ) + acc_pipeline = pipeline.PipelineUmmaAsync.create( + barrier_storage=storage.acc_full_mbar_ptr.data_ptr(), + num_stages=self.num_acc_stage, + producer_group=acc_pipeline_producer_group, + consumer_group=acc_pipeline_consumer_group, + cta_layout_vmnk=cluster_layout_vmnk, + ) + acc_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.num_acc_stage + ) + acc_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_acc_stage + ) + + # Tensor memory dealloc barrier init + if use_2cta_instrs: + if warp_idx == 0: + num_tmem_dealloc_threads = 32 + with cute.arch.elect_one(): + cute.arch.mbarrier_init( + tmem_dealloc_mbar_ptr, num_tmem_dealloc_threads + ) + cute.arch.mbarrier_init_fence() + + # Cluster arrive after barrier init + if cute.size(self.cluster_shape_mn) > 1: + cute.arch.cluster_arrive_relaxed() + + # + # Setup smem tensor A/B/C + # + # (EPI_TILE_M, EPI_TILE_N, STAGE) + sC = ( + storage.sC.get_tensor( + c_smem_layout_staged.outer, swizzle=c_smem_layout_staged.inner + ) + if self.use_tma_store + else None + ) + # (MMA, MMA_M, MMA_K, STAGE) + sA = storage.sA.get_tensor( + a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner + ) + # (MMA, MMA_N, MMA_K, STAGE) + sB = storage.sB.get_tensor( + b_smem_layout_staged.outer, swizzle=b_smem_layout_staged.inner + ) + + # + # Compute multicast mask for A/B buffer full + # + a_full_mcast_mask = None + b_full_mcast_mask = None + if cutlass.const_expr(self.is_a_mcast or self.is_b_mcast or use_2cta_instrs): + a_full_mcast_mask = cpasync.create_tma_multicast_mask( + cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2 + ) + b_full_mcast_mask = cpasync.create_tma_multicast_mask( + cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=1 + ) + + # + # Local_tile partition global tensors + # + # (bM, bK, RestM, RestK, RestL) + gA_mkl = cute.local_tile( + mA_mkl, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None) + ) + # (bN, bK, RestN, RestK, RestL) + gB_nkl = cute.local_tile( + mB_nkl, cute.slice_(self.mma_tiler, (0, None, None)), (None, None, None) + ) + # (bM, bN, RestM, RestN, RestL) + gC_mnl = cute.local_tile( + mC_mnl, cute.slice_(self.mma_tiler, (None, None, 0)), (None, None, None) + ) + k_tile_cnt = cute.size(gA_mkl, mode=[3]) + + # + # Partition global tensor for TiledMMA_A/B/C + # + thr_mma = tiled_mma.get_slice(mma_tile_coord_v) + # (MMA, MMA_M, MMA_K, RestM, RestK, RestL) + tCgA = thr_mma.partition_A(gA_mkl) + # (MMA, MMA_N, MMA_K, RestN, RestK, RestL) + tCgB = thr_mma.partition_B(gB_nkl) + # (MMA, MMA_M, MMA_N, RestM, RestN, RestL) + tCgC = thr_mma.partition_C(gC_mnl) + + # + # Partition global/shared tensor for TMA load A/B + # + # TMA load A partition_S/D + a_cta_layout = cute.make_layout( + cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape + ) + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), RestM, RestK, RestL) + tAsA, 
tAgA = cpasync.tma_partition( + tma_atom_a, + block_in_cluster_coord_vmnk[2], + a_cta_layout, + cute.group_modes(sA, 0, 3), + cute.group_modes(tCgA, 0, 3), + ) + # TMA load B partition_S/D + b_cta_layout = cute.make_layout( + cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape + ) + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), RestN, RestK, RestL) + tBsB, tBgB = cpasync.tma_partition( + tma_atom_b, + block_in_cluster_coord_vmnk[1], + b_cta_layout, + cute.group_modes(sB, 0, 3), + cute.group_modes(tCgB, 0, 3), + ) + + # + # Partition shared/tensor memory tensor for TiledMMA_A/B/C + # + # (MMA, MMA_M, MMA_K, STAGE) + tCrA = tiled_mma.make_fragment_A(sA) + # (MMA, MMA_N, MMA_K, STAGE) + tCrB = tiled_mma.make_fragment_B(sB) + # (MMA, MMA_M, MMA_N) + acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2]) + # (MMA, MMA_M, MMA_N) + tCtAcc_fake = tiled_mma.make_fragment_C(acc_shape) + + # + # Cluster wait before tensor memory alloc + # + if cute.size(self.cluster_shape_mn) > 1: + cute.arch.cluster_wait() + + # + # Alloc tensor memory buffer + # + if warp_idx == 0: + cute.arch.alloc_tmem( + self.num_tmem_alloc_cols, tmem_holding_buf, is_two_cta=use_2cta_instrs + ) + + # + # Bar sync for retrieve tensor memory ptr from shared memory + # + cute.arch.barrier() + + # + # Retrieving tensor memory ptr and make accumulator tensor + # + tmem_ptr = cute.arch.retrieve_tmem_ptr( + self.acc_dtype, alignment=16, ptr_to_buffer_holding_addr=tmem_holding_buf + ) + # (MMA, MMA_M, MMA_N) + tCtAcc = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout) + + # + # Partition for epilogue + # + tiled_copy_t2r, tTR_tAcc, tTR_rAcc = self.epilog_tmem_copy_and_partition( + tidx, tCtAcc, tCgC, epi_tile, use_2cta_instrs + ) + + tTR_rC = None + tiled_copy_r2s = None + simt_atom = None + tRS_rC = None + tRS_sC = None + bSG_sC = None + bSG_gC = None + tTR_gC = None + if cutlass.const_expr(self.use_tma_store): + tTR_rC = cute.make_fragment(tTR_rAcc.shape, self.c_dtype) + tiled_copy_r2s, tRS_rC, tRS_sC = self.epilog_smem_copy_and_partition( + tiled_copy_t2r, tTR_rC, tidx, sC + ) + tma_atom_c, bSG_sC, bSG_gC = self.epilog_gmem_copy_and_partition( + tidx, tma_atom_c, tCgC, epi_tile, sC + ) + else: + simt_atom, tTR_rC, tTR_gC = self.epilog_gmem_copy_and_partition( + tidx, tiled_copy_t2r, tCgC, epi_tile, sC + ) + + # + # Slice to per mma tile index + # + # ((atom_v, rest_v), RestK) + tAgA = tAgA[(None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2])] + # ((atom_v, rest_v), RestK) + tBgB = tBgB[(None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2])] + if cutlass.const_expr(self.use_tma_store): + # ((ATOM_V, REST_V), EPI_M, EPI_N) + bSG_gC = bSG_gC[(None, None, None, *mma_tile_coord_mnl)] + else: + # (T2R, T2R_M, T2R_N, EPI_M, EPI_N) + tTR_gC = tTR_gC[(None, None, None, None, None, *mma_tile_coord_mnl)] + + # + # Pipelining TMA load A/B and MMA mainloop + # + prefetch_k_tile_cnt = cutlass.min(self.num_ab_stage - 2, k_tile_cnt) + + if warp_idx == 0: + # Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt + peek_ab_empty_status = cutlass.Boolean(1) + if ab_producer_state.count < k_tile_cnt: + peek_ab_empty_status = ab_pipeline.producer_try_acquire( + ab_producer_state + ) + # + # Prefetch TMA load A/B + # + for prefetch_idx in cutlass.range(prefetch_k_tile_cnt, unroll=1): + # Conditionally wait for AB buffer empty + ab_pipeline.producer_acquire(ab_producer_state, peek_ab_empty_status) + + # TMA load A/B + cute.copy( + tma_atom_a, + tAgA[(None, ab_producer_state.count)], + tAsA[(None, ab_producer_state.index)], + 
tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state), + mcast_mask=a_full_mcast_mask, + ) + cute.copy( + tma_atom_b, + tBgB[(None, ab_producer_state.count)], + tBsB[(None, ab_producer_state.index)], + tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state), + mcast_mask=b_full_mcast_mask, + ) + + # Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt + k_tile + 1 + ab_producer_state.advance() + peek_ab_empty_status = cutlass.Boolean(1) + if ab_producer_state.count < k_tile_cnt: + peek_ab_empty_status = ab_pipeline.producer_try_acquire( + ab_producer_state + ) + + # Peek (try_wait) AB buffer full for k_tile = 0 + peek_ab_full_status = cutlass.Boolean(1) + if ab_consumer_state.count < k_tile_cnt and is_leader_cta: + peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_consumer_state) + + # + # MMA mainloop + # + for k_tile in range(k_tile_cnt): + # Conditionally wait for AB buffer empty + ab_pipeline.producer_acquire(ab_producer_state, peek_ab_empty_status) + + if ab_producer_state.count < k_tile_cnt: + # TMA load A/B + cute.copy( + tma_atom_a, + tAgA[(None, ab_producer_state.count)], + tAsA[(None, ab_producer_state.index)], + tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state), + mcast_mask=a_full_mcast_mask, + ) + cute.copy( + tma_atom_b, + tBgB[(None, ab_producer_state.count)], + tBsB[(None, ab_producer_state.index)], + tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state), + mcast_mask=b_full_mcast_mask, + ) + + if is_leader_cta: + # Conditionally wait for AB buffer full + ab_pipeline.consumer_wait(ab_consumer_state, peek_ab_full_status) + + # tCtAcc += tCrA * tCrB + num_kblocks = cute.size(tCrA, mode=[2]) + for kblock_idx in cutlass.range(num_kblocks, unroll_full=True): + kblock_coord = (None, None, kblock_idx, ab_consumer_state.index) + + cute.gemm( + tiled_mma, + tCtAcc, + tCrA[kblock_coord], + tCrB[kblock_coord], + tCtAcc, + ) + # Enable accumulate on tCtAcc after first kblock + tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + + # Async arrive AB buffer empty + ab_pipeline.consumer_release(ab_consumer_state) + + # Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt + k_tile + 1 + ab_producer_state.advance() + peek_ab_empty_status = cutlass.Boolean(1) + if ab_producer_state.count < k_tile_cnt: + peek_ab_empty_status = ab_pipeline.producer_try_acquire( + ab_producer_state + ) + + # Peek (try_wait) AB buffer full for k_tile = k_tile + 1 + ab_consumer_state.advance() + peek_ab_full_status = cutlass.Boolean(1) + if ab_consumer_state.count < k_tile_cnt: + if is_leader_cta: + peek_ab_full_status = ab_pipeline.consumer_try_wait( + ab_consumer_state + ) + + # Async arrive accumulator buffer full + if is_leader_cta: + acc_pipeline.producer_commit(acc_producer_state) + + # + # Epilogue + # + + # Release tensor memory allocation lock + if warp_idx == 0: + cute.arch.relinquish_tmem_alloc_permit(is_two_cta=use_2cta_instrs) + + # Wait for accumulator buffer full + acc_pipeline.consumer_wait(acc_consumer_state) + + tTR_tAcc = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc)) + if cutlass.const_expr(self.use_tma_store): + bSG_gC = cute.group_modes(bSG_gC, 1, cute.rank(bSG_gC)) + else: + tTR_gC = cute.group_modes(tTR_gC, 3, cute.rank(tTR_gC)) + + c_pipeline = None + if cutlass.const_expr(self.use_tma_store): + # Initialize tma store c_pipeline + c_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, self.threads_per_cta, self.threads_per_cta + ) + c_pipeline = pipeline.PipelineTmaStore.create( + 
num_stages=self.num_c_stage, + producer_group=c_producer_group, + ) + + # + # Store accumulator to global memory in subtiles + # + subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3]) + for subtile_idx in range(subtile_cnt): + # + # Load accumulator from tensor memory buffer to register + # + tTR_tAcc_mn = tTR_tAcc[(None, None, None, subtile_idx)] + cute.copy(tiled_copy_t2r, tTR_tAcc_mn, tTR_rAcc) + + if cutlass.const_expr(self.use_tma_store): + # + # Perform epilogue op on accumulator and convert to C type + # + acc_vec = tiled_copy_r2s.retile(tTR_rAcc).load() + acc_vec = epilogue_op(acc_vec.to(self.c_dtype)) + tRS_rC.store(acc_vec) + + # + # Store C to shared memory + # + c_buffer = subtile_idx % self.num_c_stage + cute.copy(tiled_copy_r2s, tRS_rC, tRS_sC[(None, None, None, c_buffer)]) + # Fence and barrier to make sure shared memory store is visible to TMA store + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + cute.arch.barrier() + + # + # TMA store C to global memory + # + if warp_idx == 0: + cute.copy( + tma_atom_c, + bSG_sC[(None, c_buffer)], + bSG_gC[(None, subtile_idx)], + ) + # Fence and barrier to make sure TMA store is completed to recollect C buffer + c_pipeline.producer_commit() + c_pipeline.producer_acquire() + cute.arch.barrier() + else: + # + # Perform epilogue op on accumulator and convert to C type + # + acc_vec = tTR_rAcc.load() + acc_vec = epilogue_op(acc_vec.to(self.c_dtype)) + tTR_rC.store(acc_vec) + + # + # Store C to global memory + # + cute.copy(simt_atom, tTR_rC, tTR_gC[(None, None, None, subtile_idx)]) + + # + # Dealloc the tensor memory buffer + # + cute.arch.barrier() + if warp_idx == 0: + if use_2cta_instrs: + cute.arch.mbarrier_arrive( + tmem_dealloc_mbar_ptr, cta_rank_in_cluster ^ 1 + ) + cute.arch.mbarrier_wait(tmem_dealloc_mbar_ptr, 0) + cute.arch.dealloc_tmem( + tmem_ptr, self.num_tmem_alloc_cols, is_two_cta=use_2cta_instrs + ) + + # + # Wait for C store complete + # + if cutlass.const_expr(self.use_tma_store): + c_pipeline.producer_tail() + + # + # Wait A/B buffer empty + # + if warp_idx == 0: + # Reverse prefetch_k_tile_cnt times to next available buffer + for i in range(prefetch_k_tile_cnt): + ab_producer_state.reverse() + ab_pipeline.producer_tail(ab_producer_state) + return + + def epilog_tmem_copy_and_partition( + self, + tidx: cutlass.Int32, + tAcc: cute.Tensor, + gC_mnl: cute.Tensor, + epi_tile: cute.Tile, + use_2cta_instrs: Union[cutlass.Boolean, bool], + ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]: + """ + Make tiledCopy for tensor memory load, then use it to partition tensor memory (source) and register array (destination). 
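+        The returned objects are consumed in the epilogue loop of the kernel body above;
+        roughly (a sketch that mirrors that loop, not an additional API):
+
+        .. code-block:: python
+
+            # Copy one accumulator subtile from tensor memory into registers
+            tTR_tAcc_mn = tTR_tAcc[(None, None, None, subtile_idx)]
+            cute.copy(tiled_copy_t2r, tTR_tAcc_mn, tTR_rAcc)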
+ + :param tidx: The thread index in epilogue warp groups + :type tidx: cutlass.Int32 + :param tAcc: The accumulator tensor to be copied and partitioned + :type tAcc: cute.Tensor + :param gC_mnl: The global tensor C + :type gC_mnl: cute.Tensor + :param epi_tile: The epilogue tiler + :type epi_tile: cute.Tile + :param use_2cta_instrs: Whether use_2cta_instrs is enabled + :type use_2cta_instrs: bool + + :return: A tuple containing (tiled_copy_t2r, tTR_tAcc, tTR_rAcc) where: + - tiled_copy_t2r: The tiled copy operation for tmem to register copy(t2r) + - tTR_tAcc: The partitioned accumulator tensor + - tTR_rAcc: The accumulated tensor in register used to hold t2r results + :rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor] + """ + # Make tiledCopy for tensor memory load + copy_atom_t2r = sm100_utils.get_tmem_load_op( + self.cta_tile_shape_mnk, + self.c_layout, + self.c_dtype, + self.acc_dtype, + epi_tile, + use_2cta_instrs, + ) + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N) + tAcc_epi = cute.flat_divide( + tAcc[((None, None), 0, 0)], + epi_tile, + ) + # (EPI_TILE_M, EPI_TILE_N) + tiled_copy_t2r = tcgen05.make_tmem_copy( + copy_atom_t2r, tAcc_epi[(None, None, 0, 0)] + ) + + thr_copy_t2r = tiled_copy_t2r.get_slice(tidx) + # (T2R, T2R_M, T2R_N, EPI_M, EPI_M) + tTR_tAcc = thr_copy_t2r.partition_S(tAcc_epi) + + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, RestM, RestN, RestL) + gC_mnl_epi = cute.flat_divide( + gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile + ) + # (T2R, T2R_M, T2R_N, EPI_M, EPI_N, RestM, RestN, RestL) + tTR_gC = thr_copy_t2r.partition_D(gC_mnl_epi) + # (T2R, T2R_M, T2R_N) + tTR_rAcc = cute.make_fragment( + tTR_gC[(None, None, None, 0, 0, 0, 0, 0)].shape, self.acc_dtype + ) + return tiled_copy_t2r, tTR_tAcc, tTR_rAcc + + def epilog_smem_copy_and_partition( + self, + tiled_copy_t2r: cute.TiledCopy, + tTR_rC: cute.Tensor, + tidx: cutlass.Int32, + sC: cute.Tensor, + ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]: + """ + Make tiledCopy for shared memory store, then use it to partition register array (source) and shared memory (destination). 
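+        In the TMA store path, the kernel body above uses the returned objects roughly
+        as follows (a sketch mirroring that code, not an additional API):
+
+        .. code-block:: python
+
+            # Apply the epilogue op, convert to the C dtype, and stage into shared memory
+            acc_vec = epilogue_op(tiled_copy_r2s.retile(tTR_rAcc).load().to(self.c_dtype))
+            tRS_rC.store(acc_vec)
+            cute.copy(tiled_copy_r2s, tRS_rC, tRS_sC[(None, None, None, c_buffer)])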
+ + :param tiled_copy_t2r: The tiled copy operation for tmem to register copy(t2r) + :type tiled_copy_t2r: cute.TiledCopy + :param tTR_rC: The partitioned accumulator tensor + :type tTR_rC: cute.Tensor + :param tidx: The thread index in epilogue warp groups + :type tidx: cutlass.Int32 + :param sC: The shared memory tensor to be copied and partitioned + :type sC: cute.Tensor + + :return: A tuple containing (tiled_copy_r2s, tRS_rC, tRS_sC) where: + - tiled_copy_r2s: The tiled copy operation for register to smem copy(r2s) + - tRS_rC: The partitioned tensor C (register source) + - tRS_sC: The partitioned tensor C (smem destination) + :rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor] + """ + copy_atom_r2s = sm100_utils.get_smem_store_op( + self.c_layout, self.c_dtype, self.acc_dtype, tiled_copy_t2r + ) + tiled_copy_r2s = cute.make_tiled_copy_D(copy_atom_r2s, tiled_copy_t2r) + # (R2S, R2S_M, R2S_N, PIPE_D) + thr_copy_r2s = tiled_copy_r2s.get_slice(tidx) + tRS_sC = thr_copy_r2s.partition_D(sC) + # (R2S, R2S_M, R2S_N) + tRS_rC = tiled_copy_r2s.retile(tTR_rC) + return tiled_copy_r2s, tRS_rC, tRS_sC + + def epilog_gmem_copy_and_partition( + self, + tidx: cutlass.Int32, + atom: Union[cute.CopyAtom, cute.TiledCopy], + gC_mnl: cute.Tensor, + epi_tile: cute.Tile, + sC: cute.Tensor, + ) -> Tuple[cute.CopyAtom, cute.Tensor, cute.Tensor]: + """Make tiledCopy for global memory store, then use it to: + - partition register array (source) and global memory (destination) for none TMA store version; + - partition shared memory (source) and global memory (destination) for TMA store version. + + :param tidx: The thread index in epilogue warp groups + :type tidx: cutlass.Int32 + :param atom: The copy_atom_c to be used for TMA store version, or tiled_copy_t2r for none TMA store version + :type atom: cute.CopyAtom or cute.TiledCopy + :param gC_mnl: The global tensor C + :type gC_mnl: cute.Tensor + :param epi_tile: The epilogue tiler + :type epi_tile: cute.Tile + :param sC: The shared memory tensor to be copied and partitioned + :type sC: cute.Tensor + + :return: A tuple containing either: + - For TMA store: (tma_atom_c, bSG_sC, bSG_gC) where: + - tma_atom_c: The TMA copy atom + - bSG_sC: The partitioned shared memory tensor C + - bSG_gC: The partitioned global tensor C + - For non-TMA store: (simt_atom, tTR_rC, tTR_gC) where: + - simt_atom: The SIMT copy atom + - tTR_rC: The register tensor C + - tTR_gC: The partitioned global tensor C + :rtype: Tuple[cute.CopyAtom, cute.Tensor, cute.Tensor] + """ + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, RestM, RestN, RestL) + gC_epi = cute.flat_divide( + gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile + ) + if cutlass.const_expr(self.use_tma_store): + tma_atom_c = atom + sC_for_tma_partition = cute.group_modes(sC, 0, 2) + gC_for_tma_partition = cute.group_modes(gC_epi, 0, 2) + # ((ATOM_V, REST_V), EPI_M, EPI_N) + # ((ATOM_V, REST_V), EPI_M, EPI_N, RestM, RestN, RestL) + bSG_sC, bSG_gC = cpasync.tma_partition( + tma_atom_c, + 0, + cute.make_layout(1), + sC_for_tma_partition, + gC_for_tma_partition, + ) + return tma_atom_c, bSG_sC, bSG_gC + else: + tiled_copy_t2r = atom + # (T2R, T2R_M, T2R_N, EPI_M, EPI_N, RestM, RestN, RestL) + thr_copy_t2r = tiled_copy_t2r.get_slice(tidx) + tTR_gC = thr_copy_t2r.partition_D(gC_epi) + # (T2R, T2R_M, T2R_N) + tTR_rC = cute.make_fragment( + tTR_gC[(None, None, None, 0, 0, 0, 0, 0)].shape, self.c_dtype + ) + simt_atom = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), self.c_dtype) + return simt_atom, tTR_rC, tTR_gC + + @staticmethod 
+ def _compute_stages( + tiled_mma: cute.TiledMma, + mma_tiler_mnk: Tuple[int, int, int], + a_dtype: Type[cutlass.Numeric], + b_dtype: Type[cutlass.Numeric], + epi_tile: cute.Tile, + c_dtype: Type[cutlass.Numeric], + c_layout: utils.LayoutEnum, + smem_capacity: int, + occupancy: int, + use_tma_store: bool, + ) -> Tuple[int, int, int]: + """Computes the number of stages for A/B/C operands based on heuristics. + + :param tiled_mma: The tiled MMA object defining the core computation. + :type tiled_mma: cute.TiledMma + :param mma_tiler_mnk: The shape (M, N, K) of the MMA tile. + :type mma_tiler_mnk: tuple[int, int, int] + :param a_dtype: Data type of operand A. + :type a_dtype: type[cutlass.Numeric] + :param b_dtype: Data type of operand B. + :type b_dtype: type[cutlass.Numeric] + :param epi_tile: The epilogue tile shape. + :type epi_tile: cute.Tile + :param c_dtype: Data type of operand C (output). + :type c_dtype: type[cutlass.Numeric] + :param c_layout: Layout enum of operand C in global memory. + :type c_layout: utils.LayoutEnum + :param smem_capacity: Total available shared memory capacity in bytes. + :type smem_capacity: int + :param occupancy: Target number of CTAs per SM (occupancy). + :type occupancy: int + :param use_tma_store: Whether TMA store is enabled. + :type use_tma_store: bool + + :return: A tuple containing the computed number of stages for: + (ACC stages, A/B operand stages, epilogue stages) + :rtype: tuple[int, int, int] + """ + # Default ACC stages + num_acc_stage = 1 + # Default C stages + num_c_stage = 2 if use_tma_store else 0 + + # Calculate smem layout and size for one stage of A, B, and C + a_smem_layout_stage_one = sm100_utils.make_smem_layout_a( + tiled_mma, + mma_tiler_mnk, + a_dtype, + 1, # a tmp 1 stage is provided + ) + b_smem_layout_staged_one = sm100_utils.make_smem_layout_b( + tiled_mma, + mma_tiler_mnk, + b_dtype, + 1, # a tmp 1 stage is provided + ) + c_smem_layout_staged_one = ( + sm100_utils.make_smem_layout_epi( + c_dtype, + c_layout, + epi_tile, + 1, + ) + if use_tma_store + else None + ) + ab_bytes_per_stage = cute.size_in_bytes( + a_dtype, a_smem_layout_stage_one + ) + cute.size_in_bytes(b_dtype, b_smem_layout_staged_one) + mbar_helpers_bytes = 1024 + c_bytes_per_stage = ( + cute.size_in_bytes(c_dtype, c_smem_layout_staged_one) + if use_tma_store + else 0 + ) + c_bytes = c_bytes_per_stage * num_c_stage + + # Calculate A/B stages: + # Start with total smem per CTA (capacity / occupancy) + # Subtract reserved bytes and initial C stages bytes + # Divide remaining by bytes needed per A/B stage + num_ab_stage = ( + smem_capacity - (occupancy + 1) * (mbar_helpers_bytes + c_bytes) + ) // ab_bytes_per_stage + + # Refine epilogue stages: + # Calculate remaining smem after allocating for A/B stages and reserved bytes + # Add remaining unused smem to epilogue + if use_tma_store: + num_c_stage += ( + smem_capacity + - ab_bytes_per_stage * num_ab_stage + - (occupancy + 1) * (mbar_helpers_bytes + c_bytes) + ) // ((occupancy + 1) * c_bytes_per_stage) + return num_acc_stage, num_ab_stage, num_c_stage + + @staticmethod + def _compute_grid( + c: cute.Tensor, + cta_tile_shape_mnk: Tuple[int, int, int], + cluster_shape_mn: Tuple[int, int], + ) -> Tuple[int, int, int]: + """Compute grid shape for the output tensor C. + + :param c: The output tensor C + :type c: cute.Tensor + :param cta_tile_shape_mnk: The shape (M, N, K) of the CTA tile. + :type cta_tile_shape_mnk: tuple[int, int, int] + :param cluster_shape_mn: Shape of each cluster in M, N dimensions. 
+ :type cluster_shape_mn: tuple[int, int] + + :return: Grid shape for kernel launch. + :rtype: tuple[int, int, int] + """ + + cluster_shape_mnl = (*cluster_shape_mn, 1) + + grid = cute.round_up( + ( + cute.ceil_div(c.layout.shape[0], cta_tile_shape_mnk[0]), + cute.ceil_div(c.layout.shape[1], cta_tile_shape_mnk[1]), + c.layout.shape[2], + ), + cluster_shape_mnl, + ) + + return grid + + @staticmethod + def _compute_num_tmem_alloc_cols( + tiled_mma: cute.TiledMma, mma_tiler: Tuple[int, int, int] + ) -> int: + """ + Compute the number of tensor memory allocation columns. + + :param tiled_mma: The tiled MMA object defining the core computation. + :type tiled_mma: cute.TiledMma + :param mma_tiler: The shape (M, N, K) of the MMA tile. + :type mma_tiler: tuple[int, int, int] + + :return: The number of tensor memory allocation columns. + :rtype: int + """ + acc_shape = tiled_mma.partition_shape_C(mma_tiler[:2]) + tCtAcc_fake = tiled_mma.make_fragment_C(acc_shape) + return sm100_utils.get_num_tmem_alloc_cols(tCtAcc_fake) + + @staticmethod + def is_valid_dtypes( + ab_dtype: Type[cutlass.Numeric], + acc_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + ) -> bool: + """ + Check if the dtypes are valid + + :param ab_dtype: The data type of the A and B operands + :type ab_dtype: Type[cutlass.Numeric] + :param acc_dtype: The data type of the accumulator + :type acc_dtype: Type[cutlass.Numeric] + :param c_dtype: The data type of the output tensor + :type c_dtype: Type[cutlass.Numeric] + + :return: True if the dtypes are valid, False otherwise + :rtype: bool + """ + is_valid = True + if ab_dtype not in { + cutlass.Float16, + cutlass.BFloat16, + cutlass.TFloat32, + cutlass.Uint8, + cutlass.Int8, + cutlass.Float8E4M3FN, + cutlass.Float8E5M2, + }: + is_valid = False + if ( + acc_dtype not in {cutlass.Float32, cutlass.Float16, cutlass.Int32} + or acc_dtype == cutlass.Float16 + and ab_dtype + not in {cutlass.Float16, cutlass.Float8E4M3FN, cutlass.Float8E5M2} + or acc_dtype == cutlass.Int32 + and ab_dtype not in {cutlass.Uint8, cutlass.Int8} + ): + is_valid = False + if ( + acc_dtype == cutlass.Float32 + and c_dtype + not in { + cutlass.Float32, + cutlass.Float16, + cutlass.BFloat16, + cutlass.Float8E4M3FN, + cutlass.Float8E5M2, + cutlass.Int32, + cutlass.Int8, + cutlass.Uint8, + } + or acc_dtype == cutlass.Float16 + and c_dtype + not in { + cutlass.BFloat16, + cutlass.Float16, + } + or acc_dtype == cutlass.Int32 + and c_dtype + not in { + cutlass.BFloat16, + cutlass.Float16, + cutlass.Float32, + cutlass.Int32, + cutlass.Int8, + cutlass.Uint8, + } + ): + is_valid = False + return is_valid + + @staticmethod + def is_valid_mma_tiler_and_cluster_shape( + use_2cta_instrs: bool, + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + ) -> bool: + """ + Check if the mma tiler and cluster shape are valid + + :param use_2cta_instrs: Whether to use 2 CTA groups + :type use_2cta_instrs: bool + :param mma_tiler_mn: The (M, N) shape of the MMA instruction tiler + :type mma_tiler_mn: Tuple[int, int] + :param cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster + :type cluster_shape_mn: Tuple[int, int] + + :return: True if the mma tiler and cluster shape are valid, False otherwise + :rtype: bool + """ + is_valid = True + # Skip invalid mma tile shape + if not ( + (not use_2cta_instrs and mma_tiler_mn[0] in [64, 128]) + or (use_2cta_instrs and mma_tiler_mn[0] in [128, 256]) + ): + is_valid = False + if mma_tiler_mn[1] not in range(32, 257, 32): + is_valid = False + # Skip illegal 
cluster shape + if cluster_shape_mn[0] % (2 if use_2cta_instrs else 1) != 0: + is_valid = False + # Skip invalid cluster shape + is_power_of_2 = lambda x: x > 0 and (x & (x - 1)) == 0 + if ( + cluster_shape_mn[0] * cluster_shape_mn[1] > 16 + or cluster_shape_mn[0] <= 0 + or cluster_shape_mn[1] <= 0 + or not is_power_of_2(cluster_shape_mn[0]) + or not is_power_of_2(cluster_shape_mn[1]) + ): + is_valid = False + return is_valid + + @staticmethod + def is_valid_tensor_alignment( + m: int, + n: int, + k: int, + l: int, + ab_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + a_major: str, + b_major: str, + c_major: str, + ) -> bool: + """ + Check if the tensor alignment is valid + + :param m: The number of rows in the A tensor + :type m: int + :param n: The number of columns in the B tensor + :type n: int + :param k: The number of columns in the A tensor + :type k: int + :param l: The number of columns in the C tensor + :type l: int + :param ab_dtype: The data type of the A and B operands + :type ab_dtype: Type[cutlass.Numeric] + :param c_dtype: The data type of the output tensor + :type c_dtype: Type[cutlass.Numeric] + :param a_major: The major axis of the A tensor + :type a_major: str + :param b_major: The major axis of the B tensor + :type b_major: str + :param c_major: The major axis of the C tensor + :type c_major: str + + :return: True if the problem shape is valid, False otherwise + :rtype: bool + """ + is_valid = True + + def check_contigous_16B_alignment(dtype, is_mode0_major, tensor_shape): + major_mode_idx = 0 if is_mode0_major else 1 + num_major_elements = tensor_shape[major_mode_idx] + num_contiguous_elements = 16 * 8 // dtype.width + return num_major_elements % num_contiguous_elements == 0 + + if ( + not check_contigous_16B_alignment(ab_dtype, a_major == "m", (m, k, l)) + or not check_contigous_16B_alignment(ab_dtype, b_major == "n", (n, k, l)) + or not check_contigous_16B_alignment(c_dtype, c_major == "m", (m, n, l)) + ): + is_valid = False + return is_valid + + @staticmethod + def is_valid_epilog_store_option( + use_2cta_instrs: bool, + use_tma_store: bool, + m: int, + n: int, + mma_tiler_mn: Tuple[int, int], + ) -> bool: + """ + Check if the epilogue store option is valid + + :param use_2cta_instrs: Whether to use 2 CTA groups + :type use_2cta_instrs: bool + :param use_tma_store: Whether to use TMA store + :type use_tma_store: bool + :param m: The number of rows in the A tensor + :type m: int + :param n: The number of columns in the B tensor + :type n: int + :param mma_tiler_mn: The (M, N) shape of the MMA instruction tiler + :type mma_tiler_mn: Tuple[int, int] + + :return: True if the epilogue store option is valid, False otherwise + :rtype: bool + """ + + is_valid = True + # None TMA store version does not have predication, can not support OOB tiles + cta_tile_shape_mn = ( + mma_tiler_mn[0] // (2 if use_2cta_instrs else 1), + mma_tiler_mn[1], + ) + if not use_tma_store: + if not (m % cta_tile_shape_mn[0] == 0 and n % cta_tile_shape_mn[1] == 0): + is_valid = False + return is_valid + + @staticmethod + def can_implement( + ab_dtype: Type[cutlass.Numeric], + acc_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + use_2cta_instrs: bool, + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + use_tma_store: bool, + m: int, + n: int, + k: int, + l: int, + a_major: str, + b_major: str, + c_major: str, + ) -> bool: + """ + Check if the gemm can be implemented + + :param ab_dtype: The data type of the A and B operands + :type ab_dtype: 
Type[cutlass.Numeric] + :param acc_dtype: The data type of the accumulator + :type acc_dtype: Type[cutlass.Numeric] + :param c_dtype: The data type of the output tensor + :type c_dtype: Type[cutlass.Numeric] + :param use_2cta_instrs: Whether to use 2 CTA groups + :type use_2cta_instrs: bool + :param mma_tiler_mn: The (M, N) shape of the MMA instruction tiler + :type mma_tiler_mn: Tuple[int, int] + :param cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster + :type cluster_shape_mn: Tuple[int, int] + :param use_tma_store: Whether to use TMA store + :type use_tma_store: bool + :param m: The number of rows in the A tensor + :type m: int + :param n: The number of columns in the B tensor + :type n: int + :param k: The number of columns in the A tensor + :type k: int + :param l: The number of columns in the C tensor + :type l: int + :param a_major: The major axis of the A tensor + :type a_major: str + :param b_major: The major axis of the B tensor + :type b_major: str + :param c_major: The major axis of the C tensor + :type c_major: str + + :return: True if the gemm can be implemented, False otherwise + :rtype: bool + """ + can_implement = True + # Skip unsupported types + if not DenseGemmKernel.is_valid_dtypes(ab_dtype, acc_dtype, c_dtype): + can_implement = False + # Skip invalid mma tile shape and cluster shape + if not DenseGemmKernel.is_valid_mma_tiler_and_cluster_shape( + use_2cta_instrs, mma_tiler_mn, cluster_shape_mn + ): + can_implement = False + # Skip illegal problem shape for load/store alignment + if not DenseGemmKernel.is_valid_tensor_alignment( + m, n, k, l, ab_dtype, c_dtype, a_major, b_major, c_major + ): + can_implement = False + # Skip invalid epilogue store option + if not DenseGemmKernel.is_valid_epilog_store_option( + use_2cta_instrs, use_tma_store, m, n, mma_tiler_mn + ): + can_implement = False + return can_implement + + +def run_dense_gemm( + mnkl: Tuple[int, int, int, int], + ab_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + acc_dtype: Type[cutlass.Numeric], + a_major: str, + b_major: str, + c_major: str, + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + use_2cta_instrs: bool, + use_tma_store: bool, + tolerance: float, + warmup_iterations: int = 0, + iterations: int = 1, + skip_ref_check: bool = False, + measure_launch_overhead=False, +): + """ + Prepare A/B/C tensors, launch GPU kernel, and reference checking. 
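+    Example (mirrors the CLI defaults parsed in the __main__ block below; illustrative only):
+
+    .. code-block:: python
+
+        run_dense_gemm(
+            mnkl=(256, 256, 512, 1),
+            ab_dtype=cutlass.TFloat32,
+            c_dtype=cutlass.Float32,
+            acc_dtype=cutlass.Float32,
+            a_major="k",
+            b_major="k",
+            c_major="n",
+            mma_tiler_mn=(128, 128),
+            cluster_shape_mn=(1, 1),
+            use_2cta_instrs=False,
+            use_tma_store=False,
+            tolerance=1e-01,
+        )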
+ """ + print(f"Running B100 Dense GEMM test with:") + print(f"mnkl: {mnkl}") + print(f"AB dtype: {ab_dtype}, C dtype: {c_dtype}, Acc dtype: {acc_dtype}") + print(f"Matrix majors - A: {a_major}, B: {b_major}, C: {c_major}") + print(f"Mma Tiler (M, N): {mma_tiler_mn}, Cluster Shape (M, N): {cluster_shape_mn}") + print(f"2CTA MMA instructions: {'True' if use_2cta_instrs else 'False'}") + print(f"Use TMA Store: {'True' if use_tma_store else 'False'}") + print(f"Tolerance: {tolerance}") + print(f"Warmup iterations: {warmup_iterations}") + print(f"Iterations: {iterations}") + print(f"Skip reference checking: {skip_ref_check}") + + # Unpack parameters + m, n, k, l = mnkl + + # Skip unsupported testcase + if not DenseGemmKernel.can_implement( + ab_dtype, + acc_dtype, + c_dtype, + use_2cta_instrs, + mma_tiler_mn, + cluster_shape_mn, + use_tma_store, + m, + n, + k, + l, + a_major, + b_major, + c_major, + ): + raise TypeError( + f"Unsupported testcase {ab_dtype}, {acc_dtype}, {c_dtype}, {use_2cta_instrs}, {mma_tiler_mn}, {cluster_shape_mn}, {use_tma_store}, {m}, {n}, {k}, {l}, {a_major}, {b_major}, {c_major}" + ) + + if not torch.cuda.is_available(): + raise RuntimeError("GPU is required to run this example!") + + torch.manual_seed(1111) + + # Create and permute tensor A/B/C + def create_and_permute_tensor( + l, mode0, mode1, is_mode0_major, dtype, is_dynamic_layout=True + ): + # is_mode0_major: (l, mode1, mode0) -> (mode0, mode1, l) + # else: (l, mode0, mode1) -> (mode0, mode1, l) + shape = (l, mode1, mode0) if is_mode0_major else (l, mode0, mode1) + permute_order = (2, 1, 0) if is_mode0_major else (1, 2, 0) + is_unsigned = dtype in {cutlass.Uint8} + # Temporarily use uint8 as torch does not support fp8 type + torch_dtype = ( + cutlass_torch.dtype(dtype) + if dtype not in {cutlass.Float8E5M2, cutlass.Float8E4M3FN} + else torch.uint8 + ) + + # Create dtype torch tensor (cpu) + torch_tensor_cpu = cutlass_torch.create_and_permute_torch_tensor( + shape, + torch_dtype, + permute_order=permute_order, + init_type=cutlass_torch.TensorInitType.RANDOM, + init_config=cutlass_torch.RandomInitConfig( + min_val=0 if is_unsigned else -2, max_val=4 if is_unsigned else 2 + ), + ) + # Create dtype torch tensor (gpu) + torch_tensor = torch_tensor_cpu.cuda() + + # Create f32 torch tensor (cpu) + f32_torch_tensor = torch_tensor_cpu.to(dtype=torch.float32) + + # Create dtype cute tensor (gpu) + cute_tensor = from_dlpack(torch_tensor, assumed_align=16) + cute_tensor.element_type = dtype + if is_dynamic_layout: + cute_tensor = cute_tensor.mark_layout_dynamic( + leading_dim=(0 if is_mode0_major else 1) + ) + cute_tensor = cutlass_torch.convert_cute_tensor( + f32_torch_tensor, + cute_tensor, + dtype, + is_dynamic_layout=is_dynamic_layout, + ) + + return f32_torch_tensor, cute_tensor, torch_tensor + + a_ref, a_tensor, a_torch = create_and_permute_tensor( + l, m, k, a_major == "m", ab_dtype, is_dynamic_layout=True + ) + b_ref, b_tensor, b_torch = create_and_permute_tensor( + l, n, k, b_major == "n", ab_dtype, is_dynamic_layout=True + ) + c_ref, c_tensor, c_torch = create_and_permute_tensor( + l, m, n, c_major == "m", c_dtype, is_dynamic_layout=True + ) + + # Configure gemm kernel + gemm = DenseGemmKernel( + acc_dtype, + use_2cta_instrs, + mma_tiler_mn, + cluster_shape_mn, + use_tma_store, + ) + + torch_stream = torch.cuda.Stream() + stream = cuda.CUstream(torch_stream.cuda_stream) + # Compile gemm kernel + compiled_gemm = cute.compile(gemm, a_tensor, b_tensor, c_tensor, stream) + + # Launch GPU kernel + # Warm up + for i in 
range(warmup_iterations): + compiled_gemm(a_tensor, b_tensor, c_tensor, stream) + # Execution + for i in range(iterations): + compiled_gemm(a_tensor, b_tensor, c_tensor, stream) + + # Compute reference result + if not skip_ref_check: + if ab_dtype in { + cutlass.Int8, + cutlass.Uint8, + cutlass.Float8E4M3FN, + cutlass.Float8E5M2, + }: + ref = torch.einsum("mkl,nkl->mnl", a_ref.cpu(), b_ref.cpu()) + else: + ref = (torch.einsum("mkl,nkl->mnl", a_ref, b_ref)).cpu() + + # Copy gpu result back + gpu_c = c_torch.cpu() + + # Convert ref to c_type + if c_dtype == cutlass.Float32: + ref_c = ref + elif c_dtype in {cutlass.Float8E5M2, cutlass.Float8E4M3FN}: + # m major: (l, n, m) -> (m, n, l) + # n major: (l, m, n) -> (m, n, l) + permute_order = (1, 2, 0) if c_major == "n" else (2, 1, 0) + shape = (l, m, n) if c_major == "n" else (l, n, m) + f8_torch_tensor = cutlass_torch.create_and_permute_torch_tensor( + shape, + torch.uint8, + permute_order=permute_order, + init_type=cutlass_torch.TensorInitType.SKIP, + ).cuda() + # Create dtype cute tensor (gpu) + ref_c_tensor = from_dlpack( + f8_torch_tensor, assumed_align=16 + ).mark_layout_dynamic(leading_dim=(1 if c_major == "n" else 0)) + ref_c_tensor.element_type = c_dtype + ref_c_tensor = cutlass_torch.convert_cute_tensor( + ref, + ref_c_tensor, + c_dtype, + is_dynamic_layout=True, + ) + + ref_c = f8_torch_tensor.cpu() + else: + ref_c = ref.to(cutlass_torch.dtype(c_dtype)) + + # Reference checking ref_c and gpu_c + torch.testing.assert_close( + gpu_c, + ref_c, + atol=tolerance, + rtol=1e-05, + ) + + +if __name__ == "__main__": + + def parse_comma_separated_ints(s: str) -> Tuple[int, ...]: + try: + return tuple(int(x.strip()) for x in s.split(",")) + # or: return tuple([int(x.strip()) for x in s.split(",")]) + except ValueError: + raise argparse.ArgumentTypeError( + "Invalid format. Expected comma-separated integers." + ) + + parser = argparse.ArgumentParser( + description="Example of MxNxKxL GEMM on Blackwell." 
+ ) + + parser.add_argument( + "--mnkl", + type=parse_comma_separated_ints, + default=(256, 256, 512, 1), + help="mnkl dimensions (comma-separated)", + ) + parser.add_argument( + "--mma_tiler_mn", + type=parse_comma_separated_ints, + default=(128, 128), + help="Mma tiler (comma-separated)", + ) + parser.add_argument( + "--cluster_shape_mn", + type=parse_comma_separated_ints, + default=(1, 1), + help="Cluster shape (comma-separated)", + ) + parser.add_argument("--ab_dtype", type=cutlass.dtype, default=cutlass.TFloat32) + parser.add_argument("--c_dtype", type=cutlass.dtype, default=cutlass.Float32) + parser.add_argument("--acc_dtype", type=cutlass.dtype, default=cutlass.Float32) + parser.add_argument( + "--use_2cta_instrs", + action="store_true", + help="Enable 2CTA MMA instructions feature", + ) + parser.add_argument("--a_major", choices=["k", "m"], type=str, default="k") + parser.add_argument("--b_major", choices=["k", "n"], type=str, default="k") + parser.add_argument("--c_major", choices=["n", "m"], type=str, default="n") + parser.add_argument( + "--use_tma_store", action="store_true", help="Use tma store or not" + ) + parser.add_argument( + "--tolerance", type=float, default=1e-01, help="Tolerance for validation" + ) + parser.add_argument( + "--warmup_iterations", type=int, default=0, help="Warmup iterations" + ) + parser.add_argument("--iterations", type=int, default=1, help="Iterations") + parser.add_argument( + "--skip_ref_check", action="store_true", help="Skip reference checking" + ) + + args = parser.parse_args() + + if len(args.mnkl) != 4: + parser.error("--mnkl must contain exactly 4 values") + + if len(args.mma_tiler_mn) != 2: + parser.error("--mma_tiler_mn must contain exactly 2 values") + + if len(args.cluster_shape_mn) != 2: + parser.error("--cluster_shape_mn must contain exactly 2 values") + + run_dense_gemm( + args.mnkl, + args.ab_dtype, + args.c_dtype, + args.acc_dtype, + args.a_major, + args.b_major, + args.c_major, + args.mma_tiler_mn, + args.cluster_shape_mn, + args.use_2cta_instrs, + args.use_tma_store, + args.tolerance, + args.warmup_iterations, + args.iterations, + args.skip_ref_check, + ) + print("PASS") diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/blackwell/dense_gemm_persistent.py b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/blackwell/dense_gemm_persistent.py new file mode 100644 index 0000000000000000000000000000000000000000..f5022a82966bc93da97bbde24ab29cb598d5a914 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/blackwell/dense_gemm_persistent.py @@ -0,0 +1,2193 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. 
+
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import argparse
+from typing import Optional, Type, Tuple, Union
+
+import cuda.bindings.driver as cuda
+import torch
+
+import cutlass
+import cutlass.cute as cute
+from cutlass.cute.nvgpu import cpasync, tcgen05
+import cutlass.torch as cutlass_torch
+import cutlass.utils as utils
+import cutlass.pipeline as pipeline
+import cutlass.cute.testing as testing
+import cutlass.utils.blackwell_helpers as sm100_utils
+from cutlass.cute.runtime import from_dlpack
+
+
+"""
+A high-performance persistent batched dense GEMM example for the NVIDIA Blackwell SM100 architecture
+using CUTE DSL.
+- Matrix A is MxKxL, L is batch dimension, A can be row-major("K") or column-major("M")
+- Matrix B is NxKxL, L is batch dimension, B can be row-major("N") or column-major("K")
+- Matrix C is MxNxL, L is batch dimension, C can be row-major("N") or column-major("M")
+
+This GEMM kernel supports the following features:
+    - Utilizes Tensor Memory Access (TMA) for efficient memory operations
+    - Utilizes Blackwell's tcgen05.mma for matrix multiply-accumulate (MMA) operations (including 2cta mma instructions)
+    - Implements TMA multicast with cluster to reduce L2 memory traffic
+    - Supports persistent tile scheduling to better overlap memory load/store with MMA across tiles
+    - Supports warp specialization to avoid explicit pipelining between mainloop load and MMA
+
+This GEMM works as follows:
+1. DMA warp: Load A and B matrices from global memory (GMEM) to shared memory (SMEM) using TMA operations.
+2. MMA warp: Perform matrix multiply-accumulate (MMA) operations using the tcgen05.mma instruction.
+3. EPILOGUE warps:
+    - Load the completed accumulator from tensor memory (TMEM) to registers (RMEM) using tcgen05.ld.
+    - Type convert the C matrix to the output type.
+    - Optionally store the C matrix from registers (RMEM) to shared memory (SMEM) to global memory (GMEM) with TMA operations,
+      or directly store the C matrix from registers (RMEM) to global memory (GMEM) without TMA operations.
+    - Optionally accept an elementwise lambda function epilogue_op to apply to the output tensor:
+      e.g., relu can set epilogue_op = lambda x: cute.where(x > 0, x, cute.full_like(x, 0))
+
+SM100 tcgen05.mma instructions operate as follows:
+- Read matrix A from SMEM
+- Read matrix B from SMEM
+- Write accumulator to TMEM
+The accumulator in TMEM must then be loaded to registers before writing back to GMEM.
+
+The input arguments to this example are the same as for dense_gemm.py:
+
+.. code-block:: bash
+
+    python examples/blackwell/dense_gemm_persistent.py \
+      --ab_dtype Float16 --c_dtype Float16 --acc_dtype Float32 \
+      --mma_tiler_mn 256,128 --cluster_shape_mn 2,1 \
+      --mnkl 8192,8192,8192,1 \
+      --use_tma_store --use_2cta_instrs
+
+To collect performance with the NCU profiler:
+
+..
code-block:: bash + + ncu python examples/blackwell/dense_gemm_persistent.py \ + --ab_dtype Float16 --c_dtype Float16 --acc_dtype Float32 \ + --mma_tiler_mn 256,128 --cluster_shape_mn 2,1 \ + --mnkl 8192,8192,8192,1 \ + --use_tma_store --use_2cta_instrs \ + --warmup_iterations 1 --iterations 10 --skip_ref_check + + +Constraints are same as dense_gemm.py: +* Supported input data types: fp16, bf16, tf32, int8, uint8, fp8 (e4m3fn, e5m2), + see detailed valid dtype combinations in below PersistentDenseGemmKernel class documentation +* A/B tensor must have the same data type +* Mma tiler M must be 64/128 (use_2cta_instrs=False) or 128/256 (use_2cta_instrs=True) +* Mma tiler N must be 32-256, step 32 +* Cluster shape M/N must be positive and power of 2, total cluster size <= 16 +* Cluster shape M must be multiple of 2 if use_2cta_instrs=True +* The contiguous dimension of A/B/C tensors must be at least 16 bytes aligned, + i.e, number of elements is a multiple of 4, 8, and 16 for TFloat32, + Float16/BFloat16, and Int8/Uint8/Float8, respectively. +* OOB tiles are not allowed when TMA store is disabled +""" + + +class PersistentDenseGemmKernel: + """This class implements batched matrix multiplication (C = A x B) with support for various data types + and architectural features specific to Blackwell GPUs with persistent tile scheduling and warp specialization. + + :param acc_dtype: Data type for accumulation during computation + :type acc_dtype: type[cutlass.Numeric] + :param use_2cta_instrs: Whether to use CTA group 2 for advanced thread cooperation + :type use_2cta_instrs: bool + :param mma_tiler_mn: Shape of the Matrix Multiply-Accumulate (MMA) tile (M,N) + :type mma_tiler_mn: Tuple[int, int] + :param cluster_shape_mn: Cluster dimensions (M,N) for parallel processing + :type cluster_shape_mn: Tuple[int, int] + :param use_tma_store: Whether to use Tensor Memory Access (TMA) for storing results + :type use_tma_store: bool + + :note: In current version, A and B tensor must have the same data type + - i.e., Float8E4M3FN for A and Float8E5M2 for B is not supported + + :note: Supported A/B data types: + - TFloat32 + - Float16/BFloat16 + - Int8/Uint8 + - Float8E4M3FN/Float8E5M2 + + :note: Supported accumulator data types: + - Float32 (for all floating point A/B data types) + - Float16 (only for fp16 and fp8 A/B data types) + - Int32 (only for uint8/int8 A/B data types) + + :note: Supported C data types: + - Float32 (for float32 and int32 accumulator data types) + - Int32 (for float32 and int32 accumulator data types) + - Float16/BFloat16 (for fp16 and fp8 accumulator data types) + - Int8/Uint8 (for uint8/int8 accumulator data types) + - Float8E4M3FN/Float8E5M2 (for float32 accumulator data types) + + :note: Constraints: + - MMA tiler M must be 64/128 (use_2cta_instrs=False) or 128/256 (use_2cta_instrs=True) + - MMA tiler N must be 32-256, step 32 + - Cluster shape M must be multiple of 2 if use_2cta_instrs=True + - Cluster shape M/N must be positive and power of 2, total cluster size <= 16 + + Example: + >>> gemm = PersistentDenseGemmKernel( + ... acc_dtype=cutlass.Float32, + ... use_2cta_instrs=True, + ... mma_tiler_mn=(128, 128), + ... cluster_shape_mn=(2, 2) + ... ) + >>> gemm(a_tensor, b_tensor, c_tensor, max_active_clusters, stream) + """ + + def __init__( + self, + acc_dtype: Type[cutlass.Numeric], + use_2cta_instrs: bool, + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + use_tma_store: bool, + ): + """Initializes the configuration for a Blackwell dense GEMM kernel. 
+ + This configuration includes several key aspects: + + 1. MMA Instruction Settings (tcgen05): + - acc_dtype: Data types for MMA accumulator. + - mma_tiler_mn: The (M, N) shape of the MMA instruction tiler. + - use_2cta_instrs: Boolean indicating if the tcgen05 MMA variant + with cta_group=2 should be used. + + 2. Cluster Shape: + - cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster. + + 3. Output C tensor store mode: + - use_tma_store: Boolean indicating whether to use Tensor Memory Access (TMA) for storing results. + + :param acc_dtype: Data type of the accumulator. + :type acc_dtype: type[cutlass.Numeric] + :param mma_tiler_mn: Tuple (M, N) shape of the MMA instruction. + :type mma_tiler_mn: Tuple[int, int] + :param use_2cta_instrs: Boolean, True to use cta_group=2 MMA variant. + :type use_2cta_instrs: bool + :param cluster_shape_mn: Tuple (ClusterM, ClusterN) shape of the cluster. + :type cluster_shape_mn: Tuple[int, int] + :param use_tma_store: Use Tensor Memory Access (TMA) or normal store for output C tensor. + :type use_tma_store: bool + """ + + self.acc_dtype: Type[cutlass.Numeric] = acc_dtype + self.use_2cta_instrs = use_2cta_instrs + self.cluster_shape_mn = cluster_shape_mn + # K dimension is deferred in _setup_attributes + self.mma_tiler = (*mma_tiler_mn, 1) + self.use_tma_store = use_tma_store + + self.cta_group = ( + tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE + ) + + self.occupancy = 1 + # Set specialized warp ids + self.epilog_warp_id = ( + 0, + 1, + 2, + 3, + ) + self.mma_warp_id = 4 + self.tma_warp_id = 5 + self.threads_per_cta = 32 * len( + (self.mma_warp_id, self.tma_warp_id, *self.epilog_warp_id) + ) + # Set barrier id for cta sync, epilogue sync and tmem ptr sync + self.cta_sync_bar_id = 0 + self.epilog_sync_bar_id = 1 + self.tmem_ptr_sync_bar_id = 2 + self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_100") + + def _setup_attributes(self): + """Set up configurations that are dependent on GEMM inputs + + This method configures various attributes based on the input tensor properties + (data types, leading dimensions) and kernel settings: + - Configuring tiled MMA + - Computing MMA/cluster/tile shapes + - Computing cluster layout + - Computing multicast CTAs for A/B + - Computing epilogue subtile + - Setting up A/B/C stage counts in shared memory + - Computing A/B/C shared memory layout + - Computing tensor memory allocation columns + """ + # Configure tiled mma + tiled_mma = sm100_utils.make_trivial_tiled_mma( + self.a_dtype, + self.a_major_mode, + self.b_major_mode, + self.acc_dtype, + self.cta_group, + self.mma_tiler[:2], + ) + + # Compute mma/cluster/tile shapes + mma_inst_shape_k = cute.size(tiled_mma.shape_mnk, mode=[2]) + mma_inst_tile_k = 4 + self.mma_tiler = ( + self.mma_tiler[0], + self.mma_tiler[1], + mma_inst_shape_k * mma_inst_tile_k, + ) + self.cta_tile_shape_mnk = ( + self.mma_tiler[0] // cute.size(tiled_mma.thr_id.shape), + self.mma_tiler[1], + self.mma_tiler[2], + ) + + # Compute cluster layout + self.cluster_layout_vmnk = cute.tiled_divide( + cute.make_layout((*self.cluster_shape_mn, 1)), + (tiled_mma.thr_id.shape,), + ) + + # Compute number of multicast CTAs for A/B + self.num_mcast_ctas_a = cute.size(self.cluster_layout_vmnk.shape[2]) + self.num_mcast_ctas_b = cute.size(self.cluster_layout_vmnk.shape[1]) + self.is_a_mcast = self.num_mcast_ctas_a > 1 + self.is_b_mcast = self.num_mcast_ctas_b > 1 + + # Compute epilogue subtile + if cutlass.const_expr(self.use_tma_store): + self.epi_tile = 
sm100_utils.compute_epilogue_tile_shape( + self.cta_tile_shape_mnk, + self.use_2cta_instrs, + self.c_layout, + self.c_dtype, + ) + else: + self.epi_tile = self.cta_tile_shape_mnk[:2] + + # Setup A/B/C stage count in shared memory and ACC stage count in tensor memory + self.num_acc_stage, self.num_ab_stage, self.num_c_stage = self._compute_stages( + tiled_mma, + self.mma_tiler, + self.a_dtype, + self.b_dtype, + self.epi_tile, + self.c_dtype, + self.c_layout, + self.smem_capacity, + self.occupancy, + self.use_tma_store, + ) + + # Compute A/B/C shared memory layout + self.a_smem_layout_staged = sm100_utils.make_smem_layout_a( + tiled_mma, + self.mma_tiler, + self.a_dtype, + self.num_ab_stage, + ) + self.b_smem_layout_staged = sm100_utils.make_smem_layout_b( + tiled_mma, + self.mma_tiler, + self.b_dtype, + self.num_ab_stage, + ) + self.c_smem_layout_staged = ( + sm100_utils.make_smem_layout_epi( + self.c_dtype, + self.c_layout, + self.epi_tile, + self.num_c_stage, + ) + if cutlass.const_expr(self.use_tma_store) + else None + ) + + # Compute the number of tensor memory allocation columns + self.num_tmem_alloc_cols = self._compute_num_tmem_alloc_cols( + tiled_mma, self.mma_tiler, self.num_acc_stage + ) + + @cute.jit + def __call__( + self, + a: cute.Tensor, + b: cute.Tensor, + c: cute.Tensor, + max_active_clusters: cutlass.Constexpr, + stream: cuda.CUstream, + epilogue_op: cutlass.Constexpr = lambda x: x, + ): + """Execute the GEMM operation in steps: + - Setup static attributes before smem/grid/tma computation + - Setup TMA load/store atoms and tensors + - Compute grid size with regard to hardware constraints + - Define shared storage for kernel + - Launch the kernel synchronously + + :param a: Input tensor A + :type a: cute.Tensor + :param b: Input tensor B + :type b: cute.Tensor + :param c: Output tensor C + :type c: cute.Tensor + :param max_active_clusters: Maximum number of active clusters + :type max_active_clusters: cutlass.Constexpr + :param stream: CUDA stream for asynchronous execution + :type stream: cuda.CUstream + :param epilogue_op: Optional elementwise lambda function to apply to the output tensor + :type epilogue_op: cutlass.Constexpr + :raises TypeError: If input data types are incompatible with the MMA instruction. + :raises AssertionError: If OOB (Out-Of-Bounds) tiles are present when TMA store is disabled. 
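+
+        Example:
+            A minimal sketch of a call with a ReLU epilogue; ``a``, ``b``, ``c``
+            are assumed to be CuTe tensors, and ``max_active_clusters`` and
+            ``stream`` are assumed to be prepared as in the ``run`` helper below:
+
+            >>> relu = lambda x: cute.where(x > 0, x, cute.full_like(x, 0))
+            >>> gemm = PersistentDenseGemmKernel(
+            ...     cutlass.Float32, False, (128, 128), (1, 1), use_tma_store=True
+            ... )
+            >>> gemm(a, b, c, max_active_clusters, stream, epilogue_op=relu)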
+ """ + # Setup static attributes before smem/grid/tma computation + self.a_dtype: Type[cutlass.Numeric] = a.element_type + self.b_dtype: Type[cutlass.Numeric] = b.element_type + self.c_dtype: Type[cutlass.Numeric] = c.element_type + self.a_major_mode = utils.LayoutEnum.from_tensor(a).mma_major_mode() + self.b_major_mode = utils.LayoutEnum.from_tensor(b).mma_major_mode() + self.c_layout = utils.LayoutEnum.from_tensor(c) + + # Check if input data types are compatible with MMA instruction + if cutlass.const_expr(self.a_dtype != self.b_dtype): + raise TypeError(f"Type must match: {self.a_dtype} != {self.b_dtype}") + + # Setup attributes that dependent on gemm inputs + self._setup_attributes() + + tiled_mma = sm100_utils.make_trivial_tiled_mma( + self.a_dtype, + self.a_major_mode, + self.b_major_mode, + self.acc_dtype, + self.cta_group, + self.mma_tiler[:2], + ) + atom_thr_size = cute.size(tiled_mma.thr_id.shape) + + # Setup TMA load for A + a_op = sm100_utils.cluster_shape_to_tma_atom_A( + self.cluster_shape_mn, tiled_mma.thr_id + ) + a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0)) + tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A( + a_op, + a, + a_smem_layout, + self.mma_tiler, + tiled_mma, + self.cluster_layout_vmnk.shape, + internal_type=( + cutlass.TFloat32 if a.element_type is cutlass.Float32 else None + ), + ) + + # Setup TMA load for B + b_op = sm100_utils.cluster_shape_to_tma_atom_B( + self.cluster_shape_mn, tiled_mma.thr_id + ) + b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0)) + tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B( + b_op, + b, + b_smem_layout, + self.mma_tiler, + tiled_mma, + self.cluster_layout_vmnk.shape, + internal_type=( + cutlass.TFloat32 if b.element_type is cutlass.Float32 else None + ), + ) + + a_copy_size = cute.size_in_bytes(self.a_dtype, a_smem_layout) + b_copy_size = cute.size_in_bytes(self.b_dtype, b_smem_layout) + self.num_tma_load_bytes = (a_copy_size + b_copy_size) * atom_thr_size + + # Setup TMA store for C + tma_atom_c = None + tma_tensor_c = None + if cutlass.const_expr(self.use_tma_store): + c_cta_v_layout = cute.composition( + cute.make_identity_layout(c.shape), self.epi_tile + ) + epi_smem_layout = cute.slice_(self.c_smem_layout_staged, (None, None, 0)) + tma_atom_c, tma_tensor_c = cpasync.make_tiled_tma_atom( + cpasync.CopyBulkTensorTileS2GOp(), + c, + epi_smem_layout, + c_cta_v_layout, + ) + + # Compute grid size + self.tile_sched_params, grid = self._compute_grid( + c, self.cta_tile_shape_mnk, self.cluster_shape_mn, max_active_clusters + ) + + self.buffer_align_bytes = 1024 + + c_smem_size = ( + cute.cosize(self.c_smem_layout_staged.outer) + if cutlass.const_expr(self.use_tma_store) + else 0 + ) + + # Define shared storage for kernel + @cute.struct + class SharedStorage: + ab_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage] + ab_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage] + acc_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage] + acc_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage] + tmem_dealloc_mbar_ptr: cutlass.Int64 + tmem_holding_buf: cutlass.Int32 + # (EPI_TILE_M, EPI_TILE_N, STAGE) + sC: cute.struct.Align[ + cute.struct.MemRange[ + self.c_dtype, + c_smem_size, + ], + self.buffer_align_bytes, + ] + # (MMA, MMA_M, MMA_K, STAGE) + sA: cute.struct.Align[ + cute.struct.MemRange[ + self.a_dtype, cute.cosize(self.a_smem_layout_staged.outer) + ], + self.buffer_align_bytes, + 
] + # (MMA, MMA_N, MMA_K, STAGE) + sB: cute.struct.Align[ + cute.struct.MemRange[ + self.b_dtype, cute.cosize(self.b_smem_layout_staged.outer) + ], + self.buffer_align_bytes, + ] + + self.shared_storage = SharedStorage + + # Launch the kernel synchronously + self.kernel( + tiled_mma, + tma_atom_a, + tma_tensor_a, + tma_atom_b, + tma_tensor_b, + tma_atom_c, + tma_tensor_c if cutlass.const_expr(self.use_tma_store) else c, + self.cluster_layout_vmnk, + self.a_smem_layout_staged, + self.b_smem_layout_staged, + self.c_smem_layout_staged, + self.epi_tile, + self.tile_sched_params, + epilogue_op, + ).launch( + grid=grid, + block=[self.threads_per_cta, 1, 1], + cluster=(*self.cluster_shape_mn, 1), + stream=stream, + ) + return + + # GPU device kernel + @cute.kernel + def kernel( + self, + tiled_mma: cute.TiledMma, + tma_atom_a: cute.CopyAtom, + mA_mkl: cute.Tensor, + tma_atom_b: cute.CopyAtom, + mB_nkl: cute.Tensor, + tma_atom_c: Optional[cute.CopyAtom], + mC_mnl: cute.Tensor, + cluster_layout_vmnk: cute.Layout, + a_smem_layout_staged: cute.ComposedLayout, + b_smem_layout_staged: cute.ComposedLayout, + c_smem_layout_staged: Union[cute.Layout, cute.ComposedLayout, None], + epi_tile: cute.Tile, + tile_sched_params: utils.PersistentTileSchedulerParams, + epilogue_op: cutlass.Constexpr, + ): + """ + GPU device kernel performing the Persistent batched GEMM computation. + """ + warp_idx = cute.arch.warp_idx() + warp_idx = cute.arch.make_warp_uniform(warp_idx) + + # + # Prefetch tma desc + # + if warp_idx == self.tma_warp_id: + cpasync.prefetch_descriptor(tma_atom_a) + cpasync.prefetch_descriptor(tma_atom_b) + if cutlass.const_expr(self.use_tma_store): + cpasync.prefetch_descriptor(tma_atom_c) + + use_2cta_instrs = cute.size(tiled_mma.thr_id.shape) == 2 + + # + # Setup cta/thread coordinates + # + # Coords inside cluster + bidx, bidy, bidz = cute.arch.block_idx() + mma_tile_coord_v = bidx % cute.size(tiled_mma.thr_id.shape) + is_leader_cta = mma_tile_coord_v == 0 + cta_rank_in_cluster = cute.arch.make_warp_uniform( + cute.arch.block_idx_in_cluster() + ) + block_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord( + cta_rank_in_cluster + ) + # Coord inside cta + tidx, _, _ = cute.arch.thread_idx() + + # + # Alloc and init: a+b full/empty, accumulator full/empty, tensor memory dealloc barrier + # + smem = utils.SmemAllocator() + storage = smem.allocate(self.shared_storage) + + tmem_dealloc_mbar_ptr = storage.tmem_dealloc_mbar_ptr + tmem_holding_buf = storage.tmem_holding_buf + + # Initialize mainloop ab_pipeline (barrier) and states + ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) + num_tma_producer = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1 + ab_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, num_tma_producer + ) + ab_pipeline = pipeline.PipelineTmaUmma.create( + barrier_storage=storage.ab_full_mbar_ptr.data_ptr(), + num_stages=self.num_ab_stage, + producer_group=ab_pipeline_producer_group, + consumer_group=ab_pipeline_consumer_group, + tx_count=self.num_tma_load_bytes, + cta_layout_vmnk=cluster_layout_vmnk, + ) + + # Initialize acc_pipeline (barrier) and states + acc_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) + num_acc_consumer_threads = len(self.epilog_warp_id) * ( + 2 if use_2cta_instrs else 1 + ) + acc_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, num_acc_consumer_threads + ) + acc_pipeline = pipeline.PipelineUmmaAsync.create( + 
barrier_storage=storage.acc_full_mbar_ptr.data_ptr(), + num_stages=self.num_acc_stage, + producer_group=acc_pipeline_producer_group, + consumer_group=acc_pipeline_consumer_group, + cta_layout_vmnk=cluster_layout_vmnk, + ) + + # Tensor memory dealloc barrier init + if use_2cta_instrs: + if warp_idx == self.tma_warp_id: + num_tmem_dealloc_threads = 32 + with cute.arch.elect_one(): + cute.arch.mbarrier_init( + tmem_dealloc_mbar_ptr, num_tmem_dealloc_threads + ) + cute.arch.mbarrier_init_fence() + + # Cluster arrive after barrier init + if cute.size(self.cluster_shape_mn) > 1: + cute.arch.cluster_arrive_relaxed() + + # + # Setup smem tensor A/B/C + # + # (EPI_TILE_M, EPI_TILE_N, STAGE) + sC = ( + storage.sC.get_tensor( + c_smem_layout_staged.outer, swizzle=c_smem_layout_staged.inner + ) + if cutlass.const_expr(self.use_tma_store) + else None + ) + # (MMA, MMA_M, MMA_K, STAGE) + sA = storage.sA.get_tensor( + a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner + ) + # (MMA, MMA_N, MMA_K, STAGE) + sB = storage.sB.get_tensor( + b_smem_layout_staged.outer, swizzle=b_smem_layout_staged.inner + ) + + # + # Compute multicast mask for A/B buffer full + # + a_full_mcast_mask = None + b_full_mcast_mask = None + if cutlass.const_expr(self.is_a_mcast or self.is_b_mcast or use_2cta_instrs): + a_full_mcast_mask = cpasync.create_tma_multicast_mask( + cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2 + ) + b_full_mcast_mask = cpasync.create_tma_multicast_mask( + cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=1 + ) + + # + # Local_tile partition global tensors + # + # (bM, bK, RestM, RestK, RestL) + gA_mkl = cute.local_tile( + mA_mkl, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None) + ) + # (bN, bK, RestN, RestK, RestL) + gB_nkl = cute.local_tile( + mB_nkl, cute.slice_(self.mma_tiler, (0, None, None)), (None, None, None) + ) + # (bM, bN, RestM, RestN, RestL) + gC_mnl = cute.local_tile( + mC_mnl, cute.slice_(self.mma_tiler, (None, None, 0)), (None, None, None) + ) + k_tile_cnt = cute.size(gA_mkl, mode=[3]) + + # + # Partition global tensor for TiledMMA_A/B/C + # + thr_mma = tiled_mma.get_slice(mma_tile_coord_v) + # (MMA, MMA_M, MMA_K, RestM, RestK, RestL) + tCgA = thr_mma.partition_A(gA_mkl) + # (MMA, MMA_N, MMA_K, RestN, RestK, RestL) + tCgB = thr_mma.partition_B(gB_nkl) + # (MMA, MMA_M, MMA_N, RestM, RestN, RestL) + tCgC = thr_mma.partition_C(gC_mnl) + + # + # Partition global/shared tensor for TMA load A/B + # + # TMA load A partition_S/D + a_cta_layout = cute.make_layout( + cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape + ) + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), RestM, RestK, RestL) + tAsA, tAgA = cpasync.tma_partition( + tma_atom_a, + block_in_cluster_coord_vmnk[2], + a_cta_layout, + cute.group_modes(sA, 0, 3), + cute.group_modes(tCgA, 0, 3), + ) + # TMA load B partition_S/D + b_cta_layout = cute.make_layout( + cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape + ) + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), RestM, RestK, RestL) + tBsB, tBgB = cpasync.tma_partition( + tma_atom_b, + block_in_cluster_coord_vmnk[1], + b_cta_layout, + cute.group_modes(sB, 0, 3), + cute.group_modes(tCgB, 0, 3), + ) + + # + # Partition shared/tensor memory tensor for TiledMMA_A/B/C + # + # (MMA, MMA_M, MMA_K, STAGE) + tCrA = tiled_mma.make_fragment_A(sA) + # (MMA, MMA_N, MMA_K, STAGE) + tCrB = tiled_mma.make_fragment_B(sB) + # (MMA, MMA_M, MMA_N) + acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2]) + # (MMA, MMA_M, MMA_N, STAGE) + 
tCtAcc_fake = tiled_mma.make_fragment_C( + cute.append(acc_shape, self.num_acc_stage) + ) + + # + # Cluster wait before tensor memory alloc + # + if cute.size(self.cluster_shape_mn) > 1: + cute.arch.cluster_wait() + else: + cute.arch.barrier( + barrier_id=self.cta_sync_bar_id, number_of_threads=self.threads_per_cta + ) + + # + # Specialized TMA load warp + # + + if warp_idx == self.tma_warp_id: + # + # Persistent tile scheduling loop + # + tile_sched = utils.StaticPersistentTileScheduler.create( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + + ab_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.num_ab_stage + ) + + while work_tile.is_valid_tile: + # Get tile coord from tile scheduler + cur_tile_coord = work_tile.tile_idx + mma_tile_coord_mnl = ( + cur_tile_coord[0] // cute.size(tiled_mma.thr_id.shape), + cur_tile_coord[1], + cur_tile_coord[2], + ) + + # + # Slice to per mma tile index + # + # ((atom_v, rest_v), RestK) + tAgA_slice = tAgA[ + (None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2]) + ] + # ((atom_v, rest_v), RestK) + tBgB_slice = tBgB[ + (None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2]) + ] + + # Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt + ab_producer_state.reset_count() + peek_ab_empty_status = cutlass.Boolean(1) + if ab_producer_state.count < k_tile_cnt: + peek_ab_empty_status = ab_pipeline.producer_try_acquire( + ab_producer_state + ) + # + # Tma load loop + # + for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1): + # Conditionally wait for AB buffer empty + ab_pipeline.producer_acquire( + ab_producer_state, peek_ab_empty_status + ) + + # TMA load A/B + cute.copy( + tma_atom_a, + tAgA_slice[(None, ab_producer_state.count)], + tAsA[(None, ab_producer_state.index)], + tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state), + mcast_mask=a_full_mcast_mask, + ) + cute.copy( + tma_atom_b, + tBgB_slice[(None, ab_producer_state.count)], + tBsB[(None, ab_producer_state.index)], + tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state), + mcast_mask=b_full_mcast_mask, + ) + + # Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt + k_tile + 1 + ab_producer_state.advance() + peek_ab_empty_status = cutlass.Boolean(1) + if ab_producer_state.count < k_tile_cnt: + peek_ab_empty_status = ab_pipeline.producer_try_acquire( + ab_producer_state + ) + + # + # Advance to next tile + # + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + + # + # Wait A/B buffer empty + # + ab_pipeline.producer_tail(ab_producer_state) + + # + # Specialized MMA warp + # + if warp_idx == self.mma_warp_id: + # + # Bar sync for retrieve tensor memory ptr from shared mem + # + tmem_ptr_read_threads = 32 * len((self.mma_warp_id, *self.epilog_warp_id)) + cute.arch.barrier( + barrier_id=self.tmem_ptr_sync_bar_id, + number_of_threads=tmem_ptr_read_threads, + ) + + # + # Retrieving tensor memory ptr and make accumulator tensor + # + tmem_ptr = cute.arch.retrieve_tmem_ptr( + self.acc_dtype, + alignment=16, + ptr_to_buffer_holding_addr=tmem_holding_buf, + ) + # (MMA, MMA_M, MMA_N, STAGE) + tCtAcc_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout) + + # + # Persistent tile scheduling loop + # + tile_sched = utils.StaticPersistentTileScheduler.create( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + + ab_consumer_state = 
pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_ab_stage + ) + acc_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.num_acc_stage + ) + + while work_tile.is_valid_tile: + # Get tile coord from tile scheduler + cur_tile_coord = work_tile.tile_idx + mma_tile_coord_mnl = ( + cur_tile_coord[0] // cute.size(tiled_mma.thr_id.shape), + cur_tile_coord[1], + cur_tile_coord[2], + ) + + # Set tensor memory buffer for current tile + # (MMA, MMA_M, MMA_N) + tCtAcc = tCtAcc_base[(None, None, None, acc_producer_state.index)] + + # Peek (try_wait) AB buffer full for k_tile = 0 + ab_consumer_state.reset_count() + peek_ab_full_status = cutlass.Boolean(1) + if ab_consumer_state.count < k_tile_cnt and is_leader_cta: + peek_ab_full_status = ab_pipeline.consumer_try_wait( + ab_consumer_state + ) + + # + # Wait for accumulator buffer empty + # + if is_leader_cta: + acc_pipeline.producer_acquire(acc_producer_state) + + # + # Reset the ACCUMULATE field for each tile + # + tiled_mma.set(tcgen05.Field.ACCUMULATE, False) + + # + # Mma mainloop + # + for k_tile in range(k_tile_cnt): + if is_leader_cta: + # Conditionally wait for AB buffer full + ab_pipeline.consumer_wait( + ab_consumer_state, peek_ab_full_status + ) + + # tCtAcc += tCrA * tCrB + num_kblocks = cute.size(tCrA, mode=[2]) + for kblock_idx in cutlass.range(num_kblocks, unroll_full=True): + kblock_coord = ( + None, + None, + kblock_idx, + ab_consumer_state.index, + ) + + cute.gemm( + tiled_mma, + tCtAcc, + tCrA[kblock_coord], + tCrB[kblock_coord], + tCtAcc, + ) + # Enable accumulate on tCtAcc after first kblock + tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + + # Async arrive AB buffer empty + ab_pipeline.consumer_release(ab_consumer_state) + + # Peek (try_wait) AB buffer full for k_tile = k_tile + 1 + ab_consumer_state.advance() + peek_ab_full_status = cutlass.Boolean(1) + if ab_consumer_state.count < k_tile_cnt: + if is_leader_cta: + peek_ab_full_status = ab_pipeline.consumer_try_wait( + ab_consumer_state + ) + + # + # Async arrive accumulator buffer full + # + if is_leader_cta: + acc_pipeline.producer_commit(acc_producer_state) + acc_producer_state.advance() + + # + # Advance to next tile + # + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + + # + # Wait for accumulator buffer empty + # + acc_pipeline.producer_tail(acc_producer_state) + # + # Specialized epilogue warps + # + if warp_idx < self.mma_warp_id: + # + # Alloc tensor memory buffer + # + if warp_idx == self.epilog_warp_id[0]: + cute.arch.alloc_tmem( + self.num_tmem_alloc_cols, + tmem_holding_buf, + is_two_cta=use_2cta_instrs, + ) + + # + # Bar sync for retrieve tensor memory ptr from shared memory + # + tmem_ptr_read_threads = 32 * len((self.mma_warp_id, *self.epilog_warp_id)) + cute.arch.barrier( + barrier_id=self.tmem_ptr_sync_bar_id, + number_of_threads=tmem_ptr_read_threads, + ) + + # + # Retrieving tensor memory ptr and make accumulator tensor + # + tmem_ptr = cute.arch.retrieve_tmem_ptr( + self.acc_dtype, + alignment=16, + ptr_to_buffer_holding_addr=tmem_holding_buf, + ) + # (MMA, MMA_M, MMA_N, STAGE) + tCtAcc_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout) + + # + # Partition for epilogue + # + epi_tidx = tidx + ( + tiled_copy_t2r, + tTR_tAcc_base, + tTR_rAcc, + ) = self.epilog_tmem_copy_and_partition( + epi_tidx, tCtAcc_base, tCgC, epi_tile, use_2cta_instrs + ) + + tTR_rC = None + tiled_copy_r2s = None + simt_atom = None + tRS_rC = None + tRS_sC = None + bSG_sC = None + 
bSG_gC_partitioned = None + tTR_gC_partitioned = None + if cutlass.const_expr(self.use_tma_store): + tTR_rC = cute.make_fragment(tTR_rAcc.shape, self.c_dtype) + tiled_copy_r2s, tRS_rC, tRS_sC = self.epilog_smem_copy_and_partition( + tiled_copy_t2r, tTR_rC, epi_tidx, sC + ) + ( + tma_atom_c, + bSG_sC, + bSG_gC_partitioned, + ) = self.epilog_gmem_copy_and_partition( + epi_tidx, tma_atom_c, tCgC, epi_tile, sC + ) + else: + ( + simt_atom, + tTR_rC, + tTR_gC_partitioned, + ) = self.epilog_gmem_copy_and_partition( + epi_tidx, tiled_copy_t2r, tCgC, epi_tile, sC + ) + + # + # Persistent tile scheduling loop + # + tile_sched = utils.StaticPersistentTileScheduler.create( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + + acc_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_acc_stage + ) + + c_pipeline = None + if cutlass.const_expr(self.use_tma_store): + # Threads/warps participating in tma store pipeline + c_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + 32 * len(self.epilog_warp_id), + 32 * len(self.epilog_warp_id), + ) + c_pipeline = pipeline.PipelineTmaStore.create( + num_stages=self.num_c_stage, + producer_group=c_producer_group, + ) + + while work_tile.is_valid_tile: + # Get tile coord from tile scheduler + cur_tile_coord = work_tile.tile_idx + mma_tile_coord_mnl = ( + cur_tile_coord[0] // cute.size(tiled_mma.thr_id.shape), + cur_tile_coord[1], + cur_tile_coord[2], + ) + + # + # Slice to per mma tile index + # + bSG_gC = None + tTR_gC = None + if cutlass.const_expr(self.use_tma_store): + # ((ATOM_V, REST_V), EPI_M, EPI_N) + bSG_gC = bSG_gC_partitioned[ + ( + None, + None, + None, + *mma_tile_coord_mnl, + ) + ] + else: + # (T2R, T2R_M, T2R_N, EPI_M, EPI_N) + tTR_gC = tTR_gC_partitioned[ + ( + None, + None, + None, + None, + None, + *mma_tile_coord_mnl, + ) + ] + + # Set tensor memory buffer for current tile + # (T2R, T2R_M, T2R_N, EPI_M, EPI_M) + tTR_tAcc = tTR_tAcc_base[ + (None, None, None, None, None, acc_consumer_state.index) + ] + + # + # Wait for accumulator buffer full + # + acc_pipeline.consumer_wait(acc_consumer_state) + + tTR_tAcc = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc)) + if cutlass.const_expr(self.use_tma_store): + bSG_gC = cute.group_modes(bSG_gC, 1, cute.rank(bSG_gC)) + else: + tTR_gC = cute.group_modes(tTR_gC, 3, cute.rank(tTR_gC)) + + # + # Store accumulator to global memory in subtiles + # + subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3]) + num_prev_subtiles = tile_sched.num_tiles_executed * subtile_cnt + for subtile_idx in cutlass.range(subtile_cnt): + # + # Load accumulator from tensor memory buffer to register + # + tTR_tAcc_mn = tTR_tAcc[(None, None, None, subtile_idx)] + cute.copy(tiled_copy_t2r, tTR_tAcc_mn, tTR_rAcc) + + if cutlass.const_expr(self.use_tma_store): + # + # Convert to C type + # + acc_vec = tiled_copy_r2s.retile(tTR_rAcc).load() + acc_vec = epilogue_op(acc_vec.to(self.c_dtype)) + tRS_rC.store(acc_vec) + + # + # Store C to shared memory + # + c_buffer = (num_prev_subtiles + subtile_idx) % self.num_c_stage + cute.copy( + tiled_copy_r2s, + tRS_rC, + tRS_sC[(None, None, None, c_buffer)], + ) + # Fence and barrier to make sure shared memory store is visible to TMA store + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + epilog_threads = 32 * len(self.epilog_warp_id) + cute.arch.barrier( + barrier_id=self.epilog_sync_bar_id, + 
number_of_threads=epilog_threads, + ) + + # + # TMA store C to global memory + # + if warp_idx == self.epilog_warp_id[0]: + cute.copy( + tma_atom_c, + bSG_sC[(None, c_buffer)], + bSG_gC[(None, subtile_idx)], + ) + # Fence and barrier to make sure shared memory store is visible to TMA store + c_pipeline.producer_commit() + c_pipeline.producer_acquire() + cute.arch.barrier( + barrier_id=self.epilog_sync_bar_id, + number_of_threads=epilog_threads, + ) + else: + # + # Convert to C type + # + acc_vec = tTR_rAcc.load() + acc_vec = epilogue_op(acc_vec.to(self.c_dtype)) + tTR_rC.store(acc_vec) + + # + # Store C to global memory + # + cute.copy( + simt_atom, tTR_rC, tTR_gC[(None, None, None, subtile_idx)] + ) + + # + # Async arrive accumulator buffer empty + # + with cute.arch.elect_one(): + acc_pipeline.consumer_release(acc_consumer_state) + acc_consumer_state.advance() + + # + # Advance to next tile + # + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + + # + # Dealloc the tensor memory buffer + # + if warp_idx == self.epilog_warp_id[0]: + cute.arch.relinquish_tmem_alloc_permit(is_two_cta=use_2cta_instrs) + epilog_threads = 32 * len(self.epilog_warp_id) + cute.arch.barrier( + barrier_id=self.epilog_sync_bar_id, number_of_threads=epilog_threads + ) + if warp_idx == self.epilog_warp_id[0]: + if use_2cta_instrs: + cute.arch.mbarrier_arrive( + tmem_dealloc_mbar_ptr, cta_rank_in_cluster ^ 1 + ) + cute.arch.mbarrier_wait(tmem_dealloc_mbar_ptr, 0) + cute.arch.dealloc_tmem( + tmem_ptr, self.num_tmem_alloc_cols, is_two_cta=use_2cta_instrs + ) + # + # Wait for C store complete + # + if cutlass.const_expr(self.use_tma_store): + c_pipeline.producer_tail() + + def epilog_tmem_copy_and_partition( + self, + tidx: cutlass.Int32, + tAcc: cute.Tensor, + gC_mnl: cute.Tensor, + epi_tile: cute.Tile, + use_2cta_instrs: Union[cutlass.Boolean, bool], + ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]: + """ + Make tiledCopy for tensor memory load, then use it to partition tensor memory (source) and register array (destination). 
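+
+        The returned ``tTR_rAcc`` register fragment covers a single epilogue
+        subtile; the epilogue loop reuses it for each ``(EPI_M, EPI_N)`` subtile
+        of a work tile, loading one subtile at a time via ``tiled_copy_t2r``.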
+ + :param tidx: The thread index in epilogue warp groups + :type tidx: cutlass.Int32 + :param tAcc: The accumulator tensor to be copied and partitioned + :type tAcc: cute.Tensor + :param gC_mnl: The global tensor C + :type gC_mnl: cute.Tensor + :param epi_tile: The epilogue tiler + :type epi_tile: cute.Tile + :param use_2cta_instrs: Whether use_2cta_instrs is enabled + :type use_2cta_instrs: bool + + :return: A tuple containing (tiled_copy_t2r, tTR_tAcc, tTR_rAcc) where: + - tiled_copy_t2r: The tiled copy operation for tmem to register copy(t2r) + - tTR_tAcc: The partitioned accumulator tensor + - tTR_rAcc: The accumulated tensor in register used to hold t2r results + :rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor] + """ + # Make tiledCopy for tensor memory load + copy_atom_t2r = sm100_utils.get_tmem_load_op( + self.cta_tile_shape_mnk, + self.c_layout, + self.c_dtype, + self.acc_dtype, + epi_tile, + use_2cta_instrs, + ) + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, STAGE) + tAcc_epi = cute.flat_divide( + tAcc[((None, None), 0, 0, None)], + epi_tile, + ) + # (EPI_TILE_M, EPI_TILE_N) + tiled_copy_t2r = tcgen05.make_tmem_copy( + copy_atom_t2r, tAcc_epi[(None, None, 0, 0, 0)] + ) + + thr_copy_t2r = tiled_copy_t2r.get_slice(tidx) + # (T2R, T2R_M, T2R_N, EPI_M, EPI_M, STAGE) + tTR_tAcc = thr_copy_t2r.partition_S(tAcc_epi) + + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, RestM, RestN, RestL) + gC_mnl_epi = cute.flat_divide( + gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile + ) + # (T2R, T2R_M, T2R_N, EPI_M, EPI_N, RestM, RestN, RestL) + tTR_gC = thr_copy_t2r.partition_D(gC_mnl_epi) + # (T2R, T2R_M, T2R_N) + tTR_rAcc = cute.make_fragment( + tTR_gC[(None, None, None, 0, 0, 0, 0, 0)].shape, self.acc_dtype + ) + return tiled_copy_t2r, tTR_tAcc, tTR_rAcc + + def epilog_smem_copy_and_partition( + self, + tiled_copy_t2r: cute.TiledCopy, + tTR_rC: cute.Tensor, + tidx: cutlass.Int32, + sC: cute.Tensor, + ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]: + """ + Make tiledCopy for shared memory store, then use it to partition register array (source) and shared memory (destination). 
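+
+        ``tRS_rC`` is obtained by retiling ``tTR_rC`` with the r2s tiled copy,
+        so the same registers feed both the TMEM load and the shared-memory
+        store; no extra register storage is allocated here.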
+ + :param tiled_copy_t2r: The tiled copy operation for tmem to register copy(t2r) + :type tiled_copy_t2r: cute.TiledCopy + :param tTR_rC: The partitioned accumulator tensor + :type tTR_rC: cute.Tensor + :param tidx: The thread index in epilogue warp groups + :type tidx: cutlass.Int32 + :param sC: The shared memory tensor to be copied and partitioned + :type sC: cute.Tensor + :type sepi: cute.Tensor + + :return: A tuple containing (tiled_copy_r2s, tRS_rC, tRS_sC) where: + - tiled_copy_r2s: The tiled copy operation for register to smem copy(r2s) + - tRS_rC: The partitioned tensor C (register source) + - tRS_sC: The partitioned tensor C (smem destination) + :rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor] + """ + copy_atom_r2s = sm100_utils.get_smem_store_op( + self.c_layout, self.c_dtype, self.acc_dtype, tiled_copy_t2r + ) + tiled_copy_r2s = cute.make_tiled_copy_D(copy_atom_r2s, tiled_copy_t2r) + # (R2S, R2S_M, R2S_N, PIPE_D) + thr_copy_r2s = tiled_copy_r2s.get_slice(tidx) + tRS_sC = thr_copy_r2s.partition_D(sC) + # (R2S, R2S_M, R2S_N) + tRS_rC = tiled_copy_r2s.retile(tTR_rC) + return tiled_copy_r2s, tRS_rC, tRS_sC + + def epilog_gmem_copy_and_partition( + self, + tidx: cutlass.Int32, + atom: Union[cute.CopyAtom, cute.TiledCopy], + gC_mnl: cute.Tensor, + epi_tile: cute.Tile, + sC: cute.Tensor, + ) -> Tuple[cute.CopyAtom, cute.Tensor, cute.Tensor]: + """Make tiledCopy for global memory store, then use it to: + - partition register array (source) and global memory (destination) for none TMA store version; + - partition shared memory (source) and global memory (destination) for TMA store version. + + :param tidx: The thread index in epilogue warp groups + :type tidx: cutlass.Int32 + :param atom: The copy_atom_c to be used for TMA store version, or tiled_copy_t2r for none TMA store version + :type atom: cute.CopyAtom or cute.TiledCopy + :param gC_mnl: The global tensor C + :type gC_mnl: cute.Tensor + :param epi_tile: The epilogue tiler + :type epi_tile: cute.Tile + :param sC: The shared memory tensor to be copied and partitioned + :type sC: cute.Tensor + + :return: A tuple containing either: + - For TMA store: (tma_atom_c, bSG_sC, bSG_gC) where: + - tma_atom_c: The TMA copy atom + - bSG_sC: The partitioned shared memory tensor C + - bSG_gC: The partitioned global tensor C + - For non-TMA store: (simt_atom, tTR_rC, tTR_gC) where: + - simt_atom: The SIMT copy atom + - tTR_rC: The register tensor C + - tTR_gC: The partitioned global tensor C + :rtype: Tuple[cute.CopyAtom, cute.Tensor, cute.Tensor] + """ + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, RestM, RestN, RestL) + gC_epi = cute.flat_divide( + gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile + ) + if cutlass.const_expr(self.use_tma_store): + tma_atom_c = atom + sC_for_tma_partition = cute.group_modes(sC, 0, 2) + gC_for_tma_partition = cute.group_modes(gC_epi, 0, 2) + # ((ATOM_V, REST_V), EPI_M, EPI_N) + # ((ATOM_V, REST_V), EPI_M, EPI_N, RestM, RestN, RestL) + bSG_sC, bSG_gC = cpasync.tma_partition( + tma_atom_c, + 0, + cute.make_layout(1), + sC_for_tma_partition, + gC_for_tma_partition, + ) + return tma_atom_c, bSG_sC, bSG_gC + else: + tiled_copy_t2r = atom + # (T2R, T2R_M, T2R_N, EPI_M, EPI_N, RestM, RestN, RestL) + thr_copy_t2r = tiled_copy_t2r.get_slice(tidx) + tTR_gC = thr_copy_t2r.partition_D(gC_epi) + # (T2R, T2R_M, T2R_N) + tTR_rC = cute.make_fragment( + tTR_gC[(None, None, None, 0, 0, 0, 0, 0)].shape, self.c_dtype + ) + simt_atom = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), self.c_dtype) + return simt_atom, 
tTR_rC, tTR_gC + + @staticmethod + def _compute_stages( + tiled_mma: cute.TiledMma, + mma_tiler_mnk: Tuple[int, int, int], + a_dtype: Type[cutlass.Numeric], + b_dtype: Type[cutlass.Numeric], + epi_tile: cute.Tile, + c_dtype: Type[cutlass.Numeric], + c_layout: utils.LayoutEnum, + smem_capacity: int, + occupancy: int, + use_tma_store: bool, + ) -> Tuple[int, int, int]: + """Computes the number of stages for A/B/C operands based on heuristics. + + :param tiled_mma: The tiled MMA object defining the core computation. + :type tiled_mma: cute.TiledMma + :param mma_tiler_mnk: The shape (M, N, K) of the MMA tiler. + :type mma_tiler_mnk: tuple[int, int, int] + :param a_dtype: Data type of operand A. + :type a_dtype: type[cutlass.Numeric] + :param b_dtype: Data type of operand B. + :type b_dtype: type[cutlass.Numeric] + :param epi_tile: The epilogue tile shape. + :type epi_tile: cute.Tile + :param c_dtype: Data type of operand C (output). + :type c_dtype: type[cutlass.Numeric] + :param c_layout: Layout enum of operand C. + :type c_layout: utils.LayoutEnum + :param smem_capacity: Total available shared memory capacity in bytes. + :type smem_capacity: int + :param occupancy: Target number of CTAs per SM (occupancy). + :type occupancy: int + :param use_tma_store: Whether TMA store is enabled. + :type use_tma_store: bool + + :return: A tuple containing the computed number of stages for: + (ACC stages, A/B operand stages, C stages) + :rtype: tuple[int, int, int] + """ + # Default ACC stages + num_acc_stage = 2 + + # Default C stages + num_c_stage = 2 if use_tma_store else 0 + + # Calculate smem layout and size for one stage of A, B, and C + a_smem_layout_stage_one = sm100_utils.make_smem_layout_a( + tiled_mma, + mma_tiler_mnk, + a_dtype, + 1, # a tmp 1 stage is provided + ) + b_smem_layout_staged_one = sm100_utils.make_smem_layout_b( + tiled_mma, + mma_tiler_mnk, + b_dtype, + 1, # a tmp 1 stage is provided + ) + c_smem_layout_staged_one = ( + sm100_utils.make_smem_layout_epi( + c_dtype, + c_layout, + epi_tile, + 1, + ) + if use_tma_store + else None + ) + ab_bytes_per_stage = cute.size_in_bytes( + a_dtype, a_smem_layout_stage_one + ) + cute.size_in_bytes(b_dtype, b_smem_layout_staged_one) + mbar_helpers_bytes = 1024 + c_bytes_per_stage = ( + cute.size_in_bytes(c_dtype, c_smem_layout_staged_one) + if use_tma_store + else 0 + ) + c_bytes = c_bytes_per_stage * num_c_stage + + # Calculate A/B stages: + # Start with total smem per CTA (capacity / occupancy) + # Subtract reserved bytes and initial C stages bytes + # Divide remaining by bytes needed per A/B stage + num_ab_stage = ( + smem_capacity // occupancy - (mbar_helpers_bytes + c_bytes) + ) // ab_bytes_per_stage + + # Refine epilogue stages: + # Calculate remaining smem after allocating for A/B stages and reserved bytes + # Add remaining unused smem to epilogue + if use_tma_store: + num_c_stage += ( + smem_capacity + - occupancy * ab_bytes_per_stage * num_ab_stage + - occupancy * (mbar_helpers_bytes + c_bytes) + ) // (occupancy * c_bytes_per_stage) + return num_acc_stage, num_ab_stage, num_c_stage + + @staticmethod + def _compute_grid( + c: cute.Tensor, + cta_tile_shape_mnk: Tuple[int, int, int], + cluster_shape_mn: Tuple[int, int], + max_active_clusters: cutlass.Constexpr, + ) -> Tuple[utils.PersistentTileSchedulerParams, Tuple[int, int, int]]: + """Use persistent tile scheduler to compute the grid size for the output tensor C. + + :param c: The output tensor C + :type c: cute.Tensor + :param cta_tile_shape_mnk: The shape (M, N, K) of the CTA tile. 
+ :type cta_tile_shape_mnk: tuple[int, int, int] + :param cluster_shape_mn: Shape of each cluster in M, N dimensions. + :type cluster_shape_mn: tuple[int, int] + :param max_active_clusters: Maximum number of active clusters. + :type max_active_clusters: cutlass.Constexpr + + :return: A tuple containing: + - tile_sched_params: Parameters for the persistent tile scheduler. + - grid: Grid shape for kernel launch. + :rtype: Tuple[utils.PersistentTileSchedulerParams, tuple[int, int, int]] + """ + c_shape = cute.slice_(cta_tile_shape_mnk, (None, None, 0)) + gc = cute.zipped_divide(c, tiler=c_shape) + num_ctas_mnl = gc[(0, (None, None, None))].shape + cluster_shape_mnl = (*cluster_shape_mn, 1) + + tile_sched_params = utils.PersistentTileSchedulerParams( + num_ctas_mnl, cluster_shape_mnl + ) + grid = utils.StaticPersistentTileScheduler.get_grid_shape( + tile_sched_params, max_active_clusters + ) + + return tile_sched_params, grid + + @staticmethod + def _compute_num_tmem_alloc_cols( + tiled_mma: cute.TiledMma, + mma_tiler: Tuple[int, int, int], + num_acc_stage: int, + ) -> int: + """ + Compute the number of tensor memory allocation columns. + + :param tiled_mma: The tiled MMA object defining the core computation. + :type tiled_mma: cute.TiledMma + :param mma_tiler: The shape (M, N, K) of the MMA tile. + :type mma_tiler: tuple[int, int, int] + :param num_acc_stage: The stage of the accumulator tensor. + :type num_acc_stage: int + + :return: The number of tensor memory allocation columns. + :rtype: int + """ + acc_shape = tiled_mma.partition_shape_C(mma_tiler[:2]) + tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, num_acc_stage)) + num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols(tCtAcc_fake) + + return num_tmem_alloc_cols + + @staticmethod + def is_valid_dtypes( + ab_dtype: Type[cutlass.Numeric], + acc_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + ) -> bool: + """ + Check if the dtypes are valid + + :param ab_dtype: The data type of the A and B operands + :type ab_dtype: Type[cutlass.Numeric] + :param acc_dtype: The data type of the accumulator + :type acc_dtype: Type[cutlass.Numeric] + :param c_dtype: The data type of the output tensor + :type c_dtype: Type[cutlass.Numeric] + + :return: True if the dtypes are valid, False otherwise + :rtype: bool + """ + is_valid = True + if ab_dtype not in { + cutlass.Float16, + cutlass.BFloat16, + cutlass.TFloat32, + cutlass.Uint8, + cutlass.Int8, + cutlass.Float8E4M3FN, + cutlass.Float8E5M2, + }: + is_valid = False + if ( + acc_dtype not in {cutlass.Float32, cutlass.Float16, cutlass.Int32} + or acc_dtype == cutlass.Float16 + and ab_dtype + not in {cutlass.Float16, cutlass.Float8E4M3FN, cutlass.Float8E5M2} + or acc_dtype == cutlass.Int32 + and ab_dtype not in {cutlass.Uint8, cutlass.Int8} + ): + is_valid = False + if ( + acc_dtype == cutlass.Float32 + and c_dtype + not in { + cutlass.Float32, + cutlass.Float16, + cutlass.BFloat16, + cutlass.Float8E4M3FN, + cutlass.Float8E5M2, + cutlass.Int32, + cutlass.Int8, + cutlass.Uint8, + } + or acc_dtype == cutlass.Float16 + and c_dtype + not in { + cutlass.BFloat16, + cutlass.Float16, + } + or acc_dtype == cutlass.Int32 + and c_dtype + not in { + cutlass.BFloat16, + cutlass.Float16, + cutlass.Float32, + cutlass.Int32, + cutlass.Int8, + cutlass.Uint8, + } + ): + is_valid = False + return is_valid + + @staticmethod + def is_valid_mma_tiler_and_cluster_shape( + use_2cta_instrs: bool, + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + ) -> bool: + """ + Check if the 
mma tiler and cluster shape are valid + + :param use_2cta_instrs: Whether to use 2 CTA groups + :type use_2cta_instrs: bool + :param mma_tiler_mn: The (M, N) shape of the MMA instruction tiler + :type mma_tiler_mn: Tuple[int, int] + :param cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster + :type cluster_shape_mn: Tuple[int, int] + + :return: True if the mma tiler and cluster shape are valid, False otherwise + :rtype: bool + """ + is_valid = True + # Skip invalid mma tile shape + if not ( + (not use_2cta_instrs and mma_tiler_mn[0] in [64, 128]) + or (use_2cta_instrs and mma_tiler_mn[0] in [128, 256]) + ): + is_valid = False + if mma_tiler_mn[1] not in range(32, 257, 32): + is_valid = False + # Skip illegal cluster shape + if cluster_shape_mn[0] % (2 if use_2cta_instrs else 1) != 0: + is_valid = False + # Skip invalid cluster shape + is_power_of_2 = lambda x: x > 0 and (x & (x - 1)) == 0 + if ( + cluster_shape_mn[0] * cluster_shape_mn[1] > 16 + or cluster_shape_mn[0] <= 0 + or cluster_shape_mn[1] <= 0 + or not is_power_of_2(cluster_shape_mn[0]) + or not is_power_of_2(cluster_shape_mn[1]) + ): + is_valid = False + return is_valid + + @staticmethod + def is_valid_tensor_alignment( + m: int, + n: int, + k: int, + l: int, + ab_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + a_major: str, + b_major: str, + c_major: str, + ) -> bool: + """ + Check if the tensor alignment is valid + + :param m: The number of rows in the A tensor + :type m: int + :param n: The number of columns in the B tensor + :type n: int + :param k: The number of columns in the A tensor + :type k: int + :param l: The number of columns in the C tensor + :type l: int + :param ab_dtype: The data type of the A and B operands + :type ab_dtype: Type[cutlass.Numeric] + :param c_dtype: The data type of the output tensor + :type c_dtype: Type[cutlass.Numeric] + :param a_major: The major axis of the A tensor + :type a_major: str + :param b_major: The major axis of the B tensor + :type b_major: str + :param c_major: The major axis of the C tensor + :type c_major: str + + :return: True if the problem shape is valid, False otherwise + :rtype: bool + """ + is_valid = True + + def check_contigous_16B_alignment(dtype, is_mode0_major, tensor_shape): + major_mode_idx = 0 if is_mode0_major else 1 + num_major_elements = tensor_shape[major_mode_idx] + num_contiguous_elements = 16 * 8 // dtype.width + return num_major_elements % num_contiguous_elements == 0 + + if ( + not check_contigous_16B_alignment(ab_dtype, a_major == "m", (m, k, l)) + or not check_contigous_16B_alignment(ab_dtype, b_major == "n", (n, k, l)) + or not check_contigous_16B_alignment(c_dtype, c_major == "m", (m, n, l)) + ): + is_valid = False + return is_valid + + @staticmethod + def is_valid_epilog_store_option( + use_2cta_instrs: bool, + use_tma_store: bool, + m: int, + n: int, + mma_tiler_mn: Tuple[int, int], + ) -> bool: + """ + Check if the epilogue store option is valid + + :param use_2cta_instrs: Whether to use 2 CTA groups + :type use_2cta_instrs: bool + :param use_tma_store: Whether to use TMA store + :type use_tma_store: bool + :param m: The number of rows in the A tensor + :type m: int + :param n: The number of columns in the B tensor + :type n: int + :param mma_tiler_mn: The (M, N) shape of the MMA instruction tiler + :type mma_tiler_mn: Tuple[int, int] + + :return: True if the epilogue store option is valid, False otherwise + :rtype: bool + """ + + is_valid = True + # None TMA store version does not have predication, can not support 
OOB tiles + cta_tile_shape_mn = ( + mma_tiler_mn[0] // (2 if use_2cta_instrs else 1), + mma_tiler_mn[1], + ) + if not use_tma_store: + if not (m % cta_tile_shape_mn[0] == 0 and n % cta_tile_shape_mn[1] == 0): + is_valid = False + return is_valid + + @staticmethod + def can_implement( + ab_dtype: Type[cutlass.Numeric], + acc_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + use_2cta_instrs: bool, + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + use_tma_store: bool, + m: int, + n: int, + k: int, + l: int, + a_major: str, + b_major: str, + c_major: str, + ) -> bool: + """ + Check if the gemm can be implemented + + :param ab_dtype: The data type of the A and B operands + :type ab_dtype: Type[cutlass.Numeric] + :param acc_dtype: The data type of the accumulator + :type acc_dtype: Type[cutlass.Numeric] + :param c_dtype: The data type of the output tensor + :type c_dtype: Type[cutlass.Numeric] + :param use_2cta_instrs: Whether to use 2 CTA groups + :type use_2cta_instrs: bool + :param mma_tiler_mn: The (M, N) shape of the MMA instruction tiler + :type mma_tiler_mn: Tuple[int, int] + :param cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster + :type cluster_shape_mn: Tuple[int, int] + :param use_tma_store: Whether to use TMA store + :type use_tma_store: bool + :param m: The number of rows in the A tensor + :type m: int + :param n: The number of columns in the B tensor + :type n: int + :param k: The number of columns in the A tensor + :type k: int + :param l: The number of columns in the C tensor + :type l: int + :param a_major: The major axis of the A tensor + :type a_major: str + :param b_major: The major axis of the B tensor + :type b_major: str + :param c_major: The major axis of the C tensor + :type c_major: str + + :return: True if the gemm can be implemented, False otherwise + :rtype: bool + """ + can_implement = True + # Skip unsupported types + if not PersistentDenseGemmKernel.is_valid_dtypes(ab_dtype, acc_dtype, c_dtype): + can_implement = False + # Skip invalid mma tile shape and cluster shape + if not PersistentDenseGemmKernel.is_valid_mma_tiler_and_cluster_shape( + use_2cta_instrs, mma_tiler_mn, cluster_shape_mn + ): + can_implement = False + # Skip illegal problem shape for load/store alignment + if not PersistentDenseGemmKernel.is_valid_tensor_alignment( + m, n, k, l, ab_dtype, c_dtype, a_major, b_major, c_major + ): + can_implement = False + # Skip invalid epilogue store option + if not PersistentDenseGemmKernel.is_valid_epilog_store_option( + use_2cta_instrs, use_tma_store, m, n, mma_tiler_mn + ): + can_implement = False + return can_implement + + +def run( + mnkl: Tuple[int, int, int, int], + ab_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + acc_dtype: Type[cutlass.Numeric], + a_major: str, + b_major: str, + c_major: str, + mma_tiler_mn: Tuple[int, int] = (256, 256), + cluster_shape_mn: Tuple[int, int] = (2, 1), + use_2cta_instrs: bool = True, + use_tma_store: bool = True, + tolerance: float = 1e-01, + warmup_iterations: int = 0, + iterations: int = 1, + skip_ref_check: bool = False, + use_cold_l2: bool = False, + **kwargs, +): + """Execute a persistent batched dense GEMM operation on Blackwell architecture with performance benchmarking. + + This function prepares input tensors, configures and launches the persistent GEMM kernel, + optionally performs reference validation, and benchmarks the execution performance. 
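+
+    A minimal Python-level usage sketch (the parameter values below are one
+    valid configuration; a Blackwell SM100 GPU is assumed to be available):
+
+    .. code-block:: python
+
+        exec_time_us = run(
+            (256, 256, 512, 1),   # (M, N, K, L)
+            cutlass.Float16,      # ab_dtype
+            cutlass.Float16,      # c_dtype
+            cutlass.Float32,      # acc_dtype
+            "k", "k", "n",        # a_major, b_major, c_major
+            mma_tiler_mn=(128, 128),
+            cluster_shape_mn=(1, 1),
+            use_2cta_instrs=False,
+            use_tma_store=True,
+        )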
+ + :param mnkl: Problem size (M, N, K, L) + :type mnkl: Tuple[int, int, int, int] + :param ab_dtype: Data type for input tensors A and B + :type ab_dtype: Type[cutlass.Numeric] + :param c_dtype: Data type for output tensor C + :type c_dtype: Type[cutlass.Numeric] + :param acc_dtype: Data type for accumulation during matrix multiplication + :type acc_dtype: Type[cutlass.Numeric] + :param a_major/b_major/c_major: Memory layout of tensor A/B/C + :type a_major/b_major/c_major: str + :param mma_tiler_mn: MMA tiling size. If not specified in the decorator parameters, the autotuner will use the + default value of (256, 256). Otherwise, the autotuner will use the value specified in the decorator parameters. + :type mma_tiler_mn: Tuple[int, int], optional + :param cluster_shape_mn: Cluster shape. If not specified in the decorator parameters, the autotuner will use the + default value of (2, 1). Otherwise, the autotuner will use the value specified in the decorator parameters. + :type cluster_shape_mn: Tuple[int, int], optional + :param use_2cta_instrs: Whether to use 2CTA instructions. If not specified in the decorator parameters, the autotuner + will use the default value of True. Otherwise, the autotuner will use the value specified in the decorator parameters. + :type use_2cta_instrs: bool, optional + :param use_tma_store: Whether to use TMA store. If not specified in the decorator parameters, the autotuner will use + the default value of True. Otherwise, the autotuner will use the value specified in the decorator parameters. + :type use_tma_store: bool, optional + :param tolerance: Tolerance value for reference validation comparison, defaults to 1e-01 + :type tolerance: float, optional + :param warmup_iterations: Number of warmup iterations before benchmarking, defaults to 0 + :type warmup_iterations: int, optional + :param iterations: Number of benchmark iterations to run, defaults to 1 + :type iterations: int, optional + :param skip_ref_check: Whether to skip reference result validation, defaults to False + :type skip_ref_check: bool, optional + :param use_cold_l2: Whether to use circular buffer strategy to ensure cold L2 cache, defaults to False + :type use_cold_l2: bool, optional + :raises RuntimeError: If CUDA GPU is not available + :raises ValueError: If the configuration is invalid or unsupported by the kernel + :return: Execution time of the GEMM kernel + :rtype: float + """ + print(f"Running Blackwell Persistent Dense GEMM test with:") + print(f"mnkl: {mnkl}") + print(f"AB dtype: {ab_dtype}, C dtype: {c_dtype}, Acc dtype: {acc_dtype}") + print(f"Matrix majors - A: {a_major}, B: {b_major}, C: {c_major}") + print(f"Mma Tiler (M, N): {mma_tiler_mn}, Cluster Shape (M, N): {cluster_shape_mn}") + print(f"2CTA MMA instructions: {'True' if use_2cta_instrs else 'False'}") + print(f"Use TMA Store: {'True' if use_tma_store else 'False'}") + print(f"Tolerance: {tolerance}") + print(f"Warmup iterations: {warmup_iterations}") + print(f"Iterations: {iterations}") + print(f"Skip reference checking: {skip_ref_check}") + print(f"Use cold L2: {'True' if use_cold_l2 else 'False'}") + + # Unpack parameters + m, n, k, l = mnkl + + # Skip unsupported testcase + if not PersistentDenseGemmKernel.can_implement( + ab_dtype, + acc_dtype, + c_dtype, + use_2cta_instrs, + mma_tiler_mn, + cluster_shape_mn, + use_tma_store, + m, + n, + k, + l, + a_major, + b_major, + c_major, + ): + raise TypeError( + f"Unsupported testcase {ab_dtype}, {acc_dtype}, {c_dtype}, {use_2cta_instrs}, {mma_tiler_mn}, {cluster_shape_mn}, 
{use_tma_store}, {m}, {n}, {k}, {l}, {a_major}, {b_major}, {c_major}" + ) + + if not torch.cuda.is_available(): + raise RuntimeError("GPU is required to run this example!") + + torch.manual_seed(1111) + + # Create and permute tensor A/B/C + def create_and_permute_tensor( + l, mode0, mode1, is_mode0_major, dtype, is_dynamic_layout=True + ): + # is_mode0_major: (l, mode1, mode0) -> (mode0, mode1, l) + # else: (l, mode0, mode1) -> (mode0, mode1, l) + shape = (l, mode1, mode0) if is_mode0_major else (l, mode0, mode1) + permute_order = (2, 1, 0) if is_mode0_major else (1, 2, 0) + is_unsigned = dtype in {cutlass.Uint8} + # Temporarily use uint8 as torch does not support fp8 type + torch_dtype = ( + cutlass_torch.dtype(dtype) + if dtype not in {cutlass.Float8E5M2, cutlass.Float8E4M3FN} + else torch.uint8 + ) + + # Create dtype torch tensor (cpu) + torch_tensor_cpu = cutlass_torch.create_and_permute_torch_tensor( + shape, + torch_dtype, + permute_order=permute_order, + init_type=cutlass_torch.TensorInitType.RANDOM, + init_config=cutlass_torch.RandomInitConfig( + min_val=0 if is_unsigned else -2, max_val=4 if is_unsigned else 2 + ), + ) + # Create dtype torch tensor (gpu) + torch_tensor = torch_tensor_cpu.cuda() + + # Create f32 torch tensor (cpu) + f32_torch_tensor = torch_tensor_cpu.to(dtype=torch.float32) + + # Create dtype cute tensor (gpu) + cute_tensor = from_dlpack(torch_tensor, assumed_align=16) + cute_tensor.element_type = dtype + if is_dynamic_layout: + cute_tensor = cute_tensor.mark_layout_dynamic( + leading_dim=(0 if is_mode0_major else 1) + ) + cute_tensor = cutlass_torch.convert_cute_tensor( + f32_torch_tensor, + cute_tensor, + dtype, + is_dynamic_layout=is_dynamic_layout, + ) + + return f32_torch_tensor, cute_tensor, torch_tensor, torch_tensor_cpu + + a_ref, a_tensor, a_torch, a_torch_cpu = create_and_permute_tensor( + l, m, k, a_major == "m", ab_dtype, is_dynamic_layout=True + ) + b_ref, b_tensor, b_torch, b_torch_cpu = create_and_permute_tensor( + l, n, k, b_major == "n", ab_dtype, is_dynamic_layout=True + ) + c_ref, c_tensor, c_torch, c_torch_cpu = create_and_permute_tensor( + l, m, n, c_major == "m", c_dtype, is_dynamic_layout=True + ) + + # Configure gemm kernel + gemm = PersistentDenseGemmKernel( + acc_dtype, + use_2cta_instrs, + mma_tiler_mn, + cluster_shape_mn, + use_tma_store, + ) + + # Compute max active clusters on current device + hardware_info = cutlass.utils.HardwareInfo() + max_active_clusters = hardware_info.get_max_active_clusters( + cluster_shape_mn[0] * cluster_shape_mn[1] + ) + + # Get current CUDA stream from PyTorch + torch_stream = torch.cuda.current_stream() + # Get the raw stream pointer as a CUstream + current_stream = cuda.CUstream(torch_stream.cuda_stream) + # Compile gemm kernel + compiled_gemm = cute.compile( + gemm, a_tensor, b_tensor, c_tensor, max_active_clusters, current_stream + ) + + if not skip_ref_check: + compiled_gemm(a_tensor, b_tensor, c_tensor, current_stream) + if ab_dtype in { + cutlass.Int8, + cutlass.Uint8, + cutlass.Float8E4M3FN, + cutlass.Float8E5M2, + }: + ref = torch.einsum("mkl,nkl->mnl", a_ref.cpu(), b_ref.cpu()) + else: + ref = (torch.einsum("mkl,nkl->mnl", a_ref, b_ref)).cpu() + + # Copy gpu result back + gpu_c = c_torch.cpu() + + # Convert ref to c_type + if c_dtype == cutlass.Float32: + ref_c = ref + elif c_dtype in {cutlass.Float8E5M2, cutlass.Float8E4M3FN}: + # m major: (l, n, m) -> (m, n, l) + # n major: (l, m, n) -> (m, n, l) + permute_order = (1, 2, 0) if c_major == "n" else (2, 1, 0) + shape = (l, m, n) if c_major == "n" 
else (l, n, m) + f8_torch_tensor = cutlass_torch.create_and_permute_torch_tensor( + shape, + torch.uint8, + permute_order=permute_order, + init_type=cutlass_torch.TensorInitType.SKIP, + ).cuda() + # Create dtype cute tensor (gpu) + ref_c_tensor = from_dlpack( + f8_torch_tensor, assumed_align=16 + ).mark_layout_dynamic(leading_dim=(1 if c_major == "n" else 0)) + ref_c_tensor.element_type = c_dtype + ref_c_tensor = cutlass_torch.convert_cute_tensor( + ref, + ref_c_tensor, + c_dtype, + is_dynamic_layout=True, + ) + + ref_c = f8_torch_tensor.cpu() + else: + ref_c = ref.to(cutlass_torch.dtype(c_dtype)) + + # Reference checking ref_c and gpu_c + torch.testing.assert_close( + gpu_c, + ref_c, + atol=tolerance, + rtol=1e-05, + ) + + def generate_tensors(): + a_tensor, _ = cutlass_torch.cute_tensor_like( + a_torch_cpu, ab_dtype, is_dynamic_layout=True, assumed_align=16 + ) + b_tensor, _ = cutlass_torch.cute_tensor_like( + b_torch_cpu, ab_dtype, is_dynamic_layout=True, assumed_align=16 + ) + c_tensor, _ = cutlass_torch.cute_tensor_like( + c_torch_cpu, c_dtype, is_dynamic_layout=True, assumed_align=16 + ) + return testing.JitArguments(a_tensor, b_tensor, c_tensor, current_stream) + + workspace_count = 1 + if use_cold_l2: + one_workspace_bytes = ( + a_torch_cpu.numel() * a_torch_cpu.element_size() + + b_torch_cpu.numel() * b_torch_cpu.element_size() + + c_torch_cpu.numel() * c_torch_cpu.element_size() + ) + workspace_count = testing.get_workspace_count( + one_workspace_bytes, warmup_iterations, iterations + ) + + exec_time = testing.benchmark( + compiled_gemm, + workspace_generator=generate_tensors, + workspace_count=workspace_count, + stream=current_stream, + warmup_iterations=warmup_iterations, + iterations=iterations, + ) + + return exec_time # Return execution time in microseconds + + +if __name__ == "__main__": + + def parse_comma_separated_ints(s: str) -> Tuple[int, ...]: + try: + return tuple(int(x.strip()) for x in s.split(",")) + except ValueError: + raise argparse.ArgumentTypeError( + "Invalid format. Expected comma-separated integers." + ) + + parser = argparse.ArgumentParser( + description="Example of Dense Persistent GEMM on Blackwell." 
+ ) + + parser.add_argument( + "--mnkl", + type=parse_comma_separated_ints, + default=(256, 256, 512, 1), + help="mnkl dimensions (comma-separated)", + ) + parser.add_argument( + "--mma_tiler_mn", + type=parse_comma_separated_ints, + default=(128, 128), + help="Mma tile shape (comma-separated)", + ) + parser.add_argument( + "--cluster_shape_mn", + type=parse_comma_separated_ints, + default=(1, 1), + help="Cluster shape (comma-separated)", + ) + parser.add_argument("--ab_dtype", type=cutlass.dtype, default=cutlass.TFloat32) + parser.add_argument("--c_dtype", type=cutlass.dtype, default=cutlass.Float32) + parser.add_argument("--acc_dtype", type=cutlass.dtype, default=cutlass.Float32) + parser.add_argument( + "--use_2cta_instrs", + action="store_true", + help="Enable 2CTA MMA instructions feature", + ) + parser.add_argument("--a_major", choices=["k", "m"], type=str, default="k") + parser.add_argument("--b_major", choices=["k", "n"], type=str, default="k") + parser.add_argument("--c_major", choices=["n", "m"], type=str, default="n") + parser.add_argument( + "--use_tma_store", action="store_true", help="Use tma store or not" + ) + parser.add_argument( + "--tolerance", type=float, default=1e-01, help="Tolerance for validation" + ) + parser.add_argument( + "--warmup_iterations", type=int, default=0, help="Warmup iterations" + ) + parser.add_argument( + "--iterations", + type=int, + default=1, + help="Number of iterations to run the kernel", + ) + parser.add_argument( + "--skip_ref_check", action="store_true", help="Skip reference checking" + ) + parser.add_argument( + "--use_cold_l2", + action="store_true", + default=False, + help="Use circular buffer tensor sets to ensure L2 cold cache", + ) + + args = parser.parse_args() + + if len(args.mnkl) != 4: + parser.error("--mnkl must contain exactly 4 values") + + if len(args.mma_tiler_mn) != 2: + parser.error("--mma_tiler_mn must contain exactly 2 values") + + if len(args.cluster_shape_mn) != 2: + parser.error("--cluster_shape_mn must contain exactly 2 values") + + run( + args.mnkl, + args.ab_dtype, + args.c_dtype, + args.acc_dtype, + args.a_major, + args.b_major, + args.c_major, + args.mma_tiler_mn, + args.cluster_shape_mn, + args.use_2cta_instrs, + args.use_tma_store, + args.tolerance, + args.warmup_iterations, + args.iterations, + args.skip_ref_check, + args.use_cold_l2, + ) + print("PASS") diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/blackwell/dense_gemm_software_pipeline.py b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/blackwell/dense_gemm_software_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..8f5e172efaf7c7d13627a8fd7c7e5f05b8d90d90 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/blackwell/dense_gemm_software_pipeline.py @@ -0,0 +1,1831 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. 
Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import argparse +from typing import Optional, Type, Tuple, Union +import cuda.bindings.driver as cuda + +import torch + +import cutlass +import cutlass.cute as cute +import cutlass.utils as utils +import cutlass.pipeline as pipeline +from cutlass.cute.nvgpu import cpasync, tcgen05 +import cutlass.torch as cutlass_torch +import cutlass.utils.blackwell_helpers as sm100_utils +from cutlass.cute.runtime import from_dlpack + +""" +A high-performance batched dense GEMM (C = A * B) example for the NVIDIA Blackwell SM100 architecture +using CUTE DSL with compiler generated software pipeline. +- Matrix A is MxKxL, L is batch dimension, A can be row-major("K") or column-major("M") +- Matrix B is NxKxL, L is batch dimension, B can be row-major("N") or column-major("K") +- Matrix C is MxNxL, L is batch dimension, C can be row-major("N") or column-major("M") + +This GEMM kernel supports the following features: + - Utilizes Tensor Memory Access (TMA) for efficient memory operations + - Utilizes Blackwell's tcgen05.mma for matrix multiply-accumulate (MMA) operations (including 2cta mma instructions) + - Implements TMA multicast with cluster to reduce L2 memory traffic + - Supports multi-stage pipeline to overlap computation and memory access + +This GEMM works as follows: +1. Load A and B matrices from global memory (GMEM) to shared memory (SMEM) using TMA operations. +2. Perform matrix multiply-accumulate (MMA) operations using tcgen05.mma instruction. +3. Load completed accumulator from tensor memory (TMEM) to registers (RMEM) using tcgen05.ld. +4. Type convert C matrix to output type. +5. Optionally store C matrix from registers (RMEM) to shared memory (SMEM) to global memory (GMEM) with TMA operations, + or directly store C matrix from registers (RMEM) to global memory (GMEM) without TMA operations. +6. Optionally accept an elementwise lambda function epilogue_op to apply to the output tensor: + e.g., relu can set epilogue_op = lambda x: cute.where(x > 0, x, cute.full_like(x, 0)) + +SM100 tcgen05.mma instructions operate as follows: +- Read matrix A from SMEM +- Read matrix B from SMEM +- Write accumulator to TMEM +The accumulator in TMEM must then be loaded to registers before writing back to GMEM. + +To run this example: + +.. 
code-block:: bash + + python examples/blackwell/dense_gemm_software_pipeline.py \ + --ab_dtype Float16 --c_dtype Float16 --acc_dtype Float32 \ + --mma_tiler_mn 256,128 --cluster_shape_mn 2,1 \ + --mnkl 8192,8192,8192,1 \ + --use_tma_store --use_2cta_instrs + +The above example command compute batched gemm with M=8192, N=8192, K=8192, +batch_count=1. The Blackwell tcgen05 MMA tile shape used 2 cta with 256x128 +MMA tile and the cluster shape is (2,1). The input, mma accumulator and output +data type are set as fp16, fp32 and fp16, respectively. + +To collect performance with NCU profiler: + +.. code-block:: bash + + ncu python examples/blackwell/dense_gemm_software_pipeline.py \ + --ab_dtype Float16 --c_dtype Float16 --acc_dtype Float32 \ + --mma_tiler_mn 256,128 --cluster_shape_mn 2,1 \ + --mnkl 8192,8192,8192,1 \ + --use_tma_store --use_2cta_instrs + +Constraints: +* Supported input data types: fp16, bf16, tf32, int8, uint8, fp8 (e4m3fn, e5m2), + see detailed valid dtype combinations in below DenseGemmKernel class documentation +* A/B tensor must have the same data type +* Mma tiler M must be 64/128 (use_2cta_instrs=False) or 128/256 (use_2cta_instrs=True) +* Mma tiler N must be 32-256, step 32 +* Cluster shape M/N must be positive and power of 2, total cluster size <= 16 +* Cluster shape M must be multiple of 2 if use_2cta_instrs=True +* The contiguous dimension of A/B/C tensors must be at least 16 bytes aligned, + i.e, number of elements is a multiple of 4, 8, and 16 for TFloat32, + Float16/BFloat16, and Int8/Uint8/Float8, respectively. +* OOB tiles are not allowed when TMA store is disabled +""" + + +class DenseGemmKernel: + """ + This class implements batched matrix multiplication (C = A x B) with support for various data types + and architectural features specific to Blackwell GPUs. + + :param acc_dtype: Data type for accumulation during computation + :type acc_dtype: type[cutlass.Numeric] + :param use_2cta_instrs: Whether to use CTA group 2 for advanced thread cooperation + :type use_2cta_instrs: bool + :param mma_tiler_mn: Shape of the Matrix Multiply-Accumulate (MMA) tiler (M,N) + :type mma_tiler_mn: Tuple[int, int] + :param cluster_shape_mn: Cluster dimensions (M,N) for parallel processing + :type cluster_shape_mn: Tuple[int, int] + :param use_tma_store: Whether to use Tensor Memory Access (TMA) for storing results + :type use_tma_store: bool + + :note: In current version, A and B tensor must have the same data type + - i.e., Float8E4M3FN for A and Float8E5M2 for B is not supported + + :note: Supported A/B data types: + - TFloat32 + - Float16/BFloat16 + - Int8/Uint8 + - Float8E4M3FN/Float8E5M2 + + :note: Supported accumulator data types: + - Float32 (for all floating point A/B data types) + - Float16 (only for fp16 and fp8 A/B data types) + - Int32 (only for uint8/int8 A/B data types) + + :note: Supported C data types: + - Float32 (for float32 and int32 accumulator data types) + - Int32 (for float32 and int32 accumulator data types) + - Float16/BFloat16 (for fp16 and fp8 accumulator data types) + - Int8/Uint8 (for uint8/int8 accumulator data types) + - Float8E4M3FN/Float8E5M2 (for float32 accumulator data types) + + :note: Constraints: + - MMA tiler M must be 64/128 (use_2cta_instrs=False) or 128/256 (use_2cta_instrs=True) + - MMA tiler N must be 32-256, step 32 + - Cluster shape M must be multiple of 2 if use_2cta_instrs=True + - Cluster shape M/N must be positive and power of 2, total cluster size <= 16 + + Example: + >>> gemm = DenseGemmKernel( + ... 
acc_dtype=cutlass.Float32, + ... use_2cta_instrs=True, + ... mma_tiler_mn=(128, 128), + ... cluster_shape_mn=(2, 2) + ... ) + >>> gemm(a_tensor, b_tensor, c_tensor, stream) + """ + + def __init__( + self, + acc_dtype: Type[cutlass.Numeric], + use_2cta_instrs: bool, + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + use_tma_store: bool, + ): + """Initializes the configuration for a Blackwell dense GEMM kernel. + + This configuration includes several key aspects: + + 1. MMA Instruction Settings (tcgen05): + - acc_dtype: Data types for MMA accumulator. + - mma_tiler_mn: The (M, N) shape of the MMA instruction tiler. + - use_2cta_instrs: Boolean indicating if the tcgen05 MMA variant + with cta_group=2 should be used. + + 2. Cluster Shape: + - cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster. + + 3. Output C tensor store mode: + - use_tma_store: Boolean indicating whether to use Tensor Memory Access (TMA) for storing results. + + :param acc_dtype: Data type of the accumulator. + :type acc_dtype: type[cutlass.Numeric] + :param mma_tiler_mn: Tuple (M, N) shape of the MMA instruction. + :type mma_tiler_mn: Tuple[int, int] + :param use_2cta_instrs: Boolean, True to use cta_group=2 MMA variant. + :type use_2cta_instrs: bool + :param cluster_shape_mn: Tuple (ClusterM, ClusterN) shape of the cluster. + :type cluster_shape_mn: Tuple[int, int] + :param use_tma_store: Use Tensor Memory Access (TMA) or normal store for output C tensor. + :type use_tma_store: bool + """ + + self.acc_dtype: Type[cutlass.Numeric] = acc_dtype + self.use_2cta_instrs = use_2cta_instrs + self.cluster_shape_mn = cluster_shape_mn + # K dimension is deferred in _setup_attributes + self.mma_tiler = (*mma_tiler_mn, 1) + self.use_tma_store = use_tma_store + + self.cta_group = ( + tcgen05.CtaGroup.TWO if self.use_2cta_instrs else tcgen05.CtaGroup.ONE + ) + + self.occupancy = 1 + self.threads_per_cta = 128 + self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_100") + + def _setup_attributes(self): + """Set up configurations that are dependent on GEMM inputs + + This method configures various attributes based on the input tensor properties + (data types, leading dimensions) and kernel settings: + - Configuring tiled MMA + - Computing MMA/cluster/tile shapes + - Computing cluster layout + - Computing multicast CTAs for A/B + - Computing epilogue subtile + - Setting up A/B/C stage counts in shared memory + - Computing A/B/C shared memory layout + - Computing tensor memory allocation columns + """ + # Configure tiled mma + tiled_mma = sm100_utils.make_trivial_tiled_mma( + self.a_dtype, + self.a_major_mode, + self.b_major_mode, + self.acc_dtype, + self.cta_group, + self.mma_tiler[:2], + ) + + # Compute mma/cluster/tile shapes + mma_inst_shape_k = cute.size(tiled_mma.shape_mnk, mode=[2]) + mma_inst_tile_k = 4 + self.mma_tiler = ( + self.mma_tiler[0], + self.mma_tiler[1], + mma_inst_shape_k * mma_inst_tile_k, + ) + self.cta_tile_shape_mnk = ( + self.mma_tiler[0] // cute.size(tiled_mma.thr_id.shape), + self.mma_tiler[1], + self.mma_tiler[2], + ) + + # Compute cluster layout + self.cluster_layout_vmnk = cute.tiled_divide( + cute.make_layout((*self.cluster_shape_mn, 1)), + (tiled_mma.thr_id.shape,), + ) + + # Compute number of multicast CTAs for A/B + self.num_mcast_ctas_a = cute.size(self.cluster_layout_vmnk.shape[2]) + self.num_mcast_ctas_b = cute.size(self.cluster_layout_vmnk.shape[1]) + self.is_a_mcast = self.num_mcast_ctas_a > 1 + self.is_b_mcast = self.num_mcast_ctas_b > 1 + + # Compute epilogue 
subtile + if cutlass.const_expr(self.use_tma_store): + self.epi_tile = sm100_utils.compute_epilogue_tile_shape( + self.cta_tile_shape_mnk, + self.use_2cta_instrs, + self.c_layout, + self.c_dtype, + ) + else: + self.epi_tile = self.cta_tile_shape_mnk[:2] + + # Setup A/B/C stage count in shared memory + self.num_acc_stage, self.num_ab_stage, self.num_c_stage = self._compute_stages( + tiled_mma, + self.mma_tiler, + self.a_dtype, + self.b_dtype, + self.epi_tile, + self.c_dtype, + self.c_layout, + self.smem_capacity, + self.occupancy, + self.use_tma_store, + ) + + # Compute A/B/C shared memory layout + self.a_smem_layout_staged = sm100_utils.make_smem_layout_a( + tiled_mma, + self.mma_tiler, + self.a_dtype, + self.num_ab_stage, + ) + self.b_smem_layout_staged = sm100_utils.make_smem_layout_b( + tiled_mma, + self.mma_tiler, + self.b_dtype, + self.num_ab_stage, + ) + self.c_smem_layout_staged = ( + sm100_utils.make_smem_layout_epi( + self.c_dtype, + self.c_layout, + self.epi_tile, + self.num_c_stage, + ) + if self.use_tma_store + else None + ) + + # Compute the number of tensor memory allocation columns + self.num_tmem_alloc_cols = self._compute_num_tmem_alloc_cols( + tiled_mma, self.mma_tiler + ) + + @cute.jit + def __call__( + self, + a: cute.Tensor, + b: cute.Tensor, + c: cute.Tensor, + stream: cuda.CUstream, + epilogue_op: cutlass.Constexpr = lambda x: x, + ): + """Execute the GEMM operation in steps: + - Setup static attributes + - Setup TMA load/store atoms and tensors + - Compute grid size + - Define shared storage for kernel + - Launch the kernel synchronously + + :param a: Input tensor A + :type a: cute.Tensor + :param b: Input tensor B + :type b: cute.Tensor + :param c: Output tensor C + :type c: cute.Tensor + :param stream: CUDA stream for asynchronous execution + :type stream: cuda.CUstream + :param epilogue_op: Optional elementwise lambda function to apply to the output tensor + :type epilogue_op: cutlass.Constexpr + :raises TypeError: If input data types are incompatible with the MMA instruction. + :raises AssertionError: If OOB (Out-Of-Bounds) tiles are present when TMA store is disabled. 
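+
+        Example (an illustrative sketch showing how to pass an elementwise ``epilogue_op``;
+        ``a_tensor``/``b_tensor``/``c_tensor`` are assumed to be CuTe tensors created via
+        ``from_dlpack`` and ``stream`` a ``cuda.CUstream``):
+
+        >>> # ReLU epilogue, using the lambda form shown in the module docstring
+        >>> relu = lambda x: cute.where(x > 0, x, cute.full_like(x, 0))
+        >>> gemm = DenseGemmKernel(
+        ...     acc_dtype=cutlass.Float32,
+        ...     use_2cta_instrs=False,
+        ...     mma_tiler_mn=(128, 128),
+        ...     cluster_shape_mn=(1, 1),
+        ...     use_tma_store=True,
+        ... )
+        >>> gemm(a_tensor, b_tensor, c_tensor, stream, epilogue_op=relu)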
+ """ + # Setup static attributes before smem/grid/tma computation + self.a_dtype: Type[cutlass.Numeric] = a.element_type + self.b_dtype: Type[cutlass.Numeric] = b.element_type + self.c_dtype: Type[cutlass.Numeric] = c.element_type + self.a_major_mode = utils.LayoutEnum.from_tensor(a).mma_major_mode() + self.b_major_mode = utils.LayoutEnum.from_tensor(b).mma_major_mode() + self.c_layout = utils.LayoutEnum.from_tensor(c) + + # Check if input data types are compatible with MMA instruction + if cutlass.const_expr(self.a_dtype != self.b_dtype): + raise TypeError(f"Type must match: {self.a_dtype} != {self.b_dtype}") + + # Setup attributes that dependent on gemm inputs + self._setup_attributes() + + tiled_mma = sm100_utils.make_trivial_tiled_mma( + self.a_dtype, + self.a_major_mode, + self.b_major_mode, + self.acc_dtype, + self.cta_group, + self.mma_tiler[:2], + ) + atom_thr_size = cute.size(tiled_mma.thr_id.shape) + + # Setup TMA load for A + a_op = sm100_utils.cluster_shape_to_tma_atom_A( + self.cluster_shape_mn, tiled_mma.thr_id + ) + a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0)) + tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A( + a_op, + a, + a_smem_layout, + self.mma_tiler, + tiled_mma, + self.cluster_layout_vmnk.shape, + internal_type=( + cutlass.TFloat32 if a.element_type is cutlass.Float32 else None + ), + ) + + # Setup TMA load for B + b_op = sm100_utils.cluster_shape_to_tma_atom_B( + self.cluster_shape_mn, tiled_mma.thr_id + ) + b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0)) + tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B( + b_op, + b, + b_smem_layout, + self.mma_tiler, + tiled_mma, + self.cluster_layout_vmnk.shape, + internal_type=( + cutlass.TFloat32 if b.element_type is cutlass.Float32 else None + ), + ) + + a_copy_size = cute.size_in_bytes(self.a_dtype, a_smem_layout) + b_copy_size = cute.size_in_bytes(self.b_dtype, b_smem_layout) + self.num_tma_load_bytes = (a_copy_size + b_copy_size) * atom_thr_size + + # Setup store for C + tma_atom_c = None + tma_tensor_c = None + if cutlass.const_expr(self.use_tma_store): + c_cta_v_layout = cute.composition( + cute.make_identity_layout(c.shape), self.epi_tile + ) + epi_smem_layout = cute.slice_(self.c_smem_layout_staged, (None, None, 0)) + tma_atom_c, tma_tensor_c = cpasync.make_tiled_tma_atom( + cpasync.CopyBulkTensorTileS2GOp(), + c, + epi_smem_layout, + c_cta_v_layout, + ) + + # Compute grid size + grid = self._compute_grid(c, self.cta_tile_shape_mnk, self.cluster_shape_mn) + + self.buffer_align_bytes = 1024 + + c_smem_size = ( + cute.cosize(self.c_smem_layout_staged.outer) if self.use_tma_store else 0 + ) + + # Define shared storage for kernel + @cute.struct + class SharedStorage: + ab_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage] + ab_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage] + acc_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage] + tmem_dealloc_mbar_ptr: cutlass.Int64 + tmem_holding_buf: cutlass.Int32 + # (EPI_TILE_M, EPI_TILE_N, STAGE) + sC: cute.struct.Align[ + cute.struct.MemRange[ + self.c_dtype, + c_smem_size, + ], + self.buffer_align_bytes, + ] + # (MMA, MMA_M, MMA_K, STAGE) + sA: cute.struct.Align[ + cute.struct.MemRange[ + self.a_dtype, cute.cosize(self.a_smem_layout_staged.outer) + ], + self.buffer_align_bytes, + ] + # (MMA, MMA_N, MMA_K, STAGE) + sB: cute.struct.Align[ + cute.struct.MemRange[ + self.b_dtype, cute.cosize(self.b_smem_layout_staged.outer) + ], + 
self.buffer_align_bytes, + ] + + self.shared_storage = SharedStorage + + # Launch the kernel synchronously + self.kernel( + tiled_mma, + tma_atom_a, + tma_tensor_a, + tma_atom_b, + tma_tensor_b, + tma_atom_c, + tma_tensor_c if self.use_tma_store else c, + self.cluster_layout_vmnk, + self.a_smem_layout_staged, + self.b_smem_layout_staged, + self.c_smem_layout_staged, + self.epi_tile, + epilogue_op, + ).launch( + grid=grid, + block=[self.threads_per_cta, 1, 1], + cluster=(*self.cluster_shape_mn, 1), + stream=stream, + ) + return + + # GPU device kernel + @cute.kernel + def kernel( + self, + tiled_mma: cute.TiledMma, + tma_atom_a: cute.CopyAtom, + mA_mkl: cute.Tensor, + tma_atom_b: cute.CopyAtom, + mB_nkl: cute.Tensor, + tma_atom_c: Optional[cute.CopyAtom], + mC_mnl: cute.Tensor, + cluster_layout_vmnk: cute.Layout, + a_smem_layout_staged: cute.ComposedLayout, + b_smem_layout_staged: cute.ComposedLayout, + c_smem_layout_staged: Union[cute.Layout, cute.ComposedLayout, None], + epi_tile: cute.Tile, + epilogue_op: cutlass.Constexpr, + ): + """ + GPU device kernel performing the batched GEMM computation. + """ + warp_idx = cute.arch.warp_idx() + warp_idx = cute.arch.make_warp_uniform(warp_idx) + + # + # Prefetch tma descriptor + # + if warp_idx == 0: + cpasync.prefetch_descriptor(tma_atom_a) + cpasync.prefetch_descriptor(tma_atom_b) + if cutlass.const_expr(self.use_tma_store): + cpasync.prefetch_descriptor(tma_atom_c) + + use_2cta_instrs = cute.size(tiled_mma.thr_id.shape) == 2 + + # + # Setup cta/thread coordinates + # + # Coords inside cluster + bidx, bidy, bidz = cute.arch.block_idx() + mma_tile_coord_v = bidx % cute.size(tiled_mma.thr_id.shape) + is_leader_cta = mma_tile_coord_v == 0 + cta_rank_in_cluster = cute.arch.make_warp_uniform( + cute.arch.block_idx_in_cluster() + ) + block_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord( + cta_rank_in_cluster + ) + # Coords outside cluster + cta_coord = (bidx, bidy, bidz) + mma_tile_coord_mnl = ( + cta_coord[0] // cute.size(tiled_mma.thr_id.shape), + cta_coord[1], + cta_coord[2], + ) + # Coords inside cta + tidx, _, _ = cute.arch.thread_idx() + + # + # Alloc and init: a+b full/empty, accumulator full, tensor memory dealloc barrier + # + smem = utils.SmemAllocator() + storage = smem.allocate(self.shared_storage) + + tmem_dealloc_mbar_ptr = storage.tmem_dealloc_mbar_ptr + tmem_holding_buf = storage.tmem_holding_buf + + # Initialize mainloop ab_pipeline (barrier) and states + ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) + num_tma_producer = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1 + ab_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, num_tma_producer + ) + ab_pipeline = pipeline.PipelineTmaUmma.create( + barrier_storage=storage.ab_full_mbar_ptr.data_ptr(), + num_stages=self.num_ab_stage, + producer_group=ab_pipeline_producer_group, + consumer_group=ab_pipeline_consumer_group, + tx_count=self.num_tma_load_bytes, + cta_layout_vmnk=cluster_layout_vmnk, + ) + ab_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.num_ab_stage + ) + ab_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_ab_stage + ) + + # Initialize acc_pipeline (barrier) and states + acc_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) + acc_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, self.threads_per_cta, self.threads_per_cta + ) + acc_pipeline = 
pipeline.PipelineUmmaAsync.create( + barrier_storage=storage.acc_full_mbar_ptr.data_ptr(), + num_stages=self.num_acc_stage, + producer_group=acc_pipeline_producer_group, + consumer_group=acc_pipeline_consumer_group, + cta_layout_vmnk=cluster_layout_vmnk, + ) + acc_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.num_acc_stage + ) + acc_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_acc_stage + ) + + # Tensor memory dealloc barrier init + if use_2cta_instrs: + if warp_idx == 0: + num_tmem_dealloc_threads = 32 + with cute.arch.elect_one(): + cute.arch.mbarrier_init( + tmem_dealloc_mbar_ptr, num_tmem_dealloc_threads + ) + cute.arch.mbarrier_init_fence() + + # Cluster arrive after barrier init + if cute.size(self.cluster_shape_mn) > 1: + cute.arch.cluster_arrive_relaxed() + + # + # Setup smem tensor A/B/C + # + # (EPI_TILE_M, EPI_TILE_N, STAGE) + sC = ( + storage.sC.get_tensor( + c_smem_layout_staged.outer, swizzle=c_smem_layout_staged.inner + ) + if self.use_tma_store + else None + ) + # (MMA, MMA_M, MMA_K, STAGE) + sA = storage.sA.get_tensor( + a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner + ) + # (MMA, MMA_N, MMA_K, STAGE) + sB = storage.sB.get_tensor( + b_smem_layout_staged.outer, swizzle=b_smem_layout_staged.inner + ) + + # + # Compute multicast mask for A/B buffer full + # + a_full_mcast_mask = None + b_full_mcast_mask = None + if cutlass.const_expr(self.is_a_mcast or self.is_b_mcast or use_2cta_instrs): + a_full_mcast_mask = cpasync.create_tma_multicast_mask( + cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2 + ) + b_full_mcast_mask = cpasync.create_tma_multicast_mask( + cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=1 + ) + + # + # Local_tile partition global tensors + # + # (bM, bK, RestM, RestK, RestL) + gA_mkl = cute.local_tile( + mA_mkl, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None) + ) + # (bN, bK, RestN, RestK, RestL) + gB_nkl = cute.local_tile( + mB_nkl, cute.slice_(self.mma_tiler, (0, None, None)), (None, None, None) + ) + # (bM, bN, RestM, RestN, RestL) + gC_mnl = cute.local_tile( + mC_mnl, cute.slice_(self.mma_tiler, (None, None, 0)), (None, None, None) + ) + k_tile_cnt = cute.size(gA_mkl, mode=[3]) + + # + # Partition global tensor for TiledMMA_A/B/C + # + thr_mma = tiled_mma.get_slice(mma_tile_coord_v) + # (MMA, MMA_M, MMA_K, RestM, RestK, RestL) + tCgA = thr_mma.partition_A(gA_mkl) + # (MMA, MMA_N, MMA_K, RestN, RestK, RestL) + tCgB = thr_mma.partition_B(gB_nkl) + # (MMA, MMA_M, MMA_N, RestM, RestN, RestL) + tCgC = thr_mma.partition_C(gC_mnl) + + # + # Partition global/shared tensor for TMA load A/B + # + # TMA load A partition_S/D + a_cta_layout = cute.make_layout( + cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape + ) + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), RestM, RestK, RestL) + tAsA, tAgA = cpasync.tma_partition( + tma_atom_a, + block_in_cluster_coord_vmnk[2], + a_cta_layout, + cute.group_modes(sA, 0, 3), + cute.group_modes(tCgA, 0, 3), + ) + # TMA load B partition_S/D + b_cta_layout = cute.make_layout( + cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape + ) + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), RestN, RestK, RestL) + tBsB, tBgB = cpasync.tma_partition( + tma_atom_b, + block_in_cluster_coord_vmnk[1], + b_cta_layout, + cute.group_modes(sB, 0, 3), + cute.group_modes(tCgB, 0, 3), + ) + + # + # Partition shared/tensor memory tensor for TiledMMA_A/B/C + # + # (MMA, MMA_M, MMA_K, 
STAGE) + tCrA = tiled_mma.make_fragment_A(sA) + # (MMA, MMA_N, MMA_K, STAGE) + tCrB = tiled_mma.make_fragment_B(sB) + # (MMA, MMA_M, MMA_N) + acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2]) + # (MMA, MMA_M, MMA_N) + tCtAcc_fake = tiled_mma.make_fragment_C(acc_shape) + + # + # Cluster wait before tensor memory alloc + # + if cute.size(self.cluster_shape_mn) > 1: + cute.arch.cluster_wait() + + # + # Alloc tensor memory buffer + # + if warp_idx == 0: + cute.arch.alloc_tmem( + self.num_tmem_alloc_cols, tmem_holding_buf, is_two_cta=use_2cta_instrs + ) + + # + # Bar sync for retrieve tensor memory ptr from shared memory + # + cute.arch.barrier() + + # + # Retrieving tensor memory ptr and make accumulator tensor + # + tmem_ptr = cute.arch.retrieve_tmem_ptr( + self.acc_dtype, alignment=16, ptr_to_buffer_holding_addr=tmem_holding_buf + ) + # (MMA, MMA_M, MMA_N) + tCtAcc = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout) + + # + # Partition for epilogue + # + tiled_copy_t2r, tTR_tAcc, tTR_rAcc = self.epilog_tmem_copy_and_partition( + tidx, tCtAcc, tCgC, epi_tile, use_2cta_instrs + ) + + tTR_rC = None + tiled_copy_r2s = None + simt_atom = None + tRS_rC = None + tRS_sC = None + bSG_sC = None + bSG_gC = None + tTR_gC = None + if cutlass.const_expr(self.use_tma_store): + tTR_rC = cute.make_fragment(tTR_rAcc.shape, self.c_dtype) + tiled_copy_r2s, tRS_rC, tRS_sC = self.epilog_smem_copy_and_partition( + tiled_copy_t2r, tTR_rC, tidx, sC + ) + tma_atom_c, bSG_sC, bSG_gC = self.epilog_gmem_copy_and_partition( + tidx, tma_atom_c, tCgC, epi_tile, sC + ) + else: + simt_atom, tTR_rC, tTR_gC = self.epilog_gmem_copy_and_partition( + tidx, tiled_copy_t2r, tCgC, epi_tile, sC + ) + + # + # Slice to per mma tile index + # + # ((atom_v, rest_v), RestK) + tAgA = tAgA[(None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2])] + # ((atom_v, rest_v), RestK) + tBgB = tBgB[(None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2])] + if cutlass.const_expr(self.use_tma_store): + # ((ATOM_V, REST_V), EPI_M, EPI_N) + bSG_gC = bSG_gC[(None, None, None, *mma_tile_coord_mnl)] + else: + # (T2R, T2R_M, T2R_N, EPI_M, EPI_N) + tTR_gC = tTR_gC[(None, None, None, None, None, *mma_tile_coord_mnl)] + + # /////////////////////////////////////////////////////////////////////////////// + # MAINLOOP + # /////////////////////////////////////////////////////////////////////////////// + prefetch_k_tile_cnt = cutlass.min(self.num_ab_stage - 2, k_tile_cnt) + if warp_idx == 0: + for k_tile in cutlass.range( + k_tile_cnt, + prefetch_stages=self.num_ab_stage - 2, + ): + # wait for AB buffer empty + ab_pipeline.producer_acquire(ab_producer_state) + + # TMA load A/B + cute.copy( + tma_atom_a, + tAgA[(None, ab_producer_state.count)], + tAsA[(None, ab_producer_state.index)], + tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state), + mcast_mask=a_full_mcast_mask, + ) + cute.copy( + tma_atom_b, + tBgB[(None, ab_producer_state.count)], + tBsB[(None, ab_producer_state.index)], + tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state), + mcast_mask=b_full_mcast_mask, + ) + + if is_leader_cta: + # Wait for AB buffer full + ab_pipeline.consumer_wait(ab_consumer_state) + + # tCtAcc += tCrA * tCrB + num_kblocks = cute.size(tCrA, mode=[2]) + for kblock_idx in cutlass.range(num_kblocks, unroll_full=True): + kblock_coord = (None, None, kblock_idx, ab_consumer_state.index) + + cute.gemm( + tiled_mma, + tCtAcc, + tCrA[kblock_coord], + tCrB[kblock_coord], + tCtAcc, + ) + # Enable accumulate on tCtAcc after first kblock + 
tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + + # Async arrive AB buffer empty + ab_pipeline.consumer_release(ab_consumer_state) + + ab_producer_state.advance() + ab_consumer_state.advance() + + # Async arrive accumulator buffer full + if is_leader_cta: + acc_pipeline.producer_commit(acc_producer_state) + + # + # Epilogue + # + + # Release tensor memory allocation lock + if warp_idx == 0: + cute.arch.relinquish_tmem_alloc_permit(is_two_cta=use_2cta_instrs) + + # Wait for accumulator buffer full + acc_pipeline.consumer_wait(acc_consumer_state) + + tTR_tAcc = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc)) + if cutlass.const_expr(self.use_tma_store): + bSG_gC = cute.group_modes(bSG_gC, 1, cute.rank(bSG_gC)) + else: + tTR_gC = cute.group_modes(tTR_gC, 3, cute.rank(tTR_gC)) + + c_pipeline = None + if cutlass.const_expr(self.use_tma_store): + # Initialize tma store c_pipeline + c_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, self.threads_per_cta, self.threads_per_cta + ) + c_pipeline = pipeline.PipelineTmaStore.create( + num_stages=self.num_c_stage, + producer_group=c_producer_group, + ) + + # + # Store accumulator to global memory in subtiles + # + subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3]) + for subtile_idx in range(subtile_cnt): + # + # Load accumulator from tensor memory buffer to register + # + tTR_tAcc_mn = tTR_tAcc[(None, None, None, subtile_idx)] + cute.copy(tiled_copy_t2r, tTR_tAcc_mn, tTR_rAcc) + + if cutlass.const_expr(self.use_tma_store): + # + # Perform epilogue op on accumulator and convert to C type + # + acc_vec = tiled_copy_r2s.retile(tTR_rAcc).load() + acc_vec = epilogue_op(acc_vec.to(self.c_dtype)) + tRS_rC.store(acc_vec) + + # + # Store C to shared memory + # + c_buffer = subtile_idx % self.num_c_stage + cute.copy(tiled_copy_r2s, tRS_rC, tRS_sC[(None, None, None, c_buffer)]) + # Fence and barrier to make sure shared memory store is visible to TMA store + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + cute.arch.barrier() + + # + # TMA store C to global memory + # + if warp_idx == 0: + cute.copy( + tma_atom_c, + bSG_sC[(None, c_buffer)], + bSG_gC[(None, subtile_idx)], + ) + # Fence and barrier to make sure TMA store is completed to recollect C buffer + c_pipeline.producer_commit() + c_pipeline.producer_acquire() + cute.arch.barrier() + else: + # + # Perform epilogue op on accumulator and convert to C type + # + acc_vec = tTR_rAcc.load() + acc_vec = epilogue_op(acc_vec.to(self.c_dtype)) + tTR_rC.store(acc_vec) + + # + # Store C to global memory + # + cute.copy(simt_atom, tTR_rC, tTR_gC[(None, None, None, subtile_idx)]) + + # + # Dealloc the tensor memory buffer + # + cute.arch.barrier() + if warp_idx == 0: + if use_2cta_instrs: + cute.arch.mbarrier_arrive( + tmem_dealloc_mbar_ptr, cta_rank_in_cluster ^ 1 + ) + cute.arch.mbarrier_wait(tmem_dealloc_mbar_ptr, 0) + cute.arch.dealloc_tmem( + tmem_ptr, self.num_tmem_alloc_cols, is_two_cta=use_2cta_instrs + ) + + # + # Wait for C store complete + # + if cutlass.const_expr(self.use_tma_store): + c_pipeline.producer_tail() + + # + # Wait A/B buffer empty + # + if warp_idx == 0: + # Reverse prefetch_k_tile_cnt times to next available buffer + for i in range(prefetch_k_tile_cnt): + ab_producer_state.reverse() + ab_pipeline.producer_tail(ab_producer_state) + return + + def epilog_tmem_copy_and_partition( + self, + tidx: cutlass.Int32, + tAcc: cute.Tensor, + gC_mnl: cute.Tensor, + epi_tile: cute.Tile, + use_2cta_instrs: Union[cutlass.Boolean, 
bool], + ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]: + """ + Make tiledCopy for tensor memory load, then use it to partition tensor memory (source) and register array (destination). + + :param tidx: The thread index in epilogue warp groups + :type tidx: cutlass.Int32 + :param tAcc: The accumulator tensor to be copied and partitioned + :type tAcc: cute.Tensor + :param gC_mnl: The global tensor C + :type gC_mnl: cute.Tensor + :param epi_tile: The epilogue tiler + :type epi_tile: cute.Tile + :param use_2cta_instrs: Whether use_2cta_instrs is enabled + :type use_2cta_instrs: bool + + :return: A tuple containing (tiled_copy_t2r, tTR_tAcc, tTR_rAcc) where: + - tiled_copy_t2r: The tiled copy operation for tmem to register copy(t2r) + - tTR_tAcc: The partitioned accumulator tensor + - tTR_rAcc: The accumulated tensor in register used to hold t2r results + :rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor] + """ + # Make tiledCopy for tensor memory load + copy_atom_t2r = sm100_utils.get_tmem_load_op( + self.cta_tile_shape_mnk, + self.c_layout, + self.c_dtype, + self.acc_dtype, + epi_tile, + use_2cta_instrs, + ) + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N) + tAcc_epi = cute.flat_divide( + tAcc[((None, None), 0, 0)], + epi_tile, + ) + # (EPI_TILE_M, EPI_TILE_N) + tiled_copy_t2r = tcgen05.make_tmem_copy( + copy_atom_t2r, tAcc_epi[(None, None, 0, 0)] + ) + + thr_copy_t2r = tiled_copy_t2r.get_slice(tidx) + # (T2R, T2R_M, T2R_N, EPI_M, EPI_M) + tTR_tAcc = thr_copy_t2r.partition_S(tAcc_epi) + + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, RestM, RestN, RestL) + gC_mnl_epi = cute.flat_divide( + gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile + ) + # (T2R, T2R_M, T2R_N, EPI_M, EPI_N, RestM, RestN, RestL) + tTR_gC = thr_copy_t2r.partition_D(gC_mnl_epi) + # (T2R, T2R_M, T2R_N) + tTR_rAcc = cute.make_fragment( + tTR_gC[(None, None, None, 0, 0, 0, 0, 0)].shape, self.acc_dtype + ) + return tiled_copy_t2r, tTR_tAcc, tTR_rAcc + + def epilog_smem_copy_and_partition( + self, + tiled_copy_t2r: cute.TiledCopy, + tTR_rC: cute.Tensor, + tidx: cutlass.Int32, + sC: cute.Tensor, + ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]: + """ + Make tiledCopy for shared memory store, then use it to partition register array (source) and shared memory (destination). 
+ + :param tiled_copy_t2r: The tiled copy operation for tmem to register copy(t2r) + :type tiled_copy_t2r: cute.TiledCopy + :param tTR_rC: The partitioned accumulator tensor + :type tTR_rC: cute.Tensor + :param tidx: The thread index in epilogue warp groups + :type tidx: cutlass.Int32 + :param sC: The shared memory tensor to be copied and partitioned + :type sC: cute.Tensor + + :return: A tuple containing (tiled_copy_r2s, tRS_rC, tRS_sC) where: + - tiled_copy_r2s: The tiled copy operation for register to smem copy(r2s) + - tRS_rC: The partitioned tensor C (register source) + - tRS_sC: The partitioned tensor C (smem destination) + :rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor] + """ + copy_atom_r2s = sm100_utils.get_smem_store_op( + self.c_layout, self.c_dtype, self.acc_dtype, tiled_copy_t2r + ) + tiled_copy_r2s = cute.make_tiled_copy_D(copy_atom_r2s, tiled_copy_t2r) + # (R2S, R2S_M, R2S_N, PIPE_D) + thr_copy_r2s = tiled_copy_r2s.get_slice(tidx) + tRS_sC = thr_copy_r2s.partition_D(sC) + # (R2S, R2S_M, R2S_N) + tRS_rC = tiled_copy_r2s.retile(tTR_rC) + return tiled_copy_r2s, tRS_rC, tRS_sC + + def epilog_gmem_copy_and_partition( + self, + tidx: cutlass.Int32, + atom: Union[cute.CopyAtom, cute.TiledCopy], + gC_mnl: cute.Tensor, + epi_tile: cute.Tile, + sC: cute.Tensor, + ) -> Tuple[cute.CopyAtom, cute.Tensor, cute.Tensor]: + """Make tiledCopy for global memory store, then use it to: + - partition register array (source) and global memory (destination) for none TMA store version; + - partition shared memory (source) and global memory (destination) for TMA store version. + + :param tidx: The thread index in epilogue warp groups + :type tidx: cutlass.Int32 + :param atom: The copy_atom_c to be used for TMA store version, or tiled_copy_t2r for none TMA store version + :type atom: cute.CopyAtom or cute.TiledCopy + :param gC_mnl: The global tensor C + :type gC_mnl: cute.Tensor + :param epi_tile: The epilogue tiler + :type epi_tile: cute.Tile + :param sC: The shared memory tensor to be copied and partitioned + :type sC: cute.Tensor + + :return: A tuple containing either: + - For TMA store: (tma_atom_c, bSG_sC, bSG_gC) where: + - tma_atom_c: The TMA copy atom + - bSG_sC: The partitioned shared memory tensor C + - bSG_gC: The partitioned global tensor C + - For non-TMA store: (simt_atom, tTR_rC, tTR_gC) where: + - simt_atom: The SIMT copy atom + - tTR_rC: The register tensor C + - tTR_gC: The partitioned global tensor C + :rtype: Tuple[cute.CopyAtom, cute.Tensor, cute.Tensor] + """ + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, RestM, RestN, RestL) + gC_epi = cute.flat_divide( + gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile + ) + if cutlass.const_expr(self.use_tma_store): + tma_atom_c = atom + sC_for_tma_partition = cute.group_modes(sC, 0, 2) + gC_for_tma_partition = cute.group_modes(gC_epi, 0, 2) + # ((ATOM_V, REST_V), EPI_M, EPI_N) + # ((ATOM_V, REST_V), EPI_M, EPI_N, RestM, RestN, RestL) + bSG_sC, bSG_gC = cpasync.tma_partition( + tma_atom_c, + 0, + cute.make_layout(1), + sC_for_tma_partition, + gC_for_tma_partition, + ) + return tma_atom_c, bSG_sC, bSG_gC + else: + tiled_copy_t2r = atom + # (T2R, T2R_M, T2R_N, EPI_M, EPI_N, RestM, RestN, RestL) + thr_copy_t2r = tiled_copy_t2r.get_slice(tidx) + tTR_gC = thr_copy_t2r.partition_D(gC_epi) + # (T2R, T2R_M, T2R_N) + tTR_rC = cute.make_fragment( + tTR_gC[(None, None, None, 0, 0, 0, 0, 0)].shape, self.c_dtype + ) + simt_atom = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), self.c_dtype) + return simt_atom, tTR_rC, tTR_gC + + @staticmethod 
+ def _compute_stages( + tiled_mma: cute.TiledMma, + mma_tiler_mnk: Tuple[int, int, int], + a_dtype: Type[cutlass.Numeric], + b_dtype: Type[cutlass.Numeric], + epi_tile: cute.Tile, + c_dtype: Type[cutlass.Numeric], + c_layout: utils.LayoutEnum, + smem_capacity: int, + occupancy: int, + use_tma_store: bool, + ) -> Tuple[int, int, int]: + """Computes the number of stages for A/B/C operands based on heuristics. + + :param tiled_mma: The tiled MMA object defining the core computation. + :type tiled_mma: cute.TiledMma + :param mma_tiler_mnk: The shape (M, N, K) of the MMA tile. + :type mma_tiler_mnk: tuple[int, int, int] + :param a_dtype: Data type of operand A. + :type a_dtype: type[cutlass.Numeric] + :param b_dtype: Data type of operand B. + :type b_dtype: type[cutlass.Numeric] + :param epi_tile: The epilogue tile shape. + :type epi_tile: cute.Tile + :param c_dtype: Data type of operand C (output). + :type c_dtype: type[cutlass.Numeric] + :param c_layout: Layout enum of operand C in global memory. + :type c_layout: utils.LayoutEnum + :param smem_capacity: Total available shared memory capacity in bytes. + :type smem_capacity: int + :param occupancy: Target number of CTAs per SM (occupancy). + :type occupancy: int + :param use_tma_store: Whether TMA store is enabled. + :type use_tma_store: bool + + :return: A tuple containing the computed number of stages for: + (ACC stages, A/B operand stages, epilogue stages) + :rtype: tuple[int, int, int] + """ + # Default ACC stages + num_acc_stage = 1 + # Default C stages + num_c_stage = 2 if use_tma_store else 0 + + # Calculate smem layout and size for one stage of A, B, and C + a_smem_layout_stage_one = sm100_utils.make_smem_layout_a( + tiled_mma, + mma_tiler_mnk, + a_dtype, + 1, # a tmp 1 stage is provided + ) + b_smem_layout_staged_one = sm100_utils.make_smem_layout_b( + tiled_mma, + mma_tiler_mnk, + b_dtype, + 1, # a tmp 1 stage is provided + ) + c_smem_layout_staged_one = ( + sm100_utils.make_smem_layout_epi( + c_dtype, + c_layout, + epi_tile, + 1, + ) + if use_tma_store + else None + ) + ab_bytes_per_stage = cute.size_in_bytes( + a_dtype, a_smem_layout_stage_one + ) + cute.size_in_bytes(b_dtype, b_smem_layout_staged_one) + mbar_helpers_bytes = 1024 + c_bytes_per_stage = ( + cute.size_in_bytes(c_dtype, c_smem_layout_staged_one) + if use_tma_store + else 0 + ) + c_bytes = c_bytes_per_stage * num_c_stage + + # Calculate A/B stages: + # Start with total smem per CTA (capacity / occupancy) + # Subtract reserved bytes and initial C stages bytes + # Divide remaining by bytes needed per A/B stage + num_ab_stage = ( + smem_capacity - (occupancy + 1) * (mbar_helpers_bytes + c_bytes) + ) // ab_bytes_per_stage + + # Refine epilogue stages: + # Calculate remaining smem after allocating for A/B stages and reserved bytes + # Add remaining unused smem to epilogue + if use_tma_store: + num_c_stage += ( + smem_capacity + - ab_bytes_per_stage * num_ab_stage + - (occupancy + 1) * (mbar_helpers_bytes + c_bytes) + ) // ((occupancy + 1) * c_bytes_per_stage) + return num_acc_stage, num_ab_stage, num_c_stage + + @staticmethod + def _compute_grid( + c: cute.Tensor, + cta_tile_shape_mnk: Tuple[int, int, int], + cluster_shape_mn: Tuple[int, int], + ) -> Tuple[int, int, int]: + """Compute grid shape for the output tensor C. + + :param c: The output tensor C + :type c: cute.Tensor + :param cta_tile_shape_mnk: The shape (M, N, K) of the CTA tile. + :type cta_tile_shape_mnk: tuple[int, int, int] + :param cluster_shape_mn: Shape of each cluster in M, N dimensions. 
+ :type cluster_shape_mn: tuple[int, int] + + :return: Grid shape for kernel launch. + :rtype: tuple[int, int, int] + """ + + cluster_shape_mnl = (*cluster_shape_mn, 1) + + grid = cute.round_up( + ( + cute.ceil_div(c.layout.shape[0], cta_tile_shape_mnk[0]), + cute.ceil_div(c.layout.shape[1], cta_tile_shape_mnk[1]), + c.layout.shape[2], + ), + cluster_shape_mnl, + ) + + return grid + + @staticmethod + def _compute_num_tmem_alloc_cols( + tiled_mma: cute.TiledMma, mma_tiler: Tuple[int, int, int] + ) -> int: + """ + Compute the number of tensor memory allocation columns. + + :param tiled_mma: The tiled MMA object defining the core computation. + :type tiled_mma: cute.TiledMma + :param mma_tiler: The shape (M, N, K) of the MMA tile. + :type mma_tiler: tuple[int, int, int] + + :return: The number of tensor memory allocation columns. + :rtype: int + """ + acc_shape = tiled_mma.partition_shape_C(mma_tiler[:2]) + tCtAcc_fake = tiled_mma.make_fragment_C(acc_shape) + return sm100_utils.get_num_tmem_alloc_cols(tCtAcc_fake) + + @staticmethod + def is_valid_dtypes( + ab_dtype: Type[cutlass.Numeric], + acc_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + ) -> bool: + """ + Check if the dtypes are valid + + :param ab_dtype: The data type of the A and B operands + :type ab_dtype: Type[cutlass.Numeric] + :param acc_dtype: The data type of the accumulator + :type acc_dtype: Type[cutlass.Numeric] + :param c_dtype: The data type of the output tensor + :type c_dtype: Type[cutlass.Numeric] + + :return: True if the dtypes are valid, False otherwise + :rtype: bool + """ + is_valid = True + if ab_dtype not in { + cutlass.Float16, + cutlass.BFloat16, + cutlass.TFloat32, + cutlass.Uint8, + cutlass.Int8, + cutlass.Float8E4M3FN, + cutlass.Float8E5M2, + }: + is_valid = False + if ( + acc_dtype not in {cutlass.Float32, cutlass.Float16, cutlass.Int32} + or acc_dtype == cutlass.Float16 + and ab_dtype + not in {cutlass.Float16, cutlass.Float8E4M3FN, cutlass.Float8E5M2} + or acc_dtype == cutlass.Int32 + and ab_dtype not in {cutlass.Uint8, cutlass.Int8} + ): + is_valid = False + if ( + acc_dtype == cutlass.Float32 + and c_dtype + not in { + cutlass.Float32, + cutlass.Float16, + cutlass.BFloat16, + cutlass.Float8E4M3FN, + cutlass.Float8E5M2, + cutlass.Int32, + cutlass.Int8, + cutlass.Uint8, + } + or acc_dtype == cutlass.Float16 + and c_dtype + not in { + cutlass.BFloat16, + cutlass.Float16, + } + or acc_dtype == cutlass.Int32 + and c_dtype + not in { + cutlass.BFloat16, + cutlass.Float16, + cutlass.Float32, + cutlass.Int32, + cutlass.Int8, + cutlass.Uint8, + } + ): + is_valid = False + return is_valid + + @staticmethod + def is_valid_mma_tiler_and_cluster_shape( + use_2cta_instrs: bool, + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + ) -> bool: + """ + Check if the mma tiler and cluster shape are valid + + :param use_2cta_instrs: Whether to use 2 CTA groups + :type use_2cta_instrs: bool + :param mma_tiler_mn: The (M, N) shape of the MMA instruction tiler + :type mma_tiler_mn: Tuple[int, int] + :param cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster + :type cluster_shape_mn: Tuple[int, int] + + :return: True if the mma tiler and cluster shape are valid, False otherwise + :rtype: bool + """ + is_valid = True + # Skip invalid mma tile shape + if not ( + (not use_2cta_instrs and mma_tiler_mn[0] in [64, 128]) + or (use_2cta_instrs and mma_tiler_mn[0] in [128, 256]) + ): + is_valid = False + if mma_tiler_mn[1] not in range(32, 257, 32): + is_valid = False + # Skip illegal 
cluster shape + if cluster_shape_mn[0] % (2 if use_2cta_instrs else 1) != 0: + is_valid = False + # Skip invalid cluster shape + is_power_of_2 = lambda x: x > 0 and (x & (x - 1)) == 0 + if ( + cluster_shape_mn[0] * cluster_shape_mn[1] > 16 + or cluster_shape_mn[0] <= 0 + or cluster_shape_mn[1] <= 0 + or not is_power_of_2(cluster_shape_mn[0]) + or not is_power_of_2(cluster_shape_mn[1]) + ): + is_valid = False + return is_valid + + @staticmethod + def is_valid_tensor_alignment( + m: int, + n: int, + k: int, + l: int, + ab_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + a_major: str, + b_major: str, + c_major: str, + ) -> bool: + """ + Check if the tensor alignment is valid + + :param m: The number of rows in the A tensor + :type m: int + :param n: The number of columns in the B tensor + :type n: int + :param k: The number of columns in the A tensor + :type k: int + :param l: The number of columns in the C tensor + :type l: int + :param ab_dtype: The data type of the A and B operands + :type ab_dtype: Type[cutlass.Numeric] + :param c_dtype: The data type of the output tensor + :type c_dtype: Type[cutlass.Numeric] + :param a_major: The major axis of the A tensor + :type a_major: str + :param b_major: The major axis of the B tensor + :type b_major: str + :param c_major: The major axis of the C tensor + :type c_major: str + + :return: True if the problem shape is valid, False otherwise + :rtype: bool + """ + is_valid = True + + def check_contigous_16B_alignment(dtype, is_mode0_major, tensor_shape): + major_mode_idx = 0 if is_mode0_major else 1 + num_major_elements = tensor_shape[major_mode_idx] + num_contiguous_elements = 16 * 8 // dtype.width + return num_major_elements % num_contiguous_elements == 0 + + if ( + not check_contigous_16B_alignment(ab_dtype, a_major == "m", (m, k, l)) + or not check_contigous_16B_alignment(ab_dtype, b_major == "n", (n, k, l)) + or not check_contigous_16B_alignment(c_dtype, c_major == "m", (m, n, l)) + ): + is_valid = False + return is_valid + + @staticmethod + def is_valid_epilog_store_option( + use_2cta_instrs: bool, + use_tma_store: bool, + m: int, + n: int, + mma_tiler_mn: Tuple[int, int], + ) -> bool: + """ + Check if the epilogue store option is valid + + :param use_2cta_instrs: Whether to use 2 CTA groups + :type use_2cta_instrs: bool + :param use_tma_store: Whether to use TMA store + :type use_tma_store: bool + :param m: The number of rows in the A tensor + :type m: int + :param n: The number of columns in the B tensor + :type n: int + :param mma_tiler_mn: The (M, N) shape of the MMA instruction tiler + :type mma_tiler_mn: Tuple[int, int] + + :return: True if the epilogue store option is valid, False otherwise + :rtype: bool + """ + + is_valid = True + # None TMA store version does not have predication, can not support OOB tiles + cta_tile_shape_mn = ( + mma_tiler_mn[0] // (2 if use_2cta_instrs else 1), + mma_tiler_mn[1], + ) + if not use_tma_store: + if not (m % cta_tile_shape_mn[0] == 0 and n % cta_tile_shape_mn[1] == 0): + is_valid = False + return is_valid + + @staticmethod + def can_implement( + ab_dtype: Type[cutlass.Numeric], + acc_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + use_2cta_instrs: bool, + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + use_tma_store: bool, + m: int, + n: int, + k: int, + l: int, + a_major: str, + b_major: str, + c_major: str, + ) -> bool: + """ + Check if the gemm can be implemented + + :param ab_dtype: The data type of the A and B operands + :type ab_dtype: 
Type[cutlass.Numeric] + :param acc_dtype: The data type of the accumulator + :type acc_dtype: Type[cutlass.Numeric] + :param c_dtype: The data type of the output tensor + :type c_dtype: Type[cutlass.Numeric] + :param use_2cta_instrs: Whether to use 2 CTA groups + :type use_2cta_instrs: bool + :param mma_tiler_mn: The (M, N) shape of the MMA instruction tiler + :type mma_tiler_mn: Tuple[int, int] + :param cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster + :type cluster_shape_mn: Tuple[int, int] + :param use_tma_store: Whether to use TMA store + :type use_tma_store: bool + :param m: The number of rows in the A tensor + :type m: int + :param n: The number of columns in the B tensor + :type n: int + :param k: The number of columns in the A tensor + :type k: int + :param l: The number of columns in the C tensor + :type l: int + :param a_major: The major axis of the A tensor + :type a_major: str + :param b_major: The major axis of the B tensor + :type b_major: str + :param c_major: The major axis of the C tensor + :type c_major: str + + :return: True if the gemm can be implemented, False otherwise + :rtype: bool + """ + can_implement = True + # Skip unsupported types + if not DenseGemmKernel.is_valid_dtypes(ab_dtype, acc_dtype, c_dtype): + can_implement = False + # Skip invalid mma tile shape and cluster shape + if not DenseGemmKernel.is_valid_mma_tiler_and_cluster_shape( + use_2cta_instrs, mma_tiler_mn, cluster_shape_mn + ): + can_implement = False + # Skip illegal problem shape for load/store alignment + if not DenseGemmKernel.is_valid_tensor_alignment( + m, n, k, l, ab_dtype, c_dtype, a_major, b_major, c_major + ): + can_implement = False + # Skip invalid epilogue store option + if not DenseGemmKernel.is_valid_epilog_store_option( + use_2cta_instrs, use_tma_store, m, n, mma_tiler_mn + ): + can_implement = False + return can_implement + + +def run_dense_gemm( + mnkl: Tuple[int, int, int, int], + ab_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + acc_dtype: Type[cutlass.Numeric], + a_major: str, + b_major: str, + c_major: str, + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + use_2cta_instrs: bool, + use_tma_store: bool, + tolerance: float, + warmup_iterations: int = 0, + iterations: int = 1, + skip_ref_check: bool = False, +): + """ + Prepare A/B/C tensors, launch GPU kernel, and reference checking. 
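+
+    Example (an illustrative sketch mirroring the command-line defaults parsed below;
+    positional argument order follows the signature above):
+
+    .. code-block:: python
+
+        run_dense_gemm(
+            (256, 256, 512, 1),   # mnkl
+            cutlass.TFloat32,     # ab_dtype
+            cutlass.Float32,      # c_dtype
+            cutlass.Float32,      # acc_dtype
+            "k",                  # a_major
+            "k",                  # b_major
+            "n",                  # c_major
+            (128, 128),           # mma_tiler_mn
+            (1, 1),               # cluster_shape_mn
+            False,                # use_2cta_instrs
+            False,                # use_tma_store
+            1e-01,                # tolerance
+        )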
+ """ + print(f"Running B100 software pipeline Dense GEMM test with:") + print(f"mnkl: {mnkl}") + print(f"AB dtype: {ab_dtype}, C dtype: {c_dtype}, Acc dtype: {acc_dtype}") + print(f"Matrix majors - A: {a_major}, B: {b_major}, C: {c_major}") + print(f"Mma Tiler (M, N): {mma_tiler_mn}, Cluster Shape (M, N): {cluster_shape_mn}") + print(f"2CTA MMA instructions: {'True' if use_2cta_instrs else 'False'}") + print(f"Use TMA Store: {'True' if use_tma_store else 'False'}") + print(f"Tolerance: {tolerance}") + print(f"Warmup iterations: {warmup_iterations}") + print(f"Iterations: {iterations}") + print(f"Skip reference checking: {skip_ref_check}") + + # Unpack parameters + m, n, k, l = mnkl + + # Skip unsupported testcase + if not DenseGemmKernel.can_implement( + ab_dtype, + acc_dtype, + c_dtype, + use_2cta_instrs, + mma_tiler_mn, + cluster_shape_mn, + use_tma_store, + m, + n, + k, + l, + a_major, + b_major, + c_major, + ): + raise TypeError( + f"Unsupported testcase {ab_dtype}, {acc_dtype}, {c_dtype}, {use_2cta_instrs}, {mma_tiler_mn}, {cluster_shape_mn}, {use_tma_store}, {m}, {n}, {k}, {l}, {a_major}, {b_major}, {c_major}" + ) + + if not torch.cuda.is_available(): + raise RuntimeError("GPU is required to run this example!") + + torch.manual_seed(1111) + + # Create and permute tensor A/B/C + def create_and_permute_tensor( + l, mode0, mode1, is_mode0_major, dtype, is_dynamic_layout=True + ): + # is_mode0_major: (l, mode1, mode0) -> (mode0, mode1, l) + # else: (l, mode0, mode1) -> (mode0, mode1, l) + shape = (l, mode1, mode0) if is_mode0_major else (l, mode0, mode1) + permute_order = (2, 1, 0) if is_mode0_major else (1, 2, 0) + is_unsigned = dtype in {cutlass.Uint8} + # Temporarily use uint8 as torch does not support fp8 type + torch_dtype = ( + cutlass_torch.dtype(dtype) + if dtype not in {cutlass.Float8E5M2, cutlass.Float8E4M3FN} + else torch.uint8 + ) + + # Create dtype torch tensor (cpu) + torch_tensor_cpu = cutlass_torch.create_and_permute_torch_tensor( + shape, + torch_dtype, + permute_order=permute_order, + init_type=cutlass_torch.TensorInitType.RANDOM, + init_config=cutlass_torch.RandomInitConfig( + min_val=0 if is_unsigned else -2, max_val=4 if is_unsigned else 2 + ), + ) + # Create dtype torch tensor (gpu) + torch_tensor = torch_tensor_cpu.cuda() + + # Create f32 torch tensor (cpu) + f32_torch_tensor = torch_tensor_cpu.to(dtype=torch.float32) + + # Create dtype cute tensor (gpu) + cute_tensor = from_dlpack(torch_tensor, assumed_align=16) + cute_tensor.element_type = dtype + if is_dynamic_layout: + cute_tensor = cute_tensor.mark_layout_dynamic( + leading_dim=(0 if is_mode0_major else 1) + ) + cute_tensor = cutlass_torch.convert_cute_tensor( + f32_torch_tensor, + cute_tensor, + dtype, + is_dynamic_layout=is_dynamic_layout, + ) + + return f32_torch_tensor, cute_tensor, torch_tensor + + a_ref, a_tensor, a_torch = create_and_permute_tensor( + l, m, k, a_major == "m", ab_dtype, is_dynamic_layout=True + ) + b_ref, b_tensor, b_torch = create_and_permute_tensor( + l, n, k, b_major == "n", ab_dtype, is_dynamic_layout=True + ) + c_ref, c_tensor, c_torch = create_and_permute_tensor( + l, m, n, c_major == "m", c_dtype, is_dynamic_layout=True + ) + + # Configure gemm kernel + gemm = DenseGemmKernel( + acc_dtype, + use_2cta_instrs, + mma_tiler_mn, + cluster_shape_mn, + use_tma_store, + ) + + torch_stream = torch.cuda.Stream() + stream = cuda.CUstream(torch_stream.cuda_stream) + # Compile gemm kernel + compiled_gemm = cute.compile(gemm, a_tensor, b_tensor, c_tensor, stream) + + # Launch GPU kernel + # 
Warm up + for i in range(warmup_iterations): + compiled_gemm(a_tensor, b_tensor, c_tensor, stream) + # Execution + for i in range(iterations): + compiled_gemm(a_tensor, b_tensor, c_tensor, stream) + + # Compute reference result + if not skip_ref_check: + if ab_dtype in { + cutlass.Int8, + cutlass.Uint8, + cutlass.Float8E4M3FN, + cutlass.Float8E5M2, + }: + ref = torch.einsum("mkl,nkl->mnl", a_ref.cpu(), b_ref.cpu()) + else: + ref = (torch.einsum("mkl,nkl->mnl", a_ref, b_ref)).cpu() + + # Copy gpu result back + gpu_c = c_torch.cpu() + + # Convert ref to c_type + if c_dtype == cutlass.Float32: + ref_c = ref + elif c_dtype in {cutlass.Float8E5M2, cutlass.Float8E4M3FN}: + # m major: (l, n, m) -> (m, n, l) + # n major: (l, m, n) -> (m, n, l) + permute_order = (1, 2, 0) if c_major == "n" else (2, 1, 0) + shape = (l, m, n) if c_major == "n" else (l, n, m) + f8_torch_tensor = cutlass_torch.create_and_permute_torch_tensor( + shape, + torch.uint8, + permute_order=permute_order, + init_type=cutlass_torch.TensorInitType.SKIP, + ).cuda() + # Create dtype cute tensor (gpu) + ref_c_tensor = from_dlpack( + f8_torch_tensor, assumed_align=16 + ).mark_layout_dynamic(leading_dim=(1 if c_major == "n" else 0)) + ref_c_tensor.element_type = c_dtype + ref_c_tensor = cutlass_torch.convert_cute_tensor( + ref, + ref_c_tensor, + c_dtype, + is_dynamic_layout=True, + ) + + ref_c = f8_torch_tensor.cpu() + else: + ref_c = ref.to(cutlass_torch.dtype(c_dtype)) + + # Reference checking ref_c and gpu_c + torch.testing.assert_close( + gpu_c, + ref_c, + atol=tolerance, + rtol=1e-05, + ) + + +if __name__ == "__main__": + + def parse_comma_separated_ints(s: str) -> Tuple[int, ...]: + try: + return tuple(int(x.strip()) for x in s.split(",")) + # or: return tuple([int(x.strip()) for x in s.split(",")]) + except ValueError: + raise argparse.ArgumentTypeError( + "Invalid format. Expected comma-separated integers." + ) + + parser = argparse.ArgumentParser( + description="Example of MxNxKxL GEMM on Blackwell." 
+ ) + + parser.add_argument( + "--mnkl", + type=parse_comma_separated_ints, + default=(256, 256, 512, 1), + help="mnkl dimensions (comma-separated)", + ) + parser.add_argument( + "--mma_tiler_mn", + type=parse_comma_separated_ints, + default=(128, 128), + help="Mma tiler (comma-separated)", + ) + parser.add_argument( + "--cluster_shape_mn", + type=parse_comma_separated_ints, + default=(1, 1), + help="Cluster shape (comma-separated)", + ) + parser.add_argument("--ab_dtype", type=cutlass.dtype, default=cutlass.TFloat32) + parser.add_argument("--c_dtype", type=cutlass.dtype, default=cutlass.Float32) + parser.add_argument("--acc_dtype", type=cutlass.dtype, default=cutlass.Float32) + parser.add_argument( + "--use_2cta_instrs", + action="store_true", + help="Enable 2CTA MMA instructions feature", + ) + parser.add_argument("--a_major", choices=["k", "m"], type=str, default="k") + parser.add_argument("--b_major", choices=["k", "n"], type=str, default="k") + parser.add_argument("--c_major", choices=["n", "m"], type=str, default="n") + parser.add_argument( + "--use_tma_store", action="store_true", help="Use tma store or not" + ) + parser.add_argument( + "--tolerance", type=float, default=1e-01, help="Tolerance for validation" + ) + parser.add_argument( + "--warmup_iterations", type=int, default=0, help="Warmup iterations" + ) + parser.add_argument("--iterations", type=int, default=1, help="Iterations") + parser.add_argument( + "--skip_ref_check", action="store_true", help="Skip reference checking" + ) + + args = parser.parse_args() + + if len(args.mnkl) != 4: + parser.error("--mnkl must contain exactly 4 values") + + if len(args.mma_tiler_mn) != 2: + parser.error("--mma_tiler_mn must contain exactly 2 values") + + if len(args.cluster_shape_mn) != 2: + parser.error("--cluster_shape_mn must contain exactly 2 values") + + run_dense_gemm( + args.mnkl, + args.ab_dtype, + args.c_dtype, + args.acc_dtype, + args.a_major, + args.b_major, + args.c_major, + args.mma_tiler_mn, + args.cluster_shape_mn, + args.use_2cta_instrs, + args.use_tma_store, + args.tolerance, + args.warmup_iterations, + args.iterations, + args.skip_ref_check, + ) + print("PASS") diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/blackwell/fmha.py b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/blackwell/fmha.py new file mode 100644 index 0000000000000000000000000000000000000000..560129dfdcd4d065710b4d0bb08e638df3d51acc --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/blackwell/fmha.py @@ -0,0 +1,3100 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. 
+ +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import argparse +import enum +import math +import time +from typing import Type, Tuple + +import torch +import torch.nn.functional as F +import cuda.bindings.driver as cuda + +import cutlass +import cutlass.cute as cute +import cutlass.cute.nvgpu.tcgen05 as tcgen05 +import cutlass.utils as utils +import cutlass.pipeline as pipeline +import cutlass.torch as cutlass_torch +import cutlass.utils.blackwell_helpers as sm100_utils +import cutlass.cute.testing as testing +from cutlass.cute.runtime import from_dlpack +from cutlass.cute.typing import Int32, Int64, Float32, Boolean + +""" +A fused multi-head attention (FMHA) example for the NVIDIA Blackwell SM100 architecture using CUTE DSL + +This example demonstrates an implementation of fused multi-head attention using a TMA + Blackwell SM100 +TensorCore warp-specialized persistent kernel. The implementation integrates the Q*K^T matrix multiplication, +softmax normalization, and softmax(Q*K^T)*V into a single kernel, avoiding intermediate data movement between +global memory and shared memory, thus improving computational efficiency. + +The kernel implements key optimizations including: +- Warp specialization for different computation phases (load, MMA, softmax, correction, epilogue) +- Pipeline stages between different warps for overlapping computation and memory access +- Support for different precision data types +- Optional causal masking for autoregressive models + +To run this example: + +.. code-block:: bash + + python examples/blackwell/fmha.py \ + --qk_acc_dtype Float32 --pv_acc_dtype Float32 \ + --mma_tiler_mn 128,128 \ + --q_shape 4,1024,8,64 --k_shape 4,1024,8,64 \ + --is_persistent + +The above example runs FMHA with batch size 4, sequence length 1024, 8 attention heads, and head +dimension 64. The Blackwell tcgen05 MMA tile shape is (128, 128), and the kernel uses fp16 for input/output +with fp32 for accumulation. + +To collect performance with NCU profiler: + +.. 
code-block:: bash + + ncu python examples/blackwell/fmha.py \ + --qk_acc_dtype Float32 --pv_acc_dtype Float32 \ + --mma_tiler_mn 128,128 \ + --q_shape 4,1024,8,64 --k_shape 4,1024,8,64 \ + --is_persistent --warmup_iterations 10 \ + --iterations 10 --skip_ref_check + +Constraints for this example: +* Supported head dimensions: 32, 64, and 128 +* Number of heads in Q must be divisible by number of heads in K +* mma_tiler_mn must be 128,128 +* Batch size must be the same for Q, K, and V tensors +* For causal masking, use --is_causal (note: specify without =True/False) +* For persistent scheduling, use --is_persistent (note: specify without =True/False) +""" + +class FmhaStaticTileSchedulerParams: + def __init__( + self, + is_persistent: bool, + problem_shape_mbh: cute.Shape, + *, + loc=None, + ip=None, + ): + self.is_persistent = is_persistent + self.problem_shape_mbh = problem_shape_mbh + self._loc = loc + self._ip = ip + + def __extract_mlir_values__(self): + values, self._values_pos = [], [] + for obj in [self.is_persistent, self.problem_shape_mbh]: + obj_values = cutlass.extract_mlir_values(obj) + values += obj_values + self._values_pos.append(len(obj_values)) + return values + + def __new_from_mlir_values__(self, values): + obj_list = [] + for obj, n_items in zip( + [self.is_persistent, self.problem_shape_mbh], self._values_pos + ): + obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items])) + values = values[n_items:] + return FmhaStaticTileSchedulerParams(*(tuple(obj_list)), loc=self._loc) + + +def create_fmha_static_tile_scheduler_params( + is_persistent: bool, + problem_shape_mbh: cute.Shape, +) -> FmhaStaticTileSchedulerParams: + return FmhaStaticTileSchedulerParams(is_persistent, problem_shape_mbh) + +class FmhaStaticTileScheduler: + + def __init__( + self, + params: FmhaStaticTileSchedulerParams, + current_work_linear_idx: Int32, + blk_coord: cute.Coord, + grid_shape: cute.Shape, + *, + loc=None, + ip=None, + ): + self._params = params + self._blk_coord = blk_coord + self._grid_shape = grid_shape + self._is_persistent = params.is_persistent + self._current_work_linear_idx = current_work_linear_idx + self._problem_shape_mbh = cute.make_layout( + params.problem_shape_mbh, loc=loc, ip=ip + ) + self._num_blocks = cute.size(self._problem_shape_mbh, loc=loc, ip=ip) + self._is_first_block = True + self.num_persistent_sm = cute.size(grid_shape, loc=loc, ip=ip) + self._loc = loc + self._ip = ip + + # called by host + @staticmethod + def get_grid_shape( + params: FmhaStaticTileSchedulerParams, + *, + loc=None, + ip=None, + ) -> cute.Shape: + if params.is_persistent: + hardware_info = cutlass.utils.HardwareInfo() + sm_count = hardware_info.get_device_multiprocessor_count() + return ( + cutlass.min( + sm_count, cute.size(params.problem_shape_mbh, loc=loc, ip=ip) + ), + 1, + 1, + ) + else: + return params.problem_shape_mbh + + @staticmethod + def check_valid_work_for_seqlen_q( + q_tiler: int, + current_idx: Int32, + seqlen_q: Int32, + ) -> Boolean: + return current_idx * q_tiler < seqlen_q + + def get_current_work(self, *, loc=None, ip=None) -> utils.WorkTileInfo: + is_valid = ( + self._current_work_linear_idx < self._num_blocks + if self._is_persistent + else self._is_first_block + ) + + blk_coord = (0, 0, 0) + if self._is_persistent: + blk_coord = self._problem_shape_mbh.get_hier_coord( + self._current_work_linear_idx, loc=loc, ip=ip + ) + else: + blk_coord = self._blk_coord + + # cur_tile_coord is (mid, 0, (bid, hid)) + cur_tile_coord = ( + blk_coord[0], + 0, + (blk_coord[1], 
blk_coord[2]), + ) + + return utils.WorkTileInfo(cur_tile_coord, is_valid) + + def initial_work_tile_info(self, *, loc=None, ip=None): + return self.get_current_work(loc=loc, ip=ip) + + def advance_to_next_work(self, *, advance_count=1, loc=None, ip=None): + if self._is_persistent: + self._current_work_linear_idx += advance_count * self.num_persistent_sm + self._is_first_block = False + + def __extract_mlir_values__(self): + values = cutlass.extract_mlir_values(self._params) + values.extend(cutlass.extract_mlir_values(self._current_work_linear_idx)) + values.extend(cutlass.extract_mlir_values(self._blk_coord)) + values.extend(cutlass.extract_mlir_values(self._grid_shape)) + return values + + def __new_from_mlir_values__(self, values): + assert len(values) == 10 + new_params = cutlass.new_from_mlir_values(self._params, values[0:3]) + new_current_work_linear_idx = cutlass.new_from_mlir_values( + self._current_work_linear_idx, [values[3]] + ) + new_blk_coord = cutlass.new_from_mlir_values(self._blk_coord, values[4:7]) + new_grid_shape = cutlass.new_from_mlir_values(self._grid_shape, values[7:]) + return FmhaStaticTileScheduler( + new_params, new_current_work_linear_idx, new_blk_coord, new_grid_shape + ) + + +def create_fmha_static_tile_scheduler( + params: FmhaStaticTileSchedulerParams, + blk_coord: cute.Coord, + grid_shape: cute.Shape, +) -> FmhaStaticTileScheduler: + return FmhaStaticTileScheduler(params, blk_coord[0], blk_coord, grid_shape) + + +class MaskType(enum.Enum): + NO_MASK = enum.auto() + RESIDUAL_MASK = enum.auto() + CAUSAL_MASK = enum.auto() + + +def make_thread_cooperative_group(size: int): + return pipeline.CooperativeGroup(pipeline.Agent.Thread, size, size) + + +class BlackwellFusedMultiHeadAttentionForward: + def __init__( + self, + qk_acc_dtype: Type[cutlass.Numeric], + pv_acc_dtype: Type[cutlass.Numeric], + mma_tiler: Tuple[int, int, int], + is_persistent: bool, + mask_type: MaskType, + ): + """Initializes the configuration for a Blackwell Fused Multi-Head Attention (FMHA) kernel. + + This configuration includes several key aspects: + + 1. Data Type Settings: + - qk_acc_dtype: Data type for Q*K^T matrix multiplication accumulator + - pv_acc_dtype: Data type for P*V matrix multiplication accumulator + + 2. MMA Instruction Settings: + - mma_tiler: The (M, N, K) shape of the MMA instruction unit + - qk_mma_tiler: MMA shape for Q*K^T computation + - pv_mma_tiler: MMA shape for P*V computation + + 3. 
Kernel Execution Mode: + - is_persistent: Boolean indicating whether to use persistent kernel mode + - mask_type: Specifies the type of mask to use (no mask, residual mask, or causal mask) + + :param qk_acc_dtype: Data type for Q*K^T matrix multiplication accumulator + :type qk_acc_dtype: Type[cutlass.Numeric] + :param pv_acc_dtype: Data type for P*V matrix multiplication accumulator + :type pv_acc_dtype: Type[cutlass.Numeric] + :param mma_tiler: The (M, N, K) shape of the MMA instruction + :type mma_tiler: Tuple[int, int, int] + :param is_persistent: Whether to use persistent kernel mode + :type is_persistent: bool + :param mask_type: Type of mask to use + :type mask_type: MaskType + """ + + self.qk_acc_dtype = qk_acc_dtype + self.pv_acc_dtype = pv_acc_dtype + self.cta_tiler = ( + 2 * mma_tiler[0], # 2 Q tile per CTA + mma_tiler[1], + mma_tiler[2], + ) + self.qk_mma_tiler = mma_tiler + self.pv_mma_tiler = ( + mma_tiler[0], + mma_tiler[2], + mma_tiler[1], + ) + self.cluster_shape_mn = (1, 1) + self.is_persistent = is_persistent + self.mask_type = mask_type + self.softmax0_warp_ids = (0, 1, 2, 3) + self.softmax1_warp_ids = (4, 5, 6, 7) + self.correction_warp_ids = (8, 9, 10, 11) + self.mma_warp_id = 12 + self.load_warp_id = 13 + self.epilogue_warp_id = 14 + self.empty_warp_id = 15 + SM100_TMEM_CAPACITY_COLUMNS = 512 + self.tmem_alloc_cols = SM100_TMEM_CAPACITY_COLUMNS + + self.threads_per_warp = 32 + self.threads_per_cta = self.threads_per_warp * len( + ( + *self.softmax0_warp_ids, + *self.softmax1_warp_ids, + *self.correction_warp_ids, + self.mma_warp_id, + self.load_warp_id, + self.epilogue_warp_id, + self.empty_warp_id, + ) + ) + + self.cta_sync_bar_id = 0 + self.tmem_alloc_sync_bar_id = 1 + + self.tmem_s0_offset = 0 + self.tmem_s1_offset = 128 + self.tmem_o0_offset = 256 + self.tmem_o1_offset = 384 + self.tmem_p0_offset = 32 + self.tmem_p1_offset = 160 + + # vec buffer for row_max & row_sum + self.tmem_vec0_offset = 0 + self.tmem_vec1_offset = 128 + + self.num_regs_softmax = 192 + self.num_regs_correction = 96 + self.num_regs_other = 32 + self.num_regs_empty = 24 + + self.buffer_align_bytes = 1024 + + num_warps_per_warpgroup = 4 + self.softmax_warpgroup_count = ( + len((*self.softmax0_warp_ids, *self.softmax1_warp_ids)) + // num_warps_per_warpgroup + ) + + def _setup_attributes(self): + """Set up configurations and parameters for the FMHA kernel operation. + + This method initializes and configures various attributes required for the + execution of the fused multi-head attention kernel, mainly about the pipeline stages: + + - Sets up staging parameters for Q, K, V inputs and accumulator data + - Configures pipeline stages for softmax, correction, and epilogue operations + """ + + self.q_stage = 2 + self.kv_stage = 4 if self.q_dtype.width == 8 else 3 + self.acc_stage = 1 + self.softmax_corr_stage = 1 + self.mma_corr_stage = 2 + self.mma_softmax_stage = 1 + self.epi_stage = 2 + + @cute.jit + def __call__( + self, + q_iter: cute.Pointer, + k_iter: cute.Pointer, + v_iter: cute.Pointer, + o_iter: cute.Pointer, + problem_size: Tuple[Int32, Int32, Int32, Int32, Int32, Int32], + cum_seqlen_q: cute.Tensor | None, + cum_seqlen_k: cute.Tensor | None, + scale_softmax_log2: Float32, + scale_output: Float32, + stream: cuda.CUstream, + ): + """Execute the Fused Multi-Head Attention operation on the provided tensors. + + This method prepares the input tensors for processing, validates their shapes and types, + configures the computation parameters, and launches the CUDA kernel. 
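        As background for the layout setup below: in the dense (non-varlen) case, the
        (s, d, ((h_r, h_k), b)) layouts built for Q/O are equivalent to indexing a
        contiguous [b, s, h_k, h_r, d] buffer, while the K/V layouts use a 0 stride for
        h_r so every query head in a group reads the same K/V head (grouped-query
        broadcast). A minimal, illustrative sketch of that offset arithmetic in plain
        Python (the helper name is hypothetical and is not part of the kernel):

        .. code-block:: python

            def q_offset(s_idx, d_idx, hr_idx, hk_idx, b_idx, *, s, d, h_r, h_k):
                # Mirrors q_layout: stride (d*h_r*h_k, 1, ((d, d*h_r), h_r*h_k*s*d))
                return (
                    s_idx * d * h_r * h_k        # sequence position
                    + d_idx                      # element within the head
                    + hr_idx * d                 # query head within its K/V head group
                    + hk_idx * d * h_r           # K/V head group
                    + b_idx * h_r * h_k * s * d  # batch
                )
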
+ + The method handles: + 1. Tensor layout transformations for specific memory access patterns + 2. Validation of tensor shapes and data types + 3. Initialization of hardware-specific parameters and memory layouts + 4. Configuration of TMA (Tensor Memory Access) operations + 5. Grid and work scheduling computation + 6. Kernel launch with appropriate parameters + + :param q_iter: The query tensor pointer + :type q_iter: cute.Pointer + :param k_iter: The key tensor pointer + :type k_iter: cute.Pointer + :param v_iter: The value tensor pointer + :type v_iter: cute.Pointer + :param o_iter: The output tensor pointer + :type o_iter: cute.Pointer + :param problem_size: The problem size with shape [b, s_q, s_k, h_q, h_k, d]. If cum_seqlen_q or cum_seqlen_k is not None, s_q and s_k are the max of the cumulative sequence length respectively. + :type problem_size: Tuple[Int32, Int32, Int32, Int32, Int32, Int32] + :param cum_seqlen_q: The cumulative sequence length tensor for query + :type cum_seqlen_q: cute.Tensor | None + :param cum_seqlen_k: The cumulative sequence length tensor for key + :type cum_seqlen_k: cute.Tensor | None + :param scale_softmax_log2: The log2 scale factor for softmax + :type scale_softmax_log2: Float32 + :param scale_output: The scale factor for the output + :type scale_output: Float32 + :param stream: The CUDA stream to execute the kernel on + :type stream: cuda.CUstream + :raises TypeError: If tensor data types don't match or aren't supported + :raises RuntimeError: If tensor layouts aren't in supported formats + """ + b, s_q, s_k, h_q, h_k, d = problem_size + h_r = h_q // h_k + qo_offset = 0 if cum_seqlen_q is None else -s_q * d * h_r * h_k + kv_offset = 0 if cum_seqlen_k is None else -s_k * d * h_k + b_qo = b if cum_seqlen_q is None else s_q * (1 + b) + b_kv = b if cum_seqlen_k is None else s_k * (1 + b) + stride_b_qo = h_r * h_k * s_q * d if cum_seqlen_q is None else d * h_r * h_k + stride_b_kv = h_k * s_k * d if cum_seqlen_k is None else d * h_k + + # (s, d, ((h_r, h_k), b)) + q_layout = cute.make_layout( + (s_q, d, ((h_r, h_k), b_qo)), + stride=(d * h_r * h_k, 1, ((d, d * h_r), stride_b_qo)), + ) + q = cute.make_tensor(q_iter + qo_offset, q_layout) + # (s, d, ((h_r, h_k), b)), 0-stride for h_r to broadcast + k_layout = cute.make_layout( + (s_k, d, ((h_r, h_k), b_kv)), + stride=(d * h_k, 1, ((0, d), stride_b_kv)), + ) + k = cute.make_tensor(k_iter + kv_offset, k_layout) + # (d, s, ((h_r, h_k), b)), 0-stride for h_r to broadcast + v_layout = cute.make_layout( + (d, s_k, ((h_r, h_k), b_kv)), + stride=(1, d * h_k, ((0, d), stride_b_kv)), + ) + v = cute.make_tensor(v_iter + kv_offset, v_layout) + # (s, d, ((h_r, h_k), b)) + o_layout = cute.make_layout( + (s_q, d, ((h_r, h_k), b_qo)), + stride=(d * h_r * h_k, 1, ((d, d * h_r), stride_b_qo)), + ) + o = cute.make_tensor(o_iter + qo_offset, o_layout) + + # setup static attributes before smem/grid/tma computation + self.q_dtype = q.element_type + self.k_dtype = k.element_type + self.v_dtype = v.element_type + self.o_dtype = o.element_type + + self.tile_sched_params, grid = self._compute_grid( + cute.shape((s_q, d, ((h_r, h_k), b))), + self.cta_tiler, + self.is_persistent, + ) + + self.q_major_mode = utils.LayoutEnum.from_tensor(q).mma_major_mode() + self.k_major_mode = utils.LayoutEnum.from_tensor(k).mma_major_mode() + self.v_major_mode = utils.LayoutEnum.from_tensor(v).mma_major_mode() + self.o_layout = utils.LayoutEnum.from_tensor(o) + + if cutlass.const_expr(self.q_major_mode != tcgen05.OperandMajorMode.K): + raise 
RuntimeError("The layout of q is not supported") + if cutlass.const_expr(self.k_major_mode != tcgen05.OperandMajorMode.K): + raise RuntimeError("The layout of k is not supported") + if cutlass.const_expr(self.v_major_mode != tcgen05.OperandMajorMode.MN): + raise RuntimeError("The layout of v is not supported") + + # check type consistency + if cutlass.const_expr(self.q_dtype != self.k_dtype): + raise TypeError(f"Type mismatch: {self.q_dtype} != {self.k_dtype}") + if cutlass.const_expr(self.q_dtype != self.v_dtype): + raise TypeError(f"Type mismatch: {self.q_dtype} != {self.v_dtype}") + self._setup_attributes() + + cta_group = tcgen05.CtaGroup.ONE + # the intermediate tensor p is from tmem & k-major + p_source = tcgen05.OperandSource.TMEM + p_major_mode = tcgen05.OperandMajorMode.K + qk_tiled_mma = sm100_utils.make_trivial_tiled_mma( + self.q_dtype, + self.q_major_mode, + self.k_major_mode, + self.qk_acc_dtype, + cta_group, + self.qk_mma_tiler[:2], + ) + pv_tiled_mma = sm100_utils.make_trivial_tiled_mma( + self.v_dtype, + p_major_mode, + self.v_major_mode, + self.pv_acc_dtype, + cta_group, + self.pv_mma_tiler[:2], + p_source, + ) + + self.cluster_shape_mnk = (*self.cluster_shape_mn, 1) + self.cluster_layout_vmnk = cute.tiled_divide( + cute.make_layout(self.cluster_shape_mnk), + (qk_tiled_mma.thr_id.shape,), + ) + + self.epi_tile = self.pv_mma_tiler[:2] + + q_smem_layout_staged = sm100_utils.make_smem_layout_a( + qk_tiled_mma, + self.qk_mma_tiler, + self.q_dtype, + self.q_stage, + ) + k_smem_layout_staged = sm100_utils.make_smem_layout_b( + qk_tiled_mma, + self.qk_mma_tiler, + self.k_dtype, + self.kv_stage, + ) + p_tmem_layout_staged = sm100_utils.make_smem_layout_a( + pv_tiled_mma, + self.pv_mma_tiler, + self.q_dtype, + self.acc_stage, + ) + v_smem_layout_staged = sm100_utils.make_smem_layout_b( + pv_tiled_mma, + self.pv_mma_tiler, + self.v_dtype, + self.kv_stage, + ) + o_smem_layout_staged = sm100_utils.make_smem_layout_epi( + self.o_dtype, + self.o_layout, + self.epi_tile, + self.epi_stage, + ) + + # TMA load for Q + tma_load_op = cute.nvgpu.cpasync.CopyBulkTensorTileG2SOp(cta_group) + tma_store_op = cute.nvgpu.cpasync.CopyBulkTensorTileS2GOp() + + q_smem_layout = cute.select(q_smem_layout_staged, mode=[0, 1, 2]) + tma_atom_q, tma_tensor_q = cute.nvgpu.make_tiled_tma_atom_A( + tma_load_op, + q, + q_smem_layout, + self.qk_mma_tiler, + qk_tiled_mma, + self.cluster_layout_vmnk.shape, + ) + + # TMA load for K + k_smem_layout = cute.select(k_smem_layout_staged, mode=[0, 1, 2]) + tma_atom_k, tma_tensor_k = cute.nvgpu.make_tiled_tma_atom_B( + tma_load_op, + k, + k_smem_layout, + self.qk_mma_tiler, + qk_tiled_mma, + self.cluster_layout_vmnk.shape, + ) + # TMA load for V + v_smem_layout = cute.select(v_smem_layout_staged, mode=[0, 1, 2]) + tma_atom_v, tma_tensor_v = cute.nvgpu.make_tiled_tma_atom_B( + tma_load_op, + v, + v_smem_layout, + self.pv_mma_tiler, + pv_tiled_mma, + self.cluster_layout_vmnk.shape, + ) + + o_cta_v_layout = cute.composition( + cute.make_identity_layout(o.shape), self.epi_tile + ) + o_smem_layout = cute.select(o_smem_layout_staged, mode=[0, 1]) + + tma_atom_o, tma_tensor_o = cute.nvgpu.cpasync.make_tiled_tma_atom( + tma_store_op, + o, + o_smem_layout, + o_cta_v_layout, + ) + + q_copy_size = cute.size_in_bytes(self.q_dtype, q_smem_layout) + k_copy_size = cute.size_in_bytes(self.k_dtype, k_smem_layout) + self.tma_copy_q_bytes = q_copy_size + self.tma_copy_kv_bytes = k_copy_size + + @cute.struct + class SharedStorage: + # Pipeline barriers + load_q_mbar_ptr: 
cute.struct.MemRange[Int64, self.q_stage * 2] + load_kv_mbar_ptr: cute.struct.MemRange[Int64, self.kv_stage * 2] + mma_s0_mbar_ptr: cute.struct.MemRange[Int64, self.mma_softmax_stage * 2] + mma_s1_mbar_ptr: cute.struct.MemRange[Int64, self.mma_softmax_stage * 2] + s0_corr_mbar_ptr: cute.struct.MemRange[Int64, self.softmax_corr_stage * 2] + s1_corr_mbar_ptr: cute.struct.MemRange[Int64, self.softmax_corr_stage * 2] + s0_s1_sequence_mbar_ptr: cute.struct.MemRange[ + Int64, self.softmax_warpgroup_count + ] + corr_epi_mbar_ptr: cute.struct.MemRange[Int64, self.epi_stage * 2] + mma_corr_mbar_ptr: cute.struct.MemRange[Int64, self.mma_corr_stage * 2] + tmem_dealloc_mbar_ptr: cute.struct.MemRange[Int64, 1] + # Tmem holding buffer + tmem_holding_buf: Int32 + # Smem tensors + sO: cute.struct.Align[ + cute.struct.MemRange[self.o_dtype, cute.cosize(o_smem_layout_staged)], + self.buffer_align_bytes, + ] + sQ: cute.struct.Align[ + cute.struct.MemRange[self.q_dtype, cute.cosize(q_smem_layout_staged)], + self.buffer_align_bytes, + ] + sK: cute.struct.Align[ + cute.struct.MemRange[self.k_dtype, cute.cosize(k_smem_layout_staged)], + self.buffer_align_bytes, + ] + + self.shared_storage = SharedStorage + + # Launch the kernel synchronously + self.kernel( + qk_tiled_mma, + pv_tiled_mma, + tma_atom_q, + tma_tensor_q, + tma_atom_k, + tma_tensor_k, + tma_atom_v, + tma_tensor_v, + tma_atom_o, + tma_tensor_o, + cum_seqlen_q, + cum_seqlen_k, + scale_softmax_log2, + scale_output, + q_smem_layout_staged, + k_smem_layout_staged, + p_tmem_layout_staged, + v_smem_layout_staged, + o_smem_layout_staged, + self.tile_sched_params, + ).launch( + grid=grid, + block=[self.threads_per_cta, 1, 1], + cluster=self.cluster_shape_mnk, + stream=stream, + min_blocks_per_mp=1, + ) + + # GPU device kernel + @cute.kernel + def kernel( + self, + qk_tiled_mma: cute.TiledMma, + pv_tiled_mma: cute.TiledMma, + tma_atom_q: cute.CopyAtom, + mQ_qdl: cute.Tensor, + tma_atom_k: cute.CopyAtom, + mK_kdl: cute.Tensor, + tma_atom_v: cute.CopyAtom, + mV_dkl: cute.Tensor, + tma_atom_o: cute.CopyAtom, + mO_qdl: cute.Tensor, + cum_seqlen_q: cute.Tensor | None, + cum_seqlen_k: cute.Tensor | None, + scale_softmax_log2: Float32, + scale_output: Float32, + q_smem_layout_staged: cute.ComposedLayout, + k_smem_layout_staged: cute.ComposedLayout, + p_tmem_layout_staged: cute.ComposedLayout, + v_smem_layout_staged: cute.ComposedLayout, + o_smem_layout_staged: cute.ComposedLayout, + tile_sched_params: FmhaStaticTileSchedulerParams, + ): + """The device kernel implementation of the Fused Multi-Head Attention. + + This kernel coordinates multiple specialized warps to perform different phases of the FMHA computation: + 1. Load warp: Loads Q, K, V data from global memory to shared memory using TMA + 2. MMA warp: Performs matrix multiplications (Q*K^T and P*V) + 3. Softmax warps: Compute softmax normalization on attention scores + 4. Correction warps: Apply adjustments to intermediate results + 5. Epilogue warp: Handles final output transformation and storage + + The kernel implements a complex pipeline with overlapping computation and memory operations, + using tensor memory access (TMA) for efficient data loading, warp specialization for different + computation phases, and optional attention masking. 
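        The CTA runs 16 warps (512 threads); each warp is dispatched to one role by the
        warp_idx branches below. An illustrative summary of that mapping (hypothetical
        helper, not kernel code; the warp-id constants are set in __init__):

        .. code-block:: python

            def role_of(warp_idx: int) -> str:
                if warp_idx < 4:        # softmax0_warp_ids = (0, 1, 2, 3)
                    return "softmax for the first Q tile"
                if warp_idx < 8:        # softmax1_warp_ids = (4, 5, 6, 7)
                    return "softmax for the second Q tile"
                if warp_idx < 12:       # correction_warp_ids = (8, 9, 10, 11)
                    return "correction / O rescaling"
                return {
                    12: "MMA issue (Q*K^T and P*V)",      # mma_warp_id
                    13: "TMA load of Q/K/V",              # load_warp_id
                    14: "epilogue / TMA store of O",      # epilogue_warp_id
                    15: "empty (deallocates registers)",  # empty_warp_id
                }[warp_idx]
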
+ + :param qk_tiled_mma: Tiled MMA for Q*K^T + :type qk_tiled_mma: cute.TiledMma + :param pv_tiled_mma: Tiled MMA for P*V + :type pv_tiled_mma: cute.TiledMma + :param tma_atom_q: TMA copy atom for query tensor + :type tma_atom_q: cute.CopyAtom + :param mQ_qdl: Partitioned query tensor + :type mQ_qdl: cute.Tensor + :param tma_atom_k: TMA copy atom for key tensor + :type tma_atom_k: cute.CopyAtom + :param mK_kdl: Partitioned key tensor + :type mK_kdl: cute.Tensor + :param tma_atom_v: TMA copy atom for value tensor + :type tma_atom_v: cute.CopyAtom + :param mV_dkl: Partitioned value tensor + :type mV_dkl: cute.Tensor + :param tma_atom_o: TMA copy atom for output tensor + :type tma_atom_o: cute.CopyAtom + :param mO_qdl: Partitioned output tensor + :type mO_qdl: cute.Tensor + :param scale_softmax_log2: The log2 scale factor for softmax + :type scale_softmax_log2: Float32 + :param scale_output: The scale factor for the output + :type scale_output: Float32 + :param q_smem_layout_staged: Shared memory layout for query tensor + :type q_smem_layout_staged: cute.ComposedLayout + :param k_smem_layout_staged: Shared memory layout for key tensor + :type k_smem_layout_staged: cute.ComposedLayout + :param p_tmem_layout_staged: Tensor memory layout for probability matrix + :type p_tmem_layout_staged: cute.ComposedLayout + :param v_smem_layout_staged: Shared memory layout for value tensor + :type v_smem_layout_staged: cute.ComposedLayout + :param o_smem_layout_staged: Shared memory layout for output tensor + :type o_smem_layout_staged: cute.ComposedLayout + :param tile_sched_params: Scheduling parameters for work distribution + :type tile_sched_params: FmhaStaticTileSchedulerParams + """ + + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + # coord inside cta + tidx, _, _ = cute.arch.thread_idx() + + # + # Prefetch tma desc + # + if warp_idx == self.load_warp_id: + cute.nvgpu.cpasync.prefetch_descriptor(tma_atom_q) + cute.nvgpu.cpasync.prefetch_descriptor(tma_atom_k) + cute.nvgpu.cpasync.prefetch_descriptor(tma_atom_v) + cute.nvgpu.cpasync.prefetch_descriptor(tma_atom_o) + + # Alloc + smem = utils.SmemAllocator() + storage = smem.allocate(self.shared_storage) + + load_q_producer, load_q_consumer = pipeline.PipelineTmaUmma.create( + num_stages=self.q_stage, + producer_group=make_thread_cooperative_group(len([self.load_warp_id])), + consumer_group=make_thread_cooperative_group(len([self.mma_warp_id])), + tx_count=self.tma_copy_q_bytes, + barrier_storage=storage.load_q_mbar_ptr.data_ptr(), + ).make_participants() + load_kv_producer, load_kv_consumer = pipeline.PipelineTmaUmma.create( + num_stages=self.kv_stage, + producer_group=make_thread_cooperative_group(len([self.load_warp_id])), + consumer_group=make_thread_cooperative_group(len([self.mma_warp_id])), + tx_count=self.tma_copy_kv_bytes, + barrier_storage=storage.load_kv_mbar_ptr.data_ptr(), + ).make_participants() + mma_s0_producer, mma_s0_consumer = pipeline.PipelineUmmaAsync.create( + num_stages=self.mma_softmax_stage, + producer_group=make_thread_cooperative_group(len([self.mma_warp_id])), + consumer_group=make_thread_cooperative_group( + self.threads_per_warp * len(self.softmax0_warp_ids) + ), + barrier_storage=storage.mma_s0_mbar_ptr.data_ptr(), + ).make_participants() + mma_s1_producer, mma_s1_consumer = pipeline.PipelineUmmaAsync.create( + num_stages=self.mma_softmax_stage, + producer_group=make_thread_cooperative_group(len([self.mma_warp_id])), + consumer_group=make_thread_cooperative_group( + self.threads_per_warp * 
len(self.softmax1_warp_ids) + ), + barrier_storage=storage.mma_s1_mbar_ptr.data_ptr(), + ).make_participants() + s0_corr_producer, s0_corr_consumer = pipeline.PipelineAsync.create( + num_stages=self.softmax_corr_stage, + producer_group=make_thread_cooperative_group( + self.threads_per_warp * len(self.softmax0_warp_ids) + ), + consumer_group=make_thread_cooperative_group( + self.threads_per_warp * len(self.correction_warp_ids) + ), + barrier_storage=storage.s0_corr_mbar_ptr.data_ptr(), + ).make_participants() + s1_corr_producer, s1_corr_consumer = pipeline.PipelineAsync.create( + num_stages=self.softmax_corr_stage, + producer_group=make_thread_cooperative_group( + self.threads_per_warp * len(self.softmax1_warp_ids) + ), + consumer_group=make_thread_cooperative_group( + self.threads_per_warp * len(self.correction_warp_ids) + ), + barrier_storage=storage.s1_corr_mbar_ptr.data_ptr(), + ).make_participants() + corr_epi_producer, corr_epi_consumer = pipeline.PipelineAsync.create( + num_stages=self.epi_stage, + producer_group=make_thread_cooperative_group( + self.threads_per_warp * len(self.correction_warp_ids) + ), + consumer_group=make_thread_cooperative_group( + self.threads_per_warp * len([self.epilogue_warp_id]) + ), + barrier_storage=storage.corr_epi_mbar_ptr.data_ptr(), + ).make_participants() + mma_corr_producer, mma_corr_consumer = pipeline.PipelineUmmaAsync.create( + num_stages=self.mma_corr_stage, + producer_group=make_thread_cooperative_group(len([self.mma_warp_id])), + consumer_group=make_thread_cooperative_group( + self.threads_per_warp * len(self.correction_warp_ids) + ), + barrier_storage=storage.mma_corr_mbar_ptr.data_ptr(), + ).make_participants() + s0_s1_sequence_producer, s0_s1_sequence_consumer = ( + pipeline.PipelineAsync.create( + num_stages=1, + producer_group=make_thread_cooperative_group( + self.threads_per_warp * len(self.softmax0_warp_ids) + ), + consumer_group=make_thread_cooperative_group( + self.threads_per_warp * len(self.softmax1_warp_ids) + ), + barrier_storage=storage.s0_s1_sequence_mbar_ptr.data_ptr(), + ).make_participants() + ) + tmem_dealloc_mbar_ptr = storage.tmem_dealloc_mbar_ptr.data_ptr() + + # Correction & Epilogue & tmem barrier init + if warp_idx == self.empty_warp_id: + cute.arch.mbarrier_init( + tmem_dealloc_mbar_ptr, + self.threads_per_warp + * len( + ( + *self.softmax0_warp_ids, + *self.softmax1_warp_ids, + *self.correction_warp_ids, + ) + ), + ) + cute.arch.mbarrier_init_fence() + + # Generate smem tensor Q/K/V/O + # (MMA, MMA_Q, MMA_D, PIPE) + sQ = storage.sQ.get_tensor( + q_smem_layout_staged.outer, swizzle=q_smem_layout_staged.inner + ) + # (MMA, MMA_K, MMA_D, PIPE) + sK = storage.sK.get_tensor( + k_smem_layout_staged.outer, swizzle=k_smem_layout_staged.inner + ) + # (MMA, MMA_K, MMA_D, PIPE) + # Strip swizzle info to reuse smem + sV_ptr = cute.recast_ptr(sK.iterator, v_smem_layout_staged.inner) + sV = cute.make_tensor(sV_ptr, v_smem_layout_staged.outer) + sO = storage.sO.get_tensor( + o_smem_layout_staged.outer, swizzle=o_smem_layout_staged.inner + ) + qk_thr_mma = qk_tiled_mma.get_slice(0) # default 1sm + pv_thr_mma = pv_tiled_mma.get_slice(0) # default 1sm + tSrQ = qk_thr_mma.make_fragment_A(sQ) + tSrK = qk_thr_mma.make_fragment_B(sK) + tOrV = pv_thr_mma.make_fragment_B(sV) + qk_acc_shape = qk_thr_mma.partition_shape_C( + (self.qk_mma_tiler[0], self.qk_mma_tiler[1]) + ) + tStS = qk_thr_mma.make_fragment_C(qk_acc_shape) + pv_acc_shape = pv_thr_mma.partition_shape_C( + (self.pv_mma_tiler[0], self.pv_mma_tiler[1]) + ) + tOtO = 
pv_thr_mma.make_fragment_C(pv_acc_shape) + + tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout) + tStS1 = cute.make_tensor(tStS.iterator + self.tmem_s1_offset, tStS.layout) + tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout) + tOtO1 = cute.make_tensor(tOtO.iterator + self.tmem_o1_offset, tOtO.layout) + + tP = cute.make_tensor(tStS.iterator, p_tmem_layout_staged.outer) + tOrP = pv_thr_mma.make_fragment_A(tP)[None, None, None, 0] + tOrP0 = cute.make_tensor( + tOrP.iterator + + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset, + tOrP.layout, + ) + tOrP1 = cute.make_tensor( + tOrP.iterator + + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p1_offset, + tOrP.layout, + ) + cute.arch.barrier( + barrier_id=self.cta_sync_bar_id, + number_of_threads=self.threads_per_cta, + ) + # /////////////////////////////////////////////////////////////////////////////// + # EMPTY + # /////////////////////////////////////////////////////////////////////////////// + if warp_idx == self.empty_warp_id: + cute.arch.warpgroup_reg_dealloc(self.num_regs_empty) + + # /////////////////////////////////////////////////////////////////////////////// + # LOAD + # /////////////////////////////////////////////////////////////////////////////// + if warp_idx == self.load_warp_id: + cute.arch.warpgroup_reg_dealloc(self.num_regs_other) + + tile_sched = create_fmha_static_tile_scheduler( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + + while work_tile.is_valid_tile: + curr_block_coord = work_tile.tile_idx + batch_coord = curr_block_coord[2][1] + continue_cond = False + cuseqlen_q = Int32(0) + seqlen_q = mQ_qdl.shape[0] + if cutlass.const_expr(cum_seqlen_q is not None): + cuseqlen_q = cum_seqlen_q[batch_coord] + seqlen_q = cum_seqlen_q[batch_coord + 1] - cuseqlen_q + continue_cond = ( + not FmhaStaticTileScheduler.check_valid_work_for_seqlen_q( + self.cta_tiler[0], + curr_block_coord[0], + seqlen_q, + ) + ) + if not continue_cond: + mQ_qdl_ = mQ_qdl + mK_kdl_ = mK_kdl + mV_dkl_ = mV_dkl + seqlen_k = mK_kdl.shape[0] + curr_block_coord_q = curr_block_coord + curr_block_coord_kv = curr_block_coord + + if cutlass.const_expr(cum_seqlen_q is not None): + logical_offset_mQ = ( + mQ_qdl.shape[0] - seqlen_q, + 0, + (0, cuseqlen_q + seqlen_q), + ) + mQ_qdl_ = cute.domain_offset(logical_offset_mQ, mQ_qdl) + curr_block_coord_q = ( + curr_block_coord[0], + curr_block_coord[1], + (curr_block_coord[2][0], Int32(0)), + ) + + if cutlass.const_expr(cum_seqlen_k is not None): + cuseqlen_k = cum_seqlen_k[batch_coord] + seqlen_k = cum_seqlen_k[batch_coord + 1] - cuseqlen_k + logical_offset_mK = ( + mK_kdl.shape[0] - seqlen_k, + 0, + (0, cuseqlen_k + seqlen_k), + ) + logical_offset_mV = ( + 0, + mK_kdl.shape[0] - seqlen_k, + (0, cuseqlen_k + seqlen_k), + ) + mK_kdl_ = cute.domain_offset(logical_offset_mK, mK_kdl) + mV_dkl_ = cute.domain_offset(logical_offset_mV, mV_dkl) + curr_block_coord_kv = ( + curr_block_coord[0], + curr_block_coord[1], + (curr_block_coord[2][0], Int32(0)), + ) + + # Local tile partition global tensors + # (bM, bK, loopM, loopK, loopL) + gQ_qdl = cute.flat_divide( + mQ_qdl_, cute.select(self.qk_mma_tiler, mode=[0, 2]) + ) + tSgQ_qdl = qk_thr_mma.partition_A(gQ_qdl) + tQsQ, tQgQ_qdl = cute.nvgpu.cpasync.tma_partition( + tma_atom_q, + 0, # no multicast + cute.make_layout(1), + cute.group_modes(sQ, 0, 3), + cute.group_modes(tSgQ_qdl, 0, 3), + ) + tQgQ = tQgQ_qdl[None, None, 0, 
curr_block_coord_q[2]] + + gK_kdl = cute.flat_divide( + mK_kdl_, cute.select(self.qk_mma_tiler, mode=[1, 2]) + ) + tSgK_kdl = qk_thr_mma.partition_B(gK_kdl) + tKsK, tKgK_kdl = cute.nvgpu.cpasync.tma_partition( + tma_atom_k, + 0, # no multicast + cute.make_layout(1), + cute.group_modes(sK, 0, 3), + cute.group_modes(tSgK_kdl, 0, 3), + ) + tKgK = tKgK_kdl[None, None, 0, curr_block_coord_kv[2]] + + gV_dkl = cute.flat_divide( + mV_dkl_, cute.select(self.pv_mma_tiler, mode=[1, 2]) + ) + tSgV_dkl = pv_thr_mma.partition_B(gV_dkl) + tVsV, tVgV_dkl = cute.nvgpu.cpasync.tma_partition( + tma_atom_v, + 0, # no multicast + cute.make_layout(1), + cute.group_modes(sV, 0, 3), + cute.group_modes(tSgV_dkl, 0, 3), + ) + tVgV = tVgV_dkl[None, 0, None, curr_block_coord_kv[2]] + + # Q0 + q0_coord = 2 * curr_block_coord_q[0] + q0_handle = load_q_producer.acquire_and_advance() + cute.copy( + tma_atom_q, + tQgQ[None, q0_coord], + tQsQ[None, q0_handle.index], + tma_bar_ptr=q0_handle.barrier, + ) + # K0 + kv_coord = 0 # seqlen_kv_loop + k_handle = load_kv_producer.acquire_and_advance() + cute.copy( + tma_atom_k, + tKgK[None, kv_coord], + tKsK[None, k_handle.index], + tma_bar_ptr=k_handle.barrier, + ) + # Q1 + q1_coord = q0_coord + 1 + q1_handle = load_q_producer.acquire_and_advance() + cute.copy( + tma_atom_q, + tQgQ[None, q1_coord], + tQsQ[None, q1_handle.index], + tma_bar_ptr=q1_handle.barrier, + ) + # V0 + v_handle = load_kv_producer.acquire_and_advance() + cute.copy( + tma_atom_v, + tVgV[None, kv_coord], + tVsV[None, v_handle.index], + tma_bar_ptr=v_handle.barrier, + ) + kv_coord += 1 + + seqlen_kv_loop_steps = ( + self.get_trip_count(curr_block_coord, self.cta_tiler, seqlen_k) + - 1 + ) + for i in cutlass.range(0, seqlen_kv_loop_steps, 1, unroll=1): + # Ki + k_handle = load_kv_producer.acquire_and_advance() + cute.copy( + tma_atom_k, + tKgK[None, kv_coord], + tKsK[None, k_handle.index], + tma_bar_ptr=k_handle.barrier, + ) + # Vi + v_handle = load_kv_producer.acquire_and_advance() + cute.copy( + tma_atom_v, + tVgV[None, kv_coord], + tVsV[None, v_handle.index], + tma_bar_ptr=v_handle.barrier, + ) + kv_coord += 1 + # End of seqlen_kv loop + + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + # End of persistent scheduler loop + + # /////////////////////////////////////////////////////////////////////////////// + # MMA + # /////////////////////////////////////////////////////////////////////////////// + if warp_idx == self.mma_warp_id: + cute.arch.warpgroup_reg_dealloc(self.num_regs_other) + + # Alloc tmem buffer + tmem_alloc_cols = Int32(self.tmem_alloc_cols) + cute.arch.alloc_tmem(tmem_alloc_cols, storage.tmem_holding_buf) + cute.arch.barrier( + barrier_id=self.tmem_alloc_sync_bar_id, + number_of_threads=self.threads_per_warp, + ) + tile_sched = create_fmha_static_tile_scheduler( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + + while work_tile.is_valid_tile: + curr_block_coord = work_tile.tile_idx + batch_coord = curr_block_coord[2][1] + continue_cond = False + if cutlass.const_expr(cum_seqlen_q is not None): + cuseqlen_q = cum_seqlen_q[batch_coord] + seqlen_q = cum_seqlen_q[batch_coord + 1] - cuseqlen_q + continue_cond = ( + not FmhaStaticTileScheduler.check_valid_work_for_seqlen_q( + self.cta_tiler[0], + curr_block_coord[0], + seqlen_q, + ) + ) + + if not continue_cond: + seqlen_k = mK_kdl.shape[0] + if cutlass.const_expr(cum_seqlen_k is not None): + cuseqlen_k = cum_seqlen_k[batch_coord] + seqlen_k = 
cum_seqlen_k[batch_coord + 1] - cuseqlen_k + + # GEMM_QK00 (Q0 * K0 -> S0) + # 1. wait for Q0 + q0_handle = load_q_consumer.wait_and_advance() + tSrQ0 = tSrQ[None, None, None, q0_handle.index] + # 2. wait for K0 + k_handle = load_kv_consumer.wait_and_advance() + tSrK0 = tSrK[None, None, None, k_handle.index] + # 3. acquire empty S0 buffer + s0_handle = mma_s0_producer.acquire_and_advance() + # 4. gemm + num_kphases = cute.size(tSrQ0, mode=[2]) + for kphase_idx in cutlass.range(num_kphases, unroll_full=True): + kphase_coord_0 = (None, None, kphase_idx) + qk_tiled_mma.set(tcgen05.Field.ACCUMULATE, kphase_idx != 0) + cute.gemm( + qk_tiled_mma, + tStS0, + tSrQ0[kphase_coord_0], + tSrK0[kphase_coord_0], + tStS0, + ) + # 5. release S0 + s0_handle.commit() + # End of GEMM (Q0 * K0 -> S0) + + # GEMM_QK10 (Q1 * K0 -> S1), K0 is ready in GEMM_QK00 + # 1. wait for Q1 + q1_handle = load_q_consumer.wait_and_advance() + tSrQ1 = tSrQ[None, None, None, q1_handle.index] + # 2. acquire empty S1 + s1_handle = mma_s1_producer.acquire_and_advance() + # 3. gemm + num_kphases = cute.size(tSrQ1, mode=[2]) + for kphase_idx in cutlass.range(num_kphases, unroll_full=True): + kphase_coord_1 = (None, None, kphase_idx) + qk_tiled_mma.set(tcgen05.Field.ACCUMULATE, kphase_idx != 0) + cute.gemm( + qk_tiled_mma, + tStS1, + tSrQ1[kphase_coord_1], + tSrK0[kphase_coord_1], + tStS1, + ) + # 4. release S1 + s1_handle.commit() + # 5. release K0 + k_handle.release() + # End of GEMM (Q1 * K0 -> S1) + # Note: Q0 & Q1 are still needed in the seqlen_kv loop + # so we need to release them after the seqlen_kv loop + + # GEMM_PV00 (P0 * V0 -> O0_partial), O0 needs to be accumulated in the seqlen_kv loop + # 1. wait for V0 + v_handle = load_kv_consumer.wait_and_advance() + tOrVi = tOrV[None, None, None, v_handle.index] + # 2. acquire corrected O0_partial + # Note: acquire corr first to take it out of the critical + # path since softmax takes longer + o0_handle = mma_corr_producer.acquire_and_advance() + # 3. acquire P0 + # this acquire returns the ownership of all of S0 to the mma warp + # including the P0 part (inplaced in S0) + s0_handle = mma_s0_producer.acquire_and_advance() + # 4. gemm + num_kphases = cute.size(tOrP0, mode=[2]) + for kphase_idx in cutlass.range(num_kphases, unroll_full=True): + kphase_coord_2 = (None, None, kphase_idx) + pv_tiled_mma.set(tcgen05.Field.ACCUMULATE, kphase_idx != 0) + cute.gemm( + pv_tiled_mma, + tOtO0, + tOrP0[kphase_coord_2], + tOrVi[kphase_coord_2], + tOtO0, + ) + # 5. release accumulated O0_partial + o0_handle.commit() + # End of GEMM_PV00 (P0 * V0 -> O0_partial) + + seqlen_kv_loop_steps = ( + self.get_trip_count(curr_block_coord, self.cta_tiler, seqlen_k) + - 1 + ) + + # O1 hasn't been accumulated yet, its first MMA calculation doesn't need to accumulate + pv_whether_acc = False + for i in cutlass.range(0, seqlen_kv_loop_steps, 1, unroll=1): + # GEMM_QK0i (Q0 * Ki -> S0) + # 1. wait for Ki + k_handle = load_kv_consumer.wait_and_advance() + tSrKi = tSrK[None, None, None, k_handle.index] + # 2. gemm + inner_num_kphases = cute.size(tSrQ0, mode=[2]) + for kphase_idx in cutlass.range( + inner_num_kphases, unroll_full=True + ): + kphase_coord_3 = (None, None, kphase_idx) + qk_tiled_mma.set(tcgen05.Field.ACCUMULATE, kphase_idx != 0) + cute.gemm( + qk_tiled_mma, + tStS0, + tSrQ0[kphase_coord_3], + tSrKi[kphase_coord_3], + tStS0, + ) + # 3. release S0 + s0_handle.commit() + # End of GEMM_QK0i (Q0 * Ki -> S0) + + # GEMM_PV1(i-1) (P1 * V(i-1) -> O1_partial), V(i-1) is ready in GEMM_PV0(i-1) + # 1. 
acquire corrected O1_partial + o1_handle = mma_corr_producer.acquire_and_advance() + # 2. acquire P1 + s1_handle = mma_s1_producer.acquire_and_advance() + # 3. gemm + inner_num_kphases = cute.size(tOrP0, mode=[2]) + for kphase_idx in cutlass.range( + inner_num_kphases, unroll_full=True + ): + kphase_coord_4 = (None, None, kphase_idx) + pv_tiled_mma.set(tcgen05.Field.ACCUMULATE, pv_whether_acc) + cute.gemm( + pv_tiled_mma, + tOtO1, + tOrP1[kphase_coord_4], + tOrVi[kphase_coord_4], + tOtO1, + ) + pv_whether_acc = True + # 4. release accumulated O1_partial + o1_handle.commit() + # 5. release V(i-1) + v_handle.release() + # End of GEMM_PV1(i-1) (P1 * V(i-1) -> O1_partial) + + # GEMM_QK1i (Q1 * Ki -> S1), Q1 is ready in GEMM_QK10; Ki is ready in GEMM_QK0i + # 1. gemm + inner_num_kphases = cute.size(tSrQ1, mode=[2]) + for kphase_idx in cutlass.range( + inner_num_kphases, unroll_full=True + ): + kphase_coord_5 = (None, None, kphase_idx) + qk_tiled_mma.set(tcgen05.Field.ACCUMULATE, kphase_idx != 0) + cute.gemm( + qk_tiled_mma, + tStS1, + tSrQ1[kphase_coord_5], + tSrKi[kphase_coord_5], + tStS1, + ) + s1_handle.commit() + # 2. release Ki + k_handle.release() + # End of GEMM_QK1i (Q1 * Ki -> S1) + + # GEMM_PV0i (P0 * Vi -> O0_partial) + # 1. wait for Vi + v_handle = load_kv_consumer.wait_and_advance() + tOrVi = tOrV[None, None, None, v_handle.index] + # 2. acquire corrected O0_partial + o0_handle = mma_corr_producer.acquire_and_advance() + # 3. acquire P0 + s0_handle = mma_s0_producer.acquire_and_advance() + # 4. gemm + inner_num_kphases = cute.size(tOrP0, mode=[2]) + for kphase_idx in cutlass.range( + inner_num_kphases, unroll_full=True + ): + kphase_coord_6 = (None, None, kphase_idx) + pv_tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + cute.gemm( + pv_tiled_mma, + tOtO0, + tOrP0[kphase_coord_6], + tOrVi[kphase_coord_6], + tOtO0, + ) + # 5. release accumulated O0_partial + o0_handle.commit() + # End of GEMM_PV0i (P0 * Vi -> O0_partial) + # End of seqlen_kv loop + + # release Q0 & Q1 + q0_handle.release() + q1_handle.release() + + # GEMM_PV1(i_end) (P1 * Vi_end -> O1) + # 1. acquire corrected O1_partial + o1_handle = mma_corr_producer.acquire_and_advance() + # 2. acquire P1 + s1_handle = mma_s1_producer.acquire_and_advance() + # 3. gemm + num_kphases = cute.size(tOrP1, mode=[2]) + for kphase_idx in cutlass.range(num_kphases, unroll_full=True): + kphase_coord_7 = (None, None, kphase_idx) + pv_tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + cute.gemm( + pv_tiled_mma, + tOtO1, + tOrP1[kphase_coord_7], + tOrVi[kphase_coord_7], + tOtO1, + ) + # 4. commit accumulated O1 + o1_handle.commit() + # 5. 
release Vi_end + v_handle.release() + # End of GEMM_PV1(i_end) (P1 * Vi_end -> O1) + + # Commit S0 and S1 + s0_handle.commit() + s1_handle.commit() + + # Advance to next tile + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + # End of persistent scheduler loop + + # dealloc tmem buffer + cute.arch.relinquish_tmem_alloc_permit() + cute.arch.mbarrier_wait(tmem_dealloc_mbar_ptr, 0) + tmem_alloc_cols = Int32(self.tmem_alloc_cols) + # Retrieving tmem ptr and make acc + tmem_ptr = cute.arch.retrieve_tmem_ptr( + Float32, + alignment=16, + ptr_to_buffer_holding_addr=storage.tmem_holding_buf, + ) + cute.arch.dealloc_tmem(tmem_ptr, tmem_alloc_cols) + + # /////////////////////////////////////////////////////////////////////////////// + # Epilogue + # /////////////////////////////////////////////////////////////////////////////// + if warp_idx == self.epilogue_warp_id: + cute.arch.warpgroup_reg_dealloc(self.num_regs_other) + tile_sched = create_fmha_static_tile_scheduler( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + + while work_tile.is_valid_tile: + curr_block_coord = work_tile.tile_idx + batch_coord = curr_block_coord[2][1] + continue_cond = False + cuseqlen_q = Int32(0) + seqlen_q = mQ_qdl.shape[0] + + if cutlass.const_expr(cum_seqlen_q is not None): + cuseqlen_q = cum_seqlen_q[batch_coord] + seqlen_q = cum_seqlen_q[batch_coord + 1] - cuseqlen_q + continue_cond = ( + not FmhaStaticTileScheduler.check_valid_work_for_seqlen_q( + self.cta_tiler[0], + curr_block_coord[0], + seqlen_q, + ) + ) + if not continue_cond: + curr_block_coord_o = curr_block_coord + mO_qdl_ = mO_qdl + if cutlass.const_expr(cum_seqlen_q is not None): + logical_offset_mO = ( + mO_qdl_.shape[0] - seqlen_q, + 0, + (0, cuseqlen_q + seqlen_q), + ) + mO_qdl_ = cute.domain_offset(logical_offset_mO, mO_qdl_) + curr_block_coord_o = ( + curr_block_coord[0], + curr_block_coord[1], + (curr_block_coord[2][0], 0), + ) + + o0_coord = 2 * curr_block_coord_o[0] + o1_coord = o0_coord + 1 + gO_qdl = cute.flat_divide( + mO_qdl_, cute.select(self.pv_mma_tiler, mode=[0, 1]) + ) + gO = gO_qdl[None, None, None, 0, curr_block_coord_o[2]] + tOsO, tOgO = cute.nvgpu.cpasync.tma_partition( + tma_atom_o, + 0, + cute.make_layout(1), + cute.group_modes(sO, 0, 2), + cute.group_modes(gO, 0, 2), + ) + + # O0 O1 using the same pipeline + # wait from corr, issue tma store on smem + # O0 + # 1. wait for O0 final + o0_handle = corr_epi_consumer.wait_and_advance() + # 2. copy O0 to gmem + cute.copy(tma_atom_o, tOsO[None, 0], tOgO[None, o0_coord]) + cute.arch.cp_async_bulk_commit_group() + # O1 + # 1. wait for O1 final + o1_handle = corr_epi_consumer.wait_and_advance() + # 2. 
copy O1 to gmem + cute.copy(tma_atom_o, tOsO[None, 1], tOgO[None, o1_coord]) + cute.arch.cp_async_bulk_commit_group() + + # Ensure O0 buffer is ready to be released + cute.arch.cp_async_bulk_wait_group(1, read=True) + o0_handle.release() + # Ensure O1 buffer is ready to be released + cute.arch.cp_async_bulk_wait_group(0, read=True) + o1_handle.release() + + # Advance to next tile + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + # End of persistent scheduler loop + + # /////////////////////////////////////////////////////////////////////////////// + # Softmax0 + # /////////////////////////////////////////////////////////////////////////////// + if warp_idx < self.softmax1_warp_ids[0]: + # increase register after decreasing + cute.arch.warpgroup_reg_alloc(self.num_regs_softmax) + + self.softmax( + stage=0, + seqlen_k=mK_kdl.shape[0], + cum_seqlen_q=cum_seqlen_q, + cum_seqlen_k=cum_seqlen_k, + scale_softmax_log2=scale_softmax_log2, + qk_thr_mma=qk_thr_mma, + tStS=tStS, + tStSi=tStS0, + mma_si_consumer=mma_s0_consumer, + si_corr_producer=s0_corr_producer, + s0_s1_sequence_consumer=s0_s1_sequence_consumer, + s0_s1_sequence_producer=s0_s1_sequence_producer, + tile_sched_params=tile_sched_params, + ) + cute.arch.mbarrier_arrive(tmem_dealloc_mbar_ptr) + + # /////////////////////////////////////////////////////////////////////////////// + # Softmax1 + # /////////////////////////////////////////////////////////////////////////////// + if ( + warp_idx < self.correction_warp_ids[0] + and warp_idx >= self.softmax1_warp_ids[0] + ): + # increase register after decreasing + cute.arch.warpgroup_reg_alloc(self.num_regs_softmax) + + self.softmax( + stage=1, + seqlen_k=mK_kdl.shape[0], + cum_seqlen_q=cum_seqlen_q, + cum_seqlen_k=cum_seqlen_k, + scale_softmax_log2=scale_softmax_log2, + qk_thr_mma=qk_thr_mma, + tStS=tStS, + tStSi=tStS1, + mma_si_consumer=mma_s1_consumer, + si_corr_producer=s1_corr_producer, + s0_s1_sequence_consumer=s0_s1_sequence_consumer, + s0_s1_sequence_producer=s0_s1_sequence_producer, + tile_sched_params=tile_sched_params, + ) + cute.arch.mbarrier_arrive(tmem_dealloc_mbar_ptr) + + # /////////////////////////////////////////////////////////////////////////////// + # Correction + # /////////////////////////////////////////////////////////////////////////////// + if warp_idx >= self.correction_warp_ids[0] and warp_idx < self.mma_warp_id: + cute.arch.warpgroup_reg_dealloc(self.num_regs_correction) + + cS = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1])) + tScS = qk_thr_mma.partition_C(cS) + + tStS_vec_layout = cute.composition(tStS.layout, cute.make_layout((128, 2))) + + tStS_vec0 = cute.make_tensor( + tStS.iterator + self.tmem_vec0_offset, tStS_vec_layout + ) + tStS_vec1 = cute.make_tensor( + tStS.iterator + self.tmem_vec1_offset, tStS_vec_layout + ) + + tScS_vec_layout = cute.composition(tScS.layout, cute.make_layout((128, 2))) + tScS_vec = cute.make_tensor(tScS.iterator, tScS_vec_layout) + + tmem_load_v_atom = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(2)), + self.qk_acc_dtype, + ) + + tiled_tmem_load_vec = tcgen05.make_tmem_copy(tmem_load_v_atom, tStS_vec0) + thread_idx = tidx % (self.threads_per_warp * len(self.correction_warp_ids)) + thr_tmem_load_vec = tiled_tmem_load_vec.get_slice(thread_idx) + + tTMEM_LOAD_VECtS0 = thr_tmem_load_vec.partition_S(tStS_vec0) + tTMEM_LOAD_VECtS1 = thr_tmem_load_vec.partition_S(tStS_vec1) + tTMEM_LOAD_VECcS = thr_tmem_load_vec.partition_D(tScS_vec) + + tile_sched = 
create_fmha_static_tile_scheduler( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + + while work_tile.is_valid_tile: + curr_block_coord = work_tile.tile_idx + batch_coord = curr_block_coord[2][1] + seqlen_k = mK_kdl.shape[0] + continue_cond = False + + if cutlass.const_expr(cum_seqlen_q is not None): + cuseqlen_q = cum_seqlen_q[batch_coord] + seqlen_q = cum_seqlen_q[batch_coord + 1] - cuseqlen_q + continue_cond = ( + not FmhaStaticTileScheduler.check_valid_work_for_seqlen_q( + self.cta_tiler[0], + curr_block_coord[0], + seqlen_q, + ) + ) + + if not continue_cond: + if cutlass.const_expr(cum_seqlen_k is not None): + cuseqlen_k = cum_seqlen_k[batch_coord] + seqlen_k = cum_seqlen_k[batch_coord + 1] - cuseqlen_k + # Ignore first signal from softmax as no correction is required + vec0_handle = s0_corr_consumer.wait_and_advance() + vec0_handle.release() + vec1_handle = s1_corr_consumer.wait_and_advance() + seqlen_kv_loop_steps = ( + self.get_trip_count(curr_block_coord, self.cta_tiler, seqlen_k) + - 1 + ) + for i in cutlass.range(0, seqlen_kv_loop_steps, 1, unroll=1): + # wait for vec0 (row_wise current max & previous max) + vec0_handle = s0_corr_consumer.wait_and_advance() + tTMEM_LOAD_VECrS = cute.make_fragment( + tTMEM_LOAD_VECcS.shape, self.qk_acc_dtype + ) + cute.copy( + tiled_tmem_load_vec, tTMEM_LOAD_VECtS0, tTMEM_LOAD_VECrS + ) + scale_ = scale_softmax_log2 * ( + tTMEM_LOAD_VECrS[0] - tTMEM_LOAD_VECrS[1] + ) + scale = cute.math.exp2(scale_, fastmath=True) + # wait for o0 + o0_handle = mma_corr_consumer.wait_and_advance() + self.correction_rescale(pv_thr_mma, tOtO0, scale) + # release vec1 & o0 + vec1_handle.release() + cute.arch.fence_view_async_tmem_store() + o0_handle.release() + + # wait for vec1 (row_wise current max & previous max) + vec1_handle = s1_corr_consumer.wait_and_advance() + cute.copy( + tiled_tmem_load_vec, tTMEM_LOAD_VECtS1, tTMEM_LOAD_VECrS + ) + scale_ = scale_softmax_log2 * ( + tTMEM_LOAD_VECrS[0] - tTMEM_LOAD_VECrS[1] + ) + scale = cute.math.exp2(scale_, fastmath=True) + o1_handle = mma_corr_consumer.wait_and_advance() + self.correction_rescale(pv_thr_mma, tOtO1, scale) + vec0_handle.release() + cute.arch.fence_view_async_tmem_store() + o1_handle.release() + # End of seqlen_corr_loop_steps + vec1_handle.release() + + # wait for vec0 (row_wise global sum) + vec0_handle = s0_corr_consumer.wait_and_advance() + tTMEM_LOAD_VECrS = cute.make_fragment( + tTMEM_LOAD_VECcS.shape, self.qk_acc_dtype + ) + cute.copy(tiled_tmem_load_vec, tTMEM_LOAD_VECtS0, tTMEM_LOAD_VECrS) + cute.arch.fence_view_async_tmem_load() + vec0_handle.release() + # wait for o0 + o0_handle = mma_corr_consumer.wait_and_advance() + o0_final_handle = corr_epi_producer.acquire_and_advance() + self.correction_epilog( + pv_thr_mma, + tOtO0, + scale_output / tTMEM_LOAD_VECrS[0], + sO[None, None, 0], + ) + o0_handle.release() + o0_final_handle.commit() + + # wait for vec1 (row_wise global sum) + vec1_handle = s1_corr_consumer.wait_and_advance() + cute.copy(tiled_tmem_load_vec, tTMEM_LOAD_VECtS1, tTMEM_LOAD_VECrS) + cute.arch.fence_view_async_tmem_load() + vec1_handle.release() + # wait for o1 + o1_handle = mma_corr_consumer.wait_and_advance() + o1_final_handle = corr_epi_producer.acquire_and_advance() + self.correction_epilog( + pv_thr_mma, + tOtO1, + scale_output / tTMEM_LOAD_VECrS[0], + sO[None, None, 1], + ) + o1_handle.release() + o1_final_handle.commit() + # Advance to next tile + tile_sched.advance_to_next_work() + work_tile = 
tile_sched.get_current_work() + # End of persistent scheduler loop + cute.arch.mbarrier_arrive(tmem_dealloc_mbar_ptr) + return + + @cute.jit + def softmax_step( + self, + stage: int, + need_apply_mask: bool, + iter_args: tuple, + value_args: tuple, + pipeline_args: tuple, + atom_args: tuple, + tensor_args: tuple, + ) -> Tuple[ + Float32, + Float32, + pipeline.PipelineProducer.ImmutableResourceHandle, + pipeline.PipelineConsumer, + pipeline.PipelineProducer, + pipeline.PipelineConsumer, + pipeline.PipelineProducer, + ]: + """Perform a single step of the softmax computation on a block of attention scores. + + This method processes one block of the attention matrix, computing numerically stable + softmax by first finding the row maximum, subtracting it from all elements, applying + exponential function, and then normalizing by the sum of exponentials. It also handles + optional masking of attention scores. + + The method involves several key operations: + 1. Loading attention scores from tensor memory + 2. Applying optional masking based on position + 3. Computing row-wise maximum values for numerical stability + 4. Transforming scores using exp2(x*scale - max*scale) + 5. Computing row sums for normalization + 6. Coordinating pipeline synchronization between different processing stages + + :param stage: Processing stage (0 for first half, 1 for second half) + :type stage: int + :param need_apply_mask: Whether to apply attention masking + :type need_apply_mask: bool + :param iter_args: Tuple containing the counting tensor, row_max, row_sum, and vector buffer's handle for current iteration + :type iter_args: tuple + :param value_args: Tuple containing seqlen_k and scale_softmax_log2 + :type value_args: tuple + :param pipeline_args: Tuple containing pipeline related arguments for MMA, correction, and sequence synchronization + :type pipeline_args: tuple + :param atom_args: Tuple containing mma & copy atoms + :type atom_args: tuple + :param tensor_args: Tuple containing softmax related tensors + :type tensor_args: tuple + :return: Updated state values (row_max, row_sum, and pipeline related arguments) + :rtype: tuple + """ + cS, row_max, row_sum, vec_i_handle = iter_args + seqlen_k, scale_softmax_log2 = value_args + ( + mma_si_consumer, + si_corr_producer, + s0_s1_sequence_consumer, + s0_s1_sequence_producer, + ) = pipeline_args + ( + qk_thr_mma, + tiled_tmem_load, + tiled_tmem_store, + tiled_tmem_store_vec, + thr_tmem_load, + thr_tmem_store, + thr_tmem_store_vec, + ) = atom_args + ( + tTMEM_LOADtS, + tTMEM_STORE_VECtS, + tTMEM_STOREtS_x4, + ) = tensor_args + + tilePlikeFP32 = self.qk_mma_tiler[1] // Float32.width * self.o_dtype.width + tScS = qk_thr_mma.partition_C(cS) + tScS_vec_layout = cute.composition(tScS.layout, cute.make_layout((128, 2))) + tScS_vec = cute.make_tensor(tScS.iterator, tScS_vec_layout) + + tScS_P_layout = cute.composition( + tScS.layout, cute.make_layout((128, tilePlikeFP32)) + ) + tScS_P = cute.make_tensor(tScS.iterator, tScS_P_layout) + tTMEM_LOADcS = thr_tmem_load.partition_D(tScS) + tTMEM_STORE_VECcS = thr_tmem_store_vec.partition_S(tScS_vec) + tTMEM_STOREcS = thr_tmem_store.partition_S(tScS_P) + + # Wait for Si + si_handle = mma_si_consumer.wait_and_advance() + tTMEM_LOADrS = cute.make_fragment(tTMEM_LOADcS.shape, self.qk_acc_dtype) + cute.copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS) + if need_apply_mask: + self.apply_mask(tTMEM_LOADrS, tTMEM_LOADcS, seqlen_k) + + old_row_max = row_max + row_max = tTMEM_LOADrS.load().reduce(cute.ReductionOp.MAX, row_max, 0) + 
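        # Explanatory notes (no behavior change): row_max now holds the running
        # row-wise maximum over all K blocks processed so far for this Q tile.
        # If row_max is -inf (e.g., the row is fully masked so far), substituting
        # 0.0 below keeps the exp2-based rescaling finite by avoiding the
        # (-inf) - (-inf) = NaN case in old_row_max - row_max_safe. The pair
        # (old_row_max, row_max_safe) written to TMEM lets the correction warps
        # rescale the partial O accumulators by exp2(scale * (old_max - new_max)).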
row_max_safe = row_max + if row_max == -cutlass.Float32.inf: + row_max_safe = 0.0 + tTMEM_STORE_VECrS = cute.make_fragment( + tTMEM_STORE_VECcS.shape, self.qk_acc_dtype + ) + tTMEM_STORE_VECrS[0] = old_row_max + tTMEM_STORE_VECrS[1] = row_max_safe + cute.copy(tiled_tmem_store_vec, tTMEM_STORE_VECrS, tTMEM_STORE_VECtS) + cute.arch.fence_view_async_tmem_store() + # Notify correction wg that row_max is ready + vec_i_handle.commit() + + tTMEM_STORErS_x4 = cute.make_fragment(tTMEM_STOREcS.shape, self.qk_acc_dtype) + tTMEM_STORErS_x4_e = cute.make_tensor( + cute.recast_ptr(tTMEM_STORErS_x4.iterator, dtype=self.q_dtype), + tTMEM_LOADrS.layout, + ) + + scale = scale_softmax_log2 + minus_row_max_scale = (0.0 - row_max_safe) * scale + + # Sequence barrier wait + if cutlass.const_expr(stage == 0): + sequence_producer_handle = s0_s1_sequence_producer.acquire_and_advance() + else: + sequence_consumer_handle = s0_s1_sequence_consumer.wait_and_advance() + frg_cnt = 4 + frg_tile = cute.size(tTMEM_LOADrS) // frg_cnt + tTMEM_LOADrS_frg = cute.logical_divide(tTMEM_LOADrS, cute.make_layout(frg_tile)) + tTMEM_STORErS_x4_e_frg = cute.logical_divide( + tTMEM_STORErS_x4_e, cute.make_layout(frg_tile) + ) + for j in range(frg_cnt): + for k in range(0, cute.size(tTMEM_LOADrS_frg, mode=[0]), 2): + tTMEM_LOADrS_frg[k, j], tTMEM_LOADrS_frg[k + 1, j] = ( + cute.arch.fma_packed_f32x2( + (tTMEM_LOADrS_frg[k, j], tTMEM_LOADrS_frg[k + 1, j]), + (scale, scale), + (minus_row_max_scale, minus_row_max_scale), + ) + ) + tTMEM_LOADrS_frg[k, j] = cute.math.exp2( + tTMEM_LOADrS_frg[k, j], fastmath=True + ) + tTMEM_LOADrS_frg[k + 1, j] = cute.math.exp2( + tTMEM_LOADrS_frg[k + 1, j], fastmath=True + ) + s_vec = tTMEM_LOADrS_frg[None, j].load() + tTMEM_STORErS_x4_e_frg[None, j].store(s_vec.to(self.q_dtype)) + # Sequence barrier arrive + if cutlass.const_expr(stage == 0): + sequence_producer_handle.commit() + else: + sequence_consumer_handle.release() + cute.copy(tiled_tmem_store, tTMEM_STORErS_x4, tTMEM_STOREtS_x4) + cute.arch.fence_view_async_tmem_store() + # Notify tensor core warp that softmax(S->P) is ready + si_handle.release() + + vec_i_handle = si_corr_producer.acquire_and_advance() + acc_scale_ = scale * (old_row_max - row_max_safe) + acc_scale = cute.math.exp2(acc_scale_, fastmath=True) * 0.5 + row_sum *= acc_scale + local_row_sum_0 = (row_sum, row_sum) + local_row_sum_1 = (0.0, 0.0) + local_row_sum_2 = (0.0, 0.0) + local_row_sum_3 = (0.0, 0.0) + + reduction_unroll = 4 + frg_tile = cute.size(tTMEM_LOADrS) // reduction_unroll + tTMEM_LOADrS_frg = cute.logical_divide(tTMEM_LOADrS, cute.make_layout(frg_tile)) + + for j in cutlass.range_constexpr(0, cute.size(tTMEM_LOADrS_frg, mode=[0]), 2): + local_row_sum_0 = cute.arch.add_packed_f32x2( + local_row_sum_0, (tTMEM_LOADrS_frg[j, 0], tTMEM_LOADrS_frg[j + 1, 0]) + ) + local_row_sum_1 = cute.arch.add_packed_f32x2( + local_row_sum_1, (tTMEM_LOADrS_frg[j, 1], tTMEM_LOADrS_frg[j + 1, 1]) + ) + local_row_sum_2 = cute.arch.add_packed_f32x2( + local_row_sum_2, (tTMEM_LOADrS_frg[j, 2], tTMEM_LOADrS_frg[j + 1, 2]) + ) + local_row_sum_3 = cute.arch.add_packed_f32x2( + local_row_sum_3, (tTMEM_LOADrS_frg[j, 3], tTMEM_LOADrS_frg[j + 1, 3]) + ) + + local_row_sum_0 = cute.arch.add_packed_f32x2(local_row_sum_0, local_row_sum_1) + local_row_sum_2 = cute.arch.add_packed_f32x2(local_row_sum_2, local_row_sum_3) + local_row_sum_0 = cute.arch.add_packed_f32x2(local_row_sum_0, local_row_sum_2) + row_sum = local_row_sum_0[0] + local_row_sum_0[1] + + return ( + row_max, + row_sum, + vec_i_handle, + 
mma_si_consumer, + si_corr_producer, + s0_s1_sequence_consumer, + s0_s1_sequence_producer, + ) + + # For both softmax0 and softmax1 warp group + @cute.jit + def softmax( + self, + stage: int, + seqlen_k: Int32, + cum_seqlen_q: cute.Tensor | None, + cum_seqlen_k: cute.Tensor | None, + scale_softmax_log2: Float32, + qk_thr_mma: cute.core.ThrMma, + tStS: cute.Tensor, + tStSi: cute.Tensor, + mma_si_consumer: pipeline.PipelineConsumer, + si_corr_producer: pipeline.PipelineProducer, + s0_s1_sequence_consumer: pipeline.PipelineConsumer, + s0_s1_sequence_producer: pipeline.PipelineProducer, + tile_sched_params: FmhaStaticTileSchedulerParams, + ): + """Compute softmax on attention scores from QK matrix multiplication. + + This method handles the softmax computation for either the first or second half of the + attention matrix, depending on the 'stage' parameter. It calculates row-wise maximum + and sum values needed for stable softmax computation, applies optional masking, and + transforms raw attention scores into probability distributions. + + The implementation uses specialized memory access patterns and efficient math operations + for computing exp(x) using exp2 functions. It also coordinates pipeline + synchronization between MMA, correction, and sequence processing stages. + + :param stage: Processing stage (0 for first half, 1 for second half of attention matrix) + :type stage: int + :param scale_softmax_log2: Log2 scale factor for softmax operation + :type scale_softmax_log2: Float32 + :param qk_thr_mma: Thread MMA operation for QK matrix multiplication + :type qk_thr_mma: cute.core.ThrMma + :param tStS: Shared tensor for softmax input/output + :type tStS: cute.Tensor + :param tStSi: Input tensor containing attention scores + :type tStSi: cute.Tensor + :param mma_si_pipeline: Pipeline for synchronizing with MMA operations + :type mma_si_pipeline: pipeline.PipelineAsync + :param si_corr_pipeline: Pipeline for synchronizing with correction operations + :type si_corr_pipeline: pipeline.PipelineAsync + :param s0_s1_sequence_pipeline: Pipeline for synchronizing between stage 0 and 1 + :type s0_s1_sequence_pipeline: pipeline.PipelineAsync + :param tile_sched_params: Parameters for tile scheduling + :type tile_sched_params: FmhaStaticTileSchedulerParams + """ + tidx, _, _ = cute.arch.thread_idx() + thread_idx = tidx % ( + self.threads_per_warp + * ( + len(self.softmax0_warp_ids) + if stage == 0 + else len(self.softmax1_warp_ids) + ) + ) + + cS_base = cute.make_identity_tensor( + (self.qk_mma_tiler[0], self.qk_mma_tiler[1]) + ) + tilePlikeFP32 = self.qk_mma_tiler[1] // 32 * self.o_dtype.width + tScS = qk_thr_mma.partition_C(cS_base) + tStS_vec_layout = cute.composition(tStS.layout, cute.make_layout((128, 2))) + tmem_vec_offset = self.tmem_vec0_offset if stage == 0 else self.tmem_vec1_offset + tStS_vec = cute.make_tensor(tStS.iterator + tmem_vec_offset, tStS_vec_layout) + tScS_vec_layout = cute.composition(tScS.layout, cute.make_layout((128, 2))) + tScS_vec = cute.make_tensor(tScS.iterator, tScS_vec_layout) + tStS_P_layout = cute.composition( + tStS.layout, cute.make_layout((128, tilePlikeFP32)) + ) + tmem_p_offset = self.tmem_p0_offset if stage == 0 else self.tmem_p1_offset + tStS_P = cute.make_tensor(tStS.iterator + tmem_p_offset, tStS_P_layout) + tmem_load_atom = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), + self.qk_acc_dtype, + ) + tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStSi) + thread_idx = tidx % ( + self.threads_per_warp + * ( + 
len(self.softmax0_warp_ids) + if stage == 0 + else len(self.softmax1_warp_ids) + ) + ) + thr_tmem_load = tiled_tmem_load.get_slice(thread_idx) + tTMEM_LOADtS = thr_tmem_load.partition_S(tStSi) + tmem_store_vec_atom = cute.make_copy_atom( + tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(2)), + self.qk_acc_dtype, + ) + tiled_tmem_store_vec = tcgen05.make_tmem_copy(tmem_store_vec_atom, tStS_vec) + thr_tmem_store_vec = tiled_tmem_store_vec.get_slice(thread_idx) + tTMEM_STORE_VECtS = thr_tmem_store_vec.partition_D(tStS_vec) + tTMEM_STORE_VECcS = thr_tmem_store_vec.partition_S(tScS_vec) + tmem_store_atom = cute.make_copy_atom( + tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), + self.qk_acc_dtype, + ) + tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStS_P) + thr_tmem_store = tiled_tmem_store.get_slice(thread_idx) + tTMEM_STOREtS_x4 = thr_tmem_store.partition_D(tStS_P) + + tile_sched = create_fmha_static_tile_scheduler( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + + while work_tile.is_valid_tile: + curr_block_coord = work_tile.tile_idx + batch_coord = curr_block_coord[2][1] + seqlen_k_ = seqlen_k + continue_cond = False + + if cutlass.const_expr(cum_seqlen_q is not None): + cuseqlen_q = cum_seqlen_q[batch_coord] + seqlen_q = cum_seqlen_q[batch_coord + 1] - cuseqlen_q + continue_cond = ( + not FmhaStaticTileScheduler.check_valid_work_for_seqlen_q( + self.cta_tiler[0], + curr_block_coord[0], + seqlen_q, + ) + ) + + if not continue_cond: + if cutlass.const_expr(cum_seqlen_k is not None): + cuseqlen_k = cum_seqlen_k[batch_coord] + seqlen_k_ = cum_seqlen_k[batch_coord + 1] - cuseqlen_k + row_max = -Float32.inf + row_sum = 0.0 + value_args = (seqlen_k_, scale_softmax_log2) + atom_args = ( + qk_thr_mma, + tiled_tmem_load, + tiled_tmem_store, + tiled_tmem_store_vec, + thr_tmem_load, + thr_tmem_store, + thr_tmem_store_vec, + ) + tensor_args = ( + tTMEM_LOADtS, + tTMEM_STORE_VECtS, + tTMEM_STOREtS_x4, + ) + + logical_offset = ( + curr_block_coord[0] * self.cta_tiler[0] + + stage * self.qk_mma_tiler[0], + 0, + ) + cS = cute.domain_offset(logical_offset, cS_base) + vec_i_handle = si_corr_producer.acquire_and_advance() + unmask_count = self.get_unmasked_trip_count( + curr_block_coord, + self.cta_tiler, + seqlen_k_, + ) + for i in cutlass.range(0, unmask_count, 1, unroll=1): + cS_iter = cute.domain_offset((0, i * self.qk_mma_tiler[1]), cS) + iter_args = (cS_iter, row_max, row_sum, vec_i_handle) + pipeline_args = ( + mma_si_consumer, + si_corr_producer, + s0_s1_sequence_consumer, + s0_s1_sequence_producer, + ) + ( + row_max, + row_sum, + vec_i_handle, + mma_si_consumer, + si_corr_producer, + s0_s1_sequence_consumer, + s0_s1_sequence_producer, + ) = self.softmax_step( + stage, + False, + iter_args, + value_args, + pipeline_args, + atom_args, + tensor_args, + ) + mask_count = self.get_masked_trip_count( + curr_block_coord, + self.cta_tiler, + seqlen_k_, + ) + + for i in cutlass.range( + unmask_count, unmask_count + mask_count, 1, unroll=1 + ): + cS_iter = cute.domain_offset((0, i * self.qk_mma_tiler[1]), cS) + iter_args = (cS_iter, row_max, row_sum, vec_i_handle) + pipeline_args = ( + mma_si_consumer, + si_corr_producer, + s0_s1_sequence_consumer, + s0_s1_sequence_producer, + ) + ( + row_max, + row_sum, + vec_i_handle, + mma_si_consumer, + si_corr_producer, + s0_s1_sequence_consumer, + s0_s1_sequence_producer, + ) = self.softmax_step( + stage, + True, + iter_args, + value_args, + pipeline_args, + atom_args, + 
tensor_args, + ) + si_handle = mma_si_consumer.wait_and_advance() + tTMEM_STORE_VECrS = cute.make_fragment( + tTMEM_STORE_VECcS.shape, self.qk_acc_dtype + ) + tTMEM_STORE_VECrS[0] = row_sum + tTMEM_STORE_VECrS[1] = row_max + cute.copy(tiled_tmem_store_vec, tTMEM_STORE_VECrS, tTMEM_STORE_VECtS) + cute.arch.fence_view_async_tmem_store() + vec_i_handle.commit() + si_corr_producer.acquire() + # Empty step to sync against pipe s + si_handle.release() + + # Advance to next tile + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + # End of persistent scheduler loop + + @cute.jit + def correction_rescale( + self, + thr_mma: cute.core.ThrMma, + tOtO: cute.Tensor, + scale: Float32, + ): + """Rescale intermediate attention results based on softmax normalization factor. + + This method performs a crucial correction step in the attention computation pipeline. + When processing attention in blocks, the softmax normalization factors may change + as new blocks are processed. This method rescales previously computed partial + output values to account for updated normalization factors. + + The implementation uses efficient tensor memory operations to: + 1. Load existing partial attention output from tensor memory + 2. Apply the scaling factor to all elements + 3. Store the rescaled results back to tensor memory + + :param thr_mma: Thread MMA operation for the computation + :type thr_mma: cute.core.ThrMma + :param tOtO: Tensor representing partial attention output to be rescaled + :type tOtO: cute.Tensor + :param scale: Scaling factor to apply to the partial results + :type scale: Float32 + """ + pv_tiled_mma_shape = ( + self.pv_mma_tiler[0], + self.pv_mma_tiler[1], + ) + cO = cute.make_identity_tensor(pv_tiled_mma_shape) + tOcO = thr_mma.partition_C(cO) + + corr_tile_size = 16 # tuneable parameter + tmem_load_atom = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(corr_tile_size)), + self.pv_acc_dtype, + ) + tmem_store_atom = cute.make_copy_atom( + tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(corr_tile_size)), + self.pv_acc_dtype, + ) + + tOtO_i_layout = cute.composition( + tOtO.layout, cute.make_layout((128, corr_tile_size)) + ) + tOcO_i_layout = cute.composition( + tOcO.layout, cute.make_layout((128, corr_tile_size)) + ) + + tOtO_i = cute.make_tensor(tOtO.iterator, tOtO_i_layout) + tOcO_i = cute.make_tensor(tOcO.iterator, tOcO_i_layout) + + tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tOtO_i) + tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tOtO_i) + tidx, _, _ = cute.arch.thread_idx() + thread_idx = tidx % (self.threads_per_warp * len(self.correction_warp_ids)) + thr_tmem_load = tiled_tmem_load.get_slice(thread_idx) + thr_tmem_store = tiled_tmem_store.get_slice(thread_idx) + + tTMEM_LOADtO = thr_tmem_load.partition_S(tOtO_i) + tTMEM_LOADcO = thr_tmem_load.partition_D(tOcO_i) + + tTMEM_STOREtO = thr_tmem_store.partition_D(tOtO_i) + + tTMrO = cute.make_fragment( + (tTMEM_LOADcO.shape, 128 // corr_tile_size), self.pv_acc_dtype + ) + for i in range(self.cta_tiler[2] // corr_tile_size): + tTMrO_i_ = tTMrO[None, i] + tTMrO_i_layout = cute.composition( + tTMrO_i_.layout, cute.make_layout(tTMrO.shape[0]) + ) + tTMrO_i = cute.make_tensor(tTMrO_i_.iterator, tTMrO_i_layout) + tTMEM_LOADtO_i = cute.make_tensor( + tTMEM_LOADtO.iterator + i * corr_tile_size, tTMEM_LOADtO.layout + ) + tTMEM_STOREtO_i = cute.make_tensor( + tTMEM_STOREtO.iterator + i * corr_tile_size, tTMEM_STOREtO.layout + ) + + cute.copy(tiled_tmem_load, tTMEM_LOADtO_i, 
tTMrO_i) + for j in range(0, cute.size(tTMrO_i), 2): + tTMrO_i[j], tTMrO_i[j + 1] = cute.arch.mul_packed_f32x2( + (tTMrO_i[j], tTMrO_i[j + 1]), + (scale, scale), + ) + cute.copy(tiled_tmem_store, tTMrO_i, tTMEM_STOREtO_i) + + @cute.jit + def correction_epilog( + self, + thr_mma: cute.core.ThrMma, + tOtO: cute.Tensor, + scale: Float32, + sO: cute.Tensor, + ): + """Apply final scaling and transformation to attention output before writing to global memory. + + This correction_epilog function handles the final processing step for attention output values. + It applies a scaling factor to the accumulated attention results and prepares the + data for efficient transfer back to global memory. + + The method performs: + 1. Loading of accumulated attention results from tensor memory + 2. Application of the final output scaling factor + 3. Type conversion if necessary (typically from higher precision accumulator to output precision) + 4. Reorganization of data for optimal memory access patterns + 5. Preparation for efficient TMA store operations + + :param thr_mma: Thread MMA operation for the computation + :type thr_mma: cute.core.ThrMma + :param tOtO: Tensor containing accumulated attention output + :type tOtO: cute.Tensor + :param scale: Final scaling factor to apply to the output + :type scale: Float32 + :param sO: Shared memory tensor for the final output + :type sO: cute.Tensor + """ + + pv_tiled_mma_shape = ( + self.pv_mma_tiler[0], + self.pv_mma_tiler[1], + ) + cO = cute.make_identity_tensor(pv_tiled_mma_shape) + + corr_tile_size = 32 * 8 // self.o_dtype.width + tOsO = thr_mma.partition_C(sO) + tOcO = thr_mma.partition_C(cO) + + tOtO_i = cute.logical_divide(tOtO, cute.make_layout((128, corr_tile_size))) + tOcO_i = cute.logical_divide(tOcO, cute.make_layout((128, corr_tile_size))) + tOsO_i = cute.logical_divide(tOsO, cute.make_layout((128, corr_tile_size))) + tidx, _, _ = cute.arch.thread_idx() + thread_idx = tidx % (self.threads_per_warp * len(self.correction_warp_ids)) + + epi_subtile = (self.epi_tile[0], corr_tile_size) + tmem_copy_atom = sm100_utils.get_tmem_load_op( + self.pv_mma_tiler, + self.o_layout, + self.o_dtype, + self.pv_acc_dtype, + epi_subtile, + use_2cta_instrs=False, + ) + + tiled_tmem_load = tcgen05.make_tmem_copy( + tmem_copy_atom, tOtO_i[(None, None), 0] + ) + + thr_tmem_load = tiled_tmem_load.get_slice(thread_idx) + smem_copy_atom = sm100_utils.get_smem_store_op( + self.o_layout, self.o_dtype, self.pv_acc_dtype, tiled_tmem_load + ) + tiled_smem_store = cute.make_tiled_copy_D(smem_copy_atom, tiled_tmem_load) + + tTMEM_LOADtO = thr_tmem_load.partition_S(tOtO_i[(None, None), None]) + tTMEM_LOADsO = thr_tmem_load.partition_D(tOsO_i[(None, None), None]) + tTMEM_LOADoO = thr_tmem_load.partition_D(tOcO_i[(None, None), None]) + + for i in range(self.cta_tiler[2] // corr_tile_size): + tTMEM_LOADtO_i = tTMEM_LOADtO[None, 0, 0, i] + tTMEM_LOADsO_i = tTMEM_LOADsO[None, 0, 0, i] + tTMrO = cute.make_fragment( + tTMEM_LOADoO[None, 0, 0, i].shape, self.pv_acc_dtype + ) + cute.copy(tiled_tmem_load, tTMEM_LOADtO_i, tTMrO) + for j in range(0, cute.size(tTMrO), 2): + tTMrO[j], tTMrO[j + 1] = cute.arch.mul_packed_f32x2( + (tTMrO[j], tTMrO[j + 1]), + (scale, scale), + ) + tSMrO = cute.make_fragment(tTMrO.shape, self.o_dtype) + o_vec = tTMrO.load() + tSMrO.store(o_vec.to(self.o_dtype)) + cute.copy(tiled_smem_store, tSMrO, tTMEM_LOADsO_i) + + # fence view async shared + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + + def get_trip_count( 
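+ # Number of KV tiles along seqlen_k that this query tile visits; a causal
+ # mask clips the count at the diagonal block via the min() below.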
+ self, + blk_coord: cute.Coord, + tile_shape: cute.Shape, + seqlen_k: Int32, + ) -> Int32: + result = 0 + if ( + self.mask_type == MaskType.NO_MASK + or self.mask_type == MaskType.RESIDUAL_MASK + ): + result = cute.ceil_div(seqlen_k, tile_shape[1]) + elif self.mask_type == MaskType.CAUSAL_MASK: + max_blocks_k = cute.ceil_div(seqlen_k, tile_shape[1]) + max_blocks_q = cute.ceil_div( + (blk_coord[0] + 1) * tile_shape[0], tile_shape[1] + ) + result = cutlass.min(max_blocks_k, max_blocks_q) + return result + + @cute.jit + def get_masked_trip_count( + self, + blk_coord: cute.Coord, + tile_shape: cute.Shape, + seqlen_k: Int32, + ) -> Int32: + result = 0 + if self.mask_type == MaskType.NO_MASK: + result = 0 + elif self.mask_type == MaskType.RESIDUAL_MASK: + if seqlen_k % tile_shape[1] != 0: + result = 1 + else: + result = 0 + elif self.mask_type == MaskType.CAUSAL_MASK: + trip_count = self.get_trip_count(blk_coord, tile_shape, seqlen_k) + result = cutlass.min( + trip_count, + cute.ceil_div(tile_shape[0], tile_shape[1]), + ) + return result + + @cute.jit + def get_unmasked_trip_count( + self, + blk_coord: cute.Coord, + tile_shape: cute.Shape, + seqlen_k: Int32, + ) -> Int32: + result = 0 + if self.mask_type == MaskType.NO_MASK: + result = self.get_trip_count(blk_coord, tile_shape, seqlen_k) + elif self.mask_type == MaskType.RESIDUAL_MASK: + if seqlen_k % tile_shape[1] != 0: + result = self.get_trip_count(blk_coord, tile_shape, seqlen_k) - 1 + else: + result = self.get_trip_count(blk_coord, tile_shape, seqlen_k) + elif self.mask_type == MaskType.CAUSAL_MASK: + result = self.get_trip_count( + blk_coord, tile_shape, seqlen_k + ) - self.get_masked_trip_count(blk_coord, tile_shape, seqlen_k) + return result + + @cute.jit + def apply_mask( + self, + acc_qk: cute.Tensor, + index_qk: cute.Tensor, + seqlen_k: Int32, + ): + if self.mask_type == MaskType.RESIDUAL_MASK: + for i in range(cute.size(acc_qk)): + pos = index_qk[i] + if pos[1] >= seqlen_k: + acc_qk[i] = -Float32.inf + elif self.mask_type == MaskType.CAUSAL_MASK: + for i in range(cute.size(acc_qk)): + pos = index_qk[i] + if pos[0] < pos[1] or pos[1] >= seqlen_k: + acc_qk[i] = -Float32.inf + + @staticmethod + def _compute_grid( + o_shape: cute.Shape, + cta_tiler: Tuple[int, int, int], + is_persistent: bool, + ) -> Tuple[FmhaStaticTileSchedulerParams, Tuple[int, int, int]]: + tile_sched_params = create_fmha_static_tile_scheduler_params( + is_persistent, + ( + cute.ceil_div(cute.size(o_shape[0]), cta_tiler[0]), + cute.size(o_shape[2][0]), + cute.size(o_shape[2][1]), + ), + ) + grid = FmhaStaticTileScheduler.get_grid_shape(tile_sched_params) + return tile_sched_params, grid + + +def run( + q_shape: Tuple[int, int, int, int] | Tuple[int, Tuple[int, ...], int, int], + k_shape: Tuple[int, int, int, int] | Tuple[int, Tuple[int, ...], int, int], + in_dtype: Type[cutlass.Numeric], + out_dtype: Type[cutlass.Numeric], + qk_acc_dtype: Type[cutlass.Numeric], + pv_acc_dtype: Type[cutlass.Numeric], + mma_tiler_mn: Tuple[int, int], + is_persistent: bool, + is_causal: bool, + scale_q: float, + scale_k: float, + scale_v: float, + inv_scale_o: float, + scale_softmax: float, + tolerance: float, + warmup_iterations: int, + iterations: int, + skip_ref_check: bool, + use_cold_l2: bool = False, + **kwargs, +): + """Execute Fused Multi-Head Attention (FMHA) on Blackwell architecture and validate results. + + This function creates random input tensors for query, key, and value, then performs the + complete FMHA computation pipeline. 
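A single fused kernel computes softmax(scale * Q @ K^T) @ V, evaluating the softmax online with exp2-based arithmetic.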
It supports configurable data types, tiling parameters, + and various attention masking options. Results can be validated against a PyTorch reference + implementation or run multiple times for performance measurement. + + The implementation leverages specialized tensor memory operations and efficient math + operations optimized for Blackwell architecture, including pipelined computation stages + for maximum throughput. + + :param q_shape: Query tensor shape (B, S_q, H, D) where B=batch size, S_q=query sequence length, + H=number of heads, D=head dimension. + If S_q is a tuple, it is the variable sequence length. + :type q_shape: Tuple[int, int, int, int] | Tuple[int, Tuple[int, ...], int, int] + :param k_shape: Key tensor shape (B, S_k, H_k, D) where B=batch size, S_k=key sequence length, + H_k=number of key heads (H must be divisible by H_k), D=head dimension. + If S_k is a tuple, it is the variable sequence length. + :type k_shape: Tuple[int, int, int, int] | Tuple[int, Tuple[int, ...], int, int] + :param in_dtype: Input data type for query, key and value tensors + :type in_dtype: Type[cutlass.Numeric] + :param out_dtype: Output data type for attention output + :type out_dtype: Type[cutlass.Numeric] + :param qk_acc_dtype: Accumulator data type for query-key matrix multiplication + :type qk_acc_dtype: Type[cutlass.Numeric] + :param pv_acc_dtype: Accumulator data type for probability-value matrix multiplication + :type pv_acc_dtype: Type[cutlass.Numeric] + :param mma_tiler_mn: Matrix multiply accumulate tile shape (M, N) + :type mma_tiler_mn: Tuple[int, int] + :param is_persistent: Whether to use persistent kernel optimization + :type is_persistent: bool + :param is_causal: Whether to apply causal masking + :type is_causal: bool + :param scale_q: Scaling factor for query tensor + :type scale_q: float + :param scale_k: Scaling factor for key tensor + :type scale_k: float + :param scale_v: Scaling factor for value tensor + :type scale_v: float + :param inv_scale_o: Inverse scaling factor for output tensor + :type inv_scale_o: float + :param scale_softmax: Attention score scaling factor (defaults to 1/sqrt(D) if set to 0) + :type scale_softmax: float + :param tolerance: Maximum acceptable error for validation + :type tolerance: float + :param warmup_iterations: Number of warmup iterations + :type warmup_iterations: int + :param iterations: Number of iterations to run for performance testing + :type iterations: int + :param skip_ref_check: Skip validation against reference implementation + :type skip_ref_check: bool + :param use_cold_l2: Whether to use circular buffer strategy to ensure cold L2 cache + :type use_cold_l2: bool + + :raises ValueError: If input shapes are incompatible or head dimension is unsupported + :raises RuntimeError: If GPU is unavailable for computation + :return: Execution time of the FMHA kernel in microseconds + :rtype: float + """ + + print(f"Running Blackwell SM100 FMHA test with:") + print(f" q_shape: {q_shape}") + print(f" k_shape: {k_shape}") + print(f" in_dtype: {in_dtype}") + print(f" out_dtype: {out_dtype}") + print(f" qk_acc_dtype: {qk_acc_dtype}") + print(f" pv_acc_dtype: {pv_acc_dtype}") + print(f" mma_tiler_mn: {mma_tiler_mn}") + print(f" is_persistent: {is_persistent}") + print(f" is_causal: {is_causal}") + print(f" scale_q: {scale_q}") + print(f" scale_k: {scale_k}") + print(f" scale_v: {scale_v}") + print(f" inv_scale_o: {inv_scale_o}") + print(f" scale_softmax: {scale_softmax}") + print(f" tolerance: {tolerance}") + print(f" warmup_iterations: 
{warmup_iterations}") + print(f" iterations: {iterations}") + print(f" skip_ref_check: {skip_ref_check}") + print(f" use_cold_l2: {use_cold_l2}") + + # Unpack parameters + b, s_q, h_q, d = q_shape + b_, s_k, h_k, d_ = k_shape + + if b != b_: + raise ValueError("q & k must have the same batch size") + + if d != d_: + raise ValueError("q & k must have the same head dimension") + + if d not in {32, 64, 128}: + raise ValueError("head dimension must be 32, 64, or 128") + + if h_q % h_k != 0: + raise ValueError("h_q must be divisible by h_k") + + if isinstance(s_q, tuple) and len(s_q) != b: + raise ValueError("variable_seqlen s_q must have the length of batch size") + if isinstance(s_k, tuple) and len(s_k) != b: + raise ValueError("variable_seqlen s_k must have the length of batch size") + + if in_dtype not in {cutlass.Float8E4M3FN, cutlass.Float16}: + raise ValueError("in_dtype must be Float8E4M3FN or Float16") + + if out_dtype not in {cutlass.Float8E4M3FN, cutlass.Float16}: + raise ValueError("out_dtype must be Float8E4M3FN or Float16") + + if qk_acc_dtype not in {Float32}: + raise ValueError("qk_acc_dtype must be Float32") + + if pv_acc_dtype not in {Float32}: + raise ValueError("pv_acc_dtype must be Float32") + + if iterations < 1: + raise ValueError("iterations must be at least 1") + + h_r = h_q // h_k + + # Prepare pytorch tensors: Q, K, V (random from 0 to 2) and O (all zero) + if not torch.cuda.is_available(): + raise RuntimeError("GPU is required to run this example!") + + torch.manual_seed(1111) + + def create_cumulative_sequence_lengths(s): + s_cumsum = [0] + for i in range(len(s)): + s_cumsum.append(s_cumsum[-1] + s[i]) + + s_cumsum_cute_tensor, s_cumsum_torch_tensor = cutlass_torch.cute_tensor_like( + torch.tensor(s_cumsum, dtype=torch.int32), + Int32, + is_dynamic_layout=True, + assumed_align=16, + ) + + return s_cumsum_cute_tensor, s_cumsum_torch_tensor + + cum_seqlen_q, cum_seqlen_q_torch = ( + create_cumulative_sequence_lengths(s_q) + if isinstance(s_q, tuple) + else (None, None) + ) + cum_seqlen_k, cum_seqlen_k_torch = ( + create_cumulative_sequence_lengths(s_k) + if isinstance(s_k, tuple) + else (None, None) + ) + + def create_and_pad_tensor( + shape, padding, dtype, s_cumsum=None, is_dynamic_layout=True + ): + # (b, s, h, d) + shape_ = tuple(map(lambda x, y: x + y, shape, padding)) + if s_cumsum is not None: + if shape_[0] != 1 or padding[0] != 0: + raise ValueError("Invalid tensor creation for variable sequence length") + # (s_total + padding, h, d) + shape_ = shape_[1:] + padding = padding[1:] + + # Create f32 torch tensor (cpu) + f32_torch_tensor_full = cutlass_torch.create_and_permute_torch_tensor( + shape_, + torch.float32, + permute_order=None, + init_type=cutlass.torch.TensorInitType.RANDOM, + init_config=cutlass.torch.RandomInitConfig( + min_val=-2 if dtype.is_float or dtype.signed else 0, max_val=2 + ), + ) + # Create dtype cute & torch tensor (gpu) + _, torch_tensor_full = cutlass_torch.cute_tensor_like( + f32_torch_tensor_full, + dtype, + is_dynamic_layout, + assumed_align=16, + ) + + # Offset the tensor + slices = tuple(slice(s, e) for s, e in zip(padding, shape_)) + torch_tensor = torch_tensor_full[slices].detach() + f32_torch_tensor = f32_torch_tensor_full[slices].detach() + torch_tensor._keep_alive = torch_tensor_full + f32_torch_tensor._keep_alive = f32_torch_tensor_full + + # Create dtype cute tensor with offset (gpu) + cute_tensor = from_dlpack(torch_tensor, assumed_align=16) + cute_tensor.element_type = dtype + + # From ragged to jagged + if s_cumsum is not 
None: + torch_tensor = torch.nested.nested_tensor_from_jagged( + values=torch_tensor, offsets=s_cumsum + ) + f32_torch_tensor = torch.nested.nested_tensor_from_jagged( + values=f32_torch_tensor, offsets=s_cumsum.cpu() + ) + + return ( + f32_torch_tensor, + cute_tensor, + torch_tensor, + ) + + qo_shape = (b, s_q, h_r * h_k, d) + kv_shape = (b, s_k, h_k, d) + qo_padding = (0, 0, 0, 0, 0) + kv_padding = (0, 0, 0, 0, 0) + + if isinstance(s_q, tuple): + qo_shape = (1, sum(s_q), h_r * h_k, d) + qo_padding = (0, max(s_q), 0, 0, 0) + + if isinstance(s_k, tuple): + kv_shape = (1, sum(s_k), h_k, d) + kv_padding = (0, max(s_k), 0, 0, 0) + + q_ref, q_tensor, q_torch = create_and_pad_tensor( + qo_shape, + qo_padding, + in_dtype, + s_cumsum=cum_seqlen_q_torch, + is_dynamic_layout=True, + ) + k_ref, k_tensor, k_torch = create_and_pad_tensor( + kv_shape, + kv_padding, + in_dtype, + s_cumsum=cum_seqlen_k_torch, + is_dynamic_layout=True, + ) + v_ref, v_tensor, v_torch = create_and_pad_tensor( + kv_shape, + kv_padding, + in_dtype, + s_cumsum=cum_seqlen_k_torch, + is_dynamic_layout=True, + ) + _, o_tensor, o_torch = create_and_pad_tensor( + qo_shape, + qo_padding, + out_dtype, + s_cumsum=cum_seqlen_q_torch, + is_dynamic_layout=True, + ) + + mma_tiler = (*mma_tiler_mn, d) + + mask_type = MaskType.NO_MASK + if is_causal: + mask_type = MaskType.CAUSAL_MASK + else: + if isinstance(s_k, tuple): + for i in range(len(s_k)): + if s_k[i] % mma_tiler_mn[1] != 0: + mask_type = MaskType.RESIDUAL_MASK + else: + if s_k % mma_tiler_mn[1] != 0: + mask_type = MaskType.RESIDUAL_MASK + + fmha = BlackwellFusedMultiHeadAttentionForward( + qk_acc_dtype, + pv_acc_dtype, + mma_tiler, + is_persistent, + mask_type, + ) + + # Initialize Stream + current_stream = cutlass_torch.default_stream() + + if scale_softmax == 0.0: # default to 1/sqrt(d) + scale_softmax = 1.0 / math.sqrt(d) + log2_e = math.log2( + math.exp(1.0) + ) # gpu uses exp2 for perf concerns, we need an extra factor 'log2_e' here + + scale_softmax = scale_q * scale_k * scale_softmax + scale_softmax_log2 = scale_softmax * log2_e + scale_output = scale_v * inv_scale_o + + problem_size = ( + b, + max(s_q) if isinstance(s_q, tuple) else s_q, + max(s_k) if isinstance(s_k, tuple) else s_k, + h_q, + h_k, + d, + ) + + print("Compiling kernel with cute.compile ...") + start_time = time.time() + # compile fmha kernel + compiled_fmha = cute.compile( + fmha, + q_tensor.iterator, + k_tensor.iterator, + v_tensor.iterator, + o_tensor.iterator, + problem_size, + cum_seqlen_q, + cum_seqlen_k, + scale_softmax_log2, + scale_output, + current_stream, + ) + compilation_time = time.time() - start_time + print(f"Compilation time: {compilation_time:.4f} seconds") + + def run_torch_fmha(q, k, v, scale_softmax=1.0, scale_output=1.0, is_causal=False): + h_q = q.shape[2] + h_k = k.shape[2] + + # as we initialize q, k, v with shape (b, s, h, d) and SDPA of torch needs them to be (b, h, s, d) + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + # For the situation that torch has not supported, we need to handle it manually + situation1 = (h_q != h_k or is_causal) and (q.is_nested or k.is_nested) + situation2 = (q.is_nested and not k.is_nested) or ( + not q.is_nested and k.is_nested + ) + if situation1 or situation2: + # Once torch supports the situation, we can remove this fallback + batch_size = q.size(0) + ref_list = [] + for batch_idx in range(batch_size): + q_i = q[batch_idx] + k_i = k[batch_idx] + v_i = v[batch_idx] + + ref_i = F.scaled_dot_product_attention( + q_i, + k_i, + 
v_i, + attn_mask=None, + dropout_p=0.0, + scale=scale_softmax, + is_causal=is_causal, + enable_gqa=(h_q != h_k), + ) + ref_i = ref_i.transpose(0, 1) * scale_output + ref_list.append(ref_i) + if q.is_nested: + ref = torch.nested.nested_tensor(ref_list, layout=torch.jagged) + else: + ref = torch.stack(ref_list) + else: + ref = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=None, + dropout_p=0.0, + scale=scale_softmax, + is_causal=is_causal, + enable_gqa=(h_q != h_k), + ) + ref = ref.transpose(1, 2) * scale_output + return ref + + if not skip_ref_check: + # Execute kernel once for reference checking + compiled_fmha( + q_tensor.iterator, + k_tensor.iterator, + v_tensor.iterator, + o_tensor.iterator, + problem_size, + cum_seqlen_q, + cum_seqlen_k, + scale_softmax_log2, + scale_output, + current_stream, + ) + print("Verifying results...") + o_ref = run_torch_fmha( + q_ref, k_ref, v_ref, scale_softmax, scale_output, is_causal + ) + + if o_ref.is_nested: + o_ref = o_ref.values() + + if o_torch.is_nested: + o_torch = o_torch.values() + + # convert o back to f32 for comparison + o_fp32, o_fp32_torch = cutlass_torch.cute_tensor_like( + torch.empty(*o_torch.shape, dtype=torch.float32), + Float32, + is_dynamic_layout=True, + assumed_align=16, + ) + cute.testing.convert(o_tensor, o_fp32) + o_result = o_fp32_torch.cpu() + + if out_dtype.is_float and out_dtype.width <= 8: + ref_narrow_precision, _ = cutlass_torch.cute_tensor_like( + torch.empty(*o_ref.shape, dtype=torch.uint8), + out_dtype, + is_dynamic_layout=True, + assumed_align=16, + ) + + ref_o_f32, ref_o_f32_torch = cutlass_torch.cute_tensor_like( + o_ref, + cutlass.Float32, + is_dynamic_layout=True, + assumed_align=16, + ) + + # convert ref : f32 -> fp4/fp8 -> f32 + cute.testing.convert(ref_o_f32, ref_narrow_precision) + cute.testing.convert(ref_narrow_precision, ref_o_f32) + + o_ref = ref_o_f32_torch.cpu() + + # override tolerance + tolerance = 0.13 + + # Assert close results + torch.testing.assert_close(o_result, o_ref, atol=tolerance, rtol=1e-05) + print("Results verified successfully!") + + def generate_tensors(): + _, q_tensor_workspace, _ = create_and_pad_tensor( + qo_shape, + qo_padding, + in_dtype, + s_cumsum=cum_seqlen_q_torch, + is_dynamic_layout=True, + ) + _, k_tensor_workspace, _ = create_and_pad_tensor( + kv_shape, + kv_padding, + in_dtype, + s_cumsum=cum_seqlen_k_torch, + is_dynamic_layout=True, + ) + _, v_tensor_workspace, _ = create_and_pad_tensor( + kv_shape, + kv_padding, + in_dtype, + s_cumsum=cum_seqlen_k_torch, + is_dynamic_layout=True, + ) + _, o_tensor_workspace, _ = create_and_pad_tensor( + qo_shape, + qo_padding, + out_dtype, + s_cumsum=cum_seqlen_q_torch, + is_dynamic_layout=True, + ) + return testing.JitArguments( + q_tensor_workspace.iterator, + k_tensor_workspace.iterator, + v_tensor_workspace.iterator, + o_tensor_workspace.iterator, + problem_size, + cum_seqlen_q, + cum_seqlen_k, + scale_softmax_log2, + scale_output, + current_stream, + ) + + workspace_count = 1 + if use_cold_l2: + q_torch_effective = q_torch.values() if q_torch.is_nested else q_torch + k_torch_effective = k_torch.values() if k_torch.is_nested else k_torch + v_torch_effective = v_torch.values() if v_torch.is_nested else v_torch + o_torch_effective = o_torch.values() if o_torch.is_nested else o_torch + one_workspace_bytes = ( + q_torch_effective.numel() * q_torch_effective.element_size() + + k_torch_effective.numel() * k_torch_effective.element_size() + + v_torch_effective.numel() * v_torch_effective.element_size() + + 
o_torch_effective.numel() * o_torch_effective.element_size() + ) + workspace_count = testing.get_workspace_count( + one_workspace_bytes, warmup_iterations, iterations + ) + + exec_time = testing.benchmark( + compiled_fmha, + workspace_generator=generate_tensors, + workspace_count=workspace_count, + stream=current_stream, + warmup_iterations=warmup_iterations, + iterations=iterations, + ) + + return exec_time # Return execution time in microseconds + + +if __name__ == "__main__": + + def parse_comma_separated_ints(s: str): + try: + return tuple(int(x.strip()) for x in s.split(",")) + except ValueError: + raise argparse.ArgumentTypeError( + "Invalid format. Expected comma-separated integers." + ) + + def parse_nested_comma_separated_ints(s: str): + try: + s = s.strip() + if "(" not in s: + return tuple(int(x.strip()) for x in s.split(",")) + + start = s.find("(") + end = s.find(")") + if start == -1 or end == -1: + raise ValueError("Mismatched parentheses") + + before = s[:start].strip().rstrip(",") + middle = s[start + 1 : end].strip() + after = s[end + 1 :].strip().lstrip(",") + + result = [] + if before: + result.extend(int(x.strip()) for x in before.split(",")) + + if middle: + nested_tuple = tuple(int(x.strip()) for x in middle.split(",")) + result.append(nested_tuple) + + if after: + result.extend(int(x.strip()) for x in after.split(",")) + + return tuple(result) + + except ValueError as e: + if str(e) == "Mismatched parentheses": + raise argparse.ArgumentTypeError("Mismatched parentheses in input") + else: + raise argparse.ArgumentTypeError( + "Invalid format. Expected comma-separated integers with optional parentheses for nested tuple." + ) + + parser = argparse.ArgumentParser(description="Example of FMHA on Blackwell.") + + parser.add_argument( + "--in_dtype", + type=cutlass.dtype, + default=cutlass.Float16, + help="Input data type", + ) + + parser.add_argument( + "--out_dtype", + type=cutlass.dtype, + default=cutlass.Float16, + help="Output data type", + ) + + parser.add_argument( + "--qk_acc_dtype", + type=cutlass.dtype, + default=Float32, + help="QK accumulator data type", + ) + + parser.add_argument( + "--pv_acc_dtype", + type=cutlass.dtype, + default=Float32, + help="PV accumulator data type", + ) + + parser.add_argument( + "--mma_tiler_mn", + type=parse_comma_separated_ints, + default=(128, 128), + help="MMA tile shape (M, N)", + ) + + parser.add_argument( + "--is_persistent", + action="store_true", + help="Is persistent", + ) + + parser.add_argument( + "--is_causal", + action="store_true", + help="Whether to use casual mask", + ) + + parser.add_argument( + "--q_shape", + type=parse_nested_comma_separated_ints, + default=(1, 256, 8, 128), + help="Shape of Q (B, S_q, H, D)", + ) + + parser.add_argument( + "--k_shape", + type=parse_nested_comma_separated_ints, + default=(1, 256, 8, 128), + help="Shape of K (B, S_k, H_k, D)", + ) + + parser.add_argument( + "--scale_q", + type=float, + default=1.0, + help="Scaling factors to dequantize Q", + ) + + parser.add_argument( + "--scale_k", + type=float, + default=1.0, + help="Scaling factors to dequantize K", + ) + + parser.add_argument( + "--scale_v", + type=float, + default=1.0, + help="Scaling factors to dequantize V", + ) + + parser.add_argument( + "--inv_scale_o", + type=float, + default=1.0, + help="Scaling factor to quantize O", + ) + + parser.add_argument( + "--scale_softmax", + type=float, + default=0.0, + help="Scaling factor to scale S (i.e. 
Q*K); if zero, defaults to 1/sqrt(D)", + ) + + parser.add_argument( + "--tolerance", type=float, default=1e-01, help="Tolerance for validation" + ) + + parser.add_argument( + "--warmup_iterations", + type=int, + default=0, + help="Number of iterations for warmup", + ) + + parser.add_argument( + "--iterations", + type=int, + default=1, + help="Number of iterations after warmup", + ) + + parser.add_argument( + "--skip_ref_check", + action="store_true", + help="Skip reference check", + ) + + parser.add_argument( + "--use_cold_l2", + action="store_true", + default=False, + help="Use circular buffer tensor sets to ensure L2 cold cache", + ) + + args = parser.parse_args() + + if len(args.q_shape) != 4: + parser.error("--q_shape must contain exactly 4 values") + + if len(args.k_shape) != 4: + parser.error("--k_shape must contain exactly 4 values") + + if len(args.mma_tiler_mn) != 2: + parser.error("--mma_tiler_mn must contain exactly 2 values") + + if not torch.cuda.is_available(): + raise RuntimeError("GPU is required to run this example!") + + torch.manual_seed(1111) + + run( + args.q_shape, + args.k_shape, + args.in_dtype, + args.out_dtype, + args.qk_acc_dtype, + args.pv_acc_dtype, + args.mma_tiler_mn, + args.is_persistent, + args.is_causal, + args.scale_q, + args.scale_k, + args.scale_v, + args.inv_scale_o, + args.scale_softmax, + args.tolerance, + args.warmup_iterations, + args.iterations, + args.skip_ref_check, + args.use_cold_l2, + ) + + print("PASS") diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/blackwell/grouped_gemm.py b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/blackwell/grouped_gemm.py new file mode 100644 index 0000000000000000000000000000000000000000..c279f7588fa697c516c5acae340c9fa735437ba7 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/blackwell/grouped_gemm.py @@ -0,0 +1,2445 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import argparse +import functools +from typing import List, Type, Union +from inspect import isclass + +import torch +import cuda.bindings.driver as cuda + +import cutlass +import cutlass.cute as cute +import cutlass.cute.testing as testing +import cutlass.utils as utils +from cutlass.cute.nvgpu import cpasync, tcgen05 +import cutlass.utils.blackwell_helpers as sm100_utils +import cutlass.torch as cutlass_torch + +""" +A grouped GEMM example for the NVIDIA Blackwell SM100 architecture using CUTE DSL + +This example demonstrates an implementation of grouped GEMM using a TMA plus Blackwell SM100 TensorCore +warp-specialized persistent kernel. +The grouped GEMM workload computes a batch of GEMM operations with distinct problem sizes. Pointers to matrices +in global memory are passed to the kernel in an array (also held in global memory). Similarly, problem shapes and +strides are also stored in arrays in GMEM. + +This differs from "Batched Array" GEMM since the size of each GEMM problem in the grouped GEMM concept may be distinct. + +To run this example: + +.. code-block:: bash + + python examples/blackwell/grouped_gemm.py \ + --ab_dtype Float16 --c_dtype Float16 --acc_dtype Float32 \ + --mma_tiler_mn 128,64 --cluster_shape_mn 1,1 \ + --problem_sizes_mnkl "(8192,1280,32,1),(16,384,1536,1),(640,1280,16,1),(640,160,16,1)" \ + --num_groups 4 --tensormap_update_mode SMEM + +The above example command makes 4 groups of different m, n, k sizes. The Blackwell tcgen05 MMA tile shape +is specified as (128, 64) and the cluster shape is (1,1). The input, mma accumulator and output data types +are set to fp16, fp32 and fp16, respectively. + +To collect performance with NCU profiler: + +.. code-block:: bash + + ncu python examples/blackwell/grouped_gemm.py \ + --ab_dtype Float16 --c_dtype Float16 --acc_dtype Float32 \ + --mma_tiler_mn 128,64 --cluster_shape_mn 1,1 \ + --problem_sizes_mnkl "(8192,1280,32,1),(16,384,1536,1),(640,1280,16,1),(640,160,16,1)" \ + --num_groups 4 --tensormap_update_mode SMEM \ + --warmup_iterations 1 --iterations 10 --skip_ref_check + +There are some constraints for this example. Besides the constraints from the Blackwell dense GEMM persistent example, +there are also the following constraints: +* Only fp16 and bf16 data types are supported as inputs. +* Output data types can be fp16, bf16 or fp32. +* The contiguous dimension of each tensor must be at least 16 bytes aligned. +* The l mode (aka batch size) for each group must be 1. +* The majorness for A, B and C must be the same across all groups. +""" + + +class GroupedGemmKernel: + def __init__( + self, + acc_dtype: type[cutlass.Numeric], + use_2cta_instrs: bool, + mma_tiler_mn: tuple[int, int], + cluster_shape_mn: tuple[int, int], + tensormap_update_mode: utils.TensorMapUpdateMode = utils.TensorMapUpdateMode.SMEM, + ): + """Initializes the configuration for a Blackwell grouped GEMM kernel. 
+ + Besides configurations for dense persistent GEMM, there is an extra config specific to grouped GEMM: + + Tensormap Update Mode: + - tensormap_update_mode: Specifies whether the tensormap is + updated in global memory(GMEM) or shared memory(SMEM). + The 2 modes are functionally equivalent and the difference are: + - We buffer 3 tensormaps in SMEM for A, B, and C tensors (each TMA descriptor takes 128B) when TMA updates performed on SMEM. + - Performance varies between modes depending on problem size; optimal choice differs across workloads. + + :param acc_dtype: Data type of the accumulator. + :type acc_dtype: type[cutlass.Numeric] + :param use_2cta_instrs: Boolean, True to use cta_group=2 MMA variant. + :type use_2cta_instrs: bool + :param mma_tiler_mn: tuple (M, N) shape of the MMA instruction. + :type mma_tiler_mn: tuple[int, int] + :param cluster_shape_mn: tuple (ClusterM, ClusterN) shape of the cluster. + :type cluster_shape_mn: tuple[int, int] + :param tensormap_update_mode: Mode for updating the tensormap (GMEM or SMEM), defaults to SMEM. + :type tensormap_update_mode: utils.TensorMapUpdateMode, optional + """ + self.acc_dtype: Type[cutlass.Numeric] = acc_dtype + self.use_2cta_instrs = use_2cta_instrs + self.cluster_shape_mn = cluster_shape_mn + # K dimension is deferred in _setup_attributes + self.mma_tiler = (*mma_tiler_mn, 1) + self.cta_group = ( + tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE + ) + + self.tensormap_update_mode = tensormap_update_mode + # Delegate tensormap ab initialization to MMA warp when SMEM mode is used for better latency hiding + self.delegate_tensormap_ab_init = ( + tensormap_update_mode == utils.TensorMapUpdateMode.SMEM + ) + + self.num_mcast_ctas_a = 1 + self.num_mcast_ctas_b = 1 + self.is_a_mcast = False + self.is_b_mcast = False + + self.occupancy = 1 + # Set specialized warp ids + self.epilog_warp_id = ( + 0, + 1, + 2, + 3, + ) + self.mma_warp_id = 4 + self.tma_warp_id = 5 + self.threads_per_cta = 32 * len( + (self.mma_warp_id, self.tma_warp_id, *self.epilog_warp_id) + ) + # Set barrier id for cta sync, epilog sync, tmem ptr sync and tensormap update sync + self.cta_sync_bar_id = 0 + self.epilog_sync_bar_id = 1 + self.tmem_ptr_sync_bar_id = 2 + # Barrier ID used by MMA/TMA warps to signal A/B tensormap initialization completion + self.tensormap_ab_init_bar_id = 4 + self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_100") + self.num_tma_load_bytes = 0 + + def _setup_attributes(self): + """Set up configurations that are dependent on GEMM inputs + + Most of the implementation follows standard dense GEMM patterns, + with the key difference being additional consideration for SMEM + buffer needed for tensormap updates. 
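+
+ Concretely, this is where the MMA tiler K extent is fixed (four MMA-instruction
+ K tiles) and where the A/B and epilogue SMEM stage counts are computed so that
+ the staged tensors, mbarriers and the optional SMEM tensormap buffer fit within
+ the SM100 shared-memory capacity.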
+ """ + # Configure tiled mma + tiled_mma = sm100_utils.make_trivial_tiled_mma( + self.a_dtype, + self.a_major_mode, + self.b_major_mode, + self.acc_dtype, + self.cta_group, + self.mma_tiler[:2], + ) + + # Compute mma/cluster/tile shapes + mma_inst_shape_k = cute.size(tiled_mma.shape_mnk, mode=[2]) + mma_inst_tile_k = 4 + self.mma_tiler = ( + self.mma_tiler[0], + self.mma_tiler[1], + mma_inst_shape_k * mma_inst_tile_k, + ) + self.cta_tile_shape_mnk = ( + self.mma_tiler[0] // cute.size(tiled_mma.thr_id.shape), + self.mma_tiler[1], + self.mma_tiler[2], + ) + self.cluster_tile_shape_mnk = tuple( + x * y for x, y in zip(self.cta_tile_shape_mnk, (*self.cluster_shape_mn, 1)) + ) + + # Compute cluster layout + self.cluster_layout_vmnk = cute.tiled_divide( + cute.make_layout((*self.cluster_shape_mn, 1)), + (tiled_mma.thr_id.shape,), + ) + + # Compute number of multicast CTAs for A/B + self.num_mcast_ctas_a = cute.size(self.cluster_layout_vmnk.shape[2]) + self.num_mcast_ctas_b = cute.size(self.cluster_layout_vmnk.shape[1]) + self.is_a_mcast = self.num_mcast_ctas_a > 1 + self.is_b_mcast = self.num_mcast_ctas_b > 1 + + # Compute epilogue subtile + self.epi_tile = utils.compute_epilogue_tile_shape( + self.cta_tile_shape_mnk, + self.use_2cta_instrs, + self.c_layout, + self.c_dtype, + ) + + # Setup A/B/C stage count in shared memory and ACC stage count in tensor memory + ( + self.num_acc_stage, + self.num_ab_stage, + self.num_epi_stage, + ) = self._compute_stages( + tiled_mma, + self.mma_tiler, + self.a_dtype, + self.b_dtype, + self.epi_tile, + self.c_dtype, + self.c_layout, + self.smem_capacity, + self.occupancy, + ) + + self.a_smem_layout_staged = sm100_utils.make_smem_layout_a( + tiled_mma, + self.mma_tiler, + self.a_dtype, + self.num_ab_stage, + ) + self.b_smem_layout_staged = sm100_utils.make_smem_layout_b( + tiled_mma, + self.mma_tiler, + self.b_dtype, + self.num_ab_stage, + ) + self.epi_smem_layout_staged = sm100_utils.make_smem_layout_epi( + self.c_dtype, + self.c_layout, + self.epi_tile, + self.num_epi_stage, + ) + + tensor_smem_bytes = self._get_tensor_smem_bytes( + self.a_smem_layout_staged, + self.a_dtype, + self.b_smem_layout_staged, + self.b_dtype, + self.epi_smem_layout_staged, + self.c_dtype, + ) + mbar_smem_bytes = self._get_mbar_smem_bytes( + num_acc_stage=self.num_acc_stage, + num_ab_stage=self.num_ab_stage, + num_epi_stage=self.num_epi_stage, + ) + tensormap_smem_bytes = self._get_tensormap_smem_bytes( + self.tensormap_update_mode + ) + if ( + mbar_smem_bytes + + tensormap_smem_bytes + + GroupedGemmKernel.tensor_memory_management_bytes + > self.reserved_smem_bytes + ): + raise ValueError( + f"smem consumption for mbar and tensormap {mbar_smem_bytes + tensormap_smem_bytes} exceeds the " + f"reserved smem bytes {self.reserved_smem_bytes}" + ) + + # Compute the number of tensor memory allocation columns + self.num_tmem_alloc_cols = self._compute_num_tmem_alloc_cols( + tiled_mma, self.mma_tiler, self.num_acc_stage + ) + + @cute.jit + def __call__( + self, + initial_a: cute.Tensor, + initial_b: cute.Tensor, + initial_c: cute.Tensor, + group_count: cutlass.Constexpr[int], + problem_shape_mnkl: cute.Tensor, + strides_abc: cute.Tensor, + tensor_address_abc: cute.Tensor, + total_num_clusters: cutlass.Constexpr[int], + tensormap_cute_tensor: cute.Tensor, + max_active_clusters: cutlass.Constexpr[int], + stream: cuda.CUstream, + ): + """Execute the GEMM operation in steps: + - Setup static attributes before smem/grid/tma computation + - Setup TMA load/store atoms and tensors + - Compute grid 
size with regard to hardware constraints + - Define shared storage for kernel + - Launch the kernel synchronously + + For grouped GEMM, tensor shapes, tensor strides, and tensor address are all provided + by different tensors in global memory. The "initial" tensors only carry data type and + majorness information. + + :param initial_a: Initial tensor A, used for data type and majorness information. + :type initial_a: cute.Tensor + :param initial_b: Initial tensor B, used for data type and majorness information. + :type initial_b: cute.Tensor + :param initial_c: Initial tensor C, used for data type and majorness information. + :type initial_c: cute.Tensor + :param group_count: The number of GEMM groups. + :type group_count: cutlass.Constexpr[int] + :param problem_shape_mnkl: Tensor containing the (M, N, K, L) shape for each group. + :type problem_shape_mnkl: cute.Tensor + :param strides_abc: Tensor containing the strides for A, B, and C for each group. + :type strides_abc: cute.Tensor + :param tensor_address_abc: Tensor containing the base addresses for A, B, and C for each group. + :type tensor_address_abc: cute.Tensor + :param total_num_clusters: Total number of clusters needed for all groups. + :type total_num_clusters: cutlass.Constexpr[int] + :param tensormap_cute_tensor: Tensor for storing tensormaps. + :type tensormap_cute_tensor: cute.Tensor + :param max_active_clusters: Maximum number of active clusters. + :type max_active_clusters: cutlass.Constexpr[int] + :param stream: CUDA stream for asynchronous execution. + :type stream: cuda.CUstream + :raises TypeError: If A and B data types do not match. + """ + self.a_dtype = initial_a.element_type + self.b_dtype = initial_b.element_type + self.c_dtype = initial_c.element_type + self.a_major_mode = utils.LayoutEnum.from_tensor(initial_a).mma_major_mode() + self.b_major_mode = utils.LayoutEnum.from_tensor(initial_b).mma_major_mode() + self.c_layout = utils.LayoutEnum.from_tensor(initial_c) + if cutlass.const_expr(self.a_dtype != self.b_dtype): + raise TypeError(f"Type mismatch: {self.a_dtype} != {self.b_dtype}") + + # Setup attributes that dependent on gemm inputs + self._setup_attributes() + + tiled_mma = sm100_utils.make_trivial_tiled_mma( + self.a_dtype, + self.a_major_mode, + self.b_major_mode, + self.acc_dtype, + self.cta_group, + self.mma_tiler[:2], + ) + atom_thr_size = cute.size(tiled_mma.thr_id.shape) + + # Setup TMA load for A + a_op = sm100_utils.cluster_shape_to_tma_atom_A( + self.cluster_shape_mn, tiled_mma.thr_id + ) + a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0)) + tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A( + a_op, + initial_a, + a_smem_layout, + self.mma_tiler, + tiled_mma, + self.cluster_layout_vmnk.shape, + ) + + # Setup TMA load for B + b_op = sm100_utils.cluster_shape_to_tma_atom_B( + self.cluster_shape_mn, tiled_mma.thr_id + ) + b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0)) + tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B( + b_op, + initial_b, + b_smem_layout, + self.mma_tiler, + tiled_mma, + self.cluster_layout_vmnk.shape, + ) + + a_copy_size = cute.size_in_bytes(self.a_dtype, a_smem_layout) + b_copy_size = cute.size_in_bytes(self.b_dtype, b_smem_layout) + self.num_tma_load_bytes = (a_copy_size + b_copy_size) * atom_thr_size + + # Setup TMA store for C + tma_atom_c = None + tma_tensor_c = None + c_cta_v_layout = cute.composition( + cute.make_identity_layout(initial_c.shape), self.epi_tile + ) + epi_smem_layout = 
cute.slice_(self.epi_smem_layout_staged, (None, None, 0)) + tma_atom_c, tma_tensor_c = cpasync.make_tiled_tma_atom( + cpasync.CopyBulkTensorTileS2GOp(), + initial_c, + epi_smem_layout, + c_cta_v_layout, + ) + + self.tile_sched_params, grid = self._compute_grid( + total_num_clusters, self.cluster_shape_mn, max_active_clusters + ) + + self.buffer_align_bytes = 1024 + self.size_tensormap_in_i64 = ( + 0 + if self.tensormap_update_mode == utils.TensorMapUpdateMode.GMEM + else GroupedGemmKernel.num_tensormaps + * GroupedGemmKernel.bytes_per_tensormap + // 8 + ) + + # Define shared storage for kernel + @cute.struct + class SharedStorage: + tensormap_buffer: cute.struct.MemRange[ + cutlass.Int64, self.size_tensormap_in_i64 + ] + ab_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage] + ab_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage] + acc_full_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage] + acc_empty_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage] + tmem_dealloc_mbar_ptr: cutlass.Int64 + tmem_holding_buf: cutlass.Int32 + # (EPI_TILE_M, EPI_TILE_N, STAGE) + sC: cute.struct.Align[ + cute.struct.MemRange[ + self.c_dtype, + cute.cosize(self.epi_smem_layout_staged.outer), + ], + self.buffer_align_bytes, + ] + # (MMA, MMA_M, MMA_K, STAGE) + sA: cute.struct.Align[ + cute.struct.MemRange[ + self.a_dtype, cute.cosize(self.a_smem_layout_staged.outer) + ], + self.buffer_align_bytes, + ] + # (MMA, MMA_N, MMA_K, STAGE) + sB: cute.struct.Align[ + cute.struct.MemRange[ + self.b_dtype, cute.cosize(self.b_smem_layout_staged.outer) + ], + self.buffer_align_bytes, + ] + + self.shared_storage = SharedStorage + + # Launch the kernel synchronously + self.kernel( + tiled_mma, + tma_atom_a, + tma_tensor_a, + tma_atom_b, + tma_tensor_b, + tma_atom_c, + tma_tensor_c, + self.cluster_layout_vmnk, + self.a_smem_layout_staged, + self.b_smem_layout_staged, + self.epi_smem_layout_staged, + self.epi_tile, + self.tile_sched_params, + group_count, + problem_shape_mnkl, + strides_abc, + tensor_address_abc, + tensormap_cute_tensor, + ).launch( + grid=grid, + block=[self.threads_per_cta, 1, 1], + cluster=(*self.cluster_shape_mn, 1), + stream=stream, + ) + return + + # GPU device kernel + @cute.kernel + def kernel( + self, + tiled_mma: cute.TiledMma, + tma_atom_a: cute.CopyAtom, + mA_mkl: cute.Tensor, + tma_atom_b: cute.CopyAtom, + mB_nkl: cute.Tensor, + tma_atom_c: cute.CopyAtom, + mC_mnl: cute.Tensor, + cluster_layout_vmnk: cute.Layout, + a_smem_layout_staged: cute.ComposedLayout, + b_smem_layout_staged: cute.ComposedLayout, + epi_smem_layout_staged: Union[cute.Layout, cute.ComposedLayout], + epi_tile: cute.Tile, + tile_sched_params: utils.PersistentTileSchedulerParams, + group_count: cutlass.Constexpr[int], + problem_sizes_mnkl: cute.Tensor, + strides_abc: cute.Tensor, + ptrs_abc: cute.Tensor, + tensormaps: cute.Tensor, + ): + """ + GPU device kernel performing the grouped GEMM computation. 
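+
+ Warp roles follow the ids assigned in __init__: the TMA warp prefetches the
+ TMA descriptors and loads A/B tiles, the MMA warp issues the tcgen05 MMA
+ instructions, and the epilogue warps stage the accumulator through shared
+ memory and store C back to global memory.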
+ """ + warp_idx = cute.arch.warp_idx() + warp_idx = cute.arch.make_warp_uniform(warp_idx) + + # + # Prefetch tma desc + # + if warp_idx == self.tma_warp_id: + cpasync.prefetch_descriptor(tma_atom_a) + cpasync.prefetch_descriptor(tma_atom_b) + cpasync.prefetch_descriptor(tma_atom_c) + + use_2cta_instrs = cute.size(tiled_mma.thr_id.shape) == 2 + + # + # Setup cta/thread coordinates + # + # Coord inside cluster + bid = cute.arch.block_idx() + mma_tile_coord_v = bid[0] % cute.size(tiled_mma.thr_id.shape) + is_leader_cta = mma_tile_coord_v == 0 + cta_rank_in_cluster = cute.arch.make_warp_uniform( + cute.arch.block_idx_in_cluster() + ) + block_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord( + cta_rank_in_cluster + ) + # Coord inside cta + tidx, _, _ = cute.arch.thread_idx() + + # + # Alloc and init: tensormap buffer, a+b full/empty, accumulator full/empty, tensor memory dealloc barrier + # + smem = utils.SmemAllocator() + storage = smem.allocate(self.shared_storage) + + tensormap_a_smem_ptr = None + tensormap_b_smem_ptr = None + tensormap_c_smem_ptr = None + if cutlass.const_expr( + self.tensormap_update_mode == utils.TensorMapUpdateMode.SMEM + ): + tensormap_smem_ptr = storage.tensormap_buffer.data_ptr() + tensormap_a_smem_ptr = tensormap_smem_ptr + tensormap_b_smem_ptr = ( + tensormap_a_smem_ptr + GroupedGemmKernel.bytes_per_tensormap // 8 + ) + tensormap_c_smem_ptr = ( + tensormap_b_smem_ptr + GroupedGemmKernel.bytes_per_tensormap // 8 + ) + ab_full_mbar_ptr = storage.ab_full_mbar_ptr.data_ptr() + ab_empty_mbar_ptr = storage.ab_empty_mbar_ptr.data_ptr() + acc_full_mbar_ptr = storage.acc_full_mbar_ptr.data_ptr() + acc_empty_mbar_ptr = storage.acc_empty_mbar_ptr.data_ptr() + tmem_dealloc_mbar_ptr = storage.tmem_dealloc_mbar_ptr + tmem_holding_buf = storage.tmem_holding_buf + + # init barrier for loading A, B with TMA + if warp_idx == self.epilog_warp_id[0]: + for k_stage in range(self.num_ab_stage): + num_tma_producer = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1 + with cute.arch.elect_one(): + cute.arch.mbarrier_init(ab_full_mbar_ptr + k_stage, 1) + cute.arch.mbarrier_init( + ab_empty_mbar_ptr + k_stage, num_tma_producer + ) + # Accumulator barrier init + if warp_idx == self.mma_warp_id: + for acc_stage in range(self.num_acc_stage): + with cute.arch.elect_one(): + cute.arch.mbarrier_init(acc_full_mbar_ptr + acc_stage, 1) + cute.arch.mbarrier_init( + acc_empty_mbar_ptr + acc_stage, 8 if use_2cta_instrs else 4 + ) + # Tensor memory dealloc barrier init + if use_2cta_instrs: + if warp_idx == self.tma_warp_id: + num_tmem_dealloc_threads = 32 + with cute.arch.elect_one(): + cute.arch.mbarrier_init( + tmem_dealloc_mbar_ptr, num_tmem_dealloc_threads + ) + cute.arch.mbarrier_init_fence() + + # Cluster arrive after barrier init + if cute.size(self.cluster_shape_mn) > 1: + cute.arch.cluster_arrive_relaxed() + + # + # Setup smem tensor A/B/C + # + # (EPI_TILE_M, EPI_TILE_N, STAGE) + sC = storage.sC.get_tensor( + epi_smem_layout_staged.outer, swizzle=epi_smem_layout_staged.inner + ) + # (MMA, MMA_M, MMA_K, STAGE) + sA = storage.sA.get_tensor( + a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner + ) + # (MMA, MMA_N, MMA_K, STAGE) + sB = storage.sB.get_tensor( + b_smem_layout_staged.outer, swizzle=b_smem_layout_staged.inner + ) + + # + # Compute multicast mask for A/B buffer full and empty + # + a_full_mcast_mask = None + b_full_mcast_mask = None + ab_empty_mcast_mask = None + if cutlass.const_expr(self.is_a_mcast or self.is_b_mcast or use_2cta_instrs): + a_full_mcast_mask = 
cpasync.create_tma_multicast_mask( + cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2 + ) + b_full_mcast_mask = cpasync.create_tma_multicast_mask( + cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=1 + ) + ab_empty_mcast_mask = a_full_mcast_mask | b_full_mcast_mask + acc_full_mcast_mask = None + if cutlass.const_expr(use_2cta_instrs): + acc_full_mcast_mask = cute.make_layout_image_mask( + cluster_layout_vmnk, block_in_cluster_coord_vmnk, mode=0 + ) + block_in_cluster_coord_vmnk_peer = ( + block_in_cluster_coord_vmnk[0] ^ 1, + *block_in_cluster_coord_vmnk[1:], + ) + a_full_mcast_mask_peer = cpasync.create_tma_multicast_mask( + cluster_layout_vmnk, block_in_cluster_coord_vmnk_peer, mcast_mode=2 + ) + b_full_mcast_mask_peer = cpasync.create_tma_multicast_mask( + cluster_layout_vmnk, block_in_cluster_coord_vmnk_peer, mcast_mode=1 + ) + ab_empty_mcast_mask = ( + a_full_mcast_mask_peer + | b_full_mcast_mask_peer + | cutlass.Int16( + 0 if ab_empty_mcast_mask is None else ab_empty_mcast_mask + ) + ) + + # + # Local_tile partition global tensors + # + # (bM, bK, RestM, RestK, RestL) + gA_mkl = cute.local_tile( + mA_mkl, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None) + ) + # (bN, bK, RestN, RestK, RestL) + gB_nkl = cute.local_tile( + mB_nkl, cute.slice_(self.mma_tiler, (0, None, None)), (None, None, None) + ) + # (bM, bN, RestM, RestN, RestL) + gC_mnl = cute.local_tile( + mC_mnl, cute.slice_(self.mma_tiler, (None, None, 0)), (None, None, None) + ) + + # + # Partition global tensor for TiledMMA_A/B/C + # + thr_mma = tiled_mma.get_slice(mma_tile_coord_v) + # (MMA, MMA_M, MMA_K, RestM, RestK, RestL) + tCgA = thr_mma.partition_A(gA_mkl) + # (MMA, MMA_N, MMA_K, RestN, RestK, RestL) + tCgB = thr_mma.partition_B(gB_nkl) + # (MMA, MMA_M, MMA_N, RestM, RestN, RestL) + tCgC = thr_mma.partition_C(gC_mnl) + + # + # Partition global/shared tensor for load A, B with TMA + # + a_cta_layout = cute.make_layout( + cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape + ) + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), RestM, RestK, RestL) + tAsA, tAgA = cpasync.tma_partition( + tma_atom_a, + block_in_cluster_coord_vmnk[2], + a_cta_layout, + cute.group_modes(sA, 0, 3), + cute.group_modes(tCgA, 0, 3), + ) + # TMA load B partition_S/D + b_cta_layout = cute.make_layout( + cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape + ) + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), RestM, RestK, RestL) + tBsB, tBgB = cpasync.tma_partition( + tma_atom_b, + block_in_cluster_coord_vmnk[1], + b_cta_layout, + cute.group_modes(sB, 0, 3), + cute.group_modes(tCgB, 0, 3), + ) + + # + # Partition shared/tensor memory tensor for TiledMMA_A/B/C + # + # (MMA, MMA_M, MMA_K, STAGE) + tCrA = tiled_mma.make_fragment_A(sA) + # (MMA, MMA_N, MMA_K, STAGE) + tCrB = tiled_mma.make_fragment_B(sB) + # (MMA, MMA_M, MMA_N) + acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2]) + # (MMA, MMA_M, MMA_N, STAGE) + tCtAcc_fake = tiled_mma.make_fragment_C( + cute.append(acc_shape, self.num_acc_stage) + ) + + # + # Cluster wait before tensor memory alloc + # + if cute.size(self.cluster_shape_mn) > 1: + cute.arch.cluster_wait() + else: + cute.arch.barrier( + barrier_id=self.cta_sync_bar_id, number_of_threads=self.threads_per_cta + ) + + # + # Get tensormap buffer address + # + grid_dim = cute.arch.grid_dim() + tensormap_workspace_idx = ( + bid[2] * grid_dim[1] * grid_dim[0] + bid[1] * grid_dim[0] + bid[0] + ) + + tensormap_manager = utils.TensorMapManager( + self.tensormap_update_mode, 
GroupedGemmKernel.bytes_per_tensormap + ) + tensormap_a_ptr = tensormap_manager.get_tensormap_ptr( + tensormaps[(tensormap_workspace_idx, 0, None)].iterator + ) + tensormap_b_ptr = tensormap_manager.get_tensormap_ptr( + tensormaps[(tensormap_workspace_idx, 1, None)].iterator + ) + tensormap_c_ptr = tensormap_manager.get_tensormap_ptr( + tensormaps[(tensormap_workspace_idx, 2, None)].iterator + ) + # Setup tensormap initialization pointer based on the mode + if cutlass.const_expr( + self.tensormap_update_mode == utils.TensorMapUpdateMode.SMEM + ): + tensormap_a_init_ptr = tensormap_a_smem_ptr + tensormap_b_init_ptr = tensormap_b_smem_ptr + tensormap_c_init_ptr = tensormap_c_smem_ptr + else: + tensormap_a_init_ptr = tensormap_a_ptr + tensormap_b_init_ptr = tensormap_b_ptr + tensormap_c_init_ptr = tensormap_c_ptr + + # + # Specialized TMA load warp + # + if warp_idx == self.tma_warp_id: + # Initialize tensormaps for A, B + if cutlass.const_expr(self.delegate_tensormap_ab_init == False): + tensormap_manager.init_tensormap_from_atom( + tma_atom_a, tensormap_a_init_ptr, self.tma_warp_id + ) + tensormap_manager.init_tensormap_from_atom( + tma_atom_b, tensormap_b_init_ptr, self.tma_warp_id + ) + # + # Persistent tile scheduling loop + # + tile_sched = utils.StaticPersistentTileScheduler.create( + tile_sched_params, bid, grid_dim + ) + # grouped gemm tile scheduler helper will compute the group index for the tile we're working on + group_gemm_ts_helper = utils.GroupedGemmTileSchedulerHelper( + group_count, + tile_sched_params, + self.cluster_tile_shape_mnk, + utils.create_initial_search_state(), + ) + tensormap_init_done = cutlass.Boolean(False) + # tile count we have searched + total_k_tile_cnt = cutlass.Int32(0) + # group index of last tile + last_group_idx = cutlass.Int32(-1) + work_tile = tile_sched.initial_work_tile_info() + while work_tile.is_valid_tile: + cur_tile_coord = work_tile.tile_idx + grouped_gemm_cta_tile_info = group_gemm_ts_helper.delinearize_z( + cur_tile_coord, + problem_sizes_mnkl, + ) + cur_k_tile_cnt = grouped_gemm_cta_tile_info.cta_tile_count_k + cur_group_idx = grouped_gemm_cta_tile_info.group_idx + is_group_changed = cur_group_idx != last_group_idx + # skip tensormap update if we're working on the same group + if is_group_changed: + real_tensor_a = self.make_tensor_for_tensormap_update( + cur_group_idx, + self.a_dtype, + ( + grouped_gemm_cta_tile_info.problem_shape_m, + grouped_gemm_cta_tile_info.problem_shape_n, + grouped_gemm_cta_tile_info.problem_shape_k, + ), + strides_abc, + ptrs_abc, + 0, # 0 for tensor A + ) + real_tensor_b = self.make_tensor_for_tensormap_update( + cur_group_idx, + self.b_dtype, + ( + grouped_gemm_cta_tile_info.problem_shape_m, + grouped_gemm_cta_tile_info.problem_shape_n, + grouped_gemm_cta_tile_info.problem_shape_k, + ), + strides_abc, + ptrs_abc, + 1, # 1 for tensor B + ) + # wait tensormap initialization complete before update + if tensormap_init_done == False: + if cutlass.const_expr(self.delegate_tensormap_ab_init): + cute.arch.barrier( + barrier_id=self.tensormap_ab_init_bar_id, + number_of_threads=64, + ) + tensormap_manager.fence_tensormap_initialization() + tensormap_init_done = True + + tensormap_manager.update_tensormap( + (real_tensor_a, real_tensor_b), + (tma_atom_a, tma_atom_b), + (tensormap_a_ptr, tensormap_b_ptr), + self.tma_warp_id, + (tensormap_a_smem_ptr, tensormap_b_smem_ptr), + ) + + mma_tile_coord_mnl = ( + grouped_gemm_cta_tile_info.cta_tile_idx_m + // cute.size(tiled_mma.thr_id.shape), + 
grouped_gemm_cta_tile_info.cta_tile_idx_n, + 0, + ) + + # + # Slice to per mma tile index + # + # ((atom_v, rest_v), RestK) + tAgA_slice = tAgA[ + (None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2]) + ] + # ((atom_v, rest_v), RestK) + tBgB_slice = tBgB[ + (None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2]) + ] + + num_prev_k_blk = total_k_tile_cnt + total_k_tile_cnt += cur_k_tile_cnt + + # Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt + tma_wr_k_tile = cutlass.Int32(0) + smem_wr_buffer = (num_prev_k_blk + tma_wr_k_tile) % self.num_ab_stage + tma_wr_ab_empty_phase = ( + num_prev_k_blk + tma_wr_k_tile + ) // self.num_ab_stage % 2 ^ 1 + peek_ab_empty_status = cute.arch.mbarrier_conditional_try_wait( + tma_wr_k_tile < cur_k_tile_cnt, + ab_empty_mbar_ptr + smem_wr_buffer, + tma_wr_ab_empty_phase, + ) + # ensure the update to tensormap has completed before using it + if is_group_changed: + tensormap_manager.fence_tensormap_update(tensormap_a_ptr) + tensormap_manager.fence_tensormap_update(tensormap_b_ptr) + # + # Tma load loop + # + for k_tile in cutlass.range(0, cur_k_tile_cnt, 1, unroll=1): + tma_wr_k_tile_next = tma_wr_k_tile + 1 + smem_wr_buffer_next = ( + num_prev_k_blk + tma_wr_k_tile_next + ) % self.num_ab_stage + tma_wr_ab_empty_phase_next = ( + tma_wr_ab_empty_phase ^ 1 + if smem_wr_buffer_next == 0 + else tma_wr_ab_empty_phase + ) + + smem_full_mbar_ptr = ab_full_mbar_ptr + smem_wr_buffer + + # Wait for AB buffer empty + if peek_ab_empty_status == 0: + cute.arch.mbarrier_wait( + ab_empty_mbar_ptr + smem_wr_buffer, tma_wr_ab_empty_phase + ) + + # Arrive AB buffer and expect full transaction bytes + if is_leader_cta: + with cute.arch.elect_one(): + cute.arch.mbarrier_arrive_and_expect_tx( + smem_full_mbar_ptr, self.num_tma_load_bytes + ) + + # Load A/B with TMA + cute.copy( + tma_atom_a, + tAgA_slice[(None, tma_wr_k_tile)], + tAsA[(None, smem_wr_buffer)], + tma_bar_ptr=smem_full_mbar_ptr, + mcast_mask=a_full_mcast_mask, + tma_desc_ptr=tensormap_manager.get_tensormap_ptr( + tensormap_a_ptr, + cute.AddressSpace.generic, + ), + ) + cute.copy( + tma_atom_b, + tBgB_slice[(None, tma_wr_k_tile)], + tBsB[(None, smem_wr_buffer)], + tma_bar_ptr=smem_full_mbar_ptr, + mcast_mask=b_full_mcast_mask, + tma_desc_ptr=tensormap_manager.get_tensormap_ptr( + tensormap_b_ptr, + cute.AddressSpace.generic, + ), + ) + + # Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt + k_tile + 1 + peek_ab_empty_status = cute.arch.mbarrier_conditional_try_wait( + tma_wr_k_tile_next < cur_k_tile_cnt, + ab_empty_mbar_ptr + smem_wr_buffer_next, + tma_wr_ab_empty_phase_next, + ) + + tma_wr_k_tile = tma_wr_k_tile_next + smem_wr_buffer = smem_wr_buffer_next + tma_wr_ab_empty_phase = tma_wr_ab_empty_phase_next + + # Advance to next tile + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + last_group_idx = cur_group_idx + + # + # Specialized MMA warp + # + if warp_idx == self.mma_warp_id: + # initialize tensormap A, B for TMA warp + if cutlass.const_expr(self.delegate_tensormap_ab_init): + tensormap_manager.init_tensormap_from_atom( + tma_atom_a, tensormap_a_init_ptr, self.mma_warp_id + ) + tensormap_manager.init_tensormap_from_atom( + tma_atom_b, tensormap_b_init_ptr, self.mma_warp_id + ) + # signal tensormap initialization has finished + cute.arch.barrier( + barrier_id=self.tensormap_ab_init_bar_id, number_of_threads=64 + ) + # Bar sync for retrieve tmem ptr from shared mem + tmem_ptr_read_threads = 32 * len((self.mma_warp_id, *self.epilog_warp_id)) + 
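+ # A warp is 32 threads; the barrier below spans the MMA warp plus all epilogue
+ # warps, all of which read the tensor memory pointer published by epilogue
+ # warp 0 through tmem_holding_buf.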
cute.arch.barrier( + barrier_id=self.tmem_ptr_sync_bar_id, + number_of_threads=tmem_ptr_read_threads, + ) + + # + # Retrieving tensor memory ptr and make accumulator tensor + # + tmem_ptr = cute.arch.retrieve_tmem_ptr( + self.acc_dtype, + alignment=16, + ptr_to_buffer_holding_addr=tmem_holding_buf, + ) + # (MMA, MMA_M, MMA_N, STAGE) + tCtAcc_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout) + + # + # Persistent tile scheduling loop + # + tile_sched = utils.StaticPersistentTileScheduler.create( + tile_sched_params, bid, grid_dim + ) + # grouped gemm tile scheduler helper will compute the group index for the tile we're working on + group_gemm_ts_helper = utils.GroupedGemmTileSchedulerHelper( + group_count, + tile_sched_params, + self.cluster_tile_shape_mnk, + utils.create_initial_search_state(), + ) + + work_tile = tile_sched.initial_work_tile_info() + # tile count we have searched + total_k_tile_cnt = cutlass.Int32(0) + while work_tile.is_valid_tile: + cur_tile_coord = work_tile.tile_idx + # MMA warp is only interested in number of tiles along K dimension + ( + cur_k_tile_cnt, + cur_group_idx, + ) = group_gemm_ts_helper.search_cluster_tile_count_k( + cur_tile_coord, + problem_sizes_mnkl, + ) + # Set tensor memory buffer for current tile + acc_buf_idx = tile_sched.num_tiles_executed % self.num_acc_stage + # (MMA, MMA_M, MMA_N) + tCtAcc = tCtAcc_base[(None, None, None, acc_buf_idx)] + + num_prev_k_blk = total_k_tile_cnt + total_k_tile_cnt += cur_k_tile_cnt + + # Peek (try_wait) AB buffer full for k_tile = 0 + mma_rd_k_tile = cutlass.Int32(0) + smem_rd_buffer = (num_prev_k_blk + mma_rd_k_tile) % self.num_ab_stage + need_check_rd_buffer_full = ( + mma_rd_k_tile < cur_k_tile_cnt and is_leader_cta + ) + mma_rd_ab_full_phase = ( + (num_prev_k_blk + mma_rd_k_tile) // self.num_ab_stage % 2 + ) + peek_ab_full_status = cute.arch.mbarrier_conditional_try_wait( + need_check_rd_buffer_full, + ab_full_mbar_ptr + smem_rd_buffer, + mma_rd_ab_full_phase, + ) + + # + # Wait for accumulator buffer empty + # + if is_leader_cta: + acc_empty_phase = ( + tile_sched.num_tiles_executed // self.num_acc_stage % 2 ^ 1 + ) + cute.arch.mbarrier_wait( + acc_empty_mbar_ptr + acc_buf_idx, acc_empty_phase + ) + + # + # Reset the ACCUMULATE field for each tile + # + tiled_mma.set(tcgen05.Field.ACCUMULATE, False) + + # + # Mma mainloop + # + for k_tile in range(cur_k_tile_cnt): + mma_rd_k_tile_next = cutlass.Int32(k_tile + 1) + smem_rd_buffer_next = ( + num_prev_k_blk + mma_rd_k_tile_next + ) % self.num_ab_stage + mma_rd_ab_full_phase_next = ( + mma_rd_ab_full_phase ^ 1 + if smem_rd_buffer_next == 0 + else mma_rd_ab_full_phase + ) + if is_leader_cta: + # Wait for AB buffer full + if peek_ab_full_status == 0: + cute.arch.mbarrier_wait( + ab_full_mbar_ptr + smem_rd_buffer, mma_rd_ab_full_phase + ) + + # tCtAcc += tCrA * tCrB + num_kblocks = cute.size(tCrA, mode=[2]) + for kblock_idx in cutlass.range(num_kblocks, unroll_full=True): + kblock_coord = (None, None, kblock_idx, smem_rd_buffer) + + cute.gemm( + tiled_mma, + tCtAcc, + tCrA[kblock_coord], + tCrB[kblock_coord], + tCtAcc, + ) + # Enable accumulate on tCtAcc after first kblock + tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + + # Async arrive AB buffer empty + with cute.arch.elect_one(): + tcgen05.commit( + ab_empty_mbar_ptr + smem_rd_buffer, + ab_empty_mcast_mask, + self.cta_group, + ) + + # Peek (try_wait) AB buffer full for k_tile = k_tile + 1 + need_check_rd_buffer_full = ( + mma_rd_k_tile_next < cur_k_tile_cnt and is_leader_cta + ) + + peek_ab_full_status = 
cute.arch.mbarrier_conditional_try_wait( + need_check_rd_buffer_full, + ab_full_mbar_ptr + smem_rd_buffer_next, + mma_rd_ab_full_phase_next, + ) + + mma_rd_k_tile = mma_rd_k_tile_next + smem_rd_buffer = smem_rd_buffer_next + mma_rd_ab_full_phase = mma_rd_ab_full_phase_next + + # + # Async arrive accumulator buffer full + # + if is_leader_cta: + with cute.arch.elect_one(): + tcgen05.commit( + acc_full_mbar_ptr + acc_buf_idx, + acc_full_mcast_mask, + self.cta_group, + ) + + # + # Advance to next tile + # + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + + # + # Specialized epilogue warps + # + if warp_idx < self.mma_warp_id: + # initialize tensorap for C + tensormap_manager.init_tensormap_from_atom( + tma_atom_c, + tensormap_c_init_ptr, + self.epilog_warp_id[0], + ) + # Alloc tensor memory buffer + if warp_idx == self.epilog_warp_id[0]: + cute.arch.alloc_tmem( + self.num_tmem_alloc_cols, + tmem_holding_buf, + is_two_cta=use_2cta_instrs, + ) + + # + # Bar sync for retrieve tensor memory ptr from shared memory + # + tmem_ptr_read_threads = 32 * len((self.mma_warp_id, *self.epilog_warp_id)) + cute.arch.barrier( + barrier_id=self.tmem_ptr_sync_bar_id, + number_of_threads=tmem_ptr_read_threads, + ) + + # + # Retrieving tensor memory ptr and make accumulator tensor + # + tmem_ptr = cute.arch.retrieve_tmem_ptr( + self.acc_dtype, + alignment=16, + ptr_to_buffer_holding_addr=tmem_holding_buf, + ) + # (MMA, MMA_M, MMA_N, STAGE) + tCtAcc_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout) + + epi_tidx = tidx + # + # Partition for epilogue + # + ( + tiled_copy_t2r, + tTR_tAcc_base, + tTR_rAcc, + ) = self.epilog_tmem_copy_and_partition( + epi_tidx, tCtAcc_base, tCgC, epi_tile, use_2cta_instrs + ) + + tTR_rC = cute.make_fragment(tTR_rAcc.shape, self.c_dtype) + tiled_copy_r2s, tRS_rC, tRS_sC = self.epilog_smem_copy_and_partition( + tiled_copy_t2r, tTR_rC, epi_tidx, sC + ) + ( + tma_atom_c, + bSG_sC, + bSG_gC_partitioned, + ) = self.epilog_gmem_copy_and_partition(tma_atom_c, tCgC, epi_tile, sC) + + # + # Persistent tile scheduling loop + # + tile_sched = utils.StaticPersistentTileScheduler.create( + tile_sched_params, bid, grid_dim + ) + # grouped gemm tile scheduler helper will compute the group index for the tile we're working on + group_gemm_ts_helper = utils.GroupedGemmTileSchedulerHelper( + group_count, + tile_sched_params, + self.cluster_tile_shape_mnk, + utils.create_initial_search_state(), + ) + + work_tile = tile_sched.initial_work_tile_info() + # wait tensormap initialization complete before update + tensormap_manager.fence_tensormap_initialization() + # tile count we have searched + total_k_tile_cnt = cutlass.Int32(0) + # group index of last tile + last_group_idx = cutlass.Int32(-1) + while work_tile.is_valid_tile: + cur_tile_coord = work_tile.tile_idx + grouped_gemm_cta_tile_info = group_gemm_ts_helper.delinearize_z( + cur_tile_coord, + problem_sizes_mnkl, + ) + cur_group_idx = grouped_gemm_cta_tile_info.group_idx + is_group_changed = cur_group_idx != last_group_idx + if is_group_changed: + # construct tensor C based on real address, shape and stride information + real_tensor_c = self.make_tensor_for_tensormap_update( + cur_group_idx, + self.c_dtype, + ( + grouped_gemm_cta_tile_info.problem_shape_m, + grouped_gemm_cta_tile_info.problem_shape_n, + grouped_gemm_cta_tile_info.problem_shape_k, + ), + strides_abc, + ptrs_abc, + 2, # 2 for tensor C + ) + tensormap_manager.update_tensormap( + ((real_tensor_c),), + ((tma_atom_c),), + ((tensormap_c_ptr),), + 
self.epilog_warp_id[0], + (tensormap_c_smem_ptr,), + ) + + mma_tile_coord_mnl = ( + grouped_gemm_cta_tile_info.cta_tile_idx_m + // cute.size(tiled_mma.thr_id.shape), + grouped_gemm_cta_tile_info.cta_tile_idx_n, + 0, + ) + cur_k_tile_cnt = grouped_gemm_cta_tile_info.cta_tile_count_k + total_k_tile_cnt += cur_k_tile_cnt + + # + # Slice to per mma tile index + # + # ((ATOM_V, REST_V), EPI_M, EPI_N) + bSG_gC = bSG_gC_partitioned[ + ( + None, + None, + None, + *mma_tile_coord_mnl, + ) + ] + + # Set tensor memory buffer for current tile + acc_buf_idx = tile_sched.num_tiles_executed % self.num_acc_stage + # (T2R, T2R_M, T2R_N, EPI_M, EPI_M) + tTR_tAcc = tTR_tAcc_base[(None, None, None, None, None, acc_buf_idx)] + + # + # Wait for accumulator buffer full + # + acc_full_phase = tile_sched.num_tiles_executed // self.num_acc_stage % 2 + cute.arch.mbarrier_wait(acc_full_mbar_ptr + acc_buf_idx, acc_full_phase) + + tTR_tAcc = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc)) + bSG_gC = cute.group_modes(bSG_gC, 1, cute.rank(bSG_gC)) + # ensure the update to tensormap has completed before using it + if is_group_changed: + if warp_idx == self.epilog_warp_id[0]: + tensormap_manager.fence_tensormap_update(tensormap_c_ptr) + # + # Store accumulator to global memory in subtiles + # + subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3]) + num_prev_subtiles = tile_sched.num_tiles_executed * subtile_cnt + for subtile_idx in range(subtile_cnt): + # + # Load accumulator from tensor memory buffer to register + # + tTR_tAcc_mn = tTR_tAcc[(None, None, None, subtile_idx)] + cute.copy(tiled_copy_t2r, tTR_tAcc_mn, tTR_rAcc) + + # + # Convert to output type + # + acc_vec = tiled_copy_r2s.retile(tTR_rAcc).load() + tRS_rC.store(acc_vec.to(self.c_dtype)) + # + # Store C to shared memory + # + epi_buffer = (num_prev_subtiles + subtile_idx) % self.num_epi_stage + cute.copy( + tiled_copy_r2s, + tRS_rC, + tRS_sC[(None, None, None, epi_buffer)], + ) + # Fence and barrier to make sure shared memory store is visible to TMA store + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + epilog_threads = 32 * len(self.epilog_warp_id) + cute.arch.barrier( + barrier_id=self.epilog_sync_bar_id, + number_of_threads=epilog_threads, + ) + # + # store C to global memory with TMA + # + if warp_idx == self.epilog_warp_id[0]: + cute.copy( + tma_atom_c, + bSG_sC[(None, epi_buffer)], + bSG_gC[(None, subtile_idx)], + tma_desc_ptr=tensormap_manager.get_tensormap_ptr( + tensormap_c_ptr, + cute.AddressSpace.generic, + ), + ) + cute.arch.cp_async_bulk_commit_group() + cute.arch.cp_async_bulk_wait_group( + self.num_epi_stage - 1, read=True + ) + cute.arch.barrier( + barrier_id=self.epilog_sync_bar_id, + number_of_threads=epilog_threads, + ) + # + # Async arrive accumulator buffer empty + # + with cute.arch.elect_one(): + cute.arch.mbarrier_arrive( + acc_empty_mbar_ptr + acc_buf_idx, + cta_rank_in_cluster // 2 * 2 if use_2cta_instrs else None, + ) + + # + # Advance to next tile + # + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + last_group_idx = cur_group_idx + + # + # Dealloc the tensor memory buffer + # + if warp_idx == self.epilog_warp_id[0]: + cute.arch.relinquish_tmem_alloc_permit(is_two_cta=use_2cta_instrs) + epilog_threads = 32 * len(self.epilog_warp_id) + cute.arch.barrier( + barrier_id=self.epilog_sync_bar_id, number_of_threads=epilog_threads + ) + if warp_idx == self.epilog_warp_id[0]: + if use_2cta_instrs: + cute.arch.mbarrier_arrive( + tmem_dealloc_mbar_ptr, 
cta_rank_in_cluster ^ 1 + ) + cute.arch.mbarrier_wait(tmem_dealloc_mbar_ptr, 0) + cute.arch.dealloc_tmem( + tmem_ptr, self.num_tmem_alloc_cols, is_two_cta=use_2cta_instrs + ) + + # + # Wait a/b buffer empty + # + if warp_idx == self.epilog_warp_id[0]: + cute.arch.mbarrier_wait( + (ab_empty_mbar_ptr + ((total_k_tile_cnt - 1) % self.num_ab_stage)), + (((total_k_tile_cnt - 1) // self.num_ab_stage) % 2), + ) + + @cute.jit + def make_tensor_for_tensormap_update( + self, + group_idx: cutlass.Int32, + dtype: Type[cutlass.Numeric], + problem_shape_mnk: tuple[cutlass.Int32, cutlass.Int32, cutlass.Int32], + strides_abc: cute.Tensor, + tensor_address_abc: cute.Tensor, + tensor_index: int, + ): + """Extract stride and tensor address for a given group and construct a global tensor. + + This function is used within the kernel to dynamically create a CUTE tensor + representing A, B, or C for the current group being processed, using the + group-specific address, shape, and stride information. + + :param group_idx: The index of the current group within the grouped GEMM. + :type group_idx: cutlass.Int32 + :param dtype: The data type of the tensor elements (e.g., cutlass.Float16). + :type dtype: Type[cutlass.Numeric] + :param problem_shape_mnk: The (M, N, K) problem shape for the current group. + :type problem_shape_mnk: tuple[cutlass.Int32, cutlass.Int32, cutlass.Int32] + :param strides_abc: Tensor containing strides for A, B, C for all groups. Layout: (group_count, 3, 2). + :type strides_abc: cute.Tensor + :param tensor_address_abc: Tensor containing global memory addresses for A, B, C for all groups. Layout: (group_count, 3). + :type tensor_address_abc: cute.Tensor + :param tensor_index: Specifies which tensor to create: 0 for A, 1 for B, 2 for C. + :type tensor_index: int + :return: A CUTE tensor representing the requested global memory tensor (A, B, or C) for the specified group. + :rtype: cute.Tensor + :raises TypeError: If the provided dtype is not a subclass of cutlass.Numeric. 
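+
+ Example (illustrative values only): for tensor A (tensor_index == 0) of a
+ group with (m, n, k) = (256, 512, 128) and a K-major A, the strides read from
+ strides_abc[(group_idx, 0, None)] would be (stride_mn, stride_k) = (128, 1),
+ and the returned tensor has layout (256, 128, 1):(128, 1, 0) over the base
+ pointer read from tensor_address_abc[(group_idx, 0)].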
+ """ + ptr_i64 = tensor_address_abc[(group_idx, tensor_index)] + if cutlass.const_expr( + not isclass(dtype) or not issubclass(dtype, cutlass.Numeric) + ): + raise TypeError( + f"dtype must be a type of cutlass.Numeric, got {type(dtype)}" + ) + tensor_gmem_ptr = cute.make_ptr( + dtype, ptr_i64, cute.AddressSpace.gmem, assumed_align=16 + ) + + strides_tensor_gmem = strides_abc[(group_idx, tensor_index, None)] + strides_tensor_reg = cute.make_fragment( + cute.make_layout(2), + strides_abc.element_type, + ) + cute.autovec_copy(strides_tensor_gmem, strides_tensor_reg) + stride_mn = strides_tensor_reg[0] + stride_k = strides_tensor_reg[1] + c1 = cutlass.Int32(1) + c0 = cutlass.Int32(0) + + if cutlass.const_expr(tensor_index == 0): # tensor A + m = problem_shape_mnk[0] + k = problem_shape_mnk[2] + return cute.make_tensor( + tensor_gmem_ptr, + cute.make_layout((m, k, c1), stride=(stride_mn, stride_k, c0)), + ) + elif cutlass.const_expr(tensor_index == 1): # tensor B + n = problem_shape_mnk[1] + k = problem_shape_mnk[2] + return cute.make_tensor( + tensor_gmem_ptr, + cute.make_layout((n, k, c1), stride=(stride_mn, stride_k, c0)), + ) + else: # tensor C + m = problem_shape_mnk[0] + n = problem_shape_mnk[1] + return cute.make_tensor( + tensor_gmem_ptr, + cute.make_layout((m, n, c1), stride=(stride_mn, stride_k, c0)), + ) + + def epilog_tmem_copy_and_partition( + self, + tidx: cutlass.Int32, + tAcc: cute.Tensor, + gC_mnl: cute.Tensor, + epi_tile: cute.Tile, + use_2cta_instrs: Union[cutlass.Boolean, bool], + ) -> tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]: + """ + Make tiledCopy for tensor memory load, then use it to partition tensor memory (source) and register array (destination). + + :param tidx: The thread index in epilogue warp groups + :type tidx: cutlass.Int32 + :param tAcc: The accumulator tensor to be copied and partitioned + :type tAcc: cute.Tensor + :param gC_mnl: The global tensor C + :type gC_mnl: cute.Tensor + :param epi_tile: The epilogue tiler + :type epi_tile: cute.Tile + :param use_2cta_instrs: Whether use_2cta_instrs is enabled + :type use_2cta_instrs: bool + + :return: A tuple containing (tiled_copy_t2r, tTR_tAcc, tTR_rAcc) where: + - tiled_copy_t2r: The tiled copy operation for tmem to register copy(t2r) + - tTR_tAcc: The partitioned accumulator tensor + - tTR_rAcc: The accumulated tensor in register used to hold t2r results + :rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor] + """ + # Make tiledCopy for tensor memory load(t2r) + copy_atom_t2r = sm100_utils.get_tmem_load_op( + self.cta_tile_shape_mnk, + self.c_layout, + self.c_dtype, + self.acc_dtype, + epi_tile, + use_2cta_instrs, + ) + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, STAGE) + tAcc_epi = cute.flat_divide( + tAcc[((None, None), 0, 0, None)], + epi_tile, + ) + # (EPI_TILE_M, EPI_TILE_N) + tiled_copy_t2r = tcgen05.make_tmem_copy( + copy_atom_t2r, tAcc_epi[(None, None, 0, 0, 0)] + ) + + thr_copy_t2r = tiled_copy_t2r.get_slice(tidx) + # (T2R, T2R_M, T2R_N, EPI_M, EPI_M, STAGE) + tTR_tAcc = thr_copy_t2r.partition_S(tAcc_epi) + + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, RestM, RestN, RestL) + gC_mnl_epi = cute.flat_divide( + gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile + ) + # (T2R, T2R_M, T2R_N, EPI_M, EPI_N, RestM, RestN, RestL) + tTR_gC = thr_copy_t2r.partition_D(gC_mnl_epi) + # (T2R, T2R_M, T2R_N) + tTR_rAcc = cute.make_fragment( + tTR_gC[(None, None, None, 0, 0, 0, 0, 0)].shape, self.acc_dtype + ) + return tiled_copy_t2r, tTR_tAcc, tTR_rAcc + + def epilog_smem_copy_and_partition( + self, + 
tiled_copy_t2r: cute.TiledCopy, + tTR_rC: cute.Tensor, + tidx: cutlass.Int32, + sC: cute.Tensor, + ) -> tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]: + """ + Make tiledCopy for shared memory store, then use it to partition register array (source) and shared memory (destination). + + :param tiled_copy_t2r: The tiled copy operation for tmem to register copy(t2r) + :type tiled_copy_t2r: cute.TiledCopy + :param tTR_rC: The partitioned accumulator tensor + :type tTR_rC: cute.Tensor + :param tidx: The thread index in epilogue warp groups + :type tidx: cutlass.Int32 + :param sC: The shared memory tensor to be copied and partitioned + :type sC: cute.Tensor + + :return: A tuple containing (tiled_copy_r2s, tRS_rC, tRS_sC) where: + - tiled_copy_r2s: The tiled copy operation for register to smem copy(r2s) + - tRS_rC: The partitioned tensor C (register source) + - tRS_sC: The partitioned tensor C (smem destination) + :rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor] + """ + copy_atom_r2s = sm100_utils.get_smem_store_op( + self.c_layout, self.c_dtype, self.acc_dtype, tiled_copy_t2r + ) + tiled_copy_r2s = cute.make_tiled_copy_D(copy_atom_r2s, tiled_copy_t2r) + # (R2S, R2S_M, R2S_N, PIPE_D) + thr_copy_r2s = tiled_copy_r2s.get_slice(tidx) + tRS_sC = thr_copy_r2s.partition_D(sC) + # (R2S, R2S_M, R2S_N) + tRS_rC = tiled_copy_r2s.retile(tTR_rC) + return tiled_copy_r2s, tRS_rC, tRS_sC + + def epilog_gmem_copy_and_partition( + self, + tma_atom_c: cute.CopyAtom, + gC_mnl: cute.Tensor, + epi_tile: cute.Tile, + sC: cute.Tensor, + ) -> tuple[cute.CopyAtom, cute.Tensor, cute.Tensor]: + """Make tiledCopy for global memory store, then use it to partition + shared memory (source) and global memory (destination) for TMA store version. + + :param tma_atom_c: The TMA copy atom configured for storing tensor C. + :type tma_atom_c: cute.CopyAtom + :param gC_mnl: The global memory tensor C. + :type gC_mnl: cute.Tensor + :param epi_tile: The epilogue tiler defining the granularity of the operation. + :type epi_tile: cute.Tile + :param sC: The shared memory epilogue buffer tensor. + :type sC: cute.Tensor + :return: A tuple containing: + - tma_atom_c: The input TMA copy atom (passed through). + - bSG_sC: The source shared memory tensor partitioned for the TMA operation. + - tCgC: The destination global memory tensor partitioned for the TMA operation. + :rtype: tuple[cute.CopyAtom, cute.Tensor, cute.Tensor] + """ + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, RestM, RestN, RestL) + gC_epi = cute.flat_divide( + gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile + ) + sC_for_tma_partition = cute.group_modes(sC, 0, 2) + gC_for_tma_partition = cute.group_modes(gC_epi, 0, 2) + # ((ATOM_V, REST_V), EPI_M, EPI_N) + # ((ATOM_V, REST_V), EPI_M, EPI_N, RestM, RestN, RestL) + bSG_sC, bSG_gC = cpasync.tma_partition( + tma_atom_c, + 0, + cute.make_layout(1), + sC_for_tma_partition, + gC_for_tma_partition, + ) + return tma_atom_c, bSG_sC, bSG_gC + + @staticmethod + def _compute_stages( + tiled_mma: cute.TiledMma, + mma_tiler_mnk: tuple[int, int, int], + a_dtype: type[cutlass.Numeric], + b_dtype: type[cutlass.Numeric], + epi_tile: cute.Tile, + c_dtype: type[cutlass.Numeric], + c_layout: utils.LayoutEnum, + smem_capacity: int, + occupancy: int, + ) -> tuple[int, int, int]: + """Computes the number of stages for accumulator, A/B operands, and epilogue based on heuristics. + + :param tiled_mma: The tiled MMA object defining the core computation. 
+ :type tiled_mma: cute.TiledMma + :param mma_tiler_mnk: The shape (M, N, K) of the MMA tiler. + :type mma_tiler_mnk: tuple[int, int, int] + :param a_dtype: Data type of operand A. + :type a_dtype: type[cutlass.Numeric] + :param b_dtype: Data type of operand B. + :type b_dtype: type[cutlass.Numeric] + :param epi_tile: The epilogue tile shape. + :type epi_tile: cute.Tile + :param c_dtype: Data type of operand C (output). + :type c_dtype: type[cutlass.Numeric] + :param c_layout: Layout enum of operand C in global memory. + :type c_layout: utils.LayoutEnum + :param smem_capacity: Total available shared memory capacity in bytes. + :type smem_capacity: int + :param occupancy: Target number of CTAs per SM (occupancy). + :type occupancy: int + + :return: A tuple containing the computed number of stages for: + (accumulator stages, A/B operand stages, epilogue stages) + :rtype: tuple[int, int, int] + """ + # Default accumulator and epilogue stages + num_acc_stage = 2 + num_epi_stage = 2 + + # Calculate smem layout and size for one stage of A, B, and Epilogue + a_smem_layout_stage_one = sm100_utils.make_smem_layout_a( + tiled_mma, + mma_tiler_mnk, + a_dtype, + 1, # stage=1 + ) + b_smem_layout_staged_one = sm100_utils.make_smem_layout_b( + tiled_mma, + mma_tiler_mnk, + b_dtype, + 1, # stage=1 + ) + epi_smem_layout_staged_one = sm100_utils.make_smem_layout_epi( + c_dtype, + c_layout, + epi_tile, + 1, # stage=1 + ) + ab_bytes_per_stage = cute.size_in_bytes( + a_dtype, a_smem_layout_stage_one + ) + cute.size_in_bytes(b_dtype, b_smem_layout_staged_one) + + epi_bytes_per_stage = cute.size_in_bytes(c_dtype, epi_smem_layout_staged_one) + epi_bytes = epi_bytes_per_stage * num_epi_stage + + # Calculate A/B stages: + # Start with total smem per CTA (capacity / occupancy) + # Subtract reserved bytes and initial epilogue bytes + # Divide remaining by bytes needed per A/B stage + num_ab_stage = ( + smem_capacity // occupancy + - GroupedGemmKernel.reserved_smem_bytes + - epi_bytes + ) // ab_bytes_per_stage + + # Refine epilogue stages: + # Calculate remaining smem after allocating for A/B stages and reserved bytes + # Add remaining unused smem to epilogue + remaining_smem = ( + smem_capacity + - occupancy * ab_bytes_per_stage * num_ab_stage + - occupancy * (GroupedGemmKernel.reserved_smem_bytes + epi_bytes) + ) + num_epi_stage += remaining_smem // (occupancy * epi_bytes_per_stage) + return num_acc_stage, num_ab_stage, num_epi_stage + + @staticmethod + def _compute_grid( + total_num_clusters: int, + cluster_shape_mn: tuple[int, int], + max_active_clusters: cutlass.Constexpr[int], + ) -> tuple[utils.PersistentTileSchedulerParams, tuple[int, int, int]]: + """Compute tile scheduler parameters and grid shape for grouped GEMM operations. + + :param total_num_clusters: Total number of clusters to process across all groups. + :type total_num_clusters: int + :param cluster_shape_mn: Shape of each cluster in M, N dimensions. + :type cluster_shape_mn: tuple[int, int] + :param max_active_clusters: Maximum number of active clusters. + :type max_active_clusters: cutlass.Constexpr[int] + + :return: A tuple containing: + - tile_sched_params: Parameters for the persistent tile scheduler. + - grid: Grid shape for kernel launch. + :rtype: tuple[utils.PersistentTileSchedulerParams, tuple[int, ...]] + """ + # Create problem shape with M, N dimensions from cluster shape + # and L dimension representing the total number of clusters. 
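+ # Each index along L then corresponds to one cluster-sized output tile of some
+ # group; the persistent tile scheduler walks this linearized list, and the
+ # grouped-GEMM helper maps each tile index back to its group.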
+ problem_shape_ntile_mnl = ( + cluster_shape_mn[0], + cluster_shape_mn[1], + cutlass.Int32(total_num_clusters), + ) + + tile_sched_params = utils.PersistentTileSchedulerParams( + problem_shape_ntile_mnl, (*cluster_shape_mn, 1) + ) + + grid = utils.StaticPersistentTileScheduler.get_grid_shape( + tile_sched_params, max_active_clusters + ) + + return tile_sched_params, grid + + @staticmethod + def _get_mbar_smem_bytes(**kwargs_stages: int) -> int: + """Calculate shared memory consumption for memory barriers based on provided stages. + + Each stage requires 2 barriers, and each barrier consumes 8 bytes of shared memory. + The total consumption is the sum across all provided stages. This function calculates the total + shared memory needed for these barriers. + + :param kwargs_stages: Variable keyword arguments where each key is a stage name + (e.g., num_acc_stage, num_ab_stage) and each value is the + number of stages of that type. + :type kwargs_stages: int + :return: Total shared memory bytes required for all memory barriers. + :rtype: int + """ + num_barriers_per_stage = 2 + num_bytes_per_barrier = 8 + mbar_smem_consumption = sum( + [ + num_barriers_per_stage * num_bytes_per_barrier * stage + for stage in kwargs_stages.values() + ] + ) + return mbar_smem_consumption + + @staticmethod + def _get_tensormap_smem_bytes( + tensormap_update_mode: utils.TensorMapUpdateMode, + ) -> int: + """Get the SMEM consumption for the tensormap buffer based on the update mode. + + :param tensormap_update_mode: Specifies whether tensormaps are updated in GMEM or SMEM. + :type tensormap_update_mode: utils.TensorMapUpdateMode + :return: The shared memory bytes required for the tensormap buffer. Returns 0 if mode is GMEM. + :rtype: int + :raises ValueError: If an invalid tensormap update mode is provided. + """ + if tensormap_update_mode == utils.TensorMapUpdateMode.GMEM: + return 0 + elif tensormap_update_mode == utils.TensorMapUpdateMode.SMEM: + return ( + GroupedGemmKernel.bytes_per_tensormap * GroupedGemmKernel.num_tensormaps + ) + else: + raise ValueError(f"Invalid tensormap update mode: {tensormap_update_mode}") + + @staticmethod + def _get_tensor_smem_bytes( + a_smem_layout_staged: cute.Layout, + a_dtype: Type[cutlass.Numeric], + b_smem_layout_staged: cute.Layout, + b_dtype: Type[cutlass.Numeric], + epi_smem_layout_staged: cute.Layout, + c_dtype: Type[cutlass.Numeric], + ) -> int: + """Compute the total SMEM consumption for tensor A, B and C.""" + ab_bytes = cute.size_in_bytes( + a_dtype, a_smem_layout_staged + ) + cute.size_in_bytes(b_dtype, b_smem_layout_staged) + + epi_bytes = cute.size_in_bytes(c_dtype, epi_smem_layout_staged) + return ab_bytes + epi_bytes + + @staticmethod + def _compute_num_tmem_alloc_cols( + tiled_mma: cute.TiledMma, + mma_tiler: tuple[int, int, int], + num_acc_stage: int, + ) -> int: + """ + Compute the number of tensor memory allocation columns. + + :param tiled_mma: The tiled MMA object defining the core computation. + :type tiled_mma: cute.TiledMma + :param mma_tiler: The shape (M, N, K) of the MMA tile. + :type mma_tiler: tuple[int, int, int] + :param acc_stage: The stage of the accumulator tensor. + :type acc_stage: int + + :return: The number of tensor memory allocation columns. 
+ :rtype: int + """ + acc_shape = tiled_mma.partition_shape_C(mma_tiler[:2]) + tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, num_acc_stage)) + num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols(tCtAcc_fake) + + return num_tmem_alloc_cols + + # Size of smem we reserved for mbarrier, tensor memory management and tensormap update + reserved_smem_bytes = 1024 + bytes_per_tensormap = 128 + num_tensormaps = 3 + # size of smem used for tensor memory management + tensor_memory_management_bytes = 12 + + +# Create tensor and return the pointer, tensor, and stride +def create_tensor_and_stride( + l: int, + mode0: int, + mode1: int, + is_mode0_major: bool, + dtype: type[cutlass.Numeric], + is_dynamic_layout: bool = True, + torch_tensor_cpu: torch.Tensor = None, +) -> tuple[int, torch.Tensor, cute.Tensor, torch.Tensor, tuple[int, int]]: + """Create a GPU tensor from scratch or based on an existing CPU tensor. + + :param torch_tensor_cpu: Optional existing CPU tensor to reuse. If None, creates a new one. + :type torch_tensor_cpu: torch.Tensor, optional + """ + if torch_tensor_cpu is None: + # Create new CPU tensor + torch_tensor_cpu = cutlass_torch.matrix(l, mode0, mode1, is_mode0_major, dtype) + + # Create GPU tensor from CPU tensor (new or existing) + cute_tensor, torch_tensor = cutlass_torch.cute_tensor_like( + torch_tensor_cpu, dtype, is_dynamic_layout, assumed_align=16 + ) + return ( + torch_tensor.data_ptr(), + torch_tensor, + cute_tensor, + torch_tensor_cpu, + torch_tensor.stride()[:-1], + ) + + +def create_tensors_for_all_groups( + problem_sizes_mnkl: List[tuple[int, int, int, int]], + ab_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + a_major: str, + b_major: str, + c_major: str, + torch_fp32_tensors_abc: List[List[torch.Tensor]] = None, +) -> tuple[ + List[List[int]], + List[List[torch.Tensor]], + List[tuple], + List[List[tuple]], + List[List[torch.Tensor]], +]: + if torch_fp32_tensors_abc is not None and len(torch_fp32_tensors_abc) != len( + problem_sizes_mnkl + ): + raise ValueError("torch_fp32_tensors_abc must have one entry per group") + + # Initialize lists to store tensors for all groups + new_torch_fp32_tensors_abc = ( + [] if torch_fp32_tensors_abc is None else torch_fp32_tensors_abc + ) + torch_tensors_abc = [] + cute_tensors_abc = [] + strides_abc = [] + ptrs_abc = [] + + # Iterate through all groups and create tensors for each group + for group_idx, (m, n, k, l) in enumerate(problem_sizes_mnkl): + # Get existing CPU tensors if available, otherwise None + existing_cpu_a = ( + torch_fp32_tensors_abc[group_idx][0] if torch_fp32_tensors_abc else None + ) + existing_cpu_b = ( + torch_fp32_tensors_abc[group_idx][1] if torch_fp32_tensors_abc else None + ) + existing_cpu_c = ( + torch_fp32_tensors_abc[group_idx][2] if torch_fp32_tensors_abc else None + ) + + # Create tensors (reusing CPU tensors if provided) + ( + ptr_a, + torch_tensor_a, + cute_tensor_a, + tensor_fp32_a, + stride_mk_a, + ) = create_tensor_and_stride( + l, m, k, a_major == "m", ab_dtype, torch_tensor_cpu=existing_cpu_a + ) + ( + ptr_b, + torch_tensor_b, + cute_tensor_b, + tensor_fp32_b, + stride_nk_b, + ) = create_tensor_and_stride( + l, n, k, b_major == "n", ab_dtype, torch_tensor_cpu=existing_cpu_b + ) + ( + ptr_c, + torch_tensor_c, + cute_tensor_c, + tensor_fp32_c, + stride_mn_c, + ) = create_tensor_and_stride( + l, m, n, c_major == "m", c_dtype, torch_tensor_cpu=existing_cpu_c + ) + + # Only append to new_torch_fp32_tensors_abc if we created new CPU tensors + if torch_fp32_tensors_abc 
is None: + new_torch_fp32_tensors_abc.append( + [tensor_fp32_a, tensor_fp32_b, tensor_fp32_c] + ) + + ptrs_abc.append([ptr_a, ptr_b, ptr_c]) + torch_tensors_abc.append([torch_tensor_a, torch_tensor_b, torch_tensor_c]) + strides_abc.append([stride_mk_a, stride_nk_b, stride_mn_c]) + cute_tensors_abc.append( + ( + cute_tensor_a, + cute_tensor_b, + cute_tensor_c, + ) + ) + + return ( + ptrs_abc, + torch_tensors_abc, + cute_tensors_abc, + strides_abc, + new_torch_fp32_tensors_abc, + ) + + +def run( + num_groups: int, + problem_sizes_mnkl: tuple[int, int, int, int], + ab_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + acc_dtype: Type[cutlass.Numeric], + a_major: str, + b_major: str, + c_major: str, + mma_tiler_mn: tuple[int, int], + cluster_shape_mn: tuple[int, int], + use_2cta_instrs: bool, + tensormap_update_mode: utils.TensorMapUpdateMode, + tolerance: float, + warmup_iterations: int, + iterations: int, + skip_ref_check: bool, + use_cold_l2: bool = False, + **kwargs, +): + """Run grouped GEMM example with specified configurations. + + :param use_cold_l2: Whether to use circular buffer strategy to ensure cold L2 cache, defaults to False + :type use_cold_l2: bool, optional + :return: Execution time of the GEMM kernel in microseconds + :rtype: float + """ + print(f"Running Blackwell Grouped GEMM test with:") + print(f"{num_groups} groups") + for i, (m, n, k, l) in enumerate(problem_sizes_mnkl): + print(f"Group {i}: {m}x{n}x{k}x{l}") + print(f"AB dtype: {ab_dtype}, C dtype: {c_dtype}, Acc dtype: {acc_dtype}") + print(f"Matrix majors - A: {a_major}, B: {b_major}, C: {c_major}") + print(f"Mma Tiler (M, N): {mma_tiler_mn}, Cluster Shape (M, N): {cluster_shape_mn}") + print(f"2CTA MMA instructions: {'True' if use_2cta_instrs else 'False'}") + print(f"Tensor map update mode: {tensormap_update_mode}") + print(f"Tolerance: {tolerance}") + print(f"Warmup iterations: {warmup_iterations}") + print(f"Iterations: {iterations}") + print(f"Skip reference checking: {skip_ref_check}") + print(f"Use cold L2: {'True' if use_cold_l2 else 'False'}") + + # Skip unsupported types + if ab_dtype not in { + cutlass.Float16, + cutlass.BFloat16, + }: + raise ValueError(f"Skip unsupported ab_dtype {ab_dtype}") + if c_dtype not in {cutlass.Float16, cutlass.BFloat16, cutlass.Float32}: + raise ValueError(f"Skip unsupported c_dtype {c_dtype}") + # Skip unsupported acc dtype + if acc_dtype not in {cutlass.Float32, cutlass.Float16}: + raise ValueError(f"Skip unsupported acc_dtype {acc_dtype}") + # Skip invalid ab_dtype and acc_dtype combination + if ab_dtype == cutlass.BFloat16 and acc_dtype == cutlass.Float16: + raise ValueError("Skip invalid ab_dtype and acc_dtype combination") + # Skip invalid mma tile shape + if not ( + (not use_2cta_instrs and mma_tiler_mn[0] in [64, 128]) + or (use_2cta_instrs and mma_tiler_mn[0] in [128, 256]) + ): + raise ValueError(f"Skip invalid mma tiler M {mma_tiler_mn[0]}") + if mma_tiler_mn[1] not in range(32, 257, 32): + raise ValueError(f"Skip invalid mma tiler N {mma_tiler_mn[1]}") + # Skip illegal cluster shape + if cluster_shape_mn[0] % (2 if use_2cta_instrs else 1) != 0: + raise ValueError( + f"cluster_shape_m need align with use_2cta_instrs config {cluster_shape_mn}" + ) + # Skip invalid cluster shape + is_power_of_2 = lambda x: x > 0 and (x & (x - 1)) == 0 + if ( + cluster_shape_mn[0] * cluster_shape_mn[1] > 16 + or cluster_shape_mn[0] <= 0 + or cluster_shape_mn[1] <= 0 + or not is_power_of_2(cluster_shape_mn[0]) + or not is_power_of_2(cluster_shape_mn[1]) + ): + raise 
ValueError(f"Skip invalid cluster shape {cluster_shape_mn}") + + # Skip illegal problem shape for load/store alignment + def check_contigous_16B_alignment(dtype, is_mode0_major, tensor_shape): + major_mode_idx = 0 if is_mode0_major else 1 + num_major_elements = tensor_shape[major_mode_idx] + num_contiguous_elements = 16 * 8 // dtype.width + return num_major_elements % num_contiguous_elements == 0 + + if ( + not check_contigous_16B_alignment(ab_dtype, a_major == "m", (m, k, l)) + or not check_contigous_16B_alignment(ab_dtype, b_major == "n", (n, k, l)) + or not check_contigous_16B_alignment(c_dtype, c_major == "m", (m, n, l)) + ): + raise ValueError("Skip invalid problem alignment") + if not torch.cuda.is_available(): + raise RuntimeError("GPU is required to run this example!") + + # Create tensors for all groups using the new function + ( + ptrs_abc, + torch_tensors_abc, + cute_tensors_abc, + strides_abc, + torch_fp32_tensors_abc, + ) = create_tensors_for_all_groups( + problem_sizes_mnkl, + ab_dtype, + c_dtype, + a_major, + b_major, + c_major, + ) + + # Choose A, B, C with the smallest size to create initial tensormaps + key_size_a = lambda item: item[1][0] * item[1][2] + key_size_b = lambda item: item[1][1] * item[1][2] + key_size_c = lambda item: item[1][0] * item[1][1] + # Find the indices of the groups with the smallest tensor sizes + min_a_idx, _ = min(enumerate(problem_sizes_mnkl), key=key_size_a) + min_b_idx, _ = min(enumerate(problem_sizes_mnkl), key=key_size_b) + min_c_idx, _ = min(enumerate(problem_sizes_mnkl), key=key_size_c) + initial_cute_tensors_abc = [ + cute_tensors_abc[min_a_idx][0], # A with smallest (m, k) + cute_tensors_abc[min_b_idx][1], # B with smallest (n, k) + cute_tensors_abc[min_c_idx][2], # C with smallest (m, n) + ] + + hardware_info = utils.HardwareInfo() + sm_count = hardware_info.get_max_active_clusters(1) + max_active_clusters = hardware_info.get_max_active_clusters( + cluster_shape_mn[0] * cluster_shape_mn[1] + ) + # Prepare tensormap buffer for each SM + num_tensormap_buffers = sm_count + tensormap_shape = ( + num_tensormap_buffers, + GroupedGemmKernel.num_tensormaps, + GroupedGemmKernel.bytes_per_tensormap // 8, + ) + tensor_of_tensormap, tensor_of_tensormap_torch = cutlass_torch.cute_tensor_like( + torch.empty(tensormap_shape, dtype=torch.int64), + cutlass.Int64, + is_dynamic_layout=False, + ) + + grouped_gemm = GroupedGemmKernel( + acc_dtype, + use_2cta_instrs, + mma_tiler_mn, + cluster_shape_mn, + tensormap_update_mode, + ) + + # layout (num_groups, 4):(4, 1) + ( + tensor_of_dim_size_mnkl, + tensor_of_dim_size_mnkl_torch, + ) = cutlass_torch.cute_tensor_like( + torch.tensor(problem_sizes_mnkl, dtype=torch.int32), + cutlass.Int32, + is_dynamic_layout=False, + assumed_align=16, + ) + # layout (num_groups, 3, 2):(6, 2, 1) + tensor_of_strides_abc, tensor_of_strides_abc_torch = cutlass_torch.cute_tensor_like( + torch.tensor(strides_abc, dtype=torch.int32), + cutlass.Int32, + is_dynamic_layout=False, + assumed_align=16, + ) + + # layout (num_groups,3):(3, 1) + tensor_of_ptrs_abc, tensor_of_ptrs_abc_torch = cutlass_torch.cute_tensor_like( + torch.tensor(ptrs_abc, dtype=torch.int64), + cutlass.Int64, + is_dynamic_layout=False, + assumed_align=16, + ) + + # Compute total number of cluster tiles we need to compute for given grouped GEMM problem + def compute_total_num_clusters( + problem_sizes_mnkl: List[tuple[int, int, int, int]], + cluster_tile_shape_mn: tuple[int, int], + ) -> int: + total_num_clusters = 0 + for m, n, _, _ in problem_sizes_mnkl: + 
num_clusters_mn = tuple( + (x + y - 1) // y for x, y in zip((m, n), cluster_tile_shape_mn) + ) + total_num_clusters += functools.reduce(lambda x, y: x * y, num_clusters_mn) + return total_num_clusters + + # Compute cluster tile shape + def compute_cluster_tile_shape( + mma_tiler_mn: tuple[int, int], + cluster_shape_mn: tuple[int, int], + use_2cta_instrs: bool, + ) -> tuple[int, int]: + cta_tile_shape_mn = list(mma_tiler_mn) + if use_2cta_instrs: + cta_tile_shape_mn[0] = cta_tile_shape_mn[0] // 2 + return tuple(x * y for x, y in zip(cta_tile_shape_mn, cluster_shape_mn)) + + cluster_tile_shape_mn = compute_cluster_tile_shape( + mma_tiler_mn, cluster_shape_mn, use_2cta_instrs + ) + total_num_clusters = compute_total_num_clusters( + problem_sizes_mnkl, cluster_tile_shape_mn + ) + + # Initialize Stream + current_stream = cutlass_torch.default_stream() + + # Compile grouped GEMM kernel + compiled_grouped_gemm = cute.compile( + grouped_gemm, + initial_cute_tensors_abc[0], + initial_cute_tensors_abc[1], + initial_cute_tensors_abc[2], + num_groups, + tensor_of_dim_size_mnkl, + tensor_of_strides_abc, + tensor_of_ptrs_abc, + total_num_clusters, + tensor_of_tensormap, + max_active_clusters, + current_stream, + ) + + if not skip_ref_check: + compiled_grouped_gemm( + initial_cute_tensors_abc[0], + initial_cute_tensors_abc[1], + initial_cute_tensors_abc[2], + tensor_of_dim_size_mnkl, + tensor_of_strides_abc, + tensor_of_ptrs_abc, + tensor_of_tensormap, + current_stream, + ) + + # Compute reference result + for i, (a, b, c) in enumerate(torch_tensors_abc): + ref = torch.einsum( + "mkl,nkl->mnl", + a.cpu().to(dtype=torch.float32), + b.cpu().to(dtype=torch.float32), + ) + print(f"checking group {i}") + torch.testing.assert_close( + c.cpu(), + ref.to(cutlass_torch.dtype(c_dtype)), + atol=tolerance, + rtol=1e-05, + ) + + def generate_tensors(): + # Reuse existing CPU tensors and create new GPU tensors from them + ( + ptrs_abc_workspace, + torch_tensors_abc_workspace, + cute_tensors_abc_workspace, + strides_abc_workspace, + _, + ) = create_tensors_for_all_groups( + problem_sizes_mnkl, + ab_dtype, + c_dtype, + a_major, + b_major, + c_major, + torch_fp32_tensors_abc, + ) + + initial_cute_tensors_abc_workspace = [ + cute_tensors_abc_workspace[min_a_idx][0], # A with smallest (m, k) + cute_tensors_abc_workspace[min_b_idx][1], # B with smallest (n, k) + cute_tensors_abc_workspace[min_c_idx][2], # C with smallest (m, n) + ] + + # Create new tensors for this workspace + tensor_of_strides_abc_workspace, _ = cutlass_torch.cute_tensor_like( + torch.tensor(strides_abc_workspace, dtype=torch.int32), + cutlass.Int32, + is_dynamic_layout=False, + assumed_align=16, + ) + + tensor_of_ptrs_abc_workspace, _ = cutlass_torch.cute_tensor_like( + torch.tensor(ptrs_abc_workspace, dtype=torch.int64), + cutlass.Int64, + is_dynamic_layout=False, + assumed_align=16, + ) + + tensormap_workspace, _ = cutlass_torch.cute_tensor_like( + torch.empty(tensormap_shape, dtype=torch.int64), + cutlass.Int64, + is_dynamic_layout=False, + ) + + return testing.JitArguments( + initial_cute_tensors_abc_workspace[0], + initial_cute_tensors_abc_workspace[1], + initial_cute_tensors_abc_workspace[2], + tensor_of_dim_size_mnkl, + tensor_of_strides_abc_workspace, + tensor_of_ptrs_abc_workspace, + tensormap_workspace, + current_stream, + ) + + workspace_count = 1 + if use_cold_l2: + one_workspace_bytes = ( + sum( + [ + sum( + [ + torch_tensor.numel() * torch_tensor.element_size() + for torch_tensor in group_tensors + ] + ) + for group_tensors in 
torch_tensors_abc + ] + ) + + + # Add size of strides tensor + tensor_of_strides_abc_torch.numel() + * tensor_of_strides_abc_torch.element_size() + + + # Add size of ptrs tensor + tensor_of_ptrs_abc_torch.numel() * tensor_of_ptrs_abc_torch.element_size() + + + # Add size of tensormap tensor + tensor_of_tensormap_torch.numel() * tensor_of_tensormap_torch.element_size() + ) + workspace_count = testing.get_workspace_count( + one_workspace_bytes, warmup_iterations, iterations + ) + + exec_time = testing.benchmark( + compiled_grouped_gemm, + workspace_generator=generate_tensors, + workspace_count=workspace_count, + stream=current_stream, + warmup_iterations=warmup_iterations, + iterations=iterations, + ) + + return exec_time # Return execution time in microseconds + + +if __name__ == "__main__": + + def parse_comma_separated_ints(s: str) -> tuple[int, ...]: + try: + return tuple(int(x.strip()) for x in s.split(",")) + except ValueError: + raise argparse.ArgumentTypeError( + "Invalid format. Expected comma-separated integers." + ) + + def parse_comma_separated_tuples(s: str) -> List[tuple[int, ...]]: + if s.strip().startswith("("): + # Split on ),( to separate tuples + tuples = s.strip("()").split("),(") + result = [] + tuple_len = None + + for t in tuples: + # Parse individual tuple + nums = [int(x.strip()) for x in t.split(",")] + + # Validate tuple length consistency + if tuple_len is None: + tuple_len = len(nums) + elif len(nums) != tuple_len: + raise argparse.ArgumentTypeError( + "All tuples must have the same length" + ) + + result.append(tuple(nums)) + return result + + raise argparse.ArgumentTypeError( + "Invalid format. Expected comma-separated integers or list of tuples" + ) + + parser = argparse.ArgumentParser( + description="Example of Grouped GEMM on Blackwell." 
+ ) + parser.add_argument( + "--num_groups", + type=int, + default=2, + help="Number of groups", + ) + parser.add_argument( + "--problem_sizes_mnkl", + type=parse_comma_separated_tuples, + default=((128, 128, 128, 1), (128, 128, 128, 1)), + help="a tuple of problem sizes for each group (comma-separated tuples)", + ) + parser.add_argument( + "--mma_tiler_mn", + type=parse_comma_separated_ints, + default=(128, 128), + help="Mma tile shape (comma-separated)", + ) + parser.add_argument( + "--cluster_shape_mn", + type=parse_comma_separated_ints, + default=(1, 1), + help="Cluster shape (comma-separated)", + ) + parser.add_argument( + "--tensormap_update_mode", + type=str, + default="SMEM", + help="Tensor map update mode", + ) + parser.add_argument("--ab_dtype", type=cutlass.dtype, default=cutlass.Float16) + parser.add_argument("--c_dtype", type=cutlass.dtype, default=cutlass.Float16) + parser.add_argument("--acc_dtype", type=cutlass.dtype, default=cutlass.Float32) + parser.add_argument( + "--use_2cta_instrs", + action="store_true", + help="Enable 2CTA MMA instructions feature", + ) + parser.add_argument("--a_major", choices=["k", "m"], type=str, default="k") + parser.add_argument("--b_major", choices=["k", "n"], type=str, default="k") + parser.add_argument("--c_major", choices=["n", "m"], type=str, default="n") + parser.add_argument( + "--tolerance", type=float, default=1e-01, help="Tolerance for validation" + ) + parser.add_argument( + "--warmup_iterations", type=int, default=0, help="Warmup iterations" + ) + parser.add_argument( + "--iterations", + type=int, + default=1, + help="Number of iterations to run the kernel", + ) + parser.add_argument( + "--skip_ref_check", action="store_true", help="Skip reference checking" + ) + parser.add_argument( + "--use_cold_l2", + action="store_true", + default=False, + help="Use circular buffer tensor sets to ensure L2 cold cache", + ) + + args = parser.parse_args() + + if ( + len(args.problem_sizes_mnkl) != 0 + and len(args.problem_sizes_mnkl) != args.num_groups + ): + parser.error("--problem_sizes_mnkl must contain exactly num_groups tuples") + + # l mode must be 1 for all groups + for _, _, _, l in args.problem_sizes_mnkl: + if l != 1: + parser.error("l must be 1 for all groups") + + if len(args.mma_tiler_mn) != 2: + parser.error("--mma_tiler_mn must contain exactly 2 values") + + if len(args.cluster_shape_mn) != 2: + parser.error("--cluster_shape_mn must contain exactly 2 values") + + if args.tensormap_update_mode not in ["GMEM", "SMEM"]: + parser.error("--tensormap_update_mode must be GMEM or SMEM") + + if args.tensormap_update_mode == "GMEM": + tensormap_update_mode = utils.TensorMapUpdateMode.GMEM + else: + tensormap_update_mode = utils.TensorMapUpdateMode.SMEM + + torch.manual_seed(2025) + + run( + args.num_groups, + args.problem_sizes_mnkl, + args.ab_dtype, + args.c_dtype, + args.acc_dtype, + args.a_major, + args.b_major, + args.c_major, + args.mma_tiler_mn, + args.cluster_shape_mn, + args.use_2cta_instrs, + tensormap_update_mode, + args.tolerance, + args.warmup_iterations, + args.iterations, + args.skip_ref_check, + args.use_cold_l2, + ) + print("PASS") diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/blackwell/mamba2_ssd/mamba2_ssd.py b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/blackwell/mamba2_ssd/mamba2_ssd.py new file mode 100644 index 0000000000000000000000000000000000000000..829d6b7ecaf456c0b5782a82b0f8621370c1e532 --- /dev/null +++ 
b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/blackwell/mamba2_ssd/mamba2_ssd.py @@ -0,0 +1,3784 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +import argparse +from typing import List, Type, Tuple, Optional +import cuda.bindings.driver as cuda + +import torch +import torch.nn.functional as F + +import cutlass +import cutlass.cute as cute +import cutlass.cute.testing as testing +import cutlass.utils as utils +import cutlass.pipeline as pipeline +from cutlass.cute.nvgpu import cpasync, tcgen05 +import cutlass.torch as cutlass_torch +import cutlass.utils.blackwell_helpers as sm100_utils +from cutlass.cute.runtime import from_dlpack + +import sys +from pathlib import Path + +sys.path.append(str(Path(__file__).resolve().parent)) +from mamba2_ssd_reference import ( + ssd_reference_fp32_all, + ssd_reference_lowprecision_intermediates, + analyze_relative_diffs, +) +from mamba2_ssd_tile_scheduler import ( + Mamba2SSDTileSchedulerParams, + Mamba2SSDTileScheduler, +) + + +class SSDKernel: + def __init__( + self, + io_dtype: Type[cutlass.Numeric], + cumsum_delta_dtype: Type[cutlass.Numeric], + acc_dtype: Type[cutlass.Numeric], + L: int, + D: int, + N: int, + has_d: bool, + d_has_hdim: bool, + ): + self.io_dtype: Type[cutlass.Numeric] = io_dtype + self.acc_dtype: Type[cutlass.Numeric] = acc_dtype + self.cumsum_delta_dtype: Type[cutlass.Numeric] = cumsum_delta_dtype + # has_d means epilog warp performs Y += X*D fusion + self.has_d: bool = has_d + # d_has_hdim = True means D is (D, EH) shape and loaded by TMA + # d_has_hdim = False means D is (1, EH) shape and loaded directly to register + self.d_has_hdim: bool = d_has_hdim + self.tile_shape = (L, D, N) + + assert io_dtype in { + cutlass.Float16, + cutlass.BFloat16, + }, "Do not support other I/O types." + assert acc_dtype in {cutlass.Float32}, "Do not support other ACC types." 
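# A minimal reference sketch of the Y += X * D skip-connection that `has_d`
# enables in the epilogue warps, assuming plain (B, EH, L, D) torch layouts
# rather than the kernel's tiled TMA/TMEM layouts. With d_has_hdim the scale is
# per channel (the kernel stores it as (D, EH); here written (EH, D) for
# readability); otherwise it is a single scalar per head, shape (EH,).
import torch

def d_fusion_reference(y: torch.Tensor, x: torch.Tensor, d: torch.Tensor, d_has_hdim: bool) -> torch.Tensor:
    if d_has_hdim:
        # d: (EH, D) -> broadcast over the batch and sequence dimensions
        return y + x * d[None, :, None, :]
    # d: (EH,) -> one scalar per head
    return y + x * d[None, :, None, None]

# Example: d_fusion_reference(torch.zeros(2, 4, 64, 64), torch.randn(2, 4, 64, 64), torch.randn(4, 64), True)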
+ assert cumsum_delta_dtype in { + cutlass.Float32 + }, "Do not support other cumsum types." + assert not (not has_d and d_has_hdim), "D cannot have Hdim if has_d is False" + + # Hardcode default setting + self.use_2cta_instrs = False + self.cluster_shape_mnk = (1, 1, 1) + self.epi_tile = (128, 32) + + # Setup mma tile shapes + self.tile_shape_mnk_intra1 = (L, L, N) + self.tile_shape_mnk_intra2 = (L, D, L) + self.tile_shape_mnk_inter1 = (N, D, L) + self.tile_shape_mnk_inter2 = (L, D, N) + + self.cta_group = ( + tcgen05.CtaGroup.TWO if self.use_2cta_instrs else tcgen05.CtaGroup.ONE + ) + + # Launch config + self.occupancy = 1 + self.mma_inter_warp_id = 0 + self.mma_intra_warp_id = 1 + self.tma_b_c_warp_id = 2 + self.tma_deltas_x_d_warp_id = 3 + self.pre_inter_warp_id = [4, 5, 6, 7] + self.pre_intra_warp_id = [8, 9, 10, 11] + self.epilog_warp_id = [12, 13, 14, 15] + self.threads_per_cta = 32 * len( + ( + self.mma_inter_warp_id, + self.mma_intra_warp_id, + self.tma_b_c_warp_id, + self.tma_deltas_x_d_warp_id, + *self.pre_inter_warp_id, + *self.pre_intra_warp_id, + *self.epilog_warp_id, + ) + ) + self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_100") + + # Named barriers + self.pre_inter_sync_bar_id = 1 + self.epilog_sync_bar_id = 2 + self.pre_intra_sync_bar_id = 3 + self.tmem_dealloc_sync_bar_id = 4 + + # Number of registers used by each warp + self.num_regs_uniform_warps = 24 + self.num_regs_pre_inter_warps = 168 + self.num_regs_pre_intra_warps = 208 + self.num_regs_epilogue_warps = 112 + + # Shared storage + self.shared_storage = None + + # TMEM buffer offsets + self.tmem_intra1_acc_offset = 0 + self.tmem_intra2_q_offset = 0 + self.tmem_intra2_acc_offset = 0 + self.tmem_inter1_acc_offset = 0 + self.tmem_inter2_acc_offset = 0 + self.num_tmem_cols_total = 0 + + def _setup_attributes(self): + ( + tiled_mma_intra1, + tiled_mma_intra2, + tiled_mma_inter1, + tiled_mma_inter2, + ) = self.make_tiled_mmas( + self.io_dtype, + self.acc_dtype, + self.cta_group, + self.tile_shape_mnk_intra1, + self.tile_shape_mnk_intra2, + self.tile_shape_mnk_inter1, + self.tile_shape_mnk_inter2, + ) + + self.cluster_layout_vmnk = cute.tiled_divide( + cute.make_layout(self.cluster_shape_mnk), + (tiled_mma_intra1.thr_id.shape,), + ) + + # Setup stages + ( + self.input_stages, + self.output_stages, + self.internal_stages, + self.intra1_acc_stages, + ) = self._compute_stages( + self.smem_capacity, + ) + + # Setup smem layouts + # X is B operand (from smem) of INTRA2_MMA and INTER1_MMA + self.x_smem_layout = sm100_utils.make_smem_layout_b( + tiled_mma_intra2, + self.tile_shape_mnk_intra2, + self.io_dtype, + self.input_stages, + ) + self.num_x_load_bytes = cute.size_in_bytes( + self.io_dtype, cute.slice_(self.x_smem_layout, (None, None, None, 0)) + ) + + # XT is same shape as ACC operand of INTER2_MMA, before postprocessing by EPILOG + self.xt_smem_layout = sm100_utils.make_smem_layout_epi( + self.io_dtype, + utils.LayoutEnum.COL_MAJOR, + self.tile_shape_mnk_intra2[:2], + self.input_stages, + ) + + # B is B operand (from smem) of INTRA1_MMA + self.b_smem_layout = sm100_utils.make_smem_layout_b( + tiled_mma_intra1, + self.tile_shape_mnk_intra1, + self.io_dtype, + self.input_stages, + ) + self.num_b_load_bytes = cute.size_in_bytes( + self.io_dtype, cute.slice_(self.b_smem_layout, (None, None, None, 0)) + ) + + # B_INTERNAL is also A operand (from smem) of INTER1_MMA, after preprocessed by PRE_INTER + self.bt_internal_smem_layout = sm100_utils.make_smem_layout_a( + tiled_mma_inter1, + self.tile_shape_mnk_inter1, + 
self.io_dtype, + self.internal_stages, + ) + + # B needs to be proprocessed to be used as A operand of INTER1_MMA + self.bt_smem_layout = cute.coalesce( + sm100_utils.make_smem_layout_epi( + self.io_dtype, + utils.LayoutEnum.ROW_MAJOR, + (self.tile_shape_mnk_inter1[0], self.tile_shape_mnk_inter1[2]), + self.input_stages, + ), + target_profile=(1, 1, 1), + ) + + # C is A operand (from smem) of INTRA1_MMA and INTER2_MMA + self.c_smem_layout = sm100_utils.make_smem_layout_a( + tiled_mma_intra1, + self.tile_shape_mnk_intra1, + self.io_dtype, + self.input_stages, + ) + self.num_c_load_bytes = cute.size_in_bytes( + self.io_dtype, cute.slice_(self.c_smem_layout, (None, None, None, 0)) + ) + + # P is B operand (from smem) of INTER2_MMA, after preprocessed by PRE_INTER + self.p_smem_layout = sm100_utils.make_smem_layout_b( + tiled_mma_inter2, + self.tile_shape_mnk_inter2, + self.io_dtype, + self.internal_stages, + ) + + # PT is ACC operand (from tmem) of INTER1_MMA, after postprocessed by PRE_INTER + self.pt_smem_layout = sm100_utils.make_smem_layout_epi( + self.io_dtype, + utils.LayoutEnum.COL_MAJOR, + self.tile_shape_mnk_inter1[:2], + self.internal_stages, + ) + + # Q is A operand (from tmem) of INTRA2_MMA, after preprocessed by PRE_INTRA + self.q_tmem_layout = sm100_utils.make_smem_layout_a( + tiled_mma_intra2, + self.tile_shape_mnk_intra2, + self.io_dtype, + self.internal_stages, + ) + + # P is ACC operand (from tmem) of INTER1_MMA, to be TMA stored by PRE_INTER + self.p_smem_layout_store = sm100_utils.make_smem_layout_epi( + self.io_dtype, + utils.LayoutEnum.ROW_MAJOR, + self.tile_shape_mnk_inter2[1:], + self.internal_stages, + ) + + # Y is ACC operand (from smem) of INTER2_MMA and INTRA2_MMA, after postprocessed and TMA stored by EPILOG + self.y_smem_layout = sm100_utils.make_smem_layout_epi( + self.io_dtype, + utils.LayoutEnum.COL_MAJOR, + self.epi_tile, + self.output_stages, + ) + + # Delta is linear smem layouts for pre/post processing + self.delta_linear_smem_layout = cute.make_layout( + (self.tile_shape_mnk_inter1[2], self.input_stages) + ) + self.num_delta_load_bytes = cute.size_in_bytes( + self.io_dtype, cute.slice_(self.delta_linear_smem_layout, (None, 0)) + ) + + # Cumsum delta is linear smem layouts for pre/post processing + self.cumsum_delta_linear_smem_layout = cute.make_layout( + (self.tile_shape_mnk_inter1[2], self.input_stages) + ) + self.num_cumsum_delta_load_bytes = cute.size_in_bytes( + self.cumsum_delta_dtype, + cute.slice_(self.cumsum_delta_linear_smem_layout, (None, 0)), + ) + + # D is linear smem layouts when d_has_hdim is True + self.d_linear_smem_layout = ( + cute.make_layout((self.tile_shape_mnk_inter2[1], self.input_stages)) + if self.d_has_hdim + else None + ) + self.num_d_load_bytes = ( + cute.size_in_bytes( + self.io_dtype, + cute.slice_(self.d_linear_smem_layout, (None, 0)), + ) + if self.d_has_hdim + else 0 + ) + + # Setup tmem offsets + ( + self.tmem_intra1_acc_offset, + self.tmem_intra2_q_offset, + self.tmem_intra2_acc_offset, + self.tmem_inter1_acc_offset, + self.tmem_inter2_acc_offset, + self.num_tmem_cols_total, + ) = self._plan_tmem_offsets( + tiled_mma_intra1, + self.tile_shape_mnk_intra1, + tiled_mma_intra2, + self.tile_shape_mnk_intra2, + tiled_mma_inter1, + self.tile_shape_mnk_inter1, + tiled_mma_inter2, + self.tile_shape_mnk_inter2, + self.internal_stages, + self.q_tmem_layout, + self.io_dtype, + self.internal_stages, + self.intra1_acc_stages, + ) + + return + + @cute.jit + def __call__( + self, + x: cute.Tensor, + cumsum_delta: cute.Tensor, + delta: 
cute.Tensor, + b: cute.Tensor, + c: cute.Tensor, + y: cute.Tensor, + fstate: cute.Tensor, + d: cute.Tensor, + max_active_clusters: cutlass.Constexpr, + stream: cuda.CUstream, + ): + self._setup_attributes() + ( + tiled_mma_intra1, + tiled_mma_intra2, + tiled_mma_inter1, + tiled_mma_inter2, + ) = self.make_tiled_mmas( + self.io_dtype, + self.acc_dtype, + self.cta_group, + self.tile_shape_mnk_intra1, + self.tile_shape_mnk_intra2, + self.tile_shape_mnk_inter1, + self.tile_shape_mnk_inter2, + ) + + # Setup TMA atoms and convert TMA tensors + # TMA load for A + x_op = sm100_utils.cluster_shape_to_tma_atom_B( + self.cluster_shape_mnk, tiled_mma_intra2.thr_id + ) + tma_atom_x, tma_tensor_x = cute.nvgpu.make_tiled_tma_atom_B( + x_op, + x, + cute.slice_(self.x_smem_layout, (None, None, None, 0)), + self.tile_shape_mnk_intra2, + tiled_mma_intra2, + self.cluster_layout_vmnk.shape, + internal_type=( + cutlass.TFloat32 if x.element_type is cutlass.Float32 else None + ), + ) + + # TMA load for B + b_op = sm100_utils.cluster_shape_to_tma_atom_B( + self.cluster_shape_mnk, tiled_mma_intra1.thr_id + ) + tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B( + b_op, + b, + cute.slice_(self.b_smem_layout, (None, None, None, 0)), + self.tile_shape_mnk_intra1, + tiled_mma_intra1, + self.cluster_layout_vmnk.shape, + internal_type=( + cutlass.TFloat32 if b.element_type is cutlass.Float32 else None + ), + ) + + # TMA load for C + c_op = sm100_utils.cluster_shape_to_tma_atom_A( + self.cluster_shape_mnk, tiled_mma_intra1.thr_id + ) + tma_atom_c, tma_tensor_c = cute.nvgpu.make_tiled_tma_atom_A( + c_op, + c, + cute.slice_(self.c_smem_layout, (None, None, None, 0)), + self.tile_shape_mnk_intra1, + tiled_mma_intra1, + self.cluster_layout_vmnk.shape, + internal_type=( + cutlass.TFloat32 if c.element_type is cutlass.Float32 else None + ), + ) + + # TMA load for delta + # TODO: use bulkcp instead of tma + delta_cta_v_layout = cute.slice_( + cute.make_identity_layout(delta.shape), (None, 0, 0, 0) + ) + delta_linear_smem_layout = cute.slice_(self.delta_linear_smem_layout, (None, 0)) + tma_atom_delta, tma_tensor_delta = cpasync.make_tiled_tma_atom( + cpasync.CopyBulkTensorTileG2SOp(), + delta, + delta_linear_smem_layout, + delta_cta_v_layout, + ) + + # TMA load for cumsum_delta + cumsum_delta_cta_v_layout = cute.slice_( + cute.make_identity_layout(cumsum_delta.shape), (None, 0, 0, 0) + ) + cumsum_delta_linear_smem_layout = cute.slice_( + self.cumsum_delta_linear_smem_layout, (None, 0) + ) + ( + tma_atom_cumsum_delta, + tma_tensor_cumsum_delta, + ) = cpasync.make_tiled_tma_atom( + cpasync.CopyBulkTensorTileG2SOp(), + cumsum_delta, + cumsum_delta_linear_smem_layout, + cumsum_delta_cta_v_layout, + ) + + tma_atom_d = None + tma_tensor_d = d + # TMA load for D + if cutlass.const_expr(self.d_has_hdim): + d_cta_v_layout = cute.slice_(cute.make_identity_layout(d.shape), (None, 0)) + d_linear_smem_layout = cute.slice_(self.d_linear_smem_layout, (None, 0)) + ( + tma_atom_d, + tma_tensor_d, + ) = cpasync.make_tiled_tma_atom( + cpasync.CopyBulkTensorTileG2SOp(), + d, + d_linear_smem_layout, + d_cta_v_layout, + ) + + # TMA store for y + y_cta_v_layout = cute.composition( + cute.make_identity_layout(y.shape), self.epi_tile + ) + y_smem_layout = cute.slice_(self.y_smem_layout, (None, None, 0)) + tma_atom_y, tma_tensor_y = cpasync.make_tiled_tma_atom( + cpasync.CopyBulkTensorTileS2GOp(), + y, + y_smem_layout, + y_cta_v_layout, + ) + + # TMA store for fstate(p) + p_cta_v_layout = cute.slice_( + cute.make_identity_layout(fstate.shape), 
(None, None, 0, 0) + ) + p_smem_layout_store = cute.slice_(self.p_smem_layout_store, (None, None, 0)) + tma_atom_p, tma_tensor_p = cpasync.make_tiled_tma_atom( + cpasync.CopyBulkTensorTileS2GOp(), + fstate, + p_smem_layout_store, + p_cta_v_layout, + ) + + # Compute grid size + tile_sched_params, grid = self._compute_grid(y, b, max_active_clusters) + + # Plan shared memory storage + swizzle_buffer_align_bytes = 1024 + nonswizzle_buffer_align_bytes = 128 + + @cute.struct + class SharedStorage: + # Input stage barriers + x_full: cute.struct.MemRange[cutlass.Int64, self.input_stages] # type: ignore + x_empty: cute.struct.MemRange[cutlass.Int64, self.input_stages] # type: ignore + b_full: cute.struct.MemRange[cutlass.Int64, self.input_stages] # type: ignore + b_empty: cute.struct.MemRange[cutlass.Int64, self.input_stages] # type: ignore + c_full: cute.struct.MemRange[cutlass.Int64, self.input_stages] # type: ignore + c_empty: cute.struct.MemRange[cutlass.Int64, self.input_stages] # type: ignore + deltas_full: cute.struct.MemRange[cutlass.Int64, self.input_stages] # type: ignore + deltas_empty: cute.struct.MemRange[cutlass.Int64, self.input_stages] # type: ignore + d_full: cute.struct.MemRange[cutlass.Int64, self.input_stages] # type: ignore + d_empty: cute.struct.MemRange[cutlass.Int64, self.input_stages] # type: ignore + # Intra1 acc stage barriers + intra1_acc_full: cute.struct.MemRange[cutlass.Int64, self.intra1_acc_stages] # type: ignore + intra1_acc_empty: cute.struct.MemRange[cutlass.Int64, self.intra1_acc_stages] # type: ignore + # Internal stage barriers + intra2_q_full: cute.struct.MemRange[cutlass.Int64, self.internal_stages] # type: ignore + intra2_q_empty: cute.struct.MemRange[cutlass.Int64, self.internal_stages] # type: ignore + intra2_acc_full: cute.struct.MemRange[cutlass.Int64, self.internal_stages] # type: ignore + intra2_acc_empty: cute.struct.MemRange[cutlass.Int64, self.internal_stages] # type: ignore + inter1_b_full: cute.struct.MemRange[cutlass.Int64, self.internal_stages] # type: ignore + inter1_b_empty: cute.struct.MemRange[cutlass.Int64, self.internal_stages] # type: ignore + inter1_acc_full: cute.struct.MemRange[cutlass.Int64, self.internal_stages] # type: ignore + inter1_acc_empty: cute.struct.MemRange[cutlass.Int64, self.internal_stages] # type: ignore + inter2_p_full: cute.struct.MemRange[cutlass.Int64, self.internal_stages] # type: ignore + inter2_p_empty: cute.struct.MemRange[cutlass.Int64, self.internal_stages] # type: ignore + inter2_acc_full: cute.struct.MemRange[cutlass.Int64, self.internal_stages] # type: ignore + inter2_acc_empty: cute.struct.MemRange[cutlass.Int64, self.internal_stages] # type: ignore + # Tmem holding buffer + tmem_holding_buf: cutlass.Int32 + # Smem tensors + smem_x: cute.struct.Align[ + cute.struct.MemRange[self.io_dtype, cute.cosize(self.x_smem_layout)], + swizzle_buffer_align_bytes, + ] + smem_b: cute.struct.Align[ + cute.struct.MemRange[self.io_dtype, cute.cosize(self.b_smem_layout)], + swizzle_buffer_align_bytes, + ] + smem_bt_internal: cute.struct.Align[ + cute.struct.MemRange[ + self.io_dtype, cute.cosize(self.bt_internal_smem_layout) + ], + swizzle_buffer_align_bytes, + ] + smem_c: cute.struct.Align[ + cute.struct.MemRange[self.io_dtype, cute.cosize(self.c_smem_layout)], + swizzle_buffer_align_bytes, + ] + smem_p: cute.struct.Align[ + cute.struct.MemRange[self.io_dtype, cute.cosize(self.p_smem_layout)], + swizzle_buffer_align_bytes, + ] + smem_y: cute.struct.Align[ + cute.struct.MemRange[self.io_dtype, 
cute.cosize(self.y_smem_layout)], + swizzle_buffer_align_bytes, + ] + smem_cumsum_delta: cute.struct.Align[ + cute.struct.MemRange[ + self.cumsum_delta_dtype, + cute.cosize(self.cumsum_delta_linear_smem_layout), + ], + nonswizzle_buffer_align_bytes, + ] + smem_delta: cute.struct.Align[ + cute.struct.MemRange[ + self.io_dtype, cute.cosize(self.delta_linear_smem_layout) + ], + nonswizzle_buffer_align_bytes, + ] + smem_d: cute.struct.Align[ + cute.struct.MemRange[ + self.io_dtype, + cute.cosize(self.d_linear_smem_layout) if self.d_has_hdim else 0, + ], + nonswizzle_buffer_align_bytes, + ] + + self.shared_storage = SharedStorage + if cutlass.const_expr(self.shared_storage.size_in_bytes() > self.smem_capacity): + raise ValueError( + f"SharedStorage size {self.shared_storage.size_in_bytes()} exceeds smem_capacity {self.smem_capacity}" + ) + + # Launch the kernel synchronously + self.kernel( + tma_atom_x, + tma_tensor_x, + tma_atom_b, + tma_tensor_b, + tma_atom_c, + tma_tensor_c, + tma_atom_p, + tma_tensor_p, + tma_atom_y, + tma_tensor_y, + tma_atom_delta, + tma_tensor_delta, + tma_atom_cumsum_delta, + tma_tensor_cumsum_delta, + tma_atom_d, + tma_tensor_d, + self.cluster_layout_vmnk, + self.x_smem_layout, + self.xt_smem_layout, + self.b_smem_layout, + self.bt_smem_layout, + self.bt_internal_smem_layout, + self.c_smem_layout, + self.pt_smem_layout, + self.p_smem_layout, + self.q_tmem_layout, + self.p_smem_layout_store, + self.y_smem_layout, + self.delta_linear_smem_layout, + self.cumsum_delta_linear_smem_layout, + self.d_linear_smem_layout, + self.epi_tile, + tile_sched_params, + ).launch( + grid=grid, + block=[self.threads_per_cta, 1, 1], + cluster=self.cluster_shape_mnk, + min_blocks_per_mp=1, + stream=stream, + ) + + # GPU device kernel + @cute.kernel + def kernel( + self, + tma_atom_x: cute.CopyAtom, + tma_tensor_x: cute.Tensor, + tma_atom_b: cute.CopyAtom, + tma_tensor_b: cute.Tensor, + tma_atom_c: cute.CopyAtom, + tma_tensor_c: cute.Tensor, + tma_atom_p: cute.CopyAtom, + tma_tensor_p: cute.Tensor, + tma_atom_y: cute.CopyAtom, + tma_tensor_y: cute.Tensor, + tma_atom_delta: cute.CopyAtom, + tma_tensor_delta: cute.Tensor, + tma_atom_cumsum_delta: cute.CopyAtom, + tma_tensor_cumsum_delta: cute.Tensor, + tma_atom_d: Optional[cute.CopyAtom], + tma_tensor_d: cute.Tensor, + cluster_layout_vmnk: cute.Layout, + x_smem_layout: cute.ComposedLayout, + xt_smem_layout: cute.ComposedLayout, + b_smem_layout: cute.ComposedLayout, + bt_smem_layout: cute.ComposedLayout, + bt_internal_smem_layout: cute.ComposedLayout, + c_smem_layout: cute.ComposedLayout, + pt_smem_layout: cute.ComposedLayout, + p_smem_layout: cute.ComposedLayout, + q_tmem_layout: cute.ComposedLayout, + p_smem_layout_store: cute.ComposedLayout, + y_smem_layout: cute.ComposedLayout, + delta_linear_smem_layout: cute.Layout, + cumsum_delta_linear_smem_layout: cute.Layout, + d_linear_smem_layout: Optional[cute.Layout], + epi_tile: cute.Tile, + tile_sched_params: Mamba2SSDTileSchedulerParams, + ): + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + + # Prefetch tma descriptor + if warp_idx == 0: + tma_atoms = [ + tma_atom_x, + tma_atom_b, + tma_atom_c, + tma_atom_p, + tma_atom_y, + tma_atom_delta, + tma_atom_cumsum_delta, + ] + if cutlass.const_expr(self.d_has_hdim): + tma_atoms.append(tma_atom_d) + for tma_atom in tma_atoms: + cpasync.prefetch_descriptor(tma_atom) + + # Static consts + D = cute.size(tma_tensor_x, mode=[0]) + L = cute.size(tma_tensor_x, mode=[1]) + N = cute.size(tma_tensor_b, mode=[1]) + # Dynamic values + C = 
cute.size(tma_tensor_x, mode=[2]) + EH = cute.size(tma_tensor_x, mode=[3]) + B = cute.size(tma_tensor_x, mode=[4]) + G = cute.size(tma_tensor_b, mode=[3]) + NGROUP_RATIO = EH // G + + # Make TiledMma + ( + tiled_mma_intra1, + tiled_mma_intra2, + tiled_mma_inter1, + tiled_mma_inter2, + ) = self.make_tiled_mmas( + self.io_dtype, + self.acc_dtype, + self.cta_group, + self.tile_shape_mnk_intra1, + self.tile_shape_mnk_intra2, + self.tile_shape_mnk_inter1, + self.tile_shape_mnk_inter2, + ) + + # Setup cta/thread coordinates + # Block coord + bidx, bidy, bidz = cute.arch.block_idx() + mma_tile_coord_v = bidx % cute.size(tiled_mma_intra1.thr_id.shape) + cta_rank_in_cluster = cute.arch.make_warp_uniform( + cute.arch.block_idx_in_cluster() + ) + block_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord( + cta_rank_in_cluster + ) + # Workload coord + tile_sched = Mamba2SSDTileScheduler.create( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + + # Thread/warp coord + tidx, _, _ = cute.arch.thread_idx() + # Thread coord inside specialized warps + local_tidx = tidx % 128 + local_warp_idx = cute.arch.make_warp_uniform(local_tidx // 32) + + # Alloc and init smem tensors and pipelines + smem = utils.SmemAllocator() + smem_storage = smem.allocate(self.shared_storage) + + # Setup smem tensors + smem_x = smem_storage.smem_x.get_tensor( + x_smem_layout.outer, swizzle=x_smem_layout.inner + ) + smem_xt = smem_storage.smem_x.get_tensor( + xt_smem_layout.outer, swizzle=xt_smem_layout.inner + ) + smem_b = smem_storage.smem_b.get_tensor( + b_smem_layout.outer, swizzle=b_smem_layout.inner + ) + smem_bt = smem_storage.smem_b.get_tensor( + bt_smem_layout.outer, swizzle=bt_smem_layout.inner + ) + smem_bt_internal = smem_storage.smem_bt_internal.get_tensor( + bt_internal_smem_layout.outer, swizzle=bt_internal_smem_layout.inner + ) + smem_c = smem_storage.smem_c.get_tensor( + c_smem_layout.outer, swizzle=c_smem_layout.inner + ) + smem_p = smem_storage.smem_p.get_tensor( + p_smem_layout.outer, swizzle=p_smem_layout.inner + ) + smem_pt = smem_storage.smem_p.get_tensor( + pt_smem_layout.outer, swizzle=pt_smem_layout.inner + ) + smem_p_store = smem_storage.smem_p.get_tensor( + p_smem_layout_store.outer, swizzle=p_smem_layout_store.inner + ) + smem_y = smem_storage.smem_y.get_tensor( + y_smem_layout.outer, swizzle=y_smem_layout.inner + ) + smem_cumsum_delta = smem_storage.smem_cumsum_delta.get_tensor( + cumsum_delta_linear_smem_layout + ) + smem_delta = smem_storage.smem_delta.get_tensor(delta_linear_smem_layout) + smem_d = None + if cutlass.const_expr(self.d_has_hdim): + smem_d = smem_storage.smem_d.get_tensor(d_linear_smem_layout) + + # Init mbarrier for pipeline + x_pipeline = self.make_and_init_x_pipeline(smem_storage.x_full.data_ptr()) + b_pipeline = self.make_and_init_b_pipeline(smem_storage.b_full.data_ptr()) + c_pipeline = self.make_and_init_c_pipeline(smem_storage.c_full.data_ptr()) + deltas_pipeline = self.make_and_init_deltas_pipeline( + smem_storage.deltas_full.data_ptr() + ) + d_pipeline = self.make_and_init_d_pipeline(smem_storage.d_full.data_ptr()) + intra1_acc_pipeline = self.make_and_init_intra1_acc_pipeline( + smem_storage.intra1_acc_full.data_ptr() + ) + intra2_q_pipeline = self.make_and_init_intra2_q_pipeline( + smem_storage.intra2_q_full.data_ptr() + ) + intra2_acc_pipeline = self.make_and_init_intra2_acc_pipeline( + smem_storage.intra2_acc_full.data_ptr() + ) + inter1_b_pipeline = self.make_and_init_inter1_b_pipeline( + 
smem_storage.inter1_b_full.data_ptr() + ) + inter1_acc_pipeline = self.make_and_init_inter1_acc_pipeline( + smem_storage.inter1_acc_full.data_ptr() + ) + inter2_p_pipeline = self.make_and_init_inter2_p_pipeline( + smem_storage.inter2_p_full.data_ptr() + ) + inter2_acc_pipeline = self.make_and_init_inter2_acc_pipeline( + smem_storage.inter2_acc_full.data_ptr() + ) + + # Cluster arrive after barrier init + if cute.size(self.cluster_shape_mnk) > 1: + cute.arch.cluster_arrive_relaxed() + + # Cluster wait before tmem alloc + if cute.size(self.cluster_shape_mnk) > 1: + cute.arch.cluster_wait() + + # Alloc tmem buffer + if warp_idx == self.epilog_warp_id[0]: + cute.arch.alloc_tmem( + self.num_tmem_cols_total, + smem_storage.tmem_holding_buf, + is_two_cta=self.use_2cta_instrs, + ) + + # Bar sync before retrieving tmem ptr from shared mem + cute.arch.barrier() + + # Retrieve tmem ptr + tmem_ptr_base = cute.arch.retrieve_tmem_ptr( + self.acc_dtype, + alignment=16, + ptr_to_buffer_holding_addr=smem_storage.tmem_holding_buf, + ) + + # Specialized TMA load Delta/CumsumDelta/X warp + if warp_idx == self.tma_deltas_x_d_warp_id: + # Dealloc regs for pre-inter/pre-intra warps + cute.arch.warpgroup_reg_dealloc(self.num_regs_uniform_warps) + + # ((ATOM_V, REST_V), INPUT_STAGE) + # ((ATOM_V, REST_V), 1, 1, C, EH, B) + tXsX, tXgX_pre_slice = self.tma_partition_for_mma_b_operand( + tma_atom_x, + tma_tensor_x, + smem_x, + tiled_mma_intra2, + cluster_layout_vmnk, + mma_tile_coord_v, + block_in_cluster_coord_vmnk, + ) + + # ((ATOM_V, REST_V), INPUT_STAGE) + # ((ATOM_V, REST_V), 1, C, EH, B) + tDeltasDelta, tDeltagDelta_pre_slice = self.tma_partition_with_shape( + tma_atom_delta, + tma_tensor_delta, + smem_delta, + (self.tile_shape_mnk_inter1[2],), + ) + # ((ATOM_V, REST_V), INPUT_STAGE) + # ((ATOM_V, REST_V), 1, C, EH, B) + ( + tDeltasCumsumDelta, + tDeltagCumsumDelta_pre_slice, + ) = self.tma_partition_with_shape( + tma_atom_cumsum_delta, + tma_tensor_cumsum_delta, + smem_cumsum_delta, + (self.tile_shape_mnk_inter1[2],), + ) + + tDsD = None + tDgD_pre_slice = None + if cutlass.const_expr(self.d_has_hdim): + # Partition global/shared tensor for D + # ((ATOM_V, REST_V), INPUT_STAGE) + # ((ATOM_V, REST_V), 1, EH) + tDsD, tDgD_pre_slice = self.tma_partition_with_shape( + tma_atom_d, tma_tensor_d, smem_d, (self.tile_shape_mnk_inter2[1],) + ) + + # Pipeline X/Delta/CumsumDelta/D producer state + x_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.input_stages + ) + deltas_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.input_stages + ) + d_producer_state = None + if cutlass.const_expr(self.d_has_hdim): + # D is loaded by TMA only when d_has_hdim is True + d_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.input_stages + ) + + while work_tile.is_valid_tile: + b_idx, eh_idx, g_idx = work_tile.tile_idx + + # Slice global tensor to current tile idx + # ((ATOM_V, REST_V), C) + tXgX = tXgX_pre_slice[None, 0, 0, None, eh_idx, b_idx] + tDeltagDelta = tDeltagDelta_pre_slice[None, 0, None, eh_idx, b_idx] + tDeltagCumsumDelta = tDeltagCumsumDelta_pre_slice[ + None, 0, None, eh_idx, b_idx + ] + tDgD = None + if cutlass.const_expr(self.d_has_hdim): + # ((ATOM_V, REST_V)) + tDgD = tDgD_pre_slice[None, 0, eh_idx] + + # Reset count for pipeline state + x_producer_state.reset_count() + deltas_producer_state.reset_count() + if cutlass.const_expr(self.d_has_hdim): + d_producer_state.reset_count() + + # Peek 
(try_wait) X/deltas buffer empty status + peek_x_empty_status = self.conditional_producer_try_acquire( + x_producer_state, x_pipeline, C + ) + peek_deltas_empty_status = self.conditional_producer_try_acquire( + deltas_producer_state, deltas_pipeline, C + ) + + if cutlass.const_expr(self.d_has_hdim): + # Wait for D buffer empty + d_pipeline.producer_acquire(d_producer_state) + # TMA load D + cute.copy( + tma_atom_d, + tDgD, + tDsD[None, d_producer_state.index], + tma_bar_ptr=d_pipeline.producer_get_barrier(d_producer_state), + ) + # Advance D producer state + d_producer_state.advance() + + # Batched load over C dimension + for chunk_idx in cutlass.range(C, unroll=1): + # Conditionally wait for X buffer empty + x_pipeline.producer_acquire(x_producer_state, peek_x_empty_status) + + # TMA load X + cute.copy( + tma_atom_x, + tXgX[None, x_producer_state.count], + tXsX[None, x_producer_state.index], + tma_bar_ptr=x_pipeline.producer_get_barrier(x_producer_state), + ) + + # Conditionally wait for deltas buffer empty + deltas_pipeline.producer_acquire( + deltas_producer_state, peek_deltas_empty_status + ) + + # TMA load Delta/CumsumDelta + cute.copy( + tma_atom_delta, + tDeltagDelta[None, deltas_producer_state.count], + tDeltasDelta[None, deltas_producer_state.index], + tma_bar_ptr=deltas_pipeline.producer_get_barrier( + deltas_producer_state + ), + ) + cute.copy( + tma_atom_cumsum_delta, + tDeltagCumsumDelta[None, deltas_producer_state.count], + tDeltasCumsumDelta[None, deltas_producer_state.index], + tma_bar_ptr=deltas_pipeline.producer_get_barrier( + deltas_producer_state + ), + ) + + # Advance X/deltas producer state + x_producer_state.advance() + deltas_producer_state.advance() + + # Peek (try_wait) X/deltas buffer empty status + peek_x_empty_status = self.conditional_producer_try_acquire( + x_producer_state, x_pipeline, C + ) + peek_deltas_empty_status = self.conditional_producer_try_acquire( + deltas_producer_state, deltas_pipeline, C + ) + # END of for chunk_idx in cutlass.range(C, unroll=1) + + # Advance to next tile + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + # END of while work_tile.is_valid_tile + + # Producer tail for X/Deltas/D + x_pipeline.producer_tail(x_producer_state) + deltas_pipeline.producer_tail(deltas_producer_state) + if cutlass.const_expr(self.d_has_hdim): + d_pipeline.producer_tail(d_producer_state) + # END of specialized tma load X/Deltas/D warp + + # Specialized TMA load B/C warp + elif warp_idx == self.tma_b_c_warp_id: + # Dealloc regs for pre-inter/pre-intra warps + cute.arch.warpgroup_reg_dealloc(self.num_regs_uniform_warps) + + # ((ATOM_V, REST_V), INPUT_STAGE) + # ((ATOM_V, REST_V), 1, 1, C, G, B) + tBsB, tBgB_pre_slice = self.tma_partition_for_mma_b_operand( + tma_atom_b, + tma_tensor_b, + smem_b, + tiled_mma_intra1, + cluster_layout_vmnk, + mma_tile_coord_v, + block_in_cluster_coord_vmnk, + ) + + # ((ATOM_V, REST_V), INPUT_STAGE) + # ((ATOM_V, REST_V), 1, 1, C, G, B) + tCsC, tCgC_pre_slice = self.tma_partition_for_mma_a_operand( + tma_atom_c, + tma_tensor_c, + smem_c, + tiled_mma_intra1, + cluster_layout_vmnk, + mma_tile_coord_v, + block_in_cluster_coord_vmnk, + ) + + # Pipeline B/C producer state + b_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.input_stages + ) + c_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.input_stages + ) + + while work_tile.is_valid_tile: + b_idx, eh_idx, g_idx = work_tile.tile_idx + + # Slice global tensor to 
current tile idx + # ((ATOM_V, REST_V), C) + tBgB = tBgB_pre_slice[None, 0, 0, None, g_idx, b_idx] + tCgC = tCgC_pre_slice[None, 0, 0, None, g_idx, b_idx] + + # Reset count for pipeline state + b_producer_state.reset_count() + c_producer_state.reset_count() + + # Peek (try_wait) B/C buffer empty status + peek_b_empty_status = self.conditional_producer_try_acquire( + b_producer_state, b_pipeline, C + ) + peek_c_empty_status = self.conditional_producer_try_acquire( + c_producer_state, c_pipeline, C + ) + + # Batched load over C dimension + for chunk_idx in cutlass.range(C, unroll=1): + # Conditionally wait for B buffer empty + b_pipeline.producer_acquire(b_producer_state, peek_b_empty_status) + + # TMA load B + cute.copy( + tma_atom_b, + tBgB[None, b_producer_state.count], + tBsB[None, b_producer_state.index], + tma_bar_ptr=b_pipeline.producer_get_barrier(b_producer_state), + ) + + # Conditionally wait for C buffer empty + c_pipeline.producer_acquire(c_producer_state, peek_c_empty_status) + + # TMA load C + cute.copy( + tma_atom_c, + tCgC[None, c_producer_state.count], + tCsC[None, c_producer_state.index], + tma_bar_ptr=c_pipeline.producer_get_barrier(c_producer_state), + ) + + # Advance B/C producer state + b_producer_state.advance() + c_producer_state.advance() + + # Peek (try_wait) B/C buffer empty status + peek_b_empty_status = self.conditional_producer_try_acquire( + b_producer_state, b_pipeline, C + ) + peek_c_empty_status = self.conditional_producer_try_acquire( + c_producer_state, c_pipeline, C + ) + # END of for chunk_idx in cutlass.range(C, unroll=1) + + # Advance to next tile + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + # END of while work_tile.is_valid_tile + + # Producer tail for B/C + b_pipeline.producer_tail(b_producer_state) + c_pipeline.producer_tail(c_producer_state) + # END of specialized tma load B/C warp + + # Specialized MMA Intra warp + elif warp_idx == self.mma_intra_warp_id: + # Dealloc regs for pre-inter/pre-intra warps + cute.arch.warpgroup_reg_dealloc(self.num_regs_uniform_warps) + + # Make shared/tmem fragments for INTRA_MMA1 B/C/ACC + # (MMA, MMA_N, MMA_K, INPUT_STAGE) + # (MMA, MMA_M, MMA_K, INPUT_STAGE) + # (MMA, MMA_M, MMA_N, INTRA1_ACC_STAGE) + tCrC, tCrB, tCtAccIntra1 = self.mma_partition_ss( + tiled_mma_intra1, + self.tile_shape_mnk_intra1, + smem_c, + smem_b, + tmem_ptr_base + self.tmem_intra1_acc_offset, + self.intra1_acc_stages, + ) + + # Make shared/tmem fragments for INTRA_MMA2 X/Q/ACC + # (MMA, MMA_M, MMA_K, INTERNAL_STAGE) + # (MMA, MMA_N, MMA_K, INPUT_STAGE) + # (MMA, MMA_M, MMA_N, INTERNAL_STAGE) + tCrQ, tCrX, tCtAccIntra2 = self.mma_partition_ts( + tiled_mma_intra2, + self.tile_shape_mnk_intra2, + q_tmem_layout, + smem_x, + tmem_ptr_base + self.tmem_intra2_q_offset, + tmem_ptr_base + self.tmem_intra2_acc_offset, + self.internal_stages, + ) + + # Pipeline B/C/X/INTRA2_Q consumer state + b_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.input_stages + ) + c_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.input_stages + ) + x_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.input_stages + ) + intra2_q_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.internal_stages + ) + + # Pipeline INTRA1_ACC/INTRA2_ACC producer state + intra1_acc_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.intra1_acc_stages + ) + 
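# A host-side sketch of the multi-stage producer/consumer handshake the warps in
# this kernel follow: each pipeline stage has a "full" and an "empty" flag,
# producers peek (non-blocking try_wait, as in conditional_producer_try_acquire)
# before blocking, and a monotonically increasing count maps to a ring-buffer
# index (count % num_stages), mirroring make_pipeline_state / producer_acquire /
# consumer_wait. This is only an analogy in plain Python threads; on the GPU the
# flags are shared-memory mbarriers managed by cutlass.pipeline.
from threading import Semaphore

class ToyStagePipeline:
    def __init__(self, num_stages: int):
        self.num_stages = num_stages
        self.full = [Semaphore(0) for _ in range(num_stages)]   # stage holds valid data
        self.empty = [Semaphore(1) for _ in range(num_stages)]  # stage is free to overwrite

    def producer_try_acquire(self, count: int) -> bool:
        # Non-blocking "peek" at the empty flag.
        return self.empty[count % self.num_stages].acquire(blocking=False)

    def producer_acquire(self, count: int, peeked: bool = False) -> None:
        if not peeked:
            self.empty[count % self.num_stages].acquire()        # block until the consumer released it

    def producer_commit(self, count: int) -> None:
        self.full[count % self.num_stages].release()             # publish the stage

    def consumer_wait(self, count: int) -> None:
        self.full[count % self.num_stages].acquire()             # block until data is ready

    def consumer_release(self, count: int) -> None:
        self.empty[count % self.num_stages].release()            # return the stage to the producer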
intra2_acc_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.internal_stages + ) + + while work_tile.is_valid_tile: + # Reset count for pipeline state + b_consumer_state.reset_count() + c_consumer_state.reset_count() + intra1_acc_producer_state.reset_count() + x_consumer_state.reset_count() + intra2_q_consumer_state.reset_count() + intra2_acc_producer_state.reset_count() + + # Peek (try_wait) B/C/X/INTRA1_ACC buffer full/full/full/empty status + peek_b_full_status = self.conditional_consumer_try_wait( + b_consumer_state, b_pipeline, C + ) + peek_c_full_status = self.conditional_consumer_try_wait( + c_consumer_state, c_pipeline, C + ) + peek_wr_intra1_acc_empty_status = self.conditional_producer_try_acquire( + intra1_acc_producer_state, intra1_acc_pipeline, C + ) + peek_x_full_status = self.conditional_consumer_try_wait( + x_consumer_state, x_pipeline, C + ) + + # Manual pipeline: unrolled INTRA_MMA1 chunk_idx = 0 loop + # Conditionally wait for B/C/INTRA1_ACC buffer full/full/empty + b_pipeline.consumer_wait(b_consumer_state, peek_b_full_status) + c_pipeline.consumer_wait(c_consumer_state, peek_c_full_status) + intra1_acc_pipeline.producer_acquire( + intra1_acc_producer_state, peek_wr_intra1_acc_empty_status + ) + + # INTRA_MMA1 + tiled_mma_intra1 = self.exec_mma( + tiled_mma_intra1, + tCtAccIntra1, + tCrC, + tCrB, + intra1_acc_producer_state, + c_consumer_state, + b_consumer_state, + ) + + # Async arrive B/C/INTRA1_ACC buffer empty/empty/full + b_pipeline.consumer_release( + b_consumer_state, pipeline.PipelineOp.TCGen05Mma + ) + c_pipeline.consumer_release(c_consumer_state) + intra1_acc_pipeline.producer_commit(intra1_acc_producer_state) + + # Advance B/C/INTRA1_ACC state + b_consumer_state.advance() + c_consumer_state.advance() + intra1_acc_producer_state.advance() + + # Peek (try_wait) B/C/INTRA1_ACC buffer full/full/empty for chunk_idx = chunk_idx + 1 + peek_b_full_status = self.conditional_consumer_try_wait( + b_consumer_state, b_pipeline, C + ) + peek_c_full_status = self.conditional_consumer_try_wait( + c_consumer_state, c_pipeline, C + ) + peek_wr_intra1_acc_empty_status = self.conditional_producer_try_acquire( + intra1_acc_producer_state, intra1_acc_pipeline, C + ) + + # Manual pipeline: batched gemm over C-1 dimension + for chunk_idx in cutlass.range(C - 1, unroll=1): + # Conditionally wait for B/C/INTRA1_ACC buffer full/full/empty + b_pipeline.consumer_wait(b_consumer_state, peek_b_full_status) + c_pipeline.consumer_wait(c_consumer_state, peek_c_full_status) + intra1_acc_pipeline.producer_acquire( + intra1_acc_producer_state, peek_wr_intra1_acc_empty_status + ) + + # INTRA_MMA1 + tiled_mma_intra1 = self.exec_mma( + tiled_mma_intra1, + tCtAccIntra1, + tCrC, + tCrB, + intra1_acc_producer_state, + c_consumer_state, + b_consumer_state, + ) + + # Async arrive B/C/INTRA1_ACC buffer empty/empty/full + b_pipeline.consumer_release( + b_consumer_state, pipeline.PipelineOp.TCGen05Mma + ) + c_pipeline.consumer_release(c_consumer_state) + intra1_acc_pipeline.producer_commit(intra1_acc_producer_state) + + # Conditionally wait for X/INTRA2_Q/INTRA2_ACC buffer full/full/empty + x_pipeline.consumer_wait(x_consumer_state, peek_x_full_status) + intra2_q_pipeline.consumer_wait(intra2_q_consumer_state) + intra2_acc_pipeline.producer_acquire(intra2_acc_producer_state) + + # INTRA_MMA2 + tiled_mma_intra2 = self.exec_mma( + tiled_mma_intra2, + tCtAccIntra2, + tCrQ, + tCrX, + intra2_acc_producer_state, + intra2_q_consumer_state, + x_consumer_state, + ) + + # Async 
arrive X/INTRA2_Q/INTRA2_ACC buffer empty/empty/full + if cutlass.const_expr(self.has_d): + x_pipeline.consumer_release( + x_consumer_state, pipeline.PipelineOp.TCGen05Mma + ) + else: + x_pipeline.consumer_release(x_consumer_state) + intra2_q_pipeline.consumer_release(intra2_q_consumer_state) + intra2_acc_pipeline.producer_commit(intra2_acc_producer_state) + + # Advance B/C/INTRA1_ACC cstate + b_consumer_state.advance() + c_consumer_state.advance() + intra1_acc_producer_state.advance() + + # Peek (try_wait) B/C/INTRA1_ACC buffer full/full/empty for chunk_idx = chunk_idx + 1 + peek_b_full_status = self.conditional_consumer_try_wait( + b_consumer_state, b_pipeline, C + ) + peek_c_full_status = self.conditional_consumer_try_wait( + c_consumer_state, c_pipeline, C + ) + peek_wr_intra1_acc_empty_status = ( + self.conditional_producer_try_acquire( + intra1_acc_producer_state, intra1_acc_pipeline, C + ) + ) + + # Advance X/INTRA2_Q/INTRA2_ACC state + x_consumer_state.advance() + intra2_q_consumer_state.advance() + intra2_acc_producer_state.advance() + + # Peek (try_wait) X buffer full for chunk_idx = chunk_idx + 1 + peek_x_full_status = self.conditional_consumer_try_wait( + x_consumer_state, x_pipeline, C + ) + # END of for chunk_idx in cutlass.range(C-1, unroll=1) + + # Manual pipeline: unrolled INTRA_MMA2 chunk_idx = C-1 loop + # Conditionally wait for X/INTRA2_Q/INTRA2_ACC buffer full/full/empty + x_pipeline.consumer_wait(x_consumer_state, peek_x_full_status) + intra2_q_pipeline.consumer_wait(intra2_q_consumer_state) + intra2_acc_pipeline.producer_acquire(intra2_acc_producer_state) + + # INTRA_MMA2 + tiled_mma_intra2 = self.exec_mma( + tiled_mma_intra2, + tCtAccIntra2, + tCrQ, + tCrX, + intra2_acc_producer_state, + intra2_q_consumer_state, + x_consumer_state, + ) + + # Async arrive X/INTRA2_Q/INTRA2_ACC buffer empty/empty/full + if cutlass.const_expr(self.has_d): + x_pipeline.consumer_release( + x_consumer_state, pipeline.PipelineOp.TCGen05Mma + ) + else: + x_pipeline.consumer_release(x_consumer_state) + intra2_q_pipeline.consumer_release(intra2_q_consumer_state) + intra2_acc_pipeline.producer_commit(intra2_acc_producer_state) + + # Advance X/INTRA2_Q/INTRA2_ACC state + x_consumer_state.advance() + intra2_q_consumer_state.advance() + intra2_acc_producer_state.advance() + + # Peek (try_wait) X buffer full for chunk_idx = chunk_idx + 1 + peek_x_full_status = self.conditional_consumer_try_wait( + x_consumer_state, x_pipeline, C + ) + + # Advance to next tile + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + # END of while work_tile.is_valid_tile + + # Producer tail for INTRA1_ACC/INTRA2_ACC + intra1_acc_pipeline.producer_tail(intra1_acc_producer_state) + intra2_acc_pipeline.producer_tail(intra2_acc_producer_state) + # END of specialized mma-intra warp + + # Specialized MMA Inter warp + elif warp_idx == self.mma_inter_warp_id: + # Dealloc regs for pre-inter/pre-intra warps + cute.arch.warpgroup_reg_dealloc(self.num_regs_uniform_warps) + + # Make shared/tmem fragments for INTER_MMA1 X/B/ACC + # (MMA, MMA_N, MMA_K, INPUT_STAGE) + # (MMA, MMA_M, MMA_K, INTERNAL_STAGE) + # (MMA, MMA_M, MMA_N, INTERNAL_STAGE) + tCrB, tCrX, tCtAccInter1 = self.mma_partition_ss( + tiled_mma_inter1, + self.tile_shape_mnk_inter1, + smem_bt_internal, + smem_x, + tmem_ptr_base + self.tmem_inter1_acc_offset, + self.internal_stages, + ) + + # Make shared/tmem fragments for INTER_MMA2 C/P/ACC + # (MMA, MMA_M, MMA_K, INPUT_STAGE) + # (MMA, MMA_N, MMA_K, INTERNAL_STAGE) + # (MMA, MMA_M, MMA_N, 
INTERNAL_STAGE) + tCrC, tCrP, tCtAccInter2 = self.mma_partition_ss( + tiled_mma_inter2, + self.tile_shape_mnk_inter2, + smem_c, + smem_p, + tmem_ptr_base + self.tmem_inter2_acc_offset, + self.internal_stages, + ) + + # Pipeline X/C/INTER1_B/INTER2_P consumer state + x_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.input_stages + ) + c_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.input_stages + ) + inter1_b_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.internal_stages + ) + inter2_p_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.internal_stages + ) + + # Pipeline INTER1_ACC/INTER2_ACC producer state + inter1_acc_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.internal_stages + ) + inter2_acc_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.internal_stages + ) + + while work_tile.is_valid_tile: + # Reset count for pipeline state + x_consumer_state.reset_count() + c_consumer_state.reset_count() + inter1_acc_producer_state.reset_count() + inter1_b_consumer_state.reset_count() + inter2_p_consumer_state.reset_count() + inter2_acc_producer_state.reset_count() + + # Peek (try_wait) C/INTER2_P/INTER2_ACC buffer full/full/empty status + peek_c_full_status = self.conditional_consumer_try_wait( + c_consumer_state, c_pipeline, C + ) + peek_inter2_p_full_status = self.conditional_consumer_try_wait( + inter2_p_consumer_state, inter2_p_pipeline, C + ) + peek_inter2_acc_empty_status = self.conditional_producer_try_acquire( + inter2_acc_producer_state, inter2_acc_pipeline, C + ) + + # Batched gemm over C dimension + for chunk_idx in cutlass.range(C, unroll=1): + # Conditionally wait for C/INTER2_P/INTER2_ACC buffer full/full/empty + c_pipeline.consumer_wait(c_consumer_state, peek_c_full_status) + inter2_p_pipeline.consumer_wait( + inter2_p_consumer_state, peek_inter2_p_full_status + ) + inter2_acc_pipeline.producer_acquire( + inter2_acc_producer_state, peek_inter2_acc_empty_status + ) + + # INTER MMA2 + tiled_mma_inter2 = self.exec_mma( + tiled_mma_inter2, + tCtAccInter2, + tCrC, + tCrP, + inter2_acc_producer_state, + c_consumer_state, + inter2_p_consumer_state, + ) + + # Async arrive C/INTER2_P/INTER2_ACC buffer empty/empty/full + c_pipeline.consumer_release(c_consumer_state) + inter2_p_pipeline.consumer_release(inter2_p_consumer_state) + inter2_acc_pipeline.producer_commit(inter2_acc_producer_state) + + # Wait for X/INTER1_B/INTER1_ACC buffer full/full/empty + x_pipeline.consumer_wait(x_consumer_state) + inter1_b_pipeline.consumer_wait(inter1_b_consumer_state) + inter1_acc_pipeline.producer_acquire(inter1_acc_producer_state) + + # INTER MMA1 + tiled_mma_inter1 = self.exec_mma( + tiled_mma_inter1, + tCtAccInter1, + tCrB, + tCrX, + inter1_acc_producer_state, + inter1_b_consumer_state, + x_consumer_state, + ) + + # Async arrive X/INTER1_B/INTER1_ACC buffer empty/empty/full + if cutlass.const_expr(self.has_d): + x_pipeline.consumer_release( + x_consumer_state, pipeline.PipelineOp.TCGen05Mma + ) + else: + x_pipeline.consumer_release(x_consumer_state) + inter1_b_pipeline.consumer_release(inter1_b_consumer_state) + inter1_acc_pipeline.producer_commit(inter1_acc_producer_state) + + # Advance X/C/INTER1_B/INTER1_ACC/INTER2_P/INTER2_ACC state + x_consumer_state.advance() + c_consumer_state.advance() + inter1_b_consumer_state.advance() + 
inter1_acc_producer_state.advance() + inter2_p_consumer_state.advance() + inter2_acc_producer_state.advance() + + # Peek (try_wait) C/INTER2_P/INTER2_ACC buffer full/full/empty for chunk_idx = chunk_idx + 1 + peek_c_full_status = self.conditional_consumer_try_wait( + c_consumer_state, c_pipeline, C + ) + peek_inter2_p_full_status = self.conditional_consumer_try_wait( + inter2_p_consumer_state, inter2_p_pipeline, C + ) + peek_inter2_acc_empty_status = ( + self.conditional_producer_try_acquire( + inter2_acc_producer_state, inter2_acc_pipeline, C + ) + ) + + # Advance to next tile + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + + # Producer tail for INTER1_ACC/INTER2_ACC + inter1_acc_pipeline.producer_tail(inter1_acc_producer_state) + inter2_acc_pipeline.producer_tail(inter2_acc_producer_state) + + # Specialized Pre-Inter warp + elif ( + warp_idx == self.pre_inter_warp_id[0] + or warp_idx == self.pre_inter_warp_id[1] + or warp_idx == self.pre_inter_warp_id[2] + or warp_idx == self.pre_inter_warp_id[3] + ): + # Alloc regs in pre_inter warps + cute.arch.warpgroup_reg_alloc(self.num_regs_pre_inter_warps) + + # Make tiledCopy and partition smem/register tensor for smem load Bt + # ((S2R_ATOM_V, S2R_REST_V), S2R_M, S2R_N, INPUT_STAGE) + # ((S2R_ATOM_V, S2R_REST_V), S2R_M, S2R_N) + tiled_s2r_b, tBsB_s2r, tBrB_s2r = self.pre_inter_smem_load_and_partition_b( + local_tidx, smem_bt + ) + + # Partition shared tensor for smem store Bt + smem_bt_internal_ = cute.make_tensor( + smem_bt_internal.iterator, smem_bt.layout + ) + # Make tiledCopy and partition register/smem tensor for smem store Bt + # ((R2S_ATOM_V, R2S_REST_V), R2S_M, R2S_N) + # ((R2S_ATOM_V, R2S_REST_V), R2S_M, R2S_N, INTERNAL_STAGE) + tiled_r2s_b, tBrB_r2s, tBsB_r2s = self.pre_inter_smem_store_and_partition_b( + local_tidx, + smem_bt_internal_, + tiled_s2r_b, + tBrB_s2r, + ) + + # (MMA, MMA_M, MMA_K, INPUT_STAGE) + sDelta = self.pre_inter_make_delta(smem_delta, smem_bt.layout) + sDeltaA = self.pre_inter_make_delta(smem_cumsum_delta, smem_bt.layout) + + # Make copy_atom and partition register/smem tensor for smem load/store of Delta/DeltaA + # ((S2R_ATOM_V, S2R_REST_V), S2R_M, S2R_N, INPUT_STAGE) + # ((S2R_ATOM_V, S2R_REST_V), S2R_M, S2R_N) + ( + s2r_atom_delta, + tBsDelta_s2r, + tBrDelta_s2r, + ) = self.smem_load_and_partition_delta_d( + tiled_s2r_b, local_tidx, sDelta, (None, None, None, 0) + ) + ( + s2r_atom_cumsum, + tBsDeltaA_s2r, + tBrDeltaA_s2r, + ) = self.smem_load_and_partition_delta_d( + tiled_s2r_b, local_tidx, sDeltaA, (None, None, None, 0) + ) + + # ((R2S_ATOM_V, R2S_REST_V), R2S_M, R2S_N) + thr_r2s_b = tiled_r2s_b.get_slice(local_tidx) + tBrDelta_r2s = thr_r2s_b.retile(tBrDelta_s2r) + tBrDeltaA_r2s = thr_r2s_b.retile(tBrDeltaA_s2r) + + # Make tmem fragment for INTER1_ACC + # (MMA, MMA_M, MMA_N, INTERNAL_STAGE) + tCtAccInter1 = self.mma_partition_c( + tiled_mma_inter1, + self.tile_shape_mnk_inter1, + tmem_ptr_base + self.tmem_inter1_acc_offset, + self.internal_stages, + ) + # (M_PER_MMA, N_PER_MMA, INTERNAL_STAGE) + tInter1 = tCtAccInter1[((None, None), 0, 0, None)] + + # Make tiledCopy and partition tmem/register tensor for tmem load INTER1_ACC + # ((T2R_ATOM_V, T2R_REST_V), T2R_M, T2R_N, INTERNAL_STAGE) + # ((T2R_ATOM_V, T2R_REST_V), T2R_M, T2R_N) + ( + tiled_t2r_inter1, + tTR_tP, + tTR_rP, + ) = self.pre_inter_tmem_load_and_partition_p(local_tidx, tInter1, smem_pt) + + # Make fragment for register to hold P after post-processing (in acc dtype) + tState = cute.make_fragment(tTR_rP.shape, 
self.acc_dtype) + + # Make tiledCopy and partition smem/register tensor for smem store INTER2_P + # ((R2S_ATOM_V, R2S_REST_V), R2S_M, R2S_N) + # ((R2S_ATOM_V, R2S_REST_V), R2S_M, R2S_N, INTERNAL_STAGE) + tiled_r2s_p, tRS_rP, tRS_sP = self.smem_store_and_partition_p_y( + local_tidx, smem_pt, tiled_t2r_inter1 + ) + + # Partition global/shared tensor for P (State) + # ((ATOM_V, REST_V), INTERNAL_STAGE) + # ((ATOM_V, REST_V), 1, 1, EH, B) + bSG_sP, bSG_gP_pre_slice = self.tma_partition_with_shape( + tma_atom_p, + tma_tensor_p, + smem_p_store, + self.tile_shape_mnk_inter2[1:], + ) + + # Pipeline B/Delta/INTER1_ACC consumer state + b_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.input_stages + ) + deltas_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.input_stages + ) + inter1_acc_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.internal_stages + ) + + # Pipeline INTER1_B/INTER2_P producer state + inter1_b_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.internal_stages + ) + inter2_p_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.internal_stages + ) + + # Pipeline TMA store P + tma_p_pipeline = pipeline.PipelineTmaStore.create( + num_stages=self.internal_stages, + producer_group=pipeline.CooperativeGroup( + pipeline.Agent.Thread, 32 * len(self.pre_inter_warp_id), 128 + ), + ) + + while work_tile.is_valid_tile: + b_idx, eh_idx, g_idx = work_tile.tile_idx + + # Slice global tensor to current tile idx + # ((ATOM_V, REST_V)) + bSG_gP = bSG_gP_pre_slice[(None, 0, 0, eh_idx, b_idx)] + + # Reset count for pipeline state + b_consumer_state.reset_count() + deltas_consumer_state.reset_count() + inter1_b_producer_state.reset_count() + inter1_acc_consumer_state.reset_count() + inter2_p_producer_state.reset_count() + + # State (P) init + tState.fill(0.0) + + # Peek (try_wait) B/Delta/INTER1_B buffer full/full/empty status + peek_b_full_status = self.conditional_consumer_try_wait( + b_consumer_state, b_pipeline, C + ) + peek_deltas_full_status = self.conditional_consumer_try_wait( + deltas_consumer_state, deltas_pipeline, C + ) + peek_wr_inter1_b_empty_status = self.conditional_producer_try_acquire( + inter1_b_producer_state, inter1_b_pipeline, C + ) + + # Prefill INTER2_P with 0 + # Wait for INTER2_P buffer empty + inter2_p_pipeline.producer_acquire(inter2_p_producer_state) + + tRS_rP.fill(0.0) + # Copy INTER2_P from register to smem + inter2_p_coord = (None, None, None, inter2_p_producer_state.index) + cute.copy(tiled_r2s_p, tRS_rP, tRS_sP[inter2_p_coord]) + + # Fence for shared memory + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + # Async arrive INTER2_P buffer full + inter2_p_pipeline.producer_commit(inter2_p_producer_state) + # Advance INTER2_P producer state + inter2_p_producer_state.advance() + + # Batched processing over C dimension + for chunk_idx in cutlass.range(C, unroll=1): + # Conditionally wait for B/Delta/B_TMEM buffer full/full/empty + b_pipeline.consumer_wait(b_consumer_state, peek_b_full_status) + deltas_pipeline.consumer_wait( + deltas_consumer_state, peek_deltas_full_status + ) + inter1_b_pipeline.producer_acquire( + inter1_b_producer_state, peek_wr_inter1_b_empty_status + ) + + # Load B/Delta/DeltaA/last_column + b_coord = (None, None, None, b_consumer_state.index) + delta_coord = (None, None, None, 
deltas_consumer_state.index) + cute.copy(tiled_s2r_b, tBsB_s2r[b_coord], tBrB_s2r) + cute.copy(s2r_atom_delta, tBsDelta_s2r[delta_coord], tBrDelta_s2r) + cute.copy( + s2r_atom_cumsum, tBsDeltaA_s2r[delta_coord], tBrDeltaA_s2r + ) + last_column = smem_cumsum_delta[ + smem_cumsum_delta.shape[0] - 1, deltas_consumer_state.index + ] + + # Fence for shared memory + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + + # Combine B/Delta/DeltaA/last_column + tScaledB = self.pre_inter_scale_bt_with_delta( + tBrB_s2r, tBrDelta_r2s, tBrDeltaA_r2s, last_column + ) + + # Store scaled B to tBrB_r2s + for reg_idx in range(cute.size(tBrB_r2s)): + tBrB_r2s[reg_idx] = tScaledB[reg_idx].to(self.io_dtype) + + # Store tBrB_r2s to bt_smem_internal + inter1_b_coord = (None, None, None, inter1_b_producer_state.index) + cute.copy(tiled_r2s_b, tBrB_r2s, tBsB_r2s[inter1_b_coord]) + + # Fence for shared memory + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + + # Async arrive B/Delta/B_TMEM buffer empty/empty/full + b_pipeline.consumer_release( + b_consumer_state, pipeline.PipelineOp.AsyncThread + ) + deltas_pipeline.consumer_release(deltas_consumer_state) + inter1_b_pipeline.producer_commit(inter1_b_producer_state) + + # Wait for INTER1_ACC/INTER2_P buffer full/empty + inter1_acc_pipeline.consumer_wait(inter1_acc_consumer_state) + inter2_p_pipeline.producer_acquire(inter2_p_producer_state) + + # Load INTER1_ACC + inter1_acc_coord = ( + None, + None, + None, + inter1_acc_consumer_state.index, + ) + cute.copy(tiled_t2r_inter1, tTR_tP[inter1_acc_coord], tTR_rP) + + # Fence for TMEM load + cute.arch.fence_view_async_tmem_load() + + # Combine INTER1_ACC/last_column/State + exp_last_column = cute.math.exp(last_column, fastmath=True) + for reg_idx in range(0, cute.size(tTR_rP), 2): + ( + tTR_rP[reg_idx], + tTR_rP[reg_idx + 1], + ) = cute.arch.fma_packed_f32x2( + (exp_last_column, exp_last_column), + (tState[reg_idx], tState[reg_idx + 1]), + (tTR_rP[reg_idx], tTR_rP[reg_idx + 1]), + ) + + # Store scaled P to tRS_rP + for reg_idx in range(cute.size(tTR_rP)): + tRS_rP[reg_idx] = tTR_rP[reg_idx].to(self.io_dtype) + + # Update old state + tState.store(tTR_rP.load()) + + # Store INTER2_P + inter2_p_coord = (None, None, None, inter2_p_producer_state.index) + cute.copy(tiled_r2s_p, tRS_rP, tRS_sP[inter2_p_coord]) + + # Fence for shared memory + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + + # Async arrive INTER1_ACC/INTER2_P buffer empty/full + inter1_acc_pipeline.consumer_release(inter1_acc_consumer_state) + # Last iteration consumer is PRE_INTER warp itself, not MMA_INTER warp + if inter2_p_producer_state.count < C: + inter2_p_pipeline.producer_commit(inter2_p_producer_state) + + # Advance B/Delta/INTER1_B/INTER1_ACC state + b_consumer_state.advance() + deltas_consumer_state.advance() + inter1_b_producer_state.advance() + inter1_acc_consumer_state.advance() + # Peek (try_wait) B/Delta/INTER1_B buffer full/full./empty for chunk_idx = chunk_idx + 1 + peek_b_full_status = self.conditional_consumer_try_wait( + b_consumer_state, b_pipeline, C + ) + peek_deltas_full_status = self.conditional_consumer_try_wait( + deltas_consumer_state, deltas_pipeline, C + ) + peek_wr_inter1_b_empty_status = ( + self.conditional_producer_try_acquire( + inter1_b_producer_state, inter1_b_pipeline, C + ) + ) + + # Last iteration producer is PRE_INTER warp itself, not MMA_INTER warp 
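# A per-(batch, head) reference sketch, in plain torch, of the inter-chunk state
# recurrence the PRE_INTER warps realize above with exp(last_column) and
# fma_packed_f32x2: state <- exp(cumsum_delta_last) * state + inter1_acc, where
# inter1_acc is the (N, D) contribution of the current chunk produced by
# INTER_MMA1 and cumsum_delta_last is the last entry of the chunk's cumulative
# delta sum. The kernel feeds the state of chunk c to INTER_MMA2 of chunk c + 1
# (chunk 0 sees a zero-initialized P) and TMA-stores the final state to fstate.
import torch

def inter_chunk_state_reference(inter1_acc: torch.Tensor, cumsum_delta_last: torch.Tensor):
    # inter1_acc:        (C, N, D) per-chunk contributions
    # cumsum_delta_last: (C,)      per-chunk decay exponents
    C, N, D = inter1_acc.shape
    state = inter1_acc.new_zeros(N, D)
    per_chunk_states = []
    for c in range(C):
        state = torch.exp(cumsum_delta_last[c]) * state + inter1_acc[c]
        per_chunk_states.append(state)
    return torch.stack(per_chunk_states), state  # states after each chunk, and the final state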
+ if inter2_p_producer_state.count < C: + # Advance INTER2_P producer state + inter2_p_producer_state.advance() + # END of for chunk_idx in cutlass.range(C, unroll=1) + + # Store last INTER2_P (State) from smem to gmem + # Wait for all previous stores to smem to be done + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + cute.arch.barrier( + barrier_id=self.pre_inter_sync_bar_id, + number_of_threads=len(self.pre_inter_warp_id) * 32, + ) + + if local_warp_idx == 0: + # TMA store P + cute.copy( + tma_atom_p, + bSG_sP[(None, inter2_p_producer_state.index)], + bSG_gP, + ) + # Wait for TMA store done + tma_p_pipeline.producer_commit() + tma_p_pipeline.producer_acquire() + + cute.arch.barrier( + barrier_id=self.pre_inter_sync_bar_id, + number_of_threads=len(self.pre_inter_warp_id) * 32, + ) + tma_p_pipeline.producer_tail() + + # Advance to next tile + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + # END of while work_tile.is_valid_tile + + # Producer tail for INTER1_B/INTER2_P/TMA store P + inter1_b_pipeline.producer_tail(inter1_b_producer_state) + inter2_p_pipeline.producer_tail(inter2_p_producer_state) + # END of specialized pre-inter warp + + # Specialized Pre-Intra warp + elif ( + warp_idx == self.pre_intra_warp_id[0] + or warp_idx == self.pre_intra_warp_id[1] + or warp_idx == self.pre_intra_warp_id[2] + or warp_idx == self.pre_intra_warp_id[3] + ): + # Alloc regs in pre_inter warps + cute.arch.warpgroup_reg_alloc(self.num_regs_pre_intra_warps) + + # Make tmem fragment for INTRA1_ACC + # (MMA, MMA_M, MMA_N, INTRA1_ACC_STAGE) + tCtAccIntra1 = self.mma_partition_c( + tiled_mma_intra1, + self.tile_shape_mnk_intra1, + tmem_ptr_base + self.tmem_intra1_acc_offset, + self.intra1_acc_stages, + ) + # (M_PER_MMA, N_PER_MMA, INTRA1_ACC_STAGE) + tIntra1 = tCtAccIntra1[((None, None), 0, 0, None)] + + # Make tiledCopy and partition tmem/register tensor for tensor memory load INTRA1_ACC + # ((T2R_ATOM_V, T2R_REST_V), T2R_M, T2R_N, INTERNAL_STAGE) + # ((T2R_ATOM_V, T2R_REST_V), T2R_M, T2R_N) + tiled_t2r_intra1, tTR_tQ, tTR_rQ = self.pre_intra_tmem_load_and_partition_q( + tIntra1, local_tidx + ) + + # Broadcast delta/delta_cumsum smem tensor from LxINPUT_STAGE to LxLxINPUT_STAGE + sDeltaA_Row = self.pre_intra_make_delta(smem_cumsum_delta, 0) + sDeltaA_Col = self.pre_intra_make_delta(smem_cumsum_delta, 1) + sDelta = self.pre_intra_make_delta(smem_delta, 0) + + # Make tiledCopy and partition smem/register tensor for smem memory load delta/delta_cumsum + # ((T2R_ATOM_V, T2R_REST_V), T2R_M, T2R_N, INPUT_STAGE) + # ((T2R_ATOM_V, T2R_REST_V), T2R_M, T2R_N) + ( + s2r_atom_cumsum, + tQsDeltaA_Row, + tQrDeltaA_Row, + ) = self.smem_load_and_partition_delta_d( + tiled_t2r_intra1, local_tidx, sDeltaA_Row, (None, None, None, 0) + ) + ( + s2r_atom_cumsum, + tQsDeltaA_Col, + tQrDeltaA_Col, + ) = self.smem_load_and_partition_delta_d( + tiled_t2r_intra1, local_tidx, sDeltaA_Col, (None, None, None, 0) + ) + ( + s2r_atom_delta, + tQsDelta, + tQrDelta, + ) = self.smem_load_and_partition_delta_d( + tiled_t2r_intra1, local_tidx, sDelta, (None, None, None, 0) + ) + + # Make and partition coord tensor for delta_cumsum load + # (L, L) + coord_tensor = cute.make_identity_tensor( + cute.dice(self.tile_shape_mnk_intra1, (1, 1, None)) + ) + thr_t2r_intra1 = tiled_t2r_intra1.get_slice(local_tidx) + # ((T2R_ATOM_V, T2R_REST_V), T2R_M, T2R_N) + tCoord = thr_t2r_intra1.partition_D(coord_tensor) + + # Make tmem tensor for INTRA2_Q + # (MMA, MMA_M, 
MMA_K, INTERNAL_STAGE) + tCrQ = self.mma_partition_a_tmem( + tiled_mma_intra2, + q_tmem_layout, + tmem_ptr_base + self.tmem_intra2_q_offset, + ) + + # Make tiledCopy and partition tmem/register tensor for tensor memory store INTRA2_Q + # ((T2R_ATOM_V, T2R_REST_V), T2R_M, T2R_N, ...) + # ((T2R_ATOM_V, T2R_REST_V), T2R_M, T2R_N, ..., INTERNAL_STAGE) + tiled_r2t_q, tRT_rQ, tRT_tQ = self.pre_intra_tmem_store_and_partition_q( + local_tidx, tCrQ + ) + + # Pipeline DELTA/INTRA1_ACC consumer state + deltas_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.input_stages + ) + intra1_acc_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.intra1_acc_stages + ) + # Pipeline INTRA2_Q producer state + intra2_q_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.internal_stages + ) + + while work_tile.is_valid_tile: + # Reset count for pipeline state + deltas_consumer_state.reset_count() + intra1_acc_consumer_state.reset_count() + intra2_q_producer_state.reset_count() + + # Peek (try_wait) DELTA/INTRA1_ACC buffer full + peek_deltas_full_status = self.conditional_consumer_try_wait( + deltas_consumer_state, deltas_pipeline, C + ) + peek_rd_intra1_acc_full_status = self.conditional_consumer_try_wait( + intra1_acc_consumer_state, intra1_acc_pipeline, C + ) + + # Batched processing over C dimension + for chunk_idx in cutlass.range(C, unroll=1): + # Conditionally wait for Delta/INTRA1_ACC buffer full + deltas_pipeline.consumer_wait( + deltas_consumer_state, peek_deltas_full_status + ) + intra1_acc_pipeline.consumer_wait( + intra1_acc_consumer_state, peek_rd_intra1_acc_full_status + ) + + # Load Q from tmem + intra1_coord = (None, None, None, intra1_acc_consumer_state.index) + cute.copy(tiled_t2r_intra1, tTR_tQ[intra1_coord], tTR_rQ) + cute.arch.fence_view_async_tmem_load() + + # Load tQsDeltaA_Row/tQsDeltaA_Col/tQsDelta from smem + delta_coord = (None, None, None, deltas_consumer_state.index) + cute.copy( + s2r_atom_cumsum, tQsDeltaA_Row[delta_coord], tQrDeltaA_Row + ) + cute.copy( + s2r_atom_cumsum, tQsDeltaA_Col[delta_coord], tQrDeltaA_Col + ) + cute.copy(s2r_atom_delta, tQsDelta[delta_coord], tQrDelta) + + # SegSum + tRT_rQ = self.pre_intra_segsum( + tTR_rQ, tQrDeltaA_Row, tQrDeltaA_Col, tQrDelta, tCoord, tRT_rQ + ) + + # Wait for INTRA2_Q buffer empty + # Delay producer_acquire to right before data store + intra2_q_pipeline.producer_acquire(intra2_q_producer_state) + + # Store Q from reg to tmem + q_coord = (None, None, None, None, intra2_q_producer_state.index) + cute.copy(tiled_r2t_q, tRT_rQ, tRT_tQ[q_coord]) + + # Async arrive Delta/INTRA1_ACC buffer empty + intra1_acc_pipeline.consumer_release(intra1_acc_consumer_state) + deltas_pipeline.consumer_release(deltas_consumer_state) + + cute.arch.fence_view_async_tmem_store() + + # Async arrive INTRA2_Q buffer full + intra2_q_pipeline.producer_commit(intra2_q_producer_state) + + # Advance deltas/intra1_acc/intra2_q states + deltas_consumer_state.advance() + intra1_acc_consumer_state.advance() + intra2_q_producer_state.advance() + + # Peek (try_wait) Delta/INTRA1_ACC buffer full for chunk_idx = chunk_idx + 1 + peek_deltas_full_status = self.conditional_consumer_try_wait( + deltas_consumer_state, deltas_pipeline, C + ) + peek_rd_intra1_acc_full_status = self.conditional_consumer_try_wait( + intra1_acc_consumer_state, intra1_acc_pipeline, C + ) + # END of for chunk_idx in cutlass.range(C, unroll=1) + + # Advance to next tile + 
tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + # END of while work_tile.is_valid_tile + + # Producer tail for INTRA2_Q + intra2_q_pipeline.producer_tail(intra2_q_producer_state) + # END of specialized pre-intra warp + + # Specialized Epilogue warp + else: + # Dealloc regs for pre-inter/pre-intra warps + cute.arch.warpgroup_reg_dealloc(self.num_regs_epilogue_warps) + + # (L, D, INPUT_STAGE) + sDeltaA = self.epilog_make_delta(smem_cumsum_delta) + + # Make tmem tensor for INTRA2_ACC/INTER2_ACC + # (MMA, MMA_M, MMA_K, INTERNAL_STAGE) + tCtAccIntra2 = self.mma_partition_c( + tiled_mma_intra2, + self.tile_shape_mnk_intra2, + tmem_ptr_base + self.tmem_intra2_acc_offset, + self.internal_stages, + ) + # (M_PER_MMA, N_PER_MMA, INTERNAL_STAGE) + tIntra2 = tCtAccIntra2[((None, None), 0, 0, None)] + # (MMA, MMA_M, MMA_K, INTERNAL_STAGE) + tCtAccInter2 = self.mma_partition_c( + tiled_mma_inter2, + self.tile_shape_mnk_inter2, + tmem_ptr_base + self.tmem_inter2_acc_offset, + self.internal_stages, + ) + # (M_PER_MMA, N_PER_MMA, INTERNAL_STAGE) + tInter2 = tCtAccInter2[((None, None), 0, 0, None)] + + # Subtiling INTRA2_ACC/INTER2_ACC/Delta/Y + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, INTERNAL_STAGE) + tIntra_epi = cute.flat_divide(tIntra2, epi_tile) + tInter_epi = cute.flat_divide(tInter2, epi_tile) + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, INPUT_STAGE) + sDeltaA_epi = cute.flat_divide(sDeltaA, epi_tile) + + # Make tiled copy and partition tmem/reg tensor w.r.t tensor memory load + # ((T2R_ATOM_V, T2R_REST_V), REST_M, REST_N, EPI_M, EPI_N, INTERNAL_STAGE) + # ((T2R_ATOM_V, T2R_REST_V), REST_M, REST_N) + ( + tiled_t2r_intra2, + tTR_tIntra, + tTR_rIntra, + ) = self.epilog_tmem_load_and_partition_acc(local_tidx, tIntra_epi, smem_y) + ( + tiled_t2r_inter2, + tTR_tInter2, + tTR_rInter, + ) = self.epilog_tmem_load_and_partition_acc(local_tidx, tInter_epi, smem_y) + + # Make tiled copy and partition smem/reg tensor w.r.t smem load Delta + # ((T2R_ATOM_V, T2R_REST_V), T2R_M, T2R_N, EPI_M, EPI_N, INPUT_STAGE) + # ((T2R_ATOM_V, T2R_REST_V), T2R_M, T2R_N) + ( + s2r_atom_delta, + tTR_sDeltaA, + tTR_rDeltaA, + ) = self.smem_load_and_partition_delta_d( + tiled_t2r_inter2, local_tidx, sDeltaA_epi, (None, None, None, 0, 0, 0) + ) + + # Make tiled copy and Partition smem/register tensor w.r.t smem store Y + # ((R2S_ATOM_V, R2S_REST_V), REST_M, REST_N, OUTPUT_STAGE) + # ((R2S_ATOM_V, R2S_REST_V), REST_M, REST_N) + tiled_r2s_y, tRS_rY, tRS_sY = self.smem_store_and_partition_p_y( + local_tidx, smem_y, tiled_t2r_inter2 + ) + + tRS_rCompute = cute.make_fragment(tRS_rY.shape, self.acc_dtype) + + tiled_s2r_x = None + tSR_sX = None + tSR_rX = None + if cutlass.const_expr(self.has_d): + # Make TiledCopy/smem/register tensor for smem load X + # (R2S_ATOM, R2S_M, R2S_N, EPI_M, EPI_N, INPUT_STAGES) + # (R2S_ATOM, R2S_M, R2S_N) + tiled_s2r_x, tSR_sX, tSR_rX = self.epilog_smem_load_and_partition_x( + tiled_t2r_inter2, local_tidx, smem_xt, epi_tile + ) + + tRS_sD = None + tRS_rD = None + s2r_atom_d = None + if cutlass.const_expr(self.d_has_hdim): + # (L, D, INPUT_STAGE) + sD = self.epilog_make_d(smem_d) + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, INPUT_STAGE) + tD_sepi = cute.flat_divide(sD, epi_tile) + + # Make tiled copy and partition smem/reg tensor w.r.t smem load D + # ((T2R_ATOM_V, T2R_REST_V), REST_M, REST_N, EPI_M, EPI_N, INPUT_STAGE) + # ((T2R_ATOM_V, T2R_REST_V), REST_M, REST_N) + s2r_atom_d, tRS_sD, tRS_rD = self.smem_load_and_partition_delta_d( + tiled_t2r_inter2, local_tidx, tD_sepi, (None, 
None, None, 0, 0, 0) + ) + + elif cutlass.const_expr(self.has_d): + tRS_rD = cutlass.Float32(0.0).to(self.io_dtype) + + # Partition global/shared tensor for TMA store Y + # ((ATOM_V, REST_V), INPUT_STAGE) + # ((ATOM_V, REST_V), EPI_M, EPI_N, 1, 1, C, EH, B) + bSG_sY, bSG_gY_pre_slice = self.epilog_tma_partition_y( + tma_tensor_y, tma_atom_y, smem_y, epi_tile + ) + + # Make TMA store pipeline Y + tma_y_pipeline = pipeline.PipelineTmaStore.create( + num_stages=self.output_stages, + producer_group=pipeline.CooperativeGroup( + pipeline.Agent.Thread, 32 * len(self.epilog_warp_id), 128 + ), + ) + + # Make consumer pipeline states for Delta/INTRA2_ACC/INTER2_ACC/X/D buffer + deltas_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.input_stages + ) + intra2_acc_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.internal_stages + ) + inter2_acc_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.internal_stages + ) + x_consumer_state = None + if cutlass.const_expr(self.has_d): + x_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.input_stages + ) + d_consumer_state = None + if cutlass.const_expr(self.d_has_hdim): + d_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.input_stages + ) + + while work_tile.is_valid_tile: + b_idx, eh_idx, g_idx = work_tile.tile_idx + + # Slice global tensor to current tile idx + # ((ATOM_V, REST_V), EPI_M, EPI_N, C) + bSG_gY = bSG_gY_pre_slice[(None, None, None, 0, 0, None, eh_idx, b_idx)] + if cutlass.const_expr(self.has_d and not self.d_has_hdim): + tRS_rD = tma_tensor_d[0, eh_idx] + + # Reset count for pipeline state + deltas_consumer_state.reset_count() + intra2_acc_consumer_state.reset_count() + inter2_acc_consumer_state.reset_count() + if cutlass.const_expr(self.has_d): + x_consumer_state.reset_count() + if cutlass.const_expr(self.d_has_hdim): + d_consumer_state.reset_count() + + # Peek Delta/INTRA2_ACC/INTER2_ACC buffer status + peek_deltas_full_status = self.conditional_consumer_try_wait( + deltas_consumer_state, deltas_pipeline, C + ) + peek_rd_intra2_acc_full_status = self.conditional_consumer_try_wait( + intra2_acc_consumer_state, intra2_acc_pipeline, C + ) + peek_rd_inter2_acc_full_status = self.conditional_consumer_try_wait( + inter2_acc_consumer_state, inter2_acc_pipeline, C + ) + peek_rd_x_full_status = None + if cutlass.const_expr(self.has_d): + peek_rd_x_full_status = self.conditional_consumer_try_wait( + x_consumer_state, x_pipeline, C + ) + + if cutlass.const_expr(self.d_has_hdim): + d_pipeline.consumer_wait(d_consumer_state) + + # Batched processing over C dimension + for chunk_idx in cutlass.range(C, unroll=1): + # Conditionally wait for Delta/INTRA2_ACC/INTER2_ACC/X buffer full + deltas_pipeline.consumer_wait( + deltas_consumer_state, peek_deltas_full_status + ) + intra2_acc_pipeline.consumer_wait( + intra2_acc_consumer_state, peek_rd_intra2_acc_full_status + ) + inter2_acc_pipeline.consumer_wait( + inter2_acc_consumer_state, peek_rd_inter2_acc_full_status + ) + if cutlass.const_expr(self.has_d): + x_pipeline.consumer_wait( + x_consumer_state, peek_rd_x_full_status + ) + # Loop over EPI_M and EPI_N subtiles + for epi_n in range(cute.size(tTR_tIntra, mode=[4])): + for epi_m in range(cute.size(tTR_tIntra, mode=[3])): + epi_iter_cnt = ( + epi_n * cute.size(tTR_tIntra, mode=[3]) + epi_m + ) + epi_buffer_idx = epi_iter_cnt % self.output_stages + + # Load 
INTRA2_ACC/INTER2_ACC from tmem + subtile_coord = ( + None, + None, + None, + epi_m, + epi_n, + ) + intra2_coord = subtile_coord + ( + intra2_acc_consumer_state.index, + ) + cute.copy( + tiled_t2r_intra2, + tTR_tIntra[intra2_coord], + tTR_rIntra, + ) + inter2_coord = subtile_coord + ( + inter2_acc_consumer_state.index, + ) + cute.copy( + tiled_t2r_inter2, + tTR_tInter2[inter2_coord], + tTR_rInter, + ) + # Fence for T2R load + cute.arch.fence_view_async_tmem_load() + + # Load Delta from smem + delta_coord = subtile_coord + (deltas_consumer_state.index,) + cute.copy( + s2r_atom_delta, tTR_sDeltaA[delta_coord], tTR_rDeltaA + ) + + # Load X from smem + if cutlass.const_expr(self.has_d): + x_coord = subtile_coord + (x_consumer_state.index,) + cute.copy(tiled_s2r_x, tSR_sX[x_coord], tSR_rX) + + # Load D from smem + if cutlass.const_expr(self.d_has_hdim): + # Load vector D from smem (d_has_hdim = True) + d_coord = subtile_coord + (d_consumer_state.index,) + cute.copy(s2r_atom_d, tRS_sD[d_coord], tRS_rD) + + # Combine INTRA2_ACC/INTER2_ACC/Delta/X/D + for reg_idx in range(0, cute.size(tRS_rCompute), 2): + ( + tRS_rCompute[reg_idx], + tRS_rCompute[reg_idx + 1], + ) = cute.arch.fma_packed_f32x2( + (tTR_rInter[reg_idx], tTR_rInter[reg_idx + 1]), + ( + cute.math.exp( + tTR_rDeltaA[reg_idx], fastmath=True + ), + cute.math.exp( + tTR_rDeltaA[reg_idx + 1], fastmath=True + ), + ), + (tTR_rIntra[reg_idx], tTR_rIntra[reg_idx + 1]), + ) + # Fuse Y += X * D + if cutlass.const_expr(self.d_has_hdim): + ( + tRS_rCompute[reg_idx], + tRS_rCompute[reg_idx + 1], + ) = cute.arch.fma_packed_f32x2( + ( + tRS_rD[reg_idx].to(self.acc_dtype), + tRS_rD[reg_idx + 1].to(self.acc_dtype), + ), + ( + tSR_rX[reg_idx].to(self.acc_dtype), + tSR_rX[reg_idx + 1].to(self.acc_dtype), + ), + ( + tRS_rCompute[reg_idx], + tRS_rCompute[reg_idx + 1], + ), + ) + elif cutlass.const_expr(self.has_d): + ( + tRS_rCompute[reg_idx], + tRS_rCompute[reg_idx + 1], + ) = cute.arch.fma_packed_f32x2( + ( + tRS_rD.to(self.acc_dtype), + tRS_rD.to(self.acc_dtype), + ), + ( + tSR_rX[reg_idx].to(self.acc_dtype), + tSR_rX[reg_idx + 1].to(self.acc_dtype), + ), + ( + tRS_rCompute[reg_idx], + tRS_rCompute[reg_idx + 1], + ), + ) + + tRS_rY.store(tRS_rCompute.load().to(self.io_dtype)) + + # Store Y to smem + cute.copy( + tiled_r2s_y, + tRS_rY, + tRS_sY[None, None, None, epi_buffer_idx], + ) + + # Fence for R2S store + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + # Sync before TMA store + cute.arch.barrier( + barrier_id=self.epilog_sync_bar_id, + number_of_threads=len(self.epilog_warp_id) * 32, + ) + + # Async arrive Delta/INTRA2_ACC/INTER2_ACC buffer empty + if ( + epi_iter_cnt + == cute.size(tTR_tIntra, mode=[4]) + * cute.size(tTR_tIntra, mode=[3]) + - 1 + ): + deltas_pipeline.consumer_release(deltas_consumer_state) + intra2_acc_pipeline.consumer_release( + intra2_acc_consumer_state + ) + inter2_acc_pipeline.consumer_release( + inter2_acc_consumer_state + ) + if cutlass.const_expr(self.has_d): + x_pipeline.consumer_release( + x_consumer_state, + pipeline.PipelineOp.AsyncThread, + ) + + # TMA store Y to global memory + if local_warp_idx == 0: + cute.copy( + tma_atom_y, + bSG_sY[None, epi_buffer_idx], + bSG_gY[None, epi_m, epi_n, chunk_idx], + ) + + # Commit TMA store + tma_y_pipeline.producer_commit() + # Wait for TMA store + tma_y_pipeline.producer_acquire() + # Sync before smem store + cute.arch.barrier( + barrier_id=self.epilog_sync_bar_id, + number_of_threads=len(self.epilog_warp_id) * 32, + ) + + 
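+                    # All EPI_M/EPI_N subtiles of this chunk have been combined and
+                    # TMA-stored; the consumer releases were already signalled on the
+                    # last subtile, so only the per-chunk pipeline states remain to be
+                    # advanced before peeking at the next chunk's buffers.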
# Advance deltas/intra2_acc/inter2_acc consumer states + deltas_consumer_state.advance() + intra2_acc_consumer_state.advance() + inter2_acc_consumer_state.advance() + + # Peek (try_wait) Delta/INTRA2_ACC/INTER2_ACC buffer full for chunk_idx = chunk_idx + 1 + peek_deltas_full_status = self.conditional_consumer_try_wait( + deltas_consumer_state, deltas_pipeline, C + ) + peek_rd_intra2_acc_full_status = self.conditional_consumer_try_wait( + intra2_acc_consumer_state, intra2_acc_pipeline, C + ) + peek_rd_inter2_acc_full_status = self.conditional_consumer_try_wait( + inter2_acc_consumer_state, inter2_acc_pipeline, C + ) + + if cutlass.const_expr(self.has_d): + # Advance x consumer states + x_consumer_state.advance() + # Peek (try_wait) X buffer full for chunk_idx = chunk_idx + 1 + peek_rd_x_full_status = self.conditional_consumer_try_wait( + x_consumer_state, x_pipeline, C + ) + + if cutlass.const_expr(self.d_has_hdim): + d_pipeline.consumer_release(d_consumer_state) + d_consumer_state.advance() + + # Advance to next tile + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + + # Producer tail for TMA store Y + tma_y_pipeline.producer_tail() + + # Dealloc tmem buffer + if warp_idx == self.epilog_warp_id[0]: + cute.arch.barrier( + barrier_id=self.tmem_dealloc_sync_bar_id, + number_of_threads=self.threads_per_cta, + ) + cute.arch.dealloc_tmem( + tmem_ptr_base, + self.num_tmem_cols_total, + is_two_cta=self.use_2cta_instrs, + ) + else: + cute.arch.barrier_arrive( + barrier_id=self.tmem_dealloc_sync_bar_id, + number_of_threads=self.threads_per_cta, + ) + + return + + @staticmethod + def _compute_stages(smem_capacity): + return 2, 2, 1, 2 # input, output, internal, intra1_acc + + @staticmethod + def _compute_grid(y, b, max_active_clusters): + B = cute.size(y, mode=[4]) + EH = cute.size(y, mode=[3]) + G = cute.size(b, mode=[3]) + NGROUP_RATIO = EH // G + num_blocks = B * EH + + tile_sched_params = Mamba2SSDTileSchedulerParams(num_blocks, EH, NGROUP_RATIO) + grid = Mamba2SSDTileScheduler.get_grid_shape( + tile_sched_params, max_active_clusters + ) + return tile_sched_params, grid + + @staticmethod + def _plan_tmem_offsets( + tiled_mma_intra1, + tile_shape_mnk_intra1, + tiled_mma_intra2, + tile_shape_mnk_intra2, + tiled_mma_inter1, + tile_shape_mnk_inter1, + tiled_mma_inter2, + tile_shape_mnk_inter2, + acc_stages, + intra2_a_tmem_layout, + a_dtype, + internal_stages, + intra1_acc_stages, + ): + SM100_TMEM_CAPACITY_COLUMNS = 512 + BITS_PER_TMEM_COL = 32 + # (MMA, MMA_M, MMA_N) + acc_shape_intra1 = tiled_mma_intra1.partition_shape_C(tile_shape_mnk_intra1[:2]) + # (MMA, MMA_M, MMA_N) + tCtAccIntra1_fake = tiled_mma_intra1.make_fragment_C( + cute.append(acc_shape_intra1, intra1_acc_stages) + ) + num_intra1_acc_cols = tcgen05.find_tmem_tensor_col_offset(tCtAccIntra1_fake) + assert tile_shape_mnk_intra1[1] * intra1_acc_stages == num_intra1_acc_cols + # (MMA, MMA_N, MMA_K, STAGE) + tCrQ_fake = tiled_mma_intra2.make_fragment_A(intra2_a_tmem_layout.outer.shape) + num_intra2_a_cols = tcgen05.find_tmem_tensor_col_offset(tCrQ_fake) + assert ( + tile_shape_mnk_intra2[2] + * internal_stages + * a_dtype.width + // BITS_PER_TMEM_COL + == num_intra2_a_cols + ) + # (MMA, MMA_M, MMA_N) + acc_shape_intra2 = tiled_mma_intra2.partition_shape_C(tile_shape_mnk_intra2[:2]) + # (MMA, MMA_M, MMA_N) + tCtAccIntra2_fake = tiled_mma_intra2.make_fragment_C( + cute.append(acc_shape_intra2, acc_stages) + ) + num_intra2_acc_cols = tcgen05.find_tmem_tensor_col_offset(tCtAccIntra2_fake) + assert 
tile_shape_mnk_intra2[1] * acc_stages == num_intra2_acc_cols + + # (MMA, MMA_M, MMA_N) + acc_shape_inter1 = tiled_mma_inter1.partition_shape_C(tile_shape_mnk_inter1[:2]) + # (MMA, MMA_M, MMA_N) + tCtAccInter1_fake = tiled_mma_inter1.make_fragment_C( + cute.append(acc_shape_inter1, acc_stages) + ) + num_inter1_acc_cols = tcgen05.find_tmem_tensor_col_offset(tCtAccInter1_fake) + assert tile_shape_mnk_inter1[1] * acc_stages == num_inter1_acc_cols + + # (MMA, MMA_M, MMA_N) + acc_shape_inter2 = tiled_mma_inter2.partition_shape_C(tile_shape_mnk_inter2[:2]) + # (MMA, MMA_M, MMA_N) + tCtAccInter2_fake = tiled_mma_inter2.make_fragment_C( + cute.append(acc_shape_inter2, acc_stages) + ) + num_inter2_acc_cols = tcgen05.find_tmem_tensor_col_offset(tCtAccInter2_fake) + assert tile_shape_mnk_inter2[1] * acc_stages == num_inter2_acc_cols + + tmem_intra1_acc_offset = 0 + tmem_intra2_q_offset = tmem_intra1_acc_offset + num_intra1_acc_cols + tmem_intra2_acc_offset = tmem_intra2_q_offset + num_intra2_a_cols + tmem_inter1_acc_offset = tmem_intra2_acc_offset + num_intra2_acc_cols + tmem_inter2_acc_offset = tmem_inter1_acc_offset + num_inter1_acc_cols + num_tmem_cols_total_tmp = tmem_inter2_acc_offset + num_inter2_acc_cols + # Turn num_tmem_cols_total to the nearest power of 2 + num_tmem_cols_total = 1 + while num_tmem_cols_total < num_tmem_cols_total_tmp: + num_tmem_cols_total *= 2 + assert num_tmem_cols_total <= SM100_TMEM_CAPACITY_COLUMNS + + return ( + tmem_intra1_acc_offset, + tmem_intra2_q_offset, + tmem_intra2_acc_offset, + tmem_inter1_acc_offset, + tmem_inter2_acc_offset, + num_tmem_cols_total, + ) + + @staticmethod + def make_tiled_mmas( + io_dtype, + acc_dtype, + cta_group, + tile_shape_mnk_intra1, + tile_shape_mnk_intra2, + tile_shape_mnk_inter1, + tile_shape_mnk_inter2, + ): + tiled_mma_intra1 = sm100_utils.make_trivial_tiled_mma( + io_dtype, + tcgen05.OperandMajorMode("mn"), + tcgen05.OperandMajorMode("mn"), + acc_dtype, + cta_group, + tile_shape_mnk_intra1[:2], + tcgen05.OperandSource.SMEM, + ) + tiled_mma_intra2 = sm100_utils.make_trivial_tiled_mma( + io_dtype, + tcgen05.OperandMajorMode("k"), + tcgen05.OperandMajorMode("k"), + acc_dtype, + cta_group, + tile_shape_mnk_intra2[:2], + tcgen05.OperandSource.TMEM, + ) + tiled_mma_inter1 = sm100_utils.make_trivial_tiled_mma( + io_dtype, + tcgen05.OperandMajorMode("k"), + tcgen05.OperandMajorMode("k"), + acc_dtype, + cta_group, + tile_shape_mnk_inter1[:2], + tcgen05.OperandSource.SMEM, + ) + tiled_mma_inter2 = sm100_utils.make_trivial_tiled_mma( + io_dtype, + tcgen05.OperandMajorMode("mn"), + tcgen05.OperandMajorMode("k"), + acc_dtype, + cta_group, + tile_shape_mnk_inter2[:2], + tcgen05.OperandSource.SMEM, + ) + return tiled_mma_intra1, tiled_mma_intra2, tiled_mma_inter1, tiled_mma_inter2 + + def make_and_init_x_pipeline(self, x_full_mbar_ptr): + x_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, len([self.tma_deltas_x_d_warp_id]) + ) + if not self.has_d: + x_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + len([self.mma_intra_warp_id, self.mma_inter_warp_id]), + ) + return pipeline.PipelineTmaUmma.create( + num_stages=self.input_stages, + producer_group=x_producer_group, + consumer_group=x_consumer_group, + tx_count=self.num_x_load_bytes, + barrier_storage=x_full_mbar_ptr, + ) + else: + x_consumer_group_umma = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + len([self.mma_intra_warp_id, self.mma_inter_warp_id]), + ) + x_consumer_group_async = pipeline.CooperativeGroup( + pipeline.Agent.Thread, 32 * 
len(self.epilog_warp_id), 128 + ) + return pipeline.PipelineTmaMultiConsumersAsync.create( + num_stages=self.input_stages, + producer_group=x_producer_group, + consumer_group_umma=x_consumer_group_umma, + consumer_group_async=x_consumer_group_async, + tx_count=self.num_x_load_bytes, + barrier_storage=x_full_mbar_ptr, + ) + + def make_and_init_b_pipeline(self, b_full_mbar_ptr): + b_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, len([self.tma_b_c_warp_id]) + ) + b_consumer_group_umma = pipeline.CooperativeGroup( + pipeline.Agent.Thread, len([self.mma_intra_warp_id]) + ) + b_consumer_group_async = pipeline.CooperativeGroup( + pipeline.Agent.Thread, 32 * len(self.pre_inter_warp_id), 128 + ) + return pipeline.PipelineTmaMultiConsumersAsync.create( + num_stages=self.input_stages, + producer_group=b_producer_group, + consumer_group_umma=b_consumer_group_umma, + consumer_group_async=b_consumer_group_async, + tx_count=self.num_b_load_bytes, + barrier_storage=b_full_mbar_ptr, + ) + + def make_and_init_c_pipeline(self, c_full_mbar_ptr): + c_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, len([self.tma_b_c_warp_id]) + ) + c_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, len([self.mma_intra_warp_id, self.mma_inter_warp_id]) + ) + return pipeline.PipelineTmaUmma.create( + num_stages=self.input_stages, + producer_group=c_producer_group, + consumer_group=c_consumer_group, + tx_count=self.num_c_load_bytes, + barrier_storage=c_full_mbar_ptr, + ) + + def make_and_init_deltas_pipeline(self, deltas_full_mbar_ptr): + deltas_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, len([self.tma_deltas_x_d_warp_id]) + ) + deltas_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + len( + [*self.pre_inter_warp_id, *self.pre_intra_warp_id, *self.epilog_warp_id] + ), + len( + [*self.pre_inter_warp_id, *self.pre_intra_warp_id, *self.epilog_warp_id] + ), + ) + + return pipeline.PipelineTmaAsync.create( + num_stages=self.input_stages, + producer_group=deltas_producer_group, + consumer_group=deltas_consumer_group, + tx_count=self.num_delta_load_bytes + self.num_cumsum_delta_load_bytes, + barrier_storage=deltas_full_mbar_ptr, + ) + + def make_and_init_d_pipeline(self, d_full_mbar_ptr): + if not self.d_has_hdim: + return None + else: + d_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, len([self.tma_deltas_x_d_warp_id]) + ) + d_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + len(self.epilog_warp_id), + len(self.epilog_warp_id), + ) + + return pipeline.PipelineTmaAsync.create( + num_stages=self.input_stages, + producer_group=d_producer_group, + consumer_group=d_consumer_group, + tx_count=self.num_d_load_bytes, + barrier_storage=d_full_mbar_ptr, + ) + + def make_and_init_intra1_acc_pipeline(self, intra1_acc_full_mbar_ptr): + intra1_acc_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, len([self.mma_intra_warp_id]) + ) + intra1_acc_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, 32 * len(self.pre_intra_warp_id), 128 + ) + return pipeline.PipelineUmmaAsync.create( + num_stages=self.intra1_acc_stages, + producer_group=intra1_acc_producer_group, + consumer_group=intra1_acc_consumer_group, + barrier_storage=intra1_acc_full_mbar_ptr, + ) + + def make_and_init_intra2_q_pipeline(self, intra2_q_full_mbar_ptr): + intra2_q_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, 32 * len(self.pre_intra_warp_id), 128 + ) + intra2_q_consumer_group = 
pipeline.CooperativeGroup( + pipeline.Agent.Thread, len([self.mma_intra_warp_id]) + ) + return pipeline.PipelineAsyncUmma.create( + num_stages=self.internal_stages, + producer_group=intra2_q_producer_group, + consumer_group=intra2_q_consumer_group, + barrier_storage=intra2_q_full_mbar_ptr, + ) + + def make_and_init_intra2_acc_pipeline(self, intra2_acc_full_mbar_ptr): + intra2_acc_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, len([self.mma_intra_warp_id]) + ) + intra2_acc_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, 32 * len(self.epilog_warp_id), 128 + ) + return pipeline.PipelineUmmaAsync.create( + num_stages=self.internal_stages, + producer_group=intra2_acc_producer_group, + consumer_group=intra2_acc_consumer_group, + barrier_storage=intra2_acc_full_mbar_ptr, + ) + + def make_and_init_inter1_b_pipeline(self, inter1_b_full_mbar_ptr): + inter1_b_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, 32 * len(self.pre_inter_warp_id), 128 + ) + inter1_b_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, len([self.mma_inter_warp_id]) + ) + return pipeline.PipelineAsyncUmma.create( + num_stages=self.internal_stages, + producer_group=inter1_b_producer_group, + consumer_group=inter1_b_consumer_group, + barrier_storage=inter1_b_full_mbar_ptr, + ) + + def make_and_init_inter1_acc_pipeline(self, inter1_acc_full_mbar_ptr): + inter1_acc_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, len([self.mma_inter_warp_id]) + ) + inter1_acc_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, 32 * len(self.pre_inter_warp_id), 128 + ) + return pipeline.PipelineUmmaAsync.create( + num_stages=self.internal_stages, + producer_group=inter1_acc_producer_group, + consumer_group=inter1_acc_consumer_group, + barrier_storage=inter1_acc_full_mbar_ptr, + ) + + def make_and_init_inter2_p_pipeline(self, inter2_p_full_mbar_ptr): + inter2_p_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, 32 * len(self.pre_inter_warp_id), 128 + ) + inter2_p_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, len([self.mma_inter_warp_id]) + ) + return pipeline.PipelineAsyncUmma.create( + num_stages=self.internal_stages, + producer_group=inter2_p_producer_group, + consumer_group=inter2_p_consumer_group, + barrier_storage=inter2_p_full_mbar_ptr, + ) + + def make_and_init_inter2_acc_pipeline(self, inter2_acc_full_mbar_ptr): + inter2_acc_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, len([self.mma_inter_warp_id]) + ) + inter2_acc_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, 32 * len(self.epilog_warp_id), 128 + ) + return pipeline.PipelineUmmaAsync.create( + num_stages=self.internal_stages, + producer_group=inter2_acc_producer_group, + consumer_group=inter2_acc_consumer_group, + barrier_storage=inter2_acc_full_mbar_ptr, + ) + + def tma_partition_for_mma_b_operand( + self, + tma_atom_x, + tma_tensor_x, + smem_x, + tiled_mma_intra2, + cluster_layout_vmnk, + mma_tile_coord_v, + block_in_cluster_coord_vmnk, + ): + # Local_tile partition global tensors + # (D, L, 1, 1, C, EH, B) + gX = cute.local_tile( + tma_tensor_x, + self.tile_shape_mnk_intra2[1:], + (None, None, None, None, None), + ) + # Partition global tensor with regard to TiledMMA + thr_mma_intra2 = tiled_mma_intra2.get_slice(mma_tile_coord_v) + # (MMA, MMA_N, MMA_K, 1, 1, C, EH, B) + tCgX = thr_mma_intra2.partition_B(gX) + + # Partition global/shared tensor for X + x_cta_layout = cute.make_layout( + 
cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape + ) + # ((ATOM_V, REST_V), INPUT_STAGE) + # ((ATOM_V, REST_V), 1, 1, C, EH, B) + tXsX, tXgX_pre_slice = cpasync.tma_partition( + tma_atom_x, + block_in_cluster_coord_vmnk[2], + x_cta_layout, + cute.group_modes(smem_x, 0, 3), + cute.group_modes(tCgX, 0, 3), + ) + return tXsX, tXgX_pre_slice + + def tma_partition_for_mma_a_operand( + self, + tma_atom_c, + tma_tensor_c, + smem_c, + tiled_mma_intra1, + cluster_layout_vmnk, + mma_tile_coord_v, + block_in_cluster_coord_vmnk, + ): + # Local_tile partition global tensors + # (L, N, 1, 1, C, G, B) + gC = cute.local_tile( + tma_tensor_c, + cute.slice_(self.tile_shape_mnk_intra1, (None, 0, None)), + (None, None, None, None, None), + ) + # Partition global tensor with regard to TiledMMA + thr_mma_intra1 = tiled_mma_intra1.get_slice(mma_tile_coord_v) + # (MMA, MMA_M/N, MMA_K, 1, 1, C, G, B) + tCgC = thr_mma_intra1.partition_A(gC) + + # Partition global/shared tensor for TMA C + c_cta_layout = cute.make_layout( + cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape + ) + # ((ATOM_V, REST_V), INPUT_STAGE) + # ((ATOM_V, REST_V), 1, 1, C, G, B) + tCsC, tCgC_pre_slice = cpasync.tma_partition( + tma_atom_c, + block_in_cluster_coord_vmnk[1], + c_cta_layout, + cute.group_modes(smem_c, 0, 3), + cute.group_modes(tCgC, 0, 3), + ) + return tCsC, tCgC_pre_slice + + def tma_partition_with_shape( + self, tma_atom_delta, tma_tensor_delta, smem_delta, shape + ): + # Local_tile partition global tensors + # (L, 1, C, EH, B) + gDelta = cute.local_tile( + tma_tensor_delta, + shape, + (None,) * cute.rank(tma_tensor_delta), + ) + # Partition global/shared tensor for DELTA + # ((ATOM_V, REST_V), INPUT_STAGE) + # ((ATOM_V, REST_V), 1, C, EH, B) + tDeltasDelta, tDeltagDelta_pre_slice = cpasync.tma_partition( + tma_atom_delta, + 0, + cute.make_layout(1), + cute.group_modes(smem_delta, 0, cute.rank(shape)), + cute.group_modes(gDelta, 0, cute.rank(shape)), + ) + + return tDeltasDelta, tDeltagDelta_pre_slice + + def mma_partition_ss( + self, + tiled_mma, + tile_shape_mnk, + smem_a, + smem_b, + tmem_acc_ptr, + acc_stages, + ): + # (MMA, MMA_M, MMA_K, INPUT_STAGE) + tCrA = tiled_mma.make_fragment_A(smem_a) + # (MMA, MMA_N, MMA_K, INPUT_STAGE) + tCrB = tiled_mma.make_fragment_B(smem_b) + # (MMA, MMA_M, MMA_N, ACC_STAGE) + tCtAcc = self.mma_partition_c( + tiled_mma, tile_shape_mnk, tmem_acc_ptr, acc_stages + ) + return tCrA, tCrB, tCtAcc + + def mma_partition_ts( + self, + tiled_mma, + tile_shape_mnk, + a_tmem_layout, + smem_b, + tmem_a_ptr, + tmem_acc_ptr, + acc_stages, + ): + # (MMA, MMA_M, MMA_K, INTERNAL_STAGE) + tCrA = self.mma_partition_a_tmem(tiled_mma, a_tmem_layout, tmem_a_ptr) + # (MMA, MMA_N, MMA_K, INPUT_STAGE) + tCrB = tiled_mma.make_fragment_B(smem_b) + # (MMA, MMA_M, MMA_N, INTERNAL_STAGE) + tCtAcc = self.mma_partition_c( + tiled_mma, tile_shape_mnk, tmem_acc_ptr, acc_stages + ) + return tCrA, tCrB, tCtAcc + + def mma_partition_a_tmem(self, tiled_mma, a_tmem_layout, tmem_a_ptr): + tCrA_fake = tiled_mma.make_fragment_A(a_tmem_layout.outer.shape) + tCrA = cute.make_tensor( + cute.recast_ptr( + tmem_a_ptr, + dtype=tCrA_fake.element_type, + ), + tCrA_fake.layout, + ) + return tCrA + + def mma_partition_c(self, tiled_mma, tile_shape_mnk, tmem_acc_ptr, acc_stages): + acc_shape = tiled_mma.partition_shape_C(tile_shape_mnk[:2]) + tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, acc_stages)) + # (MMA, MMA_M, MMA_N, INTERNAL_STAGE) + tCtAcc = cute.make_tensor(tmem_acc_ptr, tCtAcc_fake.layout) + return tCtAcc 
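+
+    # The PRE_INTRA warps apply a causal, exponentially decaying mask to the
+    # (L, L) INTRA1_ACC tile (the per-chunk C @ B^T scores) inside pre_intra_segsum.
+    # The helper below is an editorial, PyTorch-only sketch of that per-chunk math,
+    # added for readability: it is never called by the kernel, and the argument
+    # names (cb, delta, cumsum_delta_a) are illustrative assumptions.
+    @staticmethod
+    def _segsum_decay_sketch(cb, delta, cumsum_delta_a):
+        import torch
+
+        # cb:             (L, L) C @ B^T for one chunk (INTRA1_ACC)
+        # delta:          (L,)   per-token dt within the chunk
+        # cumsum_delta_a: (L,)   running cumsum of dt * a within the chunk
+        n = cb.shape[0]
+        # diff[i, j] = cumsum(dt * a)[i] - cumsum(dt * a)[j]
+        diff = cumsum_delta_a[:, None] - cumsum_delta_a[None, :]
+        # Mask the non-causal (i < j) entries to -inf so they vanish after exp,
+        # mirroring the -inf fill done on registers in pre_intra_segsum.
+        idx = torch.arange(n, device=cb.device)
+        causal = idx[:, None] >= idx[None, :]
+        decay = torch.exp(diff.masked_fill(~causal, float("-inf")))
+        # Scale column j by delta[j] and apply the decay to the C @ B^T scores.
+        return cb * decay * delta[None, :]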
+ + @cute.jit + def exec_mma( + self, + tiled_mma, + tCtAcc, + tCrA, + tCrB, + acc_producer_state, + a_consumer_state, + b_consumer_state, + ): + for kphase_idx in cutlass.range(cute.size(tCrB, mode=[2]), unroll_full=True): + # set accu = 1 + tiled_mma.set( + tcgen05.Field.ACCUMULATE, + cutlass.Boolean(kphase_idx != 0), + ) + cute.gemm( + tiled_mma, + tCtAcc[None, None, None, acc_producer_state.index], + tCrA[None, None, kphase_idx, a_consumer_state.index], + tCrB[None, None, kphase_idx, b_consumer_state.index], + tCtAcc[None, None, None, acc_producer_state.index], + ) + return tiled_mma + + @cute.jit + def conditional_consumer_try_wait(self, b_consumer_state, b_pipeline, C): + peek_b_full_status = cutlass.Boolean(1) + if b_consumer_state.count < C: + peek_b_full_status = b_pipeline.consumer_try_wait(b_consumer_state) + return peek_b_full_status + + @cute.jit + def conditional_producer_try_acquire( + self, intra1_acc_producer_state, intra1_acc_pipeline, C + ): + peek_wr_intra1_acc_empty_status = cutlass.Boolean(1) + if intra1_acc_producer_state.count < C: + peek_wr_intra1_acc_empty_status = intra1_acc_pipeline.producer_try_acquire( + intra1_acc_producer_state + ) + return peek_wr_intra1_acc_empty_status + + def pre_intra_tmem_load_and_partition_q(self, tIntra1, local_tidx): + copy_atom_t2r_intra1 = cute.make_copy_atom( + tcgen05.Ld16x256bOp(tcgen05.Repetition(16), tcgen05.Pack.NONE), + self.acc_dtype, + ) + # (L, L) + fake_sQ = cute.make_tensor( + cute.make_ptr(self.io_dtype, 0, cute.AddressSpace.smem), + cute.dice(self.tile_shape_mnk_intra1, (1, 1, None)), + ) + return self.make_tmem_load_and_partition( + copy_atom_t2r_intra1, tIntra1, (None, None, 0), local_tidx, fake_sQ + ) + + def pre_intra_make_delta(self, smem_delta, extend_on_row_or_col): + smem_iterator = smem_delta.iterator + delta_linear_smem_layout = smem_delta.layout + # extend L linear layout to LxL + extend_layout = cute.make_layout(delta_linear_smem_layout.shape[0], stride=0) + if extend_on_row_or_col == 0: + # (L, L, INPUT_STAGE):(0, 1, L) + sDelta = cute.make_tensor( + smem_iterator, + cute.prepend( + delta_linear_smem_layout, + extend_layout, + ), + ) + else: + # (L, L, INPUT_STAGE):(1, 0, L) + sDelta = cute.make_tensor( + smem_iterator, + cute.append( + cute.append( + cute.get(delta_linear_smem_layout, mode=[0]), + extend_layout, + ), + cute.get(delta_linear_smem_layout, mode=[1]), + ), + ) + return sDelta + + def pre_intra_tmem_store_and_partition_q(self, local_tidx, tCrQ): + dtype = tCrQ.element_type + # Make tiledCopy for tensor memory store INTRA2_Q + copy_atom_r2t_q = cute.make_copy_atom( + tcgen05.St16x128bOp(tcgen05.Repetition(16), tcgen05.Unpack.NONE), + dtype, + ) + tiled_r2t_q = tcgen05.make_tmem_copy(copy_atom_r2t_q, tCrQ) + thr_r2t_q = tiled_r2t_q.get_slice(local_tidx) + + # Partition tmem/register tensor for tensor memory store INTRA2_Q + # ((T2R_ATOM_V, T2R_REST_V), T2R_M, T2R_N, ...) 
+ tRT_rQ = cute.make_fragment( + cute.slice_(thr_r2t_q.partition_S(tCrQ).shape, (None, None, None, None, 0)), + dtype, + ) + # ((T2R_ATOM_V, T2R_REST_V), T2R_M, T2R_N, ..., INTERNAL_STAGE) + tRT_tQ = thr_r2t_q.partition_D(tCrQ) + + return tiled_r2t_q, tRT_rQ, tRT_tQ + + @cute.jit + def pre_intra_segsum( + self, tTR_rQ, tQrDeltaA_Row, tQrDeltaA_Col, tQrDelta, tCoord, tRT_rQ + ): + # Make tmp acc type fragments + tCrDeltaA_Row = cute.make_fragment(tQrDeltaA_Row.shape, self.acc_dtype) + tCrDeltaA_Col = cute.make_fragment(tQrDeltaA_Col.shape, self.acc_dtype) + tCrDelta = cute.make_fragment(tQrDelta.shape, self.acc_dtype) + tCompute = cute.make_fragment(tRT_rQ.shape, self.acc_dtype) + + # Combine tTR_rQ/tCrDeltaA_Row/tCrDeltaA_Col/tCrDelta + tCrDeltaA_Row.store(tQrDeltaA_Row.load().to(self.acc_dtype)) + tCrDeltaA_Col.store(tQrDeltaA_Col.load().to(self.acc_dtype)) + tCrDelta.store(tQrDelta.load().to(self.acc_dtype)) + + # SegSum + # fadd2 + fsel + fmul2/mufu + fmul2 + for subtile_idx in cutlass.range(0, cute.size(tTR_rQ), 2, unroll_full=True): + ( + tCompute[subtile_idx], + tCompute[subtile_idx + 1], + ) = cute.arch.add_packed_f32x2( + (tCrDeltaA_Col[subtile_idx], tCrDeltaA_Col[subtile_idx + 1]), + (-tCrDeltaA_Row[subtile_idx], -tCrDeltaA_Row[subtile_idx + 1]), + ) + for subtile_idx in cutlass.range(cute.size(tTR_rQ), unroll_full=True): + m, n = tCoord[subtile_idx] + if m < n: + tCompute[subtile_idx] = cutlass.Float32(-float("inf")) + LOG2_E = cutlass.Float32(1.4426950408889634) + for subtile_idx in cutlass.range(0, cute.size(tTR_rQ), 2, unroll_full=True): + # TODO: use math.exp directly + tCompute_log2e = cute.arch.mul_packed_f32x2( + (tCompute[subtile_idx], tCompute[subtile_idx + 1]), (LOG2_E, LOG2_E) + ) + ( + tCompute[subtile_idx], + tCompute[subtile_idx + 1], + ) = cute.arch.mul_packed_f32x2( + ( + cute.math.exp2(tCompute_log2e[0], fastmath=True), + cute.math.exp2(tCompute_log2e[1], fastmath=True), + ), + (tCrDelta[subtile_idx], tCrDelta[subtile_idx + 1]), + ) + ( + tCompute[subtile_idx], + tCompute[subtile_idx + 1], + ) = cute.arch.mul_packed_f32x2( + (tCompute[subtile_idx], tCompute[subtile_idx + 1]), + (tTR_rQ[subtile_idx], tTR_rQ[subtile_idx + 1]), + ) + + tRT_rQ.store(tCompute.load().to(self.io_dtype)) + return tRT_rQ + + def pre_inter_smem_load_and_partition_b(self, local_tidx, smem_bt): + dtype = smem_bt.element_type + copy_atom_s2r_b = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + dtype, + num_bits_per_copy=128, + ) + num_elements_per_thread = 128 // dtype.width + num_threads_per_row = self.tile_shape_mnk_inter1[2] // num_elements_per_thread + num_threads_per_col = 128 // num_threads_per_row + thread_layout = cute.make_layout( + (num_threads_per_col, num_threads_per_row), + stride=(num_threads_per_row, 1), + ) + val_layout = cute.make_layout((1, num_elements_per_thread)) + tiled_s2r_b = cute.make_tiled_copy_tv( + copy_atom_s2r_b, + thread_layout, + val_layout, + ) + thr_s2r_b = tiled_s2r_b.get_slice(local_tidx) + + # Partition shared tensor for smem load Bt + # ((S2R_ATOM_V, S2R_REST_V), S2R_M, S2R_N, INPUT_STAGE) + tBsB_s2r = thr_s2r_b.partition_S(smem_bt) + + # ((S2R_ATOM_V, S2R_REST_V), S2R_M, S2R_N) + tBrB_s2r = cute.make_fragment( + cute.slice_(tBsB_s2r.shape, (None, None, None, 0)), + dtype, + ) + return tiled_s2r_b, tBsB_s2r, tBrB_s2r + + def pre_inter_smem_store_and_partition_b( + self, local_tidx, smem_bt_internal, tiled_s2r_b, tBrB_s2r + ): + dtype = smem_bt_internal.element_type + # Make tiledCopy from register to smem store Bt + copy_atom_r2s_b = 
cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + dtype, + num_bits_per_copy=128, + ) + tiled_r2s_b = cute.make_tiled_copy_S(copy_atom_r2s_b, tiled_s2r_b) + thr_r2s_b = tiled_r2s_b.get_slice(local_tidx) + + # Partition shared tensor for smem store Bt + # ((R2S_ATOM_V, R2S_REST_V), R2S_M, R2S_N, INTERNAL_STAGE) + tBsB_r2s = thr_r2s_b.partition_D(smem_bt_internal) + + # Make register fragments for smem load/store Bt + # ((S2R_ATOM_V, S2R_REST_V), S2R_M, S2R_N) + tBrB_r2s = thr_r2s_b.retile(tBrB_s2r) + return tiled_r2s_b, tBrB_r2s, tBsB_r2s + + def smem_load_and_partition_delta_d( + self, tiled_s2r_b, local_tidx, smem_delta, smem_tile_coord + ): + dtype = smem_delta.element_type + s2r_atom_delta = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), dtype) + + thr_s2r_b = tiled_s2r_b.get_slice(local_tidx) + # ((S2R_ATOM_V, S2R_REST_V), S2R_M, S2R_N, INPUT_STAGE) + tBsDelta_s2r = thr_s2r_b.partition_D(smem_delta) + + # Make register fragments for smem load/store of Delta/DeltaA + # ((S2R_ATOM_V, S2R_REST_V), S2R_M, S2R_N) + tBrDelta_s2r = cute.make_fragment(tBsDelta_s2r[smem_tile_coord].shape, dtype) + return s2r_atom_delta, tBsDelta_s2r, tBrDelta_s2r + + def pre_inter_tmem_load_and_partition_p(self, local_tidx, tInter1, smem_pt): + copy_atom_t2r_inter1 = cute.make_copy_atom( + tcgen05.Ld16x256bOp(tcgen05.Repetition(8), tcgen05.Pack.NONE), + self.acc_dtype, + ) + return self.make_tmem_load_and_partition( + copy_atom_t2r_inter1, + tInter1, + (None, None, 0), + local_tidx, + smem_pt[None, None, 0], + ) + + def make_tmem_load_and_partition( + self, copy_atom_t2r, tmem_tensor, tmem_tile_coord, local_tidx, smem_tensor + ): + dtype = tmem_tensor.element_type + tiled_t2r = tcgen05.make_tmem_copy(copy_atom_t2r, tmem_tensor[tmem_tile_coord]) + thr_t2r = tiled_t2r.get_slice(local_tidx) + # Partition tmem/shared tensor for tmem load INTER1_ACC + # ((T2R_ATOM_V, T2R_REST_V), T2R_M, T2R_N) + tTR_t = thr_t2r.partition_S(tmem_tensor) + tTR_s = thr_t2r.partition_D(smem_tensor) + # Make register fragments for tmem load INTER1_ACC + # ((T2R_ATOM_V, T2R_REST_V), T2R_M, T2R_N) + tTR_r = cute.make_fragment( + tTR_s.shape, + dtype, + ) + return tiled_t2r, tTR_t, tTR_r + + def smem_store_and_partition_p_y(self, local_tidx, smem_pt, tiled_t2r_inter1): + dtype = smem_pt.element_type + copy_atom_r2s_p = cute.make_copy_atom( + cute.nvgpu.warp.StMatrix8x8x16bOp(transpose=True, num_matrices=4), + dtype, + ) + tiled_r2s_p = cute.make_tiled_copy_D(copy_atom_r2s_p, tiled_t2r_inter1) + thr_r2s_p = tiled_r2s_p.get_slice(local_tidx) + # Partition smem/register tensor for smem store INTER2_P + # ((R2S_ATOM_V, R2S_REST_V), R2S_M, R2S_N, INTERNAL_STAGE) + tRS_sP = thr_r2s_p.partition_D(smem_pt) + # ((R2S_ATOM_V, R2S_REST_V), R2S_M, R2S_N) + tRS_rP = cute.make_fragment( + cute.slice_(tRS_sP.shape, (None, None, None, 0)), self.io_dtype + ) + return tiled_r2s_p, tRS_rP, tRS_sP + + def pre_inter_make_delta(self, smem_delta, smem_bt_layout): + # Broadcast Delta/DeltaA to Bt shape on M dimension + # before: (128,(64,2),2):(64,(1,8192),16384) + # after : (128,(64,2),2):(0,(1,64),128) + # (MMA, MMA_M, MMA_K, INPUT_STAGE) + sDeltaA = cute.make_tensor( + smem_delta.iterator, + cute.make_layout( + smem_bt_layout.shape, + stride=( + 0, + (1, cute.get(smem_bt_layout.shape, mode=[1, 0])), + smem_delta.layout.stride[1], + ), + ), + ) + return sDeltaA + + def pre_inter_scale_bt_with_delta( + self, tBrB_s2r, tBrDelta_s2r, tBrDeltaA_s2r, last_column + ): + tCompute = cute.make_fragment(tBrB_s2r.shape, self.acc_dtype) + tBrB_Compute = 
cute.make_fragment(tBrB_s2r.shape, self.acc_dtype) + tBrDelta_Compute = cute.make_fragment(tBrDelta_s2r.shape, self.acc_dtype) + tBrDeltaA_Compute = cute.make_fragment(tBrDeltaA_s2r.shape, self.acc_dtype) + + tBrB_Compute.store(tBrB_s2r.load().to(self.acc_dtype)) + tBrDelta_Compute.store(tBrDelta_s2r.load().to(self.acc_dtype)) + tBrDeltaA_Compute.store(tBrDeltaA_s2r.load().to(self.acc_dtype)) + + for reg_idx in range(0, cute.size(tBrB_Compute), 2): + tCompute[reg_idx], tCompute[reg_idx + 1] = cute.arch.mul_packed_f32x2( + ( + cute.math.exp( + (last_column - tBrDeltaA_Compute[reg_idx]), fastmath=True + ), + cute.math.exp( + (last_column - tBrDeltaA_Compute[reg_idx + 1]), fastmath=True + ), + ), + (tBrDelta_Compute[reg_idx], tBrDelta_Compute[reg_idx + 1]), + ) + tCompute[reg_idx], tCompute[reg_idx + 1] = cute.arch.mul_packed_f32x2( + (tCompute[reg_idx], tCompute[reg_idx + 1]), + (tBrB_Compute[reg_idx], tBrB_Compute[reg_idx + 1]), + ) + return tCompute + + def epilog_make_delta(self, smem_cumsum_delta): + # Broadcast cumsum delta from LxINPUT_STAGE to LxDxINPUT_STAGE + sDeltaA = cute.make_tensor( + smem_cumsum_delta.iterator, + cute.make_layout( + (*self.tile_shape_mnk_inter2[:2], self.input_stages), + stride=(1, 0, smem_cumsum_delta.layout.shape[0]), + ), + ) + return sDeltaA + + def epilog_make_d(self, smem_d): + # Broadcast d from DxINPUT_STAGE to LxDxINPUT_STAGE + sD = cute.make_tensor( + smem_d.iterator, + cute.make_layout( + (*self.tile_shape_mnk_inter2[:2], self.input_stages), + stride=(0, 1, smem_d.layout.shape[0]), + ), + ) + return sD + + def epilog_tma_partition_y(self, tma_tensor_y, tma_atom_y, smem_y, epi_tile): + # Local_tile partition global tensors + # (L, D, 1, 1, C, EH, B) + gY = cute.local_tile( + tma_tensor_y, + cute.slice_(self.tile_shape_mnk_inter2, (None, None, 0)), + (None, None, None, None, None), + ) + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, 1, 1, C, EH, B) + gY_epi = cute.flat_divide(gY, epi_tile) + # ((ATOM_V, REST_V), INPUT_STAGE) + # ((ATOM_V, REST_V), EPI_M, EPI_N, 1, 1, C, EH, B) + bSG_sY, bSG_gY_pre_slice = cpasync.tma_partition( + tma_atom_y, + 0, + cute.make_layout(1), + cute.group_modes(smem_y, 0, 2), + cute.group_modes(gY_epi, 0, 2), + ) + return bSG_sY, bSG_gY_pre_slice + + def epilog_smem_load_and_partition_x( + self, tiled_t2r_inter2_intra2, local_tidx, smem_xt, epi_tile + ): + dtype = smem_xt.element_type + copy_atom_s2r_x = cute.make_copy_atom( + cute.nvgpu.warp.LdMatrix8x8x16bOp(transpose=True, num_matrices=4), + dtype, + ) + tiled_s2r_x = cute.make_tiled_copy_D(copy_atom_s2r_x, tiled_t2r_inter2_intra2) + thr_s2r_x = tiled_s2r_x.get_slice(local_tidx) + # Partition smem/register tensor for smem store INTER2_P + # (R2S_ATOM, R2S_M, R2S_N, EPI_M, EPI_N, INPUT_STAGES) + tSR_sX = thr_s2r_x.partition_S(cute.flat_divide(smem_xt, epi_tile)) + # (R2S_ATOM, R2S_M, R2S_N) + tSR_rX = cute.make_fragment( + cute.slice_(tSR_sX.shape, (None, None, None, 0, 0, 0)), dtype + ) + return tiled_s2r_x, tSR_sX, tSR_rX + + def epilog_tmem_load_and_partition_acc(self, local_tidx, tIntra, smem_y): + copy_atom_t2r_inter2_intra2 = cute.make_copy_atom( + tcgen05.Ld16x256bOp(tcgen05.Repetition(4), tcgen05.Pack.NONE), + self.acc_dtype, + ) + return self.make_tmem_load_and_partition( + copy_atom_t2r_inter2_intra2, + tIntra, + (None, None, 0, 0, 0), + local_tidx, + smem_y[None, None, 0], + ) + + +def run( + gbehcdln: Tuple[int, int, int, int, int, int, int, int], + io_dtype: Type[cutlass.Numeric], + cumsum_delta_dtype: Type[cutlass.Numeric], + acc_dtype: Type[cutlass.Numeric], + 
fuse_scale_d: str, + tolerance: float, + print_rtol_stats: bool, + ref_lower_precision: bool, + warmup_iterations: int, + iterations: int, + skip_ref_check: bool, + use_cold_l2: bool = False, + **kwargs, +): + has_d = fuse_scale_d != "none" + d_has_hdim = fuse_scale_d == "vector" + + print(f"Running B100 Mamba2 SSD with:") + print(f"GBEHCDLN: {gbehcdln}") + print( + f"Input/Output dtype: {io_dtype}, Intermediate delta dtype: {cumsum_delta_dtype}, Acc dtype: {acc_dtype}" + ) + print( + f"Has D (True means fuse Y+=X*D): {has_d}, D has Hdim (True means D.shape DxEH, False means 1xEH): {d_has_hdim}" + ) + print(f"Tolerance: {tolerance}") + print(f"Warmup iterations: {warmup_iterations}") + print(f"Iterations: {iterations}") + print(f"Skip reference checking: {skip_ref_check}") + print(f"Use cold L2: {'True' if use_cold_l2 else 'False'}") + + # Unpack parameters + G, B, E, H, C, D, L, N = gbehcdln + EH = E * H + + if not torch.cuda.is_available(): + raise RuntimeError("GPU is required to run this example!") + + # Match same seed in ssd_reference.py for reference check + torch.manual_seed(42) + + # Create and permute tensor A/B/C + def create_and_permute_tensor( + shape, permute_order, dtype, dt_or_a=0, dynamic_modes=None, ref_tensor=None + ): + # Build fp32 reference torch tensor + if ref_tensor is None: + ref_tensor = ( + torch.empty(*shape, dtype=torch.float32) + # .random_(-1, 1) + .normal_(0, 0.5) + # .uniform_(-1,1) + .permute(permute_order) + ) + if dt_or_a == 1: # dt: + ref_tensor = F.softplus(ref_tensor - 4) + elif dt_or_a == 2: # A: + ref_tensor = -torch.exp(ref_tensor) + + # Build torch_dtype torch tensor + torch_dtype = cutlass_torch.dtype(dtype) + + dst_tensor = ref_tensor.to(torch_dtype).cuda() + cute_tensor = from_dlpack(dst_tensor, assumed_align=16) + for mode in dynamic_modes: + cute_tensor = cute_tensor.mark_compact_shape_dynamic( + mode=mode, stride_order=dst_tensor.dim_order() + ) + + return ref_tensor, cute_tensor, dst_tensor + + # INPUT tensors + # x: (D, L, C, EH, B):(C*L, 1, L, D*C*L, EH*D*C*L) + x_ref, x_tensor, x_torch = create_and_permute_tensor( + [B, EH, D, C, L], [2, 4, 3, 1, 0], io_dtype, dynamic_modes=[2, 3, 4] + ) + # delta/delta_a/cumsum_delta: (L, C, EH, B):(1, L, C*L, EH*C*L) + delta_ref, delta_tensor, delta_torch = create_and_permute_tensor( + [B, EH, C, L], [3, 2, 1, 0], io_dtype, dt_or_a=1, dynamic_modes=[1, 2, 3] + ) + # a: (EH):(1) + a_ref, a_tensor, a_torch = create_and_permute_tensor( + [EH], [0], io_dtype, dt_or_a=2, dynamic_modes=[0] + ) + + if has_d: + # d: (D, EH):(1, D) or (1, EH):(0, 1) + d_ref, d_tensor, d_torch = create_and_permute_tensor( + [EH, D if d_has_hdim else 1], [1, 0], io_dtype, dynamic_modes=[1] + ) + else: + d_ref = None + d_tensor = None + + # b/c: (L, N, C, G, B):(1, C*L, L, N*C*L, G*N*C*L) + b_ref, b_tensor, b_torch = create_and_permute_tensor( + [B, G, N, C, L], [4, 2, 3, 1, 0], io_dtype, dynamic_modes=[2, 3, 4] + ) + c_ref, c_tensor, c_torch = create_and_permute_tensor( + [B, G, N, C, L], [4, 2, 3, 1, 0], io_dtype, dynamic_modes=[2, 3, 4] + ) + + # OUTPUT tensors + # y: (L, D, C, EH, B):(1, C*L, L, D*C*L, EH*D*C*L) + y_ref, y_tensor, y_torch = create_and_permute_tensor( + [B, EH, D, C, L], [4, 2, 3, 1, 0], io_dtype, dynamic_modes=[2, 3, 4] + ) + # fstate: (D, N, EH, B):(N, 1, D*N, EH*D*N) + fstate_ref, fstate_tensor, fstate_torch = create_and_permute_tensor( + [B, EH, D, N], [2, 3, 1, 0], io_dtype, dynamic_modes=[2, 3] + ) + + # Call pytorch reference on cpu + if not ref_lower_precision: + ssd_reference_fp32_all( + x_ref, + 
a_ref, + delta_ref, + b_ref, + c_ref, + y_ref, + fstate_ref, + d_ref, + has_d, + d_has_hdim, + ) + else: + ssd_reference_lowprecision_intermediates( + x_ref, + a_ref, + delta_ref, + b_ref, + c_ref, + y_ref, + fstate_ref, + cutlass_torch.dtype(io_dtype), + d_ref, + has_d, + d_has_hdim, + ) + + # Compute cumsum with pytorch on cpu + delta_a_ref = delta_ref * a_ref.view(1, 1, -1, 1) + cumsum_delta_ref = torch.empty([B, EH, C, L], dtype=torch.float32).permute( + [3, 2, 1, 0] + ) + cumsum_delta_ref.copy_(torch.cumsum(delta_a_ref, dim=0).permute([0, 1, 2, 3])) + # Copy cumsum_delta_ref to cumsum_delta_tensor + ( + cumsum_delta_ref, + cumsum_delta_tensor, + cumsum_delta_torch, + ) = create_and_permute_tensor( + [B, EH, C, L], + [3, 2, 1, 0], + cumsum_delta_dtype, + ref_tensor=cumsum_delta_ref, + dynamic_modes=[1, 2, 3], + ) + + # Call fused ssd kernel + ssd = SSDKernel( + io_dtype, + cumsum_delta_dtype, + acc_dtype, + L, + D, + N, + has_d, + d_has_hdim, + ) + + # Compute max active clusters on current device + hardware_info = cutlass.utils.HardwareInfo() + max_active_clusters = hardware_info.get_max_active_clusters(1) + + stream = cutlass.cuda.default_stream() + # Compile ssd kernel + compiled_ssd = cute.compile( + ssd, + x_tensor, + cumsum_delta_tensor, + delta_tensor, + b_tensor, + c_tensor, + y_tensor, + fstate_tensor, + d_tensor, + max_active_clusters, + stream, + ) + + # Launch compiled ssd kernel for reference check + if not skip_ref_check: + compiled_ssd( + x_tensor, + cumsum_delta_tensor, + delta_tensor, + b_tensor, + c_tensor, + y_tensor, + fstate_tensor, + d_tensor, + stream, + ) + + # Reference check + if print_rtol_stats: + print("\nY's Relative diffs:") + analyze_relative_diffs( + y_torch.cpu(), y_ref.to(cutlass_torch.dtype(io_dtype)) + ) + print("\nFstate's Relative diffs:") + analyze_relative_diffs( + fstate_torch.cpu(), fstate_ref.to(cutlass_torch.dtype(io_dtype)) + ) + torch.testing.assert_close( + y_torch.cpu(), + y_ref.to(cutlass_torch.dtype(io_dtype)), + atol=tolerance, + rtol=1e-02, + ) + torch.testing.assert_close( + fstate_torch.cpu(), + fstate_ref.to(cutlass_torch.dtype(io_dtype)), + atol=tolerance, + rtol=1e-05, + ) + + def generate_tensors(): + # Reuse existing CPU reference tensors and create new GPU tensors from them + _, x_tensor_new, _ = create_and_permute_tensor( + [B, EH, D, C, L], + [2, 4, 3, 1, 0], + io_dtype, + ref_tensor=x_ref, + dynamic_modes=[2, 3, 4], + ) + _, cumsum_delta_tensor_new, _ = create_and_permute_tensor( + [B, EH, C, L], + [3, 2, 1, 0], + cumsum_delta_dtype, + ref_tensor=cumsum_delta_ref, + dynamic_modes=[1, 2, 3], + ) + _, delta_tensor_new, _ = create_and_permute_tensor( + [B, EH, C, L], + [3, 2, 1, 0], + io_dtype, + ref_tensor=delta_ref, + dynamic_modes=[1, 2, 3], + ) + _, b_tensor_new, _ = create_and_permute_tensor( + [B, G, N, C, L], + [4, 2, 3, 1, 0], + io_dtype, + ref_tensor=b_ref, + dynamic_modes=[2, 3, 4], + ) + _, c_tensor_new, _ = create_and_permute_tensor( + [B, G, N, C, L], + [4, 2, 3, 1, 0], + io_dtype, + ref_tensor=c_ref, + dynamic_modes=[2, 3, 4], + ) + _, y_tensor_new, _ = create_and_permute_tensor( + [B, EH, D, C, L], + [4, 2, 3, 1, 0], + io_dtype, + ref_tensor=y_ref, + dynamic_modes=[2, 3, 4], + ) + _, fstate_tensor_new, _ = create_and_permute_tensor( + [B, EH, D, N], + [2, 3, 1, 0], + io_dtype, + ref_tensor=fstate_ref, + dynamic_modes=[2, 3], + ) + + if has_d: + _, d_tensor_new, _ = create_and_permute_tensor( + [EH, D if d_has_hdim else 1], + [1, 0], + io_dtype, + ref_tensor=d_ref, + dynamic_modes=[1], + ) + else: + d_tensor_new 
= d_tensor + + return testing.JitArguments( + x_tensor_new, + cumsum_delta_tensor_new, + delta_tensor_new, + b_tensor_new, + c_tensor_new, + y_tensor_new, + fstate_tensor_new, + d_tensor_new, + stream, + ) + + workspace_count = 1 + if use_cold_l2: + one_workspace_bytes = ( + x_torch.numel() * x_torch.element_size() + + cumsum_delta_torch.numel() * cumsum_delta_torch.element_size() + + delta_torch.numel() * delta_torch.element_size() + + b_torch.numel() * b_torch.element_size() + + c_torch.numel() * c_torch.element_size() + + y_torch.numel() * y_torch.element_size() + + fstate_torch.numel() * fstate_torch.element_size() + ) + if has_d: + one_workspace_bytes += d_torch.numel() * d_torch.element_size() + + workspace_count = testing.get_workspace_count( + one_workspace_bytes, warmup_iterations, iterations + ) + + exec_time = testing.benchmark( + compiled_ssd, + workspace_generator=generate_tensors, + workspace_count=workspace_count, + stream=stream, + warmup_iterations=warmup_iterations, + iterations=iterations, + ) + + return exec_time # Return execution time in microseconds + + +if __name__ == "__main__": + + def parse_comma_separated_ints(s: str) -> List[int]: + try: + return [int(x.strip()) for x in s.split(",")] + except ValueError: + raise argparse.ArgumentTypeError( + "Invalid format. Expected comma-separated integers." + ) + + parser = argparse.ArgumentParser( + description="Example of MxNxKxL GEMM on Blackwell." + ) + + parser.add_argument( + "--gbehcdln", + type=parse_comma_separated_ints, + default=[2, 4, 2, 40, 32, 64, 128, 128], + # default=[2, 3, 2, 2, 8, 64, 128, 128], + # default=[1, 2, 1, 4, 8, 64, 128, 128], + help="gbehcdln dimensions (comma-separated)", + ) + parser.add_argument("--io_dtype", type=cutlass.dtype, default=cutlass.BFloat16) + parser.add_argument( + "--cumsum_delta_dtype", type=cutlass.dtype, default=cutlass.Float32 + ) + parser.add_argument("--acc_dtype", type=cutlass.dtype, default=cutlass.Float32) + parser.add_argument( + "--fuse_scale_d", + type=str, + choices=["none", "scalar", "vector"], + default="vector", + help="Fuse scale type: none (no Y+=X*D fusion), scalar (Y+=X*D fusion with D.shape=1xEH), or vector (Y+=X*D fusion with D.shape=DxEH)", + ) + parser.add_argument( + "--ref_lower_precision", + action="store_true", + default=True, + help="Use lower precision for reference check", + ) + parser.add_argument( + "--no-ref_lower_precision", + action="store_false", + dest="ref_lower_precision", + default=False, + help="Disable lower precision for reference check", + ) + parser.add_argument( + "--tolerance", type=float, default=5e-02, help="Tolerance for validation" + ) + parser.add_argument( + "--print_rtol_stats", + action="store_true", + default=True, + help="Enable print rtol stats", + ) + parser.add_argument( + "--no-print_rtol_stats", + action="store_false", + dest="print_rtol_stats", + default=False, + help="Disable print rtol stats", + ) + parser.add_argument( + "--warmup_iterations", + type=int, + default=0, + help="Number of warmup iterations", + ) + parser.add_argument( + "--iterations", + type=int, + default=1, + help="Number of iterations", + ) + parser.add_argument( + "--skip_ref_check", action="store_true", help="Skip reference checking" + ) + parser.add_argument( + "--use_cold_l2", + action="store_true", + default=False, + help="Use circular buffer tensor sets to ensure L2 cold cache", + ) + + args = parser.parse_args() + + if len(args.gbehcdln) != 8: + parser.error("--gbehcdln must contain exactly 8 values") + + run( + args.gbehcdln, + 
args.io_dtype, + args.cumsum_delta_dtype, + args.acc_dtype, + args.fuse_scale_d, + args.tolerance, + args.print_rtol_stats, + args.ref_lower_precision, + args.warmup_iterations, + args.iterations, + args.skip_ref_check, + args.use_cold_l2, + ) + print("PASS") diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/blackwell/mamba2_ssd/mamba2_ssd_reference.py b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/blackwell/mamba2_ssd/mamba2_ssd_reference.py new file mode 100644 index 0000000000000000000000000000000000000000..ced5c6f2f51fb395b4efac96ee2b655710fca463 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/blackwell/mamba2_ssd/mamba2_ssd_reference.py @@ -0,0 +1,397 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
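+# Reference implementations for the Mamba2 SSD example: ssd_reference_fp32_all keeps
+# every intermediate in fp32, while ssd_reference_lowprecision_intermediates casts
+# intermediates to the kernel's I/O dtype to better match on-device rounding. Both
+# rearrange the kernel's layouts into the layout expected by the
+# ssd_minimal_discrete-style reference before computing Y and the final state.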
+ +import torch +import torch.nn.functional as F + + +def ssd_reference_fp32_all(x, a, delta, B, C, Y_out, Fstate_out, D, has_d, d_has_hdim): + """ + Rearrange tensor dimensions from cuda layout to reference layout, then directly call TriDao's ssd implementation + Arguments: + X/x: (D, L, C, H, B):(C*L, 1, L, D*C*L, H*D*C*L) + A/delta: (L, C, H, B):(1, L, C*L, H*C*L) + a: (H):(1) + B/C: (L, N, C, G, B):(1, C*L, L, N*C*L, G*N*C*L) + D: (1, H):(0, 1) or (D, H):(1, D) + has_d: bool + d_has_hdim: bool + Return: + Y_out: (L, D, C, H, B):(1, C*L, L, D*C*L, H*D*C*L) + Fstate_out: (D, N, H, B):(N, 1, D*N, H*D*N) + """ + assert x.dtype == a.dtype == delta.dtype == B.dtype == C.dtype + + A = delta * a.view(1, 1, -1, 1) + X = x * delta.unsqueeze(0) + + # Rearrange to match cutlass layout to tridao's layout + block_len = A.shape[0] + initial_states = None + # A: l c h b-> b c l h + A = A.permute(3, 1, 0, 2) + # X: p l c h b -> b c l h p + X = X.permute(4, 2, 1, 3, 0) + # B: l n c g b -> b c l g n + B = B.permute(4, 2, 0, 3, 1) + # C: l n c g b -> b c l g n + C = C.permute(4, 2, 0, 3, 1) + # X/A/B/C: b c l ... -> b (c l) ... + X, A, B, C = [x.reshape(x.shape[0], -1, *x.shape[3:]) for x in (X, A, B, C)] + + # Ngroup (g to h) mapping + B_val, CL_val, G_val, N_val = B.shape + H_val = X.shape[2] + ngroup_ratio = H_val // G_val + # B/C: (B, CL, H, N) + h_to_g_mapping = torch.arange(H_val, device=B.device) // ngroup_ratio + B = B.gather(2, h_to_g_mapping.view(1, 1, -1, 1).expand(B_val, CL_val, -1, N_val)) + C = C.gather(2, h_to_g_mapping.view(1, 1, -1, 1).expand(B_val, CL_val, -1, N_val)) + + ################################################################### + # Call reference implementation from Tri Dao ssd_minimal_discrete + Y, final_state = ssd_minimal_discrete_fp32_all( + X, A, B, C, block_len, initial_states + ) + ################################################################### + + if has_d: + D_val = Y.shape[3] + if not d_has_hdim: + D = D.expand(D_val, -1) + Y = Y + torch.einsum("bchp,ph->bchp", X, D) + + # Rearrange to match tridao's layout to cutlass layout + # Y: b (c l) h p -> b c l h p + Y = Y.reshape(Y.shape[0], -1, block_len, Y.shape[2], Y.shape[3]) + # Y: b c l h p -> l p c h b + Y = Y.permute(2, 4, 1, 3, 0) + # Fstate_out: b h p n -> p n h b + Fstate_out.copy_(final_state.permute(2, 3, 1, 0)) + Y_out.copy_(Y) + return + + +def ssd_reference_lowprecision_intermediates( + x, a, delta, B, C, Y_out, Fstate_out, intermediate_dtype, D, has_d, d_has_hdim +): + """ + Rearrange tensor dimensions from cuda layout to reference layout, then call a reduced intermediate dtype version of ssd implementation + Arguments: + X/x: (D, L, C, H, B):(C*L, 1, L, D*C*L, H*D*C*L) + A/delta: (L, C, H, B):(1, L, C*L, H*C*L) + a: (H):(1) + B/C: (L, N, C, G, B):(1, C*L, L, N*C*L, G*N*C*L) + intermediate_dtype: input and intermediate data type + D: (1, H):(0, 1) or (D, H):(1, D) + has_d: bool + d_has_hdim: bool + Return: + Y_out: (L, D, C, H, B):(1, C*L, L, D*C*L, H*D*C*L) + Fstate_out: (D, N, H, B):(N, 1, D*N, H*D*N) + """ + assert x.dtype == a.dtype == delta.dtype == B.dtype == C.dtype + + A = delta * a.view(1, 1, -1, 1) + + # Rearrange to match cutlass layout to tridao's layout + block_len = A.shape[0] + initial_states = None + # A: l c h b-> b c l h + A = A.permute(3, 1, 0, 2) + # delta: l c h b-> b c l h + delta = delta.permute(3, 1, 0, 2) + # x: p l c h b -> b c l h p + x = x.permute(4, 2, 1, 3, 0) + # B: l n c g b -> b c l g n + B = B.permute(4, 2, 0, 3, 1) + # C: l n c g b -> b c l g n + C = C.permute(4, 2, 0, 
3, 1) + # x/A/delta/B/C: b c l ... -> b (c l) ... + x, A, delta, B, C = [ + tensor.reshape(tensor.shape[0], -1, *tensor.shape[3:]) + for tensor in (x, A, delta, B, C) + ] + + # Ngroup (g to h) mapping + B_val, CL_val, G_val, N_val = B.shape + H_val = x.shape[2] + ngroup_ratio = H_val // G_val + # B/C: (B, CL, H, N) + h_to_g_mapping = torch.arange(H_val, device=B.device) // ngroup_ratio + B = B.gather(2, h_to_g_mapping.view(1, 1, -1, 1).expand(B_val, CL_val, -1, N_val)) + C = C.gather(2, h_to_g_mapping.view(1, 1, -1, 1).expand(B_val, CL_val, -1, N_val)) + + # Type convert input tensors to input dtype (same as intermediate dtype) + x = x.to(intermediate_dtype).to(torch.float32) + A = A.to(intermediate_dtype).to(torch.float32) + delta = delta.to(intermediate_dtype).to(torch.float32) + B = B.to(intermediate_dtype).to(torch.float32) + C = C.to(intermediate_dtype).to(torch.float32) + + ######################################################################### + # Call reference implementation ssd_minimal_discrete_bf16_intermediates + Y, final_state = ssd_minimal_discrete_lowprecision_intermediates( + x, A, delta, B, C, block_len, intermediate_dtype, initial_states + ) + ######################################################################### + + if has_d: + D = D.to(intermediate_dtype).to(torch.float32) + D_val = Y.shape[3] + if not d_has_hdim: + D = D.expand(D_val, -1) + Y = Y + torch.einsum("bchp,ph->bchp", x, D) + + # Type convert output tensors to output dtype (same as intermediate dtype) + Y = Y.to(intermediate_dtype).to(torch.float32) + final_state = final_state.to(intermediate_dtype).to(torch.float32) + + # Rearrange to match tridao's layout to cutlass layout + # Y: b (c l) h p -> b c l h p + Y = Y.reshape(Y.shape[0], -1, block_len, Y.shape[2], Y.shape[3]) + # Y: b c l h p -> l p c h b + Y = Y.permute(2, 4, 1, 3, 0) + # Fstate_out: b h p n -> p n h b + Fstate_out.copy_(final_state.permute(2, 3, 1, 0)) + Y_out.copy_(Y) + return + + +def analyze_relative_diffs(actual, expected): + """ + Print statistics of relative differences between actual and expected tensors + """ + # Calculate relative differences + abs_diff = (actual - expected).abs() + rel_diff = abs_diff / (torch.maximum(expected.abs(), actual.abs()) + 0.00001) + + total_elements = rel_diff.numel() + + # Handle special cases first + nan_mask = torch.isnan(rel_diff) + inf_mask = torch.isinf(rel_diff) + nan_count = nan_mask.sum().item() + inf_count = inf_mask.sum().item() + + # Find position and value of maximum relative difference + max_rel_diff = ( + rel_diff[~nan_mask & ~inf_mask].max() + if (~nan_mask & ~inf_mask).any() + else float("nan") + ) + max_rel_diff_pos = ( + rel_diff[~nan_mask & ~inf_mask].argmax() + if (~nan_mask & ~inf_mask).any() + else -1 + ) + + # Print max relative difference info + print(f"Maximum relative difference:") + print(f"Position: {max_rel_diff_pos}") + print(f"Value: {max_rel_diff:.6e}") + print(f"Actual value: {actual.flatten()[max_rel_diff_pos]}") + print(f"Expected value: {expected.flatten()[max_rel_diff_pos]}") + print(f"NaN values: {nan_count} ({100.0 * nan_count / total_elements:.2f}%)") + print(f"Inf values: {inf_count} ({100.0 * inf_count / total_elements:.2f}%)\n") + + # Check different rtol thresholds + rtol_levels = [1e-5, 1e-4, 1e-3, 1e-2, 5e-02, 1e-01] + + for i, rtol in enumerate(rtol_levels): + if i == 0: + mask = rel_diff <= rtol + else: + mask = (rel_diff <= rtol) & (rel_diff > rtol_levels[i - 1]) + + count = mask.sum().item() + percentage = (count / total_elements) * 100 + + if i == 0: 
+ print(f"Elements with rtol <= {rtol:.0e}: {count} ({percentage:.2f}%)") + else: + print( + f"Elements with {rtol_levels[i-1]:.0e} < rtol <= {rtol:.0e}: {count} ({percentage:.2f}%)" + ) + + # Print elements exceeding the largest rtol + mask = rel_diff > rtol_levels[-1] + count = mask.sum().item() + percentage = (count / total_elements) * 100 + print(f"Elements with rtol > {rtol_levels[-1]:.0e}: {count} ({percentage:.2f}%)\n") + + +def segsum(x): + """ + More stable segment sum calculation. + x: b h c l + """ + T = x.size(-1) + # x: b h c l -> b h c l l + x = x.unsqueeze(-1).expand(*x.shape, T) + mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=-1) + x = x.masked_fill(~mask, 0) + x_segsum = torch.cumsum(x, dim=-2) + mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool), diagonal=0) + x_segsum = x_segsum.masked_fill(~mask, -torch.inf) + return x_segsum + + +def ssd_minimal_discrete_fp32_all(X, A, B, C, block_len, initial_states=None): + """ + This is same with https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/ssd_minimal.py + (all accumulation and intermediate results in fp32) + + Arguments: + X: (batch(B), length(C*L), n_heads(H), d_head(D)) + A: (batch(B), length(C*L), n_heads(H)) + B: (batch(B), length(C*L), n_heads(H), d_state(N)) + C: (batch(B), length(C*L), n_heads(H), d_state(N)) + Return: + Y: (batch(B), length(C*L), n_heads(H), d_head(D)) + final_state: (B, H, D, N) + """ + assert X.dtype == A.dtype == B.dtype == C.dtype + assert X.shape[1] % block_len == 0 + + # Rearrange into blocks/chunks + # X/A/B/C:b (c l) ... -> b c l ... + X, A, B, C = [ + x.reshape(x.shape[0], -1, block_len, *x.shape[2:]) for x in (X, A, B, C) + ] + + # A: b c l h -> b h c l + A = A.permute(0, 3, 1, 2) + # A_cumsum: (B, H, C, L) + A_cumsum = torch.cumsum(A, dim=-1) + + # 1. Compute the output for each intra-chunk (diagonal blocks) + segsum_A = segsum(A) + L = torch.exp(segsum_A) + Y_diag = torch.einsum("bclhn,bcshn,bhcls,bcshp->bclhp", C, B, L, X) + + # 2. Compute the state for each intra-chunk + # (right term of low-rank factorization of off-diagonal blocks; B terms) + decay_states = torch.exp((A_cumsum[:, :, :, -1:] - A_cumsum)) + states = torch.einsum("bclhn,bhcl,bclhp->bchpn", B, decay_states, X) + + # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries + # (middle term of factorization of off-diag blocks; A terms) + if initial_states is None: + initial_states = torch.zeros_like(states[:, :1]) + states = torch.cat([initial_states, states], dim=1) + decay_chunk = torch.exp(segsum(F.pad(A_cumsum[:, :, :, -1], (1, 0)))) + new_states = torch.einsum("bhzc,bchpn->bzhpn", decay_chunk, states) + states, final_state = new_states[:, :-1], new_states[:, -1] + + # 4. Compute state -> output conversion per chunk + # (left term of low-rank factorization of off-diagonal blocks; C terms) + state_decay_out = torch.exp(A_cumsum) + Y_off = torch.einsum("bclhn,bchpn,bhcl->bclhp", C, states, state_decay_out) + + # Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks) + # Y: b c l h p -> b (c l) h p + Y = (Y_diag + Y_off).reshape(Y_diag.shape[0], -1, Y_diag.shape[3], Y_diag.shape[4]) + return Y, final_state + + +def ssd_minimal_discrete_lowprecision_intermediates( + X, A, delta, B, C, block_len, intermediate_dtype, initial_states=None +): + """ + This is adjusted from ssd_minimal_discrete_fp32_all, with exceptions: + 1. accumulation in fp32 but intermediates Q/b_tmem/P are in intermediate_dtype + 2. 
delta is not pre-multiplied with X, delta was applied to generate Q/b_tmem to match GPU implementation + + Arguments: + X: (batch(B), length(C*L), n_heads(H), d_head(D)) + A: (batch(B), length(C*L), n_heads(H)) + delta: (batch(B), length(C*L), n_heads(H)) + B: (batch(B), length(C*L), n_heads(H), d_state(N)) + C: (batch(B), length(C*L), n_heads(H), d_state(N)) + Return: + Y: (batch(B), length(C*L), n_heads(H), d_head(D)) + final_state: (B, H, D, N) + """ + assert X.dtype == A.dtype == B.dtype == C.dtype + assert X.shape[1] % block_len == 0 + + # Rearrange into blocks/chunks + # X/A/delta/B/C: b (c l) ... -> b c l ... + X, A, delta, B, C = [ + x.reshape(x.shape[0], -1, block_len, *x.shape[2:]) for x in (X, A, delta, B, C) + ] + + # A: b c l h -> b h c l + A = A.permute(0, 3, 1, 2) + # delta: b c l h -> b h c l + delta = delta.permute(0, 3, 1, 2) + # A_cumsum: (B, H, C, L) + A_cumsum = torch.cumsum(A, dim=-1) + + # 1. Compute the output for each intra-chunk (diagonal blocks) + segsum_A = segsum(A) + L = torch.exp(segsum_A) + intra_acc_0 = torch.einsum("bclhn,bcshn->bclhs", C, B) + Q = torch.einsum("bclhs,bhcls,bhcs->bclhs", intra_acc_0, L, delta) + Y_diag = torch.einsum( + "bclhs,bcshp->bclhp", Q.to(intermediate_dtype).to(torch.float32), X + ) + + # 2. Compute the state for each intra-chunk + # (right term of low-rank factorization of off-diagonal blocks; B terms) + decay_states = torch.exp((A_cumsum[:, :, :, -1:] - A_cumsum)) + b_tmem = torch.einsum("bclhn,bhcl,bhcl->bclhn", B, decay_states, delta) + states = torch.einsum( + "bclhn,bclhp->bchpn", b_tmem.to(intermediate_dtype).to(torch.float32), X + ) + + # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries + # (middle term of factorization of off-diag blocks; A terms) + if initial_states is None: + initial_states = torch.zeros_like(states[:, :1]) + states = torch.cat([initial_states, states], dim=1) + decay_chunk = torch.exp(segsum(F.pad(A_cumsum[:, :, :, -1], (1, 0)))) + new_states = torch.einsum("bhzc,bchpn->bzhpn", decay_chunk, states) + states, final_state = new_states[:, :-1], new_states[:, -1] + final_state = final_state + + # 4. Compute state -> output conversion per chunk + # (left term of low-rank factorization of off-diagonal blocks; C terms) + state_decay_out = torch.exp(A_cumsum) + Y_off_tmp = torch.einsum( + "bclhn,bchpn->bclhp", C, states.to(intermediate_dtype).to(torch.float32) + ) + Y_off = torch.einsum("bclhp,bhcl->bclhp", Y_off_tmp, state_decay_out) + + # Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks) + # Y: b c l h p -> b (c l) h p + Y = (Y_diag + Y_off).reshape( + Y_diag.shape[0], -1, Y_diag.shape[3], Y_diag.shape[4] + ) # b (c l) h p + return Y, final_state diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/blackwell/mamba2_ssd/mamba2_ssd_tile_scheduler.py b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/blackwell/mamba2_ssd/mamba2_ssd_tile_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..544b477228132922c38c9855d44ad85d94f2effb --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/blackwell/mamba2_ssd/mamba2_ssd_tile_scheduler.py @@ -0,0 +1,200 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: BSD-3-Clause + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +from typing import Tuple + +from cutlass.cutlass_dsl import ( + Boolean, + Integer, + Int32, + min, + extract_mlir_values, + new_from_mlir_values, + dsl_user_op, +) +from cutlass._mlir import ir +import cutlass.cute as cute +from cutlass.utils import WorkTileInfo + + +class Mamba2SSDTileSchedulerParams: + def __init__( + self, + problem_shape_ntiles: int, + eh: int, + ngroup_ratio: int, + *, + loc=None, + ip=None, + ): + self.problem_shape_ntiles = problem_shape_ntiles + self.eh = eh + self.ngroup_ratio = ngroup_ratio + self._loc = loc + + def __extract_mlir_values__(self): + values, self._values_pos = [], [] + for obj in [self.problem_shape_ntiles, self.eh, self.ngroup_ratio]: + obj_values = extract_mlir_values(obj) + values += obj_values + self._values_pos.append(len(obj_values)) + return values + + def __new_from_mlir_values__(self, values): + obj_list = [] + for obj, n_items in zip( + [self.problem_shape_ntiles, self.eh, self.ngroup_ratio], self._values_pos + ): + obj_list.append(new_from_mlir_values(obj, values[:n_items])) + values = values[n_items:] + return Mamba2SSDTileSchedulerParams(*(tuple(obj_list)), loc=self._loc) + + @dsl_user_op + def get_grid_shape( + self, max_active_clusters: Int32, *, loc=None, ip=None + ) -> Tuple[Integer, Integer, Integer]: + return (min(self.problem_shape_ntiles, max_active_clusters), 1, 1) + + +class Mamba2SSDTileScheduler: + def __init__( + self, + params: Mamba2SSDTileSchedulerParams, + num_persistent_ctas: Int32, + current_work_linear_idx: Int32, + num_tiles_executed: Int32, + ): + self.params = params + self.num_persistent_ctas = num_persistent_ctas + self._current_work_linear_idx = current_work_linear_idx + self._num_tiles_executed = num_tiles_executed + + def __extract_mlir_values__(self) -> list[ir.Value]: + values = extract_mlir_values(self.num_persistent_ctas) + values.extend(extract_mlir_values(self._current_work_linear_idx)) + values.extend(extract_mlir_values(self._num_tiles_executed)) + return values + + def 
__new_from_mlir_values__( + self, values: list[ir.Value] + ) -> "Mamba2SSDTileScheduler": + assert len(values) == 3 + new_num_persistent_ctas = new_from_mlir_values( + self.num_persistent_ctas, [values[0]] + ) + new_current_work_linear_idx = new_from_mlir_values( + self._current_work_linear_idx, [values[1]] + ) + new_num_tiles_executed = new_from_mlir_values( + self._num_tiles_executed, [values[2]] + ) + return Mamba2SSDTileScheduler( + self.params, + new_num_persistent_ctas, + new_current_work_linear_idx, + new_num_tiles_executed, + ) + + # called by host + @dsl_user_op + @staticmethod + def create( + params: Mamba2SSDTileSchedulerParams, + block_idx: Tuple[Integer, Integer, Integer], + grid_dim: Tuple[Integer, Integer, Integer], + *, + loc=None, + ip=None, + ): + params = params + + # Calculate the number of persistent clusters by dividing the total grid size + # by the number of CTAs per cluster + num_persistent_ctas = Int32(cute.size(grid_dim, loc=loc, ip=ip)) + + bidx, bidy, bidz = block_idx + + # Initialize workload index equals to the cluster index in the grid + current_work_linear_idx = Int32(bidx) + + # Initialize number of tiles executed to zero + num_tiles_executed = Int32(0) + return Mamba2SSDTileScheduler( + params, + num_persistent_ctas, + current_work_linear_idx, + num_tiles_executed, + ) + + # called by host + @staticmethod + def get_grid_shape( + params: Mamba2SSDTileSchedulerParams, + max_active_clusters: Int32, + *, + loc=None, + ip=None, + ) -> Tuple[Integer, Integer, Integer]: + return params.get_grid_shape(max_active_clusters, loc=loc, ip=ip) + + # private method + def _get_current_work_for_linear_idx( + self, current_work_linear_idx: Int32, *, loc=None, ip=None + ) -> WorkTileInfo: + is_valid = current_work_linear_idx < cute.size( + self.params.problem_shape_ntiles, loc=loc, ip=ip + ) + + eh_idx = current_work_linear_idx % self.params.eh + b_idx = current_work_linear_idx // self.params.eh + g_idx = eh_idx // self.params.ngroup_ratio + # cur_tile_coord is (b_idx, eh_idx, g_idx) + cur_tile_coord = tuple(Int32(x) for x in (b_idx, eh_idx, g_idx)) + + return WorkTileInfo(cur_tile_coord, is_valid) + + @dsl_user_op + def get_current_work(self, *, loc=None, ip=None) -> WorkTileInfo: + return self._get_current_work_for_linear_idx( + self._current_work_linear_idx, loc=loc, ip=ip + ) + + @dsl_user_op + def initial_work_tile_info(self, *, loc=None, ip=None) -> WorkTileInfo: + return self.get_current_work(loc=loc, ip=ip) + + @dsl_user_op + def advance_to_next_work(self, *, advance_count: int = 1, loc=None, ip=None): + self._current_work_linear_idx += Int32(advance_count) * Int32( + self.num_persistent_ctas + ) + self._num_tiles_executed += Int32(1) + + @property + def num_tiles_executed(self) -> Int32: + return self._num_tiles_executed diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/cute/ffi/jit_argument.py b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/cute/ffi/jit_argument.py new file mode 100644 index 0000000000000000000000000000000000000000..acdb42efb997a943fbf3c442a985eda1bfa20e24 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/cute/ffi/jit_argument.py @@ -0,0 +1,305 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: BSD-3-Clause + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Example of accessing POD (Plain Old Data) from C or other languages via LLVM operations. + +This example demonstrates a basic approach to building customized interfaces as C-structures between user code +and JIT compiled functions. It provides a minimal-cost solution for calling JIT functions +and can be used to build AOT (Ahead-of-Time) launchers for JIT compiled functions. + +The C-structure is defined as: + +.. code-block:: c + + struct Tensor { + void *ptr; // Pointer to tensor data + int32_t shape[3]; // Tensor dimensions + int32_t strides[3]; // Memory strides for each dimension + }; + +The example defines Tensor and TensorValue classes that wrap C structs for view of a tensor with its data pointer, +shape, and strides, enabling efficient data passing between different language boundaries. + +.. note:: + Future development may include automated code generation flows. +""" + +import cutlass +import cutlass.cute as cute + +from cutlass._mlir import ir +from cutlass._mlir.dialects import llvm +import cutlass._mlir.extras.types as T + + +class ExampleTensorValue(ir.Value): + """A wrapper class for tensor values in MLIR. + + This class extends ir.Value to provide convenient access to tensor data pointer, + shape, and strides through MLIR operations. + + :type: ir.Value + """ + + def __init__(self, v): + """Initialize a new TensorValue. + + :param v: The underlying MLIR value to wrap + :type v: ir.Value + """ + super().__init__(v) + + @property + def data_ptr(self, *, loc=None, ip=None): + """Get the data pointer from the tensor value. + + Extracts the data pointer (first field) from the LLVM struct value. 
+ + :param loc: Optional location information for MLIR operations + :type loc: Optional[ir.Location] + :param ip: Optional insertion point for MLIR operations + :type ip: Optional[ir.InsertionPoint] + :return: An integer value representing the data pointer + :rtype: ir.Value + """ + # Extract the data pointer from the LLVM struct value + # The data pointer is the first field (index 0) in the struct + + # Use llvm.extractvalue to get the pointer field from the struct + ptr_val = llvm.extractvalue( + llvm.PointerType.get(), + self, + [0], # Extract the first field (index 0) + loc=loc, + ip=ip, + ) + + return cute.make_ptr(cutlass.Float32, ptr_val) + + @property + def shape(self): + """Get the shape of the tensor. + + Extracts the shape (second field) from the LLVM struct value. + + :return: A tuple of integers representing the tensor dimensions + :rtype: tuple[ir.Value, ...] + """ + i32_type = ir.IntegerType.get_signless(32) + # Extract the shape field from the LLVM struct value + # The shape is the second field (index 1) in the struct + shape_val = llvm.extractvalue( + llvm.StructType.get_literal([i32_type] * 3), + self, + [1], # Extract the second field (index 1) + ) + + # Extract each dimension from the shape struct + return tuple(llvm.extractvalue(i32_type, shape_val, [i]) for i in range(3)) + + @property + def stride(self): + """Get the strides of the tensor. + + Extracts the strides (third field) from the LLVM struct value. + + :return: A tuple of integers representing the tensor strides + :rtype: tuple[ir.Value, ...] + """ + i32_type = ir.IntegerType.get_signless(32) + # Extract the strides field from the LLVM struct value + # The strides are the third field (index 2) in the struct + strides_val = llvm.extractvalue( + llvm.StructType.get_literal([i32_type] * 3), + self, + [2], # Extract the third field (index 2) + ) + + # Extract each dimension from the strides struct + return tuple(llvm.extractvalue(i32_type, strides_val, [i]) for i in range(3)) + + +class ExampleTensor: + """A class representing a tensor with its data pointer, shape, and strides. + + This class provides a Python interface to create and manipulate tensor structures + that can be passed to CUTE JIT compiled functions. + + :ivar _c_struct_p: The C struct pointer for the tensor + :ivar _rank: The number of dimensions in the tensor + """ + + def __init__(self, c_struct_p, rank): + """Initialize a new Tensor. + + :param c_struct_p: The C struct pointer for the tensor + :type c_struct_p: int + :param rank: The number of dimensions in the tensor + :type rank: int + """ + self._c_struct_p = c_struct_p + self._rank = rank + + def __get_mlir_types__(self): + """Get the MLIR types for this tensor. + + Creates an LLVM structure type representing a C-structure with: + + .. 
code-block:: c + + struct Tensor { + void *ptr; + int32_t shape[3]; + int32_t strides[3]; + }; + + :return: A list containing the MLIR struct type + :rtype: list[llvm.StructType] + + Create an LLVM structure type that represents a C-structure like: + """ + + # Get the number of dimensions from the shape + ndim = self._rank + + # Create the pointer type (void*) + ptr_type = llvm.PointerType.get() + + # Create array types for shape and strides (int32_t[ndim]) + int32_type = ir.IntegerType.get_signless(32) + shape_type = llvm.StructType.get_literal([int32_type] * ndim) + strides_type = llvm.StructType.get_literal([int32_type] * ndim) + + # Create the structure type + struct_type = llvm.StructType.get_literal([ptr_type, shape_type, strides_type]) + + return [struct_type] + + def __new_from_mlir_values__(self, values): + """Create a new TensorValue from MLIR values. + + :param values: A list of MLIR values + :type values: list[ir.Value] + :return: A new TensorValue instance + :rtype: TensorValue + """ + return ExampleTensorValue(values[0]) + + def __c_pointers__(self): + """Get the C pointers for this tensor. + + :return: A list containing the C struct pointer + :rtype: list[int] + """ + return [self._c_struct_p] + + +@cute.jit +def foo(tensor): + """Example JIT function that prints tensor information. + + :param tensor: A Tensor instance to print information about + :type tensor: Tensor + """ + cute.printf("data_ptr: {}", tensor.data_ptr) + cute.printf("shape: {}", tensor.shape) + cute.printf("stride: {}", tensor.stride) + + mA = cute.make_tensor( + tensor.data_ptr, cute.make_layout(tensor.shape, stride=tensor.stride) + ) + cute.print_tensor(mA) + + +import sys +import os +import subprocess +import shutil +import tempfile +import torch + + +def run_test(tmpdir=None): + # Skip cleanup if user provides tmpdir + cleanup = tmpdir is None + # Initialize temporary build directory + tmpdir = tmpdir or tempfile.mkdtemp() + + try: + current_dir = os.path.dirname(os.path.abspath(__file__)) + + subprocess.run(["cmake", "-B", tmpdir, current_dir], check=True) + subprocess.run(["cmake", "--build", tmpdir], check=True) + + sys.path.append(tmpdir) + + from tensor import make_tensor, pycapsule_get_pointer + + # Mock test tensor and corresponding C structure for this example + # In production, this may come from external library + x = torch.arange(2 * 8 * 4).to(torch.float32).reshape(2, 8, 4) + c_struct = make_tensor(x.data_ptr(), x.shape, x.stride()) + c_struct_p = pycapsule_get_pointer(c_struct) + + # Initialize tensor wrapper and compile test function + tensor = ExampleTensor(c_struct_p, len(x.shape)) + compiled_func = cute.compile(foo, tensor) + + # Benchmark pointer access performance + from time import time + + start = time() + # Measure performance of critical path pointer access + # get C pointers is on critical path to call JIT compiled function + for _ in range(1000): + tensor.__c_pointers__() + end = time() + print(f"__c_pointers__: {(end - start) * 1000} us") + + # Execute compiled function + compiled_func(tensor) + except Exception as e: + print(e) + finally: + if cleanup: + # Clean up the temporary directory + shutil.rmtree(tmpdir) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser( + description="Set temporary directory for building C modules" + ) + parser.add_argument( + "--tmp-dir", type=str, help="Temporary directory path for building C modules" + ) + args = parser.parse_args() + + run_test(args.tmp_dir) diff --git 
a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/hopper/dense_gemm.py b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/hopper/dense_gemm.py new file mode 100644 index 0000000000000000000000000000000000000000..c59ace02d6bf9e9d96e1fa04a62c745386998e44 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/examples/python/CuTeDSL/hopper/dense_gemm.py @@ -0,0 +1,1595 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import argparse +from typing import Tuple, Type +import math +import cuda.bindings.driver as cuda + +import torch + +import cutlass +import cutlass.cute as cute +import cutlass.cute.testing as testing +import cutlass.utils as utils +import cutlass.pipeline as pipeline +import cutlass.torch as cutlass_torch +from cutlass.cute.runtime import from_dlpack +import cutlass.utils.hopper_helpers as sm90_utils + +""" +A high-performance batched dense GEMM (C = A * B) example for the NVIDIA Hopper architecture +using CuTe DSL. +- Matrix A is MxKxL, L is batch dimension, A can be row-major("K") or column-major("M") +- Matrix B is NxKxL, L is batch dimension, B can be row-major("N") or column-major("K") +- Matrix C is MxNxL, L is batch dimension, C can be row-major("N") or column-major("M") + +This GEMM kernel supports the following features: + - Utilizes Tensor Memory Access (TMA) for efficient memory operations + - Utilizes Hopper's WGMMA for matrix multiply-accumulate (MMA) operations + - Implements TMA multicast with cluster to reduce L2 memory traffic + - Supports multi-stage pipeline to overlap computation and memory access + +This GEMM works as follows: +1. Load A and B matrices from global memory (GMEM) to shared memory (SMEM) using TMA operations. +2. Perform matrix multiply-accumulate (MMA) operations using WGMMA instruction. +3. 
Store results from registers (RMEM) to shared memory (SMEM), then to global memory (GMEM) with TMA operations. + +Hopper WGMMA instructions operate as follows: +- Read matrix A from SMEM +- Read matrix B from SMEM +- Perform MMA operation and store the result in Accumulator(register) + +To run this example: + +.. code-block:: bash + + python examples/hopper/dense_gemm.py \ + --mnkl 8192,8192,8192,1 --tile_shape_mn 128,256 \ + --cluster_shape_mn 1,1 --a_dtype Float16 --b_dtype Float16 \ + --c_dtype Float16 --acc_dtype Float32 \ + --a_major k --b_major k --c_major n + +The above example command compute batched gemm with M=8192, N=8192, K=8192, +batch_count=1. The Hopper WGMMA tile shape is 128x256x64 and the cluster shape +is (1,1). The input, mma accumulator and output data type are set as fp16, fp32 +and fp16, respectively. + +To collect performance with NCU profiler: + +.. code-block:: bash + + ncu python examples/hopper/dense_gemm.py \ + --mnkl 8192,8192,8192,1 --tile_shape_mn 128,256 \ + --cluster_shape_mn 1,1 --a_dtype Float16 --b_dtype Float16 \ + --c_dtype Float16 --acc_dtype Float32 \ + --a_major k --b_major k --c_major n + +Constraints: +* Supported input data types: fp16, fp8 (e4m3fn, e5m2) +* For fp16 types, A and B must have the same data type +* For fp8 types, A and B can have different types (e4m3fn or e5m2) but both must be 8-bit +* Fp8 types only support k-major layout +* CTA tile shape M must be 64/128 +* CTA tile shape N must be 64/128/256 +* Cluster shape M/N must be positive and power of 2, total cluster size <= 4 +* The contiguous dimension of A/B/C tensors must be at least 16 bytes aligned, + i.e, number of elements is a multiple of 8, 16 for Float16, and Float8, respectively. +""" + + +# ///////////////////////////////////////////////////////////////////////////// +# Helpers to parse args +# ///////////////////////////////////////////////////////////////////////////// +def parse_comma_separated_ints(s: str): + try: + return tuple([int(x.strip()) for x in s.split(",")]) + except ValueError: + raise argparse.ArgumentTypeError( + "Invalid format. Expected comma-separated integers." 
+ ) + + +def parse_arguments() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Example of MxNxKxL GEMM on Hopper.") + + parser.add_argument( + "--mnkl", + type=parse_comma_separated_ints, + default=(4096, 4096, 4096, 1), + help="mnkl dimensions (comma-separated)", + ) + parser.add_argument( + "--tile_shape_mn", + type=parse_comma_separated_ints, + choices=[(128, 128), (128, 256), (128, 64), (64, 64)], + default=(128, 128), + help="Cta tile shape (comma-separated)", + ) + parser.add_argument( + "--cluster_shape_mn", + type=parse_comma_separated_ints, + choices=[(1, 1), (2, 1), (1, 2), (2, 2)], + default=(1, 1), + help="Cluster shape (comma-separated)", + ) + parser.add_argument( + "--a_dtype", + type=cutlass.dtype, + default=cutlass.Float16, + ) + parser.add_argument( + "--b_dtype", + type=cutlass.dtype, + default=cutlass.Float16, + ) + parser.add_argument( + "--c_dtype", + type=cutlass.dtype, + default=cutlass.Float16, + ) + parser.add_argument( + "--acc_dtype", + type=cutlass.dtype, + default=cutlass.Float32, + ) + parser.add_argument("--a_major", choices=["k", "m"], type=str, default="k") + parser.add_argument("--b_major", choices=["k", "n"], type=str, default="k") + parser.add_argument("--c_major", choices=["n", "m"], type=str, default="n") + parser.add_argument( + "--tolerance", type=float, default=1e-01, help="Tolerance for validation" + ) + parser.add_argument( + "--warmup_iterations", type=int, default=0, help="Warmup iterations" + ) + parser.add_argument( + "--iterations", + type=int, + default=1, + help="Number of iterations to run the kernel", + ) + parser.add_argument( + "--skip_ref_check", action="store_true", help="Skip reference checking" + ) + parser.add_argument( + "--use_cold_l2", + action="store_true", + default=False, + help="Use circular buffer tensor sets to ensure L2 cold cache", + ) + + args = parser.parse_args() + + if len(args.mnkl) != 4: + parser.error("--mnkl must contain exactly 4 values") + if len(args.tile_shape_mn) != 2: + parser.error("--tile_shape_mn must contain exactly 2 values") + if len(args.cluster_shape_mn) != 2: + parser.error("--cluster_shape_mn must contain exactly 2 values") + + return args + + +# ///////////////////////////////////////////////////////////////////////////// +# Host setup and device kernel launch +# ///////////////////////////////////////////////////////////////////////////// + + +class HopperWgmmaGemmKernel: + """ + This class implements batched matrix multiplication (C = A x B) with support for various data types + and architectural features specific to Hopper GPUs. 
+ + :param acc_dtype: Data type for accumulation during computation + :type acc_dtype: type[cutlass.Numeric] + :param tile_shape_mn: Shape of the CTA tile (M,N) + :type tile_shape_mn: Tuple[int, int] + :param cluster_shape_mn: Cluster dimensions (M,N) for parallel processing + :type cluster_shape_mn: Tuple[int, int] + + :note: Data type requirements: + - For 16-bit types: A and B must have the same data type + - For 8-bit types: A and B can have different types (Float8E4M3FN/Float8E5M2) as long as both are 8-bit + - Float8 types only support k-major layout + + :note: Supported data types: + - Float16 + - Float8E4M3FN/Float8E5M2 + + :note: Supported accumulation types: + - Float32 (for all floating point inputs) + + :note: Constraints: + - CTA tile M must be 64/128 + - CTA tile N must be 64/128/256 + - CTA tile K must be 64 + - Cluster shape M/N must be positive and power of 2, total cluster size <= 4 + + Example: + >>> gemm = HopperWgmmaGemmKernel( + ... acc_dtype=cutlass.Float32, + ... tile_shape_mn=(128, 256), + ... cluster_shape_mn=(1, 1) + ... ) + >>> gemm(a_tensor, b_tensor, c_tensor, stream) + """ + + def __init__( + self, + acc_dtype: type[cutlass.Numeric], + tile_shape_mn: tuple[int, int], + cluster_shape_mn: tuple[int, int], + ): + """ + Initializes the configuration for a Hopper dense GEMM kernel. + + This configuration includes data types for operands, tile shape, cluster configuration, + and thread layout. + + :param acc_dtype: Data type for accumulation during computation + :type acc_dtype: type[cutlass.Numeric] + :param tile_shape_mn: Shape of the CTA tile (M,N) + :type tile_shape_mn: Tuple[int, int] + :param cluster_shape_mn: Cluster dimensions (M,N) for parallel processing + :type cluster_shape_mn: Tuple[int, int] + """ + + self.acc_dtype = acc_dtype + + self.cluster_shape_mn = cluster_shape_mn + self.mma_inst_shape_mn = None + # K dimension is deferred in _setup_attributes + self.tile_shape_mnk = (*tile_shape_mn, 1) + # For large tile size, using two warp groups is preferred because using only one warp + # group may result in register spill + self.atom_layout_mnk = ( + (2, 1, 1) + if self.tile_shape_mnk[0] > 64 and self.tile_shape_mnk[1] > 128 + else (1, 1, 1) + ) + self.num_mcast_ctas_a = None + self.num_mcast_ctas_b = None + self.is_a_mcast = False + self.is_b_mcast = False + self.tiled_mma = None + + self.occupancy = 1 + self.mma_warp_groups = math.prod(self.atom_layout_mnk) + self.num_threads_per_warp_group = 128 + self.threads_per_cta = self.mma_warp_groups * self.num_threads_per_warp_group + self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_90") + + self.ab_stage = None + self.epi_stage = None + + self.a_smem_layout_staged = None + self.b_smem_layout_staged = None + self.epi_smem_layout_staged = None + self.epi_tile = None + + self.shared_storage = None + self.buffer_align_bytes = 1024 + + def _setup_attributes(self): + """Set up configurations that are dependent on GEMM inputs + + This method configures various attributes based on the input tensor properties + (data types, leading dimensions) and kernel settings: + - Configuring tiled MMA + - Computing MMA/cluster/tile shapes + - Computing cluster layout + - Computing multicast CTAs for A/B + - Computing epilogue subtile + - Setting up A/B/C stage counts in shared memory + - Computing A/B/C shared memory layout + """ + + # check the cta tile shape + if self.tile_shape_mnk[0] not in [64, 128]: + raise ValueError("CTA tile shape M must be 64/128") + if self.tile_shape_mnk[1] not in [64, 128, 256]: + raise 
ValueError("CTA tile shape N must be 64/128/256") + + self.tiled_mma = sm90_utils.make_trivial_tiled_mma( + self.a_dtype, + self.b_dtype, + self.a_layout.sm90_mma_major_mode(), + self.b_layout.sm90_mma_major_mode(), + self.acc_dtype, + self.atom_layout_mnk, + tiler_mn=(64, self.tile_shape_mnk[1]), + ) + mma_inst_shape_k = cute.size(self.tiled_mma.shape_mnk, mode=[2]) + mma_inst_tile_k = 4 + self.tile_shape_mnk = ( + self.tile_shape_mnk[0], + self.tile_shape_mnk[1], + mma_inst_shape_k * mma_inst_tile_k, + ) + + self.cta_layout_mnk = cute.make_layout((*self.cluster_shape_mn, 1)) + self.num_mcast_ctas_a = self.cluster_shape_mn[1] + self.num_mcast_ctas_b = self.cluster_shape_mn[0] + self.is_a_mcast = self.num_mcast_ctas_a > 1 + self.is_b_mcast = self.num_mcast_ctas_b > 1 + + is_cooperative = self.atom_layout_mnk == (2, 1, 1) + self.epi_tile = self._sm90_compute_tile_shape_or_override( + self.tile_shape_mnk, self.c_dtype, is_cooperative=is_cooperative + ) + + # Compute stage before compute smem layout + self.ab_stage, self.epi_stage = self._compute_stages( + self.tile_shape_mnk, + self.a_dtype, + self.b_dtype, + self.smem_capacity, + self.occupancy, + ) + + ( + self.a_smem_layout_staged, + self.b_smem_layout_staged, + self.epi_smem_layout_staged, + ) = self._make_smem_layouts( + self.tile_shape_mnk, + self.epi_tile, + self.a_dtype, + self.a_layout, + self.b_dtype, + self.b_layout, + self.ab_stage, + self.c_dtype, + self.c_layout, + self.epi_stage, + ) + + @cute.jit + def __call__( + self, + a: cute.Tensor, + b: cute.Tensor, + c: cute.Tensor, + stream: cuda.CUstream, + ): + """Execute the GEMM operation in steps: + - Setup static attributes + - Setup TMA load/store atoms and tensors + - Compute grid size + - Define shared storage for kernel + - Launch the kernel synchronously + + :param a: Input tensor A + :type a: cute.Tensor + :param b: Input tensor B + :type b: cute.Tensor + :param c: Output tensor C + :type c: cute.Tensor + :param stream: CUDA stream for asynchronous execution + :type stream: cuda.CUstream + """ + + # setup static attributes before smem/grid/tma computation + self.a_dtype = a.element_type + self.b_dtype = b.element_type + self.c_dtype = c.element_type + self.a_layout = utils.LayoutEnum.from_tensor(a) + self.b_layout = utils.LayoutEnum.from_tensor(b) + self.c_layout = utils.LayoutEnum.from_tensor(c) + + if cutlass.const_expr( + self.a_dtype.width == 16 and self.a_dtype != self.b_dtype + ): + raise TypeError(f"Type mismatch: {self.a_dtype} != {self.b_dtype}") + if cutlass.const_expr(self.a_dtype.width != self.b_dtype.width): + raise TypeError( + f"Type width mismatch: {self.a_dtype.width} != {self.b_dtype.width}" + ) + if cutlass.const_expr(self.a_dtype.width != 16 and self.a_dtype.width != 8): + raise TypeError(f"a_dtype should be float16 or float8") + + self._setup_attributes() + + tma_atom_a, tma_tensor_a = self._make_tma_atoms_and_tensors( + a, + self.a_smem_layout_staged, + (self.tile_shape_mnk[0], self.tile_shape_mnk[2]), + self.cluster_shape_mn[1], + ) + + tma_atom_b, tma_tensor_b = self._make_tma_atoms_and_tensors( + b, + self.b_smem_layout_staged, + (self.tile_shape_mnk[1], self.tile_shape_mnk[2]), + self.cluster_shape_mn[0], + ) + + tma_atom_c, tma_tensor_c = self._make_tma_store_atoms_and_tensors( + c, + self.epi_smem_layout_staged, + self.epi_tile, + ) + + grid = self._compute_grid(c, self.tile_shape_mnk, self.cluster_shape_mn) + + @cute.struct + class SharedStorage: + mainloop_pipeline_array_ptr: cute.struct.MemRange[ + cutlass.Int64, self.ab_stage * 2 + ] + sA: 
cute.struct.Align[ + cute.struct.MemRange[ + self.a_dtype, cute.cosize(self.a_smem_layout_staged) + ], + self.buffer_align_bytes, + ] + sB: cute.struct.Align[ + cute.struct.MemRange[ + self.b_dtype, cute.cosize(self.b_smem_layout_staged) + ], + self.buffer_align_bytes, + ] + + self.shared_storage = SharedStorage + + # Launch the kernel synchronously + self.kernel( + tma_atom_a, + tma_tensor_a, + tma_atom_b, + tma_tensor_b, + tma_atom_c, + tma_tensor_c, + self.tiled_mma, + self.cta_layout_mnk, + self.a_smem_layout_staged, + self.b_smem_layout_staged, + self.epi_smem_layout_staged, + ).launch( + grid=grid, + block=[self.threads_per_cta, 1, 1], + cluster=(*self.cluster_shape_mn, 1), + stream=stream, + ) + return + + # GPU device kernel + @cute.kernel + def kernel( + self, + tma_atom_a: cute.CopyAtom, + mA_mkl: cute.Tensor, + tma_atom_b: cute.CopyAtom, + mB_nkl: cute.Tensor, + tma_atom_c: cute.CopyAtom, + mC_mnl: cute.Tensor, + tiled_mma: cute.TiledMma, + cta_layout_mnk: cute.Layout, + a_smem_layout_staged: cute.ComposedLayout, + b_smem_layout_staged: cute.ComposedLayout, + epi_smem_layout_staged: cute.ComposedLayout, + ): + """ + GPU device kernel performing the batched GEMM computation. + + :param tma_atom_a: TMA copy atom for A tensor + :type tma_atom_a: cute.CopyAtom + :param mA_mkl: Input tensor A + :type mA_mkl: cute.Tensor + :param tma_atom_b: TMA copy atom for B tensor + :type tma_atom_b: cute.CopyAtom + :param mB_nkl: Input tensor B + :type mB_nkl: cute.Tensor + :param tma_atom_c: TMA copy atom for C tensor + :type tma_atom_c: cute.CopyAtom + :param mC_mnl: Output tensor C + :type mC_mnl: cute.Tensor + :param tiled_mma: Tiled MMA object + :type tiled_mma: cute.TiledMma + :param cta_layout_mnk: CTA layout + :type cta_layout_mnk: cute.Layout + :param a_smem_layout_staged: Shared memory layout for A + :type a_smem_layout_staged: cute.ComposedLayout + :param b_smem_layout_staged: Shared memory layout for B + :type b_smem_layout_staged: cute.ComposedLayout + :param epi_smem_layout_staged: Shared memory layout for epilogue + :type epi_smem_layout_staged: cute.ComposedLayout + """ + + warp_idx = cute.arch.warp_idx() + warp_idx = cute.arch.make_warp_uniform(warp_idx) + + # ///////////////////////////////////////////////////////////////////////////// + # Prefetch Tma desc + # ///////////////////////////////////////////////////////////////////////////// + if warp_idx == 0: + cute.nvgpu.cpasync.prefetch_descriptor(tma_atom_a) + cute.nvgpu.cpasync.prefetch_descriptor(tma_atom_b) + + # /////////////////////////////////////////////////////////////////////////////// + # Get cta/warp/thread idx + # /////////////////////////////////////////////////////////////////////////////// + bidx, bidy, bidz = cute.arch.block_idx() + tidx, _, _ = cute.arch.thread_idx() + + cidx, cidy, _ = cute.arch.cluster_idx() + cdimx, cdimy, _ = cute.arch.cluster_dim() + cluster_id = cidx + cdimx * cidy + + # CTA Swizzle to promote L2 data reuse + group_size_m = 8 + s_shape = ( + (group_size_m, cdimx // group_size_m), + cdimy, + ) + s_stride = ((1, cdimy * group_size_m), group_size_m) + s_layout = cute.make_layout(s_shape, stride=s_stride) + num_reg_cids = cute.size(s_shape) + cid_m, cid_n = s_layout.get_flat_coord(cluster_id % num_reg_cids) + + # Deal with the tail part + if cluster_id >= num_reg_cids: + tail_size_m = cdimx % group_size_m + tail_layout = cute.make_layout( + (tail_size_m, cdimy), stride=(1, tail_size_m) + ) + tail_cid = cluster_id - num_reg_cids + tail_cid_m, tail_cid_n = tail_layout.get_flat_coord(tail_cid) + 
cid_m = cute.size(s_shape, mode=[0]) + tail_cid_m + cid_n = tail_cid_n + + # Get the pid from cluster id + bidx_in_cluster = cute.arch.block_in_cluster_idx() + pid_m = cid_m * self.cluster_shape_mn[0] + bidx_in_cluster[0] + pid_n = cid_n * self.cluster_shape_mn[1] + bidx_in_cluster[1] + + tile_coord_mnkl = (pid_m, pid_n, None, bidz) + cta_rank_in_cluster = cute.arch.make_warp_uniform( + cute.arch.block_idx_in_cluster() + ) + cluster_coord_mnk = cta_layout_mnk.get_flat_coord(cta_rank_in_cluster) + + # /////////////////////////////////////////////////////////////////////////////// + # Get mcast mask + # /////////////////////////////////////////////////////////////////////////////// + a_mcast_mask = cute.make_layout_image_mask( + cta_layout_mnk, cluster_coord_mnk, mode=1 + ) + b_mcast_mask = cute.make_layout_image_mask( + cta_layout_mnk, cluster_coord_mnk, mode=0 + ) + + a_mcast_mask = a_mcast_mask if self.is_a_mcast else 0 + b_mcast_mask = b_mcast_mask if self.is_b_mcast else 0 + a_smem_layout = cute.slice_(a_smem_layout_staged, (None, None, 0)) + b_smem_layout = cute.slice_(b_smem_layout_staged, (None, None, 0)) + tma_copy_bytes = cute.size_in_bytes( + self.a_dtype, a_smem_layout + ) + cute.size_in_bytes(self.b_dtype, b_smem_layout) + + # ///////////////////////////////////////////////////////////////////////////// + # Alloc and init AB full/empty + ACC full mbar (pipeline) + # ///////////////////////////////////////////////////////////////////////////// + smem = cutlass.utils.SmemAllocator() + storage = smem.allocate(self.shared_storage) + + # mbar arrays + mainloop_pipeline_array_ptr = storage.mainloop_pipeline_array_ptr.data_ptr() + + # Threads/warps participating in this pipeline + mainloop_pipeline_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread + ) + # Each warp will constribute to the arrive count with the number of mcast size + mcast_size = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1 + num_warps = self.threads_per_cta // 32 + consumer_arrive_cnt = mcast_size * num_warps + mainloop_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, consumer_arrive_cnt + ) + + cta_layout_vmnk = cute.make_layout((1, *cta_layout_mnk.shape)) + mainloop_pipeline = pipeline.PipelineTmaAsync.create( + barrier_storage=mainloop_pipeline_array_ptr, + num_stages=self.ab_stage, + producer_group=mainloop_pipeline_producer_group, + consumer_group=mainloop_pipeline_consumer_group, + tx_count=tma_copy_bytes, + cta_layout_vmnk=cta_layout_vmnk, + ) + + # Cluster arrive after barrier init + if cute.size(self.cluster_shape_mn) > 1: + cute.arch.cluster_arrive_relaxed() + + # /////////////////////////////////////////////////////////////////////////////// + # Generate smem tensor A/B + # /////////////////////////////////////////////////////////////////////////////// + sA = storage.sA.get_tensor( + a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner + ) + sB = storage.sB.get_tensor( + b_smem_layout_staged.outer, swizzle=b_smem_layout_staged.inner + ) + sC_ptr = cute.recast_ptr( + sA.iterator, epi_smem_layout_staged.inner, dtype=self.c_dtype + ) + sC = cute.make_tensor(sC_ptr, epi_smem_layout_staged.outer) + + # /////////////////////////////////////////////////////////////////////////////// + # Local_tile partition global tensors + # /////////////////////////////////////////////////////////////////////////////// + # (bM, bK, RestK) + gA_mkl = cute.local_tile( + mA_mkl, self.tile_shape_mnk, tile_coord_mnkl, proj=(1, None, 1) + ) + # (bN, bK, RestK) + gB_nkl = 
cute.local_tile( + mB_nkl, self.tile_shape_mnk, tile_coord_mnkl, proj=(None, 1, 1) + ) + # (bM, bN) + gC_mnl = cute.local_tile( + mC_mnl, self.tile_shape_mnk, tile_coord_mnkl, proj=(1, 1, None) + ) + + # ////////////////////////////////////////////////////////////////////////////// + # Partition global tensor for TiledMMA_A/B/C + # ////////////////////////////////////////////////////////////////////////////// + warp_group_idx = cute.arch.make_warp_uniform( + tidx // self.num_threads_per_warp_group + ) + warp_group_thread_layout = cute.make_layout( + self.mma_warp_groups, stride=self.num_threads_per_warp_group + ) + thr_mma = tiled_mma.get_slice(warp_group_thread_layout(warp_group_idx)) + + tCgC = thr_mma.partition_C(gC_mnl) + + # ////////////////////////////////////////////////////////////////////////////// + # Partition shared tensor for TMA load A/B + # ////////////////////////////////////////////////////////////////////////////// + # TMA load A partition_S/D + a_cta_layout = cute.make_layout(cute.slice_(cta_layout_mnk, (0, None, 0)).shape) + a_cta_crd = cluster_coord_mnk[1] + sA_for_tma_partition = cute.group_modes(sA, 0, 2) + gA_for_tma_partition = cute.group_modes(gA_mkl, 0, 2) + tAsA, tAgA_mkl = cute.nvgpu.cpasync.tma_partition( + tma_atom_a, + a_cta_crd, + a_cta_layout, + sA_for_tma_partition, + gA_for_tma_partition, + ) + + # TMA load B partition_S/D + b_cta_layout = cute.make_layout(cute.slice_(cta_layout_mnk, (None, 0, 0)).shape) + b_cta_crd = cluster_coord_mnk[0] + sB_for_tma_partition = cute.group_modes(sB, 0, 2) + gB_for_tma_partition = cute.group_modes(gB_nkl, 0, 2) + tBsB, tBgB_nkl = cute.nvgpu.cpasync.tma_partition( + tma_atom_b, + b_cta_crd, + b_cta_layout, + sB_for_tma_partition, + gB_for_tma_partition, + ) + + # ////////////////////////////////////////////////////////////////////////////// + # Make fragments + # ////////////////////////////////////////////////////////////////////////////// + tCsA = thr_mma.partition_A(sA) + tCsB = thr_mma.partition_B(sB) + tCrA = tiled_mma.make_fragment_A(tCsA) + tCrB = tiled_mma.make_fragment_B(tCsB) + + acc_shape = tCgC.shape + accumulators = cute.make_fragment(acc_shape, self.acc_dtype) + + # /////////////////////////////////////////////////////////////////////////////// + # Cluster wait + # /////////////////////////////////////////////////////////////////////////////// + # cluster wait for barrier init + if cute.size(self.cluster_shape_mn) > 1: + cute.arch.cluster_wait() + else: + cute.arch.sync_threads() + # ///////////////////////////////////////////////////////////////////////////// + # Prefetch + # ///////////////////////////////////////////////////////////////////////////// + k_tile_cnt = cute.size(gA_mkl, mode=[2]) + prefetch_k_tile_cnt = cutlass.max(cutlass.min(self.ab_stage, k_tile_cnt), 0) + + mainloop_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.ab_stage + ) + if warp_idx == 0: + # ///////////////////////////////////////////////////////////////////////////// + # Prefetch TMA load + # ///////////////////////////////////////////////////////////////////////////// + for prefetch_idx in cutlass.range(prefetch_k_tile_cnt, unroll=1): + # ///////////////////////////////////////////////////////////////////////////// + # Wait for A/B buffers to be empty before loading into them + # Also sets the transaction barrier for the A/B buffers + # ///////////////////////////////////////////////////////////////////////////// + mainloop_pipeline.producer_acquire(mainloop_producer_state) + # 
///////////////////////////////////////////////////////////////////////////// + # Slice to global/shared memref to current k_tile + # ///////////////////////////////////////////////////////////////////////////// + tAgA_k = tAgA_mkl[(None, mainloop_producer_state.count)] + tAsA_pipe = tAsA[(None, mainloop_producer_state.index)] + + tBgB_k = tBgB_nkl[(None, mainloop_producer_state.count)] + tBsB_pipe = tBsB[(None, mainloop_producer_state.index)] + + # ///////////////////////////////////////////////////////////////////////////// + # TMA load A/B + # ///////////////////////////////////////////////////////////////////////////// + cute.copy( + tma_atom_a, + tAgA_k, + tAsA_pipe, + tma_bar_ptr=mainloop_pipeline.producer_get_barrier( + mainloop_producer_state + ), + mcast_mask=a_mcast_mask, + ) + cute.copy( + tma_atom_b, + tBgB_k, + tBsB_pipe, + tma_bar_ptr=mainloop_pipeline.producer_get_barrier( + mainloop_producer_state + ), + mcast_mask=b_mcast_mask, + ) + # Mainloop pipeline's producer commit is a NOP + mainloop_pipeline.producer_commit(mainloop_producer_state) + mainloop_producer_state.advance() + + # ///////////////////////////////////////////////////////////////////////////// + # Prologue MMAs + # ///////////////////////////////////////////////////////////////////////////// + k_pipe_mmas = 1 + + mainloop_consumer_read_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.ab_stage + ) + mainloop_consumer_release_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.ab_stage + ) + + peek_ab_full_status = cutlass.Boolean(1) + if mainloop_consumer_read_state.count < k_tile_cnt: + peek_ab_full_status = mainloop_pipeline.consumer_try_wait( + mainloop_consumer_read_state + ) + + tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, False) + num_k_blocks = cute.size(tCrA, mode=[2]) + for k_tile in cutlass.range_constexpr(k_pipe_mmas): + # Wait for A/B buffer to be ready + mainloop_pipeline.consumer_wait( + mainloop_consumer_read_state, peek_ab_full_status + ) + + cute.nvgpu.warpgroup.fence() + for k_block_idx in cutlass.range(num_k_blocks, unroll_full=True): + k_block_coord = ( + None, + None, + k_block_idx, + mainloop_consumer_read_state.index, + ) + tCrA_1phase = tCrA[k_block_coord] + tCrB_1phase = tCrB[k_block_coord] + + cute.gemm( + tiled_mma, + accumulators, + tCrA_1phase, + tCrB_1phase, + accumulators, + ) + tiled_mma.set(cute.nvgpu.warpgroup.Field.ACCUMULATE, True) + + cute.nvgpu.warpgroup.commit_group() + mainloop_consumer_read_state.advance() + peek_ab_full_status = cutlass.Boolean(1) + if mainloop_consumer_read_state.count < k_tile_cnt: + peek_ab_full_status = mainloop_pipeline.consumer_try_wait( + mainloop_consumer_read_state + ) + + # ///////////////////////////////////////////////////////////////////////////// + # MAINLOOP + # ///////////////////////////////////////////////////////////////////////////// + for k_tile in cutlass.range(k_pipe_mmas, k_tile_cnt, 1, unroll=1): + # ///////////////////////////////////////////////////////////////////////////// + # Wait for TMA copies to complete + # ///////////////////////////////////////////////////////////////////////////// + mainloop_pipeline.consumer_wait( + mainloop_consumer_read_state, peek_ab_full_status + ) + # ///////////////////////////////////////////////////////////////////////////// + # WGMMA + # ///////////////////////////////////////////////////////////////////////////// + cute.nvgpu.warpgroup.fence() + for k_block_idx in cutlass.range(num_k_blocks, unroll_full=True): + 
k_block_coord = ( + None, + None, + k_block_idx, + mainloop_consumer_read_state.index, + ) + tCrA_1phase = tCrA[k_block_coord] + tCrB_1phase = tCrB[k_block_coord] + + cute.gemm( + tiled_mma, + accumulators, + tCrA_1phase, + tCrB_1phase, + accumulators, + ) + + cute.nvgpu.warpgroup.commit_group() + # Wait on the wgmma barrier for previous k_pipe_mmas wgmmas to complete + cute.nvgpu.warpgroup.wait_group(k_pipe_mmas) + + mainloop_pipeline.consumer_release(mainloop_consumer_release_state) + + mainloop_consumer_read_state.advance() + mainloop_consumer_release_state.advance() + + peek_ab_full_status = cutlass.Boolean(1) + if mainloop_consumer_read_state.count < k_tile_cnt: + peek_ab_full_status = mainloop_pipeline.consumer_try_wait( + mainloop_consumer_read_state + ) + # ///////////////////////////////////////////////////////////////////////////// + # TMA load + # ///////////////////////////////////////////////////////////////////////////// + if warp_idx == 0 and mainloop_producer_state.count < k_tile_cnt: + # ///////////////////////////////////////////////////////////////////////////// + # Wait for A/B buffers to be empty before loading into them + # Also sets the transaction barrier for the A/B buffers + # ///////////////////////////////////////////////////////////////////////////// + mainloop_pipeline.producer_acquire(mainloop_producer_state) + + # ///////////////////////////////////////////////////////////////////////////// + # Slice to global/shared memref to current k_tile + # ///////////////////////////////////////////////////////////////////////////// + tAgA_k = tAgA_mkl[(None, mainloop_producer_state.count)] + tAsA_pipe = tAsA[(None, mainloop_producer_state.index)] + + tBgB_k = tBgB_nkl[(None, mainloop_producer_state.count)] + tBsB_pipe = tBsB[(None, mainloop_producer_state.index)] + + # ///////////////////////////////////////////////////////////////////////////// + # TMA load A/B + # ///////////////////////////////////////////////////////////////////////////// + cute.copy( + tma_atom_a, + tAgA_k, + tAsA_pipe, + tma_bar_ptr=mainloop_pipeline.producer_get_barrier( + mainloop_producer_state + ), + mcast_mask=a_mcast_mask, + ) + cute.copy( + tma_atom_b, + tBgB_k, + tBsB_pipe, + tma_bar_ptr=mainloop_pipeline.producer_get_barrier( + mainloop_producer_state + ), + mcast_mask=b_mcast_mask, + ) + # Mainloop pipeline's producer commit is a NOP + mainloop_pipeline.producer_commit(mainloop_producer_state) + mainloop_producer_state.advance() + + # ///////////////////////////////////////////////////////////////////////////// + # EPILOG + # ///////////////////////////////////////////////////////////////////////////// + cute.nvgpu.warpgroup.wait_group(0) + + if cute.size(self.cluster_shape_mn) > 1: + # Wait for all threads in the cluster to finish, avoid early release of smem + cute.arch.cluster_arrive() + cute.arch.cluster_wait() + else: + # For cluster that has a single thread block, it might have more than one warp groups. + # Wait for all warp groups in the thread block to finish, because smem for tensor A in + # the mainloop is reused in the epilogue. 
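+                # (Added note) cute.arch.sync_threads() below is a CTA-wide barrier, so all
+                # warp groups have passed wait_group(0) before the A/B smem is repurposed
+                # by the epilogue.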
+ cute.arch.sync_threads() + + copy_atom_r2s = sm90_utils.sm90_get_smem_store_op( + self.c_layout, + elem_ty_d=self.c_dtype, + elem_ty_acc=self.acc_dtype, + ) + + copy_atom_C = cute.make_copy_atom( + cute.nvgpu.warp.StMatrix8x8x16bOp( + self.c_layout.is_m_major_c(), + 4, + ), + self.c_dtype, + ) + + tiled_copy_C_Atom = cute.make_tiled_copy_C_atom(copy_atom_C, tiled_mma) + + tiled_copy_r2s = cute.make_tiled_copy_S( + copy_atom_r2s, + tiled_copy_C_Atom, + ) + + # (R2S, R2S_M, R2S_N, PIPE_D) + thr_copy_r2s = tiled_copy_r2s.get_slice(tidx) + tRS_sD = thr_copy_r2s.partition_D(sC) + # (R2S, R2S_M, R2S_N) + tRS_rAcc = tiled_copy_r2s.retile(accumulators) + + # Allocate D registers. + rD_shape = cute.shape(thr_copy_r2s.partition_S(sC)) + tRS_rD_layout = cute.make_layout(rD_shape[:3]) + tRS_rD = cute.make_fragment_like(tRS_rD_layout, self.acc_dtype) + size_tRS_rD = cute.size(tRS_rD) + + sepi_for_tma_partition = cute.group_modes(sC, 0, 2) + tCgC_for_tma_partition = cute.zipped_divide(gC_mnl, self.epi_tile) + + bSG_sD, bSG_gD = cute.nvgpu.cpasync.tma_partition( + tma_atom_c, + 0, + cute.make_layout(1), + sepi_for_tma_partition, + tCgC_for_tma_partition, + ) + + epi_tile_num = cute.size(tCgC_for_tma_partition, mode=[1]) + epi_tile_shape = tCgC_for_tma_partition.shape[1] + epi_tile_layout = cute.make_layout( + epi_tile_shape, stride=(epi_tile_shape[1], 1) + ) + + # Initialize tma store c_pipeline + c_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, self.threads_per_cta, self.threads_per_cta + ) + c_pipeline = pipeline.PipelineTmaStore.create( + num_stages=self.epi_stage, + producer_group=c_producer_group, + ) + + for epi_idx in cutlass.range_constexpr(epi_tile_num): + # Copy from accumulators to D registers + for epi_v in cutlass.range_constexpr(size_tRS_rD): + tRS_rD[epi_v] = tRS_rAcc[epi_idx * size_tRS_rD + epi_v] + + # Type conversion + tRS_rD_out = cute.make_fragment_like(tRS_rD_layout, self.c_dtype) + acc_vec = tRS_rD.load() + tRS_rD_out.store(acc_vec.to(self.c_dtype)) + + # Copy from D registers to shared memory + epi_buffer = epi_idx % cute.size(tRS_sD, mode=[3]) + cute.copy( + tiled_copy_r2s, tRS_rD_out, tRS_sD[(None, None, None, epi_buffer)] + ) + + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + # barrier for sync + cute.arch.barrier() + + gmem_coord = epi_tile_layout.get_hier_coord(epi_idx) + # Copy from shared memory to global memory + if warp_idx == 0: + cute.copy( + tma_atom_c, + bSG_sD[(None, epi_buffer)], + bSG_gD[(None, gmem_coord)], + ) + c_pipeline.producer_commit() + c_pipeline.producer_acquire() + + cute.arch.barrier() + + if warp_idx == 0: + c_pipeline.producer_tail() + + return + + @staticmethod + def _compute_stages( + tile_shape_mnk: tuple[int, int, int], + a_dtype: type[cutlass.Numeric], + b_dtype: type[cutlass.Numeric], + smem_capacity: int, + occupancy: int, + ) -> tuple[int, int]: + """Computes the number of stages for A/B/C operands based on heuristics. + + :param tile_shape_mnk: The shape (M, N, K) of the CTA tile. + :type tile_shape_mnk: tuple[int, int, int] + :param a_dtype: Data type of operand A. + :type a_dtype: type[cutlass.Numeric] + :param b_dtype: Data type of operand B. + :type b_dtype: type[cutlass.Numeric] + :param smem_capacity: Total available shared memory capacity in bytes. + :type smem_capacity: int + :param occupancy: Target number of CTAs per SM (occupancy). 
+ :type occupancy: int + + :return: A tuple containing the computed number of stages for: + (A/B operand stages, epilogue stages) + :rtype: tuple[int, int] + """ + + epi_stage = 4 + # epi_smem will reuse smem ab. + epi_bytes = 0 + + a_shape = cute.slice_(tile_shape_mnk, (None, 0, None)) + b_shape = cute.slice_(tile_shape_mnk, (0, None, None)) + ab_bytes_per_stage = ( + cute.size(a_shape) * a_dtype.width // 8 + + cute.size(b_shape) * b_dtype.width // 8 + ) + mbar_helpers_bytes = 1024 + + ab_stage = ( + smem_capacity // occupancy - mbar_helpers_bytes - epi_bytes + ) // ab_bytes_per_stage + return ab_stage, epi_stage + + @staticmethod + def _sm90_compute_tile_shape_or_override( + tile_shape_mnk: tuple[int, int, int], + element_type: type[cutlass.Numeric], + is_cooperative: bool = False, + epi_tile_override: tuple[int, int] | None = None, + ) -> tuple[int, int]: + """Compute the epilogue tile shape or use override if provided. + + :param tile_shape_mnk: CTA tile shape (M,N,K) + :type tile_shape_mnk: Tuple[int, int, int] + :param element_type: Data type of elements + :type element_type: type[cutlass.Numeric] + :param is_cooperative: Whether to use cooperative approach + :type is_cooperative: bool + :param epi_tile_override: Optional override for epilogue tile shape + :type epi_tile_override: Tuple[int, int] or None + + :return: Computed epilogue tile shape + :rtype: Tuple[int, int] + """ + if epi_tile_override is not None: + return epi_tile_override + if is_cooperative: + tile_m = min(128, cute.size(tile_shape_mnk, mode=[0])) + tile_n = min(32, cute.size(tile_shape_mnk, mode=[1])) + return (tile_m, tile_n) + else: + n_perf = 64 if element_type.width == 8 else 32 + tile_m = min(64, cute.size(tile_shape_mnk, mode=[0])) + tile_n = min(n_perf, cute.size(tile_shape_mnk, mode=[1])) + return (tile_m, tile_n) + + @staticmethod + def _make_smem_layouts( + tile_shape_mnk: tuple[int, int, int], + epi_tile: tuple[int, int], + a_dtype: type[cutlass.Numeric], + a_layout: utils.LayoutEnum, + b_dtype: type[cutlass.Numeric], + b_layout: utils.LayoutEnum, + ab_stage: int, + c_dtype: type[cutlass.Numeric], + c_layout: utils.LayoutEnum, + epi_stage: int, + ) -> tuple[cute.ComposedLayout, cute.ComposedLayout, cute.ComposedLayout]: + """Create shared memory layouts for A, B, and C tensors. 
+ + :param tile_shape_mnk: CTA tile shape (M,N,K) + :type tile_shape_mnk: Tuple[int, int, int] + :param epi_tile: Epilogue tile shape + :type epi_tile: Tuple[int, int] + :param a_dtype: Data type for matrix A + :type a_dtype: type[cutlass.Numeric] + :param a_layout: Layout enum for matrix A + :type a_layout: utils.LayoutEnum + :param b_dtype: Data type for matrix B + :type b_dtype: type[cutlass.Numeric] + :param b_layout: Layout enum for matrix B + :type b_layout: utils.LayoutEnum + :param ab_stage: Number of stages for A/B tensors + :type ab_stage: int + :param c_dtype: Data type for output matrix C + :type c_dtype: type[cutlass.Numeric] + :param c_layout: Layout enum for the output matrix C + :type c_layout: utils.LayoutEnum + :param epi_stage: Number of epilogue stages + :type epi_stage: int + + :return: Tuple of shared memory layouts for A, B, and C + :rtype: Tuple[cute.ComposedLayout, cute.ComposedLayout, cute.ComposedLayout] + """ + a_smem_shape = cute.slice_(tile_shape_mnk, (None, 0, None)) + + a_is_k_major = ( + a_layout.sm90_mma_major_mode() == cute.nvgpu.warpgroup.OperandMajorMode.K + ) + b_is_k_major = ( + b_layout.sm90_mma_major_mode() == cute.nvgpu.warpgroup.OperandMajorMode.K + ) + a_major_mode_size = tile_shape_mnk[2 if a_is_k_major else 0] + a_smem_layout_atom = cute.nvgpu.warpgroup.make_smem_layout_atom( + sm90_utils.get_smem_layout_atom( + a_layout, + a_dtype, + a_major_mode_size, + ), + a_dtype, + ) + a_smem_layout_staged = cute.tile_to_shape( + a_smem_layout_atom, + cute.append(a_smem_shape, ab_stage), + order=(0, 1, 2) if a_is_k_major else (1, 0, 2), + ) + + b_smem_shape = cute.slice_(tile_shape_mnk, (0, None, None)) + + b_major_mode_size = tile_shape_mnk[2 if b_is_k_major else 1] + b_smem_layout_atom = cute.nvgpu.warpgroup.make_smem_layout_atom( + sm90_utils.get_smem_layout_atom( + b_layout, + b_dtype, + b_major_mode_size, + ), + b_dtype, + ) + b_smem_layout_staged = cute.tile_to_shape( + b_smem_layout_atom, + cute.append(b_smem_shape, ab_stage), + order=(0, 1, 2) if b_is_k_major else (1, 0, 2), + ) + + c_smem_shape = epi_tile + c_major_mode_size = epi_tile[1] if c_layout.is_n_major_c() else epi_tile[0] + c_smem_layout_atom = cute.nvgpu.warpgroup.make_smem_layout_atom( + sm90_utils.get_smem_layout_atom( + c_layout, + c_dtype, + c_major_mode_size, + ), + c_dtype, + ) + epi_smem_layout_staged = cute.tile_to_shape( + c_smem_layout_atom, + cute.append(c_smem_shape, epi_stage), + order=(1, 0, 2) if c_layout.is_m_major_c() else (0, 1, 2), + ) + + return a_smem_layout_staged, b_smem_layout_staged, epi_smem_layout_staged + + @staticmethod + def _compute_grid( + c: cute.Tensor, + tile_shape_mnk: tuple[int, int, int], + cluster_shape_mn: tuple[int, int], + ) -> tuple[int, int, int]: + """Compute grid shape for the output tensor C. + + :param c: The output tensor C + :type c: cute.Tensor + :param tile_shape_mnk: The shape (M, N, K) of the CTA tile. + :type tile_shape_mnk: tuple[int, int, int] + :param cluster_shape_mn: Shape of each cluster in M, N dimensions. + :type cluster_shape_mn: tuple[int, int] + + :return: Grid shape for kernel launch. 
+ :rtype: tuple[int, int, int] + """ + + c_shape = (tile_shape_mnk[0], tile_shape_mnk[1]) + gc = cute.zipped_divide(c, tiler=c_shape) + cluster_shape_mnl = (*cluster_shape_mn, 1) + clusters = cute.ceil_div(cute.get(gc.layout, mode=[1]).shape, cluster_shape_mnl) + grid = tuple(x * y for x, y in zip(clusters, cluster_shape_mnl)) + return grid + + @staticmethod + def _make_tma_store_atoms_and_tensors( + tensor_c: cute.Tensor, + epi_smem_layout_staged: cute.ComposedLayout, + epi_tile: tuple[int, int], + ) -> tuple[cute.CopyAtom, cute.Tensor]: + """Create TMA atoms and tensors for C tensor storage. + + :param tensor_c: Output tensor C + :type tensor_c: cute.Tensor + :param epi_smem_layout_staged: Shared memory layout for epilogue + :type epi_smem_layout_staged: cute.ComposedLayout + :param epi_tile: Epilogue tile shape + :type epi_tile: Tuple[int, int] + + :return: TMA atom and tensor for C + :rtype: Tuple[cute.CopyAtom, cute.Tensor] + """ + epi_smem_layout = cute.slice_(epi_smem_layout_staged, (None, None, 0)) + c_cta_v_layout = cute.composition( + cute.make_identity_layout(tensor_c.shape), epi_tile + ) + tma_atom_c, tma_tensor_c = cute.nvgpu.cpasync.make_tiled_tma_atom( + cute.nvgpu.cpasync.CopyBulkTensorTileS2GOp(), + tensor_c, + epi_smem_layout, + c_cta_v_layout, + ) + + return tma_atom_c, tma_tensor_c + + @staticmethod + def _make_tma_atoms_and_tensors( + tensor: cute.Tensor, + smem_layout_staged: cute.ComposedLayout, + smem_tile: tuple[int, int], + mcast_dim: int, + ) -> tuple[cute.CopyAtom, cute.Tensor]: + """Create TMA atoms and tensors for input tensors. + + :param tensor: Input tensor (A or B) + :type tensor: cute.Tensor + :param smem_layout_staged: Shared memory layout for the tensor + :type smem_layout_staged: cute.ComposedLayout + :param smem_tile: Shared memory tile shape + :type smem_tile: Tuple[int, int] + :param mcast_dim: Multicast dimension + :type mcast_dim: int + + :return: TMA atom and tensor + :rtype: Tuple[cute.CopyAtom, cute.Tensor] + """ + op = ( + cute.nvgpu.cpasync.CopyBulkTensorTileG2SOp() + if mcast_dim == 1 + else cute.nvgpu.cpasync.CopyBulkTensorTileG2SMulticastOp() + ) + + smem_layout = cute.slice_(smem_layout_staged, (None, None, 0)) + tma_atom, tma_tensor = cute.nvgpu.cpasync.make_tiled_tma_atom( + op, + tensor, + smem_layout, + smem_tile, + num_multicast=mcast_dim, + ) + return tma_atom, tma_tensor + + @staticmethod + def is_valid_dtypes( + a_dtype: Type[cutlass.Numeric], + b_dtype: Type[cutlass.Numeric], + acc_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + a_major: str, + b_major: str, + ) -> bool: + """ + Check if the dtypes are valid + + :param a_dtype: The data type of tensor A + :type a_dtype: Type[cutlass.Numeric] + :param b_dtype: The data type of tensor B + :type b_dtype: Type[cutlass.Numeric] + :param acc_dtype: The data type of the accumulator + :type acc_dtype: Type[cutlass.Numeric] + :param c_dtype: The data type of the output tensor + :type c_dtype: Type[cutlass.Numeric] + :param a_major: major mode of tensor A + :type a_major: str + :param b_major: major mode of tensor B + :type b_major: str + + :return: True if the dtypes are valid, False otherwise + :rtype: bool + """ + is_valid = True + # tested a_dtype + if a_dtype not in { + cutlass.Float16, + cutlass.Float8E4M3FN, + cutlass.Float8E5M2, + }: + is_valid = False + # tested b_dtype + if b_dtype not in { + cutlass.Float16, + cutlass.Float8E4M3FN, + cutlass.Float8E5M2, + }: + is_valid = False + # tested acc_dtype + if acc_dtype not in {cutlass.Float32, cutlass.Float16}: + 
is_valid = False + # tested c_dtype + if c_dtype not in { + cutlass.Float32, + cutlass.Float16, + cutlass.Float8E4M3FN, + cutlass.Float8E5M2, + }: + is_valid = False + # make sure a_dtype == b_dtype for Float16 + if a_dtype.width == 16 and a_dtype != b_dtype: + is_valid = False + # make sure a_dtype.width == b_dtype.width (i.e, Float8E4M3FN or Float8E5M2) + if a_dtype.width != b_dtype.width: + is_valid = False + + # for Float8 types, this implementation only supports k-major layout + if (a_dtype.width == 8 and a_major != "k") or ( + b_dtype.width == 8 and b_major != "k" + ): + is_valid = False + + return is_valid + + +def run( + mnkl: Tuple[int, int, int, int], + a_dtype: Type[cutlass.Numeric], + b_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + acc_dtype: Type[cutlass.Numeric], + a_major: str, + b_major: str, + c_major: str, + tile_shape_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + tolerance: float, + warmup_iterations: int, + iterations: int, + skip_ref_check: bool, + use_cold_l2: bool = False, + **kwargs, +): + """ + Prepare A/B/C tensors, launch GPU kernel, and reference checking. + + :param mnkl: Problem size (M, N, K, L) + :type mnkl: Tuple[int, int, int, int] + :param a_dtype: Data type for input tensor A + :type a_dtype: Type[cutlass.Numeric] + :param b_dtype: Data type for input tensor B + :type b_dtype: Type[cutlass.Numeric] + :param c_dtype: Data type for output tensor C + :type c_dtype: Type[cutlass.Numeric] + :param acc_dtype: Data type for accumulation during matrix multiplication + :type acc_dtype: Type[cutlass.Numeric] + :param a_major/b_major/c_major: Memory layout of tensor A/B/C + :type a_major/b_major/c_major: str + :param tile_shape_mn: CTA tile shape (M, N) + :type tile_shape_mn: Tuple[int, int] + :param cluster_shape_mn: Cluster shape (M, N) + :type cluster_shape_mn: Tuple[int, int] + :param tolerance: Tolerance value for reference validation comparison + :type tolerance: float + :param warmup_iterations: Number of warmup iterations before benchmarking, defaults to 0 + :type warmup_iterations: int, optional + :param iterations: Number of benchmark iterations to run, defaults to 1 + :type iterations: int, optional + :param skip_ref_check: Whether to skip reference result validation, defaults to False + :type skip_ref_check: bool, optional + :param use_cold_l2: Whether to use circular buffer strategy to ensure cold L2 cache, defaults to False + :type use_cold_l2: bool, optional + :return: Execution time of the GEMM kernel in microseconds + :rtype: float + """ + + print(f"Running Hopper Dense GEMM with:") + print(f"mnkl: {mnkl}") + print( + f"A dtype: {a_dtype}, B dtype: {b_dtype}, C dtype: {c_dtype}, Acc dtype: {acc_dtype}" + ) + print(f"Matrix majors - A: {a_major}, B: {b_major}, C: {c_major}") + print(f"Tile Shape: {tile_shape_mn}, Cluster Shape: {cluster_shape_mn}") + print(f"Tolerance: {tolerance}") + print(f"Warmup iterations: {warmup_iterations}") + print(f"Iterations: {iterations}") + print(f"Skip reference checking: {skip_ref_check}") + print(f"Use cold L2: {use_cold_l2}") + + # Unpack parameters + m, n, k, l = mnkl + + # Skip unsupported types + if not HopperWgmmaGemmKernel.is_valid_dtypes( + a_dtype, b_dtype, acc_dtype, c_dtype, a_major, b_major + ): + raise TypeError( + f"Skipping due to unsupported combination of types and majors: {a_dtype}, {b_dtype}, {acc_dtype}, {c_dtype}, {a_major=}, {b_major=}" + ) + + # Prepare pytorch tensors: A, B (random from 0 to 2) and C (all zero) + if not torch.cuda.is_available(): + raise 
RuntimeError("GPU is required to run this example!") + + torch.manual_seed(1111) + + # Create and permute tensor A/B/C + def create_and_permute_tensor( + l, mode0, mode1, is_mode0_major, dtype, is_dynamic_layout=True + ): + # is_mode0_major: (l, mode1, mode0) -> (mode0, mode1, l) + # else : (l, mode0, mode1) -> (mode0, mode1, l) + shape = (l, mode1, mode0) if is_mode0_major else (l, mode0, mode1) + permute_order = (2, 1, 0) if is_mode0_major else (1, 2, 0) + is_unsigned = dtype in {cutlass.Uint8} + # Temporarily use uint8 as torch does not support fp8 type + torch_dtype = ( + cutlass_torch.dtype(dtype) + if dtype not in {cutlass.Float8E5M2, cutlass.Float8E4M3FN} + else torch.uint8 + ) + + # Create dtype torch tensor (cpu) + torch_tensor_cpu = cutlass.torch.create_and_permute_torch_tensor( + shape, + torch_dtype, + permute_order=permute_order, + init_type=cutlass.torch.TensorInitType.RANDOM, + init_config=cutlass.torch.RandomInitConfig( + min_val=0 if is_unsigned else -2, max_val=4 if is_unsigned else 2 + ), + ) + # Create dtype torch tensor (gpu) + torch_tensor = torch_tensor_cpu.cuda() + + # Create f32 torch tensor (cpu) + f32_torch_tensor = torch_tensor_cpu.to(dtype=torch.float32) + + # Create dtype cute tensor (gpu) + cute_tensor = from_dlpack(torch_tensor, assumed_align=16) + cute_tensor.element_type = dtype + if is_dynamic_layout: + cute_tensor = cute_tensor.mark_layout_dynamic( + leading_dim=(0 if is_mode0_major else 1) + ) + cute_tensor = cutlass.torch.convert_cute_tensor( + f32_torch_tensor, + cute_tensor, + dtype, + is_dynamic_layout=is_dynamic_layout, + ) + + return f32_torch_tensor, cute_tensor, torch_tensor + + a, mA, a_torch = create_and_permute_tensor(l, m, k, a_major == "m", a_dtype) + b, mB, b_torch = create_and_permute_tensor(l, n, k, b_major == "n", b_dtype) + c, mC, c_torch = create_and_permute_tensor(l, m, n, c_major == "m", c_dtype) + + gemm = HopperWgmmaGemmKernel(acc_dtype, tile_shape_mn, cluster_shape_mn) + + torch_stream = torch.cuda.Stream() + stream = cuda.CUstream(torch_stream.cuda_stream) + # compile gemm kernel + compiled_gemm = cute.compile(gemm, mA, mB, mC, stream) + + if not skip_ref_check: + # execution + compiled_gemm(mA, mB, mC, stream) + + torch.cuda.synchronize() + + # Ref check + ref = (torch.einsum("mkl,nkl->mnl", a, b)).cpu() + + if c_dtype in (cutlass.Float8E4M3FN, cutlass.Float8E5M2): + # m major: (l, n, m) -> (m, n, l) + # n major: (l, m, n) -> (m, n, l) + permute_order = (1, 2, 0) if c_major == "n" else (2, 1, 0) + shape = (l, m, n) if c_major == "n" else (l, n, m) + f8_torch_tensor = cutlass_torch.create_and_permute_torch_tensor( + shape, + torch.uint8, + permute_order=permute_order, + init_type=cutlass_torch.TensorInitType.SKIP, + ).cuda() + # Create dtype cute tensor (gpu) + ref_c_tensor = from_dlpack( + f8_torch_tensor, assumed_align=16 + ).mark_layout_dynamic(leading_dim=(1 if c_major == "n" else 0)) + ref_c_tensor.element_type = c_dtype + ref_c_tensor = cutlass_torch.convert_cute_tensor( + ref, + ref_c_tensor, + c_dtype, + is_dynamic_layout=True, + ) + ref_c = f8_torch_tensor.cpu() + else: + ref_c = ref.to(cutlass_torch.dtype(c_dtype)) + + torch.testing.assert_close(c_torch.cpu(), ref_c, atol=tolerance, rtol=1e-03) + + def generate_tensors(): + _, mA_workspace, _ = create_and_permute_tensor(l, m, k, a_major == "m", a_dtype) + _, mB_workspace, _ = create_and_permute_tensor(l, n, k, b_major == "n", b_dtype) + _, mC_workspace, _ = create_and_permute_tensor(l, m, n, c_major == "m", c_dtype) + return testing.JitArguments(mA_workspace, 
mB_workspace, mC_workspace, stream) + + workspace_count = 1 + if use_cold_l2: + one_workspace_bytes = ( + a_torch.numel() * a_torch.element_size() + + b_torch.numel() * b_torch.element_size() + + c_torch.numel() * c_torch.element_size() + ) + workspace_count = testing.get_workspace_count( + one_workspace_bytes, warmup_iterations, iterations + ) + + exec_time = testing.benchmark( + compiled_gemm, + workspace_generator=generate_tensors, + workspace_count=workspace_count, + stream=stream, + warmup_iterations=warmup_iterations, + iterations=iterations, + ) + + return exec_time # Return execution time in microseconds + + +if __name__ == "__main__": + args = parse_arguments() + run( + args.mnkl, + args.a_dtype, + args.b_dtype, + args.c_dtype, + args.acc_dtype, + args.a_major, + args.b_major, + args.c_major, + args.tile_shape_mn, + args.cluster_shape_mn, + args.tolerance, + args.warmup_iterations, + args.iterations, + args.skip_ref_check, + args.use_cold_l2, + ) + print("PASS") diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/algorithm/axpby.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/algorithm/axpby.hpp new file mode 100644 index 0000000000000000000000000000000000000000..60d5b468e2f28dd917f2776a88451385aa237d1c --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/algorithm/axpby.hpp @@ -0,0 +1,94 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include + +#include + +namespace cute +{ + +// +// Accept mutable temporaries +// +template > +CUTE_HOST_DEVICE +void +axpby(Alpha const& alpha, + Tensor const& x, + Beta const& beta, + Tensor && y, + PrdTensor const& p = {}) +{ + return axpby(alpha, x, beta, y, p); +} + +// +// AXPBY +// +template > +CUTE_HOST_DEVICE +void +axpby(Alpha const& alpha, + Tensor const& x, + Beta const& beta, + Tensor & y, + PrdTensor const& p = {}) +{ + auto isBetaZero = [&] () { + if constexpr (is_complex::value) { + return beta.real() == Int<0>{} && beta.imag() == Int<0>{}; + } + else { + return beta == Int<0>{}; + } + + CUTE_GCC_UNREACHABLE; + } (); + + CUTE_UNROLL + for (int i = 0; i < size(x); ++i) { + if (p(i)) { + y(i) = (isBetaZero ? alpha * x(i) : alpha * x(i) + beta * y(i)); + } + } +} + +} // end namespace cute diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/algorithm/clear.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/algorithm/clear.hpp new file mode 100644 index 0000000000000000000000000000000000000000..225c46ec878215ccfae57504021df5352fbf10dd --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/algorithm/clear.hpp @@ -0,0 +1,64 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include // CUTE_HOST_DEVICE +#include // cute::Tensor +#include // cute::fill + +namespace cute +{ + +// +// Accept mutable temporaries +// +template +CUTE_HOST_DEVICE +void +clear(Tensor&& tensor) +{ + return clear(tensor); +} + +// +// Set elements to zero +// +template +CUTE_HOST_DEVICE +void +clear(Tensor& tensor) +{ + using T = typename Tensor::value_type; + + fill(tensor, T{}); +} + +} // end namespace cute diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/algorithm/cooperative_copy.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/algorithm/cooperative_copy.hpp new file mode 100644 index 0000000000000000000000000000000000000000..2653916e3ce3eff3a26c44f77c1d67f8f92d0282 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/algorithm/cooperative_copy.hpp @@ -0,0 +1,338 @@ +/*************************************************************************************************** +* Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +* SPDX-License-Identifier: BSD-3-Clause +* +* Redistribution and use in source and binary forms, with or without +* modification, are permitted provided that the following conditions are met: +* +* 1. Redistributions of source code must retain the above copyright notice, this +* list of conditions and the following disclaimer. +* +* 2. Redistributions in binary form must reproduce the above copyright notice, +* this list of conditions and the following disclaimer in the documentation +* and/or other materials provided with the distribution. +* +* 3. Neither the name of the copyright holder nor the names of its +* contributors may be used to endorse or promote products derived from +* this software without specific prior written permission. +* +* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+* +**************************************************************************************************/ +#pragma once + +#include +#include +#include // cute::logical_divide +#include // cute::Swizzle +#include // cute::get_nonswizzle_portion +#include // cute::Tensor +#include +#include + +namespace cute +{ + +template +CUTE_HOST_DEVICE void +naive_cooperative_copy(uint32_t const& tid, + Tensor const& src, + Tensor & dst) +{ + auto N = size(dst); + auto R = N % Int{}; + if (R > 0 && tid < R) { // Likely static condition && Residue in-bounds + dst[tid] = src[tid]; + } + CUTE_UNROLL + for (uint32_t i = uint32_t(R); i < uint32_t(N); i += NumThreads) { // All in-bounds + dst[tid + i] = src[tid + i]; + } +} + +// Accept mutable temporaries +template +CUTE_HOST_DEVICE void +naive_cooperative_copy(uint32_t const& tid, + Tensor const& src, + Tensor && dst) +{ + return naive_cooperative_copy(tid, src, dst); +} + +// A heuristic to determine a "good" permutation of two tensors for later vectorization and thr-assignment +template +CUTE_HOST_DEVICE constexpr +auto +heuristic_permutation(Tensor const& a, + Tensor const& b) +{ + constexpr bool swizzleA = get_swizzle_t::num_bits != 0 or + get_swizzle_t::num_bits != 0; + constexpr bool swizzleB = get_swizzle_t::num_bits != 0 or + get_swizzle_t::num_bits != 0; + auto a_inv = right_inverse(get_nonswizzle_portion(a.layout())); + auto b_inv = right_inverse(get_nonswizzle_portion(b.layout())); + + constexpr uint8_t scoreA = (uint8_t(swizzleA) << 2) | + (uint8_t(is_smem::value) << 1) | + (uint8_t(size(a_inv) > size(b_inv)) << 0); + + constexpr uint8_t scoreB = (uint8_t(swizzleB) << 2) | + (uint8_t(is_smem::value) << 1) | + (uint8_t(size(b_inv) > size(a_inv)) << 0); + + if constexpr (scoreA >= scoreB) { + return a_inv; + } else { + return b_inv; + } +} + +// cooperative_copy(thr_idx, src, dst) +// Use NumThreads to copy Tensor src to Tensor dst with element-wise vectorization up to MaxVecBits. +// @pre 0 <= @a tid < NumThreads +// @pre Tensors @a src and @a dst are aligned up to MaxVecBits. +// That is, pointers and dynamic strides are assumed to be aligned up to MaxVecBits. 
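+// Illustrative usage (not part of the original header; tensor names are hypothetical):
+//   // 128 threads cooperatively copy a static-shape gmem tile into smem,
+//   // vectorizing up to 128 bits per access where alignment allows:
+//   cooperative_copy<128, 128>(threadIdx.x, gmem_tile, smem_tile);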
+// +template +CUTE_HOST_DEVICE +void +cooperative_copy(uint32_t const& tid, + Tensor const& src, + Tensor & dst, + CopyPolicy const& cpy = {}) +{ + // Assumes the shapes are static, can generalize/fallback + CUTE_STATIC_ASSERT_V(is_static{} && is_static{}); + CUTE_STATIC_ASSERT_V(size(src) == size(dst)); + // Assumes the types are the same, can generalize/fallback + static_assert(cute::is_same::value); + static_assert(MaxVecBits == sizeof_bits_v || + MaxVecBits == 8 || MaxVecBits == 16 || MaxVecBits == 32 || MaxVecBits == 64 || MaxVecBits == 128, + "Expected MaxVecBits to be value size or 8 or 16 or 32 or 64 or 128 for alignment and performance."); + // Check that the tensors are likely shared across threads: either gmem or smem + static_assert((is_gmem::value || is_smem::value), + "cooperative_copy expects shared gmem or smem source tensor."); + static_assert((is_gmem::value || is_smem::value), + "cooperative_copy expects shared gmem or smem destination tensor."); + // Precondition on tid in DEBUG + assert(tid < NumThreads); + // Precondition on pointer alignment in DEBUG + assert(is_byte_aligned(raw_pointer_cast(src.data()))); + assert(is_byte_aligned(raw_pointer_cast(dst.data()))); + +#if 0 + if (thread0()) { + print(" "); print("cooperative_copy\n"); + print(" "); print("NumThreads: "); print(NumThreads); print("\n"); + print(" "); print("MaxVecBits: "); print(MaxVecBits); print("\n"); + print(" "); print("src: "); print(src); print("\n"); + print(" "); print("dst: "); print(dst); print("\n"); + } +#ifdef __CUDA_ARCH__ + __syncthreads(); +#endif +#endif + + // The common layout of the two tensors that can be vectorized over elements and threads + // vidx -> coord + auto common_layout = heuristic_permutation(src, dst); + + // Apply + // (V, rest) + Tensor src_a = coalesce(logical_divide(src, common_layout), Shape<_1,_1>{}); + Tensor dst_a = coalesce(logical_divide(dst, common_layout), Shape<_1,_1>{}); + + // + // Determine vectorization of elems and thrs based on src/dst size and number of threads + // NOTE: This heuristic promotes parallelization over vectorization + // + + // The number of elements and number of bits + constexpr int elem_bits = sizeof_bits_v; + constexpr int total_elem = size(SrcLayout{}); + + // The number of elements that can be vectorized in values + constexpr int common_elem = decltype(max_common_vector(src_a, dst_a))::value; + +#if 0 + if (thread0()) { + print(" "); print("common_layout: "); print(common_layout); print("\n"); + print(" "); print("src_a: "); print(src_a); print("\n"); + print(" "); print("dst_a: "); print(dst_a); print("\n"); + } +#ifdef __CUDA_ARCH__ + __syncthreads(); +#endif +#endif + + // + if constexpr (total_elem % NumThreads != 0) { + // Not attempting to find a partitioning pattern, fallback to dynamically indexed slowpath + + if constexpr (common_elem > 1 && MaxVecBits > elem_bits) { + // If the vectorization is non-trivial and divides the maximum vectorizations, then vectorize + constexpr auto max_align_src = elem_bits * decltype(max_alignment(src_a.layout()))::value; + constexpr auto max_align_dst = elem_bits * decltype(max_alignment(dst_a.layout()))::value; + constexpr auto vec_bits = gcd(max_align_src, max_align_dst, MaxVecBits); + using VecType = uint_bit_t; + + static_assert(vec_bits % elem_bits == 0, "Expected divisibility"); + static_assert((vec_bits >= 8), "No support for subbyte copying"); + + Tensor src_v = recast(src_a); + Tensor dst_v = recast(dst_a); + +#if 0 + if (thread0()) { + print(" "); print("cooperative_copy -- 
naive\n"); + print(" "); print("src_v: "); print(src_v); print("\n"); + print(" "); print("dst_v: "); print(dst_v); print("\n"); + } +#ifdef __CUDA_ARCH__ + __syncthreads(); +#endif +#endif + + naive_cooperative_copy(tid, src_v, dst_v); + } else { + naive_cooperative_copy(tid, src_a, dst_a); + } + } else { + // If the tensors can be equally partitioned by the threads, + // compute vectorization widths in elements and threads. + + // If there are too many threads to allow a full vectorized copy, trunc the vectorization + constexpr int total_bits = total_elem * elem_bits; + constexpr int max_bits_per_thr = total_bits / NumThreads; + // At least elem_bits, at most common_bits + constexpr int common_bits = common_elem * elem_bits; + constexpr int vec_bits = cute::max(elem_bits, cute::gcd(common_bits, int(MaxVecBits), max_bits_per_thr)); + + // Should account for vec_bits < 8 and/or vec_elem <= 1 + // And also account for subbyte types, which could cause race conditions + // Want to ENFORCE sufficient vectorization in those cases + static_assert(vec_bits % elem_bits == 0, "Expected divisibility"); + static_assert(vec_bits >= 8, "No support for subbyte copying"); + + using VecType = uint_bit_t; + constexpr int vec_elem = vec_bits / elem_bits; + + constexpr int vec_thrs = cute::min(int(NumThreads), total_elem / vec_elem); + + // + // Determine the partitioning patterns for the vec_elems and vec_thrs + // + + // Distribute the rest of the V*T to some consistent portion outside of the common_layout, if needed + auto common_domain_src = domain_distribute(shape(src_a), Int{}); + auto common_domain_dst = domain_distribute(shape(dst_a), Int{}); + + // Make sure for now, could fall back here instead + CUTE_STATIC_ASSERT_V(size(common_domain_src) == Int{}); + CUTE_STATIC_ASSERT_V(compatible(common_domain_src, common_domain_dst) || + compatible(common_domain_dst, common_domain_src)); + // Use the "more specific" domain for the extra elements of V*T + auto common_domain = conditional_return(compatible(common_domain_src, common_domain_dst), + common_domain_dst, common_domain_src); + + // Construct the tiler + auto tiler_vt = common_domain.with_shape(Int{}, Int{}); + + // Apply and slice + Tensor src_v = logical_divide(src_a, tiler_vt)(make_coord(_,tid),_); + Tensor dst_v = logical_divide(dst_a, tiler_vt)(make_coord(_,tid),_); + +#if 0 + if (thread0()) { + print(" "); print("cooperative_copy -- vec\n"); + print(" "); print("Used vector: "); print(vec_elem); print("\n"); + print(" "); print("Used threads: "); print(vec_thrs); print("\n"); + print(" "); print("tiler_vt: "); print(tiler_vt); print("\n"); + print(" "); print("src_v: "); print(src_v); print("\n"); + print(" "); print("dst_v: "); print(dst_v); print("\n"); + print(" "); print("recast(src_v): "); print(recast(src_v)); print("\n"); + print(" "); print("recast(dst_v): "); print(recast(dst_v)); print("\n"); + } +#ifdef __CUDA_ARCH__ + __syncthreads(); +#endif +#endif + + // If we're using all threads (static) or the tid is in-range (dynamic) + if (vec_thrs == NumThreads or tid < vec_thrs) { + auto src_c = recast(src_v); + auto dst_c = recast(dst_v); + return copy(cpy, src_c, dst_c); + } + } +} + + +// Default max-vectorization size to value_type size +template +CUTE_HOST_DEVICE +void +cooperative_copy(uint32_t const& tid, + Tensor const& src, + Tensor & dst, + CopyPolicy const& cpy = {}) +{ + constexpr uint32_t MaxVecBits = sizeof_bits_v; + return cooperative_copy(tid, src, dst, cpy); +} + +// +// Accept mutable temporaries +// + +template 
+CUTE_HOST_DEVICE +void +cooperative_copy(uint32_t const& tid, + Tensor const& src, + Tensor && dst, + CopyPolicy const& cpy = {}) +{ + return cooperative_copy(tid, src, dst, cpy); +} + +template +CUTE_HOST_DEVICE +void +cooperative_copy(uint32_t const& tid, + Tensor const& src, + Tensor && dst, + CopyPolicy const& cpy = {}) +{ + return cooperative_copy(tid, src, dst, cpy); +} + +} // end namespace cute diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/algorithm/cooperative_gemm.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/algorithm/cooperative_gemm.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f0a993593d7d90eae9905f0f5bbbf890dcac2873 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/algorithm/cooperative_gemm.hpp @@ -0,0 +1,553 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include +#include + +#include + +#include +#include +#include + +#include + +namespace cute +{ + +// +// Cooperative Shared-Memory GEMMs +// + +namespace detail { + +// Slow fallback path: +template +CUTE_HOST_DEVICE +void +epilogue_predication(ThrMMA const& thr_mma, + Alpha const& alpha, + Tensor & tCrC, + Beta const& beta, + Tensor & sC, + Tensor & tCsC, + CLoadTransformOp const& sC_load_op, // transforms C values before use in GEMM + CStoreTransformOp const& sC_store_op) // transforms results before they are stored to C +{ + using InputTypeC = typename TSC::value_type; + using ComputeTypeC = typename ThrMMA::ValTypeC; + CUTE_STATIC_ASSERT(CUTE_STL_NAMESPACE::is_same_v); + + // Create coordinate tensors for the problem + Tensor cC = make_identity_tensor(shape(sC)); // (M,N) -> (m,n) + // Repeat partitioning with thr_mma + Tensor tCcC = thr_mma.partition_C(cC); // (MMA,MMA_M,MMA_N) -> (m,n) + + const bool isBetaZero = [&] () { + if constexpr (is_complex::value) { + return beta.real() == Int<0>{} && beta.imag() == Int<0>{}; + } + else { + return beta == Int<0>{}; + } + CUTE_GCC_UNREACHABLE; + } (); + + // Custom axpby_if for now + CUTE_UNROLL + for (int i = 0; i < size(tCrC); ++i) + { + if (elem_less(tCcC(i), shape(sC))) + { + tCsC(i) = sC_store_op(isBetaZero ? alpha * tCrC(i) + : alpha * tCrC(i) + + beta * static_cast(sC_load_op(tCsC(i)))); + } + } +} + +template +CUTE_HOST_DEVICE +void +epilogue_no_predication(uint32_t thread_idx, + ThrMMA const& thr_mma, + Alpha const& alpha, + Tensor & tCrC, + Beta const& beta, + Tensor & sC, + CLoadTransformOp const& sC_load_op, // transforms C values before use in GEMM + CStoreTransformOp const& sC_store_op, // transforms results before they are stored to C + SmemCopyLdOpC const& sC_copy_ld_op, + SmemCopyStOpC const& sC_copy_st_op) +{ + using InputTypeC = typename TSC::value_type; + using ComputeTypeC = typename TRC::value_type; + + const bool isBetaZero = [&] () { + if constexpr (is_complex::value) { + return beta.real() == Int<0>{} && beta.imag() == Int<0>{}; + } + else { + return beta == Int<0>{}; + } + CUTE_GCC_UNREACHABLE; + } (); + + Tensor tCrD = make_fragment_like(tCrC); + Tensor tCrDi = make_fragment_like(tCrD); + + if(!isBetaZero) { + auto smem_tiled_copy_C = make_tiled_copy_C(Copy_Atom{}, thr_mma); + auto smem_thr_copy_C = smem_tiled_copy_C.get_thread_slice(thread_idx); + Tensor tCsC = smem_thr_copy_C.partition_S(sC); + Tensor tCrDi_copy_view = smem_thr_copy_C.retile_D(tCrDi); + CUTE_STATIC_ASSERT_V(size<1>(tCsC) == size<1>(tCrDi_copy_view)); // CPY_M + CUTE_STATIC_ASSERT_V(size<2>(tCsC) == size<2>(tCrDi_copy_view)); // CPY_N + copy(smem_tiled_copy_C, tCsC, tCrDi_copy_view); + + // Transform C on/after load + cute::transform(tCrDi, tCrD, sC_load_op); + } + // C = alpha * (A * B) + beta * C + axpby(alpha, tCrC, beta, tCrD); + // Transform C before/on store + cute::transform(tCrD, tCrDi, sC_store_op); + + auto smem_tiled_copy_C = make_tiled_copy_C(Copy_Atom{}, thr_mma); + auto smem_thr_copy_C = smem_tiled_copy_C.get_thread_slice(thread_idx); + Tensor tCsC = smem_thr_copy_C.partition_D(sC); + Tensor tCrDi_copy_view = smem_thr_copy_C.retile_S(tCrDi); + CUTE_STATIC_ASSERT_V(size<1>(tCsC) == size<1>(tCrDi_copy_view)); // CPY_M + CUTE_STATIC_ASSERT_V(size<2>(tCsC) == size<2>(tCrDi_copy_view)); // CPY_N + copy(smem_tiled_copy_C, tCrDi_copy_view, tCsC); +} + +// Predicated Cooperative GEMM +template +CUTE_HOST_DEVICE +void 
+cooperative_gemm_predication(ThrMMA const& thr_mma, + Tensor const& sA, + Tensor const& sB, + Tensor & tCrC, + ALoadTransformOp const& sA_load_op, // transforms A values before use in GEMM + BLoadTransformOp const& sB_load_op) // transforms B values before use in GEMM +{ + using InputTypeA = typename TA::value_type; + using InputTypeB = typename TB::value_type; + using InputTypeC = typename TC::value_type; + using ComputeTypeA = typename ThrMMA::ValTypeA; + using ComputeTypeB = typename ThrMMA::ValTypeB; + using ComputeTypeC = typename ThrMMA::ValTypeC; + + // + // MMA Partitioning + // + + // Partition the sA, sB, and sC tiles across the threads for the MMA + Tensor tCsA = thr_mma.partition_A(sA); // (MMA,MMA_M,MMA_K) + Tensor tCsB = thr_mma.partition_B(sB); // (MMA,MMA_N,MMA_K) + + // Create register tensors for the MMA to operate on + Tensor tCrA = thr_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K) + Tensor tCrB = thr_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K) + // + // PREDICATION + // + + // Create coordinate tensors for the problem + Tensor cA = make_identity_tensor(shape(sA)); // (M,K) -> (m,k) + Tensor cB = make_identity_tensor(shape(sB)); // (N,K) -> (n,k) + + // Repeat partitioning with thr_mma + Tensor tCcA = thr_mma.partition_A(cA); // (MMA,MMA_M,MMA_K) -> (m,k) + Tensor tCcB = thr_mma.partition_B(cB); // (MMA,MMA_N,MMA_K) -> (n,k) + + // Allocate the preds for MMA- and MMA_MN-modes + Tensor tCpA = make_tensor(make_shape(size<0>(tCsA), size<1>(tCsA))); + Tensor tCpB = make_tensor(make_shape(size<0>(tCsB), size<1>(tCsB))); + // Populate the predicates on M and N + CUTE_UNROLL + for (int i = 0; i < size(tCpA); ++i) { + tCpA(i) = elem_less(get<0>(tCcA(_,_,Int<0>{})(i)), shape<0>(sA)); + } + CUTE_UNROLL + for (int i = 0; i < size(tCpB); ++i) { + tCpB(i) = elem_less(get<0>(tCcB(_,_,Int<0>{})(i)), shape<0>(sB)); + } + // + // PREFETCH k_block = 0 + // Condition the k-predication on (static) k_block == K_BLOCK_MAX-1, the last k_block + // Assumes the MMA-tiling in K is trivial + // + + constexpr int K_BLOCK_MAX = size<2>(tCrA); + + CUTE_UNROLL + for (int m = 0; m < size<1>(tCrA); ++m) { // Copy MMA_M + CUTE_UNROLL + for (int i = 0; i < size<0>(tCrA); ++i) { // Copy MMA_I + tCrA(i,m,0) = (tCpA(i,m) && (0 < K_BLOCK_MAX-1 || elem_less(get<1>(tCcA(i,m,0)), shape<1>(sA)))) ? static_cast(sA_load_op(tCsA(i,m,0))) : ComputeTypeA{}; + } + } + CUTE_UNROLL + for (int n = 0; n < size<1>(tCrB); ++n) { // Copy MMA_N + CUTE_UNROLL + for (int i = 0; i < size<0>(tCrB); ++i) { // Copy MMA_I + tCrB(i,n,0) = (tCpB(i,n) && (0 < K_BLOCK_MAX-1 || elem_less(get<1>(tCcB(i,n,0)), shape<1>(sB)))) ? static_cast(sB_load_op(tCsB(i,n,0))) : ComputeTypeB{}; + } + } + // + // MAINLOOP + // + + CUTE_UNROLL + for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block) + { + if (k_block < K_BLOCK_MAX-1) // static-if not the last k_block + { + int k_next = k_block + 1; // Load k_next block + + // Condition the k-predication on (static) k_block == K_BLOCK_MAX-1, the last k_block + // Assumes the MMA-tiling in K is trivial + + CUTE_UNROLL + for (int m = 0; m < size<1>(tCrA); ++m) { // Copy MMA_M + CUTE_UNROLL + for (int i = 0; i < size<0>(tCrA); ++i) { // Copy MMA_I + tCrA(i,m,k_next) = (tCpA(i,m) && (k_next < K_BLOCK_MAX-1 || elem_less(get<1>(tCcA(i,m,k_next)), shape<1>(sA)))) ? 
static_cast(sA_load_op(tCsA(i,m,k_next))) : ComputeTypeA{}; + } + } + CUTE_UNROLL + for (int n = 0; n < size<1>(tCrB); ++n) { // Copy MMA_N + CUTE_UNROLL + for (int i = 0; i < size<0>(tCrB); ++i) { // Copy MMA_I + tCrB(i,n,k_next) = (tCpB(i,n) && (k_next < K_BLOCK_MAX-1 || elem_less(get<1>(tCcB(i,n,k_next)), shape<1>(sB)))) ? static_cast(sB_load_op(tCsB(i,n,k_next))) : ComputeTypeB{}; + } + } + } + // GEMM on k_block in registers + gemm(thr_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC); + } +} + +// Unpredicated Cooperative GEMM +template +CUTE_HOST_DEVICE +void +cooperative_gemm_no_predication(uint32_t thread_idx, + ThrMMA const& thr_mma, + Tensor const& sA, + Tensor const& sB, + Tensor & tCrC, + ALoadTransformOp const& sA_load_op, // transforms A values before use in GEMM + BLoadTransformOp const& sB_load_op, // transforms B values before use in GEMM + SmemCopyOpA const& sA_copy_op, + SmemCopyOpB const& sB_copy_op) +{ + using InputTypeA = typename TA::value_type; + using InputTypeB = typename TB::value_type; + using InputTypeC = typename TC::value_type; + using ComputeTypeA = typename ThrMMA::ValTypeA; + using ComputeTypeB = typename ThrMMA::ValTypeB; + using ComputeTypeC = typename ThrMMA::ValTypeC; + + + // + // MMA Partitioning + // + + // Create register tensors for the MMA to operate on + Tensor tCrA = thr_mma.partition_fragment_A(sA); // (MMA,MMA_M,MMA_K) + Tensor tCrAi = make_fragment_like(tCrA); + Tensor tCrB = thr_mma.partition_fragment_B(sB); // (MMA,MMA_N,MMA_K) + Tensor tCrBi = make_fragment_like(tCrB); + + using CopyOpAType = SmemCopyOpA; + using CopyOpBType = SmemCopyOpB; + + auto smem_tiled_copy_A = make_tiled_copy_A(Copy_Atom{}, thr_mma); + auto smem_thr_copy_A = smem_tiled_copy_A.get_thread_slice(thread_idx); + Tensor tCsA = smem_thr_copy_A.partition_S(sA); + Tensor tCrAi_copy_view = smem_thr_copy_A.retile_D(tCrAi); + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrAi_copy_view)); // CPY_M + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCrAi_copy_view)); // CPY_K + + auto smem_tiled_copy_B = make_tiled_copy_B(Copy_Atom{}, thr_mma); + auto smem_thr_copy_B = smem_tiled_copy_B.get_thread_slice(thread_idx); + Tensor tCsB = smem_thr_copy_B.partition_S(sB); + Tensor tCrBi_copy_view = smem_thr_copy_B.retile_D(tCrBi); + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrBi_copy_view)); // CPY_N + CUTE_STATIC_ASSERT_V(size<2>(tCsB) == size<2>(tCrBi_copy_view)); // CPY_K + // + // PREFETCH + // + + copy(smem_tiled_copy_A, tCsA(_,_,Int<0>{}), tCrAi_copy_view(_,_,Int<0>{})); + copy(smem_tiled_copy_B, tCsB(_,_,Int<0>{}), tCrBi_copy_view(_,_,Int<0>{})); + // + // MAINLOOP + // + + constexpr int K_BLOCK_MAX = size<2>(tCrA); + + CUTE_UNROLL + for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block) + { + // static-if load the next k_block. No k-predication required on these loads. 
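+    // (Added note) Issuing the smem->register copies for k_block+1 before the MMA on k_block
+    // lets the copy latency overlap with the math; since the loop is fully unrolled, the
+    // "static-if" below disappears for the last iteration.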
+ if (k_block < K_BLOCK_MAX-1) + { + // Load the next k_block + int k_next = k_block + 1; // statically unrolled + copy(smem_tiled_copy_A, tCsA(_,_,k_next), tCrAi_copy_view(_,_,k_next)); + copy(smem_tiled_copy_B, tCsB(_,_,k_next), tCrBi_copy_view(_,_,k_next)); + } + + // Transform A and B, relying on the compiler to remove in case of identity ops + cute::transform(tCrAi(_,_,k_block), tCrA(_,_,k_block), sA_load_op); + cute::transform(tCrBi(_,_,k_block), tCrB(_,_,k_block), sB_load_op); + + // GEMM on k_block in registers + gemm(thr_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC); + } +} + +} // end namespace detail + +// C passed as a shared memory tensor +// Epilogue included +template +CUTE_HOST_DEVICE +void +cooperative_gemm(uint32_t thread_idx, + TiledMMA const& tiled_mma, + Alpha const& alpha, + Tensor const& sA, + Tensor const& sB, + Beta const& beta, + Tensor & sC, + ALoadTransformOp const& sA_load_op = {}, // transforms A values before use in GEMM + BLoadTransformOp const& sB_load_op = {}, // transforms B values before use in GEMM + CLoadTransformOp const& sC_load_op = {}, // transforms C values before use in GEMM + CStoreTransformOp const& sC_store_op = {}, // transforms results before they are stored to C + SmemCopyOpA const& sA_copy_op = {}, + SmemCopyOpB const& sB_copy_op = {}, + SmemCopyLdOpC const& sC_copy_ld_op = {}, + SmemCopyStOpC const& sC_copy_st_op = {}) +{ + CUTE_STATIC_ASSERT_V(rank(sA) == Int<2>{}); + CUTE_STATIC_ASSERT_V(rank(sB) == Int<2>{}); + CUTE_STATIC_ASSERT_V(rank(sC) == Int<2>{}); + + CUTE_STATIC_ASSERT_V(size<0>(sA) == size<0>(sC)); // AM == CM + CUTE_STATIC_ASSERT_V(size<0>(sB) == size<1>(sC)); // BN == CN + CUTE_STATIC_ASSERT_V(size<1>(sA) == size<1>(sB)); // AK == BK + + using InputTypeA = typename TA::value_type; + using InputTypeB = typename TB::value_type; + using InputTypeC = typename TC::value_type; + using ComputeTypeA = typename TiledMMA::ValTypeA; + using ComputeTypeB = typename TiledMMA::ValTypeB; + using ComputeTypeC = typename TiledMMA::ValTypeC; + + auto compat = evenly_divides(make_shape(size<0>(sA), size<0>(sB), size<1>(sA)), + tile_shape(TiledMMA{})); + + // ThrMMA + auto thr_mma = tiled_mma.get_thread_slice(thread_idx); + Tensor tCsC = thr_mma.partition_C(sC); // (MMA,MMA_M,MMA_N) :: InputTypeC + Tensor tCrC = thr_mma.make_fragment_C(tCsC); // (MMA,MMA_M,MMA_N) :: ComputeTypeC + + // Clear accumulators + clear(tCrC); + if constexpr (is_constant::value) { + detail::cooperative_gemm_no_predication( + thread_idx, thr_mma, sA, sB, tCrC, sA_load_op, sB_load_op, sA_copy_op, sB_copy_op + ); + detail::epilogue_no_predication( + thread_idx, thr_mma,alpha, tCrC, beta, sC, sC_load_op, sC_store_op, sC_copy_ld_op, sC_copy_st_op + ); + } else { + detail::cooperative_gemm_predication( + thr_mma, sA, sB, tCrC, sA_load_op, sB_load_op + ); + detail::epilogue_predication( + thr_mma, alpha, tCrC, beta, sC, tCsC, sC_load_op, sC_store_op + ); + } +} + +// C already partitioned into registers on input +// It can be passed non-empty +// Epilogue not included +template +CUTE_HOST_DEVICE +void +cooperative_gemm(uint32_t thread_idx, + TiledMMA const& tiled_mma, + Tensor const& sA, + Tensor const& sB, + Tensor & tCrC, + ALoadTransformOp const& sA_load_op = {}, // transforms A values before use in GEMM + BLoadTransformOp const& sB_load_op = {}, // transforms B values before use in GEMM + SmemCopyOpA const& sA_copy_op = {}, + SmemCopyOpB const& sB_copy_op = {}) +{ + CUTE_STATIC_ASSERT_V(rank(sA) == Int<2>{}); + CUTE_STATIC_ASSERT_V(rank(sB) == Int<2>{}); + + 
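+  // (Added note) Unlike the smem-C overload above, no M/N extents are checked against an sC
+  // tile here; the caller-provided accumulator tCrC is instead validated against
+  // partition_shape_C below.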
CUTE_STATIC_ASSERT_V(size<1>(sA) == size<1>(sB)); // AK == BK + + using InputTypeA = typename TA::value_type; + using InputTypeB = typename TB::value_type; + using InputTypeC = typename TC::value_type; + using ComputeTypeA = typename TiledMMA::ValTypeA; + using ComputeTypeB = typename TiledMMA::ValTypeB; + using ComputeTypeC = typename TiledMMA::ValTypeC; + + // Check if input C fragment is compatible with thr_mma and problem size + using ref_c_frag = decltype(partition_shape_C(tiled_mma, make_shape(size<0>(sA), size<0>(sB)))); + CUTE_STATIC_ASSERT_V(compatible(shape(ref_c_frag{}), shape(tCrC))); + + auto compat = evenly_divides(make_shape(size<0>(sA), size<0>(sB), size<1>(sA)), + tile_shape(TiledMMA{})); + + // ThrMMA + auto thr_mma = tiled_mma.get_thread_slice(thread_idx); + + if constexpr (is_constant::value) { + detail::cooperative_gemm_no_predication( + thread_idx, thr_mma, sA, sB, tCrC, sA_load_op, sB_load_op, sA_copy_op, sB_copy_op + ); + } else { + detail::cooperative_gemm_predication( + thr_mma, sA, sB, tCrC, sA_load_op, sB_load_op + ); + } +} + +// Accept mutable temporaries +template +CUTE_HOST_DEVICE +void +cooperative_gemm(uint32_t thread_idx, + TiledMMA const& tiled_mma, + Alpha const& alpha, + Tensor const& sA, + Tensor const& sB, + Beta const& beta, + Tensor && sC, + ALoadTransformOp const& sA_load_op = {}, // transforms A values before use in GEMM + BLoadTransformOp const& sB_load_op = {}, // transforms B values before use in GEMM + CLoadTransformOp const& sC_load_op = {}, // transforms C values before use in GEMM + CStoreTransformOp const& sC_store_op = {}, // transforms results before they are stored to C + SmemCopyOpA const& sA_copy_op = {}, + SmemCopyOpB const& sB_copy_op = {}, + SmemCopyLdOpC const& sC_copy_ld_op = {}, + SmemCopyStOpC const& sC_copy_st_op = {}) +{ + cooperative_gemm(thread_idx, tiled_mma, alpha, sA, sB, beta, sC, + sA_load_op, sB_load_op, sC_load_op, sC_store_op, + sA_copy_op, sB_copy_op, sC_copy_ld_op, sC_copy_st_op); +} + +// Legacy overload of cute::gemm for backwards-compatibility +template +CUTE_HOST_DEVICE +void +gemm(ThrMMA const& thr_mma, + Alpha const& alpha, + Tensor const& sA, + Tensor const& sB, + Beta const& beta, + Tensor & sC, + ALoadTransformOp const& sA_load_op = {}, // transforms A values before use in GEMM + BLoadTransformOp const& sB_load_op = {}, // transforms B values before use in GEMM + CLoadTransformOp const& sC_load_op = {}, // transforms C values before use in GEMM + CStoreTransformOp const& sC_store_op = {}) // transforms results before they are stored to C +{ + CUTE_STATIC_ASSERT_V(rank(sA) == Int<2>{}); + CUTE_STATIC_ASSERT_V(rank(sB) == Int<2>{}); + CUTE_STATIC_ASSERT_V(rank(sC) == Int<2>{}); + + CUTE_STATIC_ASSERT_V(size<0>(sA) == size<0>(sC)); // AM == CM + CUTE_STATIC_ASSERT_V(size<0>(sB) == size<1>(sC)); // BN == CN + CUTE_STATIC_ASSERT_V(size<1>(sA) == size<1>(sB)); // AK == BK + + Tensor tCsC = thr_mma.partition_C(sC); // (MMA,MMA_M,MMA_N) + Tensor tCrC = thr_mma.make_fragment_C(tCsC); // (MMA,MMA_M,MMA_N) + + // Goes directly to the slow path to avoid getting thread_idx from thr_mma + detail::cooperative_gemm_predication( + thr_mma, sA, sB, sC, sA_load_op, sB_load_op + ); + + detail::epilogue_predication( + thr_mma, alpha, tCrC, beta, sC, tCsC, sC_load_op, sC_store_op + ); +} + +} // end namespace cute diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/algorithm/copy.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/algorithm/copy.hpp 
new file mode 100644 index 0000000000000000000000000000000000000000..66fc49c0e67f5b6197bd488276989c0050a8c149 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/algorithm/copy.hpp @@ -0,0 +1,558 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include // CUTE_HOST_DEVICE +#include // cute::Tensor +#include // cute::Copy_Atom + +namespace cute +{ + +// +// copy_if -- Predicated Copy +// + +template +CUTE_HOST_DEVICE +void +copy_if(PrdTensor const& pred, + Tensor const& src, + Tensor & dst) +{ + using SrcType = typename SrcEngine::value_type; + using DstType = typename DstEngine::value_type; + + CUTE_UNROLL + for (int i = 0; i < size(dst); ++i) { + if (pred(i)) { + dst(i) = static_cast(static_cast(src(i))); + } + } +} + +// +// copy_if -- Predicated CopyAtom +// + +// Predicate Tensor is an Actual Tensor +template +CUTE_HOST_DEVICE +void +copy_if(Copy_Atom const& copy_atom, + Tensor const& prd, // ([V],Rest...) + Tensor const& src, // ( V, Rest...) + Tensor & dst) // ( V, Rest...) +{ + if constexpr (PrdLayout::rank == SrcLayout::rank - 1) { + // Back-compat ONLY -- Delete? 
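    // Back-compat handling: the predicate arrived without the leading value mode,
    // i.e. shaped (Rest...) rather than (V,Rest...). Prepending a Layout<_1,_0>
    // (size-1, stride-0) mode broadcasts the same predicate element across the
    // value mode, then the call recurses into the rank-matched branch below.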
+ copy_if(copy_atom, make_tensor(prd.data(), prepend(prd.layout(), Layout<_1,_0>{})), src, dst); + } else { + static_assert(SrcLayout::rank == DstLayout::rank, "CopyAtom rank-mismatch."); + static_assert(SrcLayout::rank == PrdLayout::rank, "CopyAtom rank-mismatch."); + + if constexpr (SrcLayout::rank == 1) { // Dispatch the copy + copy_atom.call(prd, src, dst); + } else { // Loop over all but the first mode + constexpr int R = SrcLayout::rank; + Tensor prd_v = group_modes<1,R>(prd); + Tensor src_v = group_modes<1,R>(src); + Tensor dst_v = group_modes<1,R>(dst); + CUTE_UNROLL + for (int i = 0; i < size<1>(dst_v); ++i) { + copy_atom.call(prd_v(_,i), src_v(_,i), dst_v(_,i)); + } + } + } +} + +template +[[deprecated("Use a bool-tensor or transform-tensor as predication.")]] +CUTE_HOST_DEVICE +void +copy_if(Copy_Atom const& copy_atom, + PredTensor const& pred, // (Rest...) + Tensor const& src, // (V,Rest...) + Tensor & dst) // (V,Rest...) +{ + Tensor tpred = cute::lazy::transform(make_tensor(counting_iterator{}, replace<0>(shape(dst), _1{})), pred); + return copy_if(copy_atom, tpred, src, dst); +} + +// +// copy_if -- AutoCopyAsync +// + +template +CUTE_HOST_DEVICE +void +copy_if(AutoCopyAsync const& cpy, + PrdTensor const& pred, + Tensor const& src, + Tensor & dst) +{ + using SrcElemWithConst = remove_reference_t; + using SrcType = typename SrcEngine::value_type; + using DstType = typename DstEngine::value_type; + + auto copy_op = []() { +#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED) + if constexpr (is_gmem::value && is_smem::value && + sizeof(SrcType) == sizeof(DstType)) { + if constexpr (is_const_v && sizeof(SrcType) == 16) { + return SM80_CP_ASYNC_CACHEGLOBAL{}; + } else if constexpr (sizeof(SrcType) == 4 || sizeof(SrcType) == 8 || sizeof(SrcType) == 16) { + return SM80_CP_ASYNC_CACHEALWAYS{}; + } else { + return UniversalCopy{}; + } + } else { + return UniversalCopy{}; + } + + CUTE_GCC_UNREACHABLE; +#else + return UniversalCopy{}; +#endif + }(); + + CUTE_UNROLL + for (int i = 0; i < size(dst); ++i) { + if (pred(i)) { + copy_op.copy(src(i), dst(i)); + } + } +} + +// +// copy -- AutoCopyAsync +// + +template +CUTE_HOST_DEVICE +void +copy(AutoCopyAsync const& cpy, + Tensor const& src, // (V,Rest...) + Tensor & dst) // (V,Rest...) +{ + copy_if(cpy, constant_fn{}, src, dst); +} + +// +// copy -- CopyAtom +// + +template +CUTE_HOST_DEVICE +void +copy(Copy_Atom const& copy_atom, + Tensor const& src, // (V,Rest...) + Tensor & dst) // (V,Rest...) 
+{ + static_assert(SrcLayout::rank == DstLayout::rank, "CopyAtom rank-mismatch."); + + if constexpr (SrcLayout::rank == 1) { // Dispatch the copy + copy_atom.call(src, dst); + } else { // Loop over all but the first mode + constexpr int R = SrcLayout::rank; + Tensor src_v = group_modes<1,R>(src); + Tensor dst_v = group_modes<1,R>(dst); + + if constexpr (is_static::value && is_static::value) { + CUTE_STATIC_ASSERT_V(size<1>(src_v) == size<1>(dst_v)); + + // AutoFilter on the Rest-mode + auto dst_null = nullspace(layout<1>(dst_v)); + + Tensor dst_n = zipped_divide(dst_v, make_tile(shape<0>(dst_v), dst_null)); // ((V, NLL), (_1, Rest)) + Tensor src_n = zipped_divide(src_v, make_tile(shape<0>(src_v), dst_null)); // ((V, NLL), (_1, Rest)) + + CUTE_STATIC_ASSERT_V(size<1>(src_n) == size<1>(dst_n)); + CUTE_STATIC_ASSERT_V((cosize<0,1>(dst_n.layout()) == Int<1>{}), "Nullspace definition error"); + CUTE_STATIC_ASSERT_V((cosize<0,1>(src_n.layout()) == Int<1>{}), "Error: Ambiguous scatter detected in copy"); + CUTE_STATIC_ASSERT_V((size<1,0>(dst_n) == Int<1>{})); + CUTE_STATIC_ASSERT_V((size<1,0>(src_n) == Int<1>{})); + + Tensor dst_c = dst_n(make_coord(_,Int<0>{}),make_coord(Int<0>{},_)); // (V, Rest) + Tensor src_c = src_n(make_coord(_,Int<0>{}),make_coord(Int<0>{},_)); // (V, Rest) + + CUTE_STATIC_ASSERT_V( size<1>(src_c) == size<1>(dst_c)); + CUTE_STATIC_ASSERT_V(shape<0>(dst_c) == shape<0>(dst)); + CUTE_STATIC_ASSERT_V(shape<0>(src_c) == shape<0>(src)); + + CUTE_UNROLL + for (int i = 0; i < size<1>(dst_c); ++i) { + copy_atom.call(src_c(_,i), dst_c(_,i)); + } + } else { + CUTE_UNROLL + for (int i = 0; i < size<1>(dst_v); ++i) { + copy_atom.call(src_v(_,i), dst_v(_,i)); + } + } + } +} + +//////////////////////////////////////////////////////// +// Special Auto-Vectorizing, Auto-Filtering Overloads // +//////////////////////////////////////////////////////// + +// Specialization for AutoVectorizingCopyAssumedAlignment +template +CUTE_HOST_DEVICE +void +copy(AutoVectorizingCopyWithAssumedAlignment const&, + Tensor const& src, + Tensor & dst) +{ + constexpr int common_elem = CUTE_STATIC_V(max_common_vector(src, dst)); + static_assert(is_integral{} * sizeof_bits_v)>::value, "Error: Attempting a subbit write!"); + + if constexpr (common_elem > 1) + { + constexpr int align_bits = CUTE_STATIC_V(gcd(max_alignment(src), max_alignment(dst), Int{})); + constexpr int vec_bits = gcd(common_elem * sizeof_bits_v, align_bits); + + if constexpr ((vec_bits % 8) == 0 && sizeof_bits_v < Int{}) + { + // If more than one element vectorizes to a multiple of 8bits that is larger than the value_type, then recast and copy + using VecType = uint_bit_t; + + // Recast + Tensor src_v = recast(src); + Tensor dst_v = recast(dst); + return copy_if(constant_fn{}, src_v, dst_v); + } else { + return copy_if(constant_fn{}, src, dst); + } + } else { + return copy_if(constant_fn{}, src, dst); + } +} + +template +struct AutoFilter { + Base const& base; + CUTE_HOST_DEVICE AutoFilter(Base const& b) : base(b) {} +}; + +// Specialization for AutoFilter +template +CUTE_HOST_DEVICE +void +copy(AutoFilter const& copy_op, + Tensor const& src, + Tensor & dst) +{ + if constexpr (is_constant::value) { + auto dst_null = nullspace(dst.layout()); + + Tensor dst_n = zipped_divide(dst, dst_null); + Tensor src_n = zipped_divide(src, dst_null); + + CUTE_STATIC_ASSERT_V(cosize<0>(dst_n.layout()) == Int<1>{}, "Nullspace definition error"); + CUTE_STATIC_ASSERT_V(cosize<0>(src_n.layout()) == Int<1>{}, "Error: Ambiguous race-condition detected."); + + 
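    // nullspace(dst.layout()) gathers the stride-0 (broadcast) modes of the
    // destination, and zipped_divide groups them into the first mode. Because those
    // modes cover a single element (cosize == 1, checked above), slicing them at
    // index 0 below performs each destination write exactly once; the source-side
    // assert guarantees the dropped source coordinates all referred to the same
    // source element. The filtered copy is then forwarded to the wrapped base policy.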
copy(copy_op.base, src_n(Int<0>{},_), dst_n(Int<0>{},_)); + } else { + copy(copy_op.base, src, dst); + } +} + +// Auto-vectorizing copy for static layouts +template +CUTE_HOST_DEVICE +void +copy(Tensor const& src, + Tensor & dst) +{ + if constexpr (is_static::value && is_static::value) { + // Assume Tensors with static layouts (e.g. registers) have pointers that are 128b aligned + return copy(AutoFilter(AutoVectorizingCopyWithAssumedAlignment<128>{}), src, dst); + } else + if constexpr (is_static::value && is_static::value) { + // Tensors with static shapes can be filtered, but do not assume that dynamic layouts are aligned. + return copy(AutoFilter(AutoVectorizingCopyWithAssumedAlignment<8>{}), src, dst); + } else { + // Do not assume that dynamic layouts are aligned. + return copy(AutoVectorizingCopyWithAssumedAlignment<8>{}, src, dst); + } +} + +// Auto-vectorizing copy with assumed alignment up to 128bit. +template +CUTE_HOST_DEVICE +void +copy_aligned(Tensor const& src, + Tensor & dst) +{ + if constexpr (is_static::value && is_static::value) { + // Tensors with static shapes can be filtered + return copy(AutoFilter(AutoVectorizingCopyWithAssumedAlignment<128>{}), src, dst); + } else { + return copy(AutoVectorizingCopyWithAssumedAlignment<128>{}, src, dst); + } +} + +// Specializaton for Atom AutoVectorizingCopyAssumedAlignment +template +CUTE_HOST_DEVICE +void +copy(Copy_Atom, Args...> const&, + Tensor const& src, + Tensor & dst) +{ + return copy(AutoVectorizingCopyWithAssumedAlignment{}, src, dst); +} + +template +CUTE_HOST_DEVICE +void +copy(Copy_Atom>, Args...> const&, + Tensor const& src, + Tensor & dst) +{ + return copy(AutoVectorizingCopyWithAssumedAlignment{}, src, dst); +} + +#if defined(CUTE_COPY_ATOM_TMA_SM90_ENABLED) +template +CUTE_HOST_DEVICE +void +copy(Copy_Traits const& atom, // Copy_Traits may or may not have the memory barrier in it already + Tensor const& src, + Tensor & dst) +{ + using SrcType = typename SrcEngine::value_type; + using DstType = typename DstEngine::value_type; + static_assert(cute::is_same::value); + static_assert((is_gmem::value && is_smem::value) || + (is_smem::value && is_gmem::value), + "Bulk Copy only supports gmem -> smem or smem -> gmem movement."); + // G2S or S2G dispatch + using BULK_COPY_OP = conditional_t::value, + SM90_BULK_COPY_G2S, + SM90_BULK_COPY_S2G>; + + // Find the common subtensor of src and dst + auto tiler = max_common_layout(src, dst); + constexpr int vec_elem = decltype(size(tiler))::value; + constexpr int vec_bits = vec_elem * sizeof_bits_v; + static_assert(vec_bits >= 128, "Expected at least 128-bits for BLKCP"); + + // Construct a new concrete Atom of the vector size + using BulkAtom = Copy_Atom, CT_Args...>, SrcType>; + auto bulk_atom = apply(atom.opargs_, [](auto const&... args) { return BulkAtom{args...}; }); + return copy(bulk_atom, logical_divide(src, tiler), logical_divide(dst, tiler)); +} + +// Backwards-compat. Throw out any extra Copy_Atom args. 
+template +CUTE_HOST_DEVICE +void +copy(Copy_Atom, CA_Args...> const& atom, + Tensor const& src, + Tensor & dst) +{ + return copy(static_cast const&>(atom), src, dst); +} +#endif // #if defined(CUTE_COPY_ATOM_TMA_SM90_ENABLED) + +// +// Decay TiledCopy to CopyAtom +// + +template +CUTE_HOST_DEVICE +void +copy_if(TiledCopy const& tiled_copy, + PrdTensor const& pred, + Tensor const& src, + Tensor & dst) +{ + return copy_if(static_cast(tiled_copy), pred, src, dst); +} + +template +CUTE_HOST_DEVICE +void +copy(TiledCopy const& tiled_copy, + Tensor const& src, + Tensor & dst) +{ + return copy(static_cast(tiled_copy), src, dst); +} + +template +CUTE_HOST_DEVICE +void +copy_if(ThrCopy const& thr_copy, + PrdTensor const& pred, + Tensor const& src, + Tensor & dst) = delete; + +template +CUTE_HOST_DEVICE +void +copy(ThrCopy const& thr_copy, + Tensor const& src, + Tensor & dst) = delete; + +// +// Catch uncaught policies +// + +template +CUTE_HOST_DEVICE +void +copy_if(CopyPolicy const& cpy, + PredTensor const& prd, + Tensor const& src, + Tensor & dst) +{ + static_assert(dependent_false, "Unrecognized CopyPolicy."); +} + +template +CUTE_HOST_DEVICE +void +copy(CopyPolicy const& cpy, + Tensor const& src, + Tensor & dst) +{ + static_assert(dependent_false, "Unrecognized CopyPolicy."); +} + +// +// Accept mutable temporaries +// + +template +CUTE_HOST_DEVICE +void +copy_if(PrdTensor const& pred, + Tensor const& src, + Tensor && dst) +{ + return copy_if(pred, src, dst); +} + +template +CUTE_HOST_DEVICE +void +copy_if(CopyPolicy const& copy_policy, + PrdTensor const& pred, + Tensor const& src, + Tensor && dst) +{ + return copy_if(copy_policy, pred, src, dst); +} + +template +CUTE_HOST_DEVICE +void +copy(Tensor const& src, + Tensor && dst) +{ + return copy(src, dst); +} + +template +CUTE_HOST_DEVICE +void +copy(CopyPolicy const& copy_policy, + Tensor const& src, + Tensor && dst) +{ + return copy(copy_policy, src, dst); +} + +template +CUTE_HOST_DEVICE +void +copy_aligned(Tensor const& src, + Tensor && dst) +{ + return copy_aligned(src, dst); +} + +} // end namespace cute diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/algorithm/fill.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/algorithm/fill.hpp new file mode 100644 index 0000000000000000000000000000000000000000..37b97f18734e74bcdcc39fede3bb6aeaa26c1338 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/algorithm/fill.hpp @@ -0,0 +1,87 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include +#include + +namespace cute +{ + +// +// Accept mutable temporaries +// +template +CUTE_HOST_DEVICE +void +fill(Tensor&& tensor, T const& value) +{ + return fill(tensor, value); +} + +namespace detail +{ + +// Prefer fill(tensor.data(), value), if possible +template +CUTE_HOST_DEVICE +auto +fill(Tensor& tensor, T const& value, prefer<1>) + -> decltype(fill(tensor.data(), value)) +{ + fill(tensor.data(), value); +} + +// Default implementation +template +CUTE_HOST_DEVICE +void +fill(Tensor& tensor, T const& value, prefer<0>) +{ + CUTE_UNROLL + for (int i = 0; i < size(tensor); ++i) { + tensor(i) = value; + } +} + +} // end namespace detail + +template +CUTE_HOST_DEVICE +void +fill(Tensor& tensor, T const& value) +{ + return detail::fill(tensor, value, prefer<1>{}); +} + +} // end namespace cute diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/algorithm/functional.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/algorithm/functional.hpp new file mode 100644 index 0000000000000000000000000000000000000000..5c56eb5cc80be460895802df09363eb7254a8455 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/algorithm/functional.hpp @@ -0,0 +1,290 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include // CUTE_HOST_DEVICE +#include // cute::max, cute::min +#include // cute::conj + +/** C++14 extensions */ + +namespace cute { + +/**************/ +/** Identity **/ +/**************/ + +struct identity { + template + CUTE_HOST_DEVICE constexpr + decltype(auto) operator()(T&& arg) const { + return static_cast(arg); + } +}; + +template +struct constant_fn { + template + CUTE_HOST_DEVICE constexpr + decltype(auto) operator()(T&&...) const { + return r_; + } + R r_; +}; + +/***********/ +/** Unary **/ +/***********/ + +#define CUTE_LEFT_UNARY_OP(NAME,OP) \ + struct NAME { \ + template \ + CUTE_HOST_DEVICE constexpr \ + decltype(auto) operator()(T&& arg) const { \ + return OP static_cast(arg); \ + } \ + } +#define CUTE_RIGHT_UNARY_OP(NAME,OP) \ + struct NAME { \ + template \ + CUTE_HOST_DEVICE constexpr \ + decltype(auto) operator()(T&& arg) const { \ + return static_cast(arg) OP ; \ + } \ + } +#define CUTE_NAMED_UNARY_OP(NAME,OP) \ + struct NAME { \ + template \ + CUTE_HOST_DEVICE constexpr \ + decltype(auto) operator()(T&& arg) const { \ + return OP (static_cast(arg)); \ + } \ + } + +CUTE_LEFT_UNARY_OP(unary_plus, +); +CUTE_LEFT_UNARY_OP(negate, -); +CUTE_LEFT_UNARY_OP(bit_not, ~); +CUTE_LEFT_UNARY_OP(logical_not, !); +CUTE_LEFT_UNARY_OP(dereference, *); +CUTE_LEFT_UNARY_OP(address_of, &); +CUTE_LEFT_UNARY_OP(pre_increment, ++); +CUTE_LEFT_UNARY_OP(pre_decrement, --); + +CUTE_RIGHT_UNARY_OP(post_increment, ++); +CUTE_RIGHT_UNARY_OP(post_decrement, --); + +CUTE_NAMED_UNARY_OP(abs_fn, abs); +CUTE_NAMED_UNARY_OP(conjugate, cute::conj); + +#undef CUTE_LEFT_UNARY_OP +#undef CUTE_RIGHT_UNARY_OP +#undef CUTE_NAMED_UNARY_OP + +template +struct shift_right_const { + static constexpr int Shift = Shift_; + + template + CUTE_HOST_DEVICE constexpr + decltype(auto) operator()(T&& arg) const { + return static_cast(arg) >> Shift; + } +}; + +template +struct shift_left_const { + static constexpr int Shift = Shift_; + + template + CUTE_HOST_DEVICE constexpr + decltype(auto) operator()(T&& arg) const { + return static_cast(arg) << Shift; + } +}; + +/************/ +/** Binary **/ +/************/ + +#define CUTE_BINARY_OP(NAME,OP) \ + struct NAME { \ + template \ + CUTE_HOST_DEVICE constexpr \ + decltype(auto) operator()(T&& lhs, U&& rhs) const { \ + return static_cast(lhs) OP static_cast(rhs); \ + } \ + } +#define CUTE_NAMED_BINARY_OP(NAME,OP) \ + struct NAME { \ + template \ + CUTE_HOST_DEVICE constexpr \ + decltype(auto) operator()(T&& lhs, U&& rhs) const { \ + return OP (static_cast(lhs), static_cast(rhs)); \ + } \ + } + + +CUTE_BINARY_OP(plus, +); +CUTE_BINARY_OP(minus, -); +CUTE_BINARY_OP(multiplies, *); +CUTE_BINARY_OP(divides, /); +CUTE_BINARY_OP(modulus, %); + +CUTE_BINARY_OP(plus_assign, +=); +CUTE_BINARY_OP(minus_assign, -=); +CUTE_BINARY_OP(multiplies_assign, *=); +CUTE_BINARY_OP(divides_assign, /=); +CUTE_BINARY_OP(modulus_assign, %=); + 
+CUTE_BINARY_OP(bit_and, &); +CUTE_BINARY_OP(bit_or, |); +CUTE_BINARY_OP(bit_xor, ^); +CUTE_BINARY_OP(left_shift, <<); +CUTE_BINARY_OP(right_shift, >>); + +CUTE_BINARY_OP(bit_and_assign, &=); +CUTE_BINARY_OP(bit_or_assign, |=); +CUTE_BINARY_OP(bit_xor_assign, ^=); +CUTE_BINARY_OP(left_shift_assign, <<=); +CUTE_BINARY_OP(right_shift_assign, >>=); + +CUTE_BINARY_OP(logical_and, &&); +CUTE_BINARY_OP(logical_or, ||); + +CUTE_BINARY_OP(equal_to, ==); +CUTE_BINARY_OP(not_equal_to, !=); +CUTE_BINARY_OP(greater, >); +CUTE_BINARY_OP(less, <); +CUTE_BINARY_OP(greater_equal, >=); +CUTE_BINARY_OP(less_equal, <=); + +CUTE_NAMED_BINARY_OP(max_fn, cute::max); +CUTE_NAMED_BINARY_OP(min_fn, cute::min); + +#undef CUTE_BINARY_OP +#undef CUTE_NAMED_BINARY_OP + +/**********/ +/** Fold **/ +/**********/ + +#define CUTE_FOLD_OP(NAME,OP) \ + struct NAME##_unary_rfold { \ + template \ + CUTE_HOST_DEVICE constexpr \ + auto operator()(T&&... t) const { \ + return (t OP ...); \ + } \ + }; \ + struct NAME##_unary_lfold { \ + template \ + CUTE_HOST_DEVICE constexpr \ + auto operator()(T&&... t) const { \ + return (... OP t); \ + } \ + }; \ + struct NAME##_binary_rfold { \ + template \ + CUTE_HOST_DEVICE constexpr \ + auto operator()(U&& u, T&&... t) const { \ + return (t OP ... OP u); \ + } \ + }; \ + struct NAME##_binary_lfold { \ + template \ + CUTE_HOST_DEVICE constexpr \ + auto operator()(U&& u, T&&... t) const { \ + return (u OP ... OP t); \ + } \ + } + +CUTE_FOLD_OP(plus, +); +CUTE_FOLD_OP(minus, -); +CUTE_FOLD_OP(multiplies, *); +CUTE_FOLD_OP(divides, /); +CUTE_FOLD_OP(modulus, %); + +CUTE_FOLD_OP(plus_assign, +=); +CUTE_FOLD_OP(minus_assign, -=); +CUTE_FOLD_OP(multiplies_assign, *=); +CUTE_FOLD_OP(divides_assign, /=); +CUTE_FOLD_OP(modulus_assign, %=); + +CUTE_FOLD_OP(bit_and, &); +CUTE_FOLD_OP(bit_or, |); +CUTE_FOLD_OP(bit_xor, ^); +CUTE_FOLD_OP(left_shift, <<); +CUTE_FOLD_OP(right_shift, >>); + +CUTE_FOLD_OP(bit_and_assign, &=); +CUTE_FOLD_OP(bit_or_assign, |=); +CUTE_FOLD_OP(bit_xor_assign, ^=); +CUTE_FOLD_OP(left_shift_assign, <<=); +CUTE_FOLD_OP(right_shift_assign, >>=); + +CUTE_FOLD_OP(logical_and, &&); +CUTE_FOLD_OP(logical_or, ||); + +CUTE_FOLD_OP(equal_to, ==); +CUTE_FOLD_OP(not_equal_to, !=); +CUTE_FOLD_OP(greater, >); +CUTE_FOLD_OP(less, <); +CUTE_FOLD_OP(greater_equal, >=); +CUTE_FOLD_OP(less_equal, <=); + +#undef CUTE_FOLD_OP + +/**********/ +/** Meta **/ +/**********/ + +template +struct bound_fn { + + template + CUTE_HOST_DEVICE constexpr + decltype(auto) + operator()(T&& arg) { + return fn_(arg_, static_cast(arg)); + } + + Fn fn_; + Arg arg_; +}; + +template +CUTE_HOST_DEVICE constexpr +auto +bind(Fn const& fn, Arg const& arg) { + return bound_fn{fn, arg}; +} + +} // end namespace cute diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/algorithm/gemm.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/algorithm/gemm.hpp new file mode 100644 index 0000000000000000000000000000000000000000..97839c044f66b7f955605cc6c58efaa0ec45e6ff --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/algorithm/gemm.hpp @@ -0,0 +1,500 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include +#include + +#include + +#include + +/** The gemm algorithm takes four (or three) tensors and computes + * D = A * B + C + * It dispatches based on the number of modes each tensor has: + * + * 1. `(V) x (V) => (V)`. + * The element-wise product of vectors. Dispatches to FMA or MMA. + * 2. `(M) x (N) => (M,N)`. + * The outer product of vectors. Dispatches to [3] with new mode K=(1). + * 3. `(M,K) x (N,K) => (M,N)`. + * The product of matrices. Dispatches to [5] with MMA vector-mode V. + * 4. `(V,M) x (V,N) => (V,M,N)`. + * The batched outer product of vectors. Accounts for register reuse and dispatches to [1] for each (m,n). + * 5. `(V,M,K) x (V,N,K) => (V,M,N)`. + * The batched product of matrices. Dispatches to [4] for each (k). 
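 *
 * Illustrative example (a sketch; the fragment names tCrA/tCrB/tCrC are assumptions,
 * not part of this header): with register fragments partitioned as tCrA:(V,M,K),
 * tCrB:(V,N,K), tCrC:(V,M,N), a call such as
 *     gemm(mma_atom, tCrC, tCrA, tCrB, tCrC);   // D = A * B + C with D aliased to C
 * resolves to dispatch [5], which loops over the K mode and invokes dispatch [4]
 * once per k-slice; [4] then issues one MMA per (m,n) through dispatch [1].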
+ */ + +namespace cute +{ + +// +// Three arguments to four +// + +template +CUTE_HOST_DEVICE +void +gemm(Tensor const& A, + Tensor const& B, + Tensor & C) +{ + return gemm(C, A, B, C); +} + +template +CUTE_HOST_DEVICE +void +gemm(MMA_Atom const& mma, + Tensor const& A, + Tensor const& B, + Tensor & C) +{ + return gemm(mma, C, A, B, C); +} + +// +// Accept mutable temporaries +// + +template +CUTE_HOST_DEVICE +void +gemm(Tensor const& A, + Tensor const& B, + Tensor && C) +{ + return gemm(C, A, B, C); +} + +template +CUTE_HOST_DEVICE +void +gemm(Tensor && D, + Tensor const& A, + Tensor const& B, + Tensor const& C) +{ + return gemm(D, A, B, C); +} + +template +CUTE_HOST_DEVICE +void +gemm(MMA_Atom const& mma, + Tensor const& A, + Tensor const& B, + Tensor && C) +{ + return gemm(mma, C, A, B, C); +} + +template +CUTE_HOST_DEVICE +void +gemm(MMA_Atom const& mma, + Tensor && D, + Tensor const& A, + Tensor const& B, + Tensor const& C) +{ + return gemm(mma, D, A, B, C); +} + +// +// Default MMA is UniversalFMA +// + +template +CUTE_HOST_DEVICE +void +gemm(Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) +{ + using MMA = MMA_Atom::value_type, + typename Tensor::value_type, + typename Tensor::value_type, + typename Tensor::value_type>>; + + return gemm(MMA{}, D, A, B, C); +} + +// +// Thread-Local Register-Memory GEMMs +// + +// Dispatch [1]: (V) x (V) => (V) +template ::value && + ALayout::rank == 1 && is_rmem::value && + BLayout::rank == 1 && is_rmem::value && + CLayout::rank == 1 && is_rmem::value)> +CUTE_HOST_DEVICE +void +gemm(MMA_Atom const& mma, + Tensor & D, // (V) Logical data + Tensor const& A, // (V) Logical data + Tensor const& B, // (V) Logical data + Tensor const& C) // (V) Logical data +{ + // No static assertions on (V), MMA checks compatibility + mma.call(D, A, B, C); +} + +// Dispatch [2]: (M) x (N) => (M,N) +template ::value && + ALayout::rank == 1 && is_rmem::value && + BLayout::rank == 1 && is_rmem::value && + CLayout::rank == 2 && is_rmem::value)> +CUTE_HOST_DEVICE +void +gemm(MMA_Atom const& mma, + Tensor & D, // (M,N) Logical data + Tensor const& A, // (M) Logical data + Tensor const& B, // (N) Logical data + Tensor const& C) // (M,N) Logical data +{ + CUTE_STATIC_ASSERT_V(size<0>(A) == size<0>(C)); // AM == CM + CUTE_STATIC_ASSERT_V(size<0>(B) == size<1>(C)); // BN == CN + CUTE_STATIC_ASSERT_V(size<0>(C) == size<0>(D) && size<1>(C) == size<1>(D)); + gemm(mma, + D, // (M,N) + make_tensor(A.data(), append<2>(A.layout())), // (M,1) + make_tensor(B.data(), append<2>(B.layout())), // (N,1) + C); // (M,N) +} + +// Dispatch [3]: (M,K) x (N,K) => (M,N) +template ::value && + ALayout::rank == 2 && is_rmem::value && + BLayout::rank == 2 && is_rmem::value && + CLayout::rank == 2 && is_rmem::value)> +CUTE_HOST_DEVICE +void +gemm(MMA_Atom const& mma, + Tensor & D, // (M,N) Logical data + Tensor const& A, // (M,K) Logical data + Tensor const& B, // (N,K) Logical data + Tensor const& C) // (M,N) Logical data +{ + CUTE_STATIC_ASSERT_V(size<0>(A) == size<0>(C)); // AM == CM + CUTE_STATIC_ASSERT_V(size<0>(B) == size<1>(C)); // BN == CN + CUTE_STATIC_ASSERT_V(size<1>(A) == size<1>(B)); // AK == BK + CUTE_STATIC_ASSERT_V(size<0>(C) == size<0>(D) && size<1>(C) == size<1>(D)); + + // Assert this is a 1-value MMA + CUTE_STATIC_ASSERT_V(size<1>(typename MMA_Atom::LayoutC_TV{}) == Int<1>{}); + CUTE_STATIC_ASSERT_V(size<1>(typename MMA_Atom::LayoutA_TV{}) == Int<1>{}); + CUTE_STATIC_ASSERT_V(size<1>(typename MMA_Atom::LayoutB_TV{}) == Int<1>{}); + + gemm(mma, + 
make_tensor(D.data(), prepend<3>(D.layout())), // (1,M,N) + make_tensor(A.data(), prepend<3>(A.layout())), // (1,M,K) + make_tensor(B.data(), prepend<3>(B.layout())), // (1,N,K) + make_tensor(C.data(), prepend<3>(C.layout()))); // (1,M,N) +} + +// Dispatch [4]: (V,M) x (V,N) => (V,M,N) +template ::value && + ALayout::rank == 2 && is_rmem::value && + BLayout::rank == 2 && is_rmem::value && + CLayout::rank == 3 && is_rmem::value)> +CUTE_HOST_DEVICE +void +gemm(MMA_Atom const& mma, + Tensor & D, // (V,M,N) Logical data + Tensor const& A, // (V,M) Logical data + Tensor const& B, // (V,N) Logical data + Tensor const& C) // (V,M,N) Logical data +{ + CUTE_STATIC_ASSERT_V(size<1>(A) == size<1>(C)); // AM == CM + CUTE_STATIC_ASSERT_V(size<1>(B) == size<2>(C)); // BN == CN + CUTE_STATIC_ASSERT_V(size<0>(C) == size<0>(D) && size<1>(C) == size<1>(D) && size<2>(C) == size<2>(D)); + auto M = size<1>(A); + auto N = size<1>(B); + // REGISTER .reuse OPTIMIZATIONS + // 64-bit traversal specialization -- serpentine path + if constexpr (decltype(size<0>(A))::value * sizeof(typename TA::value_type) == 8 && + decltype(size<0>(B))::value * sizeof(typename TB::value_type) == 8) + { +#if 1 // NOTE: Row- vs Col- major could depend on the C-matrix order... (which we can test) + // Row-major serpentine iteration + CUTE_UNROLL + for (int m = 0; m < M; ++m) { + CUTE_UNROLL + for (int n = 0; n < N; ++n) { + int ns = (m & 1) ? N-1-n : n; // Serpentine coordinate + gemm(mma, D(_,m,ns), A(_,m), B(_,ns), C(_,m,ns)); + } + } +#else + // Col-major serpentine iteration + CUTE_UNROLL + for (int n = 0; n < N; ++n) { + CUTE_UNROLL + for (int m = 0; m < M; ++m) { + int ms = (n & 1) ? M-1-m : m; // Serpentine coordinate + gemm(mma, D(_,ms,n), A(_,ms), B(_,n), C(_,ms,n)); + } + } +#endif + } else + // 32-bit traversal specialization -- kinked serpentine path + if constexpr (decltype(size<0>(A))::value * sizeof(typename TA::value_type) == 4 && + decltype(size<0>(B))::value * sizeof(typename TB::value_type) == 4) + { +#if 1 // NOTE: Row- vs Col- major could depend on the C-matrix order... (which we can test) + // Row-major kinked serpentine iteration + CUTE_UNROLL + for (int m = 0; m < M; m += 2) { + CUTE_UNROLL + for (int n = 0; n < N; ++n) { + int ns = (m & 2) ? N-1-n : n; + gemm(mma, D(_,m+0,ns), A(_,m+0), B(_,ns), C(_,m+0,ns)); + + if (m+1 < M) { + gemm(mma, D(_,m+1,ns), A(_,m+1), B(_,ns), C(_,m+1,ns)); + } + } + } +#else + // Col-major kinked serpentine iteration + CUTE_UNROLL + for (int n = 0; n < N; n += 2) { + CUTE_UNROLL + for (int m = 0; m < M; ++m) { + // Kinked serpentine traversal for maximum register reuse + int ms = (n & 2) ? M-1-m : m; + gemm(mma, D(_,ms,n+0), A(_,ms), B(_,n+0), C(_,ms,n+0)); + + if (n+1 < N) { + gemm(mma, D(_,ms,n+1), A(_,ms), B(_,n+1), C(_,ms,n+1)); + } + } + } +#endif + } else + // 64-bit + 32-bit traversal order -- keep A (64-bit) in the outer loop and serpentine B + if constexpr (decltype(size<0>(A))::value * sizeof(typename TA::value_type) == 8 && + decltype(size<0>(B))::value * sizeof(typename TB::value_type) == 4) { + // Row-major serpentine iteration + CUTE_UNROLL + for (int m = 0; m < M; ++m) { + CUTE_UNROLL + for (int n = 0; n < N; ++n) { + int ns = (m & 1) ? 
N-1-n : n; // Serpentine coordinate + gemm(mma, D(_,m,ns), A(_,m), B(_,ns), C(_,m,ns)); + } + } + } else + // 32-bit + 64-bit traversal order -- keep B (64-bit) in the outer loop and serpentine A + if constexpr (decltype(size<0>(A))::value * sizeof(typename TA::value_type) == 4 && + decltype(size<0>(B))::value * sizeof(typename TB::value_type) == 8) { + // Col-major serpentine iteration + CUTE_UNROLL + for (int n = 0; n < N; ++n) { + CUTE_UNROLL + for (int m = 0; m < M; ++m) { + int ms = (n & 1) ? M-1-m : m; // Serpentine coordinate + gemm(mma, D(_,ms,n), A(_,ms), B(_,n), C(_,ms,n)); + } + } + } else + // Fallback to serpentine loop + { + // Col-major serpentine iteration + CUTE_UNROLL + for (int n = 0; n < N; ++n) { + CUTE_UNROLL + for (int m = 0; m < M; ++m) { + int ms = (n & 1) ? M-1-m : m; // Serpentine coordinate + gemm(mma, D(_,ms,n), A(_,ms), B(_,n), C(_,ms,n)); + } + } + } +} + +// Dispatch [5]: (V,M,K) x (V,N,K) => (V,M,N) +template ::value && + ALayout::rank == 3 && is_rmem::value && + BLayout::rank == 3 && is_rmem::value && + CLayout::rank == 3 && is_rmem::value)> +CUTE_HOST_DEVICE +void +gemm(MMA_Atom const& mma, + Tensor & D, // (V,M,N) Logical data + Tensor const& A, // (V,M,K) Logical data + Tensor const& B, // (V,N,K) Logical data + Tensor const& C) // (V,M,N) Logical data +{ + CUTE_STATIC_ASSERT_V(size<1>(A) == size<1>(C)); // AM == CM + CUTE_STATIC_ASSERT_V(size<1>(B) == size<2>(C)); // BN == CN + CUTE_STATIC_ASSERT_V(size<2>(A) == size<2>(B)); // AK == BK + CUTE_STATIC_ASSERT_V(size<0>(C) == size<0>(D) && size<1>(C) == size<1>(D) && size<2>(C) == size<2>(D)); + auto K = size<2>(A); + + CUTE_UNROLL + for (int k = 0; k < K; ++k) { + gemm(mma, D, A(_,_,k), B(_,_,k), C); + } +} + +// +// Thread-Local Shared-Memory GEMMs +// + +// Dispatch [1]: (V) x (V) => (V) +// Dispatch [2]: (M) x (N) => (M,N) +// Dispatch [3]: (M,K) x (N,K) => (M,N) +// Dispatch [4]: (V,M) x (V,N) => (V,M,N) +// Dispatch [5]: (V,M,K) x (V,N,K) => (V,M,N) +// Dispatch [3]: (M,K) x (N,K) => (M,N) +template ::value && + ALayout::rank == 2 && is_smem::value && + BLayout::rank == 2 && is_smem::value && + CLayout::rank == 2 && is_rmem::value)> +CUTE_HOST_DEVICE +void +gemm(MMA_Atom const& mma, + Tensor & D, // (M,N) Logical data + Tensor const& A, // (M,K) Logical data + Tensor const& B, // (N,K) Logical data + Tensor const& C) // (M,N) Logical data +{ + CUTE_STATIC_ASSERT_V(size<0>(A) == size<0>(C)); // AM == CM + CUTE_STATIC_ASSERT_V(size<0>(B) == size<1>(C)); // BN == CN + CUTE_STATIC_ASSERT_V(size<1>(A) == size<1>(B)); // AK == BK + CUTE_STATIC_ASSERT_V(size<0>(C) == size<0>(D) && size<1>(C) == size<1>(D)); + + // Assert this is a 1-value MMA + CUTE_STATIC_ASSERT_V(size<1>(typename MMA_Atom::LayoutC_TV{}) == Int<1>{}); + CUTE_STATIC_ASSERT_V(size<1>(typename MMA_Atom::LayoutA_TV{}) == Int<1>{}); + CUTE_STATIC_ASSERT_V(size<1>(typename MMA_Atom::LayoutB_TV{}) == Int<1>{}); + + gemm(mma, + make_tensor(D.data(), prepend<3>(D.layout())), // (1,M,N) + make_tensor(A.data(), prepend<3>(A.layout())), // (1,M,K) + make_tensor(B.data(), prepend<3>(B.layout())), // (1,N,K) + make_tensor(C.data(), prepend<3>(C.layout()))); // (1,M,N) +} + +// Dispatch [5]: (V,M,K) x (V,N,K) => (V,M,N) +template ::value && + ALayout::rank == 3 && is_smem::value && + BLayout::rank == 3 && is_smem::value && + CLayout::rank == 3 && is_rmem::value)> +CUTE_HOST_DEVICE +void +gemm(MMA_Atom const& mma, + Tensor & D, // (V,M,N) Logical data + Tensor const& A, // (V,M,K) Logical data + Tensor const& B, // (V,N,K) Logical data + Tensor const& C) 
// (V,M,N) Logical data +{ + CUTE_STATIC_ASSERT_V(size<1>(A) == size<1>(C)); // AM == CM + CUTE_STATIC_ASSERT_V(size<1>(B) == size<2>(C)); // BN == CN + CUTE_STATIC_ASSERT_V(size<2>(A) == size<2>(B)); // AK == BK + CUTE_STATIC_ASSERT_V(size<0>(C) == size<0>(D) && size<1>(C) == size<1>(D) && size<2>(C) == size<2>(D)); + + auto rA = MMA_Atom::make_fragment_A(A); + auto rB = MMA_Atom::make_fragment_B(B); + + auto K = size<2>(A); + + CUTE_UNROLL + for (int k = 0; k < K; ++k) + { + copy(A(_,_,k), rA(_,_,k)); + copy(B(_,_,k), rB(_,_,k)); + // Thread-level register gemm for k + gemm(mma, D, rA(_,_,k), rB(_,_,k), C); + } +} + +} // end namespace cute diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/algorithm/prefer.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/algorithm/prefer.hpp new file mode 100644 index 0000000000000000000000000000000000000000..0a1c53e7d26fa9beb15cbabc61418dfe2a3acbc4 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/algorithm/prefer.hpp @@ -0,0 +1,46 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +namespace cute +{ + +// Infinite types that inherit from each other +template +struct prefer : prefer {}; + +template <> +struct prefer<0> {}; + +// Can be used to preferencially overload implementations +// Higher N in prefer have higher priority. 
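// Illustrative sketch of the idiom (the names impl/fast_path/slow_path/dispatch are
// hypothetical, not part of this header): overloads take a trailing prefer<N> tag,
// callers pass the highest N, SFINAE removes candidates that do not apply, and the
// argument implicitly converts down to a lower-priority base class otherwise.
//
//   template <class T>
//   auto impl(T& t, prefer<1>) -> decltype(t.fast_path()) { return t.fast_path(); }
//
//   template <class T>
//   auto impl(T& t, prefer<0>) { return t.slow_path(); }     // generic fallback
//
//   template <class T>
//   auto dispatch(T& t) { return impl(t, prefer<1>{}); }     // fast_path if it compiles
//
// cute::fill() in fill.hpp above uses exactly this pattern with prefer<1>/prefer<0>.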
+ +} // end namespace cute diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/algorithm/prefetch.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/algorithm/prefetch.hpp new file mode 100644 index 0000000000000000000000000000000000000000..265da128f169e03798d196fb05d0031f67fc8288 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/algorithm/prefetch.hpp @@ -0,0 +1,146 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include // CUTE_HOST_DEVICE +#include // cute::Tensor +#include // cute::Copy_Atom + +namespace cute +{ + +// +// Prefetch global tensors into L2 +// + +template +CUTE_HOST_DEVICE +void +cooperative_prefetch(uint32_t const& tid, + Tensor const& src) +{ + static_assert(is_gmem::value, "Expected global tensor for prefetch"); + + constexpr int V = decltype(max_common_vector(src, src))::value; + + if constexpr (V > 1) { + // L2 sector is 32B, default fetch granularity is 64B + using VecType = conditional_t<(V * sizeof_bits_v) < (FetchBytes * 8), + ArrayEngine, + uint8_t[FetchBytes] >; + + Tensor src_v = recast(src); + CUTE_UNROLL + for (int i = tid; i < size(src_v); i += NumThreads) { + prefetch(raw_pointer_cast(&src_v(i))); + } + } else { + CUTE_UNROLL + for (int i = tid; i < size(src); i += NumThreads) { + prefetch(raw_pointer_cast(&src(i))); + } + } +} + +template +CUTE_HOST_DEVICE +void +prefetch(Tensor const& src) +{ + return cooperative_prefetch<1>(0, src); +} + +// Prefetch with copy atom +namespace detail { + +template +constexpr bool has_prefetch = false; + +template +constexpr bool has_prefetch> = true; + +} // end namespace detail + +template +CUTE_HOST_DEVICE +void +prefetch(Copy_Atom, CopyType> const& atom, + Tensor const& src) +{ + if constexpr (detail::has_prefetch) { + using Prefetch_Traits = Copy_Traits; + using Prefetch_Atom = Copy_Atom; + Prefetch_Atom prefetch_atom{atom}; + //auto& dst = const_cast&>(src); // dst is ignored for prefetch atoms + Tensor dst = make_tensor(make_smem_ptr(nullptr), shape(src)); + return copy(prefetch_atom, src, dst); + } else { + return prefetch(src); + } +} + +#if defined(CUTE_COPY_ATOM_TMA_SM90_ENABLED) +template +CUTE_HOST_DEVICE +void +prefetch(Copy_Traits const& atom, + Tensor const& src) +{ + using SrcType = typename SrcEngine::value_type; + static_assert(is_gmem::value, "Expected global tensor for L2 prefetch"); + + auto tiler = max_common_layout(src, src); + constexpr int vec_elem = decltype(size(tiler))::value; + constexpr int vec_bits = vec_elem * sizeof_bits_v; + static_assert(vec_bits >= 128, "Expected at least 128-bits for BLKCP"); + + // Construct a new concrete Atom of the vector size + auto bulk_atom = Copy_Atom>, SrcType>{}; + + return prefetch(bulk_atom, logical_divide(src, tiler)); +} + +// Backwards-compat. Throw out any extra Copy_Atom args. +template +CUTE_HOST_DEVICE +void +prefetch(Copy_Atom, CA_Args...> const& atom, + Tensor const& src) +{ + return prefetch(static_cast const&>(atom), src); +} +#endif // #if defined(CUTE_COPY_ATOM_TMA_SM90_ENABLED) + +} // end namespace cute diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/algorithm/tensor_algorithms.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/algorithm/tensor_algorithms.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f47becb18d96c7bcbaf3380ade318b5abaa4dc2d --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/algorithm/tensor_algorithms.hpp @@ -0,0 +1,178 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/** Common algorithms on (hierarchical) tensors */ + +#pragma once + +#include +#include + +namespace cute +{ + +// +// for_each +// + +template +CUTE_HOST_DEVICE constexpr +void +for_each(Tensor const& tensor, UnaryOp&& op) +{ + CUTE_UNROLL + for (int i = 0; i < size(tensor); ++i) { + op(tensor(i)); + } +} + +template +CUTE_HOST_DEVICE constexpr +void +for_each(Tensor& tensor, UnaryOp&& op) +{ + CUTE_UNROLL + for (int i = 0; i < size(tensor); ++i) { + op(tensor(i)); + } +} + +// Accept mutable temporaries +template +CUTE_HOST_DEVICE constexpr +void +for_each(Tensor&& tensor, UnaryOp&& op) +{ + return for_each(tensor, op); +} + +// +// transform +// + +// Similar to std::transform but does not return number of elements affected +template +CUTE_HOST_DEVICE constexpr +void +transform(Tensor& tensor, UnaryOp&& op) +{ + CUTE_UNROLL + for (int i = 0; i < size(tensor); ++i) { + tensor(i) = op(tensor(i)); + } +} + +// Accept mutable temporaries +template +CUTE_HOST_DEVICE constexpr +void +transform(Tensor&& tensor, UnaryOp&& op) +{ + return transform(tensor, op); +} + +// Similar to std::transform transforms one tensors and assigns it to another +template +CUTE_HOST_DEVICE constexpr +void +transform(Tensor const& tensor_in, + Tensor & tensor_out, + UnaryOp&& op) +{ + CUTE_UNROLL + for (int i = 0; i < size(tensor_in); ++i) { + tensor_out(i) = op(tensor_in(i)); + } +} + +// Accept mutable temporaries +template +CUTE_HOST_DEVICE constexpr +void +transform(Tensor const& tensor_in, + Tensor && tensor_out, + UnaryOp&& op) +{ + return transform(tensor_in, tensor_out, op); +} + +// Similar to std::transform with a binary operation +// Takes two tensors as input and one tensor as output. 
+// Applies the binary_op to tensor_in1 and tensor_in2 and +// assigns it to tensor_out +template +CUTE_HOST_DEVICE constexpr +void +transform(Tensor const& tensor_in1, + Tensor const& tensor_in2, + Tensor & tensor_out, + BinaryOp&& op) +{ + CUTE_UNROLL + for (int i = 0; i < size(tensor_in1); ++i) { + tensor_out(i) = op(tensor_in1(i), tensor_in2(i)); + } +} + +// Accept mutable temporaries +template +CUTE_HOST_DEVICE constexpr +void +transform(Tensor const& tensor_in1, + Tensor const& tensor_in2, + Tensor && tensor_out, + BinaryOp&& op) +{ + return transform(tensor_in1, tensor_in2, tensor_out, op); +} + +namespace lazy { + +template +CUTE_HOST_DEVICE constexpr +auto +transform(cute::Tensor const& t, Fn const& fn) +{ + return cute::make_tensor(cute::make_transform_iter(fn, t.data()), t.layout()); +} + +} // end namespace lazy + +} // end namespace cute diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/algorithm/tensor_reduce.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/algorithm/tensor_reduce.hpp new file mode 100644 index 0000000000000000000000000000000000000000..a6f13735c214bec708cf80bf45f956e5d1df29d3 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/algorithm/tensor_reduce.hpp @@ -0,0 +1,107 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include +#include +#include +#include + +namespace cute +{ + +// Reduce @src tensor using binary reduction operator @op and initial value @init and return a scalar. 
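// For example (illustrative; the default BinaryOp appears to be cute::plus, as
// suggested by the `op = {}` default below):
//   float sum = reduce(t, 0.0f);                      // sum of all elements of t
//   float mx  = reduce(t, t(0), cute::max_fn{});      // maximum element of t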
+template +CUTE_HOST_DEVICE constexpr +T +reduce(Tensor const& src, T init, BinaryOp op = {}) +{ + for (auto i = 0; i < size(src); ++i) { + init = op(init, src(i)); + } + return init; +} + +// Reduce @src tensor RedMode using binary reduction operator @op and store the result in @dst tensor +// for each index in @dst/BatchMode. +// @pre @src tensor has rank 2 +// @pre size of @src batch mode is equal to size of @dst batch mode +template +CUTE_HOST_DEVICE constexpr +void +batch_reduce(Tensor const& src, // (RedMode, BatchMode) + Tensor & dst, // (BatchMode) + BinaryOp op = {}) +{ + // Precondition + CUTE_STATIC_ASSERT_V(rank(src) == Int<2>{}); + assert(size<1>(src) == size(dst)); + + for (int i = 0; i < size(dst); ++i) { + dst(i) = reduce(src(_,i), dst(i), op); + } +} + + +// Reduce @src tensor along selected modes specified in @target_profile using binary reduction operator @op +// and store the result in @dst tensor. @target_profile is a tuple where '_' indicates modes to keep and +// integers indicates modes to reduce. +// @pre @target_profile is compatible with @src layout +template +CUTE_HOST_DEVICE constexpr +void +logical_reduce(Tensor const& src, + Tensor & dst, + TargetProfile const& target_profile, + BinaryOp op = {}) +{ + // Precondition + assert(compatible(target_profile, shape(src))); + + auto diced_layout = dice(target_profile, src.layout()); + auto sliced_layout = slice(target_profile, src.layout()); + + auto red_mode = conditional_return{}>(Layout<_1,_0>{}, diced_layout); + auto batch_mode = conditional_return{}>(Layout<_1,_0>{}, sliced_layout); + + auto src_tensor = make_tensor(src.data(), make_layout(red_mode, batch_mode)); + + batch_reduce(src_tensor, dst, op); +} + +} // end namespace cute diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/algorithm/tuple_algorithms.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/algorithm/tuple_algorithms.hpp new file mode 100644 index 0000000000000000000000000000000000000000..311e10557f127b6f1f2771bf5379b75cc32cf824 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/algorithm/tuple_algorithms.hpp @@ -0,0 +1,1053 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include +#include +#include +#include +#include +#include + +/// @file tuple_algorithms.hpp +/// @brief Common algorithms on (hierarchical) tuples +/// +/// Code guidelines and style preferences: +/// +/// For perfect forwarding, don't use std::forward, because it may not +/// be defined in device code when compiling with NVRTC. Instead, use +/// `static_cast(parameter_name)`. +/// +/// CuTe generally does not bother forwarding functions, as +/// reference-qualified member functions are rare in this code base. +/// +/// Throughout CUTLASS, cute::make_tuple always needs to be called +/// namespace-qualified, EVEN If inside the cute namespace and/or in +/// scope of a "using namespace cute" declaration. Otherwise, the +/// compiler may select std::make_tuple instead of cute::make_tuple, +/// due to argument-dependent lookup. + +namespace cute +{ + +// +// Apply (Unpack) +// (t, f) => f(t_0,t_1,...,t_n) +// + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr +auto +apply(T&& t, F&& f, seq) +{ + return f(get(static_cast(t))...); +} + +} // end namespace detail + +template +CUTE_HOST_DEVICE constexpr +auto +apply(T&& t, F&& f) +{ + return detail::apply(static_cast(t), f, tuple_seq{}); +} + +// +// Transform Apply +// (t, f, g) => g(f(t_0),f(t_1),...) 
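A small sketch of apply(), which unpacks a tuple into a callable's argument list; the tuple and lambda are illustrative. Note the namespace-qualified cute::make_tuple, per the guideline above.

void apply_sketch()
{
  using namespace cute;
  auto t = cute::make_tuple(Int<2>{}, Int<3>{}, 4);
  // Calls f(get<0>(t), get<1>(t), get<2>(t)) -> 2 * 3 * 4 = 24
  auto product = cute::apply(t, [](auto... vs) { return (vs * ...); });
  (void)product;
}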
+// + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr +auto +tapply(T&& t, F&& f, G&& g, seq) +{ + return g(f(get(static_cast(t)))...); +} + +template +CUTE_HOST_DEVICE constexpr +auto +tapply(T0&& t0, T1&& t1, F&& f, G&& g, seq) +{ + return g(f(get(static_cast(t0)), + get(static_cast(t1)))...); +} + +template +CUTE_HOST_DEVICE constexpr +auto +tapply(T0&& t0, T1&& t1, T2&& t2, F&& f, G&& g, seq) +{ + return g(f(get(static_cast(t0)), + get(static_cast(t1)), + get(static_cast(t2)))...); +} + +} // end namespace detail + +template +CUTE_HOST_DEVICE constexpr +auto +transform_apply(T&& t, F&& f, G&& g) +{ + if constexpr (is_tuple>::value) { + return detail::tapply(static_cast(t), f, g, tuple_seq{}); + } else { + return g(f(static_cast(t))); + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +transform_apply(T0&& t0, T1&& t1, F&& f, G&& g) +{ + if constexpr (is_tuple>::value) { + return detail::tapply(static_cast(t0), static_cast(t1), f, g, tuple_seq{}); + } else { + return g(f(static_cast(t0), static_cast(t1))); + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +transform_apply(T0&& t0, T1&& t1, T2&& t2, F&& f, G&& g) +{ + if constexpr (is_tuple>::value) { + return detail::tapply(static_cast(t0), static_cast(t1), static_cast(t2), f, g, tuple_seq{}); + } else { + return g(f(static_cast(t0), static_cast(t1), static_cast(t2))); + } + + CUTE_GCC_UNREACHABLE; +} + +// +// For Each +// (t, f) => f(t_0),f(t_1),...,f(t_n) +// + +template +CUTE_HOST_DEVICE constexpr +void +for_each(T&& t, F&& f) +{ + if constexpr (is_tuple>::value) { + return detail::apply(t, [&](auto&&... a) { (f(static_cast(a)), ...); }, tuple_seq{}); + } else { + return f(static_cast(t)); + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +for_each_leaf(T&& t, F&& f) +{ + if constexpr (is_tuple>::value) { + return detail::apply(static_cast(t), [&](auto&&... a){ return (for_each_leaf(static_cast(a), f), ...); }, tuple_seq{}); + } else { + return f(static_cast(t)); + } + + CUTE_GCC_UNREACHABLE; +} + +// +// Transform +// (t, f) => (f(t_0),f(t_1),...,f(t_n)) +// + +template +CUTE_HOST_DEVICE constexpr +auto +transform(T const& t, F&& f) +{ + if constexpr (is_tuple::value) { + return detail::tapply(t, f, [](auto const&... a){ return cute::make_tuple(a...); }, tuple_seq{}); + } else { + return f(t); + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +transform(T0 const& t0, T1 const& t1, F&& f) +{ + if constexpr (is_tuple::value) { + static_assert(tuple_size::value == tuple_size::value, "Mismatched tuple_size"); + return detail::tapply(t0, t1, f, [](auto const&... a){ return cute::make_tuple(a...); }, tuple_seq{}); + } else { + return f(t0, t1); + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +transform(T0 const& t0, T1 const& t1, T2 const& t2, F&& f) +{ + if constexpr (is_tuple::value) { + static_assert(tuple_size::value == tuple_size::value, "Mismatched tuple_size"); + static_assert(tuple_size::value == tuple_size::value, "Mismatched tuple_size"); + return detail::tapply(t0, t1, t2, f, [](auto const&... 
a){ return cute::make_tuple(a...); }, tuple_seq{}); + } else { + return f(t0, t1, t2); + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +transform_leaf(T const& t, F&& f) +{ + if constexpr (is_tuple::value) { + return transform(t, [&](auto const& a) { return transform_leaf(a, f); }); + } else { + return f(t); + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +transform_leaf(T0 const& t0, T1 const& t1, F&& f) +{ + if constexpr (is_tuple::value) { + return transform(t0, t1, [&](auto const& a, auto const& b) { return transform_leaf(a, b, f); }); + } else { + return f(t0, t1); + } + + CUTE_GCC_UNREACHABLE; +} + +// +// find and find_if +// + +template +CUTE_HOST_DEVICE constexpr +auto +find_if(T const& t, F&& f) +{ + if constexpr (is_tuple::value) { + return detail::tapply(t, f, [] (auto... a) { return cute::C>{}; }, tuple_seq{}); + } else { + return cute::C{}; + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +find(T const& t, X const& x) +{ + return find_if(t, [&](auto const& v) { return v == x; }); // This should always return a static true/false +} + +template +CUTE_HOST_DEVICE constexpr +auto +any_of(T const& t, F&& f) +{ + if constexpr (is_tuple::value) { + return detail::tapply(t, f, [] (auto... a) { return (false_type{} || ... || a); }, tuple_seq{}); + } else { + return f(t); + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +all_of(T const& t, F&& f) +{ + if constexpr (is_tuple::value) { + return detail::tapply(t, f, [] (auto... a) { return (true_type{} && ... && a); }, tuple_seq{}); + } else { + return f(t); + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +none_of(T const& t, F&& f) +{ + return not any_of(t, f); +} + +// +// Filter +// (t, f) => +// + +template +CUTE_HOST_DEVICE constexpr +auto +filter_tuple(T const& t, F&& f) +{ + return transform_apply(t, f, [](auto const&... a) { return cute::tuple_cat(a...); }); +} + +template +CUTE_HOST_DEVICE constexpr +auto +filter_tuple(T0 const& t0, T1 const& t1, F&& f) +{ + return transform_apply(t0, t1, f, [](auto const&... a) { return cute::tuple_cat(a...); }); +} + +template +CUTE_HOST_DEVICE constexpr +auto +filter_tuple(T0 const& t0, T1 const& t1, T2 const& t2, F&& f) +{ + return transform_apply(t0, t1, t2, f, [](auto const&... a) { return cute::tuple_cat(a...); }); +} + +// +// Fold (Reduce, Accumulate) +// (t, v, f) => f(...f(f(v,t_0),t_1),...,t_n) +// + +namespace detail { + +template +struct FoldAdaptor { + template + CUTE_HOST_DEVICE constexpr auto operator|(X&& x) { + auto r = fn_(val_, static_cast(x)); + return FoldAdaptor{fn_, r}; + } + Fn fn_; + Val val_; +}; + +template +CUTE_HOST_DEVICE constexpr +auto +fold(T&& t, V const& v, F&& f, seq) +{ + return (FoldAdaptor{f,v} | ... 
| get(static_cast(t))).val_; +} + +} // end namespace detail + +template +CUTE_HOST_DEVICE constexpr +auto +fold(T&& t, V const& v, F&& f) +{ + if constexpr (is_tuple>::value) { + return detail::fold(static_cast(t), v, f, tuple_seq{}); + } else { + return f(v, static_cast(t)); + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +fold_first(T&& t, F&& f) +{ + if constexpr (is_tuple>::value) { + return detail::fold(static_cast(t), get<0>(t), f, make_range<1,tuple_size>::value>{}); + } else { + return t; + } + + CUTE_GCC_UNREACHABLE; +} + +// +// front, back, take, select, unwrap +// + +// Get the first non-tuple element in a hierarchical tuple +template +CUTE_HOST_DEVICE constexpr +decltype(auto) +front(T&& t) +{ + if constexpr (is_tuple>::value) { + return front(get<0>(static_cast(t))); + } else { + return static_cast(t); + } + + CUTE_GCC_UNREACHABLE; +} + +// Get the last non-tuple element in a hierarchical tuple +template +CUTE_HOST_DEVICE constexpr +decltype(auto) +back(T&& t) +{ + if constexpr (is_tuple>::value) { + constexpr int N = tuple_size>::value; + + // MSVC needs a bit of extra help here deducing return types. + // We help it by peeling off the nonrecursive case a level "early." + if constexpr (! is_tuple(static_cast(t)))>>::value) { + return get(static_cast(t)); + } else { + return back(get(static_cast(t))); + } + } else { + return static_cast(t); + } + + CUTE_GCC_UNREACHABLE; +} + +// Takes the elements in the range [B,E) +template +CUTE_HOST_DEVICE constexpr +auto +take(T const& t) +{ + if constexpr (E == -1) { + if constexpr (is_tuple::value) { + return take::value>(t); + } else { + return take(t); + } + } else + if constexpr (B <= E) { + return detail::apply(t, [](auto const&... a) { return cute::make_tuple(a...); }, make_range{}); + } else { + static_assert(B <= E); + } + + CUTE_GCC_UNREACHABLE; +} + +// Select tuple elements with given indices. +template +CUTE_HOST_DEVICE constexpr +auto +select(T const& t) +{ + return cute::make_tuple(get(t)...); +} + +// Wrap non-tuples into rank-1 tuples or forward +template +CUTE_HOST_DEVICE constexpr +auto +wrap(T const& t) +{ + if constexpr (is_tuple::value) { + return t; + } else { + return cute::make_tuple(t); + } + + CUTE_GCC_UNREACHABLE; +} + +// Unwrap rank-1 tuples until we're left with a rank>1 tuple or a non-tuple +template +CUTE_HOST_DEVICE constexpr +auto +unwrap(T const& t) +{ + if constexpr (is_tuple::value) { + if constexpr (tuple_size::value == 1) { + return unwrap(get<0>(t)); + } else { + return t; + } + } else { + return t; + } + + CUTE_GCC_UNREACHABLE; +} + +// +// Flatten and Unflatten +// + +template +struct is_flat : true_type {}; + +template +struct is_flat> : bool_constant<(true && ... && (not is_tuple::value))> {}; + +// Flatten a hierarchical tuple to a tuple of depth one +// and wrap non-tuples into a rank-1 tuple. +template +CUTE_HOST_DEVICE constexpr +auto +flatten_to_tuple(T const& t) +{ + if constexpr (is_tuple::value) { + if constexpr (is_flat::value) { // Shortcut for perf + return t; + } else { + return filter_tuple(t, [](auto const& a) { return flatten_to_tuple(a); }); + } + } else { + return cute::make_tuple(t); + } + + CUTE_GCC_UNREACHABLE; +} + +// Flatten a hierarchical tuple to a tuple of depth one +// and leave non-tuple untouched. 
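An illustrative sketch of fold and the access/slicing helpers above; results are noted in comments.

void fold_sketch()
{
  using namespace cute;
  auto shape = cute::make_tuple(Int<4>{}, cute::make_tuple(Int<8>{}, Int<2>{}));

  // Left fold with an explicit initial value: product of all leaves -> _64
  auto count = fold(flatten_to_tuple(shape), Int<1>{},
                    [](auto acc, auto s) { return acc * s; });
  static_assert(decltype(count)::value == 64);

  auto first = front(shape);        // first leaf of the hierarchy: _4
  auto last  = back(shape);         // last leaf of the hierarchy:  _2
  auto head  = take<0, 1>(shape);   // top-level slice [0,1): (_4)
  auto w     = wrap(Int<4>{});      // (_4)
  auto u     = unwrap(head);        // back to _4
  (void)first; (void)last; (void)w; (void)u;
}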
+template +CUTE_HOST_DEVICE constexpr +auto +flatten(T const& t) +{ + if constexpr (is_tuple::value) { + if constexpr (is_flat::value) { // Shortcut for perf + return t; + } else { + return filter_tuple(t, [](auto const& a) { return flatten_to_tuple(a); }); + } + } else { + return t; + } + + CUTE_GCC_UNREACHABLE; +} + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr +auto +unflatten_impl(FlatTuple const& flat_tuple, TargetProfile const& target_profile) +{ + if constexpr (is_tuple::value) { + return fold(target_profile, cute::make_tuple(cute::make_tuple(), flat_tuple), [](auto const& v, auto const& t) { + auto [result, remaining_tuple] = v; + auto [sub_result, sub_tuple] = unflatten_impl(remaining_tuple, t); + return cute::make_tuple(append(result, sub_result), sub_tuple); + }); + } else { + return cute::make_tuple(get<0>(flat_tuple), take<1, decltype(rank(flat_tuple))::value>(flat_tuple)); + } + + CUTE_GCC_UNREACHABLE; +} + +} // end namespace detail + +// Unflatten a flat tuple into a hierarchical tuple +// @pre flatten(@a flat_tuple) == @a flat_tuple +// @pre rank(flatten(@a target_profile)) == rank(@a flat_tuple) +// @post congruent(@a result, @a target_profile) +// @post flatten(@a result) == @a flat_tuple +template +CUTE_HOST_DEVICE constexpr +auto +unflatten(FlatTuple const& flat_tuple, TargetProfile const& target_profile) +{ + auto [unflatten_tuple, flat_remainder] = detail::unflatten_impl(flat_tuple, target_profile); + CUTE_STATIC_ASSERT_V(rank(flat_remainder) == Int<0>{}); + return unflatten_tuple; +} + +// +// insert and remove and replace +// + +namespace detail { + +// Shortcut around cute::tuple_cat for common insert/remove/repeat cases +template +CUTE_HOST_DEVICE constexpr +auto +construct(T const& t, X const& x, seq, seq, seq) +{ + return cute::make_tuple(get(t)..., (void(J),x)..., get(t)...); +} + +} // end namespace detail + +// Insert x into the Nth position of the tuple +template +CUTE_HOST_DEVICE constexpr +auto +insert(T const& t, X const& x) +{ + return detail::construct(t, x, make_seq{}, seq<0>{}, make_range::value>{}); +} + +// Remove the Nth element of the tuple +template +CUTE_HOST_DEVICE constexpr +auto +remove(T const& t) +{ + return detail::construct(t, 0, make_seq{}, seq<>{}, make_range::value>{}); +} + +// Replace the Nth element of the tuple with x +template +CUTE_HOST_DEVICE constexpr +auto +replace(T const& t, X const& x) +{ + if constexpr (is_tuple::value) { + return detail::construct(t, x, make_seq{}, seq<0>{}, make_range::value>{}); + } else { + static_assert(N == 0); + return x; + } + + CUTE_GCC_UNREACHABLE; +} + +// Replace the first element of the tuple with x +template +CUTE_HOST_DEVICE constexpr +auto +replace_front(T const& t, X const& x) +{ + if constexpr (is_tuple::value) { + return detail::construct(t, x, seq<>{}, seq<0>{}, make_range<1,tuple_size::value>{}); + } else { + return x; + } + + CUTE_GCC_UNREACHABLE; +} + +// Replace the last element of the tuple with x +template +CUTE_HOST_DEVICE constexpr +auto +replace_back(T const& t, X const& x) +{ + if constexpr (is_tuple::value) { + return detail::construct(t, x, make_seq::value-1>{}, seq<0>{}, seq<>{}); + } else { + return x; + } + + CUTE_GCC_UNREACHABLE; +} + +// +// Make a tuple of Xs of tuple_size N +// + +template +CUTE_HOST_DEVICE constexpr +auto +tuple_repeat(X const& x) +{ + return detail::construct(0, x, seq<>{}, make_seq{}, seq<>{}); +} + +// +// Make repeated Xs of rank N +// + +template +CUTE_HOST_DEVICE constexpr +auto +repeat(X const& x) +{ + if constexpr (N == 1) { + 
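A sketch of round-tripping a hierarchical profile through flatten/unflatten, plus the positional edits above; the profile is illustrative.

void flatten_sketch()
{
  using namespace cute;
  auto profile = cute::make_tuple(Int<4>{}, cute::make_tuple(Int<8>{}, Int<2>{}));

  auto flat   = flatten(profile);           // (_4,_8,_2)
  auto nested = unflatten(flat, profile);   // (_4,(_8,_2)) again, congruent with profile

  // Positional edits on the top level
  auto with_extra = insert<1>(profile, Int<16>{});   // (_4,_16,(_8,_2))
  auto dropped    = remove<0>(with_extra);           // (_16,(_8,_2))
  auto swapped    = replace<0>(profile, Int<64>{});  // (_64,(_8,_2))
  (void)nested; (void)dropped; (void)swapped;
}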
return x; + } else { + return detail::construct(0, x, seq<>{}, make_seq{}, seq<>{}); + } + + CUTE_GCC_UNREACHABLE; +} + +// +// Make a tuple of Xs the same profile as tuple T +// + +template +CUTE_HOST_DEVICE constexpr +auto +repeat_like(T const& t, X const& x) +{ + if constexpr (is_tuple::value) { + return transform(t, [&](auto const& a) { return repeat_like(a,x); }); + } else { + return x; + } + + CUTE_GCC_UNREACHABLE; +} + +// Group the elements [B,E) of a T into a single element +// e.g. group<2,4>(T<_1,_2,_3,_4,_5,_6>{}) +// => T<_1,_2,T<_3,_4>,_5,_6>{} +template +CUTE_HOST_DEVICE constexpr +auto +group(T const& t) +{ + if constexpr (not is_tuple::value) { + if constexpr (E == -1) { + return group(t); + } else { + return detail::construct(t, take(t), make_seq{}, make_seq<(B < E)>{}, make_range{}); + } + } else + if constexpr (E == -1) { + return group::value>(t); + } else + if constexpr (B <= E) { + return detail::construct(t, take(t), make_seq{}, make_seq<(B < E)>{}, make_range::value>{}); + } else { + static_assert(B <= E); + } + + CUTE_GCC_UNREACHABLE; +} + +// +// Extend a T to rank N by appending/prepending an element +// + +template +CUTE_HOST_DEVICE constexpr +auto +append(T const& a, X const& x) +{ + if constexpr (is_tuple::value) { + if constexpr (N == tuple_size::value) { + return a; + } else { + static_assert(N > tuple_size::value); + return detail::construct(a, x, make_seq::value>{}, make_seq::value>{}, seq<>{}); + } + } else { + if constexpr (N == 1) { + return a; + } else { + return detail::construct(cute::make_tuple(a), x, seq<0>{}, make_seq{}, seq<>{}); + } + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +append(T const& a, X const& x) +{ + if constexpr (is_tuple::value) { + return detail::construct(a, x, make_seq::value>{}, seq<0>{}, seq<>{}); + } else { + return cute::make_tuple(a, x); + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +prepend(T const& a, X const& x) +{ + if constexpr (is_tuple::value) { + if constexpr (N == tuple_size::value) { + return a; + } else { + static_assert(N > tuple_size::value); + return detail::construct(a, x, seq<>{}, make_seq::value>{}, make_seq::value>{}); + } + } else { + if constexpr (N == 1) { + return a; + } else { + static_assert(N > 1); + return detail::construct(cute::make_tuple(a), x, seq<>{}, make_seq{}, seq<0>{}); + } + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +prepend(T const& a, X const& x) +{ + if constexpr (is_tuple::value) { + return detail::construct(a, x, seq<>{}, seq<0>{}, make_seq::value>{}); + } else { + return cute::make_tuple(x, a); + } + + CUTE_GCC_UNREACHABLE; +} + +// +// Inclusive scan (prefix sum) +// + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr +auto +iscan(T const& t, V const& v, F&& f, seq) +{ + // Apply the function to v and the element at I + auto v_next = f(v, get(t)); + // Replace I with v_next + auto t_next = replace(t, v_next); + +#if 0 + std::cout << "ISCAN i" << I << std::endl; + std::cout << " t " << t << std::endl; + std::cout << " i " << v << std::endl; + std::cout << " f(i,t) " << v_next << std::endl; + std::cout << " t_n " << t_next << std::endl; +#endif + + if constexpr (sizeof...(Is) == 0) { + return t_next; + } else { + return iscan(t_next, v_next, f, seq{}); + } + + CUTE_GCC_UNREACHABLE; +} + +} // end namespace detail + +template +CUTE_HOST_DEVICE constexpr +auto +iscan(T const& t, V const& v, F&& f) +{ + return detail::iscan(t, v, f, tuple_seq{}); +} + +// +// Exclusive scan 
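A sketch of the rank-extension helpers and the inclusive scan above, with results noted in comments; the shapes are illustrative.

void extend_sketch()
{
  using namespace cute;
  auto shape = cute::make_tuple(Int<4>{}, Int<8>{});

  auto padded  = append<4>(shape, Int<1>{});    // (_4,_8,_1,_1)
  auto fronted = prepend(shape, Int<2>{});      // (_2,_4,_8)
  auto ones    = repeat_like(shape, Int<1>{});  // (_1,_1)
  auto grouped = group<0, 2>(fronted);          // ((_2,_4),_8)

  // Inclusive running product over the shape: (_4,_32)
  auto iprod = iscan(shape, Int<1>{}, [](auto a, auto b) { return a * b; });
  (void)padded; (void)ones; (void)grouped; (void)iprod;
}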
(prefix sum) +// + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr +auto +escan(T const& t, V const& v, F&& f, seq) +{ + if constexpr (sizeof...(Is) == 0) { + // Replace I with v + return replace(t, v); + } else { + // Apply the function to v and the element at I + auto v_next = f(v, get(t)); + // Replace I with v + auto t_next = replace(t, v); + +#if 0 + std::cout << "ESCAN i" << I << std::endl; + std::cout << " t " << t << std::endl; + std::cout << " i " << v << std::endl; + std::cout << " f(i,t) " << v_next << std::endl; + std::cout << " t_n " << t_next << std::endl; +#endif + + // Recurse + return escan(t_next, v_next, f, seq{}); + } + + CUTE_GCC_UNREACHABLE; +} + +} // end namespace detail + +template +CUTE_HOST_DEVICE constexpr +auto +escan(T const& t, V const& v, F&& f) +{ + return detail::escan(t, v, f, tuple_seq{}); +} + +// +// Zip (Transpose) +// + +// Take ((a,b,c,...),(x,y,z,...),...) rank-R0 x rank-R1 input +// to produce ((a,x,...),(b,y,...),(c,z,...),...) rank-R1 x rank-R0 output + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr +auto +zip_(Ts const&... ts) +{ + return cute::make_tuple(get(ts)...); +} + +template +CUTE_HOST_DEVICE constexpr +auto +zip(T const& t, seq, seq) +{ + static_assert(conjunction>::value == tuple_size>::value>...>::value, "Mismatched Ranks"); + return cute::make_tuple(zip_(get(t)...)...); +} + +} // end namespace detail + +template +CUTE_HOST_DEVICE constexpr +auto +zip(T const& t) +{ + if constexpr (is_tuple::value) { + if constexpr (is_tuple>::value) { + return detail::zip(t, tuple_seq{}, tuple_seq>{}); + } else { + return cute::make_tuple(t); + } + } else { + return t; + } + + CUTE_GCC_UNREACHABLE; +} + +// Convenient to pass them in separately +template +CUTE_HOST_DEVICE constexpr +auto +zip(T0 const& t0, T1 const& t1, Ts const&... ts) +{ + return zip(cute::make_tuple(t0, t1, ts...)); +} + +// +// zip2_by -- A guided zip for rank-2 tuples +// Take a tuple like ((A,a),((B,b),(C,c)),d) +// and produce a tuple ((A,(B,C)),(a,(b,c),d)) +// where the rank-2 modes are selected by the terminals of the guide (X,(X,X)) +// + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr +auto +zip2_by(T const& t, TG const& guide, seq, seq) +{ + // zip2_by produces the modes like ((A,a),(B,b),...) + auto split = cute::make_tuple(zip2_by(get(t), get(guide))...); + + // Rearrange and append missing modes from t to make ((A,B,...),(a,b,...,x,y)) + return cute::make_tuple(cute::make_tuple(get<0>(get(split))...), + cute::make_tuple(get<1>(get(split))..., get(t)...)); +} + +} // end namespace detail + +template +CUTE_HOST_DEVICE constexpr +auto +zip2_by(T const& t, TG const& guide) +{ + if constexpr (is_tuple::value) { + constexpr int TR = tuple_size::value; + constexpr int GR = tuple_size::value; + static_assert(TR >= GR, "Mismatched ranks"); + return detail::zip2_by(t, guide, + make_range< 0, GR>{}, + make_range{}); + } else { + static_assert(tuple_size::value == 2, "Mismatched ranks"); + return t; + } + + CUTE_GCC_UNREACHABLE; +} + +/// @return A tuple of the elements of @c t in reverse order. +template +CUTE_HOST_DEVICE constexpr +auto +reverse(T const& t) +{ + if constexpr (is_tuple::value) { + return detail::apply(t, [](auto const&... 
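A sketch of the exclusive scan and zip above; computing compact column-major strides from a shape is the canonical use of escan, and zip transposes a tuple of tuples. The inputs are illustrative.

void scan_zip_sketch()
{
  using namespace cute;
  auto shape = cute::make_tuple(Int<4>{}, Int<8>{}, Int<2>{});

  // Exclusive running product == compact column-major strides: (_1,_4,_32)
  auto strides = escan(shape, Int<1>{}, [](auto a, auto b) { return a * b; });

  // zip: ((a,b),(x,y)) -> ((a,x),(b,y))
  auto pairs = zip(cute::make_tuple(Int<1>{}, Int<2>{}),
                   cute::make_tuple(Int<3>{}, Int<4>{}));  // ((_1,_3),(_2,_4))
  (void)strides; (void)pairs;
}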
a){ return cute::make_tuple(a...); }, tuple_rseq{}); + } else { + return t; + } +} + +} // end namespace cute diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/arch/cluster_sm100.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/arch/cluster_sm100.hpp new file mode 100644 index 0000000000000000000000000000000000000000..0bcf19b0f218688edc6cc16a41a8e8968e06e611 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/arch/cluster_sm100.hpp @@ -0,0 +1,108 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +// + +// + +#pragma once + +#include + + +namespace cute { + + +// +// Cluster launch utility +// +CUTE_HOST +bool +initialize_preferred_cluster_launch(void const* const kernel_function, + dim3 const& grid_dims, + dim3 const& cluster_dims_preferred, + dim3 const& cluster_dims_fallback) +{ + // + // Validate cluster_dims + // + + // Total number of cluster cannot be greater than 32 (hardware requirement) + if (cluster_dims_preferred.x * cluster_dims_preferred.y * cluster_dims_preferred.z <= 0 || + cluster_dims_preferred.x * cluster_dims_preferred.y * cluster_dims_preferred.z > 32) { + std::cout << "Invalid preferred cluster dimensions: Attempting to init preferred cluster (" << cluster_dims_preferred.x << "," << cluster_dims_preferred.y << "," << cluster_dims_preferred.z + << ") [" << (cluster_dims_preferred.x * cluster_dims_preferred.y * cluster_dims_preferred.z) << "] which must be within (0,32]." 
<< std::endl; + return false; + } + + // Total number of cluster cannot be greater than 32 (hardware requirement) + if (cluster_dims_fallback.x * cluster_dims_fallback.y * cluster_dims_fallback.z <= 0 || + cluster_dims_fallback.x * cluster_dims_fallback.y * cluster_dims_fallback.z > 32) { + std::cout << "Invalid cluster dimensions: Attempting to init fallback cluster (" << cluster_dims_fallback.x << "," << cluster_dims_fallback.y << "," << cluster_dims_fallback.z + << ") [" << (cluster_dims_fallback.x * cluster_dims_fallback.y * cluster_dims_fallback.z) << "] which must be within (0,32]." << std::endl; + return false; + } + + // Total grid dimensions must be within (2^32, 2^16, 2^16) + if (grid_dims.y > (1 << 16) || grid_dims.z > (1 << 16)) { + std::cout << "Invalid grid dimensions: Attempting to init grid dimensions (" << grid_dims.x << "," << grid_dims.y << "," << grid_dims.z + << ") which must be within (2^32, 2^16, 2^16)." << std::endl; + return false; + } + + // grid_dims should be divisible by cluster_dims_preferred + if (grid_dims.x % cluster_dims_preferred.x != 0 || + grid_dims.y % cluster_dims_preferred.y != 0 || + grid_dims.z % cluster_dims_preferred.z != 0) { + std::cout << "Invalid grid dimensions: Preferred cluster (" << cluster_dims_preferred.x << "," << cluster_dims_preferred.y << "," << cluster_dims_preferred.z + << ") does not divide Grid (" << grid_dims.x << "," << grid_dims.y << "," << grid_dims.z << ")." << std::endl; + return false; + } + + // cluster_dims_preferred should be divisible by cluster_dims_fallback + if (cluster_dims_preferred.x % cluster_dims_fallback.x != 0 || + cluster_dims_preferred.y % cluster_dims_fallback.y != 0 || + cluster_dims_preferred.z % cluster_dims_fallback.z != 0) { + std::cout << "Invalid cluster dimensions: Fallback cluster (" << cluster_dims_fallback.x << "," << cluster_dims_fallback.y << "," << cluster_dims_fallback.z + << ") does not divide Preferred cluster (" << cluster_dims_preferred.x << "," << cluster_dims_preferred.y << "," << cluster_dims_preferred.z << ")." << std::endl; + return false; + } + + // Both cluster dimenions should have the same depth + if (cluster_dims_preferred.z != cluster_dims_fallback.z) { + std::cout << "Invalid cluster dimensions: Fallback cluster (" << cluster_dims_fallback.x << "," << cluster_dims_fallback.y << "," << cluster_dims_fallback.z + << ") and Preferred cluster (" << cluster_dims_preferred.x << "," << cluster_dims_preferred.y << "," << cluster_dims_preferred.z << ") does not have the same depth." << std::endl; + return false; + } + + return true; +} +} // end namespace cute diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/arch/cluster_sm90.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/arch/cluster_sm90.hpp new file mode 100644 index 0000000000000000000000000000000000000000..524a47efb11b04465e22541b8c2c12c6f4720af2 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/arch/cluster_sm90.hpp @@ -0,0 +1,246 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
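A host-side sketch of how this validation helper might be used before a cluster launch; the kernel symbol and dimensions are illustrative, and the actual launch (for example via cudaLaunchKernelEx with preferred/fallback cluster attributes) is deliberately omitted.

__global__ void my_kernel();  // hypothetical kernel symbol

bool try_preferred_cluster_launch()
{
  dim3 grid(256, 1, 1);
  dim3 preferred_cluster(16, 1, 1);  // <= 32 CTAs and must divide the grid
  dim3 fallback_cluster(2, 1, 1);    // must divide the preferred cluster, same .z depth

  if (!cute::initialize_preferred_cluster_launch((void const*)my_kernel, grid,
                                                 preferred_cluster, fallback_cluster)) {
    return false;  // fall back to a launch path without preferred clusters
  }
  // ... set the cluster-dimension launch attributes and launch my_kernel ...
  return true;
}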
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include +#include + +// Config +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && \ + ((__CUDACC_VER_MAJOR__ >= 12) || ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 8)))) +# define CUTE_ARCH_CLUSTER_SM90_ENABLED +#endif + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && (__CUDACC_VER_MAJOR__ >= 12)) +# define CUTE_ARCH_ELECT_ONE_SM90_ENABLED +#endif + +namespace cute { + +CUTE_DEVICE void cluster_arrive_relaxed() +{ +#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED) + asm volatile("barrier.cluster.arrive.relaxed.aligned;\n" : : ); +#else + CUTE_INVALID_CONTROL_PATH("CUTE_ARCH_CLUSTER_SM90_ENABLED is not defined"); +#endif +} + +CUTE_DEVICE void cluster_arrive() +{ +#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED) + asm volatile("barrier.cluster.arrive.aligned;\n" : : ); +#else + CUTE_INVALID_CONTROL_PATH("CUTE_ARCH_CLUSTER_SM90_ENABLED is not defined"); +#endif +} + +CUTE_DEVICE void cluster_wait() +{ +#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED) + asm volatile("barrier.cluster.wait.aligned;\n" : : ); +#else + CUTE_INVALID_CONTROL_PATH("CUTE_ARCH_CLUSTER_SM90_ENABLED is not defined"); +#endif +} + +CUTE_DEVICE void cluster_sync() +{ +#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED) + cluster_arrive(); + cluster_wait(); +#else + CUTE_INVALID_CONTROL_PATH("CUTE_ARCH_CLUSTER_SM90_ENABLED is not defined"); +#endif +} + +// Returns the dim3 grid size in terms of number of clusters. +CUTE_DEVICE dim3 cluster_grid_dims() +{ +#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED) + uint32_t x, y, z; + asm volatile("mov.u32 %0, %%nclusterid.x;\n" : "=r"(x) : ); + asm volatile("mov.u32 %0, %%nclusterid.y;\n" : "=r"(y) : ); + asm volatile("mov.u32 %0, %%nclusterid.z;\n" : "=r"(z) : ); + return {x, y, z}; +#elif defined(__CUDA_ARCH__) + // MSVC requires protecting use of gridDim with __CUDA_ARCH__. + return gridDim; +#elif defined(_MSC_VER) + CUTE_INVALID_CONTROL_PATH("cluster_grid_dims() can only be called on device"); + return {0, 0, 0}; +#else + return {0, 0, 0}; +#endif +} + +// Returns the dim3 cluster rank in the grid. 
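A device-side sketch of the split arrive/wait usage above: a CTA publishes data for its cluster peers, signals arrival, overlaps independent work, and only then waits; cluster_sync() is the fused form. The kernel, cluster shape, and shared-memory staging are illustrative.

__global__ void __cluster_dims__(2, 1, 1) cluster_sync_sketch()
{
  __shared__ int stage[1];
  if (threadIdx.x == 0) { stage[0] = blockIdx.x; }  // data peers may read via mapa
  __syncthreads();

  cute::cluster_arrive();   // publish: prior smem writes are visible after the matching wait
  // ... independent work that does not touch peer shared memory ...
  cute::cluster_wait();     // every CTA in the cluster has arrived; peer smem reads are safe now

  // Equivalent single call when there is nothing to overlap:
  // cute::cluster_sync();
}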
+CUTE_DEVICE dim3 cluster_id_in_grid() +{ +#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED) + uint32_t x, y, z; + asm volatile("mov.u32 %0, %%clusterid.x;\n" : "=r"(x) : ); + asm volatile("mov.u32 %0, %%clusterid.y;\n" : "=r"(y) : ); + asm volatile("mov.u32 %0, %%clusterid.z;\n" : "=r"(z) : ); + return {x, y, z}; +#elif defined(__CUDA_ARCH__) + // MSVC requires protecting use of blockIdx with __CUDA_ARCH__. + return blockIdx; +#elif defined(_MSC_VER) + CUTE_INVALID_CONTROL_PATH("cluster_id_in_grid() can only be called on device"); + return {0, 0, 0}; +#else + return {0, 0, 0}; +#endif +} + +// Returns the relative dim3 block rank local to the cluster. +CUTE_DEVICE dim3 block_id_in_cluster() +{ +#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED) + uint32_t x, y, z; + asm volatile("mov.u32 %0, %%cluster_ctaid.x;\n" : "=r"(x) : ); + asm volatile("mov.u32 %0, %%cluster_ctaid.y;\n" : "=r"(y) : ); + asm volatile("mov.u32 %0, %%cluster_ctaid.z;\n" : "=r"(z) : ); + return {x, y, z}; +#else + return {0,0,0}; +#endif +} + +// Returns the dim3 cluster shape. +CUTE_DEVICE dim3 cluster_shape() +{ +#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED) + uint32_t x, y, z; + asm volatile("mov.u32 %0, %%cluster_nctaid.x;\n" : "=r"(x) : ); + asm volatile("mov.u32 %0, %%cluster_nctaid.y;\n" : "=r"(y) : ); + asm volatile("mov.u32 %0, %%cluster_nctaid.z;\n" : "=r"(z) : ); + return {x, y, z}; +#else + return {1,1,1}; +#endif +} + +// Get 1D ctaid in a cluster. +CUTE_DEVICE uint32_t block_rank_in_cluster() +{ +#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED) + uint32_t rank; + asm volatile("mov.u32 %0, %%cluster_ctarank;\n" : "=r"(rank) :); + return rank; +#else + return 0; +#endif +} + +// Set the destination block-ID in cluster for a given SMEM Address +CUTE_DEVICE uint32_t set_block_rank(uint32_t smemAddr, uint32_t rank) +{ +#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED) + uint32_t result; + asm volatile("mapa.shared::cluster.u32 %0, %1, %2;\n" + : "=r"(result) + : "r"(smemAddr), "r"(rank)); + return result; +#else + return smemAddr; +#endif +} + +// Elect one thread in the warp. The elected thread gets its predicate set to true, all others obtain false. 
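A device-side sketch of mapping a local shared-memory address into a peer CTA of the same cluster using the helpers above; the pointer argument and peer selection are illustrative, and cast_smem_ptr_to_uint is assumed available from cute's arch utilities.

__device__ uint32_t peer_smem_addr(uint32_t* smem_word, uint32_t peer_offset)
{
  // Convert the generic pointer to a 32-bit .shared address
  uint32_t local_addr = cute::cast_smem_ptr_to_uint(smem_word);

  // Pick a peer CTA rank relative to this CTA, wrapping around the cluster size
  dim3     cshape    = cute::cluster_shape();
  uint32_t num_ctas  = cshape.x * cshape.y * cshape.z;
  uint32_t peer_rank = (cute::block_rank_in_cluster() + peer_offset) % num_ctas;

  // Remap the address into the peer CTA's shared memory (mapa.shared::cluster)
  return cute::set_block_rank(local_addr, peer_rank);
}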
+CUTE_HOST_DEVICE uint32_t elect_one_sync() +{ +#if defined(CUTE_ARCH_ELECT_ONE_SM90_ENABLED) + uint32_t pred = 0; + uint32_t laneid = 0; + asm volatile( + "{\n" + ".reg .b32 %%rx;\n" + ".reg .pred %%px;\n" + " elect.sync %%rx|%%px, %2;\n" + "@%%px mov.s32 %1, 1;\n" + " mov.s32 %0, %%rx;\n" + "}\n" + : "+r"(laneid), "+r"(pred) + : "r"(0xFFFFFFFF)); + return pred; +#elif defined(__CUDA_ARCH__) + return (threadIdx.x % 32) == 0; +#else + return true; +#endif +} + +struct ElectOneLaneIdReturnType { + uint32_t is_leader; + uint32_t leader_lane_id; +}; + +CUTE_HOST_DEVICE +ElectOneLaneIdReturnType +elect_one_leader_sync() +{ +#if defined(CUTE_ARCH_ELECT_ONE_SM90_ENABLED) + uint32_t pred = 0; + uint32_t laneid = 0; + asm volatile( + "{\n" + ".reg .b32 %%rx;\n" + ".reg .pred %%px;\n" + " elect.sync %%rx|%%px, %2;\n" + "@%%px mov.s32 %1, 1;\n" + " mov.s32 %0, %%rx;\n" + "}\n" + : "+r"(laneid), "+r"(pred) + : "r"(0xFFFFFFFF)); + return {pred, laneid}; +#elif defined(__CUDA_ARCH__) + return {(threadIdx.x % 32) == 0, 0}; +#else + return {true, 0}; +#endif +} + +// Store value to remote shared memory in the cluster +CUTE_DEVICE +void +store_shared_remote(uint32_t value, uint32_t smem_addr, uint32_t mbarrier_addr, uint32_t dst_cta_rank) +{ +#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED) + uint32_t dsmem_addr = set_block_rank(smem_addr, dst_cta_rank); + uint32_t remote_barrier_addr = set_block_rank(mbarrier_addr, dst_cta_rank); + asm volatile("st.async.shared::cluster.mbarrier::complete_tx::bytes.u32 [%0], %1, [%2];" + : : "r"(dsmem_addr), "r"(value), "r"(remote_barrier_addr)); +#endif +} + +} // end namespace cute diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/arch/config.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/arch/config.hpp new file mode 100644 index 0000000000000000000000000000000000000000..7ce05f30d3acb0db4b9c9f1531e1e9666432db61 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/arch/config.hpp @@ -0,0 +1,216 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
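A device-side sketch combining the helpers above: one elected lane per warp pushes a word into a peer CTA's shared memory via st.async, completing a transaction on that CTA's mbarrier. The addresses are assumed to be .shared-space addresses of a buffer and an mbarrier initialized elsewhere.

__device__ void notify_peer(uint32_t value,
                            uint32_t peer_smem_addr,      // .shared address of the destination word
                            uint32_t peer_mbarrier_addr,  // .shared address of the destination mbarrier
                            uint32_t peer_cta_rank)
{
  if (cute::elect_one_sync()) {
    // Writes value into the peer CTA's shared memory and updates the transaction
    // count on its mbarrier; the peer observes completion by waiting on that barrier.
    cute::store_shared_remote(value, peer_smem_addr, peer_mbarrier_addr, peer_cta_rank);
  }
}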
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include // CUTLASS_ARCH_MMA_SMxx_ENABLED + +// MMA SM90A +#if defined(CUTLASS_ARCH_MMA_SM90A_ENABLED) +# define CUTE_ARCH_MMA_SM90A_ENABLED +#endif + +// TMA instructions +#if defined(CUTLASS_ARCH_MMA_SM90_ENABLED) +# define CUTE_ARCH_TMA_SM90_ENABLED +#endif + +#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_ENABLED) +# define CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED +#endif + +// STSM +#if defined(CUTLASS_ARCH_MMA_SM90_ENABLED) +# define CUTE_ARCH_STSM_SM90_ENABLED +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM101A_ENABLED) ||\ + defined(CUTLASS_ARCH_MMA_SM103A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM120A_ENABLED) ||\ + defined(CUTLASS_ARCH_MMA_SM121A_ENABLED)) +# define CUTE_ARCH_TMA_SM90_ENABLED +# define CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED +# define CUTE_ARCH_STSM_SM90_ENABLED +#endif + +#if (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM101A_ENABLED) ||\ + defined(CUTLASS_ARCH_MMA_SM103A_ENABLED)) +# define CUTE_ARCH_TCGEN05_TF32_MMA_ENABLED +# define CUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED +# define CUTE_ARCH_TCGEN05_MXF8F6F4_MMA_ENABLED +# define CUTE_ARCH_TCGEN05_MXF4_MMA_ENABLED +# define CUTE_ARCH_TCGEN05_MXF4NVF4_MMA_ENABLED +#endif + +#if defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM103A_ENABLED) +# define CUTE_ARCH_TCGEN05_F16BF16_MMA_SCALED_ENABLED +#endif + +#if (defined(CUTLASS_ARCH_MMA_SM100F_ENABLED) || defined(CUTLASS_ARCH_MMA_SM101F_ENABLED) ||\ + defined(CUTLASS_ARCH_MMA_SM103F_ENABLED)) +# define CUTE_ARCH_TMA_SM90_ENABLED +# define CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED +# define CUTE_ARCH_STSM_SM90_ENABLED +# define CUTE_ARCH_TCGEN05_TF32_MMA_ENABLED +# define CUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED +# define CUTE_ARCH_TCGEN05_MXF8F6F4_MMA_ENABLED +# define CUTE_ARCH_TCGEN05_MXF4_MMA_ENABLED +# define CUTE_ARCH_TCGEN05_MXF4NVF4_MMA_ENABLED +#endif + +#if defined(CUTLASS_ARCH_MMA_SM100F_ENABLED) || defined(CUTLASS_ARCH_MMA_SM103F_ENABLED) +# define CUTE_ARCH_TCGEN05_F16BF16_MMA_SCALED_ENABLED +#endif + +#if (defined(CUTLASS_ARCH_MMA_SM120F_ENABLED) || defined(CUTLASS_ARCH_MMA_SM121F_ENABLED)) +# define CUTE_ARCH_TMA_SM90_ENABLED +# define CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED +# define CUTE_ARCH_STSM_SM90_ENABLED +#endif + + +// SM110 specific configs +#if (defined(CUTLASS_ARCH_MMA_SM110A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM110F_ENABLED)) +# define CUTE_ARCH_TMA_SM90_ENABLED +# define CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED +# define CUTE_ARCH_STSM_SM90_ENABLED +# define CUTE_ARCH_TCGEN05_TF32_MMA_ENABLED +# define CUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED +# define CUTE_ARCH_TCGEN05_MXF8F6F4_MMA_ENABLED +# define CUTE_ARCH_TCGEN05_MXF4_MMA_ENABLED +# define CUTE_ARCH_TCGEN05_MXF4NVF4_MMA_ENABLED +# 
define CUTE_ARCH_TCGEN05_S8_MMA_ENABLED +# define CUTE_ARCH_LDSM_SM100A_ENABLED +# define CUTE_ARCH_STSM_SM100A_ENABLED +# define CUTE_ARCH_TCGEN05_TMEM_ENABLED +# define CUTE_ARCH_TMA_SM100_ENABLED +# define CUTE_ARCH_LOAD256_SM100A_ENABLED +# define CUTE_ARCH_STORE256_SM100A_ENABLED +# define CUTE_ARCH_FLOAT2_MATH_ENABLED +#endif + +#if (defined(CUTLASS_ARCH_MMA_SM110A_ENABLED)) +# define CUTE_ARCH_TCGEN05_S8_MMA_ENABLED +#endif + +#if (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM101A_ENABLED)) +# define CUTE_ARCH_TCGEN05_S8_MMA_ENABLED +#endif + +#if (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM101A_ENABLED) ||\ + defined(CUTLASS_ARCH_MMA_SM103A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM120A_ENABLED) ||\ + defined(CUTLASS_ARCH_MMA_SM120A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM121A_ENABLED)) +# define CUTE_ARCH_LDSM_SM100A_ENABLED +# define CUTE_ARCH_STSM_SM100A_ENABLED +#endif + +#if (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM101A_ENABLED) ||\ + defined(CUTLASS_ARCH_MMA_SM103A_ENABLED)) +# define CUTE_ARCH_TCGEN05_TMEM_ENABLED +#endif + +#if (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM101A_ENABLED) ||\ + defined(CUTLASS_ARCH_MMA_SM103A_ENABLED)) +# define CUTE_ARCH_TMA_SM100_ENABLED +#endif + +// {add, mul, fma}.f32x2 PTX +#if defined(CUTLASS_ARCH_MMA_SM100_ENABLED) || defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) + // Enable CuTe MMA Atoms +# define CUTE_ARCH_FFMA2_SM100_ENABLED + // Enable f32x2 PTX generation +# define CUTE_ARCH_FLOAT2_MATH_ENABLED +#endif + +#if (defined(CUTLASS_ARCH_MMA_SM120_ENABLED) || defined(CUTLASS_ARCH_MMA_SM120A_ENABLED) ||\ + defined(CUTLASS_ARCH_MMA_SM121_ENABLED) || defined(CUTLASS_ARCH_MMA_SM121A_ENABLED)) +# define CUTE_ARCH_MMA_SM120_ENABLED +# define CUTE_ARCH_TMA_SM120_ENABLED +#endif + +#if (defined(CUTLASS_ARCH_MMA_SM120_ENABLED) || defined(CUTLASS_ARCH_MMA_SM120A_ENABLED)) +# if (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8)) +# define CUTE_ARCH_F8F6F4_MMA_ENABLED +# define CUTE_ARCH_MXF8F6F4_MMA_ENABLED +# define CUTE_ARCH_MXF4NVF4_2X_UE8M0_MMA_ENABLED +# define CUTE_ARCH_MXF4NVF4_4X_UE4M3_MMA_ENABLED +# endif +#endif + +#if (defined(CUTLASS_ARCH_MMA_SM121_ENABLED) || defined(CUTLASS_ARCH_MMA_SM121A_ENABLED)) +# if (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 9)) +# define CUTE_ARCH_F8F6F4_MMA_ENABLED +# define CUTE_ARCH_MXF8F6F4_MMA_ENABLED +# define CUTE_ARCH_MXF4NVF4_2X_UE8M0_MMA_ENABLED +# define CUTE_ARCH_MXF4NVF4_4X_UE4M3_MMA_ENABLED +# endif +#endif + +#if defined(CUTLASS_ARCH_MMA_SM100F_ENABLED) || defined(CUTLASS_ARCH_MMA_SM103F_ENABLED) +# define CUTE_ARCH_LDSM_SM100A_ENABLED +# define CUTE_ARCH_STSM_SM100A_ENABLED +# define CUTE_ARCH_TCGEN05_TMEM_ENABLED +# define CUTE_ARCH_TMA_SM100_ENABLED +# define CUTE_ARCH_FLOAT2_MATH_ENABLED +#endif + +#if defined(CUTLASS_ARCH_MMA_SM101F_ENABLED) +# define CUTE_ARCH_LDSM_SM100A_ENABLED +# define CUTE_ARCH_STSM_SM100A_ENABLED +# define CUTE_ARCH_TCGEN05_TMEM_ENABLED +# define CUTE_ARCH_TMA_SM100_ENABLED +#endif + +#if (defined(CUTLASS_ARCH_MMA_SM120F_ENABLED) || defined(CUTLASS_ARCH_MMA_SM121F_ENABLED)) +# define CUTE_ARCH_LDSM_SM100A_ENABLED +# define CUTE_ARCH_STSM_SM100A_ENABLED +#endif + +#if (defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM100F_ENABLED) ||\ + defined(CUTLASS_ARCH_MMA_SM101A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM101F_ENABLED) ||\ + defined(CUTLASS_ARCH_MMA_SM103A_ENABLED) 
|| defined(CUTLASS_ARCH_MMA_SM103F_ENABLED) ||\ + defined(CUTLASS_ARCH_MMA_SM120A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM120F_ENABLED) ||\ + defined(CUTLASS_ARCH_MMA_SM121A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM121F_ENABLED)) +# if (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 9)) +# define CUTE_ARCH_LOAD256_SM100A_ENABLED +# define CUTE_ARCH_STORE256_SM100A_ENABLED +# endif +#endif + +// {add, mul, fma}.f32x2 PTX +#if defined(CUTLASS_ARCH_MMA_SM100A_ENABLED) || defined(CUTLASS_ARCH_MMA_SM100F_ENABLED) + #define CUTE_ARCH_FLOAT2_MATH_ENABLED +#endif + +#if defined(CUTLASS_ARCH_MMA_SM103_ENABLED) || defined(CUTLASS_ARCH_MMA_SM100F_ENABLED) +# define CUTE_ARCH_TCGEN05_MXF4NVF4_MMA_ULTRA_ENABLED +#endif + diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/arch/copy.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/arch/copy.hpp new file mode 100644 index 0000000000000000000000000000000000000000..8b62fa91f9937afe5dd1dcedfd983743dd6ba076 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/arch/copy.hpp @@ -0,0 +1,107 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include + +#include +#include + +namespace cute +{ + +// +// Direct Copy for any specific types +// + +template +struct UniversalCopy +{ + using SRegisters = S[1]; + using DRegisters = D[1]; + + // Sanity + static_assert(sizeof_bits_v >= 8); + static_assert(sizeof_bits_v >= 8); + + CUTE_HOST_DEVICE static constexpr void + copy(S const& src, + D & dst) + { + dst = src; + } +}; + +// +// Placeholder for the copy algorithm's stronger auto-vectorizing behavior +// that assumes alignment of pointers and dynamic layouts up to MaxVecBits +// + +template +struct AutoVectorizingCopyWithAssumedAlignment + : UniversalCopy> +{ + static_assert(MaxVecBits == 8 || MaxVecBits == 16 || MaxVecBits == 32 || MaxVecBits == 64 || MaxVecBits == 128, + "Expected MaxVecBits to be 8 or 16 or 32 or 64 or 128 for alignment and performance."); +}; + +// +// AutoVectorizingCopy alias assumes maximal alignment of pointers and dynamic strides. +// If this is not the case then AutoVectorizingCopyWithAssumedAlignment should be used instead +// + +using AutoVectorizingCopy = AutoVectorizingCopyWithAssumedAlignment<128>; + +// +// DefaultCopy alias does not assume alignment of pointers or dynamic strides. +// + +using DefaultCopy = AutoVectorizingCopyWithAssumedAlignment<8>; + +// +// Copy policy automatically selecting between +// UniversalCopy and cp.async , based on type and memory space. +// +struct AutoCopyAsync {}; + +// +// Global memory prefetch into L2 +// + +CUTE_HOST_DEVICE static void +prefetch(void const* gmem_ptr) +{ +#if defined(__CUDA_ARCH__) + asm volatile("prefetch.global.L2 [%0];\n" : : "l"(gmem_ptr) : "memory"); +#endif +} + +} // end namespace cute diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/arch/copy_sm100.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/arch/copy_sm100.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f8a9a67b1d86dc27a30eee2a133ff0b6b696395b --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/arch/copy_sm100.hpp @@ -0,0 +1,7615 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
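A sketch of how these policies are normally consumed: cute::copy (from the copy algorithm header) accepts one of the policies above to control auto-vectorization, and prefetch can warm L2 ahead of the loads. The tensor arguments are illustrative.

template <class TensorSrc, class TensorDst>
__device__ void copy_sketch(TensorSrc const& src_gmem, TensorDst& dst_rmem)
{
  // Hint the L2 before issuing the actual loads
  cute::prefetch(cute::raw_pointer_cast(src_gmem.data()));

  // Assume 128-bit aligned pointers/strides and let the copy auto-vectorize
  cute::copy(cute::AutoVectorizingCopyWithAssumedAlignment<128>{}, src_gmem, dst_rmem);

  // Conservative alternative when alignment is unknown (falls back to element-wise):
  // cute::copy(cute::DefaultCopy{}, src_gmem, dst_rmem);
}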
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cute { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Global Memory Load and Store PTX definitions +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM100_LOAD_256bit_CACHE_NOALLOCATION +{ + using SRegisters = uint256_t[1]; + using DRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + copy(uint256_t const& gmem_addr, + uint32_t& dst0, uint32_t& dst1, uint32_t& dst2, uint32_t& dst3, + uint32_t& dst4, uint32_t& dst5, uint32_t& dst6, uint32_t& dst7) + { + #if defined(CUTE_ARCH_LOAD256_SM100A_ENABLED) + asm volatile("ld.global.L1::no_allocate.v8.f32 {%0, %1, %2, %3, %4, %5, %6, %7}, [%8];\n" + : "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3), "=r"(dst4), "=r"(dst5), "=r"(dst6), "=r"(dst7) + : "l"(&gmem_addr) ); + #else + CUTE_INVALID_CONTROL_PATH("Trying to use LOAD.256 without CUTE_ARCH_LOAD256_SM100A_ENABLED."); + #endif + } +}; + +struct SM100_STORE_256bit_CACHE_NOALLOCATION +{ + using SRegisters = uint32_t[8]; + using DRegisters = uint256_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, + uint32_t const& src4, uint32_t const& src5, uint32_t const& src6, uint32_t const& src7, + uint256_t& gmem_addr) + { + #if defined(CUTE_ARCH_STORE256_SM100A_ENABLED) + asm volatile("st.global.L1::no_allocate.v8.f32 [%0], {%1, %2, %3, %4, %5, %6, %7, %8};\n" + :: "l"(&gmem_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3), "r"(src4), "r"(src5), "r"(src6), "r"(src7)); + #else + CUTE_INVALID_CONTROL_PATH("Trying to use stg.256 without CUTE_ARCH_STORE256_SM100A_ENABLED."); + #endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// LDSM PTX definitions +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM100_U8x8_LDSM_T +{ + using SRegisters = uint128_t[1]; + using DRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + copy(uint128_t const& smem_src, + uint32_t& dst0, uint32_t& dst1) + { +#if defined(CUTE_ARCH_LDSM_SM100A_ENABLED) + uint32_t tmp0, tmp1; + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src); + asm volatile ("ldmatrix.sync.aligned.m16n16.x1.trans.shared.b8 {%0, %1}, [%2];\n" + : "=r"(reinterpret_cast(tmp0)), "=r"(reinterpret_cast(tmp1)) + : "r"(smem_int_ptr)); + // RefLayout of ldmatrix.m16n16.x1.trans won't match stmatrix.m16n8.x2.trans without additional transformations + // Do this here so we don't need to add an additional reg to reg 
copy at the collective layer + uchar4& tmp0_ = reinterpret_cast(tmp0); + uchar4& tmp1_ = reinterpret_cast(tmp1); + uchar4 dst0_{tmp0_.x, tmp0_.y, tmp1_.x, tmp1_.y}; + uchar4 dst1_{tmp0_.z, tmp0_.w, tmp1_.z, tmp1_.w}; + dst0 = reinterpret_cast(dst0_); + dst1 = reinterpret_cast(dst1_); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM100A_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM100_U8x16_LDSM_T +{ + using SRegisters = uint128_t[1]; + using DRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + copy(uint128_t const& smem_src, + uint32_t& dst0, uint32_t& dst1, uint32_t& dst2, uint32_t& dst3) + { +#if defined(CUTE_ARCH_LDSM_SM100A_ENABLED) + uint32_t tmp0, tmp1, tmp2, tmp3; + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src); + asm volatile ("ldmatrix.sync.aligned.m16n16.x2.trans.shared.b8 {%0, %1, %2, %3}, [%4];\n" + : "=r"(reinterpret_cast(tmp0)), "=r"(reinterpret_cast(tmp1)), + "=r"(reinterpret_cast(tmp2)), "=r"(reinterpret_cast(tmp3)) + : "r"(smem_int_ptr)); + uchar4& tmp0_ = reinterpret_cast(tmp0); + uchar4& tmp1_ = reinterpret_cast(tmp1); + uchar4& tmp2_ = reinterpret_cast(tmp2); + uchar4& tmp3_ = reinterpret_cast(tmp3); + uchar4 dst0_{tmp0_.x, tmp0_.y, tmp1_.x, tmp1_.y}; + uchar4 dst1_{tmp0_.z, tmp0_.w, tmp1_.z, tmp1_.w}; + uchar4 dst2_{tmp2_.x, tmp2_.y, tmp3_.x, tmp3_.y}; + uchar4 dst3_{tmp2_.z, tmp2_.w, tmp3_.z, tmp3_.w}; + dst0 = reinterpret_cast(dst0_); + dst1 = reinterpret_cast(dst1_); + dst2 = reinterpret_cast(dst2_); + dst3 = reinterpret_cast(dst3_); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM100A_ENABLED."); +#endif + } +}; + +struct SM100_SU4_DU8x16_x1_LDSM_N +{ + using SRegisters = uint128_t[1]; + using DRegisters = uint32_t[1]; + + CUTE_DEVICE static void + copy(uint128_t const& smem_src, + uint32_t& dst0) + { +#if defined(CUTE_ARCH_LDSM_SM100A_ENABLED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src); + asm volatile ("ldmatrix.sync.aligned.m8n16.x1.shared.b8x16.b4x16_p64 {%0}, [%1];\n" + : "=r"(dst0) + : "r"(smem_int_ptr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM100A_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM100_SU6_DU8x16_x1_LDSM_N +{ + using SRegisters = uint128_t[1]; + using DRegisters = uint32_t[1]; + + CUTE_DEVICE static void + copy(uint128_t const& smem_src, + uint32_t& dst0) + { +#if defined(CUTE_ARCH_LDSM_SM100A_ENABLED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src); + asm volatile ("ldmatrix.sync.aligned.m8n16.x1.shared.b8x16.b6x16_p32 {%0}, [%1];\n" + : "=r"(dst0) + : "r"(smem_int_ptr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM100A_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM100_SU4_DU8x16_x2_LDSM_N +{ + using SRegisters = uint128_t[1]; + using DRegisters = uint32_t[2]; + + CUTE_DEVICE static void + copy(uint128_t const& smem_src, + uint32_t& dst0, uint32_t& dst1) + { +#if defined(CUTE_ARCH_LDSM_SM100A_ENABLED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src); + asm volatile ("ldmatrix.sync.aligned.m8n16.x2.shared.b8x16.b4x16_p64 {%0, %1}, [%2];\n" + : "=r"(dst0), "=r"(dst1) + : "r"(smem_int_ptr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use ldmatrix 
without CUTE_ARCH_LDSM_SM100A_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM100_SU6_DU8x16_x2_LDSM_N +{ + using SRegisters = uint128_t[1]; + using DRegisters = uint32_t[2]; + + CUTE_DEVICE static void + copy(uint128_t const& smem_src, + uint32_t& dst0, uint32_t& dst1) + { +#if defined(CUTE_ARCH_LDSM_SM100A_ENABLED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src); + asm volatile ("ldmatrix.sync.aligned.m8n16.x2.shared.b8x16.b6x16_p32 {%0, %1}, [%2];\n" + : "=r"(dst0), "=r"(dst1) + : "r"(smem_int_ptr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM100A_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM100_SU4_DU8x16_x4_LDSM_N +{ + using SRegisters = uint128_t[1]; + using DRegisters = uint32_t[4]; + + CUTE_DEVICE static void + copy(uint128_t const& smem_src, + uint32_t& dst0, uint32_t& dst1, uint32_t& dst2, uint32_t& dst3) + { +#if defined(CUTE_ARCH_LDSM_SM100A_ENABLED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src); + asm volatile ("ldmatrix.sync.aligned.m8n16.x4.shared.b8x16.b4x16_p64 {%0, %1, %2, %3}, [%4];\n" + : "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3) + : "r"(smem_int_ptr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM100A_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM100_SU6_DU8x16_x4_LDSM_N +{ + using SRegisters = uint128_t[1]; + using DRegisters = uint32_t[4]; + + CUTE_DEVICE static void + copy(uint128_t const& smem_src, + uint32_t& dst0, uint32_t& dst1, uint32_t& dst2, uint32_t& dst3) + { +#if defined(CUTE_ARCH_LDSM_SM100A_ENABLED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src); + asm volatile ("ldmatrix.sync.aligned.m8n16.x4.shared.b8x16.b6x16_p32 {%0, %1, %2, %3}, [%4];\n" + : "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3) + : "r"(smem_int_ptr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM100A_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// STSM PTX definitions +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM100_U8x4_STSM_T +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint128_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src, + uint128_t& smem_dst) + { +#if defined(CUTE_ARCH_STSM_SM100A_ENABLED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_dst); + asm volatile ("stmatrix.sync.aligned.m16n8.x1.trans.shared.b8 [%0], {%1};\n" + :: "r"(smem_int_ptr), + "r"(src)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use stmatrix without CUTE_ARCH_STSM_SM100A_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM100_U8x8_STSM_T +{ + using SRegisters = uint32_t[2]; + using DRegisters = uint128_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, + uint128_t& smem_dst) + { +#if defined(CUTE_ARCH_STSM_SM100A_ENABLED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_dst); + asm volatile ("stmatrix.sync.aligned.m16n8.x2.trans.shared.b8 [%0], {%1, %2};\n" + :: "r"(smem_int_ptr), + "r"(src0), "r"(src1)); +#else + 
CUTE_INVALID_CONTROL_PATH("Trying to use stmatrix without CUTE_ARCH_STSM_SM100A_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM100_U8x16_STSM_T +{ + using SRegisters = uint32_t[4]; + using DRegisters = uint128_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, + uint128_t& smem_dst) + { +#if defined(CUTE_ARCH_STSM_SM100A_ENABLED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_dst); + asm volatile ("stmatrix.sync.aligned.m16n8.x4.trans.shared.b8 [%0], {%1, %2, %3, %4};\n" + :: "r"(smem_int_ptr), + "r"(src0), "r"(src1), "r"(src2), "r"(src3)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use stmatrix without CUTE_ARCH_STSM_SM100A_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// UTCCP PTX definitions +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace SM100::TMEM::UTCCP { + +// 128 data path lanes, 256-bit pattern, 1cta mode +struct SM100_UTCCP_128dp256bit_1cta +{ + using SRegisters = uint64_t[1]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint64_t const& src_addr, uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.cp.cta_group::1.128x256b [%0], %1;" + : + : "r"(dst_addr) "l"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use UTCCP without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +// 128 data path lanes, 256-bit pattern, 2cta mode +struct SM100_UTCCP_128dp256bit_2cta +{ + using SRegisters = uint64_t[1]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint64_t const& src_addr, uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.cp.cta_group::2.128x256b [%0], %1;" + : + : "r"(dst_addr) "l"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use UTCCP without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +struct SM100_UTCCP_128dp128bit_1cta +{ + using SRegisters = uint64_t[1]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint64_t const& src_addr, uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.cp.cta_group::1.128x128b [%0], %1;" + : + : "r"(dst_addr) "l"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use UTCCP without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +struct SM100_UTCCP_128dp128bit_2cta +{ + using SRegisters = uint64_t[1]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint64_t const& src_addr, uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.cp.cta_group::2.128x128b [%0], %1;" + : + : "r"(dst_addr) "l"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use UTCCP without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + + +// 4 data path lanes, 256-bit pattern, 1cta mode +struct SM100_UTCCP_4dp256bit_1cta +{ + using SRegisters = uint64_t[1]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint64_t const& src_addr, uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.cp.cta_group::1.4x256b [%0], %1;" + : + : "r"(dst_addr) "l"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use UTCCP without 
CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +// 4 data path lanes, 256-bit pattern, 2cta mode +struct SM100_UTCCP_4dp256bit_2cta +{ + using SRegisters = uint64_t[1]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint64_t const& src_addr, uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.cp.cta_group::2.4x256b [%0], %1;" + : + : "r"(dst_addr) "l"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use UTCCP without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +// 4x32 data path lanes (broadcast), 128-bit pattern, 1cta mode +struct SM100_UTCCP_4x32dp128bit_1cta +{ + using SRegisters = uint64_t[1]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint64_t const& src_addr, uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.cp.cta_group::1.32x128b.warpx4 [%0], %1;" + : + : "r"(dst_addr) "l"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use UTCCP without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +// 4x32 data path lanes (broadcast), 128-bit pattern, 2cta mode +struct SM100_UTCCP_4x32dp128bit_2cta +{ + using SRegisters = uint64_t[1]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint64_t const& src_addr, uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.cp.cta_group::2.32x128b.warpx4 [%0], %1;" + : + : "r"(dst_addr) "l"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use UTCCP without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +// 2x64 data path lanes (broadcast like 4x32dp), 128-bit pattern, 1cta mode +struct SM100_UTCCP_2x64dp128bitlw0213_1cta +{ + using SRegisters = uint64_t[1]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint64_t const& src_addr, uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.cp.cta_group::1.64x128b.warpx2::02_13 [%0], %1;" + : + : "r"(dst_addr) "l"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use UTCCP without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +// 2x64 data path lanes (broadcast like 4x32dp), 128-bit pattern, 2cta mode +struct SM100_UTCCP_2x64dp128bitlw0213_2cta +{ + using SRegisters = uint64_t[1]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint64_t const& src_addr, uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.cp.cta_group::2.64x128b.warpx2::02_13 [%0], %1;" + : + : "r"(dst_addr) "l"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use UTCCP without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +// 2x64 data path lanes (broadcast seperately in upper and lower 64dp), 128-bit pattern, 1cta mode +// data_row[0:31] -> DP[0:63] +// data_row[32:63] -> DP[64:127] +struct SM100_UTCCP_2x64dp128bitlw0123_1cta +{ + using SRegisters = uint64_t[1]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint64_t const& src_addr, uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.cp.cta_group::1.64x128b.warpx2::01_23 [%0], %1;" + : + : "r"(dst_addr) "l"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use UTCCP without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +// 2x64 data path lanes (broadcast seperately in upper and lower 64dp), 128-bit pattern, 2cta mode +// data_row[0:31] -> DP[0:63] +// 
+struct SM100_UTCCP_2x64dp128bitlw0123_2cta
+{
+  using SRegisters = uint64_t[1];
+  using DRegisters = uint32_t[1];
+
+  CUTE_HOST_DEVICE static void
+  copy(uint64_t const& src_addr, uint32_t const& dst_addr)
+  {
+#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED)
+    asm volatile ("tcgen05.cp.cta_group::2.64x128b.warpx2::01_23 [%0], %1;"
+        :
+        : "r"(dst_addr), "l"(src_addr));
+#else
+    CUTE_INVALID_CONTROL_PATH("Trying to use UTCCP without CUTE_ARCH_TCGEN05_TMEM_ENABLED.");
+#endif
+  }
+};
+
+} // end namespace SM100::TMEM::UTCCP
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+namespace SM100::TMEM::LOAD {
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+//
+// TMEM LOAD PTX definitions
+//
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// 16 data path lanes, 256-bit pattern, repeated 1 times
+struct SM100_TMEM_LOAD_16dp256b1x
+{
+  using SRegisters = uint32_t[1];
+  using DRegisters = uint32_t[4];
+
+  CUTE_HOST_DEVICE static void
+  copy(uint32_t const& src_addr,
+       uint32_t& dst0, uint32_t& dst1, uint32_t& dst2, uint32_t& dst3)
+  {
+#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED)
+    asm volatile ("tcgen05.ld.sync.aligned.16x256b.x1.b32"
+                    "{%0, %1, %2, %3},"
+                    "[%4];\n"
+        : "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3)
+        : "r"(src_addr));
+#else
+    CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED.");
+#endif
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// 16 data path lanes, 256-bit pattern, repeated 1 times, packed 16b read
+struct SM100_TMEM_LOAD_16dp256b1x_16b
+{
+  using SRegisters = uint32_t[1];
+  using DRegisters = uint32_t[4];
+
+  CUTE_HOST_DEVICE static void
+  copy(uint32_t const& src_addr,
+       uint32_t& dst0, uint32_t& dst1, uint32_t& dst2, uint32_t& dst3)
+  {
+#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED)
+    asm volatile ("tcgen05.ld.sync.aligned.16x256b.x1.pack::16b.b32"
+                    "{%0, %1, %2, %3},"
+                    "[%4];\n"
+        : "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3)
+        : "r"(src_addr));
+#else
+    CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED.");
+#endif
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// 16 data path lanes, 256-bit pattern, repeated 2 times
+struct SM100_TMEM_LOAD_16dp256b2x
+{
+  using SRegisters = uint32_t[1];
+  using DRegisters = uint32_t[8];
+
+  CUTE_HOST_DEVICE static void
+  copy(uint32_t const& src_addr,
+       uint32_t& dst0, uint32_t& dst1, uint32_t& dst2, uint32_t& dst3,
+       uint32_t& dst4, uint32_t& dst5, uint32_t& dst6, uint32_t& dst7)
+  {
+#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED)
+    asm volatile ("tcgen05.ld.sync.aligned.16x256b.x2.b32"
+                    "{%0, %1, %2, %3,"
+                    "%4, %5, %6, %7},"
+                    "[%8];\n"
+        : "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3),
+          "=r"(dst4), "=r"(dst5), "=r"(dst6), "=r"(dst7)
+        : "r"(src_addr));
+#else
+    CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED.");
+#endif
+  }
+};
+
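+// Note on the SM100_TMEM_LOAD_* naming used in this section: <N>dp<B>b<R>x denotes a tcgen05.ld
+// access of N tensor-memory data path lanes with a B-bit pattern repeated R times, and the _16b
+// suffix selects the pack::16b variant, which packs two 16-bit elements from adjacent columns into
+// each 32-bit destination register. The DRegisters count follows from that shape: the warp-wide
+// load of N*B*R bits is spread over 32 threads, giving (N*B*R)/1024 32-bit registers per thread
+// (e.g. the 16x256b.x2 shape above yields 8 destination registers).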
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 256-bit pattern, repeated 2 times, packed 16b read +struct SM100_TMEM_LOAD_16dp256b2x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst0, uint32_t& dst1, uint32_t& dst2, uint32_t& dst3, + uint32_t& dst4, uint32_t& dst5, uint32_t& dst6, uint32_t& dst7) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x256b.x2.pack::16b.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7}," + "[%8];\n" + : "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3), + "=r"(dst4), "=r"(dst5), "=r"(dst6), "=r"(dst7) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 256-bit pattern, repeated 4 times +struct SM100_TMEM_LOAD_16dp256b4x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst00, uint32_t& dst01, uint32_t& dst02, uint32_t& dst03, + uint32_t& dst04, uint32_t& dst05, uint32_t& dst06, uint32_t& dst07, + uint32_t& dst08, uint32_t& dst09, uint32_t& dst10, uint32_t& dst11, + uint32_t& dst12, uint32_t& dst13, uint32_t& dst14, uint32_t& dst15) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x256b.x4.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15}," + "[%16];\n" + : "=r"(dst00), "=r"(dst01), "=r"(dst02), "=r"(dst03), + "=r"(dst04), "=r"(dst05), "=r"(dst06), "=r"(dst07), + "=r"(dst08), "=r"(dst09), "=r"(dst10), "=r"(dst11), + "=r"(dst12), "=r"(dst13), "=r"(dst14), "=r"(dst15) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 256-bit pattern, repeated 4 times, packed 16b read +struct SM100_TMEM_LOAD_16dp256b4x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst00, uint32_t& dst01, uint32_t& dst02, uint32_t& dst03, + uint32_t& dst04, uint32_t& dst05, uint32_t& dst06, uint32_t& dst07, + uint32_t& dst08, uint32_t& dst09, uint32_t& dst10, uint32_t& dst11, + uint32_t& dst12, uint32_t& dst13, uint32_t& dst14, uint32_t& dst15) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x256b.x4.pack::16b.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15}," + "[%16];\n" + : "=r"(dst00), "=r"(dst01), "=r"(dst02), "=r"(dst03), + "=r"(dst04), "=r"(dst05), "=r"(dst06), "=r"(dst07), + "=r"(dst08), "=r"(dst09), "=r"(dst10), "=r"(dst11), + "=r"(dst12), "=r"(dst13), "=r"(dst14), "=r"(dst15) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 256-bit pattern, repeated 8 times +struct SM100_TMEM_LOAD_16dp256b8x +{ + using SRegisters = uint32_t[1]; + using DRegisters = 
uint32_t[32]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst00, uint32_t& dst01, uint32_t& dst02, uint32_t& dst03, + uint32_t& dst04, uint32_t& dst05, uint32_t& dst06, uint32_t& dst07, + uint32_t& dst08, uint32_t& dst09, uint32_t& dst10, uint32_t& dst11, + uint32_t& dst12, uint32_t& dst13, uint32_t& dst14, uint32_t& dst15, + uint32_t& dst16, uint32_t& dst17, uint32_t& dst18, uint32_t& dst19, + uint32_t& dst20, uint32_t& dst21, uint32_t& dst22, uint32_t& dst23, + uint32_t& dst24, uint32_t& dst25, uint32_t& dst26, uint32_t& dst27, + uint32_t& dst28, uint32_t& dst29, uint32_t& dst30, uint32_t& dst31) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x256b.x8.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15," + "%16, %17, %18, %19," + "%20, %21, %22, %23," + "%24, %25, %26, %27," + "%28, %29, %30, %31}," + "[%32];\n" + : "=r"(dst00), "=r"(dst01), "=r"(dst02), "=r"(dst03), + "=r"(dst04), "=r"(dst05), "=r"(dst06), "=r"(dst07), + "=r"(dst08), "=r"(dst09), "=r"(dst10), "=r"(dst11), + "=r"(dst12), "=r"(dst13), "=r"(dst14), "=r"(dst15), + "=r"(dst16), "=r"(dst17), "=r"(dst18), "=r"(dst19), + "=r"(dst20), "=r"(dst21), "=r"(dst22), "=r"(dst23), + "=r"(dst24), "=r"(dst25), "=r"(dst26), "=r"(dst27), + "=r"(dst28), "=r"(dst29), "=r"(dst30), "=r"(dst31) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 256-bit pattern, repeated 8 times, packed 16b read +struct SM100_TMEM_LOAD_16dp256b8x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst00, uint32_t& dst01, uint32_t& dst02, uint32_t& dst03, + uint32_t& dst04, uint32_t& dst05, uint32_t& dst06, uint32_t& dst07, + uint32_t& dst08, uint32_t& dst09, uint32_t& dst10, uint32_t& dst11, + uint32_t& dst12, uint32_t& dst13, uint32_t& dst14, uint32_t& dst15, + uint32_t& dst16, uint32_t& dst17, uint32_t& dst18, uint32_t& dst19, + uint32_t& dst20, uint32_t& dst21, uint32_t& dst22, uint32_t& dst23, + uint32_t& dst24, uint32_t& dst25, uint32_t& dst26, uint32_t& dst27, + uint32_t& dst28, uint32_t& dst29, uint32_t& dst30, uint32_t& dst31) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x256b.x8.pack::16b.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15," + "%16, %17, %18, %19," + "%20, %21, %22, %23," + "%24, %25, %26, %27," + "%28, %29, %30, %31}," + "[%32];\n" + : "=r"(dst00), "=r"(dst01), "=r"(dst02), "=r"(dst03), + "=r"(dst04), "=r"(dst05), "=r"(dst06), "=r"(dst07), + "=r"(dst08), "=r"(dst09), "=r"(dst10), "=r"(dst11), + "=r"(dst12), "=r"(dst13), "=r"(dst14), "=r"(dst15), + "=r"(dst16), "=r"(dst17), "=r"(dst18), "=r"(dst19), + "=r"(dst20), "=r"(dst21), "=r"(dst22), "=r"(dst23), + "=r"(dst24), "=r"(dst25), "=r"(dst26), "=r"(dst27), + "=r"(dst28), "=r"(dst29), "=r"(dst30), "=r"(dst31) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 256-bit pattern, repeated 16 times +struct SM100_TMEM_LOAD_16dp256b16x +{ + using SRegisters = 
uint32_t[1]; + using DRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst00, uint32_t& dst01, uint32_t& dst02, uint32_t& dst03, + uint32_t& dst04, uint32_t& dst05, uint32_t& dst06, uint32_t& dst07, + uint32_t& dst08, uint32_t& dst09, uint32_t& dst10, uint32_t& dst11, + uint32_t& dst12, uint32_t& dst13, uint32_t& dst14, uint32_t& dst15, + uint32_t& dst16, uint32_t& dst17, uint32_t& dst18, uint32_t& dst19, + uint32_t& dst20, uint32_t& dst21, uint32_t& dst22, uint32_t& dst23, + uint32_t& dst24, uint32_t& dst25, uint32_t& dst26, uint32_t& dst27, + uint32_t& dst28, uint32_t& dst29, uint32_t& dst30, uint32_t& dst31, + uint32_t& dst32, uint32_t& dst33, uint32_t& dst34, uint32_t& dst35, + uint32_t& dst36, uint32_t& dst37, uint32_t& dst38, uint32_t& dst39, + uint32_t& dst40, uint32_t& dst41, uint32_t& dst42, uint32_t& dst43, + uint32_t& dst44, uint32_t& dst45, uint32_t& dst46, uint32_t& dst47, + uint32_t& dst48, uint32_t& dst49, uint32_t& dst50, uint32_t& dst51, + uint32_t& dst52, uint32_t& dst53, uint32_t& dst54, uint32_t& dst55, + uint32_t& dst56, uint32_t& dst57, uint32_t& dst58, uint32_t& dst59, + uint32_t& dst60, uint32_t& dst61, uint32_t& dst62, uint32_t& dst63) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x256b.x16.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15," + "%16, %17, %18, %19," + "%20, %21, %22, %23," + "%24, %25, %26, %27," + "%28, %29, %30, %31," + "%32, %33, %34, %35," + "%36, %37, %38, %39," + "%40, %41, %42, %43," + "%44, %45, %46, %47," + "%48, %49, %50, %51," + "%52, %53, %54, %55," + "%56, %57, %58, %59," + "%60, %61, %62, %63}," + "[%64];\n" + : "=r"(dst00), "=r"(dst01), "=r"(dst02), "=r"(dst03), + "=r"(dst04), "=r"(dst05), "=r"(dst06), "=r"(dst07), + "=r"(dst08), "=r"(dst09), "=r"(dst10), "=r"(dst11), + "=r"(dst12), "=r"(dst13), "=r"(dst14), "=r"(dst15), + "=r"(dst16), "=r"(dst17), "=r"(dst18), "=r"(dst19), + "=r"(dst20), "=r"(dst21), "=r"(dst22), "=r"(dst23), + "=r"(dst24), "=r"(dst25), "=r"(dst26), "=r"(dst27), + "=r"(dst28), "=r"(dst29), "=r"(dst30), "=r"(dst31), + "=r"(dst32), "=r"(dst33), "=r"(dst34), "=r"(dst35), + "=r"(dst36), "=r"(dst37), "=r"(dst38), "=r"(dst39), + "=r"(dst40), "=r"(dst41), "=r"(dst42), "=r"(dst43), + "=r"(dst44), "=r"(dst45), "=r"(dst46), "=r"(dst47), + "=r"(dst48), "=r"(dst49), "=r"(dst50), "=r"(dst51), + "=r"(dst52), "=r"(dst53), "=r"(dst54), "=r"(dst55), + "=r"(dst56), "=r"(dst57), "=r"(dst58), "=r"(dst59), + "=r"(dst60), "=r"(dst61), "=r"(dst62), "=r"(dst63) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 256-bit pattern, repeated 16 times, packed 16b read +struct SM100_TMEM_LOAD_16dp256b16x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst00, uint32_t& dst01, uint32_t& dst02, uint32_t& dst03, + uint32_t& dst04, uint32_t& dst05, uint32_t& dst06, uint32_t& dst07, + uint32_t& dst08, uint32_t& dst09, uint32_t& dst10, uint32_t& dst11, + uint32_t& dst12, uint32_t& dst13, uint32_t& dst14, uint32_t& dst15, + uint32_t& dst16, uint32_t& dst17, uint32_t& dst18, uint32_t& dst19, + uint32_t& dst20, uint32_t& dst21, uint32_t& dst22, uint32_t& dst23, + uint32_t& dst24, uint32_t& dst25, 
uint32_t& dst26, uint32_t& dst27, + uint32_t& dst28, uint32_t& dst29, uint32_t& dst30, uint32_t& dst31, + uint32_t& dst32, uint32_t& dst33, uint32_t& dst34, uint32_t& dst35, + uint32_t& dst36, uint32_t& dst37, uint32_t& dst38, uint32_t& dst39, + uint32_t& dst40, uint32_t& dst41, uint32_t& dst42, uint32_t& dst43, + uint32_t& dst44, uint32_t& dst45, uint32_t& dst46, uint32_t& dst47, + uint32_t& dst48, uint32_t& dst49, uint32_t& dst50, uint32_t& dst51, + uint32_t& dst52, uint32_t& dst53, uint32_t& dst54, uint32_t& dst55, + uint32_t& dst56, uint32_t& dst57, uint32_t& dst58, uint32_t& dst59, + uint32_t& dst60, uint32_t& dst61, uint32_t& dst62, uint32_t& dst63) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x256b.x16.pack::16b.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15," + "%16, %17, %18, %19," + "%20, %21, %22, %23," + "%24, %25, %26, %27," + "%28, %29, %30, %31," + "%32, %33, %34, %35," + "%36, %37, %38, %39," + "%40, %41, %42, %43," + "%44, %45, %46, %47," + "%48, %49, %50, %51," + "%52, %53, %54, %55," + "%56, %57, %58, %59," + "%60, %61, %62, %63}," + "[%64];\n" + : "=r"(dst00), "=r"(dst01), "=r"(dst02), "=r"(dst03), + "=r"(dst04), "=r"(dst05), "=r"(dst06), "=r"(dst07), + "=r"(dst08), "=r"(dst09), "=r"(dst10), "=r"(dst11), + "=r"(dst12), "=r"(dst13), "=r"(dst14), "=r"(dst15), + "=r"(dst16), "=r"(dst17), "=r"(dst18), "=r"(dst19), + "=r"(dst20), "=r"(dst21), "=r"(dst22), "=r"(dst23), + "=r"(dst24), "=r"(dst25), "=r"(dst26), "=r"(dst27), + "=r"(dst28), "=r"(dst29), "=r"(dst30), "=r"(dst31), + "=r"(dst32), "=r"(dst33), "=r"(dst34), "=r"(dst35), + "=r"(dst36), "=r"(dst37), "=r"(dst38), "=r"(dst39), + "=r"(dst40), "=r"(dst41), "=r"(dst42), "=r"(dst43), + "=r"(dst44), "=r"(dst45), "=r"(dst46), "=r"(dst47), + "=r"(dst48), "=r"(dst49), "=r"(dst50), "=r"(dst51), + "=r"(dst52), "=r"(dst53), "=r"(dst54), "=r"(dst55), + "=r"(dst56), "=r"(dst57), "=r"(dst58), "=r"(dst59), + "=r"(dst60), "=r"(dst61), "=r"(dst62), "=r"(dst63) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 256-bit pattern, repeated 32 times +struct SM100_TMEM_LOAD_16dp256b32x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst000, uint32_t& dst001, uint32_t& dst002, uint32_t& dst003, + uint32_t& dst004, uint32_t& dst005, uint32_t& dst006, uint32_t& dst007, + uint32_t& dst008, uint32_t& dst009, uint32_t& dst010, uint32_t& dst011, + uint32_t& dst012, uint32_t& dst013, uint32_t& dst014, uint32_t& dst015, + uint32_t& dst016, uint32_t& dst017, uint32_t& dst018, uint32_t& dst019, + uint32_t& dst020, uint32_t& dst021, uint32_t& dst022, uint32_t& dst023, + uint32_t& dst024, uint32_t& dst025, uint32_t& dst026, uint32_t& dst027, + uint32_t& dst028, uint32_t& dst029, uint32_t& dst030, uint32_t& dst031, + uint32_t& dst032, uint32_t& dst033, uint32_t& dst034, uint32_t& dst035, + uint32_t& dst036, uint32_t& dst037, uint32_t& dst038, uint32_t& dst039, + uint32_t& dst040, uint32_t& dst041, uint32_t& dst042, uint32_t& dst043, + uint32_t& dst044, uint32_t& dst045, uint32_t& dst046, uint32_t& dst047, + uint32_t& dst048, uint32_t& dst049, uint32_t& dst050, uint32_t& dst051, + uint32_t& dst052, uint32_t& dst053, uint32_t& dst054, uint32_t& dst055, + 
uint32_t& dst056, uint32_t& dst057, uint32_t& dst058, uint32_t& dst059, + uint32_t& dst060, uint32_t& dst061, uint32_t& dst062, uint32_t& dst063, + uint32_t& dst064, uint32_t& dst065, uint32_t& dst066, uint32_t& dst067, + uint32_t& dst068, uint32_t& dst069, uint32_t& dst070, uint32_t& dst071, + uint32_t& dst072, uint32_t& dst073, uint32_t& dst074, uint32_t& dst075, + uint32_t& dst076, uint32_t& dst077, uint32_t& dst078, uint32_t& dst079, + uint32_t& dst080, uint32_t& dst081, uint32_t& dst082, uint32_t& dst083, + uint32_t& dst084, uint32_t& dst085, uint32_t& dst086, uint32_t& dst087, + uint32_t& dst088, uint32_t& dst089, uint32_t& dst090, uint32_t& dst091, + uint32_t& dst092, uint32_t& dst093, uint32_t& dst094, uint32_t& dst095, + uint32_t& dst096, uint32_t& dst097, uint32_t& dst098, uint32_t& dst099, + uint32_t& dst100, uint32_t& dst101, uint32_t& dst102, uint32_t& dst103, + uint32_t& dst104, uint32_t& dst105, uint32_t& dst106, uint32_t& dst107, + uint32_t& dst108, uint32_t& dst109, uint32_t& dst110, uint32_t& dst111, + uint32_t& dst112, uint32_t& dst113, uint32_t& dst114, uint32_t& dst115, + uint32_t& dst116, uint32_t& dst117, uint32_t& dst118, uint32_t& dst119, + uint32_t& dst120, uint32_t& dst121, uint32_t& dst122, uint32_t& dst123, + uint32_t& dst124, uint32_t& dst125, uint32_t& dst126, uint32_t& dst127) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x256b.x32.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15," + "%16, %17, %18, %19," + "%20, %21, %22, %23," + "%24, %25, %26, %27," + "%28, %29, %30, %31," + "%32, %33, %34, %35," + "%36, %37, %38, %39," + "%40, %41, %42, %43," + "%44, %45, %46, %47," + "%48, %49, %50, %51," + "%52, %53, %54, %55," + "%56, %57, %58, %59," + "%60, %61, %62, %63," + "%64, %65, %66, %67," + "%68, %69, %70, %71," + "%72, %73, %74, %75," + "%76, %77, %78, %79," + "%80, %81, %82, %83," + "%84, %85, %86, %87," + "%88, %89, %90, %91," + "%92, %93, %94, %95," + "%96, %97, %98, %99," + "%100, %101, %102, %103," + "%104, %105, %106, %107," + "%108, %109, %110, %111," + "%112, %113, %114, %115," + "%116, %117, %118, %119," + "%120, %121, %122, %123," + "%124, %125, %126, %127}," + "[%128];\n" + : "=r"(dst000), "=r"(dst001), "=r"(dst002), "=r"(dst003), + "=r"(dst004), "=r"(dst005), "=r"(dst006), "=r"(dst007), + "=r"(dst008), "=r"(dst009), "=r"(dst010), "=r"(dst011), + "=r"(dst012), "=r"(dst013), "=r"(dst014), "=r"(dst015), + "=r"(dst016), "=r"(dst017), "=r"(dst018), "=r"(dst019), + "=r"(dst020), "=r"(dst021), "=r"(dst022), "=r"(dst023), + "=r"(dst024), "=r"(dst025), "=r"(dst026), "=r"(dst027), + "=r"(dst028), "=r"(dst029), "=r"(dst030), "=r"(dst031), + "=r"(dst032), "=r"(dst033), "=r"(dst034), "=r"(dst035), + "=r"(dst036), "=r"(dst037), "=r"(dst038), "=r"(dst039), + "=r"(dst040), "=r"(dst041), "=r"(dst042), "=r"(dst043), + "=r"(dst044), "=r"(dst045), "=r"(dst046), "=r"(dst047), + "=r"(dst048), "=r"(dst049), "=r"(dst050), "=r"(dst051), + "=r"(dst052), "=r"(dst053), "=r"(dst054), "=r"(dst055), + "=r"(dst056), "=r"(dst057), "=r"(dst058), "=r"(dst059), + "=r"(dst060), "=r"(dst061), "=r"(dst062), "=r"(dst063), + "=r"(dst064), "=r"(dst065), "=r"(dst066), "=r"(dst067), + "=r"(dst068), "=r"(dst069), "=r"(dst070), "=r"(dst071), + "=r"(dst072), "=r"(dst073), "=r"(dst074), "=r"(dst075), + "=r"(dst076), "=r"(dst077), "=r"(dst078), "=r"(dst079), + "=r"(dst080), "=r"(dst081), "=r"(dst082), "=r"(dst083), + "=r"(dst084), "=r"(dst085), "=r"(dst086), "=r"(dst087), + "=r"(dst088), "=r"(dst089), 
"=r"(dst090), "=r"(dst091), + "=r"(dst092), "=r"(dst093), "=r"(dst094), "=r"(dst095), + "=r"(dst096), "=r"(dst097), "=r"(dst098), "=r"(dst099), + "=r"(dst100), "=r"(dst101), "=r"(dst102), "=r"(dst103), + "=r"(dst104), "=r"(dst105), "=r"(dst106), "=r"(dst107), + "=r"(dst108), "=r"(dst109), "=r"(dst110), "=r"(dst111), + "=r"(dst112), "=r"(dst113), "=r"(dst114), "=r"(dst115), + "=r"(dst116), "=r"(dst117), "=r"(dst118), "=r"(dst119), + "=r"(dst120), "=r"(dst121), "=r"(dst122), "=r"(dst123), + "=r"(dst124), "=r"(dst125), "=r"(dst126), "=r"(dst127) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 256-bit pattern, repeated 32 times, packed 16b read +struct SM100_TMEM_LOAD_16dp256b32x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst000, uint32_t& dst001, uint32_t& dst002, uint32_t& dst003, + uint32_t& dst004, uint32_t& dst005, uint32_t& dst006, uint32_t& dst007, + uint32_t& dst008, uint32_t& dst009, uint32_t& dst010, uint32_t& dst011, + uint32_t& dst012, uint32_t& dst013, uint32_t& dst014, uint32_t& dst015, + uint32_t& dst016, uint32_t& dst017, uint32_t& dst018, uint32_t& dst019, + uint32_t& dst020, uint32_t& dst021, uint32_t& dst022, uint32_t& dst023, + uint32_t& dst024, uint32_t& dst025, uint32_t& dst026, uint32_t& dst027, + uint32_t& dst028, uint32_t& dst029, uint32_t& dst030, uint32_t& dst031, + uint32_t& dst032, uint32_t& dst033, uint32_t& dst034, uint32_t& dst035, + uint32_t& dst036, uint32_t& dst037, uint32_t& dst038, uint32_t& dst039, + uint32_t& dst040, uint32_t& dst041, uint32_t& dst042, uint32_t& dst043, + uint32_t& dst044, uint32_t& dst045, uint32_t& dst046, uint32_t& dst047, + uint32_t& dst048, uint32_t& dst049, uint32_t& dst050, uint32_t& dst051, + uint32_t& dst052, uint32_t& dst053, uint32_t& dst054, uint32_t& dst055, + uint32_t& dst056, uint32_t& dst057, uint32_t& dst058, uint32_t& dst059, + uint32_t& dst060, uint32_t& dst061, uint32_t& dst062, uint32_t& dst063, + uint32_t& dst064, uint32_t& dst065, uint32_t& dst066, uint32_t& dst067, + uint32_t& dst068, uint32_t& dst069, uint32_t& dst070, uint32_t& dst071, + uint32_t& dst072, uint32_t& dst073, uint32_t& dst074, uint32_t& dst075, + uint32_t& dst076, uint32_t& dst077, uint32_t& dst078, uint32_t& dst079, + uint32_t& dst080, uint32_t& dst081, uint32_t& dst082, uint32_t& dst083, + uint32_t& dst084, uint32_t& dst085, uint32_t& dst086, uint32_t& dst087, + uint32_t& dst088, uint32_t& dst089, uint32_t& dst090, uint32_t& dst091, + uint32_t& dst092, uint32_t& dst093, uint32_t& dst094, uint32_t& dst095, + uint32_t& dst096, uint32_t& dst097, uint32_t& dst098, uint32_t& dst099, + uint32_t& dst100, uint32_t& dst101, uint32_t& dst102, uint32_t& dst103, + uint32_t& dst104, uint32_t& dst105, uint32_t& dst106, uint32_t& dst107, + uint32_t& dst108, uint32_t& dst109, uint32_t& dst110, uint32_t& dst111, + uint32_t& dst112, uint32_t& dst113, uint32_t& dst114, uint32_t& dst115, + uint32_t& dst116, uint32_t& dst117, uint32_t& dst118, uint32_t& dst119, + uint32_t& dst120, uint32_t& dst121, uint32_t& dst122, uint32_t& dst123, + uint32_t& dst124, uint32_t& dst125, uint32_t& dst126, uint32_t& dst127) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x256b.x32.pack::16b.b32" + "{%0, %1, %2, %3," 
+ "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15," + "%16, %17, %18, %19," + "%20, %21, %22, %23," + "%24, %25, %26, %27," + "%28, %29, %30, %31," + "%32, %33, %34, %35," + "%36, %37, %38, %39," + "%40, %41, %42, %43," + "%44, %45, %46, %47," + "%48, %49, %50, %51," + "%52, %53, %54, %55," + "%56, %57, %58, %59," + "%60, %61, %62, %63," + "%64, %65, %66, %67," + "%68, %69, %70, %71," + "%72, %73, %74, %75," + "%76, %77, %78, %79," + "%80, %81, %82, %83," + "%84, %85, %86, %87," + "%88, %89, %90, %91," + "%92, %93, %94, %95," + "%96, %97, %98, %99," + "%100, %101, %102, %103," + "%104, %105, %106, %107," + "%108, %109, %110, %111," + "%112, %113, %114, %115," + "%116, %117, %118, %119," + "%120, %121, %122, %123," + "%124, %125, %126, %127}," + "[%128];\n" + : "=r"(dst000), "=r"(dst001), "=r"(dst002), "=r"(dst003), + "=r"(dst004), "=r"(dst005), "=r"(dst006), "=r"(dst007), + "=r"(dst008), "=r"(dst009), "=r"(dst010), "=r"(dst011), + "=r"(dst012), "=r"(dst013), "=r"(dst014), "=r"(dst015), + "=r"(dst016), "=r"(dst017), "=r"(dst018), "=r"(dst019), + "=r"(dst020), "=r"(dst021), "=r"(dst022), "=r"(dst023), + "=r"(dst024), "=r"(dst025), "=r"(dst026), "=r"(dst027), + "=r"(dst028), "=r"(dst029), "=r"(dst030), "=r"(dst031), + "=r"(dst032), "=r"(dst033), "=r"(dst034), "=r"(dst035), + "=r"(dst036), "=r"(dst037), "=r"(dst038), "=r"(dst039), + "=r"(dst040), "=r"(dst041), "=r"(dst042), "=r"(dst043), + "=r"(dst044), "=r"(dst045), "=r"(dst046), "=r"(dst047), + "=r"(dst048), "=r"(dst049), "=r"(dst050), "=r"(dst051), + "=r"(dst052), "=r"(dst053), "=r"(dst054), "=r"(dst055), + "=r"(dst056), "=r"(dst057), "=r"(dst058), "=r"(dst059), + "=r"(dst060), "=r"(dst061), "=r"(dst062), "=r"(dst063), + "=r"(dst064), "=r"(dst065), "=r"(dst066), "=r"(dst067), + "=r"(dst068), "=r"(dst069), "=r"(dst070), "=r"(dst071), + "=r"(dst072), "=r"(dst073), "=r"(dst074), "=r"(dst075), + "=r"(dst076), "=r"(dst077), "=r"(dst078), "=r"(dst079), + "=r"(dst080), "=r"(dst081), "=r"(dst082), "=r"(dst083), + "=r"(dst084), "=r"(dst085), "=r"(dst086), "=r"(dst087), + "=r"(dst088), "=r"(dst089), "=r"(dst090), "=r"(dst091), + "=r"(dst092), "=r"(dst093), "=r"(dst094), "=r"(dst095), + "=r"(dst096), "=r"(dst097), "=r"(dst098), "=r"(dst099), + "=r"(dst100), "=r"(dst101), "=r"(dst102), "=r"(dst103), + "=r"(dst104), "=r"(dst105), "=r"(dst106), "=r"(dst107), + "=r"(dst108), "=r"(dst109), "=r"(dst110), "=r"(dst111), + "=r"(dst112), "=r"(dst113), "=r"(dst114), "=r"(dst115), + "=r"(dst116), "=r"(dst117), "=r"(dst118), "=r"(dst119), + "=r"(dst120), "=r"(dst121), "=r"(dst122), "=r"(dst123), + "=r"(dst124), "=r"(dst125), "=r"(dst126), "=r"(dst127) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 128-bit pattern, repeated 1 times +struct SM100_TMEM_LOAD_16dp128b1x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst0, uint32_t& dst1) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x128b.x1.b32" + "{%0, %1}," + "[%2];\n" + : "=r"(dst0), "=r"(dst1) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 
16 data path lanes, 128-bit pattern, repeated 1 times, packed 16b read +struct SM100_TMEM_LOAD_16dp128b1x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst0, uint32_t& dst1) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x128b.x1.pack::16b.b32" + "{%0, %1}," + "[%2];\n" + : "=r"(dst0), "=r"(dst1) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 128-bit pattern, repeated 2 times +struct SM100_TMEM_LOAD_16dp128b2x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst0, uint32_t& dst1, uint32_t& dst2, uint32_t& dst3) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x128b.x2.b32" + "{%0, %1, %2, %3}," + "[%4];\n" + : "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 128-bit pattern, repeated 2 times, packed 16b read +struct SM100_TMEM_LOAD_16dp128b2x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst0, uint32_t& dst1, uint32_t& dst2, uint32_t& dst3) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x128b.x2.pack::16b.b32" + "{%0, %1, %2, %3}," + "[%4];\n" + : "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 128-bit pattern, repeated 4 times +struct SM100_TMEM_LOAD_16dp128b4x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst0, uint32_t& dst1, uint32_t& dst2, uint32_t& dst3, + uint32_t& dst4, uint32_t& dst5, uint32_t& dst6, uint32_t& dst7) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x128b.x4.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7}," + "[%8];\n" + : "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3), + "=r"(dst4), "=r"(dst5), "=r"(dst6), "=r"(dst7) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 128-bit pattern, repeated 4 times, packed 16b read +struct SM100_TMEM_LOAD_16dp128b4x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst0, uint32_t& dst1, uint32_t& dst2, uint32_t& dst3, + uint32_t& dst4, uint32_t& dst5, uint32_t& dst6, uint32_t& dst7) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile 
("tcgen05.ld.sync.aligned.16x128b.x4.pack::16b.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7}," + "[%8];\n" + : "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3), + "=r"(dst4), "=r"(dst5), "=r"(dst6), "=r"(dst7) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 128-bit pattern, repeated 8 times +struct SM100_TMEM_LOAD_16dp128b8x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst00, uint32_t& dst01, uint32_t& dst02, uint32_t& dst03, + uint32_t& dst04, uint32_t& dst05, uint32_t& dst06, uint32_t& dst07, + uint32_t& dst08, uint32_t& dst09, uint32_t& dst10, uint32_t& dst11, + uint32_t& dst12, uint32_t& dst13, uint32_t& dst14, uint32_t& dst15) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x128b.x8.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15}," + "[%16];\n" + : "=r"(dst00), "=r"(dst01), "=r"(dst02), "=r"(dst03), + "=r"(dst04), "=r"(dst05), "=r"(dst06), "=r"(dst07), + "=r"(dst08), "=r"(dst09), "=r"(dst10), "=r"(dst11), + "=r"(dst12), "=r"(dst13), "=r"(dst14), "=r"(dst15) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 128-bit pattern, repeated 8 times, packed 16b read +struct SM100_TMEM_LOAD_16dp128b8x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst00, uint32_t& dst01, uint32_t& dst02, uint32_t& dst03, + uint32_t& dst04, uint32_t& dst05, uint32_t& dst06, uint32_t& dst07, + uint32_t& dst08, uint32_t& dst09, uint32_t& dst10, uint32_t& dst11, + uint32_t& dst12, uint32_t& dst13, uint32_t& dst14, uint32_t& dst15) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x128b.x8.pack::16b.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15}," + "[%16];\n" + : "=r"(dst00), "=r"(dst01), "=r"(dst02), "=r"(dst03), + "=r"(dst04), "=r"(dst05), "=r"(dst06), "=r"(dst07), + "=r"(dst08), "=r"(dst09), "=r"(dst10), "=r"(dst11), + "=r"(dst12), "=r"(dst13), "=r"(dst14), "=r"(dst15) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 128-bit pattern, repeated 16 times +struct SM100_TMEM_LOAD_16dp128b16x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst00, uint32_t& dst01, uint32_t& dst02, uint32_t& dst03, + uint32_t& dst04, uint32_t& dst05, uint32_t& dst06, uint32_t& dst07, + uint32_t& dst08, uint32_t& dst09, uint32_t& dst10, uint32_t& dst11, + uint32_t& dst12, uint32_t& dst13, uint32_t& dst14, uint32_t& dst15, + uint32_t& dst16, uint32_t& dst17, uint32_t& dst18, uint32_t& dst19, + uint32_t& dst20, uint32_t& dst21, uint32_t& dst22, uint32_t& dst23, + uint32_t& dst24, uint32_t& dst25, uint32_t& dst26, 
uint32_t& dst27, + uint32_t& dst28, uint32_t& dst29, uint32_t& dst30, uint32_t& dst31) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x128b.x16.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15," + "%16, %17, %18, %19," + "%20, %21, %22, %23," + "%24, %25, %26, %27," + "%28, %29, %30, %31}," + "[%32];\n" + : "=r"(dst00), "=r"(dst01), "=r"(dst02), "=r"(dst03), + "=r"(dst04), "=r"(dst05), "=r"(dst06), "=r"(dst07), + "=r"(dst08), "=r"(dst09), "=r"(dst10), "=r"(dst11), + "=r"(dst12), "=r"(dst13), "=r"(dst14), "=r"(dst15), + "=r"(dst16), "=r"(dst17), "=r"(dst18), "=r"(dst19), + "=r"(dst20), "=r"(dst21), "=r"(dst22), "=r"(dst23), + "=r"(dst24), "=r"(dst25), "=r"(dst26), "=r"(dst27), + "=r"(dst28), "=r"(dst29), "=r"(dst30), "=r"(dst31) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 128-bit pattern, repeated 16 times, packed 16b read +struct SM100_TMEM_LOAD_16dp128b16x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst00, uint32_t& dst01, uint32_t& dst02, uint32_t& dst03, + uint32_t& dst04, uint32_t& dst05, uint32_t& dst06, uint32_t& dst07, + uint32_t& dst08, uint32_t& dst09, uint32_t& dst10, uint32_t& dst11, + uint32_t& dst12, uint32_t& dst13, uint32_t& dst14, uint32_t& dst15, + uint32_t& dst16, uint32_t& dst17, uint32_t& dst18, uint32_t& dst19, + uint32_t& dst20, uint32_t& dst21, uint32_t& dst22, uint32_t& dst23, + uint32_t& dst24, uint32_t& dst25, uint32_t& dst26, uint32_t& dst27, + uint32_t& dst28, uint32_t& dst29, uint32_t& dst30, uint32_t& dst31) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x128b.x16.pack::16b.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15," + "%16, %17, %18, %19," + "%20, %21, %22, %23," + "%24, %25, %26, %27," + "%28, %29, %30, %31}," + "[%32];\n" + : "=r"(dst00), "=r"(dst01), "=r"(dst02), "=r"(dst03), + "=r"(dst04), "=r"(dst05), "=r"(dst06), "=r"(dst07), + "=r"(dst08), "=r"(dst09), "=r"(dst10), "=r"(dst11), + "=r"(dst12), "=r"(dst13), "=r"(dst14), "=r"(dst15), + "=r"(dst16), "=r"(dst17), "=r"(dst18), "=r"(dst19), + "=r"(dst20), "=r"(dst21), "=r"(dst22), "=r"(dst23), + "=r"(dst24), "=r"(dst25), "=r"(dst26), "=r"(dst27), + "=r"(dst28), "=r"(dst29), "=r"(dst30), "=r"(dst31) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 128-bit pattern, repeated 32 times +struct SM100_TMEM_LOAD_16dp128b32x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst00, uint32_t& dst01, uint32_t& dst02, uint32_t& dst03, + uint32_t& dst04, uint32_t& dst05, uint32_t& dst06, uint32_t& dst07, + uint32_t& dst08, uint32_t& dst09, uint32_t& dst10, uint32_t& dst11, + uint32_t& dst12, uint32_t& dst13, uint32_t& dst14, uint32_t& dst15, + uint32_t& dst16, uint32_t& dst17, uint32_t& dst18, uint32_t& dst19, + uint32_t& dst20, uint32_t& dst21, uint32_t& dst22, uint32_t& dst23, + uint32_t& dst24, 
uint32_t& dst25, uint32_t& dst26, uint32_t& dst27, + uint32_t& dst28, uint32_t& dst29, uint32_t& dst30, uint32_t& dst31, + uint32_t& dst32, uint32_t& dst33, uint32_t& dst34, uint32_t& dst35, + uint32_t& dst36, uint32_t& dst37, uint32_t& dst38, uint32_t& dst39, + uint32_t& dst40, uint32_t& dst41, uint32_t& dst42, uint32_t& dst43, + uint32_t& dst44, uint32_t& dst45, uint32_t& dst46, uint32_t& dst47, + uint32_t& dst48, uint32_t& dst49, uint32_t& dst50, uint32_t& dst51, + uint32_t& dst52, uint32_t& dst53, uint32_t& dst54, uint32_t& dst55, + uint32_t& dst56, uint32_t& dst57, uint32_t& dst58, uint32_t& dst59, + uint32_t& dst60, uint32_t& dst61, uint32_t& dst62, uint32_t& dst63) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x128b.x32.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15," + "%16, %17, %18, %19," + "%20, %21, %22, %23," + "%24, %25, %26, %27," + "%28, %29, %30, %31," + "%32, %33, %34, %35," + "%36, %37, %38, %39," + "%40, %41, %42, %43," + "%44, %45, %46, %47," + "%48, %49, %50, %51," + "%52, %53, %54, %55," + "%56, %57, %58, %59," + "%60, %61, %62, %63}," + "[%64];\n" + : "=r"(dst00), "=r"(dst01), "=r"(dst02), "=r"(dst03), + "=r"(dst04), "=r"(dst05), "=r"(dst06), "=r"(dst07), + "=r"(dst08), "=r"(dst09), "=r"(dst10), "=r"(dst11), + "=r"(dst12), "=r"(dst13), "=r"(dst14), "=r"(dst15), + "=r"(dst16), "=r"(dst17), "=r"(dst18), "=r"(dst19), + "=r"(dst20), "=r"(dst21), "=r"(dst22), "=r"(dst23), + "=r"(dst24), "=r"(dst25), "=r"(dst26), "=r"(dst27), + "=r"(dst28), "=r"(dst29), "=r"(dst30), "=r"(dst31), + "=r"(dst32), "=r"(dst33), "=r"(dst34), "=r"(dst35), + "=r"(dst36), "=r"(dst37), "=r"(dst38), "=r"(dst39), + "=r"(dst40), "=r"(dst41), "=r"(dst42), "=r"(dst43), + "=r"(dst44), "=r"(dst45), "=r"(dst46), "=r"(dst47), + "=r"(dst48), "=r"(dst49), "=r"(dst50), "=r"(dst51), + "=r"(dst52), "=r"(dst53), "=r"(dst54), "=r"(dst55), + "=r"(dst56), "=r"(dst57), "=r"(dst58), "=r"(dst59), + "=r"(dst60), "=r"(dst61), "=r"(dst62), "=r"(dst63) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 128-bit pattern, repeated 32 times, packed 16b read +struct SM100_TMEM_LOAD_16dp128b32x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst00, uint32_t& dst01, uint32_t& dst02, uint32_t& dst03, + uint32_t& dst04, uint32_t& dst05, uint32_t& dst06, uint32_t& dst07, + uint32_t& dst08, uint32_t& dst09, uint32_t& dst10, uint32_t& dst11, + uint32_t& dst12, uint32_t& dst13, uint32_t& dst14, uint32_t& dst15, + uint32_t& dst16, uint32_t& dst17, uint32_t& dst18, uint32_t& dst19, + uint32_t& dst20, uint32_t& dst21, uint32_t& dst22, uint32_t& dst23, + uint32_t& dst24, uint32_t& dst25, uint32_t& dst26, uint32_t& dst27, + uint32_t& dst28, uint32_t& dst29, uint32_t& dst30, uint32_t& dst31, + uint32_t& dst32, uint32_t& dst33, uint32_t& dst34, uint32_t& dst35, + uint32_t& dst36, uint32_t& dst37, uint32_t& dst38, uint32_t& dst39, + uint32_t& dst40, uint32_t& dst41, uint32_t& dst42, uint32_t& dst43, + uint32_t& dst44, uint32_t& dst45, uint32_t& dst46, uint32_t& dst47, + uint32_t& dst48, uint32_t& dst49, uint32_t& dst50, uint32_t& dst51, + uint32_t& dst52, uint32_t& dst53, uint32_t& dst54, uint32_t& dst55, + uint32_t& dst56, uint32_t& 
dst57, uint32_t& dst58, uint32_t& dst59, + uint32_t& dst60, uint32_t& dst61, uint32_t& dst62, uint32_t& dst63) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x128b.x32.pack::16b.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15," + "%16, %17, %18, %19," + "%20, %21, %22, %23," + "%24, %25, %26, %27," + "%28, %29, %30, %31," + "%32, %33, %34, %35," + "%36, %37, %38, %39," + "%40, %41, %42, %43," + "%44, %45, %46, %47," + "%48, %49, %50, %51," + "%52, %53, %54, %55," + "%56, %57, %58, %59," + "%60, %61, %62, %63}," + "[%64];\n" + : "=r"(dst00), "=r"(dst01), "=r"(dst02), "=r"(dst03), + "=r"(dst04), "=r"(dst05), "=r"(dst06), "=r"(dst07), + "=r"(dst08), "=r"(dst09), "=r"(dst10), "=r"(dst11), + "=r"(dst12), "=r"(dst13), "=r"(dst14), "=r"(dst15), + "=r"(dst16), "=r"(dst17), "=r"(dst18), "=r"(dst19), + "=r"(dst20), "=r"(dst21), "=r"(dst22), "=r"(dst23), + "=r"(dst24), "=r"(dst25), "=r"(dst26), "=r"(dst27), + "=r"(dst28), "=r"(dst29), "=r"(dst30), "=r"(dst31), + "=r"(dst32), "=r"(dst33), "=r"(dst34), "=r"(dst35), + "=r"(dst36), "=r"(dst37), "=r"(dst38), "=r"(dst39), + "=r"(dst40), "=r"(dst41), "=r"(dst42), "=r"(dst43), + "=r"(dst44), "=r"(dst45), "=r"(dst46), "=r"(dst47), + "=r"(dst48), "=r"(dst49), "=r"(dst50), "=r"(dst51), + "=r"(dst52), "=r"(dst53), "=r"(dst54), "=r"(dst55), + "=r"(dst56), "=r"(dst57), "=r"(dst58), "=r"(dst59), + "=r"(dst60), "=r"(dst61), "=r"(dst62), "=r"(dst63) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 128-bit pattern, repeated 64 times +struct SM100_TMEM_LOAD_16dp128b64x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst000, uint32_t& dst001, uint32_t& dst002, uint32_t& dst003, + uint32_t& dst004, uint32_t& dst005, uint32_t& dst006, uint32_t& dst007, + uint32_t& dst008, uint32_t& dst009, uint32_t& dst010, uint32_t& dst011, + uint32_t& dst012, uint32_t& dst013, uint32_t& dst014, uint32_t& dst015, + uint32_t& dst016, uint32_t& dst017, uint32_t& dst018, uint32_t& dst019, + uint32_t& dst020, uint32_t& dst021, uint32_t& dst022, uint32_t& dst023, + uint32_t& dst024, uint32_t& dst025, uint32_t& dst026, uint32_t& dst027, + uint32_t& dst028, uint32_t& dst029, uint32_t& dst030, uint32_t& dst031, + uint32_t& dst032, uint32_t& dst033, uint32_t& dst034, uint32_t& dst035, + uint32_t& dst036, uint32_t& dst037, uint32_t& dst038, uint32_t& dst039, + uint32_t& dst040, uint32_t& dst041, uint32_t& dst042, uint32_t& dst043, + uint32_t& dst044, uint32_t& dst045, uint32_t& dst046, uint32_t& dst047, + uint32_t& dst048, uint32_t& dst049, uint32_t& dst050, uint32_t& dst051, + uint32_t& dst052, uint32_t& dst053, uint32_t& dst054, uint32_t& dst055, + uint32_t& dst056, uint32_t& dst057, uint32_t& dst058, uint32_t& dst059, + uint32_t& dst060, uint32_t& dst061, uint32_t& dst062, uint32_t& dst063, + uint32_t& dst064, uint32_t& dst065, uint32_t& dst066, uint32_t& dst067, + uint32_t& dst068, uint32_t& dst069, uint32_t& dst070, uint32_t& dst071, + uint32_t& dst072, uint32_t& dst073, uint32_t& dst074, uint32_t& dst075, + uint32_t& dst076, uint32_t& dst077, uint32_t& dst078, uint32_t& dst079, + uint32_t& dst080, uint32_t& dst081, uint32_t& dst082, uint32_t& dst083, + uint32_t& dst084, uint32_t& dst085, 
uint32_t& dst086, uint32_t& dst087, + uint32_t& dst088, uint32_t& dst089, uint32_t& dst090, uint32_t& dst091, + uint32_t& dst092, uint32_t& dst093, uint32_t& dst094, uint32_t& dst095, + uint32_t& dst096, uint32_t& dst097, uint32_t& dst098, uint32_t& dst099, + uint32_t& dst100, uint32_t& dst101, uint32_t& dst102, uint32_t& dst103, + uint32_t& dst104, uint32_t& dst105, uint32_t& dst106, uint32_t& dst107, + uint32_t& dst108, uint32_t& dst109, uint32_t& dst110, uint32_t& dst111, + uint32_t& dst112, uint32_t& dst113, uint32_t& dst114, uint32_t& dst115, + uint32_t& dst116, uint32_t& dst117, uint32_t& dst118, uint32_t& dst119, + uint32_t& dst120, uint32_t& dst121, uint32_t& dst122, uint32_t& dst123, + uint32_t& dst124, uint32_t& dst125, uint32_t& dst126, uint32_t& dst127) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x128b.x64.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15," + "%16, %17, %18, %19," + "%20, %21, %22, %23," + "%24, %25, %26, %27," + "%28, %29, %30, %31," + "%32, %33, %34, %35," + "%36, %37, %38, %39," + "%40, %41, %42, %43," + "%44, %45, %46, %47," + "%48, %49, %50, %51," + "%52, %53, %54, %55," + "%56, %57, %58, %59," + "%60, %61, %62, %63," + "%64, %65, %66, %67," + "%68, %69, %70, %71," + "%72, %73, %74, %75," + "%76, %77, %78, %79," + "%80, %81, %82, %83," + "%84, %85, %86, %87," + "%88, %89, %90, %91," + "%92, %93, %94, %95," + "%96, %97, %98, %99," + "%100, %101, %102, %103," + "%104, %105, %106, %107," + "%108, %109, %110, %111," + "%112, %113, %114, %115," + "%116, %117, %118, %119," + "%120, %121, %122, %123," + "%124, %125, %126, %127}," + "[%128];\n" + : "=r"(dst000), "=r"(dst001), "=r"(dst002), "=r"(dst003), + "=r"(dst004), "=r"(dst005), "=r"(dst006), "=r"(dst007), + "=r"(dst008), "=r"(dst009), "=r"(dst010), "=r"(dst011), + "=r"(dst012), "=r"(dst013), "=r"(dst014), "=r"(dst015), + "=r"(dst016), "=r"(dst017), "=r"(dst018), "=r"(dst019), + "=r"(dst020), "=r"(dst021), "=r"(dst022), "=r"(dst023), + "=r"(dst024), "=r"(dst025), "=r"(dst026), "=r"(dst027), + "=r"(dst028), "=r"(dst029), "=r"(dst030), "=r"(dst031), + "=r"(dst032), "=r"(dst033), "=r"(dst034), "=r"(dst035), + "=r"(dst036), "=r"(dst037), "=r"(dst038), "=r"(dst039), + "=r"(dst040), "=r"(dst041), "=r"(dst042), "=r"(dst043), + "=r"(dst044), "=r"(dst045), "=r"(dst046), "=r"(dst047), + "=r"(dst048), "=r"(dst049), "=r"(dst050), "=r"(dst051), + "=r"(dst052), "=r"(dst053), "=r"(dst054), "=r"(dst055), + "=r"(dst056), "=r"(dst057), "=r"(dst058), "=r"(dst059), + "=r"(dst060), "=r"(dst061), "=r"(dst062), "=r"(dst063), + "=r"(dst064), "=r"(dst065), "=r"(dst066), "=r"(dst067), + "=r"(dst068), "=r"(dst069), "=r"(dst070), "=r"(dst071), + "=r"(dst072), "=r"(dst073), "=r"(dst074), "=r"(dst075), + "=r"(dst076), "=r"(dst077), "=r"(dst078), "=r"(dst079), + "=r"(dst080), "=r"(dst081), "=r"(dst082), "=r"(dst083), + "=r"(dst084), "=r"(dst085), "=r"(dst086), "=r"(dst087), + "=r"(dst088), "=r"(dst089), "=r"(dst090), "=r"(dst091), + "=r"(dst092), "=r"(dst093), "=r"(dst094), "=r"(dst095), + "=r"(dst096), "=r"(dst097), "=r"(dst098), "=r"(dst099), + "=r"(dst100), "=r"(dst101), "=r"(dst102), "=r"(dst103), + "=r"(dst104), "=r"(dst105), "=r"(dst106), "=r"(dst107), + "=r"(dst108), "=r"(dst109), "=r"(dst110), "=r"(dst111), + "=r"(dst112), "=r"(dst113), "=r"(dst114), "=r"(dst115), + "=r"(dst116), "=r"(dst117), "=r"(dst118), "=r"(dst119), + "=r"(dst120), "=r"(dst121), "=r"(dst122), "=r"(dst123), + "=r"(dst124), "=r"(dst125), "=r"(dst126), "=r"(dst127) + : 
"r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 128-bit pattern, repeated 64 times, packed 16b read +struct SM100_TMEM_LOAD_16dp128b64x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst000, uint32_t& dst001, uint32_t& dst002, uint32_t& dst003, + uint32_t& dst004, uint32_t& dst005, uint32_t& dst006, uint32_t& dst007, + uint32_t& dst008, uint32_t& dst009, uint32_t& dst010, uint32_t& dst011, + uint32_t& dst012, uint32_t& dst013, uint32_t& dst014, uint32_t& dst015, + uint32_t& dst016, uint32_t& dst017, uint32_t& dst018, uint32_t& dst019, + uint32_t& dst020, uint32_t& dst021, uint32_t& dst022, uint32_t& dst023, + uint32_t& dst024, uint32_t& dst025, uint32_t& dst026, uint32_t& dst027, + uint32_t& dst028, uint32_t& dst029, uint32_t& dst030, uint32_t& dst031, + uint32_t& dst032, uint32_t& dst033, uint32_t& dst034, uint32_t& dst035, + uint32_t& dst036, uint32_t& dst037, uint32_t& dst038, uint32_t& dst039, + uint32_t& dst040, uint32_t& dst041, uint32_t& dst042, uint32_t& dst043, + uint32_t& dst044, uint32_t& dst045, uint32_t& dst046, uint32_t& dst047, + uint32_t& dst048, uint32_t& dst049, uint32_t& dst050, uint32_t& dst051, + uint32_t& dst052, uint32_t& dst053, uint32_t& dst054, uint32_t& dst055, + uint32_t& dst056, uint32_t& dst057, uint32_t& dst058, uint32_t& dst059, + uint32_t& dst060, uint32_t& dst061, uint32_t& dst062, uint32_t& dst063, + uint32_t& dst064, uint32_t& dst065, uint32_t& dst066, uint32_t& dst067, + uint32_t& dst068, uint32_t& dst069, uint32_t& dst070, uint32_t& dst071, + uint32_t& dst072, uint32_t& dst073, uint32_t& dst074, uint32_t& dst075, + uint32_t& dst076, uint32_t& dst077, uint32_t& dst078, uint32_t& dst079, + uint32_t& dst080, uint32_t& dst081, uint32_t& dst082, uint32_t& dst083, + uint32_t& dst084, uint32_t& dst085, uint32_t& dst086, uint32_t& dst087, + uint32_t& dst088, uint32_t& dst089, uint32_t& dst090, uint32_t& dst091, + uint32_t& dst092, uint32_t& dst093, uint32_t& dst094, uint32_t& dst095, + uint32_t& dst096, uint32_t& dst097, uint32_t& dst098, uint32_t& dst099, + uint32_t& dst100, uint32_t& dst101, uint32_t& dst102, uint32_t& dst103, + uint32_t& dst104, uint32_t& dst105, uint32_t& dst106, uint32_t& dst107, + uint32_t& dst108, uint32_t& dst109, uint32_t& dst110, uint32_t& dst111, + uint32_t& dst112, uint32_t& dst113, uint32_t& dst114, uint32_t& dst115, + uint32_t& dst116, uint32_t& dst117, uint32_t& dst118, uint32_t& dst119, + uint32_t& dst120, uint32_t& dst121, uint32_t& dst122, uint32_t& dst123, + uint32_t& dst124, uint32_t& dst125, uint32_t& dst126, uint32_t& dst127) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x128b.x64.pack::16b.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15," + "%16, %17, %18, %19," + "%20, %21, %22, %23," + "%24, %25, %26, %27," + "%28, %29, %30, %31," + "%32, %33, %34, %35," + "%36, %37, %38, %39," + "%40, %41, %42, %43," + "%44, %45, %46, %47," + "%48, %49, %50, %51," + "%52, %53, %54, %55," + "%56, %57, %58, %59," + "%60, %61, %62, %63," + "%64, %65, %66, %67," + "%68, %69, %70, %71," + "%72, %73, %74, %75," + "%76, %77, %78, %79," + "%80, %81, %82, %83," + "%84, %85, %86, %87," + "%88, %89, %90, %91," + "%92, %93, %94, %95," + 
"%96, %97, %98, %99," + "%100, %101, %102, %103," + "%104, %105, %106, %107," + "%108, %109, %110, %111," + "%112, %113, %114, %115," + "%116, %117, %118, %119," + "%120, %121, %122, %123," + "%124, %125, %126, %127}," + "[%128];\n" + : "=r"(dst000), "=r"(dst001), "=r"(dst002), "=r"(dst003), + "=r"(dst004), "=r"(dst005), "=r"(dst006), "=r"(dst007), + "=r"(dst008), "=r"(dst009), "=r"(dst010), "=r"(dst011), + "=r"(dst012), "=r"(dst013), "=r"(dst014), "=r"(dst015), + "=r"(dst016), "=r"(dst017), "=r"(dst018), "=r"(dst019), + "=r"(dst020), "=r"(dst021), "=r"(dst022), "=r"(dst023), + "=r"(dst024), "=r"(dst025), "=r"(dst026), "=r"(dst027), + "=r"(dst028), "=r"(dst029), "=r"(dst030), "=r"(dst031), + "=r"(dst032), "=r"(dst033), "=r"(dst034), "=r"(dst035), + "=r"(dst036), "=r"(dst037), "=r"(dst038), "=r"(dst039), + "=r"(dst040), "=r"(dst041), "=r"(dst042), "=r"(dst043), + "=r"(dst044), "=r"(dst045), "=r"(dst046), "=r"(dst047), + "=r"(dst048), "=r"(dst049), "=r"(dst050), "=r"(dst051), + "=r"(dst052), "=r"(dst053), "=r"(dst054), "=r"(dst055), + "=r"(dst056), "=r"(dst057), "=r"(dst058), "=r"(dst059), + "=r"(dst060), "=r"(dst061), "=r"(dst062), "=r"(dst063), + "=r"(dst064), "=r"(dst065), "=r"(dst066), "=r"(dst067), + "=r"(dst068), "=r"(dst069), "=r"(dst070), "=r"(dst071), + "=r"(dst072), "=r"(dst073), "=r"(dst074), "=r"(dst075), + "=r"(dst076), "=r"(dst077), "=r"(dst078), "=r"(dst079), + "=r"(dst080), "=r"(dst081), "=r"(dst082), "=r"(dst083), + "=r"(dst084), "=r"(dst085), "=r"(dst086), "=r"(dst087), + "=r"(dst088), "=r"(dst089), "=r"(dst090), "=r"(dst091), + "=r"(dst092), "=r"(dst093), "=r"(dst094), "=r"(dst095), + "=r"(dst096), "=r"(dst097), "=r"(dst098), "=r"(dst099), + "=r"(dst100), "=r"(dst101), "=r"(dst102), "=r"(dst103), + "=r"(dst104), "=r"(dst105), "=r"(dst106), "=r"(dst107), + "=r"(dst108), "=r"(dst109), "=r"(dst110), "=r"(dst111), + "=r"(dst112), "=r"(dst113), "=r"(dst114), "=r"(dst115), + "=r"(dst116), "=r"(dst117), "=r"(dst118), "=r"(dst119), + "=r"(dst120), "=r"(dst121), "=r"(dst122), "=r"(dst123), + "=r"(dst124), "=r"(dst125), "=r"(dst126), "=r"(dst127) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 64-bit pattern, repeated 1 times +struct SM100_TMEM_LOAD_16dp64b1x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst0) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x64b.x1.b32" + "{%0}," + "[%1];\n" + : "=r"(dst0) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 64-bit pattern, repeated 1 times, packed 16b read +struct SM100_TMEM_LOAD_16dp64b1x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst0) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x64b.x1.pack::16b.b32" + "{%0}," + "[%1];\n" + : "=r"(dst0) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 64-bit pattern, repeated 2 times +struct SM100_TMEM_LOAD_16dp64b2x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst0, uint32_t& dst1) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x64b.x2.b32" + "{%0, %1}," + "[%2];\n" + : "=r"(dst0), "=r"(dst1) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 64-bit pattern, repeated 2 times, packed 16b read +struct SM100_TMEM_LOAD_16dp64b2x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst0, uint32_t& dst1) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x64b.x2.pack::16b.b32" + "{%0, %1}," + "[%2];\n" + : "=r"(dst0), "=r"(dst1) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 64-bit pattern, repeated 4 times +struct SM100_TMEM_LOAD_16dp64b4x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst0, uint32_t& dst1, uint32_t& dst2, uint32_t& dst3) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x64b.x4.b32" + "{%0, %1, %2, %3}," + "[%4];\n" + : "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 64-bit pattern, repeated 4 times, packed 16b read +struct SM100_TMEM_LOAD_16dp64b4x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst0, uint32_t& dst1, uint32_t& dst2, uint32_t& dst3) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x64b.x4.pack::16b.b32" + "{%0, %1, %2, %3}," + "[%4];\n" + : "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 64-bit pattern, repeated 8 times +struct SM100_TMEM_LOAD_16dp64b8x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst0, uint32_t& dst1, uint32_t& dst2, uint32_t& dst3, + uint32_t& dst4, uint32_t& dst5, uint32_t& dst6, uint32_t& dst7) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x64b.x8.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7}," + "[%8];\n" + : "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3), + "=r"(dst4), "=r"(dst5), 
"=r"(dst6), "=r"(dst7) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 64-bit pattern, repeated 8 times, packed 16b read +struct SM100_TMEM_LOAD_16dp64b8x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst0, uint32_t& dst1, uint32_t& dst2, uint32_t& dst3, + uint32_t& dst4, uint32_t& dst5, uint32_t& dst6, uint32_t& dst7) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x64b.x8.pack::16b.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7}," + "[%8];\n" + : "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3), + "=r"(dst4), "=r"(dst5), "=r"(dst6), "=r"(dst7) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 64-bit pattern, repeated 16 times +struct SM100_TMEM_LOAD_16dp64b16x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst00, uint32_t& dst01, uint32_t& dst02, uint32_t& dst03, + uint32_t& dst04, uint32_t& dst05, uint32_t& dst06, uint32_t& dst07, + uint32_t& dst08, uint32_t& dst09, uint32_t& dst10, uint32_t& dst11, + uint32_t& dst12, uint32_t& dst13, uint32_t& dst14, uint32_t& dst15) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x64b.x16.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15}," + "[%16];\n" + : "=r"(dst00), "=r"(dst01), "=r"(dst02), "=r"(dst03), + "=r"(dst04), "=r"(dst05), "=r"(dst06), "=r"(dst07), + "=r"(dst08), "=r"(dst09), "=r"(dst10), "=r"(dst11), + "=r"(dst12), "=r"(dst13), "=r"(dst14), "=r"(dst15) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 64-bit pattern, repeated 16 times, packed 16b read +struct SM100_TMEM_LOAD_16dp64b16x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst00, uint32_t& dst01, uint32_t& dst02, uint32_t& dst03, + uint32_t& dst04, uint32_t& dst05, uint32_t& dst06, uint32_t& dst07, + uint32_t& dst08, uint32_t& dst09, uint32_t& dst10, uint32_t& dst11, + uint32_t& dst12, uint32_t& dst13, uint32_t& dst14, uint32_t& dst15) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x64b.x16.pack::16b.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15}," + "[%16];\n" + : "=r"(dst00), "=r"(dst01), "=r"(dst02), "=r"(dst03), + "=r"(dst04), "=r"(dst05), "=r"(dst06), "=r"(dst07), + "=r"(dst08), "=r"(dst09), "=r"(dst10), "=r"(dst11), + "=r"(dst12), "=r"(dst13), "=r"(dst14), "=r"(dst15) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + 
+// 16 data path lanes, 64-bit pattern, repeated 32 times +struct SM100_TMEM_LOAD_16dp64b32x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst00, uint32_t& dst01, uint32_t& dst02, uint32_t& dst03, + uint32_t& dst04, uint32_t& dst05, uint32_t& dst06, uint32_t& dst07, + uint32_t& dst08, uint32_t& dst09, uint32_t& dst10, uint32_t& dst11, + uint32_t& dst12, uint32_t& dst13, uint32_t& dst14, uint32_t& dst15, + uint32_t& dst16, uint32_t& dst17, uint32_t& dst18, uint32_t& dst19, + uint32_t& dst20, uint32_t& dst21, uint32_t& dst22, uint32_t& dst23, + uint32_t& dst24, uint32_t& dst25, uint32_t& dst26, uint32_t& dst27, + uint32_t& dst28, uint32_t& dst29, uint32_t& dst30, uint32_t& dst31) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x64b.x32.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15," + "%16, %17, %18, %19," + "%20, %21, %22, %23," + "%24, %25, %26, %27," + "%28, %29, %30, %31}," + "[%32];\n" + : "=r"(dst00), "=r"(dst01), "=r"(dst02), "=r"(dst03), + "=r"(dst04), "=r"(dst05), "=r"(dst06), "=r"(dst07), + "=r"(dst08), "=r"(dst09), "=r"(dst10), "=r"(dst11), + "=r"(dst12), "=r"(dst13), "=r"(dst14), "=r"(dst15), + "=r"(dst16), "=r"(dst17), "=r"(dst18), "=r"(dst19), + "=r"(dst20), "=r"(dst21), "=r"(dst22), "=r"(dst23), + "=r"(dst24), "=r"(dst25), "=r"(dst26), "=r"(dst27), + "=r"(dst28), "=r"(dst29), "=r"(dst30), "=r"(dst31) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 64-bit pattern, repeated 32 times, packed 16b read +struct SM100_TMEM_LOAD_16dp64b32x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst00, uint32_t& dst01, uint32_t& dst02, uint32_t& dst03, + uint32_t& dst04, uint32_t& dst05, uint32_t& dst06, uint32_t& dst07, + uint32_t& dst08, uint32_t& dst09, uint32_t& dst10, uint32_t& dst11, + uint32_t& dst12, uint32_t& dst13, uint32_t& dst14, uint32_t& dst15, + uint32_t& dst16, uint32_t& dst17, uint32_t& dst18, uint32_t& dst19, + uint32_t& dst20, uint32_t& dst21, uint32_t& dst22, uint32_t& dst23, + uint32_t& dst24, uint32_t& dst25, uint32_t& dst26, uint32_t& dst27, + uint32_t& dst28, uint32_t& dst29, uint32_t& dst30, uint32_t& dst31) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x64b.x32.pack::16b.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15," + "%16, %17, %18, %19," + "%20, %21, %22, %23," + "%24, %25, %26, %27," + "%28, %29, %30, %31}," + "[%32];\n" + : "=r"(dst00), "=r"(dst01), "=r"(dst02), "=r"(dst03), + "=r"(dst04), "=r"(dst05), "=r"(dst06), "=r"(dst07), + "=r"(dst08), "=r"(dst09), "=r"(dst10), "=r"(dst11), + "=r"(dst12), "=r"(dst13), "=r"(dst14), "=r"(dst15), + "=r"(dst16), "=r"(dst17), "=r"(dst18), "=r"(dst19), + "=r"(dst20), "=r"(dst21), "=r"(dst22), "=r"(dst23), + "=r"(dst24), "=r"(dst25), "=r"(dst26), "=r"(dst27), + "=r"(dst28), "=r"(dst29), "=r"(dst30), "=r"(dst31) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + 
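+// Note on the `_16b` variants (added for clarity): the only difference from their plain
+// counterparts is the `.pack::16b` qualifier on the tcgen05.ld, which, per the PTX ISA
+// description, packs pairs of 16-bit elements into each 32-bit destination register; the
+// per-thread register footprint (DRegisters) is unchanged. See the PTX ISA entry for
+// tcgen05.ld for the exact packing behaviour.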
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 64-bit pattern, repeated 64 times +struct SM100_TMEM_LOAD_16dp64b64x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst00, uint32_t& dst01, uint32_t& dst02, uint32_t& dst03, + uint32_t& dst04, uint32_t& dst05, uint32_t& dst06, uint32_t& dst07, + uint32_t& dst08, uint32_t& dst09, uint32_t& dst10, uint32_t& dst11, + uint32_t& dst12, uint32_t& dst13, uint32_t& dst14, uint32_t& dst15, + uint32_t& dst16, uint32_t& dst17, uint32_t& dst18, uint32_t& dst19, + uint32_t& dst20, uint32_t& dst21, uint32_t& dst22, uint32_t& dst23, + uint32_t& dst24, uint32_t& dst25, uint32_t& dst26, uint32_t& dst27, + uint32_t& dst28, uint32_t& dst29, uint32_t& dst30, uint32_t& dst31, + uint32_t& dst32, uint32_t& dst33, uint32_t& dst34, uint32_t& dst35, + uint32_t& dst36, uint32_t& dst37, uint32_t& dst38, uint32_t& dst39, + uint32_t& dst40, uint32_t& dst41, uint32_t& dst42, uint32_t& dst43, + uint32_t& dst44, uint32_t& dst45, uint32_t& dst46, uint32_t& dst47, + uint32_t& dst48, uint32_t& dst49, uint32_t& dst50, uint32_t& dst51, + uint32_t& dst52, uint32_t& dst53, uint32_t& dst54, uint32_t& dst55, + uint32_t& dst56, uint32_t& dst57, uint32_t& dst58, uint32_t& dst59, + uint32_t& dst60, uint32_t& dst61, uint32_t& dst62, uint32_t& dst63) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x64b.x64.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15," + "%16, %17, %18, %19," + "%20, %21, %22, %23," + "%24, %25, %26, %27," + "%28, %29, %30, %31," + "%32, %33, %34, %35," + "%36, %37, %38, %39," + "%40, %41, %42, %43," + "%44, %45, %46, %47," + "%48, %49, %50, %51," + "%52, %53, %54, %55," + "%56, %57, %58, %59," + "%60, %61, %62, %63}," + "[%64];\n" + : "=r"(dst00), "=r"(dst01), "=r"(dst02), "=r"(dst03), + "=r"(dst04), "=r"(dst05), "=r"(dst06), "=r"(dst07), + "=r"(dst08), "=r"(dst09), "=r"(dst10), "=r"(dst11), + "=r"(dst12), "=r"(dst13), "=r"(dst14), "=r"(dst15), + "=r"(dst16), "=r"(dst17), "=r"(dst18), "=r"(dst19), + "=r"(dst20), "=r"(dst21), "=r"(dst22), "=r"(dst23), + "=r"(dst24), "=r"(dst25), "=r"(dst26), "=r"(dst27), + "=r"(dst28), "=r"(dst29), "=r"(dst30), "=r"(dst31), + "=r"(dst32), "=r"(dst33), "=r"(dst34), "=r"(dst35), + "=r"(dst36), "=r"(dst37), "=r"(dst38), "=r"(dst39), + "=r"(dst40), "=r"(dst41), "=r"(dst42), "=r"(dst43), + "=r"(dst44), "=r"(dst45), "=r"(dst46), "=r"(dst47), + "=r"(dst48), "=r"(dst49), "=r"(dst50), "=r"(dst51), + "=r"(dst52), "=r"(dst53), "=r"(dst54), "=r"(dst55), + "=r"(dst56), "=r"(dst57), "=r"(dst58), "=r"(dst59), + "=r"(dst60), "=r"(dst61), "=r"(dst62), "=r"(dst63) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 64-bit pattern, repeated 64 times, packed 16b read +struct SM100_TMEM_LOAD_16dp64b64x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst00, uint32_t& dst01, uint32_t& dst02, uint32_t& dst03, + uint32_t& dst04, uint32_t& dst05, uint32_t& dst06, uint32_t& dst07, + uint32_t& dst08, uint32_t& dst09, uint32_t& dst10, uint32_t& dst11, + uint32_t& dst12, uint32_t& 
dst13, uint32_t& dst14, uint32_t& dst15, + uint32_t& dst16, uint32_t& dst17, uint32_t& dst18, uint32_t& dst19, + uint32_t& dst20, uint32_t& dst21, uint32_t& dst22, uint32_t& dst23, + uint32_t& dst24, uint32_t& dst25, uint32_t& dst26, uint32_t& dst27, + uint32_t& dst28, uint32_t& dst29, uint32_t& dst30, uint32_t& dst31, + uint32_t& dst32, uint32_t& dst33, uint32_t& dst34, uint32_t& dst35, + uint32_t& dst36, uint32_t& dst37, uint32_t& dst38, uint32_t& dst39, + uint32_t& dst40, uint32_t& dst41, uint32_t& dst42, uint32_t& dst43, + uint32_t& dst44, uint32_t& dst45, uint32_t& dst46, uint32_t& dst47, + uint32_t& dst48, uint32_t& dst49, uint32_t& dst50, uint32_t& dst51, + uint32_t& dst52, uint32_t& dst53, uint32_t& dst54, uint32_t& dst55, + uint32_t& dst56, uint32_t& dst57, uint32_t& dst58, uint32_t& dst59, + uint32_t& dst60, uint32_t& dst61, uint32_t& dst62, uint32_t& dst63) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x64b.x64.pack::16b.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15," + "%16, %17, %18, %19," + "%20, %21, %22, %23," + "%24, %25, %26, %27," + "%28, %29, %30, %31," + "%32, %33, %34, %35," + "%36, %37, %38, %39," + "%40, %41, %42, %43," + "%44, %45, %46, %47," + "%48, %49, %50, %51," + "%52, %53, %54, %55," + "%56, %57, %58, %59," + "%60, %61, %62, %63}," + "[%64];\n" + : "=r"(dst00), "=r"(dst01), "=r"(dst02), "=r"(dst03), + "=r"(dst04), "=r"(dst05), "=r"(dst06), "=r"(dst07), + "=r"(dst08), "=r"(dst09), "=r"(dst10), "=r"(dst11), + "=r"(dst12), "=r"(dst13), "=r"(dst14), "=r"(dst15), + "=r"(dst16), "=r"(dst17), "=r"(dst18), "=r"(dst19), + "=r"(dst20), "=r"(dst21), "=r"(dst22), "=r"(dst23), + "=r"(dst24), "=r"(dst25), "=r"(dst26), "=r"(dst27), + "=r"(dst28), "=r"(dst29), "=r"(dst30), "=r"(dst31), + "=r"(dst32), "=r"(dst33), "=r"(dst34), "=r"(dst35), + "=r"(dst36), "=r"(dst37), "=r"(dst38), "=r"(dst39), + "=r"(dst40), "=r"(dst41), "=r"(dst42), "=r"(dst43), + "=r"(dst44), "=r"(dst45), "=r"(dst46), "=r"(dst47), + "=r"(dst48), "=r"(dst49), "=r"(dst50), "=r"(dst51), + "=r"(dst52), "=r"(dst53), "=r"(dst54), "=r"(dst55), + "=r"(dst56), "=r"(dst57), "=r"(dst58), "=r"(dst59), + "=r"(dst60), "=r"(dst61), "=r"(dst62), "=r"(dst63) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 64-bit pattern, repeated 128 times +struct SM100_TMEM_LOAD_16dp64b128x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst000, uint32_t& dst001, uint32_t& dst002, uint32_t& dst003, + uint32_t& dst004, uint32_t& dst005, uint32_t& dst006, uint32_t& dst007, + uint32_t& dst008, uint32_t& dst009, uint32_t& dst010, uint32_t& dst011, + uint32_t& dst012, uint32_t& dst013, uint32_t& dst014, uint32_t& dst015, + uint32_t& dst016, uint32_t& dst017, uint32_t& dst018, uint32_t& dst019, + uint32_t& dst020, uint32_t& dst021, uint32_t& dst022, uint32_t& dst023, + uint32_t& dst024, uint32_t& dst025, uint32_t& dst026, uint32_t& dst027, + uint32_t& dst028, uint32_t& dst029, uint32_t& dst030, uint32_t& dst031, + uint32_t& dst032, uint32_t& dst033, uint32_t& dst034, uint32_t& dst035, + uint32_t& dst036, uint32_t& dst037, uint32_t& dst038, uint32_t& dst039, + uint32_t& dst040, uint32_t& dst041, uint32_t& dst042, uint32_t& dst043, + 
uint32_t& dst044, uint32_t& dst045, uint32_t& dst046, uint32_t& dst047, + uint32_t& dst048, uint32_t& dst049, uint32_t& dst050, uint32_t& dst051, + uint32_t& dst052, uint32_t& dst053, uint32_t& dst054, uint32_t& dst055, + uint32_t& dst056, uint32_t& dst057, uint32_t& dst058, uint32_t& dst059, + uint32_t& dst060, uint32_t& dst061, uint32_t& dst062, uint32_t& dst063, + uint32_t& dst064, uint32_t& dst065, uint32_t& dst066, uint32_t& dst067, + uint32_t& dst068, uint32_t& dst069, uint32_t& dst070, uint32_t& dst071, + uint32_t& dst072, uint32_t& dst073, uint32_t& dst074, uint32_t& dst075, + uint32_t& dst076, uint32_t& dst077, uint32_t& dst078, uint32_t& dst079, + uint32_t& dst080, uint32_t& dst081, uint32_t& dst082, uint32_t& dst083, + uint32_t& dst084, uint32_t& dst085, uint32_t& dst086, uint32_t& dst087, + uint32_t& dst088, uint32_t& dst089, uint32_t& dst090, uint32_t& dst091, + uint32_t& dst092, uint32_t& dst093, uint32_t& dst094, uint32_t& dst095, + uint32_t& dst096, uint32_t& dst097, uint32_t& dst098, uint32_t& dst099, + uint32_t& dst100, uint32_t& dst101, uint32_t& dst102, uint32_t& dst103, + uint32_t& dst104, uint32_t& dst105, uint32_t& dst106, uint32_t& dst107, + uint32_t& dst108, uint32_t& dst109, uint32_t& dst110, uint32_t& dst111, + uint32_t& dst112, uint32_t& dst113, uint32_t& dst114, uint32_t& dst115, + uint32_t& dst116, uint32_t& dst117, uint32_t& dst118, uint32_t& dst119, + uint32_t& dst120, uint32_t& dst121, uint32_t& dst122, uint32_t& dst123, + uint32_t& dst124, uint32_t& dst125, uint32_t& dst126, uint32_t& dst127) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x64b.x128.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15," + "%16, %17, %18, %19," + "%20, %21, %22, %23," + "%24, %25, %26, %27," + "%28, %29, %30, %31," + "%32, %33, %34, %35," + "%36, %37, %38, %39," + "%40, %41, %42, %43," + "%44, %45, %46, %47," + "%48, %49, %50, %51," + "%52, %53, %54, %55," + "%56, %57, %58, %59," + "%60, %61, %62, %63," + "%64, %65, %66, %67," + "%68, %69, %70, %71," + "%72, %73, %74, %75," + "%76, %77, %78, %79," + "%80, %81, %82, %83," + "%84, %85, %86, %87," + "%88, %89, %90, %91," + "%92, %93, %94, %95," + "%96, %97, %98, %99," + "%100, %101, %102, %103," + "%104, %105, %106, %107," + "%108, %109, %110, %111," + "%112, %113, %114, %115," + "%116, %117, %118, %119," + "%120, %121, %122, %123," + "%124, %125, %126, %127}," + "[%128];\n" + : "=r"(dst000), "=r"(dst001), "=r"(dst002), "=r"(dst003), + "=r"(dst004), "=r"(dst005), "=r"(dst006), "=r"(dst007), + "=r"(dst008), "=r"(dst009), "=r"(dst010), "=r"(dst011), + "=r"(dst012), "=r"(dst013), "=r"(dst014), "=r"(dst015), + "=r"(dst016), "=r"(dst017), "=r"(dst018), "=r"(dst019), + "=r"(dst020), "=r"(dst021), "=r"(dst022), "=r"(dst023), + "=r"(dst024), "=r"(dst025), "=r"(dst026), "=r"(dst027), + "=r"(dst028), "=r"(dst029), "=r"(dst030), "=r"(dst031), + "=r"(dst032), "=r"(dst033), "=r"(dst034), "=r"(dst035), + "=r"(dst036), "=r"(dst037), "=r"(dst038), "=r"(dst039), + "=r"(dst040), "=r"(dst041), "=r"(dst042), "=r"(dst043), + "=r"(dst044), "=r"(dst045), "=r"(dst046), "=r"(dst047), + "=r"(dst048), "=r"(dst049), "=r"(dst050), "=r"(dst051), + "=r"(dst052), "=r"(dst053), "=r"(dst054), "=r"(dst055), + "=r"(dst056), "=r"(dst057), "=r"(dst058), "=r"(dst059), + "=r"(dst060), "=r"(dst061), "=r"(dst062), "=r"(dst063), + "=r"(dst064), "=r"(dst065), "=r"(dst066), "=r"(dst067), + "=r"(dst068), "=r"(dst069), "=r"(dst070), "=r"(dst071), + "=r"(dst072), "=r"(dst073), "=r"(dst074), 
"=r"(dst075), + "=r"(dst076), "=r"(dst077), "=r"(dst078), "=r"(dst079), + "=r"(dst080), "=r"(dst081), "=r"(dst082), "=r"(dst083), + "=r"(dst084), "=r"(dst085), "=r"(dst086), "=r"(dst087), + "=r"(dst088), "=r"(dst089), "=r"(dst090), "=r"(dst091), + "=r"(dst092), "=r"(dst093), "=r"(dst094), "=r"(dst095), + "=r"(dst096), "=r"(dst097), "=r"(dst098), "=r"(dst099), + "=r"(dst100), "=r"(dst101), "=r"(dst102), "=r"(dst103), + "=r"(dst104), "=r"(dst105), "=r"(dst106), "=r"(dst107), + "=r"(dst108), "=r"(dst109), "=r"(dst110), "=r"(dst111), + "=r"(dst112), "=r"(dst113), "=r"(dst114), "=r"(dst115), + "=r"(dst116), "=r"(dst117), "=r"(dst118), "=r"(dst119), + "=r"(dst120), "=r"(dst121), "=r"(dst122), "=r"(dst123), + "=r"(dst124), "=r"(dst125), "=r"(dst126), "=r"(dst127) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 64-bit pattern, repeated 128 times, packed 16b read +struct SM100_TMEM_LOAD_16dp64b128x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst000, uint32_t& dst001, uint32_t& dst002, uint32_t& dst003, + uint32_t& dst004, uint32_t& dst005, uint32_t& dst006, uint32_t& dst007, + uint32_t& dst008, uint32_t& dst009, uint32_t& dst010, uint32_t& dst011, + uint32_t& dst012, uint32_t& dst013, uint32_t& dst014, uint32_t& dst015, + uint32_t& dst016, uint32_t& dst017, uint32_t& dst018, uint32_t& dst019, + uint32_t& dst020, uint32_t& dst021, uint32_t& dst022, uint32_t& dst023, + uint32_t& dst024, uint32_t& dst025, uint32_t& dst026, uint32_t& dst027, + uint32_t& dst028, uint32_t& dst029, uint32_t& dst030, uint32_t& dst031, + uint32_t& dst032, uint32_t& dst033, uint32_t& dst034, uint32_t& dst035, + uint32_t& dst036, uint32_t& dst037, uint32_t& dst038, uint32_t& dst039, + uint32_t& dst040, uint32_t& dst041, uint32_t& dst042, uint32_t& dst043, + uint32_t& dst044, uint32_t& dst045, uint32_t& dst046, uint32_t& dst047, + uint32_t& dst048, uint32_t& dst049, uint32_t& dst050, uint32_t& dst051, + uint32_t& dst052, uint32_t& dst053, uint32_t& dst054, uint32_t& dst055, + uint32_t& dst056, uint32_t& dst057, uint32_t& dst058, uint32_t& dst059, + uint32_t& dst060, uint32_t& dst061, uint32_t& dst062, uint32_t& dst063, + uint32_t& dst064, uint32_t& dst065, uint32_t& dst066, uint32_t& dst067, + uint32_t& dst068, uint32_t& dst069, uint32_t& dst070, uint32_t& dst071, + uint32_t& dst072, uint32_t& dst073, uint32_t& dst074, uint32_t& dst075, + uint32_t& dst076, uint32_t& dst077, uint32_t& dst078, uint32_t& dst079, + uint32_t& dst080, uint32_t& dst081, uint32_t& dst082, uint32_t& dst083, + uint32_t& dst084, uint32_t& dst085, uint32_t& dst086, uint32_t& dst087, + uint32_t& dst088, uint32_t& dst089, uint32_t& dst090, uint32_t& dst091, + uint32_t& dst092, uint32_t& dst093, uint32_t& dst094, uint32_t& dst095, + uint32_t& dst096, uint32_t& dst097, uint32_t& dst098, uint32_t& dst099, + uint32_t& dst100, uint32_t& dst101, uint32_t& dst102, uint32_t& dst103, + uint32_t& dst104, uint32_t& dst105, uint32_t& dst106, uint32_t& dst107, + uint32_t& dst108, uint32_t& dst109, uint32_t& dst110, uint32_t& dst111, + uint32_t& dst112, uint32_t& dst113, uint32_t& dst114, uint32_t& dst115, + uint32_t& dst116, uint32_t& dst117, uint32_t& dst118, uint32_t& dst119, + uint32_t& dst120, uint32_t& dst121, uint32_t& dst122, uint32_t& 
dst123, + uint32_t& dst124, uint32_t& dst125, uint32_t& dst126, uint32_t& dst127) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x64b.x128.pack::16b.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15," + "%16, %17, %18, %19," + "%20, %21, %22, %23," + "%24, %25, %26, %27," + "%28, %29, %30, %31," + "%32, %33, %34, %35," + "%36, %37, %38, %39," + "%40, %41, %42, %43," + "%44, %45, %46, %47," + "%48, %49, %50, %51," + "%52, %53, %54, %55," + "%56, %57, %58, %59," + "%60, %61, %62, %63," + "%64, %65, %66, %67," + "%68, %69, %70, %71," + "%72, %73, %74, %75," + "%76, %77, %78, %79," + "%80, %81, %82, %83," + "%84, %85, %86, %87," + "%88, %89, %90, %91," + "%92, %93, %94, %95," + "%96, %97, %98, %99," + "%100, %101, %102, %103," + "%104, %105, %106, %107," + "%108, %109, %110, %111," + "%112, %113, %114, %115," + "%116, %117, %118, %119," + "%120, %121, %122, %123," + "%124, %125, %126, %127}," + "[%128];\n" + : "=r"(dst000), "=r"(dst001), "=r"(dst002), "=r"(dst003), + "=r"(dst004), "=r"(dst005), "=r"(dst006), "=r"(dst007), + "=r"(dst008), "=r"(dst009), "=r"(dst010), "=r"(dst011), + "=r"(dst012), "=r"(dst013), "=r"(dst014), "=r"(dst015), + "=r"(dst016), "=r"(dst017), "=r"(dst018), "=r"(dst019), + "=r"(dst020), "=r"(dst021), "=r"(dst022), "=r"(dst023), + "=r"(dst024), "=r"(dst025), "=r"(dst026), "=r"(dst027), + "=r"(dst028), "=r"(dst029), "=r"(dst030), "=r"(dst031), + "=r"(dst032), "=r"(dst033), "=r"(dst034), "=r"(dst035), + "=r"(dst036), "=r"(dst037), "=r"(dst038), "=r"(dst039), + "=r"(dst040), "=r"(dst041), "=r"(dst042), "=r"(dst043), + "=r"(dst044), "=r"(dst045), "=r"(dst046), "=r"(dst047), + "=r"(dst048), "=r"(dst049), "=r"(dst050), "=r"(dst051), + "=r"(dst052), "=r"(dst053), "=r"(dst054), "=r"(dst055), + "=r"(dst056), "=r"(dst057), "=r"(dst058), "=r"(dst059), + "=r"(dst060), "=r"(dst061), "=r"(dst062), "=r"(dst063), + "=r"(dst064), "=r"(dst065), "=r"(dst066), "=r"(dst067), + "=r"(dst068), "=r"(dst069), "=r"(dst070), "=r"(dst071), + "=r"(dst072), "=r"(dst073), "=r"(dst074), "=r"(dst075), + "=r"(dst076), "=r"(dst077), "=r"(dst078), "=r"(dst079), + "=r"(dst080), "=r"(dst081), "=r"(dst082), "=r"(dst083), + "=r"(dst084), "=r"(dst085), "=r"(dst086), "=r"(dst087), + "=r"(dst088), "=r"(dst089), "=r"(dst090), "=r"(dst091), + "=r"(dst092), "=r"(dst093), "=r"(dst094), "=r"(dst095), + "=r"(dst096), "=r"(dst097), "=r"(dst098), "=r"(dst099), + "=r"(dst100), "=r"(dst101), "=r"(dst102), "=r"(dst103), + "=r"(dst104), "=r"(dst105), "=r"(dst106), "=r"(dst107), + "=r"(dst108), "=r"(dst109), "=r"(dst110), "=r"(dst111), + "=r"(dst112), "=r"(dst113), "=r"(dst114), "=r"(dst115), + "=r"(dst116), "=r"(dst117), "=r"(dst118), "=r"(dst119), + "=r"(dst120), "=r"(dst121), "=r"(dst122), "=r"(dst123), + "=r"(dst124), "=r"(dst125), "=r"(dst126), "=r"(dst127) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 32-bit pattern, repeated 1 times +struct SM100_TMEM_LOAD_16dp32b1x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst0) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x32bx2.x1.b32" + "{%0}," + "[%1], 1;\n" + : "=r"(dst0) + : "r"(src_addr)); +#else + 
CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 32-bit pattern, repeated 1 times, packed 16b read +struct SM100_TMEM_LOAD_16dp32b1x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst0) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x32bx2.x1.pack::16b.b32" + "{%0}," + "[%1], 2;\n" + : "=r"(dst0) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 32-bit pattern, repeated 2 times +struct SM100_TMEM_LOAD_16dp32b2x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst0, uint32_t& dst1) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x32bx2.x2.b32" + "{%0, %1}," + "[%2], 2;\n" + : "=r"(dst0), "=r"(dst1) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 32-bit pattern, repeated 2 times, packed 16b read +struct SM100_TMEM_LOAD_16dp32b2x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst0, uint32_t& dst1) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x32bx2.x2.pack::16b.b32" + "{%0, %1}," + "[%2], 4;\n" + : "=r"(dst0), "=r"(dst1) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 32-bit pattern, repeated 4 times +struct SM100_TMEM_LOAD_16dp32b4x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst0, uint32_t& dst1, uint32_t& dst2, uint32_t& dst3) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x32bx2.x4.b32" + "{%0, %1, %2, %3}," + "[%4], 4;\n" + : "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 32-bit pattern, repeated 4 times, packed 16b read +struct SM100_TMEM_LOAD_16dp32b4x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst0, uint32_t& dst1, uint32_t& dst2, uint32_t& dst3) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x32bx2.x4.pack::16b.b32" + "{%0, %1, %2, %3}," + "[%4], 8;\n" + : "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3) + : "r"(src_addr)); +#else + 
CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 32-bit pattern, repeated 8 times +struct SM100_TMEM_LOAD_16dp32b8x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst0, uint32_t& dst1, uint32_t& dst2, uint32_t& dst3, + uint32_t& dst4, uint32_t& dst5, uint32_t& dst6, uint32_t& dst7) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x32bx2.x8.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7}," + "[%8], 8;\n" + : "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3), + "=r"(dst4), "=r"(dst5), "=r"(dst6), "=r"(dst7) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 32-bit pattern, repeated 8 times, packed 16b read +struct SM100_TMEM_LOAD_16dp32b8x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst0, uint32_t& dst1, uint32_t& dst2, uint32_t& dst3, + uint32_t& dst4, uint32_t& dst5, uint32_t& dst6, uint32_t& dst7) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x32bx2.x8.pack::16b.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7}," + "[%8], 16;\n" + : "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3), + "=r"(dst4), "=r"(dst5), "=r"(dst6), "=r"(dst7) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 32-bit pattern, repeated 16 times +struct SM100_TMEM_LOAD_16dp32b16x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst00, uint32_t& dst01, uint32_t& dst02, uint32_t& dst03, + uint32_t& dst04, uint32_t& dst05, uint32_t& dst06, uint32_t& dst07, + uint32_t& dst08, uint32_t& dst09, uint32_t& dst10, uint32_t& dst11, + uint32_t& dst12, uint32_t& dst13, uint32_t& dst14, uint32_t& dst15) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x32bx2.x16.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15}," + "[%16], 16;\n" + : "=r"(dst00), "=r"(dst01), "=r"(dst02), "=r"(dst03), + "=r"(dst04), "=r"(dst05), "=r"(dst06), "=r"(dst07), + "=r"(dst08), "=r"(dst09), "=r"(dst10), "=r"(dst11), + "=r"(dst12), "=r"(dst13), "=r"(dst14), "=r"(dst15) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 32-bit pattern, repeated 16 times, packed 16b read +struct SM100_TMEM_LOAD_16dp32b16x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst00, uint32_t& dst01, uint32_t& dst02, uint32_t& dst03, + uint32_t& dst04, uint32_t& dst05, uint32_t& dst06, 
uint32_t& dst07, + uint32_t& dst08, uint32_t& dst09, uint32_t& dst10, uint32_t& dst11, + uint32_t& dst12, uint32_t& dst13, uint32_t& dst14, uint32_t& dst15) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x32bx2.x16.pack::16b.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15}," + "[%16], 32;\n" + : "=r"(dst00), "=r"(dst01), "=r"(dst02), "=r"(dst03), + "=r"(dst04), "=r"(dst05), "=r"(dst06), "=r"(dst07), + "=r"(dst08), "=r"(dst09), "=r"(dst10), "=r"(dst11), + "=r"(dst12), "=r"(dst13), "=r"(dst14), "=r"(dst15) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 32-bit pattern, repeated 32 times +struct SM100_TMEM_LOAD_16dp32b32x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst00, uint32_t& dst01, uint32_t& dst02, uint32_t& dst03, + uint32_t& dst04, uint32_t& dst05, uint32_t& dst06, uint32_t& dst07, + uint32_t& dst08, uint32_t& dst09, uint32_t& dst10, uint32_t& dst11, + uint32_t& dst12, uint32_t& dst13, uint32_t& dst14, uint32_t& dst15, + uint32_t& dst16, uint32_t& dst17, uint32_t& dst18, uint32_t& dst19, + uint32_t& dst20, uint32_t& dst21, uint32_t& dst22, uint32_t& dst23, + uint32_t& dst24, uint32_t& dst25, uint32_t& dst26, uint32_t& dst27, + uint32_t& dst28, uint32_t& dst29, uint32_t& dst30, uint32_t& dst31) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x32bx2.x32.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15," + "%16, %17, %18, %19," + "%20, %21, %22, %23," + "%24, %25, %26, %27," + "%28, %29, %30, %31}," + "[%32], 32;\n" + : "=r"(dst00), "=r"(dst01), "=r"(dst02), "=r"(dst03), + "=r"(dst04), "=r"(dst05), "=r"(dst06), "=r"(dst07), + "=r"(dst08), "=r"(dst09), "=r"(dst10), "=r"(dst11), + "=r"(dst12), "=r"(dst13), "=r"(dst14), "=r"(dst15), + "=r"(dst16), "=r"(dst17), "=r"(dst18), "=r"(dst19), + "=r"(dst20), "=r"(dst21), "=r"(dst22), "=r"(dst23), + "=r"(dst24), "=r"(dst25), "=r"(dst26), "=r"(dst27), + "=r"(dst28), "=r"(dst29), "=r"(dst30), "=r"(dst31) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 32-bit pattern, repeated 32 times, packed 16b read +struct SM100_TMEM_LOAD_16dp32b32x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst00, uint32_t& dst01, uint32_t& dst02, uint32_t& dst03, + uint32_t& dst04, uint32_t& dst05, uint32_t& dst06, uint32_t& dst07, + uint32_t& dst08, uint32_t& dst09, uint32_t& dst10, uint32_t& dst11, + uint32_t& dst12, uint32_t& dst13, uint32_t& dst14, uint32_t& dst15, + uint32_t& dst16, uint32_t& dst17, uint32_t& dst18, uint32_t& dst19, + uint32_t& dst20, uint32_t& dst21, uint32_t& dst22, uint32_t& dst23, + uint32_t& dst24, uint32_t& dst25, uint32_t& dst26, uint32_t& dst27, + uint32_t& dst28, uint32_t& dst29, uint32_t& dst30, uint32_t& dst31) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x32bx2.x32.pack::16b.b32" + 
"{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15," + "%16, %17, %18, %19," + "%20, %21, %22, %23," + "%24, %25, %26, %27," + "%28, %29, %30, %31}," + "[%32], 64;\n" + : "=r"(dst00), "=r"(dst01), "=r"(dst02), "=r"(dst03), + "=r"(dst04), "=r"(dst05), "=r"(dst06), "=r"(dst07), + "=r"(dst08), "=r"(dst09), "=r"(dst10), "=r"(dst11), + "=r"(dst12), "=r"(dst13), "=r"(dst14), "=r"(dst15), + "=r"(dst16), "=r"(dst17), "=r"(dst18), "=r"(dst19), + "=r"(dst20), "=r"(dst21), "=r"(dst22), "=r"(dst23), + "=r"(dst24), "=r"(dst25), "=r"(dst26), "=r"(dst27), + "=r"(dst28), "=r"(dst29), "=r"(dst30), "=r"(dst31) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 32-bit pattern, repeated 64 times +struct SM100_TMEM_LOAD_16dp32b64x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst00, uint32_t& dst01, uint32_t& dst02, uint32_t& dst03, + uint32_t& dst04, uint32_t& dst05, uint32_t& dst06, uint32_t& dst07, + uint32_t& dst08, uint32_t& dst09, uint32_t& dst10, uint32_t& dst11, + uint32_t& dst12, uint32_t& dst13, uint32_t& dst14, uint32_t& dst15, + uint32_t& dst16, uint32_t& dst17, uint32_t& dst18, uint32_t& dst19, + uint32_t& dst20, uint32_t& dst21, uint32_t& dst22, uint32_t& dst23, + uint32_t& dst24, uint32_t& dst25, uint32_t& dst26, uint32_t& dst27, + uint32_t& dst28, uint32_t& dst29, uint32_t& dst30, uint32_t& dst31, + uint32_t& dst32, uint32_t& dst33, uint32_t& dst34, uint32_t& dst35, + uint32_t& dst36, uint32_t& dst37, uint32_t& dst38, uint32_t& dst39, + uint32_t& dst40, uint32_t& dst41, uint32_t& dst42, uint32_t& dst43, + uint32_t& dst44, uint32_t& dst45, uint32_t& dst46, uint32_t& dst47, + uint32_t& dst48, uint32_t& dst49, uint32_t& dst50, uint32_t& dst51, + uint32_t& dst52, uint32_t& dst53, uint32_t& dst54, uint32_t& dst55, + uint32_t& dst56, uint32_t& dst57, uint32_t& dst58, uint32_t& dst59, + uint32_t& dst60, uint32_t& dst61, uint32_t& dst62, uint32_t& dst63) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x32bx2.x64.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15," + "%16, %17, %18, %19," + "%20, %21, %22, %23," + "%24, %25, %26, %27," + "%28, %29, %30, %31," + "%32, %33, %34, %35," + "%36, %37, %38, %39," + "%40, %41, %42, %43," + "%44, %45, %46, %47," + "%48, %49, %50, %51," + "%52, %53, %54, %55," + "%56, %57, %58, %59," + "%60, %61, %62, %63}," + "[%64], 64;\n" + : "=r"(dst00), "=r"(dst01), "=r"(dst02), "=r"(dst03), + "=r"(dst04), "=r"(dst05), "=r"(dst06), "=r"(dst07), + "=r"(dst08), "=r"(dst09), "=r"(dst10), "=r"(dst11), + "=r"(dst12), "=r"(dst13), "=r"(dst14), "=r"(dst15), + "=r"(dst16), "=r"(dst17), "=r"(dst18), "=r"(dst19), + "=r"(dst20), "=r"(dst21), "=r"(dst22), "=r"(dst23), + "=r"(dst24), "=r"(dst25), "=r"(dst26), "=r"(dst27), + "=r"(dst28), "=r"(dst29), "=r"(dst30), "=r"(dst31), + "=r"(dst32), "=r"(dst33), "=r"(dst34), "=r"(dst35), + "=r"(dst36), "=r"(dst37), "=r"(dst38), "=r"(dst39), + "=r"(dst40), "=r"(dst41), "=r"(dst42), "=r"(dst43), + "=r"(dst44), "=r"(dst45), "=r"(dst46), "=r"(dst47), + "=r"(dst48), "=r"(dst49), "=r"(dst50), "=r"(dst51), + "=r"(dst52), "=r"(dst53), "=r"(dst54), "=r"(dst55), + "=r"(dst56), "=r"(dst57), "=r"(dst58), "=r"(dst59), + 
"=r"(dst60), "=r"(dst61), "=r"(dst62), "=r"(dst63) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 32-bit pattern, repeated 64 times, packed 16b read +struct SM100_TMEM_LOAD_16dp32b64x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst00, uint32_t& dst01, uint32_t& dst02, uint32_t& dst03, + uint32_t& dst04, uint32_t& dst05, uint32_t& dst06, uint32_t& dst07, + uint32_t& dst08, uint32_t& dst09, uint32_t& dst10, uint32_t& dst11, + uint32_t& dst12, uint32_t& dst13, uint32_t& dst14, uint32_t& dst15, + uint32_t& dst16, uint32_t& dst17, uint32_t& dst18, uint32_t& dst19, + uint32_t& dst20, uint32_t& dst21, uint32_t& dst22, uint32_t& dst23, + uint32_t& dst24, uint32_t& dst25, uint32_t& dst26, uint32_t& dst27, + uint32_t& dst28, uint32_t& dst29, uint32_t& dst30, uint32_t& dst31, + uint32_t& dst32, uint32_t& dst33, uint32_t& dst34, uint32_t& dst35, + uint32_t& dst36, uint32_t& dst37, uint32_t& dst38, uint32_t& dst39, + uint32_t& dst40, uint32_t& dst41, uint32_t& dst42, uint32_t& dst43, + uint32_t& dst44, uint32_t& dst45, uint32_t& dst46, uint32_t& dst47, + uint32_t& dst48, uint32_t& dst49, uint32_t& dst50, uint32_t& dst51, + uint32_t& dst52, uint32_t& dst53, uint32_t& dst54, uint32_t& dst55, + uint32_t& dst56, uint32_t& dst57, uint32_t& dst58, uint32_t& dst59, + uint32_t& dst60, uint32_t& dst61, uint32_t& dst62, uint32_t& dst63) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x32bx2.x64.pack::16b.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15," + "%16, %17, %18, %19," + "%20, %21, %22, %23," + "%24, %25, %26, %27," + "%28, %29, %30, %31," + "%32, %33, %34, %35," + "%36, %37, %38, %39," + "%40, %41, %42, %43," + "%44, %45, %46, %47," + "%48, %49, %50, %51," + "%52, %53, %54, %55," + "%56, %57, %58, %59," + "%60, %61, %62, %63}," + "[%64], 128;\n" + : "=r"(dst00), "=r"(dst01), "=r"(dst02), "=r"(dst03), + "=r"(dst04), "=r"(dst05), "=r"(dst06), "=r"(dst07), + "=r"(dst08), "=r"(dst09), "=r"(dst10), "=r"(dst11), + "=r"(dst12), "=r"(dst13), "=r"(dst14), "=r"(dst15), + "=r"(dst16), "=r"(dst17), "=r"(dst18), "=r"(dst19), + "=r"(dst20), "=r"(dst21), "=r"(dst22), "=r"(dst23), + "=r"(dst24), "=r"(dst25), "=r"(dst26), "=r"(dst27), + "=r"(dst28), "=r"(dst29), "=r"(dst30), "=r"(dst31), + "=r"(dst32), "=r"(dst33), "=r"(dst34), "=r"(dst35), + "=r"(dst36), "=r"(dst37), "=r"(dst38), "=r"(dst39), + "=r"(dst40), "=r"(dst41), "=r"(dst42), "=r"(dst43), + "=r"(dst44), "=r"(dst45), "=r"(dst46), "=r"(dst47), + "=r"(dst48), "=r"(dst49), "=r"(dst50), "=r"(dst51), + "=r"(dst52), "=r"(dst53), "=r"(dst54), "=r"(dst55), + "=r"(dst56), "=r"(dst57), "=r"(dst58), "=r"(dst59), + "=r"(dst60), "=r"(dst61), "=r"(dst62), "=r"(dst63) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 32-bit pattern, repeated 128 times +struct SM100_TMEM_LOAD_16dp32b128x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst000, uint32_t& 
dst001, uint32_t& dst002, uint32_t& dst003, + uint32_t& dst004, uint32_t& dst005, uint32_t& dst006, uint32_t& dst007, + uint32_t& dst008, uint32_t& dst009, uint32_t& dst010, uint32_t& dst011, + uint32_t& dst012, uint32_t& dst013, uint32_t& dst014, uint32_t& dst015, + uint32_t& dst016, uint32_t& dst017, uint32_t& dst018, uint32_t& dst019, + uint32_t& dst020, uint32_t& dst021, uint32_t& dst022, uint32_t& dst023, + uint32_t& dst024, uint32_t& dst025, uint32_t& dst026, uint32_t& dst027, + uint32_t& dst028, uint32_t& dst029, uint32_t& dst030, uint32_t& dst031, + uint32_t& dst032, uint32_t& dst033, uint32_t& dst034, uint32_t& dst035, + uint32_t& dst036, uint32_t& dst037, uint32_t& dst038, uint32_t& dst039, + uint32_t& dst040, uint32_t& dst041, uint32_t& dst042, uint32_t& dst043, + uint32_t& dst044, uint32_t& dst045, uint32_t& dst046, uint32_t& dst047, + uint32_t& dst048, uint32_t& dst049, uint32_t& dst050, uint32_t& dst051, + uint32_t& dst052, uint32_t& dst053, uint32_t& dst054, uint32_t& dst055, + uint32_t& dst056, uint32_t& dst057, uint32_t& dst058, uint32_t& dst059, + uint32_t& dst060, uint32_t& dst061, uint32_t& dst062, uint32_t& dst063, + uint32_t& dst064, uint32_t& dst065, uint32_t& dst066, uint32_t& dst067, + uint32_t& dst068, uint32_t& dst069, uint32_t& dst070, uint32_t& dst071, + uint32_t& dst072, uint32_t& dst073, uint32_t& dst074, uint32_t& dst075, + uint32_t& dst076, uint32_t& dst077, uint32_t& dst078, uint32_t& dst079, + uint32_t& dst080, uint32_t& dst081, uint32_t& dst082, uint32_t& dst083, + uint32_t& dst084, uint32_t& dst085, uint32_t& dst086, uint32_t& dst087, + uint32_t& dst088, uint32_t& dst089, uint32_t& dst090, uint32_t& dst091, + uint32_t& dst092, uint32_t& dst093, uint32_t& dst094, uint32_t& dst095, + uint32_t& dst096, uint32_t& dst097, uint32_t& dst098, uint32_t& dst099, + uint32_t& dst100, uint32_t& dst101, uint32_t& dst102, uint32_t& dst103, + uint32_t& dst104, uint32_t& dst105, uint32_t& dst106, uint32_t& dst107, + uint32_t& dst108, uint32_t& dst109, uint32_t& dst110, uint32_t& dst111, + uint32_t& dst112, uint32_t& dst113, uint32_t& dst114, uint32_t& dst115, + uint32_t& dst116, uint32_t& dst117, uint32_t& dst118, uint32_t& dst119, + uint32_t& dst120, uint32_t& dst121, uint32_t& dst122, uint32_t& dst123, + uint32_t& dst124, uint32_t& dst125, uint32_t& dst126, uint32_t& dst127) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x32bx2.x128.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15," + "%16, %17, %18, %19," + "%20, %21, %22, %23," + "%24, %25, %26, %27," + "%28, %29, %30, %31," + "%32, %33, %34, %35," + "%36, %37, %38, %39," + "%40, %41, %42, %43," + "%44, %45, %46, %47," + "%48, %49, %50, %51," + "%52, %53, %54, %55," + "%56, %57, %58, %59," + "%60, %61, %62, %63," + "%64, %65, %66, %67," + "%68, %69, %70, %71," + "%72, %73, %74, %75," + "%76, %77, %78, %79," + "%80, %81, %82, %83," + "%84, %85, %86, %87," + "%88, %89, %90, %91," + "%92, %93, %94, %95," + "%96, %97, %98, %99," + "%100, %101, %102, %103," + "%104, %105, %106, %107," + "%108, %109, %110, %111," + "%112, %113, %114, %115," + "%116, %117, %118, %119," + "%120, %121, %122, %123," + "%124, %125, %126, %127}," + "[%128], 128;\n" + : "=r"(dst000), "=r"(dst001), "=r"(dst002), "=r"(dst003), + "=r"(dst004), "=r"(dst005), "=r"(dst006), "=r"(dst007), + "=r"(dst008), "=r"(dst009), "=r"(dst010), "=r"(dst011), + "=r"(dst012), "=r"(dst013), "=r"(dst014), "=r"(dst015), + "=r"(dst016), "=r"(dst017), "=r"(dst018), "=r"(dst019), + 
"=r"(dst020), "=r"(dst021), "=r"(dst022), "=r"(dst023), + "=r"(dst024), "=r"(dst025), "=r"(dst026), "=r"(dst027), + "=r"(dst028), "=r"(dst029), "=r"(dst030), "=r"(dst031), + "=r"(dst032), "=r"(dst033), "=r"(dst034), "=r"(dst035), + "=r"(dst036), "=r"(dst037), "=r"(dst038), "=r"(dst039), + "=r"(dst040), "=r"(dst041), "=r"(dst042), "=r"(dst043), + "=r"(dst044), "=r"(dst045), "=r"(dst046), "=r"(dst047), + "=r"(dst048), "=r"(dst049), "=r"(dst050), "=r"(dst051), + "=r"(dst052), "=r"(dst053), "=r"(dst054), "=r"(dst055), + "=r"(dst056), "=r"(dst057), "=r"(dst058), "=r"(dst059), + "=r"(dst060), "=r"(dst061), "=r"(dst062), "=r"(dst063), + "=r"(dst064), "=r"(dst065), "=r"(dst066), "=r"(dst067), + "=r"(dst068), "=r"(dst069), "=r"(dst070), "=r"(dst071), + "=r"(dst072), "=r"(dst073), "=r"(dst074), "=r"(dst075), + "=r"(dst076), "=r"(dst077), "=r"(dst078), "=r"(dst079), + "=r"(dst080), "=r"(dst081), "=r"(dst082), "=r"(dst083), + "=r"(dst084), "=r"(dst085), "=r"(dst086), "=r"(dst087), + "=r"(dst088), "=r"(dst089), "=r"(dst090), "=r"(dst091), + "=r"(dst092), "=r"(dst093), "=r"(dst094), "=r"(dst095), + "=r"(dst096), "=r"(dst097), "=r"(dst098), "=r"(dst099), + "=r"(dst100), "=r"(dst101), "=r"(dst102), "=r"(dst103), + "=r"(dst104), "=r"(dst105), "=r"(dst106), "=r"(dst107), + "=r"(dst108), "=r"(dst109), "=r"(dst110), "=r"(dst111), + "=r"(dst112), "=r"(dst113), "=r"(dst114), "=r"(dst115), + "=r"(dst116), "=r"(dst117), "=r"(dst118), "=r"(dst119), + "=r"(dst120), "=r"(dst121), "=r"(dst122), "=r"(dst123), + "=r"(dst124), "=r"(dst125), "=r"(dst126), "=r"(dst127) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 32-bit pattern, repeated 128 times, packed 16b read +struct SM100_TMEM_LOAD_16dp32b128x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst000, uint32_t& dst001, uint32_t& dst002, uint32_t& dst003, + uint32_t& dst004, uint32_t& dst005, uint32_t& dst006, uint32_t& dst007, + uint32_t& dst008, uint32_t& dst009, uint32_t& dst010, uint32_t& dst011, + uint32_t& dst012, uint32_t& dst013, uint32_t& dst014, uint32_t& dst015, + uint32_t& dst016, uint32_t& dst017, uint32_t& dst018, uint32_t& dst019, + uint32_t& dst020, uint32_t& dst021, uint32_t& dst022, uint32_t& dst023, + uint32_t& dst024, uint32_t& dst025, uint32_t& dst026, uint32_t& dst027, + uint32_t& dst028, uint32_t& dst029, uint32_t& dst030, uint32_t& dst031, + uint32_t& dst032, uint32_t& dst033, uint32_t& dst034, uint32_t& dst035, + uint32_t& dst036, uint32_t& dst037, uint32_t& dst038, uint32_t& dst039, + uint32_t& dst040, uint32_t& dst041, uint32_t& dst042, uint32_t& dst043, + uint32_t& dst044, uint32_t& dst045, uint32_t& dst046, uint32_t& dst047, + uint32_t& dst048, uint32_t& dst049, uint32_t& dst050, uint32_t& dst051, + uint32_t& dst052, uint32_t& dst053, uint32_t& dst054, uint32_t& dst055, + uint32_t& dst056, uint32_t& dst057, uint32_t& dst058, uint32_t& dst059, + uint32_t& dst060, uint32_t& dst061, uint32_t& dst062, uint32_t& dst063, + uint32_t& dst064, uint32_t& dst065, uint32_t& dst066, uint32_t& dst067, + uint32_t& dst068, uint32_t& dst069, uint32_t& dst070, uint32_t& dst071, + uint32_t& dst072, uint32_t& dst073, uint32_t& dst074, uint32_t& dst075, + uint32_t& dst076, uint32_t& dst077, uint32_t& dst078, uint32_t& dst079, + uint32_t& 
dst080, uint32_t& dst081, uint32_t& dst082, uint32_t& dst083, + uint32_t& dst084, uint32_t& dst085, uint32_t& dst086, uint32_t& dst087, + uint32_t& dst088, uint32_t& dst089, uint32_t& dst090, uint32_t& dst091, + uint32_t& dst092, uint32_t& dst093, uint32_t& dst094, uint32_t& dst095, + uint32_t& dst096, uint32_t& dst097, uint32_t& dst098, uint32_t& dst099, + uint32_t& dst100, uint32_t& dst101, uint32_t& dst102, uint32_t& dst103, + uint32_t& dst104, uint32_t& dst105, uint32_t& dst106, uint32_t& dst107, + uint32_t& dst108, uint32_t& dst109, uint32_t& dst110, uint32_t& dst111, + uint32_t& dst112, uint32_t& dst113, uint32_t& dst114, uint32_t& dst115, + uint32_t& dst116, uint32_t& dst117, uint32_t& dst118, uint32_t& dst119, + uint32_t& dst120, uint32_t& dst121, uint32_t& dst122, uint32_t& dst123, + uint32_t& dst124, uint32_t& dst125, uint32_t& dst126, uint32_t& dst127) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.16x32bx2.x128.pack::16b.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15," + "%16, %17, %18, %19," + "%20, %21, %22, %23," + "%24, %25, %26, %27," + "%28, %29, %30, %31," + "%32, %33, %34, %35," + "%36, %37, %38, %39," + "%40, %41, %42, %43," + "%44, %45, %46, %47," + "%48, %49, %50, %51," + "%52, %53, %54, %55," + "%56, %57, %58, %59," + "%60, %61, %62, %63," + "%64, %65, %66, %67," + "%68, %69, %70, %71," + "%72, %73, %74, %75," + "%76, %77, %78, %79," + "%80, %81, %82, %83," + "%84, %85, %86, %87," + "%88, %89, %90, %91," + "%92, %93, %94, %95," + "%96, %97, %98, %99," + "%100, %101, %102, %103," + "%104, %105, %106, %107," + "%108, %109, %110, %111," + "%112, %113, %114, %115," + "%116, %117, %118, %119," + "%120, %121, %122, %123," + "%124, %125, %126, %127}," + "[%128], 256;\n" + : "=r"(dst000), "=r"(dst001), "=r"(dst002), "=r"(dst003), + "=r"(dst004), "=r"(dst005), "=r"(dst006), "=r"(dst007), + "=r"(dst008), "=r"(dst009), "=r"(dst010), "=r"(dst011), + "=r"(dst012), "=r"(dst013), "=r"(dst014), "=r"(dst015), + "=r"(dst016), "=r"(dst017), "=r"(dst018), "=r"(dst019), + "=r"(dst020), "=r"(dst021), "=r"(dst022), "=r"(dst023), + "=r"(dst024), "=r"(dst025), "=r"(dst026), "=r"(dst027), + "=r"(dst028), "=r"(dst029), "=r"(dst030), "=r"(dst031), + "=r"(dst032), "=r"(dst033), "=r"(dst034), "=r"(dst035), + "=r"(dst036), "=r"(dst037), "=r"(dst038), "=r"(dst039), + "=r"(dst040), "=r"(dst041), "=r"(dst042), "=r"(dst043), + "=r"(dst044), "=r"(dst045), "=r"(dst046), "=r"(dst047), + "=r"(dst048), "=r"(dst049), "=r"(dst050), "=r"(dst051), + "=r"(dst052), "=r"(dst053), "=r"(dst054), "=r"(dst055), + "=r"(dst056), "=r"(dst057), "=r"(dst058), "=r"(dst059), + "=r"(dst060), "=r"(dst061), "=r"(dst062), "=r"(dst063), + "=r"(dst064), "=r"(dst065), "=r"(dst066), "=r"(dst067), + "=r"(dst068), "=r"(dst069), "=r"(dst070), "=r"(dst071), + "=r"(dst072), "=r"(dst073), "=r"(dst074), "=r"(dst075), + "=r"(dst076), "=r"(dst077), "=r"(dst078), "=r"(dst079), + "=r"(dst080), "=r"(dst081), "=r"(dst082), "=r"(dst083), + "=r"(dst084), "=r"(dst085), "=r"(dst086), "=r"(dst087), + "=r"(dst088), "=r"(dst089), "=r"(dst090), "=r"(dst091), + "=r"(dst092), "=r"(dst093), "=r"(dst094), "=r"(dst095), + "=r"(dst096), "=r"(dst097), "=r"(dst098), "=r"(dst099), + "=r"(dst100), "=r"(dst101), "=r"(dst102), "=r"(dst103), + "=r"(dst104), "=r"(dst105), "=r"(dst106), "=r"(dst107), + "=r"(dst108), "=r"(dst109), "=r"(dst110), "=r"(dst111), + "=r"(dst112), "=r"(dst113), "=r"(dst114), "=r"(dst115), + "=r"(dst116), "=r"(dst117), "=r"(dst118), "=r"(dst119), + 
"=r"(dst120), "=r"(dst121), "=r"(dst122), "=r"(dst123), + "=r"(dst124), "=r"(dst125), "=r"(dst126), "=r"(dst127) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 32 data path lanes, 32-bit pattern, repeated 1 times +struct SM100_TMEM_LOAD_32dp32b1x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst0) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.32x32b.x1.b32" + "{%0}," + "[%1];\n" + : "=r"(dst0) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 32 data path lanes, 32-bit pattern, repeated 1 times, packed 16b read +struct SM100_TMEM_LOAD_32dp32b1x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst0) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.32x32b.x1.pack::16b.b32" + "{%0}," + "[%1];\n" + : "=r"(dst0) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 32 data path lanes, 32-bit pattern, repeated 2 times +struct SM100_TMEM_LOAD_32dp32b2x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst0, uint32_t& dst1) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.32x32b.x2.b32" + "{%0, %1}," + "[%2];\n" + : "=r"(dst0), "=r"(dst1) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 32 data path lanes, 32-bit pattern, repeated 2 times, packed 16b read +struct SM100_TMEM_LOAD_32dp32b2x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst0, uint32_t& dst1) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.32x32b.x2.pack::16b.b32" + "{%0, %1}," + "[%2];\n" + : "=r"(dst0), "=r"(dst1) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 32 data path lanes, 32-bit pattern, repeated 4 times +struct SM100_TMEM_LOAD_32dp32b4x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst0, uint32_t& dst1, uint32_t& dst2, uint32_t& dst3) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.32x32b.x4.b32" + "{%0, %1, %2, %3}," + "[%4];\n" + : "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3) + : "r"(src_addr)); +#else + 
CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 32 data path lanes, 32-bit pattern, repeated 4 times, packed 16b read +struct SM100_TMEM_LOAD_32dp32b4x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst0, uint32_t& dst1, uint32_t& dst2, uint32_t& dst3) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.32x32b.x4.pack::16b.b32" + "{%0, %1, %2, %3}," + "[%4];\n" + : "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 32 data path lanes, 32-bit pattern, repeated 8 times +struct SM100_TMEM_LOAD_32dp32b8x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst0, uint32_t& dst1, uint32_t& dst2, uint32_t& dst3, + uint32_t& dst4, uint32_t& dst5, uint32_t& dst6, uint32_t& dst7) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.32x32b.x8.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7}," + "[%8];\n" + : "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3), + "=r"(dst4), "=r"(dst5), "=r"(dst6), "=r"(dst7) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 32 data path lanes, 32-bit pattern, repeated 8 times, packed 16b read +struct SM100_TMEM_LOAD_32dp32b8x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst0, uint32_t& dst1, uint32_t& dst2, uint32_t& dst3, + uint32_t& dst4, uint32_t& dst5, uint32_t& dst6, uint32_t& dst7) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.32x32b.x8.pack::16b.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7}," + "[%8];\n" + : "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3), + "=r"(dst4), "=r"(dst5), "=r"(dst6), "=r"(dst7) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 32 data path lanes, 32-bit pattern, repeated 16 times +struct SM100_TMEM_LOAD_32dp32b16x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst00, uint32_t& dst01, uint32_t& dst02, uint32_t& dst03, + uint32_t& dst04, uint32_t& dst05, uint32_t& dst06, uint32_t& dst07, + uint32_t& dst08, uint32_t& dst09, uint32_t& dst10, uint32_t& dst11, + uint32_t& dst12, uint32_t& dst13, uint32_t& dst14, uint32_t& dst15) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.32x32b.x16.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15}," + "[%16];\n" + : "=r"(dst00), "=r"(dst01), "=r"(dst02), "=r"(dst03), + "=r"(dst04), "=r"(dst05), "=r"(dst06), 
"=r"(dst07), + "=r"(dst08), "=r"(dst09), "=r"(dst10), "=r"(dst11), + "=r"(dst12), "=r"(dst13), "=r"(dst14), "=r"(dst15) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 32 data path lanes, 32-bit pattern, repeated 16 times, packed 16b read +struct SM100_TMEM_LOAD_32dp32b16x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst00, uint32_t& dst01, uint32_t& dst02, uint32_t& dst03, + uint32_t& dst04, uint32_t& dst05, uint32_t& dst06, uint32_t& dst07, + uint32_t& dst08, uint32_t& dst09, uint32_t& dst10, uint32_t& dst11, + uint32_t& dst12, uint32_t& dst13, uint32_t& dst14, uint32_t& dst15) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.32x32b.x16.pack::16b.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15}," + "[%16];\n" + : "=r"(dst00), "=r"(dst01), "=r"(dst02), "=r"(dst03), + "=r"(dst04), "=r"(dst05), "=r"(dst06), "=r"(dst07), + "=r"(dst08), "=r"(dst09), "=r"(dst10), "=r"(dst11), + "=r"(dst12), "=r"(dst13), "=r"(dst14), "=r"(dst15) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 32 data path lanes, 32-bit pattern, repeated 32 times +struct SM100_TMEM_LOAD_32dp32b32x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst00, uint32_t& dst01, uint32_t& dst02, uint32_t& dst03, + uint32_t& dst04, uint32_t& dst05, uint32_t& dst06, uint32_t& dst07, + uint32_t& dst08, uint32_t& dst09, uint32_t& dst10, uint32_t& dst11, + uint32_t& dst12, uint32_t& dst13, uint32_t& dst14, uint32_t& dst15, + uint32_t& dst16, uint32_t& dst17, uint32_t& dst18, uint32_t& dst19, + uint32_t& dst20, uint32_t& dst21, uint32_t& dst22, uint32_t& dst23, + uint32_t& dst24, uint32_t& dst25, uint32_t& dst26, uint32_t& dst27, + uint32_t& dst28, uint32_t& dst29, uint32_t& dst30, uint32_t& dst31) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.32x32b.x32.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15," + "%16, %17, %18, %19," + "%20, %21, %22, %23," + "%24, %25, %26, %27," + "%28, %29, %30, %31}," + "[%32];\n" + : "=r"(dst00), "=r"(dst01), "=r"(dst02), "=r"(dst03), + "=r"(dst04), "=r"(dst05), "=r"(dst06), "=r"(dst07), + "=r"(dst08), "=r"(dst09), "=r"(dst10), "=r"(dst11), + "=r"(dst12), "=r"(dst13), "=r"(dst14), "=r"(dst15), + "=r"(dst16), "=r"(dst17), "=r"(dst18), "=r"(dst19), + "=r"(dst20), "=r"(dst21), "=r"(dst22), "=r"(dst23), + "=r"(dst24), "=r"(dst25), "=r"(dst26), "=r"(dst27), + "=r"(dst28), "=r"(dst29), "=r"(dst30), "=r"(dst31) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 32 data path lanes, 32-bit pattern, repeated 32 times, packed 16b read +struct SM100_TMEM_LOAD_32dp32b32x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE 
static void + copy(uint32_t const& src_addr, + uint32_t& dst00, uint32_t& dst01, uint32_t& dst02, uint32_t& dst03, + uint32_t& dst04, uint32_t& dst05, uint32_t& dst06, uint32_t& dst07, + uint32_t& dst08, uint32_t& dst09, uint32_t& dst10, uint32_t& dst11, + uint32_t& dst12, uint32_t& dst13, uint32_t& dst14, uint32_t& dst15, + uint32_t& dst16, uint32_t& dst17, uint32_t& dst18, uint32_t& dst19, + uint32_t& dst20, uint32_t& dst21, uint32_t& dst22, uint32_t& dst23, + uint32_t& dst24, uint32_t& dst25, uint32_t& dst26, uint32_t& dst27, + uint32_t& dst28, uint32_t& dst29, uint32_t& dst30, uint32_t& dst31) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.32x32b.x32.pack::16b.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15," + "%16, %17, %18, %19," + "%20, %21, %22, %23," + "%24, %25, %26, %27," + "%28, %29, %30, %31}," + "[%32];\n" + : "=r"(dst00), "=r"(dst01), "=r"(dst02), "=r"(dst03), + "=r"(dst04), "=r"(dst05), "=r"(dst06), "=r"(dst07), + "=r"(dst08), "=r"(dst09), "=r"(dst10), "=r"(dst11), + "=r"(dst12), "=r"(dst13), "=r"(dst14), "=r"(dst15), + "=r"(dst16), "=r"(dst17), "=r"(dst18), "=r"(dst19), + "=r"(dst20), "=r"(dst21), "=r"(dst22), "=r"(dst23), + "=r"(dst24), "=r"(dst25), "=r"(dst26), "=r"(dst27), + "=r"(dst28), "=r"(dst29), "=r"(dst30), "=r"(dst31) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 32 data path lanes, 32-bit pattern, repeated 64 times +struct SM100_TMEM_LOAD_32dp32b64x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst00, uint32_t& dst01, uint32_t& dst02, uint32_t& dst03, + uint32_t& dst04, uint32_t& dst05, uint32_t& dst06, uint32_t& dst07, + uint32_t& dst08, uint32_t& dst09, uint32_t& dst10, uint32_t& dst11, + uint32_t& dst12, uint32_t& dst13, uint32_t& dst14, uint32_t& dst15, + uint32_t& dst16, uint32_t& dst17, uint32_t& dst18, uint32_t& dst19, + uint32_t& dst20, uint32_t& dst21, uint32_t& dst22, uint32_t& dst23, + uint32_t& dst24, uint32_t& dst25, uint32_t& dst26, uint32_t& dst27, + uint32_t& dst28, uint32_t& dst29, uint32_t& dst30, uint32_t& dst31, + uint32_t& dst32, uint32_t& dst33, uint32_t& dst34, uint32_t& dst35, + uint32_t& dst36, uint32_t& dst37, uint32_t& dst38, uint32_t& dst39, + uint32_t& dst40, uint32_t& dst41, uint32_t& dst42, uint32_t& dst43, + uint32_t& dst44, uint32_t& dst45, uint32_t& dst46, uint32_t& dst47, + uint32_t& dst48, uint32_t& dst49, uint32_t& dst50, uint32_t& dst51, + uint32_t& dst52, uint32_t& dst53, uint32_t& dst54, uint32_t& dst55, + uint32_t& dst56, uint32_t& dst57, uint32_t& dst58, uint32_t& dst59, + uint32_t& dst60, uint32_t& dst61, uint32_t& dst62, uint32_t& dst63) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.32x32b.x64.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15," + "%16, %17, %18, %19," + "%20, %21, %22, %23," + "%24, %25, %26, %27," + "%28, %29, %30, %31," + "%32, %33, %34, %35," + "%36, %37, %38, %39," + "%40, %41, %42, %43," + "%44, %45, %46, %47," + "%48, %49, %50, %51," + "%52, %53, %54, %55," + "%56, %57, %58, %59," + "%60, %61, %62, %63}," + "[%64];\n" + : "=r"(dst00), "=r"(dst01), "=r"(dst02), "=r"(dst03), + "=r"(dst04), "=r"(dst05), "=r"(dst06), "=r"(dst07), + 
"=r"(dst08), "=r"(dst09), "=r"(dst10), "=r"(dst11), + "=r"(dst12), "=r"(dst13), "=r"(dst14), "=r"(dst15), + "=r"(dst16), "=r"(dst17), "=r"(dst18), "=r"(dst19), + "=r"(dst20), "=r"(dst21), "=r"(dst22), "=r"(dst23), + "=r"(dst24), "=r"(dst25), "=r"(dst26), "=r"(dst27), + "=r"(dst28), "=r"(dst29), "=r"(dst30), "=r"(dst31), + "=r"(dst32), "=r"(dst33), "=r"(dst34), "=r"(dst35), + "=r"(dst36), "=r"(dst37), "=r"(dst38), "=r"(dst39), + "=r"(dst40), "=r"(dst41), "=r"(dst42), "=r"(dst43), + "=r"(dst44), "=r"(dst45), "=r"(dst46), "=r"(dst47), + "=r"(dst48), "=r"(dst49), "=r"(dst50), "=r"(dst51), + "=r"(dst52), "=r"(dst53), "=r"(dst54), "=r"(dst55), + "=r"(dst56), "=r"(dst57), "=r"(dst58), "=r"(dst59), + "=r"(dst60), "=r"(dst61), "=r"(dst62), "=r"(dst63) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 32 data path lanes, 32-bit pattern, repeated 64 times, packed 16b read +struct SM100_TMEM_LOAD_32dp32b64x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst00, uint32_t& dst01, uint32_t& dst02, uint32_t& dst03, + uint32_t& dst04, uint32_t& dst05, uint32_t& dst06, uint32_t& dst07, + uint32_t& dst08, uint32_t& dst09, uint32_t& dst10, uint32_t& dst11, + uint32_t& dst12, uint32_t& dst13, uint32_t& dst14, uint32_t& dst15, + uint32_t& dst16, uint32_t& dst17, uint32_t& dst18, uint32_t& dst19, + uint32_t& dst20, uint32_t& dst21, uint32_t& dst22, uint32_t& dst23, + uint32_t& dst24, uint32_t& dst25, uint32_t& dst26, uint32_t& dst27, + uint32_t& dst28, uint32_t& dst29, uint32_t& dst30, uint32_t& dst31, + uint32_t& dst32, uint32_t& dst33, uint32_t& dst34, uint32_t& dst35, + uint32_t& dst36, uint32_t& dst37, uint32_t& dst38, uint32_t& dst39, + uint32_t& dst40, uint32_t& dst41, uint32_t& dst42, uint32_t& dst43, + uint32_t& dst44, uint32_t& dst45, uint32_t& dst46, uint32_t& dst47, + uint32_t& dst48, uint32_t& dst49, uint32_t& dst50, uint32_t& dst51, + uint32_t& dst52, uint32_t& dst53, uint32_t& dst54, uint32_t& dst55, + uint32_t& dst56, uint32_t& dst57, uint32_t& dst58, uint32_t& dst59, + uint32_t& dst60, uint32_t& dst61, uint32_t& dst62, uint32_t& dst63) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.32x32b.x64.pack::16b.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15," + "%16, %17, %18, %19," + "%20, %21, %22, %23," + "%24, %25, %26, %27," + "%28, %29, %30, %31," + "%32, %33, %34, %35," + "%36, %37, %38, %39," + "%40, %41, %42, %43," + "%44, %45, %46, %47," + "%48, %49, %50, %51," + "%52, %53, %54, %55," + "%56, %57, %58, %59," + "%60, %61, %62, %63}," + "[%64];\n" + : "=r"(dst00), "=r"(dst01), "=r"(dst02), "=r"(dst03), + "=r"(dst04), "=r"(dst05), "=r"(dst06), "=r"(dst07), + "=r"(dst08), "=r"(dst09), "=r"(dst10), "=r"(dst11), + "=r"(dst12), "=r"(dst13), "=r"(dst14), "=r"(dst15), + "=r"(dst16), "=r"(dst17), "=r"(dst18), "=r"(dst19), + "=r"(dst20), "=r"(dst21), "=r"(dst22), "=r"(dst23), + "=r"(dst24), "=r"(dst25), "=r"(dst26), "=r"(dst27), + "=r"(dst28), "=r"(dst29), "=r"(dst30), "=r"(dst31), + "=r"(dst32), "=r"(dst33), "=r"(dst34), "=r"(dst35), + "=r"(dst36), "=r"(dst37), "=r"(dst38), "=r"(dst39), + "=r"(dst40), "=r"(dst41), "=r"(dst42), "=r"(dst43), + "=r"(dst44), "=r"(dst45), "=r"(dst46), "=r"(dst47), + "=r"(dst48), "=r"(dst49), 
"=r"(dst50), "=r"(dst51), + "=r"(dst52), "=r"(dst53), "=r"(dst54), "=r"(dst55), + "=r"(dst56), "=r"(dst57), "=r"(dst58), "=r"(dst59), + "=r"(dst60), "=r"(dst61), "=r"(dst62), "=r"(dst63) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 32 data path lanes, 32-bit pattern, repeated 128 times +struct SM100_TMEM_LOAD_32dp32b128x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst000, uint32_t& dst001, uint32_t& dst002, uint32_t& dst003, + uint32_t& dst004, uint32_t& dst005, uint32_t& dst006, uint32_t& dst007, + uint32_t& dst008, uint32_t& dst009, uint32_t& dst010, uint32_t& dst011, + uint32_t& dst012, uint32_t& dst013, uint32_t& dst014, uint32_t& dst015, + uint32_t& dst016, uint32_t& dst017, uint32_t& dst018, uint32_t& dst019, + uint32_t& dst020, uint32_t& dst021, uint32_t& dst022, uint32_t& dst023, + uint32_t& dst024, uint32_t& dst025, uint32_t& dst026, uint32_t& dst027, + uint32_t& dst028, uint32_t& dst029, uint32_t& dst030, uint32_t& dst031, + uint32_t& dst032, uint32_t& dst033, uint32_t& dst034, uint32_t& dst035, + uint32_t& dst036, uint32_t& dst037, uint32_t& dst038, uint32_t& dst039, + uint32_t& dst040, uint32_t& dst041, uint32_t& dst042, uint32_t& dst043, + uint32_t& dst044, uint32_t& dst045, uint32_t& dst046, uint32_t& dst047, + uint32_t& dst048, uint32_t& dst049, uint32_t& dst050, uint32_t& dst051, + uint32_t& dst052, uint32_t& dst053, uint32_t& dst054, uint32_t& dst055, + uint32_t& dst056, uint32_t& dst057, uint32_t& dst058, uint32_t& dst059, + uint32_t& dst060, uint32_t& dst061, uint32_t& dst062, uint32_t& dst063, + uint32_t& dst064, uint32_t& dst065, uint32_t& dst066, uint32_t& dst067, + uint32_t& dst068, uint32_t& dst069, uint32_t& dst070, uint32_t& dst071, + uint32_t& dst072, uint32_t& dst073, uint32_t& dst074, uint32_t& dst075, + uint32_t& dst076, uint32_t& dst077, uint32_t& dst078, uint32_t& dst079, + uint32_t& dst080, uint32_t& dst081, uint32_t& dst082, uint32_t& dst083, + uint32_t& dst084, uint32_t& dst085, uint32_t& dst086, uint32_t& dst087, + uint32_t& dst088, uint32_t& dst089, uint32_t& dst090, uint32_t& dst091, + uint32_t& dst092, uint32_t& dst093, uint32_t& dst094, uint32_t& dst095, + uint32_t& dst096, uint32_t& dst097, uint32_t& dst098, uint32_t& dst099, + uint32_t& dst100, uint32_t& dst101, uint32_t& dst102, uint32_t& dst103, + uint32_t& dst104, uint32_t& dst105, uint32_t& dst106, uint32_t& dst107, + uint32_t& dst108, uint32_t& dst109, uint32_t& dst110, uint32_t& dst111, + uint32_t& dst112, uint32_t& dst113, uint32_t& dst114, uint32_t& dst115, + uint32_t& dst116, uint32_t& dst117, uint32_t& dst118, uint32_t& dst119, + uint32_t& dst120, uint32_t& dst121, uint32_t& dst122, uint32_t& dst123, + uint32_t& dst124, uint32_t& dst125, uint32_t& dst126, uint32_t& dst127) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.32x32b.x128.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15," + "%16, %17, %18, %19," + "%20, %21, %22, %23," + "%24, %25, %26, %27," + "%28, %29, %30, %31," + "%32, %33, %34, %35," + "%36, %37, %38, %39," + "%40, %41, %42, %43," + "%44, %45, %46, %47," + "%48, %49, %50, %51," + "%52, %53, %54, %55," + "%56, %57, %58, %59," + "%60, %61, %62, %63," + "%64, %65, %66, %67," + "%68, %69, 
%70, %71," + "%72, %73, %74, %75," + "%76, %77, %78, %79," + "%80, %81, %82, %83," + "%84, %85, %86, %87," + "%88, %89, %90, %91," + "%92, %93, %94, %95," + "%96, %97, %98, %99," + "%100, %101, %102, %103," + "%104, %105, %106, %107," + "%108, %109, %110, %111," + "%112, %113, %114, %115," + "%116, %117, %118, %119," + "%120, %121, %122, %123," + "%124, %125, %126, %127}," + "[%128];\n" + : "=r"(dst000), "=r"(dst001), "=r"(dst002), "=r"(dst003), + "=r"(dst004), "=r"(dst005), "=r"(dst006), "=r"(dst007), + "=r"(dst008), "=r"(dst009), "=r"(dst010), "=r"(dst011), + "=r"(dst012), "=r"(dst013), "=r"(dst014), "=r"(dst015), + "=r"(dst016), "=r"(dst017), "=r"(dst018), "=r"(dst019), + "=r"(dst020), "=r"(dst021), "=r"(dst022), "=r"(dst023), + "=r"(dst024), "=r"(dst025), "=r"(dst026), "=r"(dst027), + "=r"(dst028), "=r"(dst029), "=r"(dst030), "=r"(dst031), + "=r"(dst032), "=r"(dst033), "=r"(dst034), "=r"(dst035), + "=r"(dst036), "=r"(dst037), "=r"(dst038), "=r"(dst039), + "=r"(dst040), "=r"(dst041), "=r"(dst042), "=r"(dst043), + "=r"(dst044), "=r"(dst045), "=r"(dst046), "=r"(dst047), + "=r"(dst048), "=r"(dst049), "=r"(dst050), "=r"(dst051), + "=r"(dst052), "=r"(dst053), "=r"(dst054), "=r"(dst055), + "=r"(dst056), "=r"(dst057), "=r"(dst058), "=r"(dst059), + "=r"(dst060), "=r"(dst061), "=r"(dst062), "=r"(dst063), + "=r"(dst064), "=r"(dst065), "=r"(dst066), "=r"(dst067), + "=r"(dst068), "=r"(dst069), "=r"(dst070), "=r"(dst071), + "=r"(dst072), "=r"(dst073), "=r"(dst074), "=r"(dst075), + "=r"(dst076), "=r"(dst077), "=r"(dst078), "=r"(dst079), + "=r"(dst080), "=r"(dst081), "=r"(dst082), "=r"(dst083), + "=r"(dst084), "=r"(dst085), "=r"(dst086), "=r"(dst087), + "=r"(dst088), "=r"(dst089), "=r"(dst090), "=r"(dst091), + "=r"(dst092), "=r"(dst093), "=r"(dst094), "=r"(dst095), + "=r"(dst096), "=r"(dst097), "=r"(dst098), "=r"(dst099), + "=r"(dst100), "=r"(dst101), "=r"(dst102), "=r"(dst103), + "=r"(dst104), "=r"(dst105), "=r"(dst106), "=r"(dst107), + "=r"(dst108), "=r"(dst109), "=r"(dst110), "=r"(dst111), + "=r"(dst112), "=r"(dst113), "=r"(dst114), "=r"(dst115), + "=r"(dst116), "=r"(dst117), "=r"(dst118), "=r"(dst119), + "=r"(dst120), "=r"(dst121), "=r"(dst122), "=r"(dst123), + "=r"(dst124), "=r"(dst125), "=r"(dst126), "=r"(dst127) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 32 data path lanes, 32-bit pattern, repeated 128 times, packed 16b read +struct SM100_TMEM_LOAD_32dp32b128x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src_addr, + uint32_t& dst000, uint32_t& dst001, uint32_t& dst002, uint32_t& dst003, + uint32_t& dst004, uint32_t& dst005, uint32_t& dst006, uint32_t& dst007, + uint32_t& dst008, uint32_t& dst009, uint32_t& dst010, uint32_t& dst011, + uint32_t& dst012, uint32_t& dst013, uint32_t& dst014, uint32_t& dst015, + uint32_t& dst016, uint32_t& dst017, uint32_t& dst018, uint32_t& dst019, + uint32_t& dst020, uint32_t& dst021, uint32_t& dst022, uint32_t& dst023, + uint32_t& dst024, uint32_t& dst025, uint32_t& dst026, uint32_t& dst027, + uint32_t& dst028, uint32_t& dst029, uint32_t& dst030, uint32_t& dst031, + uint32_t& dst032, uint32_t& dst033, uint32_t& dst034, uint32_t& dst035, + uint32_t& dst036, uint32_t& dst037, uint32_t& dst038, uint32_t& dst039, + uint32_t& dst040, uint32_t& dst041, uint32_t& dst042, uint32_t& 
dst043, + uint32_t& dst044, uint32_t& dst045, uint32_t& dst046, uint32_t& dst047, + uint32_t& dst048, uint32_t& dst049, uint32_t& dst050, uint32_t& dst051, + uint32_t& dst052, uint32_t& dst053, uint32_t& dst054, uint32_t& dst055, + uint32_t& dst056, uint32_t& dst057, uint32_t& dst058, uint32_t& dst059, + uint32_t& dst060, uint32_t& dst061, uint32_t& dst062, uint32_t& dst063, + uint32_t& dst064, uint32_t& dst065, uint32_t& dst066, uint32_t& dst067, + uint32_t& dst068, uint32_t& dst069, uint32_t& dst070, uint32_t& dst071, + uint32_t& dst072, uint32_t& dst073, uint32_t& dst074, uint32_t& dst075, + uint32_t& dst076, uint32_t& dst077, uint32_t& dst078, uint32_t& dst079, + uint32_t& dst080, uint32_t& dst081, uint32_t& dst082, uint32_t& dst083, + uint32_t& dst084, uint32_t& dst085, uint32_t& dst086, uint32_t& dst087, + uint32_t& dst088, uint32_t& dst089, uint32_t& dst090, uint32_t& dst091, + uint32_t& dst092, uint32_t& dst093, uint32_t& dst094, uint32_t& dst095, + uint32_t& dst096, uint32_t& dst097, uint32_t& dst098, uint32_t& dst099, + uint32_t& dst100, uint32_t& dst101, uint32_t& dst102, uint32_t& dst103, + uint32_t& dst104, uint32_t& dst105, uint32_t& dst106, uint32_t& dst107, + uint32_t& dst108, uint32_t& dst109, uint32_t& dst110, uint32_t& dst111, + uint32_t& dst112, uint32_t& dst113, uint32_t& dst114, uint32_t& dst115, + uint32_t& dst116, uint32_t& dst117, uint32_t& dst118, uint32_t& dst119, + uint32_t& dst120, uint32_t& dst121, uint32_t& dst122, uint32_t& dst123, + uint32_t& dst124, uint32_t& dst125, uint32_t& dst126, uint32_t& dst127) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.ld.sync.aligned.32x32b.x128.pack::16b.b32" + "{%0, %1, %2, %3," + "%4, %5, %6, %7," + "%8, %9, %10, %11," + "%12, %13, %14, %15," + "%16, %17, %18, %19," + "%20, %21, %22, %23," + "%24, %25, %26, %27," + "%28, %29, %30, %31," + "%32, %33, %34, %35," + "%36, %37, %38, %39," + "%40, %41, %42, %43," + "%44, %45, %46, %47," + "%48, %49, %50, %51," + "%52, %53, %54, %55," + "%56, %57, %58, %59," + "%60, %61, %62, %63," + "%64, %65, %66, %67," + "%68, %69, %70, %71," + "%72, %73, %74, %75," + "%76, %77, %78, %79," + "%80, %81, %82, %83," + "%84, %85, %86, %87," + "%88, %89, %90, %91," + "%92, %93, %94, %95," + "%96, %97, %98, %99," + "%100, %101, %102, %103," + "%104, %105, %106, %107," + "%108, %109, %110, %111," + "%112, %113, %114, %115," + "%116, %117, %118, %119," + "%120, %121, %122, %123," + "%124, %125, %126, %127}," + "[%128];\n" + : "=r"(dst000), "=r"(dst001), "=r"(dst002), "=r"(dst003), + "=r"(dst004), "=r"(dst005), "=r"(dst006), "=r"(dst007), + "=r"(dst008), "=r"(dst009), "=r"(dst010), "=r"(dst011), + "=r"(dst012), "=r"(dst013), "=r"(dst014), "=r"(dst015), + "=r"(dst016), "=r"(dst017), "=r"(dst018), "=r"(dst019), + "=r"(dst020), "=r"(dst021), "=r"(dst022), "=r"(dst023), + "=r"(dst024), "=r"(dst025), "=r"(dst026), "=r"(dst027), + "=r"(dst028), "=r"(dst029), "=r"(dst030), "=r"(dst031), + "=r"(dst032), "=r"(dst033), "=r"(dst034), "=r"(dst035), + "=r"(dst036), "=r"(dst037), "=r"(dst038), "=r"(dst039), + "=r"(dst040), "=r"(dst041), "=r"(dst042), "=r"(dst043), + "=r"(dst044), "=r"(dst045), "=r"(dst046), "=r"(dst047), + "=r"(dst048), "=r"(dst049), "=r"(dst050), "=r"(dst051), + "=r"(dst052), "=r"(dst053), "=r"(dst054), "=r"(dst055), + "=r"(dst056), "=r"(dst057), "=r"(dst058), "=r"(dst059), + "=r"(dst060), "=r"(dst061), "=r"(dst062), "=r"(dst063), + "=r"(dst064), "=r"(dst065), "=r"(dst066), "=r"(dst067), + "=r"(dst068), "=r"(dst069), "=r"(dst070), "=r"(dst071), + "=r"(dst072), 
"=r"(dst073), "=r"(dst074), "=r"(dst075), + "=r"(dst076), "=r"(dst077), "=r"(dst078), "=r"(dst079), + "=r"(dst080), "=r"(dst081), "=r"(dst082), "=r"(dst083), + "=r"(dst084), "=r"(dst085), "=r"(dst086), "=r"(dst087), + "=r"(dst088), "=r"(dst089), "=r"(dst090), "=r"(dst091), + "=r"(dst092), "=r"(dst093), "=r"(dst094), "=r"(dst095), + "=r"(dst096), "=r"(dst097), "=r"(dst098), "=r"(dst099), + "=r"(dst100), "=r"(dst101), "=r"(dst102), "=r"(dst103), + "=r"(dst104), "=r"(dst105), "=r"(dst106), "=r"(dst107), + "=r"(dst108), "=r"(dst109), "=r"(dst110), "=r"(dst111), + "=r"(dst112), "=r"(dst113), "=r"(dst114), "=r"(dst115), + "=r"(dst116), "=r"(dst117), "=r"(dst118), "=r"(dst119), + "=r"(dst120), "=r"(dst121), "=r"(dst122), "=r"(dst123), + "=r"(dst124), "=r"(dst125), "=r"(dst126), "=r"(dst127) + : "r"(src_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_LOAD without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace SM100::TMEM::LOAD + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace SM100::TMEM::STORE { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// TMEM STORE PTX definitions +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 256-bit pattern, repeated 1 times +struct SM100_TMEM_STORE_16dp256b1x +{ + using SRegisters = uint32_t[4]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x256b.x1.b32" + "[%0]," + "{%1, %2, %3, %4};\n" + : + : "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 256-bit pattern, repeated 1 times, expand 16b write +struct SM100_TMEM_STORE_16dp256b1x_16b +{ + using SRegisters = uint32_t[4]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x256b.x1.unpack::16b.b32" + "[%0]," + "{%1, %2, %3, %4};\n" + : + : "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 256-bit pattern, repeated 2 times +struct SM100_TMEM_STORE_16dp256b2x +{ + using SRegisters = uint32_t[8]; + 
using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, + uint32_t const& src4, uint32_t const& src5, uint32_t const& src6, uint32_t const& src7, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x256b.x2.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8};\n" + : + : "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3), + "r"(src4), "r"(src5), "r"(src6), "r"(src7) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 256-bit pattern, repeated 2 times, expand 16b write +struct SM100_TMEM_STORE_16dp256b2x_16b +{ + using SRegisters = uint32_t[8]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, + uint32_t const& src4, uint32_t const& src5, uint32_t const& src6, uint32_t const& src7, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x256b.x2.unpack::16b.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8};\n" + : + : "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3), + "r"(src4), "r"(src5), "r"(src6), "r"(src7) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 256-bit pattern, repeated 4 times +struct SM100_TMEM_STORE_16dp256b4x +{ + using SRegisters = uint32_t[16]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src00, uint32_t const& src01, uint32_t const& src02, uint32_t const& src03, + uint32_t const& src04, uint32_t const& src05, uint32_t const& src06, uint32_t const& src07, + uint32_t const& src08, uint32_t const& src09, uint32_t const& src10, uint32_t const& src11, + uint32_t const& src12, uint32_t const& src13, uint32_t const& src14, uint32_t const& src15, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x256b.x4.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16};\n" + : + : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), + "r"(src04), "r"(src05), "r"(src06), "r"(src07), + "r"(src08), "r"(src09), "r"(src10), "r"(src11), + "r"(src12), "r"(src13), "r"(src14), "r"(src15) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 256-bit pattern, repeated 4 times, expand 16b write +struct SM100_TMEM_STORE_16dp256b4x_16b +{ + using SRegisters = uint32_t[16]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src00, uint32_t const& src01, uint32_t const& src02, uint32_t const& src03, + uint32_t const& src04, uint32_t const& src05, uint32_t const& src06, uint32_t const& src07, + uint32_t const& src08, uint32_t const& src09, uint32_t const& src10, uint32_t const& src11, + uint32_t const& src12, uint32_t const& src13, uint32_t 
const& src14, uint32_t const& src15, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x256b.x4.unpack::16b.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16};\n" + : + : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), + "r"(src04), "r"(src05), "r"(src06), "r"(src07), + "r"(src08), "r"(src09), "r"(src10), "r"(src11), + "r"(src12), "r"(src13), "r"(src14), "r"(src15) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 256-bit pattern, repeated 8 times +struct SM100_TMEM_STORE_16dp256b8x +{ + using SRegisters = uint32_t[32]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src00, uint32_t const& src01, uint32_t const& src02, uint32_t const& src03, + uint32_t const& src04, uint32_t const& src05, uint32_t const& src06, uint32_t const& src07, + uint32_t const& src08, uint32_t const& src09, uint32_t const& src10, uint32_t const& src11, + uint32_t const& src12, uint32_t const& src13, uint32_t const& src14, uint32_t const& src15, + uint32_t const& src16, uint32_t const& src17, uint32_t const& src18, uint32_t const& src19, + uint32_t const& src20, uint32_t const& src21, uint32_t const& src22, uint32_t const& src23, + uint32_t const& src24, uint32_t const& src25, uint32_t const& src26, uint32_t const& src27, + uint32_t const& src28, uint32_t const& src29, uint32_t const& src30, uint32_t const& src31, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x256b.x8.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16," + "%17, %18, %19, %20," + "%21, %22, %23, %24," + "%25, %26, %27, %28," + "%29, %30, %31, %32};\n" + : + : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), + "r"(src04), "r"(src05), "r"(src06), "r"(src07), + "r"(src08), "r"(src09), "r"(src10), "r"(src11), + "r"(src12), "r"(src13), "r"(src14), "r"(src15), + "r"(src16), "r"(src17), "r"(src18), "r"(src19), + "r"(src20), "r"(src21), "r"(src22), "r"(src23), + "r"(src24), "r"(src25), "r"(src26), "r"(src27), + "r"(src28), "r"(src29), "r"(src30), "r"(src31) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 256-bit pattern, repeated 8 times, expand 16b write +struct SM100_TMEM_STORE_16dp256b8x_16b +{ + using SRegisters = uint32_t[32]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src00, uint32_t const& src01, uint32_t const& src02, uint32_t const& src03, + uint32_t const& src04, uint32_t const& src05, uint32_t const& src06, uint32_t const& src07, + uint32_t const& src08, uint32_t const& src09, uint32_t const& src10, uint32_t const& src11, + uint32_t const& src12, uint32_t const& src13, uint32_t const& src14, uint32_t const& src15, + uint32_t const& src16, uint32_t const& src17, uint32_t const& src18, uint32_t const& src19, + uint32_t const& src20, uint32_t const& src21, uint32_t const& src22, uint32_t const& src23, + uint32_t const& src24, uint32_t const& src25, uint32_t const& src26, uint32_t const& src27, + 
uint32_t const& src28, uint32_t const& src29, uint32_t const& src30, uint32_t const& src31, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x256b.x8.unpack::16b.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16," + "%17, %18, %19, %20," + "%21, %22, %23, %24," + "%25, %26, %27, %28," + "%29, %30, %31, %32};\n" + : + : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), + "r"(src04), "r"(src05), "r"(src06), "r"(src07), + "r"(src08), "r"(src09), "r"(src10), "r"(src11), + "r"(src12), "r"(src13), "r"(src14), "r"(src15), + "r"(src16), "r"(src17), "r"(src18), "r"(src19), + "r"(src20), "r"(src21), "r"(src22), "r"(src23), + "r"(src24), "r"(src25), "r"(src26), "r"(src27), + "r"(src28), "r"(src29), "r"(src30), "r"(src31) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 256-bit pattern, repeated 16 times +struct SM100_TMEM_STORE_16dp256b16x +{ + using SRegisters = uint32_t[64]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src00, uint32_t const& src01, uint32_t const& src02, uint32_t const& src03, + uint32_t const& src04, uint32_t const& src05, uint32_t const& src06, uint32_t const& src07, + uint32_t const& src08, uint32_t const& src09, uint32_t const& src10, uint32_t const& src11, + uint32_t const& src12, uint32_t const& src13, uint32_t const& src14, uint32_t const& src15, + uint32_t const& src16, uint32_t const& src17, uint32_t const& src18, uint32_t const& src19, + uint32_t const& src20, uint32_t const& src21, uint32_t const& src22, uint32_t const& src23, + uint32_t const& src24, uint32_t const& src25, uint32_t const& src26, uint32_t const& src27, + uint32_t const& src28, uint32_t const& src29, uint32_t const& src30, uint32_t const& src31, + uint32_t const& src32, uint32_t const& src33, uint32_t const& src34, uint32_t const& src35, + uint32_t const& src36, uint32_t const& src37, uint32_t const& src38, uint32_t const& src39, + uint32_t const& src40, uint32_t const& src41, uint32_t const& src42, uint32_t const& src43, + uint32_t const& src44, uint32_t const& src45, uint32_t const& src46, uint32_t const& src47, + uint32_t const& src48, uint32_t const& src49, uint32_t const& src50, uint32_t const& src51, + uint32_t const& src52, uint32_t const& src53, uint32_t const& src54, uint32_t const& src55, + uint32_t const& src56, uint32_t const& src57, uint32_t const& src58, uint32_t const& src59, + uint32_t const& src60, uint32_t const& src61, uint32_t const& src62, uint32_t const& src63, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x256b.x16.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16," + "%17, %18, %19, %20," + "%21, %22, %23, %24," + "%25, %26, %27, %28," + "%29, %30, %31, %32," + "%33, %34, %35, %36," + "%37, %38, %39, %40," + "%41, %42, %43, %44," + "%45, %46, %47, %48," + "%49, %50, %51, %52," + "%53, %54, %55, %56," + "%57, %58, %59, %60," + "%61, %62, %63, %64};\n" + : + : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), + "r"(src04), "r"(src05), "r"(src06), "r"(src07), + "r"(src08), "r"(src09), "r"(src10), "r"(src11), + "r"(src12), "r"(src13), "r"(src14), "r"(src15), + "r"(src16), 
"r"(src17), "r"(src18), "r"(src19), + "r"(src20), "r"(src21), "r"(src22), "r"(src23), + "r"(src24), "r"(src25), "r"(src26), "r"(src27), + "r"(src28), "r"(src29), "r"(src30), "r"(src31), + "r"(src32), "r"(src33), "r"(src34), "r"(src35), + "r"(src36), "r"(src37), "r"(src38), "r"(src39), + "r"(src40), "r"(src41), "r"(src42), "r"(src43), + "r"(src44), "r"(src45), "r"(src46), "r"(src47), + "r"(src48), "r"(src49), "r"(src50), "r"(src51), + "r"(src52), "r"(src53), "r"(src54), "r"(src55), + "r"(src56), "r"(src57), "r"(src58), "r"(src59), + "r"(src60), "r"(src61), "r"(src62), "r"(src63) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 256-bit pattern, repeated 16 times, expand 16b write +struct SM100_TMEM_STORE_16dp256b16x_16b +{ + using SRegisters = uint32_t[64]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src00, uint32_t const& src01, uint32_t const& src02, uint32_t const& src03, + uint32_t const& src04, uint32_t const& src05, uint32_t const& src06, uint32_t const& src07, + uint32_t const& src08, uint32_t const& src09, uint32_t const& src10, uint32_t const& src11, + uint32_t const& src12, uint32_t const& src13, uint32_t const& src14, uint32_t const& src15, + uint32_t const& src16, uint32_t const& src17, uint32_t const& src18, uint32_t const& src19, + uint32_t const& src20, uint32_t const& src21, uint32_t const& src22, uint32_t const& src23, + uint32_t const& src24, uint32_t const& src25, uint32_t const& src26, uint32_t const& src27, + uint32_t const& src28, uint32_t const& src29, uint32_t const& src30, uint32_t const& src31, + uint32_t const& src32, uint32_t const& src33, uint32_t const& src34, uint32_t const& src35, + uint32_t const& src36, uint32_t const& src37, uint32_t const& src38, uint32_t const& src39, + uint32_t const& src40, uint32_t const& src41, uint32_t const& src42, uint32_t const& src43, + uint32_t const& src44, uint32_t const& src45, uint32_t const& src46, uint32_t const& src47, + uint32_t const& src48, uint32_t const& src49, uint32_t const& src50, uint32_t const& src51, + uint32_t const& src52, uint32_t const& src53, uint32_t const& src54, uint32_t const& src55, + uint32_t const& src56, uint32_t const& src57, uint32_t const& src58, uint32_t const& src59, + uint32_t const& src60, uint32_t const& src61, uint32_t const& src62, uint32_t const& src63, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x256b.x16.unpack::16b.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16," + "%17, %18, %19, %20," + "%21, %22, %23, %24," + "%25, %26, %27, %28," + "%29, %30, %31, %32," + "%33, %34, %35, %36," + "%37, %38, %39, %40," + "%41, %42, %43, %44," + "%45, %46, %47, %48," + "%49, %50, %51, %52," + "%53, %54, %55, %56," + "%57, %58, %59, %60," + "%61, %62, %63, %64};\n" + : + : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), + "r"(src04), "r"(src05), "r"(src06), "r"(src07), + "r"(src08), "r"(src09), "r"(src10), "r"(src11), + "r"(src12), "r"(src13), "r"(src14), "r"(src15), + "r"(src16), "r"(src17), "r"(src18), "r"(src19), + "r"(src20), "r"(src21), "r"(src22), "r"(src23), + "r"(src24), "r"(src25), "r"(src26), "r"(src27), + "r"(src28), "r"(src29), "r"(src30), "r"(src31), + "r"(src32), "r"(src33), "r"(src34), "r"(src35), + 
"r"(src36), "r"(src37), "r"(src38), "r"(src39), + "r"(src40), "r"(src41), "r"(src42), "r"(src43), + "r"(src44), "r"(src45), "r"(src46), "r"(src47), + "r"(src48), "r"(src49), "r"(src50), "r"(src51), + "r"(src52), "r"(src53), "r"(src54), "r"(src55), + "r"(src56), "r"(src57), "r"(src58), "r"(src59), + "r"(src60), "r"(src61), "r"(src62), "r"(src63) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 256-bit pattern, repeated 32 times +struct SM100_TMEM_STORE_16dp256b32x +{ + using SRegisters = uint32_t[128]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src000, uint32_t const& src001, uint32_t const& src002, uint32_t const& src003, + uint32_t const& src004, uint32_t const& src005, uint32_t const& src006, uint32_t const& src007, + uint32_t const& src008, uint32_t const& src009, uint32_t const& src010, uint32_t const& src011, + uint32_t const& src012, uint32_t const& src013, uint32_t const& src014, uint32_t const& src015, + uint32_t const& src016, uint32_t const& src017, uint32_t const& src018, uint32_t const& src019, + uint32_t const& src020, uint32_t const& src021, uint32_t const& src022, uint32_t const& src023, + uint32_t const& src024, uint32_t const& src025, uint32_t const& src026, uint32_t const& src027, + uint32_t const& src028, uint32_t const& src029, uint32_t const& src030, uint32_t const& src031, + uint32_t const& src032, uint32_t const& src033, uint32_t const& src034, uint32_t const& src035, + uint32_t const& src036, uint32_t const& src037, uint32_t const& src038, uint32_t const& src039, + uint32_t const& src040, uint32_t const& src041, uint32_t const& src042, uint32_t const& src043, + uint32_t const& src044, uint32_t const& src045, uint32_t const& src046, uint32_t const& src047, + uint32_t const& src048, uint32_t const& src049, uint32_t const& src050, uint32_t const& src051, + uint32_t const& src052, uint32_t const& src053, uint32_t const& src054, uint32_t const& src055, + uint32_t const& src056, uint32_t const& src057, uint32_t const& src058, uint32_t const& src059, + uint32_t const& src060, uint32_t const& src061, uint32_t const& src062, uint32_t const& src063, + uint32_t const& src064, uint32_t const& src065, uint32_t const& src066, uint32_t const& src067, + uint32_t const& src068, uint32_t const& src069, uint32_t const& src070, uint32_t const& src071, + uint32_t const& src072, uint32_t const& src073, uint32_t const& src074, uint32_t const& src075, + uint32_t const& src076, uint32_t const& src077, uint32_t const& src078, uint32_t const& src079, + uint32_t const& src080, uint32_t const& src081, uint32_t const& src082, uint32_t const& src083, + uint32_t const& src084, uint32_t const& src085, uint32_t const& src086, uint32_t const& src087, + uint32_t const& src088, uint32_t const& src089, uint32_t const& src090, uint32_t const& src091, + uint32_t const& src092, uint32_t const& src093, uint32_t const& src094, uint32_t const& src095, + uint32_t const& src096, uint32_t const& src097, uint32_t const& src098, uint32_t const& src099, + uint32_t const& src100, uint32_t const& src101, uint32_t const& src102, uint32_t const& src103, + uint32_t const& src104, uint32_t const& src105, uint32_t const& src106, uint32_t const& src107, + uint32_t const& src108, uint32_t const& src109, uint32_t const& src110, uint32_t const& src111, + uint32_t const& src112, 
uint32_t const& src113, uint32_t const& src114, uint32_t const& src115, + uint32_t const& src116, uint32_t const& src117, uint32_t const& src118, uint32_t const& src119, + uint32_t const& src120, uint32_t const& src121, uint32_t const& src122, uint32_t const& src123, + uint32_t const& src124, uint32_t const& src125, uint32_t const& src126, uint32_t const& src127, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x256b.x32.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16," + "%17, %18, %19, %20," + "%21, %22, %23, %24," + "%25, %26, %27, %28," + "%29, %30, %31, %32," + "%33, %34, %35, %36," + "%37, %38, %39, %40," + "%41, %42, %43, %44," + "%45, %46, %47, %48," + "%49, %50, %51, %52," + "%53, %54, %55, %56," + "%57, %58, %59, %60," + "%61, %62, %63, %64," + "%65, %66, %67, %68," + "%69, %70, %71, %72," + "%73, %74, %75, %76," + "%77, %78, %79, %80," + "%81, %82, %83, %84," + "%85, %86, %87, %88," + "%89, %90, %91, %92," + "%93, %94, %95, %96," + "%97, %98, %99, %100," + "%101, %102, %103, %104," + "%105, %106, %107, %108," + "%109, %110, %111, %112," + "%113, %114, %115, %116," + "%117, %118, %119, %120," + "%121, %122, %123, %124," + "%125, %126, %127, %128};\n" + : + : "r"(dst_addr), "r"(src000), "r"(src001), "r"(src002), "r"(src003), + "r"(src004), "r"(src005), "r"(src006), "r"(src007), + "r"(src008), "r"(src009), "r"(src010), "r"(src011), + "r"(src012), "r"(src013), "r"(src014), "r"(src015), + "r"(src016), "r"(src017), "r"(src018), "r"(src019), + "r"(src020), "r"(src021), "r"(src022), "r"(src023), + "r"(src024), "r"(src025), "r"(src026), "r"(src027), + "r"(src028), "r"(src029), "r"(src030), "r"(src031), + "r"(src032), "r"(src033), "r"(src034), "r"(src035), + "r"(src036), "r"(src037), "r"(src038), "r"(src039), + "r"(src040), "r"(src041), "r"(src042), "r"(src043), + "r"(src044), "r"(src045), "r"(src046), "r"(src047), + "r"(src048), "r"(src049), "r"(src050), "r"(src051), + "r"(src052), "r"(src053), "r"(src054), "r"(src055), + "r"(src056), "r"(src057), "r"(src058), "r"(src059), + "r"(src060), "r"(src061), "r"(src062), "r"(src063), + "r"(src064), "r"(src065), "r"(src066), "r"(src067), + "r"(src068), "r"(src069), "r"(src070), "r"(src071), + "r"(src072), "r"(src073), "r"(src074), "r"(src075), + "r"(src076), "r"(src077), "r"(src078), "r"(src079), + "r"(src080), "r"(src081), "r"(src082), "r"(src083), + "r"(src084), "r"(src085), "r"(src086), "r"(src087), + "r"(src088), "r"(src089), "r"(src090), "r"(src091), + "r"(src092), "r"(src093), "r"(src094), "r"(src095), + "r"(src096), "r"(src097), "r"(src098), "r"(src099), + "r"(src100), "r"(src101), "r"(src102), "r"(src103), + "r"(src104), "r"(src105), "r"(src106), "r"(src107), + "r"(src108), "r"(src109), "r"(src110), "r"(src111), + "r"(src112), "r"(src113), "r"(src114), "r"(src115), + "r"(src116), "r"(src117), "r"(src118), "r"(src119), + "r"(src120), "r"(src121), "r"(src122), "r"(src123), + "r"(src124), "r"(src125), "r"(src126), "r"(src127) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 256-bit pattern, repeated 32 times, expand 16b write +struct SM100_TMEM_STORE_16dp256b32x_16b +{ + using SRegisters = uint32_t[128]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src000, uint32_t const& 
src001, uint32_t const& src002, uint32_t const& src003, + uint32_t const& src004, uint32_t const& src005, uint32_t const& src006, uint32_t const& src007, + uint32_t const& src008, uint32_t const& src009, uint32_t const& src010, uint32_t const& src011, + uint32_t const& src012, uint32_t const& src013, uint32_t const& src014, uint32_t const& src015, + uint32_t const& src016, uint32_t const& src017, uint32_t const& src018, uint32_t const& src019, + uint32_t const& src020, uint32_t const& src021, uint32_t const& src022, uint32_t const& src023, + uint32_t const& src024, uint32_t const& src025, uint32_t const& src026, uint32_t const& src027, + uint32_t const& src028, uint32_t const& src029, uint32_t const& src030, uint32_t const& src031, + uint32_t const& src032, uint32_t const& src033, uint32_t const& src034, uint32_t const& src035, + uint32_t const& src036, uint32_t const& src037, uint32_t const& src038, uint32_t const& src039, + uint32_t const& src040, uint32_t const& src041, uint32_t const& src042, uint32_t const& src043, + uint32_t const& src044, uint32_t const& src045, uint32_t const& src046, uint32_t const& src047, + uint32_t const& src048, uint32_t const& src049, uint32_t const& src050, uint32_t const& src051, + uint32_t const& src052, uint32_t const& src053, uint32_t const& src054, uint32_t const& src055, + uint32_t const& src056, uint32_t const& src057, uint32_t const& src058, uint32_t const& src059, + uint32_t const& src060, uint32_t const& src061, uint32_t const& src062, uint32_t const& src063, + uint32_t const& src064, uint32_t const& src065, uint32_t const& src066, uint32_t const& src067, + uint32_t const& src068, uint32_t const& src069, uint32_t const& src070, uint32_t const& src071, + uint32_t const& src072, uint32_t const& src073, uint32_t const& src074, uint32_t const& src075, + uint32_t const& src076, uint32_t const& src077, uint32_t const& src078, uint32_t const& src079, + uint32_t const& src080, uint32_t const& src081, uint32_t const& src082, uint32_t const& src083, + uint32_t const& src084, uint32_t const& src085, uint32_t const& src086, uint32_t const& src087, + uint32_t const& src088, uint32_t const& src089, uint32_t const& src090, uint32_t const& src091, + uint32_t const& src092, uint32_t const& src093, uint32_t const& src094, uint32_t const& src095, + uint32_t const& src096, uint32_t const& src097, uint32_t const& src098, uint32_t const& src099, + uint32_t const& src100, uint32_t const& src101, uint32_t const& src102, uint32_t const& src103, + uint32_t const& src104, uint32_t const& src105, uint32_t const& src106, uint32_t const& src107, + uint32_t const& src108, uint32_t const& src109, uint32_t const& src110, uint32_t const& src111, + uint32_t const& src112, uint32_t const& src113, uint32_t const& src114, uint32_t const& src115, + uint32_t const& src116, uint32_t const& src117, uint32_t const& src118, uint32_t const& src119, + uint32_t const& src120, uint32_t const& src121, uint32_t const& src122, uint32_t const& src123, + uint32_t const& src124, uint32_t const& src125, uint32_t const& src126, uint32_t const& src127, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x256b.x32.unpack::16b.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16," + "%17, %18, %19, %20," + "%21, %22, %23, %24," + "%25, %26, %27, %28," + "%29, %30, %31, %32," + "%33, %34, %35, %36," + "%37, %38, %39, %40," + "%41, %42, %43, %44," + "%45, %46, %47, %48," + "%49, %50, %51, %52," 
+ "%53, %54, %55, %56," + "%57, %58, %59, %60," + "%61, %62, %63, %64," + "%65, %66, %67, %68," + "%69, %70, %71, %72," + "%73, %74, %75, %76," + "%77, %78, %79, %80," + "%81, %82, %83, %84," + "%85, %86, %87, %88," + "%89, %90, %91, %92," + "%93, %94, %95, %96," + "%97, %98, %99, %100," + "%101, %102, %103, %104," + "%105, %106, %107, %108," + "%109, %110, %111, %112," + "%113, %114, %115, %116," + "%117, %118, %119, %120," + "%121, %122, %123, %124," + "%125, %126, %127, %128};\n" + : + : "r"(dst_addr), "r"(src000), "r"(src001), "r"(src002), "r"(src003), + "r"(src004), "r"(src005), "r"(src006), "r"(src007), + "r"(src008), "r"(src009), "r"(src010), "r"(src011), + "r"(src012), "r"(src013), "r"(src014), "r"(src015), + "r"(src016), "r"(src017), "r"(src018), "r"(src019), + "r"(src020), "r"(src021), "r"(src022), "r"(src023), + "r"(src024), "r"(src025), "r"(src026), "r"(src027), + "r"(src028), "r"(src029), "r"(src030), "r"(src031), + "r"(src032), "r"(src033), "r"(src034), "r"(src035), + "r"(src036), "r"(src037), "r"(src038), "r"(src039), + "r"(src040), "r"(src041), "r"(src042), "r"(src043), + "r"(src044), "r"(src045), "r"(src046), "r"(src047), + "r"(src048), "r"(src049), "r"(src050), "r"(src051), + "r"(src052), "r"(src053), "r"(src054), "r"(src055), + "r"(src056), "r"(src057), "r"(src058), "r"(src059), + "r"(src060), "r"(src061), "r"(src062), "r"(src063), + "r"(src064), "r"(src065), "r"(src066), "r"(src067), + "r"(src068), "r"(src069), "r"(src070), "r"(src071), + "r"(src072), "r"(src073), "r"(src074), "r"(src075), + "r"(src076), "r"(src077), "r"(src078), "r"(src079), + "r"(src080), "r"(src081), "r"(src082), "r"(src083), + "r"(src084), "r"(src085), "r"(src086), "r"(src087), + "r"(src088), "r"(src089), "r"(src090), "r"(src091), + "r"(src092), "r"(src093), "r"(src094), "r"(src095), + "r"(src096), "r"(src097), "r"(src098), "r"(src099), + "r"(src100), "r"(src101), "r"(src102), "r"(src103), + "r"(src104), "r"(src105), "r"(src106), "r"(src107), + "r"(src108), "r"(src109), "r"(src110), "r"(src111), + "r"(src112), "r"(src113), "r"(src114), "r"(src115), + "r"(src116), "r"(src117), "r"(src118), "r"(src119), + "r"(src120), "r"(src121), "r"(src122), "r"(src123), + "r"(src124), "r"(src125), "r"(src126), "r"(src127) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 128-bit pattern, repeated 1 times +struct SM100_TMEM_STORE_16dp128b1x +{ + using SRegisters = uint32_t[2]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x128b.x1.b32" + "[%0]," + "{%1, %2};\n" + : + : "r"(dst_addr), "r"(src0), "r"(src1) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 128-bit pattern, repeated 1 times, expand 16b write +struct SM100_TMEM_STORE_16dp128b1x_16b +{ + using SRegisters = uint32_t[2]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile 
("tcgen05.st.sync.aligned.16x128b.x1.unpack::16b.b32" + "[%0]," + "{%1, %2};\n" + : + : "r"(dst_addr), "r"(src0), "r"(src1) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 128-bit pattern, repeated 2 times +struct SM100_TMEM_STORE_16dp128b2x +{ + using SRegisters = uint32_t[4]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x128b.x2.b32" + "[%0]," + "{%1, %2, %3, %4};\n" + : + : "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 128-bit pattern, repeated 2 times, expand 16b write +struct SM100_TMEM_STORE_16dp128b2x_16b +{ + using SRegisters = uint32_t[4]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x128b.x2.unpack::16b.b32" + "[%0]," + "{%1, %2, %3, %4};\n" + : + : "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 128-bit pattern, repeated 4 times +struct SM100_TMEM_STORE_16dp128b4x +{ + using SRegisters = uint32_t[8]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, + uint32_t const& src4, uint32_t const& src5, uint32_t const& src6, uint32_t const& src7, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x128b.x4.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8};\n" + : + : "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3), + "r"(src4), "r"(src5), "r"(src6), "r"(src7) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 128-bit pattern, repeated 4 times, expand 16b write +struct SM100_TMEM_STORE_16dp128b4x_16b +{ + using SRegisters = uint32_t[8]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, + uint32_t const& src4, uint32_t const& src5, uint32_t const& src6, uint32_t const& src7, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x128b.x4.unpack::16b.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8};\n" + : + : "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3), + "r"(src4), "r"(src5), "r"(src6), "r"(src7) ); +#else + 
CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 128-bit pattern, repeated 8 times +struct SM100_TMEM_STORE_16dp128b8x +{ + using SRegisters = uint32_t[16]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src00, uint32_t const& src01, uint32_t const& src02, uint32_t const& src03, + uint32_t const& src04, uint32_t const& src05, uint32_t const& src06, uint32_t const& src07, + uint32_t const& src08, uint32_t const& src09, uint32_t const& src10, uint32_t const& src11, + uint32_t const& src12, uint32_t const& src13, uint32_t const& src14, uint32_t const& src15, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x128b.x8.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16};\n" + : + : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), + "r"(src04), "r"(src05), "r"(src06), "r"(src07), + "r"(src08), "r"(src09), "r"(src10), "r"(src11), + "r"(src12), "r"(src13), "r"(src14), "r"(src15) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 128-bit pattern, repeated 8 times, expand 16b write +struct SM100_TMEM_STORE_16dp128b8x_16b +{ + using SRegisters = uint32_t[16]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src00, uint32_t const& src01, uint32_t const& src02, uint32_t const& src03, + uint32_t const& src04, uint32_t const& src05, uint32_t const& src06, uint32_t const& src07, + uint32_t const& src08, uint32_t const& src09, uint32_t const& src10, uint32_t const& src11, + uint32_t const& src12, uint32_t const& src13, uint32_t const& src14, uint32_t const& src15, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x128b.x8.unpack::16b.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16};\n" + : + : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), + "r"(src04), "r"(src05), "r"(src06), "r"(src07), + "r"(src08), "r"(src09), "r"(src10), "r"(src11), + "r"(src12), "r"(src13), "r"(src14), "r"(src15) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 128-bit pattern, repeated 16 times +struct SM100_TMEM_STORE_16dp128b16x +{ + using SRegisters = uint32_t[32]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src00, uint32_t const& src01, uint32_t const& src02, uint32_t const& src03, + uint32_t const& src04, uint32_t const& src05, uint32_t const& src06, uint32_t const& src07, + uint32_t const& src08, uint32_t const& src09, uint32_t const& src10, uint32_t const& src11, + uint32_t const& src12, uint32_t const& src13, uint32_t const& src14, uint32_t const& src15, + uint32_t const& src16, uint32_t const& src17, uint32_t const& src18, uint32_t const& src19, + uint32_t const& src20, uint32_t const& src21, uint32_t const& src22, uint32_t const& src23, + 
uint32_t const& src24, uint32_t const& src25, uint32_t const& src26, uint32_t const& src27, + uint32_t const& src28, uint32_t const& src29, uint32_t const& src30, uint32_t const& src31, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x128b.x16.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16," + "%17, %18, %19, %20," + "%21, %22, %23, %24," + "%25, %26, %27, %28," + "%29, %30, %31, %32};\n" + : + : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), + "r"(src04), "r"(src05), "r"(src06), "r"(src07), + "r"(src08), "r"(src09), "r"(src10), "r"(src11), + "r"(src12), "r"(src13), "r"(src14), "r"(src15), + "r"(src16), "r"(src17), "r"(src18), "r"(src19), + "r"(src20), "r"(src21), "r"(src22), "r"(src23), + "r"(src24), "r"(src25), "r"(src26), "r"(src27), + "r"(src28), "r"(src29), "r"(src30), "r"(src31) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 128-bit pattern, repeated 16 times, expand 16b write +struct SM100_TMEM_STORE_16dp128b16x_16b +{ + using SRegisters = uint32_t[32]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src00, uint32_t const& src01, uint32_t const& src02, uint32_t const& src03, + uint32_t const& src04, uint32_t const& src05, uint32_t const& src06, uint32_t const& src07, + uint32_t const& src08, uint32_t const& src09, uint32_t const& src10, uint32_t const& src11, + uint32_t const& src12, uint32_t const& src13, uint32_t const& src14, uint32_t const& src15, + uint32_t const& src16, uint32_t const& src17, uint32_t const& src18, uint32_t const& src19, + uint32_t const& src20, uint32_t const& src21, uint32_t const& src22, uint32_t const& src23, + uint32_t const& src24, uint32_t const& src25, uint32_t const& src26, uint32_t const& src27, + uint32_t const& src28, uint32_t const& src29, uint32_t const& src30, uint32_t const& src31, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x128b.x16.unpack::16b.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16," + "%17, %18, %19, %20," + "%21, %22, %23, %24," + "%25, %26, %27, %28," + "%29, %30, %31, %32};\n" + : + : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), + "r"(src04), "r"(src05), "r"(src06), "r"(src07), + "r"(src08), "r"(src09), "r"(src10), "r"(src11), + "r"(src12), "r"(src13), "r"(src14), "r"(src15), + "r"(src16), "r"(src17), "r"(src18), "r"(src19), + "r"(src20), "r"(src21), "r"(src22), "r"(src23), + "r"(src24), "r"(src25), "r"(src26), "r"(src27), + "r"(src28), "r"(src29), "r"(src30), "r"(src31) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 128-bit pattern, repeated 32 times +struct SM100_TMEM_STORE_16dp128b32x +{ + using SRegisters = uint32_t[64]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src00, uint32_t const& src01, uint32_t const& src02, uint32_t const& src03, + uint32_t const& src04, uint32_t const& src05, uint32_t const& src06, uint32_t const& src07, + uint32_t const& 
src08, uint32_t const& src09, uint32_t const& src10, uint32_t const& src11, + uint32_t const& src12, uint32_t const& src13, uint32_t const& src14, uint32_t const& src15, + uint32_t const& src16, uint32_t const& src17, uint32_t const& src18, uint32_t const& src19, + uint32_t const& src20, uint32_t const& src21, uint32_t const& src22, uint32_t const& src23, + uint32_t const& src24, uint32_t const& src25, uint32_t const& src26, uint32_t const& src27, + uint32_t const& src28, uint32_t const& src29, uint32_t const& src30, uint32_t const& src31, + uint32_t const& src32, uint32_t const& src33, uint32_t const& src34, uint32_t const& src35, + uint32_t const& src36, uint32_t const& src37, uint32_t const& src38, uint32_t const& src39, + uint32_t const& src40, uint32_t const& src41, uint32_t const& src42, uint32_t const& src43, + uint32_t const& src44, uint32_t const& src45, uint32_t const& src46, uint32_t const& src47, + uint32_t const& src48, uint32_t const& src49, uint32_t const& src50, uint32_t const& src51, + uint32_t const& src52, uint32_t const& src53, uint32_t const& src54, uint32_t const& src55, + uint32_t const& src56, uint32_t const& src57, uint32_t const& src58, uint32_t const& src59, + uint32_t const& src60, uint32_t const& src61, uint32_t const& src62, uint32_t const& src63, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x128b.x32.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16," + "%17, %18, %19, %20," + "%21, %22, %23, %24," + "%25, %26, %27, %28," + "%29, %30, %31, %32," + "%33, %34, %35, %36," + "%37, %38, %39, %40," + "%41, %42, %43, %44," + "%45, %46, %47, %48," + "%49, %50, %51, %52," + "%53, %54, %55, %56," + "%57, %58, %59, %60," + "%61, %62, %63, %64};\n" + : + : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), + "r"(src04), "r"(src05), "r"(src06), "r"(src07), + "r"(src08), "r"(src09), "r"(src10), "r"(src11), + "r"(src12), "r"(src13), "r"(src14), "r"(src15), + "r"(src16), "r"(src17), "r"(src18), "r"(src19), + "r"(src20), "r"(src21), "r"(src22), "r"(src23), + "r"(src24), "r"(src25), "r"(src26), "r"(src27), + "r"(src28), "r"(src29), "r"(src30), "r"(src31), + "r"(src32), "r"(src33), "r"(src34), "r"(src35), + "r"(src36), "r"(src37), "r"(src38), "r"(src39), + "r"(src40), "r"(src41), "r"(src42), "r"(src43), + "r"(src44), "r"(src45), "r"(src46), "r"(src47), + "r"(src48), "r"(src49), "r"(src50), "r"(src51), + "r"(src52), "r"(src53), "r"(src54), "r"(src55), + "r"(src56), "r"(src57), "r"(src58), "r"(src59), + "r"(src60), "r"(src61), "r"(src62), "r"(src63) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 128-bit pattern, repeated 32 times, expand 16b write +struct SM100_TMEM_STORE_16dp128b32x_16b +{ + using SRegisters = uint32_t[64]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src00, uint32_t const& src01, uint32_t const& src02, uint32_t const& src03, + uint32_t const& src04, uint32_t const& src05, uint32_t const& src06, uint32_t const& src07, + uint32_t const& src08, uint32_t const& src09, uint32_t const& src10, uint32_t const& src11, + uint32_t const& src12, uint32_t const& src13, uint32_t const& src14, uint32_t const& src15, + uint32_t const& src16, uint32_t const& src17, uint32_t const& src18, uint32_t 
const& src19, + uint32_t const& src20, uint32_t const& src21, uint32_t const& src22, uint32_t const& src23, + uint32_t const& src24, uint32_t const& src25, uint32_t const& src26, uint32_t const& src27, + uint32_t const& src28, uint32_t const& src29, uint32_t const& src30, uint32_t const& src31, + uint32_t const& src32, uint32_t const& src33, uint32_t const& src34, uint32_t const& src35, + uint32_t const& src36, uint32_t const& src37, uint32_t const& src38, uint32_t const& src39, + uint32_t const& src40, uint32_t const& src41, uint32_t const& src42, uint32_t const& src43, + uint32_t const& src44, uint32_t const& src45, uint32_t const& src46, uint32_t const& src47, + uint32_t const& src48, uint32_t const& src49, uint32_t const& src50, uint32_t const& src51, + uint32_t const& src52, uint32_t const& src53, uint32_t const& src54, uint32_t const& src55, + uint32_t const& src56, uint32_t const& src57, uint32_t const& src58, uint32_t const& src59, + uint32_t const& src60, uint32_t const& src61, uint32_t const& src62, uint32_t const& src63, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x128b.x32.unpack::16b.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16," + "%17, %18, %19, %20," + "%21, %22, %23, %24," + "%25, %26, %27, %28," + "%29, %30, %31, %32," + "%33, %34, %35, %36," + "%37, %38, %39, %40," + "%41, %42, %43, %44," + "%45, %46, %47, %48," + "%49, %50, %51, %52," + "%53, %54, %55, %56," + "%57, %58, %59, %60," + "%61, %62, %63, %64};\n" + : + : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), + "r"(src04), "r"(src05), "r"(src06), "r"(src07), + "r"(src08), "r"(src09), "r"(src10), "r"(src11), + "r"(src12), "r"(src13), "r"(src14), "r"(src15), + "r"(src16), "r"(src17), "r"(src18), "r"(src19), + "r"(src20), "r"(src21), "r"(src22), "r"(src23), + "r"(src24), "r"(src25), "r"(src26), "r"(src27), + "r"(src28), "r"(src29), "r"(src30), "r"(src31), + "r"(src32), "r"(src33), "r"(src34), "r"(src35), + "r"(src36), "r"(src37), "r"(src38), "r"(src39), + "r"(src40), "r"(src41), "r"(src42), "r"(src43), + "r"(src44), "r"(src45), "r"(src46), "r"(src47), + "r"(src48), "r"(src49), "r"(src50), "r"(src51), + "r"(src52), "r"(src53), "r"(src54), "r"(src55), + "r"(src56), "r"(src57), "r"(src58), "r"(src59), + "r"(src60), "r"(src61), "r"(src62), "r"(src63) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 128-bit pattern, repeated 64 times +struct SM100_TMEM_STORE_16dp128b64x +{ + using SRegisters = uint32_t[128]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src000, uint32_t const& src001, uint32_t const& src002, uint32_t const& src003, + uint32_t const& src004, uint32_t const& src005, uint32_t const& src006, uint32_t const& src007, + uint32_t const& src008, uint32_t const& src009, uint32_t const& src010, uint32_t const& src011, + uint32_t const& src012, uint32_t const& src013, uint32_t const& src014, uint32_t const& src015, + uint32_t const& src016, uint32_t const& src017, uint32_t const& src018, uint32_t const& src019, + uint32_t const& src020, uint32_t const& src021, uint32_t const& src022, uint32_t const& src023, + uint32_t const& src024, uint32_t const& src025, uint32_t const& src026, uint32_t const& src027, + uint32_t const& src028, 
uint32_t const& src029, uint32_t const& src030, uint32_t const& src031, + uint32_t const& src032, uint32_t const& src033, uint32_t const& src034, uint32_t const& src035, + uint32_t const& src036, uint32_t const& src037, uint32_t const& src038, uint32_t const& src039, + uint32_t const& src040, uint32_t const& src041, uint32_t const& src042, uint32_t const& src043, + uint32_t const& src044, uint32_t const& src045, uint32_t const& src046, uint32_t const& src047, + uint32_t const& src048, uint32_t const& src049, uint32_t const& src050, uint32_t const& src051, + uint32_t const& src052, uint32_t const& src053, uint32_t const& src054, uint32_t const& src055, + uint32_t const& src056, uint32_t const& src057, uint32_t const& src058, uint32_t const& src059, + uint32_t const& src060, uint32_t const& src061, uint32_t const& src062, uint32_t const& src063, + uint32_t const& src064, uint32_t const& src065, uint32_t const& src066, uint32_t const& src067, + uint32_t const& src068, uint32_t const& src069, uint32_t const& src070, uint32_t const& src071, + uint32_t const& src072, uint32_t const& src073, uint32_t const& src074, uint32_t const& src075, + uint32_t const& src076, uint32_t const& src077, uint32_t const& src078, uint32_t const& src079, + uint32_t const& src080, uint32_t const& src081, uint32_t const& src082, uint32_t const& src083, + uint32_t const& src084, uint32_t const& src085, uint32_t const& src086, uint32_t const& src087, + uint32_t const& src088, uint32_t const& src089, uint32_t const& src090, uint32_t const& src091, + uint32_t const& src092, uint32_t const& src093, uint32_t const& src094, uint32_t const& src095, + uint32_t const& src096, uint32_t const& src097, uint32_t const& src098, uint32_t const& src099, + uint32_t const& src100, uint32_t const& src101, uint32_t const& src102, uint32_t const& src103, + uint32_t const& src104, uint32_t const& src105, uint32_t const& src106, uint32_t const& src107, + uint32_t const& src108, uint32_t const& src109, uint32_t const& src110, uint32_t const& src111, + uint32_t const& src112, uint32_t const& src113, uint32_t const& src114, uint32_t const& src115, + uint32_t const& src116, uint32_t const& src117, uint32_t const& src118, uint32_t const& src119, + uint32_t const& src120, uint32_t const& src121, uint32_t const& src122, uint32_t const& src123, + uint32_t const& src124, uint32_t const& src125, uint32_t const& src126, uint32_t const& src127, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x128b.x64.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16," + "%17, %18, %19, %20," + "%21, %22, %23, %24," + "%25, %26, %27, %28," + "%29, %30, %31, %32," + "%33, %34, %35, %36," + "%37, %38, %39, %40," + "%41, %42, %43, %44," + "%45, %46, %47, %48," + "%49, %50, %51, %52," + "%53, %54, %55, %56," + "%57, %58, %59, %60," + "%61, %62, %63, %64," + "%65, %66, %67, %68," + "%69, %70, %71, %72," + "%73, %74, %75, %76," + "%77, %78, %79, %80," + "%81, %82, %83, %84," + "%85, %86, %87, %88," + "%89, %90, %91, %92," + "%93, %94, %95, %96," + "%97, %98, %99, %100," + "%101, %102, %103, %104," + "%105, %106, %107, %108," + "%109, %110, %111, %112," + "%113, %114, %115, %116," + "%117, %118, %119, %120," + "%121, %122, %123, %124," + "%125, %126, %127, %128};\n" + : + : "r"(dst_addr), "r"(src000), "r"(src001), "r"(src002), "r"(src003), + "r"(src004), "r"(src005), "r"(src006), "r"(src007), + "r"(src008), "r"(src009), "r"(src010), "r"(src011), + 
"r"(src012), "r"(src013), "r"(src014), "r"(src015), + "r"(src016), "r"(src017), "r"(src018), "r"(src019), + "r"(src020), "r"(src021), "r"(src022), "r"(src023), + "r"(src024), "r"(src025), "r"(src026), "r"(src027), + "r"(src028), "r"(src029), "r"(src030), "r"(src031), + "r"(src032), "r"(src033), "r"(src034), "r"(src035), + "r"(src036), "r"(src037), "r"(src038), "r"(src039), + "r"(src040), "r"(src041), "r"(src042), "r"(src043), + "r"(src044), "r"(src045), "r"(src046), "r"(src047), + "r"(src048), "r"(src049), "r"(src050), "r"(src051), + "r"(src052), "r"(src053), "r"(src054), "r"(src055), + "r"(src056), "r"(src057), "r"(src058), "r"(src059), + "r"(src060), "r"(src061), "r"(src062), "r"(src063), + "r"(src064), "r"(src065), "r"(src066), "r"(src067), + "r"(src068), "r"(src069), "r"(src070), "r"(src071), + "r"(src072), "r"(src073), "r"(src074), "r"(src075), + "r"(src076), "r"(src077), "r"(src078), "r"(src079), + "r"(src080), "r"(src081), "r"(src082), "r"(src083), + "r"(src084), "r"(src085), "r"(src086), "r"(src087), + "r"(src088), "r"(src089), "r"(src090), "r"(src091), + "r"(src092), "r"(src093), "r"(src094), "r"(src095), + "r"(src096), "r"(src097), "r"(src098), "r"(src099), + "r"(src100), "r"(src101), "r"(src102), "r"(src103), + "r"(src104), "r"(src105), "r"(src106), "r"(src107), + "r"(src108), "r"(src109), "r"(src110), "r"(src111), + "r"(src112), "r"(src113), "r"(src114), "r"(src115), + "r"(src116), "r"(src117), "r"(src118), "r"(src119), + "r"(src120), "r"(src121), "r"(src122), "r"(src123), + "r"(src124), "r"(src125), "r"(src126), "r"(src127) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 128-bit pattern, repeated 64 times, expand 16b write +struct SM100_TMEM_STORE_16dp128b64x_16b +{ + using SRegisters = uint32_t[128]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src000, uint32_t const& src001, uint32_t const& src002, uint32_t const& src003, + uint32_t const& src004, uint32_t const& src005, uint32_t const& src006, uint32_t const& src007, + uint32_t const& src008, uint32_t const& src009, uint32_t const& src010, uint32_t const& src011, + uint32_t const& src012, uint32_t const& src013, uint32_t const& src014, uint32_t const& src015, + uint32_t const& src016, uint32_t const& src017, uint32_t const& src018, uint32_t const& src019, + uint32_t const& src020, uint32_t const& src021, uint32_t const& src022, uint32_t const& src023, + uint32_t const& src024, uint32_t const& src025, uint32_t const& src026, uint32_t const& src027, + uint32_t const& src028, uint32_t const& src029, uint32_t const& src030, uint32_t const& src031, + uint32_t const& src032, uint32_t const& src033, uint32_t const& src034, uint32_t const& src035, + uint32_t const& src036, uint32_t const& src037, uint32_t const& src038, uint32_t const& src039, + uint32_t const& src040, uint32_t const& src041, uint32_t const& src042, uint32_t const& src043, + uint32_t const& src044, uint32_t const& src045, uint32_t const& src046, uint32_t const& src047, + uint32_t const& src048, uint32_t const& src049, uint32_t const& src050, uint32_t const& src051, + uint32_t const& src052, uint32_t const& src053, uint32_t const& src054, uint32_t const& src055, + uint32_t const& src056, uint32_t const& src057, uint32_t const& src058, uint32_t const& src059, + uint32_t const& src060, uint32_t const& src061, uint32_t 
const& src062, uint32_t const& src063, + uint32_t const& src064, uint32_t const& src065, uint32_t const& src066, uint32_t const& src067, + uint32_t const& src068, uint32_t const& src069, uint32_t const& src070, uint32_t const& src071, + uint32_t const& src072, uint32_t const& src073, uint32_t const& src074, uint32_t const& src075, + uint32_t const& src076, uint32_t const& src077, uint32_t const& src078, uint32_t const& src079, + uint32_t const& src080, uint32_t const& src081, uint32_t const& src082, uint32_t const& src083, + uint32_t const& src084, uint32_t const& src085, uint32_t const& src086, uint32_t const& src087, + uint32_t const& src088, uint32_t const& src089, uint32_t const& src090, uint32_t const& src091, + uint32_t const& src092, uint32_t const& src093, uint32_t const& src094, uint32_t const& src095, + uint32_t const& src096, uint32_t const& src097, uint32_t const& src098, uint32_t const& src099, + uint32_t const& src100, uint32_t const& src101, uint32_t const& src102, uint32_t const& src103, + uint32_t const& src104, uint32_t const& src105, uint32_t const& src106, uint32_t const& src107, + uint32_t const& src108, uint32_t const& src109, uint32_t const& src110, uint32_t const& src111, + uint32_t const& src112, uint32_t const& src113, uint32_t const& src114, uint32_t const& src115, + uint32_t const& src116, uint32_t const& src117, uint32_t const& src118, uint32_t const& src119, + uint32_t const& src120, uint32_t const& src121, uint32_t const& src122, uint32_t const& src123, + uint32_t const& src124, uint32_t const& src125, uint32_t const& src126, uint32_t const& src127, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x128b.x64.unpack::16b.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16," + "%17, %18, %19, %20," + "%21, %22, %23, %24," + "%25, %26, %27, %28," + "%29, %30, %31, %32," + "%33, %34, %35, %36," + "%37, %38, %39, %40," + "%41, %42, %43, %44," + "%45, %46, %47, %48," + "%49, %50, %51, %52," + "%53, %54, %55, %56," + "%57, %58, %59, %60," + "%61, %62, %63, %64," + "%65, %66, %67, %68," + "%69, %70, %71, %72," + "%73, %74, %75, %76," + "%77, %78, %79, %80," + "%81, %82, %83, %84," + "%85, %86, %87, %88," + "%89, %90, %91, %92," + "%93, %94, %95, %96," + "%97, %98, %99, %100," + "%101, %102, %103, %104," + "%105, %106, %107, %108," + "%109, %110, %111, %112," + "%113, %114, %115, %116," + "%117, %118, %119, %120," + "%121, %122, %123, %124," + "%125, %126, %127, %128};\n" + : + : "r"(dst_addr), "r"(src000), "r"(src001), "r"(src002), "r"(src003), + "r"(src004), "r"(src005), "r"(src006), "r"(src007), + "r"(src008), "r"(src009), "r"(src010), "r"(src011), + "r"(src012), "r"(src013), "r"(src014), "r"(src015), + "r"(src016), "r"(src017), "r"(src018), "r"(src019), + "r"(src020), "r"(src021), "r"(src022), "r"(src023), + "r"(src024), "r"(src025), "r"(src026), "r"(src027), + "r"(src028), "r"(src029), "r"(src030), "r"(src031), + "r"(src032), "r"(src033), "r"(src034), "r"(src035), + "r"(src036), "r"(src037), "r"(src038), "r"(src039), + "r"(src040), "r"(src041), "r"(src042), "r"(src043), + "r"(src044), "r"(src045), "r"(src046), "r"(src047), + "r"(src048), "r"(src049), "r"(src050), "r"(src051), + "r"(src052), "r"(src053), "r"(src054), "r"(src055), + "r"(src056), "r"(src057), "r"(src058), "r"(src059), + "r"(src060), "r"(src061), "r"(src062), "r"(src063), + "r"(src064), "r"(src065), "r"(src066), "r"(src067), + "r"(src068), "r"(src069), "r"(src070), "r"(src071), + 
"r"(src072), "r"(src073), "r"(src074), "r"(src075), + "r"(src076), "r"(src077), "r"(src078), "r"(src079), + "r"(src080), "r"(src081), "r"(src082), "r"(src083), + "r"(src084), "r"(src085), "r"(src086), "r"(src087), + "r"(src088), "r"(src089), "r"(src090), "r"(src091), + "r"(src092), "r"(src093), "r"(src094), "r"(src095), + "r"(src096), "r"(src097), "r"(src098), "r"(src099), + "r"(src100), "r"(src101), "r"(src102), "r"(src103), + "r"(src104), "r"(src105), "r"(src106), "r"(src107), + "r"(src108), "r"(src109), "r"(src110), "r"(src111), + "r"(src112), "r"(src113), "r"(src114), "r"(src115), + "r"(src116), "r"(src117), "r"(src118), "r"(src119), + "r"(src120), "r"(src121), "r"(src122), "r"(src123), + "r"(src124), "r"(src125), "r"(src126), "r"(src127) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 64-bit pattern, repeated 1 times +struct SM100_TMEM_STORE_16dp64b1x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x64b.x1.b32" + "[%0]," + "{%1};\n" + : + : "r"(dst_addr), "r"(src0) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 64-bit pattern, repeated 1 times, expand 16b write +struct SM100_TMEM_STORE_16dp64b1x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x64b.x1.unpack::16b.b32" + "[%0]," + "{%1};\n" + : + : "r"(dst_addr), "r"(src0) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 64-bit pattern, repeated 2 times +struct SM100_TMEM_STORE_16dp64b2x +{ + using SRegisters = uint32_t[2]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x64b.x2.b32" + "[%0]," + "{%1, %2};\n" + : + : "r"(dst_addr), "r"(src0), "r"(src1) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 64-bit pattern, repeated 2 times, expand 16b write +struct SM100_TMEM_STORE_16dp64b2x_16b +{ + using SRegisters = uint32_t[2]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x64b.x2.unpack::16b.b32" + "[%0]," + "{%1, %2};\n" + : + : "r"(dst_addr), "r"(src0), "r"(src1) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without 
CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 64-bit pattern, repeated 4 times +struct SM100_TMEM_STORE_16dp64b4x +{ + using SRegisters = uint32_t[4]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x64b.x4.b32" + "[%0]," + "{%1, %2, %3, %4};\n" + : + : "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 64-bit pattern, repeated 4 times, expand 16b write +struct SM100_TMEM_STORE_16dp64b4x_16b +{ + using SRegisters = uint32_t[4]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x64b.x4.unpack::16b.b32" + "[%0]," + "{%1, %2, %3, %4};\n" + : + : "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 64-bit pattern, repeated 8 times +struct SM100_TMEM_STORE_16dp64b8x +{ + using SRegisters = uint32_t[8]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, + uint32_t const& src4, uint32_t const& src5, uint32_t const& src6, uint32_t const& src7, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x64b.x8.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8};\n" + : + : "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3), + "r"(src4), "r"(src5), "r"(src6), "r"(src7) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 64-bit pattern, repeated 8 times, expand 16b write +struct SM100_TMEM_STORE_16dp64b8x_16b +{ + using SRegisters = uint32_t[8]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, + uint32_t const& src4, uint32_t const& src5, uint32_t const& src6, uint32_t const& src7, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x64b.x8.unpack::16b.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8};\n" + : + : "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3), + "r"(src4), "r"(src5), "r"(src6), "r"(src7) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 
16 data path lanes, 64-bit pattern, repeated 16 times +struct SM100_TMEM_STORE_16dp64b16x +{ + using SRegisters = uint32_t[16]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src00, uint32_t const& src01, uint32_t const& src02, uint32_t const& src03, + uint32_t const& src04, uint32_t const& src05, uint32_t const& src06, uint32_t const& src07, + uint32_t const& src08, uint32_t const& src09, uint32_t const& src10, uint32_t const& src11, + uint32_t const& src12, uint32_t const& src13, uint32_t const& src14, uint32_t const& src15, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x64b.x16.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16};\n" + : + : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), + "r"(src04), "r"(src05), "r"(src06), "r"(src07), + "r"(src08), "r"(src09), "r"(src10), "r"(src11), + "r"(src12), "r"(src13), "r"(src14), "r"(src15) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 64-bit pattern, repeated 16 times, expand 16b write +struct SM100_TMEM_STORE_16dp64b16x_16b +{ + using SRegisters = uint32_t[16]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src00, uint32_t const& src01, uint32_t const& src02, uint32_t const& src03, + uint32_t const& src04, uint32_t const& src05, uint32_t const& src06, uint32_t const& src07, + uint32_t const& src08, uint32_t const& src09, uint32_t const& src10, uint32_t const& src11, + uint32_t const& src12, uint32_t const& src13, uint32_t const& src14, uint32_t const& src15, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x64b.x16.unpack::16b.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16};\n" + : + : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), + "r"(src04), "r"(src05), "r"(src06), "r"(src07), + "r"(src08), "r"(src09), "r"(src10), "r"(src11), + "r"(src12), "r"(src13), "r"(src14), "r"(src15) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 64-bit pattern, repeated 32 times +struct SM100_TMEM_STORE_16dp64b32x +{ + using SRegisters = uint32_t[32]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src00, uint32_t const& src01, uint32_t const& src02, uint32_t const& src03, + uint32_t const& src04, uint32_t const& src05, uint32_t const& src06, uint32_t const& src07, + uint32_t const& src08, uint32_t const& src09, uint32_t const& src10, uint32_t const& src11, + uint32_t const& src12, uint32_t const& src13, uint32_t const& src14, uint32_t const& src15, + uint32_t const& src16, uint32_t const& src17, uint32_t const& src18, uint32_t const& src19, + uint32_t const& src20, uint32_t const& src21, uint32_t const& src22, uint32_t const& src23, + uint32_t const& src24, uint32_t const& src25, uint32_t const& src26, uint32_t const& src27, + uint32_t const& src28, uint32_t const& src29, uint32_t const& src30, uint32_t const& src31, + uint32_t const& dst_addr) + { +#if 
defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x64b.x32.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16," + "%17, %18, %19, %20," + "%21, %22, %23, %24," + "%25, %26, %27, %28," + "%29, %30, %31, %32};\n" + : + : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), + "r"(src04), "r"(src05), "r"(src06), "r"(src07), + "r"(src08), "r"(src09), "r"(src10), "r"(src11), + "r"(src12), "r"(src13), "r"(src14), "r"(src15), + "r"(src16), "r"(src17), "r"(src18), "r"(src19), + "r"(src20), "r"(src21), "r"(src22), "r"(src23), + "r"(src24), "r"(src25), "r"(src26), "r"(src27), + "r"(src28), "r"(src29), "r"(src30), "r"(src31) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 64-bit pattern, repeated 32 times, expand 16b write +struct SM100_TMEM_STORE_16dp64b32x_16b +{ + using SRegisters = uint32_t[32]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src00, uint32_t const& src01, uint32_t const& src02, uint32_t const& src03, + uint32_t const& src04, uint32_t const& src05, uint32_t const& src06, uint32_t const& src07, + uint32_t const& src08, uint32_t const& src09, uint32_t const& src10, uint32_t const& src11, + uint32_t const& src12, uint32_t const& src13, uint32_t const& src14, uint32_t const& src15, + uint32_t const& src16, uint32_t const& src17, uint32_t const& src18, uint32_t const& src19, + uint32_t const& src20, uint32_t const& src21, uint32_t const& src22, uint32_t const& src23, + uint32_t const& src24, uint32_t const& src25, uint32_t const& src26, uint32_t const& src27, + uint32_t const& src28, uint32_t const& src29, uint32_t const& src30, uint32_t const& src31, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x64b.x32.unpack::16b.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16," + "%17, %18, %19, %20," + "%21, %22, %23, %24," + "%25, %26, %27, %28," + "%29, %30, %31, %32};\n" + : + : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), + "r"(src04), "r"(src05), "r"(src06), "r"(src07), + "r"(src08), "r"(src09), "r"(src10), "r"(src11), + "r"(src12), "r"(src13), "r"(src14), "r"(src15), + "r"(src16), "r"(src17), "r"(src18), "r"(src19), + "r"(src20), "r"(src21), "r"(src22), "r"(src23), + "r"(src24), "r"(src25), "r"(src26), "r"(src27), + "r"(src28), "r"(src29), "r"(src30), "r"(src31) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 64-bit pattern, repeated 64 times +struct SM100_TMEM_STORE_16dp64b64x +{ + using SRegisters = uint32_t[64]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src00, uint32_t const& src01, uint32_t const& src02, uint32_t const& src03, + uint32_t const& src04, uint32_t const& src05, uint32_t const& src06, uint32_t const& src07, + uint32_t const& src08, uint32_t const& src09, uint32_t const& src10, uint32_t const& src11, + uint32_t const& src12, uint32_t const& src13, uint32_t const& src14, uint32_t const& src15, + uint32_t const& src16, uint32_t const& src17, uint32_t 
const& src18, uint32_t const& src19, + uint32_t const& src20, uint32_t const& src21, uint32_t const& src22, uint32_t const& src23, + uint32_t const& src24, uint32_t const& src25, uint32_t const& src26, uint32_t const& src27, + uint32_t const& src28, uint32_t const& src29, uint32_t const& src30, uint32_t const& src31, + uint32_t const& src32, uint32_t const& src33, uint32_t const& src34, uint32_t const& src35, + uint32_t const& src36, uint32_t const& src37, uint32_t const& src38, uint32_t const& src39, + uint32_t const& src40, uint32_t const& src41, uint32_t const& src42, uint32_t const& src43, + uint32_t const& src44, uint32_t const& src45, uint32_t const& src46, uint32_t const& src47, + uint32_t const& src48, uint32_t const& src49, uint32_t const& src50, uint32_t const& src51, + uint32_t const& src52, uint32_t const& src53, uint32_t const& src54, uint32_t const& src55, + uint32_t const& src56, uint32_t const& src57, uint32_t const& src58, uint32_t const& src59, + uint32_t const& src60, uint32_t const& src61, uint32_t const& src62, uint32_t const& src63, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x64b.x64.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16," + "%17, %18, %19, %20," + "%21, %22, %23, %24," + "%25, %26, %27, %28," + "%29, %30, %31, %32," + "%33, %34, %35, %36," + "%37, %38, %39, %40," + "%41, %42, %43, %44," + "%45, %46, %47, %48," + "%49, %50, %51, %52," + "%53, %54, %55, %56," + "%57, %58, %59, %60," + "%61, %62, %63, %64};\n" + : + : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), + "r"(src04), "r"(src05), "r"(src06), "r"(src07), + "r"(src08), "r"(src09), "r"(src10), "r"(src11), + "r"(src12), "r"(src13), "r"(src14), "r"(src15), + "r"(src16), "r"(src17), "r"(src18), "r"(src19), + "r"(src20), "r"(src21), "r"(src22), "r"(src23), + "r"(src24), "r"(src25), "r"(src26), "r"(src27), + "r"(src28), "r"(src29), "r"(src30), "r"(src31), + "r"(src32), "r"(src33), "r"(src34), "r"(src35), + "r"(src36), "r"(src37), "r"(src38), "r"(src39), + "r"(src40), "r"(src41), "r"(src42), "r"(src43), + "r"(src44), "r"(src45), "r"(src46), "r"(src47), + "r"(src48), "r"(src49), "r"(src50), "r"(src51), + "r"(src52), "r"(src53), "r"(src54), "r"(src55), + "r"(src56), "r"(src57), "r"(src58), "r"(src59), + "r"(src60), "r"(src61), "r"(src62), "r"(src63) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 64-bit pattern, repeated 64 times, expand 16b write +struct SM100_TMEM_STORE_16dp64b64x_16b +{ + using SRegisters = uint32_t[64]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src00, uint32_t const& src01, uint32_t const& src02, uint32_t const& src03, + uint32_t const& src04, uint32_t const& src05, uint32_t const& src06, uint32_t const& src07, + uint32_t const& src08, uint32_t const& src09, uint32_t const& src10, uint32_t const& src11, + uint32_t const& src12, uint32_t const& src13, uint32_t const& src14, uint32_t const& src15, + uint32_t const& src16, uint32_t const& src17, uint32_t const& src18, uint32_t const& src19, + uint32_t const& src20, uint32_t const& src21, uint32_t const& src22, uint32_t const& src23, + uint32_t const& src24, uint32_t const& src25, uint32_t const& src26, uint32_t const& src27, + uint32_t const& src28, 
uint32_t const& src29, uint32_t const& src30, uint32_t const& src31, + uint32_t const& src32, uint32_t const& src33, uint32_t const& src34, uint32_t const& src35, + uint32_t const& src36, uint32_t const& src37, uint32_t const& src38, uint32_t const& src39, + uint32_t const& src40, uint32_t const& src41, uint32_t const& src42, uint32_t const& src43, + uint32_t const& src44, uint32_t const& src45, uint32_t const& src46, uint32_t const& src47, + uint32_t const& src48, uint32_t const& src49, uint32_t const& src50, uint32_t const& src51, + uint32_t const& src52, uint32_t const& src53, uint32_t const& src54, uint32_t const& src55, + uint32_t const& src56, uint32_t const& src57, uint32_t const& src58, uint32_t const& src59, + uint32_t const& src60, uint32_t const& src61, uint32_t const& src62, uint32_t const& src63, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x64b.x64.unpack::16b.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16," + "%17, %18, %19, %20," + "%21, %22, %23, %24," + "%25, %26, %27, %28," + "%29, %30, %31, %32," + "%33, %34, %35, %36," + "%37, %38, %39, %40," + "%41, %42, %43, %44," + "%45, %46, %47, %48," + "%49, %50, %51, %52," + "%53, %54, %55, %56," + "%57, %58, %59, %60," + "%61, %62, %63, %64};\n" + : + : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), + "r"(src04), "r"(src05), "r"(src06), "r"(src07), + "r"(src08), "r"(src09), "r"(src10), "r"(src11), + "r"(src12), "r"(src13), "r"(src14), "r"(src15), + "r"(src16), "r"(src17), "r"(src18), "r"(src19), + "r"(src20), "r"(src21), "r"(src22), "r"(src23), + "r"(src24), "r"(src25), "r"(src26), "r"(src27), + "r"(src28), "r"(src29), "r"(src30), "r"(src31), + "r"(src32), "r"(src33), "r"(src34), "r"(src35), + "r"(src36), "r"(src37), "r"(src38), "r"(src39), + "r"(src40), "r"(src41), "r"(src42), "r"(src43), + "r"(src44), "r"(src45), "r"(src46), "r"(src47), + "r"(src48), "r"(src49), "r"(src50), "r"(src51), + "r"(src52), "r"(src53), "r"(src54), "r"(src55), + "r"(src56), "r"(src57), "r"(src58), "r"(src59), + "r"(src60), "r"(src61), "r"(src62), "r"(src63) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 64-bit pattern, repeated 128 times +struct SM100_TMEM_STORE_16dp64b128x +{ + using SRegisters = uint32_t[128]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src000, uint32_t const& src001, uint32_t const& src002, uint32_t const& src003, + uint32_t const& src004, uint32_t const& src005, uint32_t const& src006, uint32_t const& src007, + uint32_t const& src008, uint32_t const& src009, uint32_t const& src010, uint32_t const& src011, + uint32_t const& src012, uint32_t const& src013, uint32_t const& src014, uint32_t const& src015, + uint32_t const& src016, uint32_t const& src017, uint32_t const& src018, uint32_t const& src019, + uint32_t const& src020, uint32_t const& src021, uint32_t const& src022, uint32_t const& src023, + uint32_t const& src024, uint32_t const& src025, uint32_t const& src026, uint32_t const& src027, + uint32_t const& src028, uint32_t const& src029, uint32_t const& src030, uint32_t const& src031, + uint32_t const& src032, uint32_t const& src033, uint32_t const& src034, uint32_t const& src035, + uint32_t const& src036, uint32_t const& src037, uint32_t 
const& src038, uint32_t const& src039, + uint32_t const& src040, uint32_t const& src041, uint32_t const& src042, uint32_t const& src043, + uint32_t const& src044, uint32_t const& src045, uint32_t const& src046, uint32_t const& src047, + uint32_t const& src048, uint32_t const& src049, uint32_t const& src050, uint32_t const& src051, + uint32_t const& src052, uint32_t const& src053, uint32_t const& src054, uint32_t const& src055, + uint32_t const& src056, uint32_t const& src057, uint32_t const& src058, uint32_t const& src059, + uint32_t const& src060, uint32_t const& src061, uint32_t const& src062, uint32_t const& src063, + uint32_t const& src064, uint32_t const& src065, uint32_t const& src066, uint32_t const& src067, + uint32_t const& src068, uint32_t const& src069, uint32_t const& src070, uint32_t const& src071, + uint32_t const& src072, uint32_t const& src073, uint32_t const& src074, uint32_t const& src075, + uint32_t const& src076, uint32_t const& src077, uint32_t const& src078, uint32_t const& src079, + uint32_t const& src080, uint32_t const& src081, uint32_t const& src082, uint32_t const& src083, + uint32_t const& src084, uint32_t const& src085, uint32_t const& src086, uint32_t const& src087, + uint32_t const& src088, uint32_t const& src089, uint32_t const& src090, uint32_t const& src091, + uint32_t const& src092, uint32_t const& src093, uint32_t const& src094, uint32_t const& src095, + uint32_t const& src096, uint32_t const& src097, uint32_t const& src098, uint32_t const& src099, + uint32_t const& src100, uint32_t const& src101, uint32_t const& src102, uint32_t const& src103, + uint32_t const& src104, uint32_t const& src105, uint32_t const& src106, uint32_t const& src107, + uint32_t const& src108, uint32_t const& src109, uint32_t const& src110, uint32_t const& src111, + uint32_t const& src112, uint32_t const& src113, uint32_t const& src114, uint32_t const& src115, + uint32_t const& src116, uint32_t const& src117, uint32_t const& src118, uint32_t const& src119, + uint32_t const& src120, uint32_t const& src121, uint32_t const& src122, uint32_t const& src123, + uint32_t const& src124, uint32_t const& src125, uint32_t const& src126, uint32_t const& src127, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x64b.x128.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16," + "%17, %18, %19, %20," + "%21, %22, %23, %24," + "%25, %26, %27, %28," + "%29, %30, %31, %32," + "%33, %34, %35, %36," + "%37, %38, %39, %40," + "%41, %42, %43, %44," + "%45, %46, %47, %48," + "%49, %50, %51, %52," + "%53, %54, %55, %56," + "%57, %58, %59, %60," + "%61, %62, %63, %64," + "%65, %66, %67, %68," + "%69, %70, %71, %72," + "%73, %74, %75, %76," + "%77, %78, %79, %80," + "%81, %82, %83, %84," + "%85, %86, %87, %88," + "%89, %90, %91, %92," + "%93, %94, %95, %96," + "%97, %98, %99, %100," + "%101, %102, %103, %104," + "%105, %106, %107, %108," + "%109, %110, %111, %112," + "%113, %114, %115, %116," + "%117, %118, %119, %120," + "%121, %122, %123, %124," + "%125, %126, %127, %128};\n" + : + : "r"(dst_addr), "r"(src000), "r"(src001), "r"(src002), "r"(src003), + "r"(src004), "r"(src005), "r"(src006), "r"(src007), + "r"(src008), "r"(src009), "r"(src010), "r"(src011), + "r"(src012), "r"(src013), "r"(src014), "r"(src015), + "r"(src016), "r"(src017), "r"(src018), "r"(src019), + "r"(src020), "r"(src021), "r"(src022), "r"(src023), + "r"(src024), "r"(src025), "r"(src026), "r"(src027), + "r"(src028), 
"r"(src029), "r"(src030), "r"(src031), + "r"(src032), "r"(src033), "r"(src034), "r"(src035), + "r"(src036), "r"(src037), "r"(src038), "r"(src039), + "r"(src040), "r"(src041), "r"(src042), "r"(src043), + "r"(src044), "r"(src045), "r"(src046), "r"(src047), + "r"(src048), "r"(src049), "r"(src050), "r"(src051), + "r"(src052), "r"(src053), "r"(src054), "r"(src055), + "r"(src056), "r"(src057), "r"(src058), "r"(src059), + "r"(src060), "r"(src061), "r"(src062), "r"(src063), + "r"(src064), "r"(src065), "r"(src066), "r"(src067), + "r"(src068), "r"(src069), "r"(src070), "r"(src071), + "r"(src072), "r"(src073), "r"(src074), "r"(src075), + "r"(src076), "r"(src077), "r"(src078), "r"(src079), + "r"(src080), "r"(src081), "r"(src082), "r"(src083), + "r"(src084), "r"(src085), "r"(src086), "r"(src087), + "r"(src088), "r"(src089), "r"(src090), "r"(src091), + "r"(src092), "r"(src093), "r"(src094), "r"(src095), + "r"(src096), "r"(src097), "r"(src098), "r"(src099), + "r"(src100), "r"(src101), "r"(src102), "r"(src103), + "r"(src104), "r"(src105), "r"(src106), "r"(src107), + "r"(src108), "r"(src109), "r"(src110), "r"(src111), + "r"(src112), "r"(src113), "r"(src114), "r"(src115), + "r"(src116), "r"(src117), "r"(src118), "r"(src119), + "r"(src120), "r"(src121), "r"(src122), "r"(src123), + "r"(src124), "r"(src125), "r"(src126), "r"(src127) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 64-bit pattern, repeated 128 times, expand 16b write +struct SM100_TMEM_STORE_16dp64b128x_16b +{ + using SRegisters = uint32_t[128]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src000, uint32_t const& src001, uint32_t const& src002, uint32_t const& src003, + uint32_t const& src004, uint32_t const& src005, uint32_t const& src006, uint32_t const& src007, + uint32_t const& src008, uint32_t const& src009, uint32_t const& src010, uint32_t const& src011, + uint32_t const& src012, uint32_t const& src013, uint32_t const& src014, uint32_t const& src015, + uint32_t const& src016, uint32_t const& src017, uint32_t const& src018, uint32_t const& src019, + uint32_t const& src020, uint32_t const& src021, uint32_t const& src022, uint32_t const& src023, + uint32_t const& src024, uint32_t const& src025, uint32_t const& src026, uint32_t const& src027, + uint32_t const& src028, uint32_t const& src029, uint32_t const& src030, uint32_t const& src031, + uint32_t const& src032, uint32_t const& src033, uint32_t const& src034, uint32_t const& src035, + uint32_t const& src036, uint32_t const& src037, uint32_t const& src038, uint32_t const& src039, + uint32_t const& src040, uint32_t const& src041, uint32_t const& src042, uint32_t const& src043, + uint32_t const& src044, uint32_t const& src045, uint32_t const& src046, uint32_t const& src047, + uint32_t const& src048, uint32_t const& src049, uint32_t const& src050, uint32_t const& src051, + uint32_t const& src052, uint32_t const& src053, uint32_t const& src054, uint32_t const& src055, + uint32_t const& src056, uint32_t const& src057, uint32_t const& src058, uint32_t const& src059, + uint32_t const& src060, uint32_t const& src061, uint32_t const& src062, uint32_t const& src063, + uint32_t const& src064, uint32_t const& src065, uint32_t const& src066, uint32_t const& src067, + uint32_t const& src068, uint32_t const& src069, uint32_t const& src070, uint32_t const& 
src071, + uint32_t const& src072, uint32_t const& src073, uint32_t const& src074, uint32_t const& src075, + uint32_t const& src076, uint32_t const& src077, uint32_t const& src078, uint32_t const& src079, + uint32_t const& src080, uint32_t const& src081, uint32_t const& src082, uint32_t const& src083, + uint32_t const& src084, uint32_t const& src085, uint32_t const& src086, uint32_t const& src087, + uint32_t const& src088, uint32_t const& src089, uint32_t const& src090, uint32_t const& src091, + uint32_t const& src092, uint32_t const& src093, uint32_t const& src094, uint32_t const& src095, + uint32_t const& src096, uint32_t const& src097, uint32_t const& src098, uint32_t const& src099, + uint32_t const& src100, uint32_t const& src101, uint32_t const& src102, uint32_t const& src103, + uint32_t const& src104, uint32_t const& src105, uint32_t const& src106, uint32_t const& src107, + uint32_t const& src108, uint32_t const& src109, uint32_t const& src110, uint32_t const& src111, + uint32_t const& src112, uint32_t const& src113, uint32_t const& src114, uint32_t const& src115, + uint32_t const& src116, uint32_t const& src117, uint32_t const& src118, uint32_t const& src119, + uint32_t const& src120, uint32_t const& src121, uint32_t const& src122, uint32_t const& src123, + uint32_t const& src124, uint32_t const& src125, uint32_t const& src126, uint32_t const& src127, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x64b.x128.unpack::16b.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16," + "%17, %18, %19, %20," + "%21, %22, %23, %24," + "%25, %26, %27, %28," + "%29, %30, %31, %32," + "%33, %34, %35, %36," + "%37, %38, %39, %40," + "%41, %42, %43, %44," + "%45, %46, %47, %48," + "%49, %50, %51, %52," + "%53, %54, %55, %56," + "%57, %58, %59, %60," + "%61, %62, %63, %64," + "%65, %66, %67, %68," + "%69, %70, %71, %72," + "%73, %74, %75, %76," + "%77, %78, %79, %80," + "%81, %82, %83, %84," + "%85, %86, %87, %88," + "%89, %90, %91, %92," + "%93, %94, %95, %96," + "%97, %98, %99, %100," + "%101, %102, %103, %104," + "%105, %106, %107, %108," + "%109, %110, %111, %112," + "%113, %114, %115, %116," + "%117, %118, %119, %120," + "%121, %122, %123, %124," + "%125, %126, %127, %128};\n" + : + : "r"(dst_addr), "r"(src000), "r"(src001), "r"(src002), "r"(src003), + "r"(src004), "r"(src005), "r"(src006), "r"(src007), + "r"(src008), "r"(src009), "r"(src010), "r"(src011), + "r"(src012), "r"(src013), "r"(src014), "r"(src015), + "r"(src016), "r"(src017), "r"(src018), "r"(src019), + "r"(src020), "r"(src021), "r"(src022), "r"(src023), + "r"(src024), "r"(src025), "r"(src026), "r"(src027), + "r"(src028), "r"(src029), "r"(src030), "r"(src031), + "r"(src032), "r"(src033), "r"(src034), "r"(src035), + "r"(src036), "r"(src037), "r"(src038), "r"(src039), + "r"(src040), "r"(src041), "r"(src042), "r"(src043), + "r"(src044), "r"(src045), "r"(src046), "r"(src047), + "r"(src048), "r"(src049), "r"(src050), "r"(src051), + "r"(src052), "r"(src053), "r"(src054), "r"(src055), + "r"(src056), "r"(src057), "r"(src058), "r"(src059), + "r"(src060), "r"(src061), "r"(src062), "r"(src063), + "r"(src064), "r"(src065), "r"(src066), "r"(src067), + "r"(src068), "r"(src069), "r"(src070), "r"(src071), + "r"(src072), "r"(src073), "r"(src074), "r"(src075), + "r"(src076), "r"(src077), "r"(src078), "r"(src079), + "r"(src080), "r"(src081), "r"(src082), "r"(src083), + "r"(src084), "r"(src085), "r"(src086), "r"(src087), + "r"(src088), 
"r"(src089), "r"(src090), "r"(src091), + "r"(src092), "r"(src093), "r"(src094), "r"(src095), + "r"(src096), "r"(src097), "r"(src098), "r"(src099), + "r"(src100), "r"(src101), "r"(src102), "r"(src103), + "r"(src104), "r"(src105), "r"(src106), "r"(src107), + "r"(src108), "r"(src109), "r"(src110), "r"(src111), + "r"(src112), "r"(src113), "r"(src114), "r"(src115), + "r"(src116), "r"(src117), "r"(src118), "r"(src119), + "r"(src120), "r"(src121), "r"(src122), "r"(src123), + "r"(src124), "r"(src125), "r"(src126), "r"(src127) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 32-bit pattern, repeated 1 times +struct SM100_TMEM_STORE_16dp32b1x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x32bx2.x1.b32" + "[%0] , 1," + "{%1};\n" + : + : "r"(dst_addr), "r"(src0) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 32-bit pattern, repeated 1 times, expand 16b write +struct SM100_TMEM_STORE_16dp32b1x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x32bx2.x1.unpack::16b.b32" + "[%0] , 2," + "{%1};\n" + : + : "r"(dst_addr), "r"(src0) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 32-bit pattern, repeated 2 times +struct SM100_TMEM_STORE_16dp32b2x +{ + using SRegisters = uint32_t[2]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x32bx2.x2.b32" + "[%0] , 2," + "{%1, %2};\n" + : + : "r"(dst_addr), "r"(src0), "r"(src1) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 32-bit pattern, repeated 2 times, expand 16b write +struct SM100_TMEM_STORE_16dp32b2x_16b +{ + using SRegisters = uint32_t[2]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x32bx2.x2.unpack::16b.b32" + "[%0] , 4," + "{%1, %2};\n" + : + : "r"(dst_addr), "r"(src0), "r"(src1) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 32-bit pattern, repeated 4 times +struct 
SM100_TMEM_STORE_16dp32b4x +{ + using SRegisters = uint32_t[4]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x32bx2.x4.b32" + "[%0] , 4," + "{%1, %2, %3, %4};\n" + : + : "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 32-bit pattern, repeated 4 times, expand 16b write +struct SM100_TMEM_STORE_16dp32b4x_16b +{ + using SRegisters = uint32_t[4]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x32bx2.x4.unpack::16b.b32" + "[%0] , 8," + "{%1, %2, %3, %4};\n" + : + : "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 32-bit pattern, repeated 8 times +struct SM100_TMEM_STORE_16dp32b8x +{ + using SRegisters = uint32_t[8]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, + uint32_t const& src4, uint32_t const& src5, uint32_t const& src6, uint32_t const& src7, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x32bx2.x8.b32" + "[%0] , 8," + "{%1, %2, %3, %4," + "%5, %6, %7, %8};\n" + : + : "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3), + "r"(src4), "r"(src5), "r"(src6), "r"(src7) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 32-bit pattern, repeated 8 times, expand 16b write +struct SM100_TMEM_STORE_16dp32b8x_16b +{ + using SRegisters = uint32_t[8]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, + uint32_t const& src4, uint32_t const& src5, uint32_t const& src6, uint32_t const& src7, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x32bx2.x8.unpack::16b.b32" + "[%0] , 16," + "{%1, %2, %3, %4," + "%5, %6, %7, %8};\n" + : + : "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3), + "r"(src4), "r"(src5), "r"(src6), "r"(src7) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 32-bit pattern, repeated 16 times +struct SM100_TMEM_STORE_16dp32b16x +{ + using SRegisters = uint32_t[16]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + 
copy(uint32_t const& src00, uint32_t const& src01, uint32_t const& src02, uint32_t const& src03, + uint32_t const& src04, uint32_t const& src05, uint32_t const& src06, uint32_t const& src07, + uint32_t const& src08, uint32_t const& src09, uint32_t const& src10, uint32_t const& src11, + uint32_t const& src12, uint32_t const& src13, uint32_t const& src14, uint32_t const& src15, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x32bx2.x16.b32" + "[%0] , 16," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16};\n" + : + : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), + "r"(src04), "r"(src05), "r"(src06), "r"(src07), + "r"(src08), "r"(src09), "r"(src10), "r"(src11), + "r"(src12), "r"(src13), "r"(src14), "r"(src15) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 32-bit pattern, repeated 16 times, expand 16b write +struct SM100_TMEM_STORE_16dp32b16x_16b +{ + using SRegisters = uint32_t[16]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src00, uint32_t const& src01, uint32_t const& src02, uint32_t const& src03, + uint32_t const& src04, uint32_t const& src05, uint32_t const& src06, uint32_t const& src07, + uint32_t const& src08, uint32_t const& src09, uint32_t const& src10, uint32_t const& src11, + uint32_t const& src12, uint32_t const& src13, uint32_t const& src14, uint32_t const& src15, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x32bx2.x16.unpack::16b.b32" + "[%0] , 32," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16};\n" + : + : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), + "r"(src04), "r"(src05), "r"(src06), "r"(src07), + "r"(src08), "r"(src09), "r"(src10), "r"(src11), + "r"(src12), "r"(src13), "r"(src14), "r"(src15) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 32-bit pattern, repeated 32 times +struct SM100_TMEM_STORE_16dp32b32x +{ + using SRegisters = uint32_t[32]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src00, uint32_t const& src01, uint32_t const& src02, uint32_t const& src03, + uint32_t const& src04, uint32_t const& src05, uint32_t const& src06, uint32_t const& src07, + uint32_t const& src08, uint32_t const& src09, uint32_t const& src10, uint32_t const& src11, + uint32_t const& src12, uint32_t const& src13, uint32_t const& src14, uint32_t const& src15, + uint32_t const& src16, uint32_t const& src17, uint32_t const& src18, uint32_t const& src19, + uint32_t const& src20, uint32_t const& src21, uint32_t const& src22, uint32_t const& src23, + uint32_t const& src24, uint32_t const& src25, uint32_t const& src26, uint32_t const& src27, + uint32_t const& src28, uint32_t const& src29, uint32_t const& src30, uint32_t const& src31, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x32bx2.x32.b32" + "[%0] , 32," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, 
%14, %15, %16," + "%17, %18, %19, %20," + "%21, %22, %23, %24," + "%25, %26, %27, %28," + "%29, %30, %31, %32};\n" + : + : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), + "r"(src04), "r"(src05), "r"(src06), "r"(src07), + "r"(src08), "r"(src09), "r"(src10), "r"(src11), + "r"(src12), "r"(src13), "r"(src14), "r"(src15), + "r"(src16), "r"(src17), "r"(src18), "r"(src19), + "r"(src20), "r"(src21), "r"(src22), "r"(src23), + "r"(src24), "r"(src25), "r"(src26), "r"(src27), + "r"(src28), "r"(src29), "r"(src30), "r"(src31) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 32-bit pattern, repeated 32 times, expand 16b write +struct SM100_TMEM_STORE_16dp32b32x_16b +{ + using SRegisters = uint32_t[32]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src00, uint32_t const& src01, uint32_t const& src02, uint32_t const& src03, + uint32_t const& src04, uint32_t const& src05, uint32_t const& src06, uint32_t const& src07, + uint32_t const& src08, uint32_t const& src09, uint32_t const& src10, uint32_t const& src11, + uint32_t const& src12, uint32_t const& src13, uint32_t const& src14, uint32_t const& src15, + uint32_t const& src16, uint32_t const& src17, uint32_t const& src18, uint32_t const& src19, + uint32_t const& src20, uint32_t const& src21, uint32_t const& src22, uint32_t const& src23, + uint32_t const& src24, uint32_t const& src25, uint32_t const& src26, uint32_t const& src27, + uint32_t const& src28, uint32_t const& src29, uint32_t const& src30, uint32_t const& src31, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x32bx2.x32.unpack::16b.b32" + "[%0] , 64," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16," + "%17, %18, %19, %20," + "%21, %22, %23, %24," + "%25, %26, %27, %28," + "%29, %30, %31, %32};\n" + : + : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), + "r"(src04), "r"(src05), "r"(src06), "r"(src07), + "r"(src08), "r"(src09), "r"(src10), "r"(src11), + "r"(src12), "r"(src13), "r"(src14), "r"(src15), + "r"(src16), "r"(src17), "r"(src18), "r"(src19), + "r"(src20), "r"(src21), "r"(src22), "r"(src23), + "r"(src24), "r"(src25), "r"(src26), "r"(src27), + "r"(src28), "r"(src29), "r"(src30), "r"(src31) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 32-bit pattern, repeated 64 times +struct SM100_TMEM_STORE_16dp32b64x +{ + using SRegisters = uint32_t[64]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src00, uint32_t const& src01, uint32_t const& src02, uint32_t const& src03, + uint32_t const& src04, uint32_t const& src05, uint32_t const& src06, uint32_t const& src07, + uint32_t const& src08, uint32_t const& src09, uint32_t const& src10, uint32_t const& src11, + uint32_t const& src12, uint32_t const& src13, uint32_t const& src14, uint32_t const& src15, + uint32_t const& src16, uint32_t const& src17, uint32_t const& src18, uint32_t const& src19, + uint32_t const& src20, uint32_t const& src21, uint32_t const& src22, uint32_t const& src23, + uint32_t const& src24, uint32_t const& 
src25, uint32_t const& src26, uint32_t const& src27, + uint32_t const& src28, uint32_t const& src29, uint32_t const& src30, uint32_t const& src31, + uint32_t const& src32, uint32_t const& src33, uint32_t const& src34, uint32_t const& src35, + uint32_t const& src36, uint32_t const& src37, uint32_t const& src38, uint32_t const& src39, + uint32_t const& src40, uint32_t const& src41, uint32_t const& src42, uint32_t const& src43, + uint32_t const& src44, uint32_t const& src45, uint32_t const& src46, uint32_t const& src47, + uint32_t const& src48, uint32_t const& src49, uint32_t const& src50, uint32_t const& src51, + uint32_t const& src52, uint32_t const& src53, uint32_t const& src54, uint32_t const& src55, + uint32_t const& src56, uint32_t const& src57, uint32_t const& src58, uint32_t const& src59, + uint32_t const& src60, uint32_t const& src61, uint32_t const& src62, uint32_t const& src63, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x32bx2.x64.b32" + "[%0] , 64," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16," + "%17, %18, %19, %20," + "%21, %22, %23, %24," + "%25, %26, %27, %28," + "%29, %30, %31, %32," + "%33, %34, %35, %36," + "%37, %38, %39, %40," + "%41, %42, %43, %44," + "%45, %46, %47, %48," + "%49, %50, %51, %52," + "%53, %54, %55, %56," + "%57, %58, %59, %60," + "%61, %62, %63, %64};\n" + : + : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), + "r"(src04), "r"(src05), "r"(src06), "r"(src07), + "r"(src08), "r"(src09), "r"(src10), "r"(src11), + "r"(src12), "r"(src13), "r"(src14), "r"(src15), + "r"(src16), "r"(src17), "r"(src18), "r"(src19), + "r"(src20), "r"(src21), "r"(src22), "r"(src23), + "r"(src24), "r"(src25), "r"(src26), "r"(src27), + "r"(src28), "r"(src29), "r"(src30), "r"(src31), + "r"(src32), "r"(src33), "r"(src34), "r"(src35), + "r"(src36), "r"(src37), "r"(src38), "r"(src39), + "r"(src40), "r"(src41), "r"(src42), "r"(src43), + "r"(src44), "r"(src45), "r"(src46), "r"(src47), + "r"(src48), "r"(src49), "r"(src50), "r"(src51), + "r"(src52), "r"(src53), "r"(src54), "r"(src55), + "r"(src56), "r"(src57), "r"(src58), "r"(src59), + "r"(src60), "r"(src61), "r"(src62), "r"(src63) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 32-bit pattern, repeated 64 times, expand 16b write +struct SM100_TMEM_STORE_16dp32b64x_16b +{ + using SRegisters = uint32_t[64]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src00, uint32_t const& src01, uint32_t const& src02, uint32_t const& src03, + uint32_t const& src04, uint32_t const& src05, uint32_t const& src06, uint32_t const& src07, + uint32_t const& src08, uint32_t const& src09, uint32_t const& src10, uint32_t const& src11, + uint32_t const& src12, uint32_t const& src13, uint32_t const& src14, uint32_t const& src15, + uint32_t const& src16, uint32_t const& src17, uint32_t const& src18, uint32_t const& src19, + uint32_t const& src20, uint32_t const& src21, uint32_t const& src22, uint32_t const& src23, + uint32_t const& src24, uint32_t const& src25, uint32_t const& src26, uint32_t const& src27, + uint32_t const& src28, uint32_t const& src29, uint32_t const& src30, uint32_t const& src31, + uint32_t const& src32, uint32_t const& src33, uint32_t const& src34, uint32_t const& src35, + 
uint32_t const& src36, uint32_t const& src37, uint32_t const& src38, uint32_t const& src39, + uint32_t const& src40, uint32_t const& src41, uint32_t const& src42, uint32_t const& src43, + uint32_t const& src44, uint32_t const& src45, uint32_t const& src46, uint32_t const& src47, + uint32_t const& src48, uint32_t const& src49, uint32_t const& src50, uint32_t const& src51, + uint32_t const& src52, uint32_t const& src53, uint32_t const& src54, uint32_t const& src55, + uint32_t const& src56, uint32_t const& src57, uint32_t const& src58, uint32_t const& src59, + uint32_t const& src60, uint32_t const& src61, uint32_t const& src62, uint32_t const& src63, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x32bx2.x64.unpack::16b.b32" + "[%0] , 128," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16," + "%17, %18, %19, %20," + "%21, %22, %23, %24," + "%25, %26, %27, %28," + "%29, %30, %31, %32," + "%33, %34, %35, %36," + "%37, %38, %39, %40," + "%41, %42, %43, %44," + "%45, %46, %47, %48," + "%49, %50, %51, %52," + "%53, %54, %55, %56," + "%57, %58, %59, %60," + "%61, %62, %63, %64};\n" + : + : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), + "r"(src04), "r"(src05), "r"(src06), "r"(src07), + "r"(src08), "r"(src09), "r"(src10), "r"(src11), + "r"(src12), "r"(src13), "r"(src14), "r"(src15), + "r"(src16), "r"(src17), "r"(src18), "r"(src19), + "r"(src20), "r"(src21), "r"(src22), "r"(src23), + "r"(src24), "r"(src25), "r"(src26), "r"(src27), + "r"(src28), "r"(src29), "r"(src30), "r"(src31), + "r"(src32), "r"(src33), "r"(src34), "r"(src35), + "r"(src36), "r"(src37), "r"(src38), "r"(src39), + "r"(src40), "r"(src41), "r"(src42), "r"(src43), + "r"(src44), "r"(src45), "r"(src46), "r"(src47), + "r"(src48), "r"(src49), "r"(src50), "r"(src51), + "r"(src52), "r"(src53), "r"(src54), "r"(src55), + "r"(src56), "r"(src57), "r"(src58), "r"(src59), + "r"(src60), "r"(src61), "r"(src62), "r"(src63) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 32-bit pattern, repeated 128 times +struct SM100_TMEM_STORE_16dp32b128x +{ + using SRegisters = uint32_t[128]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src000, uint32_t const& src001, uint32_t const& src002, uint32_t const& src003, + uint32_t const& src004, uint32_t const& src005, uint32_t const& src006, uint32_t const& src007, + uint32_t const& src008, uint32_t const& src009, uint32_t const& src010, uint32_t const& src011, + uint32_t const& src012, uint32_t const& src013, uint32_t const& src014, uint32_t const& src015, + uint32_t const& src016, uint32_t const& src017, uint32_t const& src018, uint32_t const& src019, + uint32_t const& src020, uint32_t const& src021, uint32_t const& src022, uint32_t const& src023, + uint32_t const& src024, uint32_t const& src025, uint32_t const& src026, uint32_t const& src027, + uint32_t const& src028, uint32_t const& src029, uint32_t const& src030, uint32_t const& src031, + uint32_t const& src032, uint32_t const& src033, uint32_t const& src034, uint32_t const& src035, + uint32_t const& src036, uint32_t const& src037, uint32_t const& src038, uint32_t const& src039, + uint32_t const& src040, uint32_t const& src041, uint32_t const& src042, uint32_t const& src043, + uint32_t const& 
src044, uint32_t const& src045, uint32_t const& src046, uint32_t const& src047, + uint32_t const& src048, uint32_t const& src049, uint32_t const& src050, uint32_t const& src051, + uint32_t const& src052, uint32_t const& src053, uint32_t const& src054, uint32_t const& src055, + uint32_t const& src056, uint32_t const& src057, uint32_t const& src058, uint32_t const& src059, + uint32_t const& src060, uint32_t const& src061, uint32_t const& src062, uint32_t const& src063, + uint32_t const& src064, uint32_t const& src065, uint32_t const& src066, uint32_t const& src067, + uint32_t const& src068, uint32_t const& src069, uint32_t const& src070, uint32_t const& src071, + uint32_t const& src072, uint32_t const& src073, uint32_t const& src074, uint32_t const& src075, + uint32_t const& src076, uint32_t const& src077, uint32_t const& src078, uint32_t const& src079, + uint32_t const& src080, uint32_t const& src081, uint32_t const& src082, uint32_t const& src083, + uint32_t const& src084, uint32_t const& src085, uint32_t const& src086, uint32_t const& src087, + uint32_t const& src088, uint32_t const& src089, uint32_t const& src090, uint32_t const& src091, + uint32_t const& src092, uint32_t const& src093, uint32_t const& src094, uint32_t const& src095, + uint32_t const& src096, uint32_t const& src097, uint32_t const& src098, uint32_t const& src099, + uint32_t const& src100, uint32_t const& src101, uint32_t const& src102, uint32_t const& src103, + uint32_t const& src104, uint32_t const& src105, uint32_t const& src106, uint32_t const& src107, + uint32_t const& src108, uint32_t const& src109, uint32_t const& src110, uint32_t const& src111, + uint32_t const& src112, uint32_t const& src113, uint32_t const& src114, uint32_t const& src115, + uint32_t const& src116, uint32_t const& src117, uint32_t const& src118, uint32_t const& src119, + uint32_t const& src120, uint32_t const& src121, uint32_t const& src122, uint32_t const& src123, + uint32_t const& src124, uint32_t const& src125, uint32_t const& src126, uint32_t const& src127, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x32bx2.x128.b32" + "[%0] , 128," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16," + "%17, %18, %19, %20," + "%21, %22, %23, %24," + "%25, %26, %27, %28," + "%29, %30, %31, %32," + "%33, %34, %35, %36," + "%37, %38, %39, %40," + "%41, %42, %43, %44," + "%45, %46, %47, %48," + "%49, %50, %51, %52," + "%53, %54, %55, %56," + "%57, %58, %59, %60," + "%61, %62, %63, %64," + "%65, %66, %67, %68," + "%69, %70, %71, %72," + "%73, %74, %75, %76," + "%77, %78, %79, %80," + "%81, %82, %83, %84," + "%85, %86, %87, %88," + "%89, %90, %91, %92," + "%93, %94, %95, %96," + "%97, %98, %99, %100," + "%101, %102, %103, %104," + "%105, %106, %107, %108," + "%109, %110, %111, %112," + "%113, %114, %115, %116," + "%117, %118, %119, %120," + "%121, %122, %123, %124," + "%125, %126, %127, %128};\n" + : + : "r"(dst_addr), "r"(src000), "r"(src001), "r"(src002), "r"(src003), + "r"(src004), "r"(src005), "r"(src006), "r"(src007), + "r"(src008), "r"(src009), "r"(src010), "r"(src011), + "r"(src012), "r"(src013), "r"(src014), "r"(src015), + "r"(src016), "r"(src017), "r"(src018), "r"(src019), + "r"(src020), "r"(src021), "r"(src022), "r"(src023), + "r"(src024), "r"(src025), "r"(src026), "r"(src027), + "r"(src028), "r"(src029), "r"(src030), "r"(src031), + "r"(src032), "r"(src033), "r"(src034), "r"(src035), + "r"(src036), "r"(src037), "r"(src038), "r"(src039), + 
"r"(src040), "r"(src041), "r"(src042), "r"(src043), + "r"(src044), "r"(src045), "r"(src046), "r"(src047), + "r"(src048), "r"(src049), "r"(src050), "r"(src051), + "r"(src052), "r"(src053), "r"(src054), "r"(src055), + "r"(src056), "r"(src057), "r"(src058), "r"(src059), + "r"(src060), "r"(src061), "r"(src062), "r"(src063), + "r"(src064), "r"(src065), "r"(src066), "r"(src067), + "r"(src068), "r"(src069), "r"(src070), "r"(src071), + "r"(src072), "r"(src073), "r"(src074), "r"(src075), + "r"(src076), "r"(src077), "r"(src078), "r"(src079), + "r"(src080), "r"(src081), "r"(src082), "r"(src083), + "r"(src084), "r"(src085), "r"(src086), "r"(src087), + "r"(src088), "r"(src089), "r"(src090), "r"(src091), + "r"(src092), "r"(src093), "r"(src094), "r"(src095), + "r"(src096), "r"(src097), "r"(src098), "r"(src099), + "r"(src100), "r"(src101), "r"(src102), "r"(src103), + "r"(src104), "r"(src105), "r"(src106), "r"(src107), + "r"(src108), "r"(src109), "r"(src110), "r"(src111), + "r"(src112), "r"(src113), "r"(src114), "r"(src115), + "r"(src116), "r"(src117), "r"(src118), "r"(src119), + "r"(src120), "r"(src121), "r"(src122), "r"(src123), + "r"(src124), "r"(src125), "r"(src126), "r"(src127) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 16 data path lanes, 32-bit pattern, repeated 128 times, expand 16b write +struct SM100_TMEM_STORE_16dp32b128x_16b +{ + using SRegisters = uint32_t[128]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src000, uint32_t const& src001, uint32_t const& src002, uint32_t const& src003, + uint32_t const& src004, uint32_t const& src005, uint32_t const& src006, uint32_t const& src007, + uint32_t const& src008, uint32_t const& src009, uint32_t const& src010, uint32_t const& src011, + uint32_t const& src012, uint32_t const& src013, uint32_t const& src014, uint32_t const& src015, + uint32_t const& src016, uint32_t const& src017, uint32_t const& src018, uint32_t const& src019, + uint32_t const& src020, uint32_t const& src021, uint32_t const& src022, uint32_t const& src023, + uint32_t const& src024, uint32_t const& src025, uint32_t const& src026, uint32_t const& src027, + uint32_t const& src028, uint32_t const& src029, uint32_t const& src030, uint32_t const& src031, + uint32_t const& src032, uint32_t const& src033, uint32_t const& src034, uint32_t const& src035, + uint32_t const& src036, uint32_t const& src037, uint32_t const& src038, uint32_t const& src039, + uint32_t const& src040, uint32_t const& src041, uint32_t const& src042, uint32_t const& src043, + uint32_t const& src044, uint32_t const& src045, uint32_t const& src046, uint32_t const& src047, + uint32_t const& src048, uint32_t const& src049, uint32_t const& src050, uint32_t const& src051, + uint32_t const& src052, uint32_t const& src053, uint32_t const& src054, uint32_t const& src055, + uint32_t const& src056, uint32_t const& src057, uint32_t const& src058, uint32_t const& src059, + uint32_t const& src060, uint32_t const& src061, uint32_t const& src062, uint32_t const& src063, + uint32_t const& src064, uint32_t const& src065, uint32_t const& src066, uint32_t const& src067, + uint32_t const& src068, uint32_t const& src069, uint32_t const& src070, uint32_t const& src071, + uint32_t const& src072, uint32_t const& src073, uint32_t const& src074, uint32_t const& src075, + uint32_t const& src076, uint32_t const& src077, 
uint32_t const& src078, uint32_t const& src079, + uint32_t const& src080, uint32_t const& src081, uint32_t const& src082, uint32_t const& src083, + uint32_t const& src084, uint32_t const& src085, uint32_t const& src086, uint32_t const& src087, + uint32_t const& src088, uint32_t const& src089, uint32_t const& src090, uint32_t const& src091, + uint32_t const& src092, uint32_t const& src093, uint32_t const& src094, uint32_t const& src095, + uint32_t const& src096, uint32_t const& src097, uint32_t const& src098, uint32_t const& src099, + uint32_t const& src100, uint32_t const& src101, uint32_t const& src102, uint32_t const& src103, + uint32_t const& src104, uint32_t const& src105, uint32_t const& src106, uint32_t const& src107, + uint32_t const& src108, uint32_t const& src109, uint32_t const& src110, uint32_t const& src111, + uint32_t const& src112, uint32_t const& src113, uint32_t const& src114, uint32_t const& src115, + uint32_t const& src116, uint32_t const& src117, uint32_t const& src118, uint32_t const& src119, + uint32_t const& src120, uint32_t const& src121, uint32_t const& src122, uint32_t const& src123, + uint32_t const& src124, uint32_t const& src125, uint32_t const& src126, uint32_t const& src127, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.16x32bx2.x128.unpack::16b.b32" + "[%0] , 256," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16," + "%17, %18, %19, %20," + "%21, %22, %23, %24," + "%25, %26, %27, %28," + "%29, %30, %31, %32," + "%33, %34, %35, %36," + "%37, %38, %39, %40," + "%41, %42, %43, %44," + "%45, %46, %47, %48," + "%49, %50, %51, %52," + "%53, %54, %55, %56," + "%57, %58, %59, %60," + "%61, %62, %63, %64," + "%65, %66, %67, %68," + "%69, %70, %71, %72," + "%73, %74, %75, %76," + "%77, %78, %79, %80," + "%81, %82, %83, %84," + "%85, %86, %87, %88," + "%89, %90, %91, %92," + "%93, %94, %95, %96," + "%97, %98, %99, %100," + "%101, %102, %103, %104," + "%105, %106, %107, %108," + "%109, %110, %111, %112," + "%113, %114, %115, %116," + "%117, %118, %119, %120," + "%121, %122, %123, %124," + "%125, %126, %127, %128};\n" + : + : "r"(dst_addr), "r"(src000), "r"(src001), "r"(src002), "r"(src003), + "r"(src004), "r"(src005), "r"(src006), "r"(src007), + "r"(src008), "r"(src009), "r"(src010), "r"(src011), + "r"(src012), "r"(src013), "r"(src014), "r"(src015), + "r"(src016), "r"(src017), "r"(src018), "r"(src019), + "r"(src020), "r"(src021), "r"(src022), "r"(src023), + "r"(src024), "r"(src025), "r"(src026), "r"(src027), + "r"(src028), "r"(src029), "r"(src030), "r"(src031), + "r"(src032), "r"(src033), "r"(src034), "r"(src035), + "r"(src036), "r"(src037), "r"(src038), "r"(src039), + "r"(src040), "r"(src041), "r"(src042), "r"(src043), + "r"(src044), "r"(src045), "r"(src046), "r"(src047), + "r"(src048), "r"(src049), "r"(src050), "r"(src051), + "r"(src052), "r"(src053), "r"(src054), "r"(src055), + "r"(src056), "r"(src057), "r"(src058), "r"(src059), + "r"(src060), "r"(src061), "r"(src062), "r"(src063), + "r"(src064), "r"(src065), "r"(src066), "r"(src067), + "r"(src068), "r"(src069), "r"(src070), "r"(src071), + "r"(src072), "r"(src073), "r"(src074), "r"(src075), + "r"(src076), "r"(src077), "r"(src078), "r"(src079), + "r"(src080), "r"(src081), "r"(src082), "r"(src083), + "r"(src084), "r"(src085), "r"(src086), "r"(src087), + "r"(src088), "r"(src089), "r"(src090), "r"(src091), + "r"(src092), "r"(src093), "r"(src094), "r"(src095), + "r"(src096), "r"(src097), "r"(src098), "r"(src099), + 
"r"(src100), "r"(src101), "r"(src102), "r"(src103), + "r"(src104), "r"(src105), "r"(src106), "r"(src107), + "r"(src108), "r"(src109), "r"(src110), "r"(src111), + "r"(src112), "r"(src113), "r"(src114), "r"(src115), + "r"(src116), "r"(src117), "r"(src118), "r"(src119), + "r"(src120), "r"(src121), "r"(src122), "r"(src123), + "r"(src124), "r"(src125), "r"(src126), "r"(src127) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 32 data path lanes, 32-bit pattern, repeated 1 times +struct SM100_TMEM_STORE_32dp32b1x +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.32x32b.x1.b32" + "[%0]," + "{%1};\n" + : + : "r"(dst_addr), "r"(src0) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 32 data path lanes, 32-bit pattern, repeated 1 times, expand 16b write +struct SM100_TMEM_STORE_32dp32b1x_16b +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.32x32b.x1.unpack::16b.b32" + "[%0]," + "{%1};\n" + : + : "r"(dst_addr), "r"(src0) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 32 data path lanes, 32-bit pattern, repeated 2 times +struct SM100_TMEM_STORE_32dp32b2x +{ + using SRegisters = uint32_t[2]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.32x32b.x2.b32" + "[%0]," + "{%1, %2};\n" + : + : "r"(dst_addr), "r"(src0), "r"(src1) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 32 data path lanes, 32-bit pattern, repeated 2 times, expand 16b write +struct SM100_TMEM_STORE_32dp32b2x_16b +{ + using SRegisters = uint32_t[2]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.32x32b.x2.unpack::16b.b32" + "[%0]," + "{%1, %2};\n" + : + : "r"(dst_addr), "r"(src0), "r"(src1) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 32 data path lanes, 32-bit pattern, repeated 4 times +struct SM100_TMEM_STORE_32dp32b4x +{ + using SRegisters = uint32_t[4]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& 
src1, uint32_t const& src2, uint32_t const& src3, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.32x32b.x4.b32" + "[%0]," + "{%1, %2, %3, %4};\n" + : + : "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 32 data path lanes, 32-bit pattern, repeated 4 times, expand 16b write +struct SM100_TMEM_STORE_32dp32b4x_16b +{ + using SRegisters = uint32_t[4]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.32x32b.x4.unpack::16b.b32" + "[%0]," + "{%1, %2, %3, %4};\n" + : + : "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 32 data path lanes, 32-bit pattern, repeated 8 times +struct SM100_TMEM_STORE_32dp32b8x +{ + using SRegisters = uint32_t[8]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, + uint32_t const& src4, uint32_t const& src5, uint32_t const& src6, uint32_t const& src7, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.32x32b.x8.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8};\n" + : + : "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3), + "r"(src4), "r"(src5), "r"(src6), "r"(src7) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 32 data path lanes, 32-bit pattern, repeated 8 times, expand 16b write +struct SM100_TMEM_STORE_32dp32b8x_16b +{ + using SRegisters = uint32_t[8]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, + uint32_t const& src4, uint32_t const& src5, uint32_t const& src6, uint32_t const& src7, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.32x32b.x8.unpack::16b.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8};\n" + : + : "r"(dst_addr), "r"(src0), "r"(src1), "r"(src2), "r"(src3), + "r"(src4), "r"(src5), "r"(src6), "r"(src7) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 32 data path lanes, 32-bit pattern, repeated 16 times +struct SM100_TMEM_STORE_32dp32b16x +{ + using SRegisters = uint32_t[16]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src00, uint32_t const& src01, uint32_t const& src02, uint32_t const& src03, + uint32_t const& src04, uint32_t const& src05, uint32_t const& src06, uint32_t const& src07, + uint32_t 
const& src08, uint32_t const& src09, uint32_t const& src10, uint32_t const& src11, + uint32_t const& src12, uint32_t const& src13, uint32_t const& src14, uint32_t const& src15, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.32x32b.x16.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16};\n" + : + : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), + "r"(src04), "r"(src05), "r"(src06), "r"(src07), + "r"(src08), "r"(src09), "r"(src10), "r"(src11), + "r"(src12), "r"(src13), "r"(src14), "r"(src15) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 32 data path lanes, 32-bit pattern, repeated 16 times, expand 16b write +struct SM100_TMEM_STORE_32dp32b16x_16b +{ + using SRegisters = uint32_t[16]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src00, uint32_t const& src01, uint32_t const& src02, uint32_t const& src03, + uint32_t const& src04, uint32_t const& src05, uint32_t const& src06, uint32_t const& src07, + uint32_t const& src08, uint32_t const& src09, uint32_t const& src10, uint32_t const& src11, + uint32_t const& src12, uint32_t const& src13, uint32_t const& src14, uint32_t const& src15, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.32x32b.x16.unpack::16b.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16};\n" + : + : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), + "r"(src04), "r"(src05), "r"(src06), "r"(src07), + "r"(src08), "r"(src09), "r"(src10), "r"(src11), + "r"(src12), "r"(src13), "r"(src14), "r"(src15) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 32 data path lanes, 32-bit pattern, repeated 32 times +struct SM100_TMEM_STORE_32dp32b32x +{ + using SRegisters = uint32_t[32]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src00, uint32_t const& src01, uint32_t const& src02, uint32_t const& src03, + uint32_t const& src04, uint32_t const& src05, uint32_t const& src06, uint32_t const& src07, + uint32_t const& src08, uint32_t const& src09, uint32_t const& src10, uint32_t const& src11, + uint32_t const& src12, uint32_t const& src13, uint32_t const& src14, uint32_t const& src15, + uint32_t const& src16, uint32_t const& src17, uint32_t const& src18, uint32_t const& src19, + uint32_t const& src20, uint32_t const& src21, uint32_t const& src22, uint32_t const& src23, + uint32_t const& src24, uint32_t const& src25, uint32_t const& src26, uint32_t const& src27, + uint32_t const& src28, uint32_t const& src29, uint32_t const& src30, uint32_t const& src31, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.32x32b.x32.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16," + "%17, %18, %19, %20," + "%21, %22, %23, %24," + "%25, %26, %27, %28," + "%29, %30, %31, %32};\n" + : + : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), + "r"(src04), "r"(src05), "r"(src06), 
"r"(src07), + "r"(src08), "r"(src09), "r"(src10), "r"(src11), + "r"(src12), "r"(src13), "r"(src14), "r"(src15), + "r"(src16), "r"(src17), "r"(src18), "r"(src19), + "r"(src20), "r"(src21), "r"(src22), "r"(src23), + "r"(src24), "r"(src25), "r"(src26), "r"(src27), + "r"(src28), "r"(src29), "r"(src30), "r"(src31) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 32 data path lanes, 32-bit pattern, repeated 32 times, expand 16b write +struct SM100_TMEM_STORE_32dp32b32x_16b +{ + using SRegisters = uint32_t[32]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src00, uint32_t const& src01, uint32_t const& src02, uint32_t const& src03, + uint32_t const& src04, uint32_t const& src05, uint32_t const& src06, uint32_t const& src07, + uint32_t const& src08, uint32_t const& src09, uint32_t const& src10, uint32_t const& src11, + uint32_t const& src12, uint32_t const& src13, uint32_t const& src14, uint32_t const& src15, + uint32_t const& src16, uint32_t const& src17, uint32_t const& src18, uint32_t const& src19, + uint32_t const& src20, uint32_t const& src21, uint32_t const& src22, uint32_t const& src23, + uint32_t const& src24, uint32_t const& src25, uint32_t const& src26, uint32_t const& src27, + uint32_t const& src28, uint32_t const& src29, uint32_t const& src30, uint32_t const& src31, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.32x32b.x32.unpack::16b.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16," + "%17, %18, %19, %20," + "%21, %22, %23, %24," + "%25, %26, %27, %28," + "%29, %30, %31, %32};\n" + : + : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), + "r"(src04), "r"(src05), "r"(src06), "r"(src07), + "r"(src08), "r"(src09), "r"(src10), "r"(src11), + "r"(src12), "r"(src13), "r"(src14), "r"(src15), + "r"(src16), "r"(src17), "r"(src18), "r"(src19), + "r"(src20), "r"(src21), "r"(src22), "r"(src23), + "r"(src24), "r"(src25), "r"(src26), "r"(src27), + "r"(src28), "r"(src29), "r"(src30), "r"(src31) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 32 data path lanes, 32-bit pattern, repeated 64 times +struct SM100_TMEM_STORE_32dp32b64x +{ + using SRegisters = uint32_t[64]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src00, uint32_t const& src01, uint32_t const& src02, uint32_t const& src03, + uint32_t const& src04, uint32_t const& src05, uint32_t const& src06, uint32_t const& src07, + uint32_t const& src08, uint32_t const& src09, uint32_t const& src10, uint32_t const& src11, + uint32_t const& src12, uint32_t const& src13, uint32_t const& src14, uint32_t const& src15, + uint32_t const& src16, uint32_t const& src17, uint32_t const& src18, uint32_t const& src19, + uint32_t const& src20, uint32_t const& src21, uint32_t const& src22, uint32_t const& src23, + uint32_t const& src24, uint32_t const& src25, uint32_t const& src26, uint32_t const& src27, + uint32_t const& src28, uint32_t const& src29, uint32_t const& src30, uint32_t const& src31, + uint32_t const& src32, uint32_t const& src33, uint32_t const& src34, uint32_t const& 
src35, + uint32_t const& src36, uint32_t const& src37, uint32_t const& src38, uint32_t const& src39, + uint32_t const& src40, uint32_t const& src41, uint32_t const& src42, uint32_t const& src43, + uint32_t const& src44, uint32_t const& src45, uint32_t const& src46, uint32_t const& src47, + uint32_t const& src48, uint32_t const& src49, uint32_t const& src50, uint32_t const& src51, + uint32_t const& src52, uint32_t const& src53, uint32_t const& src54, uint32_t const& src55, + uint32_t const& src56, uint32_t const& src57, uint32_t const& src58, uint32_t const& src59, + uint32_t const& src60, uint32_t const& src61, uint32_t const& src62, uint32_t const& src63, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.32x32b.x64.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16," + "%17, %18, %19, %20," + "%21, %22, %23, %24," + "%25, %26, %27, %28," + "%29, %30, %31, %32," + "%33, %34, %35, %36," + "%37, %38, %39, %40," + "%41, %42, %43, %44," + "%45, %46, %47, %48," + "%49, %50, %51, %52," + "%53, %54, %55, %56," + "%57, %58, %59, %60," + "%61, %62, %63, %64};\n" + : + : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), + "r"(src04), "r"(src05), "r"(src06), "r"(src07), + "r"(src08), "r"(src09), "r"(src10), "r"(src11), + "r"(src12), "r"(src13), "r"(src14), "r"(src15), + "r"(src16), "r"(src17), "r"(src18), "r"(src19), + "r"(src20), "r"(src21), "r"(src22), "r"(src23), + "r"(src24), "r"(src25), "r"(src26), "r"(src27), + "r"(src28), "r"(src29), "r"(src30), "r"(src31), + "r"(src32), "r"(src33), "r"(src34), "r"(src35), + "r"(src36), "r"(src37), "r"(src38), "r"(src39), + "r"(src40), "r"(src41), "r"(src42), "r"(src43), + "r"(src44), "r"(src45), "r"(src46), "r"(src47), + "r"(src48), "r"(src49), "r"(src50), "r"(src51), + "r"(src52), "r"(src53), "r"(src54), "r"(src55), + "r"(src56), "r"(src57), "r"(src58), "r"(src59), + "r"(src60), "r"(src61), "r"(src62), "r"(src63) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 32 data path lanes, 32-bit pattern, repeated 64 times, expand 16b write +struct SM100_TMEM_STORE_32dp32b64x_16b +{ + using SRegisters = uint32_t[64]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src00, uint32_t const& src01, uint32_t const& src02, uint32_t const& src03, + uint32_t const& src04, uint32_t const& src05, uint32_t const& src06, uint32_t const& src07, + uint32_t const& src08, uint32_t const& src09, uint32_t const& src10, uint32_t const& src11, + uint32_t const& src12, uint32_t const& src13, uint32_t const& src14, uint32_t const& src15, + uint32_t const& src16, uint32_t const& src17, uint32_t const& src18, uint32_t const& src19, + uint32_t const& src20, uint32_t const& src21, uint32_t const& src22, uint32_t const& src23, + uint32_t const& src24, uint32_t const& src25, uint32_t const& src26, uint32_t const& src27, + uint32_t const& src28, uint32_t const& src29, uint32_t const& src30, uint32_t const& src31, + uint32_t const& src32, uint32_t const& src33, uint32_t const& src34, uint32_t const& src35, + uint32_t const& src36, uint32_t const& src37, uint32_t const& src38, uint32_t const& src39, + uint32_t const& src40, uint32_t const& src41, uint32_t const& src42, uint32_t const& src43, + uint32_t const& src44, uint32_t const& src45, uint32_t 
const& src46, uint32_t const& src47, + uint32_t const& src48, uint32_t const& src49, uint32_t const& src50, uint32_t const& src51, + uint32_t const& src52, uint32_t const& src53, uint32_t const& src54, uint32_t const& src55, + uint32_t const& src56, uint32_t const& src57, uint32_t const& src58, uint32_t const& src59, + uint32_t const& src60, uint32_t const& src61, uint32_t const& src62, uint32_t const& src63, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.32x32b.x64.unpack::16b.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16," + "%17, %18, %19, %20," + "%21, %22, %23, %24," + "%25, %26, %27, %28," + "%29, %30, %31, %32," + "%33, %34, %35, %36," + "%37, %38, %39, %40," + "%41, %42, %43, %44," + "%45, %46, %47, %48," + "%49, %50, %51, %52," + "%53, %54, %55, %56," + "%57, %58, %59, %60," + "%61, %62, %63, %64};\n" + : + : "r"(dst_addr), "r"(src00), "r"(src01), "r"(src02), "r"(src03), + "r"(src04), "r"(src05), "r"(src06), "r"(src07), + "r"(src08), "r"(src09), "r"(src10), "r"(src11), + "r"(src12), "r"(src13), "r"(src14), "r"(src15), + "r"(src16), "r"(src17), "r"(src18), "r"(src19), + "r"(src20), "r"(src21), "r"(src22), "r"(src23), + "r"(src24), "r"(src25), "r"(src26), "r"(src27), + "r"(src28), "r"(src29), "r"(src30), "r"(src31), + "r"(src32), "r"(src33), "r"(src34), "r"(src35), + "r"(src36), "r"(src37), "r"(src38), "r"(src39), + "r"(src40), "r"(src41), "r"(src42), "r"(src43), + "r"(src44), "r"(src45), "r"(src46), "r"(src47), + "r"(src48), "r"(src49), "r"(src50), "r"(src51), + "r"(src52), "r"(src53), "r"(src54), "r"(src55), + "r"(src56), "r"(src57), "r"(src58), "r"(src59), + "r"(src60), "r"(src61), "r"(src62), "r"(src63) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 32 data path lanes, 32-bit pattern, repeated 128 times +struct SM100_TMEM_STORE_32dp32b128x +{ + using SRegisters = uint32_t[128]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src000, uint32_t const& src001, uint32_t const& src002, uint32_t const& src003, + uint32_t const& src004, uint32_t const& src005, uint32_t const& src006, uint32_t const& src007, + uint32_t const& src008, uint32_t const& src009, uint32_t const& src010, uint32_t const& src011, + uint32_t const& src012, uint32_t const& src013, uint32_t const& src014, uint32_t const& src015, + uint32_t const& src016, uint32_t const& src017, uint32_t const& src018, uint32_t const& src019, + uint32_t const& src020, uint32_t const& src021, uint32_t const& src022, uint32_t const& src023, + uint32_t const& src024, uint32_t const& src025, uint32_t const& src026, uint32_t const& src027, + uint32_t const& src028, uint32_t const& src029, uint32_t const& src030, uint32_t const& src031, + uint32_t const& src032, uint32_t const& src033, uint32_t const& src034, uint32_t const& src035, + uint32_t const& src036, uint32_t const& src037, uint32_t const& src038, uint32_t const& src039, + uint32_t const& src040, uint32_t const& src041, uint32_t const& src042, uint32_t const& src043, + uint32_t const& src044, uint32_t const& src045, uint32_t const& src046, uint32_t const& src047, + uint32_t const& src048, uint32_t const& src049, uint32_t const& src050, uint32_t const& src051, + uint32_t const& src052, uint32_t const& src053, uint32_t const& src054, 
uint32_t const& src055, + uint32_t const& src056, uint32_t const& src057, uint32_t const& src058, uint32_t const& src059, + uint32_t const& src060, uint32_t const& src061, uint32_t const& src062, uint32_t const& src063, + uint32_t const& src064, uint32_t const& src065, uint32_t const& src066, uint32_t const& src067, + uint32_t const& src068, uint32_t const& src069, uint32_t const& src070, uint32_t const& src071, + uint32_t const& src072, uint32_t const& src073, uint32_t const& src074, uint32_t const& src075, + uint32_t const& src076, uint32_t const& src077, uint32_t const& src078, uint32_t const& src079, + uint32_t const& src080, uint32_t const& src081, uint32_t const& src082, uint32_t const& src083, + uint32_t const& src084, uint32_t const& src085, uint32_t const& src086, uint32_t const& src087, + uint32_t const& src088, uint32_t const& src089, uint32_t const& src090, uint32_t const& src091, + uint32_t const& src092, uint32_t const& src093, uint32_t const& src094, uint32_t const& src095, + uint32_t const& src096, uint32_t const& src097, uint32_t const& src098, uint32_t const& src099, + uint32_t const& src100, uint32_t const& src101, uint32_t const& src102, uint32_t const& src103, + uint32_t const& src104, uint32_t const& src105, uint32_t const& src106, uint32_t const& src107, + uint32_t const& src108, uint32_t const& src109, uint32_t const& src110, uint32_t const& src111, + uint32_t const& src112, uint32_t const& src113, uint32_t const& src114, uint32_t const& src115, + uint32_t const& src116, uint32_t const& src117, uint32_t const& src118, uint32_t const& src119, + uint32_t const& src120, uint32_t const& src121, uint32_t const& src122, uint32_t const& src123, + uint32_t const& src124, uint32_t const& src125, uint32_t const& src126, uint32_t const& src127, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.32x32b.x128.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16," + "%17, %18, %19, %20," + "%21, %22, %23, %24," + "%25, %26, %27, %28," + "%29, %30, %31, %32," + "%33, %34, %35, %36," + "%37, %38, %39, %40," + "%41, %42, %43, %44," + "%45, %46, %47, %48," + "%49, %50, %51, %52," + "%53, %54, %55, %56," + "%57, %58, %59, %60," + "%61, %62, %63, %64," + "%65, %66, %67, %68," + "%69, %70, %71, %72," + "%73, %74, %75, %76," + "%77, %78, %79, %80," + "%81, %82, %83, %84," + "%85, %86, %87, %88," + "%89, %90, %91, %92," + "%93, %94, %95, %96," + "%97, %98, %99, %100," + "%101, %102, %103, %104," + "%105, %106, %107, %108," + "%109, %110, %111, %112," + "%113, %114, %115, %116," + "%117, %118, %119, %120," + "%121, %122, %123, %124," + "%125, %126, %127, %128};\n" + : + : "r"(dst_addr), "r"(src000), "r"(src001), "r"(src002), "r"(src003), + "r"(src004), "r"(src005), "r"(src006), "r"(src007), + "r"(src008), "r"(src009), "r"(src010), "r"(src011), + "r"(src012), "r"(src013), "r"(src014), "r"(src015), + "r"(src016), "r"(src017), "r"(src018), "r"(src019), + "r"(src020), "r"(src021), "r"(src022), "r"(src023), + "r"(src024), "r"(src025), "r"(src026), "r"(src027), + "r"(src028), "r"(src029), "r"(src030), "r"(src031), + "r"(src032), "r"(src033), "r"(src034), "r"(src035), + "r"(src036), "r"(src037), "r"(src038), "r"(src039), + "r"(src040), "r"(src041), "r"(src042), "r"(src043), + "r"(src044), "r"(src045), "r"(src046), "r"(src047), + "r"(src048), "r"(src049), "r"(src050), "r"(src051), + "r"(src052), "r"(src053), "r"(src054), "r"(src055), + "r"(src056), "r"(src057), "r"(src058), 
"r"(src059), + "r"(src060), "r"(src061), "r"(src062), "r"(src063), + "r"(src064), "r"(src065), "r"(src066), "r"(src067), + "r"(src068), "r"(src069), "r"(src070), "r"(src071), + "r"(src072), "r"(src073), "r"(src074), "r"(src075), + "r"(src076), "r"(src077), "r"(src078), "r"(src079), + "r"(src080), "r"(src081), "r"(src082), "r"(src083), + "r"(src084), "r"(src085), "r"(src086), "r"(src087), + "r"(src088), "r"(src089), "r"(src090), "r"(src091), + "r"(src092), "r"(src093), "r"(src094), "r"(src095), + "r"(src096), "r"(src097), "r"(src098), "r"(src099), + "r"(src100), "r"(src101), "r"(src102), "r"(src103), + "r"(src104), "r"(src105), "r"(src106), "r"(src107), + "r"(src108), "r"(src109), "r"(src110), "r"(src111), + "r"(src112), "r"(src113), "r"(src114), "r"(src115), + "r"(src116), "r"(src117), "r"(src118), "r"(src119), + "r"(src120), "r"(src121), "r"(src122), "r"(src123), + "r"(src124), "r"(src125), "r"(src126), "r"(src127) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 32 data path lanes, 32-bit pattern, repeated 128 times, expand 16b write +struct SM100_TMEM_STORE_32dp32b128x_16b +{ + using SRegisters = uint32_t[128]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src000, uint32_t const& src001, uint32_t const& src002, uint32_t const& src003, + uint32_t const& src004, uint32_t const& src005, uint32_t const& src006, uint32_t const& src007, + uint32_t const& src008, uint32_t const& src009, uint32_t const& src010, uint32_t const& src011, + uint32_t const& src012, uint32_t const& src013, uint32_t const& src014, uint32_t const& src015, + uint32_t const& src016, uint32_t const& src017, uint32_t const& src018, uint32_t const& src019, + uint32_t const& src020, uint32_t const& src021, uint32_t const& src022, uint32_t const& src023, + uint32_t const& src024, uint32_t const& src025, uint32_t const& src026, uint32_t const& src027, + uint32_t const& src028, uint32_t const& src029, uint32_t const& src030, uint32_t const& src031, + uint32_t const& src032, uint32_t const& src033, uint32_t const& src034, uint32_t const& src035, + uint32_t const& src036, uint32_t const& src037, uint32_t const& src038, uint32_t const& src039, + uint32_t const& src040, uint32_t const& src041, uint32_t const& src042, uint32_t const& src043, + uint32_t const& src044, uint32_t const& src045, uint32_t const& src046, uint32_t const& src047, + uint32_t const& src048, uint32_t const& src049, uint32_t const& src050, uint32_t const& src051, + uint32_t const& src052, uint32_t const& src053, uint32_t const& src054, uint32_t const& src055, + uint32_t const& src056, uint32_t const& src057, uint32_t const& src058, uint32_t const& src059, + uint32_t const& src060, uint32_t const& src061, uint32_t const& src062, uint32_t const& src063, + uint32_t const& src064, uint32_t const& src065, uint32_t const& src066, uint32_t const& src067, + uint32_t const& src068, uint32_t const& src069, uint32_t const& src070, uint32_t const& src071, + uint32_t const& src072, uint32_t const& src073, uint32_t const& src074, uint32_t const& src075, + uint32_t const& src076, uint32_t const& src077, uint32_t const& src078, uint32_t const& src079, + uint32_t const& src080, uint32_t const& src081, uint32_t const& src082, uint32_t const& src083, + uint32_t const& src084, uint32_t const& src085, uint32_t const& src086, uint32_t const& src087, + uint32_t 
const& src088, uint32_t const& src089, uint32_t const& src090, uint32_t const& src091, + uint32_t const& src092, uint32_t const& src093, uint32_t const& src094, uint32_t const& src095, + uint32_t const& src096, uint32_t const& src097, uint32_t const& src098, uint32_t const& src099, + uint32_t const& src100, uint32_t const& src101, uint32_t const& src102, uint32_t const& src103, + uint32_t const& src104, uint32_t const& src105, uint32_t const& src106, uint32_t const& src107, + uint32_t const& src108, uint32_t const& src109, uint32_t const& src110, uint32_t const& src111, + uint32_t const& src112, uint32_t const& src113, uint32_t const& src114, uint32_t const& src115, + uint32_t const& src116, uint32_t const& src117, uint32_t const& src118, uint32_t const& src119, + uint32_t const& src120, uint32_t const& src121, uint32_t const& src122, uint32_t const& src123, + uint32_t const& src124, uint32_t const& src125, uint32_t const& src126, uint32_t const& src127, + uint32_t const& dst_addr) + { +#if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile ("tcgen05.st.sync.aligned.32x32b.x128.unpack::16b.b32" + "[%0]," + "{%1, %2, %3, %4," + "%5, %6, %7, %8," + "%9, %10, %11, %12," + "%13, %14, %15, %16," + "%17, %18, %19, %20," + "%21, %22, %23, %24," + "%25, %26, %27, %28," + "%29, %30, %31, %32," + "%33, %34, %35, %36," + "%37, %38, %39, %40," + "%41, %42, %43, %44," + "%45, %46, %47, %48," + "%49, %50, %51, %52," + "%53, %54, %55, %56," + "%57, %58, %59, %60," + "%61, %62, %63, %64," + "%65, %66, %67, %68," + "%69, %70, %71, %72," + "%73, %74, %75, %76," + "%77, %78, %79, %80," + "%81, %82, %83, %84," + "%85, %86, %87, %88," + "%89, %90, %91, %92," + "%93, %94, %95, %96," + "%97, %98, %99, %100," + "%101, %102, %103, %104," + "%105, %106, %107, %108," + "%109, %110, %111, %112," + "%113, %114, %115, %116," + "%117, %118, %119, %120," + "%121, %122, %123, %124," + "%125, %126, %127, %128};\n" + : + : "r"(dst_addr), "r"(src000), "r"(src001), "r"(src002), "r"(src003), + "r"(src004), "r"(src005), "r"(src006), "r"(src007), + "r"(src008), "r"(src009), "r"(src010), "r"(src011), + "r"(src012), "r"(src013), "r"(src014), "r"(src015), + "r"(src016), "r"(src017), "r"(src018), "r"(src019), + "r"(src020), "r"(src021), "r"(src022), "r"(src023), + "r"(src024), "r"(src025), "r"(src026), "r"(src027), + "r"(src028), "r"(src029), "r"(src030), "r"(src031), + "r"(src032), "r"(src033), "r"(src034), "r"(src035), + "r"(src036), "r"(src037), "r"(src038), "r"(src039), + "r"(src040), "r"(src041), "r"(src042), "r"(src043), + "r"(src044), "r"(src045), "r"(src046), "r"(src047), + "r"(src048), "r"(src049), "r"(src050), "r"(src051), + "r"(src052), "r"(src053), "r"(src054), "r"(src055), + "r"(src056), "r"(src057), "r"(src058), "r"(src059), + "r"(src060), "r"(src061), "r"(src062), "r"(src063), + "r"(src064), "r"(src065), "r"(src066), "r"(src067), + "r"(src068), "r"(src069), "r"(src070), "r"(src071), + "r"(src072), "r"(src073), "r"(src074), "r"(src075), + "r"(src076), "r"(src077), "r"(src078), "r"(src079), + "r"(src080), "r"(src081), "r"(src082), "r"(src083), + "r"(src084), "r"(src085), "r"(src086), "r"(src087), + "r"(src088), "r"(src089), "r"(src090), "r"(src091), + "r"(src092), "r"(src093), "r"(src094), "r"(src095), + "r"(src096), "r"(src097), "r"(src098), "r"(src099), + "r"(src100), "r"(src101), "r"(src102), "r"(src103), + "r"(src104), "r"(src105), "r"(src106), "r"(src107), + "r"(src108), "r"(src109), "r"(src110), "r"(src111), + "r"(src112), "r"(src113), "r"(src114), "r"(src115), + "r"(src116), "r"(src117), "r"(src118), 
"r"(src119), + "r"(src120), "r"(src121), "r"(src122), "r"(src123), + "r"(src124), "r"(src125), "r"(src126), "r"(src127) ); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMEM_STORE without CUTE_ARCH_TCGEN05_TMEM_ENABLED."); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace SM100::TMEM::STORE + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // end namespace cute diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/arch/copy_sm100_tma.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/arch/copy_sm100_tma.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ef6cfe01026946101e9468fd45da549c9d5bcb86 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/arch/copy_sm100_tma.hpp @@ -0,0 +1,666 @@ +/*************************************************************************************************** + * Copyright (c) 2020 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + + + +#pragma once + +#include + +#include +#include +#include "cutlass/arch/synclog.hpp" + +namespace cute +{ + +constexpr uint32_t Sm100MmaPeerBitMask = 0xFEFFFFFF; +constexpr uint64_t Sm100MemDescDefault = uint64_t(0x1000000000000000); + +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// UTMA_LOAD : Initiates a TMA copy from global memory to shared memory +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM100_TMA_2SM_LOAD_1D +{ + CUTE_HOST_DEVICE static void + copy([[maybe_unused]] void const* desc_ptr, [[maybe_unused]] uint64_t* mbar_ptr, [[maybe_unused]] uint64_t cache_hint, + [[maybe_unused]] void * smem_ptr, + [[maybe_unused]] int32_t const& crd0) + { +#if defined(CUTE_ARCH_TMA_SM100_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + // Executed by both CTAs. Set peer bit to 0 so that the + // transaction bytes will update CTA0's barrier. + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr) & Sm100MmaPeerBitMask; + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile ( + "cp.async.bulk.tensor.1d.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint" + " [%0], [%1, {%3}], [%2], %4;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(crd0), "l"(cache_hint) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM100_ENABLED."); +#endif + } +}; + +struct SM100_TMA_2SM_LOAD_2D +{ + CUTE_HOST_DEVICE static void + copy([[maybe_unused]] void const* desc_ptr, [[maybe_unused]] uint64_t* mbar_ptr, [[maybe_unused]] uint64_t cache_hint, + [[maybe_unused]] void * smem_ptr, + [[maybe_unused]] int32_t const& crd0, int32_t const& crd1) + { +#if defined(CUTE_ARCH_TMA_SM100_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + // Executed by both CTAs. Set peer bit to 0 so that the + // transaction bytes will update CTA0's barrier. + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr) & Sm100MmaPeerBitMask; + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile ( + "cp.async.bulk.tensor.2d.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint" + " [%0], [%1, {%3, %4}], [%2], %5;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(crd0), "r"(crd1), "l"(cache_hint) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM100_ENABLED."); +#endif + } +}; + +struct SM100_TMA_2SM_LOAD_3D +{ + CUTE_HOST_DEVICE static void + copy([[maybe_unused]] void const* desc_ptr, [[maybe_unused]] uint64_t* mbar_ptr, [[maybe_unused]] uint64_t cache_hint, + [[maybe_unused]] void * smem_ptr, + [[maybe_unused]] int32_t const& crd0, int32_t const& crd1, int32_t const& crd2) + { +#if defined(CUTE_ARCH_TMA_SM100_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + // Executed by both CTAs. Set peer bit to 0 so that the + // transaction bytes will update CTA0's barrier. 
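    // (Sketch of what the masking below does, inferred from the comment above and the constant
    // defined near the top of this file: Sm100MmaPeerBitMask == 0xFEFFFFFF clears only bit 24 of
    // the mbarrier's shared-memory address -- the bit that selects the peer CTA of the
    // cta_group::2 pair -- e.g. 0x01000020 & 0xFEFFFFFF == 0x00000020, so both CTAs report their
    // transaction bytes to CTA0's barrier.)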
+ uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr) & Sm100MmaPeerBitMask; + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile ( + "cp.async.bulk.tensor.3d.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint" + " [%0], [%1, {%3, %4, %5}], [%2], %6;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(crd0), "r"(crd1), "r"(crd2), "l"(cache_hint) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM100_ENABLED."); +#endif + } +}; + +struct SM100_TMA_2SM_LOAD_4D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3) + { +#if defined(CUTE_ARCH_TMA_SM100_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + // Executed by both CTAs. Set peer bit to 0 so that the + // transaction bytes will update CTA0's barrier. + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr) & Sm100MmaPeerBitMask; + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile ( + "cp.async.bulk.tensor.4d.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint" + " [%0], [%1, {%3, %4, %5, %6}], [%2], %7;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "l"(cache_hint) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM100_ENABLED."); +#endif + } +}; + +struct SM100_TMA_2SM_LOAD_5D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4) + { +#if defined(CUTE_ARCH_TMA_SM100_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + // Executed by both CTAs. Set peer bit to 0 so that the + // transaction bytes will update CTA0's barrier. 
+ uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr) & Sm100MmaPeerBitMask; + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile ( + "cp.async.bulk.tensor.5d.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint" + " [%0], [%1, {%3, %4, %5, %6, %7}], [%2], %8;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "r"(crd4), "l"(cache_hint) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM100_ENABLED."); +#endif + } +}; + +struct SM100_TMA_2SM_LOAD +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0) + { + return SM100_TMA_2SM_LOAD_1D::copy(desc_ptr, mbar_ptr, cache_hint, smem_ptr, crd0); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1) + { + return SM100_TMA_2SM_LOAD_2D::copy(desc_ptr, mbar_ptr, cache_hint, smem_ptr, crd0, crd1); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2) + { + return SM100_TMA_2SM_LOAD_3D::copy(desc_ptr, mbar_ptr, cache_hint, smem_ptr, crd0, crd1, crd2); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3) + { + return SM100_TMA_2SM_LOAD_4D::copy(desc_ptr, mbar_ptr, cache_hint, smem_ptr, crd0, crd1, crd2, crd3); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4) + { + return SM100_TMA_2SM_LOAD_5D::copy(desc_ptr, mbar_ptr, cache_hint, smem_ptr, crd0, crd1, crd2, crd3, crd4); + } + + using PREFETCH = typename SM90_TMA_LOAD::PREFETCH; +}; + + + +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// TMA_LOAD_MULTICAST: Initiates a TMA copy from global memory to shared memory +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM100_TMA_2SM_LOAD_MULTICAST_1D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0) + { +#if defined(CUTE_ARCH_TMA_SM100_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + // Executed by both CTAs. Set peer bit to 0 so that the + // transaction bytes will update CTA0's barrier. 
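    // (Note on the extra parameter relative to the non-multicast ops above: multicast_mask is a
    // 16-bit mask, one bit per CTA rank in the cluster, fed to the PTX below through the "h"
    // constraint as operand %3. For example, a mask of 0b0011 would deposit the loaded tile into
    // the shared memory of cluster ranks 0 and 1; in practice the mask is computed by the
    // higher-level TMA copy machinery rather than written by hand.)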
+ uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr) & Sm100MmaPeerBitMask; + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile ( + "cp.async.bulk.tensor.1d.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster.L2::cache_hint" + " [%0], [%1, {%4}], [%2], %3, %5;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), "h"(multicast_mask), + "r"(crd0), "l"(cache_hint) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM100_ENABLED."); +#endif + } +}; + +struct SM100_TMA_2SM_LOAD_MULTICAST_2D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1) + { +#if defined(CUTE_ARCH_TMA_SM100_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + // Executed by both CTAs. Set peer bit to 0 so that the + // transaction bytes will update CTA0's barrier. + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr) & Sm100MmaPeerBitMask; + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile ( + "cp.async.bulk.tensor.2d.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster.L2::cache_hint" + " [%0], [%1, {%4, %5}], [%2], %3, %6;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), "h"(multicast_mask), + "r"(crd0), "r"(crd1), "l"(cache_hint) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM100_ENABLED."); +#endif + } +}; + +struct SM100_TMA_2SM_LOAD_MULTICAST_3D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2) + { +#if defined(CUTE_ARCH_TMA_SM100_ENABLED) +uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + // Executed by both CTAs. Set peer bit to 0 so that the + // transaction bytes will update CTA0's barrier. + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr) & Sm100MmaPeerBitMask; + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile ( + "cp.async.bulk.tensor.3d.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster.L2::cache_hint" + " [%0], [%1, {%4, %5, %6}], [%2], %3, %7;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), "h"(multicast_mask), + "r"(crd0), "r"(crd1), "r"(crd2), "l"(cache_hint) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM100_ENABLED."); +#endif + } +}; + +struct SM100_TMA_2SM_LOAD_MULTICAST_4D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3) + { +#if defined(CUTE_ARCH_TMA_SM100_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + // Executed by both CTAs. Set peer bit to 0 so that the + // transaction bytes will update CTA0's barrier. 
+    uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr) & Sm100MmaPeerBitMask;
+    uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr);
+    asm volatile (
+      "cp.async.bulk.tensor.4d.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster.L2::cache_hint"
+      " [%0], [%1, {%4, %5, %6, %7}], [%2], %3, %8;"
+      :
+      : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), "h"(multicast_mask),
+        "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "l"(cache_hint)
+      : "memory");
+#else
+    CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM100_ENABLED.");
+#endif
+  }
+};
+
+struct SM100_TMA_2SM_LOAD_MULTICAST_5D
+{
+  CUTE_HOST_DEVICE static void
+  copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, uint64_t cache_hint,
+       void * smem_ptr,
+       int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4)
+  {
+#if defined(CUTE_ARCH_TMA_SM100_ENABLED)
+    uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(desc_ptr);
+    // Executed by both CTAs. Set peer bit to 0 so that the
+    // transaction bytes will update CTA0's barrier.
+    uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr) & Sm100MmaPeerBitMask;
+    uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr);
+    asm volatile (
+      "cp.async.bulk.tensor.5d.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster.L2::cache_hint"
+      " [%0], [%1, {%4, %5, %6, %7, %8}], [%2], %3, %9;"
+      :
+      : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), "h"(multicast_mask),
+        "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "r"(crd4), "l"(cache_hint)
+      : "memory");
+#else
+    CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM100_ENABLED.");
+#endif
+  }
+};
+
+struct SM100_TMA_2SM_LOAD_MULTICAST
+{
+  CUTE_HOST_DEVICE static void
+  copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, uint64_t cache_hint,
+       void * smem_ptr,
+       int32_t const& crd0)
+  {
+    return SM100_TMA_2SM_LOAD_MULTICAST_1D::copy(desc_ptr, mbar_ptr, multicast_mask, cache_hint, smem_ptr, crd0);
+  }
+  CUTE_HOST_DEVICE static void
+  copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, uint64_t cache_hint,
+       void * smem_ptr,
+       int32_t const& crd0, int32_t const& crd1)
+  {
+    return SM100_TMA_2SM_LOAD_MULTICAST_2D::copy(desc_ptr, mbar_ptr, multicast_mask, cache_hint, smem_ptr, crd0, crd1);
+  }
+  CUTE_HOST_DEVICE static void
+  copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, uint64_t cache_hint,
+       void * smem_ptr,
+       int32_t const& crd0, int32_t const& crd1, int32_t const& crd2)
+  {
+    return SM100_TMA_2SM_LOAD_MULTICAST_3D::copy(desc_ptr, mbar_ptr, multicast_mask, cache_hint, smem_ptr, crd0, crd1, crd2);
+  }
+  CUTE_HOST_DEVICE static void
+  copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, uint64_t cache_hint,
+       void * smem_ptr,
+       int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3)
+  {
+    return SM100_TMA_2SM_LOAD_MULTICAST_4D::copy(desc_ptr, mbar_ptr, multicast_mask, cache_hint, smem_ptr, crd0, crd1, crd2, crd3);
+  }
+  CUTE_HOST_DEVICE static void
+  copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, uint64_t cache_hint,
+       void * smem_ptr,
+       int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4)
+  {
+    return SM100_TMA_2SM_LOAD_MULTICAST_5D::copy(desc_ptr, mbar_ptr, multicast_mask, cache_hint, smem_ptr, crd0, crd1, crd2, crd3, crd4);
+  }
+
+  using PREFETCH = typename SM90_TMA_LOAD::PREFETCH;
+};
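The 1D-5D multicast wrappers above differ only in tensor rank, and the SM100_TMA_2SM_LOAD_MULTICAST dispatcher picks one by overload resolution on the number of coordinates. In CUTLASS kernels these ops are normally reached through CuTe's TMA copy atoms rather than called directly; purely as an illustrative sketch (the descriptor, barrier, buffer, and coordinates below are hypothetical placeholders, not part of this header), a direct call to the simpler non-multicast 2D wrapper defined earlier would look like:

// Hedged example, not part of the vendored header: assumes `tma_desc` points to a host-built
// CUtensorMap, `mbar` is a shared-memory mbarrier already initialized and armed with the expected
// transaction byte count, and `smem_buf` is the destination shared-memory tile.
__device__ void example_2sm_tile_load(void const* tma_desc, uint64_t* mbar, void* smem_buf,
                                      int32_t crd0, int32_t crd1)
{
  // Default memory-descriptor constant defined near the top of this header.
  uint64_t cache_hint = cute::Sm100MemDescDefault;
  cute::SM100_TMA_2SM_LOAD_2D::copy(tma_desc, mbar, cache_hint, smem_buf, crd0, crd1);
}

As the comments in these ops note, the instruction is executed by both CTAs of the pair, and in a real kernel only one elected thread per participating CTA would issue it.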
+//////////////////////////////////////////////////////////////////////////////////////////////////// + + +struct SM100_TMA_2SM_LOAD_IM2COL_3D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, + void * smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_n, + uint16_t const& offset_w) + { +#if defined(CUTE_ARCH_TMA_SM100_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + // Executed by both CTAs. Set peer bit to 0 so that the + // transaction bytes will update CTA0's barrier. + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr) & Sm100MmaPeerBitMask; + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile ( + "cp.async.bulk.tensor.3d.im2col.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint" + " [%0], [%1, {%3, %4, %5}], [%2], {%6}, %7;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(coord_c), "r"(coord_w), "r"(coord_n), + "h"(offset_w), "l"(Sm100MemDescDefault) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM100_ENABLED."); +#endif + } +}; + +struct SM100_TMA_2SM_LOAD_IM2COL_4D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, + void * smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_n, + uint16_t const& offset_w, + uint16_t const& offset_h) + { +#if defined(CUTE_ARCH_TMA_SM100_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + // Executed by both CTAs. Set peer bit to 0 so that the + // transaction bytes will update CTA0's barrier. + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr) & Sm100MmaPeerBitMask; + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile ( + "cp.async.bulk.tensor.4d.im2col.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint" + " [%0], [%1, {%3, %4, %5, %6}], [%2], {%7, %8}, %9;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(coord_c), "r"(coord_w), "r"(coord_h), "r"(coord_n), + "h"(offset_w), "h"(offset_h), "l"(Sm100MemDescDefault) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM100_ENABLED."); +#endif + } +}; + +struct SM100_TMA_2SM_LOAD_IM2COL_5D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, + void * smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_d, int32_t const& coord_n, + uint16_t const& offset_w, + uint16_t const& offset_h, + uint16_t const& offset_d) + { +#if defined(CUTE_ARCH_TMA_SM100_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + // Executed by both CTAs. Set peer bit to 0 so that the + // transaction bytes will update CTA0's barrier. 
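    // (Relative to the tiled loads above, the im2col variants add the 16-bit im2col offsets --
    // here {offset_w, offset_h, offset_d} -- as an extra brace group in the PTX, and pass
    // Sm100MemDescDefault in the trailing cache-hint/memory-descriptor operand, since these
    // overloads take no caller-supplied cache_hint.)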
+ uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr) & Sm100MmaPeerBitMask; + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile ( + "cp.async.bulk.tensor.5d.im2col.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint" + " [%0], [%1, {%3, %4, %5, %6, %7}], [%2], {%8, %9, %10}, %11;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(coord_c), "r"(coord_w), "r"(coord_h), "r"(coord_d), "r"(coord_n), + "h"(offset_w), "h"(offset_h), "h"(offset_d), "l"(Sm100MemDescDefault) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM100_ENABLED."); +#endif + } +}; + +struct SM100_TMA_2SM_LOAD_IM2COL +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, + void * smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_n, + uint16_t const& offset_w) + { + return SM100_TMA_2SM_LOAD_IM2COL_3D::copy(desc_ptr, mbar_ptr, smem_ptr, + coord_c, coord_w, coord_n, + offset_w); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, + void * smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_n, + uint16_t const& offset_w, + uint16_t const& offset_h) + { + return SM100_TMA_2SM_LOAD_IM2COL_4D::copy(desc_ptr, mbar_ptr, smem_ptr, + coord_c, coord_w, coord_h, coord_n, + offset_w, offset_h); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, + void * smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_d, int32_t const& coord_n, + uint16_t const& offset_w, + uint16_t const& offset_h, + uint16_t const& offset_d) + { + return SM100_TMA_2SM_LOAD_IM2COL_5D::copy(desc_ptr, mbar_ptr, smem_ptr, + coord_c, coord_w, coord_h, coord_d, coord_n, + offset_w, offset_h, offset_d); + } + + using PREFETCH = typename SM90_TMA_LOAD_IM2COL::PREFETCH; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM100_TMA_2SM_LOAD_IM2COL_MULTICAST_3D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, + void * smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_n, + uint16_t const& offset_w) + { +#if defined(CUTE_ARCH_TMA_SM100_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + // Executed by both CTAs. Set peer bit to 0 so that the + // transaction bytes will update CTA0's barrier. 
+ uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr) & Sm100MmaPeerBitMask; + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile ( + "cp.async.bulk.tensor.3d.im2col.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster.L2::cache_hint" + " [%0], [%1, {%3, %4, %5}], [%2], {%6}, %7, %8;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(coord_c), "r"(coord_w), "r"(coord_n), + "h"(offset_w), + "h"(multicast_mask), + "l"(Sm100MemDescDefault) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM100_ENABLED."); +#endif + } +}; + +struct SM100_TMA_2SM_LOAD_IM2COL_MULTICAST_4D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, + void * smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_n, + uint16_t const& offset_w, + uint16_t const& offset_h) + { +#if defined(CUTE_ARCH_TMA_SM100_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + // Executed by both CTAs. Set peer bit to 0 so that the + // transaction bytes will update CTA0's barrier. + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr) & Sm100MmaPeerBitMask; + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile ( + "cp.async.bulk.tensor.4d.im2col.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster.L2::cache_hint" + " [%0], [%1, {%3, %4, %5, %6}], [%2], {%7, %8}, %9, %10;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(coord_c), "r"(coord_w), "r"(coord_h), "r"(coord_n), + "h"(offset_w), "h"(offset_h), + "h"(multicast_mask), + "l"(Sm100MemDescDefault) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM100_ENABLED."); +#endif + } +}; + +struct SM100_TMA_2SM_LOAD_IM2COL_MULTICAST_5D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, + void * smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_d, int32_t const& coord_n, + uint16_t const& offset_w, + uint16_t const& offset_h, + uint16_t const& offset_d) + { +#if defined(CUTE_ARCH_TMA_SM100_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + // Executed by both CTAs. Set peer bit to 0 so that the + // transaction bytes will update CTA0's barrier. 
+ uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr) & Sm100MmaPeerBitMask; + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile ( + "cp.async.bulk.tensor.5d.im2col.cta_group::2.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster.L2::cache_hint" + " [%0], [%1, {%3, %4, %5, %6, %7}], [%2], {%8, %9, %10}, %11, %12;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(coord_c), "r"(coord_w), "r"(coord_h), "r"(coord_d), "r"(coord_n), + "h"(offset_w), "h"(offset_h), "h"(offset_d), + "h"(multicast_mask), + "l"(Sm100MemDescDefault) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM100_ENABLED."); +#endif + } +}; + +struct SM100_TMA_2SM_LOAD_IM2COL_MULTICAST +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, + void * smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_n, + uint16_t const& offset_w) + { + return SM100_TMA_2SM_LOAD_IM2COL_MULTICAST_3D::copy(desc_ptr, mbar_ptr, multicast_mask, + smem_ptr, + coord_c, coord_w, coord_n, + offset_w); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, + void * smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_n, + uint16_t const& offset_w, uint16_t const& offset_h) + { + return SM100_TMA_2SM_LOAD_IM2COL_MULTICAST_4D::copy(desc_ptr, mbar_ptr, multicast_mask, + smem_ptr, + coord_c, coord_w, coord_h, coord_n, + offset_w, offset_h); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, + void * smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_d, int32_t const& coord_n, + uint16_t const& offset_w, uint16_t const& offset_h, uint16_t const& offset_d) + { + return SM100_TMA_2SM_LOAD_IM2COL_MULTICAST_5D::copy(desc_ptr, mbar_ptr, multicast_mask, + smem_ptr, + coord_c, coord_w, coord_h, coord_d, coord_n, + offset_w, offset_h, offset_d); + } + + using PREFETCH = typename SM90_TMA_LOAD_IM2COL::PREFETCH; +}; + + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +} // end namespace cute diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/arch/copy_sm50.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/arch/copy_sm50.hpp new file mode 100644 index 0000000000000000000000000000000000000000..12518fccd0781f0444faf771c929252ba952face --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/arch/copy_sm50.hpp @@ -0,0 +1,98 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 500 + #define CUTE_ARCH_WARP_SHUFFLE_ENABLED 1 +#endif + +namespace cute +{ +// Shuffle data between thread pair (0, 1), (2, 3), etc. +struct SM50_Shuffle_U32_2x2Trans_XOR1 +{ + using SRegisters = uint32_t[2]; + using DRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, uint32_t& dst0, uint32_t& dst1) + { +#if defined(CUTE_ARCH_WARP_SHUFFLE_ENABLED) + uint32_t x0 = src0; + uint32_t y0 = __shfl_xor_sync(0xffffffff, x0, 1); + + uint32_t x1 = src1; + uint32_t y1 = __shfl_xor_sync(0xffffffff, x1, 1); + + if (threadIdx.x % 2 == 0) { + dst1 = y0; + } + else { + dst0 = y1; + } +#else + CUTE_INVALID_CONTROL_PATH("Trying to use __shfl_xor_sync without CUTE_ARCH_WARP_SHUFFLE_ENABLED."); +#endif + } +}; + +// Shuffle data between thread pair (0, 4), (1, 5), etc. +struct SM50_Shuffle_U32_2x2Trans_XOR4 +{ + using SRegisters = uint32_t[2]; + using DRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, uint32_t& dst0, uint32_t& dst1) + { +#if defined(CUTE_ARCH_WARP_SHUFFLE_ENABLED) + uint32_t x0 = threadIdx.x & 4 ? src0 : src1; + uint32_t y0 = __shfl_xor_sync(0xffffffff, x0, 4); + + // Replace detination register with shuffle result. + if (threadIdx.x & 0x4) { + dst0 = y0; + } + else { + dst1 = y0; + } +#else + CUTE_INVALID_CONTROL_PATH("Trying to use __shfl_xor_sync without CUTE_ARCH_WARP_SHUFFLE_ENABLED."); +#endif + } +}; + + +} // end namespace cute diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/arch/copy_sm75.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/arch/copy_sm75.hpp new file mode 100644 index 0000000000000000000000000000000000000000..a6cb2387cd824da13cc16e7efeb2c721969adf4a --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/arch/copy_sm75.hpp @@ -0,0 +1,269 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include + +// Config +#if defined(__clang__) && defined(__CUDA__) + // ldmatrix PTX instructions added in Clang 14: https://reviews.llvm.org/D107046 + // ... but will not work until Clang 15: + // * https://reviews.llvm.org/D121666 + // * https://reviews.llvm.org/D126846 + #define CUTE_ARCH_CLANG_SUPPORTS_LDSM_SM75 (__clang_major__ >= 15) + #define CUTE_ARCH_CLANG_SUPPORTS_MOVM_SM75 (__clang_major__ >= 15) +#endif + +#if defined(__NVCC__) || defined(__CUDACC_RTC__) + // ldmatrix PTX instruction added in CUDA 10.2+ + #define CUTE_ARCH_NVCC_SUPPORTS_LDSM_SM75 ((__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2) || __CUDACC_VER_MAJOR__ >= 11) + #define CUTE_ARCH_NVCC_SUPPORTS_MOVM_SM75 ((__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2) || __CUDACC_VER_MAJOR__ >= 11) +#endif + +#if ! defined(CUTE_ARCH_LDSM_SM75_SUPPORTED) + #define CUTE_ARCH_LDSM_SM75_SUPPORTED (CUTE_ARCH_NVCC_SUPPORTS_LDSM_SM75 || CUTE_ARCH_CLANG_SUPPORTS_LDSM_SM75) +#endif + +#if ! defined(CUTE_ARCH_LDSM_SM75_ENABLED) + #define CUTE_ARCH_LDSM_SM75_ENABLED (CUTE_ARCH_LDSM_SM75_SUPPORTED) +#endif + +#if (CUTE_ARCH_LDSM_SM75_ENABLED) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750 + #define CUTE_ARCH_LDSM_SM75_ACTIVATED 1 +#endif + +#if ! defined(CUTE_ARCH_MOVM_SM75_SUPPORTED) + #define CUTE_ARCH_MOVM_SM75_SUPPORTED (CUTE_ARCH_NVCC_SUPPORTS_MOVM_SM75 || CUTE_ARCH_CLANG_SUPPORTS_MOVM_SM75) +#endif + +#if ! 
defined(CUTE_ARCH_MOVM_SM75_ENABLED) + #define CUTE_ARCH_MOVM_SM75_ENABLED (CUTE_ARCH_MOVM_SM75_SUPPORTED) +#endif + +#if (CUTE_ARCH_MOVM_SM75_ENABLED) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750 + #define CUTE_ARCH_MOVM_SM75_ACTIVATED 1 +#endif + + +namespace cute +{ + +struct SM75_U32x1_LDSM_N +{ + using SRegisters = uint128_t[1]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint128_t const& smem_src, + uint32_t& dst) + { +#if defined(CUTE_ARCH_LDSM_SM75_ACTIVATED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src); + asm volatile ("ldmatrix.sync.aligned.x1.m8n8.shared.b16 {%0}, [%1];\n" + : "=r"(dst) + : "r"(smem_int_ptr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM75_ACTIVATED."); +#endif + } +}; + +struct SM75_U32x2_LDSM_N +{ + using SRegisters = uint128_t[1]; + using DRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + copy(uint128_t const& smem_src, + uint32_t& dst0, uint32_t& dst1) + { +#if defined(CUTE_ARCH_LDSM_SM75_ACTIVATED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src); + asm volatile ("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n" + : "=r"(dst0), "=r"(dst1) + : "r"(smem_int_ptr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM75_ACTIVATED."); +#endif + } +}; + +struct SM75_U32x4_LDSM_N +{ + using SRegisters = uint128_t[1]; + using DRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + copy(uint128_t const& smem_src, + uint32_t& dst0, uint32_t& dst1, uint32_t& dst2, uint32_t& dst3) + { +#if defined(CUTE_ARCH_LDSM_SM75_ACTIVATED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src); + asm volatile ("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n" + : "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3) + : "r"(smem_int_ptr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM75_ACTIVATED."); +#endif + } +}; + +struct SM75_U16x2_LDSM_T +{ + using SRegisters = uint128_t[1]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint128_t const& smem_src, + uint32_t& dst) + { +#if defined(CUTE_ARCH_LDSM_SM75_ACTIVATED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src); + asm volatile ("ldmatrix.sync.aligned.x1.trans.m8n8.shared.b16 {%0}, [%1];\n" + : "=r"(dst) + : "r"(smem_int_ptr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM75_ACTIVATED."); +#endif + } +}; + +struct SM75_U16x4_LDSM_T +{ + using SRegisters = uint128_t[1]; + using DRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + copy(uint128_t const& smem_src, + uint32_t& dst0, uint32_t& dst1) + { +#if defined(CUTE_ARCH_LDSM_SM75_ACTIVATED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src); + asm volatile ("ldmatrix.sync.aligned.x2.trans.m8n8.shared.b16 {%0, %1}, [%2];\n" + : "=r"(dst0), "=r"(dst1) + : "r"(smem_int_ptr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM75_ACTIVATED."); +#endif + } +}; + +struct SM75_U16x8_LDSM_T +{ + using SRegisters = uint128_t[1]; + using DRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + copy(uint128_t const& smem_src, + uint32_t& dst0, uint32_t& dst1, uint32_t& dst2, uint32_t& dst3) + { +#if defined(CUTE_ARCH_LDSM_SM75_ACTIVATED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src); + asm volatile ("ldmatrix.sync.aligned.x4.trans.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n" + : "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3) + : 
"r"(smem_int_ptr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM75_ACTIVATED."); +#endif + } +}; + +struct SM75_U32x1_MOVM_T +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t src, + uint32_t &dst) + { +#if CUTE_ARCH_MOVM_SM75_ACTIVATED + asm volatile("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;\n" + : "=r"(dst) + : "r"(src)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use movmatrix without CUTE_ARCH_MOVM_SM75_ACTIVATED."); +#endif + } +}; +// +// Legacy LDSM interfaces that aren't very useful +// + +template +CUTE_HOST_DEVICE +void +copy_ldsm(uint128_t const* const smem_ptr, + T* rmem_ptr) +{ + uint32_t* reg_ptr = reinterpret_cast(rmem_ptr); + + // if constexpr + if (sizeof(T) == 4) { + SM75_U32x1_LDSM_N::copy(smem_ptr[0], reg_ptr[0]); + } + else if (sizeof(T) == 8) { + SM75_U32x2_LDSM_N::copy(smem_ptr[0], reg_ptr[0], reg_ptr[1]); + } + else if (sizeof(T) == 16) { + SM75_U32x4_LDSM_N::copy(smem_ptr[0], reg_ptr[0], reg_ptr[1], reg_ptr[2], reg_ptr[3]); + } + else { + static_assert(sizeof(T) == 4 || sizeof(T) == 8 || sizeof(T) == 16, "sizeof(T) is not supported"); + } +} + +template +CUTE_HOST_DEVICE +void +copy_ldsm_trans(uint128_t const* const smem_ptr, + T* rmem_ptr) +{ + uint32_t* reg_ptr = reinterpret_cast(rmem_ptr); + + // if constexpr + if (sizeof(T) == 4) { + SM75_U16x2_LDSM_T::copy(smem_ptr[0], reg_ptr[0]); + } + else if (sizeof(T) == 8) { + SM75_U16x4_LDSM_T::copy(smem_ptr[0], reg_ptr[0], reg_ptr[1]); + } + else if (sizeof(T) == 16) { + SM75_U16x8_LDSM_T::copy(smem_ptr[0], reg_ptr[0], reg_ptr[1], reg_ptr[2], reg_ptr[3]); + } + else { + static_assert(sizeof(T) == 4 || sizeof(T) == 8 || sizeof(T) == 16, "sizeof(T) is not supported"); + } +} + +} // end namespace cute diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/arch/copy_sm80.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/arch/copy_sm80.hpp new file mode 100644 index 0000000000000000000000000000000000000000..71a7b3a46616869d79e67586c6672fcd6eaac666 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/arch/copy_sm80.hpp @@ -0,0 +1,198 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include + +// Config +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) +# define CUTE_ARCH_CP_ASYNC_SM80_ENABLED +#endif + +namespace cute +{ + +/// Copy via cp.async with caching at all levels +template +struct SM80_CP_ASYNC_CACHEALWAYS +{ + using SRegisters = TS[1]; + using DRegisters = TD[1]; + + static_assert(sizeof(TS) == sizeof(TD), "cp.async requires sizeof(src_value_type) == sizeof(dst_value_type)"); + static_assert(sizeof(TS) == 4 || sizeof(TS) == 8 || sizeof(TS) == 16, "cp.async sizeof(TS) is not supported"); + + CUTE_HOST_DEVICE static void + copy(TS const& gmem_src, + TD & smem_dst) + { +#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED) + TS const* gmem_ptr = &gmem_src; + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_dst); + asm volatile("cp.async.ca.shared.global.L2::128B [%0], [%1], %2;\n" + :: "r"(smem_int_ptr), + "l"(gmem_ptr), + "n"(sizeof(TS))); +#else + CUTE_INVALID_CONTROL_PATH("Support for cp.async instructions has not been enabled"); +#endif + } +}; + +/// Copy via cp.async with caching at global level +template +struct SM80_CP_ASYNC_CACHEGLOBAL +{ + using SRegisters = TS[1]; + using DRegisters = TD[1]; + + static_assert(sizeof(TS) == sizeof(TD), "cp.async requires sizeof(src_value_type) == sizeof(dst_value_type)"); + static_assert(sizeof(TS) == 16, "cp.async sizeof(TS) is not supported"); + + CUTE_HOST_DEVICE static void + copy(TS const& gmem_src, + TD & smem_dst) + { +#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED) + TS const* gmem_ptr = &gmem_src; + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_dst); + asm volatile("cp.async.cg.shared.global.L2::128B [%0], [%1], %2;\n" + :: "r"(smem_int_ptr), + "l"(gmem_ptr), + "n"(sizeof(TS))); +#else + CUTE_INVALID_CONTROL_PATH("Support for cp.async instructions has not been enabled"); +#endif + } +}; + +/// Copy via cp.async with caching at all levels +template +struct SM80_CP_ASYNC_CACHEALWAYS_ZFILL +{ + using SRegisters = TS[1]; + using DRegisters = TD[1]; + + static_assert(sizeof(TS) == sizeof(TD), "cp.async requires sizeof(src_value_type) == sizeof(dst_value_type)"); + static_assert(sizeof(TS) == 4 || sizeof(TS) == 8 || sizeof(TS) == 16, "cp.async sizeof(TS) is not supported"); + + CUTE_HOST_DEVICE static void + copy(TS const& gmem_src, + TD & smem_dst, + bool pred) + { +#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED) + TS const* gmem_ptr = &gmem_src; + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_dst); + int src_size = pred ? 
sizeof(TS) : 0; + asm volatile("cp.async.ca.shared.global.L2::128B [%0], [%1], %2, %3;\n" + :: "r"(smem_int_ptr), + "l"(gmem_ptr), + "n"(sizeof(TS)), + "r"(src_size)); +#else + CUTE_INVALID_CONTROL_PATH("Support for cp.async instructions has not been enabled"); +#endif + } +}; + +/// Copy via cp.async with caching at global level +template +struct SM80_CP_ASYNC_CACHEGLOBAL_ZFILL +{ + using SRegisters = TS[1]; + using DRegisters = TD[1]; + + static_assert(sizeof(TS) == sizeof(TD), "cp.async requires sizeof(src_value_type) == sizeof(dst_value_type)"); + static_assert(sizeof(TS) == 16, "cp.async sizeof(TS) is not supported"); + + CUTE_HOST_DEVICE static void + copy(TS const& gmem_src, + TD & smem_dst, + bool pred) + { +#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED) + TS const* gmem_ptr = &gmem_src; + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_dst); + int src_size = pred ? sizeof(TS) : 0; + asm volatile("cp.async.cg.shared.global.L2::128B [%0], [%1], %2, %3;\n" + :: "r"(smem_int_ptr), + "l"(gmem_ptr), + "n"(sizeof(TS)), + "r"(src_size)); +#else + CUTE_INVALID_CONTROL_PATH("Support for cp.async instructions has not been enabled"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Establishes an ordering w.r.t previously issued cp.async instructions. Does not block. +CUTE_HOST_DEVICE +void +cp_async_fence() +{ +#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED) + asm volatile("cp.async.commit_group;\n" ::); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Blocks until all but N previous cp.async.commit_group operations have committed. +template +CUTE_HOST_DEVICE +void +cp_async_wait() +{ +#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED) + if constexpr (N == 0) { + asm volatile("cp.async.wait_all;\n" ::); + } else { + asm volatile("cp.async.wait_group %0;\n" :: "n"(N)); + } +#endif +} + +template +CUTE_HOST_DEVICE +void +cp_async_wait(Int) +{ + return cp_async_wait(); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // end namespace cute diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/arch/copy_sm90.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/arch/copy_sm90.hpp new file mode 100644 index 0000000000000000000000000000000000000000..5c0745da0a205f304fed9c9ec257bf0bcf16a7cb --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/arch/copy_sm90.hpp @@ -0,0 +1,219 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include // CUTE_HOST_DEVICE +#include // CUTE_ARCH_TMA_SMxx_ENABLED +#include + +namespace cute +{ + +struct SM90_U32x1_STSM_N +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint128_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src, + uint128_t & smem_dst) + { +#if defined(CUTE_ARCH_STSM_SM90_ENABLED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_dst); + asm volatile ("stmatrix.sync.aligned.x1.m8n8.shared.b16 [%0], {%1};\n" + :: "r"(smem_int_ptr), + "r"(src)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use stmatrix without CUTE_ARCH_STSM_SM90_ENABLED."); +#endif + } +}; + +struct SM90_U32x2_STSM_N +{ + using SRegisters = uint32_t[2]; + using DRegisters = uint128_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, + uint128_t& smem_dst) + { +#if defined(CUTE_ARCH_STSM_SM90_ENABLED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_dst); + asm volatile ("stmatrix.sync.aligned.x2.m8n8.shared.b16 [%0], {%1, %2};\n" + :: "r"(smem_int_ptr), + "r"(src0), "r"(src1)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use stmatrix without CUTE_ARCH_STSM_SM90_ENABLED."); +#endif + } +}; + +struct SM90_U32x4_STSM_N +{ + using SRegisters = uint32_t[4]; + using DRegisters = uint128_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, + uint128_t& smem_dst) + { +#if defined(CUTE_ARCH_STSM_SM90_ENABLED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_dst); + asm volatile ("stmatrix.sync.aligned.x4.m8n8.shared.b16 [%0], {%1, %2, %3, %4};\n" + :: "r"(smem_int_ptr), + "r"(src0), "r"(src1), "r"(src2), "r"(src3)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use stmatrix without CUTE_ARCH_STSM_SM90_ENABLED."); +#endif + } +}; + +struct SM90_U16x2_STSM_T +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint128_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src, + uint128_t& smem_dst) + { +#if defined(CUTE_ARCH_STSM_SM90_ENABLED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_dst); + asm volatile ("stmatrix.sync.aligned.x1.trans.m8n8.shared.b16 [%0], {%1};\n" + :: "r"(smem_int_ptr), + "r"(src)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use stmatrix without CUTE_ARCH_STSM_SM90_ENABLED."); +#endif + } +}; + +struct SM90_U16x4_STSM_T +{ + using SRegisters = uint32_t[2]; + using DRegisters = uint128_t[1]; + + 
CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, + uint128_t& smem_dst) + { +#if defined(CUTE_ARCH_STSM_SM90_ENABLED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_dst); + asm volatile ("stmatrix.sync.aligned.x2.trans.m8n8.shared.b16 [%0], {%1, %2};\n" + :: "r"(smem_int_ptr), + "r"(src0), "r"(src1)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use stmatrix without CUTE_ARCH_STSM_SM90_ENABLED."); +#endif + } +}; + +struct SM90_U16x8_STSM_T +{ + using SRegisters = uint32_t[4]; + using DRegisters = uint128_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, + uint128_t& smem_dst) + { +#if defined(CUTE_ARCH_STSM_SM90_ENABLED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_dst); + asm volatile ("stmatrix.sync.aligned.x4.trans.m8n8.shared.b16 [%0], {%1, %2, %3, %4};\n" + :: "r"(smem_int_ptr), + "r"(src0), "r"(src1), "r"(src2), "r"(src3)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use stmatrix without CUTE_ARCH_STSM_SM90_ENABLED."); +#endif + } +}; + +// +// Legacy STSM interfaces that aren't very useful +// + +template +CUTE_HOST_DEVICE +void +copy_stsm(T const* const rmem_ptr, + uint128_t* const smem_ptr) +{ + uint32_t const* reg_ptr = reinterpret_cast(rmem_ptr); + + // if constexpr + if (sizeof(T) == 4) { + SM90_U32x1_STSM_N::copy(reg_ptr[0], smem_ptr[0]); + } + else if (sizeof(T) == 8) { + SM90_U32x2_STSM_N::copy(reg_ptr[0], reg_ptr[1], smem_ptr[0]); + } + else if (sizeof(T) == 16) { + SM90_U32x4_STSM_N::copy(reg_ptr[0], reg_ptr[1], reg_ptr[2], reg_ptr[3], smem_ptr[0]); + } + else { + static_assert(sizeof(T) == 4 || sizeof(T) == 8 || sizeof(T) == 16, "sizeof(T) is not supported"); + } +} + +template +CUTE_HOST_DEVICE +void +copy_stsm_trans(T const* const rmem_ptr, + uint128_t* const smem_ptr) +{ + uint32_t const* reg_ptr = reinterpret_cast(rmem_ptr); + + // if constexpr + if (sizeof(T) == 4) { + SM90_U16x2_STSM_T::copy(reg_ptr[0], smem_ptr[0]); + } + else if (sizeof(T) == 8) { + SM90_U16x4_STSM_T::copy(reg_ptr[0], reg_ptr[1], smem_ptr[0]); + } + else if (sizeof(T) == 16) { + SM90_U16x8_STSM_T::copy(reg_ptr[0], reg_ptr[1], reg_ptr[2], reg_ptr[3], smem_ptr[0]); + } + else { + static_assert(sizeof(T) == 4 || sizeof(T) == 8 || sizeof(T) == 16, "sizeof(T) is not supported"); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // end namespace cute + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/arch/copy_sm90_desc.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/arch/copy_sm90_desc.hpp new file mode 100644 index 0000000000000000000000000000000000000000..63bfb563b1a886b40e8b798d217c70e2db073df7 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/arch/copy_sm90_desc.hpp @@ -0,0 +1,475 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+#pragma once
+
+#include "cutlass/numeric_types.h"
+
+#if !defined(__CUDACC_RTC__)
+#include
+#include
+#endif
+
+#include
+
+#include // cute::cast_smem_ptr_to_uint
+#include // CUTE_ARCH_TMA_SMxx_ENABLED
+#include
+#include
+
+#include
+#include
+#include
+#include
+
+namespace cute
+{
+
+//////////////////////////////////////////////////////////////////////////////////////////////////////
+/// Barriers are 64 bits of user-managed information used in broadly two types of synchronization patterns
+/// 1) arrive/wait on threads (usage: cp.async and warp-specialized kernels)
+/// 2) transaction-based (usage: TMA transaction where a CTA issues one transaction)
+//////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// Initialize barrier present in shared memory
+CUTE_HOST_DEVICE
+void
+initialize_barrier(uint64_t& smem_barrier,  // 64 bits user-managed barrier in smem
+                   int thread_count = 1)    // Thread count expected to arrive/wait on this barrier
+{
+#if defined(CUTE_ARCH_TMA_SM90_ENABLED)
+  uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_barrier);
+  asm volatile ("mbarrier.init.shared::cta.b64 [%0], %1;\n"
+    :: "r"(smem_int_ptr),
+       "r"(thread_count));
+#endif
+}
+
+// Set the number of bytes transferred per transaction and perform an arrive operation as well
+CUTE_HOST_DEVICE
+void
+set_barrier_transaction_bytes(uint64_t& smem_barrier,  // 64 bits user-managed barrier in smem
+                              uint32_t bytes)          // Number of bytes transferred per TMA transaction
+{
+#if defined(CUTE_ARCH_TMA_SM90_ENABLED)
+  uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_barrier);
+  asm volatile ("mbarrier.arrive.expect_tx.shared::cta.b64 _, [%0], %1;\n"
+    :: "r"(smem_int_ptr),
+       "r"(bytes));
+#endif
+}
+
+// Barrier wait
+CUTE_HOST_DEVICE
+void
+wait_barrier(uint64_t& smem_barrier,  // 64 bits user-managed barrier in smem
+             int phase_bit)           // Current phase bit that the barrier is waiting on to flip
+{
+#if defined(CUTE_ARCH_TMA_SM90_ENABLED)
+  uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_barrier);
+  asm volatile(
+    "{\n"
+    ".reg .pred P1;\n"
+    "LAB_WAIT:\n"
+    "mbarrier.try_wait.parity.shared::cta.b64 P1, [%0], %1;\n"
+    "@P1 bra DONE;\n"
+    "bra LAB_WAIT;\n"
+    "DONE:\n"
+    "}\n"
+    :: "r"(smem_int_ptr),
+       "r"(phase_bit));
+
+#endif
+}
+
+// Barrier arrive
+CUTE_HOST_DEVICE
+void
+arrive_barrier(uint64_t& smem_barrier)  // 64 bits user-managed barrier in smem
+{
+#if defined(CUTE_ARCH_TMA_SM90_ENABLED)
+  uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_barrier);
+  asm volatile(
+    "{\n"
+    ".reg .b64 state; \n"
+    "mbarrier.arrive.shared::cta.b64 state, [%0];\n"
+    "}\n"
+    :: "r"(smem_int_ptr));
+#endif
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+// TMA Descriptor and utilities
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+namespace TMA {
+
+enum class SmemSwizzleBits : uint8_t {
+  DISABLE = 0,
+  B32 = 1,
+  B64 = 2,
+  B128 = 3,
+};
+
+enum class SmemSwizzleBase : uint8_t {
+  SWIZZLE_BASE_16B = 0,
+
+  SWIZZLE_BASE_32B = 1,
+  SWIZZLE_BASE_32B_FLIP_8B = 2,
+  SWIZZLE_BASE_64B = 3,
+
+};
+
+enum class OOBFill : uint8_t {
+  ZERO = 0,
+  CONSTANT = 1,
+};
+
+CUTE_HOST_DEVICE char const* to_string(OOBFill const& t) {
+  switch (t) {
+    case OOBFill::ZERO:     return "ZERO";
+    case OOBFill::CONSTANT: return "CONSTANT";
+  }
+  return nullptr;
+}
+
+enum class L2Promotion : uint8_t {
+  DISABLE = 0,
+  B64 = 1,
+  B128 = 2,
+  B256 = 3,
+};
+
+CUTE_HOST_DEVICE char const* to_string(L2Promotion const& t) {
+  switch (t) {
+    case L2Promotion::DISABLE: return "DISABLE";
+    case L2Promotion::B64:     return "B64";
+    case L2Promotion::B128:    return "B128";
+    case L2Promotion::B256:    return "B256";
+  }
+  return nullptr;
+}
+
+// Aux parameters which are independent of the problem size
+struct DescriptorAuxParams {
+  OOBFill oobfill_ = OOBFill::ZERO;
+  L2Promotion l2promo_ = L2Promotion::DISABLE;
+};
+
+enum class CacheHintSm90 : uint64_t {
+  EVICT_NORMAL = 0x1000000000000000,
+  EVICT_FIRST = 0x12F0000000000000,
+  EVICT_LAST = 0x14F0000000000000,
+};
+
+
+enum class CacheHintSm100 : uint64_t {
+  EVICT_NORMAL = 0x1000000000000000,
+  EVICT_FIRST = 0x12F0000000000000,
+  EVICT_LAST = 0x14F0000000000000,
+};
+
+
+#if (__CUDACC_VER_MAJOR__ >= 12)
+
+#if !defined(__CUDACC_RTC__)
+/// @return The TMA descriptor datatype enum corresponding to T.
+template +inline CUtensorMapDataType +to_CUtensorMapDataType() { + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_UINT8; } else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_UINT8; } else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_UINT8; } else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_UINT8; } else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_UINT8; } else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_UINT8;} else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_UINT16; } else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_UINT32; } else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_UINT64; } else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_INT32; } else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_INT64; } else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_FLOAT16; } else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_FLOAT32; } else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_FLOAT64; } else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_BFLOAT16; } else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_TFLOAT32; } else + #if ((__CUDACC_VER_MAJOR__ > 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ > 6))) + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN8B; } else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_16U6_ALIGN16B; } else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_16U6_ALIGN16B; } else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN8B; } else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_16U6_ALIGN16B; } else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN16B; } else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_16U6_ALIGN16B; } else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_16U6_ALIGN16B; } else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN16B; } else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_16U6_ALIGN16B; } else + #endif + + { static_assert(sizeof(T) < 0, "Unknown TMA Format!"); } +} + +inline CUtensorMapSwizzle +to_CUtensorMapSwizzle(SmemSwizzleBits const& t, SmemSwizzleBase const& b) { + switch (t) { + default: throw std::runtime_error("Unsupported pair of SmemSwizzleBits and SmemSwizzleBase!"); + case SmemSwizzleBits::DISABLE: + assert((b == SmemSwizzleBase::SWIZZLE_BASE_16B) && "Expected 16B swizzle base for 0B swizzle bits."); + return CU_TENSOR_MAP_SWIZZLE_NONE; + case SmemSwizzleBits::B32: + assert((b == SmemSwizzleBase::SWIZZLE_BASE_16B) && "Expected 16B swizzle base for 32B swizzle bits."); + return CU_TENSOR_MAP_SWIZZLE_32B; + case SmemSwizzleBits::B64: + assert((b == SmemSwizzleBase::SWIZZLE_BASE_16B) && "Expected 16B swizzle base for 64B swizzle bits."); + return CU_TENSOR_MAP_SWIZZLE_64B; + case SmemSwizzleBits::B128: + switch (b) { + default: throw std::runtime_error("Unsupported pair of SmemSwizzleBits and SmemSwizzleBase!"); + case SmemSwizzleBase::SWIZZLE_BASE_16B: return CU_TENSOR_MAP_SWIZZLE_128B; + + #if ((__CUDACC_VER_MAJOR__ > 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ > 6))) + case SmemSwizzleBase::SWIZZLE_BASE_32B: return CU_TENSOR_MAP_SWIZZLE_128B_ATOM_32B; + case SmemSwizzleBase::SWIZZLE_BASE_64B: return CU_TENSOR_MAP_SWIZZLE_128B_ATOM_64B; + #endif + } + } +} + +inline CUtensorMapFloatOOBfill 
+to_CUtensorMapFloatOOBfill(OOBFill const& t) { + switch(t) { + default: throw std::runtime_error("Unknown OOBFill!"); + case OOBFill::ZERO: return CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE; + case OOBFill::CONSTANT: return CU_TENSOR_MAP_FLOAT_OOB_FILL_NAN_REQUEST_ZERO_FMA; + } +} + +inline CUtensorMapL2promotion +to_CUtensorMapL2promotion(L2Promotion const& t) { + switch(t) { + default: throw std::runtime_error("Unknown L2Promotion!"); + case L2Promotion::DISABLE: return CU_TENSOR_MAP_L2_PROMOTION_NONE; + case L2Promotion::B64: return CU_TENSOR_MAP_L2_PROMOTION_L2_64B; + case L2Promotion::B128: return CU_TENSOR_MAP_L2_PROMOTION_L2_128B; + case L2Promotion::B256: return CU_TENSOR_MAP_L2_PROMOTION_L2_256B; + } +} + +#endif // !defined(__CUDACC_RTC__) + +#endif // (__CUDACC_VER_MAJOR__ >= 12) + +} // end namespace TMA + +#if (__CUDACC_VER_MAJOR__ >= 12) && !defined(__CUDACC_RTC__) + using TmaDescriptor = CUtensorMap; + using Im2ColTmaDescriptor = CUtensorMap; +#else + using TmaDescriptor = struct alignas(64) { char bytes[128]; }; + using Im2ColTmaDescriptor = struct alignas(64) { char bytes[128]; }; +#endif +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// Initiates a TensorMap Prefetch +//////////////////////////////////////////////////////////////////////////////////////////////////// + +CUTE_HOST_DEVICE +void +prefetch_tma_descriptor(TmaDescriptor const* desc_ptr) +{ +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + // Prefetch TMA Descriptor using generic addressing (i.e. no specific state space: const or param) + asm volatile ( + "prefetch.tensormap [%0];" + : + : "l"(gmem_int_desc) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMA Descriptor Prefetch without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// Perform a TensorMap modification (by each field) +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Replace tensor pointer directly in GMEM +CUTE_HOST_DEVICE +void +tma_descriptor_replace_addr_in_global_mem(TmaDescriptor const* desc_ptr, + void const* const new_tensor_ptr) +{ +#if defined(CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint64_t const new_desc_addr = reinterpret_cast(new_tensor_ptr); + asm volatile ( + "tensormap.replace.tile.global_address.global.b1024.b64 [%0], %1;" + :: "l"(gmem_int_desc), "l"(new_desc_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Using TMA Descriptor modification without CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED and CUDA 12.3"); +#endif +} + +// Replace tensor pointer by bringing the tensormap from GMEM into the shared memory +CUTE_HOST_DEVICE +void +tma_descriptor_replace_addr_in_shared_mem(TmaDescriptor& smem_desc, + void const* const new_tensor_ptr) +{ +#if defined(CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED) + uint32_t smem_int_desc = cast_smem_ptr_to_uint(&smem_desc); + uint64_t const new_desc_addr = reinterpret_cast(new_tensor_ptr); + asm volatile ( + "tensormap.replace.tile.global_address.shared::cta.b1024.b64 [%0], %1;" + :: "r"(smem_int_desc), "l"(new_desc_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Using TMA Descriptor modification without CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED and CUDA 12.3"); +#endif +} + +// Replace tensor dims and strides for GEMMs by bringing the tensormap from GMEM into the shared 
memory +CUTE_HOST_DEVICE +void +tma_descriptor_replace_dims_strides_in_shared_mem(TmaDescriptor & smem_desc, + cute::array const& prob_shape, + cute::array const& prob_stride) +{ +#if defined(CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED) + uint32_t smem_int_desc = cast_smem_ptr_to_uint(&smem_desc); + uint64_t const smem_int64_desc = 0; + asm volatile ( + "cvt.u64.u32 %0, %1;" + :: "l"(smem_int64_desc), "r"(smem_int_desc)); + asm volatile ( + "tensormap.replace.tile.global_dim.shared::cta.b1024.b32 [%0], 0, %1;" + :: "l"(smem_int64_desc), "r"(prob_shape[0])); + asm volatile ( + "tensormap.replace.tile.global_dim.shared::cta.b1024.b32 [%0], 1, %1;" + :: "l"(smem_int64_desc), "r"(prob_shape[1])); + asm volatile ( + "tensormap.replace.tile.global_dim.shared::cta.b1024.b32 [%0], 2, %1;" + :: "l"(smem_int64_desc), "r"(prob_shape[2])); + asm volatile ( + "tensormap.replace.tile.global_dim.shared::cta.b1024.b32 [%0], 3, %1;" + :: "l"(smem_int64_desc), "r"(prob_shape[3])); + asm volatile ( + "tensormap.replace.tile.global_dim.shared::cta.b1024.b32 [%0], 4, %1;" + :: "l"(smem_int64_desc), "r"(prob_shape[4])); + // Strides must be a multiple of 16. Also, stride for the intermost dimension is implicitly 1 + #if ((__CUDACC_VER_MAJOR__ > 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ >= 5))) + asm volatile ( + "tensormap.replace.tile.global_stride.shared::cta.b1024.b64 [%0], 0, %1;" + :: "l"(smem_int64_desc), "l"(prob_stride[1])); + asm volatile ( + "tensormap.replace.tile.global_stride.shared::cta.b1024.b64 [%0], 1, %1;" + :: "l"(smem_int64_desc), "l"(prob_stride[2])); + asm volatile ( + "tensormap.replace.tile.global_stride.shared::cta.b1024.b64 [%0], 2, %1;" + :: "l"(smem_int64_desc), "l"(prob_stride[3])); + asm volatile ( + "tensormap.replace.tile.global_stride.shared::cta.b1024.b64 [%0], 3, %1;" + :: "l"(smem_int64_desc), "l"(prob_stride[4])); + #else + // 4 LSBs are not included + asm volatile ( + "tensormap.replace.tile.global_stride.shared::cta.b1024.b64 [%0], 0, %1;" + :: "l"(smem_int64_desc), "l"(prob_stride[1] >> 4)); + asm volatile ( + "tensormap.replace.tile.global_stride.shared::cta.b1024.b64 [%0], 1, %1;" + :: "l"(smem_int64_desc), "l"(prob_stride[2] >> 4)); + asm volatile ( + "tensormap.replace.tile.global_stride.shared::cta.b1024.b64 [%0], 2, %1;" + :: "l"(smem_int64_desc), "l"(prob_stride[3] >> 4)); + asm volatile ( + "tensormap.replace.tile.global_stride.shared::cta.b1024.b64 [%0], 3, %1;" + :: "l"(smem_int64_desc), "l"(prob_stride[4] >> 4)); + #endif +#else + CUTE_INVALID_CONTROL_PATH("Using TMA Descriptor modification without CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED and CUDA 12.3"); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// Perform a fused copy and fence operation (needed when modifying tensormap in shared memory) +//////////////////////////////////////////////////////////////////////////////////////////////////// + +CUTE_HOST_DEVICE +void +tma_descriptor_cp_fence_release(TmaDescriptor const* gmem_desc_ptr, TmaDescriptor& smem_desc) +{ +#if defined(CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(gmem_desc_ptr); + uint32_t smem_int_desc = cast_smem_ptr_to_uint(&smem_desc); + asm volatile ( + "tensormap.cp_fenceproxy.global.shared::cta.tensormap::generic.release.gpu.sync.aligned [%0], [%1], 128;" + :: "l"(gmem_int_desc), "r"(smem_int_desc)); +#else + CUTE_INVALID_CONTROL_PATH("Using TMA Descriptor modification without 
CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED and CUDA 12.3"); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// Perform a release fence operation (needed when modifying tensormap directly in GMEM) +//////////////////////////////////////////////////////////////////////////////////////////////////// + +CUTE_HOST_DEVICE +void +tma_descriptor_fence_release() +{ +#if defined(CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED) + asm volatile ("fence.proxy.tensormap::generic.release.gpu;"); +#else + CUTE_INVALID_CONTROL_PATH("Using TMA Descriptor modification without CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED and CUDA 12.3"); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// Perform a acquire fence operation +//////////////////////////////////////////////////////////////////////////////////////////////////// + +CUTE_HOST_DEVICE +void +tma_descriptor_fence_acquire(TmaDescriptor const* desc_ptr) +{ +#if defined(CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + asm volatile ( + "fence.proxy.tensormap::generic.acquire.gpu [%0], 128;" + : + : "l"(gmem_int_desc) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Using TMA Descriptor modification without CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED and CUDA 12.3"); +#endif +} + +/////////////////////////////////////////////////////////////////////////////// + +} // end namespace cute diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/arch/copy_sm90_tma.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/arch/copy_sm90_tma.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ec1564499ec85d8ae3aa5b8a687ac37520b2ed7f --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/arch/copy_sm90_tma.hpp @@ -0,0 +1,1496 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include // CUTE_ARCH_TMA_SMxx_ENABLED +#include +#include +#include "cutlass/arch/synclog.hpp" + +namespace cute +{ + +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// TMA_LOAD : Initiates a TMA copy from global memory to shared memory +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM90_TMA_LOAD_1D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_load(__LINE__, gmem_int_desc, smem_int_mbar, smem_int_ptr); +#if defined(CUTE_ARCH_TMA_SM120_ENABLED) + asm volatile ( + "cp.async.bulk.tensor.1d.shared::cta.global.mbarrier::complete_tx::bytes.L2::cache_hint" + " [%0], [%1, {%3}], [%2], %4;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(crd0), "l"(cache_hint) + : "memory"); +#else + asm volatile ( + "cp.async.bulk.tensor.1d.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint" + " [%0], [%1, {%3}], [%2], %4;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(crd0), "l"(cache_hint) + : "memory"); +#endif +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } + + struct PREFETCH + { + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + int32_t const& crd0) + { + #if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + asm volatile ( + "cp.async.bulk.prefetch.tensor.1d.L2.global" + " [%0, {%1}];" + : + : "l"(gmem_int_desc), + "r"(crd0) + : "memory"); + #else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); + #endif + } + }; +}; + +struct SM90_TMA_LOAD_2D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_load(__LINE__, gmem_int_desc, smem_int_mbar, smem_int_ptr); +#if defined(CUTE_ARCH_TMA_SM120_ENABLED) + asm volatile ( + "cp.async.bulk.tensor.2d.shared::cta.global.mbarrier::complete_tx::bytes.L2::cache_hint" + " [%0], [%1, {%3, %4}], [%2], %5;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(crd0), "r"(crd1), "l"(cache_hint) + : "memory"); +#else + asm volatile ( + 
"cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint" + " [%0], [%1, {%3, %4}], [%2], %5;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(crd0), "r"(crd1), "l"(cache_hint) + : "memory"); +#endif +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } + + struct PREFETCH + { + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + int32_t const& crd0, int32_t const& crd1) + { + #if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + asm volatile ( + "cp.async.bulk.prefetch.tensor.2d.L2.global" + " [%0, {%1, %2}];" + : + : "l"(gmem_int_desc), + "r"(crd0), "r"(crd1) + : "memory"); + #else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); + #endif + } + }; +}; + +struct SM90_TMA_LOAD_3D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_load(__LINE__, gmem_int_desc, smem_int_mbar, smem_int_ptr); +#if defined(CUTE_ARCH_TMA_SM120_ENABLED) + asm volatile ( + "cp.async.bulk.tensor.3d.shared::cta.global.mbarrier::complete_tx::bytes.L2::cache_hint" + " [%0], [%1, {%3, %4, %5}], [%2], %6;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(crd0), "r"(crd1), "r"(crd2), "l"(cache_hint) + : "memory"); +#else + asm volatile ( + "cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint" + " [%0], [%1, {%3, %4, %5}], [%2], %6;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(crd0), "r"(crd1), "r"(crd2), "l"(cache_hint) + : "memory"); +#endif +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } + + struct PREFETCH + { + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2) + { + #if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + asm volatile ( + "cp.async.bulk.prefetch.tensor.3d.L2.global" + " [%0, {%1, %2, %3}];" + : + : "l"(gmem_int_desc), + "r"(crd0), "r"(crd1), "r"(crd2) + : "memory"); + #else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); + #endif + } + }; +}; + +struct SM90_TMA_LOAD_4D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_load(__LINE__, gmem_int_desc, smem_int_mbar, smem_int_ptr); +#if defined(CUTE_ARCH_TMA_SM120_ENABLED) + asm volatile ( + "cp.async.bulk.tensor.4d.shared::cta.global.mbarrier::complete_tx::bytes.L2::cache_hint" + " [%0], [%1, {%3, %4, %5, %6}], [%2], %7;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "l"(cache_hint) + : "memory"); +#else + asm 
volatile ( + "cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint" + " [%0], [%1, {%3, %4, %5, %6}], [%2], %7;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "l"(cache_hint) + : "memory"); +#endif +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } + + struct PREFETCH + { + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3) + { + #if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + asm volatile ( + "cp.async.bulk.prefetch.tensor.4d.L2.global" + " [%0, {%1, %2, %3, %4}];" + : + : "l"(gmem_int_desc), + "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3) + : "memory"); + #else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); + #endif + } + }; +}; + +struct SM90_TMA_LOAD_5D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_load(__LINE__, gmem_int_desc, smem_int_mbar, smem_int_ptr); +#if defined(CUTE_ARCH_TMA_SM120_ENABLED) + asm volatile ( + "cp.async.bulk.tensor.5d.shared::cta.global.mbarrier::complete_tx::bytes.L2::cache_hint" + " [%0], [%1, {%3, %4, %5, %6, %7}], [%2], %8;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "r"(crd4), "l"(cache_hint) + : "memory"); +#else + asm volatile ( + "cp.async.bulk.tensor.5d.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint" + " [%0], [%1, {%3, %4, %5, %6, %7}], [%2], %8;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "r"(crd4), "l"(cache_hint) + : "memory"); +#endif +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } + + struct PREFETCH + { + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4) + { + #if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + asm volatile ( + "cp.async.bulk.prefetch.tensor.5d.L2.global" + " [%0, {%1, %2, %3, %4, %5}];" + : + : "l"(gmem_int_desc), + "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "r"(crd4) + : "memory"); + #else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); + #endif + } + }; +}; + +struct SM90_TMA_LOAD +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0) + { + return SM90_TMA_LOAD_1D::copy(desc_ptr, mbar_ptr, cache_hint, smem_ptr, crd0); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1) + { + return SM90_TMA_LOAD_2D::copy(desc_ptr, mbar_ptr, cache_hint, smem_ptr, crd0, crd1); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t 
cache_hint, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2) + { + return SM90_TMA_LOAD_3D::copy(desc_ptr, mbar_ptr, cache_hint, smem_ptr, crd0, crd1, crd2); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3) + { + return SM90_TMA_LOAD_4D::copy(desc_ptr, mbar_ptr, cache_hint, smem_ptr, crd0, crd1, crd2, crd3); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4) + { + return SM90_TMA_LOAD_5D::copy(desc_ptr, mbar_ptr, cache_hint, smem_ptr, crd0, crd1, crd2, crd3, crd4); + } + + struct PREFETCH + { + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + int32_t const& crd0) + { + return SM90_TMA_LOAD_1D::PREFETCH::copy(desc_ptr, crd0); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + int32_t const& crd0, int32_t const& crd1) + { + return SM90_TMA_LOAD_2D::PREFETCH::copy(desc_ptr, crd0, crd1); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2) + { + return SM90_TMA_LOAD_3D::PREFETCH::copy(desc_ptr, crd0, crd1, crd2); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3) + { + return SM90_TMA_LOAD_4D::PREFETCH::copy(desc_ptr, crd0, crd1, crd2, crd3); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4) + { + return SM90_TMA_LOAD_5D::PREFETCH::copy(desc_ptr, crd0, crd1, crd2, crd3, crd4); + } + }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// TMA_LOAD im2col: Initiates a TMA copy, in im2col mode, from global memory to shared memory +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM90_TMA_LOAD_IM2COL_3D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, + void * smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_n, + uint16_t const& offset_w) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_load(__LINE__, gmem_int_desc, smem_int_mbar, smem_int_ptr); + // Copy from global to shared::cluster. 
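+    // The int32_t arguments are the base tensor coordinates (channel, spatial, batch); the trailing
+    // uint16_t argument(s) are the per-spatial-dimension im2col offsets forwarded to the braced
+    // operand of the PTX im2col instruction below.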
+ asm volatile ( + "cp.async.bulk.tensor.3d.shared::cluster.global.im2col.mbarrier::complete_tx::bytes" + " [%0], [%1, {%3, %4, %5}], [%2], {%6};" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(coord_c), "r"(coord_w), "r"(coord_n), + "h"(offset_w) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } + + struct PREFETCH + { + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_n, + uint16_t const& offset_w) + { + #if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + asm volatile ( + "cp.async.bulk.prefetch.tensor.3d.L2.global.im2col" + " [%0, {%1, %2, %3}], {%4};" + : + : "l"(gmem_int_desc), + "r"(coord_c), "r"(coord_w), "r"(coord_n), + "h"(offset_w) + : "memory"); + #else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); + #endif + } + }; +}; + +struct SM90_TMA_LOAD_IM2COL_4D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, + void * smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_n, + uint16_t const& offset_w, uint16_t const& offset_h) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_load(__LINE__, gmem_int_desc, smem_int_mbar, smem_int_ptr); + // Copy from global to shared::cluster. + asm volatile ( + "cp.async.bulk.tensor.4d.shared::cluster.global.im2col.mbarrier::complete_tx::bytes" + " [%0], [%1, {%3, %4, %5, %6}], [%2], {%7, %8};" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(coord_c), "r"(coord_w), "r"(coord_h), "r"(coord_n), + "h"(offset_w), "h"(offset_h) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } + + struct PREFETCH + { + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_n, + uint16_t const& offset_w, uint16_t const& offset_h) + { + #if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + asm volatile ( + "cp.async.bulk.prefetch.tensor.4d.L2.global.im2col" + " [%0, {%1, %2, %3, %4}], {%5, %6};" + : + : "l"(gmem_int_desc), + "r"(coord_c), "r"(coord_w), "r"(coord_h), "r"(coord_n), + "h"(offset_w), "h"(offset_h) + : "memory"); + #else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); + #endif + } + }; +}; + +struct SM90_TMA_LOAD_IM2COL_5D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, + void * smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_d, int32_t const& coord_n, + uint16_t const& offset_w, uint16_t const& offset_h, uint16_t const& offset_d) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_load(__LINE__, gmem_int_desc, smem_int_mbar, smem_int_ptr); + // Copy from global to shared::cluster. 
+ asm volatile ( + "cp.async.bulk.tensor.5d.shared::cluster.global.im2col.mbarrier::complete_tx::bytes" + " [%0], [%1, {%3, %4, %5, %6, %7}], [%2], {%8, %9, %10};" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(coord_c), "r"(coord_w), "r"(coord_h), "r"(coord_d), "r"(coord_n), + "h"(offset_w), "h"(offset_h), "h"(offset_d) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } + + struct PREFETCH + { + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_d, int32_t const& coord_n, + uint16_t const& offset_w, uint16_t const& offset_h, uint16_t const& offset_d) + { + #if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + asm volatile ( + "cp.async.bulk.prefetch.tensor.5d.L2.global.im2col" + " [%0, {%1, %2, %3, %4, %5}], {%6, %7, %8};" + : + : "l"(gmem_int_desc), + "r"(coord_c), "r"(coord_w), "r"(coord_h), "r"(coord_d), "r"(coord_n), + "h"(offset_w), "h"(offset_h), "h"(offset_d) + : "memory"); + #else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); + #endif + } + }; +}; + +struct SM90_TMA_LOAD_IM2COL +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, + void * smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_n, + uint16_t const& offset_w) + { + return SM90_TMA_LOAD_IM2COL_3D::copy(desc_ptr, mbar_ptr, smem_ptr, + coord_c, coord_w, coord_n, + offset_w); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, + void * smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_n, + uint16_t const& offset_w, uint16_t const& offset_h) + { + return SM90_TMA_LOAD_IM2COL_4D::copy(desc_ptr, mbar_ptr, smem_ptr, + coord_c, coord_w, coord_h, coord_n, + offset_w, offset_h); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, + void * smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_d, int32_t const& coord_n, + uint16_t const& offset_w, uint16_t const& offset_h, uint16_t const& offset_d) + { + return SM90_TMA_LOAD_IM2COL_5D::copy(desc_ptr, mbar_ptr, smem_ptr, + coord_c, coord_w, coord_h, coord_d, coord_n, + offset_w, offset_h, offset_d); + } + + struct PREFETCH + { + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_n, + uint16_t const& offset_w) + { + return SM90_TMA_LOAD_IM2COL_3D::PREFETCH::copy(desc_ptr, + coord_c, coord_w, coord_n, + offset_w); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_n, + uint16_t const& offset_w, uint16_t const& offset_h) + { + return SM90_TMA_LOAD_IM2COL_4D::PREFETCH::copy(desc_ptr, + coord_c, coord_w, coord_h, coord_n, + offset_w, offset_h); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_d, int32_t const& coord_n, + uint16_t const& offset_w, uint16_t const& offset_h, uint16_t const& offset_d) + { + return SM90_TMA_LOAD_IM2COL_5D::PREFETCH::copy(desc_ptr, + coord_c, coord_w, coord_h, coord_d, coord_n, + offset_w, offset_h, offset_d); + } + }; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// +/// TMA_LOAD_MULTICAST: Initiates a TMA copy from global memory to shared memory +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM90_TMA_LOAD_MULTICAST_1D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) +#if defined(CUTE_ARCH_TMA_SM120_ENABLED) + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_load(__LINE__, gmem_int_desc, smem_int_mbar, smem_int_ptr); + asm volatile ( + "cp.async.bulk.tensor.1d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster.L2::cache_hint" + " [%0], [%1, {%4}], [%2], %3, %5;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "h"(multicast_mask), + "r"(crd0), "l"(cache_hint) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_LOAD_MULTICAST_2D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) +#if defined(CUTE_ARCH_TMA_SM120_ENABLED) + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_load(__LINE__, gmem_int_desc, smem_int_mbar, smem_int_ptr); + asm volatile ( + "cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster.L2::cache_hint" + " [%0], [%1, {%4, %5}], [%2], %3, %6;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "h"(multicast_mask), + "r"(crd0), "r"(crd1), "l"(cache_hint) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_LOAD_MULTICAST_3D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) +#if defined(CUTE_ARCH_TMA_SM120_ENABLED) + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_load(__LINE__, gmem_int_desc, smem_int_mbar, smem_int_ptr); + asm volatile ( + "cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster.L2::cache_hint" + " [%0], [%1, {%4, %5, %6}], [%2], %3, %7;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "h"(multicast_mask), + "r"(crd0), "r"(crd1), "r"(crd2), "l"(cache_hint) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without 
CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_LOAD_MULTICAST_4D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) +#if defined(CUTE_ARCH_TMA_SM120_ENABLED) + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_load(__LINE__, gmem_int_desc, smem_int_mbar, smem_int_ptr); + asm volatile ( + "cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster.L2::cache_hint" + " [%0], [%1, {%4, %5, %6, %7}], [%2], %3, %8;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "h"(multicast_mask), + "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "l"(cache_hint) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_LOAD_MULTICAST_5D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) +#if defined(CUTE_ARCH_TMA_SM120_ENABLED) + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_load(__LINE__, gmem_int_desc, smem_int_mbar, smem_int_ptr); + asm volatile ( + "cp.async.bulk.tensor.5d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster.L2::cache_hint" + " [%0], [%1, {%4, %5, %6, %7, %8}], [%2], %3, %9;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "h"(multicast_mask), + "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "r"(crd4), "l"(cache_hint) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_LOAD_MULTICAST +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0) + { + return SM90_TMA_LOAD_MULTICAST_1D::copy(desc_ptr, mbar_ptr, multicast_mask, cache_hint, smem_ptr, crd0); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1) + { + return SM90_TMA_LOAD_MULTICAST_2D::copy(desc_ptr, mbar_ptr, multicast_mask, cache_hint, smem_ptr, crd0, crd1); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2) + { + return SM90_TMA_LOAD_MULTICAST_3D::copy(desc_ptr, mbar_ptr, multicast_mask, cache_hint, smem_ptr, crd0, crd1, crd2); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, uint64_t cache_hint, + void * smem_ptr, + 
int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3) + { + return SM90_TMA_LOAD_MULTICAST_4D::copy(desc_ptr, mbar_ptr, multicast_mask, cache_hint, smem_ptr, crd0, crd1, crd2, crd3); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4) + { + return SM90_TMA_LOAD_MULTICAST_5D::copy(desc_ptr, mbar_ptr, multicast_mask, cache_hint, smem_ptr, crd0, crd1, crd2, crd3, crd4); + } + + using PREFETCH = typename SM90_TMA_LOAD::PREFETCH; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// TMA_LOAD_MULTICAST im2col: Initiates a TMA copy, in im2col mode, from global memory to shared memory +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM90_TMA_LOAD_IM2COL_MULTICAST_3D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, + void * smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_n, + uint16_t const& offset_w) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) +#if defined(CUTE_ARCH_TMA_SM120_ENABLED) + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_load(__LINE__, gmem_int_desc, smem_int_mbar, smem_int_ptr); + // Copy from global to shared::cluster. + asm volatile ( + "cp.async.bulk.tensor.3d.shared::cluster.global.im2col.mbarrier::complete_tx::bytes.multicast::cluster" + " [%0], [%1, {%3, %4, %5}], [%2], {%6}, %7;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(coord_c), "r"(coord_w), "r"(coord_n), + "h"(offset_w), + "h"(multicast_mask) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_LOAD_IM2COL_MULTICAST_4D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, + void * smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_n, + uint16_t const& offset_w, uint16_t const& offset_h) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) +#if defined(CUTE_ARCH_TMA_SM120_ENABLED) + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_load(__LINE__, gmem_int_desc, smem_int_mbar, smem_int_ptr); + // Copy from global to shared::cluster. 
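+    // The 16-bit multicast_mask selects which CTAs in the cluster receive the data (one bit per
+    // CTA rank); it is passed as the final operand of the multicast im2col instruction below.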
+ asm volatile ( + "cp.async.bulk.tensor.4d.shared::cluster.global.im2col.mbarrier::complete_tx::bytes.multicast::cluster" + " [%0], [%1, {%3, %4, %5, %6}], [%2], {%7, %8}, %9;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(coord_c), "r"(coord_w), "r"(coord_h), "r"(coord_n), + "h"(offset_w), "h"(offset_h), + "h"(multicast_mask) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_LOAD_IM2COL_MULTICAST_5D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, + void * smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_d, int32_t const& coord_n, + uint16_t const& offset_w, uint16_t const& offset_h, uint16_t const& offset_d) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) +#if defined(CUTE_ARCH_TMA_SM120_ENABLED) + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_load(__LINE__, gmem_int_desc, smem_int_mbar, smem_int_ptr); + // Copy from global to shared::cluster. + asm volatile ( + "cp.async.bulk.tensor.5d.shared::cluster.global.im2col.mbarrier::complete_tx::bytes.multicast::cluster" + " [%0], [%1, {%3, %4, %5, %6, %7}], [%2], {%8, %9, %10}, %11;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(coord_c), "r"(coord_w), "r"(coord_h), "r"(coord_d), "r"(coord_n), + "h"(offset_w), "h"(offset_h), "h"(offset_d), + "h"(multicast_mask) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_LOAD_IM2COL_MULTICAST +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, + void * smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_n, + uint16_t const& offset_w) + { + return SM90_TMA_LOAD_IM2COL_MULTICAST_3D::copy(desc_ptr, mbar_ptr, multicast_mask, + smem_ptr, + coord_c, coord_w, coord_n, + offset_w); + } + + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, + void * smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_n, + uint16_t const& offset_w, uint16_t const& offset_h) + { + return SM90_TMA_LOAD_IM2COL_MULTICAST_4D::copy(desc_ptr, mbar_ptr, multicast_mask, + smem_ptr, + coord_c, coord_w, coord_h, coord_n, + offset_w, offset_h); + } + + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, + void * smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_d, int32_t const& coord_n, + uint16_t const& offset_w, uint16_t const& offset_h, uint16_t const& offset_d) + { + return SM90_TMA_LOAD_IM2COL_MULTICAST_5D::copy(desc_ptr, mbar_ptr, multicast_mask, + smem_ptr, + coord_c, coord_w, coord_h, coord_d, coord_n, + offset_w, offset_h, offset_d); + } + + using PREFETCH = typename SM90_TMA_LOAD_IM2COL::PREFETCH; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// TMA_STORE : Initiates a TMA copy from shared memory to global memory 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM90_TMA_STORE_1D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + void const* smem_ptr, + int32_t const& crd0) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_store(__LINE__, gmem_int_desc, smem_int_ptr); + asm volatile ( + "cp.async.bulk.tensor.1d.global.shared::cta.bulk_group [%0, {%2}], [%1];" + : + : "l"(gmem_int_desc), "r"(smem_int_ptr), + "r"(crd0) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_STORE_2D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + void const* smem_ptr, + int32_t const& crd0, int32_t const& crd1) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_store(__LINE__, gmem_int_desc, smem_int_ptr); + asm volatile ( + "cp.async.bulk.tensor.2d.global.shared::cta.bulk_group [%0, {%2, %3}], [%1];" + : + : "l"(gmem_int_desc), "r"(smem_int_ptr), + "r"(crd0), "r"(crd1) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_STORE_3D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + void const* smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_store(__LINE__, gmem_int_desc, smem_int_ptr); + asm volatile ( + "cp.async.bulk.tensor.3d.global.shared::cta.bulk_group [%0, {%2, %3, %4}], [%1];" + : + : "l"(gmem_int_desc), "r"(smem_int_ptr), + "r"(crd0), "r"(crd1), "r"(crd2) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_STORE_4D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + void const* smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_store(__LINE__, gmem_int_desc, smem_int_ptr); + asm volatile ( + "cp.async.bulk.tensor.4d.global.shared::cta.bulk_group [%0, {%2, %3, %4, %5}], [%1];" + : + : "l"(gmem_int_desc), "r"(smem_int_ptr), + "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_STORE_5D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + void const* smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_store(__LINE__, gmem_int_desc, smem_int_ptr); + asm volatile ( + "cp.async.bulk.tensor.5d.global.shared::cta.bulk_group [%0, {%2, %3, %4, %5, %6}], [%1];" + : + : "l"(gmem_int_desc), "r"(smem_int_ptr), + 
"r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "r"(crd4) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_STORE +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + void const* smem_ptr, + int32_t const& crd0) + { + return SM90_TMA_STORE_1D::copy(desc_ptr, smem_ptr, crd0); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + void const* smem_ptr, + int32_t const& crd0, int32_t const& crd1) + { + return SM90_TMA_STORE_2D::copy(desc_ptr, smem_ptr, crd0, crd1); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + void const* smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2) + { + return SM90_TMA_STORE_3D::copy(desc_ptr, smem_ptr, crd0, crd1, crd2); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + void const* smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3) + { + return SM90_TMA_STORE_4D::copy(desc_ptr, smem_ptr, crd0, crd1, crd2, crd3); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + void const* smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4) + { + return SM90_TMA_STORE_5D::copy(desc_ptr, smem_ptr, crd0, crd1, crd2, crd3, crd4); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// TMA_STORE im2col: Initiates a TMA copy, in im2col mode, from shared memory to global memory +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM90_TMA_STORE_IM2COL_3D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + void const* smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_n) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_store(__LINE__, gmem_int_desc, smem_int_ptr); + asm volatile ( + "cp.async.bulk.tensor.3d.global.shared::cta.im2col_no_offs.bulk_group" + " [%0, {%2, %3, %4}], [%1];" + : + : "l"(gmem_int_desc), "r"(smem_int_ptr), + "r"(coord_c), "r"(coord_w), "r"(coord_n) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_STORE_IM2COL_4D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + void const* smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_n) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_store(__LINE__, gmem_int_desc, smem_int_ptr); + asm volatile ( + "cp.async.bulk.tensor.4d.global.shared::cta.im2col_no_offs.bulk_group" + " [%0, {%2, %3, %4, %5}], [%1];" + : + : "l"(gmem_int_desc), "r"(smem_int_ptr), + "r"(coord_c), "r"(coord_w), "r"(coord_h), "r"(coord_n) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_STORE_IM2COL_5D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + void const* smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_d, int32_t const& coord_n) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc 
= reinterpret_cast(desc_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_store(__LINE__, gmem_int_desc, smem_int_ptr); + asm volatile ( + "cp.async.bulk.tensor.5d.global.shared::cta.im2col_no_offs.bulk_group" + " [%0, {%2, %3, %4, %5, %6}], [%1];" + : + : "l"(gmem_int_desc), "r"(smem_int_ptr), + "r"(coord_c), "r"(coord_w), "r"(coord_h), "r"(coord_d), "r"(coord_n) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_STORE_IM2COL +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + void const* smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_n) + { + return SM90_TMA_STORE_IM2COL_3D::copy(desc_ptr, smem_ptr, coord_c, coord_w, coord_n); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + void const* smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_n) + { + return SM90_TMA_STORE_IM2COL_4D::copy(desc_ptr, smem_ptr, coord_c, coord_w, coord_h, coord_n); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + void const* smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_d, int32_t const& coord_n) + { + return SM90_TMA_STORE_IM2COL_5D::copy(desc_ptr, smem_ptr, coord_c, coord_w, coord_h, coord_d, coord_n); + } +}; + +// Fence for smem stores for subsequent TMA_STORE +CUTE_HOST_DEVICE static void +tma_store_fence() { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + cutlass::arch::synclog_emit_fence_view_async_shared(__LINE__); + asm volatile ("fence.proxy.async.shared::cta;"); +#elif defined(__CUDA_ARCH__) + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif +} + +// Indicate arrival of warp issuing TMA_STORE +CUTE_HOST_DEVICE static void +tma_store_arrive() { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + cutlass::arch::synclog_emit_tma_store_arrive(__LINE__); + asm volatile("cp.async.bulk.commit_group;"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif +} + + +CUTE_HOST_DEVICE static void +tma_desc_commit_group() { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + asm volatile("cp.async.bulk.commit_group;"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif +} + + +// Wait until at most Count committed TMA_STOREs are pending and all prior commits are complete +template +CUTE_HOST_DEVICE static void +tma_store_wait() { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + asm volatile( + "cp.async.bulk.wait_group.read %0;" + : + : "n"(Count) + : "memory"); + cutlass::arch::synclog_emit_tma_store_wait(__LINE__, Count); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif +} + + +// Wait until all TMA descriptor previously issued are safe to be modified after tma_desc_commit_group() +CUTE_HOST_DEVICE static void +tma_desc_wait_group() { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + asm volatile( + "cp.async.bulk.wait_group.read %0;" + : + : "n"(0) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif +} + + +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// TMA_REDUCE_ADD : Initiates a TMA reduce-add from shared memory to global memory 
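+/// The variants below use cp.reduce.async.bulk.tensor.*.add, i.e. the shared-memory tile is
+/// added element-wise into the destination tensor in global memory instead of overwriting it.
+/// Completion is tracked exactly like TMA_STORE: the reductions join the current bulk group,
+/// so the same tma_store_arrive() / tma_store_wait<N>() pairing above applies.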
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM90_TMA_REDUCE_ADD_1D +{ + CUTE_HOST_DEVICE static void + copy(void const* const desc_ptr, + void const* const smem_ptr, + int32_t const& crd0) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_store(__LINE__, gmem_int_desc, smem_int_ptr); + asm volatile ( + "cp.reduce.async.bulk.tensor.1d.global.shared::cta.add.bulk_group [%0, {%2}], [%1];" + : + : "l"(gmem_int_desc), "r"(smem_int_ptr), + "r"(crd0) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_REDUCE_ADD_2D +{ + CUTE_HOST_DEVICE static void + copy(void const* const desc_ptr, + void const* const smem_ptr, + int32_t const& crd0, int32_t const& crd1) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_store(__LINE__, gmem_int_desc, smem_int_ptr); + asm volatile ( + "cp.reduce.async.bulk.tensor.2d.global.shared::cta.add.bulk_group [%0, {%2, %3}], [%1];" + : + : "l"(gmem_int_desc), "r"(smem_int_ptr), + "r"(crd0), "r"(crd1) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_REDUCE_ADD_3D +{ + CUTE_HOST_DEVICE static void + copy(void const* const desc_ptr, + void const* const smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_store(__LINE__, gmem_int_desc, smem_int_ptr); + asm volatile ( + "cp.reduce.async.bulk.tensor.3d.global.shared::cta.add.bulk_group [%0, {%2, %3, %4}], [%1];" + : + : "l"(gmem_int_desc), "r"(smem_int_ptr), + "r"(crd0), "r"(crd1), "r"(crd2) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_REDUCE_ADD_4D +{ + CUTE_HOST_DEVICE static void + copy(void const* const desc_ptr, + void const* const smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_store(__LINE__, gmem_int_desc, smem_int_ptr); + asm volatile ( + "cp.reduce.async.bulk.tensor.4d.global.shared::cta.add.bulk_group [%0, {%2, %3, %4, %5}], [%1];" + : + : "l"(gmem_int_desc), "r"(smem_int_ptr), + "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_REDUCE_ADD_5D +{ + CUTE_HOST_DEVICE static void + copy(void const* const desc_ptr, + void const* const smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_store(__LINE__, gmem_int_desc, smem_int_ptr); + asm volatile ( + 
"cp.reduce.async.bulk.tensor.5d.global.shared::cta.add.bulk_group [%0, {%2, %3, %4, %5, %6}], [%1];" + : + : "l"(gmem_int_desc), "r"(smem_int_ptr), + "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "r"(crd4) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_REDUCE_ADD +{ + CUTE_HOST_DEVICE static void + copy(void const* const desc_ptr, + void const* const smem_ptr, + int32_t const& crd0) + { + return SM90_TMA_REDUCE_ADD_1D::copy(desc_ptr, smem_ptr, crd0); + } + CUTE_HOST_DEVICE static void + copy(void const* const desc_ptr, + void const* const smem_ptr, + int32_t const& crd0, int32_t const& crd1) + { + return SM90_TMA_REDUCE_ADD_2D::copy(desc_ptr, smem_ptr, crd0, crd1); + } + CUTE_HOST_DEVICE static void + copy(void const* const desc_ptr, + void const* const smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2) + { + return SM90_TMA_REDUCE_ADD_3D::copy(desc_ptr, smem_ptr, crd0, crd1, crd2); + } + CUTE_HOST_DEVICE static void + copy(void const* const desc_ptr, + void const* const smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3) + { + return SM90_TMA_REDUCE_ADD_4D::copy(desc_ptr, smem_ptr, crd0, crd1, crd2, crd3); + } + CUTE_HOST_DEVICE static void + copy(void const* const desc_ptr, + void const* const smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4) + { + return SM90_TMA_REDUCE_ADD_5D::copy(desc_ptr, smem_ptr, crd0, crd1, crd2, crd3, crd4); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// BULK_COPY : Copy a bulk of memory between shared memory and global memory +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM90_BULK_COPY_G2S +{ + CUTE_HOST_DEVICE static void + copy(void const* gmem_ptr, uint64_t* mbar_ptr, + void * smem_ptr, int32_t load_bytes) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile("cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];\n" + : + : "r"(smem_int_ptr), "l"(gmem_ptr), "r"(load_bytes), "r"(smem_int_mbar) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use BULK_COPY without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } + + struct PREFETCH + { + CUTE_HOST_DEVICE static void + copy(void const* gmem_ptr, int32_t load_bytes) + { + #if defined(CUTE_ARCH_TMA_SM90_ENABLED) + asm volatile("cp.async.bulk.prefetch.L2.global [%0], %1;\n" + : + : "l"(gmem_ptr), "r"(load_bytes) + : "memory"); + #else + CUTE_INVALID_CONTROL_PATH("Trying to use BULK_COPY without CUTE_ARCH_TMA_SM90_ENABLED."); + #endif + } + }; +}; + +struct SM90_BULK_COPY_S2G +{ + CUTE_HOST_DEVICE static void + copy(void const* smem_ptr, + void * gmem_ptr, int32_t store_bytes) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile("cp.async.bulk.global.shared::cta.bulk_group [%0], [%1], %2;\n" + : + : "l"(gmem_ptr), "r"(smem_int_ptr), "r"(store_bytes) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use BULK_COPY without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_BULK_COPY_AUTO {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // end 
namespace cute diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/arch/mma.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/arch/mma.hpp new file mode 100644 index 0000000000000000000000000000000000000000..8b97f50ab4a9a9faaa6e1065c65938ccfc8cfa59 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/arch/mma.hpp @@ -0,0 +1,64 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include // CUTE_HOST_DEVICE +#include // cute::fma +#include // cute::fma + +namespace cute +{ + +// +// Direct FMA for any type +// + +template +struct UniversalFMA +{ + using DRegisters = D[1]; + using ARegisters = A[1]; + using BRegisters = B[1]; + using CRegisters = C[1]; + + CUTE_HOST_DEVICE static constexpr void + fma(D & d, + A const& a, + B const& b, + C const& c) + { + // Forward to an ADL/cute free function for these types + using cute::fma; + fma(d, a, b, c); + } +}; + +} // end namespace cute diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/arch/mma_sm100.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/arch/mma_sm100.hpp new file mode 100644 index 0000000000000000000000000000000000000000..749da8167e68d19a7c332d771da71d9172e49fd6 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/arch/mma_sm100.hpp @@ -0,0 +1,83 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +// +// + +#pragma once + +#include +#include + +#include + +namespace cute { + +struct SM100_2x1x1_F32F32F32F32 { + using DRegisters = float2[1]; + using ARegisters = float2[1]; + using BRegisters = float[1]; + using CRegisters = float2[1]; + + CUTE_HOST_DEVICE static void + fma(float2 & d01, + float2 const& a01, + float const& b0, + float2 const& c01) + { +#if defined(CUTE_ARCH_FFMA2_SM100_ENABLED) + cute::fma(d01, a01, make_float2(b0, b0), c01); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_2x1x1_F32F32F32F32 without CUTE_ARCH_FLOAT2_MATH_ENABLED"); +#endif + } +}; + +struct SM100_1x2x1_F32F32F32F32 { + using DRegisters = float2[1]; + using ARegisters = float[1]; + using BRegisters = float2[1]; + using CRegisters = float2[1]; + + CUTE_HOST_DEVICE static void + fma(float2 & d01, + float const& a0, + float2 const& b01, + float2 const& c01) + { +#if defined(CUTE_ARCH_FFMA2_SM100_ENABLED) + cute::fma(d01, make_float2(a0, a0), b01, c01); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_1x2x1_F32F32F32F32 without CUTE_ARCH_FFMA2_SM100_ENABLED"); +#endif + } +}; + +} // namespace cute diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/arch/mma_sm100_desc.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/arch/mma_sm100_desc.hpp new file mode 100644 index 0000000000000000000000000000000000000000..d255c41f179b9f76da96f093e2feef131c17cd87 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/arch/mma_sm100_desc.hpp @@ -0,0 +1,651 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +// + +// + +#pragma once + +#if !defined(__CUDACC_RTC__) +#include +#endif + +#include + +#include + +#include +#include // cute::array + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cute { + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// UMMA Descriptor and utilities + +// UMMA enums and utilities +namespace UMMA +{ + +enum class Major : uint8_t { + K = 0, + MN = 1 +}; + +enum class ScaleIn : uint8_t { + One = 0, + Neg = 1 +}; + +enum class ScaleOut : uint8_t { + Zero = 0, + One = 1 +}; + +enum class Saturate : uint8_t { + False = 0, + True = 1 +}; + +enum class LayoutType : uint8_t { + SWIZZLE_NONE = 0, + SWIZZLE_128B_BASE32B = 1, + SWIZZLE_128B = 2, + SWIZZLE_64B = 4, + SWIZZLE_32B = 6 +}; + +CUTE_HOST_DEVICE char const* to_string(LayoutType const& t) { + switch (t) { + case LayoutType::SWIZZLE_NONE: return "SWIZZLE_NONE"; + case LayoutType::SWIZZLE_128B_BASE32B: return "SWIZZLE_128B_BASE32B"; + case LayoutType::SWIZZLE_128B: return "SWIZZLE_128B"; + case LayoutType::SWIZZLE_64B: return "SWIZZLE_64B"; + case LayoutType::SWIZZLE_32B: return "SWIZZLE_32B"; + } + return nullptr; +} + +union SmemDescriptor +{ + uint64_t desc_ = 0; + // Bitfield implementation avoids the need for shifts in assignment + struct { + // start_address, bit [0,14), 4LSB not included + uint16_t start_address_ : 14, : 2; // 14 bits [0,14), 2 bits unused + // leading dimension byte offset, bit [16,30), 4LSB not included + uint16_t leading_byte_offset_ : 14, : 2; // 14 bits [0,14), 2 bits unused + // stride dimension byte offset, bit [32,46), 4LSB not included + uint16_t stride_byte_offset_ : 14, version_ : 2; // 14 bits [0,14), 2 bits [14,16) + // base_offset, bit [49,52). leading_byte_offset_mode, bit [52,53). 
+ uint8_t : 1, base_offset_ : 3, lbo_mode_ : 1, : 3; // 1 bit unused, 3 bits [1,4), 1 bit [4,5), 3 bits unused + // layout type, bit [61,64), SWIZZLE_NONE matrix descriptor = 0, SWIZZLE_128B matrix descriptor = 2, SWIZZLE_64B descriptor = 4, SWIZZLE_32B descriptor = 6, SWIZZLE_128B_BASE32B = 1, N/A = 3, N/A = 5, N/A = 7 + uint8_t : 5, layout_type_ : 3; // 6 bits unused, 3 bits [5,8) + }; + // Seperate the field, as we may only update one part of desc + struct { + uint32_t lo; + uint32_t hi; + }; + + // Decay to a uint64_t + CUTE_HOST_DEVICE constexpr + operator uint64_t() const noexcept { return desc_; } +}; + +enum class F16F32Format : uint8_t { + F16 = 0, + BF16 = 1, + TF32 = 2, +}; + +CUTE_HOST_DEVICE char const* to_string(F16F32Format const& t) { + switch (t) { + case F16F32Format::F16: return "F16"; + case F16F32Format::BF16: return "BF16"; + case F16F32Format::TF32: return "TF32"; + } + return nullptr; +} + +template +CUTE_HOST_DEVICE constexpr F16F32Format to_F16F32Format() { + if constexpr (is_same_v) { return F16F32Format::F16; } else + if constexpr (is_same_v) { return F16F32Format::BF16; } else + if constexpr (is_same_v) { return F16F32Format::TF32; } else + { static_assert(sizeof(T) == 0, "Unknown type for F16F32Format"); } +} + +enum class S8Format : uint8_t { + UINT8 = 0, + INT8 = 1, +}; + +CUTE_HOST_DEVICE char const* to_string(S8Format const& t) { + switch (t) { + case S8Format::UINT8: return "UINT8"; + case S8Format::INT8: return "INT8"; + } + return nullptr; +} + +template +CUTE_HOST_DEVICE constexpr S8Format to_S8Format() { + if constexpr (is_same_v) { return S8Format::UINT8; } else + if constexpr (is_same_v) { return S8Format::INT8; } else + { static_assert(sizeof(T) == 0, "Unknown type for S8Format"); } +} + +enum class MXF8F6F4Format : uint8_t { + E4M3 = 0, + E5M2 = 1, + E2M3 = 3, + E3M2 = 4, + E2M1 = 5, + INVALID = 7 // an invalid datatype for runtime proxy type +}; + +CUTE_HOST_DEVICE char const* to_string(MXF8F6F4Format const& t) { + switch (t) { + case MXF8F6F4Format::E4M3: return "E4M3"; + case MXF8F6F4Format::E5M2: return "E5M2"; + case MXF8F6F4Format::E2M3: return "E2M3"; + case MXF8F6F4Format::E3M2: return "E3M2"; + case MXF8F6F4Format::E2M1: return "E2M1"; + case MXF8F6F4Format::INVALID: return "INVALID"; + } + return nullptr; +} + +template +CUTE_HOST_DEVICE constexpr MXF8F6F4Format to_MXF8F6F4Format() { + if constexpr (is_same_v) { return MXF8F6F4Format::E4M3; } else + if constexpr (is_same_v) { return MXF8F6F4Format::E5M2; } else + if constexpr (is_same_v) { return MXF8F6F4Format::E2M3; } else + if constexpr (is_same_v) { return MXF8F6F4Format::E3M2; } else + if constexpr (is_same_v) { return MXF8F6F4Format::E2M1; } else + { static_assert(sizeof(T) == 0, "Unknown type for MXF8F6F4Format"); } +} + +enum class MXF4Format : uint8_t { + E2M1 = 1, +}; + +CUTE_HOST_DEVICE char const* to_string(MXF4Format const& t) { + switch (t) { + case MXF4Format::E2M1: return "E2M1"; + } + return nullptr; +} + +template +CUTE_HOST_DEVICE constexpr MXF4Format to_MXF4Format() { + if constexpr (is_same_v) { return MXF4Format::E2M1; } else + { static_assert(sizeof(T) == 0, "Unknown type for MXF4Format"); } +} + +enum class ScaleFormat : uint8_t { + UE4M3 = 0, + UE8M0 = 1, +}; + +CUTE_HOST_DEVICE char const* to_string(ScaleFormat const& t) { + switch (t) { + case ScaleFormat::UE4M3: return "UE4M3"; + case ScaleFormat::UE8M0: return "UE8M0"; + } + return nullptr; +} + +template +CUTE_HOST_DEVICE constexpr ScaleFormat to_ScaleFormat() { + if constexpr (is_same_v) { return 
ScaleFormat::UE4M3; } else + if constexpr (is_same_v) { return ScaleFormat::UE8M0; } else + { static_assert(sizeof(T) == 0, "Unknown type for ScaleFormat"); } +} + +enum class CFormat : uint8_t { + F16 = 0, + F32 = 1, + S32 = 2, +}; + +CUTE_HOST_DEVICE char const* to_string(CFormat const& t) { + switch (t) { + case CFormat::F16: return "F16"; + case CFormat::F32: return "F32"; + case CFormat::S32: return "S32"; + } + return nullptr; +} + +enum class MaxShift : uint8_t { + NoShift = 0, + MaxShift8 = 1, + MaxShift16 = 2, + MaxShift32 = 3 +}; + +enum class BMatrixBufferId : uint8_t { + Zero = 0u, + One = 1u, + Two = 2u, + Three = 3u +}; + +enum class BMatrixBufferReuse : uint8_t { + Keep = 1u, + Reuse = 2u, + ReuseAndKeep = 3u +}; + +// using MaskAndShiftB = uint32_t[2]; +union MaskAndShiftB +{ + uint32_t uri[2]; + + struct { + // Bitfield implementation avoids the need for shifts in assignment + uint8_t start_count_ [4]; // bit [ 0:32) : 8 bits each. Specifies the start count for mask generation. + uint32_t first_span_ : 4, // bit [32:36) : 1 bit each. 0 = start where B is used. 1 = start with where B is skipped(0 value is used). + : 3, // + nzm_ : 1, // bit [39:40) : 0 = Enable the mask. 1 = Disable the mask. + skip_span_ : 8, // bit [40:48) : Count-1 (zero encoded in this field specifies use span of 1) of consecutive columns where 0 value is used. + use_span_ : 8, // bit [48:55) : Count-1 (zero encoded in this field specifies use span of 1) of consecutive columns where B matrix data is used. + shift_ : 6, // bit [56:62) : Shift value for B matrix data. + : 2; + }; +}; + +template +CUTE_HOST_DEVICE constexpr auto +make_column_zero_mask(ShapeType conv_q, int32_t cta_coord_q, int32_t num_pixels_skip_left) { + + static_assert(cute::is_same_v || cute::is_integral::value); + + cute::array column_zero_masks{}; + + static_assert(FLT_S == 3, "Filter size not supported."); + constexpr int MAX_USE_SPAN_COUNT = 256; + constexpr int MAX_SKIP_SPAN_COUNT = 256; + + // conv_q_int used for non-divmod case (add/minus/..) + // conv_q used for divmod case (div/mod/...) + int32_t conv_q_int = int(conv_q); + auto [_, cta_q] = divmod(cta_coord_q * CTA_N, conv_q); + + int step_q = CTA_M == 128 ? CTA_N / 1 + : CTA_M == 64 ? CTA_N / 2 + : CTA_M == 32 ? CTA_N / 4 + : 0; + + for (int mask_iter = 0; mask_iter < int(CTA_N / step_q); ++mask_iter) { + + for (int s_iter = 0; s_iter < FLT_S; s_iter += 1) { + + int32_t skip_span{0}, use_span{0}, nzm{1}, first_span{0}, start_count{0}, shift{0}; + + shift = s_iter; + + // Examples for CZM setting + // CASE0: (skip_span_ < 0) + // | padding |<- conv_q ->| + // |skip_span_|<- use_span ->|skip_span_| + // -skip_span 0 ^cta_q conv_q-1 + // 0 ^index + // + // CASE1: (skip_span_ > 0) + // |<- conv_q ->| + // |skip_span_|<- use_span ->|skip_span_| + // 0 ^cta_q conv_q-1 + // 0 ^index + // + // line 0 an input vector from 0 to conv_q with the padding + // line 1 shows the different spans we need to skip or load + // lines 2-3 show the different coordinates of different boundaries. + // CTQ_q is the coordinate of the present cta. 
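+      // In short, the code below fills one MaskAndShiftB entry per filter tap s_iter:
+      // skip_span / use_span encode (stored as count-1 in the descriptor fields) how many
+      // consecutive columns use the zero value vs. real B data, first_span selects which
+      // region the pattern starts in, start_count is the per-subtile start offset for mask
+      // generation, and shift_ carries the filter-tap shift. Spans are clamped to the 8-bit
+      // descriptor fields via MAX_USE_SPAN_COUNT / MAX_SKIP_SPAN_COUNT above.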
+ + int32_t skip_span_ = num_pixels_skip_left - shift; + int32_t index{0}; + if (skip_span_ > 0) { + auto [_, index_mod] = divmod(cta_q, conv_q); + index = index_mod; + } else if (skip_span_ < 0) { + auto [_, index_mod] = divmod((cta_q - skip_span_), conv_q); + index = index_mod; + } else { + nzm = 0; + } + skip_span = cute::max(cute::abs(skip_span_), 1); + use_span = cute::min(conv_q_int - static_cast(skip_span), MAX_USE_SPAN_COUNT); + if (use_span > 0) { + first_span = index >= skip_span ? 0 : 1; + if ((first_span == 0) && (index + CTA_N < conv_q_int + skip_span)) { + nzm = 0; + } else { + start_count = first_span == 0 ? (use_span - (conv_q_int - index)) : index; + } + } else { + skip_span = MAX_SKIP_SPAN_COUNT; + use_span = 1; + first_span = 1; + start_count = 0; + } + + column_zero_masks[s_iter].start_count_[mask_iter] = start_count; + column_zero_masks[s_iter].first_span_ |= first_span << mask_iter; + column_zero_masks[s_iter].nzm_ |= nzm; + column_zero_masks[s_iter].skip_span_ = skip_span - 1; + column_zero_masks[s_iter].use_span_ = use_span - 1; + column_zero_masks[s_iter].shift_ = shift; + + } + + cta_q += step_q; + } + + return column_zero_masks; +} + +template +CUTE_HOST_DEVICE constexpr auto to_UMMAFormat() { + if constexpr (is_same_v) { return F16F32Format::F16; } else + if constexpr (is_same_v) { return F16F32Format::BF16; } else + if constexpr (is_same_v) { return F16F32Format::TF32; } else + if constexpr (is_same_v) { return S8Format::UINT8; } else + if constexpr (is_same_v) { return S8Format::INT8; } else + if constexpr (is_same_v) {return MXF8F6F4Format::INVALID; } else + + if constexpr (is_same_v) {return MXF8F6F4Format::INVALID; } else + if constexpr (is_same_v) {return MXF8F6F4Format::INVALID; } else + if constexpr (is_same_v) {return MXF8F6F4Format::INVALID; } else + + if constexpr (is_same_v) { return MXF8F6F4Format::E4M3; } else + if constexpr (is_same_v) { return MXF8F6F4Format::E5M2; } else + if constexpr (is_same_v) {return MXF8F6F4Format::INVALID; } else + if constexpr (is_same_v) { return MXF8F6F4Format::E2M3; } else + if constexpr (is_same_v) { return MXF8F6F4Format::E3M2; } else + if constexpr (is_same_v) { return MXF8F6F4Format::E2M3; } else + if constexpr (is_same_v) { return MXF8F6F4Format::E3M2; } else + if constexpr (is_same_v) { return MXF8F6F4Format::E2M1; } else + if constexpr (is_same_v) { return MXF4Format::E2M1; } else + { static_assert(sizeof(T) == 0, "Unknown type for UMMAFormat"); } +} + +template +CUTE_HOST_DEVICE constexpr CFormat to_CFormat() { + if constexpr (is_same_v) { return CFormat::F16; } else + if constexpr (is_same_v) { return CFormat::F32; } else + if constexpr (is_same_v) { return CFormat::S32; } else + { static_assert(sizeof(T) == 0, "Unknown type for CFormat"); } +} + +union InstrDescriptor +{ + uint32_t desc_; + + struct { + // Bitfield implementation avoids the need for shifts in assignment + uint16_t sparse_id2_ : 2, // bit [ 0, 2) : Sparse meta data id2 + sparse_flag_ : 1, // bit [ 2, 3) : 0 = dense. 1 = sparse. 1 value valid only for F32F16/S8/MXF8F6F4 + saturate_ : 1, // bit [ 3, 4) : 0 = no saturate. 1 = saturate. 1 value valid only for S8 + c_format_ : 2, // bit [ 4, 6) : 0 = F16. 1 = F32, 2 = S32 + : 1, // + a_format_ : 3, // bit [ 7,10) : MXF8F6F4Format:0 = E4M3, 1 = E5M2, 3 = E2M3, 4 = E3M2, 5 = E2M1. F32F16Format: 0 = F16, 1 = BF16, 2 = TF32. S8: 0 unsigned 8 bit, 1 signed 8 bit. Boolean MMA: 0 Boolean + b_format_ : 3, // bit [10,13) : MXF8F6F4Format:0 = E4M3, 1 = E5M2, 3 = E2M3, 4 = E3M2, 5 = E2M1. 
F32F16Format: 0 = F16, 1 = BF16, 2 = TF32. S8: 0 unsigned 8 bit, 1 signed 8 bit. Boolean MMA: 0 Boolean + a_negate_ : 1, // bit [13,14) : 0 = no negate. 1 = negate. 1 value valid only for F32F16Format and MXF8F6F4Format + b_negate_ : 1, // bit [14,15) : 0 = no negate. 1 = negate. 1 value valid only for F32F16Format and MXF8F6F4Format + a_major_ : 1; // bit [15,16) : 0 = K-major. 1 = MN-major. Major value of 1 is only valid for E4M3, E5M2, INT8 (signed and unsigned), F16, BF16 and TF32 source formats + uint16_t b_major_ : 1, // bit [16,17) : 0 = K-major. 1 = MN-major. Major value of 1 is only valid for E4M3, E5M2, INT8 (signed and unsigned), F16, BF16 and TF32 source formats + n_dim_ : 6, // bit [17,23) : 3 LSBs not included. Valid values range from 1 (N=8) to 32 (N=256). All values are not valid for all instruction formats + : 1, // + m_dim_ : 5, // bit [24,29) : 4 LSBs not included. Valid values are: 4 (M=64), 8 (M=128), 16 (M=256) + : 1, // + max_shift_ : 2; // bit [30,32) : Maximum shift for WS instruction. Encoded as follows: 0 = no shift, 1 = maximum shift of 8, 2 = maximum shift of 16, 3 = maximum shift of 32. + }; + + // Decay to a uint32_t + CUTE_HOST_DEVICE constexpr explicit + operator uint32_t() const noexcept { return desc_; } +}; + +union InstrDescriptorBlockScaled +{ + uint32_t desc_; + + struct { + // Bitfield implementation avoids the need for shifts in assignment + uint16_t sparse_id2_ : 2, // bit [ 0, 2) : Sparse meta data id2 + sparse_flag_ : 1, // bit [ 2, 3) : 0 = dense. 1 = sparse. 1 value valid only for F32F16/S8/MXF8F6F4 + : 1, // + b_sf_id_ : 2, // bit [ 4, 6) : Matrix B Scale Factor ID + : 1, // + a_format_ : 3, // bit [ 7, 9) : MXF8F6F4Format:0 = E4M3, 1 = E5M2, 3 = E2M3, 4 = E3M2, 5 = E2M1. F32F16Format: 0 = F16, 1 = BF16, 2 = TF32. S8: 0 unsigned 8 bit, 1 signed 8 bit. BMMA: 0 Boolean + b_format_ : 3, // bit [10,12) : MXF8F6F4Format:0 = E4M3, 1 = E5M2, 3 = E2M3, 4 = E3M2, 5 = E2M1. F32F16Format: 0 = F16, 1 = BF16, 2 = TF32. S8: 0 unsigned 8 bit, 1 signed 8 bit. BMMA: 0 Boolean + a_negate_ : 1, // bit [13,14) : 0 = no negate. 1 = negate. 1 value valid only for F32F16Format and MXF8F6F4Format + b_negate_ : 1, // bit [14,15) : 0 = no negate. 1 = negate. 1 value valid only for F32F16Format and MXF8F6F4Format + a_major_ : 1; // bit [15,16) : 0 = K-major. 1 = MN-major. Major value of 1 is only valid for E4M3, E5M2, INT8 (signed and unsigned), F16, BF16 and TF32 source formats + uint16_t b_major_ : 1, // bit [16,17) : 0 = K-major. 1 = MN-major. Major value of 1 is only valid for E4M3, E5M2, INT8 (signed and unsigned), F16, BF16 and TF32 source formats + n_dim_ : 6, // bit [17,23) : 3 LSBs not included. Valid values range from 1 (N=8) to 32 (N=256). All values are not valid for all instruction formats + scale_format_ : 1, // bit [23,24) : 0=E4M3, 1=E8M0 + m_dim_ : 5, // bit [24,29) : 4 LSBs not included. Valid values are: 4 (M=64), 8 (M=128), 16 (M=256) + a_sf_id_ : 2, // bit [29,31) : Matrix A Scale Factor ID + k_size_ : 1; // bit [31,32) : MMA-K Dim. MXF8F6F4Format: 0=[dense: K32, sparse: K64]. S8Format: 0=[dense: K32, sparse: invalid]. MXF4Format: 0=[dense: K64, sparse: K128], 1=[dense: K96, sparse: invalid]. 
+ }; + + // Decay to a uint32_t + CUTE_HOST_DEVICE constexpr + operator uint32_t() const noexcept { return desc_; } +}; + +template +CUTE_HOST_DEVICE constexpr +UMMA::InstrDescriptor +make_instr_desc() +{ + UMMA::InstrDescriptor desc_i = {}; + + desc_i.a_format_ = uint8_t(UMMA::to_UMMAFormat()); + desc_i.b_format_ = uint8_t(UMMA::to_UMMAFormat()); + desc_i.c_format_ = uint8_t(UMMA::to_CFormat()); + + desc_i.m_dim_ = (M >> 4); + desc_i.n_dim_ = (N >> 3); + + desc_i.a_major_ = uint8_t(a_major); + desc_i.b_major_ = uint8_t(b_major); + + desc_i.a_negate_ = uint8_t(a_neg); + desc_i.b_negate_ = uint8_t(b_neg); + desc_i.saturate_ = uint8_t(c_sat); + + desc_i.sparse_flag_ = is_sparse; // 1 = Sparse + desc_i.sparse_id2_ = 0; + + desc_i.max_shift_ = uint8_t(max_shift); + + return desc_i; +} + +template +CUTE_HOST_DEVICE +constexpr uint64_t +make_runtime_instr_desc(uint16_t sparse_id2 = 0u, uint32_t tmem_e = 0u) { + UMMA::InstrDescriptor desc_i = UMMA::make_instr_desc< + a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg, c_sat, is_sparse, + max_shift>(); + + if constexpr (is_sparse) { + desc_i.sparse_id2_ = sparse_id2; + } + else { + assert(sparse_id2 == 0u); + } + // In current compiler exposure, idescE is a uint64_t. It should contain: + // - Lower 32b URe: Specifies the tmem address that stores the sparse metadata. + // Only needed for Sparse MMA instructions. Otherwise, ignored. + // - Upper 32b URh: Specifies the instruction descriptor. + uint64_t idescE = (static_cast(static_cast(desc_i)) << 32); + + return idescE; +} + +template +CUTE_HOST_DEVICE +constexpr uint64_t +make_runtime_instr_desc(UMMA::InstrDescriptor desc_i, uint16_t sparse_id2 = 0u, uint32_t tmem_e = 0u) +{ + if constexpr (is_sparse) { + desc_i.sparse_id2_ = sparse_id2; + } + else { + assert(sparse_id2 == 0u); + } + // In current compiler exposure, idescE is a uint64_t. It should contain: + // - Lower 32b URe: Specifies the tmem address that stores the sparse metadata. + // Only needed for Sparse MMA instructions. Otherwise, ignored. + // - Upper 32b URh: Specifies the instruction descriptor. + uint64_t idescE = (static_cast(static_cast(desc_i)) << 32); + + return idescE; +} + +template +CUTE_HOST_DEVICE constexpr +UMMA::InstrDescriptorBlockScaled +make_instr_desc_block_scaled() +{ + UMMA::InstrDescriptorBlockScaled desc_i = {}; + + desc_i.a_format_ = uint8_t(UMMA::to_UMMAFormat()); + desc_i.b_format_ = uint8_t(UMMA::to_UMMAFormat()); + + desc_i.scale_format_ = uint8_t(UMMA::to_ScaleFormat()); + desc_i.a_sf_id_ = 0; + desc_i.b_sf_id_ = 0; + + desc_i.m_dim_ = (M >> 4); + desc_i.n_dim_ = (N >> 3); + + desc_i.a_major_ = uint8_t(a_major); + desc_i.b_major_ = uint8_t(b_major); + + desc_i.a_negate_ = uint8_t(a_neg); + desc_i.b_negate_ = uint8_t(b_neg); + desc_i.sparse_flag_ = is_sparse; // 1 = Sparse + desc_i.sparse_id2_ = 0; + + // Below would bring some warnings. +#if defined(__GNUC__) +# pragma GCC diagnostic ignored "-Wconversion" +#endif + return desc_i; +} + +template +CUTE_HOST_DEVICE +constexpr uint64_t +make_runtime_instr_desc_block_scaled(uint32_t const tmem_sfa_addr, uint32_t const tmem_sfb_addr, + uint16_t const sparse_id2 = 0u, uint32_t const tmem_e = 0u) +{ + UMMA::InstrDescriptorBlockScaled desc_i = UMMA::make_instr_desc_block_scaled< + a_type, b_type, c_type, sf_type, M, N, + a_major, b_major, + a_neg, b_neg, + is_sparse>(); + + // The first 2-bits of TMEM address includes byte address. 
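+  // The two most-significant bits of each scale-factor TMEM address are extracted with
+  // (addr & 0xC0000000) >> 30 and stored directly into the 2-bit a_sf_id_ / b_sf_id_
+  // descriptor fields below.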
+ desc_i.a_sf_id_ = (tmem_sfa_addr & 0xC0000000) >> 30; + desc_i.b_sf_id_ = (tmem_sfb_addr & 0xC0000000) >> 30; + + if constexpr (is_sparse) { + desc_i.sparse_id2_ = sparse_id2; + } + else { + assert(sparse_id2 == 0u); + } + + // In current compiler exposure, idescE is a uint64_t. It should contain: + // - Lower 32b URe: Specifies the tmem address that stores the sparse metadata. + // Only needed for Sparse MMA instructions. Otherwise, ignored. + // - Upper 32b URh: Specifies the instruction descriptor. + uint64_t idescE = (static_cast(static_cast(desc_i)) << 32); + + return idescE; +} + +template +CUTE_HOST_DEVICE +constexpr uint64_t +make_runtime_instr_desc_block_scaled(UMMA::InstrDescriptorBlockScaled desc_i, + uint32_t const tmem_sfa_addr, uint32_t const tmem_sfb_addr, + uint16_t const sparse_id2 = 0u, uint32_t const tmem_e = 0u) +{ + // The first 2-bits of TMEM address includes byte address. + desc_i.a_sf_id_ = (tmem_sfa_addr & 0xC0000000) >> 30; + desc_i.b_sf_id_ = (tmem_sfb_addr & 0xC0000000) >> 30; + + if constexpr (is_sparse) { + desc_i.sparse_id2_ = sparse_id2; + } + else { + assert(sparse_id2 == 0u); + } + + // In current compiler exposure, idescE is a uint64_t. It should contain: + // - Lower 32b URe: Specifies the tmem address that stores the sparse metadata. + // Only needed for Sparse MMA instructions. Otherwise, ignored. + // - Upper 32b URh: Specifies the instruction descriptor. + uint64_t idescE = (static_cast(static_cast(desc_i)) << 32); + + return idescE; +} + +} // end namespace UMMA +} // namespace cute diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/arch/mma_sm100_umma.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/arch/mma_sm100_umma.hpp new file mode 100644 index 0000000000000000000000000000000000000000..25b504dc0b2bfc36f494e70b78801a5b10c3c553 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/arch/mma_sm100_umma.hpp @@ -0,0 +1,1777 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include +#include +#include +#include +#include + +namespace cute +{ + +template +struct SM100_MMA_TF32_SS +{ + static_assert(M == 64 || M == 128, "SM100_MMA_TF32 M-mode size should be 64 or 128 for 1 CTA cluster MMA."); + static_assert((N % 8 == 0) && (8 <= N) && (N <= 256), + "SM100_MMA_TF32 N-mode size should be a multiple of 8 between 8 and 256."); + + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE) + { +#if defined(CUTE_ARCH_TCGEN05_TF32_MMA_ENABLED) + if (cute::elect_one_sync()) { + uint32_t mask[4] = {0, 0, 0, 0}; + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::1.kind::tf32 [%0], %1, %2, %3, {%5, %6, %7, %8}, p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3])); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_TF32_SS without CUTE_ARCH_TCGEN05_TF32_MMA_ENABLED"); +#endif + } +}; + +template +struct SM100_MMA_F16BF16_SS +{ + static_assert(M == 64 || M == 128, "SM100_MMA_F16BF16 M-mode size should be 64 or 128 for 1 CTA cluster MMA."); + static_assert((N % 8 == 0) && (8 <= N) && (N <= 256), + "SM100_MMA_F16BF16 N-mode size should be a multiple of 8 between 8 and 256."); + + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE) + { +#if defined(CUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED) + if (cute::elect_one_sync()) { + uint32_t mask[4] = {0, 0, 0, 0}; + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::1.kind::f16 [%0], %1, %2, %3, {%5, %6, %7, %8}, p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3])); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_F16BF16_SS without CUTE_ARCH_MMA_SM100A_ENABLED"); +#endif + } +}; + +template +struct SM100_MMA_TF32_TS +{ + static_assert(M == 64 || M == 128, "SM100_MMA_TF32 M-mode size should be 64 or 128 for 1 CTA cluster MMA."); + static_assert((M == 64 && (N % 8 == 0) && (8 <= N) && (N <= 256)) || + (M == 128 && (N % 16 == 0) && (16 <= N) && (N <= 256)), + "SM100_MMA_TF32 N-mode size should be a multiple of 8 between 8 and 256 for M=64,\ + or a multiple of 16 between 16 and 256 for M=128."); + static_assert(a_major == UMMA::Major::K, "SM100_MMA_TF32 A from TMEM can't 
be transposed"); + + using DRegisters = void; + using ARegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& tmem_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE) + { +#if defined(CUTE_ARCH_TCGEN05_TF32_MMA_ENABLED) + uint32_t mask[4] = {0, 0, 0, 0}; + if (cute::elect_one_sync()) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::1.kind::tf32 [%0], [%1], %2, %3, {%5, %6, %7, %8}, p; \n\t" + "}\n" + : + : "r"(tmem_c), "r"(tmem_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3])); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_TF32_TS without CUTE_ARCH_TCGEN05_TF32_MMA_ENABLED"); +#endif + } +}; + +template +struct SM100_MMA_F16BF16_TS +{ + static_assert(M == 64 || M == 128, "SM100_MMA_F16BF16 M-mode size should be 64 or 128 for 1 CTA cluster MMA."); + static_assert((M == 64 && (N % 8 == 0) && (8 <= N) && (N <= 256)) || + (M == 128 && (N % 16 == 0) && (16 <= N) && (N <= 256)), + "SM100_MMA_F16BF16 N-mode size should be a multiple of 8 between 8 and 256 for M=64,\ + or a multiple of 16 between 16 and 256 for M=128."); + static_assert(a_major == UMMA::Major::K, "SM100_MMA_F16BF16 A from TMEM can't be transposed"); + + using DRegisters = void; + using ARegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& tmem_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE) + { +#if defined(CUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED) + uint32_t mask[4] = {0, 0, 0, 0}; + if (cute::elect_one_sync()) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::1.kind::f16 [%0], [%1], %2, %3, {%5, %6, %7, %8}, p; \n\t" + "}\n" + : + : "r"(tmem_c), "r"(tmem_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3])); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_F16BF16_TS without CUTE_ARCH_MMA_SM100A_ENABLED"); +#endif + } +}; + +template +struct SM100_MMA_F16BF16_SS_SCALED +{ + static_assert(M == 64 || M == 128, "SM100_MMA_F16BF16_SS_SCALED M-mode size should be 64 or 128 for 1 CTA cluster MMA."); + static_assert((N % 8 == 0) && (8 <= N) && (N <= 256), + "SM100_MMA_F16BF16_SS_SCALED N-mode size should be a multiple of 8 between 8 and 256."); + + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& accumulate, + uint64_t const& idescE) + { +#if defined(CUTE_ARCH_TCGEN05_F16BF16_MMA_SCALED_ENABLED) + if (cute::elect_one_sync()) { + // ScaleC input should be a literal or compile time constant + uint32_t mask[4] = {0, 0, 0, 0}; + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::1.kind::f16 [%0], %1, %2, %3, {%5, %6, %7, %8}, p, %9; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(accumulate), + "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3]), "n"(ScaleC)); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_F16BF16_SS without 
CUTE_ARCH_TCGEN05_F16BF16_MMA_SCALED_ENABLED"); +#endif + } +}; + +template +struct SM100_MMA_F16BF16_TS_SCALED +{ + static_assert(M == 64 || M == 128, "SM100_MMA_F16BF16_TS_SCALED M-mode size should be 64 or 128 for 1 CTA cluster MMA."); + static_assert((M == 64 && (N % 8 == 0) && (8 <= N) && (N <= 256)) || + (M == 128 && (N % 16 == 0) && (16 <= N) && (N <= 256)), + "SM100_MMA_F16BF16_TS_SCALED N-mode size should be a multiple of 8 between 8 and 256 for M=64,\ + or a multiple of 16 between 16 and 256 for M=128."); + static_assert(a_major == UMMA::Major::K, "SM100_MMA_F16BF16_TS_SCALED A from TMEM can't be transposed"); + + using DRegisters = void; + using ARegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& tmem_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& accumulate, + uint64_t const& idescE) + { +#if defined(CUTE_ARCH_TCGEN05_F16BF16_MMA_SCALED_ENABLED) + if (cute::elect_one_sync()) { + // ScaleC input should be a literal or compile time constant + uint32_t mask[4] = {0, 0, 0, 0}; + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::1.kind::f16 [%0], [%1], %2, %3, {%5, %6, %7, %8}, p, %9; \n\t" + "}\n" + : + : "r"(tmem_c), "r"(tmem_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(accumulate), + "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3]), "n"(ScaleC)); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_F16BF16_TS_SCALED without CUTE_ARCH_TCGEN05_F16BF16_MMA_SCALED_ENABLED"); +#endif + } +}; + +template +struct SM100_MMA_TF32_SS_SPARSE +{ + static_assert(M == 64 || M == 128, "SM100_MMA_TF32_SS_SPARSE M-mode size should be 64 or 128 for 1 CTA cluster MMA."); + static_assert((N % 8 == 0) && (8 <= N) && (N <= 256), + "SM100_MMA_TF32_SS_SPARSE N-mode size should be a multiple of 8 between 8 and 256."); + + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE, + uint32_t const& tmem_e) + { +#if defined(CUTE_ARCH_TCGEN05_TF32_MMA_ENABLED) + if (cute::elect_one_sync()) { + uint32_t mask[4] = {0, 0, 0, 0}; + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.sp.cta_group::1.kind::tf32 [%0], %1, %2, [%9], %3, {%5, %6, %7, %8}, p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3]), "r"(tmem_e)); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_TF32_SS_SPARSE without CUTE_ARCH_TCGEN05_TF32_MMA_ENABLED"); +#endif + } +}; + +template +struct SM100_MMA_F16BF16_SS_SPARSE +{ + static_assert(M == 64 || M == 128, "SM100_MMA_F16BF16_SS_SPARSE M-mode size should be 64 or 128 for 1 CTA cluster MMA."); + static_assert((N % 8 == 0) && (8 <= N) && (N <= 256), + "SM100_MMA_F16BF16_SS_SPARSE N-mode size should be a multiple of 8 between 8 and 256."); + + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE, + uint32_t const& tmem_e) + { +#if 
defined(CUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED) + if (cute::elect_one_sync()) { + uint32_t mask[4] = {0, 0, 0, 0}; + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.sp.cta_group::1.kind::f16 [%0], %1, %2, [%9], %3, {%5, %6, %7, %8}, p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3]), "r"(tmem_e)); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_F16BF16_SS_SPARSE without CUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED"); +#endif + } +}; + +template +struct SM100_MMA_TF32_2x1SM_SS +{ + static_assert(M == 128 || M == 256, "SM100_MMA_TF32_2x1SM_SS M-mode size should be 128 or 256 for 2 CTA cluster MMA."); + static_assert((N % 16 == 0) && (16 <= N) && (N <= 256), "SM100_MMA_TF32_2x1SM_SS N-mode size should be a multiple of 16 between 16 and 256."); + + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE) + { +#if defined(CUTE_ARCH_TCGEN05_TF32_MMA_ENABLED) + if (cute::elect_one_sync()) { + uint32_t mask[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::2.kind::tf32 [%0], %1, %2, %3, {%5, %6, %7, %8, %9, %10, %11, %12}, p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3]), + "r"(mask[4]), "r"(mask[5]), "r"(mask[6]), "r"(mask[7])); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_TF32_2x1SM_SS without CUTE_ARCH_TCGEN05_TF32_MMA_ENABLED"); +#endif + } +}; + +template +struct SM100_MMA_F16BF16_2x1SM_SS +{ + static_assert(M == 128 || M == 256, "SM100_MMA_F16BF16_2x1SM_SS M-mode size should be 128 or 256 for 2 CTA cluster MMA."); + static_assert((N % 16 == 0) && (16 <= N) && (N <= 256), "SM100_MMA_F16BF16_2x1SM_SS N-mode size should be a multiple of 16 between 16 and 256."); + + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE) + { +#if defined(CUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED) + if (cute::elect_one_sync()) { + uint32_t mask[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::2.kind::f16 [%0], %1, %2, %3, {%5, %6, %7, %8, %9, %10, %11, %12}, p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3]), + "r"(mask[4]), "r"(mask[5]), "r"(mask[6]), "r"(mask[7])); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_F16BF16_2x1SM_SS without CUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED"); +#endif + } +}; + +template +struct SM100_MMA_TF32_2x1SM_TS +{ + static_assert(M == 128 || M == 256, "SM100_MMA_TF32_2x1SM_TS M-mode size should be 128 or 256 for 2 CTA cluster MMA."); + static_assert((N % 32 == 0) && (32 <= N) && (N <= 256), "SM100_MMA_TF32_2x1SM_TS N-mode size should be a multiple of 32 between 32 and 256."); + static_assert(a_major == UMMA::Major::K, 
"SM100_MMA_TF32_2x1SM_TS A from TMEM can't be transposed"); + + using DRegisters = void; + using ARegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& tmem_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE) + { +#if defined(CUTE_ARCH_TCGEN05_TF32_MMA_ENABLED) + if (cute::elect_one_sync()) { + uint32_t mask[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::2.kind::tf32 [%0], [%1], %2, %3, {%5, %6, %7, %8, %9, %10, %11, %12}, p; \n\t" + "}\n" + : + : "r"(tmem_c), "r"(tmem_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3]), + "r"(mask[4]), "r"(mask[5]), "r"(mask[6]), "r"(mask[7])); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_TF32_2x1SM_TS without CUTE_ARCH_TCGEN05_TF32_MMA_ENABLED"); +#endif + } +}; + +template +struct SM100_MMA_F16BF16_2x1SM_TS +{ + static_assert(M == 128 || M == 256, "SM100_MMA_F16BF16_2x1SM_TS M-mode size should be 128 or 256 for 2 CTA cluster MMA."); + static_assert((N % 32 == 0) && (32 <= N) && (N <= 256), "SM100_MMA_F16BF16_2x1SM_TS N-mode size should be a multiple of 32 between 32 and 256."); + static_assert(a_major == UMMA::Major::K, "SM100_MMA_F16BF16_2x1SM_TS A from TMEM can't be transposed"); + + using DRegisters = void; + using ARegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& tmem_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE) + { +#if defined(CUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED) + if (cute::elect_one_sync()) { + uint32_t mask[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::2.kind::f16 [%0], [%1], %2, %3, {%5, %6, %7, %8, %9, %10, %11, %12}, p; \n\t" + "}\n" + : + : "r"(tmem_c), "r"(tmem_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3]), + "r"(mask[4]), "r"(mask[5]), "r"(mask[6]), "r"(mask[7])); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_F16BF16_2x1SM_TS without CUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED"); +#endif + } +}; + +template +struct SM100_MMA_F16BF16_2x1SM_SS_SCALED +{ + static_assert(M == 128 || M == 256, "SM100_MMA_F16BF16_2x1SM_SS_SCALED M-mode size should be 128 or 256 for 2 CTA cluster MMA."); + static_assert((N % 16 == 0) && (16 <= N) && (N <= 256), "SM100_MMA_F16BF16_2x1SM_SS_SCALED N-mode size should be a multiple of 16 between 16 and 256."); + + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& accumulate, + uint64_t const& idescE) + { +#if defined(CUTE_ARCH_TCGEN05_F16BF16_MMA_SCALED_ENABLED) + if (cute::elect_one_sync()) { + // ScaleC input should be a literal or compile time constant + uint32_t mask[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::2.kind::f16 [%0], %1, %2, %3, {%5, %6, %7, %8, %9, %10, %11, %12}, p, %13; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), 
"r"(accumulate), + "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3]), + "r"(mask[4]), "r"(mask[5]), "r"(mask[6]), "r"(mask[7]), "n"(ScaleC)); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_F16BF16_2x1SM_SS_SCALED without CUTE_ARCH_TCGEN05_F16BF16_MMA_SCALED_ENABLED"); +#endif + } +}; + +template +struct SM100_MMA_F16BF16_2x1SM_TS_SCALED +{ + static_assert(M == 128 || M == 256, "SM100_MMA_F16BF16_2x1SM_TS_SCALED M-mode size should be 128 or 256 for 2 CTA cluster MMA."); + static_assert((N % 32 == 0) && (32 <= N) && (N <= 256), "SM100_MMA_F16BF16_2x1SM_TS_SCALED N-mode size should be a multiple of 32 between 32 and 256."); + static_assert(a_major == UMMA::Major::K, "SM100_MMA_F16BF16_2x1SM_TS_SCALED A from TMEM can't be transposed"); + + using DRegisters = void; + using ARegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& tmem_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& accumulate, + uint64_t idescE) + { +#if defined(CUTE_ARCH_TCGEN05_F16BF16_MMA_SCALED_ENABLED) + if (cute::elect_one_sync()) { + // ScaleC input should be a literal or compile time constant + uint32_t mask[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::2.kind::f16 [%0], [%1], %2, %3, {%5, %6, %7, %8, %9, %10, %11, %12}, p, %13; \n\t" + "}\n" + : + : "r"(tmem_c), "r"(tmem_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(accumulate), + "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3]), + "r"(mask[4]), "r"(mask[5]), "r"(mask[6]), "r"(mask[7]), "n"(ScaleC)); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_F16BF16_2x1SM_TS_SCALED without CUTE_ARCH_TCGEN05_F16BF16_MMA_SCALED_ENABLED"); +#endif + } +}; + +template +struct SM100_MMA_TF32_2x1SM_SS_SPARSE +{ + static_assert(M == 128 || M == 256, "SM100_MMA_TF32_2x1SM_SS_SPARSE M-mode size should be 128 or 256 for 2 CTA cluster MMA."); + static_assert((N % 16 == 0) && (16 <= N) && (N <= 256), "SM100_MMA_TF32_2x1SM_SS_SPARSE N-mode size should be a multiple of 16 between 16 and 256."); + + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE, + uint32_t const& tmem_e) + { +#if defined(CUTE_ARCH_TCGEN05_TF32_MMA_ENABLED) + if (cute::elect_one_sync()) { + uint32_t mask[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.sp.cta_group::2.kind::tf32 [%0], %1, %2, [%13], %3, {%5, %6, %7, %8, %9, %10, %11, %12}, p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3]), + "r"(mask[4]), "r"(mask[5]), "r"(mask[6]), "r"(mask[7]), "r"(tmem_e)); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_TF32_2x1SM_SS_SPARSE without CUTE_ARCH_TCGEN05_TF32_MMA_ENABLED"); +#endif + } +}; + +template +struct SM100_MMA_F16BF16_2x1SM_SS_SPARSE +{ + static_assert(M == 128 || M == 256, "SM100_MMA_F16BF16_2x1SM_SS_SPARSE M-mode size should be 128 or 256 for 2 CTA cluster MMA."); + static_assert((N % 16 == 0) && (16 <= N) && (N <= 256), "SM100_MMA_F16BF16_2x1SM_SS_SPARSE N-mode size should be a multiple of 16 between 16 and 256."); + + 
using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE, + uint32_t const& tmem_e) + { +#if defined(CUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED) + if (cute::elect_one_sync()) { + uint32_t mask[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.sp.cta_group::2.kind::f16 [%0], %1, %2, [%13], %3, {%5, %6, %7, %8, %9, %10, %11, %12}, p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3]), + "r"(mask[4]), "r"(mask[5]), "r"(mask[6]), "r"(mask[7]), "r"(tmem_e)); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_F16BF16_2x1SM_SS_SPARSE without CUTE_ARCH_TCGEN05_F16F32_MMA_ENABLED"); +#endif + } +}; + +template +struct SM100_MMA_S8_SS +{ + static_assert(is_same_v, "SM100_MMA_S8_SS result type can only be int32_t."); + static_assert(M == 64 || M == 128, "SM100_MMA_S8_SS M-mode size should be 64 or 128 for 1 CTA cluster MMA."); + static_assert(N == 8 || ((N % 16 == 0) && (16 <= N) && (N <= 256)), "SM100_MMA_S8_SS N-mode size should be 8 or a multiple of 16 between 16 and 256."); + + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE) + { +#if defined(CUTE_ARCH_TCGEN05_S8_MMA_ENABLED) + if (cute::elect_one_sync()) { + uint32_t mask[4] = {0, 0, 0, 0}; + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::1.kind::i8 [%0], %1, %2, %3, {%5, %6, %7, %8}, p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3])); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_S8_SS without CUTE_ARCH_TCGEN05_S8_MMA_ENABLED"); +#endif + } +}; + +template +struct SM100_MMA_S8_TS +{ + static_assert(M == 64 || M == 128, "SM100_MMA_S8_TS M-mode size should be 64 or 128 for 1 CTA cluster MMA."); + static_assert(N == 8 || ((N % 16 == 0) && (16 <= N) && (N <= 256)), "SM100_MMA_S8_TS N-mode size should be 8 or a multiple of 16 between 16 and 256."); + static_assert(a_major == UMMA::Major::K, "SM100_MMA_S8_TS A from TMEM can't be transposed"); + + using DRegisters = void; + using ARegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& tmem_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE) + { +#if defined(CUTE_ARCH_TCGEN05_S8_MMA_ENABLED) + if (cute::elect_one_sync()) { + uint32_t mask[4] = {0, 0, 0, 0}; + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::1.kind::i8 [%0], [%1], %2, %3, {%5, %6, %7, %8}, p; \n\t" + "}\n" + : + : "r"(tmem_c), "r"(tmem_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3])); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_S8_TS without CUTE_ARCH_TCGEN05_S8_MMA_ENABLED"); +#endif + } +}; + +template 
+struct SM100_MMA_S8_SS_SPARSE +{ + static_assert(is_same_v, "SM100_MMA_S8_SS_SPARSE result type can only be int32_t."); + static_assert(M == 64 || M == 128, "SM100_MMA_S8_SS_SPARSE M-mode size should be 64 or 128 for 1 CTA cluster MMA."); + static_assert(N == 8 || ((N % 16 == 0) && (16 <= N) && (N <= 256)), "SM100_MMA_S8_SS_SPARSE N-mode size should be 8 or a multiple of 16 between 16 and 256."); + + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE, + uint32_t const& tmem_e) + { +#if defined(CUTE_ARCH_TCGEN05_S8_MMA_ENABLED) + if (cute::elect_one_sync()) { + uint32_t mask[4] = {0, 0, 0, 0}; + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.sp.cta_group::1.kind::i8 [%0], %1, %2, [%9], %3, {%5, %6, %7, %8}, p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3]), "r"(tmem_e)); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_S8_SS_SPARSE without CUTE_ARCH_TCGEN05_S8_MMA_ENABLED"); +#endif + } +}; + +template +struct SM100_MMA_S8_2x1SM_SS +{ + static_assert(M == 128 || M == 256, "SM100_MMA_S8_2x1SM_SS M-mode size should be 128 or 256 for 2 CTA cluster MMA."); + static_assert((N % 32 == 0) && (32 <= N) && (N <= 256), "SM100_MMA_S8_2x1SM_SS N-mode size should be a multiple of 32 between 32 and 256."); + + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE) + { +#if defined(CUTE_ARCH_TCGEN05_S8_MMA_ENABLED) + if (cute::elect_one_sync()) { + uint32_t mask[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::2.kind::i8 [%0], %1, %2, %3, {%5, %6, %7, %8, %9, %10, %11, %12}, p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3]), + "r"(mask[4]), "r"(mask[5]), "r"(mask[6]), "r"(mask[7])); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_S8_2x1SM_SS without CUTE_ARCH_TCGEN05_S8_MMA_ENABLED"); +#endif + } +}; + +template +struct SM100_MMA_S8_2x1SM_TS +{ + static_assert(M == 128 || M == 256, "SM100_MMA_S8_2x1SM_TS M-mode size should be 128 or 256 for 2 CTA cluster MMA."); + static_assert((N % 32 == 0) && (32 <= N) && (N <= 256), "SM100_MMA_S8_2x1SM_TS N-mode size should be a multiple of 32 between 32 and 256."); + static_assert(a_major == UMMA::Major::K, "SM100_MMA_S8_2x1SM_TS A from TMEM can't be transposed"); + + using DRegisters = void; + using ARegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& tmem_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE) + { +#if defined(CUTE_ARCH_TCGEN05_S8_MMA_ENABLED) + if (cute::elect_one_sync()) { + uint32_t mask[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::2.kind::i8 [%0], [%1], %2, 
%3, {%5, %6, %7, %8, %9, %10, %11, %12}, p; \n\t" + "}\n" + : + : "r"(tmem_c), "r"(tmem_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3]), + "r"(mask[4]), "r"(mask[5]), "r"(mask[6]), "r"(mask[7])); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_S8_2x1SM_TS without CUTE_ARCH_TCGEN05_S8_MMA_ENABLED"); +#endif + } +}; + +template +struct SM100_MMA_S8_2x1SM_SS_SPARSE +{ + static_assert(M == 128 || M == 256, "SM100_MMA_S8 M-mode size should be 128 or 256 for 2 CTA cluster MMA."); + static_assert((N % 32 == 0) && (32 <= N) && (N <= 256), "SM100_MMA_S8 N-mode size should be a multiple of 32 between 32 and 256."); + + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE, + uint32_t const& tmem_e) + { +#if defined(CUTE_ARCH_TCGEN05_S8_MMA_ENABLED) + if (cute::elect_one_sync()) { + uint32_t mask[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.sp.cta_group::2.kind::i8 [%0], %1, %2, [%13], %3, {%5, %6, %7, %8, %9, %10, %11, %12}, p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3]), + "r"(mask[4]), "r"(mask[5]), "r"(mask[6]), "r"(mask[7]), "r"(tmem_e)); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_S8_2x1SM_SS_SPARSE without CUTE_ARCH_TCGEN05_S8_MMA_ENABLED"); +#endif + } +}; + +struct SM100_MMA_F8F6F4_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE) + { +#if defined(CUTE_ARCH_TCGEN05_MXF8F6F4_MMA_ENABLED) + if (cute::elect_one_sync()) { + uint32_t mask[4] = {0, 0, 0, 0}; + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::1.kind::f8f6f4 [%0], %1, %2, %3, {%5, %6, %7, %8}, p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3])); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_F8F6F4_SS without CUTE_ARCH_TCGEN05_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +template +struct SM100_MMA_MXF8F6F4_SS +{ + static_assert(M == 128, "SM100_MMA_MXF8F6F4_SS M-mode size should be 64 or 128 for 1 CTA cluster MMA."); + static_assert((N % 8 == 0) && (8 <= N) && (N <= 256), "SM100_MMA_MXF8F6F4_SS N-mode size should be a multiple of 8 between 8 and 256."); + + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + using SFARegisters = uint32_t[1]; + using SFBRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE, + uint32_t const& tsfa_addr, + uint32_t const& tsfb_addr) + { +#if defined(CUTE_ARCH_TCGEN05_MXF8F6F4_MMA_ENABLED) + if (cute::elect_one_sync()) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + 
"tcgen05.mma.cta_group::1.kind::mxf8f6f4.block_scale [%0], %1, %2, %3, [%5], [%6], p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), "r"(tsfa_addr), "r"(tsfb_addr)); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_F8F6F4_SS without CUTE_ARCH_TCGEN05_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +template +struct SM100_MMA_F8F6F4_TS +{ + static_assert(M == 64 || M == 128, "SM100_MMA_F8F6F4_TS M-mode size should be 64 or 128 for 1 CTA cluster MMA."); + static_assert((M == 64 && (N % 8 == 0) && (8 <= N) && (N <= 256)) || + (M == 128 && (N % 16 == 0) && (16 <= N) && (N <= 256)), + "SM100_MMA_F8F6F4_TS N-mode size should be a multiple of 8 between 8 and 256 for M=64,\ + or a multiple of 16 between 16 and 256 for M=128."); + static_assert(a_major == UMMA::Major::K, "SM100_MMA_F8F6F4_TS A from TMEM can't be transposed"); + + using DRegisters = void; + using ARegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& tmem_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE) + { +#if defined(CUTE_ARCH_TCGEN05_MXF8F6F4_MMA_ENABLED) + if (cute::elect_one_sync()) { + uint32_t mask[4] = {0, 0, 0, 0}; + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::1.kind::f8f6f4 [%0], [%1], %2, %3, {%5, %6, %7, %8}, p; \n\t" + "}\n" + : + : "r"(tmem_c), "r"(tmem_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3])); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_F8F6F4_TS without CUTE_ARCH_TCGEN05_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +template +struct SM100_MMA_F8F6F4_2x1SM_TS +{ + static_assert(M == 128 || M == 256, "SM100_MMA_F8F6F4_2x1SM_TS M-mode size should be 64 or 128 for 1 CTA cluster MMA."); + static_assert((N % 32 == 0) && (32 <= N) && (N <= 256), "SM100_MMA_F8F6F4_2x1SM_TS N-mode size should be a multiple of 32 between 32 and 256."); + static_assert(a_major == UMMA::Major::K, "SM100_MMA_F8F6F4_2x1SM_TS A from TMEM can't be transposed"); + + using DRegisters = void; + using ARegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& tmem_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE) + { +#if defined(CUTE_ARCH_TCGEN05_MXF8F6F4_MMA_ENABLED) + if (cute::elect_one_sync()) { + uint32_t mask[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::2.kind::f8f6f4 [%0], [%1], %2, %3, {%5, %6, %7, %8, %9, %10, %11, %12}, p; \n\t" + "}\n" + : + : "r"(tmem_c), "r"(tmem_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3]), + "r"(mask[4]), "r"(mask[5]), "r"(mask[6]), "r"(mask[7])); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_F8F6F4_TS without CUTE_ARCH_TCGEN05_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +template +struct SM100_MMA_F8F6F4_SS_SPARSE +{ + static_assert(M == 64 || M == 128, "SM100_MMA_F8F6F4_SS_SPARSE M-mode size should be 64 or 128 for 1 CTA cluster MMA."); + static_assert((M == 64 && (N % 8 == 0) && (8 <= N) && (N <= 256)) || + (M == 128 && (N % 16 == 0) && (16 <= N) && (N <= 256)), + "SM100_MMA_F8F6F4_SS_SPARSE N-mode size should be a multiple of 8 
between 8 and 256 for M=64,\ + or a multiple of 16 between 16 and 256 for M=128."); + + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE, + uint32_t const& tmem_e) + { +#if defined(CUTE_ARCH_TCGEN05_MXF8F6F4_MMA_ENABLED) + if (cute::elect_one_sync()) { + uint32_t mask[4] = {0, 0, 0, 0}; // %5, %6, %7, %8 + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.sp.cta_group::1.kind::f8f6f4 [%0], %1, %2, [%9], %3, {%5, %6, %7, %8}, p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3]), "r"(tmem_e)); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_F8F6F4_SS_SPARSE without CUTE_ARCH_TCGEN05_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +template +struct SM100_MMA_MXF8F6F4_SS_SPARSE +{ + static_assert(M == 128, "SM100_MMA_MXF8F6F4_SS_SPARSE M-mode size should be 128 for 1 CTA cluster MMA."); + static_assert((N % 8 == 0) && (8 <= N) && (N <= 256), "SM100_MMA_MXF8F6F4_SS_SPARSE N-mode size should be a multiple of 8 between 8 and 256."); + + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + using SFARegisters = uint32_t[1]; + using SFBRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE, + uint32_t const& tsfa_addr, + uint32_t const& tsfb_addr, + uint32_t const& tmem_e) + { +#if defined(CUTE_ARCH_TCGEN05_MXF8F6F4_MMA_ENABLED) + if (cute::elect_one_sync()) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.sp.cta_group::1.kind::mxf8f6f4.block_scale [%0], %1, %2, [%7], %3, [%5], [%6], p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), "r"(tsfa_addr), "r"(tsfb_addr), "r"(tmem_e)); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_MXF8F6F4_SS_SPARSE without CUTE_ARCH_TCGEN05_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +struct SM100_MMA_F8F6F4_2x1SM_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE) + { +#if defined(CUTE_ARCH_TCGEN05_MXF8F6F4_MMA_ENABLED) + if (cute::elect_one_sync()) { + uint32_t mask[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::2.kind::f8f6f4 [%0], %1, %2, %3, {%5, %6, %7, %8, %9, %10, %11, %12}, p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3]), + "r"(mask[4]), "r"(mask[5]), "r"(mask[6]), "r"(mask[7])); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_F8F6F4_2x1SM_SS without CUTE_ARCH_TCGEN05_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +template +struct SM100_MMA_MXF8F6F4_2x1SM_SS_SPARSE +{ + static_assert(M == 256, "SM100_MMA_MXF8F6F4_2x1SM_SS_SPARSE M-mode size should be 256 for 2 CTA cluster MMA."); + 
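+  // Besides the A/B shared-memory descriptors and the accumulator TMEM address,
+  // the sparse block-scaled fma() below also takes the TMEM addresses of the A/B
+  // scale factors (tsfa_addr, tsfb_addr) and of the sparsity metadata (tmem_e).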
static_assert((N % 16 == 0) && (16 <= N) && (N <= 256), "SM100_MMA_MXF8F6F4_2x1SM_SS_SPARSE N-mode size should be a multiple of 16 between 16 and 256."); + + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE, + uint32_t const& tsfa_addr, + uint32_t const& tsfb_addr, + uint32_t const& tmem_e) + { +#if defined(CUTE_ARCH_TCGEN05_MXF8F6F4_MMA_ENABLED) + if (cute::elect_one_sync()) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.sp.cta_group::2.kind::mxf8f6f4.block_scale [%0], %1, %2, [%7], %3, [%5], [%6], p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(tsfa_addr), "r"(tsfb_addr), "r"(tmem_e)); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_MXF8F6F4_2x1SM_SS_SPARSE without CUTE_ARCH_TCGEN05_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +template +struct SM100_MMA_MXF8F6F4_2x1SM_SS +{ + static_assert(M == 256, "SM100_MMA_MXF8F6F4_2x1SM_SS M-mode size should be 128 or 256 for 2 CTA cluster MMA."); + static_assert((N % 16 == 0) && (16 <= N) && (N <= 256), "SM100_MMA_MXF8F6F4_2x1SM_SS N-mode size should be a multiple of 16 between 16 and 256."); + + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE, + uint32_t const& tsfa_addr, + uint32_t const& tsfb_addr) + { +#if defined(CUTE_ARCH_TCGEN05_MXF8F6F4_MMA_ENABLED) + if (cute::elect_one_sync()) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.cta_group::2.kind::mxf8f6f4.block_scale [%0], %1, %2, %3, [%5], [%6], p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(tsfa_addr), "r"(tsfb_addr)); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_MXF8F6F4_2x1SM_SS without CUTE_ARCH_TCGEN05_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +template +struct SM100_MMA_F8F6F4_2x1SM_SS_SPARSE +{ + static_assert(M == 128 || M == 256, "SM100_MMA_F8F6F4_2x1SM_SS_SPARSE M-mode size should be 128 or 256 for 2 CTA cluster MMA."); + static_assert((N % 32 == 0) && (32 <= N) && (N <= 256), "SM100_MMA_F8F6F4_2x1SM_SS_SPARSE N-mode size should be a multiple of 32 between 32 and 256."); + + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE, + uint32_t const& tmem_e) + { +#if defined(CUTE_ARCH_TCGEN05_MXF8F6F4_MMA_ENABLED) + if (cute::elect_one_sync()) { + uint32_t mask[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" + "tcgen05.mma.sp.cta_group::2.kind::f8f6f4 [%0], %1, %2, [%13], %3, {%5, %6, %7, %8, %9, %10, %11, %12}, p; \n\t" + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(mask[0]), "r"(mask[1]), "r"(mask[2]), "r"(mask[3]), + "r"(mask[4]), "r"(mask[5]), "r"(mask[6]), "r"(mask[7]), "r"(tmem_e)); + } +#else + 
CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_F8F6F4_2x1SM_SS_SPARSE without CUTE_ARCH_TCGEN05_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +template +struct SM100_MMA_MXF4_SS +{ + static_assert(M == 128, "SM100_MMA_MXF4_SS M-mode size should be 128 for 1 CTA cluster MMA."); + static_assert((N % 8 == 0) && (8 <= N) && (N <= 256), "SM100_MMA_MXF4_SS N-mode size should be a multiple of 8 between 8 and 256."); + static_assert((VS == 16) || (VS == 32), "SM100_MMA_MXF4_SS Vector size can only be 16 or 32."); + + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + using SFARegisters = uint32_t[1]; + using SFBRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE, + uint32_t const& tsfa_addr, + uint32_t const& tsfb_addr) + { + if constexpr (VS == 16) { +#if defined(CUTE_ARCH_TCGEN05_MXF4NVF4_MMA_ENABLED) + if (cute::elect_one_sync()) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" +#if (__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 9) + "tcgen05.mma.cta_group::1.kind::mxf4nvf4.block_scale.block16 [%0], %1, %2, %3, [%5], [%6], p; \n\t" +#else + "tcgen05.mma.cta_group::1.kind::mxf4nvf4.block_scale.scale_vec::4X [%0], %1, %2, %3, [%5], [%6], p; \n\t" +#endif + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(tsfa_addr), "r"(tsfb_addr)); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_MXF4_SS (VS = 16) without CUTE_ARCH_TCGEN05_MXF4NVF4_MMA_ENABLED"); +#endif + } + if constexpr (VS == 32) { +#if defined(CUTE_ARCH_TCGEN05_MXF4_MMA_ENABLED) + if (cute::elect_one_sync()) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" +#if (__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 9) + "tcgen05.mma.cta_group::1.kind::mxf4.block_scale.block32 [%0], %1, %2, %3, [%5], [%6], p; \n\t" +#else + "tcgen05.mma.cta_group::1.kind::mxf4.block_scale.scale_vec::2X [%0], %1, %2, %3, [%5], [%6], p; \n\t" +#endif + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(tsfa_addr), "r"(tsfb_addr)); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_MXF4_SS (VS = 32) without CUTE_ARCH_TCGEN05_MXF4_MMA_ENABLED"); +#endif + } + } + +}; + +template +struct SM100_MMA_MXF4NVF4_SS_SPARSE +{ + static_assert(M == 128, "SM100_MMA_MXF4NVF4_SS_SPARSE M-mode size should be 128 for 1 CTA cluster MMA."); + static_assert((N % 8 == 0) && (8 <= N) && (N <= 256), "SM100_MMA_MXF4NVF4_SS_SPARSE N-mode size should be a multiple of 8 between 8 and 256."); + + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + using SFARegisters = uint32_t[1]; + using SFBRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE, + uint32_t const& tsfa_addr, + uint32_t const& tsfb_addr, + uint32_t const& tmem_e) + { + if constexpr (VS == 32) { +#if defined(CUTE_ARCH_TCGEN05_MXF4NVF4_MMA_ENABLED) + if (cute::elect_one_sync()) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" +#if (__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && 
__CUDACC_VER_MINOR__ >= 9) + "tcgen05.mma.sp.cta_group::1.kind::mxf4nvf4.block_scale.block16 [%0], %1, %2, [%7], %3, [%5], [%6], p; \n\t" +#else + "tcgen05.mma.sp.cta_group::1.kind::mxf4nvf4.block_scale.scale_vec::4X [%0], %1, %2, [%7], %3, [%5], [%6], p; \n\t" +#endif + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(tsfa_addr), "r"(tsfb_addr), "r"(tmem_e)); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_MXF4NVF4_SS_SPARSE (VS = 32) without CUTE_ARCH_TCGEN05_MXF4NVF4_MMA_ENABLED"); +#endif + } + + if constexpr (VS == 64) { +#if defined(CUTE_ARCH_TCGEN05_MXF4_MMA_ENABLED) + if (cute::elect_one_sync()) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" +#if (__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 9) + "tcgen05.mma.sp.cta_group::1.kind::mxf4.block_scale.block32 [%0], %1, %2, [%7], %3, [%5], [%6], p; \n\t" +#else + "tcgen05.mma.sp.cta_group::1.kind::mxf4.block_scale.scale_vec::2X [%0], %1, %2, [%7], %3, [%5], [%6], p; \n\t" +#endif + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(tsfa_addr), "r"(tsfb_addr), "r"(tmem_e)); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_MXF4NVF4_SS_SPARSE (VS = 64) without CUTE_ARCH_TCGEN05_MXF4_MMA_ENABLED"); +#endif + } + } +}; + +template +struct SM100_MMA_MXF4_2x1SM_SS +{ + static_assert(M == 128 || M == 256, "SM100_MMA_MXF4_2x1SM_SS M-mode size should be 128 or 256 for 2 CTA cluster MMA."); + static_assert((N % 16 == 0) && (16 <= N) && (N <= 256), "SM100_MMA_MXF4_2x1SM_SS N-mode size should be a multiple of 16 between 16 and 256."); + static_assert((VS == 16) || (VS == 32), "SM100_MMA_MXF4_2x1SM_SS Vector size can only be 16 or 32."); + + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + using SFARegisters = uint32_t[1]; + using SFBRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE, + uint32_t const& tsfa_addr, + uint32_t const& tsfb_addr) + { + if constexpr (VS == 16) { +#if defined(CUTE_ARCH_TCGEN05_MXF4NVF4_MMA_ENABLED) + if (cute::elect_one_sync()) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" +#if (__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 9) + "tcgen05.mma.cta_group::2.kind::mxf4nvf4.block_scale.block16 [%0], %1, %2, %3, [%5], [%6], p; \n\t" +#else + "tcgen05.mma.cta_group::2.kind::mxf4nvf4.block_scale.scale_vec::4X [%0], %1, %2, %3, [%5], [%6], p; \n\t" +#endif + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(tsfa_addr), "r"(tsfb_addr)); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_MXF4_2x1SM_SS (VS = 16) without CUTE_ARCH_TCGEN05_MXF4NVF4_MMA_ENABLED"); +#endif + } + if constexpr (VS == 32) { +#if defined(CUTE_ARCH_TCGEN05_MXF4_MMA_ENABLED) + if (cute::elect_one_sync()) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" +#if (__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 9) + "tcgen05.mma.cta_group::2.kind::mxf4.block_scale.block32 [%0], %1, %2, %3, [%5], [%6], p; \n\t" +#else + "tcgen05.mma.cta_group::2.kind::mxf4.block_scale.scale_vec::2X [%0], %1, %2, %3, [%5], [%6], p; \n\t" +#endif + "}\n" + : + : 
"r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(tsfa_addr), "r"(tsfb_addr)); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_MXF4_2x1SM_SS (VS = 32) without CUTE_ARCH_TCGEN05_MXF4_MMA_ENABLED"); +#endif + } + } +}; + +template +struct SM100_MMA_MXF4NVF4_2x1SM_SS_SPARSE +{ + static_assert((N % 16 == 0) && (16 <= N) && (N <= 256), "SM100_MMA_MXF4NVF4_2x1SM_SS_SPARSE N-mode size should be a multiple of 16 between 16 and 256."); + static_assert((VS == 32) || (VS == 64), "SM100_MMA_MXF4NVF4_2x1SM_SS_SPARSE Vector size can only be 32 or 64."); + + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + using SFARegisters = uint32_t[1]; + using SFBRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE, + uint32_t const& tsfa_addr, + uint32_t const& tsfb_addr, + uint32_t const& tmem_e) + { + if constexpr (VS == 32) { +#if defined(CUTE_ARCH_TCGEN05_MXF4NVF4_MMA_ENABLED) + if (cute::elect_one_sync()) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" +#if (__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 9) + "tcgen05.mma.sp.cta_group::2.kind::mxf4nvf4.block_scale.block16 [%0], %1, %2, [%7], %3, [%5], [%6], p; \n\t" +#else + "tcgen05.mma.sp.cta_group::2.kind::mxf4nvf4.block_scale.scale_vec::4X [%0], %1, %2, [%7], %3, [%5], [%6], p; \n\t" +#endif + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(tsfa_addr), "r"(tsfb_addr), "r"(tmem_e)); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_MXF4NVF4_2x1SM_SS_SPARSE (VS = 32) without CUTE_ARCH_TCGEN05_MXF4NVF4_MMA_ENABLED"); +#endif + } + + if constexpr (VS == 64) { +#if defined(CUTE_ARCH_TCGEN05_MXF4_MMA_ENABLED) + if (cute::elect_one_sync()) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" +#if (__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 9) + "tcgen05.mma.sp.cta_group::2.kind::mxf4.block_scale.block32 [%0], %1, %2, [%7], %3, [%5], [%6], p; \n\t" +#else + "tcgen05.mma.sp.cta_group::2.kind::mxf4.block_scale.scale_vec::2X [%0], %1, %2, [%7], %3, [%5], [%6], p; \n\t" +#endif + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(tsfa_addr), "r"(tsfb_addr), "r"(tmem_e)); + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM100_MMA_MXF4NVF4_2x1SM_SS_SPARSE (VS = 64) without CUTE_ARCH_TCGEN05_MXF4_MMA_ENABLED"); +#endif + } + } +}; + +namespace SM103 { +template +struct SM103_MXF4_ULTRA_SS_VS +{ + static_assert(M == 128, "MMA M-mode size should be 128 for 1 CTA cluster MMA."); + static_assert((N % 16 == 0) && (16 <= N) && (N <= 256), "MMA N-mode size should be a multiple of 16 between 16 and 256."); + static_assert(((VS == 32) & (is_same_v && is_same_v)) || (VS == 16), + "Vector size can only be 4x mode (VS=16) or 2x mode (VS=32) for MMA. 
2x mode only supports float_e2m1_t for a/b types and ue8m0_t for sf type"); + + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + using SFARegisters = uint32_t[1]; + using SFBRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE, + uint32_t const& tsfa_addr, + uint32_t const& tsfb_addr) + { +#if defined(CUTE_ARCH_TCGEN05_MXF4NVF4_MMA_ULTRA_ENABLED) + if constexpr (VS == 16) { + if (cute::elect_one_sync()) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" +#if (__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 9) + "tcgen05.mma.cta_group::1.kind::mxf4nvf4.block_scale.block16 [%0], %1, %2, %3, [%5], [%6], p; \n\t" +#else + "tcgen05.mma.cta_group::1.kind::mxf4nvf4.block_scale.scale_vec::4X [%0], %1, %2, %3, [%5], [%6], p; \n\t" +#endif + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(tsfa_addr), "r"(tsfb_addr)); + } + } + else if constexpr (VS == 32) { + if (cute::elect_one_sync()) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" +#if (__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 9) + "tcgen05.mma.cta_group::1.kind::mxf4nvf4.block_scale.block32 [%0], %1, %2, %3, [%5], [%6], p; \n\t" +#else + "tcgen05.mma.cta_group::1.kind::mxf4nvf4.block_scale.scale_vec::2X [%0], %1, %2, %3, [%5], [%6], p; \n\t" +#endif + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(tsfa_addr), "r"(tsfb_addr)); + } + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM103_MXF4_ULTRA_SS_VS without CUTE_ARCH_MMA_SM103A_ENABLED"); +#endif + } + +}; + + +template +struct SM103_MXF4_ULTRA_2x1SM_SS_VS +{ + static_assert(M == 128 || M == 256, "MMA M-mode size should be 128 or 256 for 2 CTA cluster MMA."); + static_assert((N % 16 == 0) && (16 <= N) && (N <= 256), "MMA N-mode size should be a multiple of 16 between 16 and 256."); + static_assert(((VS == 32) & (is_same_v && is_same_v)) || (VS == 16), + "Vector size can only be 4x mode (VS=16) or 2x mode (VS=32) for MMA. 
2x mode only supports float_e2m1_t for a/b types and ue8m0_t for sf type"); + + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[1]; + using SFARegisters = uint32_t[1]; + using SFBRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t const& tmem_c, + uint32_t const& scaleC, + uint64_t const& idescE, + uint32_t const& tsfa_addr, + uint32_t const& tsfb_addr) + { +#if defined(CUTE_ARCH_TCGEN05_MXF4NVF4_MMA_ULTRA_ENABLED) + if constexpr (VS == 16) { + if (cute::elect_one_sync()) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" +#if (__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 9) + "tcgen05.mma.cta_group::2.kind::mxf4nvf4.block_scale.block16 [%0], %1, %2, %3, [%5], [%6], p; \n\t" +#else + "tcgen05.mma.cta_group::2.kind::mxf4nvf4.block_scale.scale_vec::4X [%0], %1, %2, %3, [%5], [%6], p; \n\t" +#endif + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(tsfa_addr), "r"(tsfb_addr)); + } + } + else if constexpr (VS == 32) { + if (cute::elect_one_sync()) { + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.ne.b32 p, %4, 0;\n\t" +#if (__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 9) + "tcgen05.mma.cta_group::2.kind::mxf4nvf4.block_scale.block32 [%0], %1, %2, %3, [%5], [%6], p; \n\t" +#else + "tcgen05.mma.cta_group::2.kind::mxf4nvf4.block_scale.scale_vec::2X [%0], %1, %2, %3, [%5], [%6], p; \n\t" +#endif + "}\n" + : + : "r"(tmem_c), "l"(desc_a), "l"(desc_b), "r"(uint32_t(idescE>>32)), "r"(scaleC), + "r"(tsfa_addr), "r"(tsfb_addr)); + } + } +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM103_MXF4_ULTRA_2x1SM_SS_VS without CUTE_ARCH_MMA_SM103A_ENABLED"); +#endif + } + +}; +} // namespace SM103 + +} // end namespace cute diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/arch/mma_sm120.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/arch/mma_sm120.hpp new file mode 100644 index 0000000000000000000000000000000000000000..1433a2c8d0b558ab12d2d6a0fbe028d77a63e689 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/arch/mma_sm120.hpp @@ -0,0 +1,3254 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + +#pragma once + +#include +#include +#include // cute::float_e4m3_t, etc +#include + +namespace cute { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct SM120_16x8x32_TN +{ + static_assert(cutlass::detail::dependent_false, "No MMA matches SM120_16x8x32_TN for given data types."); +}; + + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN E2M1 x E2M1 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e2m1.e2m1.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN E2M1 x E3M2 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e2m1.e3m2.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 
TN E2M1 x E2M3 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e2m1.e2m3.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN E2M1 x E4M3 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e2m1.e4m3.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN E2M1 x E5M2 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e2m1.e5m2.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// MMA 16x8x32 TN E3M2 x E2M1 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& 
a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e3m2.e2m1.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN E3M2 x E3M2 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e3m2.e3m2.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN E3M2 x E2M3 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e3m2.e2m3.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN E3M2 x E4M3 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e3m2.e4m3.f32 " + "{%0, %1, %2, 
%3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN E3M2 x E5M2 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e3m2.e5m2.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + + +// MMA 16x8x32 TN E2M3 x E2M1 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e2m3.e2m1.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN E2M3 x E3M2 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e2m3.e3m2.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without 
CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN E2M3 x E2M3 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e2m3.e2m3.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN E2M3 x E4M3 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e2m3.e4m3.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN E2M3 x E5M2 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e2m3.e5m2.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + + +// MMA 16x8x32 TN E4M3 x E2M1 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using 
BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e4m3.e2m1.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN E4M3 x E3M2 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e4m3.e3m2.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN E4M3 x E2M3 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e4m3.e2m3.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN E4M3 x E4M3 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float 
const & c2, float const & c3) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e4m3.e4m3.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN E4M3 x E5M2 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e4m3.e5m2.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + + +// MMA 16x8x32 TN E5M2 x E2M1 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e5m2.e2m1.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN E5M2 x E3M2 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e5m2.e3m2.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), 
"r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN E5M2 x E2M3 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e5m2.e2m3.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN E5M2 x E4M3 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e5m2.e4m3.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN E5M2 x E5M2 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f32.e5m2.e5m2.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + + +// MMA.F16 16x8x32 TN E2M1 x E2M1 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f16.e2m1.e2m1.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.F16 16x8x32 TN E2M1 x E3M2 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f16.e2m1.e3m2.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.F16 16x8x32 TN E2M1 x E2M3 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f16.e2m1.e2m3.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.F16 16x8x32 TN E2M1 x E4M3 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm 
volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f16.e2m1.e4m3.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.F16 16x8x32 TN E2M1 x E5M2 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f16.e2m1.e5m2.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.F16 16x8x32 TN E3M2 x E2M1 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f16.e3m2.e2m1.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.F16 16x8x32 TN E3M2 x E3M2 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f16.e3m2.e3m2.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.F16 16x8x32 TN E3M2 x E2M3 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = uint32_t[2]; + using 
ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f16.e3m2.e2m3.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.F16 16x8x32 TN E3M2 x E4M3 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f16.e3m2.e4m3.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.F16 16x8x32 TN E3M2 x E5M2 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f16.e3m2.e5m2.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.F16 16x8x32 TN E2M3 x E2M1 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f16.e2m3.e2m1.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), 
"r"(b1), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.F16 16x8x32 TN E2M3 x E3M2 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f16.e2m3.e3m2.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.F16 16x8x32 TN E2M3 x E2M3 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f16.e2m3.e2m3.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.F16 16x8x32 TN E2M3 x E4M3 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f16.e2m3.e4m3.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.F16 16x8x32 TN E2M3 x E5M2 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, 
uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f16.e2m3.e5m2.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +// MMA.F16 16x8x32 TN E4M3 x E2M1 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f16.e4m3.e2m1.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.F16 16x8x32 TN E4M3 x E3M2 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f16.e4m3.e3m2.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.F16 16x8x32 TN E4M3 x E2M3 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f16.e4m3.e2m3.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + 
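The MMA.F16 variants in this stretch keep the same A/B register footprint but accumulate and return in half precision: C and D are two 32-bit registers each, holding two packed f16 values per register, which is why their constraints are "r" instead of "f". Below is a small, hedged helper for moving between that packed form and individual __half values; the low-half-first packing order within each 32-bit register is the usual PTX convention and is assumed here rather than stated by this header.

// Hedged sketch: pack/unpack the two-register f16 accumulator fragment used
// by the MMA.F16 wrappers above. Assumes two f16 values per 32-bit register,
// low half first.
#include <cstdint>
#include <cstring>
#include <cuda_fp16.h>

__host__ __device__ inline uint32_t pack_f16x2(__half lo, __half hi)
{
  __half2 h2;
  h2.x = lo;
  h2.y = hi;
  uint32_t bits;
  memcpy(&bits, &h2, sizeof(bits));  // bit-cast __half2 -> uint32_t
  return bits;
}

__host__ __device__ inline void unpack_f16x2(uint32_t bits, __half& lo, __half& hi)
{
  __half2 h2;
  memcpy(&h2, &bits, sizeof(h2));    // bit-cast uint32_t -> __half2
  lo = h2.x;
  hi = h2.y;
}

Under that assumption, each wrapper's c0/c1 (and d0/d1) registers carry one thread's four half-precision accumulator values as two such pairs.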
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.F16 16x8x32 TN E4M3 x E4M3 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f16.e4m3.e4m3.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.F16 16x8x32 TN E4M3 x E5M2 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f16.e4m3.e5m2.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.F16 16x8x32 TN E5M2 x E2M1 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f16.e5m2.e2m1.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.F16 16x8x32 TN E5M2 x E3M2 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm 
volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f16.e5m2.e3m2.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.F16 16x8x32 TN E5M2 x E2M3 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f16.e5m2.e2m3.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.F16 16x8x32 TN E5M2 x E4M3 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f16.e5m2.e4m3.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.F16 16x8x32 TN E5M2 x E5M2 +template <> +struct SM120_16x8x32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.m16n8k32.row.col.f16.e5m2.e5m2.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_16x8x32_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace 
SM120::BLOCKSCALED { +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct SM120_16x8x32_TN_VS +{ + static_assert(cutlass::detail::dependent_false, "No MMA matches SM120_16x8x32_TN_VS for given data types."); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct SM120_16x8x64_TN_VS +{ + static_assert(cutlass::detail::dependent_false, "No MMA matches SM120_16x8x64_TN_VS for given data types."); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.SF 16x8x32 TN E2M1 x E2M1 with SF UE8M0 +template +struct SM120_16x8x32_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3, + uint8_t const& sfa0, + uint8_t const& sfb0) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 32, "Scaling factor vector size has to be 32 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.block_scale.scale_vec::1X.m16n8k32.row.col.f32.e2m1.e2m1.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13}," + "{%14}," + "{%15, %16}," + "{%17}," + "{%18, %19};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(uint32_t(sfa0)), "h"(bidA), "h"(tidA), + "r"(uint32_t(sfb0)), "h"(bidB), "h"(tidB)); + +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120::BLOCKSCALED::SM120_16x8x32_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.SF 16x8x32 TN E2M1 x E3M2 with SF UE8M0 +template +struct SM120_16x8x32_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3, + uint8_t const& sfa0, + uint8_t const& sfb0) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 32, "Scaling factor vector size has to be 32 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.block_scale.scale_vec::1X.m16n8k32.row.col.f32.e2m1.e3m2.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13}," + "{%14}," + "{%15, %16}," + "{%17}," + "{%18, %19};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : 
"r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(uint32_t(sfa0)) , "h"(bidA), "h"(tidA), + "r"(uint32_t(sfb0)) , "h"(bidB), "h"(tidB)); + +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120::BLOCKSCALED::SM120_16x8x32_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.SF 16x8x32 TN E2M1 x E2M3 with SF UE8M0 +template +struct SM120_16x8x32_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3, + uint8_t const& sfa0, + uint8_t const& sfb0) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 32, "Scaling factor vector size has to be 32 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.block_scale.scale_vec::1X.m16n8k32.row.col.f32.e2m1.e2m3.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13}," + "{%14}," + "{%15, %16}," + "{%17}," + "{%18, %19};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(uint32_t(sfa0)) , "h"(bidA), "h"(tidA), + "r"(uint32_t(sfb0)) , "h"(bidB), "h"(tidB)); + +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120::BLOCKSCALED::SM120_16x8x32_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.SF 16x8x32 TN E2M1 x E4M3 with SF UE8M0 +template +struct SM120_16x8x32_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3, + uint8_t const& sfa0, + uint8_t const& sfb0) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 32, "Scaling factor vector size has to be 32 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.block_scale.scale_vec::1X.m16n8k32.row.col.f32.e2m1.e4m3.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13}," + "{%14}," + "{%15, %16}," + "{%17}," + "{%18, %19};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(uint32_t(sfa0)) , "h"(bidA), "h"(tidA), + "r"(uint32_t(sfb0)) , "h"(bidB), "h"(tidB)); + +#else + 
CUTE_INVALID_CONTROL_PATH("Attempting to use SM120::BLOCKSCALED::SM120_16x8x32_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.SF 16x8x32 TN E2M1 x E5M2 with SF UE8M0 +template +struct SM120_16x8x32_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3, + uint8_t const& sfa0, + uint8_t const& sfb0) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 32, "Scaling factor vector size has to be 32 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.block_scale.scale_vec::1X.m16n8k32.row.col.f32.e2m1.e5m2.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13}," + "{%14}," + "{%15, %16}," + "{%17}," + "{%18, %19};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(uint32_t(sfa0)) , "h"(bidA), "h"(tidA), + "r"(uint32_t(sfb0)) , "h"(bidB), "h"(tidB)); + +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120::BLOCKSCALED::SM120_16x8x32_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.SF 16x8x32 TN E3M2 x E2M1 with SF UE8M0 +template +struct SM120_16x8x32_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3, + uint8_t const& sfa0, + uint8_t const& sfb0) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 32, "Scaling factor vector size has to be 32 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.block_scale.scale_vec::1X.m16n8k32.row.col.f32.e3m2.e2m1.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13}," + "{%14}," + "{%15, %16}," + "{%17}," + "{%18, %19};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(uint32_t(sfa0)) , "h"(bidA), "h"(tidA), + "r"(uint32_t(sfb0)) , "h"(bidB), "h"(tidB)); + +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120::BLOCKSCALED::SM120_16x8x32_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + 
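The block-scaled specializations in SM120::BLOCKSCALED issue kind::mxf8f6f4 with block_scale.scale_vec::1X: besides the A/B/C fragments, each call passes one UE8M0 scale-factor byte for A and one for B, together with byte-id/thread-id selectors that are hard-coded to zero above, and the scale vector size VS is statically asserted to be 32. UE8M0 is an exponent-only encoding, so a scale byte e represents 2^(e - 127). Below is a small host-side helper for decoding those bytes when building a reference check; the bias of 127 and the NaN encoding at 0xFF follow the OCP MX convention and are assumptions here, not something this header states.

// Hedged sketch: decode a UE8M0 block scale factor for reference checks.
// Assumes the OCP MX convention: value = 2^(e - 127), with e == 0xFF as NaN.
#include <cmath>
#include <cstdint>
#include <limits>

inline float ue8m0_to_float(uint8_t e)
{
  if (e == 0xFFu) {
    return std::numeric_limits<float>::quiet_NaN();
  }
  return std::ldexp(1.0f, static_cast<int>(e) - 127);
}

// One way to fold the per-block scales into a reference result: scale the
// unscaled partial product of a 32-wide K block by sfa * sfb before adding C.
inline float scaled_block_ref(float unscaled_block_sum, uint8_t sfa, uint8_t sfb)
{
  return unscaled_block_sum * ue8m0_to_float(sfa) * ue8m0_to_float(sfb);
}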
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.SF 16x8x32 TN E3M2 x E3M2 with SF UE8M0 +template +struct SM120_16x8x32_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3, + uint8_t const& sfa0, + uint8_t const& sfb0) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 32, "Scaling factor vector size has to be 32 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.block_scale.scale_vec::1X.m16n8k32.row.col.f32.e3m2.e3m2.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13}," + "{%14}," + "{%15, %16}," + "{%17}," + "{%18, %19};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(uint32_t(sfa0)) , "h"(bidA), "h"(tidA), + "r"(uint32_t(sfb0)) , "h"(bidB), "h"(tidB)); + +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120::BLOCKSCALED::SM120_16x8x32_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.SF 16x8x32 TN E3M2 x E2M3 with SF UE8M0 +template +struct SM120_16x8x32_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3, + uint8_t const& sfa0, + uint8_t const& sfb0) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 32, "Scaling factor vector size has to be 32 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.block_scale.scale_vec::1X.m16n8k32.row.col.f32.e3m2.e2m3.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13}," + "{%14}," + "{%15, %16}," + "{%17}," + "{%18, %19};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(uint32_t(sfa0)) , "h"(bidA), "h"(tidA), + "r"(uint32_t(sfb0)) , "h"(bidB), "h"(tidB)); + +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120::BLOCKSCALED::SM120_16x8x32_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.SF 16x8x32 TN E3M2 x E4M3 with SF UE8M0 +template +struct SM120_16x8x32_TN_VS +{ + using 
DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3, + uint8_t const& sfa0, + uint8_t const& sfb0) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 32, "Scaling factor vector size has to be 32 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.block_scale.scale_vec::1X.m16n8k32.row.col.f32.e3m2.e4m3.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13}," + "{%14}," + "{%15, %16}," + "{%17}," + "{%18, %19};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(uint32_t(sfa0)) , "h"(bidA), "h"(tidA), + "r"(uint32_t(sfb0)) , "h"(bidB), "h"(tidB)); + +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120::BLOCKSCALED::SM120_16x8x32_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.SF 16x8x32 TN E3M2 x E5M2 with SF UE8M0 +template +struct SM120_16x8x32_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3, + uint8_t const& sfa0, + uint8_t const& sfb0) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 32, "Scaling factor vector size has to be 32 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.block_scale.scale_vec::1X.m16n8k32.row.col.f32.e3m2.e5m2.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13}," + "{%14}," + "{%15, %16}," + "{%17}," + "{%18, %19};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(uint32_t(sfa0)) , "h"(bidA), "h"(tidA), + "r"(uint32_t(sfb0)) , "h"(bidB), "h"(tidB)); + +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120::BLOCKSCALED::SM120_16x8x32_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.SF 16x8x32 TN E2M3 x E2M1 with SF UE8M0 +template +struct SM120_16x8x32_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + 
CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3, + uint8_t const& sfa0, + uint8_t const& sfb0) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 32, "Scaling factor vector size has to be 32 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.block_scale.scale_vec::1X.m16n8k32.row.col.f32.e2m3.e2m1.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13}," + "{%14}," + "{%15, %16}," + "{%17}," + "{%18, %19};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(uint32_t(sfa0)) , "h"(bidA), "h"(tidA), + "r"(uint32_t(sfb0)) , "h"(bidB), "h"(tidB)); + +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120::BLOCKSCALED::SM120_16x8x32_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.SF 16x8x32 TN E2M3 x E3M2 with SF UE8M0 +template +struct SM120_16x8x32_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3, + uint8_t const& sfa0, + uint8_t const& sfb0) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 32, "Scaling factor vector size has to be 32 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.block_scale.scale_vec::1X.m16n8k32.row.col.f32.e2m3.e3m2.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13}," + "{%14}," + "{%15, %16}," + "{%17}," + "{%18, %19};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(uint32_t(sfa0)) , "h"(bidA), "h"(tidA), + "r"(uint32_t(sfb0)) , "h"(bidB), "h"(tidB)); + +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120::BLOCKSCALED::SM120_16x8x32_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.SF 16x8x32 TN E2M3 x E2M3 with SF UE8M0 +template +struct SM120_16x8x32_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, 
+ float const & c0, float const & c1, float const & c2, float const & c3, + uint8_t const& sfa0, + uint8_t const& sfb0) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 32, "Scaling factor vector size has to be 32 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.block_scale.scale_vec::1X.m16n8k32.row.col.f32.e2m3.e2m3.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13}," + "{%14}," + "{%15, %16}," + "{%17}," + "{%18, %19};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(uint32_t(sfa0)) , "h"(bidA), "h"(tidA), + "r"(uint32_t(sfb0)) , "h"(bidB), "h"(tidB)); + +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120::BLOCKSCALED::SM120_16x8x32_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.SF 16x8x32 TN E2M3 x E4M3 with SF UE8M0 +template +struct SM120_16x8x32_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3, + uint8_t const& sfa0, + uint8_t const& sfb0) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 32, "Scaling factor vector size has to be 32 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.block_scale.scale_vec::1X.m16n8k32.row.col.f32.e2m3.e4m3.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13}," + "{%14}," + "{%15, %16}," + "{%17}," + "{%18, %19};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(uint32_t(sfa0)) , "h"(bidA), "h"(tidA), + "r"(uint32_t(sfb0)) , "h"(bidB), "h"(tidB)); + +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120::BLOCKSCALED::SM120_16x8x32_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.SF 16x8x32 TN E2M3 x E5M2 with SF UE8M0 +template +struct SM120_16x8x32_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3, + uint8_t const& sfa0, + uint8_t const& sfb0) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; 
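+    // tidA/bidA/tidB/bidB are the {byte-id, thread-id} selector operands for the scale
+    // factors in the PTX instruction below; they stay 0 because the single UE8M0 byte
+    // handed in via sfa0/sfb0 is consumed directly.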
+ static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 32, "Scaling factor vector size has to be 32 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.block_scale.scale_vec::1X.m16n8k32.row.col.f32.e2m3.e5m2.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13}," + "{%14}," + "{%15, %16}," + "{%17}," + "{%18, %19};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(uint32_t(sfa0)) , "h"(bidA), "h"(tidA), + "r"(uint32_t(sfb0)) , "h"(bidB), "h"(tidB)); + +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120::BLOCKSCALED::SM120_16x8x32_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.SF 16x8x32 TN E4M3 x E2M1 with SF UE8M0 +template +struct SM120_16x8x32_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3, + uint8_t const& sfa0, + uint8_t const& sfb0) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 32, "Scaling factor vector size has to be 32 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.block_scale.scale_vec::1X.m16n8k32.row.col.f32.e4m3.e2m1.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13}," + "{%14}," + "{%15, %16}," + "{%17}," + "{%18, %19};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(uint32_t(sfa0)) , "h"(bidA), "h"(tidA), + "r"(uint32_t(sfb0)) , "h"(bidB), "h"(tidB)); + +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120::BLOCKSCALED::SM120_16x8x32_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.SF 16x8x32 TN E4M3 x E3M2 with SF UE8M0 +template +struct SM120_16x8x32_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3, + uint8_t const& sfa0, + uint8_t const& sfb0) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 32, "Scaling factor vector size has to be 32 for MXF8F6F4 
MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.block_scale.scale_vec::1X.m16n8k32.row.col.f32.e4m3.e3m2.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13}," + "{%14}," + "{%15, %16}," + "{%17}," + "{%18, %19};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(uint32_t(sfa0)) , "h"(bidA), "h"(tidA), + "r"(uint32_t(sfb0)) , "h"(bidB), "h"(tidB)); + +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120::BLOCKSCALED::SM120_16x8x32_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.SF 16x8x32 TN E4M3 x E2M3 with SF UE8M0 +template +struct SM120_16x8x32_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3, + uint8_t const& sfa0, + uint8_t const& sfb0) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 32, "Scaling factor vector size has to be 32 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.block_scale.scale_vec::1X.m16n8k32.row.col.f32.e4m3.e2m3.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13}," + "{%14}," + "{%15, %16}," + "{%17}," + "{%18, %19};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(uint32_t(sfa0)) , "h"(bidA), "h"(tidA), + "r"(uint32_t(sfb0)) , "h"(bidB), "h"(tidB)); + +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120::BLOCKSCALED::SM120_16x8x32_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.SF 16x8x32 TN E4M3 x E4M3 with SF UE8M0 +template +struct SM120_16x8x32_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3, + uint8_t const& sfa0, + uint8_t const& sfb0) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 32, "Scaling factor vector size has to be 32 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.block_scale.scale_vec::1X.m16n8k32.row.col.f32.e4m3.e4m3.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, 
%13}," + "{%14}," + "{%15, %16}," + "{%17}," + "{%18, %19};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(uint32_t(sfa0)) , "h"(bidA), "h"(tidA), + "r"(uint32_t(sfb0)) , "h"(bidB), "h"(tidB)); + +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120::BLOCKSCALED::SM120_16x8x32_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.SF 16x8x32 TN E4M3 x E5M2 with SF UE8M0 +template +struct SM120_16x8x32_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3, + uint8_t const& sfa0, + uint8_t const& sfb0) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 32, "Scaling factor vector size has to be 32 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.block_scale.scale_vec::1X.m16n8k32.row.col.f32.e4m3.e5m2.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13}," + "{%14}," + "{%15, %16}," + "{%17}," + "{%18, %19};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(uint32_t(sfa0)) , "h"(bidA), "h"(tidA), + "r"(uint32_t(sfb0)) , "h"(bidB), "h"(tidB)); + +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120::BLOCKSCALED::SM120_16x8x32_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.SF 16x8x32 TN E5M2 x E2M1 with SF UE8M0 +template +struct SM120_16x8x32_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3, + uint8_t const& sfa0, + uint8_t const& sfb0) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 32, "Scaling factor vector size has to be 32 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.block_scale.scale_vec::1X.m16n8k32.row.col.f32.e5m2.e2m1.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13}," + "{%14}," + "{%15, %16}," + "{%17}," + "{%18, %19};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + 
"r"(uint32_t(sfa0)) , "h"(bidA), "h"(tidA), + "r"(uint32_t(sfb0)) , "h"(bidB), "h"(tidB)); + +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120::BLOCKSCALED::SM120_16x8x32_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.SF 16x8x32 TN E5M2 x E3M2 with SF UE8M0 +template +struct SM120_16x8x32_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3, + uint8_t const& sfa0, + uint8_t const& sfb0) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 32, "Scaling factor vector size has to be 32 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.block_scale.scale_vec::1X.m16n8k32.row.col.f32.e5m2.e3m2.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13}," + "{%14}," + "{%15, %16}," + "{%17}," + "{%18, %19};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(uint32_t(sfa0)) , "h"(bidA), "h"(tidA), + "r"(uint32_t(sfb0)) , "h"(bidB), "h"(tidB)); + +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120::BLOCKSCALED::SM120_16x8x32_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.SF 16x8x32 TN E5M2 x E2M3 with SF UE8M0 +template +struct SM120_16x8x32_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3, + uint8_t const& sfa0, + uint8_t const& sfb0) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 32, "Scaling factor vector size has to be 32 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.block_scale.scale_vec::1X.m16n8k32.row.col.f32.e5m2.e2m3.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13}," + "{%14}," + "{%15, %16}," + "{%17}," + "{%18, %19};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(uint32_t(sfa0)) , "h"(bidA), "h"(tidA), + "r"(uint32_t(sfb0)) , "h"(bidB), "h"(tidB)); + +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120::BLOCKSCALED::SM120_16x8x32_TN_VS without 
CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.SF 16x8x32 TN E5M2 x E4M3 with SF UE8M0 +template +struct SM120_16x8x32_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3, + uint8_t const& sfa0, + uint8_t const& sfb0) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 32, "Scaling factor vector size has to be 32 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.block_scale.scale_vec::1X.m16n8k32.row.col.f32.e5m2.e4m3.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13}," + "{%14}," + "{%15, %16}," + "{%17}," + "{%18, %19};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(uint32_t(sfa0)) , "h"(bidA), "h"(tidA), + "r"(uint32_t(sfb0)) , "h"(bidB), "h"(tidB)); + +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120::BLOCKSCALED::SM120_16x8x32_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.SF 16x8x32 TN E5M2 x E5M2 with SF UE8M0 +template +struct SM120_16x8x32_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3, + uint8_t const& sfa0, + uint8_t const& sfb0) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 32, "Scaling factor vector size has to be 32 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.block_scale.scale_vec::1X.m16n8k32.row.col.f32.e5m2.e5m2.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13}," + "{%14}," + "{%15, %16}," + "{%17}," + "{%18, %19};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(uint32_t(sfa0)) , "h"(bidA), "h"(tidA), + "r"(uint32_t(sfb0)) , "h"(bidB), "h"(tidB)); + +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120::BLOCKSCALED::SM120_16x8x32_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA.SF 16x8x64 TN E2M1 x E2M1 with SF UE8M0 
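+// Unlike the 16x8x32 MXF8F6F4 atoms above, the 16x8x64 atoms below take the MXF4NVF4 path:
+// A and B are packed E2M1 (FP4) values, and the scale-factor register width tracks the
+// vector size VS (the SFBits logic below picks 16-bit registers for VS == 32 and 32-bit
+// registers for VS == 16).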
+template +struct SM120_16x8x64_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + static constexpr int SFBits = (VS == 32) ? 16 : 32; + using RegTypeSF = uint_bit_t; + + using SFARegisters = RegTypeSF[1]; + using SFBRegisters = RegTypeSF[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3, + RegTypeSF const& sfa0, + RegTypeSF const& sfb0) + { + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + CUTE_STATIC_ASSERT(VS == 16 || VS == 32, "Scaling factor vector size has to be 16 or 32 for MXF4NVF4 MMA."); + +#if defined(CUTE_ARCH_MXF4NVF4_2X_UE8M0_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::mxf4nvf4.block_scale.scale_vec::2X.m16n8k64.row.col.f32.e2m1.e2m1.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13}," + "{%14}," + "{%15, %16}," + "{%17}," + "{%18, %19};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(uint32_t(sfa0)), "h"(bidA), "h"(tidA), + "r"(uint32_t(sfb0)), "h"(bidB), "h"(tidB)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120::BLOCKSCALED::SM120_16x8x64_TN_VS without CUTE_ARCH_MXF4NVF4_2X_UE8M0_MMA_ENABLED"); +#endif + } +}; + + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +// MMA.SF 16x8x64 TN E2M1 x E2M1 with SF E4M3 +template +struct SM120_16x8x64_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + static constexpr int SFBits = (VS == 32) ? 
16 : 32; + using RegTypeSF = uint_bit_t; + + using SFARegisters = RegTypeSF[1]; + using SFBRegisters = RegTypeSF[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3, + RegTypeSF const& sfa0, + RegTypeSF const& sfb0) + { + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 16 || VS == 32, "Scaling factor vector size has to be 16 or 32 for MXF4NVF4 MMA."); +#if defined(CUTE_ARCH_MXF4NVF4_4X_UE4M3_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::mxf4nvf4.block_scale.scale_vec::4X.m16n8k64.row.col.f32.e2m1.e2m1.f32.ue4m3 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13}," + "{%14}," + "{%15, %16}," + "{%17}," + "{%18, %19};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(uint32_t(sfa0)) , "h"(bidA), "h"(tidA), + "r"(uint32_t(sfb0)) , "h"(bidB), "h"(tidB)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120::BLOCKSCALED::SM120_16x8x64_TN_VS without CUTE_ARCH_MXF4NVF4_4X_UE4M3_MMA_ENABLED"); +#endif + + } +}; + + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace SM120::BLOCKSCALED + + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + class ElementA, + class ElementB, + class ElementC +> +CUTE_HOST_DEVICE constexpr +auto +rr_op_selector_sm120() +{ + return SM120_16x8x32_TN{}; +} + +template < + class ElementA, + class ElementB, + class ElementC, + class ElementSF, + int SFVecSize, + bool UseF8F6F4 +> +CUTE_HOST_DEVICE constexpr +auto +rr_blockscaled_op_selector_sm120() +{ + if constexpr (UseF8F6F4) { + return SM120::BLOCKSCALED::SM120_16x8x32_TN_VS{}; + } + else{ + return SM120::BLOCKSCALED::SM120_16x8x64_TN_VS{}; + } +} + +} // namespace cute diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/arch/mma_sm120_sparse.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/arch/mma_sm120_sparse.hpp new file mode 100644 index 0000000000000000000000000000000000000000..a695003216d2c7553ab4b62ddd12594eabc576f2 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/arch/mma_sm120_sparse.hpp @@ -0,0 +1,3444 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include +#include +#include // cute::float_e4m3_t, etc +#include +#include + +namespace cute { + +namespace SM120::SPARSE { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct SM120_SPARSE_16x8x64_TN +{ + static_assert(cutlass::detail::dependent_false, "No MMA matches SM120_SPARSE_16x8x64_TN for given data types."); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP32 ACC and inputs E2M1 x E2M1 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f32.e2m1.e2m1.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "%16, 0x0;\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP32 ACC and inputs E2M1 x E3M2 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e) + { +#if 
defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f32.e2m1.e3m2.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "%16, 0x0;\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP32 ACC and inputs E2M1 x E2M3 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f32.e2m1.e2m3.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "%16, 0x0;\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP32 ACC and inputs E2M1 x E4M3 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f32.e2m1.e4m3.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "%16, 0x0;\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP32 ACC and inputs E2M1 x E5M2 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, 
float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f32.e2m1.e5m2.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "%16, 0x0;\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP32 ACC and inputs E3M2 x E2M1 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f32.e3m2.e2m1.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "%16, 0x0;\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP32 ACC and inputs E3M2 x E3M2 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f32.e3m2.e3m2.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "%16, 0x0;\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN 
with FP32 ACC and inputs E3M2 x E2M3 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f32.e3m2.e2m3.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "%16, 0x0;\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP32 ACC and inputs E3M2 x E4M3 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f32.e3m2.e4m3.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "%16, 0x0;\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP32 ACC and inputs E3M2 x E5M2 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f32.e3m2.e5m2.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "%16, 0x0;\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), 
"f"(c1), "f"(c2), "f"(c3), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP32 ACC and inputs E2M3 x E2M1 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f32.e2m3.e2m1.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "%16, 0x0;\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP32 ACC and inputs E2M3 x E3M2 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f32.e2m3.e3m2.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "%16, 0x0;\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP32 ACC and inputs E2M3 x E2M3 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + 
"mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f32.e2m3.e2m3.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "%16, 0x0;\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP32 ACC and inputs E2M3 x E4M3 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f32.e2m3.e4m3.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "%16, 0x0;\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP32 ACC and inputs E2M3 x E5M2 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f32.e2m3.e5m2.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "%16, 0x0;\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP32 ACC and inputs E4M3 x E2M1 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& 
a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f32.e4m3.e2m1.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "%16, 0x0;\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP32 ACC and inputs E4M3 x E3M2 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f32.e4m3.e3m2.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "%16, 0x0;\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP32 ACC and inputs E4M3 x E2M3 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f32.e4m3.e2m3.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "%16, 0x0;\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP32 ACC and inputs E4M3 x E4M3 +template <> 
+struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f32.e4m3.e4m3.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "%16, 0x0;\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP32 ACC and inputs E4M3 x E5M2 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f32.e4m3.e5m2.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "%16, 0x0;\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP32 ACC and inputs E5M2 x E2M1 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f32.e5m2.e2m1.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "%16, 0x0;\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e)); +#else + 
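+    // Fallback when CUTE_ARCH_F8F6F4_MMA_ENABLED is not defined for the target architecture:
+    // calling this atom raises a diagnostic instead of issuing the sparse MMA.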
CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP32 ACC and inputs E5M2 x E3M2 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f32.e5m2.e3m2.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "%16, 0x0;\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP32 ACC and inputs E5M2 x E2M3 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f32.e5m2.e2m3.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "%16, 0x0;\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP32 ACC and inputs E5M2 x E4M3 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + 
"mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f32.e5m2.e4m3.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "%16, 0x0;\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP32 ACC and inputs E5M2 x E5M2 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f32.e5m2.e5m2.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "%16, 0x0;\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// MMA 16x8x64 TN with FP16 ACC and inputs E2M1 x E2M1 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = uint32_t[2]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + uint32_t const& c0, uint32_t const& c1, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f16.e2m1.e2m1.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7, %8, %9}," + "{%10, %11}," + "%12, 0x0;\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "r"(c0), "r"(c1), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP16 ACC and inputs E2M1 x E3M2 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = uint32_t[2]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t 
const& b2, uint32_t const& b3, + uint32_t const& c0, uint32_t const& c1, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f16.e2m1.e3m2.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7, %8, %9}," + "{%10, %11}," + "%12, 0x0;\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "r"(c0), "r"(c1), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP16 ACC and inputs E2M1 x E2M3 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = uint32_t[2]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + uint32_t const& c0, uint32_t const& c1, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f16.e2m1.e2m3.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7, %8, %9}," + "{%10, %11}," + "%12, 0x0;\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "r"(c0), "r"(c1), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP16 ACC and inputs E2M1 x E4M3 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = uint32_t[2]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + uint32_t const& c0, uint32_t const& c1, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f16.e2m1.e4m3.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7, %8, %9}," + "{%10, %11}," + "%12, 0x0;\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "r"(c0), "r"(c1), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP16 ACC and inputs E2M1 x E5M2 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = uint32_t[2]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, 
uint32_t const& b2, uint32_t const& b3, + uint32_t const& c0, uint32_t const& c1, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f16.e2m1.e5m2.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7, %8, %9}," + "{%10, %11}," + "%12, 0x0;\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "r"(c0), "r"(c1), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP16 ACC and inputs E3M2 x E2M1 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = uint32_t[2]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + uint32_t const& c0, uint32_t const& c1, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f16.e3m2.e2m1.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7, %8, %9}," + "{%10, %11}," + "%12, 0x0;\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "r"(c0), "r"(c1), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP16 ACC and inputs E3M2 x E3M2 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = uint32_t[2]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + uint32_t const& c0, uint32_t const& c1, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f16.e3m2.e3m2.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7, %8, %9}," + "{%10, %11}," + "%12, 0x0;\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "r"(c0), "r"(c1), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP16 ACC and inputs E3M2 x E2M3 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = uint32_t[2]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& 
b1, uint32_t const& b2, uint32_t const& b3, + uint32_t const& c0, uint32_t const& c1, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f16.e3m2.e2m3.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7, %8, %9}," + "{%10, %11}," + "%12, 0x0;\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "r"(c0), "r"(c1), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP16 ACC and inputs E3M2 x E4M3 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = uint32_t[2]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + uint32_t const& c0, uint32_t const& c1, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f16.e3m2.e4m3.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7, %8, %9}," + "{%10, %11}," + "%12, 0x0;\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "r"(c0), "r"(c1), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP16 ACC and inputs E3M2 x E5M2 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = uint32_t[2]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + uint32_t const& c0, uint32_t const& c1, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f16.e3m2.e5m2.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7, %8, %9}," + "{%10, %11}," + "%12, 0x0;\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "r"(c0), "r"(c1), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP16 ACC and inputs E2M3 x E2M1 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = uint32_t[2]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t 
const& b1, uint32_t const& b2, uint32_t const& b3, + uint32_t const& c0, uint32_t const& c1, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f16.e2m3.e2m1.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7, %8, %9}," + "{%10, %11}," + "%12, 0x0;\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "r"(c0), "r"(c1), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP16 ACC and inputs E2M3 x E3M2 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = uint32_t[2]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + uint32_t const& c0, uint32_t const& c1, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f16.e2m3.e3m2.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7, %8, %9}," + "{%10, %11}," + "%12, 0x0;\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "r"(c0), "r"(c1), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP16 ACC and inputs E2M3 x E2M3 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = uint32_t[2]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + uint32_t const& c0, uint32_t const& c1, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f16.e2m3.e2m3.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7, %8, %9}," + "{%10, %11}," + "%12, 0x0;\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "r"(c0), "r"(c1), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP16 ACC and inputs E2M3 x E4M3 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = uint32_t[2]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, 
uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + uint32_t const& c0, uint32_t const& c1, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f16.e2m3.e4m3.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7, %8, %9}," + "{%10, %11}," + "%12, 0x0;\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "r"(c0), "r"(c1), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP16 ACC and inputs E2M3 x E5M2 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = uint32_t[2]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + uint32_t const& c0, uint32_t const& c1, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f16.e2m3.e5m2.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7, %8, %9}," + "{%10, %11}," + "%12, 0x0;\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "r"(c0), "r"(c1), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP16 ACC and inputs E4M3 x E2M1 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = uint32_t[2]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + uint32_t const& c0, uint32_t const& c1, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f16.e4m3.e2m1.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7, %8, %9}," + "{%10, %11}," + "%12, 0x0;\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "r"(c0), "r"(c1), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP16 ACC and inputs E4M3 x E3M2 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = uint32_t[2]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& 
b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + uint32_t const& c0, uint32_t const& c1, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f16.e4m3.e3m2.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7, %8, %9}," + "{%10, %11}," + "%12, 0x0;\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "r"(c0), "r"(c1), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP16 ACC and inputs E4M3 x E2M3 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = uint32_t[2]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + uint32_t const& c0, uint32_t const& c1, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f16.e4m3.e2m3.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7, %8, %9}," + "{%10, %11}," + "%12, 0x0;\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "r"(c0), "r"(c1), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP16 ACC and inputs E4M3 x E4M3 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = uint32_t[2]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + uint32_t const& c0, uint32_t const& c1, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f16.e4m3.e4m3.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7, %8, %9}," + "{%10, %11}," + "%12, 0x0;\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "r"(c0), "r"(c1), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP16 ACC and inputs E4M3 x E5M2 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = uint32_t[2]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t 
const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + uint32_t const& c0, uint32_t const& c1, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f16.e4m3.e5m2.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7, %8, %9}," + "{%10, %11}," + "%12, 0x0;\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "r"(c0), "r"(c1), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP16 ACC and inputs E5M2 x E2M1 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = uint32_t[2]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + uint32_t const& c0, uint32_t const& c1, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f16.e5m2.e2m1.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7, %8, %9}," + "{%10, %11}," + "%12, 0x0;\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "r"(c0), "r"(c1), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP16 ACC and inputs E5M2 x E3M2 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = uint32_t[2]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + uint32_t const& c0, uint32_t const& c1, + uint32_t const& e) + { +#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f16.e5m2.e3m2.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7, %8, %9}," + "{%10, %11}," + "%12, 0x0;\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "r"(c0), "r"(c1), + "r"(e)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN with FP16 ACC and inputs E5M2 x E2M3 +template <> +struct SM120_SPARSE_16x8x64_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = uint32_t[2]; + using ERegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + 
uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3,
+      uint32_t const& c0, uint32_t const& c1,
+      uint32_t const& e)
+  {
+#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED)
+    asm volatile(
+    "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f16.e5m2.e2m3.f16 "
+      "{%0, %1},"
+      "{%2, %3, %4, %5},"
+      "{%6, %7, %8, %9},"
+      "{%10, %11},"
+      "%12, 0x0;\n"
+      : "=r"(d0), "=r"(d1)
+      : "r"(a0), "r"(a1), "r"(a2), "r"(a3),
+        "r"(b0), "r"(b1), "r"(b2), "r"(b3),
+        "r"(c0), "r"(c1),
+        "r"(e));
+#else
+    CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED");
+#endif
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// MMA 16x8x64 TN with FP16 ACC and inputs E5M2 x E4M3
+template <>
+struct SM120_SPARSE_16x8x64_TN<float_e5m2_t, float_e4m3_t, half_t>
+{
+  using DRegisters = uint32_t[2];
+  using ARegisters = uint32_t[4];
+  using BRegisters = uint32_t[4];
+  using CRegisters = uint32_t[2];
+  using ERegisters = uint32_t[1];
+
+  CUTE_HOST_DEVICE static void
+  fma(uint32_t & d0, uint32_t & d1,
+      uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3,
+      uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3,
+      uint32_t const& c0, uint32_t const& c1,
+      uint32_t const& e)
+  {
+#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED)
+    asm volatile(
+    "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f16.e5m2.e4m3.f16 "
+      "{%0, %1},"
+      "{%2, %3, %4, %5},"
+      "{%6, %7, %8, %9},"
+      "{%10, %11},"
+      "%12, 0x0;\n"
+      : "=r"(d0), "=r"(d1)
+      : "r"(a0), "r"(a1), "r"(a2), "r"(a3),
+        "r"(b0), "r"(b1), "r"(b2), "r"(b3),
+        "r"(c0), "r"(c1),
+        "r"(e));
+#else
+    CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED");
+#endif
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// MMA 16x8x64 TN with FP16 ACC and inputs E5M2 x E5M2
+template <>
+struct SM120_SPARSE_16x8x64_TN<float_e5m2_t, float_e5m2_t, half_t>
+{
+  using DRegisters = uint32_t[2];
+  using ARegisters = uint32_t[4];
+  using BRegisters = uint32_t[4];
+  using CRegisters = uint32_t[2];
+  using ERegisters = uint32_t[1];
+
+  CUTE_HOST_DEVICE static void
+  fma(uint32_t & d0, uint32_t & d1,
+      uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3,
+      uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3,
+      uint32_t const& c0, uint32_t const& c1,
+      uint32_t const& e)
+  {
+#if defined(CUTE_ARCH_F8F6F4_MMA_ENABLED)
+    asm volatile(
+    "mma.sync.aligned.kind::f8f6f4.sp::ordered_metadata.m16n8k64.row.col.f16.e5m2.e5m2.f16 "
+      "{%0, %1},"
+      "{%2, %3, %4, %5},"
+      "{%6, %7, %8, %9},"
+      "{%10, %11},"
+      "%12, 0x0;\n"
+      : "=r"(d0), "=r"(d1)
+      : "r"(a0), "r"(a1), "r"(a2), "r"(a3),
+        "r"(b0), "r"(b1), "r"(b2), "r"(b3),
+        "r"(c0), "r"(c1),
+        "r"(e));
+#else
+    CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN without CUTE_ARCH_F8F6F4_MMA_ENABLED");
+#endif
+  }
+};
+
+} // end namespace SM120::SPARSE
+
+namespace SM120::BLOCKSCALED::SPARSE {
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <class a_type, class b_type, class c_type, class sf_type, int VS>
+struct SM120_SPARSE_16x8x64_TN_VS
+{
+  static_assert(cutlass::detail::dependent_false<a_type>, "No MMA matches SM120_SPARSE_16x8x64_TN_VS for given data types.");
+};
+
+template <class a_type, class b_type, class c_type, class sf_type, int VS>
+struct SM120_SPARSE_16x8x128_TN_VS
+{
+  static_assert(cutlass::detail::dependent_false<a_type>, "No MMA matches SM120_SPARSE_16x8x128_TN_VS for given data types.");
+};
+
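// ----------------------------------------------------------------------------------------------------
// Illustrative usage sketch (editorial addition, not part of the CUTLASS header added by this diff).
// Each sparse MMA specialization above exposes a static fma() that one thread of the warp calls with
// its raw register fragments: A and B values packed into uint32_t lanes, the per-thread accumulator
// fragment, and the 2:4 sparsity metadata word `e`. The sketch assumes the FP32-accumulating
// E4M3 x E4M3 specialization is spelled SM120_SPARSE_16x8x64_TN<float_e4m3_t, float_e4m3_t, float>;
// that spelling and the helper name sparse_mma_fragment_example are assumptions made for illustration.
inline __device__ void
sparse_mma_fragment_example(uint32_t const (&a)[4],   // 16 packed E4M3 values of the 2:4-sparse A fragment
                            uint32_t const (&b)[4],   // 16 packed E4M3 values of the B fragment
                            uint32_t        e,        // ordered sparsity metadata owned by this thread
                            float          (&acc)[4]) // per-thread C/D accumulator fragment
{
  using MmaOp = cute::SM120::SPARSE::SM120_SPARSE_16x8x64_TN<cute::float_e4m3_t,
                                                             cute::float_e4m3_t,
                                                             float>;
  // D = A * B + C, accumulated in place: the same registers serve as the C inputs and the D outputs.
  MmaOp::fma(acc[0], acc[1], acc[2], acc[3],
             a[0], a[1], a[2], a[3],
             b[0], b[1], b[2], b[3],
             acc[0], acc[1], acc[2], acc[3],
             e);
}
// ----------------------------------------------------------------------------------------------------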
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA SPARSE BLOCKSCALED 16x8x64 TN with FP32 ACC and inputs E2M1 x E2M1, SF UE8M0 +template +struct SM120_SPARSE_16x8x64_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e, uint8_t const& sfa, uint8_t const& sfb) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 64, "Scaling factor vector size has to be 64 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.sp::ordered_metadata.block_scale.scale_vec::1X.m16n8k64.row.col.f32.e2m1.e2m1.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "{%16}, 0x0," + "{%17}," + "{%18, %19}," + "{%20}," + "{%21, %22};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e), + "r"((uint32_t)sfa), "h"(bidA), "h"(tidA), + "r"((uint32_t)sfb), "h"(bidB), "h"(tidB)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA SPARSE BLOCKSCALED 16x8x64 TN with FP32 ACC and inputs E2M1 x E3M2, SF UE8M0 +template +struct SM120_SPARSE_16x8x64_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e, uint8_t const& sfa, uint8_t const& sfb) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 64, "Scaling factor vector size has to be 64 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.sp::ordered_metadata.block_scale.scale_vec::1X.m16n8k64.row.col.f32.e2m1.e3m2.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "{%16}, 0x0," + "{%17}," + "{%18, %19}," + "{%20}," + "{%21, %22};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e), + "r"((uint32_t)sfa), "h"(bidA), "h"(tidA), + "r"((uint32_t)sfb), 
"h"(bidB), "h"(tidB)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA SPARSE BLOCKSCALED 16x8x64 TN with FP32 ACC and inputs E2M1 x E2M3, SF UE8M0 +template +struct SM120_SPARSE_16x8x64_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e, uint8_t const& sfa, uint8_t const& sfb) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 64, "Scaling factor vector size has to be 64 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.sp::ordered_metadata.block_scale.scale_vec::1X.m16n8k64.row.col.f32.e2m1.e2m3.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "{%16}, 0x0," + "{%17}," + "{%18, %19}," + "{%20}," + "{%21, %22};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e), + "r"((uint32_t)sfa), "h"(bidA), "h"(tidA), + "r"((uint32_t)sfb), "h"(bidB), "h"(tidB)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA SPARSE BLOCKSCALED 16x8x64 TN with FP32 ACC and inputs E2M1 x E4M3, SF UE8M0 +template +struct SM120_SPARSE_16x8x64_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e, uint8_t const& sfa, uint8_t const& sfb) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 64, "Scaling factor vector size has to be 64 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.sp::ordered_metadata.block_scale.scale_vec::1X.m16n8k64.row.col.f32.e2m1.e4m3.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "{%16}, 0x0," + "{%17}," + "{%18, %19}," + "{%20}," + "{%21, %22};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), 
"r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e), + "r"((uint32_t)sfa), "h"(bidA), "h"(tidA), + "r"((uint32_t)sfb), "h"(bidB), "h"(tidB)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA SPARSE BLOCKSCALED 16x8x64 TN with FP32 ACC and inputs E2M1 x E5M2, SF UE8M0 +template +struct SM120_SPARSE_16x8x64_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e, uint8_t const& sfa, uint8_t const& sfb) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 64, "Scaling factor vector size has to be 64 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.sp::ordered_metadata.block_scale.scale_vec::1X.m16n8k64.row.col.f32.e2m1.e5m2.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "{%16}, 0x0," + "{%17}," + "{%18, %19}," + "{%20}," + "{%21, %22};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e), + "r"((uint32_t)sfa), "h"(bidA), "h"(tidA), + "r"((uint32_t)sfb), "h"(bidB), "h"(tidB)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA SPARSE BLOCKSCALED 16x8x64 TN with FP32 ACC and inputs E3M2 x E2M1, SF UE8M0 +template +struct SM120_SPARSE_16x8x64_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e, uint8_t const& sfa, uint8_t const& sfb) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 64, "Scaling factor vector size has to be 64 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.sp::ordered_metadata.block_scale.scale_vec::1X.m16n8k64.row.col.f32.e3m2.e2m1.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, 
%13, %14, %15}," + "{%16}, 0x0," + "{%17}," + "{%18, %19}," + "{%20}," + "{%21, %22};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e), + "r"((uint32_t)sfa), "h"(bidA), "h"(tidA), + "r"((uint32_t)sfb), "h"(bidB), "h"(tidB)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA SPARSE BLOCKSCALED 16x8x64 TN with FP32 ACC and inputs E3M2 x E3M2, SF UE8M0 +template +struct SM120_SPARSE_16x8x64_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e, uint8_t const& sfa, uint8_t const& sfb) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 64, "Scaling factor vector size has to be 64 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.sp::ordered_metadata.block_scale.scale_vec::1X.m16n8k64.row.col.f32.e3m2.e3m2.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "{%16}, 0x0," + "{%17}," + "{%18, %19}," + "{%20}," + "{%21, %22};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e), + "r"((uint32_t)sfa), "h"(bidA), "h"(tidA), + "r"((uint32_t)sfb), "h"(bidB), "h"(tidB)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA SPARSE BLOCKSCALED 16x8x64 TN with FP32 ACC and inputs E3M2 x E2M3, SF UE8M0 +template +struct SM120_SPARSE_16x8x64_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e, uint8_t const& sfa, uint8_t const& sfb) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 64, "Scaling factor vector size has to be 64 for MXF8F6F4 MMA."); + + asm volatile( + 
"mma.sync.aligned.kind::mxf8f6f4.sp::ordered_metadata.block_scale.scale_vec::1X.m16n8k64.row.col.f32.e3m2.e2m3.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "{%16}, 0x0," + "{%17}," + "{%18, %19}," + "{%20}," + "{%21, %22};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e), + "r"((uint32_t)sfa), "h"(bidA), "h"(tidA), + "r"((uint32_t)sfb), "h"(bidB), "h"(tidB)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA SPARSE BLOCKSCALED 16x8x64 TN with FP32 ACC and inputs E3M2 x E4M3, SF UE8M0 +template +struct SM120_SPARSE_16x8x64_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e, uint8_t const& sfa, uint8_t const& sfb) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 64, "Scaling factor vector size has to be 64 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.sp::ordered_metadata.block_scale.scale_vec::1X.m16n8k64.row.col.f32.e3m2.e4m3.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "{%16}, 0x0," + "{%17}," + "{%18, %19}," + "{%20}," + "{%21, %22};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e), + "r"((uint32_t)sfa), "h"(bidA), "h"(tidA), + "r"((uint32_t)sfb), "h"(bidB), "h"(tidB)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA SPARSE BLOCKSCALED 16x8x64 TN with FP32 ACC and inputs E3M2 x E5M2, SF UE8M0 +template +struct SM120_SPARSE_16x8x64_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e, uint8_t const& sfa, uint8_t const& sfb) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t 
tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 64, "Scaling factor vector size has to be 64 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.sp::ordered_metadata.block_scale.scale_vec::1X.m16n8k64.row.col.f32.e3m2.e5m2.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "{%16}, 0x0," + "{%17}," + "{%18, %19}," + "{%20}," + "{%21, %22};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e), + "r"((uint32_t)sfa), "h"(bidA), "h"(tidA), + "r"((uint32_t)sfb), "h"(bidB), "h"(tidB)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA SPARSE BLOCKSCALED 16x8x64 TN with FP32 ACC and inputs E2M3 x E2M1, SF UE8M0 +template +struct SM120_SPARSE_16x8x64_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e, uint8_t const& sfa, uint8_t const& sfb) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 64, "Scaling factor vector size has to be 64 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.sp::ordered_metadata.block_scale.scale_vec::1X.m16n8k64.row.col.f32.e2m3.e2m1.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "{%16}, 0x0," + "{%17}," + "{%18, %19}," + "{%20}," + "{%21, %22};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e), + "r"((uint32_t)sfa), "h"(bidA), "h"(tidA), + "r"((uint32_t)sfb), "h"(bidB), "h"(tidB)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA SPARSE BLOCKSCALED 16x8x64 TN with FP32 ACC and inputs E2M3 x E3M2, SF UE8M0 +template +struct SM120_SPARSE_16x8x64_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e, uint8_t const& sfa, uint8_t 
const& sfb) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 64, "Scaling factor vector size has to be 64 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.sp::ordered_metadata.block_scale.scale_vec::1X.m16n8k64.row.col.f32.e2m3.e3m2.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "{%16}, 0x0," + "{%17}," + "{%18, %19}," + "{%20}," + "{%21, %22};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e), + "r"((uint32_t)sfa), "h"(bidA), "h"(tidA), + "r"((uint32_t)sfb), "h"(bidB), "h"(tidB)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA SPARSE BLOCKSCALED 16x8x64 TN with FP32 ACC and inputs E2M3 x E2M3, SF UE8M0 +template +struct SM120_SPARSE_16x8x64_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e, uint8_t const& sfa, uint8_t const& sfb) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 64, "Scaling factor vector size has to be 64 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.sp::ordered_metadata.block_scale.scale_vec::1X.m16n8k64.row.col.f32.e2m3.e2m3.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "{%16}, 0x0," + "{%17}," + "{%18, %19}," + "{%20}," + "{%21, %22};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e), + "r"((uint32_t)sfa), "h"(bidA), "h"(tidA), + "r"((uint32_t)sfb), "h"(bidB), "h"(tidB)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA SPARSE BLOCKSCALED 16x8x64 TN with FP32 ACC and inputs E2M3 x E4M3, SF UE8M0 +template +struct SM120_SPARSE_16x8x64_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, 
uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e, uint8_t const& sfa, uint8_t const& sfb) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 64, "Scaling factor vector size has to be 64 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.sp::ordered_metadata.block_scale.scale_vec::1X.m16n8k64.row.col.f32.e2m3.e4m3.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "{%16}, 0x0," + "{%17}," + "{%18, %19}," + "{%20}," + "{%21, %22};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e), + "r"((uint32_t)sfa), "h"(bidA), "h"(tidA), + "r"((uint32_t)sfb), "h"(bidB), "h"(tidB)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA SPARSE BLOCKSCALED 16x8x64 TN with FP32 ACC and inputs E2M3 x E5M2, SF UE8M0 +template +struct SM120_SPARSE_16x8x64_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e, uint8_t const& sfa, uint8_t const& sfb) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 64, "Scaling factor vector size has to be 64 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.sp::ordered_metadata.block_scale.scale_vec::1X.m16n8k64.row.col.f32.e2m3.e5m2.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "{%16}, 0x0," + "{%17}," + "{%18, %19}," + "{%20}," + "{%21, %22};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e), + "r"((uint32_t)sfa), "h"(bidA), "h"(tidA), + "r"((uint32_t)sfb), "h"(bidB), "h"(tidB)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA SPARSE BLOCKSCALED 16x8x64 TN with FP32 ACC and inputs E4M3 x E2M1, SF UE8M0 +template +struct SM120_SPARSE_16x8x64_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, 
float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e, uint8_t const& sfa, uint8_t const& sfb) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 64, "Scaling factor vector size has to be 64 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.sp::ordered_metadata.block_scale.scale_vec::1X.m16n8k64.row.col.f32.e4m3.e2m1.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "{%16}, 0x0," + "{%17}," + "{%18, %19}," + "{%20}," + "{%21, %22};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e), + "r"((uint32_t)sfa), "h"(bidA), "h"(tidA), + "r"((uint32_t)sfb), "h"(bidB), "h"(tidB)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA SPARSE BLOCKSCALED 16x8x64 TN with FP32 ACC and inputs E4M3 x E3M2, SF UE8M0 +template +struct SM120_SPARSE_16x8x64_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e, uint8_t const& sfa, uint8_t const& sfb) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 64, "Scaling factor vector size has to be 64 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.sp::ordered_metadata.block_scale.scale_vec::1X.m16n8k64.row.col.f32.e4m3.e3m2.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "{%16}, 0x0," + "{%17}," + "{%18, %19}," + "{%20}," + "{%21, %22};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e), + "r"((uint32_t)sfa), "h"(bidA), "h"(tidA), + "r"((uint32_t)sfb), "h"(bidB), "h"(tidB)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA SPARSE BLOCKSCALED 16x8x64 TN with FP32 ACC and inputs E4M3 x E2M3, SF UE8M0 +template +struct SM120_SPARSE_16x8x64_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = 
float[4]; + using ERegisters = uint32_t[1]; + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e, uint8_t const& sfa, uint8_t const& sfb) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 64, "Scaling factor vector size has to be 64 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.sp::ordered_metadata.block_scale.scale_vec::1X.m16n8k64.row.col.f32.e4m3.e2m3.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "{%16}, 0x0," + "{%17}," + "{%18, %19}," + "{%20}," + "{%21, %22};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e), + "r"((uint32_t)sfa), "h"(bidA), "h"(tidA), + "r"((uint32_t)sfb), "h"(bidB), "h"(tidB)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA SPARSE BLOCKSCALED 16x8x64 TN with FP32 ACC and inputs E4M3 x E4M3, SF UE8M0 +template +struct SM120_SPARSE_16x8x64_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e, uint8_t const& sfa, uint8_t const& sfb) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 64, "Scaling factor vector size has to be 64 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.sp::ordered_metadata.block_scale.scale_vec::1X.m16n8k64.row.col.f32.e4m3.e4m3.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "{%16}, 0x0," + "{%17}," + "{%18, %19}," + "{%20}," + "{%21, %22};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e), + "r"((uint32_t)sfa), "h"(bidA), "h"(tidA), + "r"((uint32_t)sfb), "h"(bidB), "h"(tidB)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA SPARSE BLOCKSCALED 16x8x64 TN with FP32 ACC and inputs E4M3 x E5M2, SF UE8M0 
+template +struct SM120_SPARSE_16x8x64_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e, uint8_t const& sfa, uint8_t const& sfb) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 64, "Scaling factor vector size has to be 64 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.sp::ordered_metadata.block_scale.scale_vec::1X.m16n8k64.row.col.f32.e4m3.e5m2.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "{%16}, 0x0," + "{%17}," + "{%18, %19}," + "{%20}," + "{%21, %22};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e), + "r"((uint32_t)sfa), "h"(bidA), "h"(tidA), + "r"((uint32_t)sfb), "h"(bidB), "h"(tidB)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA SPARSE BLOCKSCALED 16x8x64 TN with FP32 ACC and inputs E5M2 x E2M1, SF UE8M0 +template +struct SM120_SPARSE_16x8x64_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e, uint8_t const& sfa, uint8_t const& sfb) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 64, "Scaling factor vector size has to be 64 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.sp::ordered_metadata.block_scale.scale_vec::1X.m16n8k64.row.col.f32.e5m2.e2m1.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "{%16}, 0x0," + "{%17}," + "{%18, %19}," + "{%20}," + "{%21, %22};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e), + "r"((uint32_t)sfa), "h"(bidA), "h"(tidA), + "r"((uint32_t)sfb), "h"(bidB), "h"(tidB)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + 
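// -----------------------------------------------------------------------------------------------
// Editorial usage sketch, not part of the vendored header or this patch. All of the
// SM120_SPARSE_16x8x64_TN_VS atoms above expose the same fma() calling convention that is
// visible in their signatures: four f32 accumulators in/out, four packed 32-bit A and B
// fragments, one 32-bit sparsity-metadata word, and one 8-bit block scale factor per operand.
// `MmaOp` below is a stand-in for any one of those structs; the concrete template arguments
// (operand types and scale-factor vector size) are elided here because they are atom-specific,
// and fragment loading/packing is omitted entirely.
#include <cstdint>   // assumed available for uint32_t / uint8_t in this sketch

template <class MmaOp>
__device__ void sparse_blockscaled_mma_example(
    float          (&d)[4],     // DRegisters: 4 x f32 accumulators (output)
    uint32_t const (&a)[4],     // ARegisters: packed, compressed A fragment
    uint32_t const (&b)[4],     // BRegisters: packed B fragment
    float    const (&c)[4],     // CRegisters: 4 x f32 accumulators (input)
    uint32_t e,                 // ERegisters: sparsity metadata for A
    uint8_t sfa, uint8_t sfb)   // per-fragment block scale factors (UE8M0 in the atoms above)
{
  // One warp-wide sparse block-scaled MMA step; register layouts follow the PTX
  // mma.sync...sp::ordered_metadata.block_scale instruction used by the selected atom.
  MmaOp::fma(d[0], d[1], d[2], d[3],
             a[0], a[1], a[2], a[3],
             b[0], b[1], b[2], b[3],
             c[0], c[1], c[2], c[3],
             e, sfa, sfb);
}
// -----------------------------------------------------------------------------------------------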
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA SPARSE BLOCKSCALED 16x8x64 TN with FP32 ACC and inputs E5M2 x E3M2, SF UE8M0 +template +struct SM120_SPARSE_16x8x64_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e, uint8_t const& sfa, uint8_t const& sfb) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 64, "Scaling factor vector size has to be 64 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.sp::ordered_metadata.block_scale.scale_vec::1X.m16n8k64.row.col.f32.e5m2.e3m2.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "{%16}, 0x0," + "{%17}," + "{%18, %19}," + "{%20}," + "{%21, %22};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e), + "r"((uint32_t)sfa), "h"(bidA), "h"(tidA), + "r"((uint32_t)sfb), "h"(bidB), "h"(tidB)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA SPARSE BLOCKSCALED 16x8x64 TN with FP32 ACC and inputs E5M2 x E2M3, SF UE8M0 +template +struct SM120_SPARSE_16x8x64_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e, uint8_t const& sfa, uint8_t const& sfb) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 64, "Scaling factor vector size has to be 64 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.sp::ordered_metadata.block_scale.scale_vec::1X.m16n8k64.row.col.f32.e5m2.e2m3.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "{%16}, 0x0," + "{%17}," + "{%18, %19}," + "{%20}," + "{%21, %22};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e), + "r"((uint32_t)sfa), "h"(bidA), "h"(tidA), + "r"((uint32_t)sfb), 
"h"(bidB), "h"(tidB)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA SPARSE BLOCKSCALED 16x8x64 TN with FP32 ACC and inputs E5M2 x E4M3, SF UE8M0 +template +struct SM120_SPARSE_16x8x64_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e, uint8_t const& sfa, uint8_t const& sfb) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 64, "Scaling factor vector size has to be 64 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.sp::ordered_metadata.block_scale.scale_vec::1X.m16n8k64.row.col.f32.e5m2.e4m3.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "{%16}, 0x0," + "{%17}," + "{%18, %19}," + "{%20}," + "{%21, %22};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e), + "r"((uint32_t)sfa), "h"(bidA), "h"(tidA), + "r"((uint32_t)sfb), "h"(bidB), "h"(tidB)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA SPARSE BLOCKSCALED 16x8x64 TN with FP32 ACC and inputs E5M2 x E5M2, SF UE8M0 +template +struct SM120_SPARSE_16x8x64_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + using SFARegisters = uint8_t[1]; + using SFBRegisters = uint8_t[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e, uint8_t const& sfa, uint8_t const& sfb) + { +#if defined(CUTE_ARCH_MXF8F6F4_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 64, "Scaling factor vector size has to be 64 for MXF8F6F4 MMA."); + + asm volatile( + "mma.sync.aligned.kind::mxf8f6f4.sp::ordered_metadata.block_scale.scale_vec::1X.m16n8k64.row.col.f32.e5m2.e5m2.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "{%16}, 0x0," + "{%17}," + "{%18, %19}," + "{%20}," + "{%21, %22};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), 
"r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e), + "r"((uint32_t)sfa), "h"(bidA), "h"(tidA), + "r"((uint32_t)sfb), "h"(bidB), "h"(tidB)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120_SPARSE_16x8x64_TN_VS without CUTE_ARCH_MXF8F6F4_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA SPARSE BLOCKSCALED 16x8x128 TN E2M1 x E2M1 with SF UE8M0 +template +struct SM120_SPARSE_16x8x128_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + + static constexpr int SFBits = (VS == 64) ? 16 : 32; + using RegTypeSF = uint_bit_t; + using SFARegisters = RegTypeSF[1]; + using SFBRegisters = RegTypeSF[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e, RegTypeSF const& sfa, RegTypeSF const& sfb) + { + + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 64 || VS == 32, "Scaling factor vector size has to be 64 or 32 for MXF4NVF4."); +#if defined(CUTE_ARCH_MXF4NVF4_2X_UE8M0_MMA_ENABLED) + asm volatile( + "mma.sync.aligned.kind::mxf4nvf4.sp::ordered_metadata.block_scale.scale_vec::2X.m16n8k128.row.col.f32.e2m1.e2m1.f32.ue8m0 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "{%16}, 0x0," + "{%17}, {%18, %19}," + "{%20}, {%21, %22};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e), + "r"(uint32_t(sfa)), "h"(bidA), "h"(tidA), + "r"(uint32_t(sfb)), "h"(bidB), "h"(tidB)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120::SPARSE::SM120_SPARSE_16x8x128_TN_VS without CUTE_ARCH_MXF4NVF4_2X_UE8M0_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA SPARSE BLOCKSCALED 16x8x128 TN E2M1 x E2M1 with SF E4M3 +template +struct SM120_SPARSE_16x8x128_TN_VS +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[4]; + using CRegisters = float[4]; + using ERegisters = uint32_t[1]; + + static constexpr int SFBits = (VS == 64) ? 
16 : 32; + using RegTypeSF = uint_bit_t; + using SFARegisters = RegTypeSF[1]; + using SFBRegisters = RegTypeSF[1]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, uint32_t const& b2, uint32_t const& b3, + float const& c0, float const& c1, float const& c2, float const& c3, + uint32_t const& e, RegTypeSF const& sfa, RegTypeSF const& sfb) + { +#if defined(CUTE_ARCH_MXF4NVF4_4X_UE4M3_MMA_ENABLED) + static constexpr uint16_t tidA = 0; + static constexpr uint16_t bidA = 0; + static constexpr uint16_t tidB = 0; + static constexpr uint16_t bidB = 0; + + CUTE_STATIC_ASSERT(VS == 32, "Scaling factor vector size has to be 32 for NVF4 with e2m1 and scale factor e4m3."); + asm volatile( + "mma.sync.aligned.kind::mxf4nvf4.sp::ordered_metadata.block_scale.scale_vec::4X.m16n8k128.row.col.f32.e2m1.e2m1.f32.ue4m3 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "{%16}, 0x0," + "{%17}, {%18, %19}," + "{%20}, {%21, %22};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), "r"(b2), "r"(b3), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "r"(e), + "r"(uint32_t(sfa)), "h"(bidA), "h"(tidA), + "r"(uint32_t(sfb)), "h"(bidB), "h"(tidB)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM120::SPARSE::SM120_SPARSE_16x8x128_TN_VS without CUTE_ARCH_MXF4NVF4_4X_UE4M3_MMA_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // end namespace SM120::BLOCKSCALED::SPARSE + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + class ElementA, + class ElementB, + class ElementC +> +CUTE_HOST_DEVICE constexpr +auto +rr_sparse_op_selector_sm120() +{ + // Get MMA SPARSE OP + return SM120::SPARSE::SM120_SPARSE_16x8x64_TN{}; +} + +template < + class ElementA, + class ElementB, + class ElementC, + class ElementSF, + int SFVecSize, + bool UseF8F6F4 +> +CUTE_HOST_DEVICE constexpr +auto +rr_blockscaled_sparse_op_selector_sm120() +{ + if constexpr (UseF8F6F4) { + return SM120::BLOCKSCALED::SPARSE::SM120_SPARSE_16x8x64_TN_VS{}; + } + else { + return SM120::BLOCKSCALED::SPARSE::SM120_SPARSE_16x8x128_TN_VS{}; + } +} + +} // namespace cute diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/arch/mma_sm61.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/arch/mma_sm61.hpp new file mode 100644 index 0000000000000000000000000000000000000000..fce9a47f02a982a1c02c438fc582aaec3e99b157 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/arch/mma_sm61.hpp @@ -0,0 +1,87 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include +#include + +// Config +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 610)) +# define CUTE_ARCH_MMA_SM61_ENABLED +#endif + +namespace cute +{ + +struct SM61_DP4A +{ + using DRegisters = int32_t[1]; + using ARegisters = uint32_t[1]; + using BRegisters = uint32_t[1]; + using CRegisters = int32_t[1]; + + // Register asm fma + CUTE_HOST_DEVICE static void + fma(int32_t& d, uint32_t const& a, uint32_t const& b, int32_t const& c) + { +#if defined(CUTE_ARCH_MMA_SM61_ENABLED) + asm volatile("dp4a.s32.s32 %0, %1, %2, %3;" + : "=r"(d) + : "r"(a), "r"(b), "r"(c)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM61_DP4A without CUTE_ARCH_MMA_SM61_ENABLED"); +#endif + } +}; + +struct SM61_DP2A +{ + using DRegisters = int32_t[1]; + using ARegisters = uint32_t[1]; + using BRegisters = uint32_t[1]; + using CRegisters = int32_t[1]; + + // Register asm fma + CUTE_HOST_DEVICE static void + fma(int32_t& d, uint32_t const& a, uint32_t const& b, int32_t const& c) + { +#if defined(CUTE_ARCH_MMA_SM61_ENABLED) + asm volatile("dp2a.s32.s32 %0, %1, %2, %3;" + : "=r"(d) + : "r"(a), "r"(b), "r"(c)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM61_DP2A without CUTE_ARCH_MMA_SM61_ENABLED"); +#endif + } +}; + +} // namespace cute diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/arch/mma_sm70.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/arch/mma_sm70.hpp new file mode 100644 index 0000000000000000000000000000000000000000..2acd421dda8f8247201f7c399cae220b3355105f --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/arch/mma_sm70.hpp @@ -0,0 +1,329 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include + +// Config +#if ((__CUDACC_VER_MAJOR__ > 10) || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 1)) +# define CUTE_ARCH_MMA_SM70_SUPPORTED +# if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700)) +# define CUTE_ARCH_MMA_SM70_ENABLED +# endif +#endif + +namespace cute +{ + +// +// SM70 MMA 884 F16F16F16 +// + +struct SM70_8x8x4_F16F16F16F16_TN +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + // Register asm fma + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM70_ENABLED) + asm volatile("mma.sync.aligned.m8n8k4.row.col.f16.f16.f16.f16" + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6, %7}," + "{%8, %9, %10, %11};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM70_8x8x4_F16F16F16F16_TN without CUTE_ARCH_MMA_SM70_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM70_8x8x4_F16F16F16F16_NT +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + // Register asm fma + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM70_ENABLED) + asm volatile("mma.sync.aligned.m8n8k4.col.row.f16.f16.f16.f16" + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6, %7}," + "{%8, %9, %10, %11};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + 
"r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM70_8x8x4_F16F16F16F16_NT without CUTE_ARCH_MMA_SM70_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM70_8x8x4_F16F16F16F16_NN +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + // Register asm fma + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM70_ENABLED) + asm volatile("mma.sync.aligned.m8n8k4.col.col.f16.f16.f16.f16" + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6, %7}," + "{%8, %9, %10, %11};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM70_8x8x4_F16F16F16F16_NN without CUTE_ARCH_MMA_SM70_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM70_8x8x4_F16F16F16F16_TT +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + // Register asm fma + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM70_ENABLED) + asm volatile("mma.sync.aligned.m8n8k4.row.row.f16.f16.f16.f16" + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6, %7}," + "{%8, %9, %10, %11};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM70_8x8x4_F16F16F16F16_TT without CUTE_ARCH_MMA_SM70_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// SM70 MMA 884 F16F16F32 +// + +struct SM70_8x8x4_F32F16F16F32_TN +{ + using DRegisters = float[8]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[2]; + using CRegisters = float[8]; + + // Register asm fma + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, uint32_t const& b1, + float const& c0, float const& c1, float const& c2, float const& c3, + float const& c4, float const& c5, float const& c6, float const& c7) + { +#if defined(CUTE_ARCH_MMA_SM70_ENABLED) + asm volatile("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11}," + "{%12, %13, %14, %15, %16, %17, %18, %19};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3), + "=f"(d4), "=f"(d5), "=f"(d6), "=f"(d7) + : "r"(a0), "r"(a1), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "f"(c4), "f"(c5), "f"(c6), "f"(c7)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM70_8x8x4_F32F16F16F32_TN without CUTE_ARCH_MMA_SM70_ENABLED"); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM70_8x8x4_F32F16F16F32_NT +{ + using DRegisters = float[8]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[2]; + using CRegisters = float[8]; + + // Register asm fma + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, uint32_t const& b1, + float const& c0, float const& c1, float const& c2, float const& c3, + float const& c4, float const& c5, float const& c6, float const& c7) + { +#if defined(CUTE_ARCH_MMA_SM70_ENABLED) + asm volatile("mma.sync.aligned.m8n8k4.col.row.f32.f16.f16.f32" + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11}," + "{%12, %13, %14, %15, %16, %17, %18, %19};" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3), + "=f"(d4), "=f"(d5), "=f"(d6), "=f"(d7) + : "r"(a0), "r"(a1), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "f"(c4), "f"(c5), "f"(c6), "f"(c7)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM70_8x8x4_F32F16F16F32_NT without CUTE_ARCH_MMA_SM70_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM70_8x8x4_F32F16F16F32_NN +{ + using DRegisters = float[8]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[2]; + using CRegisters = float[8]; + + // Register asm fma + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, uint32_t const& b1, + float const& c0, float const& c1, float const& c2, float const& c3, + float const& c4, float const& c5, float const& c6, float const& c7) + { +#if defined(CUTE_ARCH_MMA_SM70_ENABLED) + asm volatile("mma.sync.aligned.m8n8k4.col.col.f32.f16.f16.f32" + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11}," + "{%12, %13, %14, %15, %16, %17, %18, %19};" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3), + "=f"(d4), "=f"(d5), "=f"(d6), "=f"(d7) + : "r"(a0), "r"(a1), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "f"(c4), "f"(c5), "f"(c6), "f"(c7)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM70_8x8x4_F32F16F16F32_NN without CUTE_ARCH_MMA_SM70_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM70_8x8x4_F32F16F16F32_TT +{ + using DRegisters = float[8]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[2]; + using CRegisters = float[8]; + + // Register asm fma + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, uint32_t const& b1, + float const& c0, float const& c1, float const& c2, float const& c3, + float const& c4, float const& c5, float const& c6, float const& c7) + { +#if defined(CUTE_ARCH_MMA_SM70_ENABLED) + asm volatile("mma.sync.aligned.m8n8k4.row.row.f32.f16.f16.f32" + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11}," + "{%12, %13, %14, %15, %16, %17, %18, %19};" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3), + "=f"(d4), "=f"(d5), "=f"(d6), "=f"(d7) + : "r"(a0), "r"(a1), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "f"(c4), "f"(c5), "f"(c6), "f"(c7)); +#else + 
CUTE_INVALID_CONTROL_PATH("Attempting to use SM70_8x8x4_F32F16F16F32_TT without CUTE_ARCH_MMA_SM70_ENABLED"); +#endif + } + +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // end namespace cute diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/arch/mma_sm75.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/arch/mma_sm75.hpp new file mode 100644 index 0000000000000000000000000000000000000000..7d34e2b45fbd3352d7b52ed0f62c8fc7134add4c --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/arch/mma_sm75.hpp @@ -0,0 +1,120 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include + +#include + +// Config +#if ((__CUDACC_VER_MAJOR__ > 10) || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2)) +# define CUTE_ARCH_MMA_SM75_SUPPORTED +# if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750)) +# define CUTE_ARCH_MMA_SM75_ENABLED +# endif +#endif + +namespace cute +{ + +// +// SM75 MMA 1688 F16F16F32 +// + +struct SM75_16x8x8_F32F16F16F32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = float[4]; + + // Register asm fma + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + float const& c0, float const& c1, float const& c2, float const& c3) + { +#if defined(CUTE_ARCH_MMA_SM75_ENABLED) + asm volatile("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), + "r"(b0), + "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM75_16x8x8_F32F16F16F32_TN without CUTE_ARCH_MMA_SM75_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// SM75 MMA 8816 S8S8S32 +// + +struct SM75_8x8x16_S32S8S8S32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[1]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[2]; + + // Register asm fma + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_SM75_ENABLED) + asm volatile("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32" + "{%0, %1}," + "{%2}," + "{%3}," + "{%4, %5};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), + "r"(b0), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM75_8x8x16_S32S8S8S32_TN without CUTE_ARCH_MMA_SM75_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // end namespace cute diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/arch/mma_sm80.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/arch/mma_sm80.hpp new file mode 100644 index 0000000000000000000000000000000000000000..e0dbc630760210989eb073576e7ec1693f44f240 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/arch/mma_sm80.hpp @@ -0,0 +1,2241 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include +#include +#include + +// Config +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) +# define CUTE_ARCH_MMA_SM80_ENABLED + +#if (__CUDA_ARCH__ <= 900) +#define CUTE_ARCH_MMA_B1_AND_SM80_ENABLED +#endif + +#if (__CUDA_ARCH__ <= 890) +#define CUTE_ARCH_MMA_B1_XOR_SM80_ENABLED +#endif + +#endif + + + +namespace cute { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x8 TN +struct SM80_16x8x8_F16F16F16F16_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 " + "{%0, %1}," + "{%2, %3}," + "{%4}," + "{%5, %6};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), + "r"(b0), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x8_F16F16F16F16_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x16 TN +struct SM80_16x8x16_F16F16F16F16_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x16_F16F16F16F16_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x8 TN +struct SM80_16x8x8_F32F16F16F32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = 
float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + float const & c0, float const & c1, float const & c2, float const & c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), + "r"(b0), + "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x8_F32F16F16F32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x16 TN +struct SM80_16x8x16_F32F16F16F32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x16_F32F16F16F32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x8 TN +struct SM80_16x8x8_F32BF16BF16F32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + float const & c0, float const & c1, float const & c2, float const & c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k8.row.col.f32.bf16.bf16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), + "r"(b0), + "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x8_F32BF16BF16F32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x16 TN +struct SM80_16x8x16_F32BF16BF16F32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), 
"r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x16_F32BF16BF16F32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x4 TN +struct SM80_16x8x4_F32TF32TF32F32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + float const & c0, float const & c1, float const & c2, float const & c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k4.row.col.f32.tf32.tf32.f32 " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), + "r"(b0), + "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x4_F32TF32TF32F32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x8 TN +struct SM80_16x8x8_F32TF32TF32F32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x8_F32TF32TF32F32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 8x8x4 TN +struct SM80_8x8x4_F64F64F64F64_TN +{ + using DRegisters = double[2]; + using ARegisters = double[1]; + using BRegisters = double[1]; + using CRegisters = double[2]; + + CUTE_HOST_DEVICE static void + fma(double & d0, double & d1, + double const& a0, + double const& b0, + double const& c0, double const& c1) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m8n8k4.row.col.f64.f64.f64.f64 " + "{%0, %1}," + "{%2}," + "{%3}," + "{%4, %5};\n" + : "=d"(d0), "=d"(d1) + : "d"(a0), + "d"(b0), + "d"(c0), "d"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_8x8x4_F64F64F64F64_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +// MMA 8x8x4 TN with Planar Complex multiplication +struct SM80_8x8x4_C64C64C64C64_TN +{ + using DRegisters = complex[2]; + using ARegisters = complex[1]; + using BRegisters = complex[1]; + using CRegisters = complex[2]; + + CUTE_HOST_DEVICE static void + fma(complex & d0, complex & d1, + complex const& a0, + complex const& b0, + complex const& c0, complex const& c1) + { + // Because thrust::complex does not provide a mutable ref + double& rd0 = reinterpret_cast(d0)[0]; + double& id0 = reinterpret_cast(d0)[1]; + 
double& rd1 = reinterpret_cast(d1)[0]; + double& id1 = reinterpret_cast(d1)[1]; + + // d.real() = a.real() * b.real() + c.real(); + SM80_8x8x4_F64F64F64F64_TN::fma( + rd0, rd1, + a0.real(), + b0.real(), + c0.real(), c1.real()); + + // d.imag() = a.imag() * b.real() + c.imag(); + SM80_8x8x4_F64F64F64F64_TN::fma( + id0, id1, + a0.imag(), + b0.real(), + c0.imag(), c1.imag()); + + // d.real() = -a.imag() * b.imag() + d.real(); + SM80_8x8x4_F64F64F64F64_TN::fma( + rd0, rd1, + -a0.imag(), + b0.imag(), + d0.real(), d1.real()); + + // d.imag() = a.real() * b.imag() + d.imag(); + SM80_8x8x4_F64F64F64F64_TN::fma( + id0, id1, + a0.real(), + b0.imag(), + d0.imag(), d1.imag()); + } +}; + +// MMA 8x8x4 TN with Gaussian Complex multiplication: +// (a + bi)*(c + di) +// yields +// t0 += a*c +// t1 += b*d +// t2 += (a+b)*(c+d) +// then +// re = t0 - t1 +// im = t2 - t0 - t1 +struct SM80_8x8x4_GC64C64C64GC64_TN +{ + struct GaussComplex { + double t0, t1, t2; + + CUTE_HOST_DEVICE //constexpr + operator complex() const { return complex(t0 - t1, t2 - t0 - t1); } + + CUTE_HOST_DEVICE friend //constexpr + complex operator*(GaussComplex const& a, complex const& b) { return static_cast>(a) * b; } + CUTE_HOST_DEVICE friend //constexpr + complex operator*(complex const& a, GaussComplex const& b) { return b * a; } + + CUTE_HOST_DEVICE friend //constexpr + complex operator+(GaussComplex const& a, complex const& b) { return static_cast>(a) + b; } + CUTE_HOST_DEVICE friend //constexpr + complex operator+(complex const& a, GaussComplex const& b) { return b + a; } + }; + + using DRegisters = GaussComplex[2]; + using ARegisters = complex[1]; + using BRegisters = complex[1]; + using CRegisters = GaussComplex[2]; + + CUTE_HOST_DEVICE static void + fma(GaussComplex & d0, GaussComplex & d1, + complex const& a0, + complex const& b0, + GaussComplex const& c0, GaussComplex const& c1) + { + SM80_8x8x4_F64F64F64F64_TN::fma(d0.t0, d1.t0, + a0.real(), + b0.real(), + c0.t0, c1.t0); + SM80_8x8x4_F64F64F64F64_TN::fma(d0.t1, d1.t1, + a0.imag(), + b0.imag(), + c0.t1, c1.t1); + SM80_8x8x4_F64F64F64F64_TN::fma(d0.t2, d1.t2, + a0.real() + a0.imag(), + b0.real() + b0.imag(), + c0.t2, c1.t2); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 8x8x16 TN +struct SM80_8x8x16_S32S8S8S32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[1]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 " + "{%0, %1}," + "{%2}," + "{%3}," + "{%4, %5};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), + "r"(b0), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_8x8x16_S32S8S8S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 8x8x16 TN +struct SM80_8x8x16_S32S8S8S32_TN_SATURATE +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[1]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + 
"mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32.satfinite " + "{%0, %1}," + "{%2}," + "{%3}," + "{%4, %5};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), + "r"(b0), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_8x8x16_S32S8S8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x16 TN +struct SM80_16x8x16_S32S8S8S32_TN +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + "r"(b0), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x16_S32S8S8S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x16 TN +struct SM80_16x8x16_S32S8S8S32_TN_SATURATE +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + "r"(b0), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x16_S32S8S8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN +struct SM80_16x8x32_S32S8S8S32_TN +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x32_S32S8S8S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN +struct 
SM80_16x8x32_S32S8S8S32_TN_SATURATE +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x32_S32S8S8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 8x8x16 TN +struct SM80_8x8x16_S32S8U8S32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[1]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m8n8k16.row.col.s32.s8.u8.s32 " + "{%0, %1}," + "{%2}," + "{%3}," + "{%4, %5};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), + "r"(b0), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_8x8x16_S32S8U8S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 8x8x16 TN +struct SM80_8x8x16_S32S8U8S32_TN_SATURATE +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[1]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m8n8k16.row.col.s32.s8.u8.s32.satfinite " + "{%0, %1}," + "{%2}," + "{%3}," + "{%4, %5};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), + "r"(b0), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_8x8x16_S32S8U8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x16 TN +struct SM80_16x8x16_S32S8U8S32_TN +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.s32.s8.u8.s32 " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + "r"(b0), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use 
SM80_16x8x16_S32S8U8S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x16 TN +struct SM80_16x8x16_S32S8U8S32_TN_SATURATE +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.s32.s8.u8.s32.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + "r"(b0), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x16_S32S8U8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN +struct SM80_16x8x32_S32S8U8S32_TN +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.s8.u8.s32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x32_S32S8U8S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN +struct SM80_16x8x32_S32S8U8S32_TN_SATURATE +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.s8.u8.s32.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x32_S32S8U8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 8x8x16 TN +struct SM80_8x8x16_S32U8S8S32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[1]; + using BRegisters = uint32_t[1]; + using 
CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m8n8k16.row.col.s32.u8.s8.s32 " + "{%0, %1}," + "{%2}," + "{%3}," + "{%4, %5};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), + "r"(b0), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_8x8x16_S32U8S8S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 8x8x16 TN +struct SM80_8x8x16_S32U8S8S32_TN_SATURATE +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[1]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m8n8k16.row.col.s32.u8.s8.s32.satfinite " + "{%0, %1}," + "{%2}," + "{%3}," + "{%4, %5};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), + "r"(b0), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_8x8x16_S32U8S8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x16 TN +struct SM80_16x8x16_S32U8S8S32_TN +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.s32.u8.s8.s32 " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + "r"(b0), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x16_S32U8S8S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x16 TN +struct SM80_16x8x16_S32U8S8S32_TN_SATURATE +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.s32.u8.s8.s32.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + "r"(b0), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x16_S32U8S8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN +struct SM80_16x8x32_S32U8S8S32_TN +{ + using 
DRegisters = uint32_t[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.u8.s8.s32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x32_S32U8S8S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN +struct SM80_16x8x32_S32U8S8S32_TN_SATURATE +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.u8.s8.s32.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x32_S32U8S8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 8x8x16 TN +struct SM80_8x8x16_S32U8U8S32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[1]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m8n8k16.row.col.s32.u8.u8.s32 " + "{%0, %1}," + "{%2}," + "{%3}," + "{%4, %5};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), + "r"(b0), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_8x8x16_S32U8U8S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 8x8x16 TN +struct SM80_8x8x16_S32U8U8S32_TN_SATURATE +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[1]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m8n8k16.row.col.s32.u8.u8.s32.satfinite " + "{%0, %1}," + "{%2}," + "{%3}," + "{%4, %5};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), + "r"(b0), + "r"(c0), "r"(c1)); +#else + 
CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_8x8x16_S32U8U8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x16 TN +struct SM80_16x8x16_S32U8U8S32_TN +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.s32.u8.u8.s32 " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + "r"(b0), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x16_S32U8U8S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x16 TN +struct SM80_16x8x16_S32U8U8S32_TN_SATURATE +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.s32.u8.u8.s32.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + "r"(b0), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x16_S32U8U8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN +struct SM80_16x8x32_S32U8U8S32_TN +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.u8.u8.s32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x32_S32U8U8S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN +struct SM80_16x8x32_S32U8U8S32_TN_SATURATE +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + 
fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.u8.u8.s32.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x32_S32U8U8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 8x8x32 TN +struct SM80_8x8x32_S32S4S4S32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[1]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m8n8k32.row.col.s32.s4.s4.s32 " + "{%0, %1}," + "{%2}," + "{%3}," + "{%4, %5};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), + "r"(b0), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_8x8x32_S32S4S4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 8x8x32 TN +struct SM80_8x8x32_S32S4S4S32_TN_SATURATE +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[1]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m8n8k32.row.col.s32.s4.s4.s32.satfinite " + "{%0, %1}," + "{%2}," + "{%3}," + "{%4, %5};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), + "r"(b0), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_8x8x32_S32S4S4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN +struct SM80_16x8x32_S32S4S4S32_TN +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.s4.s4.s32 " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + "r"(b0), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x32_S32S4S4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN +struct 
SM80_16x8x32_S32S4S4S32_TN_SATURATE +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.s4.s4.s32.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + "r"(b0), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x32_S32S4S4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN +struct SM80_16x8x64_S32S4S4S32_TN +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k64.row.col.s32.s4.s4.s32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x64_S32S4S4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN +struct SM80_16x8x64_S32S4S4S32_TN_SATURATE +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k64.row.col.s32.s4.s4.s32.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x64_S32S4S4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 8x8x32 TN +struct SM80_8x8x32_S32S4U4S32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[1]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1) + { +#if 
defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m8n8k32.row.col.s32.s4.u4.s32 " + "{%0, %1}," + "{%2}," + "{%3}," + "{%4, %5};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), + "r"(b0), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_8x8x32_S32S4U4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 8x8x32 TN +struct SM80_8x8x32_S32S4U4S32_TN_SATURATE +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[1]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m8n8k32.row.col.s32.s4.u4.s32.satfinite " + "{%0, %1}," + "{%2}," + "{%3}," + "{%4, %5};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), + "r"(b0), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_8x8x32_S32S4U4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN +struct SM80_16x8x32_S32S4U4S32_TN +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.s4.u4.s32 " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + "r"(b0), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x32_S32S4U4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN +struct SM80_16x8x32_S32S4U4S32_TN_SATURATE +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.s4.u4.s32.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + "r"(b0), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x32_S32S4U4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN +struct SM80_16x8x64_S32S4U4S32_TN +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, 
uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k64.row.col.s32.s4.u4.s32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x64_S32S4U4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN +struct SM80_16x8x64_S32S4U4S32_TN_SATURATE +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k64.row.col.s32.s4.u4.s32.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x64_S32S4U4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 8x8x32 TN +struct SM80_8x8x32_S32U4S4S32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[1]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m8n8k32.row.col.s32.u4.s4.s32 " + "{%0, %1}," + "{%2}," + "{%3}," + "{%4, %5};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), + "r"(b0), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_8x8x32_S32U4S4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 8x8x32 TN +struct SM80_8x8x32_S32U4S4S32_TN_SATURATE +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[1]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m8n8k32.row.col.s32.u4.s4.s32.satfinite " + "{%0, %1}," + "{%2}," + "{%3}," + "{%4, %5};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), + "r"(b0), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_8x8x32_S32U4S4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + 
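+// NOTE (illustrative sketch added for this write-up, not part of the upstream CUTLASS header):
+// every struct in this file wraps a single mma.sync PTX instruction behind a uniform static
+// fma() call whose arity mirrors its DRegisters/ARegisters/BRegisters/CRegisters arrays.
+// A minimal direct invocation of the atom defined just above (operand values hypothetical)
+// would look like:
+//
+//   uint32_t d0 = 0, d1 = 0;                 // two s32 accumulators held by this thread
+//   uint32_t a0 = packed_a, b0 = packed_b;   // eight 4-bit values packed per 32-bit register
+//   SM80_8x8x32_S32U4S4S32_TN_SATURATE::fma(d0, d1, a0, b0, d0, d1);   // D = A*B + C, saturated
+//
+// In practice these atoms are consumed through cute::MMA_Atom / MMA_Traits, which handle the
+// per-thread register packing and lane ownership, rather than being called directly.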
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN +struct SM80_16x8x32_S32U4S4S32_TN +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.u4.s4.s32 " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + "r"(b0), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x32_S32U4S4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN +struct SM80_16x8x32_S32U4S4S32_TN_SATURATE +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.u4.s4.s32.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + "r"(b0), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x32_S32U4S4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN +struct SM80_16x8x64_S32U4S4S32_TN +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k64.row.col.s32.u4.s4.s32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x64_S32U4S4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN +struct SM80_16x8x64_S32U4S4S32_TN_SATURATE +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, 
+ uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k64.row.col.s32.u4.s4.s32.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x64_S32U4S4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 8x8x32 TN +struct SM80_8x8x32_S32U4U4S32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[1]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m8n8k32.row.col.s32.u4.u4.s32 " + "{%0, %1}," + "{%2}," + "{%3}," + "{%4, %5};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), + "r"(b0), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_8x8x32_S32U4U4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 8x8x32 TN +struct SM80_8x8x32_S32U4U4S32_TN_SATURATE +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[1]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m8n8k32.row.col.s32.u4.u4.s32.satfinite " + "{%0, %1}," + "{%2}," + "{%3}," + "{%4, %5};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), + "r"(b0), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_8x8x32_S32U4U4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN +struct SM80_16x8x32_S32U4U4S32_TN +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.u4.u4.s32 " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + "r"(b0), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x32_S32U4U4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN +struct SM80_16x8x32_S32U4U4S32_TN_SATURATE +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters 
= uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.u4.u4.s32.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + "r"(b0), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x32_S32U4U4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN +struct SM80_16x8x64_S32U4U4S32_TN +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k64.row.col.s32.u4.u4.s32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x64_S32U4U4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN +struct SM80_16x8x64_S32U4U4S32_TN_SATURATE +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k64.row.col.s32.u4.u4.s32.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x64_S32U4U4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 8x8x128 TN +struct SM80_8x8x128_S32U1U1S32_TN_ANDPOPC +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[1]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m8n8k128.row.col.s32.b1.b1.s32.and.popc " + "{%0, %1}," + "{%2}," + "{%3}," + "{%4, %5};\n" + : "=r"(d0), 
"=r"(d1) + : "r"(a0), + "r"(b0), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_8x8x128_S32U1U1S32_TN_ANDPOPC without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x128 TN +struct SM80_16x8x128_S32U1U1S32_TN_ANDPOPC +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k128.row.col.s32.b1.b1.s32.and.popc " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + "r"(b0), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x128_S32U1U1S32_TN_ANDPOPC without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x256 TN +struct SM80_16x8x256_S32U1U1S32_TN_ANDPOPC +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_B1_AND_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.and.popc " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x256_S32U1U1S32_TN_ANDPOPC without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 8x8x128 TN +struct SM80_8x8x128_S32U1U1S32_TN_XORPOPC +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[1]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_B1_XOR_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m8n8k128.row.col.s32.b1.b1.s32.xor.popc " + "{%0, %1}," + "{%2}," + "{%3}," + "{%4, %5};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), + "r"(b0), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_8x8x128_S32U1U1S32_TN_XORPOPC without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x128 TN +struct SM80_16x8x128_S32U1U1S32_TN_XORPOPC +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, 
uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_B1_XOR_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k128.row.col.s32.b1.b1.s32.xor.popc " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + "r"(b0), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x128_S32U1U1S32_TN_XORPOPC without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x256 TN +struct SM80_16x8x256_S32U1U1S32_TN_XORPOPC +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_B1_XOR_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.xor.popc " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x256_S32U1U1S32_TN_XORPOPC without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cute diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/arch/mma_sm89.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/arch/mma_sm89.hpp new file mode 100644 index 0000000000000000000000000000000000000000..b6c440944afecc029cea55e29ab79c8d563c7ded --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/arch/mma_sm89.hpp @@ -0,0 +1,296 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +// + +// +#pragma once + +#include +#include + +//////////////////////////////////////////////////////////////////////////////// + +#if (__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 4) +# define CUTE_ARCH_MMA_F32_SM89_SUPPORTED +#endif + +#if (__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 8) +# define CUTE_ARCH_MMA_F16_SM89_SUPPORTED +#endif + +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 890) +# if defined(CUTE_ARCH_MMA_F32_SM89_SUPPORTED) +# define CUTE_ARCH_MMA_F32_SM89_ENABLED +# endif + +# if defined(CUTE_ARCH_MMA_F16_SM89_SUPPORTED) +# define CUTE_ARCH_MMA_F16_SM89_ENABLED +# endif +#endif + +//////////////////////////////////////////////////////////////////////////////// + +namespace cute { +// MMA 16x8x32 TN +struct SM89_16x8x32_F32E4M3E4M3F32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const& c0, float const& c1, float const& c2, float const& c3) + { +#if defined(CUTE_ARCH_MMA_F32_SM89_ENABLED) + asm( + "mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : + "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3) + ); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM89_16x8x32_F32E4M3E4M3F32_TN without CUTE_ARCH_MMA_F32_SM89_ENABLED"); +#endif + } +}; + +struct SM89_16x8x32_F32E4M3E5M2F32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const& c0, float const& c1, float const& c2, float const& c3) + { +#if defined(CUTE_ARCH_MMA_F32_SM89_ENABLED) + asm( + "mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e5m2.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : + "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3) + ); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM89_16x8x32_F32E4M3E5M2F32_TN without CUTE_ARCH_MMA_F32_SM89_ENABLED"); +#endif + } +}; + +struct SM89_16x8x32_F32E5M2E5M2F32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & 
d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const& c0, float const& c1, float const& c2, float const& c3) + { +#if defined(CUTE_ARCH_MMA_F32_SM89_ENABLED) + asm( + "mma.sync.aligned.m16n8k32.row.col.f32.e5m2.e5m2.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : + "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3) + ); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM89_16x8x32_F32E5M2E5M2F32_TN without CUTE_ARCH_MMA_F32_SM89_ENABLED"); +#endif + } +}; + +struct SM89_16x8x32_F32E5M2E4M3F32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const& c0, float const& c1, float const& c2, float const& c3) + { +#if defined(CUTE_ARCH_MMA_F32_SM89_ENABLED) + asm( + "mma.sync.aligned.m16n8k32.row.col.f32.e5m2.e4m3.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : + "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3) + ); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM89_16x8x32_F32E5M2E4M3F32_TN without CUTE_ARCH_MMA_F32_SM89_ENABLED"); +#endif + } +}; + +struct SM89_16x8x32_F16E4M3E4M3F16_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_F16_SM89_ENABLED) + asm( + "mma.sync.aligned.m16n8k32.row.col.f16.e4m3.e4m3.f16 " + "{%0,%1}, {%2,%3,%4,%5}, {%6,%7}, {%8,%9};\n" + : "=r"(d0), "=r"(d1) + : + "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1) + ); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM89_16x8x32_F16E4M3E4M3F16_TN without CUTE_ARCH_MMA_F16_SM89_ENABLED"); +#endif + } +}; + +struct SM89_16x8x32_F16E4M3E5M2F16_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_F16_SM89_ENABLED) + asm( + "mma.sync.aligned.m16n8k32.row.col.f16.e4m3.e5m2.f16 " + "{%0,%1}, {%2,%3,%4,%5}, {%6,%7}, {%8,%9};\n" + : "=r"(d0), "=r"(d1) + : + "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1) + ); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM89_16x8x32_F16E4M3E5M2F16_TN without CUTE_ARCH_MMA_F16_SM89_ENABLED"); +#endif + } +}; + +struct SM89_16x8x32_F16E5M2E4M3F16_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, 
uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_F16_SM89_ENABLED) + asm( + "mma.sync.aligned.m16n8k32.row.col.f16.e5m2.e4m3.f16 " + "{%0,%1}, {%2,%3,%4,%5}, {%6,%7}, {%8,%9};\n" + : "=r"(d0), "=r"(d1) + : + "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1) + ); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM89_16x8x32_F16E5M2E4M3F16_TN without CUTE_ARCH_MMA_F16_SM89_ENABLED"); +#endif + } +}; + +struct SM89_16x8x32_F16E5M2E5M2F16_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_F16_SM89_ENABLED) + asm( + "mma.sync.aligned.m16n8k32.row.col.f16.e5m2.e5m2.f16 " + "{%0,%1}, {%2,%3,%4,%5}, {%6,%7}, {%8,%9};\n" + : "=r"(d0), "=r"(d1) + : + "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1) + ); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM89_16x8x32_F16E5M2E5M2F16_TN without CUTE_ARCH_MMA_F16_SM89_ENABLED"); +#endif + } +}; + +} // namespace cute diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/arch/mma_sm90.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/arch/mma_sm90.hpp new file mode 100644 index 0000000000000000000000000000000000000000..6aa5fba2b2da6d04d3aea259d7b2047c12bd960e --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/arch/mma_sm90.hpp @@ -0,0 +1,9331 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include +#include + +// Config +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) +# define CUTE_ARCH_MMA_SM90_ENABLED +# define CUTE_ARCH_MMA_F64_SM90_ENABLED +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cute { + +namespace SM90 { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x4 TN +struct MMA_16x8x4_F64F64F64F64_TN +{ + using DRegisters = double[4]; + using ARegisters = double[2]; + using BRegisters = double[1]; + using CRegisters = double[4]; + + CUTE_HOST_DEVICE static void + fma(double & d0, double & d1, double & d2, double & d3, + double const& a0, double const& a1, + double const& b0, + double const& c0, double const& c1, double const& c2, double const& c3) + { +#if defined(CUTE_ARCH_MMA_F64_SM90_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k4.row.col.f64.f64.f64.f64" + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=d"(d0), "=d"(d1), "=d"(d2), "=d"(d3) + : "d"(a0), "d"(a1), + "d"(b0), + "d"(c0), "d"(c1), "d"(c2), "d"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_16x8x4_F64F64F64F64_TN without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x8 TN +struct MMA_16x8x8_F64F64F64F64_TN +{ + using DRegisters = double[4]; + using ARegisters = double[4]; + using BRegisters = double[2]; + using CRegisters = double[4]; + + CUTE_HOST_DEVICE static void + fma(double & d0, double & d1, double & d2, double & d3, + double const& a0, double const& a1, double const& a2, double const& a3, + double const& b0, double const& b1, + double const& c0, double const& c1, double const& c2, double const& c3) + { +#if defined(CUTE_ARCH_MMA_F64_SM90_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k8.row.col.f64.f64.f64.f64" + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=d"(d0), "=d"(d1), "=d"(d2), "=d"(d3) + : "d"(a0), "d"(a1), "d"(a2), "d"(a3), + "d"(b0), "d"(b1), + "d"(c0), "d"(c1), "d"(c2), "d"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_16x8x8_F64F64F64F64_TN without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x16 TN +struct MMA_16x8x16_F64F64F64F64_TN +{ + using DRegisters = double[4]; + using ARegisters = double[8]; + using BRegisters = double[4]; + using CRegisters = double[4]; + + CUTE_HOST_DEVICE static void + fma(double & d0, double & d1, double & d2, double & d3, + double const& a0, double const& a1, double const& a2, double const& a3, + double const& a4, double const& a5, double const& a6, double const& a7, + double const& b0, double const& b1, 
double const& b2, double const& b3, + double const& c0, double const& c1, double const& c2, double const& c3) + { +#if defined(CUTE_ARCH_MMA_F64_SM90_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f64.f64.f64.f64" + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7, %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "{%16, %17, %18, %19};\n" + : "=d"(d0), "=d"(d1), "=d"(d2), "=d"(d3) + : "d"(a0), "d"(a1), "d"(a2), "d"(a3), + "d"(a4), "d"(a5), "d"(a6), "d"(a7), + "d"(b0), "d"(b1), "d"(b2), "d"(b3), + "d"(c0), "d"(c1), "d"(c2), "d"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_16x8x16_F64F64F64F64_TN without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x4 TN +struct MMA_16x8x4_C64C64C64C64_TN +{ + using DRegisters = complex[4]; + using ARegisters = complex[2]; + using BRegisters = complex[1]; + using CRegisters = complex[4]; + + CUTE_HOST_DEVICE static void + fma(complex & d0, complex & d1, + complex & d2, complex & d3, + complex const& a0, complex const& a1, + complex const& b0, + complex const& c0, complex const& c1, + complex const& c2, complex const& c3) + { + // Because thrust::complex does not provide a mutable ref + double& rd0 = reinterpret_cast(d0)[0]; + double& id0 = reinterpret_cast(d0)[1]; + double& rd1 = reinterpret_cast(d1)[0]; + double& id1 = reinterpret_cast(d1)[1]; + double& rd2 = reinterpret_cast(d2)[0]; + double& id2 = reinterpret_cast(d2)[1]; + double& rd3 = reinterpret_cast(d3)[0]; + double& id3 = reinterpret_cast(d3)[1]; + + // d.real() = a.real() * b.real() + c.real(); + MMA_16x8x4_F64F64F64F64_TN::fma( + rd0, rd1, rd2, rd3, + a0.real(), a1.real(), + b0.real(), + c0.real(), c1.real(), c2.real(), c3.real()); + + // d.imag() = a.imag() * b.real() + c.imag(); + MMA_16x8x4_F64F64F64F64_TN::fma( + id0, id1, id2, id3, + a0.imag(), a1.imag(), + b0.real(), + c0.imag(), c1.imag(), c2.imag(), c3.imag()); + + // d.real() = -a.imag() * b.imag() + d.real(); + MMA_16x8x4_F64F64F64F64_TN::fma( + rd0, rd1, rd2, rd3, + -a0.imag(), -a1.imag(), + b0.imag(), + d0.real(), d1.real(), d2.real(), d3.real()); + + // d.imag() = a.real() * b.imag() + d.imag(); + MMA_16x8x4_F64F64F64F64_TN::fma( + id0, id1, id2, id3, + a0.real(), a1.real(), + b0.imag(), + d0.imag(), d1.imag(), d2.imag(), d3.imag()); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x8 TN +struct MMA_16x8x8_C64C64C64C64_TN +{ + using DRegisters = complex[4]; + using ARegisters = complex[4]; + using BRegisters = complex[2]; + using CRegisters = complex[4]; + + CUTE_HOST_DEVICE static void + fma(complex & d0, complex & d1, + complex & d2, complex & d3, + complex const& a0, complex const& a1, + complex const& a2, complex const& a3, + complex const& b0, complex const& b1, + complex const& c0, complex const& c1, + complex const& c2, complex const& c3) + { + // Because thrust::complex does not provide a mutable ref + double& rd0 = reinterpret_cast(d0)[0]; + double& id0 = reinterpret_cast(d0)[1]; + double& rd1 = reinterpret_cast(d1)[0]; + double& id1 = reinterpret_cast(d1)[1]; + double& rd2 = reinterpret_cast(d2)[0]; + double& id2 = reinterpret_cast(d2)[1]; + double& rd3 = reinterpret_cast(d3)[0]; + double& id3 = reinterpret_cast(d3)[1]; + + // d.real() = a.real() * b.real() + c.real(); + MMA_16x8x8_F64F64F64F64_TN::fma( + rd0, rd1, rd2, rd3, + a0.real(), a1.real(), a2.real(), a3.real(), + b0.real(), b1.real(), + c0.real(), 
c1.real(), c2.real(), c3.real()); + + // d.imag() = a.imag() * b.real() + c.imag(); + MMA_16x8x8_F64F64F64F64_TN::fma( + id0, id1, id2, id3, + a0.imag(), a1.imag(), a2.imag(), a3.imag(), + b0.real(), b1.real(), + c0.imag(), c1.imag(), c2.imag(), c3.imag()); + + // d.real() = -a.imag() * b.imag() + d.real(); + MMA_16x8x8_F64F64F64F64_TN::fma( + rd0, rd1, rd2, rd3, + -a0.imag(), -a1.imag(), -a2.imag(), -a3.imag(), + b0.imag(), b1.imag(), + d0.real(), d1.real(), d2.real(), d3.real()); + + // d.imag() = a.real() * b.imag() + d.imag(); + MMA_16x8x8_F64F64F64F64_TN::fma( + id0, id1, id2, id3, + a0.real(), a1.real(), a2.real(), a3.real(), + b0.imag(), b1.imag(), + d0.imag(), d1.imag(), d2.imag(), d3.imag()); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x16 TN +struct MMA_16x8x16_C64C64C64C64_TN +{ + using DRegisters = complex[4]; + using ARegisters = complex[8]; + using BRegisters = complex[4]; + using CRegisters = complex[4]; + + CUTE_HOST_DEVICE static void + fma(complex & d0, complex & d1, + complex & d2, complex & d3, + complex const& a0, complex const& a1, + complex const& a2, complex const& a3, + complex const& a4, complex const& a5, + complex const& a6, complex const& a7, + complex const& b0, complex const& b1, + complex const& b2, complex const& b3, + complex const& c0, complex const& c1, + complex const& c2, complex const& c3) + { + // Because thrust::complex does not provide a mutable ref + double& rd0 = reinterpret_cast(d0)[0]; + double& id0 = reinterpret_cast(d0)[1]; + double& rd1 = reinterpret_cast(d1)[0]; + double& id1 = reinterpret_cast(d1)[1]; + double& rd2 = reinterpret_cast(d2)[0]; + double& id2 = reinterpret_cast(d2)[1]; + double& rd3 = reinterpret_cast(d3)[0]; + double& id3 = reinterpret_cast(d3)[1]; + + // d.real() = a.real() * b.real() + c.real(); + MMA_16x8x16_F64F64F64F64_TN::fma( + rd0, rd1, rd2, rd3, + a0.real(), a1.real(), a2.real(), a3.real(), + a4.real(), a5.real(), a6.real(), a7.real(), + b0.real(), b1.real(), b2.real(), b3.real(), + c0.real(), c1.real(), c2.real(), c3.real()); + + // d.imag() = a.imag() * b.real() + c.imag(); + MMA_16x8x16_F64F64F64F64_TN::fma( + id0, id1, id2, id3, + a0.imag(), a1.imag(), a2.imag(), a3.imag(), + a4.imag(), a5.imag(), a6.imag(), a7.imag(), + b0.real(), b1.real(), b2.real(), b3.real(), + c0.imag(), c1.imag(), c2.imag(), c3.imag()); + + // d.real() = -a.imag() * b.imag() + d.real(); + MMA_16x8x16_F64F64F64F64_TN::fma( + rd0, rd1, rd2, rd3, + -a0.imag(), -a1.imag(), -a2.imag(), -a3.imag(), + -a4.imag(), -a5.imag(), -a6.imag(), -a7.imag(), + b0.imag(), b1.imag(), b2.imag(), b3.imag(), + d0.real(), d1.real(), d2.real(), d3.real()); + + // d.imag() = a.real() * b.imag() + d.imag(); + MMA_16x8x16_F64F64F64F64_TN::fma( + id0, id1, id2, id3, + a0.real(), a1.real(), a2.real(), a3.real(), + a4.real(), a5.real(), a6.real(), a7.real(), + b0.imag(), b1.imag(), b2.imag(), b3.imag(), + d0.imag(), d1.imag(), d2.imag(), d3.imag()); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} + +} // namespace cute + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#include +#include +#include +#include // cute::size +#include // cute::is_static +#include // cute::half_t, cute::float_e4m3_t, cute::tfloat32_t, etc +#include // cute::is_same_v + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cute 
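+/* The ss_op_selector() helper below is a compile-time dispatcher: it inspects the A/B/C element
+ * types and a static TileShape_MNK, enforces the tile constraints (Tile_M a multiple of 64,
+ * Tile_K a multiple of the instruction K), and returns the widest SS GMMA atom whose N divides
+ * Tile_N; the non-power-of-two N variants are only considered when
+ * CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED is defined. A minimal usage sketch, with the tile shape
+ * chosen purely for illustration:
+ *
+ *   using TileShape = cute::Shape<cute::_64, cute::_128, cute::_16>;   // M=64, N=128, K=16
+ *   auto mma_atom  = cute::SM90::GMMA::ss_op_selector<
+ *       cute::half_t, cute::half_t, cute::half_t, TileShape>();        // picks MMA_64x128x16_F16F16F16_SS
+ *   auto tiled_mma = cute::make_tiled_mma(mma_atom);
+ */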
{ +namespace SM90::GMMA { + +template < + class ElementA, + class ElementB, + class ElementC, + class TileShape_MNK, + GMMA::Major MajorA = GMMA::Major::K, + GMMA::Major MajorB = GMMA::Major::K, + auto... Args // e.g. GMMA::ScaleOut::One, [GMMA::ScaleIn::One, GMMA::ScaleIn::One] + // But most commonly leave empty for defaults +> +CUTE_HOST_DEVICE constexpr +auto +ss_op_selector() +{ + static_assert(is_static::value, "TileShape_MNK must be static."); + static_assert(rank(TileShape_MNK{}) == 3, "TileShape_MNK must be rank 3."); + static_assert(size<0>(TileShape_MNK{}) % 64 == 0, "Tile_M must be a multiple of 64."); + auto Tile_N = size<1>(TileShape_MNK{}); + + // F16 accumulator + if constexpr (is_same_v) { + + // Input A: half_t ; Input B: half_t + if constexpr (is_same_v && is_same_v) { + static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x16_F16F16F16_SS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::MMA_64x248x16_F16F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x16_F16F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::MMA_64x232x16_F16F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x16_F16F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::MMA_64x216x16_F16F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x16_F16F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::MMA_64x200x16_F16F16F16_SS{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x16_F16F16F16_SS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::MMA_64x184x16_F16F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x16_F16F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::MMA_64x168x16_F16F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x16_F16F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::MMA_64x152x16_F16F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x16_F16F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::MMA_64x136x16_F16F16F16_SS{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x16_F16F16F16_SS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::MMA_64x120x16_F16F16F16_SS{}; + } +#endif +#if 
defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x16_F16F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::MMA_64x104x16_F16F16F16_SS{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x16_F16F16F16_SS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::MMA_64x88x16_F16F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x16_F16F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::MMA_64x72x16_F16F16F16_SS{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x16_F16F16F16_SS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::MMA_64x56x16_F16F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x16_F16F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::MMA_64x40x16_F16F16F16_SS{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x16_F16F16F16_SS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x16_F16F16F16_SS{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x16_F16F16F16_SS{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x16_F16F16F16_SS{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e4m3_t ; Input B: float_e4m3_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x32_F16E4M3E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::MMA_64x248x32_F16E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x32_F16E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::MMA_64x232x32_F16E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x32_F16E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::MMA_64x216x32_F16E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x32_F16E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::MMA_64x200x32_F16E4M3E4M3_SS_TN{}; + } 
+#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x32_F16E4M3E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::MMA_64x184x32_F16E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x32_F16E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::MMA_64x168x32_F16E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x32_F16E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::MMA_64x152x32_F16E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x32_F16E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::MMA_64x136x32_F16E4M3E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x32_F16E4M3E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::MMA_64x120x32_F16E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x32_F16E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::MMA_64x104x32_F16E4M3E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x32_F16E4M3E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::MMA_64x88x32_F16E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x32_F16E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::MMA_64x72x32_F16E4M3E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x32_F16E4M3E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::MMA_64x56x32_F16E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x32_F16E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::MMA_64x40x32_F16E4M3E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x32_F16E4M3E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x32_F16E4M3E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x32_F16E4M3E4M3_SS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x32_F16E4M3E4M3_SS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: 
float_e4m3_t ; Input B: float_e5m2_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x32_F16E4M3E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::MMA_64x248x32_F16E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x32_F16E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::MMA_64x232x32_F16E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x32_F16E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::MMA_64x216x32_F16E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x32_F16E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::MMA_64x200x32_F16E4M3E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x32_F16E4M3E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::MMA_64x184x32_F16E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x32_F16E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::MMA_64x168x32_F16E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x32_F16E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::MMA_64x152x32_F16E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x32_F16E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::MMA_64x136x32_F16E4M3E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x32_F16E4M3E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::MMA_64x120x32_F16E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x32_F16E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::MMA_64x104x32_F16E4M3E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x32_F16E4M3E5M2_SS_TN{}; + } +#if 
defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::MMA_64x88x32_F16E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x32_F16E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::MMA_64x72x32_F16E4M3E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x32_F16E4M3E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::MMA_64x56x32_F16E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x32_F16E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::MMA_64x40x32_F16E4M3E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x32_F16E4M3E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x32_F16E4M3E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x32_F16E4M3E5M2_SS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x32_F16E4M3E5M2_SS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e5m2_t ; Input B: float_e4m3_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x32_F16E5M2E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::MMA_64x248x32_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x32_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::MMA_64x232x32_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x32_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::MMA_64x216x32_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x32_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::MMA_64x200x32_F16E5M2E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x32_F16E5M2E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::MMA_64x184x32_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 
176 == 0) { + return SM90::GMMA::MMA_64x176x32_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::MMA_64x168x32_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x32_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::MMA_64x152x32_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x32_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::MMA_64x136x32_F16E5M2E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x32_F16E5M2E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::MMA_64x120x32_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x32_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::MMA_64x104x32_F16E5M2E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x32_F16E5M2E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::MMA_64x88x32_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x32_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::MMA_64x72x32_F16E5M2E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x32_F16E5M2E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::MMA_64x56x32_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x32_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::MMA_64x40x32_F16E5M2E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x32_F16E5M2E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x32_F16E5M2E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x32_F16E5M2E4M3_SS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x32_F16E5M2E4M3_SS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e5m2_t ; Input B: float_e5m2_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 
32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x32_F16E5M2E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::MMA_64x248x32_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x32_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::MMA_64x232x32_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x32_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::MMA_64x216x32_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x32_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::MMA_64x200x32_F16E5M2E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x32_F16E5M2E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::MMA_64x184x32_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x32_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::MMA_64x168x32_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x32_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::MMA_64x152x32_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x32_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::MMA_64x136x32_F16E5M2E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x32_F16E5M2E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::MMA_64x120x32_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x32_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::MMA_64x104x32_F16E5M2E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x32_F16E5M2E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::MMA_64x88x32_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x32_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if 
constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::MMA_64x72x32_F16E5M2E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x32_F16E5M2E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::MMA_64x56x32_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x32_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::MMA_64x40x32_F16E5M2E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x32_F16E5M2E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x32_F16E5M2E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x32_F16E5M2E5M2_SS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x32_F16E5M2E5M2_SS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + else { + static_assert(sizeof(ElementA) == 0, "No eligible GMMA operator for request configuration."); + } + } + + // F32 accumulator + else if constexpr (is_same_v) { + + // Input A: half_t ; Input B: half_t + if constexpr (is_same_v && is_same_v) { + static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x16_F32F16F16_SS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::MMA_64x248x16_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x16_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::MMA_64x232x16_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x16_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::MMA_64x216x16_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x16_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::MMA_64x200x16_F32F16F16_SS{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x16_F32F16F16_SS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::MMA_64x184x16_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x16_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::MMA_64x168x16_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x16_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if 
constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::MMA_64x152x16_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x16_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::MMA_64x136x16_F32F16F16_SS{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x16_F32F16F16_SS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::MMA_64x120x16_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x16_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::MMA_64x104x16_F32F16F16_SS{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x16_F32F16F16_SS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::MMA_64x88x16_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x16_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::MMA_64x72x16_F32F16F16_SS{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x16_F32F16F16_SS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::MMA_64x56x16_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x16_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::MMA_64x40x16_F32F16F16_SS{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x16_F32F16F16_SS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x16_F32F16F16_SS{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x16_F32F16F16_SS{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x16_F32F16F16_SS{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: bfloat16_t ; Input B: bfloat16_t + else if constexpr (is_same_v && is_same_v) { + static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x16_F32BF16BF16_SS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::MMA_64x248x16_F32BF16BF16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x16_F32BF16BF16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::MMA_64x232x16_F32BF16BF16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x16_F32BF16BF16_SS{}; + } +#endif +#if 
defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::MMA_64x216x16_F32BF16BF16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x16_F32BF16BF16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::MMA_64x200x16_F32BF16BF16_SS{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x16_F32BF16BF16_SS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::MMA_64x184x16_F32BF16BF16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x16_F32BF16BF16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::MMA_64x168x16_F32BF16BF16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x16_F32BF16BF16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::MMA_64x152x16_F32BF16BF16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x16_F32BF16BF16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::MMA_64x136x16_F32BF16BF16_SS{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x16_F32BF16BF16_SS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::MMA_64x120x16_F32BF16BF16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x16_F32BF16BF16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::MMA_64x104x16_F32BF16BF16_SS{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x16_F32BF16BF16_SS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::MMA_64x88x16_F32BF16BF16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x16_F32BF16BF16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::MMA_64x72x16_F32BF16BF16_SS{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x16_F32BF16BF16_SS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::MMA_64x56x16_F32BF16BF16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x16_F32BF16BF16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::MMA_64x40x16_F32BF16BF16_SS{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x16_F32BF16BF16_SS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if 
constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x16_F32BF16BF16_SS{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x16_F32BF16BF16_SS{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x16_F32BF16BF16_SS{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: tfloat32_t ; Input B: tfloat32_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 8 == 0, "Tile_K must be a multiple of 8."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x8_F32TF32TF32_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::MMA_64x248x8_F32TF32TF32_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x8_F32TF32TF32_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::MMA_64x232x8_F32TF32TF32_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x8_F32TF32TF32_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::MMA_64x216x8_F32TF32TF32_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x8_F32TF32TF32_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::MMA_64x200x8_F32TF32TF32_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x8_F32TF32TF32_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::MMA_64x184x8_F32TF32TF32_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x8_F32TF32TF32_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::MMA_64x168x8_F32TF32TF32_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x8_F32TF32TF32_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::MMA_64x152x8_F32TF32TF32_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x8_F32TF32TF32_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::MMA_64x136x8_F32TF32TF32_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x8_F32TF32TF32_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::MMA_64x120x8_F32TF32TF32_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + 
return SM90::GMMA::MMA_64x112x8_F32TF32TF32_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::MMA_64x104x8_F32TF32TF32_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x8_F32TF32TF32_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::MMA_64x88x8_F32TF32TF32_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x8_F32TF32TF32_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::MMA_64x72x8_F32TF32TF32_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x8_F32TF32TF32_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::MMA_64x56x8_F32TF32TF32_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x8_F32TF32TF32_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::MMA_64x40x8_F32TF32TF32_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x8_F32TF32TF32_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x8_F32TF32TF32_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x8_F32TF32TF32_SS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x8_F32TF32TF32_SS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e4m3_t ; Input B: float_e4m3_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x32_F32E4M3E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::MMA_64x248x32_F32E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x32_F32E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::MMA_64x232x32_F32E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x32_F32E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::MMA_64x216x32_F32E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x32_F32E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::MMA_64x200x32_F32E4M3E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N 
% 192 == 0) { + return SM90::GMMA::MMA_64x192x32_F32E4M3E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::MMA_64x184x32_F32E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x32_F32E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::MMA_64x168x32_F32E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x32_F32E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::MMA_64x152x32_F32E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x32_F32E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::MMA_64x136x32_F32E4M3E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x32_F32E4M3E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::MMA_64x120x32_F32E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x32_F32E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::MMA_64x104x32_F32E4M3E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x32_F32E4M3E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::MMA_64x88x32_F32E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x32_F32E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::MMA_64x72x32_F32E4M3E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x32_F32E4M3E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::MMA_64x56x32_F32E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x32_F32E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::MMA_64x40x32_F32E4M3E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x32_F32E4M3E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x32_F32E4M3E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x32_F32E4M3E4M3_SS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x32_F32E4M3E4M3_SS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e4m3_t ; Input B: float_e5m2_t + 
else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x32_F32E4M3E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::MMA_64x248x32_F32E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x32_F32E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::MMA_64x232x32_F32E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x32_F32E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::MMA_64x216x32_F32E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x32_F32E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::MMA_64x200x32_F32E4M3E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x32_F32E4M3E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::MMA_64x184x32_F32E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x32_F32E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::MMA_64x168x32_F32E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x32_F32E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::MMA_64x152x32_F32E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x32_F32E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::MMA_64x136x32_F32E4M3E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x32_F32E4M3E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::MMA_64x120x32_F32E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x32_F32E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::MMA_64x104x32_F32E4M3E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x32_F32E4M3E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr 
(Tile_N % 88 == 0) { + return SM90::GMMA::MMA_64x88x32_F32E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x32_F32E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::MMA_64x72x32_F32E4M3E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x32_F32E4M3E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::MMA_64x56x32_F32E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x32_F32E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::MMA_64x40x32_F32E4M3E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x32_F32E4M3E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x32_F32E4M3E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x32_F32E4M3E5M2_SS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x32_F32E4M3E5M2_SS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e5m2_t ; Input B: float_e4m3_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x32_F32E5M2E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::MMA_64x248x32_F32E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x32_F32E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::MMA_64x232x32_F32E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x32_F32E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::MMA_64x216x32_F32E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x32_F32E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::MMA_64x200x32_F32E5M2E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x32_F32E5M2E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::MMA_64x184x32_F32E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x32_F32E5M2E4M3_SS_TN{}; 
+ } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::MMA_64x168x32_F32E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x32_F32E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::MMA_64x152x32_F32E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x32_F32E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::MMA_64x136x32_F32E5M2E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x32_F32E5M2E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::MMA_64x120x32_F32E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x32_F32E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::MMA_64x104x32_F32E5M2E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x32_F32E5M2E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::MMA_64x88x32_F32E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x32_F32E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::MMA_64x72x32_F32E5M2E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x32_F32E5M2E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::MMA_64x56x32_F32E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x32_F32E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::MMA_64x40x32_F32E5M2E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x32_F32E5M2E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x32_F32E5M2E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x32_F32E5M2E4M3_SS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x32_F32E5M2E4M3_SS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e5m2_t ; Input B: float_e5m2_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return 
SM90::GMMA::MMA_64x256x32_F32E5M2E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::MMA_64x248x32_F32E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x32_F32E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::MMA_64x232x32_F32E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x32_F32E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::MMA_64x216x32_F32E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x32_F32E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::MMA_64x200x32_F32E5M2E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x32_F32E5M2E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::MMA_64x184x32_F32E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x32_F32E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::MMA_64x168x32_F32E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x32_F32E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::MMA_64x152x32_F32E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x32_F32E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::MMA_64x136x32_F32E5M2E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x32_F32E5M2E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::MMA_64x120x32_F32E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x32_F32E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::MMA_64x104x32_F32E5M2E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x32_F32E5M2E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::MMA_64x88x32_F32E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x32_F32E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return 
SM90::GMMA::MMA_64x72x32_F32E5M2E5M2_SS_TN{};
+      }
+#endif
+      else if constexpr (Tile_N % 64 == 0) {
+        return SM90::GMMA::MMA_64x64x32_F32E5M2E5M2_SS_TN{};
+      }
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 56 == 0) {
+        return SM90::GMMA::MMA_64x56x32_F32E5M2E5M2_SS_TN{};
+      }
+#endif
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 48 == 0) {
+        return SM90::GMMA::MMA_64x48x32_F32E5M2E5M2_SS_TN{};
+      }
+#endif
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 40 == 0) {
+        return SM90::GMMA::MMA_64x40x32_F32E5M2E5M2_SS_TN{};
+      }
+#endif
+      else if constexpr (Tile_N % 32 == 0) {
+        return SM90::GMMA::MMA_64x32x32_F32E5M2E5M2_SS_TN{};
+      }
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 24 == 0) {
+        return SM90::GMMA::MMA_64x24x32_F32E5M2E5M2_SS_TN{};
+      }
+#endif
+      else if constexpr (Tile_N % 16 == 0) {
+        return SM90::GMMA::MMA_64x16x32_F32E5M2E5M2_SS_TN{};
+      }
+      else if constexpr (Tile_N % 8 == 0) {
+        return SM90::GMMA::MMA_64x8x32_F32E5M2E5M2_SS_TN{};
+      }
+      else {
+        static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8.");
+      }
+    }
+
+    else {
+      static_assert(sizeof(ElementA) == 0, "No eligible GMMA operator for request configuration.");
+    }
+  }
+
+  // S32 accumulator
+  else if constexpr (is_same_v<ElementC, int32_t>) {
+
+    // Input A: int8_t ; Input B: int8_t
+    if constexpr (is_same_v<ElementA, int8_t> && is_same_v<ElementB, int8_t>) {
+      static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config.");
+      static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config.");
+      static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32.");
+
+      if constexpr (Tile_N % 256 == 0) {
+        return SM90::GMMA::MMA_64x256x32_S32S8S8_SS_TN{};
+      }
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 240 == 0) {
+        return SM90::GMMA::MMA_64x240x32_S32S8S8_SS_TN{};
+      }
+#endif
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 224 == 0) {
+        return SM90::GMMA::MMA_64x224x32_S32S8S8_SS_TN{};
+      }
+#endif
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 208 == 0) {
+        return SM90::GMMA::MMA_64x208x32_S32S8S8_SS_TN{};
+      }
+#endif
+      else if constexpr (Tile_N % 192 == 0) {
+        return SM90::GMMA::MMA_64x192x32_S32S8S8_SS_TN{};
+      }
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 176 == 0) {
+        return SM90::GMMA::MMA_64x176x32_S32S8S8_SS_TN{};
+      }
+#endif
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 160 == 0) {
+        return SM90::GMMA::MMA_64x160x32_S32S8S8_SS_TN{};
+      }
+#endif
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 144 == 0) {
+        return SM90::GMMA::MMA_64x144x32_S32S8S8_SS_TN{};
+      }
+#endif
+      else if constexpr (Tile_N % 128 == 0) {
+        return SM90::GMMA::MMA_64x128x32_S32S8S8_SS_TN{};
+      }
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 112 == 0) {
+        return SM90::GMMA::MMA_64x112x32_S32S8S8_SS_TN{};
+      }
+#endif
+      else if constexpr (Tile_N % 96 == 0) {
+        return SM90::GMMA::MMA_64x96x32_S32S8S8_SS_TN{};
+      }
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 80 == 0) {
+        return SM90::GMMA::MMA_64x80x32_S32S8S8_SS_TN{};
+      }
+#endif
+      else if constexpr (Tile_N % 64 == 0) {
+        return SM90::GMMA::MMA_64x64x32_S32S8S8_SS_TN{};
+      }
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 48 == 0) {
+
return SM90::GMMA::MMA_64x48x32_S32S8S8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x32_S32S8S8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x32_S32S8S8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x32_S32S8S8_SS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x32_S32S8S8_SS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: int8_t ; Input B: uint8_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x32_S32S8U8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x32_S32S8U8_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x32_S32S8U8_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x32_S32S8U8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x32_S32S8U8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x32_S32S8U8_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x32_S32S8U8_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x32_S32S8U8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x32_S32S8U8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x32_S32S8U8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x32_S32S8U8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x32_S32S8U8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x32_S32S8U8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x32_S32S8U8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x32_S32S8U8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x32_S32S8U8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x32_S32S8U8_SS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x32_S32S8U8_SS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: uint8_t ; Input B: int8_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this 
config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x32_S32U8S8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x32_S32U8S8_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x32_S32U8S8_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x32_S32U8S8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x32_S32U8S8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x32_S32U8S8_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x32_S32U8S8_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x32_S32U8S8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x32_S32U8S8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x32_S32U8S8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x32_S32U8S8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x32_S32U8S8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x32_S32U8S8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x32_S32U8S8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x32_S32U8S8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x32_S32U8S8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x32_S32U8S8_SS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x32_S32U8S8_SS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: uint8_t ; Input B: uint8_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x32_S32U8U8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x32_S32U8U8_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x32_S32U8U8_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x32_S32U8U8_SS_TN{}; + } 
+#endif
+      else if constexpr (Tile_N % 192 == 0) {
+        return SM90::GMMA::MMA_64x192x32_S32U8U8_SS_TN{};
+      }
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 176 == 0) {
+        return SM90::GMMA::MMA_64x176x32_S32U8U8_SS_TN{};
+      }
+#endif
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 160 == 0) {
+        return SM90::GMMA::MMA_64x160x32_S32U8U8_SS_TN{};
+      }
+#endif
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 144 == 0) {
+        return SM90::GMMA::MMA_64x144x32_S32U8U8_SS_TN{};
+      }
+#endif
+      else if constexpr (Tile_N % 128 == 0) {
+        return SM90::GMMA::MMA_64x128x32_S32U8U8_SS_TN{};
+      }
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 112 == 0) {
+        return SM90::GMMA::MMA_64x112x32_S32U8U8_SS_TN{};
+      }
+#endif
+      else if constexpr (Tile_N % 96 == 0) {
+        return SM90::GMMA::MMA_64x96x32_S32U8U8_SS_TN{};
+      }
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 80 == 0) {
+        return SM90::GMMA::MMA_64x80x32_S32U8U8_SS_TN{};
+      }
+#endif
+      else if constexpr (Tile_N % 64 == 0) {
+        return SM90::GMMA::MMA_64x64x32_S32U8U8_SS_TN{};
+      }
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 48 == 0) {
+        return SM90::GMMA::MMA_64x48x32_S32U8U8_SS_TN{};
+      }
+#endif
+      else if constexpr (Tile_N % 32 == 0) {
+        return SM90::GMMA::MMA_64x32x32_S32U8U8_SS_TN{};
+      }
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 24 == 0) {
+        return SM90::GMMA::MMA_64x24x32_S32U8U8_SS_TN{};
+      }
+#endif
+      else if constexpr (Tile_N % 16 == 0) {
+        return SM90::GMMA::MMA_64x16x32_S32U8U8_SS_TN{};
+      }
+      else if constexpr (Tile_N % 8 == 0) {
+        return SM90::GMMA::MMA_64x8x32_S32U8U8_SS_TN{};
+      }
+      else {
+        static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8.");
+      }
+    }
+
+    else {
+      static_assert(sizeof(ElementA) == 0, "No eligible GMMA operator for request configuration.");
+    }
+  }
+
+  // Unknown accumulator type
+  else {
+    static_assert(sizeof(ElementC) == 0, "Unknown ElementC accumulator type.");
+  }
+}
+
+template <
+  class ElementA,
+  class ElementB,
+  class ElementC,
+  class TileShape_MNK,
+  GMMA::Major MajorA = GMMA::Major::K,
+  GMMA::Major MajorB = GMMA::Major::K,
+  auto... Args        // e.g. GMMA::ScaleOut::One, [GMMA::ScaleIn::One, GMMA::ScaleIn::One]
+                      // But most commonly leave empty for defaults
+>
+CUTE_HOST_DEVICE constexpr
+auto
+ss_op_selector_sparse()
+{
+  static_assert(is_static<TileShape_MNK>::value, "TileShape_MNK must be static.");
+  static_assert(rank(TileShape_MNK{}) == 3, "TileShape_MNK must be rank 3.");
+  static_assert(size<0>(TileShape_MNK{}) % 64 == 0, "Tile_M must be a multiple of 64.");
+  auto Tile_N = size<1>(TileShape_MNK{});
+
+  // F16 accumulator
+  if constexpr (is_same_v<ElementC, half_t>) {
+
+    // Input A: half_t ; Input B: half_t
+    if constexpr (is_same_v<ElementA, half_t> && is_same_v<ElementB, half_t>) {
+      static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32.");
+
+      if constexpr (Tile_N % 256 == 0) {
+        return SM90::GMMA::SPARSE::GMMA_64x256x32_F16F16F16_SS{};
+      }
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 248 == 0) {
+        return SM90::GMMA::SPARSE::GMMA_64x248x32_F16F16F16_SS{};
+      }
+#endif
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 240 == 0) {
+        return SM90::GMMA::SPARSE::GMMA_64x240x32_F16F16F16_SS{};
+      }
+#endif
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 232 == 0) {
+        return SM90::GMMA::SPARSE::GMMA_64x232x32_F16F16F16_SS{};
+      }
+#endif
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 224 == 0) {
+        return SM90::GMMA::SPARSE::GMMA_64x224x32_F16F16F16_SS{};
+      }
+#endif
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 216 == 0) {
+        return SM90::GMMA::SPARSE::GMMA_64x216x32_F16F16F16_SS{};
+      }
+#endif
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 208 == 0) {
+        return SM90::GMMA::SPARSE::GMMA_64x208x32_F16F16F16_SS{};
+      }
+#endif
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 200 == 0) {
+        return SM90::GMMA::SPARSE::GMMA_64x200x32_F16F16F16_SS{};
+      }
+#endif
+      else if constexpr (Tile_N % 192 == 0) {
+        return SM90::GMMA::SPARSE::GMMA_64x192x32_F16F16F16_SS{};
+      }
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 184 == 0) {
+        return SM90::GMMA::SPARSE::GMMA_64x184x32_F16F16F16_SS{};
+      }
+#endif
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 176 == 0) {
+        return SM90::GMMA::SPARSE::GMMA_64x176x32_F16F16F16_SS{};
+      }
+#endif
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 168 == 0) {
+        return SM90::GMMA::SPARSE::GMMA_64x168x32_F16F16F16_SS{};
+      }
+#endif
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 160 == 0) {
+        return SM90::GMMA::SPARSE::GMMA_64x160x32_F16F16F16_SS{};
+      }
+#endif
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 152 == 0) {
+        return SM90::GMMA::SPARSE::GMMA_64x152x32_F16F16F16_SS{};
+      }
+#endif
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 144 == 0) {
+        return SM90::GMMA::SPARSE::GMMA_64x144x32_F16F16F16_SS{};
+      }
+#endif
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 136 == 0) {
+        return SM90::GMMA::SPARSE::GMMA_64x136x32_F16F16F16_SS{};
+      }
+#endif
+      else if constexpr (Tile_N % 128 == 0) {
+        return SM90::GMMA::SPARSE::GMMA_64x128x32_F16F16F16_SS{};
+      }
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 120 == 0) {
+        return SM90::GMMA::SPARSE::GMMA_64x120x32_F16F16F16_SS{};
+      }
+#endif
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr
(Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x32_F16F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x104x32_F16F16F16_SS{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x32_F16F16F16_SS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x88x32_F16F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x32_F16F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x72x32_F16F16F16_SS{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x32_F16F16F16_SS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x56x32_F16F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x32_F16F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x40x32_F16F16F16_SS{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x32_F16F16F16_SS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x32_F16F16F16_SS{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x32_F16F16F16_SS{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x32_F16F16F16_SS{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e4m3_t ; Input B: float_e4m3_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x64_F16E4M3E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x248x64_F16E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x64_F16E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x232x64_F16E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x64_F16E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x216x64_F16E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x64_F16E4M3E4M3_SS_TN{}; + } +#endif +#if 
defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x200x64_F16E4M3E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x64_F16E4M3E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x184x64_F16E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x64_F16E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x168x64_F16E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x64_F16E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x152x64_F16E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x64_F16E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x136x64_F16E4M3E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x64_F16E4M3E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x120x64_F16E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x64_F16E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x104x64_F16E4M3E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x64_F16E4M3E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x88x64_F16E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x64_F16E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x72x64_F16E4M3E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x64_F16E4M3E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x56x64_F16E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x64_F16E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x40x64_F16E4M3E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x64_F16E4M3E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return 
SM90::GMMA::SPARSE::GMMA_64x24x64_F16E4M3E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x64_F16E4M3E4M3_SS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x64_F16E4M3E4M3_SS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e4m3_t ; Input B: float_e5m2_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x64_F16E4M3E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x248x64_F16E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x64_F16E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x232x64_F16E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x64_F16E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x216x64_F16E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x64_F16E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x200x64_F16E4M3E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x64_F16E4M3E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x184x64_F16E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x64_F16E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x168x64_F16E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x64_F16E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x152x64_F16E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x64_F16E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x136x64_F16E4M3E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x64_F16E4M3E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 
== 0) { + return SM90::GMMA::SPARSE::GMMA_64x120x64_F16E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x64_F16E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x104x64_F16E4M3E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x64_F16E4M3E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x88x64_F16E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x64_F16E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x72x64_F16E4M3E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x64_F16E4M3E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x56x64_F16E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x64_F16E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x40x64_F16E4M3E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x64_F16E4M3E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x64_F16E4M3E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x64_F16E4M3E5M2_SS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x64_F16E4M3E5M2_SS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e5m2_t ; Input B: float_e4m3_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x64_F16E5M2E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x248x64_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x64_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x232x64_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x64_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return 
SM90::GMMA::SPARSE::GMMA_64x216x64_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x64_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x200x64_F16E5M2E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x64_F16E5M2E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x184x64_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x64_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x168x64_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x64_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x152x64_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x64_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x136x64_F16E5M2E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x64_F16E5M2E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x120x64_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x64_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x104x64_F16E5M2E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x64_F16E5M2E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x88x64_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x64_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x72x64_F16E5M2E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x64_F16E5M2E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x56x64_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x64_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return 
SM90::GMMA::SPARSE::GMMA_64x40x64_F16E5M2E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x64_F16E5M2E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x64_F16E5M2E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x64_F16E5M2E4M3_SS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x64_F16E5M2E4M3_SS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e5m2_t ; Input B: float_e5m2_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x64_F16E5M2E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x248x64_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x64_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x232x64_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x64_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x216x64_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x64_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x200x64_F16E5M2E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x64_F16E5M2E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x184x64_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x64_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x168x64_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x64_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x152x64_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x64_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 
0) { + return SM90::GMMA::SPARSE::GMMA_64x136x64_F16E5M2E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x64_F16E5M2E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x120x64_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x64_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x104x64_F16E5M2E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x64_F16E5M2E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x88x64_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x64_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x72x64_F16E5M2E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x64_F16E5M2E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x56x64_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x64_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x40x64_F16E5M2E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x64_F16E5M2E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x64_F16E5M2E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x64_F16E5M2E5M2_SS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x64_F16E5M2E5M2_SS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + else { + static_assert(sizeof(ElementA) == 0, "No eligible GMMA operator for request configuration."); + } + } + + // F32 accumulator + else if constexpr (is_same_v) { + + // Input A: half_t ; Input B: half_t + if constexpr (is_same_v && is_same_v) { + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x32_F32F16F16_SS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x248x32_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x32_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x232x32_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) 
{ + return SM90::GMMA::SPARSE::GMMA_64x224x32_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x216x32_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x32_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x200x32_F32F16F16_SS{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x32_F32F16F16_SS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x184x32_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x32_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x168x32_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x32_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x152x32_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x32_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x136x32_F32F16F16_SS{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x32_F32F16F16_SS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x120x32_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x32_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x104x32_F32F16F16_SS{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x32_F32F16F16_SS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x88x32_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x32_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x72x32_F32F16F16_SS{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x32_F32F16F16_SS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x56x32_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x32_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 
40 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x40x32_F32F16F16_SS{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x32_F32F16F16_SS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x32_F32F16F16_SS{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x32_F32F16F16_SS{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x32_F32F16F16_SS{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: bfloat16_t ; Input B: bfloat16_t + else if constexpr (is_same_v && is_same_v) { + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x32_F32BF16BF16_SS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x248x32_F32BF16BF16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x32_F32BF16BF16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x232x32_F32BF16BF16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x32_F32BF16BF16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x216x32_F32BF16BF16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x32_F32BF16BF16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x200x32_F32BF16BF16_SS{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x32_F32BF16BF16_SS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x184x32_F32BF16BF16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x32_F32BF16BF16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x168x32_F32BF16BF16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x32_F32BF16BF16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x152x32_F32BF16BF16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x32_F32BF16BF16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x136x32_F32BF16BF16_SS{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x32_F32BF16BF16_SS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + 
else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x120x32_F32BF16BF16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x32_F32BF16BF16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x104x32_F32BF16BF16_SS{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x32_F32BF16BF16_SS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x88x32_F32BF16BF16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x32_F32BF16BF16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x72x32_F32BF16BF16_SS{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x32_F32BF16BF16_SS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x56x32_F32BF16BF16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x32_F32BF16BF16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x40x32_F32BF16BF16_SS{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x32_F32BF16BF16_SS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x32_F32BF16BF16_SS{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x32_F32BF16BF16_SS{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x32_F32BF16BF16_SS{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: tfloat32_t ; Input B: tfloat32_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x16_F32TF32TF32_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x248x16_F32TF32TF32_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x16_F32TF32TF32_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x232x16_F32TF32TF32_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x16_F32TF32TF32_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x216x16_F32TF32TF32_SS_TN{}; + } 
+#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x16_F32TF32TF32_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x200x16_F32TF32TF32_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x16_F32TF32TF32_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x184x16_F32TF32TF32_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x16_F32TF32TF32_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x168x16_F32TF32TF32_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x16_F32TF32TF32_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x152x16_F32TF32TF32_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x16_F32TF32TF32_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x136x16_F32TF32TF32_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x16_F32TF32TF32_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x120x16_F32TF32TF32_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x16_F32TF32TF32_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x104x16_F32TF32TF32_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x16_F32TF32TF32_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x88x16_F32TF32TF32_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x16_F32TF32TF32_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x72x16_F32TF32TF32_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x16_F32TF32TF32_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x56x16_F32TF32TF32_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x16_F32TF32TF32_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x40x16_F32TF32TF32_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { 
+ return SM90::GMMA::SPARSE::GMMA_64x32x16_F32TF32TF32_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x16_F32TF32TF32_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x16_F32TF32TF32_SS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x16_F32TF32TF32_SS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e4m3_t ; Input B: float_e4m3_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x64_F32E4M3E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x248x64_F32E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x64_F32E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x232x64_F32E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x64_F32E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x216x64_F32E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x64_F32E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x200x64_F32E4M3E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x64_F32E4M3E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x184x64_F32E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x64_F32E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x168x64_F32E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x64_F32E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x152x64_F32E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x64_F32E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x136x64_F32E4M3E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N 
% 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x64_F32E4M3E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x120x64_F32E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x64_F32E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x104x64_F32E4M3E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x64_F32E4M3E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x88x64_F32E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x64_F32E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x72x64_F32E4M3E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x64_F32E4M3E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x56x64_F32E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x64_F32E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x40x64_F32E4M3E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x64_F32E4M3E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x64_F32E4M3E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x64_F32E4M3E4M3_SS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x64_F32E4M3E4M3_SS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e4m3_t ; Input B: float_e5m2_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x64_F32E4M3E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x248x64_F32E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x64_F32E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x232x64_F32E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return 
SM90::GMMA::SPARSE::GMMA_64x224x64_F32E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x216x64_F32E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x64_F32E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x200x64_F32E4M3E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x64_F32E4M3E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x184x64_F32E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x64_F32E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x168x64_F32E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x64_F32E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x152x64_F32E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x64_F32E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x136x64_F32E4M3E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x64_F32E4M3E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x120x64_F32E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x64_F32E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x104x64_F32E4M3E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x64_F32E4M3E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x88x64_F32E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x64_F32E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x72x64_F32E4M3E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x64_F32E4M3E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x56x64_F32E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return 
SM90::GMMA::SPARSE::GMMA_64x48x64_F32E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x40x64_F32E4M3E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x64_F32E4M3E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x64_F32E4M3E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x64_F32E4M3E5M2_SS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x64_F32E4M3E5M2_SS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e5m2_t ; Input B: float_e4m3_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x64_F32E5M2E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x248x64_F32E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x64_F32E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x232x64_F32E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x64_F32E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x216x64_F32E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x64_F32E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x200x64_F32E5M2E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x64_F32E5M2E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x184x64_F32E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x64_F32E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x168x64_F32E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x64_F32E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x152x64_F32E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) 
{ + return SM90::GMMA::SPARSE::GMMA_64x144x64_F32E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x136x64_F32E5M2E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x64_F32E5M2E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x120x64_F32E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x64_F32E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x104x64_F32E5M2E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x64_F32E5M2E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x88x64_F32E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x64_F32E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x72x64_F32E5M2E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x64_F32E5M2E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x56x64_F32E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x64_F32E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x40x64_F32E5M2E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x64_F32E5M2E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x64_F32E5M2E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x64_F32E5M2E4M3_SS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x64_F32E5M2E4M3_SS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e5m2_t ; Input B: float_e5m2_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x64_F32E5M2E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x248x64_F32E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x64_F32E5M2E5M2_SS_TN{}; + } +#endif +#if 
defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x232x64_F32E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x64_F32E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x216x64_F32E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x64_F32E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x200x64_F32E5M2E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x64_F32E5M2E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x184x64_F32E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x64_F32E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x168x64_F32E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x64_F32E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x152x64_F32E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x64_F32E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x136x64_F32E5M2E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x64_F32E5M2E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x120x64_F32E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x64_F32E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x104x64_F32E5M2E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x64_F32E5M2E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x88x64_F32E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x64_F32E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x72x64_F32E5M2E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x64_F32E5M2E5M2_SS_TN{}; + } +#if 
defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x56x64_F32E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x64_F32E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x40x64_F32E5M2E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x64_F32E5M2E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x64_F32E5M2E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x64_F32E5M2E5M2_SS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x64_F32E5M2E5M2_SS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + else { + static_assert(sizeof(ElementA) == 0, "No eligible GMMA operator for request configuration."); + } + } + + // S32 accumulator + else if constexpr (is_same_v) { + + // Input A: int8_t ; Input B: int8_t + if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x64_S32S8S8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x64_S32S8S8_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x64_S32S8S8_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x64_S32S8S8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x64_S32S8S8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x64_S32S8S8_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x64_S32S8S8_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x64_S32S8S8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x64_S32S8S8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x64_S32S8S8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x64_S32S8S8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x64_S32S8S8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x64_S32S8S8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N 
% 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x64_S32S8S8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x64_S32S8S8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x64_S32S8S8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x64_S32S8S8_SS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x64_S32S8S8_SS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: int8_t ; Input B: uint8_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x64_S32S8U8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x64_S32S8U8_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x64_S32S8U8_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x64_S32S8U8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x64_S32S8U8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x64_S32S8U8_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x64_S32S8U8_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x64_S32S8U8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x64_S32S8U8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x64_S32S8U8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x64_S32S8U8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x64_S32S8U8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x64_S32S8U8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x64_S32S8U8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x64_S32S8U8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x64_S32S8U8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x64_S32S8U8_SS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x64_S32S8U8_SS_TN{}; + } + else { + static_assert(Tile_N % 8 
== 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: uint8_t ; Input B: int8_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x64_S32U8S8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x64_S32U8S8_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x64_S32U8S8_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x64_S32U8S8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x64_S32U8S8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x64_S32U8S8_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x64_S32U8S8_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x64_S32U8S8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x64_S32U8S8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x64_S32U8S8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x64_S32U8S8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x64_S32U8S8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x64_S32U8S8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x64_S32U8S8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x64_S32U8S8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x64_S32U8S8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x64_S32U8S8_SS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x64_S32U8S8_SS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: uint8_t ; Input B: uint8_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x64_S32U8U8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N 
% 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x64_S32U8U8_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x64_S32U8U8_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x64_S32U8U8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x64_S32U8U8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x64_S32U8U8_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x64_S32U8U8_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x64_S32U8U8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x64_S32U8U8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x64_S32U8U8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x64_S32U8U8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x64_S32U8U8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x64_S32U8U8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x64_S32U8U8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x64_S32U8U8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x64_S32U8U8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x64_S32U8U8_SS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x64_S32U8U8_SS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + else { + static_assert(sizeof(ElementA) == 0, "No eligible GMMA operator for request configuration."); + } + } + + // Unknown accumulator type + else { + static_assert(sizeof(ElementC) == 0, "Unknown ElementC accumulator type."); + } +} + +template < + class ElementA, + class ElementB, + class ElementC, + class TileShape_MNK, + GMMA::Major MajorA = GMMA::Major::K, + GMMA::Major MajorB = GMMA::Major::K, + auto... Args // e.g. 
GMMA::ScaleOut::One, [GMMA::ScaleIn::One, GMMA::ScaleIn::One] + // But most commonly leave empty for defaults +> +CUTE_HOST_DEVICE constexpr +auto +rs_op_selector() +{ + static_assert(is_static::value, "TileShape_MNK must be static."); + static_assert(rank(TileShape_MNK{}) == 3, "TileShape_MNK must be rank 3."); + static_assert(size<0>(TileShape_MNK{}) % 64 == 0, "Tile_M must be a multiple of 64."); + static_assert(MajorA == GMMA::Major::K, "Register source A operand GMMAs must have K-major A layout."); + auto Tile_N = size<1>(TileShape_MNK{}); + + // F16 accumulator + if constexpr (is_same_v) { + + // Input A: half_t ; Input B: half_t + if constexpr (is_same_v && is_same_v) { + static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x16_F16F16F16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::MMA_64x248x16_F16F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x16_F16F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::MMA_64x232x16_F16F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x16_F16F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::MMA_64x216x16_F16F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x16_F16F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::MMA_64x200x16_F16F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x16_F16F16F16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::MMA_64x184x16_F16F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x16_F16F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::MMA_64x168x16_F16F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x16_F16F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::MMA_64x152x16_F16F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x16_F16F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::MMA_64x136x16_F16F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x16_F16F16F16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::MMA_64x120x16_F16F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return 
SM90::GMMA::MMA_64x112x16_F16F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::MMA_64x104x16_F16F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x16_F16F16F16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::MMA_64x88x16_F16F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x16_F16F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::MMA_64x72x16_F16F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x16_F16F16F16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::MMA_64x56x16_F16F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x16_F16F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::MMA_64x40x16_F16F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x16_F16F16F16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x16_F16F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x16_F16F16F16_RS{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x16_F16F16F16_RS{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e4m3_t ; Input B: float_e4m3_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x32_F16E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::MMA_64x248x32_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x32_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::MMA_64x232x32_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x32_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::MMA_64x216x32_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x32_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::MMA_64x200x32_F16E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return 
SM90::GMMA::MMA_64x192x32_F16E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::MMA_64x184x32_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x32_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::MMA_64x168x32_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x32_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::MMA_64x152x32_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x32_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::MMA_64x136x32_F16E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x32_F16E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::MMA_64x120x32_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x32_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::MMA_64x104x32_F16E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x32_F16E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::MMA_64x88x32_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x32_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::MMA_64x72x32_F16E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x32_F16E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::MMA_64x56x32_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x32_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::MMA_64x40x32_F16E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x32_F16E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x32_F16E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x32_F16E4M3E4M3_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x32_F16E4M3E4M3_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e4m3_t ; Input B: float_e5m2_t + else if constexpr 
(is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x32_F16E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::MMA_64x248x32_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x32_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::MMA_64x232x32_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x32_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::MMA_64x216x32_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x32_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::MMA_64x200x32_F16E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x32_F16E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::MMA_64x184x32_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x32_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::MMA_64x168x32_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x32_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::MMA_64x152x32_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x32_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::MMA_64x136x32_F16E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x32_F16E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::MMA_64x120x32_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x32_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::MMA_64x104x32_F16E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x32_F16E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + 
return SM90::GMMA::MMA_64x88x32_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x32_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::MMA_64x72x32_F16E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x32_F16E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::MMA_64x56x32_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x32_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::MMA_64x40x32_F16E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x32_F16E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x32_F16E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x32_F16E4M3E5M2_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x32_F16E4M3E5M2_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e5m2_t ; Input B: float_e4m3_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x32_F16E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::MMA_64x248x32_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x32_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::MMA_64x232x32_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x32_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::MMA_64x216x32_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x32_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::MMA_64x200x32_F16E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x32_F16E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::MMA_64x184x32_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x32_F16E5M2E4M3_RS_TN{}; + } +#endif +#if 
defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::MMA_64x168x32_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x32_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::MMA_64x152x32_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x32_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::MMA_64x136x32_F16E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x32_F16E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::MMA_64x120x32_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x32_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::MMA_64x104x32_F16E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x32_F16E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::MMA_64x88x32_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x32_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::MMA_64x72x32_F16E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x32_F16E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::MMA_64x56x32_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x32_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::MMA_64x40x32_F16E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x32_F16E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x32_F16E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x32_F16E5M2E4M3_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x32_F16E5M2E4M3_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e5m2_t ; Input B: float_e5m2_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return 
SM90::GMMA::MMA_64x256x32_F16E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::MMA_64x248x32_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x32_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::MMA_64x232x32_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x32_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::MMA_64x216x32_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x32_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::MMA_64x200x32_F16E5M2E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x32_F16E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::MMA_64x184x32_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x32_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::MMA_64x168x32_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x32_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::MMA_64x152x32_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x32_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::MMA_64x136x32_F16E5M2E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x32_F16E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::MMA_64x120x32_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x32_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::MMA_64x104x32_F16E5M2E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x32_F16E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::MMA_64x88x32_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x32_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return 
SM90::GMMA::MMA_64x72x32_F16E5M2E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x32_F16E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::MMA_64x56x32_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x32_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::MMA_64x40x32_F16E5M2E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x32_F16E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x32_F16E5M2E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x32_F16E5M2E5M2_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x32_F16E5M2E5M2_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + else { + static_assert(sizeof(ElementA) == 0, "No eligible GMMA operator for request configuration."); + } + } + + // F32 accumulator + else if constexpr (is_same_v) { + + // Input A: half_t ; Input B: half_t + if constexpr (is_same_v && is_same_v) { + static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x16_F32F16F16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::MMA_64x248x16_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x16_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::MMA_64x232x16_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x16_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::MMA_64x216x16_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x16_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::MMA_64x200x16_F32F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x16_F32F16F16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::MMA_64x184x16_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x16_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::MMA_64x168x16_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x16_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return 
SM90::GMMA::MMA_64x152x16_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x16_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::MMA_64x136x16_F32F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x16_F32F16F16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::MMA_64x120x16_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x16_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::MMA_64x104x16_F32F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x16_F32F16F16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::MMA_64x88x16_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x16_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::MMA_64x72x16_F32F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x16_F32F16F16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::MMA_64x56x16_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x16_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::MMA_64x40x16_F32F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x16_F32F16F16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x16_F32F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x16_F32F16F16_RS{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x16_F32F16F16_RS{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: bfloat16_t ; Input B: bfloat16_t + else if constexpr (is_same_v && is_same_v) { + static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x16_F32BF16BF16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::MMA_64x248x16_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x16_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::MMA_64x232x16_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x16_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + 
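// [Editor's sketch - not part of the vendored file above] The selector chains here are a
// widest-fit dispatch: rs_op_selector() returns the GMMA atom with the largest instruction-N
// that evenly divides Tile_N, falling back to the 64x8 atom, and the N values that are not
// multiples of 32 are only tried when CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED is defined.
// A minimal usage sketch, assuming the usual CuTe headers and that the selector lives in the
// cute::GMMA namespace; the *_Sketch names are hypothetical:
#include <cute/tensor.hpp>
#include <cute/atom/mma_atom.hpp>

using TileShape_MNK_Sketch = cute::Shape<cute::_128, cute::_128, cute::_64>;  // M=128, N=128, K=64
// half_t inputs, float accumulator, A operand sourced from registers (RS):
using AtomRS_Sketch = decltype(cute::GMMA::rs_op_selector<
    cute::half_t, cute::half_t, float, TileShape_MNK_Sketch>());
// Tile_N = 128 divides 128, so this resolves to SM90::GMMA::MMA_64x128x16_F32F16F16_RS
// (see the F32F16F16 chain above); the atom can then feed cute::make_tiled_mma.
using TiledMma_Sketch = decltype(cute::make_tiled_mma(AtomRS_Sketch{}));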
else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::MMA_64x216x16_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x16_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::MMA_64x200x16_F32BF16BF16_RS{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x16_F32BF16BF16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::MMA_64x184x16_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x16_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::MMA_64x168x16_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x16_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::MMA_64x152x16_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x16_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::MMA_64x136x16_F32BF16BF16_RS{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x16_F32BF16BF16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::MMA_64x120x16_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x16_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::MMA_64x104x16_F32BF16BF16_RS{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x16_F32BF16BF16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::MMA_64x88x16_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x16_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::MMA_64x72x16_F32BF16BF16_RS{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x16_F32BF16BF16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::MMA_64x56x16_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x16_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::MMA_64x40x16_F32BF16BF16_RS{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x16_F32BF16BF16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return 
SM90::GMMA::MMA_64x24x16_F32BF16BF16_RS{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x16_F32BF16BF16_RS{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x16_F32BF16BF16_RS{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: tfloat32_t ; Input B: tfloat32_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 8 == 0, "Tile_K must be a multiple of 8."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x8_F32TF32TF32_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::MMA_64x248x8_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x8_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::MMA_64x232x8_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x8_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::MMA_64x216x8_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x8_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::MMA_64x200x8_F32TF32TF32_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x8_F32TF32TF32_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::MMA_64x184x8_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x8_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::MMA_64x168x8_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x8_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::MMA_64x152x8_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x8_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::MMA_64x136x8_F32TF32TF32_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x8_F32TF32TF32_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::MMA_64x120x8_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return 
SM90::GMMA::MMA_64x112x8_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::MMA_64x104x8_F32TF32TF32_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x8_F32TF32TF32_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::MMA_64x88x8_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x8_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::MMA_64x72x8_F32TF32TF32_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x8_F32TF32TF32_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::MMA_64x56x8_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x8_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::MMA_64x40x8_F32TF32TF32_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x8_F32TF32TF32_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x8_F32TF32TF32_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x8_F32TF32TF32_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x8_F32TF32TF32_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e4m3_t ; Input B: float_e4m3_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x32_F32E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::MMA_64x248x32_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x32_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::MMA_64x232x32_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x32_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::MMA_64x216x32_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x32_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::MMA_64x200x32_F32E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 
== 0) { + return SM90::GMMA::MMA_64x192x32_F32E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::MMA_64x184x32_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x32_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::MMA_64x168x32_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x32_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::MMA_64x152x32_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x32_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::MMA_64x136x32_F32E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x32_F32E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::MMA_64x120x32_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x32_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::MMA_64x104x32_F32E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x32_F32E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::MMA_64x88x32_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x32_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::MMA_64x72x32_F32E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x32_F32E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::MMA_64x56x32_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x32_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::MMA_64x40x32_F32E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x32_F32E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x32_F32E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x32_F32E4M3E4M3_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x32_F32E4M3E4M3_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e4m3_t ; Input B: float_e5m2_t + else 
if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x32_F32E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::MMA_64x248x32_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x32_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::MMA_64x232x32_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x32_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::MMA_64x216x32_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x32_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::MMA_64x200x32_F32E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x32_F32E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::MMA_64x184x32_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x32_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::MMA_64x168x32_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x32_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::MMA_64x152x32_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x32_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::MMA_64x136x32_F32E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x32_F32E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::MMA_64x120x32_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x32_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::MMA_64x104x32_F32E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x32_F32E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 
88 == 0) { + return SM90::GMMA::MMA_64x88x32_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x32_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::MMA_64x72x32_F32E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x32_F32E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::MMA_64x56x32_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x32_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::MMA_64x40x32_F32E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x32_F32E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x32_F32E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x32_F32E4M3E5M2_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x32_F32E4M3E5M2_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e5m2_t ; Input B: float_e4m3_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x32_F32E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::MMA_64x248x32_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x32_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::MMA_64x232x32_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x32_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::MMA_64x216x32_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x32_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::MMA_64x200x32_F32E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x32_F32E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::MMA_64x184x32_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x32_F32E5M2E4M3_RS_TN{}; + } 
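// [Editor's sketch - not part of the vendored file above] Every FP8 branch asserts that both
// operands are K-major (TN layout) and that Tile_K is a multiple of 32, since the FP8 GMMA
// instructions consume 32 elements of K per issue. Mixed e4m3/e5m2 operand pairings map to
// their own *_RS_TN atoms, as in this hypothetical selection (namespaces assumed as above):
#include <cutlass/numeric_types.h>

using TileShapeFp8_Sketch = cute::Shape<cute::_64, cute::_192, cute::_128>;  // K = 128 is a multiple of 32
using AtomFp8_Sketch = decltype(cute::GMMA::rs_op_selector<
    cutlass::float_e4m3_t, cutlass::float_e5m2_t, float, TileShapeFp8_Sketch>());
// Tile_N = 192 is not a multiple of 256, so the chain settles on
// SM90::GMMA::MMA_64x192x32_F32E4M3E5M2_RS_TN.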
+#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::MMA_64x168x32_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x32_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::MMA_64x152x32_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x32_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::MMA_64x136x32_F32E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x32_F32E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::MMA_64x120x32_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x32_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::MMA_64x104x32_F32E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x32_F32E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::MMA_64x88x32_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x32_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::MMA_64x72x32_F32E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x32_F32E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::MMA_64x56x32_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x32_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::MMA_64x40x32_F32E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x32_F32E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x32_F32E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x32_F32E5M2E4M3_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x32_F32E5M2E4M3_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e5m2_t ; Input B: float_e5m2_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return 
SM90::GMMA::MMA_64x256x32_F32E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::MMA_64x248x32_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x32_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::MMA_64x232x32_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x32_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::MMA_64x216x32_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x32_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::MMA_64x200x32_F32E5M2E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x32_F32E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::MMA_64x184x32_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x32_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::MMA_64x168x32_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x32_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::MMA_64x152x32_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x32_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::MMA_64x136x32_F32E5M2E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x32_F32E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::MMA_64x120x32_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x32_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::MMA_64x104x32_F32E5M2E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x32_F32E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::MMA_64x88x32_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x32_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return 
SM90::GMMA::MMA_64x72x32_F32E5M2E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x32_F32E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::MMA_64x56x32_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x32_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::MMA_64x40x32_F32E5M2E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x32_F32E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x32_F32E5M2E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x32_F32E5M2E5M2_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x32_F32E5M2E5M2_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + else { + static_assert(sizeof(ElementA) == 0, "No eligible GMMA operator for request configuration."); + } + } + + // S32 accumulator + else if constexpr (is_same_v) { + + // Input A: int8_t ; Input B: int8_t + if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x32_S32S8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x32_S32S8S8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x32_S32S8S8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x32_S32S8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x32_S32S8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x32_S32S8S8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x32_S32S8S8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x32_S32S8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x32_S32S8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x32_S32S8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x32_S32S8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x32_S32S8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x32_S32S8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + 
return SM90::GMMA::MMA_64x48x32_S32S8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x32_S32S8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x32_S32S8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x32_S32S8S8_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x32_S32S8S8_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: int8_t ; Input B: uint8_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x32_S32S8U8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x32_S32S8U8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x32_S32S8U8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x32_S32S8U8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x32_S32S8U8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x32_S32S8U8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x32_S32S8U8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x32_S32S8U8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x32_S32S8U8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x32_S32S8U8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x32_S32S8U8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x32_S32S8U8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x32_S32S8U8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x32_S32S8U8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x32_S32S8U8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x32_S32S8U8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x32_S32S8U8_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x32_S32S8U8_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: uint8_t ; Input B: int8_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this 
config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x32_S32U8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x32_S32U8S8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x32_S32U8S8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x32_S32U8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x32_S32U8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x32_S32U8S8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x32_S32U8S8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x32_S32U8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x32_S32U8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x32_S32U8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x32_S32U8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x32_S32U8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x32_S32U8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x32_S32U8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x32_S32U8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x32_S32U8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x32_S32U8S8_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x32_S32U8S8_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: uint8_t ; Input B: uint8_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x32_S32U8U8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x32_S32U8U8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x32_S32U8U8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x32_S32U8U8_RS_TN{}; + } 
+#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x32_S32U8U8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x32_S32U8U8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x32_S32U8U8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x32_S32U8U8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x32_S32U8U8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x32_S32U8U8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x32_S32U8U8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x32_S32U8U8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x32_S32U8U8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x32_S32U8U8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x32_S32U8U8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x32_S32U8U8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x32_S32U8U8_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x32_S32U8U8_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + else { + static_assert(sizeof(ElementA) == 0, "No eligible GMMA operator for request configuration."); + } + } + + // Unknown accumulator type + else { + static_assert(sizeof(ElementC) == 0, "Unknown ElementC accumulator type."); + } +} + +template < + class ElementA, + class ElementB, + class ElementC, + class TileShape_MNK, + GMMA::Major MajorA = GMMA::Major::K, + GMMA::Major MajorB = GMMA::Major::K, + auto... Args // e.g. 
GMMA::ScaleOut::One, [GMMA::ScaleIn::One, GMMA::ScaleIn::One]
+                 // But most commonly leave empty for defaults
+>
+CUTE_HOST_DEVICE constexpr
+auto
+rs_op_selector_sparse()
+{
+  static_assert(is_static<TileShape_MNK>::value, "TileShape_MNK must be static.");
+  static_assert(rank(TileShape_MNK{}) == 3, "TileShape_MNK must be rank 3.");
+  static_assert(size<0>(TileShape_MNK{}) % 64 == 0, "Tile_M must be a multiple of 64.");
+  static_assert(MajorA == GMMA::Major::K, "Register source A operand GMMAs must have K-major A layout.");
+  auto Tile_N = size<1>(TileShape_MNK{});
+
+  // F16 accumulator
+  if constexpr (is_same_v<ElementC, half_t>) {
+
+    // Input A: half_t ; Input B: half_t
+    if constexpr (is_same_v<ElementA, half_t> && is_same_v<ElementB, half_t>) {
+      static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32.");
+
+      if constexpr (Tile_N % 256 == 0) {
+        return SM90::GMMA::SPARSE::GMMA_64x256x32_F16F16F16_RS{};
+      }
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 248 == 0) {
+        return SM90::GMMA::SPARSE::GMMA_64x248x32_F16F16F16_RS{};
+      }
+#endif
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 240 == 0) {
+        return SM90::GMMA::SPARSE::GMMA_64x240x32_F16F16F16_RS{};
+      }
+#endif
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 232 == 0) {
+        return SM90::GMMA::SPARSE::GMMA_64x232x32_F16F16F16_RS{};
+      }
+#endif
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 224 == 0) {
+        return SM90::GMMA::SPARSE::GMMA_64x224x32_F16F16F16_RS{};
+      }
+#endif
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 216 == 0) {
+        return SM90::GMMA::SPARSE::GMMA_64x216x32_F16F16F16_RS{};
+      }
+#endif
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 208 == 0) {
+        return SM90::GMMA::SPARSE::GMMA_64x208x32_F16F16F16_RS{};
+      }
+#endif
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 200 == 0) {
+        return SM90::GMMA::SPARSE::GMMA_64x200x32_F16F16F16_RS{};
+      }
+#endif
+      else if constexpr (Tile_N % 192 == 0) {
+        return SM90::GMMA::SPARSE::GMMA_64x192x32_F16F16F16_RS{};
+      }
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 184 == 0) {
+        return SM90::GMMA::SPARSE::GMMA_64x184x32_F16F16F16_RS{};
+      }
+#endif
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 176 == 0) {
+        return SM90::GMMA::SPARSE::GMMA_64x176x32_F16F16F16_RS{};
+      }
+#endif
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 168 == 0) {
+        return SM90::GMMA::SPARSE::GMMA_64x168x32_F16F16F16_RS{};
+      }
+#endif
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 160 == 0) {
+        return SM90::GMMA::SPARSE::GMMA_64x160x32_F16F16F16_RS{};
+      }
+#endif
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 152 == 0) {
+        return SM90::GMMA::SPARSE::GMMA_64x152x32_F16F16F16_RS{};
+      }
+#endif
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 144 == 0) {
+        return SM90::GMMA::SPARSE::GMMA_64x144x32_F16F16F16_RS{};
+      }
+#endif
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 136 == 0) {
+        return SM90::GMMA::SPARSE::GMMA_64x136x32_F16F16F16_RS{};
+      }
+#endif
+      else if constexpr (Tile_N % 128 == 0) {
+        return SM90::GMMA::SPARSE::GMMA_64x128x32_F16F16F16_RS{};
+      }
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 120 == 0) {
+        return
SM90::GMMA::SPARSE::GMMA_64x120x32_F16F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x32_F16F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x104x32_F16F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x32_F16F16F16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x88x32_F16F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x32_F16F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x72x32_F16F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x32_F16F16F16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x56x32_F16F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x32_F16F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x40x32_F16F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x32_F16F16F16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x32_F16F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x32_F16F16F16_RS{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x32_F16F16F16_RS{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e4m3_t ; Input B: float_e4m3_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x64_F16E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x248x64_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x64_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x232x64_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x64_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x216x64_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if 
constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x64_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x200x64_F16E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x64_F16E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x184x64_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x64_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x168x64_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x64_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x152x64_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x64_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x136x64_F16E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x64_F16E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x120x64_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x64_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x104x64_F16E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x64_F16E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x88x64_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x64_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x72x64_F16E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x64_F16E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x56x64_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x64_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x40x64_F16E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x64_F16E4M3E4M3_RS_TN{}; + } 
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x64_F16E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x64_F16E4M3E4M3_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x64_F16E4M3E4M3_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e4m3_t ; Input B: float_e5m2_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x64_F16E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x248x64_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x64_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x232x64_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x64_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x216x64_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x64_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x200x64_F16E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x64_F16E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x184x64_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x64_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x168x64_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x64_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x152x64_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x64_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x136x64_F16E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return 
SM90::GMMA::SPARSE::GMMA_64x128x64_F16E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x120x64_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x64_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x104x64_F16E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x64_F16E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x88x64_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x64_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x72x64_F16E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x64_F16E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x56x64_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x64_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x40x64_F16E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x64_F16E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x64_F16E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x64_F16E4M3E5M2_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x64_F16E4M3E5M2_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e5m2_t ; Input B: float_e4m3_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x64_F16E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x248x64_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x64_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x232x64_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x64_F16E5M2E4M3_RS_TN{}; + } +#endif 
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x216x64_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x64_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x200x64_F16E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x64_F16E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x184x64_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x64_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x168x64_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x64_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x152x64_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x64_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x136x64_F16E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x64_F16E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x120x64_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x64_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x104x64_F16E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x64_F16E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x88x64_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x64_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x72x64_F16E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x64_F16E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x56x64_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x64_F16E5M2E4M3_RS_TN{}; + } +#endif +#if 
defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x40x64_F16E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x64_F16E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x64_F16E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x64_F16E5M2E4M3_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x64_F16E5M2E4M3_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e5m2_t ; Input B: float_e5m2_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x64_F16E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x248x64_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x64_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x232x64_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x64_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x216x64_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x64_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x200x64_F16E5M2E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x64_F16E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x184x64_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x64_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x168x64_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x64_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x152x64_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x64_F16E5M2E5M2_RS_TN{}; + } 
+#endif
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 136 == 0) {
+        return SM90::GMMA::SPARSE::GMMA_64x136x64_F16E5M2E5M2_RS_TN{};
+      }
+#endif
+      else if constexpr (Tile_N % 128 == 0) {
+        return SM90::GMMA::SPARSE::GMMA_64x128x64_F16E5M2E5M2_RS_TN{};
+      }
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 120 == 0) {
+        return SM90::GMMA::SPARSE::GMMA_64x120x64_F16E5M2E5M2_RS_TN{};
+      }
+#endif
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 112 == 0) {
+        return SM90::GMMA::SPARSE::GMMA_64x112x64_F16E5M2E5M2_RS_TN{};
+      }
+#endif
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 104 == 0) {
+        return SM90::GMMA::SPARSE::GMMA_64x104x64_F16E5M2E5M2_RS_TN{};
+      }
+#endif
+      else if constexpr (Tile_N % 96 == 0) {
+        return SM90::GMMA::SPARSE::GMMA_64x96x64_F16E5M2E5M2_RS_TN{};
+      }
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 88 == 0) {
+        return SM90::GMMA::SPARSE::GMMA_64x88x64_F16E5M2E5M2_RS_TN{};
+      }
+#endif
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 80 == 0) {
+        return SM90::GMMA::SPARSE::GMMA_64x80x64_F16E5M2E5M2_RS_TN{};
+      }
+#endif
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 72 == 0) {
+        return SM90::GMMA::SPARSE::GMMA_64x72x64_F16E5M2E5M2_RS_TN{};
+      }
+#endif
+      else if constexpr (Tile_N % 64 == 0) {
+        return SM90::GMMA::SPARSE::GMMA_64x64x64_F16E5M2E5M2_RS_TN{};
+      }
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 56 == 0) {
+        return SM90::GMMA::SPARSE::GMMA_64x56x64_F16E5M2E5M2_RS_TN{};
+      }
+#endif
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 48 == 0) {
+        return SM90::GMMA::SPARSE::GMMA_64x48x64_F16E5M2E5M2_RS_TN{};
+      }
+#endif
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 40 == 0) {
+        return SM90::GMMA::SPARSE::GMMA_64x40x64_F16E5M2E5M2_RS_TN{};
+      }
+#endif
+      else if constexpr (Tile_N % 32 == 0) {
+        return SM90::GMMA::SPARSE::GMMA_64x32x64_F16E5M2E5M2_RS_TN{};
+      }
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 24 == 0) {
+        return SM90::GMMA::SPARSE::GMMA_64x24x64_F16E5M2E5M2_RS_TN{};
+      }
+#endif
+      else if constexpr (Tile_N % 16 == 0) {
+        return SM90::GMMA::SPARSE::GMMA_64x16x64_F16E5M2E5M2_RS_TN{};
+      }
+      else if constexpr (Tile_N % 8 == 0) {
+        return SM90::GMMA::SPARSE::GMMA_64x8x64_F16E5M2E5M2_RS_TN{};
+      }
+      else {
+        static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8.");
+      }
+    }
+
+    else {
+      static_assert(sizeof(ElementA) == 0, "No eligible GMMA operator for request configuration.");
+    }
+  }
+
+  // F32 accumulator
+  else if constexpr (is_same_v<ElementC, float>) {
+
+    // Input A: half_t ; Input B: half_t
+    if constexpr (is_same_v<ElementA, half_t> && is_same_v<ElementB, half_t>) {
+      static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32.");
+
+      if constexpr (Tile_N % 256 == 0) {
+        return SM90::GMMA::SPARSE::GMMA_64x256x32_F32F16F16_RS{};
+      }
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 248 == 0) {
+        return SM90::GMMA::SPARSE::GMMA_64x248x32_F32F16F16_RS{};
+      }
+#endif
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 240 == 0) {
+        return SM90::GMMA::SPARSE::GMMA_64x240x32_F32F16F16_RS{};
+      }
+#endif
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 232 == 0) {
+        return SM90::GMMA::SPARSE::GMMA_64x232x32_F32F16F16_RS{};
+      }
+#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x32_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x216x32_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x32_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x200x32_F32F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x32_F32F16F16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x184x32_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x32_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x168x32_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x32_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x152x32_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x32_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x136x32_F32F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x32_F32F16F16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x120x32_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x32_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x104x32_F32F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x32_F32F16F16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x88x32_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x32_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x72x32_F32F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x32_F32F16F16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x56x32_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return 
SM90::GMMA::SPARSE::GMMA_64x48x32_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x40x32_F32F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x32_F32F16F16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x32_F32F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x32_F32F16F16_RS{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x32_F32F16F16_RS{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: bfloat16_t ; Input B: bfloat16_t + else if constexpr (is_same_v && is_same_v) { + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x32_F32BF16BF16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x248x32_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x32_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x232x32_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x32_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x216x32_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x32_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x200x32_F32BF16BF16_RS{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x32_F32BF16BF16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x184x32_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x32_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x168x32_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x32_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x152x32_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x32_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x136x32_F32BF16BF16_RS{}; + } +#endif + else if 
constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x32_F32BF16BF16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x120x32_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x32_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x104x32_F32BF16BF16_RS{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x32_F32BF16BF16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x88x32_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x32_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x72x32_F32BF16BF16_RS{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x32_F32BF16BF16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x56x32_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x32_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x40x32_F32BF16BF16_RS{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x32_F32BF16BF16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x32_F32BF16BF16_RS{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x32_F32BF16BF16_RS{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x32_F32BF16BF16_RS{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: tfloat32_t ; Input B: tfloat32_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x16_F32TF32TF32_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x248x16_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x16_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x232x16_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x16_F32TF32TF32_RS_TN{}; + } +#endif +#if 
defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x216x16_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x16_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x200x16_F32TF32TF32_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x16_F32TF32TF32_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x184x16_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x16_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x168x16_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x16_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x152x16_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x16_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x136x16_F32TF32TF32_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x16_F32TF32TF32_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x120x16_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x16_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x104x16_F32TF32TF32_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x16_F32TF32TF32_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x88x16_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x16_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x72x16_F32TF32TF32_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x16_F32TF32TF32_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x56x16_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x16_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) 
+ else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x40x16_F32TF32TF32_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x16_F32TF32TF32_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x16_F32TF32TF32_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x16_F32TF32TF32_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x16_F32TF32TF32_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e4m3_t ; Input B: float_e4m3_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x64_F32E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x248x64_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x64_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x232x64_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x64_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x216x64_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x64_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x200x64_F32E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x64_F32E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x184x64_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x64_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x168x64_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x64_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x152x64_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x64_F32E4M3E4M3_RS_TN{}; + } +#endif +#if 
defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x136x64_F32E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x64_F32E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x120x64_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x64_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x104x64_F32E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x64_F32E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x88x64_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x64_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x72x64_F32E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x64_F32E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x56x64_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x64_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x40x64_F32E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x64_F32E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x64_F32E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x64_F32E4M3E4M3_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x64_F32E4M3E4M3_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e4m3_t ; Input B: float_e5m2_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x64_F32E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x248x64_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x64_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return 
SM90::GMMA::SPARSE::GMMA_64x232x64_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x64_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x216x64_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x64_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x200x64_F32E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x64_F32E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x184x64_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x64_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x168x64_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x64_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x152x64_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x64_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x136x64_F32E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x64_F32E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x120x64_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x64_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x104x64_F32E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x64_F32E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x88x64_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x64_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x72x64_F32E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x64_F32E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return 
SM90::GMMA::SPARSE::GMMA_64x56x64_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x64_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x40x64_F32E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x64_F32E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x64_F32E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x64_F32E4M3E5M2_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x64_F32E4M3E5M2_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e5m2_t ; Input B: float_e4m3_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x64_F32E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x248x64_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x64_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x232x64_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x64_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x216x64_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x64_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x200x64_F32E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x64_F32E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x184x64_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x64_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x168x64_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x64_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { 
+ return SM90::GMMA::SPARSE::GMMA_64x152x64_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x64_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x136x64_F32E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x64_F32E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x120x64_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x64_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x104x64_F32E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x64_F32E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x88x64_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x64_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x72x64_F32E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x64_F32E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x56x64_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x64_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x40x64_F32E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x64_F32E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x64_F32E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x64_F32E5M2E4M3_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x64_F32E5M2E4M3_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e5m2_t ; Input B: float_e5m2_t + else if constexpr (is_same_v<ElementA, float_e5m2_t> && is_same_v<ElementB, float_e5m2_t>) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x64_F32E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x248x64_F32E5M2E5M2_RS_TN{}; + } +#endif +#if 
defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x64_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x232x64_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x64_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x216x64_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x64_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x200x64_F32E5M2E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x64_F32E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x184x64_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x64_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x168x64_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x64_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x152x64_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x64_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x136x64_F32E5M2E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x64_F32E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x120x64_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x64_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x104x64_F32E5M2E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x64_F32E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x88x64_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x64_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return 
SM90::GMMA::SPARSE::GMMA_64x72x64_F32E5M2E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x64_F32E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x56x64_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x64_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x40x64_F32E5M2E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x64_F32E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x64_F32E5M2E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x64_F32E5M2E5M2_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x64_F32E5M2E5M2_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + else { + static_assert(sizeof(ElementA) == 0, "No eligible GMMA operator for requested configuration."); + } + } + + // S32 accumulator + else if constexpr (is_same_v<ElementC, int32_t>) { + + // Input A: int8_t ; Input B: int8_t + if constexpr (is_same_v<ElementA, int8_t> && is_same_v<ElementB, int8_t>) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x64_S32S8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x64_S32S8S8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x64_S32S8S8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x64_S32S8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x64_S32S8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x64_S32S8S8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x64_S32S8S8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x64_S32S8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x64_S32S8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x64_S32S8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x64_S32S8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x64_S32S8S8_RS_TN{}; + } +#endif + 
else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x64_S32S8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x64_S32S8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x64_S32S8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x64_S32S8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x64_S32S8S8_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x64_S32S8S8_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: int8_t ; Input B: uint8_t + else if constexpr (is_same_v<ElementA, int8_t> && is_same_v<ElementB, uint8_t>) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x64_S32S8U8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x64_S32S8U8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x64_S32S8U8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x64_S32S8U8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x64_S32S8U8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x64_S32S8U8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x64_S32S8U8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x64_S32S8U8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x64_S32S8U8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x64_S32S8U8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x64_S32S8U8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x64_S32S8U8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x64_S32S8U8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x64_S32S8U8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x64_S32S8U8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x64_S32S8U8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return 
SM90::GMMA::SPARSE::GMMA_64x16x64_S32S8U8_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x64_S32S8U8_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: uint8_t ; Input B: int8_t + else if constexpr (is_same_v<ElementA, uint8_t> && is_same_v<ElementB, int8_t>) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x64_S32U8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x64_S32U8S8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x64_S32U8S8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x64_S32U8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x64_S32U8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x64_S32U8S8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x64_S32U8S8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x64_S32U8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x64_S32U8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x64_S32U8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x64_S32U8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x64_S32U8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x64_S32U8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x64_S32U8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x64_S32U8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x64_S32U8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x64_S32U8S8_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x64_S32U8S8_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: uint8_t ; Input B: uint8_t + else if constexpr (is_same_v<ElementA, uint8_t> && is_same_v<ElementB, uint8_t>) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple 
of 64."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x64_S32U8U8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x64_S32U8U8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x64_S32U8U8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x64_S32U8U8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x64_S32U8U8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x64_S32U8U8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x64_S32U8U8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x64_S32U8U8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x64_S32U8U8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x64_S32U8U8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x64_S32U8U8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x64_S32U8U8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x64_S32U8U8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x64_S32U8U8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x64_S32U8U8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x64_S32U8U8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x64_S32U8U8_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x64_S32U8U8_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + else { + static_assert(sizeof(ElementA) == 0, "No eligible GMMA operator for request configuration."); + } + } + + // Unknown accumulator type + else { + static_assert(sizeof(ElementC) == 0, "Unknown ElementC accumulator type."); + } +} + +} // end namespace SM90::GMMA +} // end namespace cute + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/arch/mma_sm90_desc.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/arch/mma_sm90_desc.hpp new file mode 100644 index 0000000000000000000000000000000000000000..5f65746b34f44edc95c6f3af4b284e7fa8a0cd18 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/arch/mma_sm90_desc.hpp @@ -0,0 +1,151 @@ 
+/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include + +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cute { + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// GMMA Descriptor and utilities + +// GMMA enums and utilities +namespace SM90::GMMA { + +enum class LayoutType : uint8_t { + INTERLEAVE = 0, + B128 = 1, + B64 = 2, + B32 = 3, +}; + +CUTE_HOST_DEVICE char const* to_string(LayoutType const& t) { + switch (t) { + case LayoutType::INTERLEAVE: return "INTERLEAVE"; + case LayoutType::B128: return "B128"; + case LayoutType::B64: return "B64"; + case LayoutType::B32: return "B32"; + } + return nullptr; +} + +#if !defined(__CUDACC_RTC__) +// Output operator for all enums in this namespace +CUTE_HOST std::ostream& operator<<(std::ostream& os, LayoutType const& t) { + char const* s = to_string(t); + if (s) { + std::operator<<(os, s); // Explicit call to avoid ambiguity + } else { + os.setstate(std::ios_base::failbit); + } + return os; +} +#endif // !defined(__CUDACC_RTC__) + +} // end namespace SM90::GMMA + +union GmmaDescriptor +{ + CUTE_HOST_DEVICE constexpr + GmmaDescriptor() noexcept : desc_(0) {} + CUTE_HOST_DEVICE constexpr + GmmaDescriptor(uint64_t desc) noexcept : desc_(desc) {} + CUTE_HOST_DEVICE constexpr + GmmaDescriptor(GmmaDescriptor const& t) noexcept : desc_(t.desc_) {} + CUTE_HOST_DEVICE constexpr + GmmaDescriptor(GmmaDescriptor && t) noexcept : desc_(t.desc_) {} + + CUTE_HOST_DEVICE constexpr + GmmaDescriptor& operator=(GmmaDescriptor const& t) noexcept { + desc_ = t.desc_; + return *this; + } + + CUTE_HOST_DEVICE constexpr + GmmaDescriptor& 
operator=(GmmaDescriptor && t) noexcept { + desc_ = t.desc_; + return *this; + } + + uint64_t desc_; + uint32_t reg32_[2]; + uint16_t reg16_[4]; + + // Bitfield implementation avoids the need for shifts in assignment + struct { + // start_address, bit [0,14), 4LSB not included + uint16_t start_address_ : 14, : 2; // 14 bits [0,14), 2 bits unused + // leading dimension byte offset, bit [16,30), 4LSB not included + // For N: This is the stride from the first col to the second col of the 8x2 brick in INTERLEAVED + // Unused for all SWIZZLE_* layouts (and assumed to be 1) + // For T: This is the stride from the first 8 rows to the next 8 rows. + uint16_t leading_byte_offset_ : 14, : 2; // 14 bits [0,14), 2 bits unused + // stride dimension byte offset, bit [32,46), 4LSB not included + // For N: This is the stride from the first 8 rows to the next 8 rows. + // For T: This is the stride from the first 8 cols to the next 8 cols. + uint16_t stride_byte_offset_ : 14, : 2; // 14 bits [0,14), 2 bits unused + // base_offset, bit [49,52) + // Valid only for SWIZZLE_128B and SWIZZLE_64B + uint8_t : 1, base_offset_ : 3, : 4; // 1 bit unused, 3 bits [1,4), 4 bits unused + // layout type, bit [62,64) + // SWIZZLE_NONE = 0, SWIZZLE_32B = 3, SWIZZLE_64B = 2, SWIZZLE_128B = 1 + uint8_t : 6, layout_type_ : 2; // 6 bits unused, 2 bits [6,8) + } bitfield; + + // Decay to a uint64_t + CUTE_HOST_DEVICE constexpr + operator uint64_t() const noexcept { return desc_; } +}; + +// Printer +CUTE_HOST_DEVICE void +print(GmmaDescriptor const& t) +{ +#if !defined(__CUDACC_RTC__) + printf("GmmaDescriptor: 0x%016llx\n", static_cast<unsigned long long>(t.desc_)); + printf(" start_addr : 0x%04x\n", t.bitfield.start_address_); + printf(" leading_off: 0x%04x (%d)\n", t.bitfield.leading_byte_offset_, t.bitfield.leading_byte_offset_); + printf(" stride_off : 0x%04x (%d)\n", t.bitfield.stride_byte_offset_, t.bitfield.stride_byte_offset_); + printf(" base_offset: 0x%01x\n", t.bitfield.base_offset_); + printf(" layout_type: 0x%01x (%s)\n", t.bitfield.layout_type_, to_string(static_cast<SM90::GMMA::LayoutType>(t.bitfield.layout_type_))); +#endif // !defined(__CUDACC_RTC__) +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cute + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/arch/mma_sm90_gmma.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/arch/mma_sm90_gmma.hpp new file mode 100644 index 0000000000000000000000000000000000000000..c0ce4a2e7198ac59c9c6a013a2b75a352b50a5b6 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/arch/mma_sm90_gmma.hpp @@ -0,0 +1,20974 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include <cute/config.hpp> // CUTE_HOST_DEVICE + +#include "cutlass/arch/synclog.hpp" + +// Config +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && defined(__CUDA_ARCH_FEAT_SM90_ALL)) +# define CUTE_ARCH_MMA_SM90A_ENABLED +#endif + +namespace cute { + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Warpgroup sync primitives + +CUTE_HOST_DEVICE +void +warpgroup_arrive() +{ +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_warpgroup_arrive(__LINE__); + asm volatile ("wgmma.fence.sync.aligned;\n" ::: "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use wgmma.fence without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif +} + +template <int N> +CUTE_HOST_DEVICE +void +warpgroup_wait() +{ + static_assert(N >= 0 && N <= 7, "WGMMA wait: N must be in range [0, 7]"); +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_warpgroup_wait(__LINE__, N); + asm volatile("wgmma.wait_group.sync.aligned %0;\n" :: "n"(N) : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use wgmma.wait_group without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif +} + +// Marks the commit point for one or more sized batches of warpgroup MMAs. +CUTE_HOST_DEVICE +void +warpgroup_commit_batch() +{ +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_warpgroup_commit_batch(__LINE__); + asm volatile("wgmma.commit_group.sync.aligned;\n" ::: "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use wgmma.commit_group without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif +} + +CUTE_HOST_DEVICE +void +warpgroup_fence_operand(uint32_t& reg) { + // MSVC emits a build error for 'asm volatile' + // even if it only occurs in a __device__ function. + // This prevents the error. 
+#if defined(__CUDA_ARCH__) + asm volatile("" : "+r"(reg) :: "memory"); +#endif +} + +CUTE_HOST_DEVICE +void +warpgroup_fence_operand(float& reg) { +#if defined(__CUDA_ARCH__) + asm volatile("" : "+f"(reg) :: "memory"); +#endif +} + +namespace SM90::GMMA { + +enum class Major { + K = 0, + MN = 1 +}; + +enum class ScaleOut { + Zero = 0, + One = 1 +}; + +enum class ScaleIn { + Neg = -1, + One = 1 +}; + +enum class SparseSel { + Zero = 0, + One = 1 +}; + + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// GMMA PTX definitions: C = (scaleA * A) * (scaleB * B) + (scaleD * C) +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x8x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %4, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k16.f16.f16.f16 " + "{%0, %1}," + " %2," + " %3," + " p, %5, %6, %7, %8;\n" + "}\n" + : "+r"(d0), "+r"(d1) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x8x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[2]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %7, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k16.f16.f16.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + " %6," + " p, %8, %9, %10;\n" + "}\n" + : "+r"(d0), "+r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct 
MMA_64x16x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k16.f16.f16.f16 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " p, %7, %8, %9, %10;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x16x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k16.f16.f16.f16 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " p, %10, %11, %12;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x32x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " p, %11, %12, %13, %14;\n" + "}\n" + : "+r"(d0), "+r"(d1), 
"+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x32x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " p, %14, %15, %16;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x64x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " p, %19, %20, %21, %22;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use 
MMA_64x64x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x64x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " p, %22, %23, %24;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x96x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %26, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " p, %27, %28, %29, %30;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), 
"+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x96x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %29, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " p, %30, %31, %32;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x128x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t 
& d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " p, %35, %36, %37, %38;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x128x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " p, %38, %39, %40;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + 
CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x192x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " p, %51, %52, %53, %54;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x192x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, 
uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " p, %54, %55, %56;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x256x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, 
uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " p, %67, %68, %69, %70;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x256x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & 
d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " p, %70, %71, %72;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x8x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " p, %7, %8, %9, %10;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = 
GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x8x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " p, %10, %11, %12;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x16x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " p, %11, %12, %13, %14;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x16x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " p, %14, %15, %16;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x32x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " p, %19, %20, %21, %22;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x32x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, 
%6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " p, %22, %23, %24;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x64x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " p, %35, %36, %37, %38;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x64x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + 
float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " p, %38, %39, %40;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x96x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " p, %51, %52, %53, %54;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), 
+ "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x96x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " p, %54, %55, %56;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x16 F32+=F16*F16 
+template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x128x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " p, %67, %68, %69, %70;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x128x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + 
CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " p, %70, %71, %72;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x192x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float 
& d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %98, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " p, %99, %100, %101, %102;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = 
GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x192x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %101, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " p, %102, %103, %104;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), 
"+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x256x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %130, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, 
%58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " p, %131, %132, %133, %134;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x256x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & 
d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %133, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " p, %134, %135, %136;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), 
"+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x8x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " p, %7, %8, %9, %10;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x8x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " p, %10, %11, %12;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x16_F32BF16BF16_RS without 
CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x16x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " p, %11, %12, %13, %14;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x16x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " p, %14, %15, %16;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x32x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, 
+ float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " p, %19, %20, %21, %22;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x32x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " p, %22, %23, %24;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x64x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, 
+ float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " p, %35, %36, %37, %38;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x64x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " p, %38, %39, %40;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), 
"+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x96x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " p, %51, %52, %53, %54;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x96x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + 
CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " p, %54, %55, %56;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x128x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float 
& d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " p, %67, %68, %69, %70;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x128x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " p, %70, %71, %72;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x192x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if 
defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %98, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " p, %99, %100, %101, %102;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x192x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & 
d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %101, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " p, %102, %103, %104;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x256x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & 
d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %130, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " p, %131, %132, %133, %134;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), 
"+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x256x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, 
float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %133, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " p, %134, %135, %136;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = 
GMMA::ScaleIn::One +> +struct MMA_64x8x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k8.f32.tf32.tf32 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " p, %7, %8;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x8x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k8.f32.tf32.tf32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " p, %10, %11;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x16x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " p, %11, %12;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x8_F32TF32TF32_SS_TN without 
CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x16x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " p, %14, %15;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x32x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " p, %19, %20;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x32x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& 
a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " p, %22, %23;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x64x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " p, %35, %36;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x64x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = 
uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " p, %38, %39;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x96x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " 
%48," + " %49," + " p, %51, %52;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x96x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " p, %54, %55;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x128x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " p, %67, %68;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x128x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, 
uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " p, %70, %71;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x192x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, 
float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %98, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " p, %99, %100;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x192x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = 
float[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %101, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " p, %102, %103;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), 
"n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x256x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %130, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " 
%120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " p, %131, %132;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x256x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & 
d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %133, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " p, %134, %135;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "r"(a000), 
"r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN S32+=S8*S8 +struct MMA_64x8x32_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.s8 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN S32+=S8*S8 +struct MMA_64x8x32_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3}," + " %4," + " %5," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN S32+=S8*S8 +struct MMA_64x16x32_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN S32+=S8*S8 +struct MMA_64x16x32_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN S32+=S8*S8 +struct MMA_64x32x32_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN S32+=S8*S8 +struct MMA_64x32x32_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 
p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN S32+=S8*S8 +struct MMA_64x64x32_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN S32+=S8*S8 +struct MMA_64x64x32_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = 
GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN S32+=S8*S8 +struct MMA_64x96x32_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32S8S8_SS_TN without 
CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN S32+=S8*S8 +struct MMA_64x96x32_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN S32+=S8*S8 +struct MMA_64x128x32_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & 
d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN S32+=S8*S8 +struct MMA_64x128x32_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & 
d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN S32+=S8*S8 +struct MMA_64x192x32_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + 
uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %98, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN S32+=S8*S8 +struct MMA_64x192x32_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & 
d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %98, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN S32+=S8*S8 +struct MMA_64x256x32_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, 
+ uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %130, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), 
"+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN S32+=S8*S8 +struct MMA_64x256x32_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & 
d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %130, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN S32+=S8*S8 +struct MMA_64x8x32_S32S8S8_RS_TN +{ + using 
DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.s8 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN S32+=S8*S8 +struct MMA_64x8x32_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN S32+=S8*S8 +struct MMA_64x16x32_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN S32+=S8*S8 +struct MMA_64x16x32_S32S8S8_RS_TN_SATURATE +{ + using 
DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN S32+=S8*S8 +struct MMA_64x32x32_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN S32+=S8*S8 +struct MMA_64x32x32_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + 
"{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN S32+=S8*S8 +struct MMA_64x64x32_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN S32+=S8*S8 +struct MMA_64x64x32_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & 
d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN S32+=S8*S8 +struct MMA_64x96x32_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), 
"+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN S32+=S8*S8 +struct MMA_64x96x32_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN S32+=S8*S8 +struct MMA_64x128x32_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, 
uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN S32+=S8*S8 +struct MMA_64x128x32_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, 
uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN S32+=S8*S8 +struct MMA_64x192x32_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & 
d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %101, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN S32+=S8*S8 +struct MMA_64x192x32_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t 
& d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %101, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), 
"+r"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN S32+=S8*S8 +struct MMA_64x256x32_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %133, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " 
%56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN S32+=S8*S8 +struct MMA_64x256x32_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, 
uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %133, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), 
+ "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN S32+=S8*U8 +struct MMA_64x8x32_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.u8 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN S32+=S8*U8 +struct MMA_64x8x32_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3}," + " %4," + " %5," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN S32+=S8*U8 +struct MMA_64x16x32_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = 
GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN S32+=S8*U8 +struct MMA_64x16x32_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN S32+=S8*U8 +struct MMA_64x32x32_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN S32+=S8*U8 +struct MMA_64x32x32_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = 
uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN S32+=S8*U8 +struct MMA_64x64x32_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN S32+=S8*U8 +struct MMA_64x64x32_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE 
static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN S32+=S8*U8 +struct MMA_64x96x32_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), 
"+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN S32+=S8*U8 +struct MMA_64x96x32_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN S32+=S8*U8 +struct MMA_64x128x32_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + 
CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN S32+=S8*U8 +struct MMA_64x128x32_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & 
d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN S32+=S8*U8 +struct MMA_64x192x32_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, 
uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %98, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN S32+=S8*U8 +struct MMA_64x192x32_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t 
& d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %98, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32S8U8_SS_TN_SATURATE without 
CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN S32+=S8*U8 +struct MMA_64x256x32_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %130, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, 
%100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN S32+=S8*U8 +struct MMA_64x256x32_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + 
uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %130, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), 
"+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN S32+=S8*U8 +struct MMA_64x8x32_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.u8 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN S32+=S8*U8 +struct MMA_64x8x32_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN S32+=S8*U8 +struct MMA_64x16x32_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + 
"setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN S32+=S8*U8 +struct MMA_64x16x32_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN S32+=S8*U8 +struct MMA_64x32x32_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN S32+=S8*U8 +struct MMA_64x32x32_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = 
uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN S32+=S8*U8 +struct MMA_64x64x32_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 
64x64x32 TN S32+=S8*U8 +struct MMA_64x64x32_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN S32+=S8*U8 +struct MMA_64x96x32_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.s32.s8.u8 " + "{%0, %1, 
%2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN S32+=S8*U8 +struct MMA_64x96x32_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), 
"r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN S32+=S8*U8 +struct MMA_64x128x32_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN S32+=S8*U8 +struct 
MMA_64x128x32_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN S32+=S8*U8 +struct MMA_64x192x32_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, 
uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %101, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : 
"r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN S32+=S8*U8 +struct MMA_64x192x32_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %101, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), 
"+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN S32+=S8*U8 +struct MMA_64x256x32_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, 
uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %133, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN S32+=S8*U8 +struct MMA_64x256x32_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & 
d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %133, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), 
"+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN S32+=U8*S8 +struct MMA_64x8x32_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.s8 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN S32+=U8*S8 +struct MMA_64x8x32_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3}," + " %4," + " %5," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + 
"r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN S32+=U8*S8 +struct MMA_64x16x32_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN S32+=U8*S8 +struct MMA_64x16x32_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN S32+=U8*S8 +struct MMA_64x32x32_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, 
%11, %12, %13, %14, %15}," + " %16," + " %17," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN S32+=U8*S8 +struct MMA_64x32x32_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN S32+=U8*S8 +struct MMA_64x64x32_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + 
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN S32+=U8*S8 +struct MMA_64x64x32_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN S32+=U8*S8 +struct MMA_64x96x32_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & 
d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN S32+=U8*S8 +struct MMA_64x96x32_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), 
"+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN S32+=U8*S8 +struct MMA_64x128x32_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN S32+=U8*S8 +struct MMA_64x128x32_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN S32+=U8*S8 +struct MMA_64x192x32_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, 
uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %98, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "l"(desc_a), + "l"(desc_b), + 
"r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN S32+=U8*S8 +struct MMA_64x192x32_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %98, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + 
"+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN S32+=U8*S8 +struct MMA_64x256x32_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %130, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN S32+=U8*S8 +struct MMA_64x256x32_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + 
uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %130, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + 
"+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN S32+=U8*S8 +struct MMA_64x8x32_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.s8 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN S32+=U8*S8 +struct MMA_64x8x32_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN S32+=U8*S8 +struct MMA_64x16x32_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN S32+=U8*S8 +struct MMA_64x16x32_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN S32+=U8*S8 +struct MMA_64x32x32_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.s32.u8.s8 " + 
"{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN S32+=U8*S8 +struct MMA_64x32x32_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN S32+=U8*S8 +struct MMA_64x64x32_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, 
%29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN S32+=U8*S8 +struct MMA_64x64x32_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN S32+=U8*S8 +struct MMA_64x96x32_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, 
uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN S32+=U8*S8 +struct MMA_64x96x32_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " 
%24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN S32+=U8*S8 +struct MMA_64x128x32_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), 
"+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN S32+=U8*S8 +struct MMA_64x128x32_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), 
"+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN S32+=U8*S8 +struct MMA_64x192x32_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %101, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + 
"+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN S32+=U8*S8 +struct MMA_64x192x32_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %101, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, 
%44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN S32+=U8*S8 +struct MMA_64x256x32_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + 
uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %133, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : 
"r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN S32+=U8*S8 +struct MMA_64x256x32_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %133, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " 
%56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN S32+=U8*U8 +struct MMA_64x8x32_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.u8 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN S32+=U8*U8 +struct 
MMA_64x8x32_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3}," + " %4," + " %5," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN S32+=U8*U8 +struct MMA_64x16x32_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN S32+=U8*U8 +struct MMA_64x16x32_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN S32+=U8*U8 +struct MMA_64x32x32_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters 
= uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN S32+=U8*U8 +struct MMA_64x32x32_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN S32+=U8*U8 +struct MMA_64x64x32_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + 
uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN S32+=U8*U8 +struct MMA_64x64x32_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN S32+=U8*U8 +struct MMA_64x96x32_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& 
desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN S32+=U8*U8 +struct MMA_64x96x32_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg 
.pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN S32+=U8*U8 +struct MMA_64x128x32_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + 
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN S32+=U8*U8 +struct MMA_64x128x32_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), 
"+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN S32+=U8*U8 +struct MMA_64x192x32_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %98, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), 
"+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN S32+=U8*U8 +struct MMA_64x192x32_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %98, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, 
%68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN S32+=U8*U8 +struct MMA_64x256x32_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, 
uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %130, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
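+// Editorial note (not in upstream CUTLASS, added for orientation): the *_SS_TN atoms above
+// source both A and B from shared memory via 64-bit matrix descriptors (desc_a, desc_b) and
+// keep the s32 accumulator in per-thread uint32_t registers (CRegisters). The predicate built
+// from scale_D selects between overwrite (ScaleOut::Zero -> D = A*B) and accumulate
+// (ScaleOut::One -> D = A*B + D). These structs are typically consumed as operation tags
+// through cute::MMA_Atom / cute::TiledMMA rather than by calling fma() by hand; any usage
+// sketched in comments here is illustrative only and not part of this header.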
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN S32+=U8*U8 +struct MMA_64x256x32_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %130, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " 
%104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN S32+=U8*U8 +struct MMA_64x8x32_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.u8 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN S32+=U8*U8 +struct MMA_64x8x32_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t 
const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN S32+=U8*U8 +struct MMA_64x16x32_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN S32+=U8*U8 +struct MMA_64x16x32_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN S32+=U8*U8 +struct MMA_64x32x32_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + 
using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN S32+=U8*U8 +struct MMA_64x32x32_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN S32+=U8*U8 +struct MMA_64x64x32_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t 
& d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN S32+=U8*U8 +struct MMA_64x64x32_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
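+// Editorial note (not in upstream CUTLASS, added for orientation): the *_RS_TN variants differ
+// from the *_SS_TN atoms only in the A operand, which is read from four 32-bit registers
+// (a0..a3, the warpgroup's 64x32 u8 A fragment) instead of a shared-memory descriptor; B still
+// comes through desc_b. The *_SATURATE forms add the .satfinite qualifier so the s32
+// accumulator clamps at the integer limits on overflow instead of wrapping.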
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN S32+=U8*U8 +struct MMA_64x96x32_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN S32+=U8*U8 +struct MMA_64x96x32_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & 
d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN S32+=U8*U8 +struct MMA_64x128x32_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.s32.u8.u8 " + "{%0, 
%1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN S32+=U8*U8 +struct MMA_64x128x32_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, 
%65, %66, %67}," + " %68," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN S32+=U8*U8 +struct MMA_64x192x32_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %101, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, 
%38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN S32+=U8*U8 +struct MMA_64x192x32_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, 
uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %101, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN S32+=U8*U8 +struct MMA_64x256x32_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & 
d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %133, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), 
+ "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN S32+=U8*U8 +struct MMA_64x256x32_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + 
uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %133, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x8x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint64_t 
const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %4, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.f16.e4m3.e4m3 " + "{%0, %1}," + " %2," + " %3," + " p, %5, %6;\n" + "}\n" + : "+r"(d0), "+r"(d1) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x8x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %7, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.f16.e4m3.e4m3 " + "{%0, %1}," + "{%2, %3, %4, %5}," + " %6," + " p, %8, %9;\n" + "}\n" + : "+r"(d0), "+r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x8x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " p, %7, %8;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x8x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; 
+ + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " p, %10, %11;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x16x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " p, %7, %8;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x16x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " p, %10, %11;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN F32+=E4M3*E4M3 
+template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x16x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " p, %11, %12;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x16x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " p, %14, %15;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x32x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + 
" %8," + " %9," + " p, %11, %12;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x32x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " p, %14, %15;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x32x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " p, %19, %20;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn 
scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x32x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " p, %22, %23;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x64x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " p, %19, %20;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x64x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; 
+ + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " p, %22, %23;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x64x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " p, %35, %36;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn 
scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x64x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " p, %38, %39;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x96x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %26, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " p, %27, %28;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), 
"+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x96x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %29, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " p, %30, %31;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x96x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if 
defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " p, %51, %52;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x96x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " p, %54, %55;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), 
"+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x128x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " p, %35, %36;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x128x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & 
d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " p, %38, %39;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x128x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " p, %67, %68;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), 
"+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x128x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " p, %70, %71;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), 
"+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x192x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " p, %51, %52;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x192x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using 
BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " p, %54, %55;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x192x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + 
float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %98, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " p, %99, %100;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x192x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, 
float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %101, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " p, %102, %103;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x256x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " p, %67, %68;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x256x32_F16E4M3E4M3_RS_TN +{ + using 
DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " p, %70, %71;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x256x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& 
desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %130, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " p, %131, %132;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), 
"+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x256x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float 
& d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %133, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " p, %134, %135;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x8x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[2]; + + 
CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %4, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.f16.e4m3.e5m2 " + "{%0, %1}," + " %2," + " %3," + " p, %5, %6;\n" + "}\n" + : "+r"(d0), "+r"(d1) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x8x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %7, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.f16.e4m3.e5m2 " + "{%0, %1}," + "{%2, %3, %4, %5}," + " %6," + " p, %8, %9;\n" + "}\n" + : "+r"(d0), "+r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x8x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " p, %7, %8;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x8x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = 
uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " p, %10, %11;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x16x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " p, %7, %8;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x16x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " p, %10, %11;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + 
+// GMMA 64x16x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x16x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " p, %11, %12;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x16x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " p, %14, %15;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x32x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.f16.e4m3.e5m2 " + 
"{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " p, %11, %12;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x32x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " p, %14, %15;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x32x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " p, %19, %20;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN 
F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x32x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " p, %22, %23;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x64x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " p, %19, %20;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x64x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = 
uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " p, %22, %23;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x64x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " p, %35, %36;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 
TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x64x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " p, %38, %39;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x96x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %26, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " p, %27, %28;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), 
"+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x96x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %29, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " p, %30, %31;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x96x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + GMMA::ScaleOut 
const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " p, %51, %52;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x96x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " p, %54, %55;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), 
"+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x128x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " p, %35, %36;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x128x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, 
uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " p, %38, %39;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x128x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " p, %67, %68;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), 
"+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x128x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " p, %70, %71;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), 
"+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x192x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " p, %51, %52;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x192x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; 
+ using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " p, %54, %55;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x192x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, 
float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %98, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " p, %99, %100;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x192x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + 
float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %101, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " p, %102, %103;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F32E4M3E5M2_RS_TN without 
CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x256x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " p, %67, %68;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> 
+struct MMA_64x256x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " p, %70, %71;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x256x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void 
+ fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %130, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " p, %131, %132;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), 
"+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x256x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, 
float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %133, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " p, %134, %135;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x8x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = 
uint64_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %4, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.f16.e5m2.e4m3 " + "{%0, %1}," + " %2," + " %3," + " p, %5, %6;\n" + "}\n" + : "+r"(d0), "+r"(d1) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x8x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %7, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.f16.e5m2.e4m3 " + "{%0, %1}," + "{%2, %3, %4, %5}," + " %6," + " p, %8, %9;\n" + "}\n" + : "+r"(d0), "+r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x8x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " p, %7, %8;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x8x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + 
using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " p, %10, %11;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x16x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " p, %7, %8;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x16x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " p, %10, %11;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
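+// Note on the atoms above and below (descriptive comment, inferred from the code
+// in this header): the *_SS_TN variants read both A and B through 64-bit
+// shared-memory matrix descriptors (desc_a/desc_b), while the *_RS_TN variants
+// take the A fragment as four 32-bit registers and only B via a descriptor. In
+// every variant the runtime scale_D argument drives the `setp.ne.b32 p, ...`
+// predicate handed to wgmma, so ScaleOut::Zero effectively computes D = A*B for
+// the first MMA of a K-loop while ScaleOut::One accumulates D = A*B + D;
+// scaleA/scaleB are compile-time immediates ("n" constraints) forwarded as the
+// trailing instruction operands.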
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x16x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " p, %11, %12;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x16x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " p, %14, %15;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x32x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg 
.pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " p, %11, %12;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x32x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " p, %14, %15;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x32x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " p, %19, %20;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x32x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " p, %22, %23;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x64x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " p, %19, %20;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> 
+struct MMA_64x64x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " p, %22, %23;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x64x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " p, %35, %36;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
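+// Minimal usage sketch (illustrative only; assumes CuTe's MMA_Atom /
+// make_tiled_mma wrappers rather than calling fma() by hand, and the tensor
+// names tCrA/tCrB/tCrC are hypothetical placeholders from the caller's own
+// partitioning):
+//
+//   using Atom = cute::MMA_Atom<MMA_64x64x32_F32E5M2E4M3_SS_TN<>>;
+//   auto tiled_mma = cute::make_tiled_mma(Atom{});
+//   // ... partition the shared-memory/register tensors with tiled_mma, then:
+//   cute::gemm(tiled_mma, tCrA, tCrB, tCrC);  // lowers to the wgmma atom above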
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x64x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " p, %38, %39;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x96x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %26, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " p, %27, %28;\n" + "}\n" + : 
"+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x96x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %29, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " p, %30, %31;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x96x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, 
+ float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " p, %51, %52;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x96x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " p, %54, %55;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + 
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x128x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " p, %35, %36;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x128x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t 
& d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " p, %38, %39;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x128x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, 
%63}," + " %64," + " %65," + " p, %67, %68;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x128x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " p, %70, %71;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + 
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x192x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " p, %51, %52;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + 
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x192x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " p, %54, %55;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x192x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float 
& d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %98, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " p, %99, %100;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x192x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + 
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %101, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " p, %102, %103;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); 
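+    // Editorial note (annotation only, not part of the upstream CUTLASS source): in the wgmma
+    // inline asm above, the 96 "+f" constraints are the float accumulator fragment held in
+    // registers (read-modify-write), the four "r" inputs are the A fragment in registers (the RS
+    // variant), and the "l" input is the 64-bit shared-memory matrix descriptor for B. The
+    // predicate p, derived from scale_D, selects whether the existing accumulator is added to
+    // (ScaleOut::One) or overwritten (ScaleOut::Zero), and the trailing "n" immediates encode the
+    // compile-time scaleA/scaleB template parameters.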
+#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x256x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " p, %67, %68;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN F16+=E5M2*E4M3 +template < + 
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x256x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " p, %70, %71;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x256x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using 
BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %130, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " p, %131, %132;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), 
"+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x256x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + 
float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %133, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " p, %134, %135;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct 
MMA_64x8x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %4, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.f16.e5m2.e5m2 " + "{%0, %1}," + " %2," + " %3," + " p, %5, %6;\n" + "}\n" + : "+r"(d0), "+r"(d1) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x8x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %7, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.f16.e5m2.e5m2 " + "{%0, %1}," + "{%2, %3, %4, %5}," + " %6," + " p, %8, %9;\n" + "}\n" + : "+r"(d0), "+r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x8x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " p, %7, %8;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + 
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x8x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " p, %10, %11;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x16x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " p, %7, %8;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x16x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " p, %10, %11;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use 
MMA_64x16x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x16x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " p, %11, %12;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x16x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " p, %14, %15;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x32x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
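+    // Editorial note (annotation only, not part of the upstream CUTLASS source): the SS variants
+    // such as this one take both A and B as 64-bit shared-memory matrix descriptors (desc_a and
+    // desc_b, bound via "l" constraints), whereas the corresponding RS variants read the A fragment
+    // from registers. When CUTE_ARCH_MMA_SM90A_ENABLED is not defined, the #else branch below routes
+    // any use of this atom to CUTE_INVALID_CONTROL_PATH, signalling that SM90a support was not
+    // compiled in.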
cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " p, %11, %12;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x32x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " p, %14, %15;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x32x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " p, %19, %20;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use 
MMA_64x32x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x32x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " p, %22, %23;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x64x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " p, %19, %20;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN F16+=E5M2*E5M2 +template < + 
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x64x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " p, %22, %23;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x64x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " p, %35, %36;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to 
use MMA_64x64x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x64x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " p, %38, %39;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x96x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %26, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, 
%17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " p, %27, %28;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x96x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %29, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " p, %30, %31;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x96x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, 
float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " p, %51, %52;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x96x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " p, %54, %55;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + 
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x128x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " p, %35, %36;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x128x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, 
uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " p, %38, %39;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x128x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, 
%47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " p, %67, %68;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x128x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " p, %70, %71;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + 
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x192x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " p, %51, %52;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 
GMMA 64x192x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x192x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " p, %54, %55;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x192x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & 
d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %98, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " p, %99, %100;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x192x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using 
BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %101, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " p, %102, %103;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "r"(a00), "r"(a01), "r"(a02), 
"r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x256x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " p, %67, %68;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x256x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " p, %70, %71;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + 
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x256x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %130, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " p, %131, %132;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), 
"+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x256x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & 
d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %133, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " p, %134, %135;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // 
namespace SM90::GMMA + +} // namespace cute + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +#include "mma_sm90_gmma_ext.hpp" +#endif diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/arch/mma_sm90_gmma_ext.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/arch/mma_sm90_gmma_ext.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f697b94912f843b55e86c86849a20eb1625cb097 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/arch/mma_sm90_gmma_ext.hpp @@ -0,0 +1,56445 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once + +#include // CUTE_HOST_DEVICE + +#include "cutlass/arch/synclog.hpp" + +// Config +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && defined(__CUDA_ARCH_FEAT_SM90_ALL)) +# define CUTE_ARCH_MMA_SM90A_ENABLED +#endif + +namespace cute { + +namespace SM90::GMMA { + +// GMMA 64x24x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x24x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[6]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5}," + " %6," + " %7," + " p, %9, %10, %11, %12;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x24x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[6]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5}," + "{%6, %7, %8, %9}," + " %10," + " p, %12, %13, %14;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x40x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x40x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; 
+ using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[10]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sync.aligned.m64n40k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9}," + " %10," + " %11," + " p, %13, %14, %15, %16;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x40x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x40x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x40x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[10]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sync.aligned.m64n40k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9}," + "{%10, %11, %12, %13}," + " %14," + " p, %16, %17, %18;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x40x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x48x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + 
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %14, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " p, %15, %16, %17, %18;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x48x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %17, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " p, %18, %19, %20;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x56x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x56x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[14]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + 
".reg .pred p;\n" + "setp.ne.b32 p, %16, 0;\n" + "wgmma.mma_async.sync.aligned.m64n56k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13}," + " %14," + " %15," + " p, %17, %18, %19, %20;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x56x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x56x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x56x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[14]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %19, 0;\n" + "wgmma.mma_async.sync.aligned.m64n56k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13}," + "{%14, %15, %16, %17}," + " %18," + " p, %20, %21, %22;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x56x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x72x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x72x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[18]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + 
"setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sync.aligned.m64n72k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17}," + " %18," + " %19," + " p, %21, %22, %23, %24;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x72x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x72x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x72x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[18]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sync.aligned.m64n72k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17}," + "{%18, %19, %20, %21}," + " %22," + " p, %24, %25, %26;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x72x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x80x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[20]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + 
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %22, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + " %20," + " %21," + " p, %23, %24, %25, %26;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x80x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[20]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %25, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + "{%20, %21, %22, %23}," + " %24," + " p, %26, %27, %28;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x88x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x88x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[22]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, 
uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %24, 0;\n" + "wgmma.mma_async.sync.aligned.m64n88k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21}," + " %22," + " %23," + " p, %25, %26, %27, %28;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x88x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x88x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x88x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[22]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %27, 0;\n" + "wgmma.mma_async.sync.aligned.m64n88k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21}," + "{%22, %23, %24, %25}," + " %26," + " p, %28, %29, %30;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x88x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 
GMMA 64x104x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x104x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[26]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sync.aligned.m64n104k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25}," + " %26," + " %27," + " p, %29, %30, %31, %32;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x104x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x104x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x104x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[26]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sync.aligned.m64n104k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25}," + "{%26, %27, %28, %29}," + " %30," + 
" p, %32, %33, %34;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x104x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x112x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[28]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %30, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + " %28," + " %29," + " p, %31, %32, %33, %34;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x112x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[28]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & 
d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %33, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + "{%28, %29, %30, %31}," + " %32," + " p, %34, %35, %36;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x120x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x120x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[30]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %32, 0;\n" + "wgmma.mma_async.sync.aligned.m64n120k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29}," + " %30," + " %31," + " p, %33, %34, %35, %36;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + 
CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x120x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x120x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x120x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[30]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %35, 0;\n" + "wgmma.mma_async.sync.aligned.m64n120k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29}," + "{%30, %31, %32, %33}," + " %34," + " p, %36, %37, %38;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x120x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x136x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x136x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[34]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, 
+ uint32_t & d32, uint32_t & d33, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sync.aligned.m64n136k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33}," + " %34," + " %35," + " p, %37, %38, %39, %40;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x136x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x136x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x136x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[34]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sync.aligned.m64n136k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33}," + "{%34, %35, %36, %37}," + " %38," + " p, %40, %41, %42;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), 
"n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x136x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x144x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[36]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %38, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + " %36," + " %37," + " p, %39, %40, %41, %42;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x144x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[36]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & 
d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %41, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + "{%36, %37, %38, %39}," + " %40," + " p, %42, %43, %44;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x152x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x152x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[38]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %40, 0;\n" + "wgmma.mma_async.sync.aligned.m64n152k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37}," + " %38," + " %39," + " p, %41, %42, %43, %44;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), 
"+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x152x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x152x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x152x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[38]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %43, 0;\n" + "wgmma.mma_async.sync.aligned.m64n152k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37}," + "{%38, %39, %40, %41}," + " %42," + " p, %44, %45, %46;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x152x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x160x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; 
+ using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %42, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " p, %43, %44, %45, %46;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x160x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if 
defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %45, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " p, %46, %47, %48;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x168x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x168x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[42]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sync.aligned.m64n168k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41}," + " %42," + " %43," + " p, %45, %46, %47, %48;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), 
"+r"(d39), + "+r"(d40), "+r"(d41) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x168x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x168x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x168x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[42]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sync.aligned.m64n168k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41}," + "{%42, %43, %44, %45}," + " %46," + " p, %48, %49, %50;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x168x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x176x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[44]; + + CUTE_HOST_DEVICE 
static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %46, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + " %44," + " %45," + " p, %47, %48, %49, %50;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x176x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[44]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + 
uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %49, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + "{%44, %45, %46, %47}," + " %48," + " p, %50, %51, %52;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x184x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x184x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[46]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %48, 0;\n" + "wgmma.mma_async.sync.aligned.m64n184k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45}," + " %46," + " %47," + " p, %49, %50, %51, %52;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), 
"+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x184x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x184x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x184x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[46]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %51, 0;\n" + "wgmma.mma_async.sync.aligned.m64n184k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45}," + "{%46, %47, %48, %49}," + " %50," + " p, %52, %53, %54;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x184x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x200x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x200x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[50]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sync.aligned.m64n200k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49}," + " %50," + " %51," + " p, %53, %54, %55, %56;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x200x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x200x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x200x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[50]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t 
const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sync.aligned.m64n200k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49}," + "{%50, %51, %52, %53}," + " %54," + " p, %56, %57, %58;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x200x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x208x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[52]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + 
uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %54, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + " %52," + " %53," + " p, %55, %56, %57, %58;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x208x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[52]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %57, 0;\n" + 
"wgmma.mma_async.sync.aligned.m64n208k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + "{%52, %53, %54, %55}," + " %56," + " p, %58, %59, %60;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x216x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x216x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[54]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %56, 0;\n" + "wgmma.mma_async.sync.aligned.m64n216k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53}," + " %54," + " %55," + " p, %57, %58, %59, %60;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), 
"+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x216x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x216x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x216x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[54]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %59, 0;\n" + "wgmma.mma_async.sync.aligned.m64n216k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53}," + "{%54, %55, %56, %57}," + " %58," + " p, %60, %61, %62;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), 
"+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x216x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x224x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %58, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " p, %59, %60, %61, %62;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x16 
F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x224x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %61, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " p, %62, %63, %64;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x232x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x232x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = 
uint32_t[58]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sync.aligned.m64n232k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57}," + " %58," + " %59," + " p, %61, %62, %63, %64;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x232x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x232x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x232x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[58]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, 
uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sync.aligned.m64n232k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57}," + "{%58, %59, %60, %61}," + " %62," + " p, %64, %65, %66;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x232x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x240x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[60]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & 
d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %62, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + " %60," + " %61," + " p, %63, %64, %65, %66;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x240x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[60]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, 
uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %65, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + "{%60, %61, %62, %63}," + " %64," + " p, %66, %67, %68;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x248x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x248x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[62]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if 
defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %64, 0;\n" + "wgmma.mma_async.sync.aligned.m64n248k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61}," + " %62," + " %63," + " p, %65, %66, %67, %68;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x248x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x248x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x248x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[62]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" 
+ ".reg .pred p;\n" + "setp.ne.b32 p, %67, 0;\n" + "wgmma.mma_async.sync.aligned.m64n248k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61}," + "{%62, %63, %64, %65}," + " %66," + " p, %68, %69, %70;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x248x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x24x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %14, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " p, %15, %16, %17, %18;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x24x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[12]; + + static_assert(tnspA == GMMA::Major::K, + 
"Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %17, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " p, %18, %19, %20;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x40x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x40x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[20]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %22, 0;\n" + "wgmma.mma_async.sync.aligned.m64n40k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + " %20," + " %21," + " p, %23, %24, %25, %26;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x40x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x40x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x40x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[20]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + 
CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %25, 0;\n" + "wgmma.mma_async.sync.aligned.m64n40k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + "{%20, %21, %22, %23}," + " %24," + " p, %26, %27, %28;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x40x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x48x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %26, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " p, %27, %28, %29, %30;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + 
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x48x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %29, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " p, %30, %31, %32;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x56x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x56x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[28]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %30, 0;\n" + "wgmma.mma_async.sync.aligned.m64n56k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + " %28," + " %29," + " p, %31, %32, %33, %34;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), 
"+f"(d25), "+f"(d26), "+f"(d27) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x56x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x56x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x56x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[28]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %33, 0;\n" + "wgmma.mma_async.sync.aligned.m64n56k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + "{%28, %29, %30, %31}," + " %32," + " p, %34, %35, %36;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x56x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x72x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x72x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[36]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + GMMA::ScaleOut const 
scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %38, 0;\n" + "wgmma.mma_async.sync.aligned.m64n72k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + " %36," + " %37," + " p, %39, %40, %41, %42;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x72x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x72x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x72x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[36]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %41, 0;\n" + "wgmma.mma_async.sync.aligned.m64n72k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + "{%36, %37, %38, %39}," + " %40," + " p, %42, %43, %44;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + 
CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x72x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x80x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %42, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " p, %43, %44, %45, %46;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x80x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & 
d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %45, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " p, %46, %47, %48;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x88x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x88x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[44]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %46, 0;\n" + "wgmma.mma_async.sync.aligned.m64n88k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + " %44," + " %45," + " p, %47, %48, %49, %50;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + 
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x88x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x88x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x88x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[44]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %49, 0;\n" + "wgmma.mma_async.sync.aligned.m64n88k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + "{%44, %45, %46, %47}," + " %48," + " p, %50, %51, %52;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x88x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x104x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x104x16_F32F16F16_SS +{ + using DRegisters = void; + using 
ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[52]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %54, 0;\n" + "wgmma.mma_async.sync.aligned.m64n104k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + " %52," + " %53," + " p, %55, %56, %57, %58;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x104x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x104x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x104x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[52]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float 
& d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %57, 0;\n" + "wgmma.mma_async.sync.aligned.m64n104k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + "{%52, %53, %54, %55}," + " %56," + " p, %58, %59, %60;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x104x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x112x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %58, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, 
%28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " p, %59, %60, %61, %62;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x112x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %61, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " p, %62, %63, %64;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + 
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x120x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x120x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[60]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %62, 0;\n" + "wgmma.mma_async.sync.aligned.m64n120k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + " %60," + " %61," + " p, %63, %64, %65, %66;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + 
CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x120x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x120x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x120x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[60]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %65, 0;\n" + "wgmma.mma_async.sync.aligned.m64n120k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + "{%60, %61, %62, %63}," + " %64," + " p, %66, %67, %68;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x120x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x136x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + 
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x136x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %70, 0;\n" + "wgmma.mma_async.sync.aligned.m64n136k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + " %68," + " %69," + " p, %71, %72, %73, %74;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x136x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x136x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x136x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must 
have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %73, 0;\n" + "wgmma.mma_async.sync.aligned.m64n136k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + "{%68, %69, %70, %71}," + " %72," + " p, %74, %75, %76;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x136x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x144x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + 
float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %74, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " p, %75, %76, %77, %78;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x144x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & 
d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %77, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " p, %78, %79, %80;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x152x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x152x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[76]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & 
d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %78, 0;\n" + "wgmma.mma_async.sync.aligned.m64n152k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75}," + " %76," + " %77," + " p, %79, %80, %81, %82;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x152x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x152x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x152x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[76]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + 
float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %81, 0;\n" + "wgmma.mma_async.sync.aligned.m64n152k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75}," + "{%76, %77, %78, %79}," + " %80," + " p, %82, %83, %84;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x152x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x160x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & 
d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %82, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " p, %83, %84, %85, %86;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x160x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t 
const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %85, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " p, %86, %87, %88;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x168x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x168x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; 
+ using CRegisters = float[84]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %86, 0;\n" + "wgmma.mma_async.sync.aligned.m64n168k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83}," + " %84," + " %85," + " p, %87, %88, %89, %90;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x168x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x168x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA 
= GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x168x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[84]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %89, 0;\n" + "wgmma.mma_async.sync.aligned.m64n168k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83}," + "{%84, %85, %86, %87}," + " %88," + " p, %90, %91, %92;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), 
"n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x168x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x176x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %90, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " p, %91, %92, %93, %94;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), 
"+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x176x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %93, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " p, %94, %95, %96;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), 
+ "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x184x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x184x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %94, 0;\n" + "wgmma.mma_async.sync.aligned.m64n184k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " 
+ " %88, %89, %90, %91}," + " %92," + " %93," + " p, %95, %96, %97, %98;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x184x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x184x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x184x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + 
".reg .pred p;\n" + "setp.ne.b32 p, %97, 0;\n" + "wgmma.mma_async.sync.aligned.m64n184k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91}," + "{%92, %93, %94, %95}," + " %96," + " p, %98, %99, %100;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x184x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x200x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x200x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[100]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, 
float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %102, 0;\n" + "wgmma.mma_async.sync.aligned.m64n200k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99}," + " %100," + " %101," + " p, %103, %104, %105, %106;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x200x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x200x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x200x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[100]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& 
a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %105, 0;\n" + "wgmma.mma_async.sync.aligned.m64n200k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99}," + "{%100, %101, %102, %103}," + " %104," + " p, %106, %107, %108;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), 
"+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x200x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x208x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %106, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " p, %107, %108, %109, %110;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), 
"+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x208x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + 
float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %109, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " p, %110, %111, %112;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x216x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x216x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[108]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & 
d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %110, 0;\n" + "wgmma.mma_async.sync.aligned.m64n216k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107}," + " %108," + " %109," + " p, %111, %112, %113, %114;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), 
"+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x216x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x216x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x216x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[108]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %113, 0;\n" + "wgmma.mma_async.sync.aligned.m64n216k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107}," + "{%108, 
%109, %110, %111}," + " %112," + " p, %114, %115, %116;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x216x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x224x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, 
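+      // SS variant: both A and B are sourced from shared memory through the 64-bit matrix
+      // descriptors desc_a / desc_b; only the accumulator fragment lives in registers.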
float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %114, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " p, %115, %116, %117, %118;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x224x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major 
layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %117, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " p, %118, %119, %120;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), 
"+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x232x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x232x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %118, 0;\n" + "wgmma.mma_async.sync.aligned.m64n232k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, 
%15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + " %116," + " %117," + " p, %119, %120, %121, %122;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x232x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x232x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x232x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, 
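+      // RS variant: the A fragment arrives in four 32-bit registers (a000..a003) and must be
+      // K-major (enforced by the static_assert above); B is still read via the SMEM descriptor desc_b.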
float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %121, 0;\n" + "wgmma.mma_async.sync.aligned.m64n232k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + "{%116, %117, %118, %119}," + " %120," + " p, %122, %123, %124;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), 
"+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x232x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x240x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %122, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " 
+ " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " p, %123, %124, %125, %126;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x240x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, 
float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %125, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " p, %126, %127, %128;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "r"(a000), "r"(a001), "r"(a002), 
"r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x248x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x248x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[124]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %126, 0;\n" + "wgmma.mma_async.sync.aligned.m64n248k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, 
%113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + " %124," + " %125," + " p, %127, %128, %129, %130;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x248x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x248x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x248x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[124]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & 
d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %129, 0;\n" + "wgmma.mma_async.sync.aligned.m64n248k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + "{%124, %125, %126, %127}," + " %128," + " p, %130, %131, %132;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), 
"+f"(d122), "+f"(d123) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x248x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x24x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %14, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " p, %15, %16, %17, %18;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x24x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[12]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %17, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " p, %18, %19, %20;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x16_F32BF16BF16_RS without 
CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x40x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x40x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[20]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %22, 0;\n" + "wgmma.mma_async.sync.aligned.m64n40k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + " %20," + " %21," + " p, %23, %24, %25, %26;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x40x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x40x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x40x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[20]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %25, 0;\n" + "wgmma.mma_async.sync.aligned.m64n40k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + "{%20, %21, %22, %23}," + " %24," + " p, %26, %27, %28;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "r"(a00), "r"(a01), "r"(a02), 
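+      // Constraint legend: "r" = 32-bit register (A fragment), "l" = 64-bit register (SMEM descriptor),
+      // "n" = compile-time immediate (scale / transpose flags).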
"r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x40x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x48x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %26, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " p, %27, %28, %29, %30;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x48x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %29, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, 
%14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " p, %30, %31, %32;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x56x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x56x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[28]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %30, 0;\n" + "wgmma.mma_async.sync.aligned.m64n56k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + " %28," + " %29," + " p, %31, %32, %33, %34;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x56x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x56x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x56x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[28]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, 
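+      // scale_D selects accumulate vs. overwrite: ScaleOut::One gives D = A*B + D,
+      // ScaleOut::Zero gives D = A*B; it is forwarded to the wgmma via predicate p.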
float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %33, 0;\n" + "wgmma.mma_async.sync.aligned.m64n56k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + "{%28, %29, %30, %31}," + " %32," + " p, %34, %35, %36;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x56x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x72x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x72x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[36]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %38, 0;\n" + "wgmma.mma_async.sync.aligned.m64n72k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + " %36," + " %37," + " p, %39, %40, %41, %42;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + 
CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x72x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x72x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x72x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[36]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %41, 0;\n" + "wgmma.mma_async.sync.aligned.m64n72k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + "{%36, %37, %38, %39}," + " %40," + " p, %42, %43, %44;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x72x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x80x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + 
float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %42, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " p, %43, %44, %45, %46;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x80x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %45, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " p, %46, %47, %48;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), 
"+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x88x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x88x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[44]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %46, 0;\n" + "wgmma.mma_async.sync.aligned.m64n88k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + " %44," + " %45," + " p, %47, %48, %49, %50;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x88x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x88x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x88x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[44]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + 
CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %49, 0;\n" + "wgmma.mma_async.sync.aligned.m64n88k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + "{%44, %45, %46, %47}," + " %48," + " p, %50, %51, %52;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x88x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x104x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x104x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[52]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, 
desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %54, 0;\n" + "wgmma.mma_async.sync.aligned.m64n104k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + " %52," + " %53," + " p, %55, %56, %57, %58;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x104x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x104x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x104x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[52]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %57, 0;\n" + "wgmma.mma_async.sync.aligned.m64n104k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + "{%52, %53, %54, %55}," + " %56," + " p, %58, %59, %60;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), 
"+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x104x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x112x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %58, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " p, %59, %60, %61, %62;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + 
CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x112x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %61, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " p, %62, %63, %64;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x120x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x120x16_F32BF16BF16_SS +{ + using DRegisters = void; + 
using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[60]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %62, 0;\n" + "wgmma.mma_async.sync.aligned.m64n120k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + " %60," + " %61," + " p, %63, %64, %65, %66;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x120x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x120x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x120x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[60]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, 
float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %65, 0;\n" + "wgmma.mma_async.sync.aligned.m64n120k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + "{%60, %61, %62, %63}," + " %64," + " p, %66, %67, %68;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x120x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x136x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x136x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & 
d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %70, 0;\n" + "wgmma.mma_async.sync.aligned.m64n136k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + " %68," + " %69," + " p, %71, %72, %73, %74;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x136x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x136x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x136x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float 
& d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %73, 0;\n" + "wgmma.mma_async.sync.aligned.m64n136k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + "{%68, %69, %70, %71}," + " %72," + " p, %74, %75, %76;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x136x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x144x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %74, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " p, %75, %76, %77, %78;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x144x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %77, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " p, %78, %79, %80;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x152x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x152x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[76]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + 
".reg .pred p;\n" + "setp.ne.b32 p, %78, 0;\n" + "wgmma.mma_async.sync.aligned.m64n152k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75}," + " %76," + " %77," + " p, %79, %80, %81, %82;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x152x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x152x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x152x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[76]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if 
defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %81, 0;\n" + "wgmma.mma_async.sync.aligned.m64n152k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75}," + "{%76, %77, %78, %79}," + " %80," + " p, %82, %83, %84;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x152x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x160x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + GMMA::ScaleOut const scale_D = 
GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %82, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " p, %83, %84, %85, %86;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x160x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + 
float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %85, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " p, %86, %87, %88;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x168x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x168x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[84]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, 
float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %86, 0;\n" + "wgmma.mma_async.sync.aligned.m64n168k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83}," + " %84," + " %85," + " p, %87, %88, %89, %90;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x168x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x168x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x168x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[84]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & 
d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %89, 0;\n" + "wgmma.mma_async.sync.aligned.m64n168k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83}," + "{%84, %85, %86, %87}," + " %88," + " p, %90, %91, %92;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x168x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x176x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float 
& d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %90, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " p, %91, %92, %93, %94;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x176x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using 
BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %93, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " p, %94, %95, %96;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + 
CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x184x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x184x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %94, 0;\n" + "wgmma.mma_async.sync.aligned.m64n184k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91}," + " %92," + " %93," + " p, %95, %96, %97, %98;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), 
"+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x184x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x184x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x184x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %97, 0;\n" + "wgmma.mma_async.sync.aligned.m64n184k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91}," + "{%92, %93, %94, %95}," + " %96," + " p, %98, %99, %100;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), 
"+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x184x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x200x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x200x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[100]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %102, 0;\n" + "wgmma.mma_async.sync.aligned.m64n200k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " 
+ " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99}," + " %100," + " %101," + " p, %103, %104, %105, %106;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x200x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x200x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x200x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[100]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & 
d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %105, 0;\n" + "wgmma.mma_async.sync.aligned.m64n200k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99}," + "{%100, %101, %102, %103}," + " %104," + " p, %106, %107, %108;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x200x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x208x16_F32BF16BF16_SS +{ + using 
DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %106, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " p, %107, %108, %109, %110;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), 
"+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x208x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %109, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, 
%57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " p, %110, %111, %112;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x216x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x216x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[108]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & 
d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %110, 0;\n" + "wgmma.mma_async.sync.aligned.m64n216k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107}," + " %108," + " %109," + " p, %111, %112, %113, %114;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x216x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x216x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x216x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using 
CRegisters = float[108]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %113, 0;\n" + "wgmma.mma_async.sync.aligned.m64n216k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107}," + "{%108, %109, %110, %111}," + " %112," + " p, %114, %115, %116;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + 
"+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x216x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x224x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %114, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, 
%26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " p, %115, %116, %117, %118;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x224x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, 
float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %117, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " p, %118, %119, %120;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), 
"n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x232x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x232x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %118, 0;\n" + "wgmma.mma_async.sync.aligned.m64n232k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + " %116," + " %117," + " p, %119, %120, %121, %122;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), 
"+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x232x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x232x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x232x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, 
float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %121, 0;\n" + "wgmma.mma_async.sync.aligned.m64n232k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + "{%116, %117, %118, %119}," + " %120," + " p, %122, %123, %124;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x232x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = 
GMMA::ScaleIn::One +> +struct MMA_64x240x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %122, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " p, %123, %124, %125, %126;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), 
"+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x240x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float 
& d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %125, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " p, %126, %127, %128;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x248x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x248x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[124]; + + CUTE_HOST_DEVICE 
static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %126, 0;\n" + "wgmma.mma_async.sync.aligned.m64n248k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + " %124," + " %125," + " p, %127, %128, %129, %130;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), 
"+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x248x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x248x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x248x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[124]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & 
d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %129, 0;\n" + "wgmma.mma_async.sync.aligned.m64n248k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + "{%124, %125, %126, %127}," + " %128," + " p, %130, %131, %132;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x248x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x24x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters 
= float[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %14, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " p, %15, %16;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x24x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %17, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " p, %18, %19;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x40x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x40x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[20]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + 
asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %22, 0;\n" + "wgmma.mma_async.sync.aligned.m64n40k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + " %20," + " %21," + " p, %23, %24;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x40x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x40x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x40x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[20]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %25, 0;\n" + "wgmma.mma_async.sync.aligned.m64n40k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + "{%20, %21, %22, %23}," + " %24," + " p, %26, %27;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x40x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x48x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + 
asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %26, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " p, %27, %28;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x48x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %29, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " p, %30, %31;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x56x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x56x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[28]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, 
float & d23, + float & d24, float & d25, float & d26, float & d27, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %30, 0;\n" + "wgmma.mma_async.sync.aligned.m64n56k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + " %28," + " %29," + " p, %31, %32;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x56x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x56x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x56x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[28]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %33, 0;\n" + "wgmma.mma_async.sync.aligned.m64n56k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + "{%28, %29, %30, %31}," + " %32," + " p, %34, %35;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x56x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x72x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x72x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = 
uint64_t[1]; + using CRegisters = float[36]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %38, 0;\n" + "wgmma.mma_async.sync.aligned.m64n72k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + " %36," + " %37," + " p, %39, %40;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x72x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x72x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x72x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[36]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %41, 0;\n" + "wgmma.mma_async.sync.aligned.m64n72k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + "{%36, %37, %38, %39}," + " %40," + " p, %42, %43;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), 
"+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x72x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x80x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %42, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " p, %43, %44;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x80x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + 
float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %45, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " p, %46, %47;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x88x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x88x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[44]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %46, 0;\n" + "wgmma.mma_async.sync.aligned.m64n88k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + " %44," + " %45," + " p, %47, %48;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), 
"+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x88x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x88x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x88x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[44]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %49, 0;\n" + "wgmma.mma_async.sync.aligned.m64n88k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + "{%44, %45, %46, %47}," + " %48," + " p, %50, %51;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x88x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x104x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x104x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + 
using CRegisters = float[52]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %54, 0;\n" + "wgmma.mma_async.sync.aligned.m64n104k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + " %52," + " %53," + " p, %55, %56;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x104x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x104x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x104x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[52]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & 
d50, float & d51, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %57, 0;\n" + "wgmma.mma_async.sync.aligned.m64n104k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + "{%52, %53, %54, %55}," + " %56," + " p, %58, %59;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x104x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x112x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %58, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " p, %59, %60;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), 
+ "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x112x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %61, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " p, %62, %63;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), 
"n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x120x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x120x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[60]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %62, 0;\n" + "wgmma.mma_async.sync.aligned.m64n120k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + " %60," + " %61," + " p, %63, %64;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x120x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x120x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x120x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[60]; + + CUTE_HOST_DEVICE static void + 
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %65, 0;\n" + "wgmma.mma_async.sync.aligned.m64n120k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + "{%60, %61, %62, %63}," + " %64," + " p, %66, %67;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x120x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x136x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x136x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & 
d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %70, 0;\n" + "wgmma.mma_async.sync.aligned.m64n136k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + " %68," + " %69," + " p, %71, %72;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x136x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x136x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x136x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & 
d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %73, 0;\n" + "wgmma.mma_async.sync.aligned.m64n136k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + "{%68, %69, %70, %71}," + " %72," + " p, %74, %75;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x136x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x144x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %74, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " p, %75, %76;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x144x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %77, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k8.f32.tf32.tf32 
" + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " p, %78, %79;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x152x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x152x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[76]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %78, 0;\n" + "wgmma.mma_async.sync.aligned.m64n152k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, 
%28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75}," + " %76," + " %77," + " p, %79, %80;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x152x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x152x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x152x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[76]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %81, 0;\n" + "wgmma.mma_async.sync.aligned.m64n152k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " 
%40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75}," + "{%76, %77, %78, %79}," + " %80," + " p, %82, %83;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x152x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x160x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %82, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, 
%45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " p, %83, %84;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x160x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %85, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, 
%37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " p, %86, %87;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x168x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x168x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[84]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %86, 0;\n" + "wgmma.mma_async.sync.aligned.m64n168k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, 
%21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83}," + " %84," + " %85," + " p, %87, %88;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x168x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x168x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x168x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[84]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %89, 
0;\n" + "wgmma.mma_async.sync.aligned.m64n168k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83}," + "{%84, %85, %86, %87}," + " %88," + " p, %90, %91;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x168x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x176x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + GMMA::ScaleOut const 
scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %90, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " p, %91, %92;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x176x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float 
& d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %93, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " p, %94, %95;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x184x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x184x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, 
float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %94, 0;\n" + "wgmma.mma_async.sync.aligned.m64n184k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91}," + " %92," + " %93," + " p, %95, %96;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x184x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x184x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x184x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float 
& d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %97, 0;\n" + "wgmma.mma_async.sync.aligned.m64n184k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91}," + "{%92, %93, %94, %95}," + " %96," + " p, %98, %99;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x184x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x200x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x200x8_F32TF32TF32_SS_TN +{ + using DRegisters = 
void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[100]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %102, 0;\n" + "wgmma.mma_async.sync.aligned.m64n200k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99}," + " %100," + " %101," + " p, %103, %104;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), 
"+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x200x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x200x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x200x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[100]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %105, 0;\n" + "wgmma.mma_async.sync.aligned.m64n200k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99}," + "{%100, %101, %102, %103}," + " %104," + " p, %106, %107;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), 
"+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x200x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x208x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if 
defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %106, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " p, %107, %108;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x208x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, 
float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %109, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " p, %110, %111;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
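+// Descriptive summary of the pattern shared by the TF32 wgmma atoms in this section
+// (comment only; it restates what the surrounding code already does):
+//   * _SS_TN variants source both A and B from shared memory through 64-bit matrix
+//     descriptors (desc_a, desc_b); _RS_TN variants source A from four 32-bit register
+//     fragments and B from shared memory through desc_b.
+//   * The predicate p is derived from scale_D: GMMA::ScaleOut::One accumulates into the
+//     existing D registers, while GMMA::ScaleOut::Zero overwrites them with A*B.
+//   * scaleA and scaleB are compile-time immediates ("n" constraints) taken from the
+//     GMMA::ScaleIn template parameters, allowing an input operand to be negated.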
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x216x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x216x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[108]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %110, 0;\n" + "wgmma.mma_async.sync.aligned.m64n216k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107}," + " %108," + " %109," + " p, %111, %112;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + 
"+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x216x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x216x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x216x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[108]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %113, 0;\n" + "wgmma.mma_async.sync.aligned.m64n216k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, 
%7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107}," + "{%108, %109, %110, %111}," + " %112," + " p, %114, %115;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x216x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x224x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & 
d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %114, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " p, %115, %116;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x224x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %117, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " p, %118, %119;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), 
"+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x232x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x232x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + GMMA::ScaleOut const scale_D = 
GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %118, 0;\n" + "wgmma.mma_async.sync.aligned.m64n232k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + " %116," + " %117," + " p, %119, %120;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x232x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x232x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x232x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + 
float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %121, 0;\n" + "wgmma.mma_async.sync.aligned.m64n232k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + "{%116, %117, %118, %119}," + " %120," + " p, %122, %123;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), 
"+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x232x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x240x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %122, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, 
%71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " p, %123, %124;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x240x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, 
float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %125, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " p, %126, %127;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); 
+#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x248x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x248x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[124]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %126, 0;\n" + "wgmma.mma_async.sync.aligned.m64n248k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + " %124," + " %125," + " p, %127, %128;\n" + "}\n" + : "+f"(d000), 
"+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x248x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x248x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x248x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[124]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, 
float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %129, 0;\n" + "wgmma.mma_async.sync.aligned.m64n248k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + "{%124, %125, %126, %127}," + " %128," + " p, %130, %131;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x248x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
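+
+// Descriptive note on the TF32 atoms above (an editorial comment, not part of the generated
+// PTX wrappers): every struct follows the same inline-asm calling convention. Accumulator
+// fragments are bound as read-modify-write outputs ("+f" for f32), shared-memory A/B operands
+// travel as 64-bit GMMA matrix descriptors ("l"), register-resident A fragments as 32-bit
+// registers ("r"), scale_D is lowered through setp.ne.b32 into the predicate `p` that the
+// wgmma instruction uses to decide whether the existing accumulator contents are kept, and the
+// scaleA/scaleB ScaleIn template parameters are folded in as compile-time immediates ("n").
+// When CUTE_ARCH_MMA_SM90A_ENABLED is not defined, each fma() collapses to
+// CUTE_INVALID_CONTROL_PATH so pre-SM90a or host-only builds fail loudly instead of emitting
+// unsupported wgmma instructions.
+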
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN S32+=S8*S8 +struct MMA_64x24x32_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %14, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN S32+=S8*S8 +struct MMA_64x24x32_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %14, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN S32+=S8*S8 +struct MMA_64x48x32_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if 
defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %26, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN S32+=S8*S8 +struct MMA_64x48x32_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %26, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN S32+=S8*S8 +struct MMA_64x80x32_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + 
uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %42, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN S32+=S8*S8 +struct MMA_64x80x32_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %42, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use 
MMA_64x80x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN S32+=S8*S8 +struct MMA_64x112x32_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %58, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN S32+=S8*S8 +struct MMA_64x112x32_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & 
d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %58, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN S32+=S8*S8 +struct MMA_64x144x32_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t 
& d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %74, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN S32+=S8*S8 +struct MMA_64x144x32_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, 
uint32_t & d70, uint32_t & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %74, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN S32+=S8*S8 +struct MMA_64x160x32_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + GMMA::ScaleOut 
const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %82, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN S32+=S8*S8 +struct MMA_64x160x32_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + 
uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %82, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN S32+=S8*S8 +struct MMA_64x176x32_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & 
d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %90, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN S32+=S8*S8 +struct MMA_64x176x32_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & 
d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %90, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN S32+=S8*S8 +struct MMA_64x208x32_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, 
uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %106, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + 
"+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN S32+=S8*S8 +struct MMA_64x208x32_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %106, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), 
"+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN S32+=S8*S8 +struct MMA_64x224x32_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, 
uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %114, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN S32+=S8*S8 +struct MMA_64x224x32_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, 
uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %114, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), 
"+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN S32+=S8*S8 +struct MMA_64x240x32_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %122, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " 
+ " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN S32+=S8*S8 +struct MMA_64x240x32_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, 
+ uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %122, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), 
"+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN S32+=S8*S8 +struct MMA_64x24x32_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %17, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN S32+=S8*S8 +struct MMA_64x24x32_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %17, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN S32+=S8*S8 +struct MMA_64x48x32_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters 
= uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %29, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN S32+=S8*S8 +struct MMA_64x48x32_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %29, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// 
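+// A reading of the register convention shared by the S32+=S8*S8 atoms in this file: the 64xN tile
+// of 32-bit accumulators is distributed over the 128 threads of a warpgroup, so each thread owns
+// 64*N/128 = N/2 uint32_t accumulator registers, hence CRegisters = uint32_t[N/2]
+// (N=24 -> 12, N=48 -> 24, N=80 -> 40, N=208 -> 104, and so on). "_SS_" atoms read both A and B
+// through 64-bit shared-memory descriptors, "_RS_" atoms read A from four 32-bit registers and B
+// through a descriptor, and the trailing scale_D operand (lowered to the predicate `p` in the asm)
+// selects between overwriting (ScaleOut::Zero) and accumulating into (ScaleOut::One) the C registers.
+//
+// A minimal usage sketch, assuming `desc_b` is a valid GMMA shared-memory descriptor and `a0..a3`
+// hold the per-thread A fragment; the names are illustrative, and these atoms are normally driven
+// through cute::TiledMMA rather than called by hand:
+//
+//   uint32_t acc[12] = {};                       // N/2 = 12 accumulators for the 64x24 atom above
+//   warpgroup_arrive();
+//   MMA_64x24x32_S32S8S8_RS_TN::fma(a0, a1, a2, a3, desc_b,
+//       acc[0], acc[1], acc[2],  acc[3], acc[4],  acc[5],
+//       acc[6], acc[7], acc[8],  acc[9], acc[10], acc[11],
+//       GMMA::ScaleOut::One);                    // D = A*B + D
+//   warpgroup_commit_batch();
+//   warpgroup_wait<0>();                         // wait for the async MMA before reading acc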
+ +// GMMA 64x80x32 TN S32+=S8*S8 +struct MMA_64x80x32_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %45, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN S32+=S8*S8 +struct MMA_64x80x32_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm 
volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %45, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN S32+=S8*S8 +struct MMA_64x112x32_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %61, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + 
"+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN S32+=S8*S8 +struct MMA_64x112x32_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %61, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN S32+=S8*S8 +struct 
MMA_64x144x32_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %77, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN S32+=S8*S8 +struct MMA_64x144x32_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using 
BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %77, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN S32+=S8*S8 +struct MMA_64x160x32_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static 
void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %85, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN S32+=S8*S8 +struct 
MMA_64x160x32_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %85, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32S8S8_RS_TN_SATURATE without 
CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN S32+=S8*S8 +struct MMA_64x176x32_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %93, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), 
"+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN S32+=S8*S8 +struct MMA_64x176x32_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %93, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), 
"+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN S32+=S8*S8 +struct MMA_64x208x32_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %109, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, 
%5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN S32+=S8*S8 +struct MMA_64x208x32_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, 
uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %109, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN S32+=S8*S8 +struct 
MMA_64x224x32_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %117, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), 
"+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN S32+=S8*S8 +struct MMA_64x224x32_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, 
uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %117, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN S32+=S8*S8 +struct MMA_64x240x32_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t 
& d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %125, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), 
"+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN S32+=S8*S8 +struct MMA_64x240x32_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & 
d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %125, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN S32+=S8*U8 +struct MMA_64x24x32_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + GMMA::ScaleOut const scale_D = 
GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %14, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN S32+=S8*U8 +struct MMA_64x24x32_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %14, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN S32+=S8*U8 +struct MMA_64x48x32_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %26, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + 
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN S32+=S8*U8 +struct MMA_64x48x32_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %26, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN S32+=S8*U8 +struct MMA_64x80x32_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %42, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, 
%37, %38, %39}," + " %40," + " %41," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN S32+=S8*U8 +struct MMA_64x80x32_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %42, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN S32+=S8*U8 +struct MMA_64x112x32_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, 
uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %58, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN S32+=S8*U8 +struct MMA_64x112x32_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, 
uint32_t & d53, uint32_t & d54, uint32_t & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %58, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN S32+=S8*U8 +struct MMA_64x144x32_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %74, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, 
%21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN S32+=S8*U8 +struct MMA_64x144x32_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %74, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, 
%53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN S32+=S8*U8 +struct MMA_64x160x32_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %82, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, 
" + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN S32+=S8*U8 +struct MMA_64x160x32_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %82, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " 
%48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN S32+=S8*U8 +struct MMA_64x176x32_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %90, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, 
%12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN S32+=S8*U8 +struct MMA_64x176x32_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + 
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %90, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN S32+=S8*U8 +struct MMA_64x208x32_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & 
d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %106, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN S32+=S8*U8 +struct MMA_64x208x32_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& 
desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %106, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), 
"+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN S32+=S8*U8 +struct MMA_64x224x32_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %114, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, 
%31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN S32+=S8*U8 +struct MMA_64x224x32_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, 
uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %114, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
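+// Note (illustrative, not part of the generated atom list): each
+// MMA_64xNx32_S32S8U8_* struct in this file is a thin wrapper around a single
+// `wgmma.mma_async.sync.aligned.m64nNk32.s32.s8.u8` PTX instruction for SM90
+// (Hopper) warpgroup MMA. `CRegisters` is the 32-bit accumulator fragment
+// (N/2 registers for an m64nNk32 tile, e.g. 112 for n224), `ARegisters` and
+// `BRegisters` are shared-memory matrix descriptors in the `_SS_` variants
+// (register fragments for A in the `_RS_` variants below), and `scale_D`,
+// lowered to the predicate `p`, controls whether the existing accumulator is
+// added into the result (GMMA::ScaleOut::One) or discarded (ScaleOut::Zero).
+// The `_SATURATE` variants emit the `.satfinite` form of the instruction.
+// Every atom is guarded by CUTE_ARCH_MMA_SM90A_ENABLED; without it the
+// CUTE_INVALID_CONTROL_PATH branch reports the missing architecture support.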
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN S32+=S8*U8 +struct MMA_64x240x32_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %122, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " p;\n" + "}\n" + : "+r"(d000), 
"+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN S32+=S8*U8 +struct MMA_64x240x32_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & 
d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %122, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN S32+=S8*U8 +struct 
MMA_64x24x32_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %17, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN S32+=S8*U8 +struct MMA_64x24x32_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %17, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN S32+=S8*U8 +struct MMA_64x48x32_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & 
d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %29, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN S32+=S8*U8 +struct MMA_64x48x32_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %29, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN S32+=S8*U8 +struct MMA_64x80x32_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t 
& d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %45, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN S32+=S8*U8 +struct MMA_64x80x32_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %45, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), 
"+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN S32+=S8*U8 +struct MMA_64x112x32_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %61, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN S32+=S8*U8 +struct 
MMA_64x112x32_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %61, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN S32+=S8*U8 +struct MMA_64x144x32_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & 
d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %77, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN S32+=S8*U8 +struct MMA_64x144x32_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & 
d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %77, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN S32+=S8*U8 +struct MMA_64x160x32_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, 
uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %85, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN S32+=S8*U8 +struct MMA_64x160x32_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & 
d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %85, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN S32+=S8*U8 +struct MMA_64x176x32_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & 
d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %93, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN S32+=S8*U8 +struct 
MMA_64x176x32_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %93, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + 
"+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN S32+=S8*U8 +struct MMA_64x208x32_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %109, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " p;\n" + "}\n" + : 
"+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN S32+=S8*U8 +struct MMA_64x208x32_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, 
uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %109, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN S32+=S8*U8 +struct MMA_64x224x32_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t 
& d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %117, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), 
"+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN S32+=S8*U8 +struct MMA_64x224x32_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %117, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, 
%23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN S32+=S8*U8 +struct MMA_64x240x32_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & 
d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %125, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), 
"+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN S32+=S8*U8 +struct MMA_64x240x32_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %125, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " 
+ " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN S32+=U8*S8 +struct MMA_64x24x32_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %14, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_S32U8S8_SS_TN without 
CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN S32+=U8*S8 +struct MMA_64x24x32_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %14, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN S32+=U8*S8 +struct MMA_64x48x32_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %26, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN S32+=U8*S8 +struct MMA_64x48x32_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, 
uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %26, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN S32+=U8*S8 +struct MMA_64x80x32_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %42, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32U8S8_SS_TN without 
CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN S32+=U8*S8 +struct MMA_64x80x32_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %42, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN S32+=U8*S8 +struct MMA_64x112x32_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, 
uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %58, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN S32+=U8*S8 +struct MMA_64x112x32_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %58, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, 
%55}," + " %56," + " %57," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN S32+=U8*S8 +struct MMA_64x144x32_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %74, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + 
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN S32+=U8*S8 +struct MMA_64x144x32_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %74, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + 
"+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN S32+=U8*S8 +struct MMA_64x160x32_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %82, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), 
"+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN S32+=U8*S8 +struct MMA_64x160x32_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %82, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), 
"+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN S32+=U8*S8 +struct MMA_64x176x32_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %90, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), 
"+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN S32+=U8*S8 +struct MMA_64x176x32_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %90, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, 
%63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN S32+=U8*S8 +struct MMA_64x208x32_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + 
uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %106, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN S32+=U8*S8 +struct MMA_64x208x32_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, 
uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %106, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use 
MMA_64x208x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN S32+=U8*S8 +struct MMA_64x224x32_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %114, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), 
"+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN S32+=U8*S8 +struct MMA_64x224x32_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, 
uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %114, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN S32+=U8*S8 +struct MMA_64x240x32_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & 
d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %122, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), 
"+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN S32+=U8*S8 +struct MMA_64x240x32_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + GMMA::ScaleOut 
const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %122, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN S32+=U8*S8 +struct MMA_64x24x32_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, 
desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %17, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN S32+=U8*S8 +struct MMA_64x24x32_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %17, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN S32+=U8*S8 +struct MMA_64x48x32_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %29, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), 
"+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN S32+=U8*S8 +struct MMA_64x48x32_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %29, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN S32+=U8*S8 +struct MMA_64x80x32_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %45, 0;\n" + 
"wgmma.mma_async.sync.aligned.m64n80k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN S32+=U8*S8 +struct MMA_64x80x32_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %45, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN S32+=U8*S8 +struct 
MMA_64x112x32_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %61, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN S32+=U8*S8 +struct MMA_64x112x32_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, 
+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %61, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN S32+=U8*S8 +struct MMA_64x144x32_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, 
uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %77, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN S32+=U8*S8 +struct MMA_64x144x32_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + 
uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %77, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN S32+=U8*S8 +struct MMA_64x160x32_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, 
uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %85, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN S32+=U8*S8 +struct MMA_64x160x32_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & 
d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %85, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN S32+=U8*S8 +struct MMA_64x176x32_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & 
d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %93, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN S32+=U8*S8 +struct MMA_64x176x32_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & 
d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %93, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN S32+=U8*S8 +struct MMA_64x208x32_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + 
fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %109, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), 
"+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN S32+=U8*S8 +struct MMA_64x208x32_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %109, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, 
%18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN S32+=U8*S8 +struct MMA_64x224x32_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, 
uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %117, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32U8S8_RS_TN without 
CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN S32+=U8*S8 +struct MMA_64x224x32_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %117, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), 
"+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN S32+=U8*S8 +struct MMA_64x240x32_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, 
uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %125, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN S32+=U8*S8 +struct MMA_64x240x32_S32U8S8_RS_TN_SATURATE +{ + using 
DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %125, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), 
"+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN S32+=U8*U8 +struct MMA_64x24x32_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %14, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN S32+=U8*U8 +struct MMA_64x24x32_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + 
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %14, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN S32+=U8*U8 +struct MMA_64x48x32_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %26, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN S32+=U8*U8 +struct MMA_64x48x32_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %26, 0;\n" + 
"wgmma.mma_async.sync.aligned.m64n48k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN S32+=U8*U8 +struct MMA_64x80x32_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %42, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN S32+=U8*U8 +struct MMA_64x80x32_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t 
& d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %42, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN S32+=U8*U8 +struct MMA_64x112x32_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %58, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, 
%39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN S32+=U8*U8 +struct MMA_64x112x32_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %58, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), 
"+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN S32+=U8*U8 +struct MMA_64x144x32_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %74, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use 
MMA_64x144x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN S32+=U8*U8 +struct MMA_64x144x32_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %74, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
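+// Note on the operand convention shared by the GMMA atoms above: the D
+// accumulator fragments are passed as "+r" (read-write) register operands, the
+// A/B operands arrive either as registers ("r", the *_RS_* atoms) or as 64-bit
+// shared-memory matrix descriptors ("l", the *_SS_* atoms), and the wgmma
+// instruction is predicated on scale_D, so GMMA::ScaleOut::Zero makes the
+// first MMA of a tile overwrite the accumulator instead of adding to it.
+//
+// A minimal usage sketch (illustrative only; descriptor construction and
+// fragment layouts are normally produced by the cute MMA_Traits / TiledMMA
+// machinery rather than written by hand as below):
+//
+//   uint64_t desc_a = /* GMMA descriptor for the A tile in shared memory */;
+//   uint64_t desc_b = /* GMMA descriptor for the B tile in shared memory */;
+//   uint32_t d[12] = {};   // accumulator fragment for one 64x24x32 atom
+//   MMA_64x24x32_S32U8U8_SS_TN::fma(desc_a, desc_b,
+//       d[0], d[1], d[2],  d[3], d[4], d[5], d[6], d[7],
+//       d[8], d[9], d[10], d[11],
+//       GMMA::ScaleOut::Zero);   // first k-block: overwrite the accumulator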
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN S32+=U8*U8 +struct MMA_64x160x32_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %82, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use 
MMA_64x160x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN S32+=U8*U8 +struct MMA_64x160x32_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %82, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "l"(desc_a), + 
"l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN S32+=U8*U8 +struct MMA_64x176x32_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %90, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), 
"+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN S32+=U8*U8 +struct MMA_64x176x32_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %90, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), 
+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN S32+=U8*U8 +struct MMA_64x208x32_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %106, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, 
%26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN S32+=U8*U8 +struct MMA_64x208x32_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, 
uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %106, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN S32+=U8*U8 +struct MMA_64x224x32_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, 
uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %114, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), 
"+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN S32+=U8*U8 +struct MMA_64x224x32_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 
p, %114, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN S32+=U8*U8 +struct MMA_64x240x32_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & 
d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %122, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), 
"+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN S32+=U8*U8 +struct MMA_64x240x32_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %122, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, 
%41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN S32+=U8*U8 +struct MMA_64x24x32_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %17, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use 
MMA_64x24x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN S32+=U8*U8 +struct MMA_64x24x32_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %17, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN S32+=U8*U8 +struct MMA_64x48x32_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %29, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN S32+=U8*U8 +struct MMA_64x48x32_S32U8U8_RS_TN_SATURATE +{ + using 
DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %29, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN S32+=U8*U8 +struct MMA_64x80x32_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %45, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), 
"+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN S32+=U8*U8 +struct MMA_64x80x32_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %45, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN S32+=U8*U8 +struct MMA_64x112x32_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, 
uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %61, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN S32+=U8*U8 +struct MMA_64x112x32_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & 
d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %61, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN S32+=U8*U8 +struct MMA_64x144x32_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %77, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.s32.u8.u8 " + "{%0, %1, %2, %3, 
%4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN S32+=U8*U8 +struct MMA_64x144x32_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %77, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, 
%20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN S32+=U8*U8 +struct MMA_64x160x32_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %85, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, 
%14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN S32+=U8*U8 +struct MMA_64x160x32_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + 
".reg .pred p;\n" + "setp.ne.b32 p, %85, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN S32+=U8*U8 +struct MMA_64x176x32_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + 
uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %93, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN S32+=U8*U8 +struct MMA_64x176x32_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, 
uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %93, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN S32+=U8*U8 +struct MMA_64x208x32_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, 
+ uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %109, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + 
"+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN S32+=U8*U8 +struct MMA_64x208x32_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %109, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " 
p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN S32+=U8*U8 +struct MMA_64x224x32_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + 
uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %117, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN S32+=U8*U8 +struct MMA_64x224x32_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, 
uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %117, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), 
"+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN S32+=U8*U8 +struct MMA_64x240x32_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, 
uint32_t & d118, uint32_t & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %125, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN S32+=U8*U8 +struct MMA_64x240x32_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, 
uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %125, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + 
"+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x24x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[6]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5}," + " %6," + " %7," + " p, %9, %10;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x24x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[6]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5}," + "{%6, %7, %8, %9}," + " %10," + " p, %12, %13;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), 
"n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x24x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %14, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " p, %15, %16;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x24x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %17, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " p, %18, %19;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x40x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x40x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + 
using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[10]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sync.aligned.m64n40k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9}," + " %10," + " %11," + " p, %13, %14;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x40x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x40x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x40x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[10]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sync.aligned.m64n40k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9}," + "{%10, %11, %12, %13}," + " %14," + " p, %16, %17;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x40x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x40x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x40x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[20]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, 
desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %22, 0;\n" + "wgmma.mma_async.sync.aligned.m64n40k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + " %20," + " %21," + " p, %23, %24;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x40x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x40x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x40x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[20]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %25, 0;\n" + "wgmma.mma_async.sync.aligned.m64n40k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + "{%20, %21, %22, %23}," + " %24," + " p, %26, %27;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x40x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x48x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %14, 0;\n" + 
"wgmma.mma_async.sync.aligned.m64n48k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " p, %15, %16;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x48x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %17, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " p, %18, %19;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x48x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %26, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " p, %27, %28;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + 
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x48x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %29, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " p, %30, %31;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x56x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x56x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[14]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %16, 0;\n" + "wgmma.mma_async.sync.aligned.m64n56k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13}," + " %14," + " %15," + " p, %17, %18;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), 
+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x56x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x56x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x56x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[14]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %19, 0;\n" + "wgmma.mma_async.sync.aligned.m64n56k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13}," + "{%14, %15, %16, %17}," + " %18," + " p, %20, %21;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x56x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x56x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x56x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[28]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %30, 0;\n" + "wgmma.mma_async.sync.aligned.m64n56k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + " %28," + " %29," + " p, %31, %32;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + 
"+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x56x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x56x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x56x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[28]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %33, 0;\n" + "wgmma.mma_async.sync.aligned.m64n56k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + "{%28, %29, %30, %31}," + " %32," + " p, %34, %35;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x56x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x72x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x72x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[18]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + 
"wgmma.mma_async.sync.aligned.m64n72k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17}," + " %18," + " %19," + " p, %21, %22;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x72x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x72x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x72x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[18]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sync.aligned.m64n72k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17}," + "{%18, %19, %20, %21}," + " %22," + " p, %24, %25;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x72x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x72x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x72x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[36]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %38, 0;\n" + "wgmma.mma_async.sync.aligned.m64n72k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + " %36," + " %37," + " p, %39, %40;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x72x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x72x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x72x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[36]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %41, 0;\n" + "wgmma.mma_async.sync.aligned.m64n72k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + "{%36, %37, %38, %39}," + " %40," + " p, %42, %43;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x72x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN F16+=E4M3*E4M3 +template < 
+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x80x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[20]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %22, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + " %20," + " %21," + " p, %23, %24;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x80x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[20]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %25, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + "{%20, %21, %22, %23}," + " %24," + " p, %26, %27;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x80x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %42, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " p, %43, %44;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x80x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %45, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " p, %46, %47;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x88x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x88x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[22]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %24, 0;\n" + "wgmma.mma_async.sync.aligned.m64n88k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21}," + " %22," + " %23," + " p, %25, %26;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x88x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x88x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x88x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[22]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t 
const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %27, 0;\n" + "wgmma.mma_async.sync.aligned.m64n88k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21}," + "{%22, %23, %24, %25}," + " %26," + " p, %28, %29;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x88x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x88x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x88x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[44]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %46, 0;\n" + "wgmma.mma_async.sync.aligned.m64n88k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + " %44," + " %45," + " p, %47, %48;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), 
"+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x88x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x88x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x88x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[44]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %49, 0;\n" + "wgmma.mma_async.sync.aligned.m64n88k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + "{%44, %45, %46, %47}," + " %48," + " p, %50, %51;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x88x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x104x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x104x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[26]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, 
uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sync.aligned.m64n104k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25}," + " %26," + " %27," + " p, %29, %30;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x104x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x104x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x104x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[26]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sync.aligned.m64n104k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25}," + "{%26, %27, %28, %29}," + " %30," + " p, %32, %33;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x104x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x104x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = 
GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x104x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[52]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %54, 0;\n" + "wgmma.mma_async.sync.aligned.m64n104k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + " %52," + " %53," + " p, %55, %56;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x104x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x104x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x104x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[52]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, 
float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %57, 0;\n" + "wgmma.mma_async.sync.aligned.m64n104k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + "{%52, %53, %54, %55}," + " %56," + " p, %58, %59;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x104x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x112x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[28]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %30, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + " %28," + " %29," + " p, %31, %32;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + : 
"l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x112x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[28]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %33, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + "{%28, %29, %30, %31}," + " %32," + " p, %34, %35;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x112x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & 
d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %58, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " p, %59, %60;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x112x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %61, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, 
%53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " p, %62, %63;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x120x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x120x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[30]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %32, 0;\n" + "wgmma.mma_async.sync.aligned.m64n120k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29}," + " %30," + " %31," + " p, %33, %34;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x120x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x120x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x120x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[30]; + + CUTE_HOST_DEVICE 
static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %35, 0;\n" + "wgmma.mma_async.sync.aligned.m64n120k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29}," + "{%30, %31, %32, %33}," + " %34," + " p, %36, %37;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x120x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x120x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x120x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[60]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %62, 0;\n" + "wgmma.mma_async.sync.aligned.m64n120k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, 
%31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + " %60," + " %61," + " p, %63, %64;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x120x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x120x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x120x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[60]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %65, 0;\n" + "wgmma.mma_async.sync.aligned.m64n120k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + "{%60, %61, %62, %63}," + " %64," + " p, %66, %67;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + 
"+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x120x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x136x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x136x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[34]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sync.aligned.m64n136k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33}," + " %34," + " %35," + " p, %37, %38;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x136x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x136x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x136x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[34]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, 
uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sync.aligned.m64n136k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33}," + "{%34, %35, %36, %37}," + " %38," + " p, %40, %41;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x136x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x136x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x136x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %70, 0;\n" + "wgmma.mma_async.sync.aligned.m64n136k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, 
%27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + " %68," + " %69," + " p, %71, %72;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x136x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x136x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x136x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %73, 0;\n" + "wgmma.mma_async.sync.aligned.m64n136k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + "{%68, %69, %70, %71}," + " %72," + " p, %74, %75;\n" + "}\n" + : "+f"(d00), 
"+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x136x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x144x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[36]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %38, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + " %36," + " %37," + " p, %39, %40;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + 
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x144x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[36]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %41, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + "{%36, %37, %38, %39}," + " %40," + " p, %42, %43;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x144x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, 
float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %74, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " p, %75, %76;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x144x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) 
+ { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %77, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " p, %78, %79;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x152x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x152x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[38]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %40, 0;\n" + "wgmma.mma_async.sync.aligned.m64n152k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37}," + " %38," + " %39," + " p, %41, %42;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), 
"+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x152x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x152x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x152x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[38]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %43, 0;\n" + "wgmma.mma_async.sync.aligned.m64n152k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37}," + "{%38, %39, %40, %41}," + " %42," + " p, %44, %45;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x152x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x152x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x152x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[76]; + + 
CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %78, 0;\n" + "wgmma.mma_async.sync.aligned.m64n152k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75}," + " %76," + " %77," + " p, %79, %80;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x152x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x152x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x152x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[76]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + 
uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %81, 0;\n" + "wgmma.mma_async.sync.aligned.m64n152k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75}," + "{%76, %77, %78, %79}," + " %80," + " p, %82, %83;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x152x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x160x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & 
d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %42, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " p, %43, %44;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x160x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %45, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, 
%28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " p, %46, %47;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x160x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %82, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " p, %83, %84;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), 
"+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x160x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %85, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " p, %86, %87;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + 
"+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x168x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x168x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[42]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sync.aligned.m64n168k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41}," + " %42," + " %43," + " p, %45, %46;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x168x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x168x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x168x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[42]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sync.aligned.m64n168k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41}," + "{%42, %43, %44, %45}," + " %46," + " p, %48, %49;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x168x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x168x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x168x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[84]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + 
float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %86, 0;\n" + "wgmma.mma_async.sync.aligned.m64n168k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83}," + " %84," + " %85," + " p, %87, %88;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x168x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x168x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x168x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[84]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float 
& d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %89, 0;\n" + "wgmma.mma_async.sync.aligned.m64n168k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83}," + "{%84, %85, %86, %87}," + " %88," + " p, %90, %91;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x168x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x176x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[44]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, 
uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %46, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + " %44," + " %45," + " p, %47, %48;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x176x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[44]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %49, 0;\n" + 
"wgmma.mma_async.sync.aligned.m64n176k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + "{%44, %45, %46, %47}," + " %48," + " p, %50, %51;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x176x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %90, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " 
%72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " p, %91, %92;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x176x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %93, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, 
%21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " p, %94, %95;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x184x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x184x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[46]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %48, 0;\n" + "wgmma.mma_async.sync.aligned.m64n184k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, 
%45}," + " %46," + " %47," + " p, %49, %50;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x184x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x184x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x184x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[46]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %51, 0;\n" + "wgmma.mma_async.sync.aligned.m64n184k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45}," + "{%46, %47, %48, %49}," + " %50," + " p, %52, %53;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x184x32_F16E4M3E4M3_RS_TN without 
CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x184x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x184x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %94, 0;\n" + "wgmma.mma_async.sync.aligned.m64n184k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91}," + " %92," + " %93," + " p, %95, %96;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + 
"+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x184x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x184x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x184x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %97, 0;\n" + "wgmma.mma_async.sync.aligned.m64n184k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91}," + "{%92, %93, %94, %95}," + " %96," + " p, %98, %99;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), 
"+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x184x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x200x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x200x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[50]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sync.aligned.m64n200k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49}," + " %50," + " %51," + " p, %53, %54;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + 
CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x200x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x200x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x200x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[50]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sync.aligned.m64n200k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49}," + "{%50, %51, %52, %53}," + " %54," + " p, %56, %57;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x200x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x200x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x200x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[100]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, 
float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %102, 0;\n" + "wgmma.mma_async.sync.aligned.m64n200k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99}," + " %100," + " %101," + " p, %103, %104;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + 
"+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x200x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x200x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x200x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[100]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %105, 0;\n" + "wgmma.mma_async.sync.aligned.m64n200k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99}," + "{%100, %101, %102, %103}," + " %104," + " p, %106, %107;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + 
"+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x200x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x208x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[52]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %54, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + " %52," + " %53," + " p, %55, %56;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), 
"+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x208x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[52]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %57, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + "{%52, %53, %54, %55}," + " %56," + " p, %58, %59;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x208x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %106, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " p, %107, %108;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), 
"+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x208x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %109, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " 
%40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " p, %110, %111;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x216x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x216x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[54]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %56, 0;\n" + "wgmma.mma_async.sync.aligned.m64n216k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53}," + " %54," + " %55," + " p, %57, %58;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x216x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x216x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x216x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[54]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %59, 0;\n" + "wgmma.mma_async.sync.aligned.m64n216k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53}," + "{%54, %55, %56, %57}," + " %58," + " p, %60, %61;\n" + "}\n" + : 
"+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x216x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x216x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x216x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[108]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %110, 0;\n" + "wgmma.mma_async.sync.aligned.m64n216k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, 
%51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107}," + " %108," + " %109," + " p, %111, %112;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x216x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x216x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x216x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[108]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & 
d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %113, 0;\n" + "wgmma.mma_async.sync.aligned.m64n216k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107}," + "{%108, %109, %110, %111}," + " %112," + " p, %114, %115;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x216x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x224x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters 
= uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %58, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " p, %59, %60;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x224x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, 
uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %61, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " p, %62, %63;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x224x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, 
float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %114, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " p, %115, %116;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = 
GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x224x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %117, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " p, %118, %119;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), 
"+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x232x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x232x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[58]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sync.aligned.m64n232k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57}," + " %58," + " %59," + " p, %61, %62;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), 
"+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x232x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x232x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x232x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[58]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sync.aligned.m64n232k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57}," + "{%58, %59, %60, %61}," + " %62," + " p, %64, %65;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57) + : 
"r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x232x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x232x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x232x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %118, 0;\n" + "wgmma.mma_async.sync.aligned.m64n232k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + " %116," + " %117," + " p, %119, %120;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), 
"+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x232x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x232x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x232x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float 
& d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %121, 0;\n" + "wgmma.mma_async.sync.aligned.m64n232k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + "{%116, %117, %118, %119}," + " %120," + " p, %122, %123;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x232x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x240x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[60]; + + CUTE_HOST_DEVICE static void + 
fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %62, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + " %60," + " %61," + " p, %63, %64;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x240x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[60]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, 
uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %65, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + "{%60, %61, %62, %63}," + " %64," + " p, %66, %67;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x240x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, 
float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %122, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " p, %123, %124;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), 
"+f"(d117), "+f"(d118), "+f"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x240x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %125, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, 
%115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " p, %126, %127;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x248x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x248x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[62]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %64, 0;\n" + "wgmma.mma_async.sync.aligned.m64n248k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61}," + " %62," + " %63," + " p, %65, %66;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x248x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x248x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x248x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[62]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %67, 0;\n" + "wgmma.mma_async.sync.aligned.m64n248k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, 
%20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61}," + "{%62, %63, %64, %65}," + " %66," + " p, %68, %69;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x248x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x248x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x248x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[124]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, 
float & d121, float & d122, float & d123, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %126, 0;\n" + "wgmma.mma_async.sync.aligned.m64n248k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + " %124," + " %125," + " p, %127, %128;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x248x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x248x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x248x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[124]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + 
float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %129, 0;\n" + "wgmma.mma_async.sync.aligned.m64n248k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + "{%124, %125, %126, %127}," + " %128," + " p, %130, %131;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), 
"+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x248x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x24x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[6]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5}," + " %6," + " %7," + " p, %9, %10;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x24x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[6]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5}," + "{%6, %7, %8, %9}," + " %10," + " p, %12, %13;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), 
+ "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x24x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %14, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " p, %15, %16;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x24x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %17, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " p, %18, %19;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x40x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x40x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using 
ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[10]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sync.aligned.m64n40k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9}," + " %10," + " %11," + " p, %13, %14;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x40x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x40x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x40x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[10]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sync.aligned.m64n40k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9}," + "{%10, %11, %12, %13}," + " %14," + " p, %16, %17;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x40x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x40x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x40x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[20]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %22, 0;\n" + "wgmma.mma_async.sync.aligned.m64n40k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + " %20," + " %21," + " p, %23, %24;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x40x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x40x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x40x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[20]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %25, 0;\n" + "wgmma.mma_async.sync.aligned.m64n40k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + "{%20, %21, %22, %23}," + " %24," + " p, %26, %27;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x40x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x48x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" 
+ "setp.ne.b32 p, %14, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " p, %15, %16;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x48x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %17, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " p, %18, %19;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x48x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %26, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " p, %27, %28;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), 
"+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x48x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %29, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " p, %30, %31;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x56x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x56x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[14]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %16, 0;\n" + "wgmma.mma_async.sync.aligned.m64n56k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13}," + " %14," + " %15," + " p, %17, %18;\n" + "}\n" + : "+r"(d00), 
"+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x56x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x56x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x56x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[14]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %19, 0;\n" + "wgmma.mma_async.sync.aligned.m64n56k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13}," + "{%14, %15, %16, %17}," + " %18," + " p, %20, %21;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x56x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x56x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x56x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[28]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %30, 0;\n" + "wgmma.mma_async.sync.aligned.m64n56k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + " %28," + " %29," + " p, %31, %32;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), 
"+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x56x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x56x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x56x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[28]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %33, 0;\n" + "wgmma.mma_async.sync.aligned.m64n56k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + "{%28, %29, %30, %31}," + " %32," + " p, %34, %35;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x56x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x72x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x72x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[18]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 
0;\n" + "wgmma.mma_async.sync.aligned.m64n72k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17}," + " %18," + " %19," + " p, %21, %22;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x72x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x72x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x72x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[18]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sync.aligned.m64n72k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17}," + "{%18, %19, %20, %21}," + " %22," + " p, %24, %25;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x72x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x72x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x72x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[36]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %38, 0;\n" + "wgmma.mma_async.sync.aligned.m64n72k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + " %36," + " %37," + " p, %39, %40;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x72x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x72x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x72x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[36]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %41, 0;\n" + "wgmma.mma_async.sync.aligned.m64n72k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + "{%36, %37, %38, %39}," + " %40," + " p, %42, %43;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x72x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN F16+=E4M3*E5M2 +template < 
+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x80x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[20]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %22, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + " %20," + " %21," + " p, %23, %24;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x80x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[20]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %25, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + "{%20, %21, %22, %23}," + " %24," + " p, %26, %27;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x80x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %42, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " p, %43, %44;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x80x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %45, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " p, %46, %47;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x88x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x88x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[22]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %24, 0;\n" + "wgmma.mma_async.sync.aligned.m64n88k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21}," + " %22," + " %23," + " p, %25, %26;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x88x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x88x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x88x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[22]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t 
const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %27, 0;\n" + "wgmma.mma_async.sync.aligned.m64n88k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21}," + "{%22, %23, %24, %25}," + " %26," + " p, %28, %29;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x88x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x88x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x88x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[44]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %46, 0;\n" + "wgmma.mma_async.sync.aligned.m64n88k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + " %44," + " %45," + " p, %47, %48;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), 
"+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x88x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x88x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x88x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[44]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %49, 0;\n" + "wgmma.mma_async.sync.aligned.m64n88k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + "{%44, %45, %46, %47}," + " %48," + " p, %50, %51;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x88x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x104x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x104x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[26]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, 
uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sync.aligned.m64n104k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25}," + " %26," + " %27," + " p, %29, %30;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x104x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x104x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x104x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[26]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sync.aligned.m64n104k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25}," + "{%26, %27, %28, %29}," + " %30," + " p, %32, %33;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x104x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x104x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = 
GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x104x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[52]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %54, 0;\n" + "wgmma.mma_async.sync.aligned.m64n104k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + " %52," + " %53," + " p, %55, %56;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x104x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x104x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x104x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[52]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, 
float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %57, 0;\n" + "wgmma.mma_async.sync.aligned.m64n104k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + "{%52, %53, %54, %55}," + " %56," + " p, %58, %59;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x104x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x112x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[28]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %30, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + " %28," + " %29," + " p, %31, %32;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + : 
"l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x112x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[28]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %33, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + "{%28, %29, %30, %31}," + " %32," + " p, %34, %35;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x112x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & 
d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %58, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " p, %59, %60;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x112x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %61, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, 
%53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " p, %62, %63;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x120x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x120x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[30]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %32, 0;\n" + "wgmma.mma_async.sync.aligned.m64n120k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29}," + " %30," + " %31," + " p, %33, %34;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x120x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x120x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x120x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[30]; + + CUTE_HOST_DEVICE 
static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %35, 0;\n" + "wgmma.mma_async.sync.aligned.m64n120k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29}," + "{%30, %31, %32, %33}," + " %34," + " p, %36, %37;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x120x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x120x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x120x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[60]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %62, 0;\n" + "wgmma.mma_async.sync.aligned.m64n120k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, 
%31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + " %60," + " %61," + " p, %63, %64;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x120x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x120x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x120x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[60]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %65, 0;\n" + "wgmma.mma_async.sync.aligned.m64n120k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + "{%60, %61, %62, %63}," + " %64," + " p, %66, %67;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + 
"+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x120x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x136x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x136x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[34]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sync.aligned.m64n136k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33}," + " %34," + " %35," + " p, %37, %38;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x136x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x136x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x136x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[34]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, 
uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sync.aligned.m64n136k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33}," + "{%34, %35, %36, %37}," + " %38," + " p, %40, %41;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x136x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x136x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x136x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %70, 0;\n" + "wgmma.mma_async.sync.aligned.m64n136k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, 
%27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + " %68," + " %69," + " p, %71, %72;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x136x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x136x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x136x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %73, 0;\n" + "wgmma.mma_async.sync.aligned.m64n136k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + "{%68, %69, %70, %71}," + " %72," + " p, %74, %75;\n" + "}\n" + : "+f"(d00), 
"+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x136x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x144x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[36]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %38, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + " %36," + " %37," + " p, %39, %40;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + 
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x144x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[36]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %41, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + "{%36, %37, %38, %39}," + " %40," + " p, %42, %43;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x144x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, 
float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %74, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " p, %75, %76;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x144x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) 
+ { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %77, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " p, %78, %79;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x152x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x152x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[38]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %40, 0;\n" + "wgmma.mma_async.sync.aligned.m64n152k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37}," + " %38," + " %39," + " p, %41, %42;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), 
"+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x152x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x152x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x152x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[38]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %43, 0;\n" + "wgmma.mma_async.sync.aligned.m64n152k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37}," + "{%38, %39, %40, %41}," + " %42," + " p, %44, %45;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x152x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x152x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x152x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[76]; + + 
CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %78, 0;\n" + "wgmma.mma_async.sync.aligned.m64n152k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75}," + " %76," + " %77," + " p, %79, %80;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x152x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x152x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x152x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[76]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + 
uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %81, 0;\n" + "wgmma.mma_async.sync.aligned.m64n152k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75}," + "{%76, %77, %78, %79}," + " %80," + " p, %82, %83;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x152x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x160x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & 
d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %42, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " p, %43, %44;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x160x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %45, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, 
%28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " p, %46, %47;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x160x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %82, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " p, %83, %84;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), 
"+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x160x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %85, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " p, %86, %87;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + 
"+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x168x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x168x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[42]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sync.aligned.m64n168k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41}," + " %42," + " %43," + " p, %45, %46;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x168x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x168x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x168x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[42]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sync.aligned.m64n168k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41}," + "{%42, %43, %44, %45}," + " %46," + " p, %48, %49;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x168x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x168x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x168x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[84]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + 
float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %86, 0;\n" + "wgmma.mma_async.sync.aligned.m64n168k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83}," + " %84," + " %85," + " p, %87, %88;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x168x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x168x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x168x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[84]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float 
& d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %89, 0;\n" + "wgmma.mma_async.sync.aligned.m64n168k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83}," + "{%84, %85, %86, %87}," + " %88," + " p, %90, %91;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x168x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x176x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[44]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, 
uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %46, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + " %44," + " %45," + " p, %47, %48;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x176x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[44]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %49, 0;\n" + 
"wgmma.mma_async.sync.aligned.m64n176k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + "{%44, %45, %46, %47}," + " %48," + " p, %50, %51;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x176x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %90, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " 
%72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " p, %91, %92;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x176x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %93, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, 
%21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " p, %94, %95;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x184x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x184x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[46]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %48, 0;\n" + "wgmma.mma_async.sync.aligned.m64n184k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, 
%45}," + " %46," + " %47," + " p, %49, %50;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x184x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x184x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x184x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[46]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %51, 0;\n" + "wgmma.mma_async.sync.aligned.m64n184k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45}," + "{%46, %47, %48, %49}," + " %50," + " p, %52, %53;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x184x32_F16E4M3E5M2_RS_TN without 
CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x184x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x184x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %94, 0;\n" + "wgmma.mma_async.sync.aligned.m64n184k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91}," + " %92," + " %93," + " p, %95, %96;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + 
"+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x184x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x184x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x184x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %97, 0;\n" + "wgmma.mma_async.sync.aligned.m64n184k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91}," + "{%92, %93, %94, %95}," + " %96," + " p, %98, %99;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), 
"+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x184x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x200x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x200x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[50]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sync.aligned.m64n200k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49}," + " %50," + " %51," + " p, %53, %54;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + 
CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x200x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x200x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x200x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[50]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sync.aligned.m64n200k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49}," + "{%50, %51, %52, %53}," + " %54," + " p, %56, %57;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x200x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x200x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x200x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[100]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, 
float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %102, 0;\n" + "wgmma.mma_async.sync.aligned.m64n200k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99}," + " %100," + " %101," + " p, %103, %104;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + 
"+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x200x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x200x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x200x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[100]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %105, 0;\n" + "wgmma.mma_async.sync.aligned.m64n200k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99}," + "{%100, %101, %102, %103}," + " %104," + " p, %106, %107;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + 
"+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x200x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x208x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[52]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %54, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + " %52," + " %53," + " p, %55, %56;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), 
"+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x208x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[52]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %57, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + "{%52, %53, %54, %55}," + " %56," + " p, %58, %59;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x208x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %106, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " p, %107, %108;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), 
"+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x208x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %109, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " 
%40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " p, %110, %111;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x216x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x216x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[54]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %56, 0;\n" + "wgmma.mma_async.sync.aligned.m64n216k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53}," + " %54," + " %55," + " p, %57, %58;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x216x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x216x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x216x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[54]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %59, 0;\n" + "wgmma.mma_async.sync.aligned.m64n216k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53}," + "{%54, %55, %56, %57}," + " %58," + " p, %60, %61;\n" + "}\n" + : 
"+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x216x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x216x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x216x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[108]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %110, 0;\n" + "wgmma.mma_async.sync.aligned.m64n216k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, 
%51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107}," + " %108," + " %109," + " p, %111, %112;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x216x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x216x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x216x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[108]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & 
d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %113, 0;\n" + "wgmma.mma_async.sync.aligned.m64n216k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107}," + "{%108, %109, %110, %111}," + " %112," + " p, %114, %115;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x216x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x224x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters 
= uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %58, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " p, %59, %60;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x224x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, 
uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %61, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " p, %62, %63;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x224x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, 
float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %114, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " p, %115, %116;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = 
GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x224x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %117, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " p, %118, %119;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), 
"+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x232x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x232x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[58]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sync.aligned.m64n232k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57}," + " %58," + " %59," + " p, %61, %62;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), 
"+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x232x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x232x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x232x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[58]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sync.aligned.m64n232k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57}," + "{%58, %59, %60, %61}," + " %62," + " p, %64, %65;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57) + : 
"r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x232x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x232x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x232x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %118, 0;\n" + "wgmma.mma_async.sync.aligned.m64n232k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + " %116," + " %117," + " p, %119, %120;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), 
"+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x232x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x232x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x232x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float 
& d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %121, 0;\n" + "wgmma.mma_async.sync.aligned.m64n232k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + "{%116, %117, %118, %119}," + " %120," + " p, %122, %123;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x232x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x240x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[60]; + + CUTE_HOST_DEVICE static void + 
fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %62, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + " %60," + " %61," + " p, %63, %64;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x240x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[60]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, 
uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %65, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + "{%60, %61, %62, %63}," + " %64," + " p, %66, %67;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x240x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, 
float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %122, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " p, %123, %124;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), 
"+f"(d117), "+f"(d118), "+f"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x240x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %125, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, 
%115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " p, %126, %127;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x248x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x248x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[62]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %64, 0;\n" + "wgmma.mma_async.sync.aligned.m64n248k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61}," + " %62," + " %63," + " p, %65, %66;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x248x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x248x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x248x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[62]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %67, 0;\n" + "wgmma.mma_async.sync.aligned.m64n248k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, 
%20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61}," + "{%62, %63, %64, %65}," + " %66," + " p, %68, %69;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x248x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x248x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x248x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[124]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, 
float & d121, float & d122, float & d123, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %126, 0;\n" + "wgmma.mma_async.sync.aligned.m64n248k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + " %124," + " %125," + " p, %127, %128;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x248x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x248x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x248x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[124]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + 
float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %129, 0;\n" + "wgmma.mma_async.sync.aligned.m64n248k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + "{%124, %125, %126, %127}," + " %128," + " p, %130, %131;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), 
"+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x248x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x24x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[6]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5}," + " %6," + " %7," + " p, %9, %10;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x24x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[6]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5}," + "{%6, %7, %8, %9}," + " %10," + " p, %12, %13;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), 
+ "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x24x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %14, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " p, %15, %16;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x24x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %17, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " p, %18, %19;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x40x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x40x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using 
ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[10]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sync.aligned.m64n40k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9}," + " %10," + " %11," + " p, %13, %14;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x40x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x40x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x40x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[10]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sync.aligned.m64n40k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9}," + "{%10, %11, %12, %13}," + " %14," + " p, %16, %17;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x40x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x40x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x40x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[20]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
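// Note: the _SS_ atoms source both A and B from shared memory via descriptors, whereas the + // _RS_ variants take the A fragment from four 32-bit registers; with f32 accumulation each + // of the 128 threads in the warpgroup holds N/2 accumulator values (here 40/2 = 20 floats), + // which is why CRegisters scales with the tile's N extent. +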
cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %22, 0;\n" + "wgmma.mma_async.sync.aligned.m64n40k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + " %20," + " %21," + " p, %23, %24;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x40x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x40x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x40x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[20]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %25, 0;\n" + "wgmma.mma_async.sync.aligned.m64n40k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + "{%20, %21, %22, %23}," + " %24," + " p, %26, %27;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x40x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x48x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" 
+ "setp.ne.b32 p, %14, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " p, %15, %16;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x48x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %17, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " p, %18, %19;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x48x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %26, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " p, %27, %28;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), 
"+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x48x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %29, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " p, %30, %31;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x56x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x56x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[14]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %16, 0;\n" + "wgmma.mma_async.sync.aligned.m64n56k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13}," + " %14," + " %15," + " p, %17, %18;\n" + "}\n" + : "+r"(d00), 
"+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x56x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x56x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x56x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[14]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %19, 0;\n" + "wgmma.mma_async.sync.aligned.m64n56k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13}," + "{%14, %15, %16, %17}," + " %18," + " p, %20, %21;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x56x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x56x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x56x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[28]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %30, 0;\n" + "wgmma.mma_async.sync.aligned.m64n56k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + " %28," + " %29," + " p, %31, %32;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), 
"+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x56x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x56x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x56x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[28]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %33, 0;\n" + "wgmma.mma_async.sync.aligned.m64n56k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + "{%28, %29, %30, %31}," + " %32," + " p, %34, %35;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x56x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x72x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x72x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[18]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 
0;\n" + "wgmma.mma_async.sync.aligned.m64n72k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17}," + " %18," + " %19," + " p, %21, %22;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x72x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x72x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x72x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[18]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sync.aligned.m64n72k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17}," + "{%18, %19, %20, %21}," + " %22," + " p, %24, %25;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x72x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x72x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x72x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[36]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %38, 0;\n" + "wgmma.mma_async.sync.aligned.m64n72k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + " %36," + " %37," + " p, %39, %40;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x72x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x72x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x72x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[36]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %41, 0;\n" + "wgmma.mma_async.sync.aligned.m64n72k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + "{%36, %37, %38, %39}," + " %40," + " p, %42, %43;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x72x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN F16+=E5M2*E4M3 +template < 
+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x80x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[20]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %22, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + " %20," + " %21," + " p, %23, %24;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x80x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[20]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %25, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + "{%20, %21, %22, %23}," + " %24," + " p, %26, %27;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
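+//
+// The FP8 GMMA atoms in this file all follow the same pattern:
+//   - *_SS_TN variants source both A and B from shared memory via 64-bit matrix descriptors
+//     (desc_a, desc_b); *_RS_TN variants source A from four 32-bit registers and B from a
+//     shared-memory descriptor.
+//   - Each of the 128 threads in the warpgroup holds 64*N/128 accumulator elements for an
+//     m64nNk32 tile: N/2 float registers for F32 accumulation, or N/4 uint32_t registers
+//     (two packed halves each) for F16 accumulation.
+//   - scale_D is lowered to the predicate p via setp.ne.b32: ScaleOut::Zero overwrites the
+//     accumulator, ScaleOut::One accumulates into it. scaleA/scaleB are compile-time
+//     immediates taken from the template parameters.
+//
+// A minimal usage sketch, assuming descriptors and accumulator registers are prepared
+// elsewhere and that the cute warpgroup helpers are available (illustrative only, not a
+// complete kernel):
+//
+//   warpgroup_arrive();
+//   MMA_64x80x32_F32E5M2E4M3_SS_TN<>::fma(desc_a, desc_b, d00, /* ... */, d39, scale_D);
+//   warpgroup_commit_batch();
+//   warpgroup_wait<0>();
+//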
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x80x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %42, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " p, %43, %44;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x80x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %45, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " p, %46, %47;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x88x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x88x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[22]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %24, 0;\n" + "wgmma.mma_async.sync.aligned.m64n88k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21}," + " %22," + " %23," + " p, %25, %26;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x88x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x88x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x88x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[22]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t 
const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %27, 0;\n" + "wgmma.mma_async.sync.aligned.m64n88k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21}," + "{%22, %23, %24, %25}," + " %26," + " p, %28, %29;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x88x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x88x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x88x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[44]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %46, 0;\n" + "wgmma.mma_async.sync.aligned.m64n88k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + " %44," + " %45," + " p, %47, %48;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), 
"+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x88x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x88x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x88x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[44]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %49, 0;\n" + "wgmma.mma_async.sync.aligned.m64n88k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + "{%44, %45, %46, %47}," + " %48," + " p, %50, %51;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x88x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x104x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x104x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[26]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, 
uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sync.aligned.m64n104k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25}," + " %26," + " %27," + " p, %29, %30;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x104x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x104x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x104x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[26]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sync.aligned.m64n104k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25}," + "{%26, %27, %28, %29}," + " %30," + " p, %32, %33;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x104x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x104x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = 
GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x104x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[52]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %54, 0;\n" + "wgmma.mma_async.sync.aligned.m64n104k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + " %52," + " %53," + " p, %55, %56;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x104x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x104x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x104x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[52]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, 
float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %57, 0;\n" + "wgmma.mma_async.sync.aligned.m64n104k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + "{%52, %53, %54, %55}," + " %56," + " p, %58, %59;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x104x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x112x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[28]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %30, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + " %28," + " %29," + " p, %31, %32;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + : 
"l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x112x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[28]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %33, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + "{%28, %29, %30, %31}," + " %32," + " p, %34, %35;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x112x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & 
d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %58, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " p, %59, %60;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x112x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %61, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, 
%53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " p, %62, %63;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x120x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x120x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[30]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %32, 0;\n" + "wgmma.mma_async.sync.aligned.m64n120k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29}," + " %30," + " %31," + " p, %33, %34;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x120x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x120x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x120x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[30]; + + CUTE_HOST_DEVICE 
static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %35, 0;\n" + "wgmma.mma_async.sync.aligned.m64n120k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29}," + "{%30, %31, %32, %33}," + " %34," + " p, %36, %37;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x120x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x120x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x120x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[60]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %62, 0;\n" + "wgmma.mma_async.sync.aligned.m64n120k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, 
%31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + " %60," + " %61," + " p, %63, %64;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x120x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x120x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x120x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[60]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %65, 0;\n" + "wgmma.mma_async.sync.aligned.m64n120k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + "{%60, %61, %62, %63}," + " %64," + " p, %66, %67;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + 
"+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x120x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x136x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x136x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[34]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sync.aligned.m64n136k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33}," + " %34," + " %35," + " p, %37, %38;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x136x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x136x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x136x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[34]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, 
uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sync.aligned.m64n136k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33}," + "{%34, %35, %36, %37}," + " %38," + " p, %40, %41;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x136x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x136x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x136x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %70, 0;\n" + "wgmma.mma_async.sync.aligned.m64n136k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, 
%27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + " %68," + " %69," + " p, %71, %72;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x136x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x136x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x136x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %73, 0;\n" + "wgmma.mma_async.sync.aligned.m64n136k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + "{%68, %69, %70, %71}," + " %72," + " p, %74, %75;\n" + "}\n" + : "+f"(d00), 
"+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x136x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x144x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[36]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %38, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + " %36," + " %37," + " p, %39, %40;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + 
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x144x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[36]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %41, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + "{%36, %37, %38, %39}," + " %40," + " p, %42, %43;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x144x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, 
float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %74, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " p, %75, %76;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x144x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) 
+ { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %77, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " p, %78, %79;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x152x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x152x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[38]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %40, 0;\n" + "wgmma.mma_async.sync.aligned.m64n152k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37}," + " %38," + " %39," + " p, %41, %42;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), 
"+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x152x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x152x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x152x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[38]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %43, 0;\n" + "wgmma.mma_async.sync.aligned.m64n152k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37}," + "{%38, %39, %40, %41}," + " %42," + " p, %44, %45;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x152x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x152x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x152x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[76]; + + 
CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %78, 0;\n" + "wgmma.mma_async.sync.aligned.m64n152k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75}," + " %76," + " %77," + " p, %79, %80;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x152x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x152x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x152x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[76]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + 
uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %81, 0;\n" + "wgmma.mma_async.sync.aligned.m64n152k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75}," + "{%76, %77, %78, %79}," + " %80," + " p, %82, %83;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x152x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x160x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & 
d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %42, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " p, %43, %44;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x160x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %45, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, 
%28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " p, %46, %47;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x160x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %82, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " p, %83, %84;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), 
"+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x160x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %85, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " p, %86, %87;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + 
"+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x168x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x168x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[42]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sync.aligned.m64n168k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41}," + " %42," + " %43," + " p, %45, %46;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x168x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x168x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x168x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[42]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sync.aligned.m64n168k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41}," + "{%42, %43, %44, %45}," + " %46," + " p, %48, %49;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x168x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x168x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x168x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[84]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + 
float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %86, 0;\n" + "wgmma.mma_async.sync.aligned.m64n168k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83}," + " %84," + " %85," + " p, %87, %88;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x168x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x168x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x168x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[84]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float 
& d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %89, 0;\n" + "wgmma.mma_async.sync.aligned.m64n168k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83}," + "{%84, %85, %86, %87}," + " %88," + " p, %90, %91;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x168x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x176x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[44]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, 
uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %46, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + " %44," + " %45," + " p, %47, %48;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x176x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[44]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %49, 0;\n" + 
"wgmma.mma_async.sync.aligned.m64n176k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + "{%44, %45, %46, %47}," + " %48," + " p, %50, %51;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x176x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %90, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " 
%72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " p, %91, %92;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x176x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %93, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, 
%21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " p, %94, %95;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x184x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x184x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[46]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %48, 0;\n" + "wgmma.mma_async.sync.aligned.m64n184k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, 
%45}," + " %46," + " %47," + " p, %49, %50;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x184x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x184x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x184x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[46]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %51, 0;\n" + "wgmma.mma_async.sync.aligned.m64n184k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45}," + "{%46, %47, %48, %49}," + " %50," + " p, %52, %53;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x184x32_F16E5M2E4M3_RS_TN without 
CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x184x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x184x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %94, 0;\n" + "wgmma.mma_async.sync.aligned.m64n184k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91}," + " %92," + " %93," + " p, %95, %96;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + 
"+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x184x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x184x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x184x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %97, 0;\n" + "wgmma.mma_async.sync.aligned.m64n184k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91}," + "{%92, %93, %94, %95}," + " %96," + " p, %98, %99;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), 
"+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x184x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x200x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x200x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[50]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sync.aligned.m64n200k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49}," + " %50," + " %51," + " p, %53, %54;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + 
CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x200x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x200x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x200x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[50]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sync.aligned.m64n200k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49}," + "{%50, %51, %52, %53}," + " %54," + " p, %56, %57;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x200x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x200x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x200x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[100]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, 
float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %102, 0;\n" + "wgmma.mma_async.sync.aligned.m64n200k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99}," + " %100," + " %101," + " p, %103, %104;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + 
"+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x200x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x200x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x200x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[100]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %105, 0;\n" + "wgmma.mma_async.sync.aligned.m64n200k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99}," + "{%100, %101, %102, %103}," + " %104," + " p, %106, %107;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + 
"+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x200x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x208x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[52]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %54, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + " %52," + " %53," + " p, %55, %56;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), 
"+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x208x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[52]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %57, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + "{%52, %53, %54, %55}," + " %56," + " p, %58, %59;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x208x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %106, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " p, %107, %108;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), 
"+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x208x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %109, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " 
%40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " p, %110, %111;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x216x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x216x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[54]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %56, 0;\n" + "wgmma.mma_async.sync.aligned.m64n216k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53}," + " %54," + " %55," + " p, %57, %58;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x216x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x216x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x216x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[54]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %59, 0;\n" + "wgmma.mma_async.sync.aligned.m64n216k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53}," + "{%54, %55, %56, %57}," + " %58," + " p, %60, %61;\n" + "}\n" + : 
"+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x216x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x216x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x216x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[108]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %110, 0;\n" + "wgmma.mma_async.sync.aligned.m64n216k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, 
%51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107}," + " %108," + " %109," + " p, %111, %112;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x216x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x216x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x216x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[108]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & 
d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %113, 0;\n" + "wgmma.mma_async.sync.aligned.m64n216k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107}," + "{%108, %109, %110, %111}," + " %112," + " p, %114, %115;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x216x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x224x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters 
= uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %58, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " p, %59, %60;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x224x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, 
uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %61, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " p, %62, %63;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x224x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, 
float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %114, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " p, %115, %116;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = 
GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x224x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %117, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " p, %118, %119;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), 
"+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x232x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x232x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[58]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sync.aligned.m64n232k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57}," + " %58," + " %59," + " p, %61, %62;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), 
"+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x232x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x232x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x232x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[58]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sync.aligned.m64n232k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57}," + "{%58, %59, %60, %61}," + " %62," + " p, %64, %65;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57) + : 
"r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x232x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x232x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x232x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %118, 0;\n" + "wgmma.mma_async.sync.aligned.m64n232k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + " %116," + " %117," + " p, %119, %120;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), 
"+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x232x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x232x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x232x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float 
& d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %121, 0;\n" + "wgmma.mma_async.sync.aligned.m64n232k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + "{%116, %117, %118, %119}," + " %120," + " p, %122, %123;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x232x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x240x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[60]; + + CUTE_HOST_DEVICE static void + 
fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %62, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + " %60," + " %61," + " p, %63, %64;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x240x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[60]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, 
uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %65, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + "{%60, %61, %62, %63}," + " %64," + " p, %66, %67;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x240x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, 
float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %122, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " p, %123, %124;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), 
"+f"(d117), "+f"(d118), "+f"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x240x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %125, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, 
%115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " p, %126, %127;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x248x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x248x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[62]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %64, 0;\n" + "wgmma.mma_async.sync.aligned.m64n248k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61}," + " %62," + " %63," + " p, %65, %66;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x248x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x248x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x248x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[62]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %67, 0;\n" + "wgmma.mma_async.sync.aligned.m64n248k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, 
%20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61}," + "{%62, %63, %64, %65}," + " %66," + " p, %68, %69;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x248x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x248x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x248x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[124]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, 
float & d121, float & d122, float & d123, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %126, 0;\n" + "wgmma.mma_async.sync.aligned.m64n248k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + " %124," + " %125," + " p, %127, %128;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x248x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x248x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x248x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[124]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + 
float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %129, 0;\n" + "wgmma.mma_async.sync.aligned.m64n248k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + "{%124, %125, %126, %127}," + " %128," + " p, %130, %131;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), 
"+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x248x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x24x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[6]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5}," + " %6," + " %7," + " p, %9, %10;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x24x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[6]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5}," + "{%6, %7, %8, %9}," + " %10," + " p, %12, %13;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), 
+ "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x24x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %14, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " p, %15, %16;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x24x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %17, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " p, %18, %19;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x40x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x40x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using 
ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[10]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sync.aligned.m64n40k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9}," + " %10," + " %11," + " p, %13, %14;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x40x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x40x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x40x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[10]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sync.aligned.m64n40k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9}," + "{%10, %11, %12, %13}," + " %14," + " p, %16, %17;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x40x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x40x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x40x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[20]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %22, 0;\n" + "wgmma.mma_async.sync.aligned.m64n40k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + " %20," + " %21," + " p, %23, %24;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x40x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x40x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x40x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[20]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %25, 0;\n" + "wgmma.mma_async.sync.aligned.m64n40k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + "{%20, %21, %22, %23}," + " %24," + " p, %26, %27;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x40x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x48x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" 
+ "setp.ne.b32 p, %14, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " p, %15, %16;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x48x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %17, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " p, %18, %19;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x48x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %26, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " p, %27, %28;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), 
"+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x48x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %29, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " p, %30, %31;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x56x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x56x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[14]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %16, 0;\n" + "wgmma.mma_async.sync.aligned.m64n56k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13}," + " %14," + " %15," + " p, %17, %18;\n" + "}\n" + : "+r"(d00), 
"+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x56x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x56x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x56x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[14]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %19, 0;\n" + "wgmma.mma_async.sync.aligned.m64n56k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13}," + "{%14, %15, %16, %17}," + " %18," + " p, %20, %21;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x56x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x56x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x56x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[28]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %30, 0;\n" + "wgmma.mma_async.sync.aligned.m64n56k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + " %28," + " %29," + " p, %31, %32;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), 
"+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x56x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x56x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x56x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[28]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %33, 0;\n" + "wgmma.mma_async.sync.aligned.m64n56k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + "{%28, %29, %30, %31}," + " %32," + " p, %34, %35;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x56x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x72x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x72x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[18]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 
0;\n" + "wgmma.mma_async.sync.aligned.m64n72k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17}," + " %18," + " %19," + " p, %21, %22;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x72x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x72x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x72x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[18]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sync.aligned.m64n72k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17}," + "{%18, %19, %20, %21}," + " %22," + " p, %24, %25;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x72x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x72x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x72x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[36]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %38, 0;\n" + "wgmma.mma_async.sync.aligned.m64n72k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + " %36," + " %37," + " p, %39, %40;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x72x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x72x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x72x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[36]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %41, 0;\n" + "wgmma.mma_async.sync.aligned.m64n72k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + "{%36, %37, %38, %39}," + " %40," + " p, %42, %43;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x72x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN F16+=E5M2*E5M2 +template < 
+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x80x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[20]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %22, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + " %20," + " %21," + " p, %23, %24;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x80x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[20]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %25, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + "{%20, %21, %22, %23}," + " %24," + " p, %26, %27;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
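+// Editorial note (illustrative sketch, not part of the upstream CUTLASS header):
+// these extended-N GMMA atoms are normally not called by hand; CUTLASS selects one
+// through cute's GMMA op selectors and wraps it in a TiledMMA, and the surrounding
+// mainloop issues the warpgroup fence/commit/wait sequence around the asynchronous
+// wgmma instruction. Assuming desc_a/desc_b are valid shared-memory matrix
+// descriptors (built by cute's GmmaDescriptor machinery), a rough hand-written call
+// of the 64x80x32 F16 SS atom above would look like:
+//
+//   uint32_t acc[20] = {};        // 20 packed f16x2 accumulators (CRegisters = uint32_t[20])
+//   cute::warpgroup_arrive();                       // wgmma.fence before issuing
+//   MMA_64x80x32_F16E5M2E5M2_SS_TN<>::fma(
+//       desc_a, desc_b,
+//       acc[0],  acc[1],  /* ... remaining accumulators ... */ acc[18], acc[19],
+//       GMMA::ScaleOut::One);
+//   cute::warpgroup_commit_batch();                 // wgmma.commit_group
+//   cute::warpgroup_wait<0>();                      // wgmma.wait_group 0
+//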
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x80x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %42, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " p, %43, %44;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x80x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %45, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " p, %46, %47;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x88x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x88x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[22]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %24, 0;\n" + "wgmma.mma_async.sync.aligned.m64n88k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21}," + " %22," + " %23," + " p, %25, %26;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x88x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x88x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x88x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[22]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t 
const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %27, 0;\n" + "wgmma.mma_async.sync.aligned.m64n88k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21}," + "{%22, %23, %24, %25}," + " %26," + " p, %28, %29;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x88x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x88x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x88x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[44]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %46, 0;\n" + "wgmma.mma_async.sync.aligned.m64n88k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + " %44," + " %45," + " p, %47, %48;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), 
"+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x88x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x88x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x88x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[44]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %49, 0;\n" + "wgmma.mma_async.sync.aligned.m64n88k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + "{%44, %45, %46, %47}," + " %48," + " p, %50, %51;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x88x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x104x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x104x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[26]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, 
uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sync.aligned.m64n104k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25}," + " %26," + " %27," + " p, %29, %30;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x104x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x104x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x104x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[26]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sync.aligned.m64n104k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25}," + "{%26, %27, %28, %29}," + " %30," + " p, %32, %33;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x104x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x104x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = 
GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x104x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[52]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %54, 0;\n" + "wgmma.mma_async.sync.aligned.m64n104k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + " %52," + " %53," + " p, %55, %56;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x104x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x104x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x104x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[52]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, 
float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %57, 0;\n" + "wgmma.mma_async.sync.aligned.m64n104k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + "{%52, %53, %54, %55}," + " %56," + " p, %58, %59;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x104x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x112x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[28]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %30, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + " %28," + " %29," + " p, %31, %32;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + : 
"l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x112x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[28]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %33, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + "{%28, %29, %30, %31}," + " %32," + " p, %34, %35;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x112x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & 
d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %58, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " p, %59, %60;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x112x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %61, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, 
%53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " p, %62, %63;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x120x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x120x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[30]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %32, 0;\n" + "wgmma.mma_async.sync.aligned.m64n120k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29}," + " %30," + " %31," + " p, %33, %34;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x120x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x120x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x120x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[30]; + + CUTE_HOST_DEVICE 
static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %35, 0;\n" + "wgmma.mma_async.sync.aligned.m64n120k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29}," + "{%30, %31, %32, %33}," + " %34," + " p, %36, %37;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x120x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x120x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x120x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[60]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %62, 0;\n" + "wgmma.mma_async.sync.aligned.m64n120k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, 
%31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + " %60," + " %61," + " p, %63, %64;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x120x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x120x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x120x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[60]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %65, 0;\n" + "wgmma.mma_async.sync.aligned.m64n120k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + "{%60, %61, %62, %63}," + " %64," + " p, %66, %67;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + 
"+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x120x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x136x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x136x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[34]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sync.aligned.m64n136k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33}," + " %34," + " %35," + " p, %37, %38;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x136x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x136x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x136x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[34]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, 
uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sync.aligned.m64n136k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33}," + "{%34, %35, %36, %37}," + " %38," + " p, %40, %41;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x136x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x136x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x136x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %70, 0;\n" + "wgmma.mma_async.sync.aligned.m64n136k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, 
%27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + " %68," + " %69," + " p, %71, %72;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x136x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x136x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x136x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %73, 0;\n" + "wgmma.mma_async.sync.aligned.m64n136k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + "{%68, %69, %70, %71}," + " %72," + " p, %74, %75;\n" + "}\n" + : "+f"(d00), 
"+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x136x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x144x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[36]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %38, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + " %36," + " %37," + " p, %39, %40;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + 
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x144x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[36]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %41, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + "{%36, %37, %38, %39}," + " %40," + " p, %42, %43;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x144x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, 
float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %74, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " p, %75, %76;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x144x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) 
+ { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %77, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " p, %78, %79;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x152x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x152x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[38]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %40, 0;\n" + "wgmma.mma_async.sync.aligned.m64n152k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37}," + " %38," + " %39," + " p, %41, %42;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), 
"+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x152x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x152x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x152x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[38]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %43, 0;\n" + "wgmma.mma_async.sync.aligned.m64n152k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37}," + "{%38, %39, %40, %41}," + " %42," + " p, %44, %45;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x152x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x152x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x152x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[76]; + + 
CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %78, 0;\n" + "wgmma.mma_async.sync.aligned.m64n152k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75}," + " %76," + " %77," + " p, %79, %80;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x152x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x152x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x152x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[76]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + 
uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %81, 0;\n" + "wgmma.mma_async.sync.aligned.m64n152k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75}," + "{%76, %77, %78, %79}," + " %80," + " p, %82, %83;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x152x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x160x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & 
d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %42, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " p, %43, %44;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x160x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %45, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, 
%28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " p, %46, %47;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x160x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %82, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " p, %83, %84;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), 
"+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x160x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %85, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " p, %86, %87;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + 
"+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x168x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x168x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[42]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sync.aligned.m64n168k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41}," + " %42," + " %43," + " p, %45, %46;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x168x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x168x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x168x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[42]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sync.aligned.m64n168k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41}," + "{%42, %43, %44, %45}," + " %46," + " p, %48, %49;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x168x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x168x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x168x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[84]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + 
float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %86, 0;\n" + "wgmma.mma_async.sync.aligned.m64n168k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83}," + " %84," + " %85," + " p, %87, %88;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x168x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x168x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x168x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[84]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float 
& d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %89, 0;\n" + "wgmma.mma_async.sync.aligned.m64n168k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83}," + "{%84, %85, %86, %87}," + " %88," + " p, %90, %91;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x168x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x176x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[44]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, 
uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %46, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + " %44," + " %45," + " p, %47, %48;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x176x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[44]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %49, 0;\n" + 
"wgmma.mma_async.sync.aligned.m64n176k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + "{%44, %45, %46, %47}," + " %48," + " p, %50, %51;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x176x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %90, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " 
%72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " p, %91, %92;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x176x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %93, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, 
%21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " p, %94, %95;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x184x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x184x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[46]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %48, 0;\n" + "wgmma.mma_async.sync.aligned.m64n184k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, 
%45}," + " %46," + " %47," + " p, %49, %50;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x184x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x184x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x184x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[46]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %51, 0;\n" + "wgmma.mma_async.sync.aligned.m64n184k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45}," + "{%46, %47, %48, %49}," + " %50," + " p, %52, %53;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x184x32_F16E5M2E5M2_RS_TN without 
CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x184x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x184x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %94, 0;\n" + "wgmma.mma_async.sync.aligned.m64n184k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91}," + " %92," + " %93," + " p, %95, %96;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + 
"+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x184x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x184x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x184x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %97, 0;\n" + "wgmma.mma_async.sync.aligned.m64n184k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91}," + "{%92, %93, %94, %95}," + " %96," + " p, %98, %99;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), 
"+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x184x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x200x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x200x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[50]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sync.aligned.m64n200k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49}," + " %50," + " %51," + " p, %53, %54;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + 
CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x200x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x200x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x200x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[50]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sync.aligned.m64n200k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49}," + "{%50, %51, %52, %53}," + " %54," + " p, %56, %57;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x200x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x200x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x200x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[100]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, 
float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %102, 0;\n" + "wgmma.mma_async.sync.aligned.m64n200k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99}," + " %100," + " %101," + " p, %103, %104;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + 
"+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x200x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x200x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x200x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[100]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %105, 0;\n" + "wgmma.mma_async.sync.aligned.m64n200k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99}," + "{%100, %101, %102, %103}," + " %104," + " p, %106, %107;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + 
"+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x200x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x208x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[52]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %54, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + " %52," + " %53," + " p, %55, %56;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), 
"+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x208x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[52]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %57, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + "{%52, %53, %54, %55}," + " %56," + " p, %58, %59;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x208x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %106, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " p, %107, %108;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), 
"+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x208x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %109, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " 
%40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " p, %110, %111;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x216x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x216x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[54]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %56, 0;\n" + "wgmma.mma_async.sync.aligned.m64n216k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53}," + " %54," + " %55," + " p, %57, %58;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x216x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x216x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x216x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[54]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %59, 0;\n" + "wgmma.mma_async.sync.aligned.m64n216k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53}," + "{%54, %55, %56, %57}," + " %58," + " p, %60, %61;\n" + "}\n" + : 
"+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x216x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x216x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x216x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[108]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %110, 0;\n" + "wgmma.mma_async.sync.aligned.m64n216k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, 
%51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107}," + " %108," + " %109," + " p, %111, %112;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x216x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x216x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x216x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[108]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & 
d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %113, 0;\n" + "wgmma.mma_async.sync.aligned.m64n216k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107}," + "{%108, %109, %110, %111}," + " %112," + " p, %114, %115;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x216x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x224x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters 
= uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %58, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " p, %59, %60;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x224x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, 
uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %61, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " p, %62, %63;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x224x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, 
float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %114, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " p, %115, %116;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = 
GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x224x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %117, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " p, %118, %119;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), 
"+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x232x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x232x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[58]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sync.aligned.m64n232k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57}," + " %58," + " %59," + " p, %61, %62;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), 
"+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x232x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x232x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x232x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[58]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sync.aligned.m64n232k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57}," + "{%58, %59, %60, %61}," + " %62," + " p, %64, %65;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57) + : 
"r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x232x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x232x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x232x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %118, 0;\n" + "wgmma.mma_async.sync.aligned.m64n232k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + " %116," + " %117," + " p, %119, %120;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), 
"+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x232x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x232x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x232x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float 
& d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %121, 0;\n" + "wgmma.mma_async.sync.aligned.m64n232k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + "{%116, %117, %118, %119}," + " %120," + " p, %122, %123;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x232x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x240x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[60]; + + CUTE_HOST_DEVICE static void + 
fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %62, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + " %60," + " %61," + " p, %63, %64;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x240x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[60]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, 
uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %65, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + "{%60, %61, %62, %63}," + " %64," + " p, %66, %67;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x240x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, 
float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %122, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " p, %123, %124;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), 
"+f"(d117), "+f"(d118), "+f"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x240x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %125, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, 
%115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " p, %126, %127;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x248x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x248x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[62]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %64, 0;\n" + "wgmma.mma_async.sync.aligned.m64n248k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61}," + " %62," + " %63," + " p, %65, %66;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x248x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x248x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x248x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[62]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %67, 0;\n" + "wgmma.mma_async.sync.aligned.m64n248k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, 
%20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61}," + "{%62, %63, %64, %65}," + " %66," + " p, %68, %69;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x248x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x248x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x248x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[124]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, 
float & d121, float & d122, float & d123, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %126, 0;\n" + "wgmma.mma_async.sync.aligned.m64n248k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + " %124," + " %125," + " p, %127, %128;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x248x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x248x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x248x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[124]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + 
float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %129, 0;\n" + "wgmma.mma_async.sync.aligned.m64n248k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + "{%124, %125, %126, %127}," + " %128," + " p, %130, %131;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), 
"+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x248x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace SM90::GMMA + +} // namespace cute diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/arch/mma_sm90_gmma_sparse.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/arch/mma_sm90_gmma_sparse.hpp new file mode 100644 index 0000000000000000000000000000000000000000..cd97a3dd64ae07ad21640689eeef0d63030e52e7 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/arch/mma_sm90_gmma_sparse.hpp @@ -0,0 +1,22745 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once + +#include // CUTE_HOST_DEVICE +#include // GMMA::Major, etc. + +#include "cutlass/arch/synclog.hpp" + +namespace cute { + +namespace SM90::GMMA::SPARSE { + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// GMMA PTX definitions: C = (scaleA * A) * (scaleB * B) + (scaleD * C) +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k32.f16.f16.f16 " + "{%0, %1}," + " %2," + " %3," + " %4, %5," + " p, %7, %8, %9, %10;\n" + "}\n" + : "+r"(d0), "+r"(d1) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[2]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k32.f16.f16.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + " %6," + " %7, %8," + " p, %10, %11, %12;\n" + "}\n" + : "+r"(d0), "+r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + 
+// SPARSE GMMA 64x16x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k32.f16.f16.f16 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p, %9, %10, %11, %12;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k32.f16.f16.f16 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p, %12, %13, %14;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; 
+ + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p, %13, %14, %15, %16;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p, %16, %17, %18;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & 
d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p, %21, %22, %23, %24;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p, %24, %25, %26;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = 
GMMA::SparseSel::Zero +> +struct GMMA_64x96x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p, %29, %30, %31, %32;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p, %32, %33, %34;\n" 
+ "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p, %37, %38, %39, %40;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = 
uint64_t[1]; + using CRegisters = uint32_t[32]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p, %40, %41, %42;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, 
+ GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p, %53, %54, %55, %56;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, 
%23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p, %56, %57, %58;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p, %69, %70, %71, 
%72;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, 
%54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p, %72, %73, %74;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k32.f32.f16.f16 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p, %9, %10, %11, %12;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + uint32_t const& e, + GMMA::ScaleOut const 
scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k32.f32.f16.f16 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p, %12, %13, %14;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p, %13, %14, %15, %16;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, 
%6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p, %16, %17, %18;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p, %21, %22, %23, %24;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( 
+ "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p, %24, %25, %26;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p, %37, %38, %39, %40;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x32_F32F16F16_RS +{ + using DRegisters = void; + using 
ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p, %40, %41, %42;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, 
desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p, %53, %54, %55, %56;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p, %56, %57, %58;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), 
"+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p, %69, %70, %71, %72;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), 
"+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p, %72, %73, %74;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), 
"+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %100, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " %98, %99," + " p, %101, %102, %103, %104;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), 
"+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + 
"{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %103, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " %101, %102," + " p, %104, %105, %106;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & 
d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %132, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " %130, %131," + " p, %133, %134, %135, %136;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), 
"+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + 
"setp.ne.b32 p, %135, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " %133, %134," + " p, %136, %137, %138;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & 
d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k32.f32.bf16.bf16 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p, %9, %10, %11, %12;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k32.f32.bf16.bf16 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p, %12, %13, %14;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," 
+ " p, %13, %14, %15, %16;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p, %16, %17, %18;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p, %21, %22, %23, %24;\n" + "}\n" + : "+f"(d00), 
"+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p, %24, %25, %26;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & 
d29, float & d30, float & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p, %37, %38, %39, %40;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p, %40, %41, %42;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), 
"n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p, %53, %54, %55, %56;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using 
ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p, %56, %57, %58;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float 
& d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p, %69, %70, %71, %72;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + 
float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p, %72, %73, %74;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + 
float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %100, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " %98, %99," + " p, %101, %102, %103, %104;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using 
CRegisters = float[96]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %103, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " %101, %102," + " p, %104, %105, %106;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), 
"+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %132, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " 
%56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " %130, %131," + " p, %133, %134, %135, %136;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & 
d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %135, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " %133, %134," + " p, %136, %137, %138;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), 
"+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k16.f32.tf32.tf32 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p, %9, %10;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k16.f32.tf32.tf32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, 
%7}," + " %8," + " %9, %10," + " p, %12, %13;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p, %13, %14;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p, %16, %17;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p, %21, %22;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p, %24, %25;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use 
SM90::GMMA::SPARSE::GMMA_64x32x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p, %37, %38;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if 
defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p, %40, %41;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p, %53, %54;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), 
"+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p, %56, %57;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct 
GMMA_64x128x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p, %69, %70;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t 
const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p, %72, %73;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & 
d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %100, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " %98, %99," + " p, %101, %102;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x16 TN F32+=TF32*TF32 +template 
< + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %103, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " %101, %102," + " p, %104, %105;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), 
"+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %132, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, 
%28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " %130, %131," + " p, %133, %134;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float 
& d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %135, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " %133, %134," + " p, %136, %137;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + 
"+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.s8.s8 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32S8S8_SS_TN_SATURATE without 
CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & 
d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, 
+ uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, 
uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = 
GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %68, 0;\n" + 
"wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & 
d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %100, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " %98, %99," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, 
uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %100, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " %98, %99," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE 
static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %132, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " %130, %131," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + 
"+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, 
uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %132, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " %130, %131," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + 
"+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.s8.s8 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, 
uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + 
"wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + 
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t 
& d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + 
uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t 
& d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + 
uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %103, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " %101, %102," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & 
d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %103, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " %101, %102," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use 
SM90::GMMA::SPARSE::GMMA_64x192x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %135, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " 
%56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " %133, %134," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, 
uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %135, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " %133, %134," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), 
"+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.s8.u8 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to 
use SM90::GMMA::SPARSE::GMMA_64x8x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & 
d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, 
+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
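//
// ---------------------------------------------------------------------------
// Editorial sketch (not part of the upstream CUTLASS header): the structs in
// this file only wrap a single `wgmma.mma_async.sp` instruction; a kernel is
// still responsible for the warpgroup fencing around the issue. The guarded
// snippet below shows, under stated assumptions, how one of the SS wrappers
// defined above (GMMA_64x8x64_S32S8U8_SS_TN) might be driven from device
// code. `desc_a`/`desc_b` are assumed to be already-built 64-bit shared-memory
// matrix descriptors, `meta` the packed 2:4 sparsity metadata consumed as the
// `e` operand, and `acc` the four s32 accumulator registers; constructing
// those is normally done through CuTe's TiledMMA machinery, not by hand.
// ---------------------------------------------------------------------------
#if 0
__device__ void example_sparse_gmma_sketch(uint64_t desc_a, uint64_t desc_b,
                                           uint32_t meta, uint32_t (&acc)[4])
{
  cute::warpgroup_arrive();                              // wgmma.fence before issuing
  GMMA_64x8x64_S32S8U8_SS_TN<>::fma(desc_a, desc_b,      // default SparseSel::Zero
                                    acc[0], acc[1], acc[2], acc[3],
                                    meta, GMMA::ScaleOut::One);  // accumulate into acc
  cute::warpgroup_commit_batch();                        // wgmma.commit_group
  cute::warpgroup_wait<0>();                             // wait for the committed batch
}
#endif
//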
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, 
uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = 
GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %68, 0;\n" + 
"wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & 
d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %100, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " %98, %99," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, 
uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %100, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " %98, %99," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE 
static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %132, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " %130, %131," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + 
"+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, 
uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %132, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " %130, %131," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + 
"+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.s8.u8 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, 
uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + 
"wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + 
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
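+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// Illustrative usage note: every *_RS_TN atom in this family shares one calling convention. The A
+// fragment is passed as four 32-bit registers (ARegisters = uint32_t[4]), B is named by a 64-bit
+// shared-memory matrix descriptor, `e` carries the sparsity metadata word (with `spsel` supplying
+// the metadata selector as an immediate), and the accumulators are updated in place. A minimal
+// sketch for the smallest shape, assuming sm_90a device code where the caller has already staged
+// the A fragment, metadata, and B descriptor; the warpgroup fence/commit/wait sequence used by
+// real mainloops is omitted and all local names below are placeholders:
+//
+//   uint32_t a0, a1, a2, a3;                  // packed s8 A fragment, loaded by the caller
+//   uint64_t desc_b;                          // GMMA descriptor for the u8 B tile in shared memory
+//   uint32_t meta;                            // 2:4 sparsity metadata word
+//   uint32_t d0 = 0, d1 = 0, d2 = 0, d3 = 0;  // s32 accumulators
+//   SM90::GMMA::SPARSE::GMMA_64x8x64_S32S8U8_RS_TN<>::fma(
+//       a0, a1, a2, a3, desc_b,
+//       d0, d1, d2, d3,
+//       meta, GMMA::ScaleOut::One);           // ScaleOut::Zero discards the prior accumulator values
+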
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t 
& d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + 
uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t 
& d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + 
uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %103, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " %101, %102," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & 
d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %103, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " %101, %102," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use 
SM90::GMMA::SPARSE::GMMA_64x192x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %135, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " 
%56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " %133, %134," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, 
uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %135, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " %133, %134," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), 
"+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.u8.s8 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to 
use SM90::GMMA::SPARSE::GMMA_64x8x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & 
d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, 
+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
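+// The "_SS_" sparse atoms take both the A and B operands as 64-bit shared-memory matrix
+// descriptors (desc_a / desc_b); "e" carries the structured-sparsity metadata for the A
+// operand, and the SparseSel template parameter is passed through as the immediate
+// metadata-selector operand. scale_D feeds the predicate "p", which selects whether the C
+// registers are accumulated into or overwritten. The "_SATURATE" variants differ only in
+// the ".satfinite" qualifier, which clamps the s32 accumulators instead of letting them
+// wrap on overflow.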
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, 
uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = 
GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %68, 0;\n" + 
"wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & 
d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %100, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " %98, %99," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, 
uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %100, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " %98, %99," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE 
static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %132, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " %130, %131," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + 
"+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, 
uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %132, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " %130, %131," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + 
"+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.u8.s8 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, 
uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + 
"wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + 
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
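+// The "_RS_" sparse atoms in this group differ from the "_SS_" forms only in the A operand:
+// A is supplied directly from registers (a00..a03, ARegisters = uint32_t[4]) instead of
+// through a shared-memory descriptor, while B is still described by desc_b. The sparsity
+// metadata operand "e", the SparseSel selector, and the scale_D predicate behave exactly as
+// in the "_SS_" atoms above.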
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t 
& d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + 
uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t 
& d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + 
uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %103, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " %101, %102," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & 
d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %103, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " %101, %102," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use 
SM90::GMMA::SPARSE::GMMA_64x192x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %135, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " 
%56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " %133, %134," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, 
uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %135, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " %133, %134," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), 
"+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.u8.u8 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to 
use SM90::GMMA::SPARSE::GMMA_64x8x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & 
d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, 
+ uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, 
uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = 
GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %68, 0;\n" + 
"wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & 
d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %100, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " %98, %99," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, 
uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %100, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " %98, %99," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE 
static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %132, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " %130, %131," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + 
"+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, 
uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %132, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " %130, %131," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + 
"+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.u8.u8 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, 
uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + 
"wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + 
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
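+// ----------------------------------------------------------------------------
+// Usage sketch (illustrative only, not part of the upstream generated header):
+// a direct, warpgroup-wide invocation of the register-sourced sparse atom
+// GMMA_64x8x64_S32U8U8_RS_TN declared earlier in this file. The A-fragment
+// values, the B-matrix shared-memory descriptor `desc_b`, and the sparsity
+// metadata word `e` are placeholders, and the wgmma fence / commit / wait
+// sequencing that must surround the call in real kernels is omitted.
+//
+//   __device__ void sparse_gmma_example(uint64_t desc_b, uint32_t e) {
+//     uint32_t a[4] = {0u, 0u, 0u, 0u};    // packed u8 A fragment (placeholder)
+//     uint32_t d[4] = {0u, 0u, 0u, 0u};    // s32 accumulators, updated in place
+//     SM90::GMMA::SPARSE::GMMA_64x8x64_S32U8U8_RS_TN<>::fma(
+//         a[0], a[1], a[2], a[3],           // ARegisters
+//         desc_b,                           // BRegisters (smem descriptor)
+//         d[0], d[1], d[2], d[3],           // CRegisters (accumulators)
+//         e,                                // ERegisters (sparsity metadata)
+//         GMMA::ScaleOut::One);             // accumulate onto existing d values
+//   }
+// ----------------------------------------------------------------------------
+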
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t 
& d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + 
uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t 
& d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + 
uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %103, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " %101, %102," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & 
d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %103, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " %101, %102," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use 
SM90::GMMA::SPARSE::GMMA_64x192x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %135, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " 
%56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " %133, %134," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, 
uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %135, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " %133, %134," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), 
"+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.f16.e4m3.e4m3 " + "{%0, %1}," + " %2," + " %3," + " %4, %5," + " p, %7, %8;\n" + "}\n" + : "+r"(d0), "+r"(d1) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.f16.e4m3.e4m3 " + "{%0, %1}," + "{%2, %3, %4, %5}," + " %6," + " %7, 
%8," + " p, %10, %11;\n" + "}\n" + : "+r"(d0), "+r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p, %9, %10;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p, %12, %13;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct 
GMMA_64x16x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p, %9, %10;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p, %12, %13;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + 
"wgmma.mma_async.sp.sync.aligned.m64n16k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p, %13, %14;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p, %16, %17;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p, %13, %14;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use 
SM90::GMMA::SPARSE::GMMA_64x32x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p, %16, %17;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p, %21, %22;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
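+// Note: the sparse FP8 (E4M3) GMMA wrappers here share one calling convention. The _SS_
+// variants take both A and B as 64-bit shared-memory matrix descriptors, while the _RS_
+// variants take A as four 32-bit registers and only B as a descriptor. `e` is the sparsity
+// metadata register and `spsel` the sparsity-selector immediate; the predicate `p`, computed
+// from scale_D, tells wgmma whether to accumulate into the existing D registers or overwrite
+// them, and scaleA/scaleB select the compile-time input scaling (identity or negation).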
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p, %24, %25;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p, %21, %22;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting 
to use SM90::GMMA::SPARSE::GMMA_64x64x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p, %24, %25;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " 
%32," + " %33," + " %34, %35," + " p, %37, %38;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p, %40, %41;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using 
CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p, %29, %30;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p, %32, %33;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + 
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p, %53, %54;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t 
const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p, %56, %57;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p, %37, %38;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p, %40, %41;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } 
+}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p, %69, %70;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> 
+struct GMMA_64x128x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p, %72, %73;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + 
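+  // 48 u32 accumulators per thread: each packs two f16 values, so the 128-thread warpgroup
+  // holds the full 64x192 f16 tile (48 * 2 * 128 = 12288 = 64 * 192).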
CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p, %53, %54;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, 
uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p, %56, %57;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + 
float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %100, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " %98, %99," + " p, %101, %102;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & 
d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %103, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " %101, %102," + " p, %104, %105;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = 
GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p, %69, %70;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = 
uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p, %72, %73;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using 
ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %132, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " %130, %131," + " p, %133, %134;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + 
"+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, 
float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %135, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " %133, %134," + " p, %136, %137;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_F32E4M3E4M3_RS_TN without 
CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.f16.e4m3.e5m2 " + "{%0, %1}," + " %2," + " %3," + " %4, %5," + " p, %7, %8;\n" + "}\n" + : "+r"(d0), "+r"(d1) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.f16.e4m3.e5m2 " + "{%0, %1}," + "{%2, %3, %4, %5}," + " %6," + " %7, %8," + " p, %10, %11;\n" + "}\n" + : "+r"(d0), "+r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if 
defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p, %9, %10;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p, %12, %13;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p, %9, %10;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
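+// Illustration only (hedged sketch, not part of the upstream interface): each of these
+// wrapper structs exposes a static fma() that issues one wgmma.mma_async.sp instruction.
+// In the _SS_TN variants, desc_a and desc_b are 64-bit shared-memory matrix descriptors;
+// in the _RS_TN variants the A fragment is supplied in registers instead of via desc_a.
+// The extra operand `e` is the structured-sparsity metadata word, the SparseSel template
+// parameter supplies the sparsity-selector immediate, and scale_D selects between
+// overwriting the D registers (ScaleOut::Zero) and accumulating into them (ScaleOut::One).
+// A hypothetical direct call, assuming desc_a, desc_b, and the metadata word `meta`
+// have already been constructed elsewhere, might look like:
+//
+//   uint32_t d0 = 0, d1 = 0, d2 = 0, d3 = 0;
+//   GMMA_64x16x64_F16E4M3E5M2_SS_TN<>::fma(desc_a, desc_b, d0, d1, d2, d3,
+//                                          meta, GMMA::ScaleOut::One);
+//
+// In practice these ops are normally reached through CuTe MMA traits and tiled-MMA
+// abstractions rather than invoked directly.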
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p, %12, %13;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p, %13, %14;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, 
uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p, %16, %17;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p, %13, %14;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" 
+ "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p, %16, %17;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p, %21, %22;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + 
"wgmma.mma_async.sp.sync.aligned.m64n32k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p, %24, %25;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p, %21, %22;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if 
defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p, %24, %25;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p, %37, %38;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = 
uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p, %40, %41;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p, %29, %30;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), 
"+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p, %32, %33;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & 
d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p, %53, %54;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, 
%14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p, %56, %57;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p, %37, %38;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
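+// Register-footprint note for these m64nNk64 sparse FP8 shapes: the 64xN accumulator is
+// spread over the 128 threads of a warpgroup, so CRegisters is float[N/2] for f32
+// accumulation (float[64] here for N=128) and uint32_t[N/4] for f16 accumulation (two
+// packed halves per register, uint32_t[32] for N=128), while ARegisters, BRegisters, and
+// ERegisters stay fixed at one descriptor (or four A registers for _RS_TN), one
+// descriptor, and one metadata word respectively.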
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p, %40, %41;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float 
& d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p, %69, %70;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + 
float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p, %72, %73;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + 
"setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p, %53, %54;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p, %56, %57;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), 
"+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %100, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " %98, %99," + " p, %101, %102;\n" + "}\n" + : 
"+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + 
"{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %103, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " %101, %102," + " p, %104, %105;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, 
uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p, %69, %70;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, 
uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p, %72, %73;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & 
d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %132, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " %130, %131," + " p, %133, %134;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + 
"+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %135, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " 
%48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " %133, %134," + " p, %136, %137;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.f16.e5m2.e4m3 " + "{%0, %1}," + " %2," + " %3," + " %4, %5," + " p, %7, %8;\n" + "}\n" + : "+r"(d0), "+r"(d1) + : 
"l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.f16.e5m2.e4m3 " + "{%0, %1}," + "{%2, %3, %4, %5}," + " %6," + " %7, %8," + " p, %10, %11;\n" + "}\n" + : "+r"(d0), "+r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p, %9, %10;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters 
= uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p, %12, %13;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p, %9, %10;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p, %12, %13;\n" + "}\n" + : 
"+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p, %13, %14;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p, %16, %17;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 
SPARSE GMMA 64x32x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p, %13, %14;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p, %16, %17;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + 
CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p, %21, %22;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p, %24, %25;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + 
using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p, %21, %22;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p, %24, %25;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel 
= GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p, %37, %38;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, 
%31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p, %40, %41;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p, %29, %30;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, 
uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p, %32, %33;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p, %53, %54;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), 
"+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p, %56, %57;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p, %37, %38;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t 
const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p, %40, %41;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p, %69, %70;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), 
"+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p, %72, %73;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + 
"+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p, %53, %54;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use 
SM90::GMMA::SPARSE::GMMA_64x192x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p, %56, %57;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using 
BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %100, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " %98, %99," + " p, %101, %102;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + 
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %103, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " %101, %102," + " p, %104, %105;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + 
"+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p, %69, %70;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), 
+ "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p, %72, %73;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + 
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg 
.pred p;\n" + "setp.ne.b32 p, %132, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " %130, %131," + " p, %133, %134;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & 
d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %135, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " %133, %134," + " p, %136, %137;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + 
"+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.f16.e5m2.e5m2 " + "{%0, %1}," + " %2," + " %3," + " %4, %5," + " p, %7, %8;\n" + "}\n" + : "+r"(d0), "+r"(d1) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" 
+ ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.f16.e5m2.e5m2 " + "{%0, %1}," + "{%2, %3, %4, %5}," + " %6," + " %7, %8," + " p, %10, %11;\n" + "}\n" + : "+r"(d0), "+r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p, %9, %10;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p, %12, %13;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN F16+=E5M2*E5M2 +template < + 
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p, %9, %10;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p, %12, %13;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p, %13, %14;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p, %16, %17;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p, %13, %14;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(e), 
"n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p, %16, %17;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p, %21, %22;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use 
SM90::GMMA::SPARSE::GMMA_64x32x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p, %24, %25;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p, %21, %22;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + 
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p, %24, %25;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, 
%12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p, %37, %38;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p, %40, %41;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = 
void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p, %29, %30;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p, %32, %33;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), 
"+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p, %53, %54;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = 
uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p, %56, %57;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const 
scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p, %37, %38;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p, %40, %41;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use 
SM90::GMMA::SPARSE::GMMA_64x128x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p, %69, %70;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = 
GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p, %72, %73;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = 
uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p, %53, %54;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & 
d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p, %56, %57;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, 
float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %100, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " %98, %99," + " p, %101, %102;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float 
& d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %103, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " %101, %102," + " p, %104, %105;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN F16+=E5M2*E5M2 +template < 
+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p, %69, %70;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = 
GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p, %72, %73;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero 
+> +struct GMMA_64x256x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %132, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " %130, %131," + " p, %133, %134;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), 
"+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, 
float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %135, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " %133, %134," + " p, %136, %137;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); 
+#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace SM90::GMMA::SPARSE + +} // namespace cute + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +#include "mma_sm90_gmma_sparse_ext.hpp" +#endif diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/arch/mma_sm90_gmma_sparse_ext.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/arch/mma_sm90_gmma_sparse_ext.hpp new file mode 100644 index 0000000000000000000000000000000000000000..8945551a1c94801d3a902ce47c57efea0e7cafdf --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/arch/mma_sm90_gmma_sparse_ext.hpp @@ -0,0 +1,60445 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once + +#include // CUTE_HOST_DEVICE + +#include "cutlass/arch/synclog.hpp" + +// Config +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && defined(__CUDA_ARCH_FEAT_SM90_ALL)) +# define CUTE_ARCH_MMA_SM90A_ENABLED +#endif + +namespace cute { + +namespace SM90::GMMA::SPARSE { + +// SPARSE GMMA 64x24x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[6]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5}," + " %6," + " %7," + " %8, %9," + " p, %11, %12, %13, %14;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[6]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5}," + "{%6, %7, %8, %9}," + " %10," + " %11, %12," + " p, %14, %15, %16;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
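// Aside (a minimal editorial sketch, not part of the vendored CUTLASS header):
// the CRegisters sizes of these sparse GMMA atoms follow directly from the tile
// shape. A warpgroup is 128 threads, so an M x N f32 accumulator needs M*N/128
// float registers per thread, and the f16 atoms pack two halves into each
// 32-bit register, halving that count. The `editorial_sketch` namespace below
// is invented for illustration only; the asserted values are the CRegisters
// sizes visible in this file.

namespace editorial_sketch {

constexpr int kWarpgroupThreads = 128;

// 32-bit accumulator registers held by each thread for an M x N tile.
constexpr int accum_regs(int m, int n, bool f16_accum) {
  int elems_per_thread = (m * n) / kWarpgroupThreads;
  return f16_accum ? elems_per_thread / 2 : elems_per_thread;
}

static_assert(accum_regs(64,  24, true)  == 6,   "GMMA_64x24x32_F16*  -> uint32_t[6]");
static_assert(accum_regs(64,  40, true)  == 10,  "GMMA_64x40x32_F16*  -> uint32_t[10]");
static_assert(accum_regs(64, 256, false) == 128, "GMMA_64x256x64_F32* -> float[128]");

} // namespace editorial_sketch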
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x40x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x40x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[10]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %14, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n40k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9}," + " %10," + " %11," + " %12, %13," + " p, %15, %16, %17, %18;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x40x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x40x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x40x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[10]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %17, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n40k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9}," + "{%10, %11, %12, %13}," + " %14," + " %15, %16," + " p, %18, %19, %20;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x40x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } 
+}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %16, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " %14, %15," + " p, %17, %18, %19, %20;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %19, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " %17, %18," + " p, %20, %21, %22;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + 
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x56x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x56x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[14]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n56k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13}," + " %14," + " %15," + " %16, %17," + " p, %19, %20, %21, %22;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x56x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x56x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x56x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[14]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n56k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13}," + "{%14, %15, %16, %17}," + " %18," + " %19, %20," + " p, %22, %23, %24;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), 
"+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x56x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x72x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x72x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[18]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %22, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n72k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17}," + " %18," + " %19," + " %20, %21," + " p, %23, %24, %25, %26;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x72x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x72x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x72x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[18]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if 
defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %25, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n72k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17}," + "{%18, %19, %20, %21}," + " %22," + " %23, %24," + " p, %26, %27, %28;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x72x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[20]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %24, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + " %20," + " %21," + " %22, %23," + " p, %25, %26, %27, %28;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[20]; + + static_assert(tnspA 
== GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %27, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + "{%20, %21, %22, %23}," + " %24," + " %25, %26," + " p, %28, %29, %30;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x88x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x88x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[22]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %26, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n88k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21}," + " %22," + " %23," + " %24, %25," + " p, %27, %28, %29, %30;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), 
"n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x88x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x88x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x88x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[22]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %29, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n88k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21}," + "{%22, %23, %24, %25}," + " %26," + " %27, %28," + " p, %30, %31, %32;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x88x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x104x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x104x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[26]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, + uint32_t const& e, + GMMA::ScaleOut const 
scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %30, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n104k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25}," + " %26," + " %27," + " %28, %29," + " p, %31, %32, %33, %34;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x104x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x104x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x104x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[26]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %33, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n104k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25}," + "{%26, %27, %28, %29}," + " %30," + " %31, %32," + " p, %34, %35, %36;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x104x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
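// Aside (a minimal editorial sketch, not part of the vendored CUTLASS header):
// the _SS and _RS atoms differ only in where operand A lives. _SS passes both
// A and B as 64-bit shared-memory descriptors; _RS holds A in four 32-bit
// registers (useful when A was just produced in the register file) and passes
// only B as a descriptor. The remaining inline-asm operands (metadata word e,
// SparseSel, the ScaleOut predicate, then the ScaleIn and, where present,
// Major immediates) always follow the accumulators in that order, so the
// operand index consumed by `setp.ne.b32` is a pure function of the
// accumulator count. The helpers below are illustration only; the asserted
// indices are the ones used by structs in this file.

namespace editorial_sketch {

// _SS: %0..%n-1 accumulators, then desc_a, desc_b, e, spsel, scale_D.
constexpr int scale_d_operand_ss(int accum_regs) { return accum_regs + 4; }

// _RS: %0..%n-1 accumulators, then a0..a3, desc_b, e, spsel, scale_D.
constexpr int scale_d_operand_rs(int accum_regs) { return accum_regs + 7; }

static_assert(scale_d_operand_ss(6)  == 10, "GMMA_64x24x32_*_SS  tests %10");
static_assert(scale_d_operand_rs(6)  == 13, "GMMA_64x24x32_*_RS  tests %13");
static_assert(scale_d_operand_ss(26) == 30, "GMMA_64x104x32_*_SS tests %30");
static_assert(scale_d_operand_rs(26) == 33, "GMMA_64x104x32_*_RS tests %33");

} // namespace editorial_sketch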
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[28]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %32, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + " %28," + " %29," + " %30, %31," + " p, %33, %34, %35, %36;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[28]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & 
d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %35, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + "{%28, %29, %30, %31}," + " %32," + " %33, %34," + " p, %36, %37, %38;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x120x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x120x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[30]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n120k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29}," + " %30," + " %31," + " %32, %33," + " p, %35, %36, %37, %38;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x120x32_F16F16F16_SS without 
CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x120x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x120x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[30]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n120k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29}," + "{%30, %31, %32, %33}," + " %34," + " %35, %36," + " p, %38, %39, %40;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x120x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x136x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x136x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[34]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + 
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %38, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n136k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33}," + " %34," + " %35," + " %36, %37," + " p, %39, %40, %41, %42;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x136x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x136x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x136x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[34]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %41, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n136k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33}," + "{%34, %35, %36, %37}," + " %38," + " %39, %40," + " p, %42, %43, %44;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), 
"+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x136x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[36]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %40, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + " %36," + " %37," + " %38, %39," + " p, %41, %42, %43, %44;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel 
spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[36]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %43, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + "{%36, %37, %38, %39}," + " %40," + " %41, %42," + " p, %44, %45, %46;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x152x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x152x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[38]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, 
uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %42, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n152k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37}," + " %38," + " %39," + " %40, %41," + " p, %43, %44, %45, %46;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x152x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x152x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x152x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[38]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %45, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n152k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37}," + "{%38, %39, %40, %41}," + " %42," + " %43, %44," + " p, %46, %47, %48;\n" 
+ "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x152x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p, %45, %46, %47, %48;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p, %48, %49, %50;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x168x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x168x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[42]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, 
uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %46, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n168k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41}," + " %42," + " %43," + " %44, %45," + " p, %47, %48, %49, %50;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x168x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x168x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x168x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[42]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + 
uint32_t & d40, uint32_t & d41, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %49, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n168k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41}," + "{%42, %43, %44, %45}," + " %46," + " %47, %48," + " p, %50, %51, %52;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x168x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[44]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %48, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + " %44," + " %45," + " %46, %47," + " p, %49, %50, %51, %52;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + 
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[44]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %51, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + "{%44, %45, %46, %47}," + " %48," + " %49, %50," + " p, %52, %53, %54;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + 
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x184x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x184x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[46]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n184k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45}," + " %46," + " %47," + " %48, %49," + " p, %51, %52, %53, %54;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x184x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x184x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x184x32_F16F16F16_RS +{ + using DRegisters 
= void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[46]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n184k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45}," + "{%46, %47, %48, %49}," + " %50," + " %51, %52," + " p, %54, %55, %56;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x184x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x200x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x200x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[50]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t 
& d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %54, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n200k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49}," + " %50," + " %51," + " %52, %53," + " p, %55, %56, %57, %58;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x200x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x200x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x200x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[50]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, 
uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %57, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n200k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49}," + "{%50, %51, %52, %53}," + " %54," + " %55, %56," + " p, %58, %59, %60;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x200x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[52]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" 
+ ".reg .pred p;\n" + "setp.ne.b32 p, %56, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + " %52," + " %53," + " %54, %55," + " p, %57, %58, %59, %60;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[52]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %59, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " 
%32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + "{%52, %53, %54, %55}," + " %56," + " %57, %58," + " p, %60, %61, %62;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x216x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x216x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[54]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %58, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n216k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53}," + " %54," + " %55," + " %56, %57," + " p, %59, %60, %61, %62;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + 
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x216x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x216x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x216x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[54]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %61, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n216k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53}," + "{%54, %55, %56, %57}," + " %58," + " %59, %60," + " p, %62, %63, %64;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), 
"+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x216x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p, %61, %62, %63, %64;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + 
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p, %64, %65, %66;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + 
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x232x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x232x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[58]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %62, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n232k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57}," + " %58," + " %59," + " %60, %61," + " p, %63, %64, %65, %66;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x232x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x232x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x232x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[58]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %65, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n232k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57}," + "{%58, %59, %60, %61}," + " %62," + " %63, %64," + " p, %66, %67, %68;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x232x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[60]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %64, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + " %60," + " %61," + " %62, %63," + " p, %65, %66, %67, %68;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x32 F16+=F16*F16 +template < + GMMA::Major 
tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[60]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %67, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + "{%60, %61, %62, %63}," + " %64," + " %65, %66," + " p, %68, %69, %70;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x248x32 F16+=F16*F16 +template < + GMMA::Major 
tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x248x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[62]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n248k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61}," + " %62," + " %63," + " %64, %65," + " p, %67, %68, %69, %70;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x248x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x248x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = 
GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x248x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[62]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n248k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61}," + "{%62, %63, %64, %65}," + " %66," + " %67, %68," + " p, %70, %71, %72;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x248x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + 
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %16, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " %14, %15," + " p, %17, %18, %19, %20;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[12]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %19, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " %17, %18," + " p, %20, %21, %22;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x40x32 F32+=F16*F16 
+template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x40x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[20]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %24, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n40k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + " %20," + " %21," + " %22, %23," + " p, %25, %26, %27, %28;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x40x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x40x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x40x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[20]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %27, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n40k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + "{%20, %21, %22, %23}," + " %24," + " %25, %26," + " p, %28, %29, %30;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), 
"+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x40x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p, %29, %30, %31, %32;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + 
float & d20, float & d21, float & d22, float & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p, %32, %33, %34;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x56x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x56x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[28]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %32, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n56k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + " %28," + " %29," + " %30, %31," + " p, %33, %34, %35, %36;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x56x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x56x32 F32+=F16*F16 
+template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x56x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[28]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %35, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n56k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + "{%28, %29, %30, %31}," + " %32," + " %33, %34," + " p, %36, %37, %38;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x56x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x72x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x72x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[36]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm 
volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %40, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n72k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + " %36," + " %37," + " %38, %39," + " p, %41, %42, %43, %44;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x72x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x72x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x72x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[36]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %43, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n72k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + "{%36, %37, %38, %39}," + " %40," + " %41, %42," + " p, %44, %45, %46;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), 
"n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x72x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p, %45, %46, %47, %48;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t 
const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p, %48, %49, %50;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x88x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x88x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[44]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %48, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n88k32.f32.f16.f16 " + "{%0, 
%1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + " %44," + " %45," + " %46, %47," + " p, %49, %50, %51, %52;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x88x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x88x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x88x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[44]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %51, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n88k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + "{%44, %45, %46, %47}," + " %48," + " %49, %50," + " p, %52, %53, %54;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + 
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x88x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x104x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x104x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[52]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %56, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n104k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + " %52," + " %53," + " %54, %55," + " p, %57, %58, %59, %60;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x104x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 
64x104x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x104x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[52]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %59, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n104k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + "{%52, %53, %54, %55}," + " %56," + " %57, %58," + " p, %60, %61, %62;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x104x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + 
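+  // Note on the operand layout (as read from the asm constraints below):
+  // desc_a and desc_b are 64-bit shared-memory matrix descriptors for A and B,
+  // d00..d55 form the per-thread FP32 accumulator fragment (bound read-write
+  // via "+f"), e carries the sparsity metadata for operand A, and spsel,
+  // scaleA, scaleB, tnspA, tnspB are compile-time immediates. The runtime
+  // scale_D value sets the predicate p that is passed as the scale-d operand.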
CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p, %61, %62, %63, %64;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, 
float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p, %64, %65, %66;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x120x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x120x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[60]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + 
float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %64, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n120k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + " %60," + " %61," + " %62, %63," + " p, %65, %66, %67, %68;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x120x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x120x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x120x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[60]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if 
defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %67, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n120k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + "{%60, %61, %62, %63}," + " %64," + " %65, %66," + " p, %68, %69, %70;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x120x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x136x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x136x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %72, 0;\n" + 
"wgmma.mma_async.sp.sync.aligned.m64n136k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + " %68," + " %69," + " %70, %71," + " p, %73, %74, %75, %76;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x136x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x136x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x136x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg 
.pred p;\n" + "setp.ne.b32 p, %75, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n136k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + "{%68, %69, %70, %71}," + " %72," + " %73, %74," + " p, %76, %77, %78;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x136x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %76, 
0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " %74, %75," + " p, %77, %78, %79, %80;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if 
defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %79, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " %77, %78," + " p, %80, %81, %82;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x152x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x152x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[76]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + uint32_t 
const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %80, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n152k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75}," + " %76," + " %77," + " %78, %79," + " p, %81, %82, %83, %84;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x152x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x152x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x152x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[76]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & 
d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %83, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n152k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75}," + "{%76, %77, %78, %79}," + " %80," + " %81, %82," + " p, %84, %85, %86;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x152x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & 
d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %84, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " %82, %83," + " p, %85, %86, %87, %88;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float 
& d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %87, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " %85, %86," + " p, %88, %89, %90;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x168x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x168x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = 
float[84]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %88, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n168k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83}," + " %84," + " %85," + " %86, %87," + " p, %89, %90, %91, %92;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x168x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x168x32 F32+=F16*F16 
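+// (RS variant: operand A is sourced from registers as four 32-bit fragments
+// per thread and must be K-major, per the static_assert below; B is still
+// read through a shared-memory descriptor.)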
+template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x168x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[84]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %91, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n168k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83}," + "{%84, %85, %86, %87}," + " %88," + " %89, %90," + " p, %92, %93, %94;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), 
"+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x168x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %92, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " %90, %91," + " p, %93, %94, %95, %96;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + 
"+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %95, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, 
%59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " %93, %94," + " p, %96, %97, %98;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x184x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x184x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if 
defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %96, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n184k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91}," + " %92," + " %93," + " %94, %95," + " p, %97, %98, %99, %100;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x184x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x184x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x184x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & 
d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %99, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n184k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91}," + "{%92, %93, %94, %95}," + " %96," + " %97, %98," + " p, %100, %101, %102;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x184x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x200x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x200x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using 
CRegisters = float[100]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %104, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n200k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99}," + " %100," + " %101," + " %102, %103," + " p, %105, %106, %107, %108;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + 
"+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x200x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x200x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x200x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[100]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %107, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n200k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " 
%72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99}," + "{%100, %101, %102, %103}," + " %104," + " %105, %106," + " p, %108, %109, %110;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x200x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, 
float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %108, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " %106, %107," + " p, %109, %110, %111, %112;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using 
BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %111, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " %109, %110," + " p, %112, %113, %114;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), 
"+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x216x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x216x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[108]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %112, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n216k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, 
%22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107}," + " %108," + " %109," + " %110, %111," + " p, %113, %114, %115, %116;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x216x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x216x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x216x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[108]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & 
d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %115, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n216k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107}," + "{%108, %109, %110, %111}," + " %112," + " %113, %114," + " p, %116, %117, %118;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), 
"n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x216x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %116, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " %114, %115," + " p, %117, %118, 
%119, %120;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, 
float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %119, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " %117, %118," + " p, %120, %121, %122;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 
64x232x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x232x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %120, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n232k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + " %116," + " %117," + " %118, %119," + " p, %121, %122, %123, %124;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), 
"+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x232x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x232x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x232x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float 
& d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %123, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n232k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + "{%116, %117, %118, %119}," + " %120," + " %121, %122," + " p, %124, %125, %126;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x232x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, 
+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %124, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " %122, %123," + " p, %125, %126, %127, %128;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), 
"+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & 
d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %127, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " %125, %126," + " p, %128, %129, %130;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
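+
+// Reader's note on the convention shared by the sparse GMMA wrappers above and below
+// (a summary of the existing code, not new behavior): the _SS variants source both A
+// and B from shared memory through 64-bit matrix descriptors, while the _RS variants
+// source A from four 32-bit register fragments (K-major only, enforced by the
+// static_assert) and B from a descriptor. Every sparse fma() additionally takes the
+// 32-bit metadata word `e`, the compile-time SparseSel choosing which metadata half to
+// use, and a runtime ScaleOut scale_D that the predicate `p` turns into "accumulate
+// into D" (One) versus "overwrite D" (Zero). For an m64nNk32 tile the accumulator is
+// N/2 float registers (e.g. 120 registers for N = 240).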
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x248x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x248x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[124]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %128, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n248k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + " %124," + " %125," + " %126, %127," + " p, %129, %130, %131, 
%132;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x248x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x248x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x248x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[124]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, 
float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %131, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n248k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + "{%124, %125, %126, %127}," + " %128," + " %129, %130," + " p, %132, %133, %134;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), 
"+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x248x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %16, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " %14, %15," + " p, %17, %18, %19, %20;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[12]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %19, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " %17, %18," + " p, %20, 
%21, %22;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x40x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x40x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[20]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %24, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n40k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + " %20," + " %21," + " %22, %23," + " p, %25, %26, %27, %28;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x40x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x40x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x40x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[20]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + uint32_t const& 
e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %27, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n40k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + "{%20, %21, %22, %23}," + " %24," + " %25, %26," + " p, %28, %29, %30;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x40x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p, %29, %30, %31, %32;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x32_F32BF16BF16_RS +{ + using 
DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p, %32, %33, %34;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x56x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x56x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[28]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %32, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n56k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + " %28," + " %29," + " %30, %31," + " p, %33, %34, %35, %36;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), 
"+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x56x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x56x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x56x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[28]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %35, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n56k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + "{%28, %29, %30, %31}," + " %32," + " %33, %34," + " p, %36, %37, %38;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x56x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x72x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x72x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[36]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t 
const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %40, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n72k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + " %36," + " %37," + " %38, %39," + " p, %41, %42, %43, %44;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x72x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x72x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x72x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[36]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %43, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n72k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, 
%20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + "{%36, %37, %38, %39}," + " %40," + " %41, %42," + " p, %44, %45, %46;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x72x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p, %45, %46, %47, %48;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
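+
+// Note on the inline-asm operand bindings used by these wrappers (a description of the
+// constraint lists as written, identical across the shapes in this header): accumulators
+// are bound as read-write "+f" float operands, shared-memory descriptors as 64-bit "l"
+// operands, and the metadata word `e` and the runtime scale_D as "r" operands. The
+// template parameters (spsel, scaleA, scaleB and the transpose flags) are emitted as "n"
+// compile-time immediates, which is why they appear only as int32_t casts in the
+// constraint list rather than as register operands in the PTX body.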
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p, %48, %49, %50;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x88x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x88x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[44]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float 
& d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %48, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n88k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + " %44," + " %45," + " %46, %47," + " p, %49, %50, %51, %52;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x88x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x88x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x88x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[44]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %51, 0;\n" 
+ "wgmma.mma_async.sp.sync.aligned.m64n88k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + "{%44, %45, %46, %47}," + " %48," + " %49, %50," + " p, %52, %53, %54;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x88x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x104x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x104x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[52]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %56, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n104k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + " %52," + " %53," + " %54, %55," + " p, %57, %58, %59, %60;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + 
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x104x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x104x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x104x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[52]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %59, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n104k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + "{%52, %53, %54, %55}," + " %56," + " %57, %58," + " p, %60, %61, %62;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + 
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x104x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p, %61, %62, %63, %64;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA 
= GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p, %64, %65, %66;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x120x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x120x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using 
CRegisters = float[60]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %64, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n120k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + " %60," + " %61," + " %62, %63," + " p, %65, %66, %67, %68;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x120x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x120x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x120x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[60]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, 
float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %67, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n120k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + "{%60, %61, %62, %63}," + " %64," + " %65, %66," + " p, %68, %69, %70;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x120x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x136x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x136x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, 
float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %72, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n136k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + " %68," + " %69," + " %70, %71," + " p, %73, %74, %75, %76;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x136x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x136x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x136x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float 
& d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %75, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n136k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + "{%68, %69, %70, %71}," + " %72," + " %73, %74," + " p, %76, %77, %78;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x136x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + 
float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %76, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " %74, %75," + " p, %77, %78, %79, %80;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & 
d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %79, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " %77, %78," + " p, %80, %81, %82;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x152x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x152x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[76]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float 
& d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %80, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n152k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75}," + " %76," + " %77," + " %78, %79," + " p, %81, %82, %83, %84;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x152x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x152x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x152x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[76]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, 
uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %83, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n152k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75}," + "{%76, %77, %78, %79}," + " %80," + " %81, %82," + " p, %84, %85, %86;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x152x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x32_F32BF16BF16_SS +{ + using 
DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %84, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " %82, %83," + " p, %85, %86, %87, %88;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 
64x160x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %87, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " %85, %86," + " p, %88, %89, %90;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "r"(a00), "r"(a01), 
"r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x168x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x168x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[84]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %88, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n168k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83}," + " %84," + " %85," + " %86, %87," + " p, %89, %90, %91, %92;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), 
"+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x168x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x168x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x168x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[84]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %91, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n168k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83}," + "{%84, %85, %86, %87}," + " %88," + " %89, %90," + " p, %92, %93, %94;\n" + "}\n" + : "+f"(d00), "+f"(d01), 
"+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x168x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %92, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, 
%23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " %90, %91," + " p, %93, %94, %95, %96;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float 
& d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %95, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " %93, %94," + " p, %96, %97, %98;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x184x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x184x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & 
d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %96, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n184k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91}," + " %92," + " %93," + " %94, %95," + " p, %97, %98, %99, %100;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x184x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x184x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x184x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = 
uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %99, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n184k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91}," + "{%92, %93, %94, %95}," + " %96," + " %97, %98," + " p, %100, %101, %102;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + 
"+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x184x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x200x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x200x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[100]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %104, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n200k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99}," + " %100," + " %101," + " %102, %103," + " p, %105, %106, %107, %108;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + 
"+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x200x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x200x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x200x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[100]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & 
d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %107, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n200k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99}," + "{%100, %101, %102, %103}," + " %104," + " %105, %106," + " p, %108, %109, %110;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x200x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, 
float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %108, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " %106, %107," + " p, %109, %110, %111, %112;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + 
"+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %111, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " 
+ " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " %109, %110," + " p, %112, %113, %114;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x216x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x216x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[108]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & 
d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %112, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n216k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107}," + " %108," + " %109," + " %110, %111," + " p, %113, %114, %115, %116;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x216x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x216x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct 
GMMA_64x216x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[108]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %115, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n216k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107}," + "{%108, %109, %110, %111}," + " %112," + " %113, %114," + " p, %116, %117, %118;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), 
+ "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x216x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if 
defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %116, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " %114, %115," + " p, %117, %118, %119, %120;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, 
float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %119, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " %117, %118," + " p, %120, %121, %122;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + 
"+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x232x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x232x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %120, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n232k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, 
%13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + " %116," + " %117," + " %118, %119," + " p, %121, %122, %123, %124;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x232x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x232x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x232x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, 
float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %123, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n232k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + "{%116, %117, %118, %119}," + " %120," + " %121, %122," + " p, %124, %125, %126;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + 
"+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x232x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %124, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, 
%20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " %122, %123," + " p, %125, %126, %127, %128;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & 
d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %127, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " %125, %126," + " p, %128, %129, %130;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), 
"+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x248x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x248x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[124]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" 
+ ".reg .pred p;\n" + "setp.ne.b32 p, %128, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n248k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + " %124," + " %125," + " %126, %127," + " p, %129, %130, %131, %132;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x248x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x248x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x248x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[124]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& 
a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %131, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n248k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + "{%124, %125, %126, %127}," + " %128," + " %129, %130," + " p, %132, %133, %134;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + 
"+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x248x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %16, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " %14, %15," + " p, %17, %18;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, 
uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %19, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " %17, %18," + " p, %20, %21;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x40x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x40x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[20]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %24, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n40k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + " %20," + " %21," + " %22, %23," + " p, %25, %26;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x40x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x40x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x40x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = 
float[20]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %27, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n40k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + "{%20, %21, %22, %23}," + " %24," + " %25, %26," + " p, %28, %29;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x40x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p, %29, %30;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
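+// NOTE: informal summary of the operand convention shared by the sparse GMMA wrappers in this
+// file, as used consistently by the definitions above and below:
+//   * _SS variants take 64-bit shared-memory matrix descriptors (desc_a, desc_b); _RS variants
+//     take operand A as four 32-bit registers and, per their static_assert, require A to be
+//     K-major.
+//   * CRegisters is the per-thread accumulator fragment: an m64nN tile spreads 64*N floats over
+//     the 128 threads of the warpgroup, i.e. N/2 floats per thread (float[24] for n48,
+//     float[120] for n240).
+//   * e is the 32-bit structured-sparsity metadata word for operand A, and spsel is the sparsity
+//     selector forwarded to the instruction as an immediate.
+//   * scale_D drives the predicate p: ScaleOut::One accumulates (D = A*B + D), ScaleOut::Zero
+//     overwrites (D = A*B). scaleA/scaleB select +1/-1 input scaling; the BF16 shapes also pass
+//     tnspA/tnspB (operand major-ness) as immediates, while the *_TN TF32 shapes fix a TN layout.
+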
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p, %32, %33;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x56x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x56x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[28]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %32, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n56k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, 
%27}," + " %28," + " %29," + " %30, %31," + " p, %33, %34;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x56x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x56x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x56x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[28]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %35, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n56k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + "{%28, %29, %30, %31}," + " %32," + " %33, %34," + " p, %36, %37;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x56x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x72x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x72x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[36]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, 
float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %40, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n72k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + " %36," + " %37," + " %38, %39," + " p, %41, %42;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x72x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x72x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x72x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[36]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %43, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n72k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + "{%36, %37, %38, %39}," + " %40," + " %41, %42," + " p, %44, %45;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), 
"+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x72x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p, %45, %46;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + 
using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p, %48, %49;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x88x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x88x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[44]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %48, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n88k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + " %44," + " %45," + " %46, %47," + " p, %49, %50;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x88x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x88x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x88x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[44]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %51, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n88k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + "{%44, %45, %46, %47}," + " %48," + " %49, %50," + " p, %52, %53;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), 
"+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x88x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x104x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x104x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[52]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %56, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n104k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + " %52," + " %53," + " %54, %55," + " p, %57, %58;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x104x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x104x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = 
GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x104x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[52]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %59, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n104k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + "{%52, %53, %54, %55}," + " %56," + " %57, %58," + " p, %60, %61;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x104x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & 
d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p, %61, %62;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float 
& d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p, %64, %65;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x120x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x120x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[60]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %64, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n120k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, 
%10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + " %60," + " %61," + " %62, %63," + " p, %65, %66;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x120x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x120x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x120x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[60]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %67, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n120k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + "{%60, %61, %62, %63}," + " %64," + " %65, %66," + " p, %68, %69;\n" + "}\n" + : 
"+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x120x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x136x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x136x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %72, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n136k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + " %68," + " %69," + " %70, %71," + " p, %73, %74;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), 
"+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x136x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x136x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x136x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %75, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n136k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + "{%68, %69, %70, %71}," + " %72," + " %73, %74," + " p, %76, %77;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), 
"+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x136x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %76, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " %74, %75," + " p, %77, %78;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), 
"+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %79, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " %77, %78," + " p, %80, %81;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), 
"+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x152x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x152x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[76]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %80, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n152k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75}," + " %76," + " %77," + " %78, %79," + " p, %81, %82;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), 
"+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x152x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x152x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x152x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[76]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %83, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n152k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75}," + "{%76, %77, %78, %79}," + " %80," + " %81, %82," + " p, %84, %85;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), 
"+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x152x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %84, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " %82, %83," + " p, %85, %86;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), 
+ "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %87, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, 
%66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " %85, %86," + " p, %88, %89;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x168x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x168x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[84]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %88, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n168k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, 
%17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83}," + " %84," + " %85," + " %86, %87," + " p, %89, %90;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x168x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x168x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x168x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[84]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + uint32_t const& e, + GMMA::ScaleOut const scale_D = 
GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %91, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n168k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83}," + "{%84, %85, %86, %87}," + " %88," + " %89, %90," + " p, %92, %93;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x168x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & 
d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %92, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " %90, %91," + " p, %93, %94;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & 
d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %95, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " %93, %94," + " p, %96, %97;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x184x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x184x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = 
float[92]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %96, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n184k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91}," + " %92," + " %93," + " %94, %95," + " p, %97, %98;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x184x16_F32TF32TF32_SS_TN 
without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x184x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x184x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %99, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n184k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91}," + "{%92, %93, %94, %95}," + " %96," + " %97, %98," + " p, %100, %101;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), 
"+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x184x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x200x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x200x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[100]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %104, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n200k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, 
%82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99}," + " %100," + " %101," + " %102, %103," + " p, %105, %106;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x200x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x200x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x200x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[100]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & 
d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %107, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n200k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99}," + "{%100, %101, %102, %103}," + " %104," + " %105, %106," + " p, %108, %109;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x200x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + 
float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %108, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " %106, %107," + " p, %109, %110;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), 
"+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %111, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," 
+ " %109, %110," + " p, %112, %113;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x216x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x216x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[108]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + 
float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %112, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n216k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107}," + " %108," + " %109," + " %110, %111," + " p, %113, %114;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x216x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x216x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x216x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[108]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, 
uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %115, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n216k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107}," + "{%108, %109, %110, %111}," + " %112," + " %113, %114," + " p, %116, %117;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), 
"+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x216x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %116, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, 
%37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " %114, %115," + " p, %117, %118;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, 
float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %119, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " %117, %118," + " p, %120, %121;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), 
"n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x232x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x232x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %120, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n232k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + " %116," + " %117," + " %118, %119," + " p, %121, %122;\n" + "}\n" + : 
"+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x232x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x232x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x232x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, 
float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %123, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n232k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + "{%116, %117, %118, %119}," + " %120," + " %121, %122," + " p, %124, %125;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x232x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x16 TN F32+=TF32*TF32 
+template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %124, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " %122, %123," + " p, %125, %126;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), 
"+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & 
d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %127, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " %125, %126," + " p, %128, %129;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x248x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = 
GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x248x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[124]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %128, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n248k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + " %124," + " %125," + " %126, %127," + " p, %129, %130;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + 
"+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x248x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x248x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x248x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[124]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & 
d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %131, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n248k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + "{%124, %125, %126, %127}," + " %128," + " %129, %130," + " p, %132, %133;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x248x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
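Each wrapper in this header follows the same pattern: register-layout aliases (DRegisters/ARegisters/ERegisters/BRegisters/CRegisters) plus a static fma() whose parameters map one-to-one onto the operands of the underlying wgmma.mma_async.sp instruction. As a minimal illustrative sketch only (it is not part of the vendored header, and the descriptor/metadata values are placeholders), a direct call to one of the smaller sparse atoms defined further below could look like this:

// Hypothetical usage sketch for the sparse GMMA wrappers; desc_a/desc_b/meta
// are placeholders standing in for real shared-memory GMMA descriptors and
// 2:4 sparsity metadata, which real kernels build via CuTe's descriptor helpers.
#include <cstdint>
#include <cute/arch/mma_sm90_gmma_sparse.hpp>  // assumed include path in this tree

__device__ void sparse_gmma_sketch()
{
  uint64_t desc_a = 0;    // shared-memory descriptor for the A tile (placeholder)
  uint64_t desc_b = 0;    // shared-memory descriptor for the B tile (placeholder)
  uint32_t meta   = 0;    // 2:4 sparsity metadata word (placeholder)
  uint32_t acc[12] = {};  // CRegisters of GMMA_64x24x64_S32S8S8_SS_TN

  // Issues one sparse wgmma; ScaleOut::One (the default) accumulates into acc.
  cute::SM90::GMMA::SPARSE::GMMA_64x24x64_S32S8S8_SS_TN<>::fma(
      desc_a, desc_b,
      acc[0], acc[1], acc[2], acc[3], acc[4],  acc[5],
      acc[6], acc[7], acc[8], acc[9], acc[10], acc[11],
      meta, cute::GMMA::ScaleOut::One);
}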
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %16, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " %14, %15," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %16, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " %14, %15," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t 
const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN S32+=S8*S8 +template < + 
GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { 
+#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + 
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + 
"r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %76, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " %74, %75," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), 
+ "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %76, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " %74, %75," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(e), 
"n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %84, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " %82, %83," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), 
"+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %84, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " %82, %83," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), 
"+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %92, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, 
%87}," + " %88," + " %89," + " %90, %91," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %92, 0;\n" + 
"wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " %90, %91," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & 
d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %108, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " %106, %107," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = 
uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %108, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " %106, %107," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + 
"+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %116, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " %114, %115," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & 
d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %116, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " %114, %115," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), 
"+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %124, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.s8.s8 " + "{%0, 
%1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " %122, %123," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, 
uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %124, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " %122, %123," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + 
"+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %19, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " %17, %18," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %19, 0;\n" + 
"wgmma.mma_async.sp.sync.aligned.m64n24k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " %17, %18," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, 
uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + 
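+    // Added note: compile-time fallback -- without CUTE_ARCH_MMA_SM90A_ENABLED this sparse
+    // WGMMA shape has no valid lowering, so fail loudly rather than emit incorrect code.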
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & 
d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t 
& d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & 
d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %79, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " %77, %78," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & 
d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %79, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " %77, %78," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + 
uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %87, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " %85, %86," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & 
d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %87, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " %85, %86," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & 
d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %95, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " %93, %94," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 
SPARSE GMMA 64x176x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %95, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " %93, %94," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), 
"+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %111, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, 
%55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " %109, %110," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, 
uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %111, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " %109, %110," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = 
GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %119, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " %117, %118," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), 
"+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & 
d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %119, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " %117, %118," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using 
ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %127, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " %125, %126," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + 
"+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, 
uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %127, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " %125, %126," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + 
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %16, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " %14, %15," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %16, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " %14, %15," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using 
ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + 
} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & 
d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), 
"+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), 
+ "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %76, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " %74, %75," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), 
"+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %76, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " %74, %75," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), 
"+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %84, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " %82, %83," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), 
"+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %84, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " %82, %83," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + 
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %92, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, 
%59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " %90, %91," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if 
defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %92, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " %90, %91," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + 
uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %108, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " %106, %107," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = 
GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %108, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " %106, %107," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), 
"+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, 
uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %116, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " %114, %115," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & 
d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %116, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " %114, %115," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), 
"+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %124, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " %122, %123," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, 
uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %124, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " %122, %123," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), 
"+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %19, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " %17, %18," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if 
defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %19, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " %17, %18," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t 
& d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), 
"+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, 
uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & 
d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, 
uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %79, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " %77, %78," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, 
uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %79, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " %77, %78," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t 
& d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %87, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " %85, %86," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + 
uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %87, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " %85, %86," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, 
uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %95, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " %93, %94," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32S8U8_RS_TN without 
CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %95, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " %93, %94," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), 
"+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %111, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, 
%26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " %109, %110," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & 
d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %111, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " %109, %110," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
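+// Usage sketch (illustrative only, not a definitive recipe): every *_RS_TN atom in this
+// family follows the same operand convention -- the A fragments live in registers
+// (ARegisters), B is supplied through a shared-memory matrix descriptor (BRegisters),
+// `e` carries the 2:4 sparsity metadata word (ERegisters), and the accumulator tile is
+// spread across CRegisters, which the inline PTX updates in place via "+r" constraints.
+// A hypothetical call for the 64x208x64 shape above would look like:
+//
+//   GMMA_64x208x64_S32S8U8_RS_TN<>::fma(a0, a1, a2, a3,        // A fragments in registers
+//                                       desc_b,                 // B smem descriptor
+//                                       d[0], /* ... */ d[103], // 104 accumulators, read-modify-write
+//                                       e,                      // sparsity metadata
+//                                       GMMA::ScaleOut::One);   // accumulate into D
+//
+// The trailing scale_D argument feeds the setp.ne.b32 predicate `p`: ScaleOut::Zero
+// overwrites D with A*B, ScaleOut::One accumulates D += A*B. The SparseSel template
+// parameter is baked into the instruction as an immediate ("n" constraint).
+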
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %119, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " %117, %118," + " 
p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, 
uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %119, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " %117, %118," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN 
S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %127, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + 
"{%120, %121, %122, %123}," + " %124," + " %125, %126," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + 
uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %127, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " %125, %126," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), 
"+r"(d117), "+r"(d118), "+r"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %16, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " %14, %15," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %16, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " %14, %15," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN S32+=U8*S8 +template < + 
GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); 
+#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, 
uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, 
%42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), 
"+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %76, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " %74, %75," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), 
"+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %76, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " %74, %75," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), 
+ "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %84, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " %82, %83," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + 
"+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %84, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " %82, %83," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), 
"+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %92, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " 
%32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " %90, %91," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t 
& d85, uint32_t & d86, uint32_t & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %92, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " %90, %91," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, 
uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %108, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " %106, %107," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
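+// Note: each GMMA_64xNx64_* struct in this file wraps exactly one sparse warpgroup MMA PTX
+// instruction, "wgmma.mma_async.sp.sync.aligned.m64nNk64...". The inline-asm operand order is:
+// the per-thread 32-bit accumulator fragments (read/write, "+r"), then the A operand (an smem
+// descriptor for _SS_ atoms, four registers for _RS_ atoms), the B smem descriptor, the 32-bit
+// sparsity metadata `e`, the compile-time metadata selector `spsel`, and `scale_D`, which is
+// converted into predicate `p` so that ScaleOut::Zero overwrites the accumulators rather than
+// accumulating into them.
+//
+// Minimal usage sketch for the atom above (illustrative only; it assumes
+// CUTE_ARCH_MMA_SM90A_ENABLED, valid shared-memory descriptors and metadata, and the names
+// `desc_a`, `desc_b`, `meta`, and `acc` are placeholders). In practice these atoms are reached
+// through cute::MMA_Atom / MMA_Traits rather than called directly:
+//
+//   uint32_t acc[104] = {};  // 64x208 s32 accumulators -> 104 fragments per thread of the warpgroup
+//   SM90::GMMA::SPARSE::GMMA_64x208x64_S32U8S8_SS_TN<>::fma(
+//       desc_a, desc_b,
+//       acc[0], acc[1], /* ..., */ acc[103],
+//       meta, GMMA::ScaleOut::One);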
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %108, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " %106, %107," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), 
"+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, 
uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %116, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " %114, %115," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & 
d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %116, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " %114, %115," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), 
"+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, 
uint32_t & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %124, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " %122, %123," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, 
uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %124, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " %122, %123," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + 
"+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %19, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " %17, %18," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, 
uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %19, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " %17, %18," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t 
& d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), 
"+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t 
const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & 
d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, 
+ uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %79, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " %77, %78," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, 
uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %79, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " %77, %78," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t 
& d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %87, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " %85, %86," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, 
uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %87, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " %85, %86," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& 
a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %95, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " %93, %94," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + 
"r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %95, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " %93, %94," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), 
"+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %111, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.u8.s8 " + 
"{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " %109, %110," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & 
d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %111, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " %109, %110," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting 
to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %119, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, 
" + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " %117, %118," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & 
d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %119, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " %117, %118," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
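// Editorial sketch, not part of the upstream CUTLASS header: the sparse integer
// GMMA atoms above share one register-budget rule. The 64xN s32 accumulator tile
// is spread over the 128 threads of a warpgroup, so each thread holds
// 64*N/128 = N/2 32-bit fragments, which is exactly the CRegisters extent declared
// in each struct (104 for N=208, 112 for N=224, 120 for N=240). The namespace and
// helper name below are hypothetical and exist only to make that relationship
// checkable at compile time.
namespace gmma_sparse_sketch {
  // s32 accumulator fragments owned by each thread of a 128-thread warpgroup.
  constexpr int acc_regs_per_thread(int m, int n) { return m * n / 128; }

  static_assert(acc_regs_per_thread(64, 208) == 104, "CRegisters of GMMA_64x208x64_S32U8S8_RS_TN");
  static_assert(acc_regs_per_thread(64, 224) == 112, "CRegisters of GMMA_64x224x64_S32U8S8_RS_TN");
  static_assert(acc_regs_per_thread(64, 240) == 120, "CRegisters of GMMA_64x240x64_S32U8S8_RS_TN");
} // namespace gmma_sparse_sketch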
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %127, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, 
%101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " %125, %126," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & 
d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %127, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " %125, %126," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), 
"+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %16, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " %14, %15," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %16, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " %14, %15," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), 
"+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & 
d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, 
" + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + 
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %76, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " %74, %75," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), 
+ "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %76, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " %74, %75," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), 
"+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %84, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " %82, %83," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), 
"+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %84, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, 
%76, %77, %78, %79}," + " %80," + " %81," + " %82, %83," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %92, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, 
%5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " %90, %91," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + 
uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %92, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " %90, %91," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + 
uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %108, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " %106, %107," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use 
SM90::GMMA::SPARSE::GMMA_64x208x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %108, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " %106, %107," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), 
"+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & 
d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %116, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " %114, %115," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t 
& d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %116, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " %114, %115," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), 
"+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & 
d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %124, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " %122, %123," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & 
d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %124, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " %122, %123," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), 
"+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %19, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " %17, %18," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, 
uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %19, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " %17, %18," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t 
const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + 
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + 
CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, 
uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + 
uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %79, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " %77, %78," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, 
uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %79, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " %77, %78," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t 
& d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %87, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " %85, %86," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, 
uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %87, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " %85, %86," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = 
uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %95, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " %93, %94," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), 
"+r"(d85), "+r"(d86), "+r"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %95, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " %93, %94," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + 
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm 
volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %111, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " %109, %110," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t 
& d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %111, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " %109, %110," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "r"(a000), "r"(a001), "r"(a002), 
"r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %119, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, 
%81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " %117, %118," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, 
uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %119, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " %117, %118," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use 
SM90::GMMA::SPARSE::GMMA_64x224x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %127, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " 
%80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " %125, %126," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, 
uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %127, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " %125, %126," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), 
"+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[6]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5}," + " %6," + " %7," + " %8, %9," + " p, %11, %12;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[6]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5}," + "{%6, %7, %8, %9}," + " %10," + " %11, %12," + " p, %14, %15;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_F16E4M3E4M3_RS_TN without 
CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %16, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " %14, %15," + " p, %17, %18;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %19, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " %17, %18," + " p, %20, %21;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 
64x40x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x40x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[10]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %14, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n40k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9}," + " %10," + " %11," + " %12, %13," + " p, %15, %16;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x40x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x40x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x40x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[10]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %17, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n40k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9}," + "{%10, %11, %12, %13}," + " %14," + " %15, %16," + " p, %18, %19;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x40x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x40x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x40x64_F32E4M3E4M3_SS_TN +{ + 
using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[20]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %24, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n40k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + " %20," + " %21," + " %22, %23," + " p, %25, %26;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x40x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x40x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x40x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[20]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %27, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n40k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + "{%20, %21, %22, %23}," + " %24," + " %25, %26," + " p, %28, %29;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x40x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
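+// Illustrative usage sketch (editorial addition, not part of the upstream header):
+// the SS atoms above take SMEM matrix descriptors for both A and B plus the packed
+// 2:4 sparsity metadata word, and accumulate into the caller's register fragment.
+// The helper name and the assumption that desc_a, desc_b and meta are already
+// staged are hypothetical; in practice these atoms are driven through CuTe's
+// MMA_Traits/TiledMMA machinery rather than called by hand.
+CUTE_HOST_DEVICE void
+example_sparse_f16_ss_64x40x64(uint64_t desc_a, uint64_t desc_b,
+                               uint32_t meta, uint32_t (&acc)[10])
+{
+  // Ten packed f16x2 accumulators cover the 64x40 D tile across the warpgroup.
+  GMMA_64x40x64_F16E4M3E4M3_SS_TN<>::fma(
+      desc_a, desc_b,
+      acc[0], acc[1], acc[2], acc[3], acc[4],
+      acc[5], acc[6], acc[7], acc[8], acc[9],
+      meta, GMMA::ScaleOut::One);
+}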
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %16, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " %14, %15," + " p, %17, %18;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %19, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " %17, %18," + " p, %20, %21;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
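+// Illustrative usage sketch (editorial addition, not part of the upstream header):
+// the RS atoms differ from the SS atoms only in operand A, which is supplied as
+// four uint32_t register fragments instead of an SMEM descriptor. The helper name
+// and the assumption that the A fragment, desc_b and the metadata word are already
+// staged are hypothetical.
+CUTE_HOST_DEVICE void
+example_sparse_f16_rs_64x48x64(uint32_t const (&a)[4], uint64_t desc_b,
+                               uint32_t meta, uint32_t (&acc)[12])
+{
+  // Twelve packed f16x2 accumulators cover the 64x48 D tile across the warpgroup.
+  GMMA_64x48x64_F16E4M3E4M3_RS_TN<>::fma(
+      a[0], a[1], a[2], a[3], desc_b,
+      acc[0], acc[1], acc[2], acc[3], acc[4], acc[5],
+      acc[6], acc[7], acc[8], acc[9], acc[10], acc[11],
+      meta, GMMA::ScaleOut::One);
+}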
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p, %29, %30;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p, %32, %33;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), 
"+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x56x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x56x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[14]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n56k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13}," + " %14," + " %15," + " %16, %17," + " p, %19, %20;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x56x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x56x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x56x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[14]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n56k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, 
%7, " + " %8, %9, %10, %11, %12, %13}," + "{%14, %15, %16, %17}," + " %18," + " %19, %20," + " p, %22, %23;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x56x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x56x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x56x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[28]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %32, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n56k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + " %28," + " %29," + " %30, %31," + " p, %33, %34;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x56x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x56x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x56x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[28]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, 
+ float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %35, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n56k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + "{%28, %29, %30, %31}," + " %32," + " %33, %34," + " p, %36, %37;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x56x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x72x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x72x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[18]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %22, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n72k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17}," + " %18," + " %19," + " %20, %21," + " p, %23, %24;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x72x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x72x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + 
GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x72x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[18]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %25, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n72k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17}," + "{%18, %19, %20, %21}," + " %22," + " %23, %24," + " p, %26, %27;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x72x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x72x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x72x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[36]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %40, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n72k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + " %36," + " %37," + " %38, %39," + " p, %41, %42;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), 
"+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x72x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x72x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x72x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[36]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %43, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n72k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + "{%36, %37, %38, %39}," + " %40," + " %41, %42," + " p, %44, %45;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x72x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[20]; + + 
CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %24, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + " %20," + " %21," + " %22, %23," + " p, %25, %26;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[20]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %27, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + "{%20, %21, %22, %23}," + " %24," + " %25, %26," + " p, %28, %29;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
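+// Editorial notes on the pattern shared by the sparse wgmma atoms in this file:
+//  * Accumulator register counts follow from the tile shape: a 64xN D tile holds
+//    64*N values spread over the 128 threads of a warpgroup, i.e. N/2 values per
+//    thread, so CRegisters is uint32_t[N/4] for packed f16 accumulation and
+//    float[N/2] for f32 accumulation (e.g. N=80 gives 20 and 40 registers).
+//  * The accumulators are bound as read-write ("+r"/"+f") asm outputs because
+//    wgmma accumulates into D when scale_D is ScaleOut::One; scale_D itself is a
+//    runtime "r" operand converted to predicate p via setp.ne.b32.
+//  * The 2:4 sparsity metadata word e is a runtime "r" operand, while spsel,
+//    scaleA and scaleB are compile-time template parameters lowered to immediate
+//    "n" operands.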
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p, %45, %46;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & 
d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p, %48, %49;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x88x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x88x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[22]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %26, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n88k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21}," + " %22," + " %23," + " %24, %25," + " p, %27, %28;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x88x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } 
+}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x88x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x88x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[22]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %29, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n88k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21}," + "{%22, %23, %24, %25}," + " %26," + " %27, %28," + " p, %30, %31;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x88x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x88x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x88x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[44]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + 
"setp.ne.b32 p, %48, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n88k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + " %44," + " %45," + " %46, %47," + " p, %49, %50;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x88x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x88x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x88x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[44]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %51, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n88k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + "{%44, %45, %46, %47}," + " %48," + " %49, %50," + " p, %52, %53;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + 
"+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x88x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x104x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x104x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[26]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %30, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n104k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25}," + " %26," + " %27," + " %28, %29," + " p, %31, %32;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x104x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x104x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x104x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[26]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, 
uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %33, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n104k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25}," + "{%26, %27, %28, %29}," + " %30," + " %31, %32," + " p, %34, %35;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x104x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x104x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x104x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[52]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %56, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n104k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + " %52," + " %53," + " %54, %55," + " p, %57, %58;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), 
"+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x104x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x104x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x104x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[52]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %59, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n104k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + "{%52, %53, %54, %55}," + " %56," + " %57, %58," + " p, %60, %61;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x104x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[28]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %32, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + " %28," + " %29," + " %30, %31," + " p, %33, %34;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[28]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" 
+ ".reg .pred p;\n" + "setp.ne.b32 p, %35, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + "{%28, %29, %30, %31}," + " %32," + " %33, %34," + " p, %36, %37;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p, %61, %62;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), 
+ "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p, %64, %65;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x120x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x120x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[30]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n120k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29}," + " %30," + " %31," + " %32, %33," + " p, %35, %36;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x120x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x120x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x120x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[30]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if 
defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n120k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29}," + "{%30, %31, %32, %33}," + " %34," + " %35, %36," + " p, %38, %39;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x120x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x120x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x120x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[60]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %64, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n120k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + " %60," + " %61," + " %62, %63," + " p, %65, %66;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), 
"+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x120x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x120x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x120x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[60]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %67, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n120k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + "{%60, %61, %62, %63}," + " %64," + " %65, %66," + " p, %68, %69;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), 
"+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x120x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x136x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x136x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[34]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %38, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n136k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33}," + " %34," + " %35," + " %36, %37," + " p, %39, %40;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x136x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x136x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x136x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[34]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + 
uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %41, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n136k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33}," + "{%34, %35, %36, %37}," + " %38," + " %39, %40," + " p, %42, %43;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x136x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x136x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x136x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + 
"setp.ne.b32 p, %72, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n136k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + " %68," + " %69," + " %70, %71," + " p, %73, %74;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x136x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x136x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x136x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %75, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n136k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, 
%9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + "{%68, %69, %70, %71}," + " %72," + " %73, %74," + " p, %76, %77;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x136x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[36]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %40, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + " %36," + " %37," + " %38, %39," + " p, %41, %42;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), 
"+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[36]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %43, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + "{%36, %37, %38, %39}," + " %40," + " %41, %42," + " p, %44, %45;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = 
float[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %76, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " %74, %75," + " p, %77, %78;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& 
a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %79, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " %77, %78," + " p, %80, %81;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x152x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x152x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[38]; + + CUTE_HOST_DEVICE 
static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %42, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n152k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37}," + " %38," + " %39," + " %40, %41," + " p, %43, %44;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x152x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x152x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x152x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[38]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + 
"setp.ne.b32 p, %45, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n152k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37}," + "{%38, %39, %40, %41}," + " %42," + " %43, %44," + " p, %46, %47;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x152x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x152x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x152x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[76]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %80, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n152k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75}," + " %76," + " %77," + " %78, 
%79," + " p, %81, %82;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x152x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x152x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x152x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[76]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %83, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n152k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, 
%62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75}," + "{%76, %77, %78, %79}," + " %80," + " %81, %82," + " p, %84, %85;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x152x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p, %45, %46;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), 
"+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p, %48, %49;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using 
ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %84, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " %82, %83," + " p, %85, %86;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = 
GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %87, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " %85, %86," + " p, %88, %89;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use 
SM90::GMMA::SPARSE::GMMA_64x160x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x168x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x168x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[42]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %46, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n168k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41}," + " %42," + " %43," + " %44, %45," + " p, %47, %48;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x168x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x168x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x168x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[42]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, 
uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %49, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n168k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41}," + "{%42, %43, %44, %45}," + " %46," + " %47, %48," + " p, %50, %51;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x168x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x168x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x168x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[84]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, 
float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %88, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n168k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83}," + " %84," + " %85," + " %86, %87," + " p, %89, %90;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x168x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x168x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x168x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[84]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & 
d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %91, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n168k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83}," + "{%84, %85, %86, %87}," + " %88," + " %89, %90," + " p, %92, %93;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x168x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[44]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t 
& d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %48, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + " %44," + " %45," + " %46, %47," + " p, %49, %50;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[44]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm 
volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %51, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + "{%44, %45, %46, %47}," + " %48," + " %49, %50," + " p, %52, %53;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %92, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, 
%29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " %90, %91," + " p, %93, %94;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, 
+ uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %95, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " %93, %94," + " p, %96, %97;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x184x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x184x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[46]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & 
d43, + uint32_t & d44, uint32_t & d45, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n184k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45}," + " %46," + " %47," + " %48, %49," + " p, %51, %52;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x184x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x184x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x184x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[46]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n184k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45}," + "{%46, %47, %48, %49}," + " %50," + " %51, 
%52," + " p, %54, %55;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x184x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x184x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x184x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %96, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n184k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " 
%88, %89, %90, %91}," + " %92," + " %93," + " %94, %95," + " p, %97, %98;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x184x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x184x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x184x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg 
.pred p;\n" + "setp.ne.b32 p, %99, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n184k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91}," + "{%92, %93, %94, %95}," + " %96," + " %97, %98," + " p, %100, %101;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x184x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x200x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x200x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[50]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, + uint32_t const& e, + GMMA::ScaleOut 
const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %54, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n200k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49}," + " %50," + " %51," + " %52, %53," + " p, %55, %56;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x200x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x200x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x200x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[50]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %57, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n200k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, 
%43, %44, %45, %46, %47, " + " %48, %49}," + "{%50, %51, %52, %53}," + " %54," + " %55, %56," + " p, %58, %59;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x200x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x200x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x200x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[100]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %104, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n200k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, 
%28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99}," + " %100," + " %101," + " %102, %103," + " p, %105, %106;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x200x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x200x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x200x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[100]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float 
& d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %107, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n200k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99}," + "{%100, %101, %102, %103}," + " %104," + " %105, %106," + " p, %108, %109;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x200x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + 
using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[52]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %56, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + " %52," + " %53," + " %54, %55," + " p, %57, %58;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[52]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + 
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %59, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + "{%52, %53, %54, %55}," + " %56," + " %57, %58," + " p, %60, %61;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & 
d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %108, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " %106, %107," + " p, %109, %110;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, 
+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %111, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " %109, %110," + " p, %112, %113;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), 
"+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x216x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x216x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[54]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %58, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n216k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53}," + " %54," + " %55," + " %56, %57," + " p, %59, %60;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + 
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x216x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x216x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x216x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[54]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %61, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n216k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53}," + "{%54, %55, %56, %57}," + " %58," + " %59, %60," + " p, %62, %63;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), 
"n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x216x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x216x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x216x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[108]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %112, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n216k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107}," + " %108," + " %109," + " %110, %111," + " p, %113, %114;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), 
"+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x216x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x216x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x216x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[108]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & 
d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %115, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n216k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107}," + "{%108, %109, %110, %111}," + " %112," + " %113, %114," + " p, %116, %117;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x216x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, 
uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p, %61, %62;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & 
d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p, %64, %65;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, 
float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %116, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " %114, %115," + " p, %117, %118;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use 
SM90::GMMA::SPARSE::GMMA_64x224x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %119, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " %117, %118," + " p, %120, %121;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + 
"+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x232x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x232x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[58]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %62, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n232k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " 
%8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57}," + " %58," + " %59," + " %60, %61," + " p, %63, %64;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x232x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x232x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x232x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[58]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %65, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n232k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, 
%54, %55, " + " %56, %57}," + "{%58, %59, %60, %61}," + " %62," + " %63, %64," + " p, %66, %67;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x232x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x232x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x232x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, 
desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %120, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n232k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + " %116," + " %117," + " %118, %119," + " p, %121, %122;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x232x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x232x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x232x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, 
float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %123, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n232k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + "{%116, %117, %118, %119}," + " %120," + " %121, %122," + " p, %124, %125;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), 
"+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x232x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[60]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %64, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + " %60," + " %61," + " %62, %63," + " p, %65, %66;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), 
"+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[60]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %67, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + "{%60, %61, %62, %63}," + " %64," + " %65, %66," + " p, %68, %69;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), 
"+r"(d57), "+r"(d58), "+r"(d59) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %124, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, 
%94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " %122, %123," + " p, %125, %126;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, 
float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %127, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " %125, %126," + " p, %128, %129;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + 
"l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x248x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x248x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[62]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n248k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61}," + " %62," + " %63," + " %64, %65," + " p, %67, %68;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use 
SM90::GMMA::SPARSE::GMMA_64x248x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x248x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x248x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[62]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n248k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61}," + "{%62, %63, %64, %65}," + " %66," + " %67, %68," + " p, %70, %71;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x248x64_F16E4M3E4M3_RS_TN without 
CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x248x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x248x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[124]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %128, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n248k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + " %124," + " %125," + " %126, %127," + " p, %129, 
%130;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x248x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x248x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x248x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[124]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, 
float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %131, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n248k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + "{%124, %125, %126, %127}," + " %128," + " %129, %130," + " p, %132, %133;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + 
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x248x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[6]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5}," + " %6," + " %7," + " %8, %9," + " p, %11, %12;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[6]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5}," + "{%6, %7, %8, %9}," + " %10," + " %11, %12," + " p, %14, %15;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> 
+struct GMMA_64x24x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %16, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " %14, %15," + " p, %17, %18;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %19, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " %17, %18," + " p, %20, %21;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x40x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x40x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + 
using CRegisters = uint32_t[10]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %14, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n40k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9}," + " %10," + " %11," + " %12, %13," + " p, %15, %16;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x40x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x40x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x40x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[10]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %17, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n40k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9}," + "{%10, %11, %12, %13}," + " %14," + " %15, %16," + " p, %18, %19;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x40x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x40x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x40x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[20]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & 
d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %24, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n40k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + " %20," + " %21," + " %22, %23," + " p, %25, %26;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x40x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x40x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x40x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[20]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %27, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n40k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + "{%20, %21, %22, %23}," + " %24," + " %25, %26," + " p, %28, %29;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x40x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using 
ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %16, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " %14, %15," + " p, %17, %18;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %19, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " %17, %18," + " p, %20, %21;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + 
using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p, %29, %30;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p, %32, %33;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use 
SM90::GMMA::SPARSE::GMMA_64x48x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x56x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x56x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[14]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n56k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13}," + " %14," + " %15," + " %16, %17," + " p, %19, %20;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x56x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x56x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x56x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[14]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n56k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13}," + "{%14, %15, %16, %17}," + " %18," + " %19, %20," + " p, %22, %23;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + 
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x56x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x56x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x56x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[28]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %32, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n56k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + " %28," + " %29," + " %30, %31," + " p, %33, %34;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x56x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x56x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x56x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[28]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + 
"setp.ne.b32 p, %35, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n56k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + "{%28, %29, %30, %31}," + " %32," + " %33, %34," + " p, %36, %37;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x56x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x72x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x72x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[18]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %22, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n72k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17}," + " %18," + " %19," + " %20, %21," + " p, %23, %24;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x72x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x72x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x72x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[18]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t 
& d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %25, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n72k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17}," + "{%18, %19, %20, %21}," + " %22," + " %23, %24," + " p, %26, %27;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x72x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x72x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x72x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[36]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %40, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n72k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + " %36," + " %37," + " %38, %39," + " p, %41, %42;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use 
SM90::GMMA::SPARSE::GMMA_64x72x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x72x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x72x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[36]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %43, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n72k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + "{%36, %37, %38, %39}," + " %40," + " %41, %42," + " p, %44, %45;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x72x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[20]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t const& e, + GMMA::ScaleOut const 
scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %24, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + " %20," + " %21," + " %22, %23," + " p, %25, %26;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[20]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %27, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + "{%20, %21, %22, %23}," + " %24," + " %25, %26," + " p, %28, %29;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE 
static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p, %45, %46;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, 
%13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p, %48, %49;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x88x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x88x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[22]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %26, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n88k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21}," + " %22," + " %23," + " %24, %25," + " p, %27, %28;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x88x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x88x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x88x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[22]; + + 
CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %29, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n88k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21}," + "{%22, %23, %24, %25}," + " %26," + " %27, %28," + " p, %30, %31;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x88x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x88x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x88x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[44]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %48, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n88k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + " %44," + " %45," + " %46, %47," + " p, %49, %50;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), 
"+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x88x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x88x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x88x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[44]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %51, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n88k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + "{%44, %45, %46, %47}," + " %48," + " %49, %50," + " p, %52, %53;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x88x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x104x64 TN F16+=E4M3*E5M2 
+template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x104x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[26]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %30, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n104k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25}," + " %26," + " %27," + " %28, %29," + " p, %31, %32;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x104x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x104x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x104x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[26]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %33, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n104k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " 
%24, %25}," + "{%26, %27, %28, %29}," + " %30," + " %31, %32," + " p, %34, %35;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x104x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x104x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x104x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[52]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %56, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n104k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + " %52," + " %53," + " %54, %55," + " p, %57, %58;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x104x64_F32E4M3E5M2_SS_TN without 
CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x104x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x104x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[52]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %59, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n104k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + "{%52, %53, %54, %55}," + " %56," + " %57, %58," + " p, %60, %61;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x104x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[28]; + + CUTE_HOST_DEVICE 
static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %32, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + " %28," + " %29," + " %30, %31," + " p, %33, %34;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[28]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %35, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + "{%28, %29, %30, %31}," + " %32," + " %33, %34," + " p, %36, %37;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), 
+ "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p, %61, %62;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p, %64, %65;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x120x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x120x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = 
uint64_t[1]; + using CRegisters = uint32_t[30]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n120k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29}," + " %30," + " %31," + " %32, %33," + " p, %35, %36;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x120x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x120x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x120x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[30]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n120k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29}," + "{%30, %31, %32, %33}," + " %34," + " %35, %36," + " p, %38, %39;\n" + "}\n" + : 
"+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x120x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x120x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x120x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[60]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %64, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n120k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + " %60," + " %61," + " %62, %63," + " p, %65, %66;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "l"(desc_a), + "l"(desc_b), + "r"(e), 
"n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x120x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x120x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x120x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[60]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %67, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n120k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + "{%60, %61, %62, %63}," + " %64," + " %65, %66," + " p, %68, %69;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x120x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x136x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x136x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[34]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %38, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n136k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33}," + " %34," + " %35," + " %36, %37," + " p, %39, %40;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x136x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x136x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x136x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[34]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + 
uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %41, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n136k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33}," + "{%34, %35, %36, %37}," + " %38," + " %39, %40," + " p, %42, %43;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x136x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x136x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x136x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %72, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n136k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, 
%55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + " %68," + " %69," + " %70, %71," + " p, %73, %74;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x136x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x136x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x136x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %75, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n136k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + "{%68, %69, %70, %71}," + " %72," + " %73, %74," + " 
p, %76, %77;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x136x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[36]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %40, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + " %36," + " %37," + " %38, %39," + " p, %41, %42;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use 
SM90::GMMA::SPARSE::GMMA_64x144x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[36]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %43, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + "{%36, %37, %38, %39}," + " %40," + " %41, %42," + " p, %44, %45;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, 
float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %76, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " %74, %75," + " p, %77, %78;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float 
& d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %79, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " %77, %78," + " p, %80, %81;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x152x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x152x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[38]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t 
& d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %42, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n152k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37}," + " %38," + " %39," + " %40, %41," + " p, %43, %44;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x152x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x152x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x152x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[38]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %45, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n152k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37}," + "{%38, %39, %40, %41}," + " %42," + " %43, %44," + " p, %46, 
%47;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x152x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x152x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x152x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[76]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %80, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n152k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75}," + " %76," + " %77," + " %78, %79," + " p, %81, %82;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + 
"+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x152x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x152x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x152x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[76]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %83, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n152k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75}," + "{%76, %77, %78, %79}," + " %80," + " %81, %82," + " p, %84, %85;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), 
"+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x152x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p, %45, %46;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use 
SM90::GMMA::SPARSE::GMMA_64x160x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p, %48, %49;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, 
float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %84, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " %82, %83," + " p, %85, %86;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static 
void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %87, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " %85, %86," + " p, %88, %89;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x168x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + 
GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x168x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[42]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %46, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n168k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41}," + " %42," + " %43," + " %44, %45," + " p, %47, %48;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x168x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x168x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x168x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[42]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, 
uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %49, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n168k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41}," + "{%42, %43, %44, %45}," + " %46," + " %47, %48," + " p, %50, %51;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x168x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x168x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x168x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[84]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %88, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n168k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83}," + " %84," + " %85," + " %86, %87," + " p, %89, %90;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x168x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x168x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x168x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[84]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & 
d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %91, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n168k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83}," + "{%84, %85, %86, %87}," + " %88," + " %89, %90," + " p, %92, %93;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x168x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[44]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & 
d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %48, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + " %44," + " %45," + " %46, %47," + " p, %49, %50;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[44]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %51, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " 
%40, %41, %42, %43}," + "{%44, %45, %46, %47}," + " %48," + " %49, %50," + " p, %52, %53;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %92, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " 
%89," + " %90, %91," + " p, %93, %94;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %95, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, 
%7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " %93, %94," + " p, %96, %97;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x184x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x184x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[46]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + 
"wgmma.mma_async.sp.sync.aligned.m64n184k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45}," + " %46," + " %47," + " %48, %49," + " p, %51, %52;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x184x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x184x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x184x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[46]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n184k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45}," + "{%46, %47, %48, %49}," + " %50," + " %51, %52," + " p, %54, %55;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), 
"+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x184x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x184x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x184x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %96, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n184k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91}," + " %92," + " %93," + " %94, %95," + " p, %97, %98;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), 
"+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x184x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x184x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x184x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %99, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n184k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, 
%34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91}," + "{%92, %93, %94, %95}," + " %96," + " %97, %98," + " p, %100, %101;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x184x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x200x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x200x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[50]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %54, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n200k64.f16.e4m3.e5m2 " + 
"{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49}," + " %50," + " %51," + " %52, %53," + " p, %55, %56;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x200x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x200x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x200x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[50]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %57, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n200k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49}," + "{%50, %51, %52, %53}," + " %54," + " %55, %56," + " p, %58, %59;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), 
"+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x200x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x200x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x200x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[100]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %104, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n200k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, 
%78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99}," + " %100," + " %101," + " %102, %103," + " p, %105, %106;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x200x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x200x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x200x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[100]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, 
float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %107, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n200k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99}," + "{%100, %101, %102, %103}," + " %104," + " %105, %106," + " p, %108, %109;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x200x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[52]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & 
d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %56, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + " %52," + " %53," + " %54, %55," + " p, %57, %58;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[52]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, 
uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %59, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + "{%52, %53, %54, %55}," + " %56," + " %57, %58," + " p, %60, %61;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, 
float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %108, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " %106, %107," + " p, %109, %110;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using 
CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %111, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " %109, %110," + " p, %112, %113;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), 
"+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x216x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x216x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[54]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %58, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n216k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53}," + " %54," + " %55," + " %56, %57," + " p, %59, %60;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + 
"+r"(d52), "+r"(d53) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x216x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x216x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x216x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[54]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %61, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n216k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53}," + "{%54, %55, %56, %57}," + " %58," + " %59, %60," + " p, %62, %63;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x216x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x216x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x216x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[108]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %112, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n216k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107}," + " %108," + " %109," + " %110, %111," + " p, %113, %114;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), 
"+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x216x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x216x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x216x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[108]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + uint32_t const& e, + GMMA::ScaleOut const 
scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %115, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n216k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107}," + "{%108, %109, %110, %111}," + " %112," + " %113, %114," + " p, %116, %117;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x216x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t 
& d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p, %61, %62;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, 
uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p, %64, %65;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & 
d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %116, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " %114, %115," + " p, %117, %118;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN F32+=E4M3*E5M2 
+template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %119, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " %117, %118," + " p, %120, %121;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + 
"+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x232x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x232x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[58]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %62, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n232k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, 
%53, %54, %55, " + " %56, %57}," + " %58," + " %59," + " %60, %61," + " p, %63, %64;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x232x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x232x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x232x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[58]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %65, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n232k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57}," + "{%58, %59, %60, %61}," + " %62," + " %63, %64," + " p, %66, %67;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + 
"+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x232x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x232x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x232x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %120, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n232k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, 
%18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + " %116," + " %117," + " %118, %119," + " p, %121, %122;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x232x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x232x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x232x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, 
float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %123, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n232k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + "{%116, %117, %118, %119}," + " %120," + " %121, %122," + " p, %124, %125;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), 
"+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x232x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[60]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %64, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + " %60," + " %61," + " %62, %63," + " p, %65, %66;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + : "l"(desc_a), + "l"(desc_b), + "r"(e), 
"n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[60]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %67, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + "{%60, %61, %62, %63}," + " %64," + " %65, %66," + " p, %68, %69;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use 
SM90::GMMA::SPARSE::GMMA_64x240x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %124, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " %122, %123," + " p, %125, %126;\n" + "}\n" + : 
"+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & 
d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %127, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " %125, %126," + " p, %128, %129;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F32E4M3E5M2_RS_TN without 
CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x248x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x248x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[62]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n248k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61}," + " %62," + " %63," + " %64, %65," + " p, %67, %68;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x248x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x248x64 
TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x248x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[62]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n248k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61}," + "{%62, %63, %64, %65}," + " %66," + " %67, %68," + " p, %70, %71;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x248x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x248x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + 
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x248x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[124]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %128, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n248k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + " %124," + " %125," + " %126, %127," + " p, %129, %130;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), 
"+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x248x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x248x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x248x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[124]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & 
d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %131, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n248k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + "{%124, %125, %126, %127}," + " %128," + " %129, %130," + " p, %132, %133;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x248x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
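A note on the hunk above: the accumulator-register counts declared by these m64nNk64 sparse GMMA wrappers follow directly from the tile width N. The F32 variants expose CRegisters = float[N/2] (float[116] for N=232, float[124] for N=248), while the F16 variants expose CRegisters = uint32_t[N/4], each 32-bit register packing two half-precision values (uint32_t[60] for N=240, uint32_t[62] for N=248). The standalone C++ sketch below (not part of the diffed header; the helper names are ours) just restates that relationship:

#include <cstdint>
#include <cstdio>

// Per-thread accumulator register counts for the m64nNk64 sparse GMMA
// wrappers above, read off their CRegisters typedefs.
constexpr int f32_acc_regs(int n) { return n / 2; }  // float registers
constexpr int f16_acc_regs(int n) { return n / 4; }  // uint32_t, two halves each

static_assert(f32_acc_regs(232) == 116, "GMMA_64x232x64_F32*_TN::CRegisters");
static_assert(f32_acc_regs(248) == 124, "GMMA_64x248x64_F32*_TN::CRegisters");
static_assert(f16_acc_regs(240) == 60,  "GMMA_64x240x64_F16*_TN::CRegisters");
static_assert(f16_acc_regs(248) == 62,  "GMMA_64x248x64_F16*_TN::CRegisters");

int main() {
  for (int n : {24, 40, 48, 56, 232, 240, 248}) {
    std::printf("N=%3d  f16 acc regs=%2d  f32 acc regs=%3d\n",
                n, f16_acc_regs(n), f32_acc_regs(n));
  }
  return 0;
}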
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[6]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5}," + " %6," + " %7," + " %8, %9," + " p, %11, %12;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[6]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5}," + "{%6, %7, %8, %9}," + " %10," + " %11, %12," + " p, %14, %15;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[12]; + + CUTE_HOST_DEVICE static 
void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %16, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " %14, %15," + " p, %17, %18;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %19, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " %17, %18," + " p, %20, %21;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x40x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x40x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[10]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & 
d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %14, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n40k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9}," + " %10," + " %11," + " %12, %13," + " p, %15, %16;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x40x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x40x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x40x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[10]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %17, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n40k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9}," + "{%10, %11, %12, %13}," + " %14," + " %15, %16," + " p, %18, %19;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x40x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x40x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x40x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[20]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = 
GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %24, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n40k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + " %20," + " %21," + " %22, %23," + " p, %25, %26;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x40x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x40x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x40x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[20]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %27, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n40k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + "{%20, %21, %22, %23}," + " %24," + " %25, %26," + " p, %28, %29;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x40x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + 
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %16, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " %14, %15," + " p, %17, %18;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %19, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " %17, %18," + " p, %20, %21;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float 
& d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p, %29, %30;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p, %32, %33;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x56x64 TN 
F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x56x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[14]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n56k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13}," + " %14," + " %15," + " %16, %17," + " p, %19, %20;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x56x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x56x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x56x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[14]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n56k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13}," + "{%14, %15, %16, %17}," + " %18," + " %19, %20," + " p, %22, %23;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x56x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
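A related note: within each shape, the _SS_TN and _RS_TN variants above differ only in how the A operand is supplied. SS takes a 64-bit shared-memory matrix descriptor (ARegisters = uint64_t[1], the desc_a argument), while RS takes the A fragment in four 32-bit registers (ARegisters = uint32_t[4], a00..a03); B is always a descriptor and the sparsity metadata is always a single uint32_t (ERegisters = uint32_t[1]). For the F16 variants, each 32-bit accumulator register carries two packed fp16 values, which is why their register counts are half those of the F32 variants. The small host-side sketch below (not CUTLASS code; it assumes a little-endian host and uses illustrative names) shows that packing:

#include <cstdint>
#include <cstring>
#include <cstdio>

// One 32-bit F16-accumulator register as exposed by the F16 wrappers above:
// two fp16 bit patterns packed into a single uint32_t.
struct HalfPair { uint16_t lo, hi; };  // lo = low halfword on little-endian hosts

static uint32_t pack(HalfPair p)   { uint32_t r; std::memcpy(&r, &p, 4); return r; }
static HalfPair unpack(uint32_t r) { HalfPair p; std::memcpy(&p, &r, 4); return p; }

int main() {
  // 0x3C00 is 1.0 and 0xC000 is -2.0 as IEEE fp16 bit patterns.
  uint32_t reg = pack({0x3C00, 0xC000});
  HalfPair  p  = unpack(reg);
  std::printf("reg=0x%08X  lo=0x%04X  hi=0x%04X\n",
              (unsigned)reg, (unsigned)p.lo, (unsigned)p.hi);
  return 0;
}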
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x56x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x56x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[28]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %32, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n56k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + " %28," + " %29," + " %30, %31," + " p, %33, %34;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x56x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x56x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x56x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[28]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %35, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n56k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, 
%14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + "{%28, %29, %30, %31}," + " %32," + " %33, %34," + " p, %36, %37;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x56x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x72x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x72x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[18]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %22, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n72k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17}," + " %18," + " %19," + " %20, %21," + " p, %23, %24;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x72x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x72x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x72x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[18]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + 
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %25, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n72k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17}," + "{%18, %19, %20, %21}," + " %22," + " %23, %24," + " p, %26, %27;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x72x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x72x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x72x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[36]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %40, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n72k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + " %36," + " %37," + " %38, %39," + " p, %41, %42;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x72x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
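+//
+// [Editorial note - not part of the upstream CUTLASS source] The accumulator-register counts above
+// follow directly from the 64xN tile shape: each of the 128 threads in the warpgroup owns N/2
+// accumulator elements, so the F32 atoms declare CRegisters = float[N/2] (e.g. 36 for N = 72) and
+// the F16 atoms pack two halves per register, giving CRegisters = uint32_t[N/4]. The runtime
+// scale_D argument is lowered to the predicate 'p' via "setp.ne.b32", so passing ScaleOut::Zero
+// makes the wgmma.mma_async.sp instruction overwrite D with A*B instead of accumulating into it.
+//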
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x72x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x72x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[36]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %43, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n72k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + "{%36, %37, %38, %39}," + " %40," + " %41, %42," + " p, %44, %45;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x72x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[20]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %24, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + " %20," + " %21," + " %22, %23," + " p, %25, %26;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[20]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %27, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + "{%20, %21, %22, %23}," + " %24," + " %25, %26," + " p, %28, %29;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & 
d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p, %45, %46;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, 
%26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p, %48, %49;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x88x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x88x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[22]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %26, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n88k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21}," + " %22," + " %23," + " %24, %25," + " p, %27, %28;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x88x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x88x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x88x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[22]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, 
uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %29, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n88k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21}," + "{%22, %23, %24, %25}," + " %26," + " %27, %28," + " p, %30, %31;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x88x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x88x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x88x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[44]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %48, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n88k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + " %44," + " %45," + " %46, %47," + " p, %49, %50;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), 
"+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x88x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x88x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x88x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[44]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %51, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n88k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + "{%44, %45, %46, %47}," + " %48," + " %49, %50," + " p, %52, %53;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x88x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x104x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = 
GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x104x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[26]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %30, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n104k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25}," + " %26," + " %27," + " %28, %29," + " p, %31, %32;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x104x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x104x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x104x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[26]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %33, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n104k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25}," + "{%26, %27, %28, %29}," + " %30," + " %31, %32," + " p, %34, %35;\n" 
+ "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x104x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x104x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x104x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[52]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %56, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n104k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + " %52," + " %53," + " %54, %55," + " p, %57, %58;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x104x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x104x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x104x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[52]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %59, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n104k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + "{%52, %53, %54, %55}," + " %56," + " %57, %58," + " p, %60, %61;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x104x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[28]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t 
const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %32, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + " %28," + " %29," + " %30, %31," + " p, %33, %34;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[28]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %35, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + "{%28, %29, %30, %31}," + " %32," + " %33, %34," + " p, %36, %37;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + 
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p, %61, %62;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN 
F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p, %64, %65;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x120x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x120x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[30]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& 
desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n120k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29}," + " %30," + " %31," + " %32, %33," + " p, %35, %36;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x120x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x120x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x120x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[30]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n120k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29}," + "{%30, %31, %32, %33}," + " %34," + " %35, %36," + " p, %38, %39;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), 
"+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x120x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x120x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x120x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[60]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %64, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n120k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + " %60," + " %61," + " %62, %63," + " p, %65, %66;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use 
SM90::GMMA::SPARSE::GMMA_64x120x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x120x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x120x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[60]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %67, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n120k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + "{%60, %61, %62, %63}," + " %64," + " %65, %66," + " p, %68, %69;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x120x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x136x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB 
= GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x136x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[34]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %38, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n136k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33}," + " %34," + " %35," + " %36, %37," + " p, %39, %40;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x136x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x136x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x136x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[34]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %41, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n136k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33}," + "{%34, %35, %36, %37}," + " %38," + " %39, %40," + " p, %42, %43;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x136x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x136x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x136x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %72, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n136k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + " %68," + " %69," + " %70, %71," + " p, %73, %74;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), 
"+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x136x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x136x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x136x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %75, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n136k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + "{%68, %69, %70, %71}," + " %72," + " %73, %74," + " p, %76, %77;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), 
"+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x136x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[36]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %40, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + " %36," + " %37," + " %38, %39," + " p, %41, %42;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN F16+=E5M2*E4M3 
+template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[36]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %43, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + "{%36, %37, %38, %39}," + " %40," + " %41, %42," + " p, %44, %45;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, 
float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %76, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " %74, %75," + " p, %77, %78;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float 
& d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %79, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " %77, %78," + " p, %80, %81;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x152x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x152x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[38]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, 
uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %42, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n152k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37}," + " %38," + " %39," + " %40, %41," + " p, %43, %44;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x152x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x152x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x152x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[38]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %45, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n152k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37}," + "{%38, %39, %40, %41}," + " %42," + " %43, %44," + " p, %46, %47;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + 
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x152x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x152x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x152x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[76]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %80, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n152k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75}," + " %76," + " %77," + " %78, %79," + " p, %81, %82;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), 
"+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x152x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x152x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x152x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[76]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %83, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n152k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75}," + "{%76, %77, %78, %79}," + " %80," + " %81, %82," + " p, %84, %85;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), 
"+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x152x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p, %45, %46;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn 
scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p, %48, %49;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & 
d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %84, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " %82, %83," + " p, %85, %86;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & 
d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %87, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " %85, %86," + " p, %88, %89;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x168x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x168x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[42]; + + CUTE_HOST_DEVICE static void + 
fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %46, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n168k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41}," + " %42," + " %43," + " %44, %45," + " p, %47, %48;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x168x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x168x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x168x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[42]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, + uint32_t const& e, + GMMA::ScaleOut const 
scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %49, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n168k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41}," + "{%42, %43, %44, %45}," + " %46," + " %47, %48," + " p, %50, %51;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x168x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x168x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x168x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[84]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %88, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n168k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, 
%18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83}," + " %84," + " %85," + " %86, %87," + " p, %89, %90;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x168x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x168x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x168x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[84]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + uint32_t const& e, + GMMA::ScaleOut const scale_D = 
GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %91, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n168k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83}," + "{%84, %85, %86, %87}," + " %88," + " %89, %90," + " p, %92, %93;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x168x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[44]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if 
defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %48, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + " %44," + " %45," + " %46, %47," + " p, %49, %50;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[44]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %51, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + "{%44, %45, %46, %47}," + " %48," + " %49, %50," + " p, %52, %53;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), 
"+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %92, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " %90, %91," + " p, %93, %94;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), 
"+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %95, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, 
%52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " %93, %94," + " p, %96, %97;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x184x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x184x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[46]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n184k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, 
%44, %45}," + " %46," + " %47," + " %48, %49," + " p, %51, %52;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x184x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x184x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x184x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[46]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n184k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45}," + "{%46, %47, %48, %49}," + " %50," + " %51, %52," + " p, %54, %55;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + 
"l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x184x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x184x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x184x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %96, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n184k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91}," + " %92," + " %93," + " %94, %95," + " p, %97, %98;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), 
"+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x184x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x184x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x184x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %99, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n184k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " 
%88, %89, %90, %91}," + "{%92, %93, %94, %95}," + " %96," + " %97, %98," + " p, %100, %101;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x184x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x200x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x200x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[50]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %54, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n200k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49}," + " %50," + " %51," + " %52, 
%53," + " p, %55, %56;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x200x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x200x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x200x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[50]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %57, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n200k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49}," + "{%50, %51, %52, %53}," + " %54," + " %55, %56," + " p, %58, %59;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), 
"+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x200x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x200x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x200x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[100]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %104, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n200k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99}," + " %100," + " %101," + " %102, %103," + " p, %105, %106;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), 
"+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x200x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x200x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x200x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[100]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + uint32_t 
const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %107, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n200k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99}," + "{%100, %101, %102, %103}," + " %104," + " %105, %106," + " p, %108, %109;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x200x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[52]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & 
d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %56, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + " %52," + " %53," + " %54, %55," + " p, %57, %58;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[52]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + 
uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %59, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + "{%52, %53, %54, %55}," + " %56," + " %57, %58," + " p, %60, %61;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & 
d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %108, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " %106, %107," + " p, %109, %110;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & 
d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %111, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " %109, %110," + " p, %112, %113;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), 
"+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x216x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x216x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[54]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %58, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n216k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53}," + " %54," + " %55," + " %56, %57," + " p, %59, %60;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x216x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x216x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x216x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[54]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %61, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n216k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53}," + "{%54, %55, %56, %57}," + " %58," + " %59, %60," + " p, %62, %63;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x216x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x216x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x216x64_F32E5M2E4M3_SS_TN +{ + using 
DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[108]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %112, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n216k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107}," + " %108," + " %109," + " %110, %111," + " p, %113, %114;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), 
"+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x216x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x216x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x216x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[108]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %115, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n216k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, 
%15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107}," + "{%108, %109, %110, %111}," + " %112," + " %113, %114," + " p, %116, %117;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x216x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & 
d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p, %61, %62;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut 
const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p, %64, %65;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & 
d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %116, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " %114, %115," + " p, %117, %118;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = 
float[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %119, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " %117, %118," + " p, %120, %121;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + 
"+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x232x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x232x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[58]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %62, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n232k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57}," + " %58," + " %59," + " %60, %61," + " p, %63, %64;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), 
"+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x232x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x232x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x232x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[58]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %65, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n232k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57}," + "{%58, %59, %60, %61}," + " %62," + " %63, %64," + " p, %66, %67;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), 
"+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x232x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x232x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x232x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %120, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n232k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, 
%78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + " %116," + " %117," + " %118, %119," + " p, %121, %122;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x232x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x232x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x232x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & 
d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %123, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n232k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + "{%116, %117, %118, %119}," + " %120," + " %121, %122," + " p, %124, %125;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + 
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x232x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[60]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %64, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + " %60," + " %61," + " %62, %63," + " p, %65, %66;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[60]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %67, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + "{%60, %61, %62, %63}," + " %64," + " %65, %66," + " p, %68, %69;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN F32+=E5M2*E4M3 
+template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %124, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " %122, %123," + " p, %125, %126;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), 
"+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & 
d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %127, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " %125, %126," + " p, %128, %129;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x248x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = 
GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x248x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[62]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n248k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61}," + " %62," + " %63," + " %64, %65," + " p, %67, %68;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x248x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x248x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x248x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + 
using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[62]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n248k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61}," + "{%62, %63, %64, %65}," + " %66," + " %67, %68," + " p, %70, %71;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x248x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x248x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x248x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = 
float[124]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %128, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n248k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + " %124," + " %125," + " %126, %127," + " p, %129, %130;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), 
"+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x248x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x248x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x248x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[124]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & 
d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %131, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n248k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + "{%124, %125, %126, %127}," + " %128," + " %129, %130," + " p, %132, %133;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x248x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero 
+> +struct GMMA_64x24x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[6]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5}," + " %6," + " %7," + " %8, %9," + " p, %11, %12;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[6]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5}," + "{%6, %7, %8, %9}," + " %10," + " %11, %12," + " p, %14, %15;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + 
{ +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %16, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " %14, %15," + " p, %17, %18;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %19, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " %17, %18," + " p, %20, %21;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x40x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x40x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[10]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + 
"setp.ne.b32 p, %14, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n40k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9}," + " %10," + " %11," + " %12, %13," + " p, %15, %16;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x40x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x40x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x40x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[10]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %17, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n40k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9}," + "{%10, %11, %12, %13}," + " %14," + " %15, %16," + " p, %18, %19;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x40x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x40x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x40x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[20]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %24, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n40k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, 
" + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + " %20," + " %21," + " %22, %23," + " p, %25, %26;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x40x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x40x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x40x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[20]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %27, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n40k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + "{%20, %21, %22, %23}," + " %24," + " %25, %26," + " p, %28, %29;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x40x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if 
defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %16, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " %14, %15," + " p, %17, %18;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %19, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " %17, %18," + " p, %20, %21;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = 
GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p, %29, %30;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p, %32, %33;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x56x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x56x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using 
ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[14]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n56k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13}," + " %14," + " %15," + " %16, %17," + " p, %19, %20;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x56x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x56x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x56x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[14]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n56k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13}," + "{%14, %15, %16, %17}," + " %18," + " %19, %20," + " p, %22, %23;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x56x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x56x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x56x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + 
using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[28]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %32, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n56k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + " %28," + " %29," + " %30, %31," + " p, %33, %34;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x56x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x56x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x56x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[28]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %35, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n56k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + "{%28, %29, %30, %31}," + " %32," + " %33, %34," + " p, %36, %37;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), 
"+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x56x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x72x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x72x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[18]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %22, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n72k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17}," + " %18," + " %19," + " %20, %21," + " p, %23, %24;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x72x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x72x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x72x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[18]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %25, 0;\n" + 
"wgmma.mma_async.sp.sync.aligned.m64n72k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17}," + "{%18, %19, %20, %21}," + " %22," + " %23, %24," + " p, %26, %27;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x72x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x72x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x72x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[36]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %40, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n72k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + " %36," + " %37," + " %38, %39," + " p, %41, %42;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x72x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x72x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x72x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + 
using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[36]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %43, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n72k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + "{%36, %37, %38, %39}," + " %40," + " %41, %42," + " p, %44, %45;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x72x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[20]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %24, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + " %20," + " %21," + " %22, %23," + " p, %25, %26;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), 
"+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[20]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %27, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + "{%20, %21, %22, %23}," + " %24," + " %25, %26," + " p, %28, %29;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, 
float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p, %45, %46;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p, %48, %49;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + 
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x88x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x88x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[22]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %26, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n88k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21}," + " %22," + " %23," + " %24, %25," + " p, %27, %28;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x88x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x88x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x88x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[22]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, + uint32_t 
const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %29, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n88k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21}," + "{%22, %23, %24, %25}," + " %26," + " %27, %28," + " p, %30, %31;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x88x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x88x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x88x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[44]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %48, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n88k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + " %44," + " %45," + " %46, %47," + " p, %49, %50;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else 
+ CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x88x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x88x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x88x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[44]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %51, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n88k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + "{%44, %45, %46, %47}," + " %48," + " %49, %50," + " p, %52, %53;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x88x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x104x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x104x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[26]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & 
d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %30, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n104k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25}," + " %26," + " %27," + " %28, %29," + " p, %31, %32;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x104x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x104x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x104x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[26]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %33, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n104k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25}," + "{%26, %27, %28, %29}," + " %30," + " %31, %32," + " p, %34, %35;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), 
"n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x104x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x104x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x104x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[52]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %56, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n104k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + " %52," + " %53," + " %54, %55," + " p, %57, %58;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x104x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x104x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x104x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using 
CRegisters = float[52]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %59, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n104k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + "{%52, %53, %54, %55}," + " %56," + " %57, %58," + " p, %60, %61;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x104x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[28]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + 
uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %32, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + " %28," + " %29," + " %30, %31," + " p, %33, %34;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[28]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %35, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + "{%28, %29, %30, %31}," + " %32," + " %33, %34," + " p, %36, %37;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
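+// Illustrative usage sketch (hypothetical, not part of the generated atom list
+// above): each *_SS_TN atom consumes two 64-bit shared-memory matrix
+// descriptors (desc_a, desc_b), the accumulator fragment passed by reference,
+// a 32-bit sparse-metadata word `e` whose half is chosen by the SparseSel
+// template argument, and a runtime ScaleOut flag that the embedded setp/p
+// predicate uses to either keep (One) or zero (Zero) the existing accumulator.
+// The *_RS_TN variants differ only in sourcing the A operand from four 32-bit
+// register fragments instead of a descriptor. A minimal direct call to the
+// F16 atom just defined might look like the commented sketch below; in real
+// kernels these atoms are wrapped in cute::MMA_Atom and driven by cute::gemm
+// inside a wgmma fence / commit_group / wait_group sequence.
+//
+//   using Atom = SM90::GMMA::SPARSE::GMMA_64x112x64_F16E5M2E5M2_RS_TN<>;
+//   uint32_t a[4];        // A fragment: four 32-bit registers of packed E5M2 values
+//   uint64_t desc_b;      // B shared-memory descriptor
+//   uint32_t d[28] = {};  // 28 packed f16x2 accumulator registers
+//   uint32_t e;           // 2:4 structured-sparsity metadata word
+//   Atom::fma(a[0], a[1], a[2], a[3], desc_b,
+//             d[0], d[1], /* ..., */ d[27],
+//             e, GMMA::ScaleOut::One);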
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p, %61, %62;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t 
const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p, %64, %65;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x120x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x120x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[30]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, 
uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n120k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29}," + " %30," + " %31," + " %32, %33," + " p, %35, %36;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x120x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x120x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x120x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[30]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n120k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29}," + "{%30, %31, %32, %33}," + " %34," + " %35, %36," + " p, %38, %39;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting 
to use SM90::GMMA::SPARSE::GMMA_64x120x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x120x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x120x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[60]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %64, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n120k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + " %60," + " %61," + " %62, %63," + " p, %65, %66;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x120x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x120x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct 
GMMA_64x120x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[60]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %67, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n120k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + "{%60, %61, %62, %63}," + " %64," + " %65, %66," + " p, %68, %69;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x120x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x136x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x136x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[34]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & 
d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %38, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n136k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33}," + " %34," + " %35," + " %36, %37," + " p, %39, %40;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x136x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x136x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x136x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[34]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %41, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n136k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33}," + "{%34, %35, %36, %37}," + " %38," 
+ " %39, %40," + " p, %42, %43;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x136x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x136x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x136x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %72, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n136k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + " %68," + " %69," + " %70, %71," + " p, %73, %74;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), 
"+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x136x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x136x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x136x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %75, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n136k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + "{%68, %69, %70, %71}," + " %72," + " %73, %74," + " p, %76, %77;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), 
"+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x136x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[36]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %40, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + " %36," + " %37," + " %38, %39," + " p, %41, %42;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[36]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& 
a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %43, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + "{%36, %37, %38, %39}," + " %40," + " %41, %42," + " p, %44, %45;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, 
float & d69, float & d70, float & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %76, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " %74, %75," + " p, %77, %78;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float 
& d69, float & d70, float & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %79, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " %77, %78," + " p, %80, %81;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x152x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x152x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[38]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %42, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n152k64.f16.e5m2.e5m2 " + "{%0, 
%1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37}," + " %38," + " %39," + " %40, %41," + " p, %43, %44;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x152x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x152x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x152x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[38]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %45, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n152k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37}," + "{%38, %39, %40, %41}," + " %42," + " %43, %44," + " p, %46, %47;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use 
SM90::GMMA::SPARSE::GMMA_64x152x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x152x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x152x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[76]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %80, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n152k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75}," + " %76," + " %77," + " %78, %79," + " p, %81, %82;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), 
"n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x152x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x152x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x152x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[76]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %83, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n152k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75}," + "{%76, %77, %78, %79}," + " %80," + " %81, %82," + " p, %84, %85;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), 
"+f"(d74), "+f"(d75) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x152x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p, %45, %46;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t 
const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p, %48, %49;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + 
float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %84, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " %82, %83," + " p, %85, %86;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, 
float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %87, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " %85, %86," + " p, %88, %89;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x168x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x168x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[42]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, 
uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %46, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n168k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41}," + " %42," + " %43," + " %44, %45," + " p, %47, %48;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x168x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x168x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x168x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[42]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %49, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n168k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, 
" + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41}," + "{%42, %43, %44, %45}," + " %46," + " %47, %48," + " p, %50, %51;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x168x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x168x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x168x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[84]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %88, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n168k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83}," + " %84," + " %85," + 
" %86, %87," + " p, %89, %90;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x168x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x168x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x168x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[84]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %91, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n168k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, 
%25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83}," + "{%84, %85, %86, %87}," + " %88," + " %89, %90," + " p, %92, %93;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x168x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[44]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %48, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, 
%35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + " %44," + " %45," + " %46, %47," + " p, %49, %50;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[44]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %51, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + "{%44, %45, %46, %47}," + " %48," + " %49, %50," + " p, %52, %53;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + 
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %92, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " %90, %91," + " p, %93, %94;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + 
"+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %95, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " %93, %94," + " p, %96, %97;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), 
"+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x184x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x184x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[46]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n184k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45}," + " %46," + " %47," + " %48, %49," + " p, %51, %52;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), 
"+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x184x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x184x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x184x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[46]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n184k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45}," + "{%46, %47, %48, %49}," + " %50," + " %51, %52," + " p, %54, %55;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x184x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x184x64 TN F32+=E5M2*E5M2 +template 
< + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x184x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %96, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n184k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91}," + " %92," + " %93," + " %94, %95," + " p, %97, %98;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), 
"+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x184x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x184x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x184x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %99, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n184k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91}," + "{%92, %93, %94, %95}," + " %96," + " %97, %98," + " p, %100, %101;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + 
"+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x184x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x200x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x200x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[50]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %54, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n200k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49}," + " %50," + " %51," + " %52, %53," + " p, %55, %56;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), 
"+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x200x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x200x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x200x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[50]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %57, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n200k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49}," + "{%50, %51, %52, %53}," + " %54," + " %55, %56," + " p, %58, %59;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x200x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x200x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x200x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[100]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %104, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n200k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99}," + " %100," + " %101," + " %102, %103," + " p, %105, %106;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), 
"+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x200x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x200x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x200x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[100]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %107, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n200k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, 
%22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99}," + "{%100, %101, %102, %103}," + " %104," + " %105, %106," + " p, %108, %109;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x200x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[52]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & 
d50, uint32_t & d51, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %56, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + " %52," + " %53," + " %54, %55," + " p, %57, %58;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[52]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %59, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, 
%22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + "{%52, %53, %54, %55}," + " %56," + " %57, %58," + " p, %60, %61;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %108, 0;\n" + 
"wgmma.mma_async.sp.sync.aligned.m64n208k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " %106, %107," + " p, %109, %110;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & 
d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %111, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " %109, %110," + " p, %112, %113;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif 
+ } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x216x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x216x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[54]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %58, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n216k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53}," + " %54," + " %55," + " %56, %57," + " p, %59, %60;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x216x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x216x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x216x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; 
+ using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[54]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %61, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n216k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53}," + "{%54, %55, %56, %57}," + " %58," + " %59, %60," + " p, %62, %63;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x216x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x216x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x216x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[108]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & 
d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %112, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n216k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107}," + " %108," + " %109," + " %110, %111," + " p, %113, %114;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), 
"+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x216x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x216x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x216x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[108]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %115, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n216k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, 
%93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107}," + "{%108, %109, %110, %111}," + " %112," + " %113, %114," + " p, %116, %117;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x216x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p, %61, %62;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + 
" %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p, %64, %65;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if 
defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %116, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " %114, %115," + " p, %117, %118;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + 
float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %119, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " %117, %118," + " p, %120, %121;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + 
"+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x232x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x232x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[58]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %62, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n232k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57}," + " %58," + " %59," + " %60, %61," + " p, %63, %64;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + 
"+r"(d56), "+r"(d57) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x232x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x232x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x232x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[58]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %65, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n232k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57}," + "{%58, %59, %60, %61}," + " %62," + " %63, %64," + " p, %66, %67;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use 
SM90::GMMA::SPARSE::GMMA_64x232x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x232x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x232x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %120, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n232k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + " %116," + " %117," + " %118, %119," + " p, %121, %122;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), 
"+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x232x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x232x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x232x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, 
float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %123, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n232k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + "{%116, %117, %118, %119}," + " %120," + " %121, %122," + " p, %124, %125;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x232x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = 
GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[60]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %64, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + " %60," + " %61," + " %62, %63," + " p, %65, %66;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + 
using CRegisters = uint32_t[60]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %67, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + "{%60, %61, %62, %63}," + " %64," + " %65, %66," + " p, %68, %69;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, 
float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %124, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " %122, %123," + " p, %125, %126;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), 
"+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if 
defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %127, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " %125, %126," + " p, %128, %129;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x248x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x248x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[62]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, 
uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n248k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61}," + " %62," + " %63," + " %64, %65," + " p, %67, %68;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x248x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x248x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x248x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[62]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, 
uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n248k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61}," + "{%62, %63, %64, %65}," + " %66," + " %67, %68," + " p, %70, %71;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x248x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x248x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x248x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[124]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, 
float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %128, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n248k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + " %124," + " %125," + " %126, %127," + " p, %129, %130;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), 
"+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x248x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x248x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x248x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[124]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + 
"setp.ne.b32 p, %131, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n248k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + "{%124, %125, %126, %127}," + " %128," + " %129, %130," + " p, %132, %133;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x248x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace SM90::GMMA::SPARSE + +} // namespace cute diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/arch/simd_sm100.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/arch/simd_sm100.hpp new file mode 100644 index 0000000000000000000000000000000000000000..1c07a31e6d3134e1c0437c438a5183081f50815d --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/arch/simd_sm100.hpp @@ -0,0 +1,96 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA 
CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +// + +// + +#pragma once + +#include +#include +#include + +namespace cute { + +CUTE_HOST_DEVICE +void +add(float2 & c, + float2 const& a, + float2 const& b) +{ +#if defined(CUTE_ARCH_FLOAT2_MATH_ENABLED) + asm volatile("add.f32x2 %0, %1, %2;\n" + : "=l"(reinterpret_cast(c)) + : "l"(reinterpret_cast(a)), + "l"(reinterpret_cast(b))); +#else + add(c.x, a.x, b.x); + add(c.y, a.y, b.y); +#endif +} + +CUTE_HOST_DEVICE +void +mul(float2 & c, + float2 const& a, + float2 const& b) +{ +#if defined(CUTE_ARCH_FLOAT2_MATH_ENABLED) + asm volatile("mul.f32x2 %0, %1, %2;\n" + : "=l"(reinterpret_cast(c)) + : "l"(reinterpret_cast(a)), + "l"(reinterpret_cast(b))); +#else + mul(c.x, a.x, b.x); + mul(c.y, a.y, b.y); +#endif +} + +CUTE_HOST_DEVICE +void +fma(float2 & d, + float2 const& a, + float2 const& b, + float2 const& c) +{ +#if defined(CUTE_ARCH_FLOAT2_MATH_ENABLED) + asm volatile("fma.rn.f32x2 %0, %1, %2, %3;\n" + : "=l"(reinterpret_cast(d)) + : "l"(reinterpret_cast(a)), + "l"(reinterpret_cast(b)), + "l"(reinterpret_cast(c))); +#else + fma(d.x, a.x, b.x, c.x); + fma(d.y, a.y, b.y, c.y); +#endif +} + +} // namespace cute diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/arch/tmem_allocator_sm100.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/arch/tmem_allocator_sm100.hpp new file mode 100644 index 0000000000000000000000000000000000000000..680e237f76d158846236c3317668b150fe162098 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/arch/tmem_allocator_sm100.hpp @@ -0,0 +1,183 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include +#include +#include +#include + +namespace cute::TMEM { + +// +// TMEM Addressing Constants +// + +// 128 DP x 512 COL x uint32_t-addressing +using MAX_CAPACITY_BITS = Int<128*512*32>; + +// TMEM DP stride in bit-addressing (shift by 5 for conversion from uint32_t) +using DP_b = cute::constant; + +// TMEM DP stride in type-T addressing +template +using DP = cute::constant::OffsetShift)>; + +// +// TMEM Allocators +// + +// All operations of this class require that only a single warp uniformly participates +class Allocator1Sm { +public: + static constexpr int ColumnsPerAllocationSlice = 32; + static constexpr int Sm100TmemCapacityColumns = 512; + + __device__ Allocator1Sm() { } + + /** + * Performs a non-blocking allocation of TMEM. + * @param num_columns Number of columns being freed. Must be 32 <= num_columns <= 512 and power of 2. + * @param dst_ptr Pointer to shared memory to which to write the result tmem pointer to. + * @pre Must be issued by a single fully active warp of the CTA. + * @pre Must never be issued by more than one warp at the same time. + * @pre For repeated allocations, the same warp must be used to issue all allocations. 
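+   * @note A minimal usage sketch (illustrative only; `elect_one_warp` and
+   *       `smem_tmem_base` are assumed names, not part of this header):
+   * @code
+   *   __shared__ uint32_t smem_tmem_base;
+   *   cute::TMEM::Allocator1Sm tmem;
+   *   if (elect_one_warp) {                        // one fixed warp of the CTA
+   *     tmem.allocate(128, &smem_tmem_base);       // 128 columns: power of 2 in [32, 512]
+   *     tmem.release_allocation_lock();            // no further allocations will follow
+   *   }
+   *   __syncthreads();
+   *   uint32_t tmem_ptr = smem_tmem_base;          // base address of the allocated columns
+   *   // ... issue tcgen05 MMAs / loads / stores against tmem_ptr ...
+   *   if (elect_one_warp) { tmem.free(tmem_ptr, 128); }
+   * @endcode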
+ **/ + CUTE_HOST_DEVICE void + allocate(int num_columns, uint32_t* dst_ptr) { + #if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + uint32_t dst_intptr = cute::cast_smem_ptr_to_uint(dst_ptr); + asm volatile( + "tcgen05.alloc.cta_group::1.sync.aligned.shared::cta.b32 [%0], %1;" + : + : "r"(dst_intptr), "r"(num_columns)); + #else + CUTE_INVALID_CONTROL_PATH("Attempting to use TMEM allocation PTX without CUTE_ARCH_TCGEN05_TMEM_ENABLED"); + #endif + } + + __device__ + void + free(uint32_t tmem_ptr, int num_columns) { + #if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile( + "{\n\t" + "tcgen05.dealloc.cta_group::1.sync.aligned.b32 %0, %1; \n\t" + "}" + : + : "r"(tmem_ptr), "r"(num_columns)); + #else + CUTE_INVALID_CONTROL_PATH("Attempting to use TMEM allocation PTX without CUTE_ARCH_TCGEN05_TMEM_ENABLED"); + #endif + } + + __device__ void + release_allocation_lock() { + #if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile("tcgen05.relinquish_alloc_permit.cta_group::1.sync.aligned;" ::); + #else + CUTE_INVALID_CONTROL_PATH("Attempting to use TMEM allocation PTX without CUTE_ARCH_TCGEN05_TMEM_ENABLED"); + #endif + } +}; + +/////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +class Allocator2Sm { +public: + static constexpr int ColumnsPerAllocationSlice = 32; + static constexpr int Sm100TmemCapacityColumns = 512; + + __device__ Allocator2Sm() { } + + /** + * Performs a non-blocking allocation of TMEM. + * @param num_columns Number of columns being freed. Must be 32 <= num_columns <= 512 and power of 2. + * @param dst_ptr Pointer to shared memory to which to write the result tmem pointer to. + * Both CTAs _must_ provide the exact same dst_ptr for correctness. + * @pre Must be issued by a single fully active warp of the CTA. + * @pre Must never be issued by more than one warp at the same time. + * @pre For repeated allocations, the same warp must be used to issue all allocations. + * @pre The 2 warps from participating CTAs have the same logical warp ID. + **/ + CUTE_HOST_DEVICE void + allocate(int num_columns, uint32_t* dst_ptr) { + #if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + uint32_t dst_intptr = cute::cast_smem_ptr_to_uint(dst_ptr); + asm volatile( + "tcgen05.alloc.cta_group::2.sync.aligned.shared::cta.b32 [%0], %1;" + : + : "r"(dst_intptr), "r"(num_columns)); + #else + CUTE_INVALID_CONTROL_PATH("Attempting to use TMEM allocation PTX without CUTE_ARCH_TCGEN05_TMEM_ENABLED"); + #endif + } + + /** + * Frees the TMEM corresponding to the pointer and slice count provided. + * Release the TMEM after checking that the CTA issuing the free does indeed own the corresponding slices. + * @param tmem_ptr Base address of the TMEM address space being freed. + * @param num_columns Number of columns being freed. Must be 32 <= num_columns <= 512 and power of 2. + * @pre Must be issued by a single fully active warp of the CTA. + * @pre Must never be issued by more than one warp at the same time. + * @pre The 2 warps from participating CTAs have the same logical warp ID. 
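+   * @note Pairs with a preceding allocate(); a minimal sketch is below (illustrative
+   *       only; `smem_tmem_base` is an assumed name holding the base written by allocate()):
+   * @code
+   *   constexpr int cols = cute::TMEM::Allocator2Sm::Sm100TmemCapacityColumns;  // 512
+   *   cute::TMEM::Allocator2Sm tmem;
+   *   // ... earlier: tmem.allocate(cols, &smem_tmem_base); ...
+   *   tmem.free(smem_tmem_base, cols);   // issued from the same logical warp in both CTAs
+   * @endcode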
+ * @returns true + **/ + __device__ + void + free(uint32_t tmem_ptr, int num_columns) { + #if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile( + "{\n\t" + "tcgen05.dealloc.cta_group::2.sync.aligned.b32 %0, %1; \n\t" + "}" + : + : "r"(tmem_ptr), "r"(num_columns)); + #else + CUTE_INVALID_CONTROL_PATH("Attempting to use TMEM allocation PTX without CUTE_ARCH_TCGEN05_TMEM_ENABLED"); + #endif + } + + __device__ + void + release_allocation_lock() { + #if defined(CUTE_ARCH_TCGEN05_TMEM_ENABLED) + asm volatile("tcgen05.relinquish_alloc_permit.cta_group::2.sync.aligned;" ::); + #else + CUTE_INVALID_CONTROL_PATH("Attempting to use TMEM allocation PTX without CUTE_ARCH_TCGEN05_TMEM_ENABLED"); + #endif + } +}; + +} // namespace cute::TMEM diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/arch/util.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/arch/util.hpp new file mode 100644 index 0000000000000000000000000000000000000000..6a3883e85a45b41b36bcbdf1784668f93892ec2c --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/arch/util.hpp @@ -0,0 +1,320 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include +#include + +#if defined(__clang__) && defined(__CUDA__) + // __cvta_generic_to_shared was added in Clang 14: https://reviews.llvm.org/D111665 + #if __clang_major__ >= 14 + #define CUTE_CLANG_SUPPORTS_CVTA_GENERIC_TO_SHARED 1 + #endif + + // __nvvm_get_smem_pointer added in Clang 14: https://reviews.llvm.org/D111665 + // ... 
but will not work on Windows until Clang 15: https://reviews.llvm.org/D122897 + #if (!defined(_WIN32) && __clang_major__ >= 14) || __clang_major__ >= 15 + #define CUTE_CLANG_SUPPORTS_NVVM_GET_SMEM_POINTER 1 + #endif +#endif + +#if defined(__NVCC__) || defined(__CUDACC_RTC__) + // __cvta_generic_to_shared added in CUDA 11+ + #if __CUDACC_VER_MAJOR__ >= 11 + #define CUTE_NVCC_SUPPORTS_CVTA_GENERIC_TO_SHARED 1 + #endif + + // __nvvm_get_smem_pointer added in CUDA 10.2 + #if __CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2 + #define CUTE_NVCC_SUPPORTS_NVVM_GET_SMEM_POINTER 1 + #endif +#endif + +#if CUTE_NVCC_SUPPORTS_CVTA_GENERIC_TO_SHARED || CUTE_CLANG_SUPPORTS_CVTA_GENERIC_TO_SHARED + #define CUTE_CVTA_GENERIC_TO_SHARED_SUPPORTED 1 +#endif + +#if !defined(CUTE_CVTA_GENERIC_TO_SHARED_ACTIVATED) && CUTE_CVTA_GENERIC_TO_SHARED_SUPPORTED && defined(__CUDA_ARCH__) + #define CUTE_CVTA_GENERIC_TO_SHARED_ACTIVATED 1 +#endif + +#if CUTE_NVCC_SUPPORTS_NVVM_GET_SMEM_POINTER || CUTE_CLANG_SUPPORTS_NVVM_GET_SMEM_POINTER + #define CUTE_NVVM_GET_SMEM_POINTER_SUPPORTED 1 +#endif + +#if !defined(CUTE_NVVM_GET_SMEM_POINTER_ACTIVATED) && CUTE_NVVM_GET_SMEM_POINTER_SUPPORTED && defined(__CUDA_ARCH__) + #define CUTE_NVVM_GET_SMEM_POINTER_ACTIVATED 1 +#endif + +// Clang 14+ provides a declaration of __nvvm_get_smem_pointer, so we only need +// to provide one for NVCC +#if CUTE_NVCC_SUPPORTS_NVVM_GET_SMEM_POINTER + extern "C" { + // This NVVM intrinsic is subject to change in future versions of CUDA. + // Clients should not call it directly. + CUTE_DEVICE uint32_t __nvvm_get_smem_pointer(void*); + } +#endif + +namespace cute +{ + +/// CUTE helper to cast SMEM pointer to unsigned +CUTE_HOST_DEVICE +uint32_t +cast_smem_ptr_to_uint(void const* const ptr) +{ +// We prefer to use the new CVTA intrinsics if they are available, otherwise we will fall back to +// the previous internal intrinsics if they are available. +#if CUTE_CVTA_GENERIC_TO_SHARED_ACTIVATED + // + // This NVVM intrinsic converts an address in shared memory to a plain + // unsigned integer. This is necessary to pass to shared memory instructions + // in inline PTX. + // + // In CUDA 11 and beyond, this replaces __nvvm_get_smem_pointer() [only available in 10.2]. + // + //__device__ size_t __cvta_generic_to_shared(void* ptr); + + /// CUTE helper to get SMEM pointer + return static_cast(__cvta_generic_to_shared(ptr)); + +#elif CUTE_NVVM_GET_SMEM_POINTER_ACTIVATED + + return __nvvm_get_smem_pointer(ptr); + +#elif defined(__CUDA_ARCH__) + + uint32_t smem_ptr; + + asm( + "{ .reg .u64 smem_ptr; cvta.to.shared.u64 smem_ptr, %1; cvt.u32.u64 %0, smem_ptr; }\n" + : "=r"(smem_ptr) : "l"(ptr)); + + return smem_ptr; + +#else + + + (void) ptr; + printf("ERROR: cast_smem_ptr_to_uint not supported but used.\n"); + return 0; + +#endif +} + +namespace detail { + +// +// Wrapper for MMAOp::fma +// + +template +struct CallFMA { + template + CUTE_HOST_DEVICE constexpr void + operator()(Args&&... args) const { + return MmaOp::fma(static_cast(args)...); + } +}; + +// +// Wrapper for CopyOp::copy +// + +template +struct CallCOPY { + template + CUTE_HOST_DEVICE constexpr void + operator()(Args&&... 
args) const { + return CopyOp::copy(static_cast(args)...); + } +}; + +// +// Utility for exploding pointers/arrays/tensors into functions +// + +template +CUTE_HOST_DEVICE constexpr +void +explode(Fn fn, + PtrA&& a, int_sequence) +{ + return fn(a[I]...); +} + +template +CUTE_HOST_DEVICE constexpr +void +explode(Fn fn, + PtrS&& s, int_sequence, + PtrD&& d, int_sequence) +{ + return fn(s[Is]..., d[Id]...); +} + +template +CUTE_HOST_DEVICE constexpr +void +explode(Fn fn, + PtrA&& a, int_sequence, + PtrB&& b, int_sequence, + PtrC&& c, int_sequence) +{ + return fn(a[Ia]..., b[Ib]..., c[Ic]...); +} + +template +CUTE_HOST_DEVICE constexpr +void +explode(Fn fn, + PtrD&& d, int_sequence, + PtrA&& a, int_sequence, + PtrB&& b, int_sequence, + PtrC&& c, int_sequence) +{ + return fn(d[Id]..., a[Ia]..., b[Ib]..., c[Ic]...); +} + +template +CUTE_HOST_DEVICE constexpr +void +explode(Fn fn, + PtrD&& d, int_sequence, + PtrA&& a, int_sequence, + PtrB&& b, int_sequence, + PtrC&& c, int_sequence, + PtrE&& e, int_sequence) +{ + return fn(d[Id]..., a[Ia]..., b[Ib]..., c[Ic]..., e[Ie]...); +} + +template +CUTE_HOST_DEVICE constexpr +void +explode(Fn fn, + PtrD&& d, int_sequence, + PtrA&& a, int_sequence, + PtrB&& b, int_sequence, + PtrC&& c, int_sequence, + PtrE&& e, int_sequence, + PtrF&& f, int_sequence) +{ + return fn(d[Id]..., a[Ia]..., b[Ib]..., c[Ic]..., e[Ie]..., f[If]...); +} + +template +CUTE_HOST_DEVICE constexpr +void +explode(Fn fn, + PtrD&& d, int_sequence, + PtrA&& a, int_sequence, + PtrB&& b, int_sequence, + PtrC&& c, int_sequence, + PtrE&& e, int_sequence, + PtrF&& f, int_sequence, + PtrG&& g, int_sequence) +{ + return fn(d[Id]..., a[Ia]..., b[Ib]..., c[Ic]..., e[Ie]..., f[If]..., g[Ig]...); +} + +// +// Utility for exploding tuples into functions +// + +template +CUTE_HOST_DEVICE constexpr +void +explode_tuple(Fn fn, + TupleA&& a, int_sequence) +{ + return fn(get(a)...); +} + +template +CUTE_HOST_DEVICE constexpr +void +explode_tuple(Fn fn, + TupleA&& a, int_sequence, + TupleB&& b, int_sequence) +{ + return fn(get(a)..., get(b)...); +} + +template +CUTE_HOST_DEVICE constexpr +void +explode_tuple(Fn fn, + TupleA&& a, int_sequence, + TupleB&& b, int_sequence, + TupleC&& c, int_sequence) +{ + return fn(get(a)..., get(b)..., get(c)...); +} + +} // end namespace detail + +} // end namespace cute diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/atom/copy_atom.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/atom/copy_atom.hpp new file mode 100644 index 0000000000000000000000000000000000000000..718b342124bbfc110da65d60af3563c60b4378b5 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/atom/copy_atom.hpp @@ -0,0 +1,691 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include // CUTE_HOST_DEVICE +#include // cute::Tensor +#include // cute::__CUTE_REQUIRES +#include // cute::is_tuple +#include // cute::is_constant, cute::is_integral +#include // cute::Copy_Traits +#include // cute::TiledMMA + +namespace cute +{ + +template +struct Copy_Atom; + +template +struct Copy_Atom : Copy_Atom, CopyInternalType> +{}; + +template +struct Copy_Atom, CopyInternalType> + : Copy_Traits +{ + using Traits = Copy_Traits; + + // Bit and Thr layouts from the Copy_Traits + using ThrID = typename Traits::ThrID; + using BitLayoutSrc = typename Traits::SrcLayout; + using BitLayoutDst = typename Traits::DstLayout; + using BitLayoutRef = typename Traits::RefLayout; + + using ValType = CopyInternalType; + + using ValLayoutSrc = decltype(recast_layout(BitLayoutSrc{})); + using ValLayoutDst = decltype(recast_layout(BitLayoutDst{})); + using ValLayoutRef = decltype(recast_layout(BitLayoutRef{})); + + CUTE_STATIC_ASSERT_V(size<0>(ValLayoutSrc{}) == size(ThrID{}), "CopyOperation is not valid for Src of ValType."); + CUTE_STATIC_ASSERT_V(size<0>(ValLayoutDst{}) == size(ThrID{}), "CopyOperation is not valid for Dst of ValType."); + CUTE_STATIC_ASSERT_V(size<0>(ValLayoutRef{}) == size(ThrID{}), "CopyOperation is not valid for Ref of ValType."); + + static constexpr int NumValSrc = size<1>(ValLayoutSrc{}); + static constexpr int NumValDst = size<1>(ValLayoutDst{}); + + // Additional Trait parameters/transformations + template + CUTE_HOST_DEVICE + auto + with(TraitsArgs&&... args) const { + auto traits = Traits::with(static_cast(args)...); + return Copy_Atom{traits}; + } + + // + // Tensor call interfaces + // + + // Check and call instruction, or recurse + template + CUTE_HOST_DEVICE + void + call(Tensor const& src, + Tensor & dst) const + { + static_assert(SLayout::rank == 1, "Expected rank-1 src tensor"); + static_assert(DLayout::rank == 1, "Expected rank-1 dst tensor"); + + if constexpr (is_constant::value || + is_constant::value) { + // Dispatch to unpack to execute instruction + return copy_unpack(static_cast(*this), src, dst); + } else if constexpr (is_tuple::value && + is_tuple::value) { + // If the size of the src/dst doesn't match the instruction, + // recurse this rank-1 layout by peeling off the mode + // ((A,B,C,...)) -> (A,B,C,...) 
+ return copy(*this, tensor<0>(src), tensor<0>(dst)); + } else { + static_assert(dependent_false, + "CopyAtom: Src/Dst partitioning does not match the instruction requirement."); + } + } + + // Accept mutable temporaries + template + CUTE_HOST_DEVICE + void + call(Tensor const& src, + Tensor && dst) const + { + return call(src, dst); + } + + // Check and call instruction, or recurse + template + CUTE_HOST_DEVICE + void + call(Tensor const& prd, + Tensor const& src, + Tensor & dst) const + { + static_assert(PLayout::rank == 1, "Expected rank-1 prd tensor"); + static_assert(SLayout::rank == 1, "Expected rank-1 src tensor"); + static_assert(DLayout::rank == 1, "Expected rank-1 dst tensor"); + + if constexpr (is_constant::value || + is_constant::value) { + // Dispatch to unpack to execute instruction + Traits const& traits = static_cast(*this); + auto has_with_bool = cute::is_valid([](auto t)->void_t{}, traits); + if constexpr (has_with_bool) { + copy_unpack(traits.with(prd(Int<0>{})), src, dst); + } else { + if (prd(Int<0>{})) { copy_unpack(traits, src, dst); } + } + } else if constexpr (is_tuple::value && + is_tuple::value && + is_tuple::value) { + // If the size of the src/dst doesn't match the instruction, + // recurse this rank-1 layout by peeling off the mode + // ((A,B,C,...)) -> (A,B,C,...) + return copy_if(*this, tensor<0>(prd), tensor<0>(src), tensor<0>(dst)); + } else { + static_assert(dependent_false, + "CopyAtom: Src/Dst partitioning does not match the instruction requirement."); + } + } + + // Accept mutable temporaries + template + CUTE_HOST_DEVICE + void + call(Tensor const& prd, + Tensor const& src, + Tensor && dst) const + { + return call(prd, src, dst); + } +}; + +// +// A tiling of copy atoms +// + +template +struct ThrCopy; + +template coord [Need not be 2D...] + class ShapeTiler_MN> // coord space +struct TiledCopy : Copy_Atom +{ + // Layout information from the CopyAtom + using AtomThrID = typename Copy_Atom::ThrID; // thrid -> thr_idx + using AtomLayoutSrc = typename Copy_Atom::ValLayoutSrc; // (thr,val) -> offset + using AtomLayoutDst = typename Copy_Atom::ValLayoutDst; // (thr,val) -> offset + using AtomLayoutRef = typename Copy_Atom::ValLayoutRef; // (thr,val) -> offset + + using AtomNumThr = decltype(size<0>(AtomLayoutRef{})); + using AtomNumVal = decltype(size<1>(AtomLayoutRef{})); + + // Layout information for the TiledCopy + using Tiler_MN = ShapeTiler_MN; + using TiledLayout_TV = LayoutCopy_TV; + using TiledNumThr = decltype(size<0>(TiledLayout_TV{})); + using TiledNumVal = decltype(size<1>(TiledLayout_TV{})); + + CUTE_STATIC_ASSERT_V(TiledNumThr{} % AtomNumThr{} == Int<0>{}, "TiledCopy uses too few thrs for selected CopyAtom"); + CUTE_STATIC_ASSERT_V(TiledNumVal{} % AtomNumVal{} == Int<0>{}, "TiledCopy uses too few vals for selected CopyAtom"); + + // Tile a tensor or a layout from shape + // (M,N,...) + // to shape + // (Thr,(FrgV,FrgX),(RestM,RestN,...)) + // where + // Thr: The logical threads within the tiled copy. + // FrgV: The values local to a COPY_ATOM Src. + // FrgX: The values tiled across COPY_ATOMs Src. + // RestM: The values tiled in M. + // RestN: The values tiled in N. 
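+  // A worked shape sketch (illustrative; the layouts and sizes below are hypothetical):
+  //
+  //   auto atom  = Copy_Atom<UniversalCopy<uint128_t>, float>{};   // 4 floats per thread
+  //   auto tcopy = make_tiled_copy(atom,
+  //                                Layout<Shape<_16,_8>>{},        // thr layout: 128 threads
+  //                                Layout<Shape< _1,_4>>{});       // val layout: 4 values/thr
+  //   // Tiler_MN == (_16,_32); applying tidfrg_S to a (128,64) tensor gives shape
+  //   //   (Thr,(FrgV,FrgX),(RestM,RestN)) == (_128,(_4,_1),(8,2))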
+ template + CUTE_HOST_DEVICE constexpr static + auto + tidfrg_S(STensor&& stensor) + { + CUTE_STATIC_ASSERT_V(rank(stensor) >= rank(Tiler_MN{}), "Rank of tensor to be partitioned too small."); + + // Tile the stensor and compute the (src-thr, src-val) -> (ref-thr, ref-val) layout + return tile2thrfrg(zipped_divide(stensor,Tiler_MN{}), right_inverse(AtomLayoutRef{}).compose(AtomLayoutSrc{})); + } + + // Tile a tensor or a layout from shape + // (M,N,...) + // to shape + // (Thr,(FrgV,FrgX),(RestM,RestN,...)) + // where + // Thr: The logical threads within the tiled copy. + // FrgV: The values local to a COPY_ATOM Dst. + // FrgX: The values tiled across COPY_ATOMs Dst. + // RestM: The values tiled in M. + // RestN: The values tiled in N. + template + CUTE_HOST_DEVICE constexpr static + auto + tidfrg_D(DTensor&& dtensor) + { + CUTE_STATIC_ASSERT_V(rank(dtensor) >= rank(Tiler_MN{}), "Rank of tensor to be partitioned too small."); + + // Tile the dtensor and compute the (dst-thr, dst-val) -> (ref-thr, ref-val) layout + return tile2thrfrg(zipped_divide(dtensor,Tiler_MN{}), right_inverse(AtomLayoutRef{}).compose(AtomLayoutDst{})); + } + + // Tile a tensor or a layout from shape + // ((TileM,TileN,...), (RestM,RestN,...)) + // to shape + // (Thr,(FrgV,FrgX),(RestM,RestN,...)) + template + CUTE_HOST_DEVICE constexpr static + auto + tile2thrfrg(Tensor&& tensor, Ref2TrgLayout const& ref2trg) + { + // Take the thrs/vals that the atom is interested in + // NOTE: Assumes the AtomNumThr are contiguous and identity within TiledThrID + auto atom_layout_TV = zipped_divide(TiledLayout_TV{}, make_shape(AtomNumThr{}, AtomNumVal{})); + // ((atom_tid,atom_val),(rest_tid,rest_val)) -> (m,n) + + // Transform to the trg layout + auto trg_layout_TV = atom_layout_TV.compose(ref2trg, _); + // ((trg_tid,trg_val),(rest_tid,rest_val)) -> (m,n) + + // Transform the thrs mode from thrid to thr_idx + // NOTE: Assumes the AtomNumThr are contiguous and identity within TiledThrID + auto thrval2mn = coalesce(zip(trg_layout_TV), Shape<_1,Shape<_1,_1>>{}); + // ((trg_tid,rest_tid),(trg_val,rest_val)) -> (m,n) + + /// ================== + + // Transform the tile mode + auto tv_tensor = tensor.compose(thrval2mn, _); + // ((thrid,val),(RestM,RestN,...)) + + // Unfold and return + return tv_tensor(make_coord(_,_), _); + } + + // retile_S and retile_D assume they are working with the reference layout -- they are the same + template + CUTE_HOST_DEVICE constexpr static + auto + retile(Tensor&& tensor) + { + constexpr int R = remove_cvref_t::rank; + // Assert that AtomLayoutSrc|Dst is identity so we can skip the Ref transformation + + // Assume the first size<0>(tensor) elements are the first val_ids in TiledLayout_TV. 
+ // Then, we only need the shape+layout of those size<0>(tensor) elements in TiledLayout_TV + // and that shape is what we gather from the other modes of tensor + + auto V = size<0>(tensor); + + auto frg_layout_mn = upcast(right_inverse(TiledLayout_TV{}).with_shape(shape(Tiler_MN{}))); + // (m,n) -> v_idx -- The shape and order of the V inside of TiledLayout_TV + + auto frg_layout_v = zipped_divide(logical_product(make_layout(V), right_inverse(frg_layout_mn)), make_layout(AtomNumVal{})); + // (atom_vals,rest_vals) -> (v,m,n) + + /// ======= + + // Tile the tensor for TileFrg + auto t_tensor = zipped_divide(tensor, prepend(product_each(shape(frg_layout_mn)), V)); + // ((TileV,TileM,TileN,...),(1,RestM,RestN,...)) + + // Transform the tile mode + auto v_tensor = t_tensor.compose(frg_layout_v, _); + // ((atom_vals,rest_vals),(1,RM,RN,...)) + + // Unfold and return + return v_tensor(_, append(Int<0>{},_)); + } + + CUTE_HOST_DEVICE constexpr static + auto + get_layoutS_TV() + { + // (M,N) -> (M,N) + auto ref_S = make_layout(make_shape(shape(Tiler_MN{}), Int<1>{})); + // (thr_idx,val_idx) -> (M,N) + return tile2thrfrg(ref_S, right_inverse(AtomLayoutRef{}).compose(AtomLayoutSrc{}))(_,_,Int<0>{}); + } + + CUTE_HOST_DEVICE constexpr static + auto + get_layoutD_TV() + { + // (M,N) -> (M,N) + auto ref_D = make_layout(make_shape(shape(Tiler_MN{}), Int<1>{})); + // (thr_idx,val_idx) -> (M,N) + return tile2thrfrg(ref_D, right_inverse(AtomLayoutRef{}).compose(AtomLayoutDst{}))(_,_,Int<0>{}); + } + + template ::value)> + CUTE_HOST_DEVICE static + auto + get_slice(ThrIdx const& thr_idx) + { + return ThrCopy(thr_idx); + } + + template ::value)> + CUTE_HOST_DEVICE static + auto + get_thread_slice(ThrIdx const& thr_idx) + { + return get_slice(thr_idx); + } +}; + +template +struct ThrCopy +{ + ThrIdx thr_idx_; + + CUTE_HOST_DEVICE + ThrCopy(ThrIdx const& thr_idx) : thr_idx_(thr_idx) {} + + template + CUTE_HOST_DEVICE + auto + partition_S(STensor&& stensor) const { + //static_assert(sizeof(typename remove_cvref_t::value_type) == sizeof(typename TiledCopy::ValType), + // "Expected ValType for tiling SrcTensor."); + auto thr_tensor = make_tensor(static_cast(stensor).data(), TiledCopy::tidfrg_S(stensor.layout())); + return thr_tensor(thr_idx_, _, repeat>(_)); + } + + template + CUTE_HOST_DEVICE + auto + partition_D(DTensor&& dtensor) const { + //static_assert(sizeof(typename remove_cvref_t::value_type) == sizeof(typename TiledCopy::ValType), + // "Expected ValType for tiling DstTensor."); + auto thr_tensor = make_tensor(static_cast(dtensor).data(), TiledCopy::tidfrg_D(dtensor.layout())); + return thr_tensor(thr_idx_, _, repeat>(_)); + } + + template + CUTE_HOST_DEVICE static + auto + retile_S(STensor&& stensor) { + // static_assert(sizeof(typename remove_cvref_t::value_type) == sizeof(typename TiledCopy::ValType), + // "Expected ValType for tiling SrcTensor."); + return make_tensor(static_cast(stensor).data(), TiledCopy::retile(stensor.layout())); + } + + template + CUTE_HOST_DEVICE static + auto + retile_D(DTensor&& dtensor) { + // static_assert(sizeof(typename remove_cvref_t::value_type) == sizeof(typename TiledCopy::ValType), + // "Expected ValType for tiling DstTensor."); + return make_tensor(static_cast(dtensor).data(), TiledCopy::retile(dtensor.layout())); + } +}; + + +template +CUTE_HOST_DEVICE +auto +make_tiled_copy_impl(Copy_Atom const& atom, + LayoutCopy_TV const&, + Tiler const&) +{ + return TiledCopy, LayoutCopy_TV, Tiler>{atom}; +} + +// +// These tile the Copy_Atom as a whole +// + +template 
+CUTE_HOST_DEVICE +auto +make_tiled_copy_A(Copy_Atom const& copy_atom, + TiledMMA const& mma) +{ + return make_tiled_copy_impl(copy_atom, mma.get_layoutA_TV(), make_shape(tile_size<0>(mma),tile_size<2>(mma))); +} + +template +CUTE_HOST_DEVICE +auto +make_tiled_copy_B(Copy_Atom const& copy_atom, + TiledMMA const& mma) +{ + return make_tiled_copy_impl(copy_atom, mma.get_layoutB_TV(), make_shape(tile_size<1>(mma),tile_size<2>(mma))); +} + +template +CUTE_HOST_DEVICE +auto +make_tiled_copy_C(Copy_Atom const& copy_atom, + TiledMMA const& mma) +{ + return make_tiled_copy_impl(copy_atom, mma.get_layoutC_TV(), make_shape(tile_size<0>(mma),tile_size<1>(mma))); +} + +// returns the smallest tiled copy that can retile LayoutC_TV +// for use with pipelined epilogues with subtiled stores +template +CUTE_HOST_DEVICE +auto +make_tiled_copy_C_atom(Copy_Atom const& copy_atom, + TiledMMA const& mma) +{ + // Truncate the V-layout to just the Copy_Atom, keep the V-order + auto layoutC_TV = mma.get_layoutC_TV(); + auto copy_V = Int::NumValSrc>{}; + CUTE_STATIC_ASSERT_V(copy_V <= size<1>(layoutC_TV)); + auto layout_TV = composition(layoutC_TV, make_layout(make_shape(size<0>(layoutC_TV), copy_V))); + + // Recompute tiler and restride the TV layout for the new tiler + + // Tiler -- Find the active elements in the MMA tensor and generate a tiler to extract them + // Convert to the awkward by-mode tiler to preserve the modes of the tiled MMA + auto mma_tiler = make_shape(tile_size<0>(mma),tile_size<1>(mma)); + auto mma_zeros = repeat_like(mma_tiler, Int<0>{}); + + auto tiler = transform(make_seq{}, [&](auto i) { + return filter(composition(make_layout(mma_tiler, replace(mma_zeros, Int<1>{})), layout_TV)); + }); + + // Layout_TV -- Find the (tid,vid) -> tile coord transformation + // Apply the tiler to a reference and transform the codomain + // tile_coord -> mma_coord + auto tile2mma = composition(make_layout(mma_tiler), tiler); + + // (tid,vid) -> tile_coord + auto layout_tv = composition(left_inverse(tile2mma), layout_TV); + + return make_tiled_copy_impl(copy_atom, layout_tv, tiler); +} + +/** Produce a TiledCopy from logical thread and values layouts. + * The thread and value layouts map coordinates to thr_idx and val_idx. + * The product of these layouts is taken to produce the TV layout and the Tiler. + * Useful when threads and values need very specific mappings onto coordinates + * in the target tensors. + */ +template > +CUTE_HOST_DEVICE +auto +make_tiled_copy(Copy_Atom const& copy_atom, + ThrLayout const& thr_layout = {}, // (m,n) -> thr_idx + ValLayout const& val_layout = {}) // (m,n) -> val_idx +{ + // Take the raked_products to compute the Layout_MN + // (M,N) -> (thr_idx, val_idx) + auto layout_mn = raked_product(thr_layout, val_layout); + // (thr_idx, val_idx) -> (M,N) + auto layout_tv = right_inverse(layout_mn).with_shape(make_shape(size(thr_layout), size(val_layout))); + // Tiler for extracting relevant elements + // (M,N) -> tensor coord + auto tiler = product_each(shape(layout_mn)); + +#if 0 + print("thr_layout: "); print(thr_layout); print("\n"); + print("val_layout: "); print(val_layout); print("\n"); + print("layout_mn : "); print(layout_mn); print("\n"); + print("layout_tv : "); print(layout_tv); print("\n"); + print("tiler : "); print(tiler); print("\n"); +#endif + + return make_tiled_copy_impl(copy_atom, layout_tv, tiler); +} + +/** Produce a TiledCopy from thread and value offset maps. + * The TV Layout maps threads and values to the codomain of the data_layout. 
+ * It is verified that the intended codomain is valid within data_layout. + * Useful when threads and values don't care about owning specific coordinates, but + * care more about the vector-width and offsets between them. + */ +template +CUTE_HOST_DEVICE constexpr +auto +make_cotiled_copy(Copy_Atom const& copy_atom, + AtomTVLayout const& atom_tv_layout, // atom (thr,val) -> data addr + DataLayout const& data_layout) // coord -> data addr The target layout +{ + static_assert(is_static::value); + static_assert(is_static::value); + + // data addr -> data coord Append 1:0 so off-the-ends get the stride-0 + auto inv_data_layout = make_layout(left_inverse(data_layout), Layout<_1,_0>{}); + + // (tid,vid) -> data_coord + auto layout_tv_data = composition(inv_data_layout, atom_tv_layout); + + // Check validity + // Append 1:0 to data_layout so that OOB coordinates get the stride-0 + CUTE_STATIC_ASSERT_V(coalesce(composition(make_layout(data_layout, Layout<_1,_0>{}), layout<1>(layout_tv_data))) == coalesce(layout<1>(atom_tv_layout)), + "The memory pointed to by AtomTVLayout does not exist in the DataLayout."); + // + // Tiler -- Find the active elements in the DATA tensor and generate a tiler to extract them + // + + // Convert to the awkward by-mode tiler to preserve the modes of the tiled DATA + auto flat_data_shape = product_each(shape(data_layout)); + auto flat_data_zeros = repeat(Int<0>{}); + + auto tiler = transform(make_seq{}, [&](auto i) { + return filter(composition(make_layout(flat_data_shape, replace(flat_data_zeros, Int<1>{})), layout_tv_data)); + }); + + // + // Layout_TV -- Find the (tid,vid) -> tile coord transformation + // + + // Apply the tiler to a reference and transform the codomain + // tile_coord -> data_coord + auto tile2data = composition(make_layout(flat_data_shape), tiler); + + // (tid,vid) -> tile_coord + auto layout_tv = composition(left_inverse(tile2data), layout_tv_data); + return make_tiled_copy_impl(copy_atom, layout_tv, tiler); +} + +// Make a TiledCopy out of the copy_atom that matches the Src-Layout of tiled_copy +template +CUTE_HOST_DEVICE +auto +make_tiled_copy_S(Copy_Atom const& copy_atom, + TiledCopy const& tiled_copy) +{ + return make_tiled_copy_impl(copy_atom, tiled_copy.get_layoutS_TV(), typename TiledCopy::Tiler_MN{}); +} + +// Make a TiledCopy out of the copy_atom that matches the Dst-Layout of tiled_copy +template +CUTE_HOST_DEVICE +auto +make_tiled_copy_D(Copy_Atom const& copy_atom, + TiledCopy const& tiled_copy) +{ + return make_tiled_copy_impl(copy_atom, tiled_copy.get_layoutD_TV(), typename TiledCopy::Tiler_MN{}); +} + +// +// Size +// + +// The logical size of a TileCopy +template +CUTE_HOST_DEVICE constexpr +auto +tile_size(TiledCopy const&) +{ + return size(typename TiledCopy::Tiler_MN{}); +} + +// The number of threads involved in a TiledCopy +template +CUTE_HOST_DEVICE constexpr +auto +size(TiledCopy const&) +{ + return typename TiledCopy::TiledNumThr{}; +} + +// +// Display utilities +// + +template +CUTE_HOST_DEVICE +void +print(Copy_Atom, T> const&) +{ + using Atom = Copy_Atom, T>; + print("Copy_Atom\n"); + print(" ThrID: "); print(typename Atom::ThrID{}); print("\n"); + print(" ValLayoutSrc: "); print(typename Atom::ValLayoutSrc{}); print("\n"); + print(" ValLayoutDst: "); print(typename Atom::ValLayoutDst{}); print("\n"); + print(" ValLayoutRef: "); print(typename Atom::ValLayoutRef{}); print("\n"); + print(" ValueType: "); print(sizeof_bits::value); print("b\n"); +} + +template +CUTE_HOST_DEVICE +void +print(TiledCopy const& copy, char 
const* pad = "") +{ + using Copy = TiledCopy; + print("TiledCopy\n"); + print(" Tiler_MN: "); print(typename Copy::Tiler_MN{}); print("\n"); + print(" TiledLayout_TV: "); print(typename Copy::TiledLayout_TV{}); print("\n"); + print(static_cast(copy)); +} + +template +CUTE_HOST_DEVICE +void +print(ThrCopy const& thr_copy) +{ + print("ThrCopy\n"); + print(" ThrIdx: "); print(thr_copy.thr_idx_); print("\n"); + print(TiledCopy{}); +} + +} // end namespace cute + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#include +#include +#include +#include +#include + + +// Config +#if (__CUDACC_VER_MAJOR__ >= 12) +# define CUTE_COPY_ATOM_TMA_SM90_ENABLED +# define CUTE_COPY_ATOM_TMA_SM100_ENABLED +#endif + + +#if (!defined(CUTE_COPY_ATOM_TMA_SM90_ENABLED)) +# define CUTE_COPY_ATOM_TMA_SM90_ENABLED +#endif + +#if (!defined(CUTE_COPY_ATOM_TMA_SM100_ENABLED)) +# define CUTE_COPY_ATOM_TMA_SM100_ENABLED +#endif + + +#if defined(CUTE_COPY_ATOM_TMA_SM90_ENABLED) +#include +#endif + + +#if defined(CUTE_COPY_ATOM_TMA_SM100_ENABLED) +#include +#endif + + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/atom/copy_traits.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/atom/copy_traits.hpp new file mode 100644 index 0000000000000000000000000000000000000000..9117a1fb13de70e8a8ea74c8401e63e55164eef5 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/atom/copy_traits.hpp @@ -0,0 +1,162 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include + +#include + +namespace cute +{ + +/** + * concept Copy_Traits + * { + * using ThrID = // Logical thread id (tid) -> tidx + * + * using SrcLayout = // (Logical src thread id (tid), Logical src value id (vid)) -> bit + * using DstLayout = // (Logical dst thread id (tid), Logical dst value id (vid)) -> bit + * using RefLayout = // (Logical ref thread id (tid), Logical ref value id (vid)) -> bit + * }; + * + * The abstract bit ordering of the Copy_Traits (the codomain of SrcLayout, DstLayout, and RefLayout) + * is arbitrary and only used to construct maps + * (ref-tid,ref-vid) -> (src-tid,src-vid) + * (ref-tid,ref-vid) -> (dst-tid,dst-vid) + * in TiledCopy. The Layout_TV in TiledCopy is in accordance with the RefLayout of a Traits, then mapped to + * the Src or Dst (tid,vid) representation on demand. + * + */ + +template +struct Copy_Traits +{ + static_assert(dependent_false, "Copy_Traits not implemented for this CopyOperation."); +}; + +template +struct Copy_Traits> +{ + // Logical thread id to thread idx (one-thread) + using ThrID = Layout<_1>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout::value>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout::value>>>; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + +template +struct Copy_Traits> +{ + // Logical thread id to thread idx (one-thread) + using ThrID = Layout<_1>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, Stride<_0,_0>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout, Stride<_0,_0>>; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + +// Extract a CPY_Op from a CPY_Traits +template +struct CPY_Op {}; + +template +struct CPY_Op> { + using type = CPY_Op_Arg; +}; + +// +// Generic copy_unpack for common argument-based Copy_Traits +// + +template +CUTE_HOST_DEVICE constexpr +void +copy_unpack(AnyCPYTraits const&, + Tensor const& src, + Tensor & dst) +{ + using CopyOp = typename CPY_Op::type; + using RegistersSrc = typename CopyOp::SRegisters; + using RegistersDst = typename CopyOp::DRegisters; + using RegTypeSrc = typename remove_extent::type; + using RegTypeDst = typename remove_extent::type; + constexpr int RegNumSrc = extent::value; + constexpr int RegNumDst = extent::value; + + Tensor rS = recast(src); + Tensor rD = recast(dst); + + CUTE_STATIC_ASSERT_V(size(rS) == Int{}, + "Copy_Traits: src failed to vectorize into registers. Layout is incompatible with this CopyOp."); + CUTE_STATIC_ASSERT_V(size(rD) == Int{}, + "Copy_Traits: dst failed to vectorize into registers. 
Layout is incompatible with this CopyOp."); + + detail::explode(detail::CallCOPY{}, + rS, make_int_sequence{}, + rD, make_int_sequence{}); +} + +// Accept mutable temporaries +template +CUTE_HOST_DEVICE constexpr +void +copy_unpack(AnyCPYTraits const& traits, + Tensor const& src, + Tensor && dst) +{ + copy_unpack(traits, src, dst); +} + +namespace detail { + +template +constexpr bool is_prefetch = false; + +template +constexpr bool is_prefetch> = is_same_v; + +} // end namespace detail + + +} // end namespace cute diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/atom/copy_traits_sm100.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/atom/copy_traits_sm100.hpp new file mode 100644 index 0000000000000000000000000000000000000000..216d739e47dbeb95a2be1bd88e5ac5e1ded2be29 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/atom/copy_traits_sm100.hpp @@ -0,0 +1,3798 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once + +#include +#include + +#include +#include +#include +#include + +#include + +namespace cute +{ +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (one-thread) + using ThrID = Layout<_1>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (one-thread) + using ThrID = Layout<_1>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (warp) + using ThrID = Layout<_32>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout,_128>, + Stride, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout,Shape <_8, _2, _2, _2>>, + Stride,Stride<_1,_128,_64,_1024>>>; + + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (warp) + using ThrID = Layout<_32>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout,_128>, + Stride, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout,Shape <_8, _2, _2, _4>>, + Stride,Stride<_1,_128,_64,_1024>>>; + + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (warp) + using ThrID = Layout<_32>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout,_128>, + Stride, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout, + Stride<_32, _1>>; + + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (warp) + using ThrID = Layout<_32>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout,_128>, + Stride, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout, + Stride<_32, _1>>; + + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (warp) + using ThrID = Layout<_32>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout,_128>, + Stride, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_32,Stride< _1,_1024>>>; + + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (warp) + using ThrID = Layout<_32>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout,_128>, + Stride, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_32,Stride< _1,_1024>>>; + + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (warp) + using ThrID = Layout<_32>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride<_128, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_32,Stride< _1,_1024>>>; + // Reference map from (thr,val) to bit + using 
RefLayout = DstLayout; +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (warp) + using ThrID = Layout<_32>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride<_128, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_32,Stride< _1,_1024>>>; + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (warp) + using ThrID = Layout<_32>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout,Shape <_8, _2, _2>>, + Stride,Stride<_1,_128,_64>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout,_128>, + Stride, _1>>; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (warp) + using ThrID = Layout<_32>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout,Shape <_8, _2, _2, _2>>, + Stride,Stride<_1,_128,_64,_1024>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout,_128>, + Stride, _1>>; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (warp) + using ThrID = Layout<_32>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout,Shape <_8, _2, _2, _4>>, + Stride,Stride<_1,_128,_64,_1024>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout, + Stride<_128, _1>>; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// TMEM Traits and Utilities +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Copy_Atom; + +/** Generate a TiledCopy from a CopyAtom and a TMEM tensor + * Example: + * Tensor gmem_tensor = ... // (M,N,...) + * Tensor tmem_tensor = ... // (M,N,...) + * auto tiled_tmem_load = make_tmem_copy(TMEM_LOAD_Operation, tmem_tensor); + * auto thr_tmem_load = tiled_tmem_load.get_slice(thread_idx); + * + * Tensor tDtC = thr_tmem_load.partition_S(tmem_tensor); // (TMEM_LOAD,TMEM_LOAD_M,TMEM_LOAD_N,...) + * Tensor tDgC = thr_tmem_load.partition_D(gmem_tensor); // (TMEM_LOAD,TMEM_LOAD_M,TMEM_LOAD_N,...) + * Tensor tDrC = make_tensor(shape(tDgD)); // (TMEM_LOAD,TMEM_LOAD_M,TMEM_LOAD_N,...) 
+ * + * copy(tiled_tmem_load, tDtC, tDrC); // tmem -> rmem + * copy(tDrC, tDgC); // rmem -> gmem + */ +template +CUTE_HOST_DEVICE constexpr +auto +make_tmem_copy(Copy_Atom const& atom, + Tensor const& tmem) +{ + static_assert(is_tmem::value, "Expected TMEM tensor."); + using T = typename TEngine::value_type; + using Traits = typename Copy_Atom::Traits; + static_assert(sizeof_bits_v == sizeof_bits_v, + "Expected a CopyAtom with the same type-width as the Tensor."); + + // atom thr idx -> tmem addr 4warps where each warp points to the same position within it's own subpartition + auto atom_t_layout = Layout, Stride<_0, decltype(Int<32>{} * TMEM::DP{})>>{}; + // atom val idx -> tmem addr Cast the CopyOp's value ids to the proper data width + auto atom_v_layout = coalesce(upcast::value>(typename Traits::ValID{})); + + return make_cotiled_copy(atom, make_layout(atom_t_layout, atom_v_layout), tmem.layout()); +} + +template +CUTE_HOST_DEVICE constexpr +auto +make_tmem_copy(CopyOp const&, + Tensor const& tmem) +{ + return make_tmem_copy(Copy_Atom{}, tmem); +} + +/** Generate a TV_Tiler from a TMEM tensor + * Example: + * Tensor gmem_tensor = ... // (M,N,...) + * Tensor tmem_tensor = ... // (M,N,...) + * auto tmem_tiler = make_tmem_warp_partitioner(tmem_tensor); + * auto warp_tiler = tmem_tiler.get_slice(warp_idx); + * + * Tensor tWtC = warp_tiler.partition(tmem_tensor); // (WARP_M,WARP_N,...) + * Tensor tWgC = warp_tiler.partition(gmem_tensor); // (WARP_M,WARP_N,...) + */ +template +CUTE_HOST_DEVICE constexpr +auto +make_tmem_warp_partitioner(Tensor const& tmem) +{ + static_assert(is_tmem::value, "Expected TMEM tensor."); + using T = typename TEngine::value_type; + + // warp idx -> tmem addr This is the T in the Layout_TV + auto atom_t_layout = Layout<_4, decltype(Int<32>{} * TMEM::DP{})>{}; + + // tmem coord -> tmem addr + auto tmem_layout = tmem.layout(); + // tmem addr -> tmem coord Append 1:0 so off-the-ends get the stride-0 + auto inv_tmem_layout = make_layout(left_inverse(tmem_layout), Layout<_1,_0>{}); + + // wid -> tmem_coord + auto layout_t_tmem = composition(inv_tmem_layout, atom_t_layout); + // + // Tiler -- Find the active elements in the TMEM tensor and generate a tiler to extract them + // + + // Convert to the awkward by-mode tiler to preserve the modes of the tiled TMEM + auto flat_tmem_shape = product_each(shape(tmem_layout)); + auto flat_tmem_zeros = repeat(Int<0>{}); + + auto tiler = transform(make_seq{}, [&](auto i) { + return filter(composition(make_layout(flat_tmem_shape, replace(flat_tmem_zeros, Int<1>{})), layout_t_tmem)); + }); + + // + // Layout_TV -- Find the (tid,vid) -> tile coord transformation + // + + // Apply the tiler to a reference and transform the codomain + // tile_coord -> tmem_coord + auto tile2tmem = composition(make_layout(flat_tmem_shape), tiler); + + // wid -> tile_coord + auto layout_tv = composition(left_inverse(tile2tmem), layout_t_tmem); + return make_tiler_impl(layout_tv, tiler); +} + +namespace SM100::TMEM::LOAD { + +// +// Specialized copy_unpack implementation for SM100::TMEM::LOAD instructions +// + +template +CUTE_HOST_DEVICE constexpr +void +copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) +{ + static_assert(is_tmem::value, "Expected TMEM src."); + static_assert(is_rmem::value, "Expected RMEM dst."); + + using SrcType = typename TS::value_type; + CUTE_STATIC_ASSERT_V((coalesce(layout(src)) == coalesce(upcast::value>(typename Copy_Traits::ValID{}))), + "Expected src to have the specific TMEM layout required by 
CopyOp."); + + uint32_t tmem_addr = raw_pointer_cast(src.data()); + + using RegTypeDst = typename remove_extent::type; + Tensor rD = recast(dst); + + constexpr int RegNumDst = extent::value; + CUTE_STATIC_ASSERT_V(size(rD) == Int{}, + "In CopyAtom, dst layout doesn't vectorize into registers. This dst layout is incompatible with this CopyOp."); + + // thread idx <=> DP lane assert. + // ASSERT TMEM_LOAD thread attemping to access DP lane within sub-partition. +#if defined(__CUDA_ARCH__) && !defined(NDEBUG) + assert(((uint32_t(threadIdx.x) / 32) % 4) == (((tmem_addr >> 16) / 32) % 4)); +#endif + + detail::explode(CopyOp::copy, + &tmem_addr, seq<0>{}, + rD, make_seq{}); +} + +} // end namespace SM100::TMEM::LOAD + +namespace SM100::TMEM::STORE { + +// +// Specialized copy_unpack implementation for SM100::TMEM::STORE instructions +// + +template +CUTE_HOST_DEVICE constexpr +void +copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) +{ + static_assert(is_rmem::value, "Expected RMEM src."); + static_assert(is_tmem::value, "Expected TMEM dst."); + + using RegTypeSrc = typename remove_extent::type; + Tensor rS = recast(src); + + constexpr int RegNumSrc = extent::value; + CUTE_STATIC_ASSERT_V(size(rS) == Int{}, + "In CopyAtom, src layout doesn't vectorize into registers. This src layout is incompatible with this tiled copy."); + + using DstType = typename TD::value_type; + CUTE_STATIC_ASSERT_V((coalesce(layout(dst)) == coalesce(upcast::value>(typename Copy_Traits::ValID{}))), + "Expected dst to have the specific TMEM layout required by CopyOp."); + + uint32_t tmem_addr = raw_pointer_cast(dst.data()); + + // thread idx <=> DP lane assert. + // ASSERT TMEM_LOAD thread attemping to access DP lane within sub-partition. +#if defined(__CUDA_ARCH__) && !defined(NDEBUG) + assert(((uint32_t(threadIdx.x) / 32) % 4) == (((tmem_addr >> 16) / 32) % 4)); +#endif + + detail::explode(CopyOp::copy, + rS, make_seq{}, + &tmem_addr, seq<0>{}); +} + +} // end namespace SM100::TMEM::STORE + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// TMEM_LOAD Copy Traits +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp256b1x; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (warp) + using ThrID = Layout<_32>; + // Logical bit id to bit idx (address) + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride< _0, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout,Shape <_64, _2>>, + Stride,Stride< _1,_2048>>>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp256b1x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_64, _2>>, + Stride,Stride< _1,_2048>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp256b2x; + +template <> +struct Copy_Traits +{ + using ThrID = 
Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_64, _2, _2>>, + Stride,Stride< _1,_4096,_256>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp256b2x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_64, _2, _2>>, + Stride,Stride< _1,_4096,_256>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp256b4x; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_64, _2, _4>>, + Stride,Stride< _1,_8192,_256>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp256b4x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_64, _2, _4>>, + Stride,Stride< _1,_8192,_256>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp256b8x; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_64, _2, _8>>, + Stride,Stride< _1,_16384,_256>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp256b8x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_64, _2, _8>>, + Stride,Stride< _1,_16384,_256>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp256b16x; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_64, _2, _16>>, + Stride,Stride< _1,_32768,_256>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp256b16x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_64, _2, _16>>, + Stride,Stride< _1,_32768,_256>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp256b32x; + +template <> 
+struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_64, _2, _32>>, + Stride,Stride< _1,_65536,_256>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp256b32x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_64, _2, _32>>, + Stride,Stride< _1,_65536,_256>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp128b1x; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32, _2>>, + Stride,Stride< _1,_1024>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp128b1x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32, _2>>, + Stride,Stride< _1,_1024>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp128b2x; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32, _2, _2>>, + Stride,Stride< _1,_2048,_128>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp128b2x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32, _2, _2>>, + Stride,Stride< _1,_2048,_128>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp128b4x; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32, _2, _4>>, + Stride,Stride< _1,_4096,_128>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp128b4x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32, _2, _4>>, + Stride,Stride< _1,_4096,_128>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using 
SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp128b8x; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32, _2, _8>>, + Stride,Stride< _1,_8192,_128>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp128b8x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32, _2, _8>>, + Stride,Stride< _1,_8192,_128>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp128b16x; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32, _2, _16>>, + Stride,Stride< _1,_16384,_128>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp128b16x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32, _2, _16>>, + Stride,Stride< _1,_16384,_128>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp128b32x; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32, _2, _32>>, + Stride,Stride< _1,_32768,_128>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp128b32x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32, _2, _32>>, + Stride,Stride< _1,_32768,_128>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp128b64x; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32, _2, _64>>, + Stride,Stride< _1,_65536,_128>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp128b64x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32, _2, _64>>, + Stride,Stride< _1,_65536,_128>>>; + using RefLayout = SrcLayout; +}; + 
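+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// Usage sketch (illustrative only; tensor and index names below are placeholders,
+// and the element type is assumed to match the chosen atom).  Any TMEM_LOAD op
+// above is consumed through make_tmem_copy(), as documented with that function:
+//
+//   Tensor tmem_tensor = ...;                                            // (M,N,...) TMEM accumulator
+//   Tensor gmem_tensor = ...;                                            // (M,N,...) GMEM destination
+//   auto tiled_tmem_load = make_tmem_copy(SM100_TMEM_LOAD_16dp128b64x{}, tmem_tensor);
+//   auto thr_tmem_load   = tiled_tmem_load.get_slice(thread_idx);
+//   Tensor tDtC = thr_tmem_load.partition_S(tmem_tensor);                // TMEM source fragment
+//   Tensor tDgC = thr_tmem_load.partition_D(gmem_tensor);                // GMEM destination fragment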
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp64b1x; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,_32>, + Stride, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp64b1x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,_32>, + Stride, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp64b2x; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32, _2>>, + Stride,Stride< _1,_64>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp64b2x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32, _2>>, + Stride,Stride< _1,_64>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp64b4x; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32, _4>>, + Stride,Stride< _1,_64>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp64b4x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32, _4>>, + Stride,Stride< _1,_64>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp64b8x; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32, _8>>, + Stride,Stride< _1,_64>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp64b8x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32, _8>>, + Stride,Stride< _1,_64>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using 
SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp64b16x; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32,_16>>, + Stride,Stride< _1,_64>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp64b16x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32,_16>>, + Stride,Stride< _1,_64>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp64b32x; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32,_32>>, + Stride,Stride< _1,_64>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp64b32x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32,_32>>, + Stride,Stride< _1,_64>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp64b64x; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32,_64>>, + Stride,Stride< _1,_64>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp64b64x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32,_64>>, + Stride,Stride< _1,_64>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp64b128x; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32,_128>>, + Stride,Stride< _1, _64>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp64b128x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,Shape <_32,_128>>, + Stride,Stride< _1, _64>>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp32b1x; + 
+template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,_32>, + Stride, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp32b1x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,_32>, + Stride, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp32b2x; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,_64>, + Stride, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp32b2x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,_64>, + Stride, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp32b4x; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,_128>, + Stride, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp32b4x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,_128>, + Stride, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp32b8x; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,_256>, + Stride, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp32b8x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,_256>, + Stride, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp32b16x; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,_512>, + Stride, _1>>; + using RefLayout = SrcLayout; +}; + 
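+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// Warp-partitioning sketch (illustrative only; placeholder names): when a full
+// per-thread TiledCopy is not required, make_tmem_warp_partitioner(), documented
+// above, tiles matching TMEM and GMEM views by warp instead:
+//
+//   auto tmem_tiler = make_tmem_warp_partitioner(tmem_tensor);
+//   auto warp_tiler = tmem_tiler.get_slice(warp_idx);
+//   Tensor tWtC = warp_tiler.partition(tmem_tensor);   // (WARP_M,WARP_N,...) TMEM view
+//   Tensor tWgC = warp_tiler.partition(gmem_tensor);   // (WARP_M,WARP_N,...) GMEM view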
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp32b16x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,_512>, + Stride, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp32b32x; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,_1024>, + Stride, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp32b32x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,_1024>, + Stride, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp32b64x; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,_2048>, + Stride, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp32b64x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,_2048>, + Stride, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp32b128x; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,_4096>, + Stride, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_16dp32b128x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _16>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout,_4096>, + Stride, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b1x; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout, + Stride<_32, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b1x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using 
ValID = Layout, _32>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout, + Stride<_32, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b2x; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout, + Stride<_64, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b2x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _32>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout, + Stride<_64, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b4x; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout, + Stride<_128, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b4x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _32>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout, + Stride<_128, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b8x; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout, + Stride<_256, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b8x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _32>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout, + Stride<_256, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b16x; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout, + Stride<_512, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b16x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _32>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout, + Stride<_512, _1>>; + using RefLayout = SrcLayout; +}; + 
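+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// Continuing the usage sketch above (illustrative only; float is an assumed
+// element type): the partitioned TMEM fragment is staged through registers
+// before being written out to global memory:
+//
+//   Tensor tDrC = make_tensor<float>(shape(tDgC));   // register staging fragment
+//   copy(tiled_tmem_load, tDtC, tDrC);               // tmem -> rmem
+//   copy(tDrC, tDgC);                                // rmem -> gmem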
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b32x; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout, + Stride<_1024, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b32x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _32>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout, + Stride<_1024, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b64x; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout, + Stride<_2048, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b64x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _32>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout, + Stride<_2048, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b128x; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, + Stride< _1,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout, + Stride<_4096, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b128x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_32>; + using ValID = Layout, _32>, + Stride,TMEM::DP_b>>; + using SrcLayout = Layout, + Stride< _0, _1>>; + using DstLayout = Layout, + Stride<_4096, _1>>; + using RefLayout = SrcLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// TMEM_STORE Copy Traits +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp256b1x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp256b1x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = 
typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp256b2x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp256b2x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp256b4x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp256b4x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp256b8x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp256b8x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp256b16x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp256b16x_16b; + +template 
<> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp256b32x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp256b32x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp128b1x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp128b1x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp128b2x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp128b2x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp128b4x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp128b4x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp128b8x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp128b8x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp128b16x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp128b16x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp128b32x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp128b32x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp128b64x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = 
typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp128b64x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp64b1x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp64b1x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp64b2x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp64b2x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp64b4x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp64b4x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp64b8x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename 
Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp64b8x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp64b16x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp64b16x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp64b32x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp64b32x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp64b64x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp64b64x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp64b128x; + +template <> +struct 
Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp64b128x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp32b1x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp32b1x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp32b2x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp32b2x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp32b4x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp32b4x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// 
+ +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp32b8x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp32b8x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp32b16x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp32b16x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp32b32x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp32b32x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp32b64x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp32b64x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp32b128x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_16dp32b128x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_32dp32b1x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_32dp32b1x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_32dp32b2x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_32dp32b2x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_32dp32b4x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_32dp32b4x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename 
Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_32dp32b8x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_32dp32b8x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_32dp32b16x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_32dp32b16x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_32dp32b32x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_32dp32b32x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_32dp32b64x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_32dp32b64x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename 
Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_32dp32b128x; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::STORE::SM100_TMEM_STORE_32dp32b128x_16b; + +template <> +struct Copy_Traits +{ + using ThrID = typename Copy_Traits::ThrID; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = typename Copy_Traits::DstLayout; + using DstLayout = typename Copy_Traits::SrcLayout; + using RefLayout = typename Copy_Traits::RefLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace TMEM { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Given a 1x tmem copy op, returns the widest repeated variant that divides the specified bits in the N-mode +template +CUTE_HOST_DEVICE constexpr +auto +op_repeater() +{ + if constexpr (cute::is_same_v) { + if constexpr (bits_n % (256 * 32) == 0) { + return SM100_TMEM_LOAD_16dp256b32x{}; + } + else if constexpr (bits_n % (256 * 16) == 0) { + return SM100_TMEM_LOAD_16dp256b16x{}; + } + else if constexpr (bits_n % (256 * 8) == 0) { + return SM100_TMEM_LOAD_16dp256b8x{}; + } + else if constexpr (bits_n % (256 * 4) == 0) { + return SM100_TMEM_LOAD_16dp256b4x{}; + } + else if constexpr (bits_n % (256 * 2) == 0) { + return SM100_TMEM_LOAD_16dp256b2x{}; + } + else if constexpr (bits_n % (256 * 1) == 0) { + return SM100_TMEM_LOAD_16dp256b1x{}; + } + } + else if constexpr (cute::is_same_v) { + if constexpr (bits_n % (256 * 32) == 0) { + return SM100_TMEM_LOAD_16dp256b32x_16b{}; + } + else if constexpr (bits_n % (256 * 16) == 0) { + return SM100_TMEM_LOAD_16dp256b16x_16b{}; + } + else if constexpr (bits_n % (256 * 8) == 0) { + return SM100_TMEM_LOAD_16dp256b8x_16b{}; + } + else if constexpr (bits_n % (256 * 4) == 0) { + return SM100_TMEM_LOAD_16dp256b4x_16b{}; + } + else if constexpr (bits_n % (256 * 2) == 0) { + return SM100_TMEM_LOAD_16dp256b2x_16b{}; + } + else if constexpr (bits_n % (256 * 1) == 0) { + return SM100_TMEM_LOAD_16dp256b1x_16b{}; + } + } + else if constexpr (cute::is_same_v) { + if constexpr (bits_n % (128 * 64) == 0) { + return SM100_TMEM_LOAD_16dp128b64x{}; + } + else if constexpr (bits_n % (128 * 32) == 0) { + return SM100_TMEM_LOAD_16dp128b32x{}; + } + else if constexpr (bits_n % (128 * 16) == 0) { + return SM100_TMEM_LOAD_16dp128b16x{}; + } + else if constexpr (bits_n % (128 * 8) == 0) { + return SM100_TMEM_LOAD_16dp128b8x{}; + } + else if constexpr (bits_n % (128 * 4) == 0) { + return SM100_TMEM_LOAD_16dp128b4x{}; + } + else if constexpr (bits_n % (128 * 2) == 0) { + return SM100_TMEM_LOAD_16dp128b2x{}; + } + else if constexpr (bits_n % (128 * 1) == 0) { + return SM100_TMEM_LOAD_16dp128b1x{}; + } + } + else if constexpr (cute::is_same_v) { + if constexpr (bits_n % (128 * 
64) == 0) { + return SM100_TMEM_LOAD_16dp128b64x_16b{}; + } + else if constexpr (bits_n % (128 * 32) == 0) { + return SM100_TMEM_LOAD_16dp128b32x_16b{}; + } + else if constexpr (bits_n % (128 * 16) == 0) { + return SM100_TMEM_LOAD_16dp128b16x_16b{}; + } + else if constexpr (bits_n % (128 * 8) == 0) { + return SM100_TMEM_LOAD_16dp128b8x_16b{}; + } + else if constexpr (bits_n % (128 * 4) == 0) { + return SM100_TMEM_LOAD_16dp128b4x_16b{}; + } + else if constexpr (bits_n % (128 * 2) == 0) { + return SM100_TMEM_LOAD_16dp128b2x_16b{}; + } + else if constexpr (bits_n % (128 * 1) == 0) { + return SM100_TMEM_LOAD_16dp128b1x_16b{}; + } + } + else if constexpr (cute::is_same_v) { + if constexpr (bits_n % (64 * 128) == 0) { + return SM100_TMEM_LOAD_16dp64b128x{}; + } + else if constexpr (bits_n % (64 * 64) == 0) { + return SM100_TMEM_LOAD_16dp64b64x{}; + } + else if constexpr (bits_n % (64 * 32) == 0) { + return SM100_TMEM_LOAD_16dp64b32x{}; + } + else if constexpr (bits_n % (64 * 16) == 0) { + return SM100_TMEM_LOAD_16dp64b16x{}; + } + else if constexpr (bits_n % (64 * 8) == 0) { + return SM100_TMEM_LOAD_16dp64b8x{}; + } + else if constexpr (bits_n % (64 * 4) == 0) { + return SM100_TMEM_LOAD_16dp64b4x{}; + } + else if constexpr (bits_n % (64 * 2) == 0) { + return SM100_TMEM_LOAD_16dp64b2x{}; + } + else if constexpr (bits_n % (64 * 1) == 0) { + return SM100_TMEM_LOAD_16dp64b1x{}; + } + } + else if constexpr (cute::is_same_v) { + if constexpr (bits_n % (64 * 128) == 0) { + return SM100_TMEM_LOAD_16dp64b128x_16b{}; + } + else if constexpr (bits_n % (64 * 64) == 0) { + return SM100_TMEM_LOAD_16dp64b64x_16b{}; + } + else if constexpr (bits_n % (64 * 32) == 0) { + return SM100_TMEM_LOAD_16dp64b32x_16b{}; + } + else if constexpr (bits_n % (64 * 16) == 0) { + return SM100_TMEM_LOAD_16dp64b16x_16b{}; + } + else if constexpr (bits_n % (64 * 8) == 0) { + return SM100_TMEM_LOAD_16dp64b8x_16b{}; + } + else if constexpr (bits_n % (64 * 4) == 0) { + return SM100_TMEM_LOAD_16dp64b4x_16b{}; + } + else if constexpr (bits_n % (64 * 2) == 0) { + return SM100_TMEM_LOAD_16dp64b2x_16b{}; + } + else if constexpr (bits_n % (64 * 1) == 0) { + return SM100_TMEM_LOAD_16dp64b1x_16b{}; + } + } + else if constexpr (cute::is_same_v) { + if constexpr (bits_n % (64 * 128) == 0) { + return SM100_TMEM_LOAD_16dp32b128x{}; + } + else if constexpr (bits_n % (64 * 64) == 0) { + return SM100_TMEM_LOAD_16dp32b64x{}; + } + else if constexpr (bits_n % (64 * 32) == 0) { + return SM100_TMEM_LOAD_16dp32b32x{}; + } + else if constexpr (bits_n % (64 * 16) == 0) { + return SM100_TMEM_LOAD_16dp32b16x{}; + } + else if constexpr (bits_n % (64 * 8) == 0) { + return SM100_TMEM_LOAD_16dp32b8x{}; + } + else if constexpr (bits_n % (64 * 4) == 0) { + return SM100_TMEM_LOAD_16dp32b4x{}; + } + else if constexpr (bits_n % (64 * 2) == 0) { + return SM100_TMEM_LOAD_16dp32b2x{}; + } + else if constexpr (bits_n % (64 * 1) == 0) { + return SM100_TMEM_LOAD_16dp32b1x{}; + } + } + else if constexpr (cute::is_same_v) { + if constexpr (bits_n % (64 * 128) == 0) { + return SM100_TMEM_LOAD_16dp32b128x_16b{}; + } + else if constexpr (bits_n % (64 * 64) == 0) { + return SM100_TMEM_LOAD_16dp32b64x_16b{}; + } + else if constexpr (bits_n % (64 * 32) == 0) { + return SM100_TMEM_LOAD_16dp32b32x_16b{}; + } + else if constexpr (bits_n % (64 * 16) == 0) { + return SM100_TMEM_LOAD_16dp32b16x_16b{}; + } + else if constexpr (bits_n % (64 * 8) == 0) { + return SM100_TMEM_LOAD_16dp32b8x_16b{}; + } + else if constexpr (bits_n % (64 * 4) == 0) { + return SM100_TMEM_LOAD_16dp32b4x_16b{}; + 
} + else if constexpr (bits_n % (64 * 2) == 0) { + return SM100_TMEM_LOAD_16dp32b2x_16b{}; + } + else if constexpr (bits_n % (64 * 1) == 0) { + return SM100_TMEM_LOAD_16dp32b1x_16b{}; + } + } + else if constexpr (cute::is_same_v) { + if constexpr (bits_n % (32 * 128) == 0) { + return SM100_TMEM_LOAD_32dp32b128x{}; + } + else if constexpr (bits_n % (32 * 64) == 0) { + return SM100_TMEM_LOAD_32dp32b64x{}; + } + else if constexpr (bits_n % (32 * 32) == 0) { + return SM100_TMEM_LOAD_32dp32b32x{}; + } + else if constexpr (bits_n % (32 * 16) == 0) { + return SM100_TMEM_LOAD_32dp32b16x{}; + } + else if constexpr (bits_n % (32 * 8) == 0) { + return SM100_TMEM_LOAD_32dp32b8x{}; + } + else if constexpr (bits_n % (32 * 4) == 0) { + return SM100_TMEM_LOAD_32dp32b4x{}; + } + else if constexpr (bits_n % (32 * 2) == 0) { + return SM100_TMEM_LOAD_32dp32b2x{}; + } + else if constexpr (bits_n % (32 * 1) == 0) { + return SM100_TMEM_LOAD_32dp32b1x{}; + } + } + else if constexpr (cute::is_same_v) { + if constexpr (bits_n % (32 * 128) == 0) { + return SM100_TMEM_LOAD_32dp32b128x_16b{}; + } + else if constexpr (bits_n % (32 * 64) == 0) { + return SM100_TMEM_LOAD_32dp32b64x_16b{}; + } + else if constexpr (bits_n % (32 * 32) == 0) { + return SM100_TMEM_LOAD_32dp32b32x_16b{}; + } + else if constexpr (bits_n % (32 * 16) == 0) { + return SM100_TMEM_LOAD_32dp32b16x_16b{}; + } + else if constexpr (bits_n % (32 * 8) == 0) { + return SM100_TMEM_LOAD_32dp32b8x_16b{}; + } + else if constexpr (bits_n % (32 * 4) == 0) { + return SM100_TMEM_LOAD_32dp32b4x_16b{}; + } + else if constexpr (bits_n % (32 * 2) == 0) { + return SM100_TMEM_LOAD_32dp32b2x_16b{}; + } + else if constexpr (bits_n % (32 * 1) == 0) { + return SM100_TMEM_LOAD_32dp32b1x_16b{}; + } + } + else if constexpr (cute::is_same_v) { + if constexpr (bits_n % (256 * 32) == 0) { + return SM100_TMEM_STORE_16dp256b32x{}; + } + else if constexpr (bits_n % (256 * 16) == 0) { + return SM100_TMEM_STORE_16dp256b16x{}; + } + else if constexpr (bits_n % (256 * 8) == 0) { + return SM100_TMEM_STORE_16dp256b8x{}; + } + else if constexpr (bits_n % (256 * 4) == 0) { + return SM100_TMEM_STORE_16dp256b4x{}; + } + else if constexpr (bits_n % (256 * 2) == 0) { + return SM100_TMEM_STORE_16dp256b2x{}; + } + else if constexpr (bits_n % (256 * 1) == 0) { + return SM100_TMEM_STORE_16dp256b1x{}; + } + } + else if constexpr (cute::is_same_v) { + if constexpr (bits_n % (256 * 32) == 0) { + return SM100_TMEM_STORE_16dp256b32x_16b{}; + } + else if constexpr (bits_n % (256 * 16) == 0) { + return SM100_TMEM_STORE_16dp256b16x_16b{}; + } + else if constexpr (bits_n % (256 * 8) == 0) { + return SM100_TMEM_STORE_16dp256b8x_16b{}; + } + else if constexpr (bits_n % (256 * 4) == 0) { + return SM100_TMEM_STORE_16dp256b4x_16b{}; + } + else if constexpr (bits_n % (256 * 2) == 0) { + return SM100_TMEM_STORE_16dp256b2x_16b{}; + } + else if constexpr (bits_n % (256 * 1) == 0) { + return SM100_TMEM_STORE_16dp256b1x_16b{}; + } + } + else if constexpr (cute::is_same_v) { + if constexpr (bits_n % (128 * 64) == 0) { + return SM100_TMEM_STORE_16dp128b64x{}; + } + else if constexpr (bits_n % (128 * 32) == 0) { + return SM100_TMEM_STORE_16dp128b32x{}; + } + else if constexpr (bits_n % (128 * 16) == 0) { + return SM100_TMEM_STORE_16dp128b16x{}; + } + else if constexpr (bits_n % (128 * 8) == 0) { + return SM100_TMEM_STORE_16dp128b8x{}; + } + else if constexpr (bits_n % (128 * 4) == 0) { + return SM100_TMEM_STORE_16dp128b4x{}; + } + else if constexpr (bits_n % (128 * 2) == 0) { + return SM100_TMEM_STORE_16dp128b2x{}; + } + 
else if constexpr (bits_n % (128 * 1) == 0) { + return SM100_TMEM_STORE_16dp128b1x{}; + } + } + else if constexpr (cute::is_same_v) { + if constexpr (bits_n % (128 * 64) == 0) { + return SM100_TMEM_STORE_16dp128b64x_16b{}; + } + else if constexpr (bits_n % (128 * 32) == 0) { + return SM100_TMEM_STORE_16dp128b32x_16b{}; + } + else if constexpr (bits_n % (128 * 16) == 0) { + return SM100_TMEM_STORE_16dp128b16x_16b{}; + } + else if constexpr (bits_n % (128 * 8) == 0) { + return SM100_TMEM_STORE_16dp128b8x_16b{}; + } + else if constexpr (bits_n % (128 * 4) == 0) { + return SM100_TMEM_STORE_16dp128b4x_16b{}; + } + else if constexpr (bits_n % (128 * 2) == 0) { + return SM100_TMEM_STORE_16dp128b2x_16b{}; + } + else if constexpr (bits_n % (128 * 1) == 0) { + return SM100_TMEM_STORE_16dp128b1x_16b{}; + } + } + else if constexpr (cute::is_same_v) { + if constexpr (bits_n % (64 * 128) == 0) { + return SM100_TMEM_STORE_16dp64b128x{}; + } + else if constexpr (bits_n % (64 * 64) == 0) { + return SM100_TMEM_STORE_16dp64b64x{}; + } + else if constexpr (bits_n % (64 * 32) == 0) { + return SM100_TMEM_STORE_16dp64b32x{}; + } + else if constexpr (bits_n % (64 * 16) == 0) { + return SM100_TMEM_STORE_16dp64b16x{}; + } + else if constexpr (bits_n % (64 * 8) == 0) { + return SM100_TMEM_STORE_16dp64b8x{}; + } + else if constexpr (bits_n % (64 * 4) == 0) { + return SM100_TMEM_STORE_16dp64b4x{}; + } + else if constexpr (bits_n % (64 * 2) == 0) { + return SM100_TMEM_STORE_16dp64b2x{}; + } + else if constexpr (bits_n % (64 * 1) == 0) { + return SM100_TMEM_STORE_16dp64b1x{}; + } + } + else if constexpr (cute::is_same_v) { + if constexpr (bits_n % (64 * 128) == 0) { + return SM100_TMEM_STORE_16dp64b128x_16b{}; + } + else if constexpr (bits_n % (64 * 64) == 0) { + return SM100_TMEM_STORE_16dp64b64x_16b{}; + } + else if constexpr (bits_n % (64 * 32) == 0) { + return SM100_TMEM_STORE_16dp64b32x_16b{}; + } + else if constexpr (bits_n % (64 * 16) == 0) { + return SM100_TMEM_STORE_16dp64b16x_16b{}; + } + else if constexpr (bits_n % (64 * 8) == 0) { + return SM100_TMEM_STORE_16dp64b8x_16b{}; + } + else if constexpr (bits_n % (64 * 4) == 0) { + return SM100_TMEM_STORE_16dp64b4x_16b{}; + } + else if constexpr (bits_n % (64 * 2) == 0) { + return SM100_TMEM_STORE_16dp64b2x_16b{}; + } + else if constexpr (bits_n % (64 * 1) == 0) { + return SM100_TMEM_STORE_16dp64b1x_16b{}; + } + } + else if constexpr (cute::is_same_v) { + if constexpr (bits_n % (64 * 128) == 0) { + return SM100_TMEM_STORE_16dp32b128x{}; + } + else if constexpr (bits_n % (64 * 64) == 0) { + return SM100_TMEM_STORE_16dp32b64x{}; + } + else if constexpr (bits_n % (64 * 32) == 0) { + return SM100_TMEM_STORE_16dp32b32x{}; + } + else if constexpr (bits_n % (64 * 16) == 0) { + return SM100_TMEM_STORE_16dp32b16x{}; + } + else if constexpr (bits_n % (64 * 8) == 0) { + return SM100_TMEM_STORE_16dp32b8x{}; + } + else if constexpr (bits_n % (64 * 4) == 0) { + return SM100_TMEM_STORE_16dp32b4x{}; + } + else if constexpr (bits_n % (64 * 2) == 0) { + return SM100_TMEM_STORE_16dp32b2x{}; + } + else if constexpr (bits_n % (64 * 1) == 0) { + return SM100_TMEM_STORE_16dp32b1x{}; + } + } + else if constexpr (cute::is_same_v) { + if constexpr (bits_n % (64 * 128) == 0) { + return SM100_TMEM_STORE_16dp32b128x_16b{}; + } + else if constexpr (bits_n % (64 * 64) == 0) { + return SM100_TMEM_STORE_16dp32b64x_16b{}; + } + else if constexpr (bits_n % (64 * 32) == 0) { + return SM100_TMEM_STORE_16dp32b32x_16b{}; + } + else if constexpr (bits_n % (64 * 16) == 0) { + return 
SM100_TMEM_STORE_16dp32b16x_16b{}; + } + else if constexpr (bits_n % (64 * 8) == 0) { + return SM100_TMEM_STORE_16dp32b8x_16b{}; + } + else if constexpr (bits_n % (64 * 4) == 0) { + return SM100_TMEM_STORE_16dp32b4x_16b{}; + } + else if constexpr (bits_n % (64 * 2) == 0) { + return SM100_TMEM_STORE_16dp32b2x_16b{}; + } + else if constexpr (bits_n % (64 * 1) == 0) { + return SM100_TMEM_STORE_16dp32b1x_16b{}; + } + } + else if constexpr (cute::is_same_v) { + if constexpr (bits_n % (32 * 128) == 0) { + return SM100_TMEM_STORE_32dp32b128x{}; + } + else if constexpr (bits_n % (32 * 64) == 0) { + return SM100_TMEM_STORE_32dp32b64x{}; + } + else if constexpr (bits_n % (32 * 32) == 0) { + return SM100_TMEM_STORE_32dp32b32x{}; + } + else if constexpr (bits_n % (32 * 16) == 0) { + return SM100_TMEM_STORE_32dp32b16x{}; + } + else if constexpr (bits_n % (32 * 8) == 0) { + return SM100_TMEM_STORE_32dp32b8x{}; + } + else if constexpr (bits_n % (32 * 4) == 0) { + return SM100_TMEM_STORE_32dp32b4x{}; + } + else if constexpr (bits_n % (32 * 2) == 0) { + return SM100_TMEM_STORE_32dp32b2x{}; + } + else if constexpr (bits_n % (32 * 1) == 0) { + return SM100_TMEM_STORE_32dp32b1x{}; + } + } + else if constexpr (cute::is_same_v) { + if constexpr (bits_n % (32 * 128) == 0) { + return SM100_TMEM_STORE_32dp32b128x_16b{}; + } + else if constexpr (bits_n % (32 * 64) == 0) { + return SM100_TMEM_STORE_32dp32b64x_16b{}; + } + else if constexpr (bits_n % (32 * 32) == 0) { + return SM100_TMEM_STORE_32dp32b32x_16b{}; + } + else if constexpr (bits_n % (32 * 16) == 0) { + return SM100_TMEM_STORE_32dp32b16x_16b{}; + } + else if constexpr (bits_n % (32 * 8) == 0) { + return SM100_TMEM_STORE_32dp32b8x_16b{}; + } + else if constexpr (bits_n % (32 * 4) == 0) { + return SM100_TMEM_STORE_32dp32b4x_16b{}; + } + else if constexpr (bits_n % (32 * 2) == 0) { + return SM100_TMEM_STORE_32dp32b2x_16b{}; + } + else if constexpr (bits_n % (32 * 1) == 0) { + return SM100_TMEM_STORE_32dp32b1x_16b{}; + } + } + else { + static_assert(dependent_false, "Must pass 1x tmem copy operator"); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Select TMEM store corresponding to the provided TMEM load +template +CUTE_HOST_DEVICE constexpr auto +tmem_load_to_store(CopyOp) { + if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp256b1x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp256b1x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp256b2x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp256b2x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp256b4x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp256b4x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp256b8x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp256b8x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp256b16x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp256b16x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp256b32x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp256b32x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp128b1x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp128b1x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp128b2x{}; + } + else if constexpr (is_same_v) { + return 
SM100_TMEM_STORE_16dp128b2x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp128b4x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp128b4x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp128b8x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp128b8x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp128b16x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp128b16x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp128b32x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp128b32x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp128b64x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp128b64x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp64b1x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp64b1x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp64b2x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp64b2x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp64b4x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp64b4x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp64b8x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp64b8x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp64b16x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp64b16x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp64b32x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp64b32x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp64b64x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp64b64x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp64b128x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp64b128x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp32b1x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp32b1x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp32b2x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp32b2x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp32b4x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp32b4x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp32b8x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp32b8x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp32b16x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp32b16x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp32b32x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp32b32x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp32b64x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp32b64x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp32b128x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_16dp32b128x_16b{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_32dp32b1x{}; + } + else if constexpr (is_same_v) { + return SM100_TMEM_STORE_32dp32b1x_16b{}; + } + 
else if constexpr (is_same_v<CopyOp, SM100_TMEM_LOAD_32dp32b2x>) {
+    return SM100_TMEM_STORE_32dp32b2x{};
+  }
+  else if constexpr (is_same_v<CopyOp, SM100_TMEM_LOAD_32dp32b2x_16b>) {
+    return SM100_TMEM_STORE_32dp32b2x_16b{};
+  }
+  else if constexpr (is_same_v<CopyOp, SM100_TMEM_LOAD_32dp32b4x>) {
+    return SM100_TMEM_STORE_32dp32b4x{};
+  }
+  else if constexpr (is_same_v<CopyOp, SM100_TMEM_LOAD_32dp32b4x_16b>) {
+    return SM100_TMEM_STORE_32dp32b4x_16b{};
+  }
+  else if constexpr (is_same_v<CopyOp, SM100_TMEM_LOAD_32dp32b8x>) {
+    return SM100_TMEM_STORE_32dp32b8x{};
+  }
+  else if constexpr (is_same_v<CopyOp, SM100_TMEM_LOAD_32dp32b8x_16b>) {
+    return SM100_TMEM_STORE_32dp32b8x_16b{};
+  }
+  else if constexpr (is_same_v<CopyOp, SM100_TMEM_LOAD_32dp32b16x>) {
+    return SM100_TMEM_STORE_32dp32b16x{};
+  }
+  else if constexpr (is_same_v<CopyOp, SM100_TMEM_LOAD_32dp32b16x_16b>) {
+    return SM100_TMEM_STORE_32dp32b16x_16b{};
+  }
+  else if constexpr (is_same_v<CopyOp, SM100_TMEM_LOAD_32dp32b32x>) {
+    return SM100_TMEM_STORE_32dp32b32x{};
+  }
+  else if constexpr (is_same_v<CopyOp, SM100_TMEM_LOAD_32dp32b32x_16b>) {
+    return SM100_TMEM_STORE_32dp32b32x_16b{};
+  }
+  else if constexpr (is_same_v<CopyOp, SM100_TMEM_LOAD_32dp32b64x>) {
+    return SM100_TMEM_STORE_32dp32b64x{};
+  }
+  else if constexpr (is_same_v<CopyOp, SM100_TMEM_LOAD_32dp32b64x_16b>) {
+    return SM100_TMEM_STORE_32dp32b64x_16b{};
+  }
+  else if constexpr (is_same_v<CopyOp, SM100_TMEM_LOAD_32dp32b128x>) {
+    return SM100_TMEM_STORE_32dp32b128x{};
+  }
+  else if constexpr (is_same_v<CopyOp, SM100_TMEM_LOAD_32dp32b128x_16b>) {
+    return SM100_TMEM_STORE_32dp32b128x_16b{};
+  }
+  else {
+    static_assert(dependent_false<CopyOp>, "No TMEM_STORE matching for provided TMEM_LOAD");
+  }
+}
+
+} // namespace TMEM
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+//
+// UTCCP Copy Traits
+//
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+namespace SM100::TMEM::UTCCP {
+
+//
+// Specialized copy_unpack implementation for SM100::TMEM::UTCCP instructions
+//
+
+template <class CopyOp,
+          class TS, class SLayout,
+          class TD, class DLayout>
+CUTE_HOST_DEVICE constexpr
+void
+copy_unpack(Copy_Traits<CopyOp> const&,
+            Tensor<TS,SLayout> const& src,
+            Tensor<TD,DLayout>      & dst)
+{
+  static_assert(is_rmem<TS>::value, "Expected smem_desc src for SM100_UTCCP");
+  static_assert(is_tmem<TD>::value, "Expected tmem dst for SM100_UTCCP");
+  // src holds the SMEM matrix descriptor (kept in registers); dst provides the TMEM address to refresh.
+  CopyOp::copy(src[0], raw_pointer_cast(dst.data()));
+}
+
+} // end namespace SM100::TMEM::UTCCP
+
+// In the following UTCCP traits, the ValID represents the mapping:
+//   logical_bit_idx -> tmem_addr_offset.
+// The logical_bit_idx is numbered in the order
+//   [core_matrix_strided, core_matrix_leading, broadcast, repeat].
+// The first two modes provide convenience for smem_desc construction.
+// The last two modes provide the broadcast transformation for 4x32DP and 2x64DP.
+// Consequently, the strides of the first two modes are necessarily TMEM::DP_b and 1,
+// and the stride of the third mode in the SrcLayout must be zero.
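+////////////////////////////////////////////////////////////////////////////////////////////////////
+//
+// Illustrative usage sketch for the two TMEM selection helpers above (op_repeater and
+// tmem_load_to_store). It assumes op_repeater is instantiated as op_repeater<CopyOp, bits_n>(),
+// i.e. the 1x copy op first and the N-extent in bits second, as suggested by its body.
+//
+//   // Widest 32dp32b load whose repetition tiles a 512-bit N extent: 512 = 32 * 16 -> 16x.
+//   using RepeatedLoad  = decltype(TMEM::op_repeater<SM100_TMEM_LOAD_32dp32b1x, 512>());
+//   static_assert(cute::is_same_v<RepeatedLoad, SM100_TMEM_LOAD_32dp32b16x>);
+//
+//   // TMEM store paired with the selected load, for writing the same fragment back to TMEM.
+//   using MatchingStore = decltype(TMEM::tmem_load_to_store(RepeatedLoad{}));
+//   static_assert(cute::is_same_v<MatchingStore, SM100_TMEM_STORE_32dp32b16x>);
+//
+////////////////////////////////////////////////////////////////////////////////////////////////////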
+ +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::UTCCP::SM100_UTCCP_128dp256bit_1cta; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_1>; + using ValID = Layout, + Stride>; + using SrcLayout = Layout, + Stride<_0, _1>>; + using DstLayout = Layout, + Stride<_0,_1>>; + using RefLayout = DstLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::UTCCP::SM100_UTCCP_128dp256bit_2cta; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_2>; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = Layout, + Stride<_0, _1>>; + using DstLayout = Layout, + Stride<_0, _1>>; + using RefLayout = DstLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::UTCCP::SM100_UTCCP_128dp128bit_1cta; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_1>; + using ValID = Layout, + Stride>; + using SrcLayout = Layout, + Stride<_0, _1>>; + using DstLayout = Layout, + Stride<_0,_1>>; + using RefLayout = DstLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::UTCCP::SM100_UTCCP_128dp128bit_2cta; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_2>; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = Layout, + Stride<_0, _1>>; + using DstLayout = Layout, + Stride<_0, _1>>; + using RefLayout = DstLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::UTCCP::SM100_UTCCP_4dp256bit_1cta; + +template <> +struct Copy_Traits +{ + /* + 4DP is really hard to model if we consider this instruction as a "copy" instruction. + But, if we take it as "TMEM refresh" instruction, then everything goes out naturally. + 4DP utccp is designed to refresh the last 4 lanes of each tmem subpartition. + So, in the kernel implementation, we usually only don't need to iterate on MMA_M dimension, + but only need to iterate on MMA_K dimension. + And in each refresh, logically we are refreshing MMA's 128 rows M + 256bit K. + So the "atom_v" should be (refresh_m, refresh_k) instead of (copy_m, copy_k). + And the Src/DstLayout below is: copy_bits -> logical_refresh_bits. 
+ */ + + using ThrID = Layout<_1>; + using ValID = Layout, + Stride>; + using SrcLayout = Layout>, + Stride<_0,Stride<_32,_128>>>; + using DstLayout = Layout>, + Stride<_0,Stride<_32,_128>>>; + using RefLayout = DstLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::UTCCP::SM100_UTCCP_4dp256bit_2cta; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_2>; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = Layout>, + Stride<_0,Stride<_32,_128>>>; + using DstLayout = Layout>, + Stride<_0,Stride<_32,_128>>>; + using RefLayout = DstLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::UTCCP::SM100_UTCCP_4x32dp128bit_1cta; + +template <> +struct Copy_Traits +{ + using _DP = TMEM::DP_b; + using _DPx32 = Int<_DP{}*32>; + + using ThrID = Layout<_1>; + // logical bit_idx -> tmem_addr + // [core_matrix_strided, core_matrix_leading, broadcast] + using ValID = Layout, + Stride<_DP,_1, _DPx32>>; + using SrcLayout = Layout>, + Stride<_0,Stride<_1, _32, _0>>>; + using DstLayout = Layout, + Stride<_0,_1>>; + using RefLayout = DstLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::UTCCP::SM100_UTCCP_4x32dp128bit_2cta; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_2>; + using ValID = typename Copy_Traits::ValID; + using SrcLayout = Layout>, + Stride<_0,Stride<_1, _32, _0>>>; + using DstLayout = Layout, + Stride<_0,_1>>; + using RefLayout = DstLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::UTCCP::SM100_UTCCP_2x64dp128bitlw0213_1cta; + +template <> +struct Copy_Traits +{ + using _DP = TMEM::DP_b; + using _DPx64 = Int<_DP{}*64>; + + using ThrID = Layout<_1>; + // logical bit_idx -> tmem_addr + // [core_matrix_strided, core_matrix_leading, broadcast] + using ValID = Layout, + Stride<_DP,_1, _DPx64>>; + using SrcLayout = Layout>, + Stride<_0,Stride<_1, _64, _0>>>; + using DstLayout = Layout, + Stride<_0, _1>>; + using RefLayout = DstLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::UTCCP::SM100_UTCCP_2x64dp128bitlw0213_2cta; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_2>; + using ValID = typename Copy_Traits::ValID; + + using SrcLayout = Layout>, + Stride<_0,Stride<_1, _64, _0>>>; + using DstLayout = Layout, + Stride<_0, _1>>; + using RefLayout = DstLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::UTCCP::SM100_UTCCP_2x64dp128bitlw0123_1cta; + +template <> +struct Copy_Traits +{ + using _DP = TMEM::DP_b; + using _DPx32 = Int<_DP{}*32>; + using _DPx64 = Int<_DP{}*64>; + + using ThrID = Layout<_1>; + // logical bit_idx -> tmem_addr + // [core_matrix_strided, core_matrix_leading, repeat, broadcast] + using ValID = Layout, + Stride<_DP,_1 ,_DPx64,_DPx32>>; + + using SrcLayout = Layout>, + Stride<_0,Stride<_1, _32,_4096,_0>>>; + using DstLayout = Layout, + Stride<_0, _1>>; + using RefLayout = DstLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +using SM100::TMEM::UTCCP::SM100_UTCCP_2x64dp128bitlw0123_2cta; + +template <> +struct Copy_Traits +{ + using ThrID = Layout<_2>; + using 
ValID = typename Copy_Traits::ValID; + using SrcLayout = Layout>, + Stride<_0,Stride<_1, _32, _4096,_0>>>; + using DstLayout = Layout, + Stride<_0,_1>>; + using RefLayout = DstLayout; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +CUTE_HOST_DEVICE constexpr +auto +make_utccp_copy(CopyOp const&, + Tensor const& tmem) +{ + static_assert(is_tmem::value, "Expected TMEM tensor."); + using T = typename TEngine::value_type; + using Traits = Copy_Traits; + using Atom = Copy_Atom; + + // atom thr idx -> tmem addr This is the T in the Layout_TV + auto atom_t_layout = make_layout(size(typename Traits::ThrID{}), Int<0>{}); + // atom val idx -> tmem addr Cast the CopyOp's value ids to the proper data width + auto atom_v_layout = coalesce(upcast::value>(typename Traits::ValID{})); + + return make_cotiled_copy(Atom{}, make_layout(atom_t_layout, atom_v_layout), tmem.layout()); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cute + diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/atom/copy_traits_sm100_im2col.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/atom/copy_traits_sm100_im2col.hpp new file mode 100644 index 0000000000000000000000000000000000000000..cd3bf98bbe4c4bff03d1ba1ad33111e12edf3db9 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/atom/copy_traits_sm100_im2col.hpp @@ -0,0 +1,488 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +/*! 
\file + \brief im2col make_tma_copy + +*/ + +#include "cute/arch/copy_sm90.hpp" +#include "cute/arch/copy_sm90_desc.hpp" +#include "cute/atom/copy_traits_sm90_im2col.hpp" +#include "cute/tensor.hpp" + +namespace cute { + +struct SM100_TMA_2SM_LOAD_IM2COL_OP : SM100_TMA_2SM_LOAD_IM2COL {}; + +/// @brief Non-executable specialization of Copy_Traits for SM100 +/// im2col TMA load, with TMA descriptor but no barrier. +/// +/// Use `.with(memory_barrier)` to construct an executable version. +template +struct Copy_Traits +{ + using ThrID = Layout<_2>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, Stride>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout, Stride>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + Im2ColTmaDescriptor tma_desc_; + TMATensor tma_tensor_; + + CUTE_HOST_DEVICE constexpr + Im2ColTmaDescriptor const* + get_tma_descriptor() const + { + return &tma_desc_; + } + + template + CUTE_HOST_DEVICE constexpr + TMATensor const + get_tma_tensor(GShape const&) const + { + return tma_tensor_; + } + + /// @brief Get an executable specialization. + /// + /// Copy_Traits specializations with SM100_TMA_2SM_LOAD_IM2COL are not + /// directly executable. Instead, call this "with" member function + /// to get an executable specialization. "Executable" means that + /// @c copy_unpack works. + /// + /// @param tma_mbar Memory barrier for synchronization + /// + /// @param multicast_mask Multicast mask (unused; only exists + /// for consistency with the actual multicast Copy_Traits + /// specialization) + /// + /// @return Executable specialization of @c Copy_Traits + CUTE_HOST_DEVICE constexpr + Copy_Traits + with(uint64_t& tma_mbar, [[maybe_unused]] uint16_t const& multicast_mask = 0) const + { + return {{}, {&tma_desc_, &tma_mbar}}; + } + + // Copy_Traits specializations with SM100_TMA_2SM_LOAD_IM2COL + // are not directly executable. Instead, call .with + // to get an executable specialization. + template + CUTE_HOST_DEVICE friend constexpr void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) = delete; +}; + +/// TMA load, with TMA descriptor and barrier. +template +struct Copy_Traits + : TMA_LOAD_IM2COL_Unpack +{ + using ThrID = Layout<_2>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, Stride>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout, Stride>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + // SM100_TMA_2SM_LOAD_IM2COL arguments + tuple< + Im2ColTmaDescriptor const*, + uint64_t* // smem mbarrier + > const opargs_; +}; + +////////////////////////////////////////////////////////////////////////////// +///////////////////////////// TMA_LOAD_MULTICAST ///////////////////////////// +////////////////////////////////////////////////////////////////////////////// + +struct SM100_TMA_2SM_LOAD_IM2COL_MULTICAST_OP : SM100_TMA_2SM_LOAD_IM2COL_MULTICAST {}; + +/// @brief Non-executable specialization of Copy_Traits for SM100 +/// im2col TMA load, with TMA descriptor but no barrier or multicast +/// mask. +/// +/// Use `.with(memory_barrier)` to construct an executable version. 
+template +struct Copy_Traits +{ + using ThrID = Layout<_2>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, Stride>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout, Stride>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + Im2ColTmaDescriptor tma_desc_; + TMATensor tma_tensor_; + + CUTE_HOST_DEVICE constexpr + Im2ColTmaDescriptor const* + get_tma_descriptor() const + { + return &tma_desc_; + } + + template + CUTE_HOST_DEVICE constexpr + TMATensor const + get_tma_tensor(GShape const&) const + { + return tma_tensor_; + } + + /// @brief Get an executable specialization. + /// + /// Copy_Traits specializations with SM100_TMA_2SM_LOAD_IM2COL_MULTICAST + /// are not directly executable. Instead, call this "with" member + /// function to get an executable specialization. "Executable" + /// means that @c copy_unpack works. + /// + /// @param tma_mbar Memory barrier for synchronization + /// + /// @param multicast_mask Multicast mask (defaults to a single CTA) + /// + /// @return Executable specialization of @c Copy_Traits + CUTE_HOST_DEVICE constexpr + Copy_Traits + with(uint64_t& tma_mbar, uint16_t const& multicast_mask) const + { + return {{}, {&tma_desc_, &tma_mbar, multicast_mask}}; + } + + // Copy_Traits specializations with SM100_TMA_LOAD_IM2COL_MULTICAST + // are not directly executable. Instead, call .with to get an + // executable specialization. + template + CUTE_HOST_DEVICE friend constexpr void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) = delete; +}; + +/// @brief Executable specialization of Copy_Traits for SM100 multicast +/// im2col TMA load, with TMA descriptor, barrier, and multicast mask. +template +struct Copy_Traits + : TMA_LOAD_IM2COL_Unpack +{ + using ThrID = Layout<_2>; + // Map from (src-thr,src-val) to bit. + using SrcLayout = Layout, Stride>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout, Stride>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + // SM100_TMA_2SM_LOAD_IM2COL_MULTICAST arguments + tuple< + Im2ColTmaDescriptor const*, + uint64_t*, // smem mbarrier + uint16_t // multicast mask + > const opargs_; +}; + +//////////////////////////////////// +// Make TMA +/////////////////////////////////// + +#if !defined(__CUDACC_RTC__) +/** Make a CuTe CTA-collective TiledCopy for a TMA operation. + * + * @param CopyOp The target copy operation: SM100_TMA_2SM_LOAD + * @param gtensor The GMEM Tensor to be involved in the TMA. + * @param slayout The SMEM Layout to be involved in the TMA. + * @param cluster_tile The Cluster-local tile that each Cluster will be tiling GMEM with. + * This is often the cluster_tile_shape that is used to tile the GMEM: + * local_tile(gtensor, cluster_tile_shape, cluster_coord) + * -> Cluster-local tile of GMEM + * @param mma The TiledMMA that defines the Cluster-Tile to Block-Tile partitioning. + * + * This code attempts to maximize the TMA box size. It does this by tracing + * the SMEM "vector" -- the inverse of the smem layout -- to find the largest + * contiguous array of smem that can be written to/from global memory given + * the constraints that the TMA instruction imposes. + * + * This is accomplished by assigning "basis" strides to the GMEM to track which + * modes of SMEM map to which modes of GMEM, then reordering the modes of GMEM according + * to the SMEM vector, and then using those GMEM/SMEM modes to fill in the desc. 
+ * + * Examples: + */ +template +CUTE_HOST +auto +make_im2col_tma_copy_A_sm100(CopyOp const& copy_op, + Tensor const& gtensor, // (M,K,...) + SLayout const& slayout, // (MMA, MMA_M, MMA_K) + Cluster_Tile const& cluster_tile, // (TILE_M,TILE_N,TILE_K) + TiledMMA const& mma, + LowerCornerStride const& lower_corner_whd, + UpperCornerStride const& upper_corner_whd, + LowerPaddingStride const& lower_padding_whd, + UpperPaddingStride const& upper_padding_whd, + TraversalStride const& stride_whd, + LowerSRTStride const& lower_srt, + DilationStride const& stride_srt, + TMA::DescriptorAuxParams const& aux_params = {}) +{ + constexpr int R = GLayout::rank; + // Keep only MK modes from MNK + auto cluster_tile_shape = append(make_shape(get<0>(cluster_tile), get<2>(cluster_tile)), Int<1>{}); + auto cluster_layout = make_identity_layout(cluster_tile_shape); + // cta val idx -> gmem mode + auto cta_v_tile = layout<1>(mma.thrfrg_A(cluster_layout))(_, repeat(_)); + + auto cta_t_vmnk_strides = [](){ + if constexpr (is_same_v || + is_same_v) { + return Stride<_0,_0,_1,_0>{}; // VMNK: Use only the N-CTAs in the Multicast + } else + if constexpr (is_same_v || + is_same_v) { + return Stride<_0,_0,_0,_0>{}; // VMNK: Use no CTAs in Non-Multicast + } else { + static_assert(dependent_false, "Unsupported TMA"); + } + }(); + + auto cta_t_shape = shape(mma.get_thr_layout_vmnk()); + // cta rank -> logical cta idx + auto cta_t_map = make_layout(cta_t_shape, compact_col_major(cta_t_shape, cta_t_vmnk_strides)); + + return detail::make_tma_copy_im2col(copy_op, gtensor, slayout, + cta_t_map, cta_v_tile, + lower_corner_whd, upper_corner_whd, lower_padding_whd, upper_padding_whd, stride_whd, + lower_srt, stride_srt, aux_params); +} + +template +CUTE_HOST +auto +make_im2col_tma_copy_B_sm100(CopyOp const& copy_op, + Tensor const& gtensor, // (N,K,...) 
+ SLayout const& slayout, // (MMA, MMA_N, MMA_K) + Cluster_Tile const& cluster_tile, // (TILE_M,TILE_N,TILE_K) + TiledMMA const& mma, + LowerCornerStride const& lower_corner_whd, + UpperCornerStride const& upper_corner_whd, + LowerPaddingStride const& lower_padding_whd, + UpperPaddingStride const& upper_padding_whd, + TraversalStride const& stride_whd, + LowerSRTStride const& lower_srt, + DilationStride const& stride_srt, + TMA::DescriptorAuxParams const& aux_params = {}) +{ + constexpr int R = GLayout::rank; + // Keep only NK modes from MNK + auto cluster_tile_shape = append(make_shape(get<1>(cluster_tile), get<2>(cluster_tile)), Int<1>{}); + auto cluster_layout = make_identity_layout(cluster_tile_shape); + // cta val idx -> gmem mode + auto cta_v_tile = layout<1>(mma.thrfrg_B(cluster_layout))(_, repeat(_)); + + auto cta_t_vmnk_strides = [](){ + if constexpr (is_same_v || + is_same_v) { + return Stride<_0,_1,_0,_0>{}; // VMNK: Use only the M-CTAs in the Multicast + } else + if constexpr (is_same_v || + is_same_v) { + return Stride<_0,_0,_0,_0>{}; // VMNK: Use no CTAs in Non-Multicast + } else { + static_assert(dependent_false, "Unsupported TMA"); + } + }(); + + auto cta_t_shape = shape(mma.get_thr_layout_vmnk()); + // cta rank -> logical cta idx + auto cta_t_map = make_layout(cta_t_shape, compact_col_major(cta_t_shape, cta_t_vmnk_strides)); + + return detail::make_tma_copy_im2col(copy_op, gtensor, slayout, + cta_t_map, cta_v_tile, + lower_corner_whd, upper_corner_whd, lower_padding_whd, upper_padding_whd, stride_whd, + lower_srt, stride_srt, aux_params); +} + +///////////////////////////////////// +// Experimental Make Im2col TMA Atom +///////////////////////////////////// + +template +CUTE_HOST +auto +make_im2col_tma_atom_A_sm100(CopyOp const& copy_op, + Tensor const& gtensor, // (M, K, ...) + SLayout const& slayout, // (MMA, MMA_M, MMA_K, ...) + MMA_Tiler const& mma_tiler, // (TILE_M, TILE_N, TILE_K, ...) + TiledMMA const& mma, + ClusterShapeVMNK const& cluster_shape, // (CTA_V, CTA_M, CTA_N, CTA_K) + LowerCornerStride const& lower_corner_whd, + UpperCornerStride const& upper_corner_whd, + LowerPaddingStride const& lower_padding_whd, + UpperPaddingStride const& upper_padding_whd, + TraversalStride const& stride_whd, + LowerSRTStride const& lower_srt, + DilationStride const& stride_srt, + TMA::DescriptorAuxParams const& aux_params = {}) +{ + constexpr int R = GLayout::rank; + // Keep only MK modes from MNK + auto cluster_tile_shape = append(make_shape(get<0>(mma_tiler), get<2>(mma_tiler)), Int<1>{}); + auto cluster_layout = make_identity_layout(cluster_tile_shape); + // cta val idx -> gmem mode + auto cta_v_tile = layout<1>(mma.thrfrg_A(cluster_layout))(_, repeat(_)); + + // The size of the multicasting + auto num_multicast = [&](){ + if constexpr (is_same_v || + is_same_v) { + return size<2>(cluster_shape); // VMNK: Use only the N-CTAs in the Multicast + } else + if constexpr (is_same_v || + is_same_v || + is_same_v) { + return Int<1>{}; // VMNK: Use no CTAs in Non-Multicast + } else { + static_assert(dependent_false, "Unsupported TMA"); + } + }(); + + return detail::make_tma_atom_im2col(copy_op, gtensor, slayout, num_multicast, cta_v_tile, + lower_corner_whd, upper_corner_whd, lower_padding_whd, upper_padding_whd, + stride_whd, lower_srt, stride_srt, aux_params); +} + +template +CUTE_HOST +auto +make_im2col_tma_atom_B_sm100(CopyOp const& copy_op, + Tensor const& gtensor, // (N, K, ...) + SLayout const& slayout, // (MMA, MMA_N, MMA_K, ...) 
+ MMA_Tiler const& mma_tiler, // (TILE_M, TILE_N, TILE_K, ...) + TiledMMA const& mma, + ClusterShapeVMNK const& cluster_shape, // (CTA_V, CTA_M, CTA_N, CTA_K) + LowerCornerStride const& lower_corner_whd, + UpperCornerStride const& upper_corner_whd, + LowerPaddingStride const& lower_padding_whd, + UpperPaddingStride const& upper_padding_whd, + TraversalStride const& stride_whd, + LowerSRTStride const& lower_srt, + DilationStride const& stride_srt, + TMA::DescriptorAuxParams const& aux_params = {}) +{ + constexpr int R = GLayout::rank; + // Keep only NK modes from MNK + auto cluster_tile_shape = append(make_shape(get<1>(mma_tiler), get<2>(mma_tiler)), Int<1>{}); + auto cluster_layout = make_identity_layout(cluster_tile_shape); + // cta val idx -> gmem mode + auto cta_v_tile = layout<1>(mma.thrfrg_B(cluster_layout))(_, repeat(_)); + + // The size of the multicasting + auto num_multicast = [&](){ + if constexpr (is_same_v || + is_same_v) { + return size<1>(cluster_shape); // VMNK: Use only the M-CTAs in the Multicast + } else + if constexpr (is_same_v || + is_same_v || + is_same_v) { + return Int<1>{}; // VMNK: Use no CTAs in Non-Multicast + } else { + static_assert(dependent_false, "Unsupported TMA"); + } + }(); + + return detail::make_tma_atom_im2col(copy_op, gtensor, slayout, num_multicast, cta_v_tile, + lower_corner_whd, upper_corner_whd, lower_padding_whd, upper_padding_whd, + stride_whd, lower_srt, stride_srt, aux_params); +} +#endif // !defined(__CUDACC_RTC__) + +} // end namespace cute diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/atom/copy_traits_sm100_tma.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/atom/copy_traits_sm100_tma.hpp new file mode 100644 index 0000000000000000000000000000000000000000..8621e8a98239184aaa44a9881d6f4548bf840222 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/atom/copy_traits_sm100_tma.hpp @@ -0,0 +1,501 @@ +/*************************************************************************************************** + * Copyright (c) 2021 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + + +#pragma once + +#if !defined(__CUDACC_RTC__) +#include +#endif + +#include +#include +#include +#include + +namespace cute +{ + +////////////////////////////////////////////////////////////////////////////// +////////////////////////////// TMA_LOAD //////////////////////////////////////// +////////////////////////////////////////////////////////////////////////////// + +struct SM100_TMA_2SM_LOAD_OP : SM100_TMA_2SM_LOAD {}; + +// The non-executable SM100_TMA_2SM_LOAD with tma_desc and no tma_mbar +// Use .with(tma_mbar) to construct an executable version +template +struct Copy_Traits +{ + using ThrID = Layout<_2>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, Stride>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout, Stride>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + // SM100_TMA_2SM_LOAD arguments + TmaDescriptor tma_desc_; + using AuxParams = AuxParams_; + AuxParams aux_params_; + + // Return TmaDescriptor/TensorMap + CUTE_HOST_DEVICE constexpr + TmaDescriptor const* + get_tma_descriptor() const { + return &tma_desc_; + } + + // Construct an executable SM100_TMA_2SM_LOAD with tma_mbar + CUTE_HOST_DEVICE constexpr + Copy_Traits + with( + uint64_t& tma_mbar, + [[maybe_unused]] uint16_t const& multicast_mask = 0, + TMA::CacheHintSm100 const& cache_hint = TMA::CacheHintSm100::EVICT_NORMAL) const { + // We accept multicast_mask here to keep the API for both atoms consistent + return {{}, {&tma_desc_, &tma_mbar, static_cast(cache_hint)}}; + } + + // Construct an executable SM100_TMA_2SM_LOAD with tma_mbar (temp. 
overloaded for grouped gemm/ptr array gemm) + CUTE_HOST_DEVICE constexpr + Copy_Traits + with( + TmaDescriptor const* new_tma_desc, + uint64_t& tma_mbar, + [[maybe_unused]] uint16_t const& multicast_mask = 0, + TMA::CacheHintSm100 const& cache_hint = TMA::CacheHintSm100::EVICT_NORMAL) const { + // We accept multicast_mask here to keep the API for both atoms consistent + return {{}, {new_tma_desc, &tma_mbar, static_cast(cache_hint)}}; + } + + template + CUTE_HOST_DEVICE constexpr + auto + get_tma_tensor(GShape const& g_shape) const { + static_assert(is_congruent::value); + return make_coord_tensor(make_layout(g_shape, aux_params_.g_stride_)); + } + + // Don't try to execute a copy with SM100_TMA_2SM_LOAD before calling .with() + template + CUTE_HOST_DEVICE friend constexpr void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) = delete; +}; + +// The executable SM100_TMA_2SM_LOAD with tma_desc and tma_mbar +template +struct Copy_Traits + : TMA_LOAD_Unpack +{ + using ThrID = Layout<_2>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, Stride>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout, Stride>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + // SM100_TMA_2SM_LOAD arguments + tuple< + TmaDescriptor const*, + uint64_t*, // smem mbarrier + uint64_t // cache hint + > const opargs_; + + // Return TmaDescriptor/TensorMap + CUTE_HOST_DEVICE constexpr + TmaDescriptor const* + get_tma_descriptor() const { + return get<0>(opargs_); + } +}; + +////////////////////////////////////////////////////////////////////////////// +///////////////////////////// TMA_LOAD_MULTICAST ///////////////////////////// +////////////////////////////////////////////////////////////////////////////// + +struct SM100_TMA_2SM_LOAD_MULTICAST_OP : SM100_TMA_2SM_LOAD_MULTICAST {}; + +template +struct Copy_Traits +{ + using ThrID = Layout<_2>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, Stride>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout, Stride>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + // SM100_TMA_2SM_LOAD_MULTICAST_OP arguments + TmaDescriptor tma_desc_; + using AuxParams = AuxParams_; + AuxParams aux_params_; + + // Return TmaDescriptor/TensorMap + CUTE_HOST_DEVICE constexpr + TmaDescriptor const* + get_tma_descriptor() const { + return &tma_desc_; + } + + // Construct an executable SM100_TMA_2SM_LOAD_MULTICAST_OP with tma_mbar + CUTE_HOST_DEVICE constexpr + Copy_Traits + with( + uint64_t& tma_load_mbar, + uint16_t const& multicast_mask, + TMA::CacheHintSm100 const& cache_hint = TMA::CacheHintSm100::EVICT_NORMAL) const { + return {{}, {&tma_desc_, &tma_load_mbar, multicast_mask, static_cast(cache_hint)}}; + } + + // Construct an executable SM100_TMA_2SM_LOAD_MULTICAST_OP with tma_mbar (temp. 
overloaded for grouped gemm/ptr array gemm) + CUTE_HOST_DEVICE constexpr + Copy_Traits + with( + TmaDescriptor const* new_tma_desc, + uint64_t& tma_load_mbar, + uint16_t const& multicast_mask, + TMA::CacheHintSm100 const& cache_hint = TMA::CacheHintSm100::EVICT_NORMAL) const { + return {{}, {new_tma_desc, &tma_load_mbar, multicast_mask, static_cast(cache_hint)}}; + } + + template + CUTE_HOST_DEVICE constexpr + auto + get_tma_tensor(GShape const& g_shape) const { + static_assert(is_congruent::value); + return make_coord_tensor(make_layout(g_shape, aux_params_.g_stride_)); + } + + // Don't try to execute a copy with SM100_TMA_2SM_LOAD_MULTICAST_OP before calling .with() + template + CUTE_HOST_DEVICE friend constexpr void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) = delete; +}; + +template +struct Copy_Traits + : TMA_LOAD_Unpack +{ + using ThrID = Layout<_2>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, Stride>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout, Stride>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + // SM100_TMA_2SM_LOAD_MULTICAST_OP arguments + tuple< + TmaDescriptor const*, + uint64_t*, // smem mbarrier + uint16_t, // multicast mask + uint64_t // cache hint + > const opargs_; + + // Return TmaDescriptor/TensorMap + CUTE_HOST_DEVICE constexpr + TmaDescriptor const* + get_tma_descriptor() const { + return get<0>(opargs_); + } +}; + +//////////////////////////////////// +// Make TMA +/////////////////////////////////// + +#if !defined(__CUDACC_RTC__) +/** Make a CuTe CTA-collective TiledCopy for a TMA operation. + * + * @param CopyOp The target copy operation: SM100_TMA_2SM_LOAD + * @param gtensor The GMEM Tensor to be involved in the TMA. + * @param slayout The SMEM Layout to be involved in the TMA. + * @param cluster_tile The Cluster-local tile that each Cluster will be tiling GMEM with. + * This is often the cluster_tile_shape that is used to tile the GMEM: + * local_tile(gtensor, cluster_tile_shape, cluster_coord) + * -> Cluster-local tile of GMEM + * @param mma The TiledMMA that defines the Cluster-Tile to Block-Tile partitioning. + * + * This code attempts to maximize the TMA box size. It does this by tracing + * the SMEM "vector" -- the inverse of the smem layout -- to find the largest + * contiguous array of smem that can be written to/from global memory given + * the constraints that the TMA instruction imposes. + * + * This is accomplished by assigning "basis" strides to the GMEM to track which + * modes of SMEM map to which modes of GMEM, then reordering the modes of GMEM according + * to the SMEM vector, and then using those GMEM/SMEM modes to fill in the desc. + * + * Examples: + */ +template +CUTE_HOST +auto +make_tma_copy_A_sm100(CopyOp const& copy_op, + Tensor const& gtensor, // (M, K, ...) + SLayout const& slayout, // (MMA, MMA_M, MMA_K, ...) + Cluster_Tiler const& cluster_tiler, // (TILER_M, TILER_N, TILER_K, ...) + TiledMMA const& mma) +{ + // Keep only MK modes from MNK + auto cluster_tiler_mk = remove<1>(cluster_tiler); + // cluster tile coord -> gtensor coord + auto g_tile = make_identity_layout(shape(gtensor)).compose(cluster_tiler_mk); // (TILE_M, TILE_K, ...) + // cta val idx -> gmem mode + auto cta_v_tile = layout<1>(mma.thrfrg_A(g_tile))(_, repeat(_)); // (MMA, MMA_M, MMA_K, ...) 
+ + auto cta_t_vmnk_strides = [](){ + if constexpr (is_same_v || + is_same_v) { + return Stride<_0,_0,_1,_0>{}; // VMNK: Use only the N-CTAs in the Multicast + } else + if constexpr (is_same_v || + is_same_v || + is_same_v) { + return Stride<_0,_0,_0,_0>{}; // VMNK: Use no CTAs in Non-Multicast + } else { + static_assert(dependent_false, "Unsupported TMA"); + } + }(); + + auto cta_t_shape = shape(mma.get_thr_layout_vmnk()); + // cta rank -> logical cta idx + auto cta_t_map = coalesce(make_layout(cta_t_shape, compact_col_major(cta_t_shape, cta_t_vmnk_strides))); + + // Prefer TmaInternalType if specified. Fallback to GEngine::value_type + using TmaType = conditional_t::value, typename GEngine::value_type, TmaInternalType>; + return detail::make_tma_copy_tiled(copy_op, gtensor, slayout, cta_t_map, cta_v_tile); +} + +template +CUTE_HOST +auto +make_tma_copy_B_sm100(CopyOp const& copy_op, + Tensor const& gtensor, // (N, K, ...) + SLayout const& slayout, // (MMA, MMA_N, MMA_K, ...) + Cluster_Tiler const& cluster_tiler, // (TILE_M, TILE_N, TILE_K, ...) + TiledMMA const& mma) +{ + // Keep only NK modes from MNK + auto cluster_tiler_nk = remove<0>(cluster_tiler); + // cluster tile coord -> gtensor coord + auto g_tile = make_identity_layout(shape(gtensor)).compose(cluster_tiler_nk); // (TILE_N, TILE_K, ...) + // cta val idx -> gmem mode + auto cta_v_tile = layout<1>(mma.thrfrg_B(g_tile))(_, repeat(_)); // (MMA, MMA_N, MMA_K, ...) + + auto cta_t_vmnk_strides = [](){ + if constexpr (is_same_v || + is_same_v) { + return Stride<_0,_1,_0,_0>{}; // VMNK: Use only the M-CTAs in the Multicast + } else + if constexpr (is_same_v || + is_same_v || + is_same_v) { + return Stride<_0,_0,_0,_0>{}; // VMNK: Use no CTAs in Non-Multicast + } else { + static_assert(dependent_false, "Unsupported TMA"); + } + }(); + + auto cta_t_shape = shape(mma.get_thr_layout_vmnk()); + // cta rank -> logical cta idx + auto cta_t_map = coalesce(make_layout(cta_t_shape, compact_col_major(cta_t_shape, cta_t_vmnk_strides))); + + // Prefer TmaInternalType if specified. Fallback to GEngine::value_type + using TmaType = conditional_t::value, typename GEngine::value_type, TmaInternalType>; + return detail::make_tma_copy_tiled(copy_op, gtensor, slayout, cta_t_map, cta_v_tile); +} + +template +CUTE_HOST +auto +make_tma_copy_C_sm100(CopyOp const& copy_op, + Tensor const& gtensor, // (M, N, ...) + SLayout const& slayout, // (MMA, MMA_M, MMA_N, ...) + Cluster_Tiler const& cluster_tiler, // (TILE_M, TILE_N, TILE_K, ...) + TiledMMA const& mma) +{ + // Keep only MN modes from MNK + auto cluster_tiler_mn = remove<2>(cluster_tiler); + // cluster tile coord -> gtensor coord + auto g_tile = make_identity_layout(shape(gtensor)).compose(cluster_tiler_mn); // (TILE_M, TILE_N, ...) + // cta val idx -> gmem mode + auto cta_v_tile = layout<1>(mma.thrfrg_C(g_tile))(_, repeat(_)); // (MMA, MMA_M, MMA_N, ...) + + static_assert(is_same_v || + is_same_v || + is_same_v, + "Unsupported TMA Op, expected a non-multicast TMA"); + + // No multicast, so only 1 CTA involved + auto cta_t_map = Layout<_1,_0>{}; + + // Prefer TmaInternalType if specified. 
Fallback to GEngine::value_type + using TmaType = conditional_t::value, typename GEngine::value_type, TmaInternalType>; + return detail::make_tma_copy_tiled(copy_op, gtensor, slayout, cta_t_map, cta_v_tile); +} + +//////////////////////////////////// +// Experimental Make TMA Atom +/////////////////////////////////// + +template +CUTE_HOST +auto +make_tma_atom_A_sm100(CopyOp const& copy_op, + Tensor const& gtensor, // (M, K, ...) + SLayout const& slayout, // (MMA, MMA_M, MMA_K, ...) + MMA_Tiler const& mma_tiler, // (TILE_M, TILE_N, TILE_K, ...) + TiledMMA const& mma, + ClusterShapeVMNK const& cluster_shape) // (CTA_V, CTA_M, CTA_N, CTA_K) +{ + // Keep only MK modes from MNK + auto mma_tiler_mk = remove<1>(mma_tiler); + + // cluster tile coord -> gtensor coord + auto g_tile = make_identity_layout(shape(gtensor)).compose(mma_tiler_mk); // (TILE_M, TILE_K, ...) + + // cta val idx -> gmem mode + auto cta_v_tile = layout<1>(mma.thrfrg_A(g_tile))(_, repeat(_)); // (MMA, MMA_M, MMA_K, ...) + +#if 0 + print("(tma_a) slayout: "); print(slayout); print("\n"); + print("(tma_a) mma_tiler_nk: "); print(mma_tiler_nk); print("\n"); + print("(tma_a) g_tile: "); print(g_tile); print("\n"); + print("(tma_a) mma_tiler: "); print(mma_tiler); print("\n"); + print("(tma_a) cta_v_tile: "); print(cta_v_tile); print("\n"); +#endif + + // The size of the multicasting + auto num_multicast = [&](){ + if constexpr (is_same_v || + is_same_v) { + return size<2>(cluster_shape); // VMNK: Use only the N-CTAs in the Multicast + } else + if constexpr (is_same_v || + is_same_v || + is_same_v) { + return Int<1>{}; // VMNK: Use no CTAs in Non-Multicast + } else { + static_assert(dependent_false, "Unsupported TMA"); + } + }(); + + // Prefer TmaInternalType if specified. Fallback to GEngine::value_type + using TmaType = conditional_t::value, typename GEngine::value_type, TmaInternalType>; + return detail::make_tma_copy_atom(copy_op, gtensor, slayout, num_multicast, cta_v_tile); +} + +template +CUTE_HOST +auto +make_tma_atom_B_sm100(CopyOp const& copy_op, + Tensor const& gtensor, // (N, K, ...) + SLayout const& slayout, // (MMA, MMA_N, MMA_K, ...) + MMA_Tiler const& mma_tiler, // (TILE_M, TILE_N, TILE_K, ...) + TiledMMA const& mma, + ClusterShapeVMNK const& cluster_shape) // (CTA_V, CTA_M, CTA_N, CTA_K) +{ + // Keep only NK modes from MNK + auto mma_tiler_nk = remove<0>(mma_tiler); + // cluster tile coord -> gtensor coord + auto g_tile = make_identity_layout(shape(gtensor)).compose(mma_tiler_nk); // (TILE_N, TILE_K, ...) + // cta val idx -> gmem mode + auto cta_v_tile = layout<1>(mma.thrfrg_B(g_tile))(_, repeat(_)); // (MMA, MMA_N, MMA_K, ...) + +#if 0 + print("(tma_b) slayout: "); print(slayout); print("\n"); + print("(tma_b) mma_tiler_nk: "); print(mma_tiler_nk); print("\n"); + print("(tma_b) g_tile: "); print(g_tile); print("\n"); + print("(tma_b) mma_tiler: "); print(mma_tiler); print("\n"); + print("(tma_b) cta_v_tile: "); print(cta_v_tile); print("\n"); +#endif + + // The size of the multicasting + auto num_multicast = [&](){ + if constexpr (is_same_v || + is_same_v) { + return size<1>(cluster_shape); // VMNK: Use only the M-CTAs in the Multicast + } else + if constexpr (is_same_v || + is_same_v || + is_same_v) { + return Int<1>{}; // VMNK: Use no CTAs in Non-Multicast + } else { + static_assert(dependent_false, "Unsupported TMA"); + } + }(); + + // Prefer TmaInternalType if specified. 
Fallback to GEngine::value_type + using TmaType = conditional_t::value, typename GEngine::value_type, TmaInternalType>; + return detail::make_tma_copy_atom(copy_op, gtensor, slayout, num_multicast, cta_v_tile); +} + +#endif // !defined(__CUDACC_RTC__) + +} // end namespace cute diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/atom/copy_traits_sm50.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/atom/copy_traits_sm50.hpp new file mode 100644 index 0000000000000000000000000000000000000000..5299894bfcb8e7d8734bf98a497e67e3abf6dca6 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/atom/copy_traits_sm50.hpp @@ -0,0 +1,75 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include +#include + +#include + +namespace cute +{ + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (one-thread) + using ThrID = Layout<_32>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride<_64, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout,Shape <_32, _2>>, + Stride,Stride< _1, _64>>>; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (one-thread) + using ThrID = Layout<_32>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride<_64, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout, Shape<_32, _2>>, + Stride,Stride< _1, _256>>>; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + +} // end namespace cute diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/atom/copy_traits_sm75.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/atom/copy_traits_sm75.hpp new file mode 100644 index 0000000000000000000000000000000000000000..c43b7483ba28d0f250a9e317f805af14e1dc90f6 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/atom/copy_traits_sm75.hpp @@ -0,0 +1,159 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include +#include + +#include + +namespace cute +{ + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (warp) + using ThrID = Layout<_32>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout,_128>, + Stride, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout, + Stride<_32, _1>>; + + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (warp) + using ThrID = Layout<_32>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout,_128>, + Stride, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_32,Stride< _1,_1024>>>; + + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (warp) + using ThrID = Layout<_32>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride<_128, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_32,Stride< _1,_1024>>>; + + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (warp) + using ThrID = Layout<_32>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout,_128>, + Stride, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout,Shape <_16, _2>>, + Stride,Stride< _1,_128>>>; + + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (warp) + using ThrID = Layout<_32>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout,_128>, + Stride, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout,Shape <_16, _2, _2>>, + Stride,Stride< _1,_128,_1024>>>; + + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (warp) + using ThrID = Layout<_32>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride<_128, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout,Shape <_16, _2, _4>>, + Stride,Stride< _1,_128,_1024>>>; + + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (warp) + using ThrID = Layout<_32>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride<_32, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout, + Stride<_32, _1>>; + + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; +}; +} // end namespace cute diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/atom/copy_traits_sm80.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/atom/copy_traits_sm80.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ab8d128c291ec79fc5ccc5f6aab721b647aea150 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/atom/copy_traits_sm80.hpp @@ -0,0 +1,167 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include +#include + +#include + +namespace cute +{ + +template +struct Copy_Traits> +{ + // Logical thread id to thread idx (one-thread) + using ThrID = Layout<_1>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout::value>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout::value>>>; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + +template +struct Copy_Traits> +{ + // Logical thread id to thread idx (one-thread) + using ThrID = Layout<_1>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout::value>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout::value>>>; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + +template +struct Copy_Traits> +{ + // Logical thread id to thread idx (one-thread) + using ThrID = Layout<_1>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout::value>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout::value>>>; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + // Predicate value: true = load, false = zfill + bool pred = true; + + // Construct a zfill variant with a given predicate value + CUTE_HOST_DEVICE constexpr + Copy_Traits> + with(bool pred) const { + return {pred}; + } + + // Overload copy_unpack for zfill variant to pass the predicate into the op + template + CUTE_HOST_DEVICE friend constexpr + void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) + { + static_assert(is_gmem::value, "Expected gmem source for cp.async."); + static_assert(is_smem::value, "Expected smem destination for cp.async."); + + Tensor rS = recast(src); + Tensor rD = recast(dst); + + CUTE_STATIC_ASSERT_V(size(rS) == Int<1>{}, + "In CopyAtom, src layout doesn't vectorize into registers. 
This src layout is incompatible with this tiled copy."); + CUTE_STATIC_ASSERT_V(size(rD) == Int<1>{}, + "In CopyAtom, dst layout doesn't vectorize into registers. This dst layout is incompatible with this tiled copy."); + + SM80_CP_ASYNC_CACHEALWAYS_ZFILL::copy(rS[0], rD[0], traits.pred); + } +}; + +template +struct Copy_Traits> +{ + // Logical thread id to thread idx (one-thread) + using ThrID = Layout<_1>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout::value>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout::value>>>; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + // Predicate value: true = load, false = zfill + bool pred = true; + + // Construct a zfill variant with a given predicate value + CUTE_HOST_DEVICE constexpr + Copy_Traits> + with(bool pred) const { + return {pred}; + } + + // Overload copy_unpack for zfill variant to pass the predicate into the op + template + CUTE_HOST_DEVICE friend constexpr + void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) + { + static_assert(is_gmem::value, "Expected gmem source for cp.async."); + static_assert(is_smem::value, "Expected smem destination for cp.async."); + + Tensor rS = recast(src); + Tensor rD = recast(dst); + + CUTE_STATIC_ASSERT_V(size(rS) == Int<1>{}, + "In CopyAtom, src layout doesn't vectorize into registers. This src layout is incompatible with this tiled copy."); + CUTE_STATIC_ASSERT_V(size(rD) == Int<1>{}, + "In CopyAtom, dst layout doesn't vectorize into registers. This dst layout is incompatible with this tiled copy."); + + SM80_CP_ASYNC_CACHEGLOBAL_ZFILL::copy(rS[0], rD[0], traits.pred); + } +}; + +} // end namespace cute diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/atom/copy_traits_sm90.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/atom/copy_traits_sm90.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ad479df0606a3d23fbb7e125c052ece203282803 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/atom/copy_traits_sm90.hpp @@ -0,0 +1,132 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include +#include +#include + +#include + +namespace cute +{ + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (warp) + using ThrID = Layout<_32>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = typename Copy_Traits::DstLayout; + // Map from (dst-thr,dst-val) to bit + using DstLayout = typename Copy_Traits::SrcLayout; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (warp) + using ThrID = Layout<_32>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = typename Copy_Traits::DstLayout; + // Map from (dst-thr,dst-val) to bit + using DstLayout = typename Copy_Traits::SrcLayout; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (warp) + using ThrID = Layout<_32>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = typename Copy_Traits::DstLayout; + // Map from (dst-thr,dst-val) to bit + using DstLayout = typename Copy_Traits::SrcLayout; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (warp) + using ThrID = Layout<_32>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = typename Copy_Traits::DstLayout; + // Map from (dst-thr,dst-val) to bit + using DstLayout = typename Copy_Traits::SrcLayout; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (warp) + using ThrID = Layout<_32>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = typename Copy_Traits::DstLayout; + // Map from (dst-thr,dst-val) to bit + using DstLayout = typename Copy_Traits::SrcLayout; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (warp) + using ThrID = Layout<_32>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = typename Copy_Traits::DstLayout; + // Map from (dst-thr,dst-val) to bit + using DstLayout = typename Copy_Traits::SrcLayout; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + +} // end namespace cute diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/atom/copy_traits_sm90_im2col.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/atom/copy_traits_sm90_im2col.hpp new file mode 100644 index 0000000000000000000000000000000000000000..e4d1e3ffff293e0e6c1fdbf286c743f5d60542d0 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/atom/copy_traits_sm90_im2col.hpp @@ -0,0 +1,940 @@ 
+/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +/*! \file + \brief im2col make_tma_copy +*/ + +#include "cute/arch/copy_sm90.hpp" +#include "cute/arch/copy_sm90_desc.hpp" +#include "cute/tensor.hpp" + +#include "cute/algorithm/prefetch.hpp" +#include "cutlass/fast_math.h" +#include "cutlass/cuda_host_adapter.hpp" + +namespace cute +{ + +// Utility for unpacking TMA_LOAD_IM2COL arguments into a CopyOp +template +struct TMA_LOAD_IM2COL_Unpack +{ + /// Copy from src to dst. + /// + /// @param traits Copy traits created with a TMA descriptor that + /// correctly matches the input tensor and other convolution + /// parameters. + /// + /// @param src Tile of the im2col-transformed coordinate tensor + /// (result of get_tma_tensor), representing the global-memory + /// tensor from which to load. + /// + /// @param dst Shared memory tile, into which to load. 
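+  ///
+  /// Usage note (a sketch with illustrative names, not part of the original comment):
+  /// copy_unpack is a customization point and is normally reached through cute::copy on
+  /// an executable im2col atom; the src tile must be sliced from the coordinate tensor
+  /// returned by get_tma_tensor, e.g.
+  ///
+  ///   copy(tma_exec, tAgA(_,_,k), tAsA(_,_,pipe));   // dispatches to copy_unpack below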
+ template + CUTE_HOST_DEVICE friend constexpr void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, // tile of the transformed global activation (A) tensor + Tensor & dst) // shared memory tile + { + auto src_coord_offset = src(Int<0>{}); + auto src_coord_cwhdn_offset_srt = flatten(src_coord_offset); + // Interpret the TMA IM2COL coordinate as (c, ([w,h,d]), n, ([s,r,t])) + CUTE_STATIC_ASSERT_V(rank(src_coord_offset) == _4{}); + CUTE_STATIC_ASSERT_V(rank<1>(src_coord_offset) == rank<3>(src_coord_offset)); + + if constexpr (detail::is_prefetch) { + return detail::explode_tuple(detail::CallCOPY{}, + traits.opargs_, tuple_seq{}, + src_coord_cwhdn_offset_srt, tuple_seq{}); + } else { + static_assert(is_smem::value, "SM90_TMA_LOAD_IM2COL requires the destination be shared memory."); + void* dst_ptr = cute::raw_pointer_cast(dst.data()); + return detail::explode_tuple(detail::CallCOPY{}, + traits.opargs_, tuple_seq{}, + make_tuple(dst_ptr), seq<0>{}, + src_coord_cwhdn_offset_srt, tuple_seq{}); + } + } +}; + +// Copy_Traits for SM90 im2col TMA load comes in two layers. +// +// 1. Copy_Traits +// 2. Copy_Traits +// +// Copy_Traits +// is the "outer" layer. It has a TMA descriptor, +// but no barrier ("tma_mbar"), so it's "nonexecutable." +// One calls its "with" member function with a barrier, +// to get an executable "inner"-layer +// Copy_Traits object. +// That object's "copy_unpack" member function +// actually invokes im2col TMA load. + +struct SM90_TMA_LOAD_IM2COL_OP : SM90_TMA_LOAD_IM2COL {}; + +/// @brief Non-executable specialization of Copy_Traits for SM90 +/// im2col TMA load, with TMA descriptor but no barrier. +/// +/// Use `.with(memory_barrier)` to construct an executable version. +template +struct Copy_Traits +{ + using ThrID = Layout<_1>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + Im2ColTmaDescriptor tma_desc_; + TMATensor tma_tensor_; + + CUTE_HOST_DEVICE constexpr + Im2ColTmaDescriptor const* + get_tma_descriptor() const + { + return &tma_desc_; + } + + template + CUTE_HOST_DEVICE constexpr + TMATensor const + get_tma_tensor(GShape const&) const + { + return tma_tensor_; + } + + /// @brief Get an executable specialization. + /// + /// Copy_Traits specializations with SM90_TMA_LOAD_IM2COL are not + /// directly executable. Instead, call this "with" member function + /// to get an executable specialization. "Executable" means that + /// @c copy_unpack works. + /// + /// @param tma_mbar Memory barrier for synchronization + /// + /// @param multicast_mask Multicast mask (unused; only exists + /// for interface compatibility with the actual multicast Copy_Traits) + /// + /// @return Executable specialization of @c Copy_Traits + CUTE_HOST_DEVICE constexpr + Copy_Traits + with(uint64_t& tma_mbar, [[maybe_unused]] uint16_t const& multicast_mask = 0) const + { + return {{}, {&tma_desc_, &tma_mbar}}; + } + + // Copy_Traits specializations with SM90_TMA_LOAD_IM2COL + // are not directly executable. Instead, call .with + // to get an executable specialization. + template + CUTE_HOST_DEVICE friend constexpr void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) = delete; +}; + +/// @brief Executable specialization of Copy_Traits for SM90 im2col +/// TMA load, with TMA descriptor and barrier. 
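+///
+/// A minimal sketch of the two-layer pattern described above, using illustrative names
+/// only (`tma_atom` and `smem_mbar` are not defined in this header):
+///
+///   uint64_t& smem_mbar = ...;                   // mbarrier object living in smem
+///   auto tma_exec = tma_atom.with(smem_mbar);    // non-executable -> executable traits
+///   copy(tma_exec, src_coord_tile, dst_smem_tile);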
+template +struct Copy_Traits + : TMA_LOAD_IM2COL_Unpack +{ + using ThrID = Layout<_1>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + // SM90_TMA_LOAD_IM2COL arguments + tuple< + Im2ColTmaDescriptor const*, + uint64_t* // smem mbarrier + > const opargs_; +}; + +template +struct Copy_Traits + : TMA_LOAD_IM2COL_Unpack +{ + using ThrID = Layout<_1>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + // SM90_TMA_LOAD_IM2COL::PREFETCH arguments + tuple const opargs_; + + CUTE_HOST_DEVICE + Copy_Traits(Copy_Traits const& traits) + : opargs_({&traits.tma_desc_}) {} +}; + +////////////////////////////////////////////////////////////////////////////// +///////////////////////////// TMA_LOAD_MULTICAST ///////////////////////////// +////////////////////////////////////////////////////////////////////////////// + +struct SM90_TMA_LOAD_IM2COL_MULTICAST_OP : SM90_TMA_LOAD_IM2COL_MULTICAST {}; + +/// @brief Non-executable specialization of Copy_Traits for SM90 +/// im2col TMA load, with TMA descriptor but no barrier or multicast +/// mask. +/// +/// Use `.with(memory_barrier)` to construct an executable version. +template +struct Copy_Traits +{ + using ThrID = Layout<_1>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + Im2ColTmaDescriptor tma_desc_; + TMATensor tma_tensor_; + + CUTE_HOST_DEVICE constexpr + Im2ColTmaDescriptor const* + get_tma_descriptor() const { + return &tma_desc_; + } + + template + CUTE_HOST_DEVICE constexpr + TMATensor const + get_tma_tensor(GShape const&) const + { + return tma_tensor_; + } + + /// @brief Get an executable specialization. + /// + /// Copy_Traits specializations with SM90_TMA_LOAD_IM2COL_MULTICAST + /// are not directly executable. Instead, call this "with" member + /// function to get an executable specialization. "Executable" + /// means that @c copy_unpack works. + /// + /// @param tma_mbar Memory barrier for synchronization + /// + /// @param multicast_mask Multicast mask (defaults to a single CTA) + /// + /// @return Executable specialization of @c Copy_Traits + CUTE_HOST_DEVICE constexpr + Copy_Traits + with(uint64_t& tma_mbar, uint16_t const& multicast_mask) const { + return {{}, {&tma_desc_, &tma_mbar, multicast_mask}}; + } + + // Copy_Traits specializations with SM90_TMA_LOAD_IM2COL_MULTICAST + // are not directly executable. Instead, call .with to get an + // executable specialization. + template + CUTE_HOST_DEVICE friend constexpr void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) = delete; +}; + +/// @brief Executable specialization of Copy_Traits for SM90 multicast +/// im2col TMA load, with TMA descriptor, barrier, and multicast mask. +template +struct Copy_Traits + : TMA_LOAD_IM2COL_Unpack +{ + using ThrID = Layout<_1>; + // Map from (src-thr,src-val) to bit. 
+ using SrcLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + // SM90_TMA_LOAD_IM2COL_MULTICAST arguments + tuple< + Im2ColTmaDescriptor const*, + uint64_t*, // smem mbarrier + uint16_t // multicast mask + > const opargs_; +}; + +////////////////////////////////////////////////////////////////////////////// +///////////////////////////// TMA_STORE IM2COL//////////////////////////////// +////////////////////////////////////////////////////////////////////////////// + +// The executable SM90_TMA_STORE_IM2COL with tma_desc +template +struct Copy_Traits +{ + using ThrID = Layout<_1>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + // SM90_TMA_STORE_IM2COL arguments + Im2ColTmaDescriptor tma_desc_; + TMATensor tma_tensor_; + + // Return TmaDescriptor/TensorMap + CUTE_HOST_DEVICE constexpr + Im2ColTmaDescriptor const* + get_tma_descriptor() const { + return &tma_desc_; + } + + template + CUTE_HOST_DEVICE constexpr + TMATensor const + get_tma_tensor(GShape const&) const + { + return tma_tensor_; + } + + // This is the copy_unpack dispatch for this Copy_Traits + // Src needs to be a smem tensor + // Dst needs to be a gmem tensor with TmaCoordIterator .data() + template + CUTE_HOST_DEVICE friend constexpr void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) + { + static_assert(is_smem::value, "Expected smem src for SM90_TMA_STORE_IM2COL"); + + void const* const desc_ptr = &(traits.tma_desc_); + void const* const src_ptr = cute::raw_pointer_cast(src.data()); + auto dst_coord = flatten(take<0,3>(dst(Int<0>{}))); + + return detail::explode_tuple(detail::CallCOPY{}, + make_tuple(desc_ptr, src_ptr), seq<0,1>{}, + dst_coord, tuple_seq{}); + } +}; + +namespace detail { + +/// @brief Creates a TMA descriptor for im2col TMA load. +/// +/// @param tensor_cwhdn Global activation tensor (A matrix of Fprop). +/// This is the original (not im2col-transformed) tensor in global +/// memory. +/// +/// @param slayout Rank 2 (M,K) shared memory layout of the activation +/// tensor. Here, K is "GEMM K," not the filter tensor's mode of +/// the same name. +////// +/// @param traversal_stride Traversal strides convolution parameter +////// +/// Each of padding_shape, traversal_stride, and dilation_shape is a +/// tuple whose size is the number of spatial modes (e.g., 3 for a 5-D +/// convolution). 
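+///
+/// As an illustration (values assumed, not taken from this header): a 2-D convolution
+/// with unit traversal stride and unit dilation would pass two-entry tuples such as
+/// stride_whd = (1, 1) and stride_srt = (1, 1), while a 3-D convolution would pass
+/// three-entry tuples for each of these parameters.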
+/// +/// @return TMA descriptor for im2col TMA load +template +CUTE_HOST +auto +make_im2col_tma_copy_desc( + Tensor const& tensor_cwhdn, // (C,W,H,D,N) + uint32_t range_c, // TILE_C + uint32_t range_whdn, // TILE_WHDN + SmemSwizzle const& smem_swizzle, // Swizzle + TMALayout const& tma_layout_vt, // TMA layout + LowerCornerStride const& lower_corner_whd, // WHD offset of the "base pointer" + UpperCornerStride const& upper_corner_whd, // WHD upper corner + LowerPaddingStride const& lower_padding_whd, // WHD lower padding + UpperPaddingStride const& upper_padding_whd, // WHD upper padding + TraversalStride const& stride_whd, // WHD traversal stride + LowerSRTStride const& lower_srt, // SRT offset of the "base pointer" + DilationStride const& stride_srt, // SRT stride - dilation + TMA::DescriptorAuxParams const& aux_params = {}) +{ + static_assert(is_gmem::value, "Tensor must point to GPU global memory."); + using value_type = typename EngineA::value_type; + + constexpr uint32_t num_total_modes = LayoutA::rank; + constexpr int num_spatial_modes = num_total_modes - 2; + + // Gmem starting address + void* gmem_address = (void*) raw_pointer_cast(tensor_cwhdn.data()); + + // Gmem extents are just the tensor shape + cute::array gmem_prob_shape = {1,1,1,1,1}; + for_each(make_seq{}, [&](auto i) { + gmem_prob_shape[i] = static_cast(shape(tensor_cwhdn)); + }); + + // Gmem strides are byte strides of the activation tensor in CWHDN order + cute::array gmem_prob_stride = {0,0,0,0,0}; + for_each(make_seq{}, [&](auto i) { + gmem_prob_stride[i] = sizeof(value_type) * stride(tensor_cwhdn); + }); + + // Traversal strides are a function of the dilation shape + // corresponding to spatial (WHD) modes. + cute::array tma_traversal_strides = {1,1,1,1,1}; + for_each(make_seq{}, [&](auto i) { + tma_traversal_strides[i+1] = static_cast(get(stride_whd)); + }); + + cute::array tma_lower_corner{}; + for_each(make_seq{}, [&](auto i) { + tma_lower_corner[i] = static_cast(get(lower_corner_whd)); + }); + + cute::array tma_upper_corner{}; + for_each(make_seq{}, [&](auto i) { + tma_upper_corner[i] = static_cast(get(upper_corner_whd)); + }); + + Im2ColTmaDescriptor tma_desc; + +#if (__CUDACC_VER_MAJOR__ >= 12) + + CUtensorMapDataType tma_format = TMA::to_CUtensorMapDataType(); + CUtensorMapInterleave tma_interleave = CU_TENSOR_MAP_INTERLEAVE_NONE; + CUtensorMapL2promotion tma_l2Promotion = to_CUtensorMapL2promotion(aux_params.l2promo_); + CUtensorMapFloatOOBfill tma_oob_fill = to_CUtensorMapFloatOOBfill(aux_params.oobfill_); + TMA::SmemSwizzleBits swizzle_bits = detail::get_tma_swizzle_bits(smem_swizzle); + TMA::SmemSwizzleBase swizzle_base = detail::get_tma_swizzle_base(smem_swizzle); + CUtensorMapSwizzle tma_swizzle = TMA::to_CUtensorMapSwizzle(swizzle_bits, swizzle_base); + + CUresult encode_result = CUTLASS_CUDA_DRIVER_WRAPPER_CALL(cuTensorMapEncodeIm2col)( + &tma_desc, + tma_format, + num_total_modes, + gmem_address, + gmem_prob_shape.data(), + gmem_prob_stride.data() + 1, // gmem_prob_stride[0] implicitly sizeof(value_type) + tma_lower_corner.data(), + tma_upper_corner.data(), + range_c, + range_whdn, + tma_traversal_strides.data(), + tma_interleave, + tma_swizzle, + tma_l2Promotion, + tma_oob_fill); + + // The extra asserts help indicate the error's cause. 
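+  // For instance, CUDA_ERROR_NOT_INITIALIZED or CUDA_ERROR_INVALID_CONTEXT point at driver
+  // or context setup, while CUDA_ERROR_INVALID_VALUE points at the descriptor parameters
+  // themselves (the shape, stride, corner, or swizzle arguments above).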
+ assert(encode_result != CUDA_ERROR_DEINITIALIZED); + assert(encode_result != CUDA_ERROR_NOT_INITIALIZED); + assert(encode_result != CUDA_ERROR_INVALID_CONTEXT); + assert(encode_result != CUDA_ERROR_INVALID_VALUE); + assert(encode_result == CUDA_SUCCESS); + +#endif // (__CUDACC_VER_MAJOR__ >= 12) + // + // Calculate gemm shapes and linearized shapes based on tma layout tiling. + // + + // Compute [w, h, d, n] + // q/p/z = (w/h/d + (upper_corner_whd - lower_corner_whd - 1)) / stride_whd + 1 + auto gemm_mn_ = cute::transform(cute::make_seq{}, [&](auto i) { + return (shape(tensor_cwhdn) + get(upper_corner_whd) - get(lower_corner_whd) - Int<1>{}) / get(stride_whd) + Int<1>{}; + }); + auto gemm_mn = append(gemm_mn_, shape(tensor_cwhdn)); + + // Compute [c, s, r, t] + // fprop/wgrad, s/r/t = 1 + (upper_padding_whd - upper_corner_whd) / stride_srt + // wgrad, s/r/t = 1 + (lower_padding_whd - lower_corner_whd) / stride_srt + auto gemm_k_ = cute::transform(cute::make_seq{}, [&](auto i) { + auto padding_size = conditional_return(get(stride_srt) > Int<0>{}, + get(upper_padding_whd) - get(upper_corner_whd), + get(lower_corner_whd) - get(lower_padding_whd)); + return Int<1>{} + padding_size / get(stride_srt); + }); + auto gemm_k = prepend(gemm_k_, shape<0>(tensor_cwhdn)); + + // For fprop/dgrad kernel, gemm_shapes is ((q, p, z, n), (c, s, r, t)) + // For wgrad kernel, gemm_shapes is ((c, s, r, t), (q, p, z, n)) + auto gemm_shapes_common = make_shape( + transform_leaf(gemm_mn, [](auto s) { + return conditional_return(cute::is_static{}, s, cutlass::FastDivmod(s)); + }), + gemm_k); + auto gemm_shapes = make_shape( + basis_get(stride<0,1>(tma_layout_vt), gemm_shapes_common), + basis_get(stride<0,0>(tma_layout_vt), gemm_shapes_common)); + + // For fprop/dgrad kernel, linearized shapes is (whdn, (c, s, r, t)) + // For wgrad kernel linearized shapes is ((c, s, r, t), whdn) + auto linear_shapes_common = make_shape(size(gemm_mn), gemm_k); + auto linear_shapes = make_shape( + basis_get(stride<0,1>(tma_layout_vt), linear_shapes_common), + basis_get(stride<0,0>(tma_layout_vt), linear_shapes_common)); + + // + // Calculate gmem basis stride based on tma layout tiling. 
+ // + + auto tma_basis_scale = make_shape(Int<1>{}, stride_whd, Int<1>{}, stride_srt); + auto tma_basis = elem_scale(tma_basis_scale, make_basis_like(tma_basis_scale)); + + auto gbasis_strides_common = make_stride( + append(get<1>(tma_basis), get<2>(tma_basis)), + prepend(get<3>(tma_basis), get<0>(tma_basis))); // ((w,h,d,n),(c,s,r,t)) + auto gbasis_strides = make_stride( + basis_get(stride<0,1>(tma_layout_vt), gbasis_strides_common), + basis_get(stride<0,0>(tma_layout_vt), gbasis_strides_common)); + + // + // Create tma tensor + // + + auto lower_corner = make_arithmetic_tuple(Int<0>{}, lower_corner_whd, Int<0>{}, lower_srt); + + auto tensor_multimode = make_tensor(ArithmeticTupleIterator(lower_corner), gemm_shapes, gbasis_strides); + auto tensor_linear = make_identity_tensor(linear_shapes); + auto tma_tensor = make_tensor(tensor_multimode.data(), composition( + tensor_multimode.layout(), + tensor_linear(Int<0>{}), + tensor_linear.layout())); + + return cute::make_tuple(tma_desc, tma_tensor); +} + +template +CUTE_HOST_RTC +auto +make_tma_atom_im2col(CopyOp, + Tensor const& gtensor, // Full GMEM Tensor: ((w, h, d, n), c) + SLayout const& slayout, // CTA Tile of SMEM, potentially swizzled + int32_t const& num_multicast, // The number of CTAs involved in multicasting + Layout const& cta_v_map, // V: CTA val idx -> gmem mode + LowerCornerStride const& lower_corner_whd, + UpperCornerStride const& upper_corner_whd, + LowerPaddingStride const& lower_padding_whd, + UpperPaddingStride const& upper_padding_whd, + TraversalStride const& stride_whd, // traversal stride + LowerSRTStride const& lower_srt, + DilationStride const& stride_srt, // dilation + TMA::DescriptorAuxParams const& aux_params = {}) +{ + // + // TMA parameter checking + // + + CUTE_STATIC_ASSERT_V(product_each(shape(slayout)) == product_each(shape(cta_v_map)), + "TMA requires CTA_Tile and SLayout top-level shape equivalence."); + + // + // TMA slayout manipulation + // + + // Invert the smem to get the largest contiguous vector in the smem layout + auto inv_smem_layout = right_inverse(get_nonswizzle_portion(slayout)); + // trunc_smem_idx -> trunc_smem_coord + + // Map from smem idx to a gmem mode + auto sidx_to_gmode = coalesce(composition(cta_v_map, inv_smem_layout)); + +#if 0 + print("g_layout : "); print(gtensor.layout()); print("\n"); + print("s_layout : "); print(slayout); print("\n"); + print("cta_t_map : "); print(cta_t_map); print("\n"); + print("cta_v_map : "); print(cta_v_map); print("\n"); + print("inv_smem : "); print(inv_smem_layout); print("\n"); + print("sidx_to_gmode : "); print(sidx_to_gmode); print("\n"); +#endif + + // + // TMA gtensor manipulation + // + + // Generate a TupleBasis for the gtensor + auto glayout_basis = make_identity_layout(product_each(shape(gtensor))); + + // Tile the modes of gtensor with the truncated cta_v_map o inv_smem_layout_trunc + auto tma_layout_full = flatten(composition(glayout_basis, sidx_to_gmode)); + + // Truncate any incompatibilities -- no starting in the middle of gmodes + auto smem_rank = find_if(stride(tma_layout_full), [](auto e) { + [[maybe_unused]] auto v = basis_value(e); + return not is_constant<1,decltype(v)>{}; + }); + static_assert(smem_rank >= 2, "IM2COL expects at least 2 modes of the smem to vectorize with gmem."); + // IM2COL uses a maximum of 2 modes + constexpr int smem_tma_rank = cute::min(int(smem_rank), 2); + + // Keep only the static-1 basis modes into gmem + auto tma_layout_trunc = take<0,smem_tma_rank>(tma_layout_full); + + // Split according to the 
portion each multicast CTA will be responsible for + auto tma_layout_vt = logical_divide(tma_layout_trunc, safe_div(size(tma_layout_trunc), num_multicast)); + +#if 0 + print("glayout_basis : "); print(glayout_basis); print("\n"); + print("tma_layout_full : "); print(tma_layout_full); print("\n"); + + print("tma_layout_trunc: "); print(tma_layout_trunc); print("\n"); + print("tma_layout_vt : "); print(tma_layout_vt); print("\n"); +#endif + + auto range_c = size<0,0>(tma_layout_vt); + auto range_whdn = size<0,1>(tma_layout_vt); + Tensor gtensor_cwhdn = make_tensor(gtensor.data(), + flatten(make_layout(make_layout(basis_get(stride<0,0>(tma_layout_vt), gtensor.shape()), + basis_get(stride<0,0>(tma_layout_vt), gtensor.stride())), + make_layout(basis_get(stride<0,1>(tma_layout_vt), gtensor.shape()), + basis_get(stride<0,1>(tma_layout_vt), gtensor.stride()))))); + auto [tma_desc, tma_tensor] = make_im2col_tma_copy_desc( + gtensor_cwhdn, + range_c, + range_whdn, + get_swizzle_portion(slayout), + tma_layout_vt, + lower_corner_whd, + upper_corner_whd, + lower_padding_whd, + upper_padding_whd, + stride_whd, + lower_srt, + stride_srt, + aux_params); + + // + // Construct the Copy_Traits + // + + using T = typename GEngine::value_type; + constexpr int num_bits_per_tma = decltype(size(tma_layout_trunc))::value * sizeof(T) * 8; + + using Traits = Copy_Traits, decltype(tma_tensor)>; + using Atom = Copy_Atom; + +#if 0 + print("num_bits : "); print(num_bits_per_tma); print("\n"); +#endif + + Traits tma_traits{tma_desc, tma_tensor}; + + // Return the Copy_Atom + return Atom{tma_traits}; +} + +/// Make a TiledCopy for im2col TMA load. +/// +/// @param copy_op The copy implementation: either +/// SM90_TMA_LOAD_IM2COL or SM90_TMA_LOAD_IM2COL_MULTICAST. +/// +/// @param tensor_cwhdn The global tensor to use for im2col TMA loads. +/// For Fprop convolutions, this is the activation tensor. This is +/// the "original tensor that points to global memory, not the +/// coordinate (im2col-transformed) tensor. +/// +/// @param slayout Layout of shared memory tile. +/// +/// @param stride_whd The traversal strides convolution +/// parameter. +/// +/// @return TiledCopy specialization for im2col TMA loads. 
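+///
+/// This is the detail-level builder; user code normally reaches it through the public
+/// cute::make_im2col_tma_copy overloads defined further below, e.g. (illustrative
+/// argument names only):
+///
+///   auto tma = make_im2col_tma_copy(SM90_TMA_LOAD_IM2COL{}, gA_cwhdn, sA_layout,
+///                                   cta_tiler, cluster_size,
+///                                   lower_corner_whd, upper_corner_whd,
+///                                   lower_padding_whd, upper_padding_whd,
+///                                   stride_whd, lower_srt, stride_srt);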
+template +CUTE_HOST_RTC +auto +make_tma_copy_im2col(CopyOp const& copy_op, + Tensor const& gtensor, + SLayout const& slayout, + Layout const& cta_t_map, // CTA tid -> logical TMA tid + Layout const& cta_v_map, // CTA vid -> gmem coord + LowerCornerStride const& lower_corner_whd, + UpperCornerStride const& upper_corner_whd, + LowerPaddingStride const& lower_padding_whd, + UpperPaddingStride const& upper_padding_whd, + TraversalStride const& stride_whd, // traversal stride + LowerSRTStride const& lower_srt, + DilationStride const& stride_srt, // dilation + TMA::DescriptorAuxParams const& aux_params = {}) +{ + // + // TMA parameter checking + // + + CUTE_STATIC_ASSERT_V(size(slayout) % cosize(cta_t_map) == Int<0>{}, + "Number of active CTAs in TMA must divide domain size of slayout."); + + Copy_Atom atom = make_tma_atom_im2col(copy_op, gtensor, slayout, cosize(cta_t_map), cta_v_map, + lower_corner_whd, upper_corner_whd, lower_padding_whd, + upper_padding_whd, stride_whd, lower_srt, stride_srt, aux_params); + + // + // Construct the TiledCopy + // + + auto cta_tiler = product_each(shape(cta_v_map)); + + auto num_elems_per_tma = size<1>(typename decltype(atom)::RefLayout{}) / static_value>(); + + // smem idx -> smem coord + auto inv_smem_layout = right_inverse(get_nonswizzle_portion(slayout)); + // CTA V -> smem_coord + auto layout_v = composition(inv_smem_layout, num_elems_per_tma); + // Scale that up to cover all of the smem_coords + auto layout_V = tile_to_shape(make_layout(layout_v), size(cta_v_map)); + // CTA T -> smem idx + auto layout_t = make_layout(cosize(cta_t_map), safe_div(num_elems_per_tma, cosize(cta_t_map))); + // CTA TID -> smem coord + auto layout_T = composition(inv_smem_layout, composition(layout_t, cta_t_map)); + // Combine with the T mapping + [[maybe_unused]] auto layout_TV = make_layout(layout_T, layout_V); + +#if 0 + print("cta_tiler : "); print(cta_tiler); print("\n"); + print("layout_v : "); print(layout_v); print("\n"); + print("layout_V : "); print(layout_V); print("\n"); + print("layout_t : "); print(layout_t); print("\n"); + print("layout_T : "); print(layout_T); print("\n"); + print("layout_TV : "); print(layout_TV); print("\n"); +#endif + + return TiledCopy{atom}; +} + +/// Make a TiledCopy for im2col TMA with no offsets. +/// E.g. im2col TMA load for C and im2col TMA store for D. 
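+/// Here "no offsets" concretely means zero lower/upper corners, zero padding, a zero SRT
+/// offset, and unit traversal and dilation strides; this overload simply forwards those
+/// defaults to the full make_tma_copy_im2col above.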
+template +CUTE_HOST_RTC +auto +make_tma_copy_im2col(CopyOp const& copy_op, + Tensor const& gtensor, + SLayout const& slayout, + Layout const& cta_t_map, // CTA tid -> logical TMA tid + Layout const& cta_v_map) // CTA vid -> gmem coord +{ + constexpr int num_spatial_modes = rank<0>(GLayout{}) - 1; + return make_tma_copy_im2col(copy_op, gtensor, slayout, cta_t_map, cta_v_map, + append(Stride<_0>{}, Int<0>{}), // lower_corner_whd + append(Stride<_0>{}, Int<0>{}), // upper_corner_whd + append(Stride<_0>{}, Int<0>{}), // lower_padding_whd + append(Stride<_0>{}, Int<0>{}), // upper_padding_whd + append(Stride<_1>{}, Int<1>{}), // stride_whd + append(Stride<_0>{}, Int<0>{}), // lower_srt + append(Stride<_1>{}, Int<1>{})); // stride_srt +} + +} // namespace detail + + + +template +CUTE_HOST_RTC +auto +make_im2col_tma_copy(CopyOp const& copy_op, + Tensor const& tensor_cwhdn, + SLayout const& slayout, + CTATiler const& cta_tiler, + MulticastSize const& multicast_size, + LowerCornerStride const& lower_corner_whd, + UpperCornerStride const& upper_corner_whd, + LowerPaddingStride const& lower_padding_whd, + UpperPaddingStride const& upper_padding_whd, + TraversalStride const& stride_whd, + LowerSRTStride const& lower_srt, + DilationStride const& stride_srt) +{ + auto cta_v_tile = make_identity_layout(product_each(shape(tensor_cwhdn))).compose(cta_tiler); + auto cta_t_tile = make_layout(multicast_size); + + return detail::make_tma_copy_im2col(copy_op, tensor_cwhdn, + slayout, cta_t_tile, cta_v_tile, + lower_corner_whd, upper_corner_whd, lower_padding_whd, upper_padding_whd, stride_whd, lower_srt, stride_srt); +} + +// Explicit default for multicast_size +template +CUTE_HOST_RTC +auto +make_im2col_tma_copy(CopyOp const& copy_op, + Tensor const& tensor_cwhdn, + SLayout const& slayout, + CTATiler const& cta_tiler, + LowerCornerStride const& lower_corner_whd, + UpperCornerStride const& upper_corner_whd, + LowerPaddingStride const& lower_padding_whd, + UpperPaddingStride const& upper_padding_whd, + TraversalStride const& stride_whd, + LowerSRTStride const& lower_srt, + DilationStride const& stride_srt) +{ + return make_im2col_tma_copy(copy_op, tensor_cwhdn, slayout, cta_tiler, Int<1>{}, + lower_corner_whd, upper_corner_whd, lower_padding_whd, upper_padding_whd, stride_whd, lower_srt, stride_srt); +} + +// Explicit default for cta_tiler and multicast_size +template +CUTE_HOST_RTC +auto +make_im2col_tma_copy(CopyOp const& copy_op, + Tensor const& tensor_cwhdn, + SLayout const& slayout, + LowerCornerStride const& lower_corner_whd, + UpperCornerStride const& upper_corner_whd, + LowerPaddingStride const& lower_padding_whd, + UpperPaddingStride const& upper_padding_whd, + TraversalStride const& stride_whd, + LowerSRTStride const& lower_srt, + DilationStride const& stride_srt) +{ + return make_im2col_tma_copy(copy_op, tensor_cwhdn, slayout, product_each(shape(slayout)), Int<1>{}, + lower_corner_whd, upper_corner_whd, lower_padding_whd, upper_padding_whd, stride_whd, lower_srt, stride_srt); +} + +// No offsets copy. 
+template +CUTE_HOST_RTC +auto +make_im2col_tma_copy(CopyOp const& copy_op, + Tensor const& tensor_cwhdn, + SLayout const& slayout, + CTATiler const& cta_tiler, + MulticastSize const& multicast_size) +{ + auto cta_v_tile = make_identity_layout(product_each(shape(tensor_cwhdn))).compose(cta_tiler); + auto cta_t_tile = make_layout(multicast_size); + + return detail::make_tma_copy_im2col(copy_op, tensor_cwhdn, slayout, cta_t_tile, cta_v_tile); +} + +// Explicit default for multicast_size +template +CUTE_HOST_RTC +auto +make_im2col_tma_copy(CopyOp const& copy_op, + Tensor const& tensor_cwhdn, + SLayout const& slayout, + CTATiler const& cta_tiler) +{ + return make_im2col_tma_copy(copy_op, tensor_cwhdn, slayout, cta_tiler, Int<1>{}); +} + +// Explicit default for cta_tiler and multicast_size +template +CUTE_HOST_RTC +auto +make_im2col_tma_copy(CopyOp const& copy_op, + Tensor const& tensor_cwhdn, + SLayout const& slayout) +{ + return make_im2col_tma_copy(copy_op, tensor_cwhdn, slayout, product_each(shape(slayout)), Int<1>{}); +} + +} // namespace cute diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/atom/copy_traits_sm90_tma.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/atom/copy_traits_sm90_tma.hpp new file mode 100644 index 0000000000000000000000000000000000000000..3be30af795241037a47cad0f41b5e46da3108382 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/atom/copy_traits_sm90_tma.hpp @@ -0,0 +1,1593 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#if !defined(__CUDACC_RTC__) +#include +#endif + +#include +#include +#include + +#include + +#include + +#include + +namespace cute +{ + +template +struct AuxTmaParams { + using GmemStrides = GmemTmaBasisStrides_; // Strides for Gmem mode -> Tma coord mode, may be dynamic + GmemStrides g_stride_; + using TmaGmemBasis = TmaGmemBasis_; // Layout for Tma box shape -> Gmem mode(s), always static + static_assert(is_static::value); + using TmaSwizzle = TmaSwizzle_; // Tma swizzle, always Swizzle + static_assert(is_static::value); +}; + +// Utility for unpacking TMA_LOAD arguments into a CopyOp +template +struct TMA_LOAD_Unpack +{ + template + CUTE_HOST_DEVICE friend constexpr void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) + { + static_assert(is_smem::value, "SM90_TMA_LOAD requires the destination be shared memory."); + + auto src_coord = src.data().coord_; + void* dst_ptr = cute::raw_pointer_cast(dst.data()); +#if 0 + auto [c0,c1,c2,c3,c4] = append<5>(src_coord, 0); + printf("THR (%d,%d,%d) BLK (%d,%d,%d) TMACRD (%d,%d,%d,%d,%d) SMEMADDR (%p)\n", + threadIdx.x, threadIdx.y, threadIdx.z, + blockIdx.x, blockIdx.y, blockIdx.z, + int32_t(c0), int32_t(c1), int32_t(c2), int32_t(c3), int32_t(c4), dst_ptr); +#endif + return detail::explode_tuple(detail::CallCOPY{}, + traits.opargs_, tuple_seq{}, + make_tuple(dst_ptr), seq<0>{}, + src_coord, tuple_seq{}); + } +}; + +////////////////////////////////////////////////////////////////////////////// +///////////////////////////// TMA_LOAD /////////////////////////////////////// +////////////////////////////////////////////////////////////////////////////// + +struct SM90_TMA_LOAD_OP : SM90_TMA_LOAD {}; + +// The non-executable SM90_TMA_LOAD with tma_desc and no tma_mbar +// Use .with(tma_mbar) to construct an executable version +template +struct Copy_Traits +{ + using ThrID = Layout<_1>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + // SM90_TMA_LOAD arguments + TmaDescriptor tma_desc_; + using AuxParams = AuxParams_; + AuxParams aux_params_; + + // Return TmaDescriptor/TensorMap + CUTE_HOST_DEVICE constexpr + TmaDescriptor const* + get_tma_descriptor() const { + return &tma_desc_; + } + + // Construct an executable SM90_TMA_LOAD with tma_mbar + CUTE_HOST_DEVICE constexpr + Copy_Traits + with( + uint64_t& tma_mbar, + [[maybe_unused]] uint16_t const& multicast_mask = 0, + TMA::CacheHintSm90 const& cache_hint = TMA::CacheHintSm90::EVICT_NORMAL) const { + // We accept multicast_mask here to keep the API for both atoms consistent + return {&tma_desc_, &tma_mbar, static_cast(cache_hint)}; + } + + // Construct an executable SM90_TMA_LOAD with tma_mbar (temp. 
overloaded for grouped gemm/ptr array gemm) + CUTE_HOST_DEVICE constexpr + Copy_Traits + with( + TmaDescriptor const* new_tma_desc, + uint64_t& tma_mbar, + [[maybe_unused]] uint16_t const& multicast_mask = 0, + TMA::CacheHintSm90 const& cache_hint = TMA::CacheHintSm90::EVICT_NORMAL) const { + // We accept multicast_mask here to keep the API for both atoms consistent + return {new_tma_desc, &tma_mbar, static_cast(cache_hint)}; + } + + // Generate the TMA coord tensor + template + CUTE_HOST_DEVICE constexpr + auto + get_tma_tensor(GShape const& g_shape) const { + static_assert(is_congruent::value); + return make_coord_tensor(make_layout(g_shape, aux_params_.g_stride_)); + } + + // Don't try to execute a copy with SM90_TMA_LOAD before calling .with() + template + CUTE_HOST_DEVICE friend constexpr void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) = delete; + + // Construct with updated TMA descriptor only (no barrier change) + CUTE_HOST_DEVICE constexpr + Copy_Traits + with(TmaDescriptor const* new_tma_desc) const { + return {*new_tma_desc, aux_params_}; + } +}; + +// The executable SM90_TMA_LOAD with tma_desc and tma_mbar +template +struct Copy_Traits + : TMA_LOAD_Unpack +{ + using ThrID = Layout<_1>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + // SM90_TMA_LOAD arguments + tuple< + TmaDescriptor const*, + uint64_t*, // smem mbarrier + uint64_t // cache hint + > const opargs_; + + CUTE_HOST_DEVICE + Copy_Traits(TmaDescriptor const* desc, uint64_t* mbar, uint64_t cache) + : opargs_(desc, mbar, cache) {} + + // Return TmaDescriptor/TensorMap + CUTE_HOST_DEVICE constexpr + TmaDescriptor const* + get_tma_descriptor() const { + return get<0>(opargs_); + } +}; + +// The prefetch for SM90_TMA_LOAD with tma_desc +template +struct Copy_Traits +{ + using ThrID = Layout<_1>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + // SM90_TMA_LOAD::PREFETCH arguments + tuple const opargs_; + + // Construct with any other Traits' TMA Desc + template + CUTE_HOST_DEVICE + Copy_Traits(OtherTraits const& traits) + : opargs_({traits.get_tma_descriptor()}) {} + + // Construct directly with a TMA descriptor pointer + CUTE_HOST_DEVICE + Copy_Traits(TmaDescriptor const* desc) + : opargs_({desc}) {} + + // Build a new Prefetch traits with a different TMA descriptor pointer + CUTE_HOST_DEVICE constexpr + Copy_Traits + with(TmaDescriptor const* new_tma_desc) const { + return {new_tma_desc}; + } + + template + CUTE_HOST_DEVICE friend constexpr void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) + { + auto src_coord = src.data().coord_; + return detail::explode_tuple(detail::CallCOPY{}, + traits.opargs_, tuple_seq{}, + src_coord, tuple_seq{}); + } +}; + +////////////////////////////////////////////////////////////////////////////// +///////////////////////////// TMA_LOAD_MULTICAST ///////////////////////////// +////////////////////////////////////////////////////////////////////////////// + +struct SM90_TMA_LOAD_MULTICAST_OP : SM90_TMA_LOAD_MULTICAST {}; + +// The non-executable SM90_TMA_LOAD_MULTICAST with tma_desc and no tma_mbar +// Use .with(tma_mbar, multicast_mask) to construct an executable version +template +struct Copy_Traits +{ + 
using ThrID = Layout<_1>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + // SM90_TMA_LOAD_MULTICAST arguments + TmaDescriptor tma_desc_; + using AuxParams = AuxParams_; + AuxParams aux_params_; + + // Return TmaDescriptor/TensorMap + CUTE_HOST_DEVICE constexpr + TmaDescriptor const* + get_tma_descriptor() const { + return &tma_desc_; + } + + // Construct an executable SM90_TMA_LOAD_MULTICAST with tma_mbar + CUTE_HOST_DEVICE constexpr + Copy_Traits + with( + uint64_t& tma_load_mbar, + uint16_t const& multicast_mask, + TMA::CacheHintSm90 const& cache_hint = TMA::CacheHintSm90::EVICT_NORMAL) const { + return {&tma_desc_, &tma_load_mbar, multicast_mask, static_cast(cache_hint)}; + } + + // Construct an executable SM90_TMA_LOAD_MULTICAST_OP with tma_mbar (temp. overloaded for grouped gemm/ptr array gemm) + CUTE_HOST_DEVICE constexpr + Copy_Traits + with( + TmaDescriptor const* new_tma_desc, + uint64_t& tma_load_mbar, + uint16_t const& multicast_mask, + TMA::CacheHintSm90 const& cache_hint = TMA::CacheHintSm90::EVICT_NORMAL) const { + return {new_tma_desc, &tma_load_mbar, multicast_mask, static_cast(cache_hint)}; + } + + // Generate the TMA coord tensor + template + CUTE_HOST_DEVICE constexpr + auto + get_tma_tensor(GShape const& g_shape) const { + static_assert(is_congruent::value); + return make_coord_tensor(make_layout(g_shape, aux_params_.g_stride_)); + } + + // Don't try to execute a copy with SM90_TMA_LOAD_MULTICAST before calling .with() + template + CUTE_HOST_DEVICE friend constexpr void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) = delete; +}; + +// The executable SM90_TMA_LOAD_MULTICAST with tma_desc and tma_mbar and multicast_mask +template +struct Copy_Traits + : TMA_LOAD_Unpack +{ + using ThrID = Layout<_1>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + // SM90_TMA_LOAD_MULTICAST arguments + tuple< + TmaDescriptor const*, + uint64_t*, // smem mbarrier + uint16_t, // multicast mask + uint64_t // cache hint + > const opargs_; + + CUTE_HOST_DEVICE + Copy_Traits(TmaDescriptor const* desc, uint64_t* mbar, uint16_t mask, uint64_t hint) + : opargs_(desc, mbar, mask, hint) {} + + // Return TmaDescriptor/TensorMap + CUTE_HOST_DEVICE constexpr + TmaDescriptor const* + get_tma_descriptor() const { + return get<0>(opargs_); + } +}; + +////////////////////////////////////////////////////////////////////////////// +///////////////////////////// TMA_STORE ////////////////////////////////////// +////////////////////////////////////////////////////////////////////////////// + +struct SM90_TMA_STORE_PTR : SM90_TMA_STORE {}; + +// The executable SM90_TMA_STORE with tma_desc +template +struct Copy_Traits +{ + using ThrID = Layout<_1>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + // SM90_TMA_STORE arguments + TmaDescriptor tma_desc_; + using AuxParams = AuxParams_; + AuxParams aux_params_; + + // Return TmaDescriptor/TensorMap + CUTE_HOST_DEVICE constexpr + TmaDescriptor const* + get_tma_descriptor() const { + return &tma_desc_; + } + + // Generate the TMA coord tensor + template + 
CUTE_HOST_DEVICE constexpr + auto + get_tma_tensor(GShape const& g_shape) const { + static_assert(is_congruent::value); + return make_coord_tensor(make_layout(g_shape, aux_params_.g_stride_)); + } + + // Construct new TMA_STORE with (unsafe) swapped out TMA descriptor ptr (for grouped gemm/ptr array gemm) + CUTE_HOST_DEVICE constexpr + Copy_Traits + with(TmaDescriptor const* new_tma_desc) const { + return {new_tma_desc}; + } + + template + CUTE_HOST_DEVICE friend constexpr void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) + { + static_assert(is_smem::value, "Expected smem src for SM90_TMA_STORE"); + //static_assert(is_gmem::value, "Expected gmem dst for SM90_TMA_STORE"); // TMA spoofed src tensor + + void const* const desc_ptr = &(traits.tma_desc_); + void const* const src_ptr = cute::raw_pointer_cast(src.data()); + auto dst_coord = dst.data().coord_; +#if 0 + auto [c0,c1,c2,c3,c4] = append<5>(dst_coord, 0); + printf("THR (%d,%d,%d) BLK (%d,%d,%d) TMACRD (%d,%d,%d,%d,%d) SMEMADDR (%p)\n", + threadIdx.x, threadIdx.y, threadIdx.z, + blockIdx.x, blockIdx.y, blockIdx.z, + int32_t(c0), int32_t(c1), int32_t(c2), int32_t(c3), int32_t(c4), src_ptr); +#endif + return detail::explode_tuple(detail::CallCOPY{}, + make_tuple(desc_ptr, src_ptr), seq<0,1>{}, + dst_coord, tuple_seq{}); + } +}; + +// Same as SM90_TMA_STORE, but with an unsafe TMA Desc PTR instead +template +struct Copy_Traits +{ + using ThrID = Layout<_1>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + // SM90_TMA_STORE arguments + TmaDescriptor const* tma_desc_; + + template + CUTE_HOST_DEVICE friend constexpr void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) + { + static_assert(is_smem::value, "Expected smem src for SM90_TMA_STORE"); + //static_assert(is_gmem::value, "Expected gmem dst for SM90_TMA_STORE"); // TMA spoofed src tensor + + void const* const desc_ptr = traits.tma_desc_; + void const* const src_ptr = cute::raw_pointer_cast(src.data()); + auto dst_coord = dst.data().coord_; +#if 0 + auto [c0,c1,c2,c3,c4] = append<5>(dst_coord, 0); + printf("THR (%d,%d,%d) BLK (%d,%d,%d) TMACRD (%d,%d,%d,%d,%d) SMEMADDR (%p)\n", + threadIdx.x, threadIdx.y, threadIdx.z, + blockIdx.x, blockIdx.y, blockIdx.z, + int32_t(c0), int32_t(c1), int32_t(c2), int32_t(c3), int32_t(c4), src_ptr); +#endif + return detail::explode_tuple(detail::CallCOPY{}, + make_tuple(desc_ptr, src_ptr), seq<0,1>{}, + dst_coord, tuple_seq{}); + } +}; + +////////////////////////////////////////////////////////////////////////////// +///////////////////////////// TMA_REDUCE_ADD ////////////////////////////////////// +////////////////////////////////////////////////////////////////////////////// + +// The executable SM90_TMA_REDUCE_ADD with tma_desc +template +struct Copy_Traits +{ + using ThrID = Layout<_1>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + // SM90_TMA_REDUCE_ADD arguments + TmaDescriptor tma_desc_; + using AuxParams = AuxParams_; + AuxParams aux_params_; + + // Return TmaDescriptor/TensorMap + CUTE_HOST_DEVICE constexpr + TmaDescriptor const* + get_tma_descriptor() const { + return &tma_desc_; + } + + // Generate the TMA coord tensor + template + CUTE_HOST_DEVICE 
constexpr + auto + get_tma_tensor(GShape const& g_shape) const { + static_assert(is_congruent::value); + return make_coord_tensor(make_layout(g_shape, aux_params_.g_stride_)); + } + + template + CUTE_HOST_DEVICE constexpr + void + copy_unpack_(void const* const src_ptr, + Coord const& dst_coord, seq) const + { +#if 0 + auto [c0,c1,c2,c3,c4] = append<5>(dst_coord, 0); + printf("THR (%d,%d,%d) BLK (%d,%d,%d) TMACRD (%d,%d,%d,%d,%d) SMEMADDR (%p)\n", + threadIdx.x, threadIdx.y, threadIdx.z, + blockIdx.x, blockIdx.y, blockIdx.z, + int32_t(c0), int32_t(c1), int32_t(c2), int32_t(c3), int32_t(c4), src_ptr); +#endif + + SM90_TMA_REDUCE_ADD::copy(&tma_desc_, + src_ptr, get(dst_coord)...); + } + + // This is the copy_unpack dispatch for this Copy_Traits + // Src needs to be a smem tensor + // Dst needs to be a gmem tensor with TmaCoordIterator .data() + template + CUTE_HOST_DEVICE friend constexpr + void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) + { + static_assert(is_smem::value, "Expected smem src for SM90_TMA_REDUCE_ADD"); + //static_assert(is_gmem::value, "Expected gmem dst for SM90_TMA_REDUCE_ADD"); // TMA spoofed src tensor + + traits.copy_unpack_(cute::raw_pointer_cast(src.data()), dst.data().coord_, tuple_seq{}); + } +}; + +////////////////////////////////////////////////////////////////////////////// +///////////////////////////// BULK COPY ////////////////////////////////////// +////////////////////////////////////////////////////////////////////////////// + +template +struct Copy_Traits +{ + static_assert(int32_t(NumBitsPerTMA::value / 8) % 16 == 0, + "Bulk Copy requires copy vector size align to 16B."); + + using ThrID = Layout<_1>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + // SM90_BULK_COPY_G2S arguments + // 0: uint64_t* bulk_load_memory_barrier + cute::tuple bulk_load_mbar_; + + // Record the memory barrier for the instruction + CUTE_HOST_DEVICE constexpr + Copy_Traits + with(uint64_t& bulk_mbar) const { + return {&bulk_mbar}; + } + + template + CUTE_HOST_DEVICE friend constexpr + void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) + { + static_assert(is_same, cute::tuple>::value, + "Extra arguments not set. 
Set .with() before use."); + static_assert(is_gmem::value, "Expected gmem src for SM90_BULK_COPY_G2S"); + static_assert(is_smem::value, "Expected smem dst for SM90_BULK_COPY_G2S"); + SM90_BULK_COPY_G2S::copy(raw_pointer_cast(src.data()), get<0>(traits.bulk_load_mbar_), + raw_pointer_cast(dst.data()), int32_t(NumBitsPerTMA::value / 8)); + } +}; + +template +struct Copy_Traits + : Copy_Traits +{ + template + CUTE_HOST_DEVICE + Copy_Traits(Copy_Traits const& traits) {} + + template + CUTE_HOST_DEVICE friend constexpr + void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) + { + static_assert(is_gmem::value, "Expected gmem src for SM90_BULK_PREFETCH"); + SM90_BULK_COPY_G2S::PREFETCH::copy(raw_pointer_cast(src.data()), int32_t(NumBitsPerTMA::value / 8)); + } +}; + +template +struct Copy_Traits +{ + static_assert(int32_t(NumBitsPerTMA::value / 8) % 16 == 0, + "Bulk Copy requires copy vector size align to 16B."); + + using ThrID = Layout<_1>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + template + CUTE_HOST_DEVICE friend constexpr + void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) + { + static_assert(is_smem::value, "Expected smem src for SM90_BULK_COPY_S2G"); + static_assert(is_gmem::value, "Expected gmem dst for SM90_BULK_COPY_S2G"); + SM90_BULK_COPY_S2G::copy(raw_pointer_cast(src.data()), raw_pointer_cast(dst.data()), int32_t(NumBitsPerTMA::value / 8)); + } +}; + +// +// Placeholder for the bulk copy algorithm's default, auto-vectorizing behavior +// + +template +struct Copy_Traits +{ + // Logical thread id to thread idx (one-thread) + using ThrID = Layout<_1>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, Stride<_0,_0>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout, Stride<_0,_0>>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + // SM90_UBULK_COPY arguments + // 0: uint64_t* bulk_load_memory_barrier [if this is a BULK_LOAD_G2S] + cute::tuple opargs_; + + // Record the memory barrier for the instruction + CUTE_HOST_DEVICE constexpr + Copy_Traits + with(uint64_t& bulk_mbar) const { + return {&bulk_mbar}; + } +}; + +// +// MAKE_TMA_COPY and related +// + +namespace detail { + +// Custom version of coalesce that greedily combines modes only up to size-256 +// Look at each element and the back of the stack (in order of priority) +// back(NewLayout) get(OldLayout) +// s0:d0 _1:d1 => continue +// _1:d0 s1:d1 => replace_back s1:d1 +// s0:d0 s1:s0*d0 => replace_back s0*s1:d0 if s0*s1 <= 256 +// s0:d0 s1:d1 => append s1:d1 +// +// @pre OldShape and OldStride are flat +template +CUTE_HOST_DEVICE constexpr +auto +coalesce_256_impl(OldShape const& old_shape, OldStride const& old_stride, + NewShape const& new_shape, NewStride const& new_stride) +{ + if constexpr (I == rank_v) { + // Base case, we're done + if constexpr (is_constant<1, NewShape>::value) { + return Layout<_1,_0>{}; + } else { + return Layout{new_shape,new_stride}; + } + } else if constexpr (is_constant<1, decltype(get(old_shape))>::value) { + // shape(layout) == _1, skip it and continue + return coalesce_256_impl(old_shape, old_stride, new_shape, new_stride); + } else if constexpr (is_constant<1, NewShape>::value) { + // Replace our shape-1 with anything (Can only happen on input new_shape/new_stride) + return coalesce_256_impl(old_shape, 
old_stride, get(old_shape), get(old_stride)); + } else if constexpr (is_constant(old_stride) && + get(old_shape) * back(new_shape) <= Int<256>{})>::value) { + // Merge modes because the shapes and strides match and the merge is 256 or less + return coalesce_256_impl(old_shape, old_stride, + replace_back(new_shape, get(old_shape) * back(new_shape)), + new_stride); + } else { + // Can't replace or merge, so append a new mode + return coalesce_256_impl(old_shape, old_stride, + append(new_shape, get(old_shape)), + append(new_stride, get(old_stride))); + } + + CUTE_GCC_UNREACHABLE; +} + +// Combine all the modes that are possible to combine +// Does not respect the profile of the layout, but does preserve total size +template +CUTE_HOST_DEVICE constexpr +auto +coalesce_256(Layout const& layout) +{ + auto flat_shape = flatten(layout.shape()); + auto flat_stride = flatten(layout.stride()); + return coalesce_256_impl<1>(flat_shape, flat_stride, get<0>(flat_shape), get<0>(flat_stride)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +construct_tma_gbasis(Tensor const& gtensor, // The original GMEM Tensor + Layout const& slayout, // The layout of SMEM + Layout const& cta_v_map) // smem_idx to hier gmode +{ + // + // TMA parameter checking + // + + // CUTE_STATIC_ASSERT_V(product_each(shape(slayout)) == product_each(shape(cta_v_map)), + // "TMA requires CTA_Tile and SLayout top-level shape equivalence."); + CUTE_STATIC_ASSERT_V(size(slayout) == size(cta_v_map), + "TMA requires CTA_Tile and SLayout top-level size equivalence."); + +#if 0 + print("gtensor : "); print(gtensor); print("\n"); + print("slayout : "); print(slayout); print("\n"); + print("cta_v_map : "); print(cta_v_map); print("\n"); +#endif + + // + // TMA slayout manipulation + // + + // Invert the smem to get the largest contiguous vector in the smem layout + // smem idx -> smem coord + auto inv_smem_layout = right_inverse(get_nonswizzle_portion(slayout)); + + // Compose with the V-Map to convert smem coord (CTA val idx) to gmem mode + // smem idx -> gmem mode + auto sidx2gmode_full = coalesce(composition(cta_v_map, inv_smem_layout)); + +#if 0 + print("inv_smem_layout : "); print(inv_smem_layout); print("\n"); + print("sidx2gmode_full : "); print(sidx2gmode_full); print("\n"); +#endif + + // + // TMA gtensor truncation + // + + // Truncate any incompatibilities -- no starting in the middle of gmodes + auto smem_rank = find_if(stride(sidx2gmode_full), [](auto e) { + [[maybe_unused]] auto v = basis_value(e); + return not is_constant<1,decltype(v)>{}; + }); + static_assert(smem_rank > 0, "Could not find a common tile-gmem vectorization. 
Does the Tile select out major GMEM modes?"); + + // Keep only the static-1 basis modes into gmem + auto sidx2gmode = take<0,smem_rank>(sidx2gmode_full); + +#if 0 + print("smem_rank : "); print(smem_rank); print("\n"); + print("sidx2gmode : "); print(sidx2gmode); print("\n"); +#endif + + // + // TMA gtensor manipulation + // + + // The smem vector is the same units as gtensor, so compose first and then recast + // tma_val_idx:gmem_strides + auto tile_gstride = recast(gtensor.compose(sidx2gmode)).layout(); + // Coalesce modes up to size-256 (the maximum TMA box extent in units of TmaInternalType) + // tma_box_shape:gmem_strides + auto tma_gstride = coalesce_256(tile_gstride); + + // Perform the tiling, recast, and coalesce to the gmem vector again, but with indirections to the gtensor modes + auto gbasis = make_identity_layout(shape(gtensor)); + auto tile_gbasis_tmp = gbasis.compose(sidx2gmode); + + // Instead of the recast (gbasis doesn't have type info), replace the shape with the already-recasted shape + // tma_box_shape:gmem_mode + auto tile_gbasis = make_layout(shape(tile_gstride), stride(tile_gbasis_tmp)); + + // "Coalesce" the tile basis into a compatible shape with the tma_gstride + auto tma_gbasis_tile = tile_gbasis.compose(make_layout(wrap(shape(tma_gstride)))); + + // Recast the original tensor for shape/stride inspections + Tensor gtensor_T = recast(gtensor); + + // Find missing bases that don't appear in tile_gbasis + auto tile_gbasis_remaining_stride = filter_tuple(flatten(shape (gtensor_T)), flatten(stride(gtensor_T)), + flatten(stride(gbasis)), + [&](auto s, auto d, auto e) + { + if constexpr (is_constant<1, decltype(s)>::value || is_constant<0, decltype(d)>::value) { + return cute::tuple<>{}; // If size-1 or stride-0, then don't append + } else { + using E = decltype(e); + auto has_e = any_of(flatten(stride(tma_gbasis_tile)), [] (auto tb) { return tb == E{}; }); + if constexpr (decltype(has_e)::value) { + return cute::tuple<>{}; // If d was found, then don't append + } else { + return cute::tuple(e); // Else, this is missing so append + } + } + }); + + // Append the remaining basis modes that contribute to the TMA with size-1 + auto tile_gbasis_remaining_shape = repeat(Int<1>{}); + auto tma_gbasis_full = make_layout(tuple_cat(wrap( shape(tma_gbasis_tile)), wrap(tile_gbasis_remaining_shape )), + tuple_cat(wrap(stride(tma_gbasis_tile)), wrap(tile_gbasis_remaining_stride))); + + // Group the trailing modes to make this max rank-5 -- TMA rank limitation + // tma_box_shape:gmem_mode + auto tma_gbasis = group(tma_gbasis_full); + +#if 0 + print("tile_gstride : "); print(tile_gstride); print("\n"); + print("tma_gstride : "); print(tma_gstride); print("\n"); + print("gbasis : "); print(gbasis); print("\n"); + print("tile_gbasis : "); print(tma_gbasis_tile); print("\n"); + print("tma_gbasis : "); print(tma_gbasis); print("\n"); +#endif + + return tma_gbasis; +} + +template +CUTE_HOST_DEVICE constexpr +void +fill_tma_gmem_shape_stride(Tensor const& gtensor, // Gmem Shapes and Strides, in units of TmaInternalType + TmaGmemBasisStride const& tma_gbasis_stride, // Map Tma mode idx -> Gmem mode(s) + cute::array & gmem_prob_shape, // Tma Shapes, uint32_t or uin64_t + cute::array & gmem_prob_stride) // Tma Strides +{ + static_assert(is_tuple::value); + static_assert(is_same::value || is_same::value); + + using TmaInternalType = typename GEngine::value_type; + constexpr int tma_rank = decltype(rank(tma_gbasis_stride))::value; + static_assert(TmaRank >= tma_rank); + + auto gmem_shape = 
shape(gtensor); + auto gmem_stride = stride(gtensor); + // Use the indirections in tma_gbasis_stride into gtensor to construct the tma gmem shapes/strides + for_each(make_seq{}, [&](auto i) { + constexpr int tma_i_rank = decltype(rank(tma_gbasis_stride))::value; + if constexpr (tma_i_rank == 1) { + // Trivial contribution of this gmem mode to this tma mode + auto ej = unwrap(get(tma_gbasis_stride)); + gmem_prob_shape[i] = basis_get(ej, gmem_shape); + gmem_prob_stride[i] = basis_get(ej, gmem_stride); + } else { + // Apply a recurrence to each gmem mode that contributes to this tma mode + for_each(get(tma_gbasis_stride), [&](auto ej) { + // Problem shape + uint64_t shape_j = basis_get(ej, gmem_shape); + // Problem stride (in bytes) + uint64_t stride_j = basis_get(ej, gmem_stride); + uint64_t old_stride = gmem_prob_stride[i]; + gmem_prob_stride[i] = gcd(gmem_prob_stride[i], stride_j); + + if (gmem_prob_stride[i] != 0) { + // Recurrence: g_shape = (s_i - 1) * (d_i / gcd_j d_j) + 1 + gmem_prob_shape[i] = (gmem_prob_shape[i]-1) * (old_stride / gmem_prob_stride[i]) + + (shape_j-1) * (stride_j / gmem_prob_stride[i]) + + 1; + } else { + gmem_prob_shape[i] = shape_j; + } + }); + } + }); +} + +// Overload for an existing Copy_Traits +template +CUTE_HOST_DEVICE constexpr +void +fill_tma_gmem_shape_stride(Copy_Traits const& tma_traits, + Tensor const& gtensor, // Gmem Shapes and Strides, value_type = TmaInternalType + cute::array & gmem_prob_shape, // Tma Shapes, uint32_t or uin64_t + cute::array & gmem_prob_stride) // Tma Strides +{ + return fill_tma_gmem_shape_stride(gtensor, stride(typename Aux::TmaGmemBasis{}), + gmem_prob_shape, gmem_prob_stride); +} + +// Use a sidx2gmode to read through the GMEM tensor +// and construct a TMA Descriptor for the resulting instruction +// At the same time, construct the Tma Tensor's Stride to generate +// the TMA coordinates that the instruction consumes. 
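+// NOTE (illustrative, not from the upstream CUTLASS source): a worked example of
+// the gcd-merge recurrence used in fill_tma_gmem_shape_stride above. With two
+// hypothetical gmem modes (shape 4, stride 1) and (shape 8, stride 64) feeding one
+// TMA mode, the running stride becomes gcd(1,64) = 1 and the merged extent is
+//   (4-1)*(1/1) + (8-1)*(64/1) + 1 = 452
+// elements of stride 1, i.e. a single shape/stride pair spanning both modes.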
+// +template +CUTE_HOST_RTC +auto +make_tma_copy_desc(Tensor const& gtensor, // The original GMEM Tensor + Layout const& tma_gbasis, // TMA mode -> GMEM mode mapping + Swizzle const& swizzle, // Swizzle fn on smem_idx + uint32_t num_multicast) // The number of CTAs in multicasting +{ + // + // TMA desc creation + // + + constexpr int tma_dim = decltype(rank(tma_gbasis))::value; + + // + // TMA gmem desc info + // + + // Recast the original tensor for shape/stride inspections + Tensor gtensor_T = recast(gtensor); + + void* gmem_address = (void*) raw_pointer_cast(gtensor_T.data()); + auto gmem_layout = gtensor_T.layout(); + + cute::array gmem_prob_shape = {1,1,1,1,1}; + cute::array gmem_prob_stride = {0,0,0,0,0}; + + fill_tma_gmem_shape_stride(gtensor_T, stride(tma_gbasis), gmem_prob_shape, gmem_prob_stride); + + assert((reinterpret_cast(gmem_address) & 0b1111) == 0); // Address must be 16B-aligned + + assert(gmem_prob_shape[0] >= (uint64_t(1))); // Size must be min 1 + assert(gmem_prob_shape[0] <= (uint64_t(1) << 32)); // Size must be max 2^32 + assert(gmem_prob_shape[1] >= (uint64_t(1))); // Size must be min 1 + assert(gmem_prob_shape[1] <= (uint64_t(1) << 32)); // Size must be max 2^32 + assert(gmem_prob_shape[2] >= (uint64_t(1))); // Size must be min 1 + assert(gmem_prob_shape[2] <= (uint64_t(1) << 32)); // Size must be max 2^32 + assert(gmem_prob_shape[3] >= (uint64_t(1))); // Size must be min 1 + assert(gmem_prob_shape[3] <= (uint64_t(1) << 32)); // Size must be max 2^32 + assert(gmem_prob_shape[4] >= (uint64_t(1))); // Size must be min 1 + assert(gmem_prob_shape[4] <= (uint64_t(1) << 32)); // Size must be max 2^32 + + // TMA descriptor does not store the zeroth stride and assumes it is 1 (TmaInternalType element). + assert(gmem_prob_stride[0] == 1 && "Majorness of smem doesn't match majorness of gmem"); + + // convert strides to byte strides + for(uint64_t& stride : gmem_prob_stride) { + stride = (stride * sizeof_bits_v) / 8; + } + + // Assert the byte strides. 
Tma Descriptor uses byte strides + assert((gmem_prob_stride[1]) < (uint64_t(1) << 40)); // Stride must be max 2^40 + assert((gmem_prob_stride[1] & 0b1111) == 0); // Stride must be multiple of 16B (128b) + assert((gmem_prob_stride[2]) < (uint64_t(1) << 40)); // Stride must be max 2^40 + assert((gmem_prob_stride[2] & 0b1111) == 0); // Stride must be multiple of 16B (128b) + assert((gmem_prob_stride[3]) < (uint64_t(1) << 40)); // Stride must be max 2^40 + assert((gmem_prob_stride[3] & 0b1111) == 0); // Stride must be multiple of 16B (128b) + assert((gmem_prob_stride[4]) < (uint64_t(1) << 40)); // Stride must be max 2^40 + assert((gmem_prob_stride[4] & 0b1111) == 0); // Stride must be multiple of 16B (128b) + + // + // TMA smem desc info + // + + cute::array smem_box_shape = {1,1,1,1,1}; + cute::array smem_box_stride = {1,1,1,1,1}; + // The smem box is simply given by the sizes of the modes in tma_gbasis + for_each(make_seq{}, [&](auto i) { + smem_box_shape[i] *= size(tma_gbasis); + }); + // Finally, truncate the tma box by the num_multicast + for (uint32_t i = tma_dim-1, multicast = num_multicast; multicast > 1; --i) { + assert(smem_box_shape[i] % multicast == 0 || multicast % smem_box_shape[i] == 0); + uint32_t new_mult = ceil_div(multicast, smem_box_shape[i]); + smem_box_shape[i] = ceil_div(smem_box_shape[i], multicast); + multicast = new_mult; + } + + assert(smem_box_shape[0] >= (uint32_t(1))); // Size must be min 1 + assert(smem_box_shape[0] <= (uint32_t(1) << 8)); // Size must be max 2^8 = 256 + assert(smem_box_shape[1] >= (uint32_t(1))); // Size must be min 1 + assert(smem_box_shape[1] <= (uint32_t(1) << 8)); // Size must be max 2^8 = 256 + assert(smem_box_shape[2] >= (uint32_t(1))); // Size must be min 1 + assert(smem_box_shape[2] <= (uint32_t(1) << 8)); // Size must be max 2^8 = 256 + assert(smem_box_shape[3] >= (uint32_t(1))); // Size must be min 1 + assert(smem_box_shape[3] <= (uint32_t(1) << 8)); // Size must be max 2^8 = 256 + assert(smem_box_shape[4] >= (uint32_t(1))); // Size must be min 1 + assert(smem_box_shape[4] <= (uint32_t(1) << 8)); // Size must be max 2^8 = 256 + + assert(smem_box_stride[0] >= (uint32_t(1))); // Stride must be min 1 + assert(smem_box_stride[0] <= (uint32_t(8))); // Stride must be max 2^3 = 8 + assert(smem_box_stride[1] >= (uint32_t(1))); // Stride must be min 1 + assert(smem_box_stride[1] <= (uint32_t(8))); // Stride must be max 2^3 = 8 + assert(smem_box_stride[2] >= (uint32_t(1))); // Stride must be min 1 + assert(smem_box_stride[2] <= (uint32_t(8))); // Stride must be max 2^3 = 8 + assert(smem_box_stride[3] >= (uint32_t(1))); // Stride must be min 1 + assert(smem_box_stride[3] <= (uint32_t(8))); // Stride must be max 2^3 = 8 + assert(smem_box_stride[4] >= (uint32_t(1))); // Stride must be min 1 + assert(smem_box_stride[4] <= (uint32_t(8))); // Stride must be max 2^3 = 8 + + // + // Construct the descriptor + // + + TmaDescriptor tma_desc{}; + + // + // TMA general info + // + + #if (__CUDACC_VER_MAJOR__ >= 12) && !defined(__CUDACC_RTC__) + + CUtensorMapDataType tma_format = TMA::to_CUtensorMapDataType(); + CUtensorMapInterleave tma_interleave = CU_TENSOR_MAP_INTERLEAVE_NONE; + CUtensorMapL2promotion tma_l2Promotion = CU_TENSOR_MAP_L2_PROMOTION_L2_128B; + CUtensorMapFloatOOBfill tma_oobFill = CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE; + + // TMA smem swizzle type + TMA::SmemSwizzleBits swizzle_bits = get_tma_swizzle_bits(swizzle); + TMA::SmemSwizzleBase swizzle_base = get_tma_swizzle_base(swizzle); + CUtensorMapSwizzle smem_swizzle = 
TMA::to_CUtensorMapSwizzle(swizzle_bits, swizzle_base); + CUresult result = CUTLASS_CUDA_DRIVER_WRAPPER_CALL(cuTensorMapEncodeTiled)( + &tma_desc, + tma_format, + tma_dim, + gmem_address, + gmem_prob_shape.data(), + gmem_prob_stride.data() + 1, // gmem_prob_stride[0] implicitly 1 + smem_box_shape.data(), + smem_box_stride.data(), + tma_interleave, + smem_swizzle, + tma_l2Promotion, + tma_oobFill); + + if (result != CUDA_SUCCESS) { + std::cerr << "TMA Desc Addr: " << &tma_desc + << "\nformat " << tma_format + << "\ndim " << tma_dim + << "\ngmem_address " << gmem_address + << "\nglobalDim " << gmem_prob_shape + << "\nglobalStrides " << gmem_prob_stride + << "\nboxDim " << smem_box_shape + << "\nelementStrides " << smem_box_stride + << "\ninterleave " << tma_interleave + << "\nswizzle " << smem_swizzle + << "\nl2Promotion " << tma_l2Promotion + << "\noobFill " << tma_oobFill << std::endl; + std::cerr << "Error: Failed to initialize the TMA descriptor " << result << std::endl; + assert(false); + } + + #endif // (__CUDACC_VER_MAJOR__ >= 12) && !defined(__CUDACC_RTC__) + auto recast_ratio = cute::trait_ratio(sizeof_bits{}, + sizeof_bits< TmaInternalType>{}); + + auto gbasis = make_basis_like(shape(gtensor)); + + // Finally, get the inverse permutation of the E bases for the mocked gmem stride + auto gmem_tma_basis_stride = transform_leaf(gbasis, [&](auto ei) { + auto si = basis_get(ei, shape(gmem_layout)); + auto di = basis_get(ei, stride(gmem_layout)); + if constexpr (is_constant<1, decltype(si)>::value || is_constant<0, decltype(di)>::value) { + return Int<0>{}; // If size-1 or stride-0, return arithmetic identity -- no contribution to the TMA + } else { + auto tma_gmem_basis_stride = stride(tma_gbasis); + // Find j such that E is in stride(tma_gbasis) + using EI = decltype(ei); + [[maybe_unused]] auto j = find_if(tma_gmem_basis_stride, [&](auto tma_stride_j) { return any_of(tma_stride_j, [&](auto dj) { return dj == EI{}; }); }); + if constexpr (decltype(j == rank(tma_gmem_basis_stride))::value) { + return Int<0>{}; // If not-found, return arithmetic identity -- no contribution to the TMA + } else + if constexpr (decltype(j == Int<0>{})::value) { + auto scale = recast_ratio * basis_get(ei, stride(gtensor)); + return E{} * scale; // Return TMA Coord basis -- with a recast scale factor + } else + if constexpr (decltype(rank(tma_gmem_basis_stride) == Int<1>{})::value) { + return E{}; // Return TMA Coord basis -- known scale of Int<1>{} + } else { + int32_t scale = ceil_div(int32_t(di * sizeof_bits_v / cute::max(gmem_prob_stride[j], uint64_t{16})), 8); + return E{} * scale; // Return TMA Coord basis -- with a dynamic scale factor + } + } + }); + +#if 0 + print("gmem_tma_basis_stride : "); print(gmem_tma_basis_stride); print("\n"); +#endif + + using AuxParams = AuxTmaParams; + return cute::make_tuple(tma_desc, AuxParams{gmem_tma_basis_stride}); +} + +template +CUTE_HOST_RTC +auto +make_tma_copy_atom(CopyOp, + Tensor const& gtensor, // Full GMEM Tensor + SLayout const& slayout, // CTA Tile of SMEM, potentially swizzled + uint32_t const& num_multicast, // The number of CTAs involved in multicasting + Layout const& cta_v_map) // V: CTA val idx -> gmem mode +{ + // + // TMA truncated layout + // + + auto smem_swizzle = get_swizzle_portion(slayout); + auto smem_layout = get_nonswizzle_portion(slayout); + + auto tma_gbasis = detail::construct_tma_gbasis(gtensor, smem_layout, cta_v_map); + + // + // Construct the TMA Desc and the strides of the TMA Tensor + // + + auto [tma_desc, aux_params] = 
detail::make_tma_copy_desc(gtensor, + tma_gbasis, + smem_swizzle, + num_multicast); + + // + // Construct the Copy_Traits + // + + constexpr int num_bits_per_tma = size(tma_gbasis) * sizeof_bits_v; + using Traits = Copy_Traits, decltype(aux_params)>; + using Atom = Copy_Atom; + + Traits tma_traits{tma_desc, aux_params}; + +#if 0 + print("num_bits_per_tma : "); print(num_bits_per_tma); print("\n"); + print("g_stride_bases : "); print(tma_traits.aux_params_.g_stride_); print("\n"); +#endif + + // Return the Copy_Atom + return Atom{tma_traits}; +} + +// The "logical TMA tid" is a map from the CTA rank to its logical id +// within the instruction. It works like a mask or ordering on the +// CTAs. For non-multicast TMA, all CTAs should map to 0. For +// multicast TMA of size 4, CTAs will be mapped to {0,1,2,3}. +template +CUTE_HOST_RTC +auto +make_tma_copy_tiled(CopyOp const& copy_op, + Tensor const& gtensor, // Full GMEM Tensor + SLayout const& slayout, // CTA Tile of SMEM + Layout const& cta_t_map, // T: CTA thr idx -> logical TMA tid + Layout const& cta_v_map) // V: CTA val idx -> gmem mode +{ + Copy_Atom atom = make_tma_copy_atom(copy_op, gtensor, slayout, + cosize(cta_t_map), cta_v_map); + + // + // Construct the TiledCopy + // + + [[maybe_unused]] auto cta_tiler = product_each(shape(cta_v_map)); + + auto num_elems_per_tma = size<1>(typename decltype(atom)::RefLayout{}) / static_value>(); + + // smem idx -> smem coord + auto inv_smem_layout = right_inverse(get_nonswizzle_portion(slayout)); + // CTA V -> smem_coord + auto layout_v = composition(inv_smem_layout, num_elems_per_tma); + // Scale that up to cover all of the smem_coords + auto layout_V = tile_to_shape(make_layout(layout_v), size(cta_v_map)); + // CTA T -> smem idx + auto layout_t = make_layout(cosize(cta_t_map), safe_div(num_elems_per_tma, cosize(cta_t_map))); + // CTA TID -> smem coord + auto layout_T = composition(inv_smem_layout, composition(layout_t, cta_t_map)); + // Combine with the T mapping + [[maybe_unused]] auto layout_TV = make_layout(layout_T, layout_V); + +#if 0 + print("cta_tiler : "); print(cta_tiler); print("\n"); + print("layout_v : "); print(layout_v); print("\n"); + print("layout_V : "); print(layout_V); print("\n"); + print("layout_t : "); print(layout_t); print("\n"); + print("layout_T : "); print(layout_T); print("\n"); + print("layout_TV : "); print(layout_TV); print("\n"); +#endif + + return TiledCopy{atom}; +} + +} // end namespace detail + +/** Make a CuTe CTA-collective TiledCopy for a TMA operation. + * + * @param CopyOp The target copy operation: SM90_TMA_LOAD, SM90_TMA_LOAD_MULTICAST, SM90_TMA_STORE + * @param gtensor The GMEM Tensor to be involved in the TMA. + * @param slayout The SMEM Layout to be involved in the TMA. + * @param cta_tile The CTA-local tile that each CTA will be tiling GMEM with. + * This is often the blk_shape that is used to tile the GMEM for CTAs: + * local_tile(gtensor, blk_shape, blk_coord) -> CTA-local tile of gtensor + * @param cluster_size When using SM90_TMA_LOAD_MULTICAST, this can be a (static) power-of-2 <= 16 + * defining the multicast size (used to further partition the SMEM) + * Else, static-1 + * + * This code attempts to maximize the TMA box size. It does this by tracing + * the SMEM "vector" -- the inverse of the smem layout -- to find the largest + * contiguous array of smem that can be written to/from global memory given + * the constraints that the TMA instruction imposes. 
+ * + * This is accomplished by assigning "basis" strides to the GMEM to track which + * modes of SMEM map to which modes of GMEM, then reorder the modes of GMEM according + * to the SMEM vector, and then using those GMEM/SMEM modes to fill in the desc. + * + * Examples: + using T = float; + T* gptr = nullptr; + + { + // Simple 2D + Tensor gtensor = make_tensor(gptr, make_shape(1024, 256), GenRowMajor{}); // K-Major GMEM + auto slayout = make_layout(make_shape(_64{}, _32{}), GenRowMajor{}); // K-Major SMEM + auto tma = make_tma_copy(SM90_TMA_LOAD{}, gtensor, slayout); + } + + { + // GMMA 2D + Tensor gtensor = make_tensor(gptr, make_shape(1024, 256)); // MN-Major GMEM + auto slayout = tile_to_shape(GMMA::Layout_MN_SW128_Atom{}, make_shape(_128{},_64{})); // MN-Major Swizzled+Tiled 128x64 SMEM + auto tma = make_tma_copy(SM90_TMA_LOAD{}, gtensor, slayout); + } + + { + // 3D + Tensor gtensor = make_tensor(gptr, make_shape(1024, 32, 512), make_stride(64, Int<1>{}, 65536)); // GMEM + auto slayout = make_layout(make_shape(_16{}, _8{}, _2{}), make_stride(_16{}, _1{}, _8{})); // SMEM w/ same major-mode + auto tma = make_tma_copy(SM90_TMA_LOAD{}, gtensor, slayout); + } + + { + // cuTENSOR 4D + auto layout = make_shape(make_shape(32,40),make_shape(make_shape(8,8),656)); // GMEM + auto cta_tile = make_shape(_128{},make_shape(_32{},_2{})); // GMEM Tiling: + // Take 128-elem from m: m0 must divide 128, + // m-last may be predicated + // Take 32-elem from k0, 2-elem from k1 + auto slayout = make_layout(cta_tile); // Col-Major SMEM + auto tma = make_tma_copy(SM90_TMA_LOAD{}, gtensor, slayout, cta_tile, Int<1>{}); + } + * + * Check the TMA box size and desc: + print("TMA Box size: "); print(typename decltype(tma)::Tiler_MN{}); print("\n"); + print("TMA desc : "); print(tma.tma_desc_); print("\n"); + * + * Usage: + Tensor mA = tma_a.get_tma_tensor(make_shape(M,N)); // (M,N) TMA coord tensor + Tensor gA = local_tile(mA, cta_tile, cta_coord); // (BLK_M,BLK_N) TMA coord tensor for this CTA + Tensor sA = make_tensor(make_smem_ptr(sptr), slayout); // (BLK_M,BLK_N) SMEM tensor + + auto cta_tma = tma.get_slice(cta_idx_in_cluster); // Slice for multicast partitioning + Tensor tAgA = cta_tma.partition_S(gA); // Partition for src + Tensor tAsA = cta_tma.partition_D(sA); // Partition for dst + + copy(tma.with(barrier, mcast_mask), tAgA, tAsA); // copy with supporting TMA params + */ +template +CUTE_HOST_RTC +auto +make_tma_copy(CopyOp const& copy_op, + Tensor const& gtensor, + SLayout const& slayout, + CTA_Tiler const& cta_tiler, + Cluster_Size const& cluster_size) +{ + if constexpr (cute::is_same_v || + cute::is_same_v) { + return make_im2col_tma_copy(copy_op, + gtensor, + slayout, + cta_tiler, + cluster_size); + } else { + auto cta_v_tile = make_identity_layout(shape(gtensor)).compose(cta_tiler); + auto cta_t_tile = make_layout(cluster_size); + // Prefer TmaInternalType if specified. 
Fallback to GEngine::value_type + using TmaType = conditional_t::value, typename GEngine::value_type, TmaInternalType>; + return detail::make_tma_copy_tiled(copy_op, + gtensor, slayout, + cta_t_tile, cta_v_tile); + } +} + +// Explicit defaulting +template +CUTE_HOST_RTC +auto +make_tma_copy(CopyOp const& copy_op, + Tensor const& gtensor, + SLayout const& slayout) +{ + return make_tma_copy(copy_op, gtensor, slayout, product_each(shape(slayout)), Int<1>{}); +} + +// Explicit defaulting +template +CUTE_HOST_RTC +auto +make_tma_copy(CopyOp const& copy_op, + Tensor const& gtensor, + SLayout const& slayout, + Cluster_Size const& cluster_size) +{ + return make_tma_copy(copy_op, gtensor, slayout, product_each(shape(slayout)), cluster_size); +} + +//////////////////////////////////// +// Experimental Make TMA Atom and Partitioner +/////////////////////////////////// + +template > +CUTE_HOST_RTC +auto +make_tma_atom(CopyOp const& copy_op, + Tensor const& gtensor, + SLayout const& slayout, + CTA_Tiler const& cta_tiler, + Cluster_Size const& cluster_size = {}) +{ + auto cta_v_tile = make_identity_layout(shape(gtensor)).compose(cta_tiler); + // Prefer TmaInternalType if specified. Fallback to GEngine::value_type + using TmaType = conditional_t::value, typename GEngine::value_type, TmaInternalType>; + return detail::make_tma_copy_atom(copy_op, + gtensor, slayout, + size(cluster_size), cta_v_tile); +} + +// The "VectorCopy Partitioner" for TMA +template +CUTE_DEVICE +auto +tma_partition(Copy_Atom const& copy_atom, + CtaCoord const& cta_coord, + Layout const& cta_layout, // T: CTA coord -> logical multicast id + Tensor const& stensor, // SMEM Tensor (TMATile, Rest...) + Tensor const& gtensor) // GMEM Tensor (TMATile, Rest...) +{ + CUTE_STATIC_ASSERT_V(size<0>(stensor) == size<0>(gtensor)); + + // Invert the smem to get the largest contiguous vector in the smem layout + Layout inv_smem_layout = right_inverse(get_nonswizzle_portion(layout<0>(stensor))); + // Scale that up to cover all of the smem_coords + Layout layout_v = tile_to_shape(make_layout(inv_smem_layout), size<0>(stensor)); + + // Factor out the single-instrucion portion + Layout tma_layout_v = make_layout(Int::NumValSrc>{}); + auto layout_V = make_tile(logical_divide(layout_v, tma_layout_v)); + + // Append with _ until we cover all Rest... modes + auto glayout_V = append(layout_V, _); + auto slayout_V = append(layout_V, _); + // Transform tile mode and coalesce + Tensor gtensor_v = coalesce(gtensor.compose(glayout_V), Shape>{}); // ((TMA,TMA_Iter), Rest...) + Tensor stensor_v = coalesce(stensor.compose(slayout_V), Shape>{}); // ((TMA,TMA_Iter), Rest...) + // Offset inside the TMA-mode for the multicast + auto multicast_offset = cta_layout(cta_coord) * (size(tma_layout_v) / cosize(cta_layout)); + auto multicast_coord = make_coord(make_coord(multicast_offset, Int<0>{})); + auto gcoord = append(multicast_coord, Int<0>{}); + auto scoord = append(multicast_coord, Int<0>{}); + + Tensor gresult = domain_offset(gcoord, gtensor_v); + Tensor sresult = domain_offset(scoord, stensor_v); + + return cute::make_tuple(gresult, sresult); +} + +// Explicit defaults for cta_coord and cta_layout +template +CUTE_DEVICE +auto +tma_partition(Copy_Atom const& copy_atom, + Tensor const& stensor, // SMEM Tensor (TMATile, Rest...) + Tensor const& gtensor) // GMEM Tensor (TMATile, Rest...) 
+{ + return tma_partition(copy_atom, Int<0>{}, Layout<_1,_0>{}, stensor, gtensor); +} + +// TMA Multicast Masks Calculation +template +CUTE_HOST_DEVICE constexpr +uint16_t +create_tma_multicast_mask(CtaLayout const& cta_layout_vmnk, + CtaCoord const& cta_coord_vmnk) +{ + auto [cta_layout, elected_cta] = slice_and_offset(cta_coord_vmnk, cta_layout_vmnk); + + uint16_t mcast_mask = 0; + if constexpr (rank_v == 0) { + // Trivial case with no additional ctas + mcast_mask = uint16_t(1); + } else + if constexpr (rank_v == 1 and depth_v <= 1 and + not is_static::value) { + // Get the instruction code -- optimized for dynamic flat-rank-1 cta_layout + mcast_mask = uint16_t(1); + // Smear by stride<0> (may want to predicate on stride<0> mag?) + mcast_mask |= mcast_mask << (1*stride<0>(cta_layout)); + mcast_mask |= mcast_mask << (2*stride<0>(cta_layout)); + mcast_mask |= mcast_mask << (4*stride<0>(cta_layout)); + mcast_mask |= mcast_mask << (8*stride<0>(cta_layout)); + // Select shape<0> + mcast_mask &= (uint16_t(-1) >> (16 - shape<0>(cta_layout) * stride<0>(cta_layout))); + } else { + // Get the instruction code -- generic path + for (int i = 0; i < size(cta_layout); ++i) { + mcast_mask |= uint16_t(1) << cta_layout(i); + } + } + // Shift by the instruction's elected block rank (dynamic) + mcast_mask <<= elected_cta; + return mcast_mask; +} + +// Projections multicast mask +template +CUTE_HOST_DEVICE constexpr +uint16_t +create_tma_multicast_mask(CtaLayout const& cta_layout_vmnk, + CtaCoord const& cta_coord_vmnk) +{ + return create_tma_multicast_mask(cta_layout_vmnk, replace(cta_coord_vmnk, _)); +} + +//////////////////////////////////// +// Make TMA copy A/B/C +/////////////////////////////////// + +template +CUTE_HOST_RTC +auto +make_tma_copy_A_sm90(CopyOp const& copy_op, + Tensor const& gtensor, + SLayout const& slayout, + CTA_Tiler const& cta_tiler, + Cluster_Size const& cluster_size) +{ + // Keep only MK modes from MNK + auto cta_tiler_mk = remove<1>(cta_tiler); + + // mcast along N mode for this M load, if any + auto cluster_size_n = size<1>(cluster_size); + + if constexpr (cute::is_same_v) { + return make_im2col_tma_copy(copy_op, + gtensor, + slayout, + cta_tiler_mk, + cluster_size_n); + } else { + auto cta_v_tile = make_identity_layout(shape(gtensor)).compose(cta_tiler_mk); + auto cta_t_tile = make_layout(cluster_size_n); + + // Prefer TmaInternalType if specified. Fallback to GEngine::value_type + using TmaType = conditional_t::value, typename GEngine::value_type, TmaInternalType>; + auto tma_copy = detail::make_tma_copy_tiled(copy_op, gtensor, slayout, cta_t_tile, cta_v_tile); + return tma_copy; + } +} + +template +CUTE_HOST_RTC +auto +make_tma_copy_B_sm90(CopyOp const& copy_op, + Tensor const& gtensor, + SLayout const& slayout, + CTA_Tiler const& cta_tiler, + Cluster_Size const& cluster_size) +{ + // Keep only NK modes from MNK + auto cta_tiler_nk = remove<0>(cta_tiler); + + // mcast along M mode for this N load, if any + auto cluster_size_m = size<0>(cluster_size); + + if constexpr (cute::is_same_v) { + return make_im2col_tma_copy(copy_op, + gtensor, + slayout, + cta_tiler_nk, + cluster_size_m); + } else { + auto cta_v_tile = make_identity_layout(shape(gtensor)).compose(cta_tiler_nk); + auto cta_t_tile = make_layout(cluster_size_m); + + // Prefer TmaInternalType if specified. 
Fallback to GEngine::value_type + using TmaType = conditional_t::value, typename GEngine::value_type, TmaInternalType>; + auto tma_copy = detail::make_tma_copy_tiled(copy_op, gtensor, slayout, cta_t_tile, cta_v_tile); + return tma_copy; + } +} + +template +CUTE_HOST_RTC +auto +make_tma_copy_C_sm90(CopyOp const& copy_op, + Tensor const& gtensor, + SLayout const& slayout, + CTA_Tiler const& cta_tiler) +{ + // Keep only MN modes from MNK + auto cta_tiler_mn = remove<2>(cta_tiler); + + if constexpr (cute::is_same_v || + cute::is_same_v) { + return make_im2col_tma_copy(copy_op, + gtensor, + slayout, + cta_tiler_mn, + _1{}); + } else { + auto cta_v_tile = make_identity_layout(shape(gtensor)).compose(cta_tiler_mn); + + // No multicast, so only 1 CTA involved + auto cta_t_map = Layout<_1,_0>{}; + + // Prefer TmaInternalType if specified. Fallback to GEngine::value_type + using TmaType = conditional_t::value, typename GEngine::value_type, TmaInternalType>; + auto tma_copy = detail::make_tma_copy_tiled(copy_op, gtensor, slayout, cta_t_map, cta_v_tile); + return tma_copy; + } +} +} // end namespace cute diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/atom/copy_traits_sm90_tma_swizzle.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/atom/copy_traits_sm90_tma_swizzle.hpp new file mode 100644 index 0000000000000000000000000000000000000000..90d616df2e234b1ed8425303b9b1e5309cbf1f32 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/atom/copy_traits_sm90_tma_swizzle.hpp @@ -0,0 +1,114 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +/// @file copy_traits_sm90_tma_swizzle.hpp +/// @brief Functions for converting swizzle layout to TMA descriptor + +#if !defined(__CUDACC_RTC__) +#include +#endif + +#include +#include + +namespace cute::detail { + +template +CUTE_HOST_DEVICE constexpr +TMA::SmemSwizzleBits +get_tma_swizzle_bits(Swizzle) +{ + if constexpr (M == 4) { + static_assert(0 <= B && B <= 3, "Expected B = 0,1,2, or 3 when M == 4. Unsupported layout swizzle."); + if constexpr (B == 3) { return TMA::SmemSwizzleBits::B128; } + if constexpr (B == 2) { return TMA::SmemSwizzleBits::B64; } + if constexpr (B == 1) { return TMA::SmemSwizzleBits::B32; } + if constexpr (B == 0) { return TMA::SmemSwizzleBits::DISABLE; } + } else + + if constexpr (M == 5 || M == 6) { + static_assert(B == 2, "Expected B = 2 when M == 5 or 6. Unsupported layout swizzle."); + // S-condition as well? + return TMA::SmemSwizzleBits::B128; + } else + + { + static_assert(M < 0, "Unsupported layout swizzle."); + } +} + +template +TMA::SmemSwizzleBits +get_tma_swizzle_bits(Layout const& layout) +{ + return get_tma_swizzle_bits(get_swizzle_portion(layout)); +} + +template +CUTE_HOST_DEVICE constexpr +TMA::SmemSwizzleBase +get_tma_swizzle_base(Swizzle) +{ + if constexpr (M == 4) { + static_assert(0 <= B && B <= 3, "Expected B = 0,1,2, or 3 when M == 4. Unsupported layout swizzle."); + static_assert(S == 3, "Expected S = 3 when M == 4. Unsupported layout swizzle."); + return TMA::SmemSwizzleBase::SWIZZLE_BASE_16B; + } + + else if constexpr (M == 5) { + static_assert(B == 2, "Expected B = 2 when M == 5. Unsupported layout swizzle."); + static_assert(S == 2, "Expected S = 2 when M == 5. Unsupported layout swizzle."); + return TMA::SmemSwizzleBase::SWIZZLE_BASE_32B; + } else if constexpr (M == 6) { + static_assert(B == 2, "Expected B = 2 when M == 5. Unsupported layout swizzle."); + return TMA::SmemSwizzleBase::SWIZZLE_BASE_64B; + } + #if 1 + else { + static_assert(4 <= M && M <= 6, "Expected 128b=16B=(2^4)B to 512b=64B=(2^6)B base swizzle."); + } + #else + + else { + static_assert(M == 4, "Expected 128b=16B=(2^4)B base swizzle."); + } + #endif +} + +template +TMA::SmemSwizzleBase +get_tma_swizzle_base(Layout const& layout) +{ + return get_tma_swizzle_base(get_swizzle_portion(layout)); +} + +} // namespace cute::detail diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/atom/mma_atom.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/atom/mma_atom.hpp new file mode 100644 index 0000000000000000000000000000000000000000..57338e2a4be111e277fdd394ed97192ab2bf0a05 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/atom/mma_atom.hpp @@ -0,0 +1,705 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include +#include +#include +#include +#include + +namespace cute { + +template +struct MMA_Atom; + +template +struct MMA_Atom : MMA_Atom> +{}; + +template +struct MMA_Atom> + : MMA_Traits +{ + using MMA_Op = MMAOperation; + using Traits = MMA_Traits; + + // Element value types from the MMA_Traits + using ValTypeD = typename Traits::ValTypeD; + using ValTypeA = typename Traits::ValTypeA; + using ValTypeB = typename Traits::ValTypeB; + using ValTypeC = typename Traits::ValTypeC; + + // Thr-Val layouts from the MMA_Traits + using Shape_MNK = typename Traits::Shape_MNK; + using ThrID = typename Traits::ThrID; + using LayoutC_TV = typename Traits::CLayout; + using LayoutA_TV = typename Traits::ALayout; + using LayoutB_TV = typename Traits::BLayout; + + // Fragment value types from the MMA_Traits (optional, defaults to Val type) + using FrgTypeD = typename detail::FrgTypeC_or_Default::type; + using FrgTypeA = typename detail::FrgTypeA_or_Default::type; + using FrgTypeB = typename detail::FrgTypeB_or_Default::type; + using FrgTypeC = typename detail::FrgTypeC_or_Default::type; + + // Additional Trait parameters/transformations + template + CUTE_HOST_DEVICE + auto + with(TraitsArgs&&... 
args) const { + auto traits = Traits::with(static_cast(args)...); + return MMA_Atom{traits}; + } + + // + // Tensor call interfaces + // + + // Cast, check, and call fma + template + CUTE_HOST_DEVICE constexpr + void + call(Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) const + { + static_assert(DLayout::rank == 1, "Expected rank-1 D tensor"); + static_assert(ALayout::rank == 1, "Expected rank-1 A tensor"); + static_assert(BLayout::rank == 1, "Expected rank-1 B tensor"); + static_assert(CLayout::rank == 1, "Expected rank-1 C tensor"); + + return mma_unpack(static_cast(*this), D, A, B, C); + } + + // Three arguments reproduces C + template + CUTE_HOST_DEVICE constexpr + void + call(Tensor const& A, + Tensor const& B, + Tensor & C) const + { + return call(C, A, B, C); + } + + // + // make_fragment_A|B|C + // These functions are awkward as they expect already-partitioned tensors + // resulting from a previous call to partition_A|B|C + // The reasoning is that we can inspect the layout of the partitioned data + // and attempt to match it in generated fragment to promote vectorization + // when copying from partition to fragment. + // + + template + CUTE_HOST_DEVICE static constexpr + auto + make_fragment_C(CTensor&& ctensor) + { + // Check that this tensor is likely already partitioned + CUTE_STATIC_ASSERT_V(rank(ctensor) >= Int<3>{}); // VMN + CUTE_STATIC_ASSERT_V(size<0>(ctensor) == size<1>(LayoutC_TV{})); + // C is a bit special because we are after accumulators here + // The input/output type doesn't have to match the accumulator type + //static_assert(std::is_same::value_type>::value, "Expecting ValTypeC type"); + + // We'll never base the accumulator layout on the input tensor layout, so just return a FrgTypeC tensor + return make_tensor(shape(ctensor)); + } + + template + CUTE_HOST_DEVICE static constexpr + auto + make_fragment_A(ATensor&& atensor) + { + // Check that this tensor is likely already partitioned + CUTE_STATIC_ASSERT_V(rank(atensor) >= Int<3>{}); // VMK + CUTE_STATIC_ASSERT_V(size<0>(atensor) == size<1>(LayoutA_TV{})); + + if constexpr (has_dereference::value) { + // If the intended FrgTypeA is a view (of the current tensor), forward the whole + static_assert(is_same::value_type>::value + || (sizeof_bits_v::value_type> == 8 && + (sizeof_bits_v == 8 || sizeof_bits_v == 6 || sizeof_bits_v == 4)) + || (sizeof_bits_v::value_type> == 4 && + (sizeof_bits_v == 4 || sizeof_bits_v == 3 || sizeof_bits_v == 2)) + , "Expecting ValTypeA type"); + return make_tensor(static_cast(atensor)); + } else { + // Else, the intended FrgTypeA is a value type, construct a new tensor with a fragment layout + return make_fragment_like(atensor); + } + + CUTE_GCC_UNREACHABLE; + } + + template + CUTE_HOST_DEVICE static constexpr + auto + make_fragment_B(BTensor&& btensor) + { + // Check that this tensor is likely already partitioned + CUTE_STATIC_ASSERT_V(rank(btensor) >= Int<3>{}); // VNK + CUTE_STATIC_ASSERT_V(size<0>(btensor) == size<1>(LayoutB_TV{})); + + if constexpr (has_dereference::value) { + // If the intended FrgTypeB is a view (of the current tensor), forward the whole + static_assert(is_same::value_type>::value + + || (sizeof_bits_v::value_type> == 8 && + (sizeof_bits_v == 8 || sizeof_bits_v == 6 || sizeof_bits_v == 4)) + + , "Expecting ValTypeB type"); + return make_tensor(static_cast(btensor)); + } else { + // Else, the intended FrgTypeB is a value type, construct a new tensor with a fragment layout + return make_fragment_like(btensor); + } + + CUTE_GCC_UNREACHABLE; + 
} +}; + +// +// A tiling of mma atoms +// + +template +struct ThrMMA; + +// @tparam MMA_Atom The MMA_Atom to use in the TiledMMA +// @tparam AtomLayoutMNK The MNK-tiling of the Atom to be performed. +// @tparam PermuationsMNK Permutations to apply to each MNK-mode before tiling for the Atom. +template > +struct TiledMMA : MMA_Atom +{ + using Atom = MMA_Atom; + using AtomShape_MNK = typename MMA_Atom::Shape_MNK; + using AtomThrID = typename MMA_Atom::ThrID; + using AtomLayoutC_TV = typename MMA_Atom::LayoutC_TV; + using AtomLayoutA_TV = typename MMA_Atom::LayoutA_TV; + using AtomLayoutB_TV = typename MMA_Atom::LayoutB_TV; + + static_assert( rank_v == 3, "TiledMMA requires rank-3 AtomLayoutMNK"); + static_assert( rank_v == 3, "TiledMMA requires rank-3 PermutationMNK"); + static_assert( is_tuple::value, "TiledMMA requires independent permutations of MNK."); + static_assert(is_static::value, "TiledMMA requires static permutations of MNK."); + + using ThrLayoutVMNK = decltype(tiled_product(AtomThrID{}, AtomLayoutMNK{})); + ThrLayoutVMNK thr_layout_vmnk_; + + CUTE_HOST_DEVICE constexpr + TiledMMA(MMA_Atom const& mma_atom = {}, AtomLayoutMNK const& thr_layout_mnk = {}) + : MMA_Atom(mma_atom), + thr_layout_vmnk_(tiled_product(AtomThrID{}, thr_layout_mnk)) {} + + CUTE_HOST_DEVICE constexpr auto + get_thr_layout_vmnk() const { + return thr_layout_vmnk_; + } + + // Tile a tensor or a layout from shape + // (M,N,...) + // to shape + // ((ThrV,(ThrM,ThrN)),(FrgV,(RestM,RestN,...))) + // where + // ThrV: The threads local to an MMA. layout<0>(ThrLayoutVMNK): ThrV -> thread_idx + // ThrM: The threads tiled in M. layout<1>(ThrLayoutVMNK): ThrM -> thread_idx + // ThrN: The threads tiled in N. layout<2>(ThrLayoutVMNK): ThrN -> thread_idx + // FrgV: The values local to an MMA. + // RestM: The values tiled in M. + // RestN: The values tiled in N. + template + CUTE_HOST_DEVICE constexpr + auto + thrfrg_C(CTensor&& ctensor) const + { + CUTE_STATIC_ASSERT_V(rank(ctensor) >= Int<2>{}); + // Reorder the tensor for the TiledAtom + auto t_tile = make_tile(permutation_mnk<0>(), + permutation_mnk<1>()); + auto t_tensor = logical_divide(ctensor, t_tile); // (PermM,PermN) + + // Tile the tensor for the Atom + auto c_tile = make_tile(make_layout(size<0>(AtomShape_MNK{})), + make_layout(size<1>(AtomShape_MNK{}))); + auto c_tensor = zipped_divide(t_tensor, c_tile); // ((AtomM,AtomN),(RestM,RestN)) + + // Transform the Atom mode from (M,N) to (Thr,Val) + auto tv_tensor = c_tensor.compose(AtomLayoutC_TV{},_); // ((ThrV,FrgV),(RestM,RestN)) + + // Tile the tensor for the C-threads + auto thr_tile = make_tile(_, + make_tile(make_layout(size<1>(thr_layout_vmnk_)), + make_layout(size<2>(thr_layout_vmnk_)))); + auto thr_tensor = zipped_divide(tv_tensor, thr_tile); // ((ThrV,(ThrM,ThrN)),(FrgV,(RestM,RestN))) + + return thr_tensor; + } + + // Tile a tensor or a layout from shape + // (M,K,...) + // to shape + // ((ThrV,(ThrM,ThrK)),(FrgV,(RestM,RestK,...))) + // where + // ThrV: The threads local to an MMA. layout<0>(ThrLayoutVMNK): ThrV -> thread_idx + // ThrM: The threads tiled in M. layout<1>(ThrLayoutVMNK): ThrM -> thread_idx + // ThrK: The threads tiled in K. layout<3>(ThrLayoutVMNK): ThrK -> thread_idx + // FrgV: The values local to an MMA. + // RestM: The values tiled in M. + // RestK: The values tiled in K. 
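+  // NOTE (illustrative, not from the upstream CUTLASS source): thrfrg_A below is
+  // normally consumed through ThrMMA defined later in this file, roughly as
+  //   auto thr_mma = tiled_mma.get_slice(thread_idx);
+  //   Tensor tCsA = thr_mma.partition_A(sA);           // this thread's slice of sA
+  //   Tensor tCrA = thr_mma.partition_fragment_A(sA);  // matching register fragment
+  // where tiled_mma, thread_idx, and sA are placeholder names used for the sketch.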
+ template + CUTE_HOST_DEVICE constexpr + auto + thrfrg_A(ATensor&& atensor) const + { + CUTE_STATIC_ASSERT_V(rank(atensor) >= Int<2>{}); + // Reorder the tensor for the TiledAtom + auto t_tile = make_tile(permutation_mnk<0>(), + permutation_mnk<2>()); + auto t_tensor = logical_divide(atensor, t_tile); // (PermM,PermK) + + // Tile the tensor for the Atom + auto a_tile = make_tile(make_layout(size<0>(AtomShape_MNK{})), + make_layout(size<2>(AtomShape_MNK{}))); + auto a_tensor = zipped_divide(t_tensor, a_tile); // ((AtomM,AtomK),(RestM,RestK)) + + // Transform the Atom mode from (M,K) to (Thr,Val) + auto tv_tensor = a_tensor.compose(AtomLayoutA_TV{},_); // ((ThrV,FrgV),(RestM,RestK)) + + // Tile the tensor for the Thread + auto thr_tile = make_tile(_, + make_tile(make_layout(size<1>(thr_layout_vmnk_)), + make_layout(size<3>(thr_layout_vmnk_)))); + auto thr_tensor = zipped_divide(tv_tensor, thr_tile); // ((ThrV,(ThrM,ThrK)),(FrgV,(RestM,RestK))) + + return thr_tensor; + } + + // Tile a tensor or a layout from shape + // (N,K,...) + // to shape + // ((ThrV,(ThrN,ThrK)),(FrgV,(RestN,RestK,...))) + // where + // ThrV: The threads local to an MMA. layout<0>(ThrLayoutVMNK): ThrV -> thread_idx + // ThrN: The threads tiled in N. layout<2>(ThrLayoutVMNK): ThrN -> thread_idx + // ThrK: The threads tiled in K. layout<3>(ThrLayoutVMNK): ThrK -> thread_idx + // FrgV: The values local to an MMA. + // RestN: The values tiled in N. + // RestK: The values tiled in K. + template + CUTE_HOST_DEVICE constexpr + auto + thrfrg_B(BTensor&& btensor) const + { + CUTE_STATIC_ASSERT_V(rank(btensor) >= Int<2>{}); + // Reorder the tensor for the TiledAtom + auto t_tile = make_tile(permutation_mnk<1>(), + permutation_mnk<2>()); + auto t_tensor = logical_divide(btensor, t_tile); // (PermN,PermK) + + // Tile the tensor for the Atom + auto b_tile = make_tile(make_layout(size<1>(AtomShape_MNK{})), + make_layout(size<2>(AtomShape_MNK{}))); + auto b_tensor = zipped_divide(t_tensor, b_tile); // ((AtomN,AtomK),(RestN,RestK)) + + // Transform the Atom mode from (N,K) to (Thr,Val) + auto tv_tensor = b_tensor.compose(AtomLayoutB_TV{},_); // ((ThrV,FrgV),(RestN,RestK)) + + // Tile the tensor for the Thread + auto thr_tile = make_tile(_, + make_tile(make_layout(size<2>(thr_layout_vmnk_)), + make_layout(size<3>(thr_layout_vmnk_)))); + auto thr_tensor = zipped_divide(tv_tensor, thr_tile); // ((ThrV,(ThrN,ThrK)),(FrgV,(RestN,RestK))) + + return thr_tensor; + } + + template ::value)> + CUTE_HOST_DEVICE constexpr + auto + get_slice(ThrIdx const& thr_idx) const + { + auto thr_vmnk = thr_layout_vmnk_.get_flat_coord(thr_idx); + return ThrMMA{*this, thr_vmnk}; + } + + template ::value)> + CUTE_HOST_DEVICE constexpr + auto + get_thread_slice(ThrIdx const& thr_idx) const + { + return get_slice(thr_idx); + } + + // + // Utility for printing and visualization + // + + // The permutation applied to the MNK-mode data + template + CUTE_HOST_DEVICE constexpr + auto + permutation_mnk() const { + static_assert(0 <= I && I < 3); + auto perm = get(PermutationMNK{}); + return conditional_return(is_underscore{}, size(AtomShape_MNK{}) * size(get_thr_layout_vmnk()), perm); + } + + // The size of the MNK-mode + template + CUTE_HOST_DEVICE constexpr + auto + tile_size_mnk() const { + static_assert(0 <= I && I < 3); + return size(permutation_mnk()); + } + + CUTE_HOST_DEVICE constexpr + auto + get_layoutC_TV() const + { + // (M,N) -> (M,N) + auto ref_C = make_layout(make_shape(tile_size_mnk<0>(), tile_size_mnk<1>())); + + // thr_idx -> (ThrV,ThrM,ThrN,ThrK) + 
auto thridx_2_thrid = composition(make_layout(make_shape (size(thr_layout_vmnk_), Int<1>{}), + make_stride(Int<1>{}, Int<0>{})), + right_inverse(make_layout(thr_layout_vmnk_, complement(thr_layout_vmnk_)))); + + // (thr_idx,val) -> (M,N) + return thrfrg_C(ref_C).compose(thridx_2_thrid, _); + } + + + CUTE_HOST_DEVICE constexpr + auto + get_layoutA_TV() const + { + // (M,K) -> (M,K) + auto ref_A = make_layout(make_shape(tile_size_mnk<0>(), tile_size_mnk<2>())); + + // (ThrV,(ThrM,ThrK)) -> (ThrV,(ThrM,ThrN,ThrK)) + auto atile = make_tile(_, + make_tile(make_layout(make_shape (size<1>(thr_layout_vmnk_), size<2>(thr_layout_vmnk_)), + make_stride( Int<1>{} , Int<0>{} )), + _)); + + // thr_idx -> (ThrV,ThrM,ThrN,ThrK) + auto thridx_2_thrid = composition(make_layout(make_shape (size(thr_layout_vmnk_), Int<1>{}), + make_stride(Int<1>{}, Int<0>{})), + right_inverse(make_layout(thr_layout_vmnk_, complement(thr_layout_vmnk_)))); + + // (thr_idx,val) -> (M,K) + return thrfrg_A(ref_A).compose(atile, _).compose(thridx_2_thrid, _); + } + + CUTE_HOST_DEVICE constexpr + auto + get_layoutB_TV() const + { + // (N,K) -> (N,K) + auto ref_B = make_layout(make_shape(tile_size_mnk<1>(), tile_size_mnk<2>())); + + // (ThrV,(ThrN,ThrK)) -> (ThrV,(ThrM,ThrN,ThrK)) + auto btile = make_tile(_, + make_tile(make_layout(make_shape (size<1>(thr_layout_vmnk_), size<2>(thr_layout_vmnk_)), + make_stride( Int<0>{} , Int<1>{} )), + _)); + + // thr_idx -> (ThrV,ThrM,ThrN,ThrK) + auto thridx_2_thrid = composition(make_layout(make_shape (size(thr_layout_vmnk_), Int<1>{}), + make_stride(Int<1>{}, Int<0>{})), + right_inverse(make_layout(thr_layout_vmnk_, complement(thr_layout_vmnk_)))); + + // (thr_idx,val) -> (N,K) + return thrfrg_B(ref_B).compose(btile, _).compose(thridx_2_thrid, _); + } +}; + +template +struct ThrMMA : TiledMMA +{ + ThrVMNK thr_vmnk_; + + template + CUTE_HOST_DEVICE constexpr + auto + partition_C(CTensor&& ctensor) const + { + auto thr_tensor = make_tensor(static_cast(ctensor).data(), this->thrfrg_C(ctensor.layout())); + + auto thr_vmn = make_coord(get<0>(thr_vmnk_), make_coord(get<1>(thr_vmnk_), get<2>(thr_vmnk_))); + return thr_tensor(thr_vmn, make_coord(_, repeat(thr_tensor)>(_))); + } + + template + CUTE_HOST_DEVICE constexpr + auto + partition_A(ATensor&& atensor) const + { + auto thr_tensor = make_tensor(static_cast(atensor).data(), this->thrfrg_A(atensor.layout())); + + auto thr_vmk = make_coord(get<0>(thr_vmnk_), make_coord(get<1>(thr_vmnk_), get<3>(thr_vmnk_))); + return thr_tensor(thr_vmk, make_coord(_, repeat(thr_tensor)>(_))); + } + + template + CUTE_HOST_DEVICE constexpr + auto + partition_B(BTensor&& btensor) const + { + auto thr_tensor = make_tensor(static_cast(btensor).data(), this->thrfrg_B(btensor.layout())); + + auto thr_vnk = make_coord(get<0>(thr_vmnk_), make_coord(get<2>(thr_vmnk_), get<3>(thr_vmnk_))); + return thr_tensor(thr_vnk, make_coord(_, repeat(thr_tensor)>(_))); + } + + template + CUTE_HOST_DEVICE constexpr + auto + partition_fragment_C(CTensor&& ctensor) const + { + return TiledMMA::make_fragment_C(partition_C(ctensor)); + } + + template + CUTE_HOST_DEVICE constexpr + auto + partition_fragment_A(ATensor&& atensor) const + { + return TiledMMA::make_fragment_A(partition_A(atensor)); + } + + template + CUTE_HOST_DEVICE constexpr + auto + partition_fragment_B(BTensor&& btensor) const + { + return TiledMMA::make_fragment_B(partition_B(btensor)); + } +}; + +// +// These tile the MMA_Atom as a whole +// + +template >, + class Permutations = Tile> +CUTE_HOST_DEVICE constexpr +auto 
+make_tiled_mma(MMA_Atom const& mma_atom, + MMAThrLayout const& thr_layout = {}, + Permutations const& permutations = {}) +{ + auto thr_layout_mnk = append<3>(thr_layout, Layout<_1,_0>{}); + auto permutation_mnk = append<3>(permutations, _); + + return TiledMMA, + decltype(thr_layout_mnk), + decltype(permutation_mnk)>{mma_atom, thr_layout_mnk}; +} + +template >, + class Permutations = Tile> +CUTE_HOST_DEVICE constexpr +auto +make_tiled_mma(MMA_Op const&, + MMAThrLayout const& thr_layout = {}, + Permutations const& permutations = {}) +{ + // Attempt to wrap in an MMA_Atom<> and forward + return make_tiled_mma(MMA_Atom{}, thr_layout, permutations); +} + +// +// partition_fragment_C -- static context +// + +template +CUTE_HOST_DEVICE constexpr +auto +partition_shape_C(TiledMMA const& mma, Shape_MN const& shape_MN) +{ + auto dummy = make_layout(shape(shape_MN)); + auto dummy_tv = mma.thrfrg_C(dummy); + // Slice+rearrange like partition_C + auto dummy_v = dummy_tv(Int<0>{}, make_coord(_, repeat(_))); + return shape(dummy_v); + +} + + +template +CUTE_HOST_DEVICE constexpr +auto +partition_fragment_C(TiledMMA const& mma, Shape_MN const& shapeMN) +{ + return make_tensor::FrgTypeC>(partition_shape_C(mma, shapeMN)); +} + +// partition_fragment_A and partition_fragment_B often depend on the +// layout of A and B and/or the thread_idx that is requesting the partition. +// For these reasons, they should not be used in a static context. +// See TiledMMA::get_slice(thr_idx).partition_fragment_A(tensorA) instead. + +template +CUTE_HOST_DEVICE constexpr +auto +partition_shape_A(TiledMMA const& mma, Shape_MK const& shape_MK) +{ + auto dummy = make_layout(shape(shape_MK)); + auto dummy_tv = mma.thrfrg_A(dummy); + // Slice+rearrange like partition_A + auto dummy_v = dummy_tv(Int<0>{}, make_coord(_, repeat(_))); + return shape(dummy_v); + +} + +template +CUTE_HOST_DEVICE constexpr +auto +partition_shape_B(TiledMMA const& mma, Shape_NK const& shape_NK) +{ + auto dummy = make_layout(shape(shape_NK)); + auto dummy_tv = mma.thrfrg_B(dummy); + // Slice+rearrange like partition_B + auto dummy_v = dummy_tv(Int<0>{}, make_coord(_, repeat(_))); + return shape(dummy_v); + +} + +// +// Size +// + +template +CUTE_HOST_DEVICE constexpr +auto +tile_size(TiledMMA const& mma) +{ + return mma.template tile_size_mnk(); +} + +template +CUTE_HOST_DEVICE constexpr +auto +tile_shape(TiledMMA const& mma) +{ + return make_shape(tile_size<0>(mma), tile_size<1>(mma), tile_size<2>(mma)); +} + +// Deprecate? 
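+// NOTE (illustrative, not from the upstream CUTLASS source): make_tiled_mma above is
+// typically invoked with an MMA operation plus a thread tiling, e.g.
+//   auto tiled_mma = make_tiled_mma(SM80_16x8x16_F32F16F16F32_TN{},
+//                                   Layout<Shape<_2,_2,_1>>{});   // 2x2x1 thread tiling
+// The particular op and tiling here are illustrative choices, not mandated by this header.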
+template +CUTE_HOST_DEVICE constexpr +auto +size(TiledMMA const& mma) +{ + return size(mma.get_thr_layout_vmnk()); +} + +// Alias +template +CUTE_HOST_DEVICE constexpr +auto +thr_size(TiledMMA const& mma) +{ + return size(mma.get_thr_layout_vmnk()); +} + +// +// Display utilities +// + +template +CUTE_HOST_DEVICE +void +print(MMA_Atom> const&) +{ + using Atom = MMA_Atom>; + print("MMA_Atom\n"); + print(" ThrID: "); print(typename Atom::ThrID{}); print("\n"); + print(" Shape_MNK: "); print(typename Atom::Shape_MNK{}); print("\n"); + print(" LayoutA_TV: "); print(typename Atom::LayoutA_TV{}); print("\n"); + print(" LayoutB_TV: "); print(typename Atom::LayoutB_TV{}); print("\n"); + print(" LayoutC_TV: "); print(typename Atom::LayoutC_TV{}); print("\n"); +} + +template +CUTE_HOST_DEVICE +void +print(TiledMMA const& mma) +{ + print("TiledMMA\n"); + print(" ThrLayoutVMNK: "); print(mma.get_thr_layout_vmnk()); print("\n"); + print(" PermutationMNK: "); print(TiledPerm{}); print("\n"); + print(static_cast(mma)); +} + +template +CUTE_HOST_DEVICE +void +print(ThrMMA const& thr_mma) +{ + print("ThrMMA\n"); + print(" Thr VMNK: "); print(thr_mma.thr_vmnk_); print("\n"); + print(static_cast(thr_mma)); +} + +} // namespace cute + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/atom/mma_traits.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/atom/mma_traits.hpp new file mode 100644 index 0000000000000000000000000000000000000000..986b1ae0a02079bc4da158c8a7aabc9dc791ab1c --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/atom/mma_traits.hpp @@ -0,0 +1,189 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include // cute::Tensor +#include // cute::is_rmem +#include // cute::UniversalFMA +#include // cute::detail::explode + +namespace cute +{ + +/** + * concept MMA_Traits + * { + * using ValTypeD = // Logical D-value type + * using ValTypeA = // Logical A-value type + * using ValTypeB = // Logical B-value type + * using ValTypeC = // Logical C-value type (NOTE: Not used? Assumed == ValTypeD) + * + * using FrgTypeA = // A-type consumed by MMA (if ommitted, same as ValTypeA) + * using FrgTypeB = // B_type consumed by MMA (if ommitted, same as ValTypeB) + * using FrgTypeC = // C_type consumed by MMA (if ommitted, same as ValTypeC) + * + * using Shape_MNK = // Logical MxNxK shape of the MMA + * + * using ThrID = // Logical thread id (tid) -> tidx + * + * using ALayout = // (Logical thread id (tid), Logical value id (vid)) -> Flat MK-coord + * using BLayout = // (Logical thread id (tid), Logical value id (vid)) -> Flat NK-coord + * using CLayout = // (Logical thread id (tid), Logical value id (vid)) -> Flat MN-coord + * }; + */ + +template +struct MMA_Traits +{ + static_assert(sizeof(MMAOperation) == 0, "MMA_Traits not implemented for this MMA_Operation."); +}; + +template +struct MMA_Traits> +{ + using ValTypeD = D; + using ValTypeA = A; + using ValTypeB = B; + using ValTypeC = C; + + // Logical shape of the MMA + using Shape_MNK = Shape<_1,_1,_1>; + + // Logical thread id (tid) -> tidx + using ThrID = Layout<_1>; + + // (Logical thread id (tid), Logical value id (vid)) -> coord + + // (tid,vid) -> (m,k) + using ALayout = Layout>; + // (tid,vid) -> (n,k) + using BLayout = Layout>; + // (tid,vid) -> (m,n) + using CLayout = Layout>; +}; + +// Extract an MMA_Op from an MMA_Traits +template +struct MMA_Op {}; + +template +struct MMA_Op> { + using type = MMA_Op_Arg; +}; + +// +// Generic mma_unpack for any MMA_Traits +// + +template +CUTE_HOST_DEVICE constexpr +void +mma_unpack(AnyMMATraits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) +{ + static_assert(is_rmem::value, "Expected registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected registers in MMA_Atom::call"); + + // Register value types from the MMA_Operation register arrays + using MMA_Op = typename MMA_Op::type; + using RegTypeD = typename remove_extent::type; + using RegTypeA = typename remove_extent::type; + using RegTypeB = typename remove_extent::type; + using RegTypeC = typename remove_extent::type; + + Tensor rA = recast(A); + Tensor rB = recast(B); + Tensor rD = recast(D); + Tensor rC = recast(C); + + constexpr int RegNumD = extent::value; + constexpr int RegNumA = extent::value; + constexpr int RegNumB = extent::value; + constexpr int RegNumC = extent::value; + + 
CUTE_STATIC_ASSERT_V(size(rA) == Int{}); + CUTE_STATIC_ASSERT_V(size(rB) == Int{}); + CUTE_STATIC_ASSERT_V(size(rD) == Int{}); + CUTE_STATIC_ASSERT_V(size(rC) == Int{}); + + detail::explode(MMA_Op::fma, + rD, make_int_sequence{}, + rA, make_int_sequence{}, + rB, make_int_sequence{}, + rC, make_int_sequence{}); +} + +// Accept mutable temporaries +template +CUTE_HOST_DEVICE constexpr +void +mma_unpack(AnyMMATraits const& traits, + Tensor && D, + Tensor const& A, + Tensor const& B, + Tensor const& C) +{ + mma_unpack(traits, D, A, B, C); +} + +namespace detail { + +template +struct FrgTypeA_or_Default { using type = typename X::ValTypeA; }; +template +struct FrgTypeA_or_Default> { using type = typename X::FrgTypeA; }; + +template +struct FrgTypeB_or_Default { using type = typename X::ValTypeB; }; +template +struct FrgTypeB_or_Default> { using type = typename X::FrgTypeB; }; + +template +struct FrgTypeC_or_Default { using type = typename X::ValTypeC; }; +template +struct FrgTypeC_or_Default> { using type = typename X::FrgTypeC; }; + +} // end namespace detail + +} // namespace cute diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/atom/mma_traits_sm100.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/atom/mma_traits_sm100.hpp new file mode 100644 index 0000000000000000000000000000000000000000..dd4d15e789510d9219ebd45e235c02e4e420dc89 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/atom/mma_traits_sm100.hpp @@ -0,0 +1,4032 @@ +/*************************************************************************************************** + * Copyright (c) 2022 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include // cute::TMEM:: + +#include +#include // cute::GMMA:: +#include // cute::GMMA:: +#include // UTCCP smem desc + +#include + +// Check that aggregate initialization in .with() initializes all fields +#if defined(__GNUG__) +#pragma GCC diagnostic warning "-Wmissing-field-initializers" +#pragma GCC diagnostic error "-Wmissing-field-initializers" +#endif + +namespace cute { + +namespace UMMA { + +////////////////////////////////////////////////// +// Common layouts for UMMA Shared Memory // +////////////////////////////////////////////////// + +using cute::GMMA::Layout_MN_INTER_Atom; +using cute::GMMA::Layout_MN_SW32_Atom; +using cute::GMMA::Layout_MN_SW64_Atom; +using cute::GMMA::Layout_MN_SW128_Atom; +using cute::GMMA::Layout_K_INTER_Atom; +using cute::GMMA::Layout_K_SW32_Atom; +using cute::GMMA::Layout_K_SW64_Atom; +using cute::GMMA::Layout_K_SW128_Atom; + +using Layout_MN_SW128_32B_Atom_Bits = ComposedLayout, smem_ptr_flag, Layout,Stride<_1, _1024>>>; + +template +using Layout_MN_SW128_32B_Atom = decltype(upcast::value>(Layout_MN_SW128_32B_Atom_Bits{})); + +////////////////////////////////////////////////// +// Common layouts for Sparse UMMA Shared Memory // +////////////////////////////////////////////////// + +using cute::GMMA::Layout_MN_INTER_SpAtom; +using cute::GMMA::Layout_MN_SW32_SpAtom; +using cute::GMMA::Layout_MN_SW64_SpAtom; +using cute::GMMA::Layout_MN_SW128_SpAtom; +using cute::GMMA::Layout_K_INTER_SpAtom; +using cute::GMMA::Layout_K_SW32_SpAtom; +using cute::GMMA::Layout_K_SW64_SpAtom; +using cute::GMMA::Layout_K_SW128_SpAtom; + +template +using Layout_MN_SW128_32B_SpAtom = ComposedLayout, smem_sparse_ptr_flag_bits>, + decltype(blocked_product(Layout>>{}, Layout_MN_SW128_32B_Atom{}.layout_b()))>; + +// With UMMA::Major param +template +using Layout_INTER_SpAtom = typename conditional, + Layout_K_INTER_SpAtom>::type; +template +using Layout_SW32_SpAtom = typename conditional, + Layout_K_SW32_SpAtom>::type; +template +using Layout_SW64_SpAtom = typename conditional, + Layout_K_SW64_SpAtom>::type; +template +using Layout_SW128_SpAtom = typename conditional, + Layout_K_SW128_SpAtom>::type; + +// Tile a MN-logical layout atom to an MMA Tile Shape ((MMA_M,MMA_N),M_MMAs,N_MMAs,...) +template +CUTE_HOST_DEVICE constexpr +auto +tile_to_mma_shape(LayoutAtom const& atom, MMATileShape const& mma_tile_shape, ModeOrder const& order = {}) +{ + constexpr int R = decltype(rank(mma_tile_shape))::value; + auto mn_shape = cute::tuple_cat(zip(shape<0>(mma_tile_shape), take<1,3>(mma_tile_shape)), take<3,R>(mma_tile_shape)); + auto mn_tiled = tile_to_shape(atom, mn_shape, order); // (BLK_M,BLK_N,...) + return tiled_divide(mn_tiled, product_each(shape<0>(mma_tile_shape))); // ((MMA_M,MMA_N),M_MMAs,N_MMAs,...) +} + +// +// Tensor (position-dependent swizzle) to LayoutType utility +// + +template +CUTE_HOST_DEVICE constexpr +LayoutType +layout_type(Tensor> const&) +{ + static_assert(is_same::value, + "Expected uint128_t type in LayoutType conversion."); + + using Swizzle = get_swizzle_t; + constexpr int B = Swizzle::num_bits; + constexpr int M = Swizzle::num_base; + constexpr int S = Swizzle::num_shft; + + if constexpr (M == 4) { + static_assert(S == 3, "Expected S = 3 when M == 4. Unsupported layout swizzle."); + switch (B) { + default: static_assert(0 <= B && B <= 3, "Expected B = 0,1,2, or 3 when M == 4. 
Unsupported layout swizzle."); + case 0: return LayoutType::SWIZZLE_NONE; + case 1: return LayoutType::SWIZZLE_32B; + case 2: return LayoutType::SWIZZLE_64B; + case 3: return LayoutType::SWIZZLE_128B; + } + } else + if constexpr (M == 5) { + static_assert(B == 2, "Expected B = 2 when M == 5. Unsupported layout swizzle."); + static_assert(S == 2, "Expected S = 2 when M == 5. Unsupported layout swizzle."); + return LayoutType::SWIZZLE_128B_BASE32B; + } else { + static_assert(M==5, "Only 16B and 32B Atoms are supported for UMMA. Unsupported layout swizzle."); + return LayoutType::SWIZZLE_NONE; // ERROR + } +} + +/////////////////////////////////////////////////////////////////////////////// +// Construction method for UMMA Descriptors +/////////////////////////////////////////////////////////////////////////////// + +/** +* /////////////////////////////// +* // make_umma_desc // +* /////////////////////////////// +* Each UmmaDescriptor Major-MN describes a canonical layout of the form +* +* LayoutType::INTERLEAVE : Swizzle<0,4,3> o smem_ptr o ((T,1,m),(8,k)):((1,T,SBO),(1T,LBO)) +* LayoutType::B32 : Swizzle<1,4,3> o smem_ptr o ((T,2,m),(8,k)):((1,T,LBO),(2T,SBO)) +* LayoutType::B64 : Swizzle<2,4,3> o smem_ptr o ((T,4,m),(8,k)):((1,T,LBO),(4T,SBO)) +* LayoutType::B128 : Swizzle<3,4,3> o smem_ptr o ((T,8,m),(8,k)):((1,T,LBO),(8T,SBO)) +* LayoutType::128B_BASE32B : Swizzle<2,5,2> o smem_ptr o ((T,8,m),(4,k)):((1,T,LBO),(?T,SBO)) +* +* where +* T : sizeof(uint128_t) / sizeof(value_type) +* m : integer in [1,16] corresponding to UMMA shape +* k : integer in [1,32] corresponding to UMMA shape +* SBO: stride byte offset +* LBO: leading byte offset +* +* See UMMA::Layout_MN_XXX_Atom for building canonical UmmaDescriptor Major-MN layouts. +* For example, +* auto smem_layout = tile_to_shape(Layout_MN_SW128_Atom{}, Shape<_128,_64>{}); +* is guaranteed to be accepted by make_umma_desc for appropriate value_type. +* +* ////////////////////////////// +* // make_umma_desc // +* ////////////////////////////// +* Each UmmaDescriptor Major-K describes a canonical layout of the form +* +* LayoutType::INTERLEAVE : Swizzle<0,4,3> o smem_ptr o ((8,m),(T,2)):((1T,SBO),(1,LBO)) +* LayoutType::B32 : Swizzle<1,4,3> o smem_ptr o ((8,m),(T,2)):((2T,SBO),(1, T )) +* LayoutType::B64 : Swizzle<2,4,3> o smem_ptr o ((8,m),(T,2)):((4T,SBO),(1, T )) +* LayoutType::B128 : Swizzle<3,4,3> o smem_ptr o ((8,m),(T,2)):((8T,SBO),(1, T )) +* +* See UMMA::Layout_K_XXX_Atom for building canonical UmmaDescriptor Major-K layouts. +* For example, +* auto smem_layout = tile_to_shape(Layout_K_SW128_Atom{}, Shape<_128,_64>{}); +* is guaranteed to be accepted by make_umma_desc for appropriate value_type. 
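+* A minimal sketch of the intended flow (illustrative only; 'smem_ptr' stands for any half_t* into shared memory):
+*
+*   auto smem_layout = tile_to_shape(Layout_K_SW128_Atom<half_t>{}, Shape<_128,_64>{});
+*   Tensor sA = make_tensor(make_smem_ptr(smem_ptr), smem_layout); // canonical Major-K smem tensor
+*   SmemDescriptor desc = make_umma_desc<Major::K>(sA);            // encodes start address, swizzle, LBO/SBO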
+*/ +template +CUTE_HOST_DEVICE constexpr +SmemDescriptor +make_umma_desc(Tensor const& tensor) +{ + static_assert(is_smem::value, "UMMA Descriptors can only be constructed on smem."); + static_assert(TLayout::rank == 2, "UMMA Descriptors can only be constructed on rank-2 tensors."); + using value_type = typename TEngine::value_type; + + Tensor u128_tensor = recast(tensor); + + // Result + SmemDescriptor desc; + desc.version_ = 1; // Set the version for blackwell + desc.lbo_mode_ = 0; // set to legacy mode by default + + // Layout type + constexpr UMMA::LayoutType LAYOUT_TYPE = UMMA::layout_type(u128_tensor); + desc.layout_type_ = uint8_t(LAYOUT_TYPE); + + // Start address (4LSB not included) + uint32_t start_address = cast_smem_ptr_to_uint(raw_pointer_cast(u128_tensor.data())); + desc.start_address_ = static_cast(start_address >> 4); + + constexpr uint8_t base_offset = 0; + desc.base_offset_ = base_offset; + + // LayoutType meta + constexpr int SwizzleAtomMNSize = LAYOUT_TYPE == UMMA::LayoutType::SWIZZLE_NONE ? 1 : + LAYOUT_TYPE == UMMA::LayoutType::SWIZZLE_32B ? 2 : + LAYOUT_TYPE == UMMA::LayoutType::SWIZZLE_64B ? 4 : + LAYOUT_TYPE == UMMA::LayoutType::SWIZZLE_128B ? 8 : + LAYOUT_TYPE == UMMA::LayoutType::SWIZZLE_128B_BASE32B ? 8 : -1; + + if constexpr (MajorMode == UMMA::Major::MN) + { + /* In units of uint128_t, each UmmaDescriptor Major-MN describes a canonical layout of the form + * + * LayoutType::INTERLEAVE : Swizzle<0,4,3> o smem_ptr o ((1,n),(8,k)):((X,SBO),(1,LBO)) + * LayoutType::B32 : Swizzle<1,4,3> o smem_ptr o ((2,n),(8,k)):((1,LBO),(2,SBO)) + * LayoutType::B64 : Swizzle<2,4,3> o smem_ptr o ((4,n),(8,k)):((1,LBO),(4,SBO)) + * LayoutType::B128 : Swizzle<3,4,3> o smem_ptr o ((8,n),(8,k)):((1,LBO),(8,SBO)) + * LayoutType::B128_BASE32B : Swizzle<2,5,2> o smem_ptr o ((8,n),(4,k)):((1,LBO),(4,SBO)) + */ + + constexpr int SwizzleAtomKSize = LAYOUT_TYPE == UMMA::LayoutType::SWIZZLE_128B_BASE32B ? 4 : 8; + + // Construct the canonical UMMA T Layout with shape ((SwizzleAtomMNSize,n),(SwizzleAtomKSize,2)) + Layout canonical_layout = logical_divide(layout(u128_tensor), Tile>,Layout>>{}); + + // Check profile of canonical + CUTE_STATIC_ASSERT_V(congruent(canonical_layout, Shape,Shape<_1,_1>>{}), "Not a canonical UMMA_MN Layout: Expected profile failure."); + // Check canonical mode strides + constexpr uint32_t stride_00 = stride<0,0>(canonical_layout); + constexpr uint32_t expected_stride_00 = LAYOUT_TYPE == UMMA::LayoutType::SWIZZLE_NONE ? stride<0,0>(canonical_layout) : 1; + static_assert(stride_00 == expected_stride_00, "Not a canonical UMMA_MN Layout: Expected stride failure."); + constexpr uint32_t stride_10 = stride<1,0>(canonical_layout); + constexpr uint32_t expected_stride_10 = SwizzleAtomMNSize; + static_assert(stride_10 == expected_stride_10, "Not a canonical UMMA_MN Layout: Expected stride failure."); + + // stride dimension byte offset and leading dimension byte offset (4LSB not included == uint128_t units) + constexpr uint32_t stride_01 = stride<0,1>(canonical_layout); + constexpr uint32_t stride_11 = stride<1,1>(canonical_layout); + + desc.stride_byte_offset_ = (LAYOUT_TYPE == UMMA::LayoutType::SWIZZLE_NONE) ? stride_01 : stride_11; + desc.leading_byte_offset_ = (LAYOUT_TYPE == UMMA::LayoutType::SWIZZLE_NONE) ? 
stride_11 : stride_01; + } else + if constexpr (MajorMode == UMMA::Major::K) + { + /* In units of uint128_t, each UmmaDescriptor Major-K describes a canonical layout of the form + * + * LayoutType::INTERLEAVE : Swizzle<0,4,3> o smem_ptr o ((8,n),2):((1,SBO),LBO) + * LayoutType::B32 : Swizzle<1,4,3> o smem_ptr o ((8,n),2):((2,SBO),1) + * LayoutType::B64 : Swizzle<2,4,3> o smem_ptr o ((8,n),2):((4,SBO),1) + * LayoutType::B128 : Swizzle<3,4,3> o smem_ptr o ((8,n),2):((8,SBO),1) + * LayoutType::B128_BASE32B : Not applicable for Major-K + */ + + static_assert(LAYOUT_TYPE != UMMA::LayoutType::SWIZZLE_128B_BASE32B, "SWIZZLE_128B_BASE32B is invalid for Major-K"); + CUTE_STATIC_ASSERT_V(size<0>(u128_tensor) % Int<8>{} == Int<0>{}, // N|M size + "Not a canonical UMMA_K Layout: Expected MN-size multiple of 8."); + + // Construct the canonical UMMA N Layout with shape ((8,n),(2,1)) + Layout canonical_layout = logical_divide(layout(u128_tensor), Tile,Layout<_2,_1>>{}); + + // Check profile of canonical + CUTE_STATIC_ASSERT_V(congruent(canonical_layout, Shape,Shape<_1,_1>>{}), "Not a canonical UMMA_K Layout: Expected profile failure."); + // Check canonical mode strides + constexpr uint32_t stride_00 = stride<0,0>(canonical_layout); + constexpr uint32_t expected_stride_00 = SwizzleAtomMNSize; + static_assert(stride_00 == expected_stride_00, "Not a canonical UMMA_K Layout: Expected stride failure."); + constexpr uint32_t stride_10 = stride<1,0>(canonical_layout); + constexpr uint32_t expected_stride_10 = (LAYOUT_TYPE == UMMA::LayoutType::SWIZZLE_NONE) ? stride<1,0>(canonical_layout) : 1; + static_assert(stride_10 == expected_stride_10, "Not a canonical UMMA_K Layout: Expected stride failure."); + + // stride dimension byte offset and leading dimension byte offset (4LSB not included == uint128_t units) + constexpr uint32_t stride_01 = stride<0,1>(canonical_layout); + + desc.stride_byte_offset_ = stride_01; + desc.leading_byte_offset_ = stride_10; + } else { + static_assert(MajorMode != UMMA::Major::MN && MajorMode != UMMA::Major::K, "Unrecognized MajorMode!"); + } + return desc; +} + +/////////////////////////////////////////////////////////////////////////////// +// Higher level UMMA Descriptor utilities +/////////////////////////////////////////////////////////////////////////////// + +struct DescriptorIterator +{ + using reference = SmemDescriptor; + using element_type = SmemDescriptor; + using value_type = SmemDescriptor; + + SmemDescriptor desc_; + + // Dereference returns the UmmaDescriptor + CUTE_HOST_DEVICE constexpr + reference operator*() const { return desc_; } + + // Advance and return a new UmmaDescriptor + template + CUTE_HOST_DEVICE constexpr + reference operator[](Index const& i) const { return *(*this + i); } + + // Return an advanced iterator + template + CUTE_HOST_DEVICE constexpr + DescriptorIterator operator+(Index const& offset) const + { + // Use 32bit calculation rather than 64 bit calculation as we only update the part of desc + SmemDescriptor ret; + ret.lo = desc_.lo + uint32_t(offset); + ret.hi = desc_.hi; + return { ret }; + } +}; + +template +CUTE_HOST_DEVICE constexpr +SmemDescriptor +raw_pointer_cast(DescriptorIterator const& ptr) { + return ptr.desc_; +} + +CUTE_HOST_DEVICE void +print(DescriptorIterator const&) { + printf("UMMA::DescriptorIterator"); +} + +// Flag for smem descriptor allocation/creation +template +struct smem_desc : DescriptorIterator {}; + +template +struct sparse_smem_desc : DescriptorIterator {}; + +} // end namespace UMMA + +// Customization point for 
creating a UMMA::smem_desc Tensor +template +struct MakeTensor> +{ + template + CUTE_HOST_DEVICE constexpr auto + operator()(Tensor const& smem_tensor) + { + static_assert(is_smem::value, "Expected SMEM Tensor to construct a UMMA Desc Tensor"); + return make_tensor(UMMA::DescriptorIterator{UMMA::make_umma_desc(tensor<0>(smem_tensor))}, + replace<0>(recast(smem_tensor).layout(), Layout<_1,_0>{})); + } +}; + +// Customization point for creating a UMMA::sparse_smem_desc Tensor +template +struct MakeTensor> +{ + // Note that this is the exact same as UMMA::smem_desc above. + // Only the interface validates that we are passed a sparse_ptr, which is recast away to construct + // the smem desc tensor + template + CUTE_HOST_DEVICE constexpr auto + operator()(Tensor const& smem_tensor) + { + static_assert(is_smem::value, "Expected SMEM Tensor to construct a UMMA Desc Tensor"); + static_assert(is_sparse::value, "Expected sparse value_type."); + static_assert(is_sparse_ptr::value, "Expected sparse iter."); + return make_tensor(UMMA::DescriptorIterator{UMMA::make_umma_desc(tensor<0>(smem_tensor))}, + replace<0>(recast(smem_tensor).layout(), Layout<_1,_0>{})); + } +}; + +// Special smem_desc_iter tensor entry for UTCCP copy. +template +constexpr auto get_utccp_smem_desc_tensor(Tensor const& smem_utccp_partitioned_tensor) { + using VecLayout = decltype(layout<0>(TLayout{})); + static_assert(VecLayout::rank == 2 && shape<1>(VecLayout{}) == 1, "Mismatched vec_mode tensor."); + static_assert(is_smem::value, "Expect vec_mode smem_tesnor."); + static_assert(is_static::value, "Utccp copy tensor's vec_mode should be static."); + + using value_type = typename TEngine::value_type; + using UtccpTaits = Copy_Traits; + + // UtccpTaits::ValID: logical_bit_idx -> tmem_offset. + // We arrange the logical_bit_idx in order of (core_matrix_strided, core_matrix_leading, repeat(only in 64dplw01), broadcast). + // So we only need the first two modes for src smem_tensor. + auto utccp_core_matrix_shape = take<0,2>(upcast>(typename UtccpTaits::ValID{}).shape()); + // logical_bit_idx -> smem_addr + Layout vec_v_layout = flatten(layout<0>(VecLayout{})); + Layout utccp_core_matrix_layout = vec_v_layout.with_shape(utccp_core_matrix_shape); + Tensor utccp_core_matrix_tensor = group_modes<0,2>(make_tensor(smem_utccp_partitioned_tensor.data(), utccp_core_matrix_layout)); + Tensor core_matrix_desc_tensor = make_tensor>(utccp_core_matrix_tensor); + return make_tensor(core_matrix_desc_tensor.data(), recast_layout(smem_utccp_partitioned_tensor.layout())); +} + +namespace UMMA { + +// Import TMEM constants +namespace TMEM = cute::TMEM; + +enum class TmemAllocMode { + // Default allocation mode. + // If a TMEM Atom uses a half-subpartition (16DPs), then multiple atoms can be + // interleaved by using the top-half-subpartition and the bottom-half-subpartition. + // Full utilization of TMEM capacity. + Interleaved = 0, + // Prevents interleaving. + // If a TMEM Atom uses a half-subpartition (16DPs), then multiple atoms will not be + // interleaved. + // Required for DP-address equivalence in TMEM-A and TMEM-C allocations in UMMA_TS. + NonInterleaved = 1, + // Duplicates the TMEM allocation across subpartitions. + // E.g. UMMA_2SM_128xNx16_TS uses a "2x2 DP" TMEM Layout, but the TMEM allocation is + // actually doubled and the input data must be duplicated between the + // subpartitions [0,1]<->[2,3], i.e., each subpartition holds all columns + // of the A matrix needed for a single UMMA operation. 
+ // For UMMA_2SM_128xNx16_TS, the distribution of the data is as follows. + // SM0: + // Subpart0 = A[0:32, 0:16], Subpart1 = A[32:64, 0:16], + // Subpart2 = A[A:32, 0:16], Subpart3 = A[32:64, 0:16] + // SM1: + // Subpart0 = A[64:96, 0:16], Subpart1 = A[96:128, 0:16], + // Subpart2 = A[64:96, 0:16], Subpart3 = A[96:128, 0:16] + Duplicated = 2, + // Duplicates the TMEM allocation across subpartitions for scale factor. + // Scale factor TMEM allocation for 4x1 data path + ScaleFactorDuplicated4by1 = 3, + // Scale factor TMEM allocation for 2x2 data path + ScaleFactorDuplicated2by2 = 4 +}; + +struct tmem_frg_base {}; + +// The UMMA Traits below have custom fragment type flags for their tmem tensors. +// These flags specialize a MakeTensor customization point to correctly make the fragment that is desired. +template +struct tmem_frg : tmem_frg_base +{ + static_assert(sizeof_bits_v <= sizeof_bits_v, "TMEM MMA allocations require StorageType big enough for ValueType."); + + // UMMA TMEM Allocator + // Each UMMA expects a specific MxN layout of TMEM for accumulators + // and sometimes a specific MxK layout of TMEM for A-values. + // @tparam ValueType The value type of the TMEM Tensor to allocate. + // @tparam StorageType The storage type of the TMEM Tensor to allocate. + // "Sparse" allocations often allocate ValueType=half_t within StorageType=uint32_t. + // "Dense" allocations often allocate ValueType=half_t within StorageType=half_t. + // @tparam N_SM The number of SMs in this UMMA_XSM instruction. + // @tparam TmemAlloc UMMA-specific allocation modifier for special cases. + // Some UMMA instructions expect strange atoms or tilings of atoms. + // @param tmem_shape ((M_MMA_SM,N_MMA_SM),MMA_M,MMA_N,...) + // The post-MMA-partitioned shape of TMEM to allocate. + // Note for UMMA_2SM_128xNx16, that M_MMA_SM will be 64, for example. + template + CUTE_HOST_DEVICE constexpr static auto + make(TmemShape const& tmem_shape) + { + CUTE_STATIC_ASSERT_V(size(tmem_shape)*Int)>{} <= TMEM::MAX_CAPACITY_BITS{}, + "Requesting more TMEM than is available."); + CUTE_STATIC_ASSERT_V(rank<0>(tmem_shape) == Int<2>{}, "Expected post-partitioned shape ((M_MMA,N_MMA),...)."); + constexpr int R = decltype(rank(tmem_shape))::value; + constexpr int M_MMA = decltype(size<0,0>(tmem_shape))::value; + constexpr int N_MMA = decltype(size<0,1>(tmem_shape))::value; + + // It's convenient to use "virtual tensor memory addressing" + // with DP_STRIDE=1, COL_STRIDE=128 to define the tmem_atom, + // then convert to "logical tensor memory addressing" on return. + using COL_ADDR = C::value / sizeof_bits::value>; + Layout tmem_restride = Layout, + Stride, COL_ADDR>>{}; + + static_assert(N_SM == 1 || N_SM == 2, "UMMA expects N_SM == 1 or N_SM == 2"); + if constexpr (N_SM == 1) + { + static_assert(TmemAlloc == UMMA::TmemAllocMode::Interleaved || TmemAlloc == UMMA::TmemAllocMode::NonInterleaved, + "UMMA_1SM only accepts Interleaved or NonInterleaved"); + static_assert(M_MMA == 64 || M_MMA == 128, "UMMA_1SM M-mode size should be 64 or 128."); + + if constexpr (M_MMA == 64) + { + // Half subpartitions layout atom: (M,N) -> tmem_addr + Layout tmem_atom = Layout, Int>, + Stride, _128>>{}; + // tile_stride = 2 causes the tiling to "skip" the first tile in DPs + constexpr int tile_stride = TmemAlloc == UMMA::TmemAllocMode::Interleaved ? 
1 : 2; + // This will tile in DPs first, then COLs + Layout tmem_logical_layout = tiled_product(tmem_atom, make_layout(take<1,R>(tmem_shape), + compact_col_major(take<1,R>(tmem_shape),Int{}))); + // Restride for the DP/COL addressing and return + return make_tensor(make_tmem_ptr(), composition(tmem_restride, tmem_logical_layout)); + } else + if constexpr (M_MMA == 128) + { + // For M_MMA = 128, all datapaths are occupied. TmemAllocMode doesn't change the allocation. + // Full subpartitions layout atom: (M,N) -> tmem_addr + Layout tmem_atom = Layout>, + Stride< _1, _128>>{}; + // This will tile in DPs first, then COLs + Layout tmem_logical_layout = tiled_product(tmem_atom, make_layout(take<1,R>(tmem_shape))); + // Restride for the DP/COL addressing and return + return make_tensor(make_tmem_ptr(), composition(tmem_restride, tmem_logical_layout)); + } + + } else + if constexpr (N_SM == 2) + { + static_assert(TmemAlloc == UMMA::TmemAllocMode::Interleaved || TmemAlloc == UMMA::TmemAllocMode::Duplicated, + "UMMA_2SM only accepts Interleaved or Duplicated"); + static_assert(M_MMA == 32 || M_MMA == 64 || M_MMA == 128, "UMMA_2SM M-mode size should be 32 or 64 or 128."); + + if constexpr (M_MMA == 32) + { + static_assert(TmemAlloc == UMMA::TmemAllocMode::Interleaved, "Only TmemAllocMode::Interleaved is supported for UMMA_2SM M_MMA=32"); + // The "1x4" layout atom: (M,N) -> tmem_addr + Layout tmem_atom = Layout, _4>>, + Stride< _1,Stride< _128,_32>>>{}; + // This will tile in DPs first, then COLs + Layout tmem_logical_layout = tiled_product(tmem_atom, make_layout(take<1,R>(tmem_shape))); + // Restride for the DP/COL addressing and return + return make_tensor(make_tmem_ptr(), composition(tmem_restride, tmem_logical_layout)); + } else + if constexpr (M_MMA == 64 && TmemAlloc == UMMA::TmemAllocMode::Interleaved) + { + // The "2x2" layout atom: (M,N) -> tmem_addr + Layout tmem_atom = Layout, _2>>, + Stride< _1,Stride< _128,_64>>>{}; + // This will tile in DPs first, then COLs + Layout tmem_logical_layout = tiled_product(tmem_atom, make_layout(take<1,R>(tmem_shape))); + // Restride for the DP/COL addressing and return + return make_tensor(make_tmem_ptr(), composition(tmem_restride, tmem_logical_layout)); + + } else + if constexpr (M_MMA == 64 && TmemAlloc == UMMA::TmemAllocMode::Duplicated) + { + // The "2x2" duplicated layout atom: (M,N) -> tmem_addr + Layout tmem_atom = Layout>, + Stride< _1, _128>>{}; + // This will tile in DPs first, then COLs + Layout tmem_logical_layout = tiled_product(tmem_atom, make_layout(take<1,R>(tmem_shape))); + // Restride for the DP/COL addressing and return + return make_tensor(make_tmem_ptr(), composition(tmem_restride, tmem_logical_layout)); + } else + if constexpr (M_MMA == 128) + { + // For M_MMA = 128, all datapaths are occupied. TmemAllocMode doesn't change the allocation. + // The "4x1" layout atom: (M,N) -> tmem_addr + Layout tmem_atom = Layout>, + Stride< _1, _128>>{}; + // This will tile in DPs first, then COLs + Layout tmem_logical_layout = tiled_product(tmem_atom, make_layout(take<1,R>(tmem_shape))); + // Restride for the DP/COL addressing and return + return make_tensor(make_tmem_ptr(), composition(tmem_restride, tmem_logical_layout)); + } + } + + CUTE_GCC_UNREACHABLE; + } +}; + +// Convenient aliases for common cases in the UMMA::ElementXFrg below +template +using tmem_frg_1sm = tmem_frg; +template +using tmem_frg_2sm = tmem_frg; + +// Make metadata TMEM fragments for sparse MMAs. 
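+// The "E" fragments below hold that metadata (the selection indices of the compressed A operand);
+// each branch selects a TMEM atom whose MMA_M x MMA_K footprint matches the element width of A
+// (128x16 for TF32, 128x32 for FP16, 128x64 for 8-bit, 128x128 for F4).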
+// Also note that the TMEM fragment addresses are assumed to be COL-4 aligned -- working with arch to remove this condition +template +struct tmem_e_frg : tmem_frg_base +{ + template + CUTE_HOST_DEVICE constexpr static auto + make(TmemShape const& tmem_shape) + { + CUTE_STATIC_ASSERT_V(rank<0>(tmem_shape) == Int<2>{}, "Expected post-partitioned shape ((M_MMA,N_MMA),...)."); + constexpr int R = decltype(rank(tmem_shape))::value; + constexpr int M_MMA = decltype(size<0,0>(tmem_shape))::value; + constexpr int N_MMA = decltype(size<0,1>(tmem_shape))::value; + + static_assert(M_MMA == 128, "Only 128 implemented right now."); + + // It's convenient to use "virtual tensor memory addressing" + // with DP_STRIDE=1, COL_STRIDE=128 to define the tmem_atom, + // then convert to "logical tensor memory addressing" on return. + [[maybe_unused]] Layout tmem_restride = Layout, + Stride>{}; + + if constexpr (sizeof_bits::value == 32) // TF32: 128x16 atom + { + static_assert(N_MMA == 16); + Layout tmem_atom = Layout, Shape < _8,_2>>, + Stride, Stride<_128,_8>>>{}; + // Tile to MMA tiling + Layout tmem_logical_layout = tiled_product(tmem_atom, make_layout(take<1,R>(tmem_shape))); + // Address transformations with upcast<2> for 2-bit base types + Layout tmem_layout = composition(upcast<2>(tmem_restride), tmem_logical_layout); + // Sparsity wrap, no sparse_ptr because tmem_ptr handles it's own subword addressing + return make_tensor(make_tmem_ptr>(), tmem_layout); + } else + if constexpr (sizeof_bits::value == 16) // FP16: 128x32 atom + { + static_assert(N_MMA == 32); + Layout tmem_atom = Layout, Shape < _16,_2>>, + Stride, Stride<_128,_8>>>{}; + // Tile to MMA tiling + Layout tmem_logical_layout = tiled_product(tmem_atom, make_layout(take<1,R>(tmem_shape))); + // Address transformations + Layout tmem_layout = composition(tmem_restride, tmem_logical_layout); + // Sparsity wrap, no sparse_ptr because tmem_ptr handles it's own subword addressing + return make_tensor(make_tmem_ptr>(), tmem_layout); + } else + if constexpr (sizeof_bits::value == 8) // S8|Mix.F4/F6/F8: 128x64 atom + { + // For Mix 8bit f4/f6/f8, will pass in ValueType = uint8_t + static_assert(N_MMA == 64); + Layout tmem_atom = Layout, + Stride< _1,_128>>{}; + // Tile to MMA tiling + Layout tmem_logical_layout = tiled_product(tmem_atom, make_layout(take<1,R>(tmem_shape))); + // Address transformations + Layout tmem_layout = composition(tmem_restride, tmem_logical_layout); + // Sparsity wrap, no sparse_ptr because tmem_ptr handles it's own subword addressing + return make_tensor(make_tmem_ptr>(), tmem_layout); + } + if constexpr (sizeof_bits::value == 4) // F4: 128x128 atom + { + // For F4, will pass in ValueType = fp4 + Layout tmem_restride1 = Layout>, + Stride, _1>>{}; + // F4 has roughly same TMEM layout as Mix8bit.F4/F6/F8, the only difference is that K is multiplied by two + static_assert(N_MMA == 128); + Layout tmem_atom = Layout, + Stride< _1, _128>>{}; + // Tile to MMA tiling + Layout tmem_logical_layout = tiled_product(tmem_atom, make_layout(take<1,R>(tmem_shape))); + // Address transformations + Layout tmem_layout = composition(tmem_restride1, tmem_logical_layout); + // Sparsity wrap, no sparse_ptr because tmem_ptr handles it's own subword addressing + return make_tensor(make_tmem_ptr>(), tmem_layout); + } + + CUTE_GCC_UNREACHABLE; + } +}; + +template +struct tmem_e_frg_ws : tmem_frg_base +{ + template + CUTE_HOST_DEVICE constexpr static auto + make(TmemShape const& tmem_shape) + { + CUTE_STATIC_ASSERT_V(rank<0>(tmem_shape) == Int<2>{}, 
"Expected post-partitioned shape ((M_MMA,N_MMA),...)."); + constexpr int R = decltype(rank(tmem_shape))::value; + constexpr int M_MMA = decltype(size<0,0>(tmem_shape))::value; + constexpr int N_MMA = decltype(size<0,1>(tmem_shape))::value; + + static_assert(M_MMA == 128 || M_MMA == 64 || M_MMA == 32, "Weight stationary UMMA_1SM M-mode size should be 32 or 64 or 128."); + + // It's convenient to use "virtual tensor memory addressing" + // with DP_STRIDE=1, COL_STRIDE=128 to define the tmem_atom, + // then convert to "logical tensor memory addressing" on return. + Layout tmem_restride = Layout, + Stride>{}; + + if constexpr (sizeof_bits::value == 32) // TF32 + { + // MMA_M x MMA_K: 128x16 atom / 64x16 atom / 32x16 atom + static_assert(N_MMA == 16); + if constexpr (M_MMA == 128) { + Layout tmem_atom = Layout, Shape < _8,_2>>, + Stride, Stride<_128,_8>>>{}; + // Tile to MMA tiling + Layout tmem_logical_layout = tiled_product(tmem_atom, make_layout(take<1,R>(tmem_shape))); + // Address transformations with upcast<2> for 2-bit base types + Layout tmem_layout = composition(upcast<2>(tmem_restride), tmem_logical_layout); + // Sparsity wrap, no sparse_ptr because tmem_ptr handles it's own subword addressing + return make_tensor(make_tmem_ptr>(), tmem_layout); + } + else if constexpr (M_MMA == 64) { + Layout tmem_atom = Layout, Shape < _8,_2>, _2>, + Stride, Stride<_128,_8>,_64>>{}; + // Tile to MMA tiling + Layout tmem_logical_layout = tiled_product(tmem_atom, make_layout(take<1,R>(tmem_shape))); + // Address transformations with upcast<2> for 2-bit base types + Layout tmem_layout = composition(upcast<2>(tmem_restride), tmem_logical_layout); + // Sparsity wrap, no sparse_ptr because tmem_ptr handles its own subword addressing + return make_tensor(make_tmem_ptr>(), tmem_layout); + } + else if constexpr (M_MMA == 32) { + Layout tmem_atom = Layout, Shape < _8,_2>, _4>, + Stride, Stride<_128,_8>,_32>>{}; + // Tile to MMA tiling + Layout tmem_logical_layout = tiled_product(tmem_atom, make_layout(take<1,R>(tmem_shape))); + // Address transformations with upcast<2> for 2-bit base types + Layout tmem_layout = composition(upcast<2>(tmem_restride), tmem_logical_layout); + // Sparsity wrap, no sparse_ptr because tmem_ptr handles its own subword addressing + return make_tensor(make_tmem_ptr>(), tmem_layout); + } + else { + static_assert(dependent_false, "Invalid M_MMA value"); + } + } + else if constexpr (sizeof_bits::value == 16) // FP16 + { + // MMA_M x MMA_K: 128x32 atom / 64x32 atom / 32x32 atom + static_assert(N_MMA == 32); + if constexpr (M_MMA == 128) { + Layout tmem_atom = Layout, Shape < _16,_2>>, + Stride, Stride<_128,_8>>>{}; + // Tile to MMA tiling + Layout tmem_logical_layout = tiled_product(tmem_atom, make_layout(take<1,R>(tmem_shape))); + // Address transformations + Layout tmem_layout = composition(tmem_restride, tmem_logical_layout); + // Sparsity wrap, no sparse_ptr because tmem_ptr handles it's own subword addressing + return make_tensor(make_tmem_ptr>(), tmem_layout); + } + else if constexpr (M_MMA == 64) { + Layout tmem_atom = Layout, Shape < _16,_2>, _2>, + Stride, Stride<_128,_8>,_64>>{}; + // Tile to MMA tiling + Layout tmem_logical_layout = tiled_product(tmem_atom, make_layout(take<1,R>(tmem_shape))); + // Address transformations + Layout tmem_layout = composition(tmem_restride, tmem_logical_layout); + // Sparsity wrap, no sparse_ptr because tmem_ptr handles it's own subword addressing + return make_tensor(make_tmem_ptr>(), tmem_layout); + } + else if constexpr (M_MMA == 32) { + Layout 
tmem_atom = Layout, Shape < _16,_2>, _4>, + Stride, Stride<_128,_8>,_32>>{}; + // Tile to MMA tiling + Layout tmem_logical_layout = tiled_product(tmem_atom, make_layout(take<1,R>(tmem_shape))); + // Address transformations + Layout tmem_layout = composition(tmem_restride, tmem_logical_layout); + // Sparsity wrap, no sparse_ptr because tmem_ptr handles it's own subword addressing + return make_tensor(make_tmem_ptr>(), tmem_layout); + } + else { + static_assert(dependent_false, "Invalid M_MMA value"); + } + } + else if constexpr (sizeof_bits::value == 8) // I8|F8 + { + // MMA_M x MMA_K: 128x64 atom / 64x64 atom / 32x64 atom + static_assert(N_MMA == 64); + if constexpr (M_MMA == 128) { + Layout tmem_atom = Layout, + Stride< _1,_128>>{}; + // Tile to MMA tiling + Layout tmem_logical_layout = tiled_product(tmem_atom, make_layout(take<1,R>(tmem_shape))); + // Address transformations + Layout tmem_layout = composition(tmem_restride, tmem_logical_layout); + // Sparsity wrap, no sparse_ptr because tmem_ptr handles it's own subword addressing + return make_tensor(make_tmem_ptr>(), tmem_layout); + } + else if constexpr (M_MMA == 64) { + Layout tmem_atom = Layout>, + Stride< _1, Stride<_128, _64>>>{}; + // Tile to MMA tiling + Layout tmem_logical_layout = tiled_product(tmem_atom, make_layout(take<1,R>(tmem_shape))); + // Address transformations + Layout tmem_layout = composition(tmem_restride, tmem_logical_layout); + // Sparsity wrap, no sparse_ptr because tmem_ptr handles it's own subword addressing + return make_tensor(make_tmem_ptr>(), tmem_layout); + } + else if constexpr (M_MMA == 32) { + Layout tmem_atom = Layout>, + Stride< _1, Stride<_128, _32>>>{}; + // Tile to MMA tiling + Layout tmem_logical_layout = tiled_product(tmem_atom, make_layout(take<1,R>(tmem_shape))); + // Address transformations + Layout tmem_layout = composition(tmem_restride, tmem_logical_layout); + // Sparsity wrap, no sparse_ptr because tmem_ptr handles it's own subword addressing + return make_tensor(make_tmem_ptr>(), tmem_layout); + } + else { + static_assert(dependent_false, "Invalid M_MMA value"); + } + } + else { + static_assert(dependent_false, "Invalid ValueType"); + } + + CUTE_GCC_UNREACHABLE; + } +}; + +template +struct tmem_sf_frg: tmem_frg_base +{ + // UMMA TMEM Allocator for Scale Factor A for Mxf4Nvf4 and Mxf8f6f4 instructions + // We expect a tensor that has the same layout as A matrix + // @tparam ValueType: data type of scaling factor + // Note that the StorageType is the same as ValueType, i.e., we always use a compact allocation + // @tparam SFVecSize: The number of values that is scaled by a single scaling factor. + // Valid values are (16, 32) + // @tparam N_SM: Number of SMs in UMMA instruction + // @param tmem_shape: An MMA partitioned shape where first mode encodes, A layout of the MMA instruction. + // Note that the shape doesn't match the actual allocation. size<0,1>(tmem_shape) will give us the number of + // elements in K-mode of MMA rather than the number of scaling factors. 
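+  // For example, with SFVecSize = 32 and an MMA K-extent of 64 elements, the K-mode of tmem_shape
+  // arrives as (VecSize, NSF) = (32, 2), so only NSF = 2 scale factors per row are actually allocated.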
+ template + CUTE_HOST_DEVICE constexpr static auto + make(TmemShape const& tmem_shape) + { + CUTE_STATIC_ASSERT_V(rank<0>(tmem_shape) == Int<2>{}, "Expected post-partitioned shape ((M_MMA,N_MMA),...)."); + constexpr int MMA_MN = decltype(size<0,0>(tmem_shape))::value; + constexpr int MMA_VS = decltype(size<0,1,0>(tmem_shape))::value; + constexpr int MMA_NSF = decltype(size<0,1,1>(tmem_shape))::value; + constexpr int R_MMA_K = decltype(rank(get<0,1>(tmem_shape)))::value; + constexpr int R = decltype(rank(tmem_shape))::value; + + // We expect an MMA-SF partitioned tensor + // ((MMA_MN, (VecSize, NSF)), num_MMA_MN, num_MMA_K, ...) + // where VecSize*NSF = MMA_K + static_assert(R >= 3, "Expected an MMA partitioned tensor"); // ((MMA), num_MMA_MN, num_MMA_K, ...) + static_assert(R_MMA_K == 2, "Expected an MMA-SF partitioned tensor"); // (VecSize, NSF) + using REP = _4; // Replication factor. Data is always replicated across subpartitions + constexpr int SUBPART_DPs = 32; // Number of DPs in a subpartition + + using COL_ADDR = C::value / sizeof_bits::value>; + Layout tmem_restride = Layout, + Stride, COL_ADDR>>{}; + + if constexpr (Is_SFA || (!Is_SFA && TmemAlloc == UMMA::TmemAllocMode::ScaleFactorDuplicated4by1)) { + // SFA, 2x2 and 4x1 data path + // SFB, 4x1 data path + auto tmem_atom = Layout < Shape< Shape< Shape, Int>, REP>, Shape, Int>>, + Stride, _32>, Stride< _0, _128>>>{}; + + Layout tmem_logical_layout = tiled_product(tmem_atom, make_layout(take<1,R>(tmem_shape))); + auto final_tmem_layout = composition(tmem_restride, tmem_logical_layout); + return make_tensor(make_tmem_ptr(), final_tmem_layout); + } + else { + // SFB, 2x2 datapath + static_assert(!Is_SFA and TmemAlloc == UMMA::TmemAllocMode::ScaleFactorDuplicated2by2); + static_assert(N_SM == 2, "Should be 2x2 Datapath"); + // 2x2 Datapth + auto tmem_atom = Layout < Shape< Shape< Shape, Int>, _2, _2>, Shape, Int>>, + Stride, _64, _32>, Stride< _0, _128>>>{}; + + Layout tmem_logical_layout = tiled_product(tmem_atom, make_layout(take<1,R>(tmem_shape))); + auto final_tmem_layout = composition(tmem_restride, tmem_logical_layout); + return make_tensor(make_tmem_ptr(), final_tmem_layout); + } + } +}; + +// Make C/D Tmem fragment for weight-stationary MMAs +template +struct tmem_frg_ws : tmem_frg_base +{ + static_assert(sizeof_bits_v <= sizeof_bits_v, "TMEM MMA allocations require StorageType big enough for ValueType."); + + // UMMA TMEM Allocator + // Each UMMA expects a specific MxN layout of TMEM for accumulators + // and sometimes a specific MxK layout of TMEM for A-values. + // @tparam ValueType The value type of the TMEM Tensor to allocate. + // @tparam StorageType The storage type of the TMEM Tensor to allocate. + // "Sparse" allocations often allocate ValueType=half_t within StorageType=uint32_t. + // "Dense" allocations often allocate ValueType=half_t within StorageType=half_t. + // @tparam N_SM The number of SMs in this UMMA_XSM instruction. + // @tparam TmemAlloc UMMA-specific allocation modifier for special cases. + // Some UMMA instructions expect strange atoms or tilings of atoms. + // @param tmem_shape ((M_MMA_SM,N_MMA_SM),MMA_M,MMA_N,...) + // The post-MMA-partitioned shape of TMEM to allocate. + // Note for UMMA_2SM_128xNx16, that M_MMA_SM will be 64, for example. 
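+  // Note: unlike tmem_frg above, this weight-stationary allocator only supports single-SM atoms
+  // (N_SM == 1) and picks a 1x4, 2x2, or full-subpartition datapath atom for M_MMA = 32, 64, or 128,
+  // as enforced by the static_asserts below.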
+ template + CUTE_HOST_DEVICE constexpr static auto + make(TmemShape const& tmem_shape) + { + CUTE_STATIC_ASSERT_V(size(tmem_shape)*Int)>{} <= TMEM::MAX_CAPACITY_BITS{}, + "Requesting more TMEM than is available."); + CUTE_STATIC_ASSERT_V(rank<0>(tmem_shape) == Int<2>{}, "Expected post-partitioned shape ((M_MMA,N_MMA),...)."); + constexpr int R = decltype(rank(tmem_shape))::value; + constexpr int M_MMA = decltype(size<0,0>(tmem_shape))::value; + constexpr int N_MMA = decltype(size<0,1>(tmem_shape))::value; + + // It's convenient to use "virtual tensor memory addressing" + // with DP_STRIDE=1, COL_STRIDE=128 to define the tmem_atom, + // then convert to "logical tensor memory addressing" on return. + using COL_ADDR = C::value / sizeof_bits::value>; + Layout tmem_restride = Layout, + Stride, COL_ADDR>>{}; + + static_assert(N_SM == 1, "UMMA.WS expects N_SM == 1"); + + static_assert(M_MMA == 32 || M_MMA == 64 || M_MMA == 128, + "Weight stationary UMMA_1SM M-mode size should be 32 or 64 or 128."); + static_assert(N_MMA == 64 || N_MMA == 128 || N_MMA == 256, + "Dense weight stationary UMMA_1SM N-mode size should be 64 or 128 or 256."); + // Weight Stationary MMA config + if constexpr (M_MMA == 32) + { + // 1x4 datapath + Layout tmem_atom = Layout, _4>>, + Stride< _1, Stride< _128,_32>> + >{}; + constexpr int tile_stride = 1; + // This will tile in DPs first, then COLs + Layout tmem_logical_layout = tiled_product(tmem_atom, make_layout(take<1,R>(tmem_shape), + compact_col_major(take<1,R>(tmem_shape), Int{}))); + // Restride for the DP/COL addressing and return + return make_tensor(make_tmem_ptr(), composition(tmem_restride, tmem_logical_layout)); + } else + if constexpr (M_MMA == 64) + { + // 2x2 datapath + Layout tmem_atom = Layout, _2>>, + Stride< _1, Stride< _128,_64>> + >{}; + constexpr int tile_stride = 1; + // This will tile in DPs first, then COLs + Layout tmem_logical_layout = tiled_product(tmem_atom, make_layout(take<1,R>(tmem_shape), + compact_col_major(take<1,R>(tmem_shape), Int{}))); + // Restride for the DP/COL addressing and return + return make_tensor(make_tmem_ptr(), composition(tmem_restride, tmem_logical_layout)); + } else + if constexpr (M_MMA == 128) + { + // For M_MMA = 128, all datapaths are occupied. TmemAllocMode doesn't change the allocation. 
+ // Full subpartitions layout atom: (M,N) -> tmem_addr + Layout tmem_atom = Layout>, + Stride< _1, _128>>{}; + // This will tile in DPs first, then COLs + Layout tmem_logical_layout = tiled_product(tmem_atom, make_layout(take<1,R>(tmem_shape))); + // Restride for the DP/COL addressing and return + return make_tensor(make_tmem_ptr(), composition(tmem_restride, tmem_logical_layout)); + } + + CUTE_GCC_UNREACHABLE; + } +}; + +// Convenient aliases for common cases in the UMMA::ElementXFrg below +template +using tmem_frg_ws_1sm = tmem_frg_ws; + +} // end namespace UMMA + +// Customization point for creating a UMMA::tmem_frg Tensor +template +struct MakeTensor> +{ + template + CUTE_HOST_DEVICE constexpr auto + operator()(Shape const& tmem_shape) { + return UMMA::tmem_frg::make(shape(tmem_shape)); + } +}; + +template +struct MakeTensor> +{ + template + CUTE_HOST_DEVICE constexpr auto + operator()(Shape const& tmem_shape) { + return UMMA::tmem_frg_ws::make(shape(tmem_shape)); + } +}; + + +// Customization point for creating a UMMA::tmem_frg Tensor +template +struct MakeTensor> +{ + template + CUTE_HOST_DEVICE constexpr auto + operator()(Shape const& tmem_shape) { + return UMMA::tmem_e_frg::make(shape(tmem_shape)); + } +}; + +template +struct MakeTensor> +{ + template + CUTE_HOST_DEVICE constexpr auto + operator()(Shape const& tmem_shape) { + return UMMA::tmem_e_frg_ws::make(shape(tmem_shape)); + } +}; + +template +struct MakeTensor> +{ + template + CUTE_HOST_DEVICE constexpr auto + operator()(Shape const& tmem_shape) { + return UMMA::tmem_sf_frg::make(shape(tmem_shape)); + } +}; + +/////////////////////////////////////////////////////////////////////////////// +//////////////////////////// MMA_TRAITS /////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 32, "SM100_MMA_TF32_SS supports 32bit types"); + + using FrgTypeA = UMMA::smem_desc; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_1sm; + + // Logical shape-K is always 256bits, transform to units of elements + static constexpr int K = 256 / cute::sizeof_bits::value; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_1>; + using ALayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< + a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg>(); + + // Accumulate or overwrite C. 
1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t desc_a = A[0]; + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_); + + SM100_MMA_TF32_SS::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc); + } +}; + +template +struct MMA_Traits> +{ + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + + static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 16, "SM100_MMA_F16BF16_SS supports 16bit types"); + + using FrgTypeA = UMMA::smem_desc; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_1sm; + + // Logical shape-K is always 256bits, transform to units of elements + static constexpr int K = 256 / cute::sizeof_bits::value; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_1>; + using ALayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< + a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg>(); + + // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t desc_a = A[0]; + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_); + + SM100_MMA_F16BF16_SS::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc); + } +}; + +template +struct MMA_Traits> +{ + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 32, "SM100_MMA_TF32_TS supports 32bit types"); + + using FrgTypeA = UMMA::tmem_frg_1sm; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_1sm; + + // Logical shape-K is always 256 bits; transform to units of elements + static constexpr int K = 256 / cute::sizeof_bits::value; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_1>; + using ALayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + + // Accumulate or overwrite C. 
1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< + a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg, c_sat>(); + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint32_t tmem_a = raw_pointer_cast(A.data()); + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_); + + SM100_MMA_TF32_TS::fma(tmem_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc); + } +}; + +template +struct MMA_Traits> +{ + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 16, "SM100_MMA_F16BF16_TS supports 16bit types"); + + using FrgTypeA = UMMA::tmem_frg_1sm; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_1sm; + + // Logical shape-K is always 256 bits; transform to units of elements + static constexpr int K = 256 / cute::sizeof_bits::value; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_1>; + using ALayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + + // Accumulate or overwrite C. 
1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< + a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg, c_sat>(); + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint32_t tmem_a = raw_pointer_cast(A.data()); + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_); + + SM100_MMA_F16BF16_TS::fma(tmem_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc); + } +}; + +template +struct MMA_Traits> +{ + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 16, "SM100_MMA_F16BF16_SS_SCALED supports 16bit types"); + + using FrgTypeA = UMMA::smem_desc; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_1sm; + + // Logical shape-K is always 256bits, transform to units of elements + static constexpr int K = 256 / cute::sizeof_bits::value; + + static constexpr uint32_t ScalingFactor = ScaleC; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_1>; + using ALayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + + // Accumulate or overwrite C. 
1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< + a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg>(); + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t desc_a = A[0]; + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_); + + SM100_MMA_F16BF16_SS_SCALED::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc); + + } + + template + CUTE_HOST_DEVICE constexpr + MMA_Traits> + with(UMMA::ScaleOut accumulate, cute::integral_constant scaleC) const { + return {accumulate, idesc_}; + } +}; + +template +struct MMA_Traits> +{ + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 16, "SM100_MMA_F16BF16_TS_SCALED supports 16bit types"); + + using FrgTypeA = UMMA::tmem_frg_1sm; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_1sm; + + // Logical shape-K is always 256 bits; transform to units of elements + static constexpr int K = 256 / cute::sizeof_bits::value; + + static constexpr uint32_t ScalingFactor = ScaleC; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_1>; + using ALayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + + // Accumulate or overwrite C. 
1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< + a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg, c_sat>(); + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint32_t tmem_a = raw_pointer_cast(A.data()); + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_); + + SM100_MMA_F16BF16_TS_SCALED::fma(tmem_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc); + } + + template + CUTE_HOST_DEVICE constexpr + MMA_Traits> + with(UMMA::ScaleOut accumulate, cute::integral_constant scaleC) const { + return {accumulate, idesc_}; + } +}; + +template +struct MMA_Traits, sparse_args...> +{ + using ValTypeD = c_type; + static_assert(sizeof(a_type) == 4); + using ValTypeA = sparse_elem<2, a_type>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = b_type; + using ValTypeC = c_type; + static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 32, "SM100_MMA_TF32_SS_SPARSE supports 32bit types"); + + using FrgTypeA = UMMA::sparse_smem_desc; + using FrgTypeE = UMMA::tmem_e_frg; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_1sm; + + // SparseMma consume double mma-k bits + static constexpr int K = 512 / cute::sizeof_bits::value; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_1>; + using ALayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + + // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + // uint32_t tmem_e: Metadata tmem address. + cute::tuple sparse_args_; + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< + a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg, UMMA::Saturate::False, true>(); + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_same, cute::tuple>::value, + "Params must be set via .with()?"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t desc_a = A[0]; + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + + uint32_t tmem_e = get<0>(traits.sparse_args_); + uint32_t id2 = tmem_e & 0x00000001; + tmem_e = tmem_e & ~0x00000001; + + uint64_t idesc = UMMA::make_runtime_instr_desc(traits.idesc_, static_cast(id2), tmem_e); + + SM100_MMA_TF32_SS_SPARSE::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc, tmem_e); + } + + // Construct an executable sparse MMA_traits with sp into set. 
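+  // Illustrative use (tCtE is a placeholder name for the FrgTypeE metadata tensor in TMEM):
+  //   auto exec_traits = traits.with(tCtE);
+  // The returned traits carry the metadata TMEM address, which mma_unpack above splits into the
+  // ID2 bit and the metadata address passed to the sparse fma.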
+ template + CUTE_HOST_DEVICE constexpr + MMA_Traits, uint32_t> + with(Tensor const& E) const { + // Check sparse_ptr, check sparsity, check shape/layout? + uint32_t tmem_e_addr = raw_pointer_cast(E.data()); // Move to a CoupledTensor rather than a .with()? + return {accumulate_, {tmem_e_addr}, idesc_}; + } +}; + +template +struct MMA_Traits, sparse_args...> +{ + using ValTypeD = c_type; + static_assert(sizeof(a_type) == 2); + using ValTypeA = sparse_elem<2, a_type>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = b_type; + using ValTypeC = c_type; + static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 16, "SM100_MMA_F16BF16_SS_SPARSE supports 16bit types"); + + using FrgTypeA = UMMA::sparse_smem_desc; + using FrgTypeE = UMMA::tmem_e_frg; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_1sm; + + // SparseMma consume double mma-k bits + static constexpr int K = 512 / cute::sizeof_bits::value; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_1>; + using ALayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + + // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + // uint32_t tmem_e: Metadata tmem address. + cute::tuple sparse_args_; + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< + a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg, UMMA::Saturate::False, true>(); + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_same, cute::tuple>::value, + "Params must be set via .with()?"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t desc_a = A[0]; + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + + uint32_t tmem_e = get<0>(traits.sparse_args_); + uint32_t id2 = tmem_e & 0x00000001; + tmem_e = tmem_e & ~0x00000001; + + uint64_t idesc = UMMA::make_runtime_instr_desc(traits.idesc_, static_cast(id2), tmem_e); + + SM100_MMA_F16BF16_SS_SPARSE::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc, tmem_e); + } + + // Construct an executable sparse MMA_traits with sp into set. + template + CUTE_HOST_DEVICE constexpr + MMA_Traits, uint32_t> + with(Tensor const& E) const { + // Check sparse_ptr, check sparsity, check shape/layout? + uint32_t tmem_e_addr = raw_pointer_cast(E.data()); // Move to a CoupledTensor rather than a .with()? 
+ return {accumulate_, {tmem_e_addr}, idesc_}; + } +}; + +template +struct MMA_Traits> +{ + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 32, "SM100_MMA_TF32_2x1SM_SS supports 32bit types"); + + using FrgTypeA = UMMA::smem_desc; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_2sm; + + // Size of instructions's K extent is always 256bits, convert to units of element + constexpr static int K = 256 / cute::sizeof_bits::value; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_2>; + using ALayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< + a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg>(); + + // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t desc_a = A[0]; + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_); + + SM100_MMA_TF32_2x1SM_SS::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc); + } +}; + +template +struct MMA_Traits> +{ + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 16, "SM100_MMA_F16BF16_2x1SM_SS supports 16bit types"); + + using FrgTypeA = UMMA::smem_desc; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_2sm; + + // Size of instructions's K extent is always 256bits, convert to units of element + constexpr static int K = 256 / cute::sizeof_bits::value; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_2>; + using ALayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< + a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg>(); + + // Accumulate or overwrite C. 
1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t desc_a = A[0]; + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_); + + SM100_MMA_F16BF16_2x1SM_SS::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc); + } +}; + +template +struct MMA_Traits> +{ + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 32, "SM100_MMA_TF32_2x1SM_TS supports 32bit types"); + + using FrgTypeA = UMMA::tmem_frg_2sm; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_2sm; + + // Size of instructions' K extent is always 256 bits; convert to units of element + constexpr static int K = 256 / cute::sizeof_bits::value; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_2>; + using ALayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + + // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< + a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg, c_sat>(); + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t tmem_a = raw_pointer_cast(A.data()); + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_); + + SM100_MMA_TF32_2x1SM_TS::fma(tmem_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc); + } +}; + +template +struct MMA_Traits> +{ + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 16, "SM100_MMA_F16BF16_2x1SM_TS supports 16bit types"); + + using FrgTypeA = UMMA::tmem_frg_2sm; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_2sm; + + // Size of instructions' K extent is always 256 bits; convert to units of element + constexpr static int K = 256 / cute::sizeof_bits::value; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_2>; + using ALayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + + // Accumulate or overwrite C. 
1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< + a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg, c_sat>(); + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t tmem_a = raw_pointer_cast(A.data()); + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_); + + SM100_MMA_F16BF16_2x1SM_TS::fma(tmem_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc); + } +}; + +template +struct MMA_Traits> +{ + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 16, "SM100_MMA_F16BF16_2x1SM_SS_SCALED supports 16bit types"); + + using FrgTypeA = UMMA::smem_desc; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_2sm; + + // Size of instructions's K extent is always 256bits, convert to units of element + constexpr static int K = 256 / cute::sizeof_bits::value; + constexpr static uint32_t ScalingFactor = ScaleC; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_2>; + using ALayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + + // Accumulate or overwrite C. 
1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< + a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg>(); + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t desc_a = A[0]; + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_); + + SM100_MMA_F16BF16_2x1SM_SS_SCALED::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc); + } + + template + CUTE_HOST_DEVICE constexpr + MMA_Traits> + with(UMMA::ScaleOut accumulate, cute::integral_constant scaleC) const { + return {accumulate, idesc_}; + } +}; + +template +struct MMA_Traits> +{ + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 16, "SM100_MMA_F16BF16_2x1SM_TS_SCALED supports 16bit types"); + + using FrgTypeA = UMMA::tmem_frg_2sm; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_2sm; + + // Size of instructions' K extent is always 256 bits; convert to units of element + constexpr static int K = 256 / cute::sizeof_bits::value; + constexpr static uint32_t ScalingFactor = ScaleC; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_2>; + using ALayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + + // Accumulate or overwrite C. 
1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< + a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg, c_sat>(); + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t tmem_a = raw_pointer_cast(A.data()); + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_); + + SM100_MMA_F16BF16_2x1SM_TS_SCALED::fma(tmem_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc); + } + + template + CUTE_HOST_DEVICE constexpr + MMA_Traits> + with(UMMA::ScaleOut accumulate, cute::integral_constant scaleC) const { + return {accumulate, idesc_}; + } +}; + +template +struct MMA_Traits, sparse_args...> +{ + using ValTypeD = c_type; + static_assert(sizeof(a_type) == 4); + using ValTypeA = sparse_elem<2, a_type>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = b_type; + using ValTypeC = c_type; + static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 32, "SM100_MMA_TF32_2x1SM_SS_SPARSE supports 32bit types"); + + using FrgTypeA = UMMA::sparse_smem_desc; + using FrgTypeE = UMMA::tmem_e_frg; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_2sm; + + // SparseMma consume double mma-k bits + constexpr static int K = 512 / cute::sizeof_bits::value; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_2>; + using ALayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + + // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + // uint32_t tmem_e: Metadata tmem address. + cute::tuple sparse_args_; + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< + a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg, UMMA::Saturate::False, true>(); + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_same, cute::tuple>::value, + "Params must be set via .with()?"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t desc_a = A[0]; + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + + uint32_t tmem_e = get<0>(traits.sparse_args_); + uint32_t id2 = tmem_e & 0x00000001; + tmem_e = tmem_e & ~0x00000001; + + uint64_t idesc = UMMA::make_runtime_instr_desc(traits.idesc_, static_cast(id2), tmem_e); + + SM100_MMA_TF32_2x1SM_SS_SPARSE::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc, tmem_e); + } + + // Construct an executable sparse MMA_traits with sp into set. 
+ template + CUTE_HOST_DEVICE constexpr + MMA_Traits, uint32_t> + with(Tensor const& E, uint32_t id2 = 0) const { + // Check sparse_ptr, check sparsity, check shape/layout? + uint32_t tmem_e_addr = raw_pointer_cast(E.data()); + return {accumulate_, {tmem_e_addr}, idesc_}; + } +}; + +template +struct MMA_Traits, sparse_args...> +{ + using ValTypeD = c_type; + static_assert(sizeof(a_type) == 2); + using ValTypeA = sparse_elem<2, a_type>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = b_type; + using ValTypeC = c_type; + static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 16, "SM100_MMA_F16BF16_2x1SM_SS_SPARSE supports 16bit types"); + + using FrgTypeA = UMMA::sparse_smem_desc; + using FrgTypeE = UMMA::tmem_e_frg; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_2sm; + + // SparseMma consume double mma-k bits + constexpr static int K = 512 / cute::sizeof_bits::value; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_2>; + using ALayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + + // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + // uint32_t tmem_e: Metadata tmem address. + cute::tuple sparse_args_; + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< + a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg, UMMA::Saturate::False, true>(); + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_same, cute::tuple>::value, + "Params must be set via .with()?"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t desc_a = A[0]; + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + + uint32_t tmem_e = get<0>(traits.sparse_args_); + uint32_t id2 = tmem_e & 0x00000001; + tmem_e = tmem_e & ~0x00000001; + + uint64_t idesc = UMMA::make_runtime_instr_desc(traits.idesc_, static_cast(id2), tmem_e); + + SM100_MMA_F16BF16_2x1SM_SS_SPARSE::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc, tmem_e); + } + + // Construct an executable sparse MMA_traits with sp into set. + template + CUTE_HOST_DEVICE constexpr + MMA_Traits, uint32_t> + with(Tensor const& E, uint32_t id2 = 0) const { + // Check sparse_ptr, check sparsity, check shape/layout? 
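+ // Editorial note (illustrative, not part of the original header): for the TF32/F16BF16
+ // sparse variants, mma_unpack() recovers the sparsity selector from the metadata address
+ // itself (id2 = tmem_e & 0x1, after which that bit is masked off), so the id2 argument
+ // accepted here is not stored in the returned traits.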
+ uint32_t tmem_e_addr = raw_pointer_cast(E.data()); + return {accumulate_, {tmem_e_addr}, idesc_}; + } +}; + +template +struct MMA_Traits> +{ + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 8, "SM100_MMA_S8_SS supports 8bit types"); + + using FrgTypeA = UMMA::smem_desc; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_1sm; + + // Logical shape-K is always 256bits, transform to units of elements + static constexpr int K = 256 / cute::sizeof_bits::value; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_1>; + using ALayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< + a_type, b_type, c_type, M, N, a_major, b_major, UMMA::ScaleIn::One, UMMA::ScaleIn::One, c_sat>(); + + // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t desc_a = A[0]; + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_); + + SM100_MMA_S8_SS::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc); + } +}; + +template +struct MMA_Traits> +{ + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 8, "SM100_MMA_S8_TS supports 8bit types"); + + using FrgTypeA = UMMA::tmem_frg_1sm; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_1sm; + + // Logical shape-K is always 256 bits; transform to units of elements + static constexpr int K = 256 / cute::sizeof_bits::value; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_1>; + using ALayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + + // Accumulate or overwrite C. 
1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< + a_type, b_type, c_type, M, N, a_major, b_major, UMMA::ScaleIn::One, UMMA::ScaleIn::One, c_sat>(); + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t tmem_a = raw_pointer_cast(A.data()); + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_); + + SM100_MMA_S8_TS::fma(tmem_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc); + } +}; + +template +struct MMA_Traits, sparse_args...> +{ + using ValTypeD = c_type; + static_assert(sizeof(a_type) == 1); + using ValTypeA = sparse_elem<2, a_type>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = b_type; + using ValTypeC = c_type; + static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 8, "SM100_MMA_S8_SS_SPARSE supports 8bit types"); + + using FrgTypeA = UMMA::sparse_smem_desc; + using FrgTypeE = UMMA::tmem_e_frg; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_1sm; + + // SparseMma consume double mma-k bits + static constexpr int K = 512 / cute::sizeof_bits::value; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_1>; + using ALayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + + // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + // uint32_t tmem_e: Metadata tmem address. + cute::tuple sparse_args_; + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< + a_type, b_type, c_type, M, N, a_major, b_major, UMMA::ScaleIn::One, UMMA::ScaleIn::One, c_sat, true>(); + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_same, cute::tuple>::value, + "Params must be set via .with()?"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t desc_a = A[0]; + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + + uint32_t tmem_e = get<0>(traits.sparse_args_); + uint32_t id2 = 0; + + uint64_t idesc = UMMA::make_runtime_instr_desc(traits.idesc_, static_cast(id2), tmem_e); + + SM100_MMA_S8_SS_SPARSE::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc, tmem_e); + } + + // Construct an executable sparse MMA_traits with sp into set. + template + CUTE_HOST_DEVICE constexpr + MMA_Traits, uint32_t> + with(Tensor const& E, uint32_t id2 = 0) const { + // Check sparse_ptr, check sparsity, check shape/layout? 
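+ // Editorial note (illustrative, not part of the original header): the S8 sparse path
+ // builds its runtime instruction descriptor with id2 fixed to 0 in mma_unpack(), so the
+ // defaulted id2 parameter of this .with() has no effect on the issued instruction.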
+ uint32_t tmem_e_addr = raw_pointer_cast(E.data()); + return {accumulate_, {tmem_e_addr}, idesc_}; + } +}; + +template +struct MMA_Traits> +{ + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 8, "SM100_MMA_S8_2x1SM_SS supports 8bit types"); + + using FrgTypeA = UMMA::smem_desc; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_2sm; + + // Size of instructions's K extent is always 256bits, convert to units of element + constexpr static int K = 256 / cute::sizeof_bits::value; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_2>; + using ALayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< + a_type, b_type, c_type, M, N, a_major, b_major, UMMA::ScaleIn::One, UMMA::ScaleIn::One, c_sat>(); + + // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t desc_a = A[0]; + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_); + + SM100_MMA_S8_2x1SM_SS::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc); + } +}; + +template +struct MMA_Traits> +{ + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 8, "SM100_MMA_S8_2x1SM_TS supports 8bit types"); + + using FrgTypeA = UMMA::tmem_frg_2sm; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_2sm; + + // Size of instructions' K extent is always 256 bits; convert to units of element + constexpr static int K = 256 / cute::sizeof_bits::value; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_2>; + using ALayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + + // Accumulate or overwrite C. 
1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< + a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg, c_sat>(); + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t tmem_a = raw_pointer_cast(A.data()); + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_); + + SM100_MMA_S8_2x1SM_TS::fma(tmem_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc); + } +}; + +template +struct MMA_Traits, sparse_args...> +{ + using ValTypeD = c_type; + static_assert(sizeof(a_type) == 1); + using ValTypeA = sparse_elem<2, a_type>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = b_type; + using ValTypeC = c_type; + static_assert(cute::sizeof_bits_v == cute::sizeof_bits_v && cute::sizeof_bits_v == 8, "SM100_MMA_S8_2x1SM_SS_SPARSE supports 8bit types"); + + using FrgTypeA = UMMA::sparse_smem_desc; + using FrgTypeE = UMMA::tmem_e_frg; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_2sm; + + // SparseMma consume double mma-k bits + constexpr static int K = 512 / cute::sizeof_bits::value; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_2>; + using ALayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + + // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + // uint32_t tmem_e: Metadata tmem address. + cute::tuple sparse_args_; + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< + a_type, b_type, c_type, M, N, a_major, b_major, UMMA::ScaleIn::One, UMMA::ScaleIn::One, c_sat, true>(); + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_same, cute::tuple>::value, + "Params must be set via .with()?"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t desc_a = A[0]; + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + + uint32_t tmem_e = get<0>(traits.sparse_args_); + uint16_t id2 = 0u; + + uint64_t idesc = UMMA::make_runtime_instr_desc(traits.idesc_, id2, tmem_e); + + SM100_MMA_S8_2x1SM_SS_SPARSE::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc, tmem_e); + } + + // Construct an executable sparse MMA_traits with sp into set. + template + CUTE_HOST_DEVICE constexpr + MMA_Traits, uint32_t> + with(Tensor const& E, uint32_t id2 = 0) const { + // Check sparse_ptr, check sparsity, check shape/layout? 
+ uint32_t tmem_e_addr = raw_pointer_cast(E.data()); + return {accumulate_, {tmem_e_addr}, idesc_}; + } +}; + +template +struct MMA_Traits, cute::C, + cute::integral_constant, + cute::integral_constant, + cute::integral_constant, + cute::integral_constant> +{ + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + static_assert(cute::sizeof_bits_v <= 8 && cute::sizeof_bits_v <= 8, "SM100_MMA_F8F6F4_SS supports types with leq 8bit types"); + static_assert(M == 64 || M == 128, "SM100_MMA_F8F6F4_SS M-mode size should be 64 or 128 for 1 CTA cluster MMA."); + static_assert(((b_major == UMMA::Major::K) && ((N % 8 == 0) && (8 <= N) && (N <= 256))) || + ((b_major == UMMA::Major::MN) && ((N % 16 == 0) && (16 <= N) && (N <= 256))), + "SM100_MMA_F8F6F4_SS N-mode size should be a multiple of 8 between 8 and 256 when B is K major. \ + SM100_MMA_F8F6F4_SS N-mode size should be a multiple of 16 between 16 and 256 when B is MN major."); + using FrgTypeA = UMMA::smem_desc; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_1sm; + + static_assert(sizeof_bits_v <= sizeof_bits_v && + sizeof_bits_v <= sizeof_bits_v); + + // Logical shape-K is always 256bits, transform to units of elements + constexpr static int K = 32; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_1>; + using ALayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc(); + + // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t desc_a = A[0]; + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_); + + SM100_MMA_F8F6F4_SS::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc); + } +}; + + +template +struct MMA_Traits> +{ + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + using ValTypeSFA = sf_type; + using ValTypeSFB = sf_type; + static_assert(cute::sizeof_bits_v <= 8 && cute::sizeof_bits_v <= 8, "SM100_MMA_MXF8F6F4_SS supports types with leq 8bit types"); + + // Logical shape-K is always 256bits, transform to units of elements + constexpr static int K = 32; + constexpr static int SFVecSize = 32; + + using FrgTypeA = UMMA::smem_desc; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_1sm; + using FrgTypeSFA = UMMA::tmem_sf_frg; + using FrgTypeSFB = UMMA::tmem_sf_frg; + + static_assert(sizeof_bits_v <= sizeof_bits_v && + sizeof_bits_v <= sizeof_bits_v); + + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_1>; + using ALayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using MMA_ScaleFactor = SM100_MMA_MXF8F6F4_SS; + + // Accumulate or 
overwrite C. 1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + uint32_t tsfa_addr_ = 0; + uint32_t tsfb_addr_ = 0; + + UMMA::InstrDescriptorBlockScaled idesc_ = UMMA::make_instr_desc_block_scaled< + a_type, b_type, c_type, sf_type, M, N, a_major, b_major, a_neg, b_neg>(); + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t desc_a = A[0]; + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + uint64_t idesc = UMMA::make_runtime_instr_desc_block_scaled<>(traits.idesc_, traits.tsfa_addr_, traits.tsfb_addr_); + + SM100_MMA_MXF8F6F4_SS::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc, traits.tsfa_addr_, traits.tsfb_addr_); + } + + // Construct an executable MMA_traits with sp into set. + template + CUTE_HOST_DEVICE constexpr + MMA_Traits> + with(UMMA::ScaleOut accumulate, Tensor const& SFA, Tensor const& SFB) const { + uint32_t tmem_sfa_addr = raw_pointer_cast(SFA.data()); + uint32_t tmem_sfb_addr = raw_pointer_cast(SFB.data()); + return {accumulate, tmem_sfa_addr, tmem_sfb_addr, idesc_}; + } +}; + +template +struct MMA_Traits, sparse_args...> +{ + using ValTypeD = c_type; + using ValTypeA = sparse_elem<2, a_type>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = b_type; + using ValTypeC = c_type; + using ValTypeSFA = sf_type; + using ValTypeSFB = sf_type; + static_assert(cute::sizeof_bits_v <= 8 && cute::sizeof_bits_v <= 8, "SM100_MMA_MXF8F6F4_SS_SPARSE supports types with leq 8bit types"); + + // Logical shape-K is always 512bits, transform to units of elements + constexpr static int K = 64; + constexpr static int SFVecSize = 64; + + using FrgTypeA = UMMA::sparse_smem_desc; + using FrgTypeE = UMMA::tmem_e_frg; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_1sm; + using FrgTypeSFA = UMMA::tmem_sf_frg; + using FrgTypeSFB = UMMA::tmem_sf_frg; + + static_assert(sizeof_bits_v <= sizeof_bits_v && + sizeof_bits_v <= sizeof_bits_v); + + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_1>; + using ALayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using MMA_ScaleFactor = SM100_MMA_MXF8F6F4_SS_SPARSE; + + // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + uint32_t tsfa_addr_ = 0; + uint32_t tsfb_addr_ = 0; + + // uint32_t tmem_e: Metadata tmem address. 
+ cute::tuple sparse_args_; + + UMMA::InstrDescriptorBlockScaled idesc_ = UMMA::make_instr_desc_block_scaled< + a_type, b_type, c_type, sf_type, M, N, a_major, b_major, a_neg, b_neg, true>(); + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_same, cute::tuple>::value, + "Params must be set via .with()?"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t desc_a = A[0]; + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + + uint32_t tmem_e = get<0>(traits.sparse_args_); + uint16_t id2 = 0u; + + uint64_t idesc = UMMA::make_runtime_instr_desc_block_scaled(traits.idesc_, traits.tsfa_addr_, traits.tsfb_addr_, id2, tmem_e); + + SM100_MMA_MXF8F6F4_SS_SPARSE::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc, traits.tsfa_addr_, traits.tsfb_addr_, tmem_e); + } + + // Construct an executable MMA_traits with sp into set. + template + CUTE_HOST_DEVICE constexpr + MMA_Traits, uint32_t> + with(UMMA::ScaleOut accumulate, Tensor const& E, Tensor const& SFA, Tensor const& SFB) const { + uint32_t tmem_e_addr = raw_pointer_cast(E.data()); + uint32_t tmem_sfa_addr = raw_pointer_cast(SFA.data()); + uint32_t tmem_sfb_addr = raw_pointer_cast(SFB.data()); + return {accumulate, tmem_sfa_addr, tmem_sfb_addr, {tmem_e_addr}, idesc_}; + } +}; + +template +struct MMA_Traits> +{ + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + static_assert(cute::sizeof_bits_v <= 8 && cute::sizeof_bits_v <= 8, "SM100_MMA_F8F6F4_TS supports types with leq 8bit types"); + + using FrgTypeA = UMMA::tmem_frg_1sm; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_1sm; + + static_assert(sizeof_bits_v <= sizeof_bits_v && + sizeof_bits_v <= sizeof_bits_v); + // Logical shape-K is always 256 bits; transform to units of elements + static constexpr int K = 32; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_1>; + using ALayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + + // Accumulate or overwrite C. 
1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< + a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg, c_sat>(); + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t tmem_a = raw_pointer_cast(A.data()); + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_); + + SM100_MMA_F8F6F4_TS::fma(tmem_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc); + } +}; + +template +struct MMA_Traits, sparse_args...> +{ + using ValTypeD = c_type; + static_assert(sizeof(a_type) == 1); + using ValTypeA = sparse_elem<2, a_type>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = b_type; + using ValTypeC = c_type; + static_assert(cute::sizeof_bits_v <= 8 && cute::sizeof_bits_v <= 8, "SM100_MMA_F8F6F4_SS_SPARSE supports types with leq 8bit types"); + + using FrgTypeA = UMMA::sparse_smem_desc; + using FrgTypeE = UMMA::tmem_e_frg; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_1sm; + + // SparseMma consume double mma-k bits + static constexpr int K = 512 / cute::sizeof_bits::value; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_1>; + using ALayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + + // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + // uint32_t tmem_e: Metadata tmem address. + cute::tuple sparse_args_; + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< + a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg, UMMA::Saturate::False, true>(); + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_same, cute::tuple>::value, + "Params must be set via .with()?"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t desc_a = A[0]; + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + + uint32_t tmem_e = get<0>(traits.sparse_args_); + uint16_t id2 = 0u; + + uint64_t idesc = UMMA::make_runtime_instr_desc(traits.idesc_, id2, tmem_e); + + SM100_MMA_F8F6F4_SS_SPARSE::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc, tmem_e); + } + + // Construct an executable sparse MMA_traits with sp into set. + template + CUTE_HOST_DEVICE constexpr + MMA_Traits, uint32_t> + with(Tensor const& E, uint32_t id2 = 0) const { + // Check sparse_ptr, check sparsity, check shape/layout? 
+ uint32_t tmem_e_addr = raw_pointer_cast(E.data()); + return {accumulate_, {tmem_e_addr}, idesc_}; + } +}; + +template +struct MMA_Traits, cute::C, + cute::integral_constant, + cute::integral_constant, + cute::integral_constant, + cute::integral_constant> +{ + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + static_assert(cute::sizeof_bits_v <= 8 && cute::sizeof_bits_v <= 8, "SM100_MMA_F8F6F4_2x1SM_SS supports types with leq 8bit types"); + static_assert(M == 128 || M == 256, "SM100_MMA_F8F6F4_2x1SM_SS M-mode size should be 64 or 128 for 1 CTA cluster MMA."); + static_assert(((b_major == UMMA::Major::K) && ((N % 16 == 0) && (16 <= N) && (N <= 256))) || + ((b_major == UMMA::Major::MN) && ((N % 32 == 0) && (32 <= N) && (N <= 256))), + "SM100_MMA_F8F6F4_2x1SM_SS N-mode size should be a multiple of 16 between 16 and 256 when B is K major. \ + SM100_MMA_F8F6F4_2x1SM_SS N-mode size should be a multiple of 32 between 32 and 256 when B is MN major."); + + using FrgTypeA = UMMA::smem_desc; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_2sm; + + static_assert(sizeof_bits_v <= sizeof_bits_v && + sizeof_bits_v <= sizeof_bits_v); + // Size of instructions's K extent is always 256bits, convert to units of element + constexpr static int K = 32; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_2>; + using ALayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc(); + + // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t desc_a = A[0]; + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_); + + SM100_MMA_F8F6F4_2x1SM_SS::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc); + } +}; + +template +struct MMA_Traits> +{ + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + static_assert(cute::sizeof_bits_v <= 8 && cute::sizeof_bits_v <= 8, "SM100_MMA_F8F6F4_2x1SM_TS supports types with leq 8bit types"); + + using FrgTypeA = UMMA::tmem_frg_2sm; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_2sm; + + static_assert(sizeof_bits_v <= sizeof_bits_v && + sizeof_bits_v <= sizeof_bits_v); + // Size of instructions' K extent is always 256 bits; convert to units of element + constexpr static int K = 32; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_2>; + using ALayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + + // Accumulate or overwrite C. 
1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< + a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg, c_sat>(); + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t tmem_a = raw_pointer_cast(A.data()); + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + uint64_t idesc = UMMA::make_runtime_instr_desc<>(traits.idesc_); + + SM100_MMA_F8F6F4_2x1SM_TS::fma(tmem_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc); + } +}; + +template +struct MMA_Traits, sparse_args...> +{ + using ValTypeD = c_type; + static_assert(sizeof(a_type) == 1); + using ValTypeA = sparse_elem<2, a_type>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = b_type; + using ValTypeC = c_type; + static_assert(cute::sizeof_bits_v <= 8 && cute::sizeof_bits_v <= 8, "SM100_MMA_F8F6F4_2x1SM_SS_SPARSE supports types with leq 8bit types"); + + using FrgTypeA = UMMA::sparse_smem_desc; + using FrgTypeE = UMMA::tmem_e_frg; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_2sm; + + // SparseMma consume double mma-k bits + constexpr static int K = 512 / cute::sizeof_bits::value; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_2>; + using ALayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + + // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + // uint32_t tmem_e: Metadata tmem address. + cute::tuple sparse_args_; + + UMMA::InstrDescriptor idesc_ = UMMA::make_instr_desc< + a_type, b_type, c_type, M, N, a_major, b_major, a_neg, b_neg, UMMA::Saturate::False, true>(); + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_same, cute::tuple>::value, + "Params must be set via .with()?"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t desc_a = A[0]; + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + + uint32_t tmem_e = get<0>(traits.sparse_args_); + uint16_t id2 = 0u; + + uint64_t idesc = UMMA::make_runtime_instr_desc(traits.idesc_, id2, tmem_e); + + SM100_MMA_F8F6F4_2x1SM_SS_SPARSE::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc, tmem_e); + } + + // Construct an executable sparse MMA_traits with sp into set. + template + CUTE_HOST_DEVICE constexpr + MMA_Traits, uint32_t> + with(Tensor const& E, uint32_t id2 = 0) const { + // Check sparse_ptr, check sparsity, check shape/layout? 
+ uint32_t tmem_e_addr = raw_pointer_cast(E.data()); + return {accumulate_, {tmem_e_addr}, idesc_}; + } +}; + +template +struct MMA_Traits> +{ + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + using ValTypeSFA = sf_type; + using ValTypeSFB = sf_type; + static_assert(cute::sizeof_bits_v <= 8 && cute::sizeof_bits_v <= 8, "SM100_MMA_MXF8F6F4_2x1SM_SS supports types with leq 8bit types"); + + using FrgTypeA = UMMA::smem_desc; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_2sm; + + // Logical shape-K is always 256bits, transform to units of elements + constexpr static int K = 32; + constexpr static int SFVecSize = 32; + + constexpr static UMMA::TmemAllocMode TmemAlloc = M == 128 ? + UMMA::TmemAllocMode::ScaleFactorDuplicated2by2 : UMMA::TmemAllocMode::ScaleFactorDuplicated4by1; + using FrgTypeSFA = UMMA::tmem_sf_frg; + using FrgTypeSFB = UMMA::tmem_sf_frg; + + static_assert(sizeof_bits_v <= sizeof_bits_v && + sizeof_bits_v <= sizeof_bits_v); + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_2>; + using ALayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using MMA_ScaleFactor = SM100_MMA_MXF8F6F4_SS 64 ? M/2 : M), (round_up(N, 128)), a_major, b_major, + a_neg, b_neg>; + + // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + uint32_t tsfa_addr_ = 0; + uint32_t tsfb_addr_ = 0; + + UMMA::InstrDescriptorBlockScaled idesc_ = UMMA::make_instr_desc_block_scaled< + a_type, b_type, c_type, sf_type, M, N, a_major, b_major, a_neg, b_neg>(); + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t desc_a = A[0]; + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + uint64_t idesc = UMMA::make_runtime_instr_desc_block_scaled<>(traits.idesc_, traits.tsfa_addr_, traits.tsfb_addr_); + + SM100_MMA_MXF8F6F4_2x1SM_SS::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc, traits.tsfa_addr_, traits.tsfb_addr_); + } + + // Construct an executable MMA_traits with sp into set. 
+ template + CUTE_HOST_DEVICE constexpr + MMA_Traits> + with(UMMA::ScaleOut accumulate, Tensor const& SFA, Tensor const& SFB) const { + uint32_t tmem_sfa_addr = raw_pointer_cast(SFA.data()); + uint32_t tmem_sfb_addr = raw_pointer_cast(SFB.data()); + return {accumulate, tmem_sfa_addr, tmem_sfb_addr, idesc_}; + } +}; + +template +struct MMA_Traits, sparse_args...> +{ + using ValTypeD = c_type; + static_assert(sizeof(a_type) == 1); + using ValTypeA = sparse_elem<2, a_type>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = b_type; + using ValTypeC = c_type; + static_assert(cute::sizeof_bits_v <= 8 && cute::sizeof_bits_v <= 8, "SM100_MMA_MXF8F6F4_2x1SM_SS_SPARSE supports types with leq 8bit types"); + + using FrgTypeA = UMMA::sparse_smem_desc; + using FrgTypeE = UMMA::tmem_e_frg; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_2sm; + + // SparseMma consume double mma-k bits + constexpr static int K = 64; + constexpr static int SFVecSize = 64; + + constexpr static UMMA::TmemAllocMode TmemAlloc = M == 128 ? + UMMA::TmemAllocMode::ScaleFactorDuplicated2by2 : UMMA::TmemAllocMode::ScaleFactorDuplicated4by1; + using FrgTypeSFA = UMMA::tmem_sf_frg; + using FrgTypeSFB = UMMA::tmem_sf_frg; + + static_assert(sizeof_bits_v <= sizeof_bits_v && + sizeof_bits_v <= sizeof_bits_v); + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_2>; + using ALayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using MMA_ScaleFactor = SM100_MMA_MXF8F6F4_SS_SPARSE 64 ? M/2 : M), (round_up(N, 128)), a_major, b_major, + a_neg, b_neg>; + + // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + uint32_t tsfa_addr_ = 0; + uint32_t tsfb_addr_ = 0; + + // uint32_t tmem_e: Metadata tmem address. + cute::tuple sparse_args_; + + UMMA::InstrDescriptorBlockScaled idesc_ = UMMA::make_instr_desc_block_scaled< + a_type, b_type, c_type, sf_type, M, N, a_major, b_major, a_neg, b_neg, true>(); + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t desc_a = A[0]; + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + + uint32_t tmem_e = get<0>(traits.sparse_args_); + uint16_t id2 = 0u; + + uint64_t idesc = UMMA::make_runtime_instr_desc_block_scaled(traits.idesc_, traits.tsfa_addr_, traits.tsfb_addr_, id2, tmem_e); + + SM100_MMA_MXF8F6F4_2x1SM_SS_SPARSE::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc, traits.tsfa_addr_, traits.tsfb_addr_, tmem_e); + } + + // Construct an executable MMA_traits with sp into set. 
+ template + CUTE_HOST_DEVICE constexpr + MMA_Traits, uint32_t> + with(UMMA::ScaleOut accumulate, Tensor const& E, Tensor const& SFA, Tensor const& SFB) const { + uint32_t tmem_e_addr = raw_pointer_cast(E.data()); + uint32_t tmem_sfa_addr = raw_pointer_cast(SFA.data()); + uint32_t tmem_sfb_addr = raw_pointer_cast(SFB.data()); + return {accumulate, tmem_sfa_addr, tmem_sfb_addr, {tmem_e_addr}, idesc_}; + } +}; + +template +struct MMA_Traits> +{ + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + using ValTypeSFA = sf_type; + using ValTypeSFB = sf_type; + static_assert(cute::sizeof_bits_v == 4 && cute::sizeof_bits_v == 4, "SM100_MMA_MXF4_SS supports 4bit types"); + + // Logical shape-K is always 256bits, transform to units of elements + constexpr static int K = 64; + constexpr static int SFVecSize = VS; + + using FrgTypeA = UMMA::smem_desc; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_1sm; + using FrgTypeSFA = UMMA::tmem_sf_frg; + using FrgTypeSFB = UMMA::tmem_sf_frg; + + static_assert((VS == 32 && ((is_same_v || is_same_v) && + (is_same_v || is_same_v)) + && is_same_v) + || (VS == 16), + "2x mode (VectorSize=32) only supports a_type and b_type=float_e2m1_t or cutlass::type_erased_dynamic_float4_t and sf_type=ue8m0_t"); + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_1>; + using ALayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using MMA_ScaleFactor = SM100_MMA_MXF4_SS; + + // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + uint32_t tsfa_addr_ = 0; + uint32_t tsfb_addr_ = 0; + + UMMA::InstrDescriptorBlockScaled idesc_ = UMMA::make_instr_desc_block_scaled< + a_type, b_type, c_type, sf_type, M, N, a_major, b_major, a_neg, b_neg>(); + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t desc_a = A[0]; + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + uint64_t idesc = UMMA::make_runtime_instr_desc_block_scaled<>(traits.idesc_, traits.tsfa_addr_, traits.tsfb_addr_); + + SM100_MMA_MXF4_SS::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc, traits.tsfa_addr_, traits.tsfb_addr_); + } + + // Construct an executable sparse MMA_traits with sp into set. + template + CUTE_HOST_DEVICE constexpr + MMA_Traits> + with(UMMA::ScaleOut accumulate, Tensor const& SFA, Tensor const& SFB) const { + uint32_t tmem_sfa_addr = raw_pointer_cast(SFA.data()); // Move to a CoupledTensor rather than a .with()? + uint32_t tmem_sfb_addr = raw_pointer_cast(SFB.data()); // Move to a CoupledTensor rather than a .with()? 
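+ // Editorial note (illustrative, not part of the original header): SFA/SFB are the
+ // tmem-resident scale-factor fragments (FrgTypeSFA/FrgTypeSFB above). Their addresses are
+ // folded into the runtime block-scaled instruction descriptor and also passed to ::fma()
+ // in mma_unpack(), so the block-scaled traits are only executable after this .with().
+ // Hedged sketch (the tensor names tCtSFA/tCtSFB are hypothetical):
+ //   auto exec_traits = traits.with(UMMA::ScaleOut::One, tCtSFA, tCtSFB);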
+ return {accumulate, tmem_sfa_addr, tmem_sfb_addr, idesc_}; + } +}; + +template +struct MMA_Traits, sparse_args...> +{ + using ValTypeD = c_type; + using ValTypeA = sparse_elem<4, uint8_t>; + using ValTypeE = sparse_elem<16, uint8_t>; + using ValTypeB = b_type; + using ValTypeC = c_type; + using ValTypeSFA = sf_type; + using ValTypeSFB = sf_type; + static_assert(cute::sizeof_bits_v == 4 && cute::sizeof_bits_v == 4, "SM100_MMA_MXF4NVF4_SS_SPARSE supports 4bit types"); + + // Logical shape-K is always 256bits, transform to units of elements + constexpr static int K = 128; + constexpr static int SFVecSize = VS; + + using FrgTypeA = UMMA::sparse_smem_desc; + using FrgTypeE = UMMA::tmem_e_frg; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_1sm; + using FrgTypeSFA = UMMA::tmem_sf_frg; + using FrgTypeSFB = UMMA::tmem_sf_frg; + + static_assert((VS == 64 && ((is_same_v || is_same_v) && + (is_same_v || is_same_v)) + && is_same_v) + || (VS == 32), + "2x mode (VectorSize=64) only supports a_type and b_type=float_e2m1_t or cutlass::type_erased_dynamic_float4_t and sf_type=ue8m0_t"); + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_1>; + using ALayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using MMA_ScaleFactor = SM100_MMA_MXF4NVF4_SS_SPARSE; + + // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + uint32_t tsfa_addr_ = 0; + uint32_t tsfb_addr_ = 0; + + // uint32_t tmem_e: Metadata tmem address. + cute::tuple sparse_args_; + + UMMA::InstrDescriptorBlockScaled idesc_ = UMMA::make_instr_desc_block_scaled< + a_type, b_type, c_type, sf_type, M, N, a_major, b_major, a_neg, b_neg, true>(); + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t desc_a = A[0]; + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + + uint32_t tmem_e = get<0>(traits.sparse_args_); + uint16_t id2 = 0u; + + uint64_t idesc = UMMA::make_runtime_instr_desc_block_scaled(traits.idesc_, traits.tsfa_addr_, traits.tsfb_addr_, id2, tmem_e); + + SM100_MMA_MXF4NVF4_SS_SPARSE::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc, traits.tsfa_addr_, traits.tsfb_addr_, tmem_e); + } + + // Construct an executable sparse MMA_traits with sp into set. + template + CUTE_HOST_DEVICE constexpr + MMA_Traits, uint32_t> + with(UMMA::ScaleOut accumulate, Tensor const& E, Tensor const& SFA, Tensor const& SFB) const { + uint32_t tmem_e_addr = raw_pointer_cast(E.data()); + uint32_t tmem_sfa_addr = raw_pointer_cast(SFA.data()); // Move to a CoupledTensor rather than a .with()? + uint32_t tmem_sfb_addr = raw_pointer_cast(SFB.data()); // Move to a CoupledTensor rather than a .with()? 
+ return {accumulate, tmem_sfa_addr, tmem_sfb_addr, {tmem_e_addr}, idesc_}; + } +}; + +template +struct MMA_Traits> +{ + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + using ValTypeSFA = sf_type; + using ValTypeSFB = sf_type; + static_assert(cute::sizeof_bits_v == 4 && cute::sizeof_bits_v == 4, "SM100_MMA_MXF4_2x1SM_SS supports 4bit types"); + + // Logical shape-K is always 256bits, transform to units of elements + constexpr static int K = 64; + constexpr static int SFVecSize = VS; + + using FrgTypeA = UMMA::smem_desc; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_2sm; + + constexpr static UMMA::TmemAllocMode TmemAlloc = M == 128 ? + UMMA::TmemAllocMode::ScaleFactorDuplicated2by2 : UMMA::TmemAllocMode::ScaleFactorDuplicated4by1; + using FrgTypeSFA = UMMA::tmem_sf_frg; + using FrgTypeSFB = UMMA::tmem_sf_frg; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_2>; + using ALayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using MMA_ScaleFactor = SM100_MMA_MXF4_SS 64 ? M/2 : M), (round_up(N, 128)), VS, a_major, b_major, + a_neg, b_neg>; + + // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + uint32_t tsfa_addr_ = 0; + uint32_t tsfb_addr_ = 0; + + UMMA::InstrDescriptorBlockScaled idesc_ = UMMA::make_instr_desc_block_scaled< + a_type, b_type, c_type, sf_type, M, N, a_major, b_major, a_neg, b_neg>(); + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t desc_a = A[0]; + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + uint64_t idesc = UMMA::make_runtime_instr_desc_block_scaled<>(traits.idesc_, traits.tsfa_addr_, traits.tsfb_addr_); + + SM100_MMA_MXF4_2x1SM_SS::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc, traits.tsfa_addr_, traits.tsfb_addr_); + } + + // Construct an executable sparse MMA_traits with sp into set. + template + CUTE_HOST_DEVICE constexpr + MMA_Traits> + with(UMMA::ScaleOut accumulate, Tensor const& SFA, Tensor const& SFB) const { + // Check sparse_ptr, check sparsity, check shape/layout? + uint32_t tmem_sfa_addr = raw_pointer_cast(SFA.data()); // Move to a CoupledTensor rather than a .with()? + uint32_t tmem_sfb_addr = raw_pointer_cast(SFB.data()); // Move to a CoupledTensor rather than a .with()? 
+ return {accumulate, tmem_sfa_addr, tmem_sfb_addr, idesc_}; + } +}; + +template +struct MMA_Traits, sparse_args...> +{ + using ValTypeD = c_type; + using ValTypeA = sparse_elem<4, uint8_t>; + using ValTypeE = sparse_elem<16, uint8_t>; + using ValTypeB = b_type; + using ValTypeC = c_type; + using ValTypeSFA = sf_type; + using ValTypeSFB = sf_type; + static_assert(cute::sizeof_bits_v == 4 && cute::sizeof_bits_v == 4, "SM100_MMA_MXF4NVF4_2x1SM_SS_SPARSE supports 4bit types"); + + // Logical shape-K is always 256bits, transform to units of elements + constexpr static int K = 128; + constexpr static int SFVecSize = VS; + + constexpr static UMMA::TmemAllocMode TmemAlloc = M == 128 ? + UMMA::TmemAllocMode::ScaleFactorDuplicated2by2 : UMMA::TmemAllocMode::ScaleFactorDuplicated4by1; + using FrgTypeA = UMMA::sparse_smem_desc; + // using FrgTypeE = UMMA::tmem_e_frg; + using FrgTypeE = UMMA::tmem_e_frg; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_2sm; + using FrgTypeSFA = UMMA::tmem_sf_frg; + using FrgTypeSFB = UMMA::tmem_sf_frg; + + static_assert((VS == 64 && ((is_same_v || is_same_v) && + (is_same_v || is_same_v)) + && is_same_v) + || (VS == 32), + "2x mode (VectorSize=64) only supports a_type and b_type=float_e2m1_t or cutlass::type_erased_dynamic_float4_t and sf_type=ue8m0_t"); + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_2>; + using ALayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using MMA_ScaleFactor = SM100_MMA_MXF4NVF4_SS_SPARSE 64 ? M/2 : M), (round_up(N, 128)), VS, a_major, b_major, + a_neg, b_neg>; + + + // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + uint32_t tsfa_addr_ = 0; + uint32_t tsfb_addr_ = 0; + + // uint32_t tmem_e: Metadata tmem address. + cute::tuple sparse_args_; + + UMMA::InstrDescriptorBlockScaled idesc_ = UMMA::make_instr_desc_block_scaled< + a_type, b_type, c_type, sf_type, M, N, a_major, b_major, a_neg, b_neg, true>(); + + template + CUTE_HOST_DEVICE constexpr friend + void + mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) + { + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t desc_a = A[0]; + uint64_t desc_b = B[0]; + uint32_t tmem_c = raw_pointer_cast(D.data()); + + uint32_t tmem_e = get<0>(traits.sparse_args_); + uint16_t id2 = 0u; + + uint64_t idesc = UMMA::make_runtime_instr_desc_block_scaled(traits.idesc_, traits.tsfa_addr_, traits.tsfb_addr_, id2, tmem_e); + + SM100_MMA_MXF4NVF4_2x1SM_SS_SPARSE::fma(desc_a, desc_b, tmem_c, uint32_t(traits.accumulate_), idesc, traits.tsfa_addr_, traits.tsfb_addr_, tmem_e); + } + + // Construct an executable sparse MMA_traits with sp into set. + template + CUTE_HOST_DEVICE constexpr + MMA_Traits, uint32_t> + with(UMMA::ScaleOut accumulate, Tensor const& E, Tensor const& SFA, Tensor const& SFB) const { + uint32_t tmem_e_addr = raw_pointer_cast(E.data()); + uint32_t tmem_sfa_addr = raw_pointer_cast(SFA.data()); // Move to a CoupledTensor rather than a .with()? + uint32_t tmem_sfb_addr = raw_pointer_cast(SFB.data()); // Move to a CoupledTensor rather than a .with()? 
+ return {accumulate, tmem_sfa_addr, tmem_sfb_addr, {tmem_e_addr}, idesc_}; + } +}; + +/** + * Specialization for a vectorized FMA per thread. + */ +template <> +struct MMA_Traits +{ + using ValTypeD = float; + using ValTypeA = float; + using ValTypeB = float; + using ValTypeC = float; + + using Shape_MNK = Shape<_2,_1,_1>; + using ThrID = Layout<_1>; + + using ALayout = Layout>; + using BLayout = Layout>; + using CLayout = Layout>; +}; + +template <> +struct MMA_Traits +{ + using ValTypeD = float; + using ValTypeA = float; + using ValTypeB = float; + using ValTypeC = float; + + using Shape_MNK = Shape<_1,_2,_1>; + using ThrID = Layout<_1>; + + using ALayout = Layout>; + using BLayout = Layout>; + using CLayout = Layout>; +}; + +namespace SM103 { + // Common mma_unpack for all MMA_Ops in cute::SM103 +template +CUTE_HOST_DEVICE constexpr +void +mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& zA, + Tensor const& zB, + Tensor const& C) + { + auto [A, next_A, SFA] = unzip_tensor(zA); + auto [B, next_B, SFB] = unzip_tensor(zB); + + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected desc registers in MMA_Atom::call"); + static_assert(is_tmem::value, "Expected tmem in MMA_Atom::call"); + + uint64_t desc_a = A[0]; + uint64_t desc_next_a = next_A[0]; + uint64_t desc_b = B[0]; + uint64_t desc_next_b = next_B[0]; + + auto desc_a_temp = reinterpret_cast(desc_a); + auto desc_next_a_temp = reinterpret_cast(desc_next_a); + desc_a_temp.lbo_mode_ = 1; + desc_a_temp.leading_byte_offset_ = desc_next_a_temp.start_address_; + + auto desc_b_temp = reinterpret_cast(desc_b); + auto desc_next_b_temp = reinterpret_cast(desc_next_b); + desc_b_temp.lbo_mode_ = 1; + desc_b_temp.leading_byte_offset_ = desc_next_b_temp.start_address_; + + uint32_t tmem_c = raw_pointer_cast(D.data()); + UMMA::InstrDescriptorBlockScaled instr_desc = traits.idesc_; + instr_desc.k_size_ = 1; + auto tsfa_addr = raw_pointer_cast(SFA.data()); + auto tsfb_addr = raw_pointer_cast(SFB.data()); + + uint64_t idesc = UMMA::make_runtime_instr_desc_block_scaled<>(instr_desc, tsfa_addr, tsfb_addr); + // print("a: "); print(A); print("\n"); + // print("b: "); print(B); print("\n"); + + MMA_Op::fma(reinterpret_cast(desc_a_temp), reinterpret_cast(desc_b_temp), tmem_c, uint32_t(traits.accumulate_), idesc, tsfa_addr, tsfb_addr); + } +} // end namespace SM103 + + +template +struct MMA_Traits> +{ + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + using ValTypeSFA = sf_type; + using ValTypeSFB = sf_type; + + // Logical shape-K is always 256bits, transform to units of elements + constexpr static int K = 96; + constexpr static int SFVecSize = VS; + + static_assert(a_major == UMMA::Major::K && b_major == UMMA::Major::K, "This MMA does not support transpose"); + + using FrgTypeA = UMMA::smem_desc; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_1sm; + using FrgTypeSFA = UMMA::tmem_sf_frg; + using FrgTypeSFB = UMMA::tmem_sf_frg; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_1>; + using ALayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride<_0,Stride< _1,Int>>>; + + using MMA_ScaleFactor = SM100_MMA_MXF4_SS; + + // Accumulate or overwrite C. 
1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + + UMMA::InstrDescriptorBlockScaled idesc_ = UMMA::make_instr_desc_block_scaled< + a_type, b_type, c_type, sf_type, M, N, a_major, b_major, a_neg, b_neg>(); +}; + +template +struct MMA_Traits> +{ + using ValTypeD = c_type; + using ValTypeA = a_type; + using ValTypeB = b_type; + using ValTypeC = c_type; + using ValTypeSFA = sf_type; + using ValTypeSFB = sf_type; + + // Logical shape-K is always 256bits, transform to units of elements + constexpr static int K = 96; + constexpr static int SFVecSize = VS; + + static_assert(a_major == UMMA::Major::K && b_major == UMMA::Major::K, "This MMA does not support transpose"); + + using FrgTypeA = UMMA::smem_desc; + using FrgTypeB = UMMA::smem_desc; + using FrgTypeC = UMMA::tmem_frg_2sm; + constexpr static UMMA::TmemAllocMode TmemAlloc = M == 128 ? + UMMA::TmemAllocMode::ScaleFactorDuplicated2by2 : UMMA::TmemAllocMode::ScaleFactorDuplicated4by1; + using FrgTypeSFA = UMMA::tmem_sf_frg; + using FrgTypeSFB = UMMA::tmem_sf_frg; + + using Shape_MNK = Shape,Int,Int>; + using ThrID = Layout<_2>; + using ALayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using BLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + using CLayout = Layout,Int>>, + Stride,Stride< _1,Int>>>; + + using MMA_ScaleFactor = SM100_MMA_MXF4_SS 64 ? M/2 : M), (round_up(N, 128)), VS, a_major, b_major, + a_neg, b_neg>; + + + // Accumulate or overwrite C. 1: read C, 0: ignore C [clear accumulators] + UMMA::ScaleOut accumulate_ = UMMA::ScaleOut::One; + + UMMA::InstrDescriptorBlockScaled idesc_ = UMMA::make_instr_desc_block_scaled< + a_type, b_type, c_type, sf_type, M, N, a_major, b_major, a_neg, b_neg>(); +}; + +} // end namespace cute diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/atom/mma_traits_sm120.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/atom/mma_traits_sm120.hpp new file mode 100644 index 0000000000000000000000000000000000000000..e3399801cf667deb9589f2df6ce1c2aca32bb368 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/atom/mma_traits_sm120.hpp @@ -0,0 +1,262 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include +#include +#include +#include +#include + +namespace cute +{ + +namespace SM120::BLOCKSCALED { + +template +CUTE_HOST_DEVICE constexpr void +mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A_zipped, + Tensor const& B_zipped, + Tensor const& C) +{ + static_assert(is_rmem::value, "Expected registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected registers in MMA_Atom::call"); + + // Register value types from the MMA_Operation register arrays + using RegTypeD = typename remove_extent::type; + using RegTypeA = typename remove_extent::type; + using RegTypeB = typename remove_extent::type; + using RegTypeC = typename remove_extent::type; + using RegTypeSFA = typename remove_extent::type; + using RegTypeSFB = typename remove_extent::type; + + constexpr int RegNumD = extent::value; + constexpr int RegNumA = extent::value; + constexpr int RegNumB = extent::value; + constexpr int RegNumC = extent::value; + constexpr int RegNumSFA = extent::value; + constexpr int RegNumSFB = extent::value; + + auto [A, SFA] = unzip_tensor(A_zipped); + auto [B, SFB] = unzip_tensor(B_zipped); + + using Shape_MNK = typename MMA_Traits::Shape_MNK; + constexpr int SFVecSize = MMA_Traits::SFVecSize; + + // Assert logical size + CUTE_STATIC_ASSERT_V(size(SFA) == size<2>(Shape_MNK{})); + CUTE_STATIC_ASSERT_V(size(SFB) == size<2>(Shape_MNK{})); + + // Assert physical size + CUTE_STATIC_ASSERT(decltype(cosize(layout(SFA))){} == size<2>(Shape_MNK{}) / SFVecSize); + CUTE_STATIC_ASSERT(decltype(cosize(layout(SFB))){} == size<2>(Shape_MNK{}) / SFVecSize); + + Tensor rA = recast(A); + Tensor rB = recast(B); + CUTE_STATIC_ASSERT_V(size(rA) == Int{}); + CUTE_STATIC_ASSERT_V(size(rB) == Int{}); + + Tensor rD = recast(D); + Tensor rC = recast(C); + CUTE_STATIC_ASSERT_V(size(rD) == Int{}); + CUTE_STATIC_ASSERT_V(size(rC) == Int{}); + + Tensor rSFA = recast(filter_zeros(SFA)); + Tensor rSFB = recast(filter_zeros(SFB)); + + CUTE_STATIC_ASSERT_V(size(rSFA) == Int{}); + CUTE_STATIC_ASSERT_V(size(rSFB) == Int{}); + + detail::explode(MMAOp::fma, + rD, make_int_sequence{}, + rA, make_int_sequence{}, + rB, make_int_sequence{}, + rC, make_int_sequence{}, + rSFA, make_int_sequence{}, + rSFB, make_int_sequence{}); +} +} // namespace SM120::BLOCKSCALED + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA F8F6F4 16x8x32 TN +template +struct MMA_Traits> + : MMA_Traits +{ + // The MMA accepts 8-bit inputs regardless of the types for A and B + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + + using ValTypeD = c_type; + using ValTypeC = c_type; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA MXF8F6F4 16x8x64 TN +template +struct MMA_Traits> +{ + // The MMA accepts 4-bit inputs regardless of the types for A and B + using ValTypeA = uint4_t; + using ValTypeB = uint4_t; + + using ValTypeD = c_type; + using ValTypeC = c_type; + + using ValTypeSF = sf_type; + constexpr static int SFVecSize = VS; + + using Shape_MNK = Shape<_16,_8,_64>; + using ThrID = Layout<_32>; + + // (T32,V32) -> (M16,K64) + using ALayout = Layout,Shape < _8,_2, _2>>, + Stride,Stride<_16,_8,_512>>>; + // (T32,V16) -> (M16,K64) + using BLayout = Layout,Shape <_8, _2>>, + Stride,Stride<_8,_256>>>; + // (T32,V64) -> (M16,K64) + using SFALayout = Layout,_64>, // Effectively 16 threads due to the 2:0 mode + Stride,_16>>; + // (T32,V64) -> (N8,K64) + using SFBLayout = Layout,_64>, // Effectively 8 threads due to the 4:0 mode + Stride, _8>>; + // (T32,V4) -> (M16,N8) + using CLayout = SM80_16x8_Row; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA MXF8F6F4 16x8x32 TN +template +struct MMA_Traits> +{ + using UnderlyingTraits = MMA_Traits>; + + // The MMA accepts 8-bit inputs regardless of the types for A and B + using ValTypeA = typename UnderlyingTraits::ValTypeA; + using ValTypeB = typename UnderlyingTraits::ValTypeB; + + using ValTypeD = typename UnderlyingTraits::ValTypeD; + using ValTypeC = typename UnderlyingTraits::ValTypeC; + + using Shape_MNK = typename UnderlyingTraits::Shape_MNK; + using ThrID = typename UnderlyingTraits::ThrID; + + using ALayout = typename UnderlyingTraits::ALayout; + using BLayout = typename UnderlyingTraits::BLayout; + using CLayout = typename UnderlyingTraits::CLayout; + + // Scaling factor + using ValTypeSF = sf_type; + constexpr static int SFVecSize = VS; + + // (T32,V32) -> (M16,K32) + using SFALayout = Layout,_32>, // Effectively 16 threads due to the 2:0 mode + Stride,_16>>; + // (T32,V32) -> (N8,K32) + using SFBLayout = Layout,_32>, // Effectively 8 threads due to the 4:0 mode + Stride, _8>>; +}; + +// Transform if needed +template +CUTLASS_DEVICE void +fp4_shift_A(MMA_Op const& op, Tensor&& tensor) { +} +template +CUTLASS_DEVICE void +fp4_shift_B(MMA_Op const& op, Tensor&& tensor) { +} + +// For SM120 MMA F8F6F4 input fp4, the operand A/B are load from ld.matrix. +// ld.matrix b4x16_p64 places FP4 data at the first four bits in each +// eight-bit container, whereas MMA F8F6F4 expects the four-bit data to be in +// the middle of the eight-bit container. Thus, e2m1 operands being fed +// to MMA F8F6F4 must be shifted left by two bits. +// 0b0000ABCD --> 0b00ABCD00 +// NOTE: Same transformation is NOT needed for FP6 and FP8. +template +CUTLASS_DEVICE void +fp4_shift_A(SM120_16x8x32_TN const&, Tensor&& tensor) { + using RegisterTypeA = typename remove_extent::ARegisters>::type; + if constexpr (cute::is_same_v) { + cute::transform(recast(tensor), [](RegisterTypeA& v){ return v << 2; }); + } +} +template +CUTLASS_DEVICE void +fp4_shift_B(SM120_16x8x32_TN const&, Tensor&& tensor) { + using RegisterTypeB = typename remove_extent::BRegisters>::type; + if constexpr (cute::is_same_v) { + cute::transform(recast(tensor), [](RegisterTypeB& v){ return v << 2; }); + } +} + +namespace SM120::BLOCKSCALED { + +// Template function with scale factor needs to enmuerate types one by one, as template +// arguments contatins two variadic lists, which cannot be deduced in one shot. 
+template +CUTLASS_DEVICE void +fp4_shift_A(SM120::BLOCKSCALED::SM120_16x8x32_TN_VS const&, Tensor&& tensor) { + using RegisterTypeA = typename remove_extent::ARegisters>::type; + if constexpr (cute::is_same_v) { + cute::transform(recast(tensor), [](RegisterTypeA& v){ return v << 2; }); + } +} +template +CUTLASS_DEVICE void +fp4_shift_B(SM120::BLOCKSCALED::SM120_16x8x32_TN_VS const&, Tensor&& tensor) { + using RegisterTypeB = typename remove_extent::BRegisters>::type; + if constexpr (cute::is_same_v) { + cute::transform(recast(tensor), [](RegisterTypeB& v){ return v << 2; }); + } +} + +} + +} // end namespace cute diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/atom/mma_traits_sm120_sparse.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/atom/mma_traits_sm120_sparse.hpp new file mode 100644 index 0000000000000000000000000000000000000000..b62c576dfb6e0f67ed76c8a95818a7000ff43e52 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/atom/mma_traits_sm120_sparse.hpp @@ -0,0 +1,326 @@ +/*************************************************************************************************** + * Copyright (c) 2025 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include +#include +#include +#include +#include + +namespace cute +{ + +namespace { + +// (T32,V4) -> (M16,N8) +using SM120_16x8_Row = Layout,Shape < _2,_2>>, + Stride,Stride<_16,_8>>>; + +} + +namespace SM120::BLOCKSCALED::SPARSE +{ + +// Unpack explode/mma call with sparse and block scalaring inputs. 
+template +CUTE_HOST_DEVICE constexpr void +mma_unpack(MMA_Traits const&, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) +{ + static_assert(is_rmem_v, "Expected registers in MMA_Atom::call"); + static_assert(is_rmem_v, "Expected registers in MMA_Atom::call"); + static_assert(is_rmem_v, "Expected registers in MMA_Atom::call"); + static_assert(is_rmem_v, "Expected registers in MMA_Atom::call"); + using DRegisters = typename MMAOp::DRegisters; + using ARegisters = typename MMAOp::ARegisters; + using ERegisters = typename MMAOp::ERegisters; + using BRegisters = typename MMAOp::BRegisters; + using CRegisters = typename MMAOp::CRegisters; + using SFARegisters = typename MMAOp::SFARegisters; + using SFBRegisters = typename MMAOp::SFBRegisters; + // Register value types from the MMAOp register arrays + using RegTypeD = typename remove_extent::type; + using RegTypeA = typename remove_extent::type; + using RegTypeE = typename remove_extent::type; + using RegTypeB = typename remove_extent::type; + using RegTypeC = typename remove_extent::type; + using RegTypeSFA = typename remove_extent::type; + using RegTypeSFB = typename remove_extent::type; + constexpr int RegNumD = extent::value; + constexpr int RegNumA = extent::value; + constexpr int RegNumE = extent::value; + constexpr int RegNumB = extent::value; + constexpr int RegNumC = extent::value; + constexpr int RegNumSFA = extent::value; + constexpr int RegNumSFB = extent::value; + + auto [tA, tSFA, tE] = unzip_tensor(A); + auto [tB, tSFB ] = unzip_tensor(B); + Tensor rA = recast(tA); + Tensor rE = recast(tE); + Tensor rB = recast(tB); + Tensor rD = recast(D); + Tensor rC = recast(C); + Tensor rSFA = recast(tSFA); + Tensor rSFB = recast(tSFB); + + CUTE_STATIC_ASSERT_V(size(rA) == Int{}); + CUTE_STATIC_ASSERT_V(size(rE) == Int{}); + CUTE_STATIC_ASSERT_V(size(rB) == Int{}); + CUTE_STATIC_ASSERT_V(size(rD) == Int{}); + CUTE_STATIC_ASSERT_V(size(rC) == Int{}); + CUTE_STATIC_ASSERT_V(size(filter_zeros(rSFA)) == Int{}); + CUTE_STATIC_ASSERT_V(size(filter_zeros(rSFB)) == Int{}); + + detail::explode(MMAOp::fma, + rD, make_int_sequence{}, + rA, make_int_sequence{}, + rB, make_int_sequence{}, + rC, make_int_sequence{}, + rE, make_int_sequence{}, + rSFA, make_int_sequence{}, + rSFB, make_int_sequence{}); +} + +} // end namespace SM120::BLOCKSCALED::SPARSE + + +namespace SM120::SPARSE +{ + +template +CUTE_HOST_DEVICE constexpr void +mma_unpack(MMA_Traits const&, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) +{ + static_assert(is_rmem_v, "Expected registers in MMA_Atom::call"); + static_assert(is_rmem_v, "Expected registers in MMA_Atom::call"); + static_assert(is_rmem_v, "Expected registers in MMA_Atom::call"); + static_assert(is_rmem_v, "Expected registers in MMA_Atom::call"); + using DRegisters = typename MMAOp::DRegisters; + using ARegisters = typename MMAOp::ARegisters; + using ERegisters = typename MMAOp::ERegisters; + using BRegisters = typename MMAOp::BRegisters; + using CRegisters = typename MMAOp::CRegisters; + // Register value types from the MMAOp register arrays + using RegTypeD = typename remove_extent::type; + using RegTypeA = typename remove_extent::type; + using RegTypeE = typename remove_extent::type; + using RegTypeB = typename remove_extent::type; + using RegTypeC = typename remove_extent::type; + constexpr int RegNumD = extent::value; + constexpr int RegNumA = extent::value; + constexpr int RegNumE = extent::value; + constexpr int RegNumB = extent::value; + constexpr int RegNumC = extent::value; 
+ + auto [tA, tE] = unzip_tensor(A); + Tensor rA = recast(tA); + Tensor rE = recast(tE); + Tensor rB = recast(B); + Tensor rD = recast(D); + Tensor rC = recast(C); + CUTE_STATIC_ASSERT_V(size(rA) == Int{}); + CUTE_STATIC_ASSERT_V(size(rE) == Int{}); + CUTE_STATIC_ASSERT_V(size(rB) == Int{}); + CUTE_STATIC_ASSERT_V(size(rD) == Int{}); + CUTE_STATIC_ASSERT_V(size(rC) == Int{}); + + detail::explode(MMAOp::fma, + rD, make_int_sequence{}, + rA, make_int_sequence{}, + rB, make_int_sequence{}, + rC, make_int_sequence{}, + rE, make_int_sequence{}); +} + +} // end namespace SM120::SPARSE + +// sparse F8F6F4 without block-scaling +template +struct MMA_Traits> +{ + using ValTypeA = sparse_elem<2, a_type>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using FrgTypeA = sparse_elem<2, uint8_t>; + using FrgTypeE = sparse_elem<8, uint8_t>; + + using ValTypeC = c_type; + using ValTypeD = c_type; + + using Shape_MNK = Shape<_16, _8, _64>; + using ThrID = Layout<_32>; + // (T32,V32) -> (M16,K64) + using ALayout = Layout,Shape < _8,_2, _2>>, + Stride,Stride<_16,_8,_512>>>; + // (T32,V16) -> (N8,K64) + using BLayout = Layout,Shape <_4, _4>>, + Stride,Stride<_8,_128>>>; + // (T32,V4) -> (M16,N8) + using CLayout = SM120_16x8_Row; + + // (T32, V32) -> (M16, K64) + using ELayout = Layout, _32>, + Stride,_16>>; +}; + +// sparse MXF8F6F4 with block-scaling. +template +struct MMA_Traits> + : MMA_Traits> +{ + using ValTypeA = sparse_elem<2, a_type>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using FrgTypeA = sparse_elem<2, uint8_t>; + using FrgTypeE = sparse_elem<8, uint8_t>; + + using ValTypeD = c_type; + using ValTypeC = c_type; + + using ValTypeSF = sf_type; + constexpr static int SFVecSize = VS; + + using UnderlyingSFTraits = MMA_Traits>; + using SFALayout = typename UnderlyingSFTraits::SFALayout; + using SFBLayout = typename UnderlyingSFTraits::SFBLayout; +}; + +template +struct MMA_Traits> +{ + using ValTypeA = sparse_elem<4, uint8_t>; + using ValTypeE = sparse_elem<16, uint8_t>; + using ValTypeB = uint4_t; + using FrgTypeA = sparse_elem<4, uint8_t>; + using FrgTypeE = sparse_elem<16, uint8_t>; + + using ValTypeC = c_type; + using ValTypeD = c_type; + + using ValTypeSF = sf_type; + + constexpr static int SFVecSize = VS; + + using Shape_MNK = Shape<_16, _8, _128>; + using ThrID = Layout<_32>; + // (T32,V64) -> (M16,K128) + using ALayout = Layout,Shape <_16,_2, _2>>, + Stride,Stride<_16,_8,_1024>>>; + // (T32,V32) -> (N8,K128) + using BLayout = Layout,Shape <_8, _4>>, + Stride,Stride<_8,_256>>>; + // (T32,V128) -> (M16,K128) + using SFALayout = Layout,_128>, + Stride, _16>>; + // (T32,V128) -> (N8,K128) + using SFBLayout = Layout,_128>, + Stride, _8>>; + // (T32,V4) -> (M16,N8) + using CLayout = SM120_16x8_Row; + // (T32, V64) -> (M16, K128) + using ELayout = Layout, Shape< _64>>, + Stride,Stride<_16>>>; +}; + +namespace SM120::SPARSE { + +// For SM120 MMA F8F6F4 input fp4, the operand A/B are load from ld.matrix. +// ld.matrix b4x16_p64 places FP4 data at the first four bits in each +// eight-bit container, whereas MMA F8F6F4 expects the four-bit data to be in +// the middle of the eight-bit container. Thus, e2m1 operands being fed +// to MMA F8F6F4 must be shifted left by two bits. +// 0b0000ABCD --> 0b00ABCD00 +// NOTE: Same transformation is NOT needed for FP6 and FP8. 
+template +CUTLASS_DEVICE void +fp4_shift_A(SM120_SPARSE_16x8x64_TN const&, Tensor&& tensor) { + using RegisterTypeA = typename remove_extent::ARegisters>::type; + if constexpr (cute::is_same_v) { + cute::transform(recast(tensor), [](RegisterTypeA& v){ return v << 2; }); + } +} +template +CUTLASS_DEVICE void +fp4_shift_B(SM120_SPARSE_16x8x64_TN const&, Tensor&& tensor) { + using RegisterTypeB = typename remove_extent::BRegisters>::type; + if constexpr (cute::is_same_v) { + cute::transform(recast(tensor), [](RegisterTypeB& v){ return v << 2; }); + } +} + +} // end namespace SM120::SPARSE + +namespace SM120::BLOCKSCALED::SPARSE { + +// Template function with scale factor needs to enmuerate types one by one, as template +// arguments contatins two variadic lists, which cannot be deduced in one shot. +template +CUTLASS_DEVICE void +fp4_shift_A(SM120_SPARSE_16x8x64_TN_VS const&, Tensor&& tensor) { + using RegisterTypeA = typename remove_extent::ARegisters>::type; + if constexpr (cute::is_same_v) { + cute::transform(recast(tensor), [](RegisterTypeA& v){ return v << 2; }); + } +} +template +CUTLASS_DEVICE void +fp4_shift_B(SM120_SPARSE_16x8x64_TN_VS const&, Tensor&& tensor) { + using RegisterTypeB = typename remove_extent::BRegisters>::type; + if constexpr (cute::is_same_v) { + cute::transform(recast(tensor), [](RegisterTypeB& v){ return v << 2; }); + } +} + +} // end namespace SM120::BLOCKSCALED::SPARSE + +} // end namespace cute diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/atom/mma_traits_sm61.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/atom/mma_traits_sm61.hpp new file mode 100644 index 0000000000000000000000000000000000000000..6b12903b987a6113e43a71ec197915a7d814649f --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/atom/mma_traits_sm61.hpp @@ -0,0 +1,73 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include +#include + +namespace cute +{ + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using Shape_MNK = Shape<_1,_1,_4>; + using ThrID = Layout<_1>; + using ALayout = Layout>; + using BLayout = Layout>; + using CLayout = Layout>; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int16_t; + using ValTypeB = int16_t; + using ValTypeC = int32_t; + + using Shape_MNK = Shape<_1,_1,_2>; + using ThrID = Layout<_1>; + using ALayout = Layout>; + using BLayout = Layout>; + using CLayout = Layout>; +}; + +} // namespace cute diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/atom/mma_traits_sm70.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/atom/mma_traits_sm70.hpp new file mode 100644 index 0000000000000000000000000000000000000000..0b5b5300b0828067fb9d66c4b6e4640bb54d2bbd --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/atom/mma_traits_sm70.hpp @@ -0,0 +1,198 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include +#include + +namespace cute +{ + +namespace { + +// Logical thread id to thread idx (quadpair) +using SM70_QuadPair = Layout, + Stride<_1,_16>>; +// (T8,V4) -> (M8,K4) +using SM70_8x4_Row = Layout, + Stride<_1,_8>>; +// (T8,V4) -> (M8,K4) +using SM70_8x4_Col = Layout,_4>, + Stride,_1>>; +// (T8,V8) -> (M8,N8) +using SM70_8x8_16b = Layout, + Stride<_1,_8>>; +// (T8,V8) -> (M8,N8) +using SM70_8x8_32b = Layout,Shape <_2,_2, _2>>, + Stride,Stride<_8,_2,_32>>>; + +} + +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using Shape_MNK = Shape<_8,_8,_4>; + using ThrID = SM70_QuadPair; + using ALayout = SM70_8x4_Row; + using BLayout = SM70_8x4_Row; + using CLayout = SM70_8x8_16b; +}; + +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using Shape_MNK = Shape<_8,_8,_4>; + using ThrID = SM70_QuadPair; + using ALayout = SM70_8x4_Col; + using BLayout = SM70_8x4_Col; + using CLayout = SM70_8x8_16b; +}; + +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using Shape_MNK = Shape<_8,_8,_4>; + using ThrID = SM70_QuadPair; + using ALayout = SM70_8x4_Col; + using BLayout = SM70_8x4_Row; + using CLayout = SM70_8x8_16b; +}; + +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using Shape_MNK = Shape<_8,_8,_4>; + using ThrID = SM70_QuadPair; + using ALayout = SM70_8x4_Row; + using BLayout = SM70_8x4_Col; + using CLayout = SM70_8x8_16b; +}; + +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using Shape_MNK = Shape<_8,_8,_4>; + using ThrID = SM70_QuadPair; + using ALayout = SM70_8x4_Row; + using BLayout = SM70_8x4_Row; + using CLayout = SM70_8x8_32b; +}; + +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using Shape_MNK = Shape<_8,_8,_4>; + using ThrID = SM70_QuadPair; + using ALayout = SM70_8x4_Col; + using BLayout = SM70_8x4_Col; + using CLayout = SM70_8x8_32b; +}; + 
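// --- Illustrative usage sketch (editor's note, not part of this patch) --------------------------
// The MMA_Traits specializations in this header are typically consumed through
// cute::make_tiled_mma rather than instantiated directly: the resulting TiledMMA uses the
// ThrID/ALayout/BLayout/CLayout declared in the traits to partition tensors and shape register
// fragments. A minimal device-side sketch is shown below. It assumes the usual CuTe headers,
// that sA/sB/gC already carry layouts and value types compatible with the chosen atom, and uses
// the fp32-accumulating SM70 HMMA op specialized just above (SM70_8x8x4_F32F16F16F32_NN in
// upstream CuTe). The kernel helper and tensor names are hypothetical.

#include <cute/tensor.hpp>
#include <cute/atom/mma_atom.hpp>

template <class SA, class SB, class GC>
CUTE_DEVICE void
toy_mma_tile(SA const& sA, SB const& sB, GC& gC)     // (M,K), (N,K), (M,N) tensors
{
  using namespace cute;
  auto tiled_mma = make_tiled_mma(SM70_8x8x4_F32F16F16F32_NN{});  // one quad-pair atom
  auto thr_mma   = tiled_mma.get_thread_slice(threadIdx.x);

  auto tCrA = thr_mma.partition_fragment_A(sA);      // register fragments shaped by ALayout
  auto tCrB = thr_mma.partition_fragment_B(sB);      // and BLayout of the traits
  auto tCgC = thr_mma.partition_C(gC);
  auto tCrC = thr_mma.make_fragment_C(tCgC);         // accumulator registers (CLayout)

  copy(thr_mma.partition_A(sA), tCrA);               // load operands into registers
  copy(thr_mma.partition_B(sB), tCrB);
  clear(tCrC);
  gemm(tiled_mma, tCrA, tCrB, tCrC);                 // dispatches to the HMMA.884 operation
  copy(tCrC, tCgC);                                  // write accumulators back out
}
// -------------------------------------------------------------------------------------------------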
+/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using Shape_MNK = Shape<_8,_8,_4>; + using ThrID = SM70_QuadPair; + using ALayout = SM70_8x4_Col; + using BLayout = SM70_8x4_Row; + using CLayout = SM70_8x8_32b; +}; + +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using Shape_MNK = Shape<_8,_8,_4>; + using ThrID = SM70_QuadPair; + using ALayout = SM70_8x4_Row; + using BLayout = SM70_8x4_Col; + using CLayout = SM70_8x8_32b; +}; + +/////////////////////////////////////////////////////////////////////////////// +} // namespace cute diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/atom/mma_traits_sm75.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/atom/mma_traits_sm75.hpp new file mode 100644 index 0000000000000000000000000000000000000000..d60c65fe59fb73a248874dfc23232c49d23ad4ec --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/atom/mma_traits_sm75.hpp @@ -0,0 +1,81 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include + +#include +#include + +namespace cute +{ + +template <> +struct MMA_Traits +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using Shape_MNK = Shape<_16,_8,_8>; + using ThrID = Layout<_32>; + using ALayout = Layout,Shape < _2,_2>>, + Stride,Stride<_16,_8>>>; + using BLayout = Layout,_2>, + Stride,_8>>; + using CLayout = Layout,Shape < _2,_2>>, + Stride,Stride<_16,_8>>>; +}; + +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using Shape_MNK = Shape<_8,_8,_16>; + using ThrID = Layout<_32>; + using ALayout = Layout,_4>, + Stride,_8>>; + using BLayout = Layout,_4>, + Stride,_8>>; + using CLayout = Layout,_2>, + Stride,_8>>; +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace cute diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/atom/mma_traits_sm80.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/atom/mma_traits_sm80.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f7d5d2fc3005d1e12fdd066cf0d80709de0e5234 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/atom/mma_traits_sm80.hpp @@ -0,0 +1,690 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include +#include +#include +#include + +namespace cute +{ + +namespace { + +// (T32,V1) -> (M8,N8) +using SM80_8x4 = Layout,_1>, + Stride,_0>>; +// (T32,V2) -> (M8,N8) +using SM80_8x8_Row = Layout,_2>, + Stride,_8>>; +// (T32,V4) -> (M8,N16) +using SM80_8x16_Row = Layout,_4>, + Stride,_8>>; +// (T32,V4) -> (M16,N8) +using SM80_16x8_Row = Layout,Shape < _2,_2>>, + Stride,Stride<_16,_8>>>; + +} + +/////////////////////////////////////////////////////////////////////////////// +//////////////////////// fp16 = fp16 * fp16 + fp16 //////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using Shape_MNK = Shape<_16,_8,_8>; + using ThrID = Layout<_32>; + using ALayout = SM80_16x8_Row; + using BLayout = SM80_8x8_Row; + using CLayout = SM80_16x8_Row; +}; + +template <> +struct MMA_Traits +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using Shape_MNK = Shape<_16,_8,_16>; + using ThrID = Layout<_32>; + using ALayout = Layout,Shape < _2,_2, _2>>, + Stride,Stride<_16,_8,_128>>>; + using BLayout = Layout,Shape <_2, _2>>, + Stride,Stride<_8,_64>>>; + using CLayout = SM80_16x8_Row; +}; + +/////////////////////////////////////////////////////////////////////////////// +//////////////////////// fp32 = fp16 * fp16 + fp32 //////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits + : MMA_Traits +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; +}; + +template <> +struct MMA_Traits + : MMA_Traits +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; +}; + +/////////////////////////////////////////////////////////////////////////////// +//////////////////////// fp32 = bf16 * bf16 + fp32 //////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits + : MMA_Traits +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; +}; + +template <> +struct MMA_Traits + : MMA_Traits +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; +}; + +/////////////////////////////////////////////////////////////////////////////// +//////////////////////// fp32 = tf32 * tf32 + fp32 //////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = float; + using ValTypeA = cutlass::tfloat32_t; + using ValTypeB = cutlass::tfloat32_t; + using ValTypeC = float; + + using Shape_MNK = Shape<_16,_8,_4>; + using ThrID = Layout<_32>; + using ALayout = Layout,_2>, + Stride,_8>>; + using BLayout = SM80_8x4; + using CLayout = SM80_16x8_Row; +}; + +template <> +struct MMA_Traits +{ + using ValTypeD = float; + using ValTypeA = cutlass::tfloat32_t; + using ValTypeB = cutlass::tfloat32_t; + using ValTypeC = float; + + using Shape_MNK = Shape<_16,_8,_8>; + using ThrID = Layout<_32>; + using ALayout = Layout,Shape <_2, _2>>, + Stride,Stride<_8,_64>>>; + using BLayout = 
Layout, _2>, + Stride,_32>>; + using CLayout = SM80_16x8_Row; +}; + +/////////////////////////////////////////////////////////////////////////////// +//////////////////////// fp64 = fp64 * fp64 + fp64 //////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = double; + using ValTypeA = double; + using ValTypeB = double; + using ValTypeC = double; + + using Shape_MNK = Shape<_8,_8,_4>; + using ThrID = Layout<_32>; + using ALayout = SM80_8x4; + using BLayout = SM80_8x4; + using CLayout = SM80_8x8_Row; +}; + +// Custom complex fp64 MMA composed of 4 fp64 MMAs -- same layouts +template <> +struct MMA_Traits + : MMA_Traits +{ + using ValTypeD = complex; + using ValTypeA = complex; + using ValTypeB = complex; + using ValTypeC = complex; +}; + +// Custom complex fp64 MMA composed of 3 fp64 MMAs -- same layouts +template <> +struct MMA_Traits + : MMA_Traits +{ + using ValTypeD = typename SM80_8x8x4_GC64C64C64GC64_TN::GaussComplex; + using ValTypeA = complex; + using ValTypeB = complex; + using ValTypeC = typename SM80_8x8x4_GC64C64C64GC64_TN::GaussComplex; +}; + +/////////////////////////////////////////////////////////////////////////////// +/////////////////////////// s32 = s8 * s8 + s32 /////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using Shape_MNK = Shape<_8,_8,_16>; + using ThrID = Layout<_32>; + using ALayout = SM80_8x16_Row; + using BLayout = SM80_8x16_Row; + using CLayout = SM80_8x8_Row; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using Shape_MNK = Shape<_16,_8,_16>; + using ThrID = Layout<_32>; + using ALayout = Layout,Shape < _4,_2>>, + Stride,Stride<_16,_8>>>; + using BLayout = SM80_8x16_Row; + using CLayout = SM80_16x8_Row; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using Shape_MNK = Shape<_16,_8,_32>; + using ThrID = Layout<_32>; + using ALayout = Layout,Shape < _4,_2, _2>>, + Stride,Stride<_16,_8,_256>>>; + using BLayout = Layout, Shape <_4, _2>>, + Stride, Stride<_8,_128>>>; + using CLayout = SM80_16x8_Row; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +/////////////////////////////////////////////////////////////////////////////// +/////////////////////////// s32 = s8 * u8 + s32 /////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits + : MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +template <> +struct MMA_Traits + : MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +template <> +struct MMA_Traits + : MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; +}; + +template <> +struct MMA_Traits + : 
MMA_Traits {}; + +/////////////////////////////////////////////////////////////////////////////// +/////////////////////////// s32 = u8 * s8 + s32 /////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits + : MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +template <> +struct MMA_Traits + : MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +template <> +struct MMA_Traits + : MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +/////////////////////////////////////////////////////////////////////////////// +/////////////////////////// s32 = u8 * u8 + s32 /////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits + : MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +template <> +struct MMA_Traits + : MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +template <> +struct MMA_Traits + : MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +/////////////////////////////////////////////////////////////////////////////// +/////////////////////////// s32 = s4 * s4 + s32 /////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits { + using ValTypeD = int32_t; + using ValTypeA = int4b_t; + using ValTypeB = int4b_t; + using ValTypeC = int32_t; + + using Shape_MNK = Shape<_8, _8, _32>; + using ThrID = Layout<_32>; + // (T32,V8) -> (M8,N32) + using ALayout = Layout, Shape <_8>>, + Stride, Stride<_8>>>; + using BLayout = Layout, Shape <_8>>, + Stride, Stride<_8>>>; + using CLayout = SM80_8x8_Row; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +template <> +struct MMA_Traits { + using ValTypeD = int32_t; + using ValTypeA = int4b_t; + using ValTypeB = int4b_t; + using ValTypeC = int32_t; + + using Shape_MNK = Shape<_16, _8, _32>; + using ThrID = Layout<_32>; + // (T32,V16) -> (M16,N32) + using ALayout = Layout, Shape < _8, _2>>, + Stride, Stride<_16, _8>>>; + // (T32,V8) -> (M8,N32) + using BLayout = Layout, Shape <_8>>, + Stride, Stride<_8>>>; + using CLayout = SM80_16x8_Row; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +template <> +struct MMA_Traits { + using ValTypeD = int32_t; + using ValTypeA = int4b_t; + using ValTypeB = int4b_t; + using ValTypeC = int32_t; + + using Shape_MNK = Shape<_16, _8, _64>; + using ThrID = Layout<_32>; + // (T32,V32) -> (M16,N64) + using ALayout = Layout, Shape < _8, _2, _2>>, + Stride, Stride<_16, _8, _512>>>; + // (T32,V16) -> (M8,N64) + using BLayout = Layout, Shape <_8, _2>>, + Stride, Stride<_8, _256>>>; + using CLayout = SM80_16x8_Row; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; 
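// --- Illustrative sketch (editor's note, not part of this patch) --------------------------------
// The traits above describe each operand with a "TV-layout": a function from a (thread, value)
// pair to a flat coordinate in the logical (M,N) or (M,K) tile, flattened column-major. The
// host-side probe below decodes the (T32,V4) -> (M16,N8) accumulator mapping that the 16x8
// integer MMAs in this file use as CLayout. The layout literal is restated locally so the snippet
// is self-contained; it mirrors the SM80_16x8_Row alias defined near the top of this header in
// upstream CuTe.

#include <cute/layout.hpp>
#include <cstdio>

int main() {
  using namespace cute;
  auto tv = Layout<Shape <Shape < _4,_8>, Shape < _2,_2>>,
                   Stride<Stride<_32,_1>, Stride<_16,_8>>>{};   // (T32,V4) -> (M16,N8)
  for (int thr = 0; thr < 32; ++thr) {
    for (int val = 0; val < 4; ++val) {
      int idx = tv(thr, val);           // flat index into the column-major (M16,N8) tile
      printf("thr %2d val %d -> C[m=%2d][n=%d]\n", thr, val, idx % 16, idx / 16);
    }
  }
  return 0;
}
// For example, lane 5, value 2 maps to C[m=9][n=2], matching the PTX m16n8 accumulator layout
// (row = lane/4 + 8*(value/2), column = 2*(lane%4) + value%2).
// -------------------------------------------------------------------------------------------------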
+ +/////////////////////////////////////////////////////////////////////////////// +/////////////////////////// s32 = s4 * u4 + s32 /////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits + : MMA_Traits { + using ValTypeD = int32_t; + using ValTypeA = int4b_t; + using ValTypeB = uint4b_t; + using ValTypeC = int32_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +template <> +struct MMA_Traits + : MMA_Traits { + using ValTypeD = int32_t; + using ValTypeA = int4b_t; + using ValTypeB = uint4b_t; + using ValTypeC = int32_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; +template <> +struct MMA_Traits + : MMA_Traits { + using ValTypeD = int32_t; + using ValTypeA = int4b_t; + using ValTypeB = uint4b_t; + using ValTypeC = int32_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +/////////////////////////////////////////////////////////////////////////////// +/////////////////////////// s32 = u4 * s4 + s32 /////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits + : MMA_Traits { + using ValTypeD = int32_t; + using ValTypeA = uint4b_t; + using ValTypeB = int4b_t; + using ValTypeC = int32_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +template <> +struct MMA_Traits + : MMA_Traits { + using ValTypeD = int32_t; + using ValTypeA = uint4b_t; + using ValTypeB = int4b_t; + using ValTypeC = int32_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +template <> +struct MMA_Traits + : MMA_Traits { + using ValTypeD = int32_t; + using ValTypeA = uint4b_t; + using ValTypeB = int4b_t; + using ValTypeC = int32_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +/////////////////////////////////////////////////////////////////////////////// +/////////////////////////// s32 = u4 * u4 + s32 /////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits + : MMA_Traits { + using ValTypeD = int32_t; + using ValTypeA = uint4b_t; + using ValTypeB = uint4b_t; + using ValTypeC = int32_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +template <> +struct MMA_Traits + : MMA_Traits { + using ValTypeD = int32_t; + using ValTypeA = uint4b_t; + using ValTypeB = uint4b_t; + using ValTypeC = int32_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +template <> +struct MMA_Traits + : MMA_Traits { + using ValTypeD = int32_t; + using ValTypeA = uint4b_t; + using ValTypeB = uint4b_t; + using ValTypeC = int32_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +/////////////////////////////////////////////////////////////////////////////// +/////////////////////////// s32 = b1 ^ b1 + s32 /////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = cute::uint1b_t; + using ValTypeB = cute::uint1b_t; + using ValTypeC = int32_t; + + using Shape_MNK = Shape<_16,_8,_256>; + using ThrID = Layout<_32>; + using ALayout = Layout,Shape<_32,_2,_2>>, + Stride,Stride<_16,_8,_2048>>>; + using BLayout = Layout,Shape<_32,_2>>, + Stride,Stride< _8,_1024>>>; + using CLayout = SM80_16x8_Row; +}; + +/////////////////////////////////////////////////////////////////////////////// +/////////////////////////// s32 = b1 & b1 + s32 /////////////////////////////// 
+/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits + : MMA_Traits {}; + +template<> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = cute::uint1b_t; + using ValTypeB = cute::uint1b_t; + using ValTypeC = int32_t; + + using Shape_MNK = Shape<_8,_8,_128>; + using ThrID = Layout<_32>; + using ALayout = Layout,_32>, + Stride,_8>>; + using BLayout = Layout,_32>, + Stride,_8>>; + using CLayout = SM80_8x8_Row; +}; + +template <> +struct MMA_Traits + :MMA_Traits {}; + +template<> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = cute::uint1b_t; + using ValTypeB = cute::uint1b_t; + using ValTypeC = int32_t; + + using Shape_MNK = Shape<_16,_8,_128>; + using ThrID = Layout<_32>; + using ALayout = Layout,Shape<_32,_2>>, + Stride,Stride>>>; + using BLayout = Layout,_32>, + Stride,_8>>; + using CLayout = SM80_16x8_Row; +}; + +template <> +struct MMA_Traits + :MMA_Traits {}; + +} // end namespace cute diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/atom/mma_traits_sm89.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/atom/mma_traits_sm89.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f56aab4cb1d642ddadf9f80aac487409aaa23b8b --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/atom/mma_traits_sm89.hpp @@ -0,0 +1,132 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +// + +// +#pragma once + +#include +#include +#include +#include + +namespace cute +{ + +namespace { + +// (T32,V4) -> (M16,N8) +using SM80_16x8_Row = Layout,Shape < _2,_2>>, + Stride,Stride<_16,_8>>>; + +} + +template <> +struct MMA_Traits { + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using Shape_MNK = Shape<_16,_8,_32>; + using ThrID = Layout<_32>; + using ALayout = Layout,Shape < _4,_2, _2>>, + Stride,Stride<_16,_8,_256>>>; + using BLayout = Layout,Shape <_4, _2>>, + Stride,Stride<_8,_128>>>; + using CLayout = SM80_16x8_Row; +}; + +template <> +struct MMA_Traits + : MMA_Traits { + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; +}; + +template <> +struct MMA_Traits + : MMA_Traits { + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; +}; + +template <> +struct MMA_Traits + : MMA_Traits { + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; +}; + +template <> +struct MMA_Traits + : MMA_Traits { + using ValTypeD = cutlass::half_t; + using ValTypeA = cutlass::float_e4m3_t; + using ValTypeB = cutlass::float_e4m3_t; + using ValTypeC = cutlass::half_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits { + using ValTypeD = cutlass::half_t; + using ValTypeA = cutlass::float_e4m3_t; + using ValTypeB = cutlass::float_e5m2_t; + using ValTypeC = cutlass::half_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits { + using ValTypeD = cutlass::half_t; + using ValTypeA = cutlass::float_e5m2_t; + using ValTypeB = cutlass::float_e5m2_t; + using ValTypeC = cutlass::half_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits { + using ValTypeD = cutlass::half_t; + using ValTypeA = cutlass::float_e5m2_t; + using ValTypeB = cutlass::float_e4m3_t; + using ValTypeC = cutlass::half_t; +}; + +} // end namespace cute diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/atom/mma_traits_sm90.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/atom/mma_traits_sm90.hpp new file mode 100644 index 0000000000000000000000000000000000000000..0467dec3e8cb688f68d86468e005ed2c4bd1216a --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/atom/mma_traits_sm90.hpp @@ -0,0 +1,144 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include +#include + +#include + +namespace cute { + +/////////////////////////////////////////////////////////////////////////////// +//////////////////////// fp64 = fp64 * fp64 + fp64 //////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +using SM90_16x8x4_F64F64F64F64_TN = SM90::MMA_16x8x4_F64F64F64F64_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = double; + using ValTypeA = double; + using ValTypeB = double; + using ValTypeC = double; + + using Shape_MNK = Shape<_16,_8,_4>; + using ThrID = Layout<_32>; + using ALayout = Layout,_2>, + Stride,_8>>; + using BLayout = Layout,_1>, + Stride,_0>>; + using CLayout = Layout,Shape < _2,_2>>, + Stride,Stride<_16,_8>>>; +}; + +using SM90_16x8x8_F64F64F64F64_TN = SM90::MMA_16x8x8_F64F64F64F64_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = double; + using ValTypeA = double; + using ValTypeB = double; + using ValTypeC = double; + + using Shape_MNK = Shape<_16,_8,_8>; + using ThrID = Layout<_32>; + using ALayout = Layout,Shape <_2, _2>>, + Stride,Stride<_8,_64>>>; + using BLayout = Layout, _2>, + Stride,_32>>; + using CLayout = Layout,Shape < _2,_2>>, + Stride,Stride<_16,_8>>>; +}; + +using SM90_16x8x16_F64F64F64F64_TN = SM90::MMA_16x8x16_F64F64F64F64_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = double; + using ValTypeA = double; + using ValTypeB = double; + using ValTypeC = double; + + using Shape_MNK = Shape<_16,_8,_16>; + using ThrID = Layout<_32>; + using ALayout = Layout,Shape <_2, _4>>, + Stride,Stride<_8,_64>>>; + using BLayout = Layout, _4>, + Stride,_32>>; + using CLayout = Layout,Shape < _2,_2>>, + Stride,Stride<_16,_8>>>; +}; + +/////////////////////////////////////////////////////////////////////////////////// +//////////////////////// cfp64 = cfp64 * cfp64 + cfp64 //////////////////////////// +/////////////////////////////////////////////////////////////////////////////////// + +using SM90_16x8x4_C64C64C64C64_TN = SM90::MMA_16x8x4_C64C64C64C64_TN; + +template <> +struct MMA_Traits + : MMA_Traits +{ + using ValTypeD = complex; + using ValTypeA = complex; + using ValTypeB = complex; + using ValTypeC = complex; +}; + +using SM90_16x8x8_C64C64C64C64_TN = SM90::MMA_16x8x8_C64C64C64C64_TN; + +template <> +struct MMA_Traits + : MMA_Traits +{ + using ValTypeD = complex; + using ValTypeA = complex; + using ValTypeB = complex; + using ValTypeC = complex; +}; + +using SM90_16x8x16_C64C64C64C64_TN = SM90::MMA_16x8x16_C64C64C64C64_TN; + +template <> +struct MMA_Traits + : MMA_Traits +{ + using ValTypeD = complex; + using ValTypeA = complex; + using 
ValTypeB = complex; + using ValTypeC = complex; +}; + +} // end namespace cute diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/atom/mma_traits_sm90_gmma.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/atom/mma_traits_sm90_gmma.hpp new file mode 100644 index 0000000000000000000000000000000000000000..e1c3bb40349faecf4ec2f8c36f0e90b7f26d23d7 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/atom/mma_traits_sm90_gmma.hpp @@ -0,0 +1,8987 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include // cute::smem_ptr_flag +#include // cute::smem_sparse_ptr_flag +#include // cute::Swizzle +#include // cute::Tensor +#include // cute::LayoutType +#include // cute::SM90_64x8x16_F16F16F16_SS, etc +#include // cute::MMA_Traits +#include // cute::ComposedLayout +#include // cute::is_static + +namespace cute { + +// Fence between the async destination accumulators of GMMA & source for their dependent use +template +CUTE_HOST_DEVICE +void +warpgroup_fence_operand(Tensor& frg) { + CUTE_STATIC_ASSERT(is_static::value); + if constexpr (is_same_v) { + auto f32_frg = recast(frg); + CUTE_UNROLL + for (int i = 0; i < size(f32_frg); ++i) { + warpgroup_fence_operand(f32_frg(i)); + } + } + else { + CUTE_STATIC_ASSERT(is_rmem::value); + auto u32_frg = recast(frg); + CUTE_UNROLL + for (int i = 0; i < size(u32_frg); ++i) { + warpgroup_fence_operand(u32_frg(i)); + } + } +} + +namespace SM90::GMMA { + +/////////////////////////////////////////// +// Common layouts for GMMA Shared Memory // +/////////////////////////////////////////// + +// M|N-major GMMA layouts in units of bits +using Layout_MN_INTER_Atom_Bits = ComposedLayout, smem_ptr_flag, Layout,Stride<_1, _128>>>; +using Layout_MN_SW32_Atom_Bits = ComposedLayout, smem_ptr_flag, Layout,Stride<_1, _256>>>; +using Layout_MN_SW64_Atom_Bits = ComposedLayout, smem_ptr_flag, Layout,Stride<_1, _512>>>; +using Layout_MN_SW128_Atom_Bits = ComposedLayout, smem_ptr_flag, Layout,Stride<_1,_1024>>>; + +// K-major GMMA layouts in units of bits +using Layout_K_INTER_Atom_Bits = ComposedLayout, smem_ptr_flag, Layout,Stride< _128,_1>>>; +using Layout_K_SW32_Atom_Bits = ComposedLayout, smem_ptr_flag, Layout,Stride< _256,_1>>>; +using Layout_K_SW64_Atom_Bits = ComposedLayout, smem_ptr_flag, Layout,Stride< _512,_1>>>; +using Layout_K_SW128_Atom_Bits = ComposedLayout, smem_ptr_flag, Layout,Stride<_1024,_1>>>; + +// M|N-major layouts in units of Type +template +using Layout_MN_INTER_Atom = decltype(upcast::value>(Layout_MN_INTER_Atom_Bits{})); +template +using Layout_MN_SW32_Atom = decltype(upcast::value>(Layout_MN_SW32_Atom_Bits{})); +template +using Layout_MN_SW64_Atom = decltype(upcast::value>(Layout_MN_SW64_Atom_Bits{})); +template +using Layout_MN_SW128_Atom = decltype(upcast::value>(Layout_MN_SW128_Atom_Bits{})); + +// K-major layouts in units of Type +template +using Layout_K_INTER_Atom = decltype(upcast::value>(Layout_K_INTER_Atom_Bits{})); +template +using Layout_K_SW32_Atom = decltype(upcast::value>(Layout_K_SW32_Atom_Bits{})); +template +using Layout_K_SW64_Atom = decltype(upcast::value>(Layout_K_SW64_Atom_Bits{})); +template +using Layout_K_SW128_Atom = decltype(upcast::value>(Layout_K_SW128_Atom_Bits{})); + +// With GMMA::Major param +template +using Layout_INTER_Atom = typename conditional, + Layout_K_INTER_Atom>::type; +template +using Layout_SW32_Atom = typename conditional, + Layout_K_SW32_Atom>::type; +template +using Layout_SW64_Atom = typename conditional, + Layout_K_SW64_Atom>::type; +template +using Layout_SW128_Atom = typename conditional, + Layout_K_SW128_Atom>::type; + +// +// Tensor (position-dependent swizzle) to LayoutType utility +// + +template +CUTE_HOST_DEVICE constexpr +LayoutType +layout_type(Tensor> const&) +{ + static_assert(is_same::value, + "Expected uint128_t type in LayoutType conversion."); + + using Swizzle = get_swizzle_t; + constexpr int B = Swizzle::num_bits; + constexpr int M = Swizzle::num_base; + 
constexpr int S = Swizzle::num_shft; + + static_assert(M == 4, "Unsupported layout swizzle"); + static_assert(0 <= B && B <= 3, "Unsupported layout swizzle"); + static_assert(S == 3, "Unsupported layout swizzle"); + + switch (B) { + case 0: return LayoutType::INTERLEAVE; + case 1: return LayoutType::B32; + case 2: return LayoutType::B64; + case 3: return LayoutType::B128; + } + return LayoutType::INTERLEAVE; // ERROR +} + +/////////////////////////////////////////////////////////////////////////////// +// Construction method for GMMA Descriptors +/////////////////////////////////////////////////////////////////////////////// + +/** +* /////////////////////////////// +* // make_gmma_desc // +* /////////////////////////////// +* Each GmmaDescriptor Major-MN describes a canonical layout of the form +* +* LayoutType::INTERLEAVE : Swizzle<0,4,3> o smem_ptr o ((T,1,m),(8,k)):((1,T,SBO),(1T,LBO)) +* LayoutType::B32 : Swizzle<1,4,3> o smem_ptr o ((T,2,m),(8,k)):((1,T,LBO),(2T,SBO)) +* LayoutType::B64 : Swizzle<2,4,3> o smem_ptr o ((T,4,m),(8,k)):((1,T,LBO),(4T,SBO)) +* LayoutType::B128 : Swizzle<3,4,3> o smem_ptr o ((T,8,m),(8,k)):((1,T,LBO),(8T,SBO)) +* +* where +* T : sizeof(uint128_t) / sizeof(value_type) +* m : integer in [1,16] corresponding to GMMA shape +* k : integer in [1,32] corresponding to GMMA shape +* SBO: stride byte offset +* LBO: leading byte offset +* +* See GMMA::Layout_MN_XXX_Atom for building canonical GmmaDescriptor Major-MN layouts. +* For example, +* auto smem_layout = tile_to_shape(Layout_MN_SW128_Atom{}, Shape<_128,_64>{}); +* is guaranteed to be accepted by make_gmma_desc for appropriate value_type. +* +* ////////////////////////////// +* // make_gmma_desc // +* ////////////////////////////// +* Each GmmaDescriptor Major-K describes a canonical layout of the form +* +* LayoutType::INTERLEAVE : Swizzle<0,4,3> o smem_ptr o ((8,m),(T,2)):((1T,SBO),(1,LBO)) +* LayoutType::B32 : Swizzle<1,4,3> o smem_ptr o ((8,m),(T,2)):((2T,SBO),(1, T )) +* LayoutType::B64 : Swizzle<2,4,3> o smem_ptr o ((8,m),(T,2)):((4T,SBO),(1, T )) +* LayoutType::B128 : Swizzle<3,4,3> o smem_ptr o ((8,m),(T,2)):((8T,SBO),(1, T )) +* +* See GMMA::Layout_K_XXX_Atom for building canonical GmmaDescriptor Major-K layouts. +* For example, +* auto smem_layout = tile_to_shape(Layout_K_SW128_Atom{}, Shape<_128,_64>{}); +* is guaranteed to be accepted by make_gmma_desc for appropriate value_type. +*/ +template +CUTE_HOST_DEVICE constexpr +GmmaDescriptor +make_gmma_desc(Tensor const& tensor) +{ + static_assert(is_smem::value, "GMMA Descriptors can only be constructed on smem."); + static_assert(TLayout::rank == 2, "GMMA Descriptors can only be constructed on rank-2 tensors."); + using value_type = typename TEngine::value_type; + + Tensor u128_tensor = recast(tensor); + + // Result + GmmaDescriptor desc; + + // Layout type + constexpr LayoutType LAYOUT_TYPE = layout_type(u128_tensor); + desc.bitfield.layout_type_ = uint8_t(LAYOUT_TYPE); + + // Start address (4LSB not included) + uint32_t start_address = cast_smem_ptr_to_uint(raw_pointer_cast(u128_tensor.data())); + desc.bitfield.start_address_ = static_cast(start_address >> 4); + + constexpr uint8_t base_offset = 0; + desc.bitfield.base_offset_ = base_offset; + + // LayoutType meta + constexpr int W = LAYOUT_TYPE == LayoutType::INTERLEAVE ? 1 : + LAYOUT_TYPE == LayoutType::B32 ? 2 : + LAYOUT_TYPE == LayoutType::B64 ? 4 : + LAYOUT_TYPE == LayoutType::B128 ? 
8 : -1; + + if constexpr (MajorMode == Major::MN) + { + /* In units of uint128_t, each GmmaDescriptor Major-MN describes a canonical layout of the form + * + * LayoutType::INTERLEAVE : Swizzle<0,4,3> o smem_ptr o ((1,n),(8,k)):((X,SBO),(1,LBO)) + * LayoutType::B32 : Swizzle<1,4,3> o smem_ptr o ((2,n),(8,k)):((1,LBO),(2,SBO)) + * LayoutType::B64 : Swizzle<2,4,3> o smem_ptr o ((4,n),(8,k)):((1,LBO),(4,SBO)) + * LayoutType::B128 : Swizzle<3,4,3> o smem_ptr o ((8,n),(8,k)):((1,LBO),(8,SBO)) + */ + static_assert(size<1>(u128_tensor) == Int<(256 / cute::sizeof_bits::value)>{} || // A and B in dense MMA + size<1>(u128_tensor) == Int<(128 / cute::sizeof_bits::value)>{} || // A in sparse MMA + size<1>(u128_tensor) == Int<(512 / cute::sizeof_bits::value)>{}, // B in sparse MMA + "Not a canonical GMMA_MN Layout: Expected K-size 256/sizeof_bits for dense or (128|512)/sizeof_bits for sparse."); + + // Construct the canonical GMMA T Layout with shape ((W,n),(8,2)) + Layout canonical_layout = logical_divide(layout(u128_tensor), Tile,_1>,Layout,_1>>{}); + + // Check profile of canonical + CUTE_STATIC_ASSERT_V(congruent(canonical_layout, Shape,Shape<_1,_1>>{}), "Not a canonical GMMA_MN Layout: Expected profile failure."); + // Check canonical mode strides + constexpr uint32_t stride_00 = stride<0,0>(canonical_layout); + constexpr uint32_t expected_stride_00 = LAYOUT_TYPE == LayoutType::INTERLEAVE ? stride<0,0>(canonical_layout) : 1; + static_assert(stride_00 == expected_stride_00, "Not a canonical GMMA_MN Layout: Expected stride failure."); + constexpr uint32_t stride_10 = stride<1,0>(canonical_layout); + constexpr uint32_t expected_stride_10 = W; + static_assert(stride_10 == expected_stride_10, "Not a canonical GMMA_MN Layout: Expected stride failure."); + + // stride dimension byte offset and leading dimension byte offset (4LSB not included == uint128_t units) + constexpr uint32_t stride_01 = stride<0,1>(canonical_layout); + constexpr uint32_t stride_11 = stride<1,1>(canonical_layout); + + desc.bitfield.stride_byte_offset_ = (LAYOUT_TYPE == LayoutType::INTERLEAVE) ? stride_01 : stride_11; + desc.bitfield.leading_byte_offset_ = (LAYOUT_TYPE == LayoutType::INTERLEAVE) ? 
stride_11 : stride_01; + } + else if constexpr (MajorMode == Major::K) + { + /* In units of uint128_t, each GmmaDescriptor Major-K describes a canonical layout of the form + * + * LayoutType::INTERLEAVE : Swizzle<0,4,3> o smem_ptr o ((8,n),2):((1,SBO),LBO) + * LayoutType::B32 : Swizzle<1,4,3> o smem_ptr o ((8,n),2):((2,SBO),1) + * LayoutType::B64 : Swizzle<2,4,3> o smem_ptr o ((8,n),2):((4,SBO),1) + * LayoutType::B128 : Swizzle<3,4,3> o smem_ptr o ((8,n),2):((8,SBO),1) + */ + CUTE_STATIC_ASSERT_V(size<0>(u128_tensor) % Int<8>{} == Int<0>{}, // N|M size + "Not a canonical GMMA_K Layout: Expected MN-size multiple of 8."); + CUTE_STATIC_ASSERT_V(size<1>(u128_tensor) == Int<2>{} || size<1>(u128_tensor) == Int<4>{}, // K size + "Not a canonical GMMA_K Layout: Expected K-size 2 for dense or 4 for sparse (in units of uint128_t)."); + + // Construct the canonical GMMA N Layout with shape ((8,n),(2,1)) + Layout canonical_layout = logical_divide(layout(u128_tensor), Tile,Layout<_2,_1>>{}); + + // Check profile of canonical + CUTE_STATIC_ASSERT_V(congruent(canonical_layout, Shape,Shape<_1,_1>>{}), "Not a canonical GMMA_K Layout: Expected profile failure."); + // Check canonical mode strides + constexpr uint32_t stride_00 = stride<0,0>(canonical_layout); + constexpr uint32_t expected_stride_00 = W; + static_assert(stride_00 == expected_stride_00, "Not a canonical GMMA_K Layout: Expected stride failure."); + constexpr uint32_t stride_10 = stride<1,0>(canonical_layout); + constexpr uint32_t expected_stride_10 = (LAYOUT_TYPE == LayoutType::INTERLEAVE) ? stride<1,0>(canonical_layout) : 1; + static_assert(stride_10 == expected_stride_10, "Not a canonical GMMA_K Layout: Expected stride failure."); + + // stride dimension byte offset and leading dimension byte offset (4LSB not included == uint128_t units) + constexpr uint32_t stride_01 = stride<0,1>(canonical_layout); + + desc.bitfield.stride_byte_offset_ = stride_01; + desc.bitfield.leading_byte_offset_ = stride_10; + } else { + static_assert(MajorMode != Major::MN && MajorMode != Major::K, "Unrecognized MajorMode!"); + } + return desc; +} + +/////////////////////////////////////////////////////////////////////////////// +// Higher level GMMA Descriptor utilities +/////////////////////////////////////////////////////////////////////////////// + +struct DescriptorIterator +{ + using reference = GmmaDescriptor; + using element_type = GmmaDescriptor; + using value_type = GmmaDescriptor; + + GmmaDescriptor desc_; + + // Dereference returns the GmmaDescriptor + CUTE_HOST_DEVICE constexpr + reference operator*() const { return desc_; } + + // Advance and return a new GmmaDescriptor + template + CUTE_HOST_DEVICE constexpr + reference operator[](Index const& i) const { return *(*this + i); } + + // Return an advanced iterator + template + CUTE_HOST_DEVICE constexpr + DescriptorIterator operator+(Index const& offset) const + { + // Use 32bit calculation rather than 64 bit calculation as we only update the part of desc + GmmaDescriptor ret; + ret.reg32_[0] = desc_.reg32_[0] + uint32_t(offset); + ret.reg32_[1] = desc_.reg32_[1]; + return { ret }; + } +}; + +template +CUTE_HOST_DEVICE constexpr +GmmaDescriptor +raw_pointer_cast(DescriptorIterator const& ptr) { + return ptr.desc_; +} + +// Recast a DescriptorIterator Tensor to uint64_t, it's RegType in mma_unpack +template +CUTE_HOST_DEVICE constexpr +DescriptorIterator +recast_ptr(DescriptorIterator const& iter) { + static_assert(is_same::value, "Can only cast GmmaDescriptorIterator to uint64_t."); + return iter; // Do 
nothing, it will still dereference to GmmaDescriptor and decay to uint64_t +} + +CUTE_HOST_DEVICE void +print(DescriptorIterator) { + printf("GMMA::DescriptorIterator"); +} + +// The GMMA Traits below have custom fragment type flags for their smem desc tensors. +// These flags specialize a MakeTensor customization point to correctly make the fragment that is desired. +template +struct smem_desc : DescriptorIterator {}; + +} // end namespace SM90::GMMA + +// Customization point for creating a GMMA::smem_desc Tensor +template +struct MakeTensor> +{ + template + CUTE_HOST_DEVICE constexpr auto + operator()(Tensor const& smem_tensor) + { + static_assert(is_smem::value, "Expected SMEM Tensor to construct a GMMA Desc Tensor"); + return make_tensor(SM90::GMMA::DescriptorIterator{SM90::GMMA::make_gmma_desc(tensor<0>(smem_tensor))}, + replace<0>(recast(smem_tensor).layout(), Layout<_1,_0>{})); + } +}; + +/////////////////////////////////////////////////////////////////////////////// +//////////////////////////// MMA_TRAITS /////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +namespace SM90::GMMA { + +// +// Specialized mma_unpack implementation for SM90 GMMA instructions +// + +template +CUTE_HOST_DEVICE constexpr +void +mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) +{ + static_assert(is_rmem::value, "Expected registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected registers in MMA_Atom::call"); + + // Register value types from the MMA_Operation register arrays + using RegTypeA = typename remove_extent::type; + using RegTypeB = typename remove_extent::type; + using RegTypeC = typename remove_extent::type; + + // SM90 GMMA take three arguments rather than four, try to assert C and D are aliased + static_assert(is_same::value, "GMMA C and D value_type must match."); + static_assert(is_same::value, "GMMA C and D layouts must match."); + // assert((void*)&C == (void*)&D); + + Tensor rA = recast(A); + Tensor rB = recast(B); + Tensor rC = recast(D); // NOTE: D and C are same, so use mutable D + + constexpr int RegNumA = extent::value; + constexpr int RegNumB = extent::value; + constexpr int RegNumC = extent::value; + + CUTE_STATIC_ASSERT_V(size(rA) == Int{}); + CUTE_STATIC_ASSERT_V(size(rB) == Int{}); + CUTE_STATIC_ASSERT_V(size(rC) == Int{}); + + detail::explode(MMA_Op::fma, + rA, make_int_sequence{}, + rB, make_int_sequence{}, + rC, make_int_sequence{}, + &(traits.accumulate_), seq<0>{}); +} + +// Accumulator layouts +template +using CLayout_64xN = Layout,Shape < _2,_2,Int>>, + Stride,Stride<_64,_8, _512>>>; + +using CLayout_64x8 = CLayout_64xN< 8>; +using CLayout_64x16 = CLayout_64xN< 16>; +using CLayout_64x32 = CLayout_64xN< 32>; +using CLayout_64x64 = CLayout_64xN< 64>; +using CLayout_64x96 = CLayout_64xN< 96>; +using CLayout_64x128 = CLayout_64xN<128>; +using CLayout_64x192 = CLayout_64xN<192>; +using CLayout_64x256 = CLayout_64xN<256>; + +// Register source layout for 32-bit value types +using ALayout_64x8 = Layout,Shape < _2, _2>>, + Stride,Stride< _8,_256>>>; + +// Register source layout for 16-bit (sparse 32-bit) value types +using ALayout_64x16 = Layout,Shape < _2,_2, _2>>, + Stride,Stride<_64,_8,_512>>>; + +// Register source layout for 8-bit (sparse 16-bit) value types +using ALayout_64x32 = 
Layout,Shape < _4,_2, _2>>, + Stride,Stride<_64,_8,_1024>>>; + +// Register source layout for sparse 8-bit value types +using ALayout_64x64 = Layout,Shape < _8,_2, _2>>, + Stride,Stride<_64,_8,_2048>>>; + +// Shared memory source layouts for any value type +template +using ABLayout = Layout,Int>>, + Stride< _0,Stride< _1,Int>>>; + +} // end namespace SM90::GMMA + +using namespace SM90; + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x16_F16F16F16_SS = SM90::GMMA::MMA_64x8x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 8, 16>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x16_F16F16F16_RS = SM90::GMMA::MMA_64x8x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 8, 16>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x16_F16F16F16_SS = SM90::GMMA::MMA_64x16x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 16, 16>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x16_F16F16F16_RS = SM90::GMMA::MMA_64x16x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 16, 16>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + 
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x16_F16F16F16_SS = SM90::GMMA::MMA_64x32x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 32, 16>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x16_F16F16F16_RS = SM90::GMMA::MMA_64x32x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 32, 16>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x16_F16F16F16_SS = SM90::GMMA::MMA_64x64x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 64, 16>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x16_F16F16F16_RS = SM90::GMMA::MMA_64x64x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 64, 16>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x16_F16F16F16_SS = SM90::GMMA::MMA_64x96x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_16>; + using ThrID 
= Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 96, 16>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x16_F16F16F16_RS = SM90::GMMA::MMA_64x96x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 96, 16>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x16_F16F16F16_SS = SM90::GMMA::MMA_64x128x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<128, 16>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x16_F16F16F16_RS = SM90::GMMA::MMA_64x128x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<128, 16>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x16_F16F16F16_SS = SM90::GMMA::MMA_64x192x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<192, 16>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB 
= GMMA::ScaleIn::One +> +using SM90_64x192x16_F16F16F16_RS = SM90::GMMA::MMA_64x192x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<192, 16>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x16_F16F16F16_SS = SM90::GMMA::MMA_64x256x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<256, 16>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x16_F16F16F16_RS = SM90::GMMA::MMA_64x256x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<256, 16>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x16_F32F16F16_SS = SM90::GMMA::MMA_64x8x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 8, 16>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x16_F32F16F16_RS = SM90::GMMA::MMA_64x8x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 8, 16>; + 
using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x16_F32F16F16_SS = SM90::GMMA::MMA_64x16x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 16, 16>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x16_F32F16F16_RS = SM90::GMMA::MMA_64x16x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 16, 16>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x16_F32F16F16_SS = SM90::GMMA::MMA_64x32x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 32, 16>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x16_F32F16F16_RS = SM90::GMMA::MMA_64x32x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 32, 16>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x16_F32F16F16_SS = SM90::GMMA::MMA_64x64x16_F32F16F16_SS; + +template +struct 
MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 64, 16>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x16_F32F16F16_RS = SM90::GMMA::MMA_64x64x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 64, 16>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x16_F32F16F16_SS = SM90::GMMA::MMA_64x96x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 96, 16>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x16_F32F16F16_RS = SM90::GMMA::MMA_64x96x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 96, 16>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x16_F32F16F16_SS = SM90::GMMA::MMA_64x128x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<128, 16>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = 
GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x16_F32F16F16_RS = SM90::GMMA::MMA_64x128x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<128, 16>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x16_F32F16F16_SS = SM90::GMMA::MMA_64x192x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<192, 16>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x16_F32F16F16_RS = SM90::GMMA::MMA_64x192x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<192, 16>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x16_F32F16F16_SS = SM90::GMMA::MMA_64x256x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<256, 16>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x16_F32F16F16_RS = SM90::GMMA::MMA_64x256x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using 
ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<256, 16>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x8x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 8, 16>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x8x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 8, 16>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x16x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 16, 16>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x16x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 16, 16>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
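+// NOTE (illustrative annotation, not from the upstream CUTLASS header): the specializations in
+// this file all follow one pattern. "_SS" atoms source both A and B from shared memory, so
+// FrgTypeA and FrgTypeB are GMMA shared-memory descriptors; "_RS" atoms source A from registers,
+// so only FrgTypeB is a descriptor. ThrID = Layout<_128> records that one full warpgroup executes
+// each atom. A minimal, hedged usage sketch (kept in comments so the header itself is unchanged):
+//
+//   using Atom = SM90_64x64x16_F32F16F16_SS<GMMA::Major::K, GMMA::Major::K>;  // K-major A and B
+//   auto tiled_mma = cute::make_tiled_mma(Atom{});   // one warpgroup, 64x64x16 per wgmma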
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x32x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 32, 16>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x32x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 32, 16>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x64x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 64, 16>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x64x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 64, 16>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x96x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using 
ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 96, 16>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x96x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 96, 16>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x128x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<128, 16>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x128x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<128, 16>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x192x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<192, 16>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = 
GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x192x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<192, 16>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x256x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<256, 16>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x256x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<256, 16>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x8x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout< 8, 8>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x8x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = 
tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout< 8, 8>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x16x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout< 16, 8>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x16x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout< 16, 8>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x32x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout< 32, 8>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x32x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout< 32, 8>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using 
SM90_64x64x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x64x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout< 64, 8>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x64x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout< 64, 8>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x96x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout< 96, 8>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x96x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout< 96, 8>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x128x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout<128, 8>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
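+// NOTE (illustrative annotation, not from the upstream CUTLASS header): the TF32 atoms are
+// provided only in "_TN" form, so the aliases carry just the two ScaleIn parameters (no Major
+// parameters) and the instruction K extent drops to 8 rather than 16, matching the wider tf32
+// element. A hedged sketch, in comments only:
+//
+//   using Tf32Atom = SM90_64x64x8_F32TF32TF32_SS_TN<>;   // defaults to ScaleIn::One for A and B
+//   auto tiled_mma = cute::make_tiled_mma(Tf32Atom{});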
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x128x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout<128, 8>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x192x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout<192, 8>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x192x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout<192, 8>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x256x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout<256, 8>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x256x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_8>; + using ThrID = 
Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout<256, 8>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x8x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x8x32_S32S8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x8x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x8x32_S32S8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x16x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x16x32_S32S8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x16x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x16x32_S32S8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x32x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x32x32_S32S8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
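+// NOTE (illustrative annotation, not from the upstream CUTLASS header): each integer atom appears
+// twice, once plain and once as "_SATURATE" (the .satfinite form of the wgmma instruction); the
+// traits of the two are otherwise identical, e.g. one could check in a consumer translation unit:
+//
+//   static_assert(std::is_same_v<MMA_Traits<SM90_64x32x32_S32S8S8_SS_TN>::Shape_MNK,
+//                                MMA_Traits<SM90_64x32x32_S32S8S8_SS_TN_SATURATE>::Shape_MNK>);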
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+
+using SM90_64x32x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x32x32_S32S8S8_SS_TN_SATURATE;
+
+template <>
+struct MMA_Traits<SM90_64x32x32_S32S8S8_SS_TN_SATURATE>
+{
+  using ValTypeD = int32_t;
+  using ValTypeA = int8_t;
+  using ValTypeB = int8_t;
+  using ValTypeC = int32_t;
+
+  using FrgTypeA = GMMA::smem_desc<GMMA::Major::K>;
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_32,_32>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ABLayout< 64, 32>;
+  using BLayout = GMMA::ABLayout< 32, 32>;
+  using CLayout = GMMA::CLayout_64x32;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+
+using SM90_64x64x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x64x32_S32S8S8_SS_TN;
+
+template <>
+struct MMA_Traits<SM90_64x64x32_S32S8S8_SS_TN>
+{
+  using ValTypeD = int32_t;
+  using ValTypeA = int8_t;
+  using ValTypeB = int8_t;
+  using ValTypeC = int32_t;
+
+  using FrgTypeA = GMMA::smem_desc<GMMA::Major::K>;
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_64,_32>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ABLayout< 64, 32>;
+  using BLayout = GMMA::ABLayout< 64, 32>;
+  using CLayout = GMMA::CLayout_64x64;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+
+using SM90_64x64x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x64x32_S32S8S8_SS_TN_SATURATE;
+
+template <>
+struct MMA_Traits<SM90_64x64x32_S32S8S8_SS_TN_SATURATE>
+{
+  using ValTypeD = int32_t;
+  using ValTypeA = int8_t;
+  using ValTypeB = int8_t;
+  using ValTypeC = int32_t;
+
+  using FrgTypeA = GMMA::smem_desc<GMMA::Major::K>;
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_64,_32>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ABLayout< 64, 32>;
+  using BLayout = GMMA::ABLayout< 64, 32>;
+  using CLayout = GMMA::CLayout_64x64;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+
+using SM90_64x96x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x96x32_S32S8S8_SS_TN;
+
+template <>
+struct MMA_Traits<SM90_64x96x32_S32S8S8_SS_TN>
+{
+  using ValTypeD = int32_t;
+  using ValTypeA = int8_t;
+  using ValTypeB = int8_t;
+  using ValTypeC = int32_t;
+
+  using FrgTypeA = GMMA::smem_desc<GMMA::Major::K>;
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_96,_32>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ABLayout< 64, 32>;
+  using BLayout = GMMA::ABLayout< 96, 32>;
+  using CLayout = GMMA::CLayout_64x96;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+
+using SM90_64x96x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x96x32_S32S8S8_SS_TN_SATURATE;
+
+template <>
+struct MMA_Traits<SM90_64x96x32_S32S8S8_SS_TN_SATURATE>
+{
+  using ValTypeD = int32_t;
+  using ValTypeA = int8_t;
+  using ValTypeB = int8_t;
+  using ValTypeC = int32_t;
+
+  using FrgTypeA = GMMA::smem_desc<GMMA::Major::K>;
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_96,_32>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ABLayout< 64, 32>;
+  using BLayout = GMMA::ABLayout< 96, 32>;
+  using CLayout = GMMA::CLayout_64x96;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+
+using SM90_64x128x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x128x32_S32S8S8_SS_TN;
+
+template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x128x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x128x32_S32S8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x192x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x192x32_S32S8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x192x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x192x32_S32S8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x256x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x256x32_S32S8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x256x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x256x32_S32S8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = 
GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x8x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x8x32_S32S8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x8x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x8x32_S32S8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x16x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x16x32_S32S8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x16x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x16x32_S32S8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x32x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x32x32_S32S8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + 
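+// NOTE (illustrative annotation, not from the upstream CUTLASS header): the "_RS_TN" integer atoms
+// read the A operand from registers rather than through a shared-memory descriptor, which is why
+// these specializations define no FrgTypeA and use GMMA::ALayout_64x32 for ALayout. A hedged
+// sketch of obtaining such a register fragment from a ThrMMA (names here are illustrative only):
+//
+//   auto thr_mma = tiled_mma.get_thread_slice(threadIdx.x);
+//   auto tCrA    = thr_mma.partition_fragment_A(sA);   // per-thread A registers for the RS atom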
+using SM90_64x32x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x32x32_S32S8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x64x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x64x32_S32S8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x64x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x64x32_S32S8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x96x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x96x32_S32S8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x96x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x96x32_S32S8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x128x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x128x32_S32S8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = 
GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x128x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x128x32_S32S8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x192x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x192x32_S32S8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x192x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x192x32_S32S8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x256x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x256x32_S32S8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x256x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x256x32_S32S8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x8x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x8x32_S32S8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + 
using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x8x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x8x32_S32S8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x16x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x16x32_S32S8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x16x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x16x32_S32S8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x32x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x32x32_S32S8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x32x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x32x32_S32S8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + 
using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x64x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x64x32_S32S8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x64x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x64x32_S32S8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x96x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x96x32_S32S8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x96x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x96x32_S32S8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x128x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x128x32_S32S8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = 
GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x128x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x128x32_S32S8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x192x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x192x32_S32S8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x192x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x192x32_S32S8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x256x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x256x32_S32S8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x256x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x256x32_S32S8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x8x32_S32S8U8_RS_TN = 
SM90::GMMA::MMA_64x8x32_S32S8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x8x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x8x32_S32S8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x16x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x16x32_S32S8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x16x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x16x32_S32S8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x32x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x32x32_S32S8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x32x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x32x32_S32S8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = 
GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x64x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x64x32_S32S8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x64x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x64x32_S32S8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x96x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x96x32_S32S8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x96x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x96x32_S32S8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x128x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x128x32_S32S8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x128x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x128x32_S32S8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; 
+ using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x192x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x192x32_S32S8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x192x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x192x32_S32S8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x256x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x256x32_S32S8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x256x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x256x32_S32S8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x8x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x8x32_S32U8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x8x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x8x32_S32U8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x16x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x16x32_S32U8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x16x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x16x32_S32U8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x32x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x32x32_S32U8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x32x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x32x32_S32U8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x64x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x64x32_S32U8S8_SS_TN; + +template 
<> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x64x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x64x32_S32U8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x96x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x96x32_S32U8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x96x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x96x32_S32U8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x128x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x128x32_S32U8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x128x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x128x32_S32U8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using 
FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x192x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x192x32_S32U8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x192x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x192x32_S32U8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x256x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x256x32_S32U8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x256x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x256x32_S32U8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x8x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x8x32_S32U8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + 
GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x8x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x8x32_S32U8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x16x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x16x32_S32U8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x16x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x16x32_S32U8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x32x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x32x32_S32U8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x32x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x32x32_S32U8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x64x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x64x32_S32U8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + 
using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x64x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x64x32_S32U8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x96x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x96x32_S32U8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x96x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x96x32_S32U8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x128x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x128x32_S32U8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x128x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x128x32_S32U8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
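A minimal host-side sketch of how one of the GMMA ops declared in this header is typically consumed via CuTe's make_tiled_mma, which reads Shape_MNK, ThrID and the A/B/C layouts from the matching MMA_Traits specialization. This is an illustrative assumption rather than part of the header: the chosen op (SM90_64x128x32_S32S8U8_SS_TN), the extra include of mma_traits_sm90_gmma.hpp, and the print calls are only for inspection, and it presumes compilation with nvcc against the CUTLASS/CuTe include tree targeting SM90.

#include <cstdio>
#include <cute/tensor.hpp>
#include <cute/atom/mma_atom.hpp>
#include <cute/atom/mma_traits_sm90_gmma.hpp>

int main() {
  using namespace cute;
  // make_tiled_mma wraps the op in an MMA_Atom; the atom's shape, thread ID
  // layout and A/B/C layouts come from MMA_Traits<SM90_64x128x32_S32S8U8_SS_TN>.
  auto tiled_mma = make_tiled_mma(SM90_64x128x32_S32S8U8_SS_TN{});
  printf("threads per tiled MMA: %d\n", int(size(tiled_mma)));  // 128, one warpgroup
  print(tiled_mma);  // dump the thread/value tiling for inspection
  return 0;
}

The same pattern applies to every specialization in this file; only the op type changes, and the resulting TiledMMA is then partitioned against shared-memory or register tensors inside a kernel.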
+//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x192x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x192x32_S32U8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x192x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x192x32_S32U8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x256x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x256x32_S32U8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x256x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x256x32_S32U8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x8x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x8x32_S32U8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x8x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x8x32_S32U8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + 
using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x16x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x16x32_S32U8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x16x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x16x32_S32U8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x32x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x32x32_S32U8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x32x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x32x32_S32U8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x64x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x64x32_S32U8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = 
GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x64x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x64x32_S32U8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x96x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x96x32_S32U8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x96x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x96x32_S32U8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x128x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x128x32_S32U8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x128x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x128x32_S32U8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x192x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x192x32_S32U8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x192x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x192x32_S32U8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x256x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x256x32_S32U8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x256x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x256x32_S32U8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x8x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x8x32_S32U8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x8x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x8x32_S32U8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits 
+{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x16x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x16x32_S32U8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x16x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x16x32_S32U8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x32x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x32x32_S32U8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x32x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x32x32_S32U8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x64x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x64x32_S32U8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x64x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x64x32_S32U8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x96x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x96x32_S32U8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x96x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x96x32_S32U8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x128x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x128x32_S32U8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x128x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x128x32_S32U8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x192x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x192x32_S32U8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + 
using Shape_MNK = Shape<_64,_192,_32>;
+  using ThrID = Layout<_128>;
+  using ALayout = GMMA::ALayout_64x32;
+  using BLayout = GMMA::ABLayout<192, 32>;
+  using CLayout = GMMA::CLayout_64x192;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+
+using SM90_64x192x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x192x32_S32U8U8_RS_TN_SATURATE;
+
+template <>
+struct MMA_Traits<SM90_64x192x32_S32U8U8_RS_TN_SATURATE>
+{
+  using ValTypeD = int32_t;
+  using ValTypeA = uint8_t;
+  using ValTypeB = uint8_t;
+  using ValTypeC = int32_t;
+
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_192,_32>;
+  using ThrID = Layout<_128>;
+  using ALayout = GMMA::ALayout_64x32;
+  using BLayout = GMMA::ABLayout<192, 32>;
+  using CLayout = GMMA::CLayout_64x192;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+
+using SM90_64x256x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x256x32_S32U8U8_RS_TN;
+
+template <>
+struct MMA_Traits<SM90_64x256x32_S32U8U8_RS_TN>
+{
+  using ValTypeD = int32_t;
+  using ValTypeA = uint8_t;
+  using ValTypeB = uint8_t;
+  using ValTypeC = int32_t;
+
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_256,_32>;
+  using ThrID = Layout<_128>;
+  using ALayout = GMMA::ALayout_64x32;
+  using BLayout = GMMA::ABLayout<256, 32>;
+  using CLayout = GMMA::CLayout_64x256;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+
+using SM90_64x256x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x256x32_S32U8U8_RS_TN_SATURATE;
+
+template <>
+struct MMA_Traits<SM90_64x256x32_S32U8U8_RS_TN_SATURATE>
+{
+  using ValTypeD = int32_t;
+  using ValTypeA = uint8_t;
+  using ValTypeB = uint8_t;
+  using ValTypeC = int32_t;
+
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_256,_32>;
+  using ThrID = Layout<_128>;
+  using ALayout = GMMA::ALayout_64x32;
+  using BLayout = GMMA::ABLayout<256, 32>;
+  using CLayout = GMMA::CLayout_64x256;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x8x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x8x32_F16E4M3E4M3_SS_TN<scaleA, scaleB>;
+
+template <GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x8x32_F16E4M3E4M3_SS_TN<scaleA, scaleB>>
+{
+  using ValTypeD = half_t;
+  using ValTypeA = float_e4m3_t;
+  using ValTypeB = float_e4m3_t;
+  using ValTypeC = half_t;
+
+  using FrgTypeA = GMMA::smem_desc<GMMA::Major::K>;
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_8,_32>;
+  using ThrID = Layout<_128>;
+  using ALayout = GMMA::ABLayout< 64, 32>;
+  using BLayout = GMMA::ABLayout< 8, 32>;
+  using CLayout = GMMA::CLayout_64x8;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x8x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x8x32_F16E4M3E4M3_RS_TN<scaleA, scaleB>;
+
+template <GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x8x32_F16E4M3E4M3_RS_TN<scaleA, scaleB>>
+{
+  using ValTypeD = half_t;
+  using ValTypeA = float_e4m3_t;
+  using ValTypeB = float_e4m3_t;
+  using ValTypeC = half_t;
+
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_8,_32>;
+  using ThrID = Layout<_128>;
+  using ALayout = GMMA::ALayout_64x32;
+  using BLayout = 
GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x8x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x8x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x16x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x16x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x16x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = 
float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x16x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x32x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x32x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x32x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = 
GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x32x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x64x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x64x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x64x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x64x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = 
GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x96x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x96x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x96x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x96x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x128x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = 
half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x128x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x128x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x128x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x192x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + 
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x192x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x192x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x192x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x256x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x256x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = 
GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x256x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x256x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x8x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x8x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x8x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = 
float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x8x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x16x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x16x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x16x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = 
GMMA::ScaleIn::One +> +using SM90_64x16x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x16x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x32x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x32x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x32x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x32x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ 
= GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x64x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x64x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x64x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x64x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x96x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using 
FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x96x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x96x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x96x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x128x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using 
SM90_64x128x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x128x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x128x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x128x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x192x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x192x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = 
GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x192x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x192x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x256x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x256x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x256x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = 
GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x256x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x8x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x8x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x8x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using 
SM90_64x8x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x8x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x16x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x16x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x16x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x16x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
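+// Usage sketch (illustrative only): any FP8 atom declared in this header can be wrapped into a
+// cute::TiledMMA and driven with cute::gemm. This assumes the usual CuTe API (make_tiled_mma,
+// partition_*, make_fragment_*, gemm); sA/sB/gC are placeholder tensors, and the TMA copies and
+// warpgroup fences of a real mainloop are omitted. The _SS_ atoms read both operands through
+// shared-memory descriptors (FrgTypeA/FrgTypeB above), while the _RS_ atoms keep A in registers
+// (no FrgTypeA descriptor, ALayout = GMMA::ALayout_64x32) and only read B through a descriptor.
+//
+//   #include <cute/tensor.hpp>
+//   #include <cute/atom/mma_atom.hpp>
+//   using namespace cute;
+//
+//   // One warpgroup (128 threads) computing a 64x64x32 e4m3 x e4m3 tile with F32 accumulation.
+//   TiledMMA tiled_mma = make_tiled_mma(SM90_64x64x32_F32E4M3E4M3_SS_TN<>{});
+//   auto     thr_mma   = tiled_mma.get_thread_slice(threadIdx.x);
+//
+//   Tensor tCsA = thr_mma.partition_A(sA);        // (MMA,MMA_M,MMA_K) view of A in shared memory
+//   Tensor tCsB = thr_mma.partition_B(sB);        // (MMA,MMA_N,MMA_K) view of B in shared memory
+//   Tensor tCrA = thr_mma.make_fragment_A(tCsA);  // smem-descriptor fragment (FrgTypeA)
+//   Tensor tCrB = thr_mma.make_fragment_B(tCsB);  // smem-descriptor fragment (FrgTypeB)
+//   Tensor tCgC = thr_mma.partition_C(gC);        // per-thread view of the output tile
+//   Tensor tCrC = thr_mma.make_fragment_C(tCgC);  // register accumulators
+//   clear(tCrC);
+//
+//   gemm(tiled_mma, tCrA, tCrB, tCrC);            // issues the wgmma.mma_async instructions
+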
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x32x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x32x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x32x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x32x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x64x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; 
+ + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x64x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x64x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x64x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x96x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x32_F16E5M2E4M3_RS_TN = 
SM90::GMMA::MMA_64x96x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x96x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x96x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x128x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x128x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
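+// Note on the accumulate_ member carried by every trait in this file: GMMA::ScaleOut::One makes
+// the wgmma instruction accumulate into C (D = A*B + C), while GMMA::ScaleOut::Zero makes it
+// overwrite (D = A*B). A common idiom, sketched here after the way the SM90 mainloops use it,
+// clears the accumulator implicitly on the first k-block and accumulates afterwards
+// (tCrA/tCrB/tCrC as in the sketch above):
+//
+//   tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;      // first MMA overwrites the accumulator
+//   for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) {
+//     gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC);
+//     tiled_mma.accumulate_ = GMMA::ScaleOut::One;     // subsequent MMAs accumulate
+//   }
+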
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x128x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x128x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x192x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x192x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x192x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = 
GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x192x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x256x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x256x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x256x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using 
SM90_64x256x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x256x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x8x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x8x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x8x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x8x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
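Each TN FP8 entry in this header follows the same two-part pattern: a `using` alias that forwards to the arch-level `SM90::GMMA::MMA_*` op, parameterized by the A/B input scale modes, and an `MMA_Traits` specialization recording the element types, operand fragment types, warpgroup tile shape, and thread/value layouts CuTe needs. The sketch below spells that pattern out for the 64x8x32 F32E5M2E5M2 SS_TN entry above; the angle-bracket arguments (`<scaleA, scaleB>` on the op alias and `<GMMA::Major::K>` on the descriptors) are filled in following the usual CUTLASS conventions rather than quoted from this file.

template <GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
          GMMA::ScaleIn scaleB = GMMA::ScaleIn::One>
using SM90_64x8x32_F32E5M2E5M2_SS_TN =
    SM90::GMMA::MMA_64x8x32_F32E5M2E5M2_SS_TN<scaleA, scaleB>;

template <GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
struct MMA_Traits<SM90_64x8x32_F32E5M2E5M2_SS_TN<scaleA, scaleB>>
{
  using ValTypeD = float;         // accumulator/output element type
  using ValTypeA = float_e5m2_t;  // FP8 (e5m2) operand A
  using ValTypeB = float_e5m2_t;  // FP8 (e5m2) operand B
  using ValTypeC = float;

  // SS variant: both operands are read from shared memory through GMMA
  // descriptors; the TN ops are K-major on both A and B.
  using FrgTypeA = GMMA::smem_desc<GMMA::Major::K>;
  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;

  using Shape_MNK = Shape<_64,_8,_32>;       // one warpgroup-wide 64x8x32 tile
  using ThrID     = Layout<_128>;            // 128 threads (one warpgroup)
  using ALayout   = GMMA::ABLayout< 64, 32>;
  using BLayout   = GMMA::ABLayout<  8, 32>;
  using CLayout   = GMMA::CLayout_64x8;

  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
};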
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x16x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x16x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x16x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x16x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x32x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; 
+ + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x32x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x32x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x32x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x64x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x32_F16E5M2E5M2_RS_TN = 
SM90::GMMA::MMA_64x64x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x64x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x64x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x96x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x96x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
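The RS_TN companions differ only in how operand A is sourced: A is held in registers rather than shared memory, so the traits drop the A-side descriptor fragment and use the register layout GMMA::ALayout_64x32 instead. A sketch of the 64x96x32 F16E5M2E5M2 RS_TN entry that closes just above, with its template parameters written out under the same assumptions as the previous sketch:

template <GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
          GMMA::ScaleIn scaleB = GMMA::ScaleIn::One>
using SM90_64x96x32_F16E5M2E5M2_RS_TN =
    SM90::GMMA::MMA_64x96x32_F16E5M2E5M2_RS_TN<scaleA, scaleB>;

template <GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
struct MMA_Traits<SM90_64x96x32_F16E5M2E5M2_RS_TN<scaleA, scaleB>>
{
  using ValTypeD = half_t;
  using ValTypeA = float_e5m2_t;
  using ValTypeB = float_e5m2_t;
  using ValTypeC = half_t;

  // RS variant: only B needs a shared-memory descriptor; A lives in registers.
  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;

  using Shape_MNK = Shape<_64,_96,_32>;
  using ThrID     = Layout<_128>;
  using ALayout   = GMMA::ALayout_64x32;     // register-resident A fragment layout
  using BLayout   = GMMA::ABLayout< 96, 32>;
  using CLayout   = GMMA::CLayout_64x96;

  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
};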
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x96x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x96x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x128x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x128x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x128x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = 
GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x128x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x192x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x192x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x192x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using 
SM90_64x192x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x192x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x256x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x256x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x256x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x256x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = 
GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // end namespace cute + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +#include "mma_traits_sm90_gmma_ext.hpp" +#endif diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/atom/mma_traits_sm90_gmma_ext.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/atom/mma_traits_sm90_gmma_ext.hpp new file mode 100644 index 0000000000000000000000000000000000000000..3cab34d257710e2b335e526558c1401826b27492 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/atom/mma_traits_sm90_gmma_ext.hpp @@ -0,0 +1,20116 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once + +#include +#include + +namespace cute { + +namespace SM90::GMMA { + +using CLayout_64x24 = CLayout_64xN< 24>; +using CLayout_64x40 = CLayout_64xN< 40>; +using CLayout_64x48 = CLayout_64xN< 48>; +using CLayout_64x56 = CLayout_64xN< 56>; +using CLayout_64x72 = CLayout_64xN< 72>; +using CLayout_64x80 = CLayout_64xN< 80>; +using CLayout_64x88 = CLayout_64xN< 88>; +using CLayout_64x104 = CLayout_64xN<104>; +using CLayout_64x112 = CLayout_64xN<112>; +using CLayout_64x120 = CLayout_64xN<120>; +using CLayout_64x136 = CLayout_64xN<136>; +using CLayout_64x144 = CLayout_64xN<144>; +using CLayout_64x152 = CLayout_64xN<152>; +using CLayout_64x160 = CLayout_64xN<160>; +using CLayout_64x168 = CLayout_64xN<168>; +using CLayout_64x176 = CLayout_64xN<176>; +using CLayout_64x184 = CLayout_64xN<184>; +using CLayout_64x200 = CLayout_64xN<200>; +using CLayout_64x208 = CLayout_64xN<208>; +using CLayout_64x216 = CLayout_64xN<216>; +using CLayout_64x224 = CLayout_64xN<224>; +using CLayout_64x232 = CLayout_64xN<232>; +using CLayout_64x240 = CLayout_64xN<240>; +using CLayout_64x248 = CLayout_64xN<248>; + +} + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x24x16_F16F16F16_SS = SM90::GMMA::MMA_64x24x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 24, 16>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x24x16_F16F16F16_RS = SM90::GMMA::MMA_64x24x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 24, 16>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x40x16_F16F16F16_SS = SM90::GMMA::MMA_64x40x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 40, 16>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x40x16_F16F16F16_RS = SM90::GMMA::MMA_64x40x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 40, 16>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x16_F16F16F16_SS = SM90::GMMA::MMA_64x48x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 48, 16>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x16_F16F16F16_RS = SM90::GMMA::MMA_64x48x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 48, 16>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x56x16_F16F16F16_SS = SM90::GMMA::MMA_64x56x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 56, 16>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x56x16_F16F16F16_RS = SM90::GMMA::MMA_64x56x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = 
half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 56, 16>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x72x16_F16F16F16_SS = SM90::GMMA::MMA_64x72x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 72, 16>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x72x16_F16F16F16_RS = SM90::GMMA::MMA_64x72x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 72, 16>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x16_F16F16F16_SS = SM90::GMMA::MMA_64x80x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 80, 16>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x16_F16F16F16_RS = SM90::GMMA::MMA_64x80x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 80, 16>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + 
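Unlike the TN FP8 atoms, the F16F16F16 entries in this extended-shapes header carry two additional GMMA::Major parameters, tnspA and tnspB, which select K-major or MN-major tiles for A and B and flow straight into the shared-memory descriptor fragments. The sketch below writes out the 64x80x16 SS entry above with all four parameters; as before, the angle-bracket arguments are filled in per the usual CUTLASS conventions and are not quoted from this file.

template <GMMA::Major tnspA, GMMA::Major tnspB,
          GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
          GMMA::ScaleIn scaleB = GMMA::ScaleIn::One>
using SM90_64x80x16_F16F16F16_SS =
    SM90::GMMA::MMA_64x80x16_F16F16F16_SS<tnspA, tnspB, scaleA, scaleB>;

template <GMMA::Major tnspA, GMMA::Major tnspB,
          GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
struct MMA_Traits<SM90_64x80x16_F16F16F16_SS<tnspA, tnspB, scaleA, scaleB>>
{
  using ValTypeD = half_t;
  using ValTypeA = half_t;
  using ValTypeB = half_t;
  using ValTypeC = half_t;

  // The operand majors parameterize the shared-memory descriptors, so one
  // specialization covers K-major and MN-major A/B tiles.
  using FrgTypeA = GMMA::smem_desc<tnspA>;
  using FrgTypeB = GMMA::smem_desc<tnspB>;

  using Shape_MNK = Shape<_64,_80,_16>;
  using ThrID     = Layout<_128>;
  using ALayout   = GMMA::ABLayout< 64, 16>;
  using BLayout   = GMMA::ABLayout< 80, 16>;
  using CLayout   = GMMA::CLayout_64x80;

  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
};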
+template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x88x16_F16F16F16_SS = SM90::GMMA::MMA_64x88x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 88, 16>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x88x16_F16F16F16_RS = SM90::GMMA::MMA_64x88x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 88, 16>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x104x16_F16F16F16_SS = SM90::GMMA::MMA_64x104x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<104, 16>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x104x16_F16F16F16_RS = SM90::GMMA::MMA_64x104x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<104, 16>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x16_F16F16F16_SS = SM90::GMMA::MMA_64x112x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = 
GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<112, 16>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x16_F16F16F16_RS = SM90::GMMA::MMA_64x112x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<112, 16>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x120x16_F16F16F16_SS = SM90::GMMA::MMA_64x120x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<120, 16>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x120x16_F16F16F16_RS = SM90::GMMA::MMA_64x120x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<120, 16>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x136x16_F16F16F16_SS = SM90::GMMA::MMA_64x136x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<136, 16>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + 
GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x136x16_F16F16F16_RS = SM90::GMMA::MMA_64x136x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<136, 16>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x16_F16F16F16_SS = SM90::GMMA::MMA_64x144x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<144, 16>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x16_F16F16F16_RS = SM90::GMMA::MMA_64x144x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<144, 16>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x152x16_F16F16F16_SS = SM90::GMMA::MMA_64x152x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<152, 16>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x152x16_F16F16F16_RS = SM90::GMMA::MMA_64x152x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_16>; + using 
ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<152, 16>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x16_F16F16F16_SS = SM90::GMMA::MMA_64x160x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<160, 16>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x16_F16F16F16_RS = SM90::GMMA::MMA_64x160x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<160, 16>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x168x16_F16F16F16_SS = SM90::GMMA::MMA_64x168x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<168, 16>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x168x16_F16F16F16_RS = SM90::GMMA::MMA_64x168x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<168, 16>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn 
scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x16_F16F16F16_SS = SM90::GMMA::MMA_64x176x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<176, 16>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x16_F16F16F16_RS = SM90::GMMA::MMA_64x176x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<176, 16>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x184x16_F16F16F16_SS = SM90::GMMA::MMA_64x184x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<184, 16>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x184x16_F16F16F16_RS = SM90::GMMA::MMA_64x184x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<184, 16>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x200x16_F16F16F16_SS = SM90::GMMA::MMA_64x200x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_16>; + using ThrID = Layout<_128>; + using ALayout = 
GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<200, 16>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x200x16_F16F16F16_RS = SM90::GMMA::MMA_64x200x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<200, 16>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x16_F16F16F16_SS = SM90::GMMA::MMA_64x208x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<208, 16>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x16_F16F16F16_RS = SM90::GMMA::MMA_64x208x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<208, 16>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x216x16_F16F16F16_SS = SM90::GMMA::MMA_64x216x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<216, 16>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> 
+using SM90_64x216x16_F16F16F16_RS = SM90::GMMA::MMA_64x216x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<216, 16>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x16_F16F16F16_SS = SM90::GMMA::MMA_64x224x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<224, 16>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x16_F16F16F16_RS = SM90::GMMA::MMA_64x224x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<224, 16>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x232x16_F16F16F16_SS = SM90::GMMA::MMA_64x232x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<232, 16>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x232x16_F16F16F16_RS = SM90::GMMA::MMA_64x232x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<232, 16>; + using 
CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x16_F16F16F16_SS = SM90::GMMA::MMA_64x240x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<240, 16>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x16_F16F16F16_RS = SM90::GMMA::MMA_64x240x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<240, 16>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x248x16_F16F16F16_SS = SM90::GMMA::MMA_64x248x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<248, 16>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x248x16_F16F16F16_RS = SM90::GMMA::MMA_64x248x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<248, 16>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x24x16_F32F16F16_SS = SM90::GMMA::MMA_64x24x16_F32F16F16_SS; + 
+template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 24, 16>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x24x16_F32F16F16_RS = SM90::GMMA::MMA_64x24x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 24, 16>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x40x16_F32F16F16_SS = SM90::GMMA::MMA_64x40x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 40, 16>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x40x16_F32F16F16_RS = SM90::GMMA::MMA_64x40x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 40, 16>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x16_F32F16F16_SS = SM90::GMMA::MMA_64x48x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 48, 16>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = 
GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x16_F32F16F16_RS = SM90::GMMA::MMA_64x48x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 48, 16>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x56x16_F32F16F16_SS = SM90::GMMA::MMA_64x56x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 56, 16>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x56x16_F32F16F16_RS = SM90::GMMA::MMA_64x56x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 56, 16>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x72x16_F32F16F16_SS = SM90::GMMA::MMA_64x72x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 72, 16>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x72x16_F32F16F16_RS = SM90::GMMA::MMA_64x72x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + 
using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 72, 16>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x16_F32F16F16_SS = SM90::GMMA::MMA_64x80x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 80, 16>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x16_F32F16F16_RS = SM90::GMMA::MMA_64x80x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 80, 16>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x88x16_F32F16F16_SS = SM90::GMMA::MMA_64x88x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 88, 16>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x88x16_F32F16F16_RS = SM90::GMMA::MMA_64x88x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 88, 16>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + 
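+// [Editor's note: illustrative sketch, not part of the upstream CUTLASS header.]
+// Each MMA_Traits specialization in this file describes one SM90 wgmma shape to
+// CuTe: Shape_MNK is the instruction tile, ThrID the 128-thread warpgroup, and
+// the A/B/C layouts map those threads onto the tile. Assuming the usual CuTe
+// workflow, one of these atoms would be lifted into a TiledMMA roughly like so
+// (names of the op and majorness choices are taken from the aliases above):
+//
+//   #include <cute/tensor.hpp>
+//   #include <cute/atom/mma_atom.hpp>   // make_tiled_mma / MMA_Atom
+//   using namespace cute;
+//   // K-major A and B, both operands staged in shared memory (the _SS variant).
+//   auto tiled_mma = make_tiled_mma(
+//       SM90_64x88x16_F32F16F16_SS<GMMA::Major::K, GMMA::Major::K>{});
+//   // A thread then takes its slice, e.g. tiled_mma.get_thread_slice(threadIdx.x),
+//   // and partitions the A/B/C tensors according to the layouts declared here.
+//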
+template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x104x16_F32F16F16_SS = SM90::GMMA::MMA_64x104x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<104, 16>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x104x16_F32F16F16_RS = SM90::GMMA::MMA_64x104x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<104, 16>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x16_F32F16F16_SS = SM90::GMMA::MMA_64x112x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<112, 16>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x16_F32F16F16_RS = SM90::GMMA::MMA_64x112x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<112, 16>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x120x16_F32F16F16_SS = SM90::GMMA::MMA_64x120x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = 
GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<120, 16>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x120x16_F32F16F16_RS = SM90::GMMA::MMA_64x120x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<120, 16>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x136x16_F32F16F16_SS = SM90::GMMA::MMA_64x136x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<136, 16>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x136x16_F32F16F16_RS = SM90::GMMA::MMA_64x136x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<136, 16>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x16_F32F16F16_SS = SM90::GMMA::MMA_64x144x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<144, 16>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major 
tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x16_F32F16F16_RS = SM90::GMMA::MMA_64x144x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<144, 16>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x152x16_F32F16F16_SS = SM90::GMMA::MMA_64x152x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<152, 16>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x152x16_F32F16F16_RS = SM90::GMMA::MMA_64x152x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<152, 16>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x16_F32F16F16_SS = SM90::GMMA::MMA_64x160x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<160, 16>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x16_F32F16F16_RS = SM90::GMMA::MMA_64x160x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_16>; + using ThrID = Layout<_128>; + 
using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<160, 16>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x168x16_F32F16F16_SS = SM90::GMMA::MMA_64x168x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<168, 16>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x168x16_F32F16F16_RS = SM90::GMMA::MMA_64x168x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<168, 16>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x16_F32F16F16_SS = SM90::GMMA::MMA_64x176x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<176, 16>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x16_F32F16F16_RS = SM90::GMMA::MMA_64x176x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<176, 16>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> 
+using SM90_64x184x16_F32F16F16_SS = SM90::GMMA::MMA_64x184x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<184, 16>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x184x16_F32F16F16_RS = SM90::GMMA::MMA_64x184x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<184, 16>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x200x16_F32F16F16_SS = SM90::GMMA::MMA_64x200x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<200, 16>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x200x16_F32F16F16_RS = SM90::GMMA::MMA_64x200x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<200, 16>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x16_F32F16F16_SS = SM90::GMMA::MMA_64x208x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = 
GMMA::ABLayout<208, 16>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x16_F32F16F16_RS = SM90::GMMA::MMA_64x208x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<208, 16>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x216x16_F32F16F16_SS = SM90::GMMA::MMA_64x216x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<216, 16>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x216x16_F32F16F16_RS = SM90::GMMA::MMA_64x216x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<216, 16>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x16_F32F16F16_SS = SM90::GMMA::MMA_64x224x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<224, 16>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x16_F32F16F16_RS = 
SM90::GMMA::MMA_64x224x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<224, 16>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x232x16_F32F16F16_SS = SM90::GMMA::MMA_64x232x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<232, 16>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x232x16_F32F16F16_RS = SM90::GMMA::MMA_64x232x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<232, 16>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x16_F32F16F16_SS = SM90::GMMA::MMA_64x240x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<240, 16>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x16_F32F16F16_RS = SM90::GMMA::MMA_64x240x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<240, 16>; + using CLayout = GMMA::CLayout_64x240; + + 
GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x248x16_F32F16F16_SS = SM90::GMMA::MMA_64x248x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<248, 16>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x248x16_F32F16F16_RS = SM90::GMMA::MMA_64x248x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<248, 16>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x24x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x24x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 24, 16>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x24x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x24x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 24, 16>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x40x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x40x16_F32BF16BF16_SS; + +template +struct 
MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 40, 16>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x40x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x40x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 40, 16>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x48x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 48, 16>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x48x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 48, 16>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x56x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x56x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 56, 16>; + using CLayout = 
GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x56x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x56x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 56, 16>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x72x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x72x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 72, 16>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x72x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x72x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 72, 16>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x80x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 80, 16>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x16_F32BF16BF16_RS = 
SM90::GMMA::MMA_64x80x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 80, 16>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x88x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x88x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 88, 16>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x88x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x88x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 88, 16>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x104x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x104x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<104, 16>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x104x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x104x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<104, 16>; + using 
CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x112x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<112, 16>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x112x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<112, 16>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x120x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x120x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<120, 16>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x120x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x120x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<120, 16>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x136x16_F32BF16BF16_SS = 
SM90::GMMA::MMA_64x136x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<136, 16>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x136x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x136x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<136, 16>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x144x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<144, 16>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x144x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<144, 16>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x152x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x152x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + 
using BLayout = GMMA::ABLayout<152, 16>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x152x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x152x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<152, 16>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x160x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<160, 16>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x160x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<160, 16>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x168x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x168x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<168, 16>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = 
GMMA::ScaleIn::One +> +using SM90_64x168x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x168x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<168, 16>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x176x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<176, 16>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x176x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<176, 16>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x184x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x184x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<184, 16>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x184x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x184x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_16>; + using ThrID = Layout<_128>; + using ALayout = 
GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<184, 16>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x200x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x200x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<200, 16>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x200x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x200x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<200, 16>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x208x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<208, 16>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x208x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<208, 16>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn 
scaleB = GMMA::ScaleIn::One +> +using SM90_64x216x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x216x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<216, 16>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x216x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x216x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<216, 16>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x224x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<224, 16>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x224x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<224, 16>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x232x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x232x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_16>; + 
using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<232, 16>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x232x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x232x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<232, 16>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x240x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<240, 16>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x240x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<240, 16>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x248x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x248x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<248, 16>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + 
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x248x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x248x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<248, 16>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x24x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x24x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout< 24, 8>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x24x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x24x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout< 24, 8>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x40x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x40x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout< 40, 8>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x40x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x40x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout< 40, 8>; + using CLayout = GMMA::CLayout_64x40; + + 
GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x48x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout< 48, 8>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x48x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout< 48, 8>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x56x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x56x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout< 56, 8>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x56x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x56x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout< 56, 8>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x72x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x72x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = 
GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout< 72, 8>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x72x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x72x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout< 72, 8>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x80x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout< 80, 8>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x80x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout< 80, 8>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x88x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x88x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout< 88, 8>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x88x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x88x8_F32TF32TF32_RS_TN; 
+
+template <GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x88x8_F32TF32TF32_RS_TN<scaleA, scaleB>>
+{
+  using ValTypeD = float;
+  using ValTypeA = tfloat32_t;
+  using ValTypeB = tfloat32_t;
+  using ValTypeC = float;
+
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_88,_8>;
+  using ThrID = Layout<_128>;
+  using ALayout = GMMA::ALayout_64x8;
+  using BLayout = GMMA::ABLayout< 88, 8>;
+  using CLayout = GMMA::CLayout_64x88;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x104x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x104x8_F32TF32TF32_SS_TN<scaleA, scaleB>;
+
+template <GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x104x8_F32TF32TF32_SS_TN<scaleA, scaleB>>
+{
+  using ValTypeD = float;
+  using ValTypeA = tfloat32_t;
+  using ValTypeB = tfloat32_t;
+  using ValTypeC = float;
+
+  using FrgTypeA = GMMA::smem_desc<GMMA::Major::K>;
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_104,_8>;
+  using ThrID = Layout<_128>;
+  using ALayout = GMMA::ABLayout< 64, 8>;
+  using BLayout = GMMA::ABLayout<104, 8>;
+  using CLayout = GMMA::CLayout_64x104;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x104x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x104x8_F32TF32TF32_RS_TN<scaleA, scaleB>;
+
+template <GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x104x8_F32TF32TF32_RS_TN<scaleA, scaleB>>
+{
+  using ValTypeD = float;
+  using ValTypeA = tfloat32_t;
+  using ValTypeB = tfloat32_t;
+  using ValTypeC = float;
+
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_104,_8>;
+  using ThrID = Layout<_128>;
+  using ALayout = GMMA::ALayout_64x8;
+  using BLayout = GMMA::ABLayout<104, 8>;
+  using CLayout = GMMA::CLayout_64x104;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x112x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x112x8_F32TF32TF32_SS_TN<scaleA, scaleB>;
+
+template <GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x112x8_F32TF32TF32_SS_TN<scaleA, scaleB>>
+{
+  using ValTypeD = float;
+  using ValTypeA = tfloat32_t;
+  using ValTypeB = tfloat32_t;
+  using ValTypeC = float;
+
+  using FrgTypeA = GMMA::smem_desc<GMMA::Major::K>;
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_112,_8>;
+  using ThrID = Layout<_128>;
+  using ALayout = GMMA::ABLayout< 64, 8>;
+  using BLayout = GMMA::ABLayout<112, 8>;
+  using CLayout = GMMA::CLayout_64x112;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x112x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x112x8_F32TF32TF32_RS_TN<scaleA, scaleB>;
+
+template <GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x112x8_F32TF32TF32_RS_TN<scaleA, scaleB>>
+{
+  using ValTypeD = float;
+  using ValTypeA = tfloat32_t;
+  using ValTypeB = tfloat32_t;
+  using ValTypeC = float;
+
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_112,_8>;
+  using ThrID = Layout<_128>;
+  using ALayout = GMMA::ALayout_64x8;
+  using BLayout = GMMA::ABLayout<112, 8>;
+  using CLayout = GMMA::CLayout_64x112;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
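// ---------------------------------------------------------------------------------------------
// Editorial note: illustrative usage sketch only, not part of the vendored CUTLASS header.
// It shows how an op whose MMA_Traits are defined above is typically consumed downstream: the
// op is wrapped in an MMA_Atom (which reads Shape_MNK, ThrID and the A/B/C layouts from the
// traits) and tiled across a warpgroup with cute::make_tiled_mma. The function name
// make_tf32_warpgroup_mma and the include paths are assumptions for the sketch; real call
// sites in dependent kernels may differ.
//
//   #include <cute/tensor.hpp>
//   #include <cute/atom/mma_atom.hpp>
//
//   inline auto make_tf32_warpgroup_mma() {
//     using namespace cute;
//     // One warpgroup (128 threads, ThrID = Layout<_128>) executing the 64x112x8 TF32 GMMA
//     // with both operands sourced from shared memory (the _SS_TN variant above). Both scale
//     // template parameters default to GMMA::ScaleIn::One.
//     return make_tiled_mma(SM90_64x112x8_F32TF32TF32_SS_TN<>{});
//   }
// ---------------------------------------------------------------------------------------------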
+template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x120x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x120x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout<120, 8>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x120x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x120x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout<120, 8>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x136x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x136x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout<136, 8>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x136x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x136x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout<136, 8>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x144x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = 
GMMA::ABLayout<144, 8>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x144x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout<144, 8>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x152x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x152x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout<152, 8>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x152x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x152x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout<152, 8>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x160x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout<160, 8>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x160x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + 
using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout<160, 8>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x168x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x168x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout<168, 8>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x168x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x168x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout<168, 8>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x176x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout<176, 8>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x176x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout<176, 8>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using 
SM90_64x184x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x184x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout<184, 8>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x184x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x184x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout<184, 8>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x200x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x200x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout<200, 8>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x200x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x200x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout<200, 8>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x208x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout<208, 8>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = 
GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x208x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout<208, 8>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x216x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x216x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout<216, 8>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x216x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x216x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout<216, 8>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x224x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout<224, 8>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x224x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = 
Shape<_64,_224,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout<224, 8>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x232x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x232x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout<232, 8>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x232x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x232x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout<232, 8>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x240x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout<240, 8>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x240x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout<240, 8>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x248x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x248x8_F32TF32TF32_SS_TN; + +template +struct 
MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout<248, 8>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x248x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x248x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout<248, 8>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x24x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x24x32_S32S8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x24x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x24x32_S32S8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x48x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x48x32_S32S8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x48x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x48x32_S32S8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using 
FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x80x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x80x32_S32S8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x80x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x80x32_S32S8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x112x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x112x32_S32S8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x112x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x112x32_S32S8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x144x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x144x32_S32S8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = 
GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x144x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x144x32_S32S8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x160x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x160x32_S32S8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x160x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x160x32_S32S8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x176x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x176x32_S32S8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x176x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x176x32_S32S8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
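// ------------------------------------------------------------------------------------------------
// Editor's sketch (not part of the vendored header): how one of the atoms above is typically
// consumed. The trait is picked up automatically when the op is wrapped in a cute::MMA_Atom and
// tiled with cute::make_tiled_mma (standard CuTe API, <cute/atom/mma_atom.hpp>); one warpgroup
// (ThrID = Layout<_128>) then cooperatively executes each 64xNx32 GMMA.
namespace example_usage {

CUTE_HOST_DEVICE constexpr auto
make_s8s8_tiled_mma()
{
  // _SS_ atom: both A and B are read from shared memory through GMMA smem descriptors,
  // which is exactly what FrgTypeA/FrgTypeB = GMMA::smem_desc<...> in the traits express.
  return cute::make_tiled_mma(cute::MMA_Atom<cute::SM90_64x24x32_S32S8S8_SS_TN>{});
}

} // namespace example_usage
// ------------------------------------------------------------------------------------------------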
+//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x208x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x208x32_S32S8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x208x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x208x32_S32S8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x224x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x224x32_S32S8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x224x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x224x32_S32S8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x240x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x240x32_S32S8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x240x32_S32S8S8_SS_TN_SATURATE = 
SM90::GMMA::MMA_64x240x32_S32S8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x24x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x24x32_S32S8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x24x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x24x32_S32S8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x48x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x48x32_S32S8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x48x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x48x32_S32S8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x80x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x80x32_S32S8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = 
GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x80x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x80x32_S32S8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x112x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x112x32_S32S8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x112x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x112x32_S32S8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x144x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x144x32_S32S8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x144x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x144x32_S32S8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x160x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x160x32_S32S8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + 
using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x160x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x160x32_S32S8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x176x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x176x32_S32S8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x176x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x176x32_S32S8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x208x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x208x32_S32S8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x208x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x208x32_S32S8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
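// ------------------------------------------------------------------------------------------------
// Editor's sketch (not part of the vendored header): the _SS_ and _RS_ families above describe the
// same instructions with A sourced from shared memory (FrgTypeA = GMMA::smem_desc<...>) versus
// held in registers (no FrgTypeA; ALayout = GMMA::ALayout_64x32). Rather than naming an atom by
// hand, CuTe's selector helpers can pick one from element types and a tile shape; the exact
// constraints of cute::GMMA::ss_op_selector are an assumption here, so treat this as illustrative.
namespace example_selector {

// Expected to resolve to one of the S32S8S8 SS atoms in this family (M = 64, K = 32 for int8).
using SelectedOp = decltype(cute::GMMA::ss_op_selector<
    int8_t, int8_t, int32_t, cute::Shape<cute::_64, cute::_64, cute::_32>>());

} // namespace example_selector
// ------------------------------------------------------------------------------------------------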
+//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x224x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x224x32_S32S8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x224x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x224x32_S32S8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x240x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x240x32_S32S8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x240x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x240x32_S32S8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x24x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x24x32_S32S8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x24x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x24x32_S32S8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + 
using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x48x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x48x32_S32S8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x48x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x48x32_S32S8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x80x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x80x32_S32S8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x80x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x80x32_S32S8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x112x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x112x32_S32S8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = 
GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x112x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x112x32_S32S8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x144x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x144x32_S32S8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x144x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x144x32_S32S8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x160x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x160x32_S32S8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x160x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x160x32_S32S8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
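// ------------------------------------------------------------------------------------------------
// Editor's note (illustrative compile-time checks, not part of the vendored header): the S32S8U8
// family multiplies signed int8 A by unsigned int8 B into int32 accumulators, and the *_SATURATE
// aliases select the saturating form of the instruction (results clamp to the int32 range instead
// of wrapping on overflow). std::is_same_v assumes <type_traits>, which the CuTe headers pull in.
namespace example_checks {

using TraitsS8U8 = cute::MMA_Traits<cute::SM90_64x160x32_S32S8U8_SS_TN>;
static_assert(std::is_same_v<TraitsS8U8::ValTypeA, int8_t>,  "A operand is signed int8");
static_assert(std::is_same_v<TraitsS8U8::ValTypeB, uint8_t>, "B operand is unsigned int8");
static_assert(std::is_same_v<TraitsS8U8::ValTypeD, int32_t>, "D accumulator is int32");

} // namespace example_checks
// ------------------------------------------------------------------------------------------------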
+//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x176x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x176x32_S32S8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x176x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x176x32_S32S8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x208x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x208x32_S32S8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x208x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x208x32_S32S8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x224x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x224x32_S32S8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x224x32_S32S8U8_SS_TN_SATURATE = 
SM90::GMMA::MMA_64x224x32_S32S8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x240x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x240x32_S32S8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x240x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x240x32_S32S8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x24x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x24x32_S32S8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x24x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x24x32_S32S8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x48x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x48x32_S32S8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; 
+ using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x48x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x48x32_S32S8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x80x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x80x32_S32S8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x80x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x80x32_S32S8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x112x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x112x32_S32S8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x112x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x112x32_S32S8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x144x32_S32S8U8_RS_TN = 
SM90::GMMA::MMA_64x144x32_S32S8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x144x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x144x32_S32S8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x160x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x160x32_S32S8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x160x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x160x32_S32S8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x176x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x176x32_S32S8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x176x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x176x32_S32S8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = 
GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x208x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x208x32_S32S8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x208x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x208x32_S32S8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x224x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x224x32_S32S8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x224x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x224x32_S32S8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x240x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x240x32_S32S8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x240x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x240x32_S32S8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = 
int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x24x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x24x32_S32U8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x24x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x24x32_S32U8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x48x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x48x32_S32U8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x48x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x48x32_S32U8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x80x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x80x32_S32U8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using 
ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x80x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x80x32_S32U8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x112x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x112x32_S32U8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x112x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x112x32_S32U8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x144x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x144x32_S32U8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x144x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x144x32_S32U8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; 
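// ------------------------------------------------------------------------------------------------
// Editor's sketch (not part of the vendored header): every trait above exposes a runtime
// accumulate_ flag. GMMA::ScaleOut::One computes D = C + A*B, while GMMA::ScaleOut::Zero computes
// D = A*B, so the first MMA of a tile can initialise the accumulators. The skeleton below mirrors
// the usual mainloop pattern; tCrA/tCrB/tCrC are placeholder names for already-partitioned
// fragments, and warpgroup synchronisation (arrive/commit/wait) is omitted for brevity.
namespace example_mainloop {

template <class TiledMma, class FrgA, class FrgB, class FrgC>
CUTE_HOST_DEVICE void
gemm_k_blocks(TiledMma tiled_mma, FrgA const& tCrA, FrgB const& tCrB, FrgC& tCrC)
{
  tiled_mma.accumulate_ = cute::GMMA::ScaleOut::Zero;   // first k-block overwrites the accumulators
  for (int k = 0; k < cute::size<2>(tCrA); ++k) {
    cute::gemm(tiled_mma, tCrA(cute::_, cute::_, k), tCrB(cute::_, cute::_, k), tCrC);
    tiled_mma.accumulate_ = cute::GMMA::ScaleOut::One;  // subsequent k-blocks accumulate into C
  }
}

} // namespace example_mainloop
// ------------------------------------------------------------------------------------------------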
+ +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x160x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x160x32_S32U8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x160x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x160x32_S32U8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x176x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x176x32_S32U8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x176x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x176x32_S32U8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x208x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x208x32_S32U8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x208x32_S32U8S8_SS_TN_SATURATE = 
SM90::GMMA::MMA_64x208x32_S32U8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x224x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x224x32_S32U8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x224x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x224x32_S32U8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x240x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x240x32_S32U8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x240x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x240x32_S32U8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x24x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x24x32_S32U8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = 
int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x24x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x24x32_S32U8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x48x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x48x32_S32U8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x48x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x48x32_S32U8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x80x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x80x32_S32U8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x80x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x80x32_S32U8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
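// Hedged sketch (illustration only): the FrgTypeA/FrgTypeB members above are GMMA
// shared-memory descriptor types; for these TN atoms, upstream CUTLASS uses K-major
// descriptors for both operands. Assuming that form (same header assumptions as in the
// sketch above), the descriptor types can be checked at compile time:
#include <type_traits>
#include <cute/tensor.hpp>
#include <cute/atom/mma_atom.hpp>
#include <cute/atom/mma_traits_sm90_gmma.hpp>

int main() {
  using namespace cute;

  using Traits = MMA_Traits<SM90_64x112x32_S32U8S8_SS_TN>;

  // Both operands of the SS form reach the warpgroup MMA through K-major smem descriptors.
  static_assert(std::is_same_v<typename Traits::FrgTypeA, GMMA::smem_desc<GMMA::Major::K>>);
  static_assert(std::is_same_v<typename Traits::FrgTypeB, GMMA::smem_desc<GMMA::Major::K>>);
  return 0;
}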
+//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x112x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x112x32_S32U8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x112x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x112x32_S32U8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x144x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x144x32_S32U8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x144x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x144x32_S32U8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x160x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x160x32_S32U8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x160x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x160x32_S32U8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; 
+ + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x176x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x176x32_S32U8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x176x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x176x32_S32U8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x208x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x208x32_S32U8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x208x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x208x32_S32U8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x224x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x224x32_S32U8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using 
SM90_64x224x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x224x32_S32U8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x240x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x240x32_S32U8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x240x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x240x32_S32U8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x24x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x24x32_S32U8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x24x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x24x32_S32U8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x48x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x48x32_S32U8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + 
using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x48x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x48x32_S32U8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x80x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x80x32_S32U8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x80x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x80x32_S32U8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x112x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x112x32_S32U8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x112x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x112x32_S32U8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = 
GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x144x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x144x32_S32U8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x144x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x144x32_S32U8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x160x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x160x32_S32U8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x160x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x160x32_S32U8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x176x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x176x32_S32U8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// 
+ + +using SM90_64x176x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x176x32_S32U8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x208x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x208x32_S32U8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x208x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x208x32_S32U8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x224x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x224x32_S32U8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x224x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x224x32_S32U8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x240x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x240x32_S32U8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA 
= uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x240x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x240x32_S32U8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x24x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x24x32_S32U8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x24x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x24x32_S32U8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x48x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x48x32_S32U8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x48x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x48x32_S32U8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + 
+ GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x80x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x80x32_S32U8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x80x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x80x32_S32U8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x112x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x112x32_S32U8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x112x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x112x32_S32U8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x144x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x144x32_S32U8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x144x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x144x32_S32U8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using 
ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x160x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x160x32_S32U8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x160x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x160x32_S32U8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x176x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x176x32_S32U8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x176x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x176x32_S32U8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x208x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x208x32_S32U8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
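// Hedged sketch (illustration only): the *_RS_* traits above differ from their *_SS_*
// counterparts only in how operand A is sourced — B still goes through a shared-memory
// descriptor (FrgTypeB), while A is held in registers, which shows up in the A layout.
// Same header assumptions as in the sketches above:
#include <type_traits>
#include <cute/tensor.hpp>
#include <cute/atom/mma_atom.hpp>
#include <cute/atom/mma_traits_sm90_gmma.hpp>

int main() {
  using namespace cute;

  using SS = MMA_Traits<SM90_64x208x32_S32U8U8_SS_TN>;
  using RS = MMA_Traits<SM90_64x208x32_S32U8U8_RS_TN>;

  // Same 64x208x32 tile on one 128-thread warpgroup in both cases...
  static_assert(std::is_same_v<typename SS::Shape_MNK, typename RS::Shape_MNK>);
  static_assert(size(typename RS::ThrID{}) == 128);

  // ...but only the SS form reads A through an smem descriptor; the RS form's A layout is
  // the per-thread register-fragment layout instead, so the two A layouts differ.
  static_assert(!std::is_same_v<typename SS::ALayout, typename RS::ALayout>);
  return 0;
}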
+//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x208x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x208x32_S32U8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x224x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x224x32_S32U8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x224x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x224x32_S32U8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x240x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x240x32_S32U8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x240x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x240x32_S32U8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x24x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x24x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using 
ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x24x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x24x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x24x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x24x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x24x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x24x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x40x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x40x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 40, 32>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + 
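// Hedged sketch (illustration only): from the FP8 (E4M3) atoms onward, both the op alias
// and its MMA_Traits carry two GMMA::ScaleIn parameters (per-operand negation, defaulting
// to ScaleIn::One), so a concrete instantiation names them explicitly. Same header
// assumptions as in the sketches above:
#include <type_traits>
#include <cute/tensor.hpp>
#include <cute/atom/mma_atom.hpp>
#include <cute/atom/mma_traits_sm90_gmma.hpp>

int main() {
  using namespace cute;

  // 64x24x32, E4M3 x E4M3 -> F32, with neither operand negated.
  using Op     = SM90_64x24x32_F32E4M3E4M3_SS_TN<GMMA::ScaleIn::One, GMMA::ScaleIn::One>;
  using Traits = MMA_Traits<Op>;

  static_assert(std::is_same_v<typename Traits::ValTypeA, float_e4m3_t>);
  static_assert(std::is_same_v<typename Traits::ValTypeD, float>);
  static_assert(size<1>(typename Traits::Shape_MNK{}) == 24);   // N = 24

  // The defaults keep the common case terse: SM90_64x24x32_F32E4M3E4M3_SS_TN<> denotes the
  // same op with ScaleIn::One for both operands.
  auto tiled_mma = make_tiled_mma(Op{});
  print(tiled_mma);
  return 0;
}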
+template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x40x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x40x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 40, 32>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x40x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x40x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 40, 32>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x40x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x40x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 40, 32>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x48x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x48x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = 
GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x48x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x48x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x56x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x56x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 56, 32>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x56x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x56x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 56, 32>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x56x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x56x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using 
ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 56, 32>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x56x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x56x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 56, 32>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x72x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x72x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 72, 32>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x72x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x72x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 72, 32>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x72x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x72x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 72, 32>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = 
GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x72x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x72x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 72, 32>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x80x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x80x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x80x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x80x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = 
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x88x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x88x32_F16E4M3E4M3_SS_TN<scaleA, scaleB>;
+
+template <GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x88x32_F16E4M3E4M3_SS_TN<scaleA, scaleB>>
+{
+  using ValTypeD = half_t;
+  using ValTypeA = float_e4m3_t;
+  using ValTypeB = float_e4m3_t;
+  using ValTypeC = half_t;
+
+  using FrgTypeA = GMMA::smem_desc<GMMA::Major::K>;
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_88,_32>;
+  using ThrID = Layout<_128>;
+  using ALayout = GMMA::ABLayout< 64, 32>;
+  using BLayout = GMMA::ABLayout< 88, 32>;
+  using CLayout = GMMA::CLayout_64x88;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x88x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x88x32_F16E4M3E4M3_RS_TN<scaleA, scaleB>;
+
+template <GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x88x32_F16E4M3E4M3_RS_TN<scaleA, scaleB>>
+{
+  using ValTypeD = half_t;
+  using ValTypeA = float_e4m3_t;
+  using ValTypeB = float_e4m3_t;
+  using ValTypeC = half_t;
+
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_88,_32>;
+  using ThrID = Layout<_128>;
+  using ALayout = GMMA::ALayout_64x32;
+  using BLayout = GMMA::ABLayout< 88, 32>;
+  using CLayout = GMMA::CLayout_64x88;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x88x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x88x32_F32E4M3E4M3_SS_TN<scaleA, scaleB>;
+
+template <GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x88x32_F32E4M3E4M3_SS_TN<scaleA, scaleB>>
+{
+  using ValTypeD = float;
+  using ValTypeA = float_e4m3_t;
+  using ValTypeB = float_e4m3_t;
+  using ValTypeC = float;
+
+  using FrgTypeA = GMMA::smem_desc<GMMA::Major::K>;
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_88,_32>;
+  using ThrID = Layout<_128>;
+  using ALayout = GMMA::ABLayout< 64, 32>;
+  using BLayout = GMMA::ABLayout< 88, 32>;
+  using CLayout = GMMA::CLayout_64x88;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x88x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x88x32_F32E4M3E4M3_RS_TN<scaleA, scaleB>;
+
+template <GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x88x32_F32E4M3E4M3_RS_TN<scaleA, scaleB>>
+{
+  using ValTypeD = float;
+  using ValTypeA = float_e4m3_t;
+  using ValTypeB = float_e4m3_t;
+  using ValTypeC = float;
+
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_88,_32>;
+  using ThrID = Layout<_128>;
+  using ALayout = GMMA::ALayout_64x32;
+  using BLayout = GMMA::ABLayout< 88, 32>;
+  using CLayout = GMMA::CLayout_64x88;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x104x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x104x32_F16E4M3E4M3_SS_TN<scaleA, scaleB>;
+
+template <GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x104x32_F16E4M3E4M3_SS_TN<scaleA, scaleB>>
+{
+  using ValTypeD = half_t;
+  using ValTypeA = float_e4m3_t;
+  using ValTypeB = float_e4m3_t;
+  using ValTypeC = 
half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<104, 32>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x104x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x104x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<104, 32>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x104x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x104x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<104, 32>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x104x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x104x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<104, 32>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x112x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + 
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x112x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x112x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x112x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x120x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x120x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<120, 32>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x120x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x120x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<120, 32>; + using CLayout = 
GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x120x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x120x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<120, 32>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x120x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x120x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<120, 32>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x136x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x136x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<136, 32>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x136x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x136x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<136, 32>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x136x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x136x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + 
using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<136, 32>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x136x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x136x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<136, 32>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x144x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x144x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x144x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = 
GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x144x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x152x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x152x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<152, 32>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x152x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x152x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<152, 32>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x152x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x152x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<152, 32>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x152x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x152x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<152, 32>; + 
using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x160x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x160x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x160x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x160x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x168x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x168x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = 
float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<168, 32>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x168x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x168x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<168, 32>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x168x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x168x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<168, 32>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x168x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x168x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<168, 32>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x176x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn 
scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x176x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x176x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x176x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x184x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x184x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<184, 32>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x184x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x184x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<184, 
32>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x184x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x184x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<184, 32>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x184x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x184x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<184, 32>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x200x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x200x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<200, 32>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x200x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x200x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<200, 32>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x200x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x200x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using 
ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<200, 32>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x200x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x200x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<200, 32>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x208x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x208x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x208x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + 
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x208x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x216x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x216x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<216, 32>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x216x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x216x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<216, 32>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x216x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x216x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<216, 32>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x216x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x216x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = 
GMMA::ABLayout<216, 32>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x224x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x224x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x224x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x224x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x232x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x232x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = 
float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<232, 32>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x232x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x232x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<232, 32>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x232x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x232x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<232, 32>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x232x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x232x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<232, 32>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x240x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x240x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x240x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x240x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x248x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x248x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<248, 32>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x248x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x248x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using 
Shape_MNK = Shape<_64,_248,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<248, 32>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x248x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x248x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<248, 32>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x248x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x248x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<248, 32>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x24x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x24x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x24x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x24x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x24x32_F32E4M3E5M2_SS_TN = 
SM90::GMMA::MMA_64x24x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x24x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x24x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x40x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x40x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 40, 32>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x40x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x40x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 40, 32>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x40x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x40x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 40, 32>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; 
+}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x40x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x40x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 40, 32>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x48x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x48x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x48x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x48x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = 
Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x56x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x56x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 56, 32>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x56x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x56x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 56, 32>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x56x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x56x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 56, 32>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x56x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x56x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 56, 32>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x72x32_F16E4M3E5M2_SS_TN = 
SM90::GMMA::MMA_64x72x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 72, 32>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x72x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x72x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 72, 32>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x72x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x72x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 72, 32>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x72x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x72x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 72, 32>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x80x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = 
GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x80x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x80x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x80x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x88x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x88x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 88, 32>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x88x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x88x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using 
Shape_MNK = Shape<_64,_88,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 88, 32>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x88x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x88x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 88, 32>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x88x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x88x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 88, 32>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x104x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x104x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<104, 32>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x104x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x104x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<104, 32>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x104x32_F32E4M3E5M2_SS_TN = 
SM90::GMMA::MMA_64x104x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<104, 32>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x104x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x104x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<104, 32>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x112x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x112x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x112x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = 
GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x112x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x120x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x120x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<120, 32>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x120x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x120x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<120, 32>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x120x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x120x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<120, 32>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x120x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x120x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = 
GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<120, 32>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x136x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x136x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<136, 32>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x136x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x136x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<136, 32>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x136x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x136x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<136, 32>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x136x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x136x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<136, 32>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using 
SM90_64x144x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x144x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x144x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x144x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x144x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x152x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x152x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<152, 32>; + using CLayout = 
GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x152x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x152x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<152, 32>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x152x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x152x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<152, 32>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x152x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x152x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<152, 32>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x160x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x160x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + 
using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x160x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x160x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x168x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x168x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<168, 32>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x168x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x168x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<168, 32>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = 
GMMA::ScaleIn::One +> +using SM90_64x168x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x168x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<168, 32>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x168x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x168x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<168, 32>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x176x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x176x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x176x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<176, 32>; + using 
CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x176x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x184x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x184x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<184, 32>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x184x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x184x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<184, 32>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x184x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x184x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<184, 32>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x184x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x184x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = 
float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<184, 32>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x200x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x200x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<200, 32>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x200x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x200x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<200, 32>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x200x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x200x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<200, 32>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x200x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x200x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<200, 32>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + 
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x208x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x208x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x208x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x208x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x216x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x216x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = 
GMMA::ABLayout<216, 32>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x216x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x216x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<216, 32>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x216x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x216x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<216, 32>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x216x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x216x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<216, 32>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x224x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x224x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = 
float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x224x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x224x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x232x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x232x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<232, 32>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x232x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x232x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<232, 32>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA 
= GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x232x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x232x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<232, 32>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x232x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x232x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<232, 32>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x240x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x240x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x240x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using 
BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x240x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x248x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x248x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<248, 32>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x248x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x248x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<248, 32>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x248x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x248x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<248, 32>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x248x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x248x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using 
ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<248, 32>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x24x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x24x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x24x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x24x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x24x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x24x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x24x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x24x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = 
GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x40x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x40x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 40, 32>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x40x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x40x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 40, 32>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x40x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x40x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 40, 32>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x40x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x40x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 40, 32>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x48x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = 
GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x48x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x48x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x48x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x56x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x56x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 56, 32>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x56x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x56x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using 
ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 56, 32>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x56x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x56x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 56, 32>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x56x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x56x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 56, 32>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x72x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x72x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 72, 32>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x72x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x72x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 72, 32>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn 
scaleB = GMMA::ScaleIn::One +> +using SM90_64x72x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x72x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 72, 32>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x72x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x72x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 72, 32>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x80x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x80x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x80x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = 
GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x80x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x88x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x88x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 88, 32>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x88x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x88x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 88, 32>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x88x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x88x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 88, 32>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x88x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x88x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = 
float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 88, 32>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x104x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x104x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<104, 32>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x104x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x104x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<104, 32>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x104x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x104x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<104, 32>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x104x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x104x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<104, 32>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> 
+using SM90_64x112x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x112x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x112x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x112x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x112x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x120x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x120x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<120, 32>; + using CLayout = 
GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x120x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x120x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<120, 32>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x120x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x120x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<120, 32>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x120x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x120x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<120, 32>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x136x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x136x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<136, 32>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x136x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x136x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + 
using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<136, 32>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x136x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x136x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<136, 32>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x136x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x136x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<136, 32>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x144x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x144x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = 
GMMA::ScaleIn::One +> +using SM90_64x144x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x144x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x144x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x152x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x152x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<152, 32>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x152x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x152x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<152, 32>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x152x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x152x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<152, 32>; + using 
CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x152x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x152x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<152, 32>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x160x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x160x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x160x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x160x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = 
float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x168x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x168x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<168, 32>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x168x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x168x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<168, 32>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x168x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x168x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<168, 32>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x168x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x168x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<168, 32>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + 
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x176x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x176x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x176x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x176x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x184x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x184x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = 
GMMA::ABLayout<184, 32>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x184x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x184x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<184, 32>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x184x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x184x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<184, 32>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x184x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x184x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<184, 32>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x200x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x200x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<200, 32>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x200x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x200x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = 
float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<200, 32>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x200x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x200x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<200, 32>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x200x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x200x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<200, 32>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x208x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x208x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA 
= GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x208x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x208x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x216x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x216x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<216, 32>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x216x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x216x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<216, 32>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x216x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x216x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using 
BLayout = GMMA::ABLayout<216, 32>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x216x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x216x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<216, 32>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x224x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x224x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x224x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x224x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using 
ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x232x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x232x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<232, 32>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x232x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x232x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<232, 32>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x232x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x232x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<232, 32>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x232x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x232x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<232, 32>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + 
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x240x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x240x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x240x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x240x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x248x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x248x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_32>; + using ThrID = Layout<_128>; + using ALayout = 
GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<248, 32>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x248x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x248x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<248, 32>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x248x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x248x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<248, 32>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x248x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x248x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<248, 32>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x24x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x24x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x24x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x24x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using 
ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x24x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x24x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x24x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x24x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x40x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x40x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 40, 32>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x40x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x40x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 40, 32>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + 
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x40x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x40x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 40, 32>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x40x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x40x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 40, 32>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x48x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x48x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x48x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + 
using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x48x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x56x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x56x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 56, 32>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x56x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x56x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 56, 32>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x56x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x56x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 56, 32>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x56x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x56x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = 
float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 56, 32>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x72x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x72x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 72, 32>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x72x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x72x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 72, 32>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x72x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x72x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 72, 32>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x72x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x72x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 72, 32>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = 
GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x80x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x80x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x80x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x80x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x88x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x88x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = 
GMMA::ABLayout< 88, 32>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x88x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x88x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 88, 32>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x88x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x88x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 88, 32>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x88x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x88x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 88, 32>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x104x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x104x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<104, 32>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x104x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x104x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + 
using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<104, 32>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x104x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x104x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<104, 32>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x104x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x104x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<104, 32>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x112x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x112x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = 
GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x112x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x112x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x120x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x120x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<120, 32>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x120x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x120x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<120, 32>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x120x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x120x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using 
BLayout = GMMA::ABLayout<120, 32>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x120x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x120x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<120, 32>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x136x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x136x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<136, 32>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x136x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x136x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<136, 32>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x136x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x136x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<136, 32>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x136x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x136x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using 
ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<136, 32>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x144x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x144x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x144x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x144x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + 
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x152x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x152x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<152, 32>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x152x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x152x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<152, 32>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x152x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x152x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<152, 32>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x152x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x152x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<152, 32>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x160x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = 
GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x160x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x160x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x160x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x168x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x168x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<168, 32>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x168x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x168x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using 
ValTypeD = half_t;
+  using ValTypeA = float_e5m2_t;
+  using ValTypeB = float_e5m2_t;
+  using ValTypeC = half_t;
+
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_168,_32>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ALayout_64x32;
+  using BLayout = GMMA::ABLayout<168, 32>;
+  using CLayout = GMMA::CLayout_64x168;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::ScaleIn  scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn  scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x168x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x168x32_F32E5M2E5M2_SS_TN<scaleA, scaleB>;
+
+template <GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x168x32_F32E5M2E5M2_SS_TN<scaleA, scaleB>>
+{
+  using ValTypeD = float;
+  using ValTypeA = float_e5m2_t;
+  using ValTypeB = float_e5m2_t;
+  using ValTypeC = float;
+
+  using FrgTypeA = GMMA::smem_desc<GMMA::Major::K>;
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_168,_32>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ABLayout< 64, 32>;
+  using BLayout = GMMA::ABLayout<168, 32>;
+  using CLayout = GMMA::CLayout_64x168;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::ScaleIn  scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn  scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x168x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x168x32_F32E5M2E5M2_RS_TN<scaleA, scaleB>;
+
+template <GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x168x32_F32E5M2E5M2_RS_TN<scaleA, scaleB>>
+{
+  using ValTypeD = float;
+  using ValTypeA = float_e5m2_t;
+  using ValTypeB = float_e5m2_t;
+  using ValTypeC = float;
+
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_168,_32>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ALayout_64x32;
+  using BLayout = GMMA::ABLayout<168, 32>;
+  using CLayout = GMMA::CLayout_64x168;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::ScaleIn  scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn  scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x176x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x176x32_F16E5M2E5M2_SS_TN<scaleA, scaleB>;
+
+template <GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x176x32_F16E5M2E5M2_SS_TN<scaleA, scaleB>>
+{
+  using ValTypeD = half_t;
+  using ValTypeA = float_e5m2_t;
+  using ValTypeB = float_e5m2_t;
+  using ValTypeC = half_t;
+
+  using FrgTypeA = GMMA::smem_desc<GMMA::Major::K>;
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_176,_32>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ABLayout< 64, 32>;
+  using BLayout = GMMA::ABLayout<176, 32>;
+  using CLayout = GMMA::CLayout_64x176;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::ScaleIn  scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn  scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x176x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x176x32_F16E5M2E5M2_RS_TN<scaleA, scaleB>;
+
+template <GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x176x32_F16E5M2E5M2_RS_TN<scaleA, scaleB>>
+{
+  using ValTypeD = half_t;
+  using ValTypeA = float_e5m2_t;
+  using ValTypeB = float_e5m2_t;
+  using ValTypeC = half_t;
+
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_176,_32>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ALayout_64x32;
+  using BLayout = GMMA::ABLayout<176, 32>;
+  using CLayout = GMMA::CLayout_64x176;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+ +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x176x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x176x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x184x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x184x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<184, 32>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x184x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x184x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<184, 32>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x184x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x184x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_32>; + using ThrID = Layout<_128>; + using 
ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<184, 32>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x184x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x184x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<184, 32>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x200x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x200x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<200, 32>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x200x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x200x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<200, 32>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x200x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x200x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<200, 32>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x200x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x200x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> 
+{
+  using ValTypeD = float;
+  using ValTypeA = float_e5m2_t;
+  using ValTypeB = float_e5m2_t;
+  using ValTypeC = float;
+
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_200,_32>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ALayout_64x32;
+  using BLayout = GMMA::ABLayout<200, 32>;
+  using CLayout = GMMA::CLayout_64x200;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::ScaleIn  scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn  scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x208x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x208x32_F16E5M2E5M2_SS_TN<scaleA, scaleB>;
+
+template <GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x208x32_F16E5M2E5M2_SS_TN<scaleA, scaleB>>
+{
+  using ValTypeD = half_t;
+  using ValTypeA = float_e5m2_t;
+  using ValTypeB = float_e5m2_t;
+  using ValTypeC = half_t;
+
+  using FrgTypeA = GMMA::smem_desc<GMMA::Major::K>;
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_208,_32>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ABLayout< 64, 32>;
+  using BLayout = GMMA::ABLayout<208, 32>;
+  using CLayout = GMMA::CLayout_64x208;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::ScaleIn  scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn  scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x208x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x208x32_F16E5M2E5M2_RS_TN<scaleA, scaleB>;
+
+template <GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x208x32_F16E5M2E5M2_RS_TN<scaleA, scaleB>>
+{
+  using ValTypeD = half_t;
+  using ValTypeA = float_e5m2_t;
+  using ValTypeB = float_e5m2_t;
+  using ValTypeC = half_t;
+
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_208,_32>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ALayout_64x32;
+  using BLayout = GMMA::ABLayout<208, 32>;
+  using CLayout = GMMA::CLayout_64x208;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::ScaleIn  scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn  scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x208x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x208x32_F32E5M2E5M2_SS_TN<scaleA, scaleB>;
+
+template <GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x208x32_F32E5M2E5M2_SS_TN<scaleA, scaleB>>
+{
+  using ValTypeD = float;
+  using ValTypeA = float_e5m2_t;
+  using ValTypeB = float_e5m2_t;
+  using ValTypeC = float;
+
+  using FrgTypeA = GMMA::smem_desc<GMMA::Major::K>;
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_208,_32>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ABLayout< 64, 32>;
+  using BLayout = GMMA::ABLayout<208, 32>;
+  using CLayout = GMMA::CLayout_64x208;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::ScaleIn  scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn  scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x208x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x208x32_F32E5M2E5M2_RS_TN<scaleA, scaleB>;
+
+template <GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x208x32_F32E5M2E5M2_RS_TN<scaleA, scaleB>>
+{
+  using ValTypeD = float;
+  using ValTypeA = float_e5m2_t;
+  using ValTypeB = float_e5m2_t;
+  using ValTypeC = float;
+
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_208,_32>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ALayout_64x32;
+  using BLayout = GMMA::ABLayout<208, 32>;
+  using CLayout = GMMA::CLayout_64x208;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x216x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x216x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<216, 32>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x216x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x216x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<216, 32>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x216x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x216x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<216, 32>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x216x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x216x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<216, 32>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x224x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = 
GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x224x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x224x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x224x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x232x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x232x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<232, 32>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using 
SM90_64x232x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x232x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<232, 32>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x232x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x232x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<232, 32>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x232x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x232x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<232, 32>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x240x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x240x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = 
GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x240x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x240x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x248x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x248x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<248, 32>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x248x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x248x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<248, 32>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x248x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x248x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = 
GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<248, 32>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x248x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x248x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<248, 32>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // end namespace cute diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/atom/mma_traits_sm90_gmma_sparse.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/atom/mma_traits_sm90_gmma_sparse.hpp new file mode 100644 index 0000000000000000000000000000000000000000..13ff07c89f47a15bb65f57be2ad999df1f6568c9 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/atom/mma_traits_sm90_gmma_sparse.hpp @@ -0,0 +1,7738 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once + +#include // cute::smem_sparse_ptr_flag +#include // cute::Swizzle +#include // cute::Tensor +#include // cute::LayoutType +#include // cute::SM90::SPARSE::GMMA_64x8x32_F16F16F16_SS, etc +#include // cute::GMMA::Layout_* +#include // cute::MMA_Traits +#include // cute::ComposedLayout +#include // cute::is_static + +namespace cute { + +namespace SM90::GMMA { + +/////////////////////////////////////////// +// Common layouts for GMMA Shared Memory // +/////////////////////////////////////////// + +// M|N-major layouts in units of Type and sparsity factor S +template +using Layout_MN_INTER_SpAtom = ComposedLayout, smem_sparse_ptr_flag_bits>, + decltype(blocked_product(Layout>>{}, Layout_MN_INTER_Atom{}.layout_b()))>; +template +using Layout_MN_SW32_SpAtom = ComposedLayout, smem_sparse_ptr_flag_bits>, + decltype(blocked_product(Layout>>{}, Layout_MN_SW32_Atom{}.layout_b()))>; +template +using Layout_MN_SW64_SpAtom = ComposedLayout, smem_sparse_ptr_flag_bits>, + decltype(blocked_product(Layout>>{}, Layout_MN_SW64_Atom{}.layout_b()))>; +template +using Layout_MN_SW128_SpAtom = ComposedLayout, smem_sparse_ptr_flag_bits>, + decltype(blocked_product(Layout>>{}, Layout_MN_SW128_Atom{}.layout_b()))>; + +// K-major layouts in units of Type and sparsity factor S +template +using Layout_K_INTER_SpAtom = ComposedLayout, smem_sparse_ptr_flag_bits>, + decltype(blocked_product(Layout>>{}, Layout_K_INTER_Atom{}.layout_b()))>; +template +using Layout_K_SW32_SpAtom = ComposedLayout, smem_sparse_ptr_flag_bits>, + decltype(blocked_product(Layout>>{}, Layout_K_SW32_Atom{}.layout_b()))>; +template +using Layout_K_SW64_SpAtom = ComposedLayout, smem_sparse_ptr_flag_bits>, + decltype(blocked_product(Layout>>{}, Layout_K_SW64_Atom{}.layout_b()))>; +template +using Layout_K_SW128_SpAtom = ComposedLayout, smem_sparse_ptr_flag_bits>, + decltype(blocked_product(Layout>>{}, Layout_K_SW128_Atom{}.layout_b()))>; + +// With GMMA::Major param +template +using Layout_INTER_SpAtom = typename conditional, + Layout_K_INTER_SpAtom>::type; +template +using Layout_SW32_SpAtom = typename conditional, + Layout_K_SW32_SpAtom>::type; +template +using Layout_SW64_SpAtom = typename conditional, + Layout_K_SW64_SpAtom>::type; +template +using Layout_SW128_SpAtom = typename conditional, + Layout_K_SW128_SpAtom>::type; + +/////////////////////////////////////////////////////////////////////////////// +// Higher level GMMA Descriptor utilities +/////////////////////////////////////////////////////////////////////////////// + +template +struct sparse_smem_desc : DescriptorIterator {}; + +} // end namespace SM90::GMMA + +// Customization point for creating a cute::GMMAsparse_smem_desc Tensor +template +struct MakeTensor> +{ + // Note that this is the exact same as cute::GMMAsmem_desc above, plus additional static checks. 
+ template + CUTE_HOST_DEVICE constexpr auto + operator()(Tensor const& smem_tensor) + { + static_assert(is_smem::value, "Expected SMEM Tensor to construct a GMMA Desc Tensor"); + static_assert(is_sparse::value, "Expected sparse value_type."); + static_assert(is_sparse_ptr::value, "Expected sparse iter."); + return make_tensor(SM90::GMMA::DescriptorIterator{SM90::GMMA::make_gmma_desc(tensor<0>(smem_tensor))}, + replace<0>(recast(smem_tensor).layout(), Layout<_1,_0>{})); + } +}; + +/////////////////////////////////////////////////////////////////////////////// +//////////////////////////// MMA_TRAITS /////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +namespace SM90::GMMA { + +// Metadata layouts +using ELayout_64x64 = Layout, Shape <_32>>, + Stride, Stride<_64>>>; + +using ELayout_64x32 = Layout, Shape <_16,_2>>, + Stride, Stride<_64,_8>>>; + +using ELayout_64x16 = Layout, Shape < _8,_2>>, + Stride, Stride<_64,_8>>>; + +} // namespace SM90::GMMA + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace SM90::GMMA::SPARSE { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +CUTE_HOST_DEVICE constexpr void +mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A_zipped, + Tensor const& B, + Tensor const& C) +{ + static_assert(is_rmem_v, "Expected registers in MMA_Atom::call"); + static_assert(is_rmem_v, "Expected registers in MMA_Atom::call"); + static_assert(is_rmem_v, "Expected registers in MMA_Atom::call"); + static_assert(is_rmem_v, "Expected registers in MMA_Atom::call"); + + using DRegisters = typename MMAOp::DRegisters; + using ARegisters = typename MMAOp::ARegisters; + using ERegisters = typename MMAOp::ERegisters; + using BRegisters = typename MMAOp::BRegisters; + using CRegisters = typename MMAOp::CRegisters; + + // Register value types from the MMAOp register arrays + using RegTypeD = typename remove_extent::type; + using RegTypeA = typename remove_extent::type; + using RegTypeE = typename remove_extent::type; + using RegTypeB = typename remove_extent::type; + using RegTypeC = typename remove_extent::type; + + constexpr int RegNumA = extent::value; + constexpr int RegNumE = extent::value; + constexpr int RegNumB = extent::value; + constexpr int RegNumC = extent::value; + + auto [A, E] = unzip_tensor(A_zipped); + Tensor rA = recast(A); + Tensor rE = recast(E); + Tensor rB = recast(B); + + CUTE_STATIC_ASSERT_V(size(rA) == Int{}); + CUTE_STATIC_ASSERT_V(size(rE) == Int{}); + CUTE_STATIC_ASSERT_V(size(rB) == Int{}); + + static_assert(is_same::value, "GMMA DRegisters must have void type."); + static_assert(is_same::value, "GMMA C and D value_type must match."); + static_assert(is_same::value, "GMMA C and D layouts must match."); + + Tensor rC = recast(D); // NOTE: D and C are same, so use mutable D + + CUTE_STATIC_ASSERT_V(size(rC) == Int{}); + + detail::explode(MMAOp::fma, + rA, make_int_sequence{}, + rB, make_int_sequence{}, + rC, make_int_sequence{}, + rE, make_int_sequence{}, + &(traits.accumulate_), seq<0>{}); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace SM90::SPARSE + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = 
sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = 
GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using 
ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = 
GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = 
sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = 
GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE 
= sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_16>; + using ThrID = Layout<_128>; + using ALayout = 
GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 8, 16>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 8, 16>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 16, 16>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 16, 16>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 32, 16>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 32, 16>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 64, 16>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 64, 16>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 96, 16>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 96, 16>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<128, 16>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE 
= sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<128, 16>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<192, 16>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<192, 16>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<256, 16>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<256, 16>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 
64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = 
GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + 
using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = 
int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = 
Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + 
using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = 
GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = 
GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, 
int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = 
GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = 
sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using 
Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = 
GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = 
Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using 
FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = 
GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + 
GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + 
using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + 
using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + 
using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using 
BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
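These MMA_Traits specializations are what hand CuTe the compile-time metadata for each sparse GMMA atom: the value types (including the sparse_elem-wrapped A operand and metadata E), the smem fragment descriptors, Shape_MNK, the 128-thread ThrID, and the A/B/C/E layouts. As a purely illustrative aside (not part of the vendored header), the sketch below shows how such a traits struct is typically queried downstream; the sparse atom names are elided in this dump, so a dense SM90 GMMA op (SM90_64x64x16_F16F16F16_SS) stands in, and compilation with nvcc plus the CUTLASS include directory on the path is assumed.

// Hypothetical, standalone sketch -- not part of this diff.
// A dense SM90 GMMA atom stands in for the sparse ops whose names are elided above.
#include <cstdio>
#include <cute/tensor.hpp>
#include <cute/atom/mma_atom.hpp>

int main() {
  using namespace cute;

  using Atom   = SM90_64x64x16_F16F16F16_SS<GMMA::Major::K, GMMA::Major::K>;
  using Traits = MMA_Traits<Atom>;

  // Shape_MNK and ThrID are read straight off the traits struct -- the same role
  // they play in the sparse specializations above (e.g. Shape<_64,_128,_64> and
  // Layout<_128> for one warpgroup of 128 threads).
  typename Traits::Shape_MNK mnk{};
  printf("atom tile: M=%d N=%d K=%d, threads=%d\n",
         int(size<0>(mnk)), int(size<1>(mnk)), int(size<2>(mnk)),
         int(size(typename Traits::ThrID{})));

  // A TiledMMA built from the atom then uses ALayout/BLayout/CLayout (plus ELayout
  // for the sparse variants) to decide which thread owns which fragment of A, B, C (and E).
  auto tiled_mma = make_tiled_mma(Atom{});
  (void)tiled_mma;
  return 0;
}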
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = 
sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using 
ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = 
GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, 
float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; 
+ using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut 
accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, 
float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + 
using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + 
GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using 
ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using 
Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = 
GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; 
+ using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using 
Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = 
GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD 
= half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = 
GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + 
using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using 
ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = 
GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = 
GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + 
+template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // end namespace cute + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +#include "mma_traits_sm90_gmma_sparse_ext.hpp" +#endif diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/atom/mma_traits_sm90_gmma_sparse_ext.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/atom/mma_traits_sm90_gmma_sparse_ext.hpp new file mode 100644 index 0000000000000000000000000000000000000000..fc28b8a97a77e5e082effd1cc994f25465b0a3e6 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/atom/mma_traits_sm90_gmma_sparse_ext.hpp @@ -0,0 +1,17335 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include +#include + +namespace cute { + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 40, 32>; + using CLayout = GMMA::CLayout_64x40; + 
+ GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 40, 32>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 56, 32>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 56, 32>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, 
uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 72, 32>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 72, 32>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 88, 32>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using 
BLayout = GMMA::ABLayout< 88, 32>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<104, 32>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<104, 32>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<120, 32>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template 
+struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<120, 32>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<136, 32>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<136, 32>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = 
Shape<_64,_152,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<152, 32>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<152, 32>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<168, 32>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<168, 32>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<184, 32>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<184, 32>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<200, 32>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, 
uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<200, 32>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<216, 32>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<216, 32>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = 
GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<232, 32>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<232, 32>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct 
MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<248, 32>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<248, 32>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 40, 32>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_32>; 
+ using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 40, 32>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 56, 32>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 56, 32>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 72, 32>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 72, 32>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 88, 32>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 88, 32>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + 
using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<104, 32>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<104, 32>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<120, 32>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<120, 32>; + using CLayout = 
GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<136, 32>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<136, 32>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<152, 32>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using 
ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<152, 32>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<168, 32>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<168, 32>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = 
GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<184, 32>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<184, 32>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<200, 32>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<200, 32>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<216, 32>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<216, 32>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using 
ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<232, 32>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<232, 32>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = 
GMMA::ABLayout<248, 32>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<248, 32>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 40, 32>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 40, 32>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using 
ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 56, 32>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 56, 32>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 72, 32>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = 
Shape<_64,_72,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 72, 32>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 88, 32>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 88, 32>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<104, 32>; + using CLayout = GMMA::CLayout_64x104; + + 
GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<104, 32>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<120, 32>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<120, 32>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, 
bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<136, 32>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<136, 32>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<152, 32>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_32>; + using ThrID = 
Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<152, 32>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<168, 32>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<168, 32>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = 
GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<184, 32>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<184, 32>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<200, 32>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<200, 32>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = 
sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<216, 32>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<216, 32>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = 
GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<232, 32>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<232, 32>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<248, 32>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<248, 32>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 24, 16>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 24, 16>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 40, 16>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 40, 16>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using 
ValTypeB = tfloat32_t;
+  using ValTypeC = float;
+
+  using FrgTypeA = GMMA::smem_desc;
+  using FrgTypeB = GMMA::smem_desc;
+
+  using Shape_MNK = Shape<_64,_48,_16>;
+  using ThrID = Layout<_128>;
+  using ALayout = GMMA::ABLayout< 64, 16>;
+  using ELayout = GMMA::ELayout_64x16;
+  using BLayout = GMMA::ABLayout< 48, 16>;
+  using CLayout = GMMA::CLayout_64x48;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template
+struct MMA_Traits>
+{
+  using ValTypeD = float;
+  using ValTypeA = sparse_elem<2, tfloat32_t>;
+  using ValTypeE = sparse_elem<4, uint8_t>;
+  using ValTypeB = tfloat32_t;
+  using ValTypeC = float;
+
+  using FrgTypeB = GMMA::smem_desc;
+
+  using Shape_MNK = Shape<_64,_48,_16>;
+  using ThrID = Layout<_128>;
+  using ALayout = GMMA::ALayout_64x16;
+  using ELayout = GMMA::ELayout_64x16;
+  using BLayout = GMMA::ABLayout< 48, 16>;
+  using CLayout = GMMA::CLayout_64x48;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template
+struct MMA_Traits>
+{
+  using ValTypeD = float;
+  using ValTypeA = sparse_elem<2, tfloat32_t>;
+  using ValTypeE = sparse_elem<4, uint8_t>;
+  using ValTypeB = tfloat32_t;
+  using ValTypeC = float;
+
+  using FrgTypeA = GMMA::smem_desc;
+  using FrgTypeB = GMMA::smem_desc;
+
+  using Shape_MNK = Shape<_64,_56,_16>;
+  using ThrID = Layout<_128>;
+  using ALayout = GMMA::ABLayout< 64, 16>;
+  using ELayout = GMMA::ELayout_64x16;
+  using BLayout = GMMA::ABLayout< 56, 16>;
+  using CLayout = GMMA::CLayout_64x56;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template
+struct MMA_Traits>
+{
+  using ValTypeD = float;
+  using ValTypeA = sparse_elem<2, tfloat32_t>;
+  using ValTypeE = sparse_elem<4, uint8_t>;
+  using ValTypeB = tfloat32_t;
+  using ValTypeC = float;
+
+  using FrgTypeB = GMMA::smem_desc;
+
+  using Shape_MNK = Shape<_64,_56,_16>;
+  using ThrID = Layout<_128>;
+  using ALayout = GMMA::ALayout_64x16;
+  using ELayout = GMMA::ELayout_64x16;
+  using BLayout = GMMA::ABLayout< 56, 16>;
+  using CLayout = GMMA::CLayout_64x56;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template
+struct MMA_Traits>
+{
+  using ValTypeD = float;
+  using ValTypeA = sparse_elem<2, tfloat32_t>;
+  using ValTypeE = sparse_elem<4, uint8_t>;
+  using ValTypeB = tfloat32_t;
+  using ValTypeC = float;
+
+  using FrgTypeA = GMMA::smem_desc;
+  using FrgTypeB = GMMA::smem_desc;
+
+  using Shape_MNK = Shape<_64,_72,_16>;
+  using ThrID = Layout<_128>;
+  using ALayout = GMMA::ABLayout< 64, 16>;
+  using ELayout = GMMA::ELayout_64x16;
+  using BLayout = GMMA::ABLayout< 72, 16>;
+  using CLayout = GMMA::CLayout_64x72;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template
+struct MMA_Traits>
+{
+  using ValTypeD = float;
+  using ValTypeA = sparse_elem<2, tfloat32_t>;
+  using ValTypeE = sparse_elem<4, uint8_t>;
+  using ValTypeB = tfloat32_t;
+  using ValTypeC = float;
+
+  using FrgTypeB = GMMA::smem_desc;
+
+  using Shape_MNK = Shape<_64,_72,_16>;
+  using ThrID = Layout<_128>;
+  using ALayout = GMMA::ALayout_64x16;
+  using ELayout = GMMA::ELayout_64x16;
+  using BLayout = GMMA::ABLayout< 72, 16>;
+  using CLayout = GMMA::CLayout_64x72;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template
+struct MMA_Traits>
+{
+  using ValTypeD = float;
+  using ValTypeA = sparse_elem<2, tfloat32_t>;
+  using ValTypeE = sparse_elem<4, uint8_t>;
+  using ValTypeB = tfloat32_t;
+  using ValTypeC = float;
+
+  using FrgTypeA = GMMA::smem_desc;
+  using FrgTypeB = GMMA::smem_desc;
+
+  using Shape_MNK = Shape<_64,_80,_16>;
+  using ThrID = Layout<_128>;
+  using ALayout = GMMA::ABLayout< 64, 16>;
+  using ELayout = GMMA::ELayout_64x16;
+  using BLayout = GMMA::ABLayout< 80, 16>;
+  using CLayout = GMMA::CLayout_64x80;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template
+struct MMA_Traits>
+{
+  using ValTypeD = float;
+  using ValTypeA = sparse_elem<2, tfloat32_t>;
+  using ValTypeE = sparse_elem<4, uint8_t>;
+  using ValTypeB = tfloat32_t;
+  using ValTypeC = float;
+
+  using FrgTypeB = GMMA::smem_desc;
+
+  using Shape_MNK = Shape<_64,_80,_16>;
+  using ThrID = Layout<_128>;
+  using ALayout = GMMA::ALayout_64x16;
+  using ELayout = GMMA::ELayout_64x16;
+  using BLayout = GMMA::ABLayout< 80, 16>;
+  using CLayout = GMMA::CLayout_64x80;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template
+struct MMA_Traits>
+{
+  using ValTypeD = float;
+  using ValTypeA = sparse_elem<2, tfloat32_t>;
+  using ValTypeE = sparse_elem<4, uint8_t>;
+  using ValTypeB = tfloat32_t;
+  using ValTypeC = float;
+
+  using FrgTypeA = GMMA::smem_desc;
+  using FrgTypeB = GMMA::smem_desc;
+
+  using Shape_MNK = Shape<_64,_88,_16>;
+  using ThrID = Layout<_128>;
+  using ALayout = GMMA::ABLayout< 64, 16>;
+  using ELayout = GMMA::ELayout_64x16;
+  using BLayout = GMMA::ABLayout< 88, 16>;
+  using CLayout = GMMA::CLayout_64x88;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template
+struct MMA_Traits>
+{
+  using ValTypeD = float;
+  using ValTypeA = sparse_elem<2, tfloat32_t>;
+  using ValTypeE = sparse_elem<4, uint8_t>;
+  using ValTypeB = tfloat32_t;
+  using ValTypeC = float;
+
+  using FrgTypeB = GMMA::smem_desc;
+
+  using Shape_MNK = Shape<_64,_88,_16>;
+  using ThrID = Layout<_128>;
+  using ALayout = GMMA::ALayout_64x16;
+  using ELayout = GMMA::ELayout_64x16;
+  using BLayout = GMMA::ABLayout< 88, 16>;
+  using CLayout = GMMA::CLayout_64x88;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template
+struct MMA_Traits>
+{
+  using ValTypeD = float;
+  using ValTypeA = sparse_elem<2, tfloat32_t>;
+  using ValTypeE = sparse_elem<4, uint8_t>;
+  using ValTypeB = tfloat32_t;
+  using ValTypeC = float;
+
+  using FrgTypeA = GMMA::smem_desc;
+  using FrgTypeB = GMMA::smem_desc;
+
+  using Shape_MNK = Shape<_64,_104,_16>;
+  using ThrID = Layout<_128>;
+  using ALayout = GMMA::ABLayout< 64, 16>;
+  using ELayout = GMMA::ELayout_64x16;
+  using BLayout = GMMA::ABLayout<104, 16>;
+  using CLayout = GMMA::CLayout_64x104;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template
+struct MMA_Traits>
+{
+  using ValTypeD = float;
+  using ValTypeA = sparse_elem<2, tfloat32_t>;
+  using ValTypeE = sparse_elem<4, uint8_t>;
+  using ValTypeB = tfloat32_t;
+  using ValTypeC = float;
+
+  using FrgTypeB = GMMA::smem_desc;
+
+  using Shape_MNK = Shape<_64,_104,_16>;
+  using ThrID = Layout<_128>;
+  using ALayout = GMMA::ALayout_64x16;
+  using ELayout = GMMA::ELayout_64x16;
+  using BLayout = GMMA::ABLayout<104, 16>;
+  using CLayout = GMMA::CLayout_64x104;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template
+struct MMA_Traits>
+{
+  using ValTypeD = float;
+  using ValTypeA = sparse_elem<2, tfloat32_t>;
+  using ValTypeE = sparse_elem<4, uint8_t>;
+  using ValTypeB = tfloat32_t;
+  using ValTypeC = float;
+
+  using FrgTypeA = GMMA::smem_desc;
+  using FrgTypeB = GMMA::smem_desc;
+
+  using Shape_MNK = Shape<_64,_112,_16>;
+  using ThrID = Layout<_128>;
+  using ALayout = GMMA::ABLayout< 64, 16>;
+  using ELayout = GMMA::ELayout_64x16;
+  using BLayout = GMMA::ABLayout<112, 16>;
+  using CLayout = GMMA::CLayout_64x112;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template
+struct MMA_Traits>
+{
+  using ValTypeD = float;
+  using ValTypeA = sparse_elem<2, tfloat32_t>;
+  using ValTypeE = sparse_elem<4, uint8_t>;
+  using ValTypeB = tfloat32_t;
+  using ValTypeC = float;
+
+  using FrgTypeB = GMMA::smem_desc;
+
+  using Shape_MNK = Shape<_64,_112,_16>;
+  using ThrID = Layout<_128>;
+  using ALayout = GMMA::ALayout_64x16;
+  using ELayout = GMMA::ELayout_64x16;
+  using BLayout = GMMA::ABLayout<112, 16>;
+  using CLayout = GMMA::CLayout_64x112;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template
+struct MMA_Traits>
+{
+  using ValTypeD = float;
+  using ValTypeA = sparse_elem<2, tfloat32_t>;
+  using ValTypeE = sparse_elem<4, uint8_t>;
+  using ValTypeB = tfloat32_t;
+  using ValTypeC = float;
+
+  using FrgTypeA = GMMA::smem_desc;
+  using FrgTypeB = GMMA::smem_desc;
+
+  using Shape_MNK = Shape<_64,_120,_16>;
+  using ThrID = Layout<_128>;
+  using ALayout = GMMA::ABLayout< 64, 16>;
+  using ELayout = GMMA::ELayout_64x16;
+  using BLayout = GMMA::ABLayout<120, 16>;
+  using CLayout = GMMA::CLayout_64x120;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template
+struct MMA_Traits>
+{
+  using ValTypeD = float;
+  using ValTypeA = sparse_elem<2, tfloat32_t>;
+  using ValTypeE = sparse_elem<4, uint8_t>;
+  using ValTypeB = tfloat32_t;
+  using ValTypeC = float;
+
+  using FrgTypeB = GMMA::smem_desc;
+
+  using Shape_MNK = Shape<_64,_120,_16>;
+  using ThrID = Layout<_128>;
+  using ALayout = GMMA::ALayout_64x16;
+  using ELayout = GMMA::ELayout_64x16;
+  using BLayout = GMMA::ABLayout<120, 16>;
+  using CLayout = GMMA::CLayout_64x120;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template
+struct MMA_Traits>
+{
+  using ValTypeD = float;
+  using ValTypeA = sparse_elem<2, tfloat32_t>;
+  using ValTypeE = sparse_elem<4, uint8_t>;
using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<136, 16>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<136, 16>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<144, 16>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<144, 16>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<152, 16>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using 
ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<152, 16>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<160, 16>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<160, 16>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<168, 16>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<168, 16>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<176, 16>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<176, 16>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<184, 16>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<184, 16>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<200, 16>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<200, 16>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + 
using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<208, 16>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<208, 16>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<216, 16>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<216, 16>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<224, 16>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using 
ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<224, 16>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<232, 16>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<232, 16>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<240, 16>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<240, 16>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<248, 16>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<248, 16>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, 
int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB 
= GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = 
GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; 
+}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using 
Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using 
Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; 
+ +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = 
int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + 
+ using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using 
ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using 
Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + 
using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout 
= GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD 
= int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; 
+ + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + 
using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = 
uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + 
using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 
64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = 
uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = 
GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using 
FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
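The specializations above cover the warpgroup-level sparse integer (S32 x U8/S8) GMMA atoms; the FP8 (float_e4m3_t) specializations that follow repeat the same pattern. As a purely illustrative sketch, and an editorial addition rather than part of the diffed CUTLASS header, the snippet below shows roughly how one of these MMA_Traits specializations is consumed through CuTe. SparseGmmaOp and make_warpgroup_tiled_mma are hypothetical placeholders: SparseGmmaOp stands for whichever SM90 sparse GMMA operation a given specialization targets (the operation names were lost from the MMA_Traits<...> argument lists in this copy).

#include <cute/atom/mma_atom.hpp>   // cute::MMA_Traits, cute::make_tiled_mma

template <class SparseGmmaOp>
CUTE_HOST_DEVICE constexpr auto
make_warpgroup_tiled_mma()
{
  using Traits = cute::MMA_Traits<SparseGmmaOp>;
  // Every specialization in this header is warpgroup-wide: 128 threads
  // cooperate on one 64 x N x 64 sparse MMA, with the 2:4-compressed A
  // operand in ValTypeA and its sparsity metadata in ValTypeE.
  static_assert(cute::size(typename Traits::ThrID{}) == 128,
                "SM90 sparse GMMA atoms are executed by a 128-thread warpgroup");
  return cute::make_tiled_mma(SparseGmmaOp{});
}

A collective mainloop would then use the resulting TiledMMA to partition shared-memory tiles of A, B, and the metadata according to Traits::ALayout/BLayout/ELayout, and the accumulator according to Traits::CLayout.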
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 40, 64>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, 
float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 40, 64>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 40, 64>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 40, 64>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using 
ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 56, 64>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 56, 64>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 56, 64>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 56, 64>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
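One pattern worth calling out, since the operation names are missing from the MMA_Traits<...> argument lists in this copy: the two flavours of each shape can still be told apart from the members themselves. The following comment block is an editorial summary, not part of the original header, and the SS/RS naming is inferred from the corresponding dense GMMA traits.

// Presumed *_SS_* variants: both operands are read from shared memory, so the
//   traits define FrgTypeA = GMMA::smem_desc<...> and describe A with
//   ALayout = GMMA::ABLayout<64, 64>.
// Presumed *_RS_* variants: A is held in registers, so FrgTypeA is omitted and
//   ALayout = GMMA::ALayout_64x64 describes the register fragment instead;
//   only FrgTypeB = GMMA::smem_desc<...> remains.
// In both cases ELayout_64x64 maps the sparsity metadata (ValTypeE) onto the
// warpgroup, and CLayout_64xN maps the accumulator for the N in Shape_MNK.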
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 72, 64>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 72, 64>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 72, 64>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 72, 64>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, 
float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 88, 64>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 88, 64>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_64>; + using 
ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 88, 64>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 88, 64>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<104, 64>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<104, 64>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<104, 64>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<104, 64>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = 
GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<120, 64>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + 
using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<120, 64>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<120, 64>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<120, 64>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<136, 64>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<136, 64>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using 
Shape_MNK = Shape<_64,_136,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<136, 64>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<136, 64>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + 
GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<152, 64>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<152, 64>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<152, 64>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<152, 64>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ 
+ using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<168, 64>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<168, 64>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB 
= GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<168, 64>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<168, 64>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = 
GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<184, 64>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<184, 64>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<184, 64>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<184, 64>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<200, 64>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + 
+template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<200, 64>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<200, 64>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<200, 64>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = 
GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<216, 64>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<216, 64>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<216, 64>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = 
GMMA::ABLayout<216, 64>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<232, 64>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<232, 64>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<232, 64>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<232, 64>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = 
sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<248, 64>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<248, 64>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<248, 64>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_64>; + using ThrID = Layout<_128>; 
+ using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<248, 64>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 40, 64>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = 
GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 40, 64>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 40, 64>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 40, 64>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + 
using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 56, 64>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 56, 64>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 56, 64>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_64>; + using ThrID = 
Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 56, 64>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 72, 64>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 72, 64>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 72, 64>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 72, 64>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut 
accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 88, 64>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 88, 64>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, 
float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 88, 64>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 88, 64>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<104, 64>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<104, 64>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<104, 64>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_64>; 
+ using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<104, 64>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<120, 64>; + using CLayout = GMMA::CLayout_64x120; + 
+ GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<120, 64>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<120, 64>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<120, 64>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<136, 64>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<136, 64>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using 
ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<136, 64>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<136, 64>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using 
Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<152, 64>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<152, 64>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<152, 64>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<152, 64>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using 
CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<168, 64>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<168, 64>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + 
using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<168, 64>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<168, 64>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB 
= GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<184, 64>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<184, 64>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<184, 64>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<184, 64>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = 
GMMA::ABLayout<200, 64>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<200, 64>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<200, 64>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<200, 64>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + 
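////////////////////////////////////////////////////////////////////////////////////////////////////
// A minimal usage sketch, assuming a concrete sparse GMMA op `SparseGmmaOp` (placeholder name)
// with an MMA_Traits specialization of the form above. Each specialization only records
// instruction metadata: the operand value types (ValTypeA/B/C/D), the sparsity-metadata type
// (ValTypeE), the per-warpgroup tile Shape_MNK, the 128-thread ThrID, and the A/E/B/C fragment
// layouts. The traits are consumed by building a TiledMMA from the op; the struct below is a
// hypothetical illustration, not part of the vendored header.

template <class SparseGmmaOp>
struct usage_sketch {
  using Traits   = MMA_Traits<SparseGmmaOp>;                  // the metadata recorded above
  using TiledMma = decltype(make_tiled_mma(SparseGmmaOp{}));  // one warpgroup-wide MMA tile
  static_assert(size(typename Traits::ThrID{}) == 128,
                "GMMA atoms are executed by a full 128-thread warpgroup");
};
////////////////////////////////////////////////////////////////////////////////////////////////////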
+template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<216, 64>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<216, 64>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<216, 64>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using 
ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<216, 64>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = 
GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<232, 64>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<232, 64>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<232, 64>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<232, 64>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
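// [Editorial note, not part of this header's diff] Every specialization in this block follows
// the same shape: ValTypeA/ValTypeE/ValTypeB/ValTypeC/ValTypeD name the sparse operand,
// metadata, and accumulator element types; Shape_MNK is the atom's MxNxK tile; ThrID is the
// 128-thread warpgroup; and the A/E/B/CLayout aliases describe the thread-value layouts the
// atom expects (the original template arguments were lost in extraction and are not
// reconstructed here). A minimal, hedged usage sketch, assuming a hypothetical sparse GMMA
// operation tag `SomeSparseGmmaOp` for which one of these MMA_Traits specializations exists:
//
//   using Atom     = cute::MMA_Atom<SomeSparseGmmaOp>;   // wraps the MMA_Traits shown above
//   auto tiled_mma = cute::make_tiled_mma(Atom{});       // builds a TiledMMA from the atom
//
// The op tag name is an assumption for illustration only; cute::MMA_Atom and
// cute::make_tiled_mma are the standard CuTe entry points that consume these traits.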
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<248, 64>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<248, 64>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<248, 64>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = 
sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<248, 64>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = 
Shape<_64,_40,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 40, 64>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 40, 64>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 40, 64>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 40, 64>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut 
accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 56, 64>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 56, 64>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 56, 64>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + 
using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 56, 64>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 72, 64>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 72, 64>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 72, 64>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 72, 64>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK 
= Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 88, 64>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 88, 64>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut 
accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 88, 64>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 88, 64>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<104, 64>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<104, 64>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<104, 64>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = 
float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<104, 64>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; 
+ + using Shape_MNK = Shape<_64,_120,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<120, 64>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<120, 64>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<120, 64>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<120, 64>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<136, 64>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<136, 64>; + using CLayout = 
GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<136, 64>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<136, 64>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template 
+struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<152, 64>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<152, 64>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<152, 64>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<152, 64>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = 
GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<168, 64>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = 
GMMA::ABLayout<168, 64>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<168, 64>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<168, 64>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<184, 64>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<184, 64>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<184, 64>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<184, 64>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = 
sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<200, 64>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<200, 64>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<200, 64>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<200, 64>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = 
Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<216, 64>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<216, 64>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<216, 64>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut 
accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<216, 64>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = 
sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<232, 64>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<232, 64>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<232, 64>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<232, 64>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = 
Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<248, 64>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<248, 64>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<248, 64>; + using CLayout = 
GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<248, 64>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = 
half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 40, 64>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 40, 64>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 40, 64>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 40, 64>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + 
using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 56, 64>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 56, 64>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 56, 64>; + using CLayout 
= GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 56, 64>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 72, 64>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 72, 64>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 72, 64>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 72, 64>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = 
half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 88, 64>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + 
using Shape_MNK = Shape<_64,_88,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 88, 64>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 88, 64>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 88, 64>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<104, 64>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<104, 64>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<104, 64>; + using 
CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<104, 64>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + 
using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<120, 64>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<120, 64>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<120, 64>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<120, 64>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<136, 64>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using 
FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<136, 64>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<136, 64>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<136, 64>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout 
= GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<152, 64>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<152, 64>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<152, 64>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<152, 64>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + 
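// [Editorial note, not part of the vendored header] Within each pair of
// specializations above, the variant that defines FrgTypeA = GMMA::smem_desc<...>
// describes an op whose A operand is read through a shared-memory descriptor
// (conventionally the *_SS_* GMMA ops), while the variant without FrgTypeA takes
// A from registers (the *_RS_* ops); B always uses a shared-memory descriptor
// here. A small, self-contained detection idiom (a sketch; "has_smem_A" is not
// a name from this file):
#include <type_traits>

template <class Traits, class = void>
struct has_smem_A : std::false_type {};

template <class Traits>
struct has_smem_A<Traits, std::void_t<typename Traits::FrgTypeA>> : std::true_type {};

// has_smem_A<Traits>::value == true  -> A sourced from shared memory (SS-style op)
// has_smem_A<Traits>::value == false -> A sourced from registers     (RS-style op)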
+template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<168, 64>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using 
ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<168, 64>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<168, 64>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<168, 64>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = 
GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<184, 64>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<184, 64>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<184, 64>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<184, 64>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
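// --- Editorial sketch (not part of the CUTLASS sources) ---------------------------------------
// The specializations above and below differ only in the N extent of the tile
// (Shape_MNK = 64 x N x 64), in the accumulator type (half_t vs. float), and in whether A is
// sourced from shared memory (FrgTypeA = GMMA::smem_desc with ALayout = GMMA::ABLayout<64, 64>)
// or from registers (no FrgTypeA, ALayout = GMMA::ALayout_64x64). The standalone snippet below
// uses standard C++ only and numbers read off the 64x160x64 sparse FP8 traits to show how the
// fields relate: one 128-thread warpgroup (ThrID = Layout<_128>) owns the whole tile, each
// thread holds M*N/128 accumulator values, and the sparse A operand (sparse_elem<2, ...>,
// i.e. 2:4 structured sparsity) stores one value for every two logical K positions.
#include <cstddef>

int main() {
  constexpr std::size_t M = 64, N = 160, K = 64;  // Shape_MNK of one atom
  constexpr std::size_t threads  = 128;           // ThrID = Layout<_128>: one warpgroup
  constexpr std::size_t sparsity = 2;             // sparse_elem<2, float_e5m2_t>: 2-of-4 sparsity

  static_assert(M * N % threads == 0, "accumulator tile must split evenly across the warpgroup");
  constexpr std::size_t acc_per_thread  = M * N / threads;  // 80 accumulator values per thread
  constexpr std::size_t a_vals_per_row  = K / sparsity;     // 32 stored FP8 values per K row of A
  static_assert(acc_per_thread == 80 && a_vals_per_row == 32, "arithmetic implied by the traits");
  return 0;
}
// ----------------------------------------------------------------------------------------------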
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<200, 64>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<200, 64>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<200, 64>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<200, 64>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = 
sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<216, 64>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<216, 64>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = 
Shape<_64,_216,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<216, 64>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<216, 64>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut 
accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<232, 64>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<232, 64>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<232, 64>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<232, 64>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using 
ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<248, 64>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<248, 64>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = 
GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<248, 64>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<248, 64>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // end namespace cute diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/atom/partitioner.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/atom/partitioner.hpp new file mode 100644 index 0000000000000000000000000000000000000000..b15de6e266e206e37bc29fcdcfbbc2e74decce16 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/atom/partitioner.hpp @@ -0,0 +1,110 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#if defined(__CUDACC_RTC__) +#include CUDA_STD_HEADER(type_traits) +#else +#include +#endif + +#include +#include + +namespace cute { + +// +// A generic tiling of thread-value layouts +// + +template coord [Need not be 2D...] + class Tiler_MN_> // coord space +struct TV_Tiler +{ + using Tiler_MN = Tiler_MN_; + using TiledLayout_TV = Layout_TV_; + + // Tile a tensor or a layout from shape + // (M,N,...) + // to shape + // ((ThrV,FrgV),(RestM,RestN,...)) + // where + // ThrV: The threads local to a tile. + // FrgV: The values local to a tile. + // RestM: The values tiled in M. + // RestN: The values tiled in N. + template + CUTE_HOST_DEVICE constexpr static + auto + apply(Tensor&& tensor) + { + // If Layout_TV and Tiler_MN were composable in general, then this won't be needed! + + // ((thr_id,val_id),(RestM,RestN,...)) + return zipped_divide(tensor, Tiler_MN{}).compose(TiledLayout_TV{}, _); + } + + template + struct TV_Partitioner + { + SliceCoord coord_; + + template + CUTE_HOST_DEVICE + auto + partition(TargetTensor&& target) { + Tensor thr_tensor = make_tensor(static_cast(target).data(), apply(target.layout())); + return thr_tensor(coord_, repeat>(_)); + } + }; + + template + CUTE_HOST_DEVICE static + auto + get_slice(SliceCoord const& coord) + { + return TV_Partitioner{coord}; + } +}; + +template +CUTE_HOST_DEVICE +auto +make_tiler_impl(Layout_TV const&, + Tiler_MN const&) +{ + return TV_Tiler{}; +} + +} diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/config.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/config.hpp new file mode 100644 index 0000000000000000000000000000000000000000..538472c274baa5ccc7c59ac3d36b52791f061fdf --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/config.hpp @@ -0,0 +1,159 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#if defined(__CUDACC__) || defined(_NVHPC_CUDA) +# define CUTE_HOST_DEVICE __forceinline__ __host__ __device__ +# define CUTE_DEVICE __forceinline__ __device__ +# define CUTE_HOST __forceinline__ __host__ +#else +# define CUTE_HOST_DEVICE inline +# define CUTE_DEVICE inline +# define CUTE_HOST inline +#endif // CUTE_HOST_DEVICE, CUTE_DEVICE + +#if defined(__CUDACC_RTC__) +# define CUTE_HOST_RTC CUTE_HOST_DEVICE +#else +# define CUTE_HOST_RTC CUTE_HOST +#endif + +#if !defined(__CUDACC_RTC__) && !defined(__clang__) && \ + (defined(__CUDA_ARCH__) || defined(_NVHPC_CUDA)) +# define CUTE_UNROLL #pragma unroll +# define CUTE_NO_UNROLL #pragma unroll 1 +#elif defined(__CUDACC_RTC__) || defined(__clang__) +# define CUTE_UNROLL _Pragma("unroll") +# define CUTE_NO_UNROLL _Pragma("unroll 1") +#else +# define CUTE_UNROLL +# define CUTE_NO_UNROLL +#endif // CUTE_UNROLL + +#if defined(__CUDA_ARCH__) || defined(_NVHPC_CUDA) +# define CUTE_INLINE_CONSTANT static const __device__ +#else +# define CUTE_INLINE_CONSTANT static constexpr +#endif + +// __grid_constant__ was introduced in CUDA 11.7. +#if ((__CUDACC_VER_MAJOR__ >= 12) || ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 7))) +# define CUTE_GRID_CONSTANT_SUPPORTED +#endif + +// __grid_constant__ can be enabled only on SM70+. +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700)) +# define CUTE_GRID_CONSTANT_ENABLED +#endif + +#if ! defined(CUTE_GRID_CONSTANT) +# if defined(CUTE_GRID_CONSTANT_SUPPORTED) && defined(CUTE_GRID_CONSTANT_ENABLED) +# define CUTE_GRID_CONSTANT __grid_constant__ +# else +# define CUTE_GRID_CONSTANT +# endif +#endif + +// Some versions of GCC < 11 have trouble deducing that a +// function with "auto" return type and all of its returns in an "if +// constexpr ... else" statement must actually return. Thus, GCC +// emits spurious "missing return statement" build warnings. +// Developers can suppress these warnings by using the +// CUTE_GCC_UNREACHABLE macro, which must be followed by a semicolon. +// It's harmless to use the macro for other GCC versions or other +// compilers, but it has no effect. +#if ! defined(CUTE_GCC_UNREACHABLE) +# if defined(__GNUC__) +# define CUTE_GCC_UNREACHABLE __builtin_unreachable() +# else +# define CUTE_GCC_UNREACHABLE +# endif +#endif + +#if defined(_MSC_VER) +// Provides support for alternative operators 'and', 'or', and 'not' +# include +#endif // _MSC_VER + +#if defined(__CUDACC_RTC__) +# define CUTE_STL_NAMESPACE cuda::std +# define CUTE_STL_NAMESPACE_IS_CUDA_STD +#else +# define CUTE_STL_NAMESPACE std +#endif + +// +// Assertion helpers +// + +#if defined(__CUDACC_RTC__) +# include +#else +# include +#endif + +#define CUTE_STATIC_V(x) decltype(x)::value + +#define CUTE_STATIC_ASSERT static_assert +#define CUTE_STATIC_ASSERT_V(x,...) static_assert(decltype(x)::value, ##__VA_ARGS__) + +// Fail and print a message. 
Typically used for notification of a compiler misconfiguration. +#if defined(__CUDA_ARCH__) +# define CUTE_INVALID_CONTROL_PATH(x) assert(0 && x); printf(x); __brkpt() +#else +# define CUTE_INVALID_CONTROL_PATH(x) assert(0 && x); printf(x) +#endif + +// +// IO +// + +#if !defined(__CUDACC_RTC__) +# include +# include +# include +#endif + +// +// Type +// + +#if defined(__CUDACC_RTC__) +# include +#else +# include +#endif + +// +// Debugging utilities +// + +#include diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/container/alignment.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/container/alignment.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f285004c0c55c9cedc67dbfa1c83b202b3c10f41 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/container/alignment.hpp @@ -0,0 +1,70 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include + +#include +#include + +namespace cute +{ + +// Test if a pointer is aligned to N bytes +template +CUTE_HOST_DEVICE constexpr +bool +is_byte_aligned(void const* const ptr) +{ + static_assert(has_single_bit(N), "N must be a power of 2 in alignment check"); + return (reinterpret_cast(ptr) & (N-1)) == 0; +} + +#if defined(__CUDACC__) +# define CUTE_ALIGNAS(n) __align__(n) +#else +# define CUTE_ALIGNAS(n) alignas(n) +#endif + +template +struct aligned_struct {}; + +template struct CUTE_ALIGNAS( 1) aligned_struct< 1, Child> {}; +template struct CUTE_ALIGNAS( 2) aligned_struct< 2, Child> {}; +template struct CUTE_ALIGNAS( 4) aligned_struct< 4, Child> {}; +template struct CUTE_ALIGNAS( 8) aligned_struct< 8, Child> {}; +template struct CUTE_ALIGNAS( 16) aligned_struct< 16, Child> {}; +template struct CUTE_ALIGNAS( 32) aligned_struct< 32, Child> {}; +template struct CUTE_ALIGNAS( 64) aligned_struct< 64, Child> {}; +template struct CUTE_ALIGNAS(128) aligned_struct<128, Child> {}; +template struct CUTE_ALIGNAS(256) aligned_struct<256, Child> {}; + +} // end namespace cute diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/container/array.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/container/array.hpp new file mode 100644 index 0000000000000000000000000000000000000000..31606e45b9c1a6520596ccdb778f6a962b561179 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/container/array.hpp @@ -0,0 +1,470 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include + +#include +#include + +namespace cute +{ + +template +struct array +{ + using element_type = T; + using value_type = remove_cv_t; + using size_type = size_t; + using difference_type = ptrdiff_t; + using reference = element_type&; + using const_reference = const element_type&; + using pointer = element_type*; + using const_pointer = const element_type*; + using iterator = pointer; + using const_iterator = const_pointer; + + CUTE_HOST_DEVICE constexpr + reference operator[](size_type pos) + { + return begin()[pos]; + } + + CUTE_HOST_DEVICE constexpr + const_reference operator[](size_type pos) const + { + return begin()[pos]; + } + + CUTE_HOST_DEVICE constexpr + reference front() + { + return *begin(); + } + + CUTE_HOST_DEVICE constexpr + const_reference front() const + { + return *begin(); + } + + CUTE_HOST_DEVICE constexpr + reference back() + { + // return *rbegin(); + return operator[](N-1); + } + + CUTE_HOST_DEVICE constexpr + const_reference back() const + { + // return *rbegin(); + return operator[](N-1); + } + + CUTE_HOST_DEVICE constexpr + T* data() + { + return __elems_; + } + + CUTE_HOST_DEVICE constexpr + T const* data() const + { + return __elems_; + } + + CUTE_HOST_DEVICE constexpr + iterator begin() + { + return data(); + } + + CUTE_HOST_DEVICE constexpr + const_iterator begin() const + { + return data(); + } + + CUTE_HOST_DEVICE constexpr + const_iterator cbegin() + { + return begin(); + } + + CUTE_HOST_DEVICE constexpr + const_iterator cbegin() const + { + return begin(); + } + + CUTE_HOST_DEVICE constexpr + iterator end() + { + return data() + size(); + } + + CUTE_HOST_DEVICE constexpr + const_iterator end() const + { + return data() + size(); + } + + CUTE_HOST_DEVICE constexpr + const_iterator cend() + { + return end(); + } + + CUTE_HOST_DEVICE constexpr + const_iterator cend() const + { + return end(); + } + + CUTE_HOST_DEVICE constexpr + bool empty() const + { + return size() == 0; + } + + CUTE_HOST_DEVICE constexpr + size_type size() const + { + return N; + } + + CUTE_HOST_DEVICE constexpr + size_type max_size() const + { + return size(); + } + + CUTE_HOST_DEVICE constexpr + void fill(const T& value) + { + for (auto& e : *this) { + e = value; + } + } + + CUTE_HOST_DEVICE constexpr + void clear() + { + fill(T(0)); + } + + CUTE_HOST_DEVICE constexpr + void swap(array& other) + { + using CUTE_STL_NAMESPACE::swap; + for (size_type i = 0; i < size(); ++i) { + swap((*this)[i], other[i]); + } + } + + element_type __elems_[N]; +}; + + +template +struct array +{ + using element_type = T; + using value_type = remove_cv_t; + using size_type = size_t; + using difference_type = ptrdiff_t; + using reference = element_type&; + using const_reference = const element_type&; + using pointer = element_type*; + using const_pointer = const element_type*; + using const_iterator = const_pointer; + using iterator = pointer; + + CUTE_HOST_DEVICE constexpr + reference operator[](size_type pos) + { + return begin()[pos]; + } + + CUTE_HOST_DEVICE constexpr + const_reference operator[](size_type pos) const + { + return begin()[pos]; + } + + CUTE_HOST_DEVICE constexpr + reference front() + { + return *begin(); + } + + CUTE_HOST_DEVICE constexpr + const_reference front() const + { + return *begin(); + } + + CUTE_HOST_DEVICE constexpr + reference back() + { + return *begin(); + } + + CUTE_HOST_DEVICE constexpr + const_reference back() const + { + return *begin(); + } + + 
CUTE_HOST_DEVICE constexpr + T* data() + { + return nullptr; + } + + CUTE_HOST_DEVICE constexpr + T const* data() const + { + return nullptr; + } + + CUTE_HOST_DEVICE constexpr + iterator begin() + { + return nullptr; + } + + CUTE_HOST_DEVICE constexpr + const_iterator begin() const + { + return nullptr; + } + + CUTE_HOST_DEVICE constexpr + const_iterator cbegin() + { + return nullptr; + } + + CUTE_HOST_DEVICE constexpr + const_iterator cbegin() const + { + return nullptr; + } + + CUTE_HOST_DEVICE constexpr + iterator end() + { + return nullptr; + } + + CUTE_HOST_DEVICE constexpr + const_iterator end() const + { + return nullptr; + } + + CUTE_HOST_DEVICE constexpr + const_iterator cend() + { + return nullptr; + } + + CUTE_HOST_DEVICE constexpr + const_iterator cend() const + { + return nullptr; + } + + CUTE_HOST_DEVICE constexpr + bool empty() const + { + return true; + } + + CUTE_HOST_DEVICE constexpr + size_type size() const + { + return 0; + } + + CUTE_HOST_DEVICE constexpr + size_type max_size() const + { + return 0; + } + + CUTE_HOST_DEVICE constexpr + void fill(const T& value) + {} + + CUTE_HOST_DEVICE constexpr + void clear() + {} + + CUTE_HOST_DEVICE constexpr + void swap(array& other) + {} +}; + +template +CUTE_HOST_DEVICE constexpr +bool operator==(array const& lhs, array const& rhs) +{ + for (size_t i = 0; i < N; ++i) { + if (lhs[i] != rhs[i]) { + return false; + } + } + return true; +} + +template +CUTE_HOST_DEVICE constexpr +void clear(array& a) +{ + a.fill(T(0)); +} + +template +CUTE_HOST_DEVICE constexpr +void fill(array& a, T const& value) +{ + a.fill(value); +} + +template +CUTE_HOST_DEVICE constexpr +void swap(array& a, array& b) +{ + a.swap(b); +} + +/// @return A cute::array of the elements of @c t in reverse order. +template +CUTE_HOST_DEVICE constexpr +cute::array reverse(cute::array const& t) +{ + if constexpr (N == 0u) { + return t; + } else { + cute::array t_r{}; + for (size_t k = 0; k < N; ++k) { + t_r[k] = t[N - k - 1]; + } + return t_r; + } +} + +} // end cute + + +// +// Specialize tuple-related functionality for cute::array +// +#include "cutlass/cutlass.h" +#if defined(__CUDACC_RTC__) +#include CUDA_STD_HEADER(tuple) +#else +#include +#endif + +namespace cute +{ + +template +CUTE_HOST_DEVICE constexpr +T& get(array& a) +{ + static_assert(I < N, "Index out of range"); + return a[I]; +} + +template +CUTE_HOST_DEVICE constexpr +T const& get(array const& a) +{ + static_assert(I < N, "Index out of range"); + return a[I]; +} + +template +CUTE_HOST_DEVICE constexpr +T&& get(array&& a) +{ + static_assert(I < N, "Index out of range"); + return cute::move(a[I]); +} + +} // end namespace cute + +namespace CUTE_STL_NAMESPACE +{ + +template +struct tuple_size> + : CUTE_STL_NAMESPACE::integral_constant +{}; + +template +struct tuple_element> +{ + using type = T; +}; + +} // end namespace CUTE_STL_NAMESPACE + +#ifdef CUTE_STL_NAMESPACE_IS_CUDA_STD +namespace std +{ + +#if defined(__CUDACC_RTC__) +template +struct tuple_size; + +template +struct tuple_element; +#endif + +template +struct tuple_size> + : CUTE_STL_NAMESPACE::integral_constant +{}; + +template +struct tuple_element> +{ + using type = T; +}; + +} // end namespace std +#endif // CUTE_STL_NAMESPACE_IS_CUDA_STD diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/container/array_aligned.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/container/array_aligned.hpp new file mode 100644 index 
0000000000000000000000000000000000000000..6491f72db1b8a98c7d8e445ab53cb9ec10ee4640 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/container/array_aligned.hpp @@ -0,0 +1,42 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include // CUTE_ALIGNAS +#include // cute::array + +namespace cute +{ + +template +struct CUTE_ALIGNAS(Alignment) array_aligned : cute::array {}; + +} // end namespace cute diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/container/array_subbyte.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/container/array_subbyte.hpp new file mode 100644 index 0000000000000000000000000000000000000000..be6410b318be72f74534b4344c3810cd1ccfb389 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/container/array_subbyte.hpp @@ -0,0 +1,640 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Statically sized array of elements that accommodates subbyte trivial types + in a packed storage. +*/ + +#pragma once + +#include + +#include +#include + +namespace cute +{ +// +// Underlying subbyte storage type +// +template +using subbyte_storage_type_t = conditional_t<(cute::sizeof_bits_v <= 8), uint8_t, + conditional_t<(cute::sizeof_bits_v <= 16), uint16_t, + conditional_t<(cute::sizeof_bits_v <= 32), uint32_t, + conditional_t<(cute::sizeof_bits_v <= 64), uint64_t, + conditional_t<(cute::sizeof_bits_v <= 128), uint128_t, + T>>>>>; + +template struct subbyte_iterator; +template struct swizzle_ptr; + +// +// subbyte_reference +// Proxy object for sub-byte element references +// +template +struct subbyte_reference +{ + // Iterator Element type (const or non-const) + using element_type = T; + // Iterator Value type without type qualifier. + using value_type = remove_cv_t; + // Storage type (const or non-const) + using storage_type = conditional_t<(is_const_v), subbyte_storage_type_t const, subbyte_storage_type_t>; + + static_assert(sizeof_bits_v % 8 == 0, "Storage type is not supported"); + + static_assert(sizeof_bits_v <= sizeof_bits_v, + "Size of Element must not be greater than Storage."); + +private: + + // Bitmask for covering one item + static constexpr storage_type BitMask = storage_type(storage_type(-1) >> (sizeof_bits_v - sizeof_bits_v)); + // Flag for fast branching on straddled elements + static constexpr bool is_storage_unaligned = ((sizeof_bits_v % sizeof_bits_v) != 0); + + friend struct subbyte_iterator; + + // Pointer to storage element + storage_type* ptr_ = nullptr; + + // Bit index of value_type starting position within storage_type element. 
+ // RI: 0 <= idx_ < sizeof_bit + uint8_t idx_ = 0; + + // Ctor + template + CUTE_HOST_DEVICE constexpr + subbyte_reference(PointerType* ptr, uint8_t idx = 0) : ptr_(reinterpret_cast(ptr)), idx_(idx) {} + +public: + + // Copy Ctor + CUTE_HOST_DEVICE constexpr + subbyte_reference(subbyte_reference const& other) { + *this = other.get(); + } + + CUTE_HOST_DEVICE constexpr + subbyte_reference(subbyte_reference const& other) { + *this = other.get(); + } + + // Copy Assignment + CUTE_HOST_DEVICE constexpr + subbyte_reference& operator=(subbyte_reference const& other) { + return *this = other.get(); + } + + CUTE_HOST_DEVICE constexpr + subbyte_reference& operator=(subbyte_reference const& other) { + return *this = other.get(); + } + + // Assignment + template + CUTE_HOST_DEVICE constexpr + enable_if_t, subbyte_reference&> operator=(value_type x) + { + static_assert(is_same_v, "Do not specify template arguments!"); + storage_type item = (reinterpret_cast(x) & BitMask); + + // Update the current storage element + storage_type bit_mask_0 = storage_type(BitMask << idx_); + ptr_[0] = storage_type((ptr_[0] & ~bit_mask_0) | (item << idx_)); + + // If value_type is unaligned with storage_type (static) and this is a straddled value (dynamic) + if (is_storage_unaligned && idx_ + sizeof_bits_v > sizeof_bits_v) { + uint8_t straddle_bits = uint8_t(sizeof_bits_v - idx_); + storage_type bit_mask_1 = storage_type(BitMask >> straddle_bits); + // Update the next storage element + ptr_[1] = storage_type((ptr_[1] & ~bit_mask_1) | (item >> straddle_bits)); + } + + return *this; + } + + // Comparison of referenced values + CUTE_HOST_DEVICE constexpr friend + bool operator==(subbyte_reference const& x, subbyte_reference const& y) { return x.get() == y.get(); } + CUTE_HOST_DEVICE constexpr friend + bool operator!=(subbyte_reference const& x, subbyte_reference const& y) { return x.get() != y.get(); } + CUTE_HOST_DEVICE constexpr friend + bool operator< (subbyte_reference const& x, subbyte_reference const& y) { return x.get() < y.get(); } + CUTE_HOST_DEVICE constexpr friend + bool operator> (subbyte_reference const& x, subbyte_reference const& y) { return x.get() > y.get(); } + CUTE_HOST_DEVICE constexpr friend + bool operator<=(subbyte_reference const& x, subbyte_reference const& y) { return x.get() <= y.get(); } + CUTE_HOST_DEVICE constexpr friend + bool operator>=(subbyte_reference const& x, subbyte_reference const& y) { return x.get() >= y.get(); } + + // Value + CUTE_HOST_DEVICE + value_type get() const + { + if constexpr (is_same_v) { // Extract to bool -- potentially faster impl + return bool((*ptr_) & (BitMask << idx_)); + } else { // Extract to value_type + // Extract from the current storage element + auto item = storage_type((ptr_[0] >> idx_) & BitMask); + + // If value_type is unaligned with storage_type (static) and this is a straddled value (dynamic) + if (is_storage_unaligned && idx_ + sizeof_bits_v > sizeof_bits_v) { + uint8_t straddle_bits = uint8_t(sizeof_bits_v - idx_); + storage_type bit_mask_1 = storage_type(BitMask >> straddle_bits); + // Extract from the next storage element + item |= storage_type((ptr_[1] & bit_mask_1) << straddle_bits); + } + + return reinterpret_cast(item); + } + } + + // Extract to type value_type + CUTE_HOST_DEVICE constexpr + operator value_type() const { + return get(); + } + + // Address + CUTE_HOST_DEVICE + subbyte_iterator operator&() const { + return {ptr_, idx_}; + } +}; + +template +CUTE_HOST_DEVICE +void +print(subbyte_reference ref) { + cute::print(ref.get()); +} + 
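// --- Editorial sketch (not part of the CUTLASS sources) ---------------------------------------
// The straddle handling in subbyte_reference::get() and operator= is easiest to see in
// isolation. The snippet below assumes plain uint8_t storage and a 6-bit element width and
// reimplements only the read path, mirroring the BitMask / idx_ / straddle_bits logic above:
// mask the bits that live in the current storage byte, and when the element crosses a byte
// boundary, pull the remaining high bits out of the next byte.
#include <cstdint>
#include <cassert>

constexpr unsigned     kElemBits  = 6;                      // sizeof_bits_v<value_type>
constexpr unsigned     kStoreBits = 8;                      // sizeof_bits_v<storage_type>
constexpr std::uint8_t kMask      = (1u << kElemBits) - 1;  // BitMask covering one element

std::uint8_t read_subbyte(const std::uint8_t* ptr, unsigned i) {
  unsigned bit  = i * kElemBits;
  unsigned byte = bit / kStoreBits;
  unsigned idx  = bit % kStoreBits;                         // plays the role of idx_
  std::uint8_t item = (ptr[byte] >> idx) & kMask;           // bits held in the current byte
  if (idx + kElemBits > kStoreBits) {                       // straddled element
    unsigned straddle = kStoreBits - idx;                   // bits already extracted
    item |= std::uint8_t((ptr[byte + 1] & (kMask >> straddle)) << straddle);
  }
  return item;
}

int main() {
  // Two 6-bit values 0b101011 and 0b110001 packed bit-contiguously:
  // byte 0 holds the low 8 bits, byte 1 holds the 4 leftover high bits.
  std::uint8_t buf[2] = { std::uint8_t(0b01101011), std::uint8_t(0b00001100) };
  assert(read_subbyte(buf, 0) == 0b101011);                 // fully inside byte 0
  assert(read_subbyte(buf, 1) == 0b110001);                 // straddles bytes 0 and 1
  return 0;
}
// ----------------------------------------------------------------------------------------------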
+template +CUTE_HOST_DEVICE +void +pretty_print(subbyte_reference ref) { + cute::pretty_print(ref.get()); +} + +// +// subbyte_iterator +// Random-access iterator over subbyte references +// +template +struct subbyte_iterator +{ + // Iterator Element type (const or non-const) + using element_type = T; + // Iterator Value type without type qualifier. + using value_type = remove_cv_t; + // Storage type (const or non-const) + using storage_type = conditional_t<(is_const_v), subbyte_storage_type_t const, subbyte_storage_type_t>; + // Reference proxy type + using reference = subbyte_reference; + + static_assert(sizeof_bits_v % 8 == 0, "Storage type is not supported"); + + static_assert(sizeof_bits_v <= sizeof_bits_v, + "Size of Element must not be greater than Storage."); + +private: + + template friend struct swizzle_ptr; + template friend CUTE_HOST_DEVICE constexpr U* raw_pointer_cast(subbyte_iterator const&); + template friend CUTE_HOST_DEVICE constexpr auto recast_ptr(subbyte_iterator const&); + template friend CUTE_HOST_DEVICE void print(subbyte_iterator const&); + + // Pointer to storage element + storage_type* ptr_; + + // Bit index of value_type starting position within storage_type element. + // RI: 0 <= idx_ < sizeof_bit + uint8_t idx_; + +public: + + // Default Ctor + CUTE_HOST_DEVICE constexpr + subbyte_iterator() : ptr_{nullptr}, idx_{0} {}; + + // Ctor + template + CUTE_HOST_DEVICE constexpr + subbyte_iterator(PointerType* ptr, uint8_t idx = 0) : ptr_(reinterpret_cast(ptr)), idx_(idx) { } + + CUTE_HOST_DEVICE constexpr + reference operator*() const { + return reference(ptr_, idx_); + } + + CUTE_HOST_DEVICE constexpr + subbyte_iterator& operator+=(uint64_t k) { + k = sizeof_bits_v * k + idx_; + ptr_ += k / sizeof_bits_v; + idx_ = k % sizeof_bits_v; + return *this; + } + + CUTE_HOST_DEVICE constexpr + subbyte_iterator operator+(uint64_t k) const { + return subbyte_iterator(ptr_, idx_) += k; + } + + CUTE_HOST_DEVICE constexpr + reference operator[](uint64_t k) const { + return *(*this + k); + } + + CUTE_HOST_DEVICE constexpr + subbyte_iterator& operator++() { + idx_ += sizeof_bits_v; + if (idx_ >= sizeof_bits_v) { + ++ptr_; + idx_ -= sizeof_bits_v; + } + return *this; + } + + CUTE_HOST_DEVICE constexpr + subbyte_iterator operator++(int) { + subbyte_iterator ret(*this); + ++(*this); + return ret; + } + + CUTE_HOST_DEVICE constexpr + subbyte_iterator& operator--() { + if (idx_ >= sizeof_bits_v) { + idx_ -= sizeof_bits_v; + } else { + --ptr_; + idx_ += sizeof_bits_v - sizeof_bits_v; + } + return *this; + } + + CUTE_HOST_DEVICE constexpr + subbyte_iterator operator--(int) { + subbyte_iterator ret(*this); + --(*this); + return ret; + } + + CUTE_HOST_DEVICE constexpr friend + bool operator==(subbyte_iterator const& x, subbyte_iterator const& y) { + return x.ptr_ == y.ptr_ && x.idx_ == y.idx_; + } + CUTE_HOST_DEVICE constexpr friend + bool operator!=(subbyte_iterator const& x, subbyte_iterator const& y) { return !(x == y); } + CUTE_HOST_DEVICE constexpr friend + bool operator< (subbyte_iterator const& x, subbyte_iterator const& y) { + return x.ptr_ < y.ptr_ || (x.ptr_ == y.ptr_ && x.idx_ < y.idx_); + } + CUTE_HOST_DEVICE constexpr friend + bool operator<=(subbyte_iterator const& x, subbyte_iterator const& y) { return !(y < x); } + CUTE_HOST_DEVICE constexpr friend + bool operator> (subbyte_iterator const& x, subbyte_iterator const& y) { return (y < x); } + CUTE_HOST_DEVICE constexpr friend + bool operator>=(subbyte_iterator const& x, subbyte_iterator const& y) { return !(x < y); } +}; + +// 
Conversion to raw pointer with loss of subbyte index +template +CUTE_HOST_DEVICE constexpr +T* +raw_pointer_cast(subbyte_iterator const& x) { + assert(x.idx_ == 0); + return reinterpret_cast(x.ptr_); +} + +// Conversion to NewT_ with possible loss of subbyte index +template +CUTE_HOST_DEVICE constexpr +auto +recast_ptr(subbyte_iterator const& x) { + using NewT = conditional_t<(is_const_v), NewT_ const, NewT_>; + if constexpr (cute::is_subbyte_v) { // Making subbyte_iter, preserve the subbyte idx + return subbyte_iterator(x.ptr_, x.idx_); + } else { // Not subbyte, assume/assert subbyte idx 0 + return reinterpret_cast(raw_pointer_cast(x)); + } + CUTE_GCC_UNREACHABLE; +} + +// Dynamic pointers have unknown static alignment +template +CUTE_HOST_DEVICE constexpr +Int<0> +max_alignment(subbyte_iterator const& x) { + return {}; +} + +template +CUTE_HOST_DEVICE void +print(subbyte_iterator const& x) { + printf("subptr[%db](%p.%u)", int(sizeof_bits_v), x.ptr_, x.idx_); +} + +template +CUTE_HOST_DEVICE void +print(subbyte_reference const& x) { + print(x.get()); +} + +// +// array_subbyte +// Statically sized array for non-byte-aligned data types +// +template +struct array_subbyte +{ + using element_type = T; + using value_type = remove_cv_t; + using pointer = element_type*; + using const_pointer = element_type const*; + + using size_type = size_t; + using difference_type = ptrdiff_t; + + // + // References + // + using reference = subbyte_reference; + using const_reference = subbyte_reference; + + // + // Iterators + // + using iterator = subbyte_iterator; + using const_iterator = subbyte_iterator; + + // Storage type (const or non-const) + using storage_type = conditional_t<(is_const_v), subbyte_storage_type_t const, subbyte_storage_type_t>; + + static_assert(sizeof_bits_v % 8 == 0, "Storage type is not supported"); + +private: + + // Number of storage elements, ceil_div + static constexpr size_type StorageElements = (N * sizeof_bits_v + sizeof_bits_v - 1) / sizeof_bits_v; + + // Internal storage + storage_type storage[StorageElements]; + +public: + + CUTE_HOST_DEVICE constexpr + size_type size() const { + return N; + } + + CUTE_HOST_DEVICE constexpr + size_type max_size() const { + return N; + } + + CUTE_HOST_DEVICE constexpr + bool empty() const { + return !N; + } + + // Efficient clear method + CUTE_HOST_DEVICE constexpr + void clear() { + CUTE_UNROLL + for (size_type i = 0; i < StorageElements; ++i) { + storage[i] = storage_type(0); + } + } + + CUTE_HOST_DEVICE constexpr + void fill(T const& value) { + CUTE_UNROLL + for (size_type i = 0; i < N; ++i) { + at(i) = value; + } + } + + CUTE_HOST_DEVICE constexpr + reference at(size_type pos) { + return iterator(storage)[pos]; + } + + CUTE_HOST_DEVICE constexpr + const_reference at(size_type pos) const { + return const_iterator(storage)[pos]; + } + + CUTE_HOST_DEVICE constexpr + reference operator[](size_type pos) { + return at(pos); + } + + CUTE_HOST_DEVICE constexpr + const_reference operator[](size_type pos) const { + return at(pos); + } + + CUTE_HOST_DEVICE constexpr + reference front() { + return at(0); + } + + CUTE_HOST_DEVICE constexpr + const_reference front() const { + return at(0); + } + + CUTE_HOST_DEVICE constexpr + reference back() { + return at(N-1); + } + + CUTE_HOST_DEVICE constexpr + const_reference back() const { + return at(N-1); + } + + // In analogy to std::vector::data(), these functions are deleted to prevent bugs. 
+ // Instead, prefer + // auto* data = raw_pointer_cast(my_subbyte_array.begin()); + // where the type of auto* is implementation-defined and + // with the knowledge that [data, data + my_subbyte_array.size()) may not be a valid range. + CUTE_HOST_DEVICE constexpr + pointer data() = delete; + + CUTE_HOST_DEVICE constexpr + const_pointer data() const = delete; + + CUTE_HOST_DEVICE constexpr + iterator begin() { + return iterator(storage); + } + + CUTE_HOST_DEVICE constexpr + const_iterator begin() const { + return const_iterator(storage); + } + + CUTE_HOST_DEVICE constexpr + const_iterator cbegin() const { + return begin(); + } + + CUTE_HOST_DEVICE constexpr + iterator end() { + return iterator(storage) + N; + } + + CUTE_HOST_DEVICE constexpr + const_iterator end() const { + return const_iterator(storage) + N; + } + + CUTE_HOST_DEVICE constexpr + const_iterator cend() const { + return end(); + } + + // + // Comparison operators + // + +}; + +// +// Operators +// + +template +CUTE_HOST_DEVICE constexpr +void clear(array_subbyte& a) +{ + a.clear(); +} + +template +CUTE_HOST_DEVICE constexpr +void fill(array_subbyte& a, T const& value) +{ + a.fill(value); +} + +} // namespace cute + +// +// Specialize tuple-related functionality for cute::array_subbyte +// +#include "cutlass/cutlass.h" +#if defined(__CUDACC_RTC__) +#include CUDA_STD_HEADER(tuple) +#else +#include +#endif + +namespace cute +{ + +template +CUTE_HOST_DEVICE constexpr +T& get(array_subbyte& a) +{ + static_assert(I < N, "Index out of range"); + return a[I]; +} + +template +CUTE_HOST_DEVICE constexpr +T const& get(array_subbyte const& a) +{ + static_assert(I < N, "Index out of range"); + return a[I]; +} + +template +CUTE_HOST_DEVICE constexpr +T&& get(array_subbyte&& a) +{ + static_assert(I < N, "Index out of range"); + return cute::move(a[I]); +} + +} // end namespace cute + +namespace CUTE_STL_NAMESPACE +{ + +template +struct is_reference> + : CUTE_STL_NAMESPACE::true_type +{}; + + +template +struct tuple_size> + : CUTE_STL_NAMESPACE::integral_constant +{}; + +template +struct tuple_element> +{ + using type = T; +}; + +} // end namespace CUTE_STL_NAMESPACE + +#ifdef CUTE_STL_NAMESPACE_IS_CUDA_STD +namespace std +{ + +#if defined(__CUDACC_RTC__) +template +struct tuple_size; + +template +struct tuple_element; +#endif + +template +struct tuple_size> + : CUTE_STL_NAMESPACE::integral_constant +{}; + +template +struct tuple_element> +{ + using type = T; +}; + +} // end namespace std +#endif // CUTE_STL_NAMESPACE_IS_CUDA_STD diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/container/bit_field.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/container/bit_field.hpp new file mode 100644 index 0000000000000000000000000000000000000000..cecdaeed35c927084f69417324908abae19148fa --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/container/bit_field.hpp @@ -0,0 +1,133 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Portable bit field that supports byte and word straddling that can + be used in unions to bit-wise define parameters. +*/ + +#pragma once + +#include // CUTE_HOST_DEVICE +#include // uint_bit_t +#include // cute::is_same + +namespace cute +{ + +class dummy_type {}; + +template +struct bit_field +{ + static_assert(0 < NumBits && NumBits <= 64, "bit_fields with more than 64 bits are not supported."); + + // value_type: Use the smallest value type that fits NumBits + static constexpr uint32_t value_type_bits = (NumBits <= 8) ? 8 : + (NumBits <= 16) ? 16 : + (NumBits <= 32) ? 32 : 64; + using value_type = cute::uint_bit_t; + // storage_type: Use the smallest storage_type that avoids boundary crossing + static constexpr uint32_t storage_type_bits = (BitStart / 8 == (BitStart + NumBits - 1) / 8) ? 8 : + (BitStart / 16 == (BitStart + NumBits - 1) / 16) ? 16 : + (BitStart / 32 == (BitStart + NumBits - 1) / 32) ? 32 : 64; + using storage_type = cute::uint_bit_t; + + static_assert(sizeof(OtherValueType) == sizeof(value_type) || is_same::value, + "sizeof(OtherValueType) must be same as sizeof(value_type)."); + + // Number of storage values needed: ceil_div(BitStart + NumBits, storage_type_bits) + static constexpr uint32_t N = (BitStart + NumBits + storage_type_bits - 1) / storage_type_bits; + // Index of storage value for BitStart + static constexpr uint32_t idx = BitStart / storage_type_bits; + // Bit of data_[idx] for BitStart + static constexpr uint32_t bit_lo = BitStart % storage_type_bits; + // Number of bits in data_[idx] used for NumBits if straddling, else 0 + static constexpr uint32_t bit_hi = (idx + 1 < N) ? (storage_type_bits - bit_lo) : 0; + +public: + + // NumBits mask + static constexpr value_type mask = value_type(uint64_t(-1) >> (64u - NumBits)); + // NumBits mask for BitStart + static constexpr storage_type mask_lo = storage_type(mask) << bit_lo; + // NumBits mask for leftover bits in data_[idx+1] if straddling, else 0 + static constexpr storage_type mask_hi = (idx + 1 < N) ? 
(storage_type(mask) >> bit_hi) : 0; + + storage_type data_[N]; + + // Get value + CUTE_HOST_DEVICE constexpr + value_type get() const { + storage_type result = (data_[idx] & mask_lo) >> bit_lo; + if constexpr (bit_hi != 0) { + result |= (data_[idx+1] & mask_hi) << bit_hi; + } + return static_cast(result); + } + + // Set value + CUTE_HOST_DEVICE constexpr + void set(value_type x) { + storage_type item = static_cast(x & mask); + data_[idx] = static_cast((data_[idx] & ~mask_lo) | (item << bit_lo)); + if constexpr (bit_hi != 0) { + data_[idx+1] = static_cast((data_[idx+1] & ~mask_hi) | (item >> bit_hi)); + } + } + + // Assign value + CUTE_HOST_DEVICE constexpr + bit_field& operator=(value_type x) { + set(x); + return *this; + } + + // Cast to value + CUTE_HOST_DEVICE constexpr + operator value_type () const { + return get(); + } + + // Assign OtherValueType + CUTE_HOST_DEVICE constexpr + bit_field& operator=(OtherValueType x) { + return *this = *reinterpret_cast(&x); + } + + // Cast to OtherValueType + CUTE_HOST_DEVICE constexpr + operator OtherValueType () const { + value_type x = get(); + return *reinterpret_cast(&x); + } +}; + +} // end namespace cute diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/container/cuda_types.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/container/cuda_types.hpp new file mode 100644 index 0000000000000000000000000000000000000000..5615fdefe580ecb54a75462a98179bbe81c7c881 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/container/cuda_types.hpp @@ -0,0 +1,183 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include // CUTE_HOST_DEVICE, CUTE_GCC_UNREACHABLE +#include // cute::integral_constant + +namespace cute +{ + +// +// dim3 +// + +using dim3 = ::dim3; + +// MSVC doesn't define its C++ version macro to match +// its C++ language version. This means that when +// building with MSVC, dim3 isn't constexpr-friendly. +template +CUTE_HOST_DEVICE +#if ! defined(_MSC_VER) +constexpr +#endif +uint32_t& get(dim3& a) +{ + static_assert(I < 3, "Index out of range"); + if constexpr (I == 0) { + return a.x; + } else if constexpr (I == 1) { + return a.y; + } else if constexpr (I == 2) { + return a.z; + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE +#if ! defined(_MSC_VER) +constexpr +#endif +uint32_t const& get(dim3 const& a) +{ + static_assert(I < 3, "Index out of range"); + if constexpr (I == 0) { + return a.x; + } else if constexpr (I == 1) { + return a.y; + } else if constexpr (I == 2) { + return a.z; + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE +#if ! defined(_MSC_VER) +constexpr +#endif +uint32_t&& get(dim3&& a) +{ + static_assert(I < 3, "Index out of range"); + if constexpr (I == 0) { + return cute::move(a.x); + } else if constexpr (I == 1) { + return cute::move(a.y); + } else if constexpr (I == 2) { + return cute::move(a.z); + } + + CUTE_GCC_UNREACHABLE; +} + +// Specialize cute::tuple-traits for external types +template <> +struct tuple_size + : integral_constant +{}; + +template +struct tuple_element +{ + using type = uint32_t; +}; + +// +// uint3 +// + +using uint3 = ::uint3; + +template +CUTE_HOST_DEVICE constexpr +uint32_t& get(uint3& a) +{ + static_assert(I < 3, "Index out of range"); + if constexpr (I == 0) { + return a.x; + } else if constexpr (I == 1) { + return a.y; + } else if constexpr (I == 2) { + return a.z; + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +uint32_t const& get(uint3 const& a) +{ + static_assert(I < 3, "Index out of range"); + if constexpr (I == 0) { + return a.x; + } else if constexpr (I == 1) { + return a.y; + } else if constexpr (I == 2) { + return a.z; + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +uint32_t&& get(uint3&& a) +{ + static_assert(I < 3, "Index out of range"); + if constexpr (I == 0) { + return cute::move(a.x); + } else if constexpr (I == 1) { + return cute::move(a.y); + } else if constexpr (I == 2) { + return cute::move(a.z); + } + + CUTE_GCC_UNREACHABLE; +} + +// Specialize cute::tuple-traits for external types +template <> +struct tuple_size + : integral_constant +{}; + +template +struct tuple_element +{ + using type = uint32_t; +}; + +} // end namespace cute diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/container/tuple.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/container/tuple.hpp new file mode 100644 index 0000000000000000000000000000000000000000..e3dd6d2742033588c101663c27319226f4b8b19e --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/container/tuple.hpp @@ -0,0 +1,723 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include +#include +#include // cute::true_type, cute::false_type +#include + +#include +#include +//#include // Advanced optimizations + +// cute::tuple is like std::tuple, with differences: +// +// 1. It works on both host and device. +// 2. Its template arguments must be semiregular types. +// 3. It is always a standard-layout type if all of its template arguments are standard-layout types. +// 4. It is always an empty type if all of its template arguments are empty types. +// +// Semiregular types are default constructible and copyable. +// They include "value types" like int or float, +// but do _not_ include references like int& or float&. +// (See std::tie for an example of a tuple of references.) +// +// Standard-layout types preserve ABI across host-device boundaries. They are safe to use as device kernel parameters. +// The standard-layout requirement prevents a more common EBO-based implemented of cute::tuple. +// +// The cute::tuple is also simplified over the implementations in std::, cuda::std::, and thrust:: by ignoring much of +// the conversion SFINAE, special overloading, and avoiding cvref template types. +// +// Over standard-conforming tuple implementations, this appears to accelerate compilation times by over 3x. + +namespace cute +{ + +template +struct tuple; + +namespace eso +{ + +// ESO stands for "empty structure optimization." +// We use this technique to ensure that cute::tuple doesn't waste space +// storing template arguments that have no data (like integral_constant). +// Empty types in the template argument list are not even constructed, +// and do not have unique element addresses. Calling `get` +// constructs and returns an instance of an empty type on demand. 
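+//
+// Illustrative sketch (added commentary; an assumed example, not part of the upstream
+// header): because empty elements are never stored, a tuple whose elements are all
+// empty types is itself an empty type, and a mixed tuple only pays for its non-empty
+// elements, e.g.
+//
+//   cute::tuple<cute::Int<2>, cute::Int<4>> a;               // every element is empty
+//   cute::tuple<cute::Int<2>, int>          b(cute::Int<2>{}, 7);
+//   static_assert(cute::is_empty<decltype(a)>::value);       // no storage needed
+//   // sizeof(b) is expected to equal sizeof(int): only the int element is stored.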
+ +template +struct ESO; + +template +static constexpr bool is_first_empty_v = cute::is_empty::value; +template +static constexpr bool is_rest_empty_v = (cute::is_empty::value && ...); + +template +using ESO_t = ESO, is_rest_empty_v, T...>; + +// Empty First and Empty Rest... +template +struct ESO { + CUTE_HOST_DEVICE constexpr + ESO() {} + + CUTE_HOST_DEVICE constexpr + ESO(First const&, Rest const&...) {} +}; + +// NonEmpty First and Empty Rest... +template +struct ESO { + CUTE_HOST_DEVICE constexpr + ESO() : first_{} {} + + CUTE_HOST_DEVICE constexpr + ESO(First const& first, Rest const&...) : first_{first} {} + + First first_; +}; + +// Empty First and NonEmpty Rest... +template +struct ESO { + CUTE_HOST_DEVICE constexpr + ESO() : rest_{} {} + + CUTE_HOST_DEVICE constexpr + ESO(First const&, Rest const&... rest) : rest_{rest...} {} + + ESO_t rest_; +}; + +// NonEmpty T and NonEmpty Rest... +template +struct ESO { + CUTE_HOST_DEVICE constexpr + ESO() : first_{}, rest_{} {} + + CUTE_HOST_DEVICE constexpr + ESO(First const& first, Rest const&... rest) : first_{first}, rest_{rest...} {} + + First first_; + ESO_t rest_; +}; + +// Get Nth value from ESO +template +CUTE_HOST_DEVICE constexpr +R +getr(S&& s) noexcept +{ + if constexpr (N == 0) { + return static_cast(s).first_; + } else { + return getr(static_cast(s).rest_); + } + CUTE_GCC_UNREACHABLE; +} + +// Compilers disagree on decltype(auto), so these implementations avoid it at cost +template +CUTE_HOST_DEVICE constexpr +cute::conditional_t>>::value, + cute::tuple_element_t>, + cute::tuple_element_t> const&> +getv_cr(ESO const& s) noexcept +{ + if constexpr (cute::is_empty>>::value) { + return {}; + } else { + return getr> const&, N>(s); + } + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +cute::conditional_t>>::value, + cute::tuple_element_t>, + cute::tuple_element_t> &> +getv_r(ESO& s) noexcept +{ + if constexpr (cute::is_empty>>::value) { + return {}; + } else { + return getr> &, N>(s); + } + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +cute::conditional_t>>::value, + cute::tuple_element_t>, + cute::tuple_element_t> &&> +getv_rr(ESO&& s) noexcept +{ + if constexpr (cute::is_empty>>::value) { + return {}; + } else { + return getr> &&, N>(static_cast&&>(s)); + } + CUTE_GCC_UNREACHABLE; +} + +} // end namespace eso + +template +struct tuple : eso::ESO_t +{ + CUTE_HOST_DEVICE constexpr + tuple() {} + + CUTE_HOST_DEVICE constexpr + tuple(T const&... t) : eso::ESO_t(t...) {} +}; + +template <> +struct tuple<> {}; + +// +// make_tuple (value-based implementation) +// + +template +CUTE_HOST_DEVICE constexpr +tuple +make_tuple(T const&... t) +{ + return {t...}; +} + +// Returns the element in the ith position of the tuple +template +CUTE_HOST_DEVICE constexpr +decltype(auto) +get(tuple const& t) noexcept +{ + static_assert(I < sizeof...(T), "Index out of range"); + return eso::getv_cr(t); +} + +template +CUTE_HOST_DEVICE constexpr +decltype(auto) +get(tuple& t) noexcept +{ + static_assert(I < sizeof...(T), "Index out of range"); + return eso::getv_r(t); +} + +template +CUTE_HOST_DEVICE constexpr +decltype(auto) +get(tuple&& t) noexcept +{ + static_assert(I < sizeof...(T), "Index out of range"); + return eso::getv_rr(static_cast&&>(t)); +} + +// Returns the first position of type X (as a static integer) in the tuple +// type's argument list. 
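+//
+// Illustrative usage (added commentary; an assumed sketch, not from the upstream header):
+//
+//   auto t   = cute::make_tuple(1.0f, 2, 3.0);   // cute::tuple<float, int, double>
+//   auto pos = cute::find<int>(t);               // a static integer, cute::C<1>
+//   static_assert(decltype(pos)::value == 1);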
+template +CUTE_HOST_DEVICE constexpr +auto +find(tuple const&) noexcept +{ + return cute::C...>>{}; +} + +// +// Custom is_tuple trait simply checks the existence of tuple_size +// and assumes get(.), tuple_element +// +namespace detail { + +template +auto has_tuple_size( T*) -> bool_constant<(0 <= tuple_size::value)>; +auto has_tuple_size(...) -> false_type; + +} // end namespace detail + +template +struct is_tuple : decltype(detail::has_tuple_size((T*)0)) {}; + +template +static constexpr bool is_tuple_v = cute::is_tuple::value; + +// +// tuple_cat concatenates multiple cute::tuple into a single cute::tuple, +// just like std::tuple_cat for std::tuple. +// + +#if 0 +// Original implementation + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr +auto +tuple_cat(T0 const& t0, T1 const& t1, + index_sequence, index_sequence) +{ + return cute::make_tuple(get(t0)..., get(t1)...); +} + +} // end namespace detail + +CUTE_HOST_DEVICE constexpr +tuple<> +tuple_cat() +{ + return {}; +} + +template ::value)> +CUTE_HOST_DEVICE constexpr +Tuple const& +tuple_cat(Tuple const& t) +{ + return t; +} + +template +CUTE_HOST_DEVICE constexpr +auto +tuple_cat(T0 const& t0, T1 const& t1) +{ + return detail::tuple_cat(t0, t1, + make_index_sequence::value>{}, + make_index_sequence::value>{}); +} + +template +CUTE_HOST_DEVICE constexpr +auto +tuple_cat(T0 const& t0, T1 const& t1, T2 const& t2, Ts const&... ts) +{ + return cute::tuple_cat(cute::tuple_cat(t0,t1),t2,ts...); +} +#endif + +#if 1 +// Extended implementation + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr +auto +tuple_cat(T0 const& t0, T1 const& t1, + index_sequence, index_sequence) +{ + return cute::make_tuple(get(t0)..., get(t1)...); +} + +template +CUTE_HOST_DEVICE constexpr +auto +tuple_cat(T0 const& t0, T1 const& t1, T2 const& t2, + index_sequence, index_sequence, index_sequence) +{ + return cute::make_tuple(get(t0)..., get(t1)..., get(t2)...); +} + +template +CUTE_HOST_DEVICE constexpr +auto +tuple_cat(T0 const& t0, T1 const& t1, T2 const& t2, T3 const& t3, + index_sequence, index_sequence, index_sequence, index_sequence) +{ + return cute::make_tuple(get(t0)..., get(t1)..., get(t2)..., get(t3)...); +} + +template +CUTE_HOST_DEVICE constexpr +auto +tuple_cat(T0 const& t0, T1 const& t1, T2 const& t2, T3 const& t3, T4 const& t4, + index_sequence, index_sequence, index_sequence, index_sequence, index_sequence) +{ + return cute::make_tuple(get(t0)..., get(t1)..., get(t2)..., get(t3)..., get(t4)...); +} + +template +struct tuple_cat_static; + +template +struct tuple_cat_static, tuple> { + using type = tuple; +}; + +} // end namespace detail + +CUTE_HOST_DEVICE constexpr +tuple<> +tuple_cat() +{ + return {}; +} + +template ::value)> +CUTE_HOST_DEVICE constexpr +Tuple const& +tuple_cat(Tuple const& t) +{ + return t; +} + +template +CUTE_HOST_DEVICE constexpr +auto +tuple_cat(T0 const& t0, T1 const& t1) +{ + if constexpr (is_static::value && is_static::value && + is_tuple::value && is_tuple::value) { + return typename detail::tuple_cat_static::type{}; + } else { + return detail::tuple_cat(t0, t1, + make_index_sequence::value>{}, + make_index_sequence::value>{}); + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +tuple_cat(T0 const& t0, T1 const& t1, T2 const& t2) +{ + return detail::tuple_cat(t0, t1, t2, + make_index_sequence::value>{}, + make_index_sequence::value>{}, + make_index_sequence::value>{}); +} + +template +CUTE_HOST_DEVICE constexpr +auto +tuple_cat(T0 const& t0, T1 const& t1, T2 const& t2, 
T3 const& t3) +{ + return detail::tuple_cat(t0, t1, t2, t3, + make_index_sequence::value>{}, + make_index_sequence::value>{}, + make_index_sequence::value>{}, + make_index_sequence::value>{}); +} + +template +CUTE_HOST_DEVICE constexpr +auto +tuple_cat(T0 const& t0, T1 const& t1, T2 const& t2, T3 const& t3, T4 const& t4) +{ + return detail::tuple_cat(t0, t1, t2, t3, t4, + make_index_sequence::value>{}, + make_index_sequence::value>{}, + make_index_sequence::value>{}, + make_index_sequence::value>{}, + make_index_sequence::value>{}); +} + +template +CUTE_HOST_DEVICE constexpr +auto +tuple_cat(T0 const& t0, T1 const& t1, T2 const& t2, T3 const& t3, T4 const& t4, T5 const& t5, Ts const&... ts) +{ + return cute::tuple_cat(cute::tuple_cat(t0,t1,t2,t3,t4), cute::tuple_cat(t5, ts...)); +} +#endif + +#if 0 +// Outer-Inner indexing trick to concat all tuples at once + +namespace detail { + +template +struct tuple_cat_helper +{ + static constexpr cute::array ns = {Ns...}; + + static constexpr size_t total_size() { + size_t sum = 0; + for (size_t n : ns) sum += n; + return sum; + } + static constexpr size_t total_size_ = total_size(); + + static constexpr auto values() { + cute::array outer_inner = {}; + + size_t idx = 0; + for (size_t i = 0; i < ns.size(); ++i) { + for (size_t j = 0; j < ns[i]; ++j, ++idx) { + outer_inner[idx][0] = i; + outer_inner[idx][1] = j; + } + } + return outer_inner; + } + static constexpr auto outer_inner_ = values(); + + using total_sequence = make_index_sequence; +}; + +template +CUTE_HOST_DEVICE constexpr +auto +tuple_cat(Tuple const& t, index_sequence) +{ + return cute::make_tuple(get(get(t))...); +} + +template +CUTE_HOST_DEVICE constexpr +auto +tuple_cat(T0 const& t0, T1 const& t1, + index_sequence, index_sequence) +{ + return cute::make_tuple(get(t0)..., get(t1)...); +} + +} // end namespace detail + +CUTE_HOST_DEVICE constexpr +tuple<> +tuple_cat() +{ + return {}; +} + +template ::value)> +CUTE_HOST_DEVICE constexpr +Tuple const& +tuple_cat(Tuple const& t) +{ + return t; +} + +template +CUTE_HOST_DEVICE constexpr +auto +tuple_cat(T0 const& t0, T1 const& t1) +{ + return detail::tuple_cat(t0, t1, + make_index_sequence::value>{}, + make_index_sequence::value>{}); +} + +template +CUTE_HOST_DEVICE constexpr +auto +tuple_cat(Tuples const&... ts) +{ + using Helper = detail::tuple_cat_helper::value...>; + return detail::tuple_cat(cute::make_tuple(ts...), typename Helper::total_sequence{}); +} +#endif + +// +// Equality operators +// + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr +auto +equal_impl(TupleA const& a, TupleB const& b, index_sequence) +{ + return (cute::true_type{} && ... 
&& (get(a) == get(b))); +} + +} // end namespace detail + +template ::value && is_tuple::value)> +CUTE_HOST_DEVICE constexpr +auto +operator==(TupleT const& t, TupleU const& u) +{ + if constexpr (tuple_size::value == tuple_size::value) { + return detail::equal_impl(t, u, make_index_sequence::value>{}); + } else { + return cute::false_type{}; + } + + CUTE_GCC_UNREACHABLE; +} + +template ::value ^ is_tuple::value)> +CUTE_HOST_DEVICE constexpr +auto +operator==(TupleT const& t, TupleU const& u) +{ + return cute::false_type{}; +} + +template ::value && is_tuple::value)> +CUTE_HOST_DEVICE constexpr +auto +operator!=(TupleT const& t, TupleU const& u) +{ + return !(t == u); +} + +template ::value ^ is_tuple::value)> +CUTE_HOST_DEVICE constexpr +auto +operator!=(TupleT const& t, TupleU const& u) +{ + return cute::true_type{}; +} + +// +// Comparison operators +// + +// +// There are many ways to compare tuple of elements and because CuTe is built +// on parameterizing layouts of coordinates, some comparisons are appropriate +// only in certain cases. +// -- lexicographical comparison [reverse, reflected, revref] +// -- colexicographical comparison [reverse, reflected, revref] +// -- element-wise comparison [any,all] +// This can be very confusing. To avoid errors in selecting the appropriate +// comparison, op<|op<=|op>|op>= are *not* implemented for cute::tuple. +// +// That said, see int_tuple for more explicitly named common comparison ops. +// + +// +// Display utilities +// + +namespace detail { + +template +CUTE_HOST_DEVICE void print_tuple(Tuple const& t, index_sequence, char s = '(', char e = ')') +{ + using cute::print; + if (sizeof...(Is) == 0) { + print(s); + } else { + ((void(print(Is == 0 ? s : ',')), void(print(get(t)))), ...); + } + print(e); +} + +#if !defined(__CUDACC_RTC__) +template +CUTE_HOST std::ostream& print_tuple_os(std::ostream& os, Tuple const& t, index_sequence, char s = '(', char e = ')') +{ + if (sizeof...(Is) == 0) { + os << s; + } else { + (void(os << (Is == 0 ? 
s : ',') << get(t)), ...); + } + return os << e; +} +#endif // !defined(__CUDACC_RTC__) + +} // end namespace detail + +template ::value)> +CUTE_HOST_DEVICE void print(Tuple const& t) +{ + return detail::print_tuple(t, make_index_sequence::value>{}); +} + +#if !defined(__CUDACC_RTC__) +template ::value)> +CUTE_HOST std::ostream& operator<<(std::ostream& os, Tuple const& t) +{ + return detail::print_tuple_os(os, t, make_index_sequence::value>{}); +} +#endif // !defined(__CUDACC_RTC__) + +} // end namespace cute + +namespace CUTE_STL_NAMESPACE +{ + +template +struct tuple_size> + : CUTE_STL_NAMESPACE::integral_constant +{}; + +template +struct tuple_element> + : CUTE_STL_NAMESPACE::tuple_element> +{}; + +} // end namespace CUTE_STL_NAMESPACE + +#ifdef CUTE_STL_NAMESPACE_IS_CUDA_STD +namespace std +{ + +#if defined(__CUDACC_RTC__) +template +struct tuple_size; + +template +struct tuple_element; +#endif + +template +struct tuple_size> + : CUTE_STL_NAMESPACE::integral_constant +{}; + +template +struct tuple_element> + : CUTE_STL_NAMESPACE::tuple_element> +{}; + +} // end namespace std +#endif // CUTE_STL_NAMESPACE_IS_CUDA_STD diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/container/type_list.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/container/type_list.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ff9498eb7162ddf5ecfb8f4826dfe142fb6fb164 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/container/type_list.hpp @@ -0,0 +1,125 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include // CUTE_HOST_DEVICE, CUTE_STL_NAMESPACE +#include + +namespace cute +{ + +template +struct type_list {}; + +// get for type_list +// Get an instance of the Ith type in the pack T... +// Requires tuple_element_t> to have std::is_default_constructible +template +CUTE_HOST_DEVICE constexpr +CUTE_STL_NAMESPACE::tuple_element_t> +get(type_list const&) noexcept { + return {}; +} + +// Find the index of the first true in the pack B... +template +struct find_true { + CUTE_HOST_DEVICE static constexpr size_t find() { + size_t i = 0; + (void) ((B ? true : (++i, false)) || ...); + return i; + } + static constexpr size_t value = find(); +}; + +template +static constexpr size_t find_true_v = find_true::value; + +// find for type_list +// Finds the first position of type X (as a static integer) in the T... pack +template +CUTE_HOST_DEVICE constexpr +CUTE_STL_NAMESPACE::integral_constant...>> +find(type_list const&) noexcept { + return {}; +} + +} // end namespace cute + +// +// Specialize tuple-related functionality for cute::type_list +// +#include "cutlass/cutlass.h" +#if defined(__CUDACC_RTC__) +#include CUDA_STD_HEADER(tuple) +#else +#include +#endif + +namespace CUTE_STL_NAMESPACE +{ + +template +struct tuple_size> + : CUTE_STL_NAMESPACE::integral_constant +{}; + +template +struct tuple_element> + : CUTE_STL_NAMESPACE::tuple_element> +{}; + +} // end namespace std + +#ifdef CUTE_STL_NAMESPACE_IS_CUDA_STD +namespace std +{ + +#if defined(__CUDACC_RTC__) +template +struct tuple_size; + +template +struct tuple_element; +#endif + +template +struct tuple_size> + : CUTE_STL_NAMESPACE::integral_constant +{}; + +template +struct tuple_element> + : CUTE_STL_NAMESPACE::tuple_element> +{}; + +} // end namespace std +#endif // CUTE_STL_NAMESPACE_IS_CUDA_STD diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/int_tuple.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/int_tuple.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f9d70004450a138e286cb9a12390996a01c3b035 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/int_tuple.hpp @@ -0,0 +1,917 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include // CUTE_HOST_DEVICE +#include // cute::array +#include // cute::is_tuple +#include // cute::Int +#include // cute::seq +#include // cute::transform + +/** IntTuple is an integer or a tuple of IntTuples. + * This file holds utilities for working with IntTuples, + * but does not hold a concrete concept or class of IntTuple. + */ + +namespace cute +{ + +// Implementation of get<0>(Integral). +// Even though is_tuple is false and tuple_size doesn't compile, +// CuTe defines rank(Integral) as 1, so it's useful for get<0>(Integral) to return its input +template >::value)> +CUTE_HOST_DEVICE constexpr +decltype(auto) +get(T&& t) noexcept +{ + static_assert(I == 0, "Index out of range"); + return static_cast(t); +} + +// Custom recursive get for anything that implements get(.) (for a single integer I). +template +CUTE_HOST_DEVICE constexpr +decltype(auto) +get(T&& t) noexcept +{ + return get(get(static_cast(t))); +} + +// +// rank +// + +template +CUTE_HOST_DEVICE constexpr +auto +rank(IntTuple const& t) +{ + if constexpr (sizeof...(Is) == 0) { + if constexpr (is_tuple::value) { + return Int::value>{}; + } else { + return Int<1>{}; + } + } else { + return rank(get(t)); + } + + CUTE_GCC_UNREACHABLE; +} + +template +using rank_t = decltype(rank(declval())); + +template +static constexpr auto rank_v = rank_t::value; + +// +// shape +// + +template +CUTE_HOST_DEVICE constexpr +auto +shape(IntTuple const& s) +{ + if constexpr (is_tuple::value) { + return transform(s, [](auto const& a) { return shape(a); }); + } else { + return s; + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +shape(IntTuple const& s) +{ + if constexpr (is_tuple::value) { + return shape(get(s)); + } else { + return get(shape(s)); + } + + CUTE_GCC_UNREACHABLE; +} + +// +// max +// + +template +CUTE_HOST_DEVICE constexpr +auto +max(T0 const& t0, Ts const&... ts) +{ + if constexpr (is_tuple::value) { + return cute::max(cute::apply(t0, [](auto const&... a){ return cute::max(a...); }), ts...); + } else if constexpr (sizeof...(Ts) == 0) { + return t0; + } else { + return cute::max(t0, cute::max(ts...)); + } + + CUTE_GCC_UNREACHABLE; +} + +// +// min +// + +template +CUTE_HOST_DEVICE constexpr +auto +min(T0 const& t0, Ts const&... ts) +{ + if constexpr (is_tuple::value) { + return cute::min(cute::apply(t0, [](auto const&... a){ return cute::min(a...); }), ts...); + } else if constexpr (sizeof...(Ts) == 0) { + return t0; + } else { + return cute::min(t0, cute::min(ts...)); + } + + CUTE_GCC_UNREACHABLE; +} + +// +// gcd +// + +template +CUTE_HOST_DEVICE constexpr +auto +gcd(T0 const& t0, Ts const&... 
ts) +{ + if constexpr (is_tuple::value) { + return cute::gcd(cute::apply(t0, [](auto const&... a){ return cute::gcd(a...); }), ts...); + } else if constexpr (sizeof...(Ts) == 0) { + return t0; + } else { + return cute::gcd(t0, cute::gcd(ts...)); + } + + CUTE_GCC_UNREACHABLE; +} + +// +// depth +// + +template +CUTE_HOST_DEVICE constexpr +auto +depth(IntTuple const& t) +{ + if constexpr (sizeof...(Is) == 0) { + if constexpr (is_tuple::value) { + return Int<1>{} + cute::apply(t, [](auto const&... v){ return cute::max(depth(v)...); }); + } else { + return Int<0>{}; + } + } else { + return depth(get(t)); + } + + CUTE_GCC_UNREACHABLE; +} + +template +using depth_t = decltype(depth(declval())); + +template +static constexpr auto depth_v = depth_t::value; + +// +// product +// + +// Implementation of product as a function object +struct Product +{ + template + CUTE_HOST_DEVICE constexpr + auto + operator()(IntTuple const& a) const + { + if constexpr (is_tuple::value) { + if constexpr (tuple_size::value == 0) { + return Int<1>{}; + } else { + return cute::transform_apply(a, Product{}, multiplies_unary_lfold{}); + } + } else if constexpr (cute::is_integral::value) { + return a; + } + + CUTE_GCC_UNREACHABLE; + } +}; +// Callable product function object +CUTE_INLINE_CONSTANT Product product; + +// Return a rank(t) tuple @a result such that get(@a result) = product(get(@a t)) +template +CUTE_HOST_DEVICE constexpr +auto +product_each(Tuple const& t) +{ + return transform(wrap(t), product); +} + +// Take the product of Tuple at the leaves of TupleG +template +CUTE_HOST_DEVICE constexpr +auto +product_like(Tuple const& tuple, TupleG const& guide) +{ + return transform_leaf(guide, tuple, [](auto const& g, auto const& t) { return product(t); }); +} + +// Return the product of elements in a mode +template +CUTE_HOST_DEVICE constexpr +auto +size(IntTuple const& a) +{ + if constexpr (sizeof...(Is) == 0) { + return product(a); + } else { + return size(get(a)); + } + + CUTE_GCC_UNREACHABLE; +} + +template +static constexpr auto size_v = decltype(size(declval()))::value; + +// +// sum +// + +template +CUTE_HOST_DEVICE constexpr +auto +sum(IntTuple const& a) +{ + if constexpr (is_tuple::value) { + return cute::apply(a, [](auto const&... v){ return (Int<0>{} + ... + sum(v)); }); + } else { + return a; + } + + CUTE_GCC_UNREACHABLE; +} + +// +// inner_product +// + +template +CUTE_HOST_DEVICE constexpr +auto +inner_product(IntTupleA const& a, IntTupleB const& b) +{ + if constexpr (is_tuple::value && is_tuple::value) { + static_assert(tuple_size::value == tuple_size::value, "Mismatched ranks"); + return transform_apply(a, b, [](auto const& x, auto const& y) { return inner_product(x,y); }, + [](auto const&... v) { return (Int<0>{} + ... 
+ v); }); + } else { + return a * b; + } + + CUTE_GCC_UNREACHABLE; +} + +// +// ceil_div +// + +template +CUTE_HOST_DEVICE constexpr +auto +ceil_div(IntTupleA const& a, IntTupleB const& b) +{ + if constexpr (is_tuple::value) { + if constexpr (is_tuple::value) { // tuple tuple + static_assert(tuple_size::value >= tuple_size::value, "Mismatched ranks"); + constexpr int R = tuple_size::value; // Missing ranks in TupleB are implicitly 1 + return transform(a, append(b,Int<1>{}), [](auto const& x, auto const& y) { return ceil_div(x,y); }); + } else { // tuple int + auto [result, rest] = fold(a, cute::make_tuple(cute::make_tuple(), b), + [] (auto const& init, auto const& ai) { + return cute::make_tuple(append(get<0>(init), ceil_div(ai, get<1>(init))), ceil_div(get<1>(init), ai)); + }); + return result; + } + } else + if constexpr (is_tuple::value) { // int tuple + return ceil_div(a, product(b)); + } else { + return (a + b - Int<1>{}) / b; + } + + CUTE_GCC_UNREACHABLE; +} + +// +// round_up +// Round @a a up to the nearest multiple of @a b. +// + +template +CUTE_HOST_DEVICE constexpr +auto +round_up(IntTupleA const& a, IntTupleB const& b) +{ + if constexpr (is_tuple::value && is_tuple::value) { + static_assert(tuple_size::value >= tuple_size::value, "Mismatched ranks"); + constexpr int R = tuple_size::value; // Missing ranks in TupleB are implicitly 1 + return transform(a, append(b,Int<1>{}), [](auto const& x, auto const& y) { return round_up(x,y); }); + } else { + return ((a + b - Int<1>{}) / b) * b; + } + + CUTE_GCC_UNREACHABLE; +} + +/** Division for Shapes + * Case Tuple Tuple: + * Perform shape_div element-wise + * Case Tuple Int: + * Fold the division of b across each element of a + * Example: shape_div((4,5,6),40) -> shape_div((1,5,6),10) -> shape_div((1,1,6),2) -> (1,1,3) + * Case Int Tuple: + * Return shape_div(a, product(b)) + * Case Int Int: + * Enforce the divisibility condition a % b == 0 || b % a == 0 when possible + * Return ceil_div(a, b) + */ +template +CUTE_HOST_DEVICE constexpr +auto +shape_div(IntTupleA const& a, IntTupleB const& b) +{ + if constexpr (is_tuple::value) { + if constexpr (is_tuple::value) { // tuple tuple + static_assert(tuple_size::value == tuple_size::value, "Mismatched ranks"); + return transform(a, b, [](auto const& x, auto const& y) { return shape_div(x,y); }); + } else { // tuple int + auto [result, rest] = fold(a, cute::make_tuple(cute::make_tuple(), b), + [] (auto const& init, auto const& ai) { + return cute::make_tuple(append(get<0>(init), shape_div(ai, get<1>(init))), shape_div(get<1>(init), ai)); + }); + return result; + } + } else + if constexpr (is_tuple::value) { // int tuple + return shape_div(a, product(b)); + } else { + // Strong divisibility condition + //static_assert((IntTupleA::value % IntTupleB::value == 0) or (IntTupleB::value % IntTupleA::value == 0), "Divisibility Condition"); + + // Weak divisibility condition + if constexpr (is_static::value and is_static::value) { + static_assert(((IntTupleA::value % IntTupleB::value) == 0) or ((IntTupleB::value % IntTupleA::value) == 0), "Divisibility Condition"); + } else { + // DEBUG assert can cause extra registers and inappropriate compile-time/run-time failure + //assert((((a % b) == 0) or ((a % b) == 0)) && "Divisibility Condition"); + } + + return (a + b - Int<1>{}) / b; + } + + CUTE_GCC_UNREACHABLE; +} + +/** Return a tuple the same profile as A scaled by corresponding elements in B + */ +template +CUTE_HOST_DEVICE constexpr +auto +elem_scale(A const& a, B const& b) +{ + if constexpr 
(is_tuple::value) { + return transform(a, b, [](auto const& x, auto const& y) { return elem_scale(x,y); }); + } else { + return a * product(b); + } + + CUTE_GCC_UNREACHABLE; +} + +/** Test if two IntTuple have the same profile (hierarchical rank division) + */ +template +CUTE_HOST_DEVICE constexpr +auto +congruent(IntTupleA const& a, IntTupleB const& b) +{ + return bool_constant::value>{}; +} + +template +using is_congruent = decltype(congruent(declval(), declval())); + +/** Test if two IntTuple have the similar profiles up to Shape A (hierarchical rank division) + * weakly_congruent is a partial order on A and B: A <= B + */ +template +CUTE_HOST_DEVICE constexpr +auto +weakly_congruent(IntTupleA const& a, IntTupleB const& b) +{ + if constexpr (is_tuple::value && is_tuple::value) { + if constexpr (tuple_size::value != tuple_size::value) { + return false_type{}; + } else { + return transform_apply(a, b, [](auto const& x, auto const& y) { return weakly_congruent(x,y); }, + [](auto const&... z) { return (true_type{} && ... && z); }); + } + } else if constexpr (is_integral::value) { + return true_type{}; + } else if constexpr (is_integral::value) { + return false_type{}; + } else { + return weakly_congruent(shape(a), shape(b)); + } + + CUTE_GCC_UNREACHABLE; +} + +template +using is_weakly_congruent = decltype(weakly_congruent(declval(), declval())); + +/** Test if Shape A is compatible with Shape B: + * the size of A and B are the same, and + * any coordinate into A can also be used as a coordinate into B + * Equivalently, the size of Shape B is the same as Shape A at each terminal of Shape A. + * compatible is a partial order on A and B: A <= B + */ +template +CUTE_HOST_DEVICE constexpr +auto +compatible(IntTupleA const& a, IntTupleB const& b) +{ + if constexpr (is_tuple::value && is_tuple::value) { + if constexpr (tuple_size::value != tuple_size::value) { + return false_type{}; + } else { + return transform_apply(a, b, [](auto const& x, auto const& y) { return compatible(x,y); }, + [](auto const&... z) { return (true_type{} && ... && z); }); + } + } else if constexpr (is_integral::value) { + return a == size(b); + } else if constexpr (is_integral::value) { + return false_type{}; + } else { + return compatible(shape(a), shape(b)); + } + + CUTE_GCC_UNREACHABLE; +} + +template +using is_compatible = decltype(compatible(declval(), declval())); + +/** Test if Shape A is evenly divided by Tiler B + * @returns Static or dynamic boolean + * @post if result is true_type, then + * size(a) == logical_divide(make_layout(shape(a)),b) will always compile + * and result in true_type. + */ +template +CUTE_HOST_DEVICE constexpr +auto +evenly_divides(Shape const& a, Tiler const& b) +{ + if constexpr (is_tuple::value) { + if constexpr (rank_v > rank_v) { + return false_type{}; + } else { + return transform_apply(b, a, [](auto const& x, auto const& y) { return evenly_divides(y,x); }, + [](auto const&... z) { return (true_type{} && ... 
&& z); }); + } + } else { + return size(a) == size(b) * size(ceil_div(shape(a), b)); + } + + CUTE_GCC_UNREACHABLE; +} + +/** Replace the elements of Tuple B that are paired with an Int<0> with an Int<1> + */ +template +CUTE_HOST_DEVICE constexpr +auto +filter_zeros(IntTupleA const& a, IntTupleB const& b) +{ + if constexpr (is_tuple::value) { + return transform(a, b, [](auto const& x, auto const& y) { return filter_zeros(x,y); }); + } else if constexpr (is_constant<0, IntTupleA>::value) { + return repeat_like(b, Int<1>{}); + } else { + return b; + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +filter_zeros(Tuple const& t) +{ + return filter_zeros(t, t); +} + +// +// Static sorting utilities in detail:: +// + +namespace detail { + +// Some compilers fail to constexpr evaluate quick_sort +// template +// constexpr cute::array quick_sort(cute::array a, int lo = 0, int hi = N-1) { +// if (hi <= lo) return; +// int p = lo; +// for (int i = lo; i < hi; ++i) { +// if (a[i] < a[hi]) { +// T tmp = a[p]; a[p] = a[i]; a[i] = tmp; +// ++p; +// } +// } +// T tmp = a[p]; a[p] = a[hi]; a[hi] = tmp; +// a = quick_sort(a, lo, p-1); +// a = quick_sort(a, p+1, hi); +// return a; +// } + +template +constexpr cute::array exchange_sort(cute::array a) { + for (size_t i = 0; i < N; ++i) { + for (size_t j = i+1; j < N; ++j) { + if (a[j] < a[i]) { + T tmp = a[j]; a[j] = a[i]; a[i] = tmp; + } + } + } + return a; +} + +template >> +struct Sort : Sort, to_seq_t> {}; + +template +struct Sort, seq> { + static_assert(sizeof...(Vs) == sizeof...(Is)); + static constexpr cute::array orig_array = {Vs...}; + static constexpr cute::array sort_array = exchange_sort(orig_array); + using type = seq; +}; + +struct kvpair { + int key, val; + constexpr bool operator<(kvpair const& o) const { return key < o.key; }; +}; + +template >> +struct SortByKey : SortByKey, to_seq_t, to_seq_t> {}; + +template +struct SortByKey, seq, seq> { + static_assert(sizeof...(Ks) == sizeof...(Vs)); + static_assert(sizeof...(Ks) == sizeof...(Is)); + static constexpr cute::array orig_array = {kvpair{Ks,Vs}...}; + static constexpr cute::array sort_array = exchange_sort(orig_array); + using key_type = seq; + using val_type = seq; +}; + +} // end namespace detail + +// +// Converters and constructors with arrays and params +// + +/** Make an IntTuple of rank N from an Indexable array. + * Access elements up to a dynamic index n, then use init (requires compatible types) + * Consider cute::take if all indexing is known to be valid + * \code + * std::vector a = {6,3,4}; + * auto tup = make_int_tuple<5>(a, a.size(), 0) // (6,3,4,0,0) + * \endcode + */ +template +CUTE_HOST_DEVICE constexpr +auto +make_int_tuple(Indexable const& t, int n, T const& init) +{ + static_assert(N > 0); + if constexpr (N == 1) { + return 0 < n ? t[0] : init; + } else { + return transform(make_seq{}, [&](auto i) { return i < n ? 
t[i] : init; }); + } + + CUTE_GCC_UNREACHABLE; +} + +/** Fill the dynamic values of a Tuple with values from another Tuple + * \code + * auto params = make_tuple(6,3,4); + * cute::tuple, cute::tuple>, int, Int<2>> result; + * fill_int_tuple_from(result, params); // (_1,(6,3,_3),4,_2) + * \endcode + */ +template +CUTE_HOST_DEVICE constexpr +auto +fill_int_tuple_from(Tuple& result, TupleV const& vals) +{ + return fold(result, vals, [](auto const& init, auto&& r) { + if constexpr (is_static>::value) { // Skip static elements of result + return init; + } else if constexpr (is_tuple>::value) { // Recurse into tuples + return fill_int_tuple_from(r, init); + } else { // Assign and consume arg + static_assert(tuple_size>::value > 0, "Not enough values to fill with!"); + r = get<0>(init); + return remove<0>(init); + } + + CUTE_GCC_UNREACHABLE; + }); +} + +/** Make a "Tuple" by filling in the dynamic values in order from the arguments + * \code + * using result_t = cute::tuple, cute::tuple>, int, Int<2>>; + * auto result = make_int_tuple_from(6,3,4); // (_1,(6,3,_3),4,_2) + * \endcode + */ +template +CUTE_HOST_DEVICE constexpr +Tuple +make_int_tuple_from(Ts const&... ts) +{ + Tuple result = Tuple{}; + fill_int_tuple_from(result, cute::make_tuple(ts...)); + return result; +} + +/** Convert a tuple to a flat homogeneous array of type T + * \code + * auto tup = cute::make_tuple(Int<1>{}, cute::make_tuple(6,3,Int<3>{}),4,Int<2>{}); + * cute::array result = to_array(tup); // [1,6,3,3,4,2] + * \endcode + */ +template +CUTE_HOST_DEVICE constexpr +auto +to_array(IntTuple const& t) +{ + auto flat_t = flatten_to_tuple(t); + constexpr int N = tuple_size::value; + cute::array result; + for_each(make_seq{}, [&] (auto i) { result[i] = get(flat_t); }); + return result; +} + +// +// Comparison operators +// + +// +// There are many ways to compare tuple of elements and because CuTe is built +// on parameterizing layouts of coordinates, some comparisons are appropriate +// only in certain cases. +// -- lexicographical comparison [reverse, reflected, revref] : Correct for coords in RowMajor Layout +// -- colexicographical comparison [reverse, reflected, revref] : Correct for coords in ColMajor Layout +// -- element-wise comparison [any,all] : +// This can be very confusing. To avoid errors in selecting the appropriate +// comparison, op<|op<=|op>|op>= are *not* implemented for cute::tuple. +// +// When actually desiring to order coordinates, the user should map them to +// their indices within the Layout they came from: +// e.g. layoutX(coordA) < layoutX(coordB) +// That said, we implement the three most common ways to compare tuples below. +// These are implemented with slighly more explicit names than op<. 
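+//
+// Illustrative values (added commentary; an assumed sketch, not from the upstream header):
+// for coordinates a = (1,2) and b = (2,1),
+//   lex_less(a, b)    is true   (mode-0 is compared first, as for RowMajor coords)
+//   colex_less(a, b)  is false  (the last mode is compared first, as for ColMajor coords)
+//   elem_less(a, b)   is false  (requires every element of a to be less than b's)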
+// + +template +CUTE_HOST_DEVICE constexpr +auto +lex_less(IntTupleA const& a, IntTupleB const& b); + +template +CUTE_HOST_DEVICE constexpr +auto +colex_less(IntTupleA const& a, IntTupleB const& b); + +template +CUTE_HOST_DEVICE constexpr +auto +elem_less(IntTupleA const& a, IntTupleB const& b); + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr +auto +lex_less_impl(TupleA const& a, TupleB const& b) +{ + if constexpr (I == tuple_size::value) { + return cute::false_type{}; // Terminal: TupleB is exhausted + } else if constexpr (I == tuple_size::value) { + return cute::true_type{}; // Terminal: TupleA is exhausted, TupleB is not exhausted + } else { + return lex_less(get(a), get(b)) || (get(a) == get(b) && lex_less_impl(a,b)); + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +colex_less_impl(TupleA const& a, TupleB const& b) +{ + if constexpr (I == tuple_size::value) { + return cute::false_type{}; // Terminal: TupleB is exhausted + } else if constexpr (I == tuple_size::value) { + return cute::true_type{}; // Terminal: TupleA is exhausted, TupleB is not exhausted + } else { + constexpr size_t A = tuple_size::value - 1 - I; + constexpr size_t B = tuple_size::value - 1 - I; + return colex_less(get(a), get(b)) || (get(a) == get(b) && colex_less_impl(a,b)); + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +elem_less_impl(TupleA const& a, TupleB const& b) +{ + if constexpr (I == tuple_size::value) { + return cute::true_type{}; // Terminal: TupleA is exhausted + } else if constexpr (I == tuple_size::value) { + return cute::false_type{}; // Terminal: TupleA is not exhausted, TupleB is exhausted + } else { + return elem_less(get(a), get(b)) && elem_less_impl(a,b); + } + + CUTE_GCC_UNREACHABLE; +} + +} // end namespace detail + +// Lexicographical comparison + +template +CUTE_HOST_DEVICE constexpr +auto +lex_less(IntTupleA const& a, IntTupleB const& b) +{ + if constexpr (is_tuple::value && is_tuple::value) { + return detail::lex_less_impl<0>(a, b); + } else { + return a < b; + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +lex_leq(T const& t, U const& u) { + return !lex_less(u, t); +} + +template +CUTE_HOST_DEVICE constexpr +auto +lex_gtr(T const& t, U const& u) { + return lex_less(u, t); +} + +template +CUTE_HOST_DEVICE constexpr +auto +lex_geq(T const& t, U const& u) { + return !lex_less(t, u); +} + +// Colexicographical comparison + +template +CUTE_HOST_DEVICE constexpr +auto +colex_less(IntTupleA const& a, IntTupleB const& b) +{ + if constexpr (is_tuple::value && is_tuple::value) { + return detail::colex_less_impl<0>(a, b); + } else { + return a < b; + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +colex_leq(T const& t, U const& u) { + return !colex_less(u, t); +} + +template +CUTE_HOST_DEVICE constexpr +auto +colex_gtr(T const& t, U const& u) { + return colex_less(u, t); +} + +template +CUTE_HOST_DEVICE constexpr +auto +colex_geq(T const& t, U const& u) { + return !colex_less(t, u); +} + +// Elementwise [all] comparison + +template +CUTE_HOST_DEVICE constexpr +auto +elem_less(IntTupleA const& a, IntTupleB const& b) +{ + if constexpr (is_tuple::value && is_tuple::value) { + return detail::elem_less_impl<0>(a, b); + } else { + return a < b; + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +elem_leq(T const& t, U const& u) { + return !elem_less(u, t); +} + +template +CUTE_HOST_DEVICE constexpr +auto +elem_gtr(T const& t, U const& u) { + 
return elem_less(u, t); +} + +template +CUTE_HOST_DEVICE constexpr +auto +elem_geq(T const& t, U const& u) { + return !elem_less(t, u); +} + +} // end namespace cute diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/layout.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/layout.hpp new file mode 100644 index 0000000000000000000000000000000000000000..cb161369cbb6b2d2961a126216dfdb3716d4e501 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/layout.hpp @@ -0,0 +1,1931 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include // cute::sizeof_bits + +namespace cute +{ + +// Aliases + +template +using Shape = cute::tuple; + +template +using Stride = cute::tuple; + +template +using Step = cute::tuple; + +template +using Coord = cute::tuple; + +template +using Tile = cute::tuple; + +template +CUTE_HOST_DEVICE constexpr +Shape +make_shape(Ts const&... t) { + return {t...}; +} +template +CUTE_HOST_DEVICE constexpr +Stride +make_stride(Ts const&... t) { + return {t...}; +} +template +CUTE_HOST_DEVICE constexpr +Step +make_step(Ts const&... t) { + return {t...}; +} +template +CUTE_HOST_DEVICE constexpr +Coord +make_coord(Ts const&... t) { + return {t...}; +} +template +CUTE_HOST_DEVICE constexpr +Tile +make_tile(Ts const&... t) +{ + return {t...}; +} + +// +// Layout +// + +template > +struct Layout + : private cute::tuple // EBO for static layouts +{ + // Expensive in compilation time... 
+ //static_assert(is_congruent::value, "Shape and Stride must be congruent"); + + // NOTE: This defaults static Shapes/Strides correctly, but not dynamic + CUTE_HOST_DEVICE constexpr + Layout(Shape const& shape = {}, Stride const& stride = {}) + : cute::tuple(shape, stride) + {} + + // + // Accessors + // + + static constexpr int rank = rank_v; + + CUTE_HOST_DEVICE constexpr + decltype(auto) + layout() { + return *this; + } + + CUTE_HOST_DEVICE constexpr + decltype(auto) + layout() const { + return *this; + } + + template + CUTE_HOST_DEVICE constexpr + decltype(auto) + shape() { + return get<0,I...>(static_cast&>(*this)); + } + + template + CUTE_HOST_DEVICE constexpr + decltype(auto) + shape() const { + return get<0,I...>(static_cast const&>(*this)); + } + + template + CUTE_HOST_DEVICE constexpr + decltype(auto) + stride() { + return get<1,I...>(static_cast&>(*this)); + } + + template + CUTE_HOST_DEVICE constexpr + decltype(auto) + stride() const { + return get<1,I...>(static_cast const&>(*this)); + } + + // + // Mappings + // + + // Map a logical coordinate to a linear index (Coord has no Underscore slice operators) + // OR + // Slice the layout and return the sublayout (Coord has an Underscore slice op) + template + CUTE_HOST_DEVICE constexpr + auto + operator()(Coord const& coord) const { + if constexpr (has_underscore::value) { + return slice(coord, *this); + } else { + return crd2idx(coord, shape(), stride()); + } + + CUTE_GCC_UNREACHABLE; + } + + // Convenience function for multi-dimensional coordinates + template + CUTE_HOST_DEVICE constexpr + auto + operator()(Coord0 const& c0, Coord1 const& c1, Coords const&... cs) const { + return operator()(make_coord(c0,c1,cs...)); + } + + // + // Compose + // + + template + CUTE_HOST_DEVICE constexpr + auto + compose(OtherLayout const& other) const { + return composition(*this, other); + } + + template + CUTE_HOST_DEVICE constexpr + auto + compose(Layouts const&... layouts) const { + return composition(*this, make_tile(layouts...)); + } + + template + CUTE_HOST_DEVICE constexpr + auto + with_shape(OtherShape const& shape) const { + return composition(*this, make_layout(shape)); + } + + template + CUTE_HOST_DEVICE constexpr + auto + with_shape(Shapes const&... shapes) const { + return composition(*this, make_layout(make_shape(shapes...))); + } + + // + // Tile + // + + template + CUTE_HOST_DEVICE constexpr + auto + tile(OtherLayout const& other) const { + return tiled_divide(*this, other); + } + + template + CUTE_HOST_DEVICE constexpr + auto + tile(Layouts const&... 
layouts) const { + return tiled_divide(*this, make_tile(layouts...)); + } + + // + // Utility + // + + // + // Index to Coordinate + // + + // NOTE: Only valid for compact layouts + + // Return the (hierarchical) ND logical coordinate corresponding to the linear index + // @post crd2idx(@a result, shape(), stride()) == idx + // @post congruent(@a result, shape()) + template ::value)> + CUTE_HOST_DEVICE constexpr + auto + get_hier_coord(IInt const& idx) const { + return cute::idx2crd(idx, shape(), stride()); + } + + // Return the (flat) ND logical coordinate corresponding to the linear index + // @post crd2idx(@a result, shape(), stride()) == idx + // @post rank(@a result) == rank(shape()) && depth(@a result) == 1 + template ::value)> + CUTE_HOST_DEVICE constexpr + auto + get_flat_coord(IInt const& idx) const { + return cute::crd2crd(this->get_hier_coord(idx), shape(), repeat(Int<1>{})); + } + + // Return the generalized column-major 1D logical coordinate corresponding to the linear index + // @post crd2idx(@a result, shape(), stride()) == idx + // @post is_integral::value + template ::value)> + CUTE_HOST_DEVICE constexpr + auto + get_1d_coord(IInt const& idx) const { + return cute::crd2idx(this->get_hier_coord(idx), shape()); + } + + // + // Coordinate to Coordinate + // + +#if 0 + // Return the (hierarchical) ND logical coordinate corresponding to the linear index + // @post congruent(@a result, shape()) + template + CUTE_HOST_DEVICE constexpr + auto + crd_2_hier_coord(Coord const& crd) const { + return cute::crd2crd(crd, shape(), shape()); + } + + // Return the (flat) ND logical coordinate corresponding to the linear index + // @post rank(@a result) == rank(shape()) && depth(@a result) == 1 + template + CUTE_HOST_DEVICE constexpr + auto + crd_2_flat_coord(Coord const& crd) const { + return cute::crd2crd(crd, shape(), product_each(shape())); + } + + // Return the generalized column-major 1D logical coordinate corresponding to the linear index + // @post is_integral::value + template + CUTE_HOST_DEVICE constexpr + auto + crd_2_1d_coord(Coord const& crd) const { + //return cute::crd2crd(crd, shape(), product(shape())); + return cute::crd2idx(crd, shape()); + } +#endif +}; + +// Equality, return a static or dynamic boolean +template +CUTE_HOST_DEVICE constexpr +auto +operator==(Layout const& layoutA, Layout const& layoutB) +{ + return layoutA.shape() == layoutB.shape() && layoutA.stride() == layoutB.stride(); +} + +template +struct is_layout : false_type {}; +template +struct is_layout> : true_type {}; + +// +// Layout construction +// + +template +CUTE_HOST_DEVICE constexpr +auto +make_layout(Shape const& shape, Stride const& stride) +{ + static_assert(is_tuple::value || is_integral::value); + static_assert(is_tuple::value || is_integral::value); + return Layout(shape, stride); +} + +template +CUTE_HOST_DEVICE constexpr +auto +make_layout(Shape const& shape) +{ + static_assert(is_tuple::value || is_integral::value); + return make_layout(shape, compact_major(shape)); +} + +// +// Convenience tags for common layouts +// + +template +CUTE_HOST_DEVICE constexpr +auto +make_layout(Shape const& shape, LayoutLeft) +{ + return make_layout(shape, compact_major(shape)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +make_layout(Shape const& shape, LayoutRight) +{ + return make_layout(shape, compact_major(shape)); +} + +// +// Construct a layout from multiple layouts by concatenation +// + +// One argument overload +template +CUTE_HOST_DEVICE constexpr +auto +make_layout(Layout const& layout0) +{ + 
return make_layout(make_shape (layout0.shape() ), + make_stride(layout0.stride())); +} + +// Two argument overload +template +CUTE_HOST_DEVICE constexpr +auto +make_layout(Layout const& layout0, + Layout const& layout1) +{ + return make_layout(make_shape (layout0.shape() , layout1.shape() ), + make_stride(layout0.stride(), layout1.stride())); +} + +// Var argument overload +template +CUTE_HOST_DEVICE constexpr +auto +make_layout(Layout const& layout0, + Layout const& layout1, + Layout const&... layouts) +{ + return make_layout(make_shape (layout0.shape() , layout1.shape() , layouts.shape()... ), + make_stride(layout0.stride(), layout1.stride(), layouts.stride()...)); +} + +// +// Advanced Layout constructions +// + +// Make a compact layout with shape @a shape and strides following the order induced by @a order. +// Dynamic values in @a order are ignored, considered large, and considered ordered from left to right. +// Example: +// make_ordered_layout(Shape<_2,_2,_2,_2>{}, Step<_0,_2,_3,_1>{}) +// -> (_2,_2,_2,_2):(_1,_4,_8,_2) +// make_ordered_layout(make_shape(2,3,4,5), make_step(Int<2>{}, 67, 42, Int<50>{})) +// -> (2,3,4,5):(_1,10,30,2) +template +CUTE_HOST_DEVICE constexpr +auto +make_ordered_layout(Shape const& shape, Order const& order) +{ + return make_layout(shape, compact_order(shape, order)); +} + +// Make a compact layout with the same shape as @a layout +// and strides following the order induced by @a layout.stride(). +// Static-0 strides in the input @a layout are preserved in the output. +// Example: +// make_layout_like(Layout, Stride<_0,_2,_4,_1>>{}) +// -> (_2,_2,_2,_2):(_0,_2,_4,_1) +// make_layout_like(make_layout(make_shape(2,3,4,5), make_stride(Int<0>{},42,Int<1>{},Int<0>{}))) +// -> (2,3,4,5):(_0,4,_1,_0) +template +CUTE_HOST_DEVICE constexpr +auto +make_layout_like(Layout const& layout) +{ + return make_layout(layout.shape(), + compact_order(filter_zeros(layout.stride(), layout.shape()), layout.stride())); +} + +// Make a compact layout with the same shape as @a layout +// and strides following the order induced by @a layout.stride(), +// except mode-0 is always stride-1 and generated column-major. +// The 0th mode is commonly used for MMA_Atoms or Copy_Atoms so this +// generates the 0th mode with LayoutLeft (preserving stride-0s) regardless of the reference layout +template +CUTE_HOST_DEVICE constexpr +auto +make_fragment_like(Layout const& layout) +{ + constexpr int R = Layout::rank; + if constexpr (R > 1 && is_static::value) { + return tiled_product(make_layout(get<0>(layout.shape()), + compact_major(filter_zeros(get<0>(layout.stride()), get<0>(layout.shape())))), + make_ordered_layout(take<1,R>(layout.shape()), take<1,R>(layout.stride()))); + } else { + return make_layout(layout.shape()); + } + + CUTE_GCC_UNREACHABLE; +} + +template ::value || is_integral::value)> +CUTE_HOST_DEVICE constexpr +auto +make_fragment_like(Shape const& shape) +{ + return make_layout(shape); +} + +// +// Make an identity layout that maps a coordinate to itself +// + +template +CUTE_HOST_DEVICE constexpr +auto +make_identity_layout(Shape const& shape) +{ + return make_layout(shape, make_basis_like(shape)); +} + +// +// Operations to manipulate Layouts like a tuple of pairs +// + +// Return the Is...th sublayout. +// For Is... 
= , equivalent to get(...get(get(layout))) +template +CUTE_HOST_DEVICE constexpr +auto +get(Layout const& layout) +{ + return make_layout(get(layout.shape()), + get(layout.stride())); +} + +// Return a new layout with only the modes in the range [B,E) +template +CUTE_HOST_DEVICE constexpr +auto +take(Layout const& layout) +{ + static_assert(B < E, "take: empty range error"); + static_assert(0 <= B && E <= Layout::rank, "take: range out of bounds"); + return make_layout(take(layout.shape()), + take(layout.stride())); +} + +// Return a new layout with only the modes Is... = +template +CUTE_HOST_DEVICE constexpr +auto +select(Layout const& layout) +{ + return make_layout(select(layout.shape()), + select(layout.stride())); +} + +// Return a layout with depth at most 1 +template +CUTE_HOST_DEVICE constexpr +auto +flatten(Layout const& layout) +{ + return make_layout(flatten(layout.shape()), + flatten(layout.stride())); +} + +// Return a layout whose profile is congruent to TargetProfile +// @pre Input layout is flat, flatten(@a layout) == @a layout +// @pre Input layout can be folded to profile, rank(@a layout) == rank(flatten(@a target_profile)) +// @post congruent(@a result, @a target_profile) +template +CUTE_HOST_DEVICE constexpr +auto +unflatten(Layout const& layout, TargetProfile const& target_profile) +{ + return make_layout(unflatten(layout.shape(), target_profile), + unflatten(layout.stride(), target_profile)); +} + +// +// Utilities +// + +// Return the sublayout of mode I... +template +CUTE_HOST_DEVICE constexpr +decltype(auto) +layout(Layout const& layout) +{ + if constexpr (sizeof...(Is) == 0) { + return layout; + } else { + return get(layout); + } + + CUTE_GCC_UNREACHABLE; +} + +// Return the shape of a mode +template +CUTE_HOST_DEVICE constexpr +decltype(auto) +shape(Layout& layout) +{ + return layout.template shape(); +} + +template +CUTE_HOST_DEVICE constexpr +decltype(auto) +shape(Layout const& layout) +{ + return layout.template shape(); +} + +// Return the stride of a mode +template +CUTE_HOST_DEVICE constexpr +decltype(auto) +stride(Layout& layout) +{ + return layout.template stride(); +} + +template +CUTE_HOST_DEVICE constexpr +decltype(auto) +stride(Layout const& layout) +{ + return layout.template stride(); +} + +// Return the number of elements in a mode +template +CUTE_HOST_DEVICE constexpr +auto +size(Layout const& layout) +{ + return size(shape(layout)); +} + +// Return the number of modes +template +CUTE_HOST_DEVICE constexpr +auto +rank(Layout const& layout) +{ + return rank(shape(layout)); +} + +// Return the depth of the layout +template +CUTE_HOST_DEVICE constexpr +auto +depth(Layout const& layout) +{ + return depth(shape(layout)); +} + +// Return the coprofile of a mode as a tuple of _0s +// @post congruent(coprofile(@a layout), @a layout(i)) for any i +// @return T Tuple that is congruent with the codomain of @a a. +template +CUTE_HOST_DEVICE constexpr +auto +coprofile(Layout const& layout) +{ + return repeat_like(as_arithmetic_tuple(sum(stride(layout))), Int<0>{}); +} + +// Return the codomain shape of a mode +// @post size(coshape(@a layout)) == cosize(@a layout) +// @return C Coordinate with smallest elements such that +// elem_less(@a sub_layout(c), C) for all c < size(@a sub_layout) +// where @a sub_layout = get(layout). 
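+// Example (illustrative sketch; values follow from the definition below):
+//   coshape((_2,_2):(_1,_6)) -> _8 : the largest mapped index is 1*1 + 1*6 = 7
+//   cosize ((_2,_2):(_1,_6)) -> _8, while size((_2,_2):(_1,_6)) == _4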
+template +CUTE_HOST_DEVICE constexpr +auto +coshape(Layout const& layout) +{ + auto m1_shapes = transform_leaf( shape(layout), [](auto s) { return s - Int<1>{}; }); + auto abs_strides = transform_leaf(stride(layout), abs_fn{}); + auto co_coord = as_arithmetic_tuple(inner_product(m1_shapes, abs_strides)); + return transform_leaf(co_coord, [](auto c) { return c + Int<1>{}; }); +} + +// Return the codomain size of a mode +// @return M smallest integer such that +// size(@a sub_layout(c)) < M for all c < size(@a sub_layout) +// where @a sub_layout = get(layout). +template +CUTE_HOST_DEVICE constexpr +auto +cosize(Layout const& layout) +{ + return size(coshape(layout)); +} + +template +using cosize_t = decltype(cosize(declval())); + +template +static constexpr auto cosize_v = cosize_t::value; + +// With crd2idx(coord, shape), makes sense to have crd2idx(coord, Layout) as well +template +CUTE_HOST_DEVICE constexpr +auto +crd2idx(Coord const& c, Layout const& layout) +{ + return crd2idx(c, layout.shape(), layout.stride()); +} + +// +// Slice and Dice a layout +// + +template +CUTE_HOST_DEVICE constexpr +auto +slice(Coord const& c, Layout const& layout) +{ + return make_layout(slice(c, layout.shape()), + slice(c, layout.stride())); +} + +template +CUTE_HOST_DEVICE constexpr +auto +slice_and_offset(Coord const& c, Layout const& layout) +{ + return cute::make_tuple(slice(c, layout), crd2idx(c, layout)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +dice(Coord const& c, Layout const& layout) +{ + return make_layout(dice(c, layout.shape()), + dice(c, layout.stride())); +} + +// Compute a pointer offset and (potentially modified) layout from a coordinate +// This exists so it can be overloaded for ComposedLayout +template +CUTE_HOST_DEVICE constexpr +auto +domain_offset(Coord const& coord, Layout const& layout) +{ + return cute::make_tuple(layout, layout(coord)); +} + +// +// Transform the modes of a layout +// + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr +auto +transform_layout(Tuple const& t, F&& f, seq) +{ + return make_layout(f(get(t))...); +} + +template +CUTE_HOST_DEVICE constexpr +auto +transform_layout(Tuple0 const& t0, Tuple1 const& t1, F&& f, seq, seq, seq) +{ + return make_layout(f(get(t0),get(t1))..., get(t0)..., get(t1)...); +} + +} // end namespace detail + +template +CUTE_HOST_DEVICE constexpr +auto +transform_layout(Tuple const& t, F&& f) +{ + return detail::transform_layout(t, f, make_seq{}); +} + +template +CUTE_HOST_DEVICE constexpr +auto +transform_layout(Tuple0 const& t0, Tuple1 const& t1, F&& f) +{ + constexpr int R0 = decltype(rank(t0))::value; + constexpr int R1 = decltype(rank(t1))::value; + constexpr int R = (R0 < R1) ? 
R0 : R1; + return detail::transform_layout(t0, t1, f, make_seq{}, make_range{}, make_range{}); +} + +// +// Coalesce and Filter +// + +namespace detail { + +// Look at each element and the front of the stack (in order of priority) +// front(NewLayout) get(Layout) +// s0:d0 _1:d1 => continue +// _1:d0 s1:d1 => replace_front s1:d1 +// s0:s1*d1 s1:d1 => replace_front s0*s1:d1 +// s0:d0 s1:d1 => prepend s1:d1 +// +// @pre OldShape and OldStride are flat +template +CUTE_HOST_DEVICE constexpr +auto +bw_coalesce(OldShape const& old_shape, OldStride const& old_stride, + NewShape const& new_shape, NewStride const& new_stride) +{ + if constexpr (I == -1) { + // Base case, we're done + if constexpr (is_constant<1, NewShape>::value) { + return Layout<_1,_0>{}; + } else { + return Layout{new_shape,new_stride}; + } + } else if constexpr (is_constant<1, decltype(get(old_shape))>::value) { + // shape(layout) == _1, skip it and continue + return bw_coalesce(old_shape, old_stride, new_shape, new_stride); + } else if constexpr (is_constant<1, NewShape>::value) { + // Replace our shape-1 with anything (Can only happen on input new_shape/new_stride) + return bw_coalesce(old_shape, old_stride, get(old_shape), get(old_stride)); + } else if constexpr (is_static(new_shape))>::value && + is_constant(old_shape) * get(old_stride) == get<0>(new_stride))>::value) { + // Merge modes because the shapes and strides match + return bw_coalesce(old_shape, old_stride, + replace_front(new_shape, get(old_shape) * get<0>(new_shape)), + replace_front(new_stride, get(old_stride))); + } else { + // Can't replace or merge, so prepend a new mode + return bw_coalesce(old_shape, old_stride, + prepend(new_shape, get(old_shape)), + prepend(new_stride, get(old_stride))); + } + + CUTE_GCC_UNREACHABLE; +} + +// cute::coalesce promises to not change the Layout as a function from integers to codomain. +// It accomplishes this inside of the Layout's domain, but not always outside of the domain. +// Example: (_4,_1):(_1,_0) coalesces to _4:_1. +// detail::coalesce_x preserves the Layout function inside its domain and outside. 
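+// For example (assuming the merge rules above), cute::coalesce gives:
+//   (_2,_4):(_1,_2) -> _8:_1           (contiguous modes merge)
+//   (_2,_4):(_4,_1) -> (_2,_4):(_4,_1) (no merge is possible)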
+// +// @post depth(@a result) <= 1 +// @post for all i, 0 <= i, @a layout(i) == @a result(i) +template +CUTE_HOST_DEVICE constexpr +auto +coalesce_x(Layout const& layout) +{ + auto flat_shape = flatten(layout.shape()); + auto flat_stride = flatten(layout.stride()); + + constexpr int R = decltype(rank(flat_shape))::value; + if constexpr (is_constant<1, decltype(get(flat_shape))>::value) { + return detail::bw_coalesce(flat_shape, flat_stride, Int<2>{}, get(flat_stride)); + } else { + return detail::bw_coalesce(flat_shape, flat_stride, get(flat_shape), get(flat_stride)); + } + + CUTE_GCC_UNREACHABLE; +} + +// Apply coalesce_x at the terminals of trg_profile +template +CUTE_HOST_DEVICE constexpr +auto +coalesce_x(Layout const& layout, IntTuple const& trg_profile) +{ + if constexpr (is_tuple::value) { + static_assert(tuple_size::value <= Layout::rank); + return cute::transform_layout(layout, trg_profile, [](auto const& l, auto const& t) { return coalesce_x(l,t); }); + } else { + return coalesce_x(layout); + } + + CUTE_GCC_UNREACHABLE; +} + +} // end namespace detail + +// "Simplify" the layout by combining modes that are possible to combine +// Does not respect the shape of the layout, but does preserve total size +// @post size(@a result) == size(@a layout) +// @post depth(@a result) <= 1 +// @post for all i, 0 <= i < size(@a layout), @a layout(i) == @a result(i) +template +CUTE_HOST_DEVICE constexpr +auto +coalesce(Layout const& layout) +{ + auto flat_shape = flatten(layout.shape()); + auto flat_stride = flatten(layout.stride()); + + constexpr int R = decltype(rank(flat_shape))::value; + return detail::bw_coalesce(flat_shape, flat_stride, get(flat_shape), get(flat_stride)); +} + +// Apply coalesce at the terminals of trg_profile +template +CUTE_HOST_DEVICE constexpr +auto +coalesce(Layout const& layout, IntTuple const& trg_profile) +{ + if constexpr (is_tuple::value) { + static_assert(tuple_size::value <= Layout::rank); + return transform_layout(layout, trg_profile, [](auto const& l, auto const& t) { return coalesce(l,t); }); + } else { + return coalesce(layout); + } + + CUTE_GCC_UNREACHABLE; +} + +// Combine static and dynamic modes of a shape. 
+// @post size(@a result) == size(@a shape) +// @post depth(@a result) <= 1 +template +CUTE_HOST_DEVICE constexpr +auto +coalesce(Shape const& shape) +{ + static_assert(is_integral::value || is_tuple::value); + + return cute::fold_first(flatten(shape), [](auto const& init, auto const& a) { + if constexpr (is_static::value == is_static::value) { + return replace_back(init, back(init) * a); // Both static or both dynamic, coalesce and replace + } else { + return append(init, a); // Can't coalesce, so append + } + + CUTE_GCC_UNREACHABLE; + }); +} + +// Replace the modes in layout that have a 0-stride with a 1-size +template +CUTE_HOST_DEVICE constexpr +auto +filter_zeros(Layout const& layout) +{ + return make_layout(filter_zeros(layout.stride(), layout.shape()), layout.stride()); +} + +// Replace the modes in layout that correspond to a 0 at the terminals of trg_profile with a 1-size +template +CUTE_HOST_DEVICE constexpr +auto +filter_zeros(Layout const& layout, IntTuple const& trg_profile) +{ + return make_layout(filter_zeros(trg_profile, layout.shape()), layout.stride()); +} + +// Remove all of the 0-strides and 1-sizes +// Return 1-shape if empty +template +CUTE_HOST_DEVICE constexpr +auto +filter(Layout const& layout) +{ + return coalesce(filter_zeros(layout)); +} + +// Apply filter at the terminals of trg_profile +template +CUTE_HOST_DEVICE constexpr +auto +filter(Layout const& layout, IntTuple const& trg_profile) +{ + if constexpr (is_tuple::value) { + static_assert(tuple_size::value <= Layout::rank); + return transform_layout(layout, trg_profile, [](auto const& l, auto const& t) { return filter(l,t); }); + } else { + return filter(layout); + } + + CUTE_GCC_UNREACHABLE; +} + +// +// Append, Prepend, Replace +// + +template +CUTE_HOST_DEVICE constexpr +auto +append(Layout const& layout, + Layout const& x = {}) +{ + return make_layout(append(layout.shape(), x.shape()), + append(layout.stride(), x.stride())); +} + +template +CUTE_HOST_DEVICE constexpr +auto +append(Layout const& layout, + Layout const& x = {}) +{ + return make_layout(append(layout.shape(), x.shape()), + append(layout.stride(), x.stride())); +} + +template +CUTE_HOST_DEVICE constexpr +auto +prepend(Layout const& layout, + Layout const& x = {}) +{ + return make_layout(prepend(layout.shape(), x.shape()), + prepend(layout.stride(), x.stride())); +} + +template +CUTE_HOST_DEVICE constexpr +auto +prepend(Layout const& layout, + Layout const& x = {}) +{ + return make_layout(prepend(layout.shape(), x.shape()), + prepend(layout.stride(), x.stride())); +} + +template +CUTE_HOST_DEVICE constexpr +auto +replace(Layout const& layout, + Layout const& x) +{ + return make_layout(replace(layout.shape(), x.shape()), + replace(layout.stride(), x.stride())); +} + +template +CUTE_HOST_DEVICE constexpr +auto +group(Layout const& layout) +{ + return make_layout(group(layout.shape()), + group(layout.stride())); +} + +// +// Composition of two layouts: lhs o rhs +// @post compatible(rhs, result) +// @post result(c) = lhs(rhs(c)) +// for all c in the domain of rhs +// + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr +auto +composition_impl(LShape const& lhs_shape, [[maybe_unused]] LStride const& lhs_stride, + RShape const& rhs_shape, RStride const& rhs_stride) +{ + if constexpr (is_tuple::value) { // Right-distributivity of Layout composition for RHS tuple + return transform_layout(rhs_shape, rhs_stride, [&](auto const& s, auto const& d) { + return composition_impl(lhs_shape, lhs_stride, s, d); + }); + } else + if constexpr 
(is_scaled_basis::value) { // Special case for a RHS ScaledBasis stride + return composition_impl(basis_get(rhs_stride, lhs_shape), basis_get(rhs_stride, lhs_stride), + rhs_shape, basis_value(rhs_stride)); + } else + if constexpr (is_constant<0, RStride>::value) { // Special case shortcut for any RHS static stride-0 + return Layout{rhs_shape, rhs_stride}; + } else + if constexpr (is_integral::value) { // Special case shortcut for any LHS integral shape + return Layout{rhs_shape, rhs_stride * lhs_stride}; + } else { // General case: LHS tuple, RHS integral + constexpr int R = tuple_size::value; + + auto [result_shape, result_stride, rest_shape, rest_stride] = + cute::fold(make_seq{}, // t = [0,1,2,...,R-1) + cute::make_tuple(cute::tuple<>{}, // v = (result_shape, + cute::tuple<>{}, // result_stride, + rhs_shape, // rest_shape:Integral, + rhs_stride), // rest_stride:Integral) + [&](auto const& init, auto curr_i) { // f(v,t) -> v' + // Can ICE on some compilers + //auto [result_shape, result_stride, rest_shape, rest_stride] = init; + //auto [curr_shape, curr_stride] = curr; + // Unpack inputs + auto result_shape = get<0>(init); + auto result_stride = get<1>(init); + auto rest_shape = get<2>(init); + auto rest_stride = get<3>(init); + + auto curr_shape = get(lhs_shape); + [[maybe_unused]] auto curr_stride = get(lhs_stride); + + // Strong divisibility condition -- requires composition to be statically verifiable. + //CUTE_STATIC_ASSERT_V(((rest_stride % curr_shape) == Int<0>{}) or (rest_stride < curr_shape), "Stride Divisibility Condition"); + + // Weak divisibility condition -- verify the divisibility condition whenever possible + if constexpr (is_static::value and is_static::value) { + CUTE_STATIC_ASSERT_V(((rest_stride % curr_shape) == Int<0>{}) or (rest_stride < curr_shape), "Stride Divisibility Condition"); + } else { + // DEBUG assert can cause extra registers and inappropriate compile-time/run-time failure + //assert((((rest_stride % curr_shape) == 0) or (rest_stride < curr_shape)) && "Stride Divisibility Condition"); + } + + // next_shape: ceil(exclusive_prefix_product(lhs_shape) / rhs_stride) + [[maybe_unused]] auto next_shape = cute::ceil_div(curr_shape, abs(rest_stride)); + // next_stride: ceil(rhs_stride / exclusive_prefix_product(lhs_shape)) + [[maybe_unused]] auto next_stride = cute::ceil_div(abs(rest_stride), curr_shape) * signum(rest_stride); + + if constexpr (is_constant<1, decltype(next_shape)>::value or is_constant<1, decltype(rest_shape)>::value) { + return cute::make_tuple(result_shape, + result_stride, + rest_shape, + next_stride); + } else { + auto new_shape = cute::min(next_shape, rest_shape); + + // Strong divisibility condition + //CUTE_STATIC_ASSERT_V(((rest_shape % new_shape) == Int<0>{}), "Shape Divisibility Condition"); + + // Weak divisibility condition + if constexpr (is_static::value and is_static::value) { + CUTE_STATIC_ASSERT_V(((rest_shape % new_shape) == Int<0>{}), "Shape Divisibility Condition"); + } else { + // DEBUG assert can cause extra registers and inappropriate compile-time/run-time failure + //assert(((rest_shape % new_shape) == 0) && "Shape Divisibility Condition"); + } + + return cute::make_tuple(append(result_shape, new_shape), + append(result_stride, rest_stride * curr_stride), + rest_shape / new_shape, + next_stride); + } + + CUTE_GCC_UNREACHABLE; + }); + + if constexpr (tuple_size::value == 0) { + return Layout{rest_shape, rest_stride * get(lhs_stride)}; + } else + if constexpr (is_constant<1, decltype(rest_shape)>::value) { + return 
Layout{unwrap(result_shape), unwrap(result_stride)}; + } else { + return Layout{append(result_shape, rest_shape), + append(result_stride, rest_stride * get(lhs_stride))}; + } + } + + CUTE_GCC_UNREACHABLE; +} + +} // end namespace detail + +template +CUTE_HOST_DEVICE constexpr +auto +composition(Layout const& lhs, + Layout const& rhs) +{ + auto flat_lhs = detail::coalesce_x(lhs, coprofile(rhs)); + return detail::composition_impl(flat_lhs.shape(), flat_lhs.stride(), rhs.shape(), rhs.stride()); +} + +template +CUTE_HOST_DEVICE constexpr +auto +composition(Layout const& lhs, + Tiler const& rhs) +{ + if constexpr (is_tuple::value) { + static_assert(tuple_size::value <= Layout::rank); + // Drop any modes of lhs that aren't hit by rhs + return detail::transform_layout(lhs, rhs, [](auto const& l, auto const& r) { return composition(l,r); }, make_seq::value>{}, seq<>{}, seq<>{}); + } else if constexpr (is_underscore::value) { + return lhs; + } else if constexpr (is_integral::value) { + auto flat_lhs = detail::coalesce_x(lhs); + return detail::composition_impl(flat_lhs.shape(), flat_lhs.stride(), rhs, Int<1>{}); + } + + CUTE_GCC_UNREACHABLE; +} + +// +// Complement +// +// Build the complement of a layout. +// @post size(@a result) >= @a cosize_hi / size(filter(@a layout))); +// @post For all i in [1,size(@a result)), +// @a result(i) < @a result(i-1) +// For all j in [0, size(@a layout)), +// @a result(i) != @a layout(j) +// + +namespace detail { + +// @pre @a layout has been filtered (flattened and no stride-0 or size-1 modes). +template +CUTE_HOST_DEVICE constexpr +auto +complement(Shape const& shape, Stride const& stride, CoTarget const& cotarget) +{ + if constexpr (is_constant<0, Stride>::value) { + // Special case for irreducible rank-1 stride-0 layout + return make_layout(coalesce(cotarget)); + } else { + // General case + constexpr int R = rank_v; + static_assert(R == 1 || is_static::value, + "Dynamic-stride complement only for rank-1 layouts"); + + // Should just be a sort and a fold... 
+ // Then we could even handle dynamic strides (but they would destroy all static strides) + auto [shape_, stride_, result_shape_, result_stride] = + fold(make_seq{}, + cute::make_tuple(shape, stride, cute::make_tuple(), cute::make_tuple(Int<1>{})), + [](auto const& init, auto i) + { + auto [shape, stride, result_shape, result_stride] = init; + auto min_stride = cute::min(stride); + auto min_idx = cute::find(stride, min_stride); + auto new_shape = min_stride / get(result_stride); + auto new_stride = min_stride * get(shape); + static_assert(not is_constant<0, decltype(new_shape)>::value, "Non-injective Layout detected in complement."); + + return cute::make_tuple(remove(shape), // Remove the min_idx from shape + remove(stride), // Remove the min_idx from stride + append(result_shape , new_shape ), // new shape = min_stride / last_stride + append(result_stride, new_stride)); // new stride = min_stride * curr_shape + }); + + // Append the last shape mode + auto new_shape = get<0>(stride_) / get(result_stride); // new shape = min_stride / last_stride + static_assert(not is_constant<0, decltype(new_shape)>::value, "Non-injective Layout detected in complement."); + auto result_shape = append(result_shape_, new_shape); + + // Compute the rest_shape and rest_stride + auto new_stride = get<0>(stride_) * get<0>(shape_); // new stride = min_stride * curr_shape + auto rest_shape = coalesce(ceil_div(cotarget, new_stride)); + auto rest_stride = compact_major(rest_shape, new_stride); + + // Coalesce and append (rest_shape, rest_stride) + return coalesce(make_layout(make_shape (result_shape , rest_shape ), + make_stride(result_stride, rest_stride))); + } + + CUTE_GCC_UNREACHABLE; +} + +} // end namespace detail + +template +CUTE_HOST_DEVICE constexpr +auto +complement(Layout const& layout, CoTarget const& cotarget) +{ + auto filter_layout = filter(layout); + return detail::complement(filter_layout.shape(), filter_layout.stride(), shape(cotarget)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +complement(Layout const& layout) +{ + auto filter_layout = filter(layout); + return detail::complement(filter_layout.shape(), filter_layout.stride(), cosize(filter_layout)); +} + +// +// Right-Inverse and Left-Inverse +// + +// +// Build the right-inverse of a layout +// @pre is_static +// @result A layout @a result such that +// @a layout(@a result(i)) == i for all i < size(@a result) +// @result A layout @a result such that +// composition(@a layout, @a result) is identical to make_layout(shape(result)) +// + +template +CUTE_HOST_DEVICE constexpr +auto +right_inverse(Layout const& layout) +{ + // Flatten and filter shape-1 + auto clayout = coalesce(layout); + auto lstride = wrap(clayout.stride()); + auto lshape = wrap(clayout.shape()); + + // Prefix product of the shape + auto preprod_shape = cute::fold(lshape, cute::tuple<_1>{}, [](auto c, auto vi) { return append(c, vi*back(c)); }); + + // Filter out any dynamic strides + [[maybe_unused]] auto filtered_seq = filter_tuple(make_seq{}, lstride, [](auto i, auto d) { + return conditional_return>(cute::tuple{i}, cute::tuple<>{}); }); + [[maybe_unused]] auto filtered_stride = transform(filtered_seq, [&](auto i) { return get(lstride); }); + + // Sort by strides + using Sorted = detail::SortByKey; + auto sorted_seq = typename Sorted::val_type{}; + //auto sorted_stride = typename Sorted::key_type{}; + + auto [result_shape, result_stride, curr] = cute::fold(sorted_seq, tuple,tuple<_0>,_1>{}, + [&](auto const& init, auto i) { + [[maybe_unused]] auto ishape = get(lshape); 
+ [[maybe_unused]] auto istride = get(lstride); + [[maybe_unused]] auto curr_stride = get<2>(init); + + if constexpr (is_constant::value) { + return make_tuple(append(get<0>(init), ishape), // result_shape + append(get<1>(init), get(preprod_shape)), // result_stride + ishape * istride); + } else { + return init; + } + + CUTE_GCC_UNREACHABLE; + }); + + return coalesce(make_layout(result_shape, result_stride)); +} + +CUTE_HOST_DEVICE constexpr +auto +right_inverse(Underscore const& _) +{ + return _; +} + +// +// Build the quasi-inverse of a layout (left-inverse when layout is injective) +// @pre is_static +// @result A layout @a result such that +// @a layout(@a result(@a layout(i))) == @a layout(i) for all i < size(@a layout) +// @result A layout @a result such that +// composition(@layout, composition(@a result, @a layout)) is identical to @a layout +// + +template +CUTE_HOST_DEVICE constexpr +auto +left_inverse(Layout const& layout) +{ + // Flatten and filter shape-1 + auto clayout = coalesce(layout); + auto lstride = wrap(clayout.stride()); + auto lshape = wrap(clayout.shape()); + + // Prefix product of the shape + auto preprod_shape = cute::fold(lshape, cute::tuple<_1>{}, [](auto c, auto vi) { return append(c, vi*back(c)); }); + + // Sort by strides + static_assert(is_static::value, "Left inverse requires static strides."); + using Sorted = detail::SortByKey>; + auto sorted_seq = typename Sorted::val_type{}; + //auto sorted_stride = typename Sorted::key_type{}; + + auto [result_shape, result_stride] = cute::fold(sorted_seq, tuple,tuple<_0>>{}, + [&](auto const& init, auto i) { + [[maybe_unused]] auto istride = get(lstride); + + if constexpr (is_constant<0, decltype(istride)>::value) { + return init; + } else { + auto result_shape = get<0>(init); + auto result_stride = get<1>(init); + + CUTE_STATIC_ASSERT_V((istride % size(result_shape)) == Int<0>{}, "Left inverse divisibility condition"); + + return make_tuple(append(result_shape, istride / size(result_shape)), + append(result_stride, get(preprod_shape))); + } + + CUTE_GCC_UNREACHABLE; + }); + + return coalesce(make_layout(append(result_shape, get(lshape)), + result_stride)); +} + +CUTE_HOST_DEVICE constexpr +auto +left_inverse(Underscore const& _) +{ + return _; +} + +// +// Max Common Layout +// + +/* Return a layout that points to the maximum number of contiguous elements + * that logically correspond in the layouts of @a a and @a b. + * + * @returns Layout R + * @post For all 0 <= i < size(R), a(R(i)) == i and b(R(i)) == i + */ +template +CUTE_HOST_DEVICE constexpr +auto +max_common_layout(Layout const& a, + Layout const& b) +{ + Layout inv_b = right_inverse(b); + Layout common = coalesce(composition(a, inv_b)); + + // Keep only the static identity component of the common layout + if constexpr (is_static(common))>::value && + is_constant<1, decltype(stride<0>(common))>::value) { + // Truncate to the size of the contiguous vector (static stride-1 mode) + return composition(inv_b, layout<0>(common)); + } else { + return Layout<_1,_0>{}; + } +} + +/* Return Int such that N is the maximum number of contiguous elements + * that logically correspond in the layouts of @a a and @a b. 
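+ *
+ * A hedged sketch of typical results (exact values depend on the layouts):
+ * \code
+ * auto a = Layout<Shape<_8,_8>, Stride<_1,_8>>{};              // column-major
+ * max_common_vector(a, a);                                     // _64 : fully contiguous correspondence
+ * max_common_vector(a, Layout<Shape<_8,_8>, Stride<_8,_1>>{}); // _1  : no common contiguous run
+ * \endcode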
+ * + * @returns Int with N >= 1 + * @post For all 0 <= n < N, a(b.get_1d_coord(n)) == n + * (NOTE: Problems with negative strides/coords in this post-condition) + */ +template +CUTE_HOST_DEVICE constexpr +auto +max_common_vector(Layout const& a, + Layout const& b) +{ + Layout common = coalesce(composition(a, right_inverse(b))); + + // Keep only the static identity component of the common layout + if constexpr (is_static(common))>::value && + is_constant<1, decltype(stride<0>(common))>::value) { + // Truncate to the size of the contiguous vector (static stride-1 mode) + return shape<0>(common); + } else { + return Int<1>{}; + } + + CUTE_GCC_UNREACHABLE; +} + +/* Return a layout that distributes ShapeB over ShapeA. + * + * @returns Layout result + * @post evenly_divides(@a b, size(@a result)) + * @post evenly_divides(@a a, @a result) + * @post For all i,j in [0,size(@a result)) with i < j, @a result(i) < @a result(j). Surjective and Ordered. + * @post composition(make_layout(shape(@a a)), @a result) is admissible + * \code + * // Note that 6 does not divide this shape + * Layout layoutA = Layout,Int<14>>>{}; + * + * // Want to tile any 6 elements and don't care where they come from + * Layout dist = domain_distribute(layoutA, Int<6>{}); // (_3,_2):(_1,_15) + * + * // Not guaranteed to find all 6 though... + * CUTE_STATIC_ASSERT_V(Int<6>{} == size(dist)); + * + * Layout result = zipped_divide(layoutA, dist); // (_6,Rest) + * \endcode + */ +template +CUTE_HOST_DEVICE constexpr +auto +domain_distribute(ShapeA const& a, ShapeB const& b) +{ + static_assert(is_integral::value); + static_assert(is_static::value); + + auto flat_shape_a = flatten(shape(a)); + + static_assert(is_static::value); + + // Compute the shape of the result + auto [result_shape, b_rest] = cute::fold(flat_shape_a, cute::make_tuple(cute::tuple<>{}, size(b)), [](auto init, auto a_) { + auto [result, b_] = init; + auto gcd_ = gcd(a_, b_); + return cute::make_tuple(append(result, gcd_), b_ / gcd_); + }); + + // Compute the stride of the result + auto result_stride = compact_major(flat_shape_a); + + return coalesce(make_layout(result_shape, result_stride)); +} + +// +// Kernel (Nullspace) of a Layout +// + +/** Return a layout that represents the nullspace of @a layout + * @post @a layout(@a result(i)) == 0 for all i < size(@a result) + * @post nullspace(@a result) == Layout<_1,_0>{} + * @post size(@a result) == size(@a layout) / size(filter(@a layout)) + */ +template +CUTE_HOST_DEVICE constexpr +auto +nullspace(Layout const& layout) +{ + [[maybe_unused]] auto flat_stride = flatten(layout.stride()); + + // Select all indices corresponding to stride-0s + auto iseq = cute::fold(make_seq>{}, cute::tuple<>{}, + [&](auto init, auto i){ + if constexpr (is_constant_v<0, decltype(get(flat_stride))>) { return append(init, i); } + else { return init; } + CUTE_GCC_UNREACHABLE; + }); + + if constexpr (tuple_size::value == 0) { + return Layout<_1,_0>{}; // Empty case, nothing found + } else { + // Generate the corresponding new strides and construct + auto flat_shape = flatten(layout.shape()); + auto rstride = compact_major(flat_shape); + return make_layout(unwrap(transform(iseq, [&](auto i) { return get(flat_shape); })), + unwrap(transform(iseq, [&](auto i) { return get(rstride); }))); + } + + CUTE_GCC_UNREACHABLE; +} + +// +// Zip +// + +template +CUTE_HOST_DEVICE constexpr +auto +zip(Layout const& layout) +{ + return make_layout(zip(layout.shape()), + zip(layout.stride())); +} + +template +CUTE_HOST_DEVICE constexpr +auto +zip(Layout const& 
layoutA, + Layout const& layoutB) +{ + return make_layout(zip(layoutA.shape(), layoutB.shape()), + zip(layoutA.stride(), layoutB.stride())); +} + +// +// Tile unzip +// Logical product and logical divide (on layouts) produce rank-2 results by design. +// Follow the profile of @a tile and zip the rank-2 modes located at the terminals into +// their own mode. +// + +template +CUTE_HOST_DEVICE constexpr +auto +tile_unzip(Layout const& layout, + Tiler const& tiler) +{ + return make_layout(zip2_by(layout.shape(), tiler), + zip2_by(layout.stride(), tiler)); +} + +// +// Logical divide +// + +template +CUTE_HOST_DEVICE constexpr +auto +logical_divide(Layout const& layout, + Layout const& tiler) +{ + return composition(layout, make_layout(tiler, complement(tiler, shape(coalesce(layout))))); +} + +template +CUTE_HOST_DEVICE constexpr +auto +logical_divide(Layout const& layout, + Tiler const& tiler) +{ + if constexpr (is_tuple::value) { + static_assert(tuple_size::value <= Layout::rank, "logical_divide: Too many modes in tiler."); + return transform_layout(layout, tiler, [](auto const& l, auto const& t) { return logical_divide(l,t); }); + } else if constexpr (is_underscore::value) { + return layout; + } else if constexpr (is_integral::value) { + return logical_divide(layout, make_layout(tiler)); + } + + CUTE_GCC_UNREACHABLE; +} + +// Generalization of ceil_div for Layout lhs +// is effectively the "rest mode" of logical_divide. +// Occurs in the calculation of gridDim, for example, for generalized tilers +// Example: +// dim3 gridDim(size(ceil_div(problem_shape_M, cta_tiler_M)), +// size(ceil_div(problem_shape_N, cta_tiler_N))); +// This does not consider compositional acceptance, so it may be the case that +// ceil_div produces a result while logical_divide (and friends) do not. +template +CUTE_HOST_DEVICE constexpr +auto +ceil_div(Target const& target, + Layout const& tiler) +{ + return shape(complement(tiler, shape(target))); +} + +// +// Convenience operator +// that produces layouts like ((BLK_A,BLK_B,...),(a,b,...,x,y)) +// by gathering the tile modes and residuals into a rank-2 result. 
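+//
+// Example (illustrative sketch for a column-major 8x8 layout and a (_4,_2) tiler):
+//   zipped_divide((_8,_8):(_1,_8), Shape<_4,_2>{}) -> ((_4,_2),(_2,_4)):((_1,_8),(_4,_16))
+//   where result(make_coord(elem_in_tile, tile_coord)) indexes element elem_in_tile of tile tile_coord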
+// + +template +CUTE_HOST_DEVICE constexpr +auto +zipped_divide(Layout const& layout, + Tiler const& tiler) +{ + return tile_unzip(logical_divide(layout, tiler), tiler); +} + +// Same as zipped_divide, but unpacks the second mode: ((BLK_A,BLK_B,...),a,b,...,x,y) +template +CUTE_HOST_DEVICE constexpr +auto +tiled_divide(Layout const& layout, + Tiler const& tiler) +{ + auto result = zipped_divide(layout, tiler); + + auto R1 = rank<1>(result); + return result(_, repeat(_)); +} + +// Same as zipped_divide, but unpacks both modes: (BLK_A,BLK_B,...,a,b,...,x,y) +template +CUTE_HOST_DEVICE constexpr +auto +flat_divide(Layout const& layout, + Tiler const& tiler) +{ + auto result = zipped_divide(layout, tiler); + + auto R0 = rank<0>(result); + auto R1 = rank<1>(result); + return result(repeat(_), repeat(_)); +} + +// +// Logical product +// + +template +CUTE_HOST_DEVICE constexpr +auto +logical_product(Layout const& block, + Layout const& tiler) +{ + return make_layout(block, composition(complement(block, size(block)*cosize(tiler)), tiler)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +logical_product(Layout const& block, + Tiler const& tiler) +{ + if constexpr (is_tuple::value) { + static_assert(tuple_size::value <= Layout::rank, "logical_product: Too many modes in tiler."); + return transform_layout(block, tiler, [](auto const& l, auto const& t) { return logical_product(l,t); }); + } else if constexpr (is_underscore::value) { + return block; + } else if constexpr (is_integral::value) { + return logical_product(block, make_layout(tiler)); + } + + CUTE_GCC_UNREACHABLE; +} + +// +// Convenience operator +// that produces layouts like ((BLK_A,BLK_B,...),(a,b,...,x,y)) +// by gathering the block modes and products into a rank-2 result. +// + +template +CUTE_HOST_DEVICE constexpr +auto +zipped_product(Layout const& block, + Tiler const& tiler) +{ + return tile_unzip(logical_product(block, tiler), tiler); +} + +// Same as zipped_product, but unpacks the second mode: ((BLK_A,BLK_B,...),a,b,...,x,y) +template +CUTE_HOST_DEVICE constexpr +auto +tiled_product(Layout const& block, + Tiler const& tiler) +{ + auto result = zipped_product(block, tiler); + + auto R1 = rank<1>(result); + return result(_, repeat(_)); +} + +// Same as zipped_product, but unpacks both modes: (BLK_A,BLK_B,...,a,b,...,x,y) +template +CUTE_HOST_DEVICE constexpr +auto +flat_product(Layout const& block, + Tiler const& tiler) +{ + auto result = zipped_product(block, tiler); + + auto R0 = rank<0>(result); + auto R1 = rank<1>(result); + return result(repeat(_), repeat(_)); +} + +// +// Rank-sensitive products +// + +// blocked_product -- Reproduce a block over a tiler. +// Think of every element of "tiler" as a "block" +// and return the layout of the resulting structure. +// @post rank(@a result) == cute::max(rank(@a block), rank(@a tiler)) +template +CUTE_HOST_DEVICE constexpr +auto +blocked_product(Layout const& block, + Layout const& tiler) +{ + constexpr int R = cute::max(rank_v, rank_v); + + auto result = logical_product(append(block), append(tiler)); + + return zip(get<0>(result), get<1>(result)); +} + +// raked_product -- Reproduce a block over a tiler with block-interleaving. +// Think of every element of "tiler" as a "block", interleave those blocks, +// and return the layout of the resulting structure. 
+// @post rank(@a result) == cute::max(rank(@a block), rank(@a tiler)) +template +CUTE_HOST_DEVICE constexpr +auto +raked_product(Layout const& block, + Layout const& tiler) +{ + constexpr int R = cute::max(rank_v, rank_v); + + auto result = logical_product(append(block), append(tiler)); + + return zip(get<1>(result), get<0>(result)); +} + +// tile_to_shape -- Perform a product of a layout so that the result matches a target shape. +// This is similar to blocked_product, but specifies the result shape instead of the +// product shape, which is more convenient in certain circumstances. +// @param block The layout to repeat +// @param trg_shape The target shape of the result +// @param ord_shape The order of the modes of @a trg_shape to tile @a layout with. +// Defaults to GenColMajor, so @a layout will repeat +// across the first mode first, the second mode second, etc +// E.g. Step<_2,_1,_3> will cause @a layout to repeat +// across the second mode first, the first mode second, and the third mode last. +// @pre rank(@a block) <= rank(@a trg_shape) +// @post compatible(@a trg_shape, shape(@a result)) +template +CUTE_HOST_DEVICE constexpr +auto +tile_to_shape(Layout const& block, + TrgShape const& trg_shape, + ModeOrder const& ord_shape = {}) +{ + CUTE_STATIC_ASSERT_V(rank(block) <= rank(trg_shape), "Rank of layout must be <= rank of target shape."); + constexpr int R = rank_v; + + auto padded_block = append(block); + + auto block_shape = product_each(shape(padded_block)); + auto target_shape = product_each(shape(trg_shape)); + + // Assert proper division + if constexpr (is_static::value) { + CUTE_STATIC_ASSERT_V(evenly_divides(target_shape, block_shape), + "tile_to_shape: block shape does not divide the target shape."); + } + + auto product_shape = ceil_div(target_shape, block_shape); + + return blocked_product(padded_block, make_ordered_layout(product_shape, ord_shape)); +} + +// +// Upcast +// For stride-1 mode, divide size by N. Divide all other strides by N. +// + +template +CUTE_HOST_DEVICE constexpr +auto +upcast(Shape const& shape, Stride const& stride) +{ + if constexpr (is_tuple::value) { // tuple stride + return transform_layout(shape, stride, [](auto const& s, auto const& d) { return upcast(s,d); }); + } else if constexpr (is_constant<0, Stride>::value) { // static-0 stride + return Layout{shape,stride}; + } else if constexpr (is_static::value) { // static stride + static_assert(Stride::value % N == 0 or N % Stride::value == 0, "Divisibility condition"); + return make_layout(ceil_div(shape, ceil_div(Int{}, abs(stride))), + signum(stride) * ceil_div(abs(stride), Int{})); + } else { // dynamic stride + // Assume dynamic strides are larger than N and divisible + // assert(stride % N == 0); + return make_layout(shape, safe_div(stride, Int{})); + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +upcast(Layout const& layout) +{ + return upcast(layout.shape(), layout.stride()); +} + +// +// Downcast +// For stride-1 mode, multiply size by N. Multiply all other strides by N. 
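+//
+// Example (illustrative sketch, e.g. viewing 16-bit elements as 32-bit or the reverse):
+//   upcast<2>  ((_4,_8):(_1,_4)) -> (_2,_8):(_1,_2) : half as many, coarser elements
+//   downcast<2>((_4,_8):(_1,_4)) -> (_8,_8):(_1,_8) : twice as many, finer elements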
+// + +template +CUTE_HOST_DEVICE constexpr +auto +downcast(Shape const& shape, Stride const& stride) +{ + if constexpr (is_tuple::value) { + return transform_layout(shape, stride, [](auto const& s, auto const& d) { return downcast(s,d); }); + } else if constexpr (is_constant<1, Stride>::value || is_constant<-1, Stride>::value) { + return make_layout(shape * Int{}, stride); + } else { + return make_layout(shape, stride * Int{}); + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +downcast(Layout const& layout) +{ + CUTE_STATIC_ASSERT(has_int1::value, "Downcast requires adjacent elements"); + return downcast(layout.shape(), layout.stride()); +} + +// +// Recast +// + +template +CUTE_HOST_DEVICE constexpr +auto +recast_layout(Layout const& layout) +{ + using scale = decltype(trait_ratio(sizeof_bits{}, sizeof_bits{})); + if constexpr (scale::num == 1 && scale::den == 1) { + return layout; + } + else if constexpr (scale::num == 1) { + return downcast(layout); + } + else if constexpr (scale::den == 1) { + return upcast(layout); + } + else { + return downcast(upcast(layout)); + } + + CUTE_GCC_UNREACHABLE; +} + +// Determine the maximum alignment of a Layout. +// The maximum alignment is the largest N for which upcast(layout) will compile. +// upcast(layout) compiles when the static shapes and strides pass divisibility checks. +// Therefore, upcast(layout) will also compile for all divisors M of N. +// Note that this only considers the static shapes and strides of the Layout +// in symmetry with upcast only checking against static shapes and strides and assuming all +// dynamic shapes and strides are large and multiples of N. +template +CUTE_HOST_DEVICE constexpr +auto +max_alignment(Layout const& layout) +{ + auto flat_layout = coalesce(layout); + auto static_shape = transform( shape(flat_layout), [](auto s){ return conditional_return::value>(s, Int<1>{}); }); + auto static_stride = transform(stride(flat_layout), [](auto d){ return conditional_return::value>(d, Int<0>{}); }); + auto filter_layout = make_layout(static_shape, static_stride); + auto permuted = logical_divide(filter_layout, right_inverse(filter_layout)); + return gcd(size<0>(permuted), stride<1>(permuted)); +} + +// +// Display utilities +// + +template +CUTE_HOST_DEVICE void print(Layout const& layout) +{ + print(layout.shape()); print(":"); print(layout.stride()); +} + +#if !defined(__CUDACC_RTC__) +template +CUTE_HOST std::ostream& operator<<(std::ostream& os, Layout const& layout) +{ + return os << shape(layout) << ":" << stride(layout); +} +#endif + +} // end namespace cute diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/layout_composed.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/layout_composed.hpp new file mode 100644 index 0000000000000000000000000000000000000000..6a9677831bf2b19c971223f18dc2c929b74fa4ab --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/layout_composed.hpp @@ -0,0 +1,661 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include // CUTE_HOST_DEVICE, CUTE_GCC_UNREACHABLE +#include // cute::tuple +#include // cute::true_type, cute::false_type, cute::Int + +/* This implements a ComposedLayout of the form + * LayoutA o Offset o LayoutB + * and is useful in cases where composition() does not or cannot apply to LayoutA and LayoutB. + * For example, when the "divisibility condition" is violated in composition(LayoutA, LayoutB). + * + * This ComposedLayout provides similar functionality to Layout including tiling, partitioning, + * coordinate-to-index mapping and layout manipulations, but is not considered a "normal" layout. + * For example, this layout provides shape() and size() functions, but does not provide stride() functions. + * Mostly, the similar functionality is accomplished by applying each operation to LayoutB only + * as LayoutB defines the domain. 
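 *
 * For example (editor's illustrative sketch, not part of the original comment):
 *
 *   auto inner = make_layout(make_shape(Int<4>{}, Int<8>{}));    // (_4,_8):(_1,_4)
 *   auto outer = make_layout(make_shape(Int<32>{}));             // (_32):(_1)
 *   auto comp  = make_composed_layout(outer, Int<0>{}, inner);   // outer o 0 o inner
 *
 * Then comp(c) == outer(0 + inner(c)) for any coordinate c of inner, while shape(comp),
 * size(comp), rank(comp), and the tiling operations below all act on inner (LayoutB).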
+ */ + +namespace cute +{ + +// A Layout of non-trivially composable functions: F o I o L +template +struct ComposedLayout : private cute::tuple // EBO for static layouts +{ + CUTE_HOST_DEVICE constexpr + ComposedLayout(LayoutA const& layoutA = {}, + Offset const& offset = {}, + LayoutB const& layoutB = {}) + : cute::tuple(layoutA, offset, layoutB) + {} + + // + // Accessors + // + + static constexpr int rank = LayoutB::rank; + + CUTE_HOST_DEVICE constexpr + decltype(auto) + layout_a() const { + return get<0>(static_cast const&>(*this)); + } + + CUTE_HOST_DEVICE constexpr + decltype(auto) + offset() const { + return get<1>(static_cast const&>(*this)); + } + + CUTE_HOST_DEVICE constexpr + decltype(auto) + layout_b() const { + return get<2>(static_cast const&>(*this)); + } + + CUTE_HOST_DEVICE constexpr + decltype(auto) + layout() const { + return *this; + } + + CUTE_HOST_DEVICE constexpr + decltype(auto) + shape() const { + return layout_b().shape(); + } + + // Doesn't really make sense to ask for the strides of this "layout" + CUTE_HOST_DEVICE constexpr + decltype(auto) + stride() const = delete; + + // + // Mappings + // + + // Map a logical coordinate to a linear index (Coord has no Underscore slice operators) + // OR + // Slice the layout and return the sublayout (Coord has an Underscore slice op) + template + CUTE_HOST_DEVICE constexpr + auto + operator()(Coord const& coord) const { + if constexpr (has_underscore::value) { + return slice(coord, *this); + } else { + return layout_a()(offset() + layout_b()(coord)); // (A o O o B)(c) + } + + CUTE_GCC_UNREACHABLE; + } + + // Convenience function for multi-dimensional coordinates + template + CUTE_HOST_DEVICE constexpr + auto + operator()(Coord0 const& c0, Coord1 const& c1, Coords const&... cs) const { + return operator()(make_coord(c0,c1,cs...)); + } + + // + // Compose + // + + template + CUTE_HOST_DEVICE constexpr + auto + compose(OtherLayout const& other) const { + return composition(*this, other); + } + + template + CUTE_HOST_DEVICE constexpr + auto + compose(Layouts const&... layouts) const { + return composition(*this, make_tile(layouts...)); + } + + template + CUTE_HOST_DEVICE constexpr + auto + with_shape(OtherShape const& shape) const { + return composition(*this, make_layout(shape)); + } + + template + CUTE_HOST_DEVICE constexpr + auto + with_shape(Shapes const&... shapes) const { + return composition(*this, make_layout(make_shape(shapes...))); + } + + // + // Tile + // + + template + CUTE_HOST_DEVICE constexpr + auto + tile(OtherLayout const& other) const { + return tiled_divide(*this, other); + } + + template + CUTE_HOST_DEVICE constexpr + auto + tile(Layouts const&... 
layouts) const { + return tiled_divide(*this, make_tile(layouts...)); + } + + // Equality, return a static or dynamic boolean + template + CUTE_HOST_DEVICE constexpr + auto + operator==(ComposedLayout const& other) const { + return this->layout_a() == other.layout_a() && + this->layout_b() == other.layout_b() && + this->offset() == other.offset(); + } +}; + +template +struct is_layout> : true_type {}; + +template +struct is_composed_layout : false_type {}; +template +struct is_composed_layout> : true_type {}; + +// +// Constructors +// + +template +CUTE_HOST_DEVICE constexpr +auto +make_composed_layout(LayoutA const& layoutA, + Offset const& offset, + LayoutB const& layoutB) +{ + return ComposedLayout{layoutA, offset, layoutB}; +} + +// +// Utilities +// + +// Return the layout of a mode +template +CUTE_HOST_DEVICE constexpr +decltype(auto) +layout(ComposedLayout const& clayout) +{ + return composition(clayout.layout_a(), clayout.offset(), layout(clayout.layout_b())); +} + +// Return the shape of a mode +template +CUTE_HOST_DEVICE constexpr +decltype(auto) +shape(ComposedLayout const& layout) +{ + return shape(layout.layout_b()); +} + +// Doesn't make sense to directly ask for the strides of this "layout" +template +CUTE_HOST_DEVICE constexpr +decltype(auto) +stride(ComposedLayout const& layout) = delete; + +// Return the number of elements in a mode +template +CUTE_HOST_DEVICE constexpr +decltype(auto) +size(ComposedLayout const& layout) +{ + return size(layout.layout_b()); +} + +// Return the number of modes +template +CUTE_HOST_DEVICE constexpr +auto +rank(ComposedLayout const& layout) +{ + return rank(layout.layout_b()); +} + +// Return the depth of the layout +template +CUTE_HOST_DEVICE constexpr +auto +depth(ComposedLayout const& layout) +{ + return depth(layout.layout_b()); +} + +// Return the codomain size of a mode +template +CUTE_HOST_DEVICE constexpr +auto +cosize(ComposedLayout const& layout) +{ + return cosize(layout.layout_b()); +} + +// +// Operations to manipulate Layouts like a tuple of pairs +// + +template +CUTE_HOST_DEVICE constexpr +auto +get(ComposedLayout const& a) +{ + return composition(a.layout_a(), a.offset(), get(a.layout_b())); +} + +template +CUTE_HOST_DEVICE constexpr +auto +take(ComposedLayout const& a) +{ + return composition(a.layout_a(), a.offset(), take(a.layout_b())); +} + +template +CUTE_HOST_DEVICE constexpr +auto +flatten(ComposedLayout const& a) +{ + return composition(a.layout_a(), a.offset(), flatten(a.layout_b())); +} + +template +CUTE_HOST_DEVICE constexpr +auto +append(ComposedLayout const& a, X const& x) +{ + return composition(a.layout_a(), a.offset(), append(a.layout_b(), x)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +group(ComposedLayout const& a) +{ + return composition(a.layout_a(), a.offset(), group(a.layout_b())); +} + +// +// Slice a ComposedLayout +// + +template +CUTE_HOST_DEVICE constexpr +auto +slice_and_offset(Coord const& coord, ComposedLayout const& layout) +{ + auto [slice, offset] = slice_and_offset(coord, layout.layout_b()); + return cute::make_tuple(ComposedLayout{layout.layout_a(), layout.offset() + offset, slice}, Int<0>{}); +} + +template +CUTE_HOST_DEVICE constexpr +auto +slice(Coord const& coord, ComposedLayout const& layout) +{ + return get<0>(slice_and_offset(coord, layout)); +} + +// Compute a pointer offset and (potentially modified) layout from a coordinate +// For composed layout tensors the offset is accumulated in the layout itself while pointer is not updated +template +CUTE_HOST_DEVICE constexpr +auto 
+domain_offset(Coord const& coord, ComposedLayout const& layout) +{ + return cute::make_tuple(ComposedLayout{layout.layout_a(), layout.offset() + layout.layout_b()(coord), layout.layout_b()}, Int<0>{}); +} + +// +// composition +// + +template +CUTE_HOST_DEVICE constexpr +auto +composition(LayoutA const& layoutA, + Offset const& offset, + LayoutB const& layoutB) +{ + return ComposedLayout{layoutA, offset, layoutB}; +} + +template +CUTE_HOST_DEVICE constexpr +auto +composition(ComposedLayout const& a, + Tiler const& b) +{ + return composition(a.layout_a(), a.offset(), composition(a.layout_b(), b)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +composition(Layout const& a, + ComposedLayout const& b) +{ + CUTE_STATIC_ASSERT_V(b.offset() == Int<0>{}, "Require offset == 0."); + + return composition(composition(a, b.layout_a()), b.layout_b()); +} + +// +// complement +// + +template +CUTE_HOST_DEVICE constexpr +auto +complement(ComposedLayout const& layout, CoTarget const& cotarget) +{ + return complement(layout.layout_b(), cotarget); +} + +template +CUTE_HOST_DEVICE constexpr +auto +complement(ComposedLayout const& layout) +{ + return complement(layout, cosize(layout)); +} + +// +// inverse +// + +template +CUTE_HOST_DEVICE constexpr +auto +right_inverse(ComposedLayout const& layout) +{ + return composition(right_inverse(layout.layout_b()), right_inverse(layout.offset()), right_inverse(layout.layout_a())); +} + +template +CUTE_HOST_DEVICE constexpr +auto +left_inverse(ComposedLayout const& layout) +{ + return composition(left_inverse(layout.layout_b()), left_inverse(layout.offset()), left_inverse(layout.layout_a())); +} + +// +// Other operations +// + +template +CUTE_HOST_DEVICE constexpr +auto +zip(ComposedLayout const& a) +{ + return composition(a.layout_a(), a.offset(), zip(a.layout_b())); +} + +// Partitions + +template +CUTE_HOST_DEVICE constexpr +auto +logical_divide(ComposedLayout const& a, + Tiler const& b) +{ + return composition(a.layout_a(), a.offset(), logical_divide(a.layout_b(), b)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +tile_unzip(ComposedLayout const& a, + Tiler const& b) +{ + return composition(a.layout_a(), a.offset(), tile_unzip(a.layout_b(), b)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +tiled_divide(ComposedLayout const& a, + Tiler const& b) +{ + return composition(a.layout_a(), a.offset(), tiled_divide(a.layout_b(), b)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +zipped_divide(ComposedLayout const& a, + Tiler const& b) +{ + return composition(a.layout_a(), a.offset(), zipped_divide(a.layout_b(), b)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +flat_divide(ComposedLayout const& a, + Tiler const& b) +{ + return composition(a.layout_a(), a.offset(), flat_divide(a.layout_b(), b)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +logical_product(ComposedLayout const& a, + Tiler const& b) +{ + return composition(a.layout_a(), a.offset(), logical_product(a.layout_b(), b)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +zipped_product(ComposedLayout const& a, + Tiler const& b) +{ + return composition(a.layout_a(), a.offset(), zipped_product(a.layout_b(), b)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +tiled_product(ComposedLayout const& a, + Tiler const& b) +{ + return composition(a.layout_a(), a.offset(), tiled_product(a.layout_b(), b)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +flat_product(ComposedLayout const& a, + Tiler const& b) +{ + return composition(a.layout_a(), a.offset(), flat_product(a.layout_b(), b)); +} + 
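// Editor's note -- an illustrative sketch, not part of the original header (it assumes
// <cute/layout.hpp> is already included, as it is when this header is used normally).
// Every divide/product above forwards to LayoutB and re-wraps the result, so tiling a
// ComposedLayout behaves exactly like tiling its inner layout:
CUTE_HOST_DEVICE constexpr
auto
composed_layout_tiling_sketch()   // hypothetical helper name, for illustration only
{
  auto inner = make_layout(make_shape(Int<8>{}, Int<8>{}));                              // (_8,_8):(_1,_8)
  auto comp  = make_composed_layout(make_layout(make_shape(Int<64>{})), Int<0>{}, inner); // outer o 0 o inner
  auto tiled = zipped_divide(comp, make_shape(Int<4>{}, Int<2>{}));                       // ((_4,_2),(_2,_4)) tile/rest modes
  return tiled;   // still a ComposedLayout; its LayoutB is zipped_divide(inner, tiler)
}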
+template +CUTE_HOST_DEVICE constexpr +auto +blocked_product(ComposedLayout const& a, + Tiler const& b) +{ + return composition(a.layout_a(), a.offset(), blocked_product(a.layout_b(), b)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +raked_product(ComposedLayout const& a, + Tiler const& b) +{ + return composition(a.layout_a(), a.offset(), raked_product(a.layout_b(), b)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +tile_to_shape(ComposedLayout const& layout, + Shape const& trg_shape, + ModeOrder const& ord_shape = {}) +{ + return composition(layout.layout_a(), layout.offset(), tile_to_shape(layout.layout_b(), trg_shape, ord_shape)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +filter(ComposedLayout const& layout, Shape const& trg_profile) +{ + return composition(layout.layout_a(), layout.offset(), filter(layout.layout_b(), trg_profile)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +coalesce(ComposedLayout const& layout) +{ + return composition(layout.layout_a(), layout.offset(), coalesce(layout.layout_b())); +} + +template +CUTE_HOST_DEVICE constexpr +auto +coalesce(ComposedLayout const& layout, Shape const& trg_profile) +{ + return composition(layout.layout_a(), layout.offset(), coalesce(layout.layout_b(), trg_profile)); +} + + +// +// Upcast and Downcast +// + +template +CUTE_HOST_DEVICE constexpr +auto +upcast(ComposedLayout const& layout) +{ + return composition(upcast(layout.layout_a()), upcast(layout.offset()), upcast(layout.layout_b())); +} + +template +CUTE_HOST_DEVICE constexpr +auto +downcast(ComposedLayout const& layout) +{ + return composition(downcast(layout.layout_a()), downcast(layout.offset()), downcast(layout.layout_b())); +} + + +template +CUTE_HOST_DEVICE constexpr +auto +recast_layout(ComposedLayout const& layout) +{ + using scale = decltype(trait_ratio(sizeof_bits{}, sizeof_bits{})); + if constexpr (scale::num == 1 && scale::den == 1) { + return layout; + } + else if constexpr (scale::num == 1) { + return downcast(layout); + } + else if constexpr (scale::den == 1) { + return upcast(layout); + } + else { + return downcast(upcast(layout)); + } + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +max_alignment(ComposedLayout const& layout) +{ + // Do not attempt for general ComposedLayouts + //return gcd(max_alignment(layout.layout_a()), max_alignment(layout.offset()), max_alignment(layout.layout_b())); + return Int<1>{}; +} + +template +CUTE_HOST_DEVICE constexpr +auto +nullspace(ComposedLayout const& layout) +{ + // Do not attempt for general ComposedLayouts + return Layout<_1,_0>{}; +} + +// +// Display utilities +// + +template +CUTE_HOST_DEVICE void print(ComposedLayout const& layout) +{ + print(layout.layout_a()); print(" o "); print(layout.offset()); print(" o "); print(layout.layout_b()); +} + +#if !defined(__CUDACC_RTC__) +template +CUTE_HOST std::ostream& operator<<(std::ostream& os, ComposedLayout const& layout) +{ + return os << layout.layout_a() << " o " << layout.offset() << " o " << layout.layout_b(); +} +#endif + +} // end namespace cute diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/numeric/arithmetic_tuple.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/numeric/arithmetic_tuple.hpp new file mode 100644 index 0000000000000000000000000000000000000000..016ac5b6cba80cf249698ea4b075840b8d6ed070 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/numeric/arithmetic_tuple.hpp @@ -0,0 
+1,535 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include +#include +#include +#include +#include + +namespace cute +{ + +template +struct ArithmeticTuple : public tuple { + CUTE_HOST_DEVICE constexpr + ArithmeticTuple() : tuple() {} + + CUTE_HOST_DEVICE constexpr + ArithmeticTuple(tuple const& t) : tuple(t) {} + + CUTE_HOST_DEVICE constexpr + ArithmeticTuple(T const&... t) : tuple(t...) {} +}; + +template +struct is_tuple> : true_type {}; + +template +struct is_flat> : is_flat> {}; + +template +CUTE_HOST_DEVICE constexpr +auto +make_arithmetic_tuple(T const&... t) { + return ArithmeticTuple(t...); +} + +template +CUTE_HOST_DEVICE constexpr +auto +as_arithmetic_tuple(T const& t) { + if constexpr (is_tuple::value) { + return detail::tapply(t, [](auto const& x){ return as_arithmetic_tuple(x); }, + [](auto const&... a){ return make_arithmetic_tuple(a...); }, + tuple_seq{}); + } else { + return t; + } + + CUTE_GCC_UNREACHABLE; +} + +// +// Numeric operators +// + +// Addition +template +CUTE_HOST_DEVICE constexpr +auto +operator+(ArithmeticTuple const& t, ArithmeticTuple const& u) { + constexpr int R = cute::max(int(sizeof...(T)), int(sizeof...(U))); + return transform_apply(append(t,Int<0>{}), append(u,Int<0>{}), plus{}, [](auto const&... 
a){ return make_arithmetic_tuple(a...); }); +} + +template +CUTE_HOST_DEVICE constexpr +auto +operator+(ArithmeticTuple const& t, tuple const& u) { + return t + ArithmeticTuple(u); +} + +template +CUTE_HOST_DEVICE constexpr +auto +operator+(tuple const& t, ArithmeticTuple const& u) { + return ArithmeticTuple(t) + u; +} + +// Subtraction +template +CUTE_HOST_DEVICE constexpr +auto +operator-(ArithmeticTuple const& t, ArithmeticTuple const& u) { + constexpr int R = cute::max(int(sizeof...(T)), int(sizeof...(U))); + return transform_apply(append(t,Int<0>{}), append(u,Int<0>{}), minus{}, [](auto const&... a){ return make_arithmetic_tuple(a...); }); +} + +template +CUTE_HOST_DEVICE constexpr +auto +operator-(ArithmeticTuple const& t, tuple const& u) { + return t - ArithmeticTuple(u); +} + +template +CUTE_HOST_DEVICE constexpr +auto +operator-(tuple const& t, ArithmeticTuple const& u) { + return ArithmeticTuple(t) - u; +} + +// Negation +template +CUTE_HOST_DEVICE constexpr +auto +operator-(ArithmeticTuple const& t) { + return transform_apply(t, negate{}, [](auto const&... a){ return make_arithmetic_tuple(a...); }); +} + +// +// Special cases for C<0> +// + +template +CUTE_HOST_DEVICE constexpr +ArithmeticTuple +operator+(C, ArithmeticTuple const& u) { + static_assert(t == 0, "Arithmetic tuple op+ error!"); + return u; +} + +template +CUTE_HOST_DEVICE constexpr +ArithmeticTuple +operator+(ArithmeticTuple const& t, C) { + static_assert(u == 0, "Arithmetic tuple op+ error!"); + return t; +} + +template +CUTE_HOST_DEVICE constexpr +ArithmeticTuple +operator-(C, ArithmeticTuple const& u) { + static_assert(t == 0, "Arithmetic tuple op- error!"); + return -u; +} + +template +CUTE_HOST_DEVICE constexpr +ArithmeticTuple +operator-(ArithmeticTuple const& t, C) { + static_assert(u == 0, "Arithmetic tuple op- error!"); + return t; +} + +// +// ArithmeticTupleIterator +// + +template +struct ArithmeticTupleIterator +{ + using value_type = ArithTuple; + using element_type = ArithTuple; + using reference = ArithTuple; + + ArithTuple coord_; + + CUTE_HOST_DEVICE constexpr + ArithmeticTupleIterator(ArithTuple const& coord = {}) : coord_(coord) {} + + CUTE_HOST_DEVICE constexpr + ArithTuple operator*() const { return coord_; } + + template + CUTE_HOST_DEVICE constexpr + auto operator[](Coord const& c) const { return *(*this + c); } + + template + CUTE_HOST_DEVICE constexpr + auto operator+(Coord const& c) const { + return ArithmeticTupleIterator>(coord_ + c); + } +}; + +template +CUTE_HOST_DEVICE constexpr +auto +make_inttuple_iter(Ts const&... ts) { + return ArithmeticTupleIterator(as_arithmetic_tuple(ts...)); +} + +// +// ArithmeticTuple "basis" elements +// A ScaledBasis is a (at least) rank-N+1 ArithmeticTuple: +// (_0,_0,...,T,_0,...) +// with value T in the Nth mode + +template +struct ScaledBasis : private tuple +{ + CUTE_HOST_DEVICE constexpr + ScaledBasis(T const& t = {}) : tuple(t) {} + + CUTE_HOST_DEVICE constexpr + decltype(auto) value() { return get<0>(static_cast &>(*this)); } + CUTE_HOST_DEVICE constexpr + decltype(auto) value() const { return get<0>(static_cast const&>(*this)); } + + // Deprecated: Get the first hierarchical mode in this basis. 
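  // Editor's note (illustrative sketch, not part of the original header): E<N> below is
  // the unit basis for mode N, so basis elements scale and add into hierarchical strides:
  //   Int<3>{} * E<1>{}            // ScaledBasis<Int<3>,1>, i.e. the tuple (_0,_3)
  //   E<0>{} + Int<2>{} * E<1>{}   // ArithmeticTuple (_1,_2)
  //   basis_get(E<1>{}, t)         // picks mode 1 (i.e. get<1>) out of a tuple t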
+ CUTE_HOST_DEVICE static constexpr + auto mode() { return get<0>(int_sequence{}); } +}; + +// Ensure flat representation +template +struct ScaledBasis, Ns...> : ScaledBasis {}; + +template +struct is_scaled_basis : false_type {}; +template +struct is_scaled_basis> : true_type {}; + +template +struct is_integral> : true_type {}; + +// Shortcuts +// E<> := _1 +// E<0> := (_1,_0,_0,...) +// E<1> := (_0,_1,_0,...) +// E<0,0> := ((_1,_0,_0,...),_0,_0,...) +// E<0,1> := ((_0,_1,_0,...),_0,_0,...) +// E<1,0> := (_0,(_1,_0,_0,...),_0,...) +// E<1,1> := (_0,(_0,_1,_0,...),_0,...) +template +using E = ScaledBasis,Ns...>; + +// Apply the Ns... pack to another Tuple +template +CUTE_HOST_DEVICE decltype(auto) +basis_get(T const&, Tuple&& t) +{ + return static_cast(t); +} + +template +CUTE_HOST_DEVICE decltype(auto) +basis_get(ScaledBasis const&, Tuple&& t) +{ + if constexpr (sizeof...(Ns) == 0) { + return static_cast(t); + } else { + return get(static_cast(t)); + } + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE decltype(auto) +basis_value(T const& e) { + if constexpr (is_scaled_basis::value) { + return e.value(); + } else { + return e; + } + CUTE_GCC_UNREACHABLE; +} + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr +auto +to_atuple_i(T const& t, seq) { + return make_arithmetic_tuple((void(I),Int<0>{})..., t); +} + +} // end namespace detail + +// Turn a ScaledBases into a rank-N+1 ArithmeticTuple +// with N prefix 0s: (_0,_0,...N...,_0,T) +template +CUTE_HOST_DEVICE constexpr +auto +as_arithmetic_tuple(ScaledBasis const& t) { + return t.value(); +} + +template +CUTE_HOST_DEVICE constexpr +auto +as_arithmetic_tuple(ScaledBasis const& t) { + return detail::to_atuple_i(as_arithmetic_tuple(ScaledBasis{t.value()}), make_seq{}); +} + +template +CUTE_HOST_DEVICE constexpr +auto +make_basis_like(Shape const& shape) +{ + if constexpr (is_tuple::value) { + // Generate bases for each mode of shape + return transform(tuple_seq{}, shape, [](auto I, auto si) { + // Generate bases for each si and add an i on end + return make_basis_like(si); + }); + } else { + return E{}; + } + CUTE_GCC_UNREACHABLE; +} + +// +// Arithmetic +// + +template +CUTE_HOST_DEVICE constexpr +auto +safe_div(ScaledBasis const& b, U const& u) +{ + auto t = safe_div(b.value(), u); + return ScaledBasis{t}; +} + +template +CUTE_HOST_DEVICE constexpr +auto +ceil_div(ScaledBasis const& b, U const& u) +{ + auto t = ceil_div(b.value(), u); + return ScaledBasis{t}; +} + +template +CUTE_HOST_DEVICE constexpr +auto +abs(ScaledBasis const& e) +{ + auto t = abs(e.value()); + return ScaledBasis{t}; +} + +// Equality +template +CUTE_HOST_DEVICE constexpr +auto +operator==(ScaledBasis const& t, ScaledBasis const& u) { + if constexpr (sizeof...(Ns) == sizeof...(Ms)) { + return bool_constant<((Ns == Ms) && ...)>{} && t.value() == u.value(); + } else { + return false_type{}; + } + CUTE_GCC_UNREACHABLE; +} + +// Not equal to anything else +template +CUTE_HOST_DEVICE constexpr +false_type +operator==(ScaledBasis const&, U const&) { + return {}; +} + +template +CUTE_HOST_DEVICE constexpr +false_type +operator==(T const&, ScaledBasis const&) { + return {}; +} + +// Multiplication +template +CUTE_HOST_DEVICE constexpr +auto +operator*(A const& a, ScaledBasis const& e) { + auto r = a * e.value(); + return ScaledBasis{r}; +} + +template +CUTE_HOST_DEVICE constexpr +auto +operator*(ScaledBasis const& e, B const& b) { + auto r = e.value() * b; + return ScaledBasis{r}; +} + +// Addition +template +CUTE_HOST_DEVICE constexpr +auto 
+operator+(ScaledBasis const& t, ScaledBasis const& u) { + return as_arithmetic_tuple(t) + as_arithmetic_tuple(u); +} + +template +CUTE_HOST_DEVICE constexpr +auto +operator+(ScaledBasis const& t, ArithmeticTuple const& u) { + return as_arithmetic_tuple(t) + u; +} + +template +CUTE_HOST_DEVICE constexpr +auto +operator+(ArithmeticTuple const& t, ScaledBasis const& u) { + return t + as_arithmetic_tuple(u); +} + +template +CUTE_HOST_DEVICE constexpr +auto +operator+(C, ScaledBasis const& u) { + if constexpr (sizeof...(Ms) == 0) { + return C{} + u.value(); + } else { + static_assert(t == 0, "ScaledBasis op+ error!"); + return u; + } + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +operator+(ScaledBasis const& t, C) { + if constexpr (sizeof...(Ns) == 0) { + return t.value() + C{}; + } else { + static_assert(u == 0, "ScaledBasis op+ error!"); + return t; + } + CUTE_GCC_UNREACHABLE; +} + +// +// Display utilities +// + +template +CUTE_HOST_DEVICE void print(ArithmeticTupleIterator const& iter) +{ + printf("ArithTuple"); print(iter.coord_); +} + +template +CUTE_HOST_DEVICE void print(ScaledBasis const& e) +{ + print(e.value()); + // Param pack trick to print in reverse + [[maybe_unused]] int dummy; (dummy = ... = (void(printf("@%d", Ns)), 0)); +} + +#if !defined(__CUDACC_RTC__) +template +CUTE_HOST std::ostream& operator<<(std::ostream& os, ArithmeticTupleIterator const& iter) +{ + return os << "ArithTuple" << iter.coord_; +} + +template +CUTE_HOST std::ostream& operator<<(std::ostream& os, ScaledBasis const& e) +{ + os << e.value(); + // Param pack trick to print in reverse + [[maybe_unused]] int dummy; (dummy = ... = (void(os << "@" << Ns),0)); + return os; +} +#endif + +} // end namespace cute + + +namespace CUTE_STL_NAMESPACE +{ + +template +struct tuple_size> + : CUTE_STL_NAMESPACE::integral_constant +{}; + +template +struct tuple_element> + : CUTE_STL_NAMESPACE::tuple_element> +{}; + +} // end namespace CUTE_STL_NAMESPACE + +#ifdef CUTE_STL_NAMESPACE_IS_CUDA_STD +namespace std +{ + +#if defined(__CUDACC_RTC__) +template +struct tuple_size; + +template +struct tuple_element; +#endif + +template +struct tuple_size> + : CUTE_STL_NAMESPACE::integral_constant +{}; + +template +struct tuple_element> + : CUTE_STL_NAMESPACE::tuple_element> +{}; + +} // end namespace std +#endif // CUTE_STL_NAMESPACE_IS_CUDA_STD diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/numeric/complex.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/numeric/complex.hpp new file mode 100644 index 0000000000000000000000000000000000000000..3115d615179d145f6c2782aaa9e09424fdbd36ba --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/numeric/complex.hpp @@ -0,0 +1,76 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include // CUTE_HOST_DEVICE + +#include // cutlass::complexm, cutlass::real, cutlass::imag, cutlass::is_complex + +namespace cute +{ + +using cutlass::complex; +using cutlass::is_complex; +using cutlass::RealType; +using cutlass::real; +using cutlass::imag; +using cutlass::conj; + +template +static constexpr auto is_complex_v = is_complex::value; + +/// Fused multiply-add for complex numbers +template +CUTE_HOST_DEVICE constexpr +void +fma(complex & d, + complex const& a, + complex const& b, + complex const& c) +{ + fma(d.real(), a.real(), b.real(), c.real()); + fma(d.imag(), a.real(), b.imag(), c.imag()); + fma(d.real(), -a.imag(), b.imag(), d.real()); + fma(d.imag(), a.imag(), b.real(), d.imag()); +} + +/// Fused multiply-add for triplets +template +CUTE_HOST_DEVICE constexpr +void +fma(complex const& a, + complex const& b, + complex & c) +{ + return fma(c, a, b, c); +} + +} // end namespace cute diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/numeric/int.hpp b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/numeric/int.hpp new file mode 100644 index 0000000000000000000000000000000000000000..cfef8cc5b237f73b145a399601742cc7ee8b6124 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/numeric/int.hpp @@ -0,0 +1,111 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once +#include "cutlass/cutlass.h" +#if defined(__CUDACC_RTC__) +#include CUDA_STD_HEADER(cstdint) +#else +#include +#endif + +#include // CUTE_STL_NAMESPACE + +#include // cutlass::int2b_t, cutlass::int4b_t + +namespace cute +{ + +// +// Signed integers +// + +using int2_t = cutlass::int2b_t; +using int4_t = cutlass::int4b_t; +using int6_t = cutlass::int6b_t; +using CUTE_STL_NAMESPACE::int8_t; +using CUTE_STL_NAMESPACE::int16_t; +using CUTE_STL_NAMESPACE::int32_t; +using CUTE_STL_NAMESPACE::int64_t; + +template struct int_bit; +template <> struct int_bit< 2> { using type = int2_t; }; +template <> struct int_bit< 4> { using type = int4_t; }; +template <> struct int_bit< 8> { using type = int8_t; }; +template <> struct int_bit< 16> { using type = int16_t; }; +template <> struct int_bit< 32> { using type = int32_t; }; +template <> struct int_bit< 64> { using type = int64_t; }; + +template +using int_bit_t = typename int_bit::type; + +template +using int_byte = int_bit<8*N>; + +template +using int_byte_t = typename int_byte::type; + +// +// Unsigned integers +// + +using uint1_t = cutlass::uint1b_t; +using uint2_t = cutlass::uint2b_t; +using uint4_t = cutlass::uint4b_t; +using uint6_t = cutlass::uint6b_t; +using CUTE_STL_NAMESPACE::uint8_t; +using CUTE_STL_NAMESPACE::uint16_t; +using CUTE_STL_NAMESPACE::uint32_t; +using CUTE_STL_NAMESPACE::uint64_t; +using cutlass::uint128_t; +using cutlass::uint256_t; + +template struct uint_bit; +template <> struct uint_bit< 1> { using type = uint1_t; }; +template <> struct uint_bit< 2> { using type = uint2_t; }; +template <> struct uint_bit< 4> { using type = uint4_t; }; +template <> struct uint_bit< 6> { using type = uint6_t; }; +template <> struct uint_bit< 8> { using type = uint8_t; }; +template <> struct uint_bit< 16> { using type = uint16_t; }; +template <> struct uint_bit< 32> { using type = uint32_t; }; +template <> struct uint_bit< 64> { using type = uint64_t; }; +template <> struct uint_bit<128> { using type = cutlass::uint128_t; }; +template <> struct uint_bit<256> { using type = cutlass::uint256_t; }; + +template +using uint_bit_t = typename uint_bit::type; + +template +using uint_byte = uint_bit<8*N>; + +template +using uint_byte_t = typename uint_byte::type; + +} // namespace cute diff --git a/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/numeric/integer_sequence.hpp 
b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/numeric/integer_sequence.hpp new file mode 100644 index 0000000000000000000000000000000000000000..799e18966233f0ca75d7bfad5c9e39c5c88d1d70 --- /dev/null +++ b/build/torch29-cxx11-cu126-aarch64-linux/include/third-party/cutlass/include/cute/numeric/integer_sequence.hpp @@ -0,0 +1,176 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include +#include +#include + +namespace cute +{ + +using CUTE_STL_NAMESPACE::integer_sequence; +using CUTE_STL_NAMESPACE::make_integer_sequence; + +namespace detail { + +template +struct range_impl; + +template +struct range_impl, Begin> { + using type = integer_sequence; +}; + +template +struct reverse_impl; + +template +struct reverse_impl> { + using type = integer_sequence; +}; + +} // end namespace detail + +template +using make_integer_range = typename detail::range_impl< + T, + make_integer_sequence 0) ? 
(End-Begin) : 0>, + Begin>::type; + +template +using make_integer_sequence_reverse = typename detail::reverse_impl< + make_integer_sequence>::type; + +// +// Common aliases +// + +// int_sequence + +template +using int_sequence = integer_sequence; + +template +using make_int_sequence = make_integer_sequence; + +template +using make_int_rsequence = make_integer_sequence_reverse; + +template +using make_int_range = make_integer_range; + +// index_sequence + +template +using index_sequence = integer_sequence; + +template +using make_index_sequence = make_integer_sequence; + +template +using make_index_rsequence = make_integer_sequence_reverse; + +template +using make_index_range = make_integer_range; + +// +// Shortcuts +// + +template +using seq = int_sequence; + +template +using make_seq = make_int_sequence; + +template +using make_rseq = make_int_rsequence; + +template +using make_range = make_int_range; + +template +using tuple_seq = make_seq>::value>; + +template +using tuple_rseq = make_rseq>::value>; + +// +// Convert a parameter pack to an int sequence +// + +template +struct to_seq; + +template <> +struct to_seq> { + using type = seq<>; +}; + +template +struct to_seq> { + using type = seq; +}; + +template